diff --git a/notebooks/visualization/src/org/jetbrains/plugins/notebooks/visualization/outputs/statistic/NotebookOutputKeyType.kt b/notebooks/visualization/src/org/jetbrains/plugins/notebooks/visualization/outputs/statistic/NotebookOutputKeyType.kt index 02fc6a43c162..cbf635925958 100644 --- a/notebooks/visualization/src/org/jetbrains/plugins/notebooks/visualization/outputs/statistic/NotebookOutputKeyType.kt +++ b/notebooks/visualization/src/org/jetbrains/plugins/notebooks/visualization/outputs/statistic/NotebookOutputKeyType.kt @@ -11,6 +11,9 @@ enum class NotebookOutputKeyType { LETS_PLOT, MARKDOWN, NUMPY_ARRAY, + EAGER_TENSOR, + RESOURCE_VARIABLE, + TORCH_TENSOR, PANDAS_DATA_FRAME, PANDAS_SERIES, POLARS_DATA_FRAME, diff --git a/python/helpers/pydev/_pydevd_bundle/pydevd_tables.py b/python/helpers/pydev/_pydevd_bundle/pydevd_tables.py index 6151724579c7..72ad69c4b008 100644 --- a/python/helpers/pydev/_pydevd_bundle/pydevd_tables.py +++ b/python/helpers/pydev/_pydevd_bundle/pydevd_tables.py @@ -22,7 +22,8 @@ def is_error_on_eval(val): return is_exception_on_eval -def exec_table_command(init_command, command_type, start_index, end_index, f_globals, f_locals): +def exec_table_command(init_command, command_type, start_index, end_index, f_globals, + f_locals): # type: (str, str, [int, None], [int, None], dict, dict) -> (bool, str) table = pydevd_vars.eval_in_context(init_command, f_globals, f_locals) is_exception_on_eval = is_error_on_eval(table) @@ -69,7 +70,11 @@ def __get_table_provider(output): 'pandas.core.series.Series']: import _pydevd_bundle.tables.pydevd_pandas as table_provider # dict is needed for sort commands - elif type_qualified_name in ['numpy.ndarray', 'builtins.dict']: + elif type_qualified_name in ['numpy.ndarray', + 'tensorflow.python.framework.ops.EagerTensor', + 'tensorflow.python.ops.resource_variable_ops.ResourceVariable', + 'torch.Tensor', + 'builtins.dict']: import _pydevd_bundle.tables.pydevd_numpy as table_provider elif type_qualified_name.startswith('polars') and ( type_qualified_name.endswith('DataFrame') diff --git a/python/helpers/pydev/_pydevd_bundle/pydevd_vars.py b/python/helpers/pydev/_pydevd_bundle/pydevd_vars.py index f8b44fe8a7bc..694a73bd20dc 100644 --- a/python/helpers/pydev/_pydevd_bundle/pydevd_vars.py +++ b/python/helpers/pydev/_pydevd_bundle/pydevd_vars.py @@ -582,6 +582,9 @@ def array_to_xml(array, name, roffset, coffset, rows, cols, format): return xml +def tensorflow_to_xml(tensor, name, roffset, coffset, rows, cols, format): + return array_to_xml(tensor.numpy(), name, roffset, coffset, rows, cols, format) + class ExceedingArrayDimensionsException(Exception): pass @@ -804,7 +807,10 @@ TYPE_TO_XML_CONVERTERS = { "DataFrame": dataframe_to_xml, "Series": dataframe_to_xml, "GeoDataFrame": dataframe_to_xml, - "GeoSeries": dataframe_to_xml + "GeoSeries": dataframe_to_xml, + "EagerTensor": tensorflow_to_xml, + "ResourceVariable": tensorflow_to_xml, + "Tensor": tensorflow_to_xml } diff --git a/python/pydevSrc/src/com/jetbrains/python/debugger/PyDebugValue.java b/python/pydevSrc/src/com/jetbrains/python/debugger/PyDebugValue.java index 1edc70f307da..bf63fabbd34d 100644 --- a/python/pydevSrc/src/com/jetbrains/python/debugger/PyDebugValue.java +++ b/python/pydevSrc/src/com/jetbrains/python/debugger/PyDebugValue.java @@ -30,10 +30,14 @@ import static com.jetbrains.python.debugger.PyDebugValueGroupsKt.*; public class PyDebugValue extends XNamedValue { protected static final Logger LOG = Logger.getInstance(PyDebugValue.class); + private static final String ARRAY = "Array"; private static final String DATA_FRAME = "DataFrame"; private static final String SERIES = "Series"; private static final Map EVALUATOR_POSTFIXES = ImmutableMap.of( - "ndarray", "Array", + "ndarray", ARRAY, + "EagerTensor", ARRAY, + "ResourceVariable", ARRAY, + "Tensor", ARRAY, DATA_FRAME, DATA_FRAME, SERIES, SERIES, "GeoDataFrame", DATA_FRAME, diff --git a/python/src/com/jetbrains/python/debugger/array/ArrayViewStrategy.java b/python/src/com/jetbrains/python/debugger/array/ArrayViewStrategy.java index a9d09afcbcb6..3fed80814a39 100644 --- a/python/src/com/jetbrains/python/debugger/array/ArrayViewStrategy.java +++ b/python/src/com/jetbrains/python/debugger/array/ArrayViewStrategy.java @@ -13,7 +13,27 @@ import org.jetbrains.annotations.Nullable; import javax.swing.*; public class ArrayViewStrategy extends DataViewStrategy { - private static final String NDARRAY = "ndarray"; + private final String myTypeName; + + public static @NotNull ArrayViewStrategy createInstanceForNumpyArray() { + return new ArrayViewStrategy("ndarray"); + } + + public static @NotNull ArrayViewStrategy createInstanceForEagerTensor() { + return new ArrayViewStrategy("EagerTensor"); + } + + public static @NotNull ArrayViewStrategy createInstanceForResourceVariable() { + return new ArrayViewStrategy("ResourceVariable"); + } + + public static @NotNull ArrayViewStrategy createInstanceForTensor() { + return new ArrayViewStrategy("Tensor"); + } + + protected ArrayViewStrategy(final @NotNull String typeName) { + this.myTypeName = typeName; + } @Override public AsyncArrayTableModel createTableModel(int rowCount, @@ -61,6 +81,6 @@ public class ArrayViewStrategy extends DataViewStrategy { @Override public @NotNull String getTypeName() { - return NDARRAY; + return myTypeName; } } diff --git a/python/src/com/jetbrains/python/debugger/containerview/DataViewStrategy.java b/python/src/com/jetbrains/python/debugger/containerview/DataViewStrategy.java index 55a10ab718b8..c2db0ca32a0c 100644 --- a/python/src/com/jetbrains/python/debugger/containerview/DataViewStrategy.java +++ b/python/src/com/jetbrains/python/debugger/containerview/DataViewStrategy.java @@ -18,7 +18,10 @@ import java.util.Set; public abstract class DataViewStrategy { private static class StrategyHolder { private static final Set STRATEGIES = ImmutableSet.of( - new ArrayViewStrategy(), + ArrayViewStrategy.createInstanceForNumpyArray(), + ArrayViewStrategy.createInstanceForEagerTensor(), + ArrayViewStrategy.createInstanceForResourceVariable(), + ArrayViewStrategy.createInstanceForTensor(), DataFrameViewStrategy.createInstanceForDataFrame(), DataFrameViewStrategy.createInstanceForGeoDataFrame(), SeriesViewStrategy.createInstanceForSeries(), diff --git a/python/src/com/jetbrains/python/debugger/containerview/PyViewNumericContainerAction.java b/python/src/com/jetbrains/python/debugger/containerview/PyViewNumericContainerAction.java index 5b49727aedbd..ec702764d90d 100644 --- a/python/src/com/jetbrains/python/debugger/containerview/PyViewNumericContainerAction.java +++ b/python/src/com/jetbrains/python/debugger/containerview/PyViewNumericContainerAction.java @@ -64,7 +64,7 @@ public class PyViewNumericContainerAction extends XDebuggerTreeActionBase { } String nodeType = debugValue.getType(); - if ("ndarray".equals(nodeType)) { + if ("ndarray".equals(nodeType) || "EagerTensor".equals(nodeType) || "ResourceVariable".equals(nodeType) || "Tensor".equals(nodeType)) { e.getPresentation().setText(PyBundle.message("debugger.numeric.view.as.array")); e.getPresentation().setVisible(true); }