[pycharm] PY-38294 Add Sparse tensors support

GitOrigin-RevId: 0b9b046ae1da37b5fd34be800fbdd6f241f7514d
This commit is contained in:
ekaterina.itsenko
2024-06-11 04:03:23 +02:00
committed by intellij-monorepo-bot
parent fc7c02e181
commit 1693fd94b9
9 changed files with 31 additions and 2 deletions

View File

@@ -13,6 +13,7 @@ enum class NotebookOutputKeyType {
NUMPY_ARRAY, NUMPY_ARRAY,
EAGER_TENSOR, EAGER_TENSOR,
RESOURCE_VARIABLE, RESOURCE_VARIABLE,
SPARSE_TENSOR,
TORCH_TENSOR, TORCH_TENSOR,
PANDAS_DATA_FRAME, PANDAS_DATA_FRAME,
PANDAS_SERIES, PANDAS_SERIES,

View File

@@ -75,6 +75,7 @@ def __get_table_provider(output):
elif type_qualified_name in ['numpy.ndarray', elif type_qualified_name in ['numpy.ndarray',
'tensorflow.python.framework.ops.EagerTensor', 'tensorflow.python.framework.ops.EagerTensor',
'tensorflow.python.ops.resource_variable_ops.ResourceVariable', 'tensorflow.python.ops.resource_variable_ops.ResourceVariable',
'tensorflow.python.framework.sparse_tensor.SparseTensor',
'torch.Tensor', 'torch.Tensor',
'builtins.dict']: 'builtins.dict']:
import _pydevd_bundle.tables.pydevd_numpy as table_provider import _pydevd_bundle.tables.pydevd_numpy as table_provider

View File

@@ -412,6 +412,11 @@ def tensor_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
return array_to_thrift_struct(tensor.numpy(), name, roffset, coffset, rows, cols, format) return array_to_thrift_struct(tensor.numpy(), name, roffset, coffset, rows, cols, format)
def sparse_tensor_to_thrift_struct(tensor, name, roffset, coffset, rows, cols, format):
import tensorflow as tf
return tensor_to_thrift_struct(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
def array_to_meta_thrift_struct(array, name, format): def array_to_meta_thrift_struct(array, name, format):
type = array.dtype.kind type = array.dtype.kind
slice = name slice = name
@@ -609,6 +614,7 @@ TYPE_TO_THRIFT_STRUCT_CONVERTERS = {
"ndarray": array_to_thrift_struct, "ndarray": array_to_thrift_struct,
"EagerTensor": tensor_to_thrift_struct, "EagerTensor": tensor_to_thrift_struct,
"ResourceVariable": tensor_to_thrift_struct, "ResourceVariable": tensor_to_thrift_struct,
"SparseTensor": sparse_tensor_to_thrift_struct,
"Tensor": tensor_to_thrift_struct, "Tensor": tensor_to_thrift_struct,
"DataFrame": dataframe_to_thrift_struct, "DataFrame": dataframe_to_thrift_struct,
"Series": dataframe_to_thrift_struct, "Series": dataframe_to_thrift_struct,

View File

@@ -586,6 +586,11 @@ def tensor_to_xml(tensor, name, roffset, coffset, rows, cols, format):
return array_to_xml(tensor.numpy(), name, roffset, coffset, rows, cols, format) return array_to_xml(tensor.numpy(), name, roffset, coffset, rows, cols, format)
def sparse_tensor_to_xml(tensor, name, roffset, coffset, rows, cols, format):
import tensorflow as tf
return tensor_to_xml(tf.sparse.to_dense(tf.sparse.reorder(tensor)), name, roffset, coffset, rows, cols, format)
class ExceedingArrayDimensionsException(Exception): class ExceedingArrayDimensionsException(Exception):
pass pass
@@ -811,6 +816,7 @@ TYPE_TO_XML_CONVERTERS = {
"GeoSeries": dataframe_to_xml, "GeoSeries": dataframe_to_xml,
"EagerTensor": tensor_to_xml, "EagerTensor": tensor_to_xml,
"ResourceVariable": tensor_to_xml, "ResourceVariable": tensor_to_xml,
"SparseTensor": sparse_tensor_to_xml,
"Tensor": tensor_to_xml "Tensor": tensor_to_xml
} }

View File

@@ -220,7 +220,12 @@ def _create_table(command, start_index=None, end_index=None):
np_array = command['data'] np_array = command['data']
sort_keys = command['sort_keys'] sort_keys = command['sort_keys']
else: else:
np_array = command try:
import tensorflow as tf
if isinstance(command, tf.SparseTensor):
command = tf.sparse.to_dense(tf.sparse.reorder(command))
finally:
np_array = command
if is_pd: if is_pd:
sorting_arr = _sort_df(pd.DataFrame(np_array), sort_keys) sorting_arr = _sort_df(pd.DataFrame(np_array), sort_keys)

View File

@@ -37,6 +37,7 @@ public class PyDebugValue extends XNamedValue {
"ndarray", ARRAY, "ndarray", ARRAY,
"EagerTensor", ARRAY, "EagerTensor", ARRAY,
"ResourceVariable", ARRAY, "ResourceVariable", ARRAY,
"SparseTensor", ARRAY,
"Tensor", ARRAY, "Tensor", ARRAY,
DATA_FRAME, DATA_FRAME, DATA_FRAME, DATA_FRAME,
SERIES, SERIES, SERIES, SERIES,

View File

@@ -27,6 +27,10 @@ public class ArrayViewStrategy extends DataViewStrategy {
return new ArrayViewStrategy("ResourceVariable"); return new ArrayViewStrategy("ResourceVariable");
} }
public static @NotNull ArrayViewStrategy createInstanceForSparseTensor() {
return new ArrayViewStrategy("SparseTensor");
}
public static @NotNull ArrayViewStrategy createInstanceForTensor() { public static @NotNull ArrayViewStrategy createInstanceForTensor() {
return new ArrayViewStrategy("Tensor"); return new ArrayViewStrategy("Tensor");
} }

View File

@@ -21,6 +21,7 @@ public abstract class DataViewStrategy {
ArrayViewStrategy.createInstanceForNumpyArray(), ArrayViewStrategy.createInstanceForNumpyArray(),
ArrayViewStrategy.createInstanceForEagerTensor(), ArrayViewStrategy.createInstanceForEagerTensor(),
ArrayViewStrategy.createInstanceForResourceVariable(), ArrayViewStrategy.createInstanceForResourceVariable(),
ArrayViewStrategy.createInstanceForSparseTensor(),
ArrayViewStrategy.createInstanceForTensor(), ArrayViewStrategy.createInstanceForTensor(),
DataFrameViewStrategy.createInstanceForDataFrame(), DataFrameViewStrategy.createInstanceForDataFrame(),
DataFrameViewStrategy.createInstanceForGeoDataFrame(), DataFrameViewStrategy.createInstanceForGeoDataFrame(),

View File

@@ -64,7 +64,11 @@ public class PyViewNumericContainerAction extends XDebuggerTreeActionBase {
} }
String nodeType = debugValue.getType(); String nodeType = debugValue.getType();
if ("ndarray".equals(nodeType) || "EagerTensor".equals(nodeType) || "ResourceVariable".equals(nodeType) || "Tensor".equals(nodeType)) { if ("ndarray".equals(nodeType) ||
"EagerTensor".equals(nodeType) ||
"ResourceVariable".equals(nodeType) ||
"SparseTensor".equals(nodeType) ||
"Tensor".equals(nodeType)) {
e.getPresentation().setText(PyBundle.message("debugger.numeric.view.as.array")); e.getPresentation().setText(PyBundle.message("debugger.numeric.view.as.array"));
e.getPresentation().setVisible(true); e.getPresentation().setVisible(true);
} }