diff --git a/python/src/com/jetbrains/python/refactoring/classes/PyClassRefactoringUtil.java b/python/src/com/jetbrains/python/refactoring/classes/PyClassRefactoringUtil.java index 6104a4dede28..a963a570fa49 100644 --- a/python/src/com/jetbrains/python/refactoring/classes/PyClassRefactoringUtil.java +++ b/python/src/com/jetbrains/python/refactoring/classes/PyClassRefactoringUtil.java @@ -15,6 +15,7 @@ import com.intellij.util.ArrayUtil; import com.intellij.util.IncorrectOperationException; import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.PyNames; +import com.jetbrains.python.codeInsight.imports.AddImportHelper; import com.jetbrains.python.codeInsight.imports.PyImportOptimizer; import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.impl.PyImportedModule; @@ -24,9 +25,7 @@ import org.jetbrains.annotations.NonNls; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import java.util.ArrayList; -import java.util.Collection; -import java.util.List; +import java.util.*; /** * @author Dennis.Ushakov @@ -37,6 +36,7 @@ public final class PyClassRefactoringUtil { private static final Key ENCODED_USE_FROM_IMPORT = Key.create("PyEncodedUseFromImport"); private static final Key ENCODED_IMPORT_AS = Key.create("PyEncodedImportAs"); private static final Key> INJECTION_REFERENCES = Key.create("PyInjectionReferences"); + private static final Key> ENCODED_FROM_FUTURE_IMPORTS = Key.create("PyFromFutureImports"); private PyClassRefactoringUtil() { @@ -185,6 +185,16 @@ public final class PyClassRefactoringUtil { public static void restoreNamedReferences(@NotNull final PsiElement newElement, @Nullable final PsiElement oldElement, final PsiElement @NotNull [] otherMovedElements) { + Set fromFutureImports = newElement.getCopyableUserData(ENCODED_FROM_FUTURE_IMPORTS); + newElement.putCopyableUserData(ENCODED_FROM_FUTURE_IMPORTS, null); + PsiFile destFile = newElement.getContainingFile(); + if (fromFutureImports != null & destFile != null) { + for (FutureFeature futureFeature: fromFutureImports) { + AddImportHelper.addOrUpdateFromImportStatement(destFile, PyNames.FUTURE_MODULE, futureFeature.toString(), null, + AddImportHelper.ImportPriority.FUTURE, null); + } + } + newElement.acceptChildren(new PyRecursiveElementVisitor() { @Override public void visitPyReferenceExpression(@NotNull PyReferenceExpression node) { @@ -279,6 +289,12 @@ public final class PyClassRefactoringUtil { * @param namesToSkip if reference inside of element has one of this names, it will not be saved. */ public static void rememberNamedReferences(@NotNull final PsiElement element, final String @NotNull ... namesToSkip) { + PsiFile containingFile = element.getContainingFile(); + if (containingFile instanceof PyFile) { + Set fromFutureImports = collectFromFutureImports((PyFile)containingFile); + element.putCopyableUserData(ENCODED_FROM_FUTURE_IMPORTS, fromFutureImports); + } + element.accept(new PyRecursiveElementVisitor() { @Override public void visitPyReferenceExpression(@NotNull PyReferenceExpression node) { @@ -339,6 +355,16 @@ public final class PyClassRefactoringUtil { host.putCopyableUserData(INJECTION_REFERENCES, rememberedReferences); } + private static @NotNull Set collectFromFutureImports(@NotNull PyFile file) { + EnumSet result = EnumSet.noneOf(FutureFeature.class); + for (FutureFeature feature: FutureFeature.values()) { + if (file.hasImportFromFuture(feature)) { + result.add(feature); + } + } + return result; + } + private static void rememberReference(@NotNull PyReferenceExpression node, @NotNull PsiElement element) { // We will remember reference in deepest node (except for references to PyImportedModules, as we need references to modules, not to // their packages) diff --git a/python/testData/refactoring/extractsuperclass/fromFutureImports/dest_module.after.py b/python/testData/refactoring/extractsuperclass/fromFutureImports/dest_module.after.py new file mode 100644 index 000000000000..586879083698 --- /dev/null +++ b/python/testData/refactoring/extractsuperclass/fromFutureImports/dest_module.after.py @@ -0,0 +1,6 @@ +from __future__ import absolute_import, annotations + + +class NewParent(object): + def foo(self): + pass \ No newline at end of file diff --git a/python/testData/refactoring/extractsuperclass/fromFutureImports/dest_module.py b/python/testData/refactoring/extractsuperclass/fromFutureImports/dest_module.py new file mode 100644 index 000000000000..3be6bd298581 --- /dev/null +++ b/python/testData/refactoring/extractsuperclass/fromFutureImports/dest_module.py @@ -0,0 +1 @@ +from __future__ import absolute_import \ No newline at end of file diff --git a/python/testData/refactoring/extractsuperclass/fromFutureImports/shared_module.py b/python/testData/refactoring/extractsuperclass/fromFutureImports/shared_module.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/python/testData/refactoring/extractsuperclass/fromFutureImports/source_module.after.py b/python/testData/refactoring/extractsuperclass/fromFutureImports/source_module.after.py new file mode 100644 index 000000000000..cfd6a984ddf5 --- /dev/null +++ b/python/testData/refactoring/extractsuperclass/fromFutureImports/source_module.after.py @@ -0,0 +1,8 @@ +from __future__ import annotations +from __future__ import absolute_import + +from dest_module import NewParent + + +class MyClass(NewParent): + pass \ No newline at end of file diff --git a/python/testData/refactoring/extractsuperclass/fromFutureImports/source_module.py b/python/testData/refactoring/extractsuperclass/fromFutureImports/source_module.py new file mode 100644 index 000000000000..122ee33aff94 --- /dev/null +++ b/python/testData/refactoring/extractsuperclass/fromFutureImports/source_module.py @@ -0,0 +1,6 @@ +from __future__ import annotations +from __future__ import absolute_import + +class MyClass(object): + def foo(self): + pass \ No newline at end of file diff --git a/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/after/src/a.py b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/after/src/a.py new file mode 100644 index 000000000000..62778e5d0b2e --- /dev/null +++ b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/after/src/a.py @@ -0,0 +1,5 @@ +from __future__ import annotations +from __future__ import unicode_literals +from __future__ import division + + diff --git a/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/after/src/b.py b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/after/src/b.py new file mode 100644 index 000000000000..ba72212f69d3 --- /dev/null +++ b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/after/src/b.py @@ -0,0 +1,6 @@ +from __future__ import unicode_literals, division, annotations +from __future__ import print_function + + +class C: + pass \ No newline at end of file diff --git a/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/before/src/a.py b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/before/src/a.py new file mode 100644 index 000000000000..8f10d63f56eb --- /dev/null +++ b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/before/src/a.py @@ -0,0 +1,7 @@ +from __future__ import annotations +from __future__ import unicode_literals +from __future__ import division + + +class C: + pass diff --git a/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/before/src/b.py b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/before/src/b.py new file mode 100644 index 000000000000..fc49d10aa119 --- /dev/null +++ b/python/testData/refactoring/move/existingFromFutureImportsNotDuplicated/before/src/b.py @@ -0,0 +1,2 @@ +from __future__ import unicode_literals +from __future__ import print_function diff --git a/python/testData/refactoring/move/fromFutureImports/after/src/a.py b/python/testData/refactoring/move/fromFutureImports/after/src/a.py new file mode 100644 index 000000000000..1aa452689c5b --- /dev/null +++ b/python/testData/refactoring/move/fromFutureImports/after/src/a.py @@ -0,0 +1,4 @@ +from __future__ import annotations +from __future__ import unicode_literals + + diff --git a/python/testData/refactoring/move/fromFutureImports/after/src/b.py b/python/testData/refactoring/move/fromFutureImports/after/src/b.py new file mode 100644 index 000000000000..9c0229bb0b7f --- /dev/null +++ b/python/testData/refactoring/move/fromFutureImports/after/src/b.py @@ -0,0 +1,5 @@ +from __future__ import unicode_literals, annotations + + +class C: + pass \ No newline at end of file diff --git a/python/testData/refactoring/move/fromFutureImports/before/src/a.py b/python/testData/refactoring/move/fromFutureImports/before/src/a.py new file mode 100644 index 000000000000..b3b31a4bcbd0 --- /dev/null +++ b/python/testData/refactoring/move/fromFutureImports/before/src/a.py @@ -0,0 +1,6 @@ +from __future__ import annotations +from __future__ import unicode_literals + + +class C: + pass diff --git a/python/testSrc/com/jetbrains/python/refactoring/PyMoveTest.java b/python/testSrc/com/jetbrains/python/refactoring/PyMoveTest.java index 3949412feb92..8a2cfea2c3fd 100644 --- a/python/testSrc/com/jetbrains/python/refactoring/PyMoveTest.java +++ b/python/testSrc/com/jetbrains/python/refactoring/PyMoveTest.java @@ -438,6 +438,16 @@ public class PyMoveTest extends PyTestCase { doMoveSymbolTest("func", "dst.py"); } + // PY-16221 + public void testFromFutureImports() { + doMoveSymbolTest("C", "b.py"); + } + + // PY-16221 + public void testExistingFromFutureImportsNotDuplicated() { + doMoveSymbolTest("C", "b.py"); + } + // PY-23831 public void testWithImportedForwardReferencesInTypeHints() { doMoveSymbolTest("test", "dst.py"); diff --git a/python/testSrc/com/jetbrains/python/refactoring/classes/extractSuperclass/PyExtractSuperclassTest.java b/python/testSrc/com/jetbrains/python/refactoring/classes/extractSuperclass/PyExtractSuperclassTest.java index 8b6b80c3a4ec..2e38af926912 100644 --- a/python/testSrc/com/jetbrains/python/refactoring/classes/extractSuperclass/PyExtractSuperclassTest.java +++ b/python/testSrc/com/jetbrains/python/refactoring/classes/extractSuperclass/PyExtractSuperclassTest.java @@ -249,4 +249,9 @@ public final class PyExtractSuperclassTest extends PyClassRefactoringTest { public void testNoClassCastExceptionInCopiedFunctionWithClassInitAndMethodCall() { doSimpleTest("Baz", "Bar", null, true, false, ".baz"); } + + // PY-16221 + public void testFromFutureImports() { + multiFileTestHelper(".foo", false); + } } \ No newline at end of file