[pycharm] PY-80607 Tables(Jupyter, DataView): support requires_grad case

(cherry picked from commit 8f8be0cc80f3473e6fabd2f9cf9034a2103cdbfd)

GitOrigin-RevId: cc10b1639b4be14201e2a7a89655146e4e4bbead
This commit is contained in:
ekaterina.itsenko
2025-07-04 14:49:21 +02:00
committed by intellij-monorepo-bot
parent ca9ae7baa5
commit 8828a301b0
4 changed files with 42 additions and 14 deletions

View File

@@ -420,17 +420,26 @@ def array_to_thrift_struct(array, name, roffset, coffset, rows, cols, format):
return array_chunk
def tensor_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
def tf_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
try:
return array_to_thrift_struct(tensor.numpy(), name, roffset, coffset, rows, cols, format)
except TypeError:
return array_to_thrift_struct(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
def sparse_tensor_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
def torch_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
try:
if tensor.requires_grad:
tensor = tensor.detach()
return array_to_thrift_struct(tensor.numpy(), name, roffset, coffset, rows, cols, format)
except TypeError:
return array_to_thrift_struct(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
def tf_sparse_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
try:
import tensorflow as tf
return tensor_to_thrift_struct(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
return tf_to_thrift_struct(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
except ImportError:
pass
@@ -646,10 +655,10 @@ def header_data_to_thrift_struct(rows, cols, dtypes, col_bounds, col_to_format,
TYPE_TO_THRIFT_STRUCT_CONVERTERS = {
"ndarray": array_to_thrift_struct,
"recarray": array_to_thrift_struct,
"EagerTensor": tensor_to_thrift_struct,
"ResourceVariable": tensor_to_thrift_struct,
"SparseTensor": sparse_tensor_to_thrift_struct,
"Tensor": tensor_to_thrift_struct,
"EagerTensor": tf_to_thrift_struct,
"ResourceVariable": tf_to_thrift_struct,
"SparseTensor": tf_sparse_to_thrift_struct,
"Tensor": torch_to_thrift_struct,
"DataFrame": dataframe_to_thrift_struct,
"Series": dataframe_to_thrift_struct,
"Dataset": dataset_to_thrift_struct,

View File

@@ -589,17 +589,26 @@ def array_to_xml(array, name, roffset, coffset, rows, cols, format):
return xml
def tensor_to_xml(tensor, name, roffset, coffset, rows, cols, format):
def tf_to_xml(tensor, name, roffset, coffset, rows, cols, format):
try:
return array_to_xml(tensor.numpy(), name, roffset, coffset, rows, cols, format)
except TypeError:
return array_to_xml(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
def sparse_tensor_to_xml(tensor, name, roffset, coffset, rows, cols, format):
def torch_to_xml(tensor, name, roffset, coffset, rows, cols, format):
try:
if tensor.requires_grad:
tensor = tensor.detach()
return array_to_xml(tensor.numpy(), name, roffset, coffset, rows, cols, format)
except TypeError:
return array_to_xml(tensor.to_dense().numpy(), name, roffset, coffset, rows, cols, format)
def tf_sparse_to_xml(tensor, name, roffset, coffset, rows, cols, format):
try:
import tensorflow as tf
return tensor_to_xml(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
return tf_to_xml(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
except ImportError:
pass
@@ -861,10 +870,10 @@ TYPE_TO_XML_CONVERTERS = {
"Series": dataframe_to_xml,
"GeoDataFrame": dataframe_to_xml,
"GeoSeries": dataframe_to_xml,
"EagerTensor": tensor_to_xml,
"ResourceVariable": tensor_to_xml,
"SparseTensor": sparse_tensor_to_xml,
"Tensor": tensor_to_xml,
"EagerTensor": tf_to_xml,
"ResourceVariable": tf_to_xml,
"SparseTensor": tf_sparse_to_xml,
"Tensor": torch_to_xml,
"Dataset": dataset_to_xml
}

View File

@@ -31,6 +31,14 @@ def create_image(arr):
arr_to_convert = tf.sparse.to_dense(tf.sparse.reorder(arr_to_convert))
except ImportError:
pass
try:
import torch
if isinstance(arr_to_convert, torch.Tensor):
if arr_to_convert.requires_grad:
arr_to_convert = arr_to_convert.detach()
arr_to_convert = arr_to_convert.to_dense()
except Exception:
pass
arr_to_convert = arr_to_convert.numpy()
arr_to_convert = np.where(arr_to_convert == None, 0, arr_to_convert)

View File

@@ -287,6 +287,8 @@ def __create_table(command, start_index=None, end_index=None, format=None):
try:
import torch
if isinstance(np_array, torch.Tensor):
if np_array.requires_grad:
np_array = np_array.detach()
np_array = np_array.to_dense()
except Exception:
pass