IDEA-CR-49176: fixes after review of function inline for python

- codestyle improvements
- reuse TypeEvalContext and PyResolveContext
- check PyDunderAllReference specifically to avoid unwanted side effects
- skip resolving PyTargetExpression
- replace references with import alias in the refactoring, instead of the utility method
- added single reference search for PyStringLiteralExpressionImpl
- cleanup obsolete tests

GitOrigin-RevId: 3e33c880a8fe116a5b98583059c83b7dd67e34a2
This commit is contained in:
Aleksei Kniazev
2019-07-03 11:59:15 +03:00
committed by intellij-monorepo-bot
parent e867306f59
commit 2d720aea16
19 changed files with 83 additions and 98 deletions

View File

@@ -28,6 +28,7 @@ import com.intellij.psi.impl.source.resolve.reference.ReferenceProvidersRegistry
import com.intellij.psi.tree.IElementType;
import com.intellij.psi.tree.TokenSet;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.ArrayUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyElementTypes;
import com.jetbrains.python.PyTokenTypes;
@@ -204,6 +205,11 @@ public class PyStringLiteralExpressionImpl extends PyElementImpl implements PySt
return ReferenceProvidersRegistry.getReferencesFromProviders(this, Hints.NO_HINTS);
}
@Override
public PsiReference getReference() {
return ArrayUtil.getFirstElement(getReferences());
}
@Override
public ItemPresentation getPresentation() {
return new ItemPresentation() {

View File

@@ -57,7 +57,6 @@ public final class PyClassRefactoringUtil {
private static final Key<PsiNamedElement> ENCODED_IMPORT = Key.create("PyEncodedImport");
private static final Key<Boolean> ENCODED_USE_FROM_IMPORT = Key.create("PyEncodedUseFromImport");
private static final Key<String> ENCODED_IMPORT_AS = Key.create("PyEncodedImportAs");
private static final Key<PyReferenceExpression> REPLACEMENT_EXPRESSION = Key.create("PyReplacementExpression");
private PyClassRefactoringUtil() {
@@ -252,7 +251,6 @@ public final class PyClassRefactoringUtil {
PsiNamedElement target = sourceNode.getCopyableUserData(ENCODED_IMPORT);
final String asName = sourceNode.getCopyableUserData(ENCODED_IMPORT_AS);
final Boolean useFromImport = sourceNode.getCopyableUserData(ENCODED_USE_FROM_IMPORT);
final PyReferenceExpression replacement = sourceNode.getCopyableUserData(REPLACEMENT_EXPRESSION);
if (target instanceof PsiDirectory) {
target = (PsiNamedElement)PyUtil.getPackageElement((PsiDirectory)target, sourceNode);
}
@@ -272,9 +270,6 @@ public final class PyClassRefactoringUtil {
else {
insertImport(targetNode, target, asName, true);
}
if (replacement != null) {
sourceNode.replace(replacement);
}
}
finally {
sourceNode.putCopyableUserData(ENCODED_IMPORT, null);
@@ -434,20 +429,25 @@ public final class PyClassRefactoringUtil {
}
/**
* Forces the use of 'import as' when restoring references (i.e. if there are name clashes). Takes an optional replacement expression
* to insert in place of the node after import.
* Forces the use of 'import as' when restoring references (i.e. if there are name clashes)
* @param node with encoded import
* @param asName new alias for import
* @param replacement reference after import
*/
public static void forceAsName(@NotNull PyReferenceExpression node, @NotNull String asName, @Nullable PyReferenceExpression replacement) {
public static void forceAsName(@NotNull PyReferenceExpression node, @NotNull String asName) {
if (node.getCopyableUserData(ENCODED_IMPORT) == null) {
LOG.warn("As name is forced on the referenceExpression, that has no encoded import. Forcing it will likely be ignored.");
}
node.putCopyableUserData(ENCODED_IMPORT_AS, asName);
if (replacement != null) {
node.putCopyableUserData(REPLACEMENT_EXPRESSION, replacement);
}
}
public static void transferEncodedImports(@NotNull PyReferenceExpression source, @NotNull PyReferenceExpression target) {
target.putCopyableUserData(ENCODED_IMPORT, source.getCopyableUserData(ENCODED_IMPORT));
target.putCopyableUserData(ENCODED_IMPORT_AS, source.getCopyableUserData(ENCODED_IMPORT_AS));
target.putCopyableUserData(ENCODED_USE_FROM_IMPORT, source.getCopyableUserData(ENCODED_USE_FROM_IMPORT));
source.putCopyableUserData(ENCODED_IMPORT, null);
source.putCopyableUserData(ENCODED_IMPORT_AS, null);
source.putCopyableUserData(ENCODED_USE_FROM_IMPORT, null);
}
public static boolean hasEncodedTarget(@NotNull PyReferenceExpression node) {

View File

@@ -12,7 +12,6 @@ import com.intellij.refactoring.inline.InlineOptionsDialog
import com.jetbrains.python.PyBundle
import com.jetbrains.python.psi.PyFunction
import com.jetbrains.python.psi.PyImportStatementBase
import com.jetbrains.python.psi.PyReferenceExpression
import com.jetbrains.python.pyi.PyiUtil
/**
@@ -27,10 +26,7 @@ class PyInlineFunctionDialog(project: Project,
private val myNumberOfOccurrences: Int = getNumberOfOccurrences(myFunction)
init {
myInvokedOnReference = if (myReference != null) {
val expression = myReference.element as PyReferenceExpression
PsiTreeUtil.getParentOfType(expression, PyImportStatementBase::class.java) == null
} else false
myInvokedOnReference = PsiTreeUtil.getParentOfType(myReference?.element, PyImportStatementBase::class.java) == null
title = if (isMethod) "Inline method $myFunctionName" else "Inline function $myFunctionName"
init()
}
@@ -65,10 +61,8 @@ class PyInlineFunctionDialog(project: Project,
val originalNum = super.getNumberOfOccurrences(nameIdentifierOwner)
val stubOrImplementation = if (PyiUtil.isInsideStub(myFunction)) PyiUtil.getOriginalElement(myFunction) else PyiUtil.getPythonStub(myFunction)
if (originalNum != -1 && stubOrImplementation != null) {
val fromDeclaration = ReferencesSearch.search(stubOrImplementation, GlobalSearchScope.projectScope(myProject)).asSequence()
.filter(this::ignoreOccurrence)
.count()
return originalNum + fromDeclaration
val fromOtherLocation = super.getNumberOfOccurrences(stubOrImplementation as PsiNameIdentifierOwner)
if (fromOtherLocation != -1) return originalNum + fromOtherLocation
}
return originalNum
}

View File

@@ -24,7 +24,7 @@ import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.pyi.PyiFile
import com.jetbrains.python.pyi.PyiUtil
import com.jetbrains.python.sdk.PySdkUtil
import com.jetbrains.python.sdk.pythonSdk
import com.jetbrains.python.sdk.PythonSdkType
/**
* @author Aleksei.Kniazev
@@ -49,7 +49,7 @@ class PyInlineFunctionHandler : InlineActionHandler() {
PyNames.INIT == element.name -> "refactoring.inline.function.constructor"
PyBuiltinCache.getInstance(element).isBuiltin(element) -> "refactoring.inline.function.builtin"
isSpecialMethod(element) -> "refactoring.inline.function.special.method"
isUnderSkeletonDir(element, project) -> "refactoring.inline.function.skeleton.only"
isUnderSkeletonDir(element) -> "refactoring.inline.function.skeleton.only"
hasDecorators(element) -> "refactoring.inline.function.decorator"
hasReferencesToSelf(element) -> "refactoring.inline.function.self.referrent"
hasStarArgs(element) -> "refactoring.inline.function.star"
@@ -139,8 +139,9 @@ class PyInlineFunctionHandler : InlineActionHandler() {
private fun hasReferencesToSelf(function: PyFunction): Boolean = SyntaxTraverser.psiTraverser(function.statementList)
.any { it is PyReferenceExpression && it.reference.isReferenceTo(function) }
private fun isUnderSkeletonDir(function: PyFunction, project: Project): Boolean {
val skeletonsDir = PySdkUtil.findSkeletonsDir(project.pythonSdk ?: return false) ?: return false
private fun isUnderSkeletonDir(function: PyFunction): Boolean {
val sdk = PythonSdkType.findPythonSdk(function.containingFile) ?: return false
val skeletonsDir = PySdkUtil.findSkeletonsDir(sdk) ?: return false
return VfsUtil.isAncestor(skeletonsDir, function.containingFile.virtualFile, true)
}

View File

@@ -15,6 +15,8 @@ import com.intellij.usageView.UsageInfo
import com.intellij.usageView.UsageViewDescriptor
import com.intellij.util.containers.MultiMap
import com.jetbrains.python.PyBundle
import com.jetbrains.python.PyNames
import com.jetbrains.python.codeInsight.PyDunderAllReference
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil
import com.jetbrains.python.psi.*
@@ -48,11 +50,8 @@ class PyInlineFunctionProcessor(project: Project,
val usagesAndImports = refUsages.get()
val (imports, usages) = usagesAndImports.partition { PsiTreeUtil.getParentOfType(it.element, PyImportStatementBase::class.java) != null }
val filteredUsages = usages.filter { usage ->
if (usage.reference is PyDunderAllReference) return@filter true
val element = usage.element!!
if (element is PyStringLiteralExpression) {
val file = element.containingFile as? PyFile
if (file?.dunderAll?.contains(element.stringValue) == true) return@filter true
}
if (element.parent is PyDecorator) {
if (!handleUsageError(element, "refactoring.inline.function.is.decorator", conflicts)) return false
return@filter false
@@ -123,12 +122,15 @@ class PyInlineFunctionProcessor(project: Project,
private fun doRefactor(usages: Array<out UsageInfo>) {
val (unsortedRefs, imports) = usages.partition { PsiTreeUtil.getParentOfType(it.element, PyImportStatementBase::class.java) == null }
val (callRefs, dunderAll) = unsortedRefs.partition { it.element is PyReferenceExpression }
val (callRefs, dunderAll) = unsortedRefs.partition { it.reference !is PyDunderAllReference }
val references = callRefs.sortedByDescending { usage ->
SyntaxTraverser.psiApi().parents(usage.element).asSequence().filter { it is PyCallExpression }.count()
}
val typeEvalContext = TypeEvalContext.userInitiated(myProject, null)
val resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(typeEvalContext)
val selfUsed = myFunction.parameterList.parameters.firstOrNull()?.let { firstParam ->
if (!firstParam.isSelf) return@let false
return@let SyntaxTraverser.psiTraverser(myFunction.statementList).traverse()
@@ -167,7 +169,7 @@ class PyInlineFunctionProcessor(project: Project,
val importAsRefs = MultiMap.create<String, PyReferenceExpression>()
val returnStatements = mutableListOf<PyReturnStatement>()
val mappedArguments = prepareArguments(callSite, declarations, generatedNames, scopeAnchor, reference, languageLevel, selfUsed)
val mappedArguments = prepareArguments(callSite, declarations, generatedNames, scopeAnchor, reference, languageLevel, resolveContext, selfUsed)
myFunction.statementList.accept(object : PyRecursiveElementVisitor() {
override fun visitPyReferenceExpression(node: PyReferenceExpression) {
@@ -175,7 +177,7 @@ class PyInlineFunctionProcessor(project: Project,
val name = node.name!!
if (name in namesInOuterScope && name !in mappedArguments) {
val resolved = node.reference.resolve()
val target = (resolved as? PyFunction)?.containingClass ?: resolved
val target = if (resolved is PyFunction && resolved.containingClass != null && resolved.name == PyNames.INIT) resolved.containingClass else resolved
if (!builtinCache.isBuiltin(target) && target !in PyResolveUtil.resolveLocally(refScopeOwner, name)) {
if (PyClassRefactoringUtil.hasEncodedTarget(node)) importAsTargets.add(name)
else nameClashes.add(name)
@@ -189,11 +191,7 @@ class PyInlineFunctionProcessor(project: Project,
if (!node.isQualified) {
val name = node.name!!
if (name in namesInOuterScope && name !in mappedArguments && functionScope.containsDeclaration(name)) {
val resolved = node.reference.resolve()
val target = (resolved as? PyFunction)?.containingClass ?: resolved
if (!builtinCache.isBuiltin(target) && target !in PyResolveUtil.resolveLocally(refScopeOwner, name)) {
nameClashes.add(name)
}
nameClashes.add(name)
}
}
super.visitPyTargetExpression(node)
@@ -243,7 +241,11 @@ class PyInlineFunctionProcessor(project: Project,
importAsRefs.entrySet().forEach { (name, elements) ->
val newRef = generateUniqueAssignment(languageLevel, name, generatedNames, scopeAnchor).assignedValue as PyReferenceExpression
elements.forEach { PyClassRefactoringUtil.forceAsName(it, newRef.name!!, newRef) }
elements.forEach {
PyClassRefactoringUtil.transferEncodedImports(it, newRef)
PyClassRefactoringUtil.forceAsName(newRef, newRef.name!!)
it.replace(newRef)
}
}
if (returnStatements.size == 1 && returnStatements[0].expression !is PyTupleExpression) {
@@ -300,7 +302,7 @@ class PyInlineFunctionProcessor(project: Project,
if (stubFunction != null && stubFunction.isWritable) {
stubFunction.delete()
}
val typingOverloads = PyiUtil.getOverloads(myFunction, TypeEvalContext.userInitiated(myProject, file))
val typingOverloads = PyiUtil.getOverloads(myFunction, typeEvalContext)
if (typingOverloads.isNotEmpty()) {
typingOverloads.forEach { it.delete() }
}
@@ -311,8 +313,7 @@ class PyInlineFunctionProcessor(project: Project,
}
private fun prepareArguments(callSite: PyCallExpression, declarations: MutableList<PyAssignmentStatement>, generatedNames: MutableSet<String>, scopeAnchor: PsiElement,
reference: PyReferenceExpression, languageLevel: LanguageLevel, selfUsed: Boolean): Map<String, PyExpression> {
val context = PyResolveContext.noImplicits().withTypeEvalContext(TypeEvalContext.userInitiated(myProject, reference.containingFile))
reference: PyReferenceExpression, languageLevel: LanguageLevel, context: PyResolveContext, selfUsed: Boolean): Map<String, PyExpression> {
val mapping = PyCallExpressionHelper.mapArguments(callSite, context).firstOrNull() ?: error("Can't map arguments for ${reference.name}")
val mappedParams = mapping.mappedParameters
val firstImplicit = mapping.implicitParameters.firstOrNull()

View File

@@ -1,10 +0,0 @@
from source import do_more_work
from source import do_work
from source import do_substantially_more_work as do_nothing
def bar():
do_work()
do_more_work(42)
do_nothing()
res = 42

View File

@@ -1,7 +0,0 @@
from source import foo
from source import do_work
from source import do_substantially_more_work as do_nothing
def bar():
res = fo<caret>o(42)

View File

@@ -1,19 +0,0 @@
def do_work():
pass
def do_more_work(x):
pass
def do_substantially_more_work():
pass
def foo(arg):
do_work()
do_more_work(arg)
do_substantially_more_work()
return arg

View File

@@ -0,0 +1,3 @@
from sys import exit
ex<caret>it()

View File

@@ -0,0 +1,4 @@
def foo():
pass # should not be inserted after inline

View File

@@ -0,0 +1,5 @@
def foo():
pass # should not be inserted after inline
fo<caret>o()

View File

@@ -0,0 +1,4 @@
def func():
"""Only a docstring."""

View File

@@ -0,0 +1,5 @@
def func():
"""Only a docstring."""
fu<caret>nc()

View File

@@ -1,5 +0,0 @@
print(1)
print(2)
z = 1 + 2
print(z)
res = z

View File

@@ -1,3 +0,0 @@
from src import foo as bar
res = ba<caret>r(1, 2)

View File

@@ -1,6 +0,0 @@
def foo(x, y):
print(x)
print(y)
z = x + y
print(z)
return z

View File

@@ -0,0 +1,6 @@
def func():
"""Only a docstring."""
if True:
pass

View File

@@ -0,0 +1,6 @@
def func():
"""Only a docstring."""
if True:
fu<caret>nc()

View File

@@ -3,7 +3,6 @@ package com.jetbrains.python.refactoring
import com.intellij.codeInsight.TargetElementUtil
import com.intellij.refactoring.util.CommonRefactoringUtil
import com.jetbrains.python.codeInsight.PyCodeInsightSettings
import com.jetbrains.python.fixtures.PyTestCase
import com.jetbrains.python.psi.LanguageLevel
import com.jetbrains.python.psi.PyElement
@@ -12,7 +11,6 @@ import com.jetbrains.python.pyi.PyiFile
import com.jetbrains.python.pyi.PyiUtil
import com.jetbrains.python.refactoring.inline.PyInlineFunctionHandler
import com.jetbrains.python.refactoring.inline.PyInlineFunctionProcessor
import junit.framework.TestCase
/**
* @author Aleksei.Kniazev
@@ -28,7 +26,7 @@ class PyInlineFunctionTest : PyTestCase() {
var element = TargetElementUtil.findTargetElement(myFixture.editor, TargetElementUtil.getInstance().referenceSearchFlags)
if (element!!.containingFile is PyiFile) element = PyiUtil.getOriginalElement(element as PyElement)
val reference = TargetElementUtil.findReference(myFixture.editor)
TestCase.assertTrue(element is PyFunction)
assertTrue(element is PyFunction)
PyInlineFunctionProcessor(myFixture.project, myFixture. editor, element as PyFunction, reference, inlineThis, remove).run()
myFixture.checkResultByFile("$testName/main.after.py")
}
@@ -44,10 +42,10 @@ class PyInlineFunctionTest : PyTestCase() {
else {
PyInlineFunctionHandler.getInstance().inlineElement(myFixture.project, myFixture.editor, element)
}
TestCase.fail("Expected error: $expectedError, but got none")
fail("Expected error: $expectedError, but got none")
}
catch (e: CommonRefactoringUtil.RefactoringErrorHintException) {
TestCase.assertEquals(expectedError, e.message)
assertEquals(expectedError, e.message)
}
}
@@ -60,7 +58,6 @@ class PyInlineFunctionTest : PyTestCase() {
fun testMultipleReturns() = doTest()
fun testImporting() = doTest()
fun testImportAs() = doTest()
//fun testExistingImports() = doTest()
fun testMethodInsideClass() = doTest()
fun testMethodOutsideClass() = doTest()
fun testNoReturnsAsExpressionStatement() = doTest()
@@ -85,8 +82,10 @@ class PyInlineFunctionTest : PyTestCase() {
fun testKeepingComments() = doTest()
fun testInvocationOnImport() = doTest(inlineThis = false, remove = true)
fun testImportedLocally() = doTest(inlineThis = false, remove = true)
//fun testInlineImportedAs() = doTest(inlineThis = false)
fun testSelfUsageDetection() = doTest(inlineThis = false, remove = true)
fun testIgnoreSolePassStatement() = doTest()
fun testInlineDocstringOnlyFunction() = doTest()
fun testTurnDocstringOnlyFunctionIntoPass() = doTest()
fun testOptimizeImportsAtDeclarationSite() {
doTest(inlineThis = false, remove = true)
val testName = getTestName(true)
@@ -113,6 +112,7 @@ class PyInlineFunctionTest : PyTestCase() {
fun testOverridden() = doTestError("Cannot inline overridden methods")
fun testNested() = doTestError("Cannot inline functions with another function declaration")
fun testInterruptedFlow() = doTestError("Cannot inline functions that interrupt control flow")
fun testFunctionFromBinaryStub() = doTestError("Cannot inline function from binary module")
fun testUsedAsDecorator() = doTestError("Function foo is used as a decorator and cannot be inlined. Function definition will not be removed", isReferenceError = true)
fun testUsedAsReference() = doTestError("Function foo is used as a reference and cannot be inlined. Function definition will not be removed", isReferenceError = true)
fun testUsesArgumentUnpacking() = doTestError("Function foo uses argument unpacking and cannot be inlined. Function definition will not be removed", isReferenceError = true)