PY-23500 Impl considering dependencies for introduce constant fix

(cherry picked from commit fe2adaeabbf1862c2f51a93df14995264a251cca)

IJ-MR-5221

GitOrigin-RevId: 08d0db849d31cdf7684a1b7a68d68072cc0d3686
This commit is contained in:
andrey.matveev
2020-12-25 23:24:12 +07:00
committed by intellij-monorepo-bot
parent 96fd6d2148
commit aa5eb8dc43
25 changed files with 316 additions and 15 deletions

View File

@@ -62,6 +62,7 @@ refactoring.introduce.variable.scope.error=The name clashes with an existing var
# introduce constant
refactoring.introduce.constant.dialog.title=Extract Constant
refactoring.introduce.constant.scope.error=The name is already declared in the scope
refactoring.introduce.constant.cannot.extract.selected.expression=Selected expression cannot be extracted into a constant
# introduce parameter
refactoring.extract.parameter.dialog.title=Extract Parameter

View File

@@ -430,14 +430,11 @@ abstract public class IntroduceHandler implements RefactoringActionHandler {
private void performActionOnElement(IntroduceOperation operation) {
if (!checkEnabled(operation)) {
showCanNotIntroduceErrorHint(operation.getProject(), operation.getEditor());
return;
}
final PsiElement element = operation.getElement();
final PsiElement parent = element.getParent();
final PyExpression initializer = parent instanceof PyAssignmentStatement ?
((PyAssignmentStatement)parent).getAssignedValue() :
(PyExpression)element;
final PyExpression initializer = getInitializerForElement(element);
operation.setInitializer(initializer);
if (initializer != null) {
@@ -451,6 +448,8 @@ abstract public class IntroduceHandler implements RefactoringActionHandler {
performActionOnElementOccurrences(operation);
}
protected void showCanNotIntroduceErrorHint(@NotNull Project project, @NotNull Editor editor) {}
protected void performActionOnElementOccurrences(final IntroduceOperation operation) {
final Editor editor = operation.getEditor();
if (editor.getSettings().isVariableInplaceRenameEnabled()) {
@@ -473,6 +472,13 @@ abstract public class IntroduceHandler implements RefactoringActionHandler {
}
}
protected @Nullable PyExpression getInitializerForElement(@Nullable PsiElement element) {
if (element == null) return null;
final PsiElement parent = element.getParent();
return parent instanceof PyAssignmentStatement ? ((PyAssignmentStatement)parent).getAssignedValue() :
element instanceof PyExpression ? (PyExpression)element : null;
}
protected void performInplaceIntroduce(IntroduceOperation operation) {
final PsiElement statement = performRefactoring(operation);
if (statement instanceof PyAssignmentStatement) {

View File

@@ -15,22 +15,33 @@
*/
package com.jetbrains.python.refactoring.introduce.constant;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.PsiUtilCore;
import com.intellij.refactoring.RefactoringBundle;
import com.intellij.refactoring.util.CommonRefactoringUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyBundle;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.imports.AddImportHelper;
import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.PyFile;
import com.jetbrains.python.psi.PyParameterList;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.resolve.PyResolveUtil;
import com.jetbrains.python.refactoring.PyReplaceExpressionUtil;
import com.jetbrains.python.refactoring.introduce.IntroduceHandler;
import com.jetbrains.python.refactoring.introduce.IntroduceOperation;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
/**
* @author Alexey.Ivanov
@@ -49,16 +60,26 @@ public class PyIntroduceConstantHandler extends IntroduceHandler {
}
@Override
protected PsiElement addDeclaration(@NotNull final PsiElement expression,
@NotNull final PsiElement declaration,
@NotNull final IntroduceOperation operation) {
final PsiElement anchor = expression.getContainingFile();
assert anchor instanceof PyFile;
return anchor.addBefore(declaration, AddImportHelper.getFileInsertPosition((PyFile)anchor));
protected PsiElement addDeclaration(@NotNull PsiElement expression,
@NotNull PsiElement declaration,
@NotNull IntroduceOperation operation) {
PsiElement containingFile = expression.getContainingFile();
assert containingFile instanceof PyFile;
PsiElement initialPosition = AddImportHelper.getFileInsertPosition((PyFile)containingFile);
List<PsiElement> sameFileRefs = collectReferencedDefinitionsInSameFile(operation);
PsiElement maxPosition = getLowermostTopLevelStatement(sameFileRefs);
if (maxPosition == null) {
return containingFile.addBefore(declaration, initialPosition);
}
assert PyUtil.isTopLevel(maxPosition);
return containingFile.addAfter(declaration, maxPosition);
}
@Override
protected Collection<String> generateSuggestedNames(@NotNull final PyExpression expression) {
protected Collection<String> generateSuggestedNames(@NotNull PyExpression expression) {
Collection<String> names = new HashSet<>();
for (String name : super.generateSuggestedNames(expression)) {
names.add(StringUtil.toUpperCase(name));
@@ -71,6 +92,65 @@ public class PyIntroduceConstantHandler extends IntroduceHandler {
return super.isValidIntroduceContext(element) || PsiTreeUtil.getParentOfType(element, PyParameterList.class) != null;
}
@Override
protected boolean checkEnabled(@NotNull IntroduceOperation operation) {
PsiElement selectionElement = getOriginalSelectionCoveringElement(operation.getElement());
PsiFile containingFile = selectionElement.getContainingFile();
if (!(containingFile instanceof PyFile)) return false;
Editor editor = operation.getEditor();
if (editor == null) return false;
List<PsiElement> sameFileRefs = collectReferencedDefinitionsInSameFile(operation);
if (!ContainerUtil.all(sameFileRefs, it -> PyUtil.isTopLevel(it))) return false;
PsiElement maxPosition = getLowermostTopLevelStatement(sameFileRefs);
if (maxPosition == null) return true;
return PsiUtilCore.compareElementsByPosition(maxPosition, selectionElement) <= 0 &&
!PsiTreeUtil.isAncestor(maxPosition, selectionElement, false);
}
private static @NotNull List<PsiElement> collectReferencedDefinitionsInSameFile(@NotNull IntroduceOperation operation) {
PsiElement selectionElement = getOriginalSelectionCoveringElement(operation.getElement());
TextRange textRange = getTextRangeForOperationElement(operation.getElement());
return StreamEx.of(PsiTreeUtil.collectElementsOfType(selectionElement, PyReferenceExpression.class))
.filter(it -> textRange.contains(it.getTextRange()))
.filter(ref -> !ref.isQualified())
.flatMap(expr -> PyResolveUtil.resolveLocally(expr).stream())
.filter(it -> it != null && it.getContainingFile() == operation.getFile())
.toList();
}
private static @NotNull TextRange getTextRangeForOperationElement(@NotNull PsiElement operationElement) {
var userData = operationElement.getUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE);
if (userData == null || userData.first == null || userData.second == null) {
return operationElement.getTextRange();
}
else {
return userData.second.shiftRight(userData.first.getTextOffset());
}
}
private static @NotNull PsiElement getOriginalSelectionCoveringElement(@NotNull PsiElement operationElement) {
var userData = operationElement.getUserData(PyReplaceExpressionUtil.SELECTION_BREAKS_AST_NODE);
return userData == null ? operationElement : userData.first;
}
private static @Nullable PsiElement getLowermostTopLevelStatement(@NotNull List<PsiElement> elements) {
return StreamEx.of(elements)
.map(it -> PyPsiUtils.getParentRightBefore(it, it.getContainingFile()))
.select(PyStatement.class)
.max(PsiUtilCore::compareElementsByPosition)
.orElse(null);
}
@Override protected void showCanNotIntroduceErrorHint(@NotNull Project project, @NotNull Editor editor) {
String message =
RefactoringBundle.getCannotRefactorMessage(PyBundle.message("refactoring.introduce.constant.cannot.extract.selected.expression"));
CommonRefactoringUtil.showErrorHint(project, editor, message, myDialogTitle, getHelpId());
}
@Override
protected String getHelpId() {
return "python.reference.introduceConstant";

View File

@@ -0,0 +1,10 @@
from six import PY2
if PY2:
def ascii(obj):
...
a = ascii(42) + 'foo'
def func(p):
X = a

View File

@@ -0,0 +1,9 @@
from six import PY2
if PY2:
def ascii(obj):
...
def func(p):
X = <selection>ascii(42) + 'foo'</selection>

View File

@@ -0,0 +1,9 @@
from six import PY2
if PY2:
def ascii(obj):
...
def func(p):
X = <selection>ascii(p) + 'foo'</selection>

View File

@@ -0,0 +1,6 @@
SUFFIX = "foo"
def func():
from sys import version
print(<selection>version + SUFFIX</selection>)

View File

@@ -0,0 +1,8 @@
SUFFIX = "foo"
from sys import version
a = version + SUFFIX
def func():
print(a)

View File

@@ -0,0 +1,6 @@
SUFFIX = "foo"
from sys import version
def func():
print(<selection>version + SUFFIX</selection>)

View File

@@ -0,0 +1,12 @@
X = 42
N = 42
A = 11
M = 24
K = 21
T = 28
a = N + M + T + 1
O = 22
def f():
print(a)

View File

@@ -0,0 +1,11 @@
X = 42
N = 42
A = 11
M = 24
K = 21
T = 28
O = 22
def f():
print(<selection>N + M + T + 1</selection>)

View File

@@ -0,0 +1,6 @@
N = 42
def f(K):
for I in range(0, 10):
for j in range(0, 10):
print(<selection>N + K + I</selection>)

View File

@@ -0,0 +1,5 @@
N = 42
def f(K):
for i in range(0, 10):
print(<selection>N + K</selection>)

View File

@@ -0,0 +1,6 @@
N = 42
a = N + 1
def f():
print(a)

View File

@@ -0,0 +1,5 @@
N = 42
def f():
print(<selection>N + 1</selection>)

View File

@@ -0,0 +1,7 @@
if True:
X = 1
else:
X = 2
a = X + 1
print(a)

View File

@@ -0,0 +1,6 @@
if True:
X = 1
else:
X = 2
print(<selection>X + 1</selection>)

View File

@@ -0,0 +1,5 @@
def f(K):
for I in range(0, 10):
N = 42
for j in range(0, 10):
print(<selection>N + K + I</selection>)

View File

@@ -0,0 +1,6 @@
def foo():
N = 42
K = 24
def f():
print(<selection>N + 1</selection>)

View File

@@ -0,0 +1,2 @@
with 42 as N:
print(<selection>N + 1</selection>)

View File

@@ -0,0 +1,2 @@
def func(param):
return par<selection>am + 42</selection>

View File

@@ -0,0 +1,6 @@
SOME_GLOBAL = 42
a = 2 + SOME_GLOBAL<caret>
def func():
return 1 + a

View File

@@ -0,0 +1,5 @@
SOME_GLOBAL = 42
def func():
return 1 + <selection>2 + SOME_GLOBAL</selection>

View File

@@ -0,0 +1,2 @@
def func(p):
return 1 + <selection>2 + p</selection>

View File

@@ -1,6 +1,7 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.refactoring;
import com.intellij.refactoring.util.CommonRefactoringUtil;
import com.intellij.testFramework.TestDataPath;
import com.jetbrains.python.psi.LanguageLevel;
import com.jetbrains.python.psi.PyExpression;
@@ -43,6 +44,85 @@ public class PyIntroduceConstantTest extends PyIntroduceTestCase {
doTest();
}
// PY-23500
public void testInsertAfterGlobalVariableOnWhichDepends() {
doTest();
}
// PY-23500
public void testInsertAfterAllGlobalVariablesOnWhichDepends() {
doTest();
}
// PY-23500
public void testInsertAfterWithStatementOnWhichDependsRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testInsertAfterLocalVariableOnWhichDependsRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testInsertAfterFunctionParameterOnWhichDependsRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testInsertAfterForIteratorOnWhichDependsRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testInsertAfterLocalVariableInForLoopOnWhichDependsRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testInsertAfterIfElse() {
doTest();
}
// PY-23500
public void testFromImportTopLevel() {
doTest();
}
// PY-23500
public void testFromImportInFunctionRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testExpressionWithFunctionCall() {
doTest();
}
// PY-23500
public void testExpressionWithParameterRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testSubexpressionWithParameterRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
// PY-23500
public void testSubexpressionWithGlobal() {
doTest();
}
// PY-23500
public void testSubexpressionNotFullWordRefactoringError() {
doTestThrowsRefactoringErrorHintException();
}
private void doTestThrowsRefactoringErrorHintException() {
assertThrows(CommonRefactoringUtil.RefactoringErrorHintException.class, () -> doTest());
}
@Override
protected String getTestDataPath() {
return super.getTestDataPath() + "/refactoring/introduceConstant";