mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-13 15:52:01 +07:00
[pycharm] PY-80607 Tables(Jupyter, DataView): support requires_grad case
(cherry picked from commit 8f8be0cc80f3473e6fabd2f9cf9034a2103cdbfd) GitOrigin-RevId: cc10b1639b4be14201e2a7a89655146e4e4bbead
This commit is contained in:
committed by
intellij-monorepo-bot
parent
ca9ae7baa5
commit
8828a301b0
@@ -420,17 +420,26 @@ def array_to_thrift_struct(array, name, roffset, coffset, rows, cols, format):
|
||||
return array_chunk
|
||||
|
||||
|
||||
def tensor_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
|
||||
def tf_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
|
||||
try:
|
||||
return array_to_thrift_struct(tensor.numpy(), name, roffset, coffset, rows, cols, format)
|
||||
except TypeError:
|
||||
return array_to_thrift_struct(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
|
||||
|
||||
|
||||
def sparse_tensor_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
|
||||
def torch_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
|
||||
try:
|
||||
if tensor.requires_grad:
|
||||
tensor = tensor.detach()
|
||||
return array_to_thrift_struct(tensor.numpy(), name, roffset, coffset, rows, cols, format)
|
||||
except TypeError:
|
||||
return array_to_thrift_struct(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
|
||||
|
||||
|
||||
def tf_sparse_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
return tensor_to_thrift_struct(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
|
||||
return tf_to_thrift_struct(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -646,10 +655,10 @@ def header_data_to_thrift_struct(rows, cols, dtypes, col_bounds, col_to_format,
|
||||
TYPE_TO_THRIFT_STRUCT_CONVERTERS = {
|
||||
"ndarray": array_to_thrift_struct,
|
||||
"recarray": array_to_thrift_struct,
|
||||
"EagerTensor": tensor_to_thrift_struct,
|
||||
"ResourceVariable": tensor_to_thrift_struct,
|
||||
"SparseTensor": sparse_tensor_to_thrift_struct,
|
||||
"Tensor": tensor_to_thrift_struct,
|
||||
"EagerTensor": tf_to_thrift_struct,
|
||||
"ResourceVariable": tf_to_thrift_struct,
|
||||
"SparseTensor": tf_sparse_to_thrift_struct,
|
||||
"Tensor": torch_to_thrift_struct,
|
||||
"DataFrame": dataframe_to_thrift_struct,
|
||||
"Series": dataframe_to_thrift_struct,
|
||||
"Dataset": dataset_to_thrift_struct,
|
||||
|
||||
@@ -589,17 +589,26 @@ def array_to_xml(array, name, roffset, coffset, rows, cols, format):
|
||||
return xml
|
||||
|
||||
|
||||
def tensor_to_xml(tensor, name, roffset, coffset, rows, cols, format):
|
||||
def tf_to_xml(tensor, name, roffset, coffset, rows, cols, format):
|
||||
try:
|
||||
return array_to_xml(tensor.numpy(), name, roffset, coffset, rows, cols, format)
|
||||
except TypeError:
|
||||
return array_to_xml(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
|
||||
|
||||
|
||||
def sparse_tensor_to_xml(tensor, name, roffset, coffset, rows, cols, format):
|
||||
def torch_to_xml(tensor, name, roffset, coffset, rows, cols, format):
|
||||
try:
|
||||
if tensor.requires_grad:
|
||||
tensor = tensor.detach()
|
||||
return array_to_xml(tensor.numpy(), name, roffset, coffset, rows, cols, format)
|
||||
except TypeError:
|
||||
return array_to_xml(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
|
||||
|
||||
|
||||
def tf_sparse_to_xml(tensor, name, roffset, coffset, rows, cols, format):
|
||||
try:
|
||||
import tensorflow as tf
|
||||
return tensor_to_xml(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
|
||||
return tf_to_xml(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
@@ -861,10 +870,10 @@ TYPE_TO_XML_CONVERTERS = {
|
||||
"Series": dataframe_to_xml,
|
||||
"GeoDataFrame": dataframe_to_xml,
|
||||
"GeoSeries": dataframe_to_xml,
|
||||
"EagerTensor": tensor_to_xml,
|
||||
"ResourceVariable": tensor_to_xml,
|
||||
"SparseTensor": sparse_tensor_to_xml,
|
||||
"Tensor": tensor_to_xml,
|
||||
"EagerTensor": tf_to_xml,
|
||||
"ResourceVariable": tf_to_xml,
|
||||
"SparseTensor": tf_sparse_to_xml,
|
||||
"Tensor": torch_to_xml,
|
||||
"Dataset": dataset_to_xml
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,14 @@ def create_image(arr):
|
||||
arr_to_convert = tf.sparse.to_dense(tf.sparse.reorder(arr_to_convert))
|
||||
except ImportError:
|
||||
pass
|
||||
try:
|
||||
import torch
|
||||
if isinstance(arr_to_convert, torch.Tensor):
|
||||
if arr_to_convert.requires_grad:
|
||||
arr_to_convert = arr_to_convert.detach()
|
||||
arr_to_convert = arr_to_convert.to_dense()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
arr_to_convert = arr_to_convert.numpy()
|
||||
arr_to_convert = np.where(arr_to_convert == None, 0, arr_to_convert)
|
||||
|
||||
@@ -287,6 +287,8 @@ def __create_table(command, start_index=None, end_index=None, format=None):
|
||||
try:
|
||||
import torch
|
||||
if isinstance(np_array, torch.Tensor):
|
||||
if np_array.requires_grad:
|
||||
np_array = np_array.detach()
|
||||
np_array = np_array.to_dense()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user