diff --git a/python/pluginResources/messages/PyBundle.properties b/python/pluginResources/messages/PyBundle.properties index 8d67713df97e..47e476fbdf47 100644 --- a/python/pluginResources/messages/PyBundle.properties +++ b/python/pluginResources/messages/PyBundle.properties @@ -62,6 +62,7 @@ refactoring.introduce.variable.scope.error=The name clashes with an existing var # introduce constant refactoring.introduce.constant.dialog.title=Extract Constant refactoring.introduce.constant.scope.error=The name is already declared in the scope +refactoring.introduce.constant.cannot.extract.selected.expression=Selected expression cannot be extracted into a constant # introduce parameter refactoring.extract.parameter.dialog.title=Extract Parameter diff --git a/python/src/com/jetbrains/python/refactoring/introduce/IntroduceHandler.java b/python/src/com/jetbrains/python/refactoring/introduce/IntroduceHandler.java index 630697b88dee..be64555d4956 100644 --- a/python/src/com/jetbrains/python/refactoring/introduce/IntroduceHandler.java +++ b/python/src/com/jetbrains/python/refactoring/introduce/IntroduceHandler.java @@ -430,14 +430,11 @@ abstract public class IntroduceHandler implements RefactoringActionHandler { private void performActionOnElement(IntroduceOperation operation) { if (!checkEnabled(operation)) { + showCanNotIntroduceErrorHint(operation.getProject(), operation.getEditor()); return; } final PsiElement element = operation.getElement(); - - final PsiElement parent = element.getParent(); - final PyExpression initializer = parent instanceof PyAssignmentStatement ? - ((PyAssignmentStatement)parent).getAssignedValue() : - (PyExpression)element; + final PyExpression initializer = getInitializerForElement(element); operation.setInitializer(initializer); if (initializer != null) { @@ -451,6 +448,8 @@ abstract public class IntroduceHandler implements RefactoringActionHandler { performActionOnElementOccurrences(operation); } + protected void showCanNotIntroduceErrorHint(@NotNull Project project, @NotNull Editor editor) {} + protected void performActionOnElementOccurrences(final IntroduceOperation operation) { final Editor editor = operation.getEditor(); if (editor.getSettings().isVariableInplaceRenameEnabled()) { @@ -473,6 +472,13 @@ abstract public class IntroduceHandler implements RefactoringActionHandler { } } + protected @Nullable PyExpression getInitializerForElement(@Nullable PsiElement element) { + if (element == null) return null; + final PsiElement parent = element.getParent(); + return parent instanceof PyAssignmentStatement ? ((PyAssignmentStatement)parent).getAssignedValue() : + element instanceof PyExpression ? (PyExpression)element : null; + } + protected void performInplaceIntroduce(IntroduceOperation operation) { final PsiElement statement = performRefactoring(operation); if (statement instanceof PyAssignmentStatement) { diff --git a/python/src/com/jetbrains/python/refactoring/introduce/constant/PyIntroduceConstantHandler.java b/python/src/com/jetbrains/python/refactoring/introduce/constant/PyIntroduceConstantHandler.java index 43e747dba1cf..9434aa8567ef 100644 --- a/python/src/com/jetbrains/python/refactoring/introduce/constant/PyIntroduceConstantHandler.java +++ b/python/src/com/jetbrains/python/refactoring/introduce/constant/PyIntroduceConstantHandler.java @@ -15,22 +15,33 @@ */ package com.jetbrains.python.refactoring.introduce.constant; +import com.intellij.openapi.editor.Editor; +import com.intellij.openapi.project.Project; +import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.text.StringUtil; import com.intellij.psi.PsiElement; +import com.intellij.psi.PsiFile; import com.intellij.psi.util.PsiTreeUtil; +import com.intellij.psi.util.PsiUtilCore; +import com.intellij.refactoring.RefactoringBundle; +import com.intellij.refactoring.util.CommonRefactoringUtil; +import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.PyBundle; import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; import com.jetbrains.python.codeInsight.imports.AddImportHelper; -import com.jetbrains.python.psi.PyExpression; -import com.jetbrains.python.psi.PyFile; -import com.jetbrains.python.psi.PyParameterList; +import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.impl.PyPsiUtils; +import com.jetbrains.python.psi.resolve.PyResolveUtil; import com.jetbrains.python.refactoring.PyReplaceExpressionUtil; import com.jetbrains.python.refactoring.introduce.IntroduceHandler; import com.jetbrains.python.refactoring.introduce.IntroduceOperation; +import one.util.streamex.StreamEx; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import java.util.Collection; import java.util.HashSet; +import java.util.List; /** * @author Alexey.Ivanov @@ -49,16 +60,26 @@ public class PyIntroduceConstantHandler extends IntroduceHandler { } @Override - protected PsiElement addDeclaration(@NotNull final PsiElement expression, - @NotNull final PsiElement declaration, - @NotNull final IntroduceOperation operation) { - final PsiElement anchor = expression.getContainingFile(); - assert anchor instanceof PyFile; - return anchor.addBefore(declaration, AddImportHelper.getFileInsertPosition((PyFile)anchor)); + protected PsiElement addDeclaration(@NotNull PsiElement expression, + @NotNull PsiElement declaration, + @NotNull IntroduceOperation operation) { + PsiElement containingFile = expression.getContainingFile(); + assert containingFile instanceof PyFile; + PsiElement initialPosition = AddImportHelper.getFileInsertPosition((PyFile)containingFile); + + List sameFileRefs = collectReferencedDefinitionsInSameFile(operation); + PsiElement maxPosition = getLowermostTopLevelStatement(sameFileRefs); + + if (maxPosition == null) { + return containingFile.addBefore(declaration, initialPosition); + } + + assert PyUtil.isTopLevel(maxPosition); + return containingFile.addAfter(declaration, maxPosition); } @Override - protected Collection generateSuggestedNames(@NotNull final PyExpression expression) { + protected Collection generateSuggestedNames(@NotNull PyExpression expression) { Collection names = new HashSet<>(); for (String name : super.generateSuggestedNames(expression)) { names.add(StringUtil.toUpperCase(name)); @@ -71,6 +92,65 @@ public class PyIntroduceConstantHandler extends IntroduceHandler { return super.isValidIntroduceContext(element) || PsiTreeUtil.getParentOfType(element, PyParameterList.class) != null; } + @Override + protected boolean checkEnabled(@NotNull IntroduceOperation operation) { + PsiElement selectionElement = getOriginalSelectionCoveringElement(operation.getElement()); + + PsiFile containingFile = selectionElement.getContainingFile(); + if (!(containingFile instanceof PyFile)) return false; + + Editor editor = operation.getEditor(); + if (editor == null) return false; + + List sameFileRefs = collectReferencedDefinitionsInSameFile(operation); + if (!ContainerUtil.all(sameFileRefs, it -> PyUtil.isTopLevel(it))) return false; + PsiElement maxPosition = getLowermostTopLevelStatement(sameFileRefs); + if (maxPosition == null) return true; + return PsiUtilCore.compareElementsByPosition(maxPosition, selectionElement) <= 0 && + !PsiTreeUtil.isAncestor(maxPosition, selectionElement, false); + } + + private static @NotNull List collectReferencedDefinitionsInSameFile(@NotNull IntroduceOperation operation) { + PsiElement selectionElement = getOriginalSelectionCoveringElement(operation.getElement()); + TextRange textRange = getTextRangeForOperationElement(operation.getElement()); + + return StreamEx.of(PsiTreeUtil.collectElementsOfType(selectionElement, PyReferenceExpression.class)) + .filter(it -> textRange.contains(it.getTextRange())) + .filter(ref -> !ref.isQualified()) + .flatMap(expr -> PyResolveUtil.resolveLocally(expr).stream()) + .filter(it -> it != null && it.getContainingFile() == operation.getFile()) + .toList(); + } + + private static @NotNull TextRange getTextRangeForOperationElement(@NotNull PsiElement operationElement) { + var userData = operationElement.getUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE); + if (userData == null || userData.first == null || userData.second == null) { + return operationElement.getTextRange(); + } + else { + return userData.second.shiftRight(userData.first.getTextOffset()); + } + } + + private static @NotNull PsiElement getOriginalSelectionCoveringElement(@NotNull PsiElement operationElement) { + var userData = operationElement.getUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE); + return userData == null ? operationElement : userData.first; + } + + private static @Nullable PsiElement getLowermostTopLevelStatement(@NotNull List elements) { + return StreamEx.of(elements) + .map(it -> PyPsiUtils.getParentRightBefore(it, it.getContainingFile())) + .select(PyStatement.class) + .max(PsiUtilCore::compareElementsByPosition) + .orElse(null); + } + + @Override protected void showCanNotIntroduceErrorHint(@NotNull Project project, @NotNull Editor editor) { + String message = + RefactoringBundle.getCannotRefactorMessage(PyBundle.message("refactoring.introduce.constant.cannot.extract.selected.expression")); + CommonRefactoringUtil.showErrorHint(project, editor, message, myDialogTitle, getHelpId()); + } + @Override protected String getHelpId() { return "python.reference.introduceConstant"; diff --git a/python/testData/refactoring/introduceConstant/expressionWithFunctionCall.after.py b/python/testData/refactoring/introduceConstant/expressionWithFunctionCall.after.py new file mode 100644 index 000000000000..a44e96978299 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/expressionWithFunctionCall.after.py @@ -0,0 +1,10 @@ +from six import PY2 + +if PY2: + def ascii(obj): + ... +a = ascii(42) + 'foo' + + +def func(p): + X = a \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/expressionWithFunctionCall.py b/python/testData/refactoring/introduceConstant/expressionWithFunctionCall.py new file mode 100644 index 000000000000..fdaa0089589c --- /dev/null +++ b/python/testData/refactoring/introduceConstant/expressionWithFunctionCall.py @@ -0,0 +1,9 @@ +from six import PY2 + +if PY2: + def ascii(obj): + ... + + +def func(p): + X = ascii(42) + 'foo' \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/expressionWithParameterRefactoringError.py b/python/testData/refactoring/introduceConstant/expressionWithParameterRefactoringError.py new file mode 100644 index 000000000000..6219fee51875 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/expressionWithParameterRefactoringError.py @@ -0,0 +1,9 @@ +from six import PY2 + +if PY2: + def ascii(obj): + ... + + +def func(p): + X = ascii(p) + 'foo' \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/fromImportInFunctionRefactoringError.py b/python/testData/refactoring/introduceConstant/fromImportInFunctionRefactoringError.py new file mode 100644 index 000000000000..f9e020e991f6 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/fromImportInFunctionRefactoringError.py @@ -0,0 +1,6 @@ +SUFFIX = "foo" + + +def func(): + from sys import version + print(version + SUFFIX) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/fromImportTopLevel.after.py b/python/testData/refactoring/introduceConstant/fromImportTopLevel.after.py new file mode 100644 index 000000000000..65bc0f057f29 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/fromImportTopLevel.after.py @@ -0,0 +1,8 @@ +SUFFIX = "foo" +from sys import version + +a = version + SUFFIX + + +def func(): + print(a) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/fromImportTopLevel.py b/python/testData/refactoring/introduceConstant/fromImportTopLevel.py new file mode 100644 index 000000000000..a0c69d88f5fe --- /dev/null +++ b/python/testData/refactoring/introduceConstant/fromImportTopLevel.py @@ -0,0 +1,6 @@ +SUFFIX = "foo" +from sys import version + + +def func(): + print(version + SUFFIX) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterAllGlobalVariablesOnWhichDepends.after.py b/python/testData/refactoring/introduceConstant/insertAfterAllGlobalVariablesOnWhichDepends.after.py new file mode 100644 index 000000000000..b5f8bfd03fc9 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterAllGlobalVariablesOnWhichDepends.after.py @@ -0,0 +1,12 @@ +X = 42 +N = 42 +A = 11 +M = 24 +K = 21 +T = 28 +a = N + M + T + 1 +O = 22 + + +def f(): + print(a) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterAllGlobalVariablesOnWhichDepends.py b/python/testData/refactoring/introduceConstant/insertAfterAllGlobalVariablesOnWhichDepends.py new file mode 100644 index 000000000000..bfef077b9fee --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterAllGlobalVariablesOnWhichDepends.py @@ -0,0 +1,11 @@ +X = 42 +N = 42 +A = 11 +M = 24 +K = 21 +T = 28 +O = 22 + + +def f(): + print(N + M + T + 1) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterForIteratorOnWhichDependsRefactoringError.py b/python/testData/refactoring/introduceConstant/insertAfterForIteratorOnWhichDependsRefactoringError.py new file mode 100644 index 000000000000..19aff4897226 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterForIteratorOnWhichDependsRefactoringError.py @@ -0,0 +1,6 @@ +N = 42 + +def f(K): + for I in range(0, 10): + for j in range(0, 10): + print(N + K + I) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterFunctionParameterOnWhichDependsRefactoringError.py b/python/testData/refactoring/introduceConstant/insertAfterFunctionParameterOnWhichDependsRefactoringError.py new file mode 100644 index 000000000000..97f9c50707a4 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterFunctionParameterOnWhichDependsRefactoringError.py @@ -0,0 +1,5 @@ +N = 42 + +def f(K): + for i in range(0, 10): + print(N + K) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterGlobalVariableOnWhichDepends.after.py b/python/testData/refactoring/introduceConstant/insertAfterGlobalVariableOnWhichDepends.after.py new file mode 100644 index 000000000000..7cc2079877b0 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterGlobalVariableOnWhichDepends.after.py @@ -0,0 +1,6 @@ +N = 42 +a = N + 1 + + +def f(): + print(a) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterGlobalVariableOnWhichDepends.py b/python/testData/refactoring/introduceConstant/insertAfterGlobalVariableOnWhichDepends.py new file mode 100644 index 000000000000..a6a6bb651262 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterGlobalVariableOnWhichDepends.py @@ -0,0 +1,5 @@ +N = 42 + + +def f(): + print(N + 1) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterIfElse.after.py b/python/testData/refactoring/introduceConstant/insertAfterIfElse.after.py new file mode 100644 index 000000000000..6f107c16c349 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterIfElse.after.py @@ -0,0 +1,7 @@ +if True: + X = 1 +else: + X = 2 +a = X + 1 + +print(a) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterIfElse.py b/python/testData/refactoring/introduceConstant/insertAfterIfElse.py new file mode 100644 index 000000000000..8a7b740ae183 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterIfElse.py @@ -0,0 +1,6 @@ +if True: + X = 1 +else: + X = 2 + +print(X + 1) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterLocalVariableInForLoopOnWhichDependsRefactoringError.py b/python/testData/refactoring/introduceConstant/insertAfterLocalVariableInForLoopOnWhichDependsRefactoringError.py new file mode 100644 index 000000000000..9ff2e302967a --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterLocalVariableInForLoopOnWhichDependsRefactoringError.py @@ -0,0 +1,5 @@ +def f(K): + for I in range(0, 10): + N = 42 + for j in range(0, 10): + print(N + K + I) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterLocalVariableOnWhichDependsRefactoringError.py b/python/testData/refactoring/introduceConstant/insertAfterLocalVariableOnWhichDependsRefactoringError.py new file mode 100644 index 000000000000..cd45546679b8 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterLocalVariableOnWhichDependsRefactoringError.py @@ -0,0 +1,6 @@ +def foo(): + N = 42 + K = 24 + + def f(): + print(N + 1) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/insertAfterWithStatementOnWhichDependsRefactoringError.py b/python/testData/refactoring/introduceConstant/insertAfterWithStatementOnWhichDependsRefactoringError.py new file mode 100644 index 000000000000..857f663dd53e --- /dev/null +++ b/python/testData/refactoring/introduceConstant/insertAfterWithStatementOnWhichDependsRefactoringError.py @@ -0,0 +1,2 @@ +with 42 as N: + print(N + 1) \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/subexpressionNotFullWordRefactoringError.py b/python/testData/refactoring/introduceConstant/subexpressionNotFullWordRefactoringError.py new file mode 100644 index 000000000000..74255e039fea --- /dev/null +++ b/python/testData/refactoring/introduceConstant/subexpressionNotFullWordRefactoringError.py @@ -0,0 +1,2 @@ +def func(param): + return param + 42 \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/subexpressionWithGlobal.after.py b/python/testData/refactoring/introduceConstant/subexpressionWithGlobal.after.py new file mode 100644 index 000000000000..f5e3a6b394f3 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/subexpressionWithGlobal.after.py @@ -0,0 +1,6 @@ +SOME_GLOBAL = 42 +a = 2 + SOME_GLOBAL + + +def func(): + return 1 + a \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/subexpressionWithGlobal.py b/python/testData/refactoring/introduceConstant/subexpressionWithGlobal.py new file mode 100644 index 000000000000..6e7770d9f3c3 --- /dev/null +++ b/python/testData/refactoring/introduceConstant/subexpressionWithGlobal.py @@ -0,0 +1,5 @@ +SOME_GLOBAL = 42 + + +def func(): + return 1 + 2 + SOME_GLOBAL \ No newline at end of file diff --git a/python/testData/refactoring/introduceConstant/subexpressionWithParameterRefactoringError.py b/python/testData/refactoring/introduceConstant/subexpressionWithParameterRefactoringError.py new file mode 100644 index 000000000000..5b6322427c4b --- /dev/null +++ b/python/testData/refactoring/introduceConstant/subexpressionWithParameterRefactoringError.py @@ -0,0 +1,2 @@ +def func(p): + return 1 + 2 + p \ No newline at end of file diff --git a/python/testSrc/com/jetbrains/python/refactoring/PyIntroduceConstantTest.java b/python/testSrc/com/jetbrains/python/refactoring/PyIntroduceConstantTest.java index f3e765a36f45..040568378613 100644 --- a/python/testSrc/com/jetbrains/python/refactoring/PyIntroduceConstantTest.java +++ b/python/testSrc/com/jetbrains/python/refactoring/PyIntroduceConstantTest.java @@ -1,6 +1,7 @@ // Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. package com.jetbrains.python.refactoring; +import com.intellij.refactoring.util.CommonRefactoringUtil; import com.intellij.testFramework.TestDataPath; import com.jetbrains.python.psi.LanguageLevel; import com.jetbrains.python.psi.PyExpression; @@ -43,6 +44,85 @@ public class PyIntroduceConstantTest extends PyIntroduceTestCase { doTest(); } + // PY-23500 + public void testInsertAfterGlobalVariableOnWhichDepends() { + doTest(); + } + + // PY-23500 + public void testInsertAfterAllGlobalVariablesOnWhichDepends() { + doTest(); + } + + // PY-23500 + public void testInsertAfterWithStatementOnWhichDependsRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testInsertAfterLocalVariableOnWhichDependsRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testInsertAfterFunctionParameterOnWhichDependsRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testInsertAfterForIteratorOnWhichDependsRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testInsertAfterLocalVariableInForLoopOnWhichDependsRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testInsertAfterIfElse() { + doTest(); + } + + // PY-23500 + public void testFromImportTopLevel() { + doTest(); + } + + // PY-23500 + public void testFromImportInFunctionRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testExpressionWithFunctionCall() { + doTest(); + } + + // PY-23500 + public void testExpressionWithParameterRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testSubexpressionWithParameterRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + // PY-23500 + public void testSubexpressionWithGlobal() { + doTest(); + } + + // PY-23500 + public void testSubexpressionNotFullWordRefactoringError() { + doTestThrowsRefactoringErrorHintException(); + } + + private void doTestThrowsRefactoringErrorHintException() { + assertThrows(CommonRefactoringUtil.RefactoringErrorHintException.class, () -> doTest()); + } + @Override protected String getTestDataPath() { return super.getTestDataPath() + "/refactoring/introduceConstant";