IDEA-CR-48018: inline function refactoring for python (PY-21287)

- ability to inline single/all invocations and keep/remove the declaration
- handled name conflicts of local and imported vars/functions
- cases when this refactoring is not applicable can be found in PyInlineFunctionHandler

(cherry picked from commit 40e298ba2b833a4408ee774628b84fe422468b91)

GitOrigin-RevId: dd3fab59800163b4f300b410cd28a62c1ee1b6f3
This commit is contained in:
Aleksei Kniazev
2019-06-04 18:53:58 +03:00
committed by intellij-monorepo-bot
parent ce764be06a
commit c6e4181228
74 changed files with 1370 additions and 9 deletions

View File

@@ -445,6 +445,7 @@
<lang.refactoringSupport.classMembersRefactoringSupport language="Python"
implementationClass="com.jetbrains.python.refactoring.classes.PyMembersRefactoringSupport"/>
<inlineActionHandler implementation="com.jetbrains.python.refactoring.inline.PyInlineLocalHandler"/>
<inlineActionHandler implementation="com.jetbrains.python.refactoring.inline.PyInlineFunctionHandler"/>
<codeInsight.gotoSuper language="Python" implementationClass="com.jetbrains.python.codeInsight.PyGotoSuperHandler"/>
<gotoDeclarationHandler implementation="com.jetbrains.python.codeInsight.PyBreakContinueGotoProvider" order="FIRST"/>
<gotoDeclarationHandler implementation="com.jetbrains.python.psi.impl.PyGotoDeclarationHandler"/>

View File

@@ -645,6 +645,27 @@ refactoring.push.down.error.cannot.perform.refactoring.not.inside.class=Cannot p
# inline
refactoring.inline.local.multiassignment=Definition is in multi-assign
# inline function
refactoring.inline.function.title=Inline Function
refactoring.inline.this.only=Inline this invocation only and keep the declaration
refactoring.inline.all.keep.declaration=Inline all invocations and keep the declaration
refactoring.inline.all.remove.declaration=Inline all invocations and remove the declaration
refactoring.inline.function.is.decorator=Function {0} is used as a decorator and cannot be inlined. Function definition will not be removed
refactoring.inline.function.is.reference=Function {0} is used as a reference and cannot be inlined. Function definition will not be removed
refactoring.inline.function.uses.unpacking=Function {0} uses argument unpacking and cannot be inlined. Function definition will not be removed
refactoring.inline.function.generator=Cannot inline generators
refactoring.inline.function.async=Cannot inline async functions
refactoring.inline.function.constructor=Cannot inline constructor calls
refactoring.inline.function.builtin=Cannot inline builtin functions
refactoring.inline.function.decorator=Cannot inline functions with decorators
refactoring.inline.function.self.referrent=Cannot inline functions that reference themselves
refactoring.inline.function.star=Cannot inline functions with * arguments
refactoring.inline.function.overridden=Cannot inline overridden functions
refactoring.inline.function.global=Cannot inline functions with global variables
refactoring.inline.function.nonlocal=Cannot inline functions with nonlocals variables
refactoring.inline.function.nested=Cannot inline functions with another function declaration
refactoring.inline.function.interrupts.flow=Cannot inline functions that interrupt control flow
# extract method
refactoring.extract.method=Extract method
refactoring.extract.method.error.interrupted.execution.flow=Cannot perform refactoring when execution flow is interrupted

View File

@@ -36,10 +36,13 @@ public interface Scope {
@Nullable
ScopeVariable getDeclaredVariable(@NotNull PsiElement anchorElement, @NotNull String name) throws DFALimitExceededException;
boolean hasGlobals();
boolean isGlobal(String name);
boolean hasNonLocals();
boolean isNonlocal(String name);
boolean hasNestedScopes();
boolean containsDeclaration(String name);
@NotNull

View File

@@ -74,6 +74,15 @@ public class ScopeImpl implements Scope {
}
}
@Override
public boolean hasGlobals() {
if (myGlobals == null || myNestedScopes == null) {
collectDeclarations();
}
if (!myGlobals.isEmpty()) return true;
return myNestedScopes.stream().anyMatch(scope -> scope.hasGlobals());
}
@Override
public boolean isGlobal(final String name) {
if (myGlobals == null || myNestedScopes == null) {
@@ -90,6 +99,15 @@ public class ScopeImpl implements Scope {
return false;
}
@Override
public boolean hasNonLocals() {
if (myNonlocals == null || myNestedScopes == null) {
collectDeclarations();
}
if (!myNonlocals.isEmpty()) return true;
return myNestedScopes.stream().anyMatch(scope -> scope.hasNonLocals());
}
@Override
public boolean isNonlocal(final String name) {
if (myNonlocals == null || myNestedScopes == null) {
@@ -98,6 +116,14 @@ public class ScopeImpl implements Scope {
return myNonlocals.contains(name);
}
@Override
public boolean hasNestedScopes() {
if (myNestedScopes == null) {
collectDeclarations();
}
return !myNestedScopes.isEmpty();
}
private boolean isAugAssignment(final String name) {
if (myAugAssignments == null || myNestedScopes == null) {
collectDeclarations();

View File

@@ -20,6 +20,7 @@ import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.function.BiPredicate;
public class PyRefactoringUtil {
private PyRefactoringUtil() {
@@ -326,7 +327,7 @@ public class PyRefactoringUtil {
*/
@NotNull
public static String selectUniqueNameFromType(@NotNull String typeName, @NotNull PsiElement scopeAnchor) {
return selectUniqueName(typeName, true, scopeAnchor);
return selectUniqueName(typeName, true, scopeAnchor, PyRefactoringUtil::isValidNewName);
}
/**
@@ -339,11 +340,16 @@ public class PyRefactoringUtil {
*/
@NotNull
public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor) {
return selectUniqueName(templateName, false, scopeAnchor);
return selectUniqueName(templateName, false, scopeAnchor, PyRefactoringUtil::isValidNewName);
}
@NotNull
private static String selectUniqueName(@NotNull String templateName, boolean templateIsType, @NotNull PsiElement scopeAnchor) {
public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
return selectUniqueName(templateName, false, scopeAnchor, isValid);
}
@NotNull
private static String selectUniqueName(@NotNull String templateName, boolean templateIsType, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
final Collection<String> suggestions;
if (templateIsType) {
suggestions = NameSuggesterUtil.generateNamesByType(templateName);
@@ -352,14 +358,14 @@ public class PyRefactoringUtil {
suggestions = NameSuggesterUtil.generateNames(templateName);
}
for (String name : suggestions) {
if (isValidNewName(name, scopeAnchor)) {
if (isValid.test(name, scopeAnchor)) {
return name;
}
}
final String shortestName = ContainerUtil.getFirstItem(suggestions);
//noinspection ConstantConditions
return appendNumberUntilValid(shortestName, scopeAnchor);
return appendNumberUntilValid(shortestName, scopeAnchor, isValid);
}
/**
@@ -367,13 +373,14 @@ public class PyRefactoringUtil {
*
* @param name initial name
* @param scopeAnchor PSI element used to determine correct scope
* @param predicate used to test if suggested name is valid
* @return unique name in the scope probably with number suffix appended
*/
@NotNull
public static String appendNumberUntilValid(@NotNull String name, @NotNull PsiElement scopeAnchor) {
public static String appendNumberUntilValid(@NotNull String name, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> predicate) {
int counter = 1;
String candidate = name;
while (!isValidNewName(candidate, scopeAnchor)) {
while (!predicate.test(candidate, scopeAnchor)) {
candidate = name + counter;
counter++;
}

View File

@@ -0,0 +1,45 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.refactoring.inline
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiReference
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.refactoring.inline.InlineOptionsDialog
import com.jetbrains.python.PyBundle
import com.jetbrains.python.psi.PyFunction
import com.jetbrains.python.psi.PyImportStatementBase
/**
* @author Aleksei.Kniazev
*/
class PyInlineFunctionDialog(project: Project,
private val myEditor: Editor,
private val myFunction: PyFunction,
private val myReference: PsiReference?) : InlineOptionsDialog(project, true, myFunction) {
private val isMethod = myFunction.asMethod() != null
init {
myInvokedOnReference = myReference != null
title = if (isMethod) "Inline method ${myFunction.name}" else "Inline function ${myFunction.name}"
init()
}
override fun doAction() {
invokeRefactoring(PyInlineFunctionProcessor(myProject, myEditor, myFunction, myReference, isInlineThisOnly, !isKeepTheDeclaration))
}
override fun getNameLabelText(): String = "The number of occurrences: ${getNumberOfOccurrences(myFunction)}"
override fun getBorderTitle(): String = "Inline"
override fun getInlineAllText(): String = PyBundle.message("refactoring.inline.all.remove.declaration")
override fun getKeepTheDeclarationText(): String = PyBundle.message("refactoring.inline.all.keep.declaration")
override fun getInlineThisText(): String = PyBundle.message("refactoring.inline.this.only")
override fun getHelpId(): String = PyInlineFunctionHandler.REFACTORING_ID
override fun allowInlineAll(): Boolean = true
override fun isInlineThis(): Boolean = true
override fun ignoreOccurrence(reference: PsiReference): Boolean {
return PsiTreeUtil.getParentOfType(reference.element, PyImportStatementBase::class.java) == null
}
}

View File

@@ -0,0 +1,127 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.refactoring.inline
import com.intellij.codeInsight.TargetElementUtil
import com.intellij.lang.Language
import com.intellij.lang.refactoring.InlineActionHandler
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElement
import com.intellij.psi.SyntaxTraverser
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.refactoring.util.CommonRefactoringUtil
import com.jetbrains.python.PyBundle
import com.jetbrains.python.PyNames
import com.jetbrains.python.PythonLanguage
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyBuiltinCache
import com.jetbrains.python.psi.search.PySuperMethodsSearch
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.pyi.PyiFile
/**
* @author Aleksei.Kniazev
*/
class PyInlineFunctionHandler : InlineActionHandler() {
override fun isEnabledForLanguage(l: Language?) = l is PythonLanguage
override fun canInlineElement(element: PsiElement?) = element is PyFunction && element.containingFile !is PyiFile
override fun inlineElement(project: Project?, editor: Editor?, element: PsiElement?) {
if (project == null || editor == null || element !is PyFunction) return
val functionScope = ControlFlowCache.getScope(element)
val error = when {
element.isAsync -> "refactoring.inline.function.async"
element.isGenerator -> "refactoring.inline.function.generator"
PyNames.INIT == element.name -> "refactoring.inline.function.constructor"
PyBuiltinCache.getInstance(element).isBuiltin(element) -> "refactoring.inline.function.builtin"
hasDecorators(element) -> "refactoring.inline.function.decorator"
hasReferencesToSelf(element) -> "refactoring.inline.function.self.referrent"
hasStarArgs(element) -> "refactoring.inline.function.star"
isOverride(element, project) -> "refactoring.inline.function.overridden"
functionScope.hasGlobals() -> "refactoring.inline.function.global"
functionScope.hasNonLocals() -> "refactoring.inline.function.nonlocal"
functionScope.hasNestedScopes() -> "refactoring.inline.function.nested"
hasNonExhaustiveIfs(element) -> "refactoring.inline.function.interrupts.flow"
else -> null
}
if (error != null) {
CommonRefactoringUtil.showErrorHint(project, editor, PyBundle.message(error), PyBundle.message("refactoring.inline.function.title"), REFACTORING_ID)
return
}
if (!ApplicationManager.getApplication().isUnitTestMode){
PyInlineFunctionDialog(project, editor, element, TargetElementUtil.findReference(editor)).show()
}
}
private fun hasNonExhaustiveIfs(function: PyFunction): Boolean {
val returns = mutableListOf<PyReturnStatement>()
function.accept(object : PyRecursiveElementVisitor() {
override fun visitPyReturnStatement(node: PyReturnStatement) {
returns.add(node)
}
})
if (returns.isEmpty()) return false
val cache = mutableSetOf<PyIfStatement>()
return returns.asSequence()
.map { PsiTreeUtil.getParentOfType(it, PyIfStatement::class.java) }
.distinct()
.filterNotNull()
.any { checkInterruptsControlFlow(it, cache) }
}
private fun checkInterruptsControlFlow(ifStatement: PyIfStatement, cache: MutableSet<PyIfStatement>): Boolean {
if (ifStatement in cache) return false
cache.add(ifStatement)
val elsePart = ifStatement.elsePart
if (elsePart == null) return true
if (checkLastStatement(ifStatement.ifPart.statementList, cache)) return true
if (checkLastStatement(elsePart.statementList, cache)) return true
ifStatement.elifParts.forEach { if (checkLastStatement(it.statementList, cache)) return true }
val parentIfStatement = PsiTreeUtil.getParentOfType(ifStatement, PyIfStatement::class.java)
if (parentIfStatement != null && checkInterruptsControlFlow(parentIfStatement, cache)) return true
return false
}
private fun checkLastStatement(statementList: PyStatementList, cache: MutableSet<PyIfStatement>): Boolean {
val statements = statementList.statements
if (statements.isEmpty()) return true
when(val last = statements.last()) {
is PyIfStatement -> if (checkInterruptsControlFlow(last, cache)) return true
!is PyReturnStatement -> return true
}
return false
}
private fun hasDecorators(function: PyFunction): Boolean = function.decoratorList?.decorators?.isNotEmpty() == true
private fun isOverride(function: PyFunction, project: Project): Boolean {
return function.containingClass != null
&& PySuperMethodsSearch.search(function, TypeEvalContext.codeAnalysis(project, function.containingFile)).any()
}
private fun hasStarArgs(function: PyFunction): Boolean {
return function.parameterList.parameters.asSequence()
.filterIsInstance<PyNamedParameter>()
.any { it.isPositionalContainer || it.isKeywordContainer }
}
private fun hasReferencesToSelf(function: PyFunction): Boolean = SyntaxTraverser.psiTraverser(function.statementList)
.any { it is PyReferenceExpression && it.reference.isReferenceTo(function) }
companion object {
@JvmStatic
fun getInstance(): PyInlineFunctionHandler {
return InlineActionHandler.EP_NAME.findExtensionOrFail(PyInlineFunctionHandler::class.java)
}
@JvmStatic
val REFACTORING_ID = "refactoring.inlineMethod"
}
}

View File

@@ -0,0 +1,311 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.refactoring.inline
import com.intellij.history.LocalHistory
import com.intellij.openapi.editor.Editor
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiReference
import com.intellij.psi.SyntaxTraverser
import com.intellij.psi.codeStyle.CodeStyleManager
import com.intellij.psi.search.searches.ReferencesSearch
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.refactoring.BaseRefactoringProcessor
import com.intellij.refactoring.util.CommonRefactoringUtil
import com.intellij.usageView.UsageInfo
import com.intellij.usageView.UsageViewDescriptor
import com.intellij.util.containers.MultiMap
import com.jetbrains.python.PyBundle
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil
import com.jetbrains.python.codeInsight.imports.PyImportOptimizer
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyBuiltinCache
import com.jetbrains.python.psi.impl.PyCallExpressionHelper
import com.jetbrains.python.psi.impl.PyPsiUtils
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.pyi.PyiUtil
import com.jetbrains.python.refactoring.PyRefactoringUtil
import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil
import org.jetbrains.annotations.PropertyKey
/**
* @author Aleksei.Kniazev
*/
class PyInlineFunctionProcessor(project: Project,
private val myEditor: Editor,
private val myFunction: PyFunction,
private val myReference: PsiReference?,
private val myInlineThis: Boolean,
removeDeclaration: Boolean) : BaseRefactoringProcessor(project) {
private val myFunctionClass = myFunction.containingClass
private val myGenerator = PyElementGenerator.getInstance(myProject)
private var myRemoveDeclaration = !myInlineThis && removeDeclaration
override fun preprocessUsages(refUsages: Ref<Array<UsageInfo>>): Boolean {
if (refUsages.isNull) return false
val conflicts = MultiMap.create<PsiElement, String>()
val usagesAndImports = refUsages.get()
val (imports, usages) = usagesAndImports.partition { PsiTreeUtil.getParentOfType(it.element, PyImportStatementBase::class.java) != null }
val filteredUsages = usages.filter { usage ->
val element = usage.element!!
if (element.parent is PyDecorator) {
if (!handleUsageError(element, "refactoring.inline.function.is.decorator", conflicts)) return false
return@filter false
}
else if (element.parent !is PyCallExpression) {
if (!handleUsageError(element, "refactoring.inline.function.is.reference", conflicts)) return false
return@filter false
}
else {
val callExpression = element.parent as PyCallExpression
if (callExpression.arguments.any { it is PyStarArgument}) {
if (!handleUsageError(element, "refactoring.inline.function.uses.unpacking", conflicts)) return false
return@filter false
}
}
return@filter true
}
val conflictLocations = conflicts.keySet().map { it.containingFile }
val filteredImports = imports.filter { it.file !in conflictLocations }
val filtered = filteredUsages + filteredImports
refUsages.set(filtered.toTypedArray())
return showConflicts(conflicts, filtered.toTypedArray())
}
private fun handleUsageError(element: PsiElement, @PropertyKey(resourceBundle = "com.jetbrains.python.PyBundle") error: String, conflicts: MultiMap<PsiElement, String>): Boolean {
val errorText = PyBundle.message(error, myFunction.name)
if (myInlineThis) {
// shortcut for inlining single reference: show error hint instead of modal dialog
CommonRefactoringUtil.showErrorHint(myProject, myEditor, errorText, PyBundle.message("refactoring.inline.function.title"), PyInlineFunctionHandler.REFACTORING_ID)
prepareSuccessful()
return false
}
conflicts.putValue(element, errorText)
myRemoveDeclaration = false
return true
}
override fun findUsages(): Array<UsageInfo> {
if (myInlineThis) {
val element = myReference!!.element as PyReferenceExpression
val import = PyResolveUtil.resolveLocally(ScopeUtil.getScopeOwner(element)!!, element.name!!).firstOrNull { it is PyImportElement }
return if (import != null) arrayOf(UsageInfo(element), UsageInfo(import)) else arrayOf(UsageInfo(element))
}
return ReferencesSearch.search(myFunction, myRefactoringScope).asSequence()
.distinct()
.map(PsiReference::getElement)
.map(::UsageInfo)
.toList()
.toTypedArray()
}
override fun performRefactoring(usages: Array<out UsageInfo>) {
val action = LocalHistory.getInstance().startAction(commandName)
try {
doRefactor(usages)
}
finally {
action.finish()
}
}
private fun doRefactor(usages: Array<out UsageInfo>) {
val (unsortedRefs, imports) = usages.partition { PsiTreeUtil.getParentOfType(it.element, PyImportStatementBase::class.java) == null }
val references = unsortedRefs.sortedByDescending { usage ->
SyntaxTraverser.psiApi().parents(usage.element).asSequence().filter { it is PyCallExpression }.count()
}
val functionScope = ControlFlowCache.getScope(myFunction)
PyClassRefactoringUtil.rememberNamedReferences(myFunction)
references.forEach { usage ->
val reference = usage.element as PyReferenceExpression
val languageLevel = LanguageLevel.forElement(reference)
val refScopeOwner = ScopeUtil.getScopeOwner(reference) ?: error("Unable to find scope owner for ${reference.name}")
val declarations = mutableListOf<PyAssignmentStatement>()
val generatedNames = mutableSetOf<String>()
val callSite = PsiTreeUtil.getParentOfType(reference, PyCallExpression::class.java) ?: error("Unable to find call expression for ${reference.name}")
val containingStatement = PsiTreeUtil.getParentOfType(callSite, PyStatement::class.java) ?: error("Unable to find statement for ${reference.name}")
val replacementFunction = myFunction.statementList.copy() as PyStatementList
val namesInOuterScope = PyRefactoringUtil.collectUsedNames(refScopeOwner)
val argumentReplacements = mutableMapOf<PyReferenceExpression, PyExpression>()
val nameClashes = MultiMap.create<String, PyExpression>()
val returnStatements = mutableListOf<PyReturnStatement>()
val mappedArguments = prepareArguments(callSite, declarations, generatedNames, reference, languageLevel)
val builtinCache = PyBuiltinCache.getInstance(reference)
replacementFunction.accept(object : PyRecursiveElementVisitor() {
override fun visitPyReferenceExpression(node: PyReferenceExpression) {
if (node.qualifier == null) {
val name = node.name
if (name in mappedArguments) {
argumentReplacements[node] = mappedArguments[name]!!
}
else if (name in namesInOuterScope && !builtinCache.isBuiltin(node.reference.resolve())) {
nameClashes.putValue(name!!, node)
}
}
super.visitPyReferenceExpression(node)
}
override fun visitPyReturnStatement(node: PyReturnStatement) {
returnStatements.add(node)
super.visitPyReturnStatement(node)
}
override fun visitPyTargetExpression(node: PyTargetExpression) {
if (node.qualifier == null) {
val name = node.name
if (name in namesInOuterScope && name !in mappedArguments && functionScope.containsDeclaration(name)) {
nameClashes.putValue(name!!, node)
}
}
super.visitPyTargetExpression(node)
}
})
// Replacing
argumentReplacements.forEach { (old, new) -> old.replace(new) }
nameClashes.entrySet().forEach { (name, elements) ->
val generated = generateUniqueAssignment(languageLevel, name, generatedNames, reference)
elements.forEach {
when (it) {
is PyTargetExpression -> it.replace(generated.targets[0])
is PyReferenceExpression -> it.replace(generated.assignedValue!!)
}
}
}
if (returnStatements.size == 1 && returnStatements[0].expression !is PyTupleExpression) {
// replace single return with expression itself
val statement = returnStatements[0]
callSite.replace(statement.expression!!)
statement.delete()
}
else if (returnStatements.isNotEmpty()) {
val newReturn = generateUniqueAssignment(languageLevel, "result", generatedNames, reference)
returnStatements.forEach {
val copy = newReturn.copy() as PyAssignmentStatement
copy.assignedValue!!.replace(it.expression!!)
it.replace(copy)
}
callSite.replace(newReturn.assignedValue!!)
}
CodeStyleManager.getInstance(myProject).reformat(replacementFunction, true)
declarations.forEach { containingStatement.parent.addBefore(it, containingStatement) }
if (replacementFunction.firstChild != null) {
val statements = replacementFunction.statements
statements.asSequence()
.map { containingStatement.parent.addBefore(it, containingStatement) }
.forEach { PyClassRefactoringUtil.restoreNamedReferences(it) }
}
if (returnStatements.isEmpty()) {
if (callSite.parent is PyExpressionStatement) {
containingStatement.delete()
}
else {
callSite.replace(myGenerator.createExpressionFromText(languageLevel, "None"))
}
}
}
imports.forEach { PyClassRefactoringUtil.optimizeImports(it.element!!.containingFile!!) }
if (myRemoveDeclaration) {
val stubFunction = PyiUtil.getPythonStub(myFunction)
if (stubFunction != null && stubFunction.isWritable) {
stubFunction.delete()
}
myFunction.delete()
}
}
private fun prepareArguments(callSite: PyCallExpression, declarations: MutableList<PyAssignmentStatement>, generatedNames: MutableSet<String>,
reference: PyReferenceExpression, languageLevel: LanguageLevel): Map<String, PyExpression> {
val context = PyResolveContext.noImplicits().withTypeEvalContext(TypeEvalContext.userInitiated(myProject, reference.containingFile))
val mapping = PyCallExpressionHelper.mapArguments(callSite, context).firstOrNull() ?: error("Can't map arguments for ${reference.name}")
val mappedParams = mapping.mappedParameters
val self = mapping.implicitParameters.firstOrNull()?.let { first ->
val implicitName = first.name!!
val selfReplacement = reference.qualifier?.let { qualifier ->
myFunctionClass?.let {
when {
qualifier is PyReferenceExpression && !qualifier.isQualified -> qualifier
else -> {
val qualifierDeclaration = generateUniqueAssignment(languageLevel, myFunctionClass.name!!, generatedNames, reference)
val newRef = qualifierDeclaration.assignedValue!!.copy() as PyExpression
qualifierDeclaration.assignedValue!!.replace(qualifier)
declarations.add(qualifierDeclaration)
newRef
}
}
}
}
mapOf(implicitName to selfReplacement!!)
} ?: emptyMap()
val passedArguments = mappedParams.asSequence()
.map { (arg, param) ->
val argValue = if (arg is PyKeywordArgument) arg.valueExpression!! else arg
tryExtractDeclaration(param.name!!, argValue, declarations, generatedNames, reference, languageLevel)
}
.toMap()
val defaultValues = myFunction.parameterList.parameters.asSequence()
.filter { it.name !in passedArguments }
.filter { it.hasDefaultValue() }
.map { tryExtractDeclaration(it.name!!, it.defaultValue!!, declarations, generatedNames, reference, languageLevel) }
.toMap()
return self + passedArguments + defaultValues
}
private fun tryExtractDeclaration(paramName: String, arg: PyExpression, declarations: MutableList<PyAssignmentStatement>, generatedNames: MutableSet<String>,
reference: PyReferenceExpression, languageLevel: LanguageLevel): Pair<String, PyExpression> {
if (arg !is PyReferenceExpression && arg !is PyLiteralExpression) {
val statement = generateUniqueAssignment(languageLevel, paramName, generatedNames, reference)
statement.assignedValue!!.replace(arg)
declarations.add(statement)
return paramName to statement.targets[0]
}
return paramName to arg
}
private fun generateUniqueAssignment(level: LanguageLevel, name: String, previouslyGeneratedNames: MutableSet<String>, scopeAnchor: PsiElement): PyAssignmentStatement {
val uniqueName = PyRefactoringUtil.selectUniqueName(name, scopeAnchor) { newName, anchor ->
PyRefactoringUtil.isValidNewName(newName, anchor) && newName !in previouslyGeneratedNames
}
previouslyGeneratedNames.add(uniqueName)
return myGenerator.createFromText(level, PyAssignmentStatement::class.java, "$uniqueName = $uniqueName")
}
override fun getCommandName() = "Inlining ${myFunction.name}"
override fun getRefactoringId() = PyInlineFunctionHandler.REFACTORING_ID
override fun createUsageViewDescriptor(usages: Array<out UsageInfo>) = object : UsageViewDescriptor {
override fun getElements(): Array<PsiElement> = arrayOf(myFunction)
override fun getProcessedElementsHeader(): String = "Function to inline "
override fun getCodeReferencesText(usagesCount: Int, filesCount: Int): String = "Invocations to be inlined in $filesCount files"
override fun getCommentReferencesText(usagesCount: Int, filesCount: Int): String = ""
}
}

View File

@@ -51,7 +51,7 @@ public class PyInlineLocalHandler extends InlineActionHandler {
private static final String REFACTORING_NAME = RefactoringBundle.message("inline.variable.title");
private static final Pair<PyStatement, Boolean> EMPTY_DEF_RESULT = Pair.create(null, false);
private static final String HELP_ID = "python.reference.inline";
private static final String HELP_ID = "refactoring.inlineVariable";
public static PyInlineLocalHandler getInstance() {
return EP_NAME.findExtensionOrFail(PyInlineLocalHandler.class);

View File

@@ -229,7 +229,7 @@ public class PyMakeMethodTopLevelProcessor extends PyBaseMakeFunctionTopLevelPro
final PsiElement anchor = ContainerUtil.getFirstItem(reads);
//noinspection ConstantConditions
if (!PyRefactoringUtil.isValidNewName(name, anchor)) {
final String indexedName = PyRefactoringUtil.appendNumberUntilValid(name, anchor);
final String indexedName = PyRefactoringUtil.appendNumberUntilValid(name, anchor, PyRefactoringUtil::isValidNewName);
myAttributeToParameterName.put(name, indexedName);
}
else {

View File

@@ -0,0 +1,19 @@
def foo(x, y, z):
print(x)
print(y)
print(z)
return x
def id(x):
return x
def bar():
x = id(1)
y = id(2)
z = id(3)
print(x)
print(y)
print(z)
res = x

View File

@@ -0,0 +1,13 @@
def foo(x, y, z):
print(x)
print(y)
print(z)
return x
def id(x):
return x
def bar():
res = fo<caret>o(id(1), id(2), z = id(3))

View File

@@ -0,0 +1,6 @@
async def foo():
pass
async def bar():
await f<caret>oo()

View File

@@ -0,0 +1 @@
inp<caret>ut()

View File

@@ -0,0 +1,16 @@
class A:
def __init__(self, a, b):
self.a = a
self.b = b
def doStuff(self):
print(self.a)
print(self.b)
return self.a + self.b
@classmethod
def foo(cls):
my_a = A(1, 2)
print(my_a.a)
print(my_a.b)
res = my_a.a + my_a.b

View File

@@ -0,0 +1,14 @@
class A:
def __init__(self, a, b):
self.a = a
self.b = b
def doStuff(self):
print(self.a)
print(self.b)
return self.a + self.b
@classmethod
def foo(cls):
my_a = A(1, 2)
res = my_a.doSt<caret>uff()

View File

@@ -0,0 +1,16 @@
class A:
def __init__(self, a, b):
self.a = a
self.b = b
def doStuff(self):
print(self.a)
print(self.b)
return self.a + self.b
@staticmethod
def foo():
my_a = A(1, 2)
print(my_a.a)
print(my_a.b)
res = my_a.a + my_a.b

View File

@@ -0,0 +1,14 @@
class A:
def __init__(self, a, b):
self.a = a
self.b = b
def doStuff(self):
print(self.a)
print(self.b)
return self.a + self.b
@staticmethod
def foo():
my_a = A(1, 2)
res = my_a.doSt<caret>uff()

View File

@@ -0,0 +1,15 @@
class MyClass:
def __init__(self, attr):
self.attr = attr
def __add__(self, other):
return MyClass(self.attr + other.attr)
def method(self):
print(self.attr)
print(self.attr)
my_class = (MyClass(1) + MyClass(2))
print(my_class.attr)
print(my_class.attr)

View File

@@ -0,0 +1,13 @@
class MyClass:
def __init__(self, attr):
self.attr = attr
def __add__(self, other):
return MyClass(self.attr + other.attr)
def method(self):
print(self.attr)
print(self.attr)
(MyClass(1) + MyClass(2)).meth<caret>od()

View File

@@ -0,0 +1,6 @@
class MyClass():
def __init__(self):
pass
MyCla<caret>ss()

View File

@@ -0,0 +1,10 @@
def foo(x):
return x
@foo
def bar():
pass
ba<caret>r()

View File

@@ -0,0 +1,8 @@
def foo(arg1=1, arg2=2, arg3=3, arg4=4):
res = arg1 + arg2 + arg3 + arg4
print(res)
def bar():
res = 10 + 20 + 3 + 4
print(res)

View File

@@ -0,0 +1,7 @@
def foo(arg1=1, arg2=2, arg3=3, arg4=4):
res = arg1 + arg2 + arg3 + arg4
print(res)
def bar():
f<caret>oo(10, 20)

View File

@@ -0,0 +1,10 @@
from source import do_more_work
from source import do_work
from source import do_substantially_more_work as do_nothing
def bar():
do_work()
do_more_work(42)
do_nothing()
res = 42

View File

@@ -0,0 +1,7 @@
from source import foo
from source import do_work
from source import do_substantially_more_work as do_nothing
def bar():
res = fo<caret>o(42)

View File

@@ -0,0 +1,19 @@
def do_work():
pass
def do_more_work(x):
pass
def do_substantially_more_work():
pass
def foo(arg):
do_work()
do_more_work(arg)
do_substantially_more_work()
return arg

View File

@@ -0,0 +1,8 @@
import source
def bar():
source.do_work()
source.do_more_work(42)
source.do_substantially_more_work()
res = 42

View File

@@ -0,0 +1,5 @@
from source import foo
def bar():
res = fo<caret>o(42)

View File

@@ -0,0 +1,19 @@
def do_work():
pass
def do_more_work(x):
pass
def do_substantially_more_work():
pass
def foo(arg):
do_work()
do_more_work(arg)
do_substantially_more_work()
return arg

View File

@@ -0,0 +1,5 @@
def foo():
yield 1
f<caret>oo()

View File

@@ -0,0 +1,8 @@
from source import do_work, do_more_work, do_substantially_more_work
def bar():
do_work()
do_more_work(42)
do_substantially_more_work()
res = 42

View File

@@ -0,0 +1,5 @@
from source import foo
def bar():
res = fo<caret>o(42)

View File

@@ -0,0 +1,19 @@
def do_work():
pass
def do_more_work(x):
pass
def do_substantially_more_work():
pass
def foo(arg):
do_work()
do_more_work(arg)
do_substantially_more_work()
return arg

View File

@@ -0,0 +1,36 @@
def foo(arg):
local = 1
if arg:
another = 2
else:
another = 3
return local
def bar():
x = 1
local = 1
if x:
another = 2
else:
another = 3
res = local
def baz():
y = 2
local = 1
if y:
another = 2
else:
another = 3
res = local
z = 1
local = 1
if z:
another = 2
else:
another = 3
res = local

View File

@@ -0,0 +1,21 @@
def fo<caret>o(arg):
local = 1
if arg:
another = 2
else:
another = 3
return local
def bar():
x = 1
res = foo(x)
def baz():
y = 2
res = foo(y)
z = 1
res = foo(z)

View File

@@ -0,0 +1,5 @@
print(1)
print(2)
z = 1 + 2
print(z)
res = z

View File

@@ -0,0 +1,3 @@
from src import foo as bar
res = ba<caret>r(1, 2)

View File

@@ -0,0 +1,6 @@
def foo(x, y):
print(x)
print(y)
z = x + y
print(z)
return z

View File

@@ -0,0 +1,8 @@
def foo(arg):
if not arg:
return None
print("working")
return arg
f<caret>oo(1)

View File

@@ -0,0 +1,8 @@
def foo(*, a, b):
print(a)
print(b)
def bar():
print(1)
print(2)

View File

@@ -0,0 +1,7 @@
def foo(*, a, b):
print(a)
print(b)
def bar():
fo<caret>o(a = 1, b = 2)

View File

@@ -0,0 +1,27 @@
class MyClass:
def do_stuff(self, x, y):
print(x)
print(y)
self.do_something_else()
if x:
print(x)
elif y:
print(y)
else:
print("nothing")
result = x, y
return result
def for_inline(self, a, b):
self.do_something_else()
if a:
print(a)
elif b:
print(b)
else:
print("nothing")
return a, b
def do_something_else(self):
pass

View File

@@ -0,0 +1,19 @@
class MyClass:
def do_stuff(self, x, y):
print(x)
print(y)
return self.for_in<caret>line(x, y)
def for_inline(self, a, b):
self.do_something_else()
if a:
print(a)
elif b:
print(b)
else:
print("nothing")
return a, b
def do_something_else(self):
pass

View File

@@ -0,0 +1,33 @@
class MyClass:
def do_stuff(self, x, y):
print(x)
print(y)
return self.for_inline(x, y)
def for_inline(self, a, b):
self.do_something_else()
if a:
print(a)
elif b:
print(b)
else:
print("nothing")
return a, b
def do_something_else(self):
pass
x = 1
y = 2
cls = MyClass()
cls.do_something_else()
if x:
print(x)
elif y:
print(y)
else:
print("nothing")
result = x, y
res = result

View File

@@ -0,0 +1,25 @@
class MyClass:
def do_stuff(self, x, y):
print(x)
print(y)
return self.for_inline(x, y)
def for_inline(self, a, b):
self.do_something_else()
if a:
print(a)
elif b:
print(b)
else:
print("nothing")
return a, b
def do_something_else(self):
pass
x = 1
y = 2
cls = MyClass()
res = cls.for_in<caret>line(x, y)

View File

@@ -0,0 +1,26 @@
def foo(x, y, z):
if x:
return x + 2
elif y:
return y + 2
else:
if z:
return z + 2
else:
return 2
def bar():
a = 1
b = 2
c = 3
if a:
result = a + 2
elif b:
result = b + 2
else:
if c:
result = c + 2
else:
result = 2
res = result

View File

@@ -0,0 +1,17 @@
def foo(x, y, z):
if x:
return x + 2
elif y:
return y + 2
else:
if z:
return z + 2
else:
return 2
def bar():
a = 1
b = 2
c = 3
res = fo<caret>o(a, b, c)

View File

@@ -0,0 +1,19 @@
def foo(arg):
local = 1
if arg:
another = 2
else:
another = 3
return local
def bar():
x = 1
local = 2
another = 3
local1 = 1
if x:
another1 = 2
else:
another1 = 3
res = local1

View File

@@ -0,0 +1,14 @@
def foo(arg):
local = 1
if arg:
another = 2
else:
another = 3
return local
def bar():
x = 1
local = 2
another = 3
res = fo<caret>o(x)

View File

@@ -0,0 +1,8 @@
def foo():
def bar():
pass
bar()
fo<caret>o()

View File

@@ -0,0 +1,3 @@
print("fun called")
print("fun called")
res = 1

View File

@@ -0,0 +1,6 @@
def fun(x):
print("fun called")
return x
res = fu<caret>n(fun(1))

View File

@@ -0,0 +1,21 @@
def foo(x, y):
s = x + y
if s > 10:
print("s>10")
elif s > 5:
print("s>5")
else:
print("less")
print("over")
def bar():
s = 1 + 2
if s > 10:
print("s>10")
elif s > 5:
print("s>5")
else:
print("less")
print("over")
res = None

View File

@@ -0,0 +1,13 @@
def foo(x, y):
s = x + y
if s > 10:
print("s>10")
elif s > 5:
print("s>5")
else:
print("less")
print("over")
def bar():
res = f<caret>oo(1, 2)

View File

@@ -0,0 +1,20 @@
def foo(x, y):
s = x + y
if s > 10:
print("s>10")
elif s > 5:
print("s>5")
else:
print("less")
print("over")
def bar():
s = 1 + 2
if s > 10:
print("s>10")
elif s > 5:
print("s>5")
else:
print("less")
print("over")

View File

@@ -0,0 +1,13 @@
def foo(x, y):
s = x + y
if s > 10:
print("s>10")
elif s > 5:
print("s>5")
else:
print("less")
print("over")
def bar():
f<caret>oo(1, 2)

View File

@@ -0,0 +1,12 @@
class A():
def method(self):
print(1)
class B(A):
def method(self):
print(2)
b = B()
b.meth<caret>od()

View File

@@ -0,0 +1,8 @@
def foo(a, b, /):
print(a)
print(b)
def bar():
print(1)
print(2)

View File

@@ -0,0 +1,7 @@
def foo(a, b, /):
print(a)
print(b)
def bar():
fo<caret>o(1, 2)

View File

@@ -0,0 +1,5 @@
def foo():
foo()
fo<caret>o()

View File

@@ -0,0 +1,3 @@
print(1)
print(2)
1 + 2

View File

@@ -0,0 +1,7 @@
def foo(x, y):
print(x)
print(y)
return x + y
f<caret>oo(1, 2)

View File

@@ -0,0 +1,2 @@
def foo(x: int, y: int) -> int:
...

View File

@@ -0,0 +1,8 @@
def bar():
x = 1
local = 1
if x:
another = 2
else:
another = 3
res = local

View File

@@ -0,0 +1,12 @@
def foo(arg):
local = 1
if arg:
another = 2
else:
another = 3
return local
def bar():
x = 1
res = fo<caret>o(x)

View File

@@ -0,0 +1,7 @@
def foo(arg):
local = arg + arg
return local
local = 42 + 42
x = local

View File

@@ -0,0 +1,6 @@
def foo(arg):
local = arg + arg
return local
x = fo<caret>o(42)

View File

@@ -0,0 +1,4 @@
def foo(*args, **kwargs):
pass
fo<caret>o(1, 2, 3, x = 4)

View File

@@ -0,0 +1,7 @@
def foo(arg):
return arg
@fo<caret>o
def bar():
pass

View File

@@ -0,0 +1,9 @@
def call_with_one(fun):
fun(1)
def foo(x):
print(x)
call_with_one(fo<caret>o)

View File

@@ -0,0 +1,6 @@
def foo(a, b, c, d):
print(a, b, c, d)
arg = (1, 2, 3, 4)
fo<caret>o(*arg)

View File

@@ -0,0 +1,96 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.refactoring
import com.intellij.codeInsight.TargetElementUtil
import com.intellij.refactoring.util.CommonRefactoringUtil
import com.jetbrains.python.codeInsight.PyCodeInsightSettings
import com.jetbrains.python.fixtures.PyTestCase
import com.jetbrains.python.psi.LanguageLevel
import com.jetbrains.python.psi.PyElement
import com.jetbrains.python.psi.PyFunction
import com.jetbrains.python.pyi.PyiFile
import com.jetbrains.python.pyi.PyiUtil
import com.jetbrains.python.refactoring.inline.PyInlineFunctionHandler
import com.jetbrains.python.refactoring.inline.PyInlineFunctionProcessor
import junit.framework.TestCase
/**
* @author Aleksei.Kniazev
*/
class PyInlineFunctionTest : PyTestCase() {
override fun getTestDataPath(): String = super.getTestDataPath() + "/refactoring/inlineFunction"
private fun doTest(inlineThis: Boolean = true, remove: Boolean = false) {
val testName = getTestName(true)
myFixture.copyDirectoryToProject(testName, "")
myFixture.configureByFile("main.py")
var element = TargetElementUtil.findTargetElement(myFixture.editor, TargetElementUtil.getInstance().referenceSearchFlags)
if (element!!.containingFile is PyiFile) element = PyiUtil.getOriginalElement(element as PyElement)
val reference = TargetElementUtil.findReference(myFixture.editor)
TestCase.assertTrue(element is PyFunction)
PyInlineFunctionProcessor(myFixture.project, myFixture. editor, element as PyFunction, reference, inlineThis, remove).run()
myFixture.checkResultByFile("$testName/main.after.py")
}
private fun doTestError(expectedError: String, isReferenceError: Boolean = false) {
myFixture.configureByFile("${getTestName(true)}.py")
val element = TargetElementUtil.findTargetElement(myFixture.editor, TargetElementUtil.getInstance().referenceSearchFlags)
try {
if (isReferenceError) {
val reference = TargetElementUtil.findReference(myFixture.editor)
PyInlineFunctionProcessor(myFixture.project, myFixture. editor, element as PyFunction, reference, myInlineThis = true, removeDeclaration = false).run()
}
else {
PyInlineFunctionHandler.getInstance().inlineElement(myFixture.project, myFixture.editor, element)
}
TestCase.fail("Expected error: $expectedError, but got none")
}
catch (e: CommonRefactoringUtil.RefactoringErrorHintException) {
TestCase.assertEquals(expectedError, e.message)
}
}
fun testSimple() = doTest()
fun testNameClash() = doTest()
fun testArgumentExtraction() = doTest()
fun testMultipleReturns() = doTest()
fun testImporting() = doTest()
//fun testExistingImports() = doTest()
fun testMethodInsideClass() = doTest()
fun testMethodOutsideClass() = doTest()
fun testNoReturnsAsExpressionStatement() = doTest()
fun testNoReturnsAsCallExpression() = doTest()
fun testInlineAll() = doTest(inlineThis = false)
fun testRemoving() = doTest(inlineThis = false, remove = true)
fun testDefaultValues() = doTest()
fun testPositionalOnlyArgs() = doTest()
fun testKeywordOnlyArgs() = doTest()
fun testNestedCalls() = doTest(inlineThis = false, remove = true)
fun testCallFromStaticMethod() = doTest()
fun testCallFromClassMethod() = doTest()
fun testComplexQualifier() = doTest()
//fun testInlineImportedAs() = doTest(inlineThis = false)
fun testRemoveFunctionWithStub() {
doTest(inlineThis = false, remove = true)
val testName = getTestName(true)
myFixture.checkResultByFile("main.pyi", "$testName/main.after.pyi",true)
}
fun testGenerator() = doTestError("Cannot inline generators")
fun testAsyncFunction() {
runWithLanguageLevel(LanguageLevel.PYTHON37) {
doTestError("Cannot inline async functions")
}
}
fun testConstructor() = doTestError("Cannot inline constructor calls")
fun testBuiltin() = doTestError("Cannot inline builtin functions")
fun testDecorator() = doTestError("Cannot inline functions with decorators")
fun testRecursive() = doTestError("Cannot inline functions that reference themselves")
fun testStar() = doTestError("Cannot inline functions with * arguments")
fun testOverridden() = doTestError("Cannot inline overridden functions")
fun testNested() = doTestError("Cannot inline functions with another function declaration")
fun testInterruptedFlow() = doTestError("Cannot inline functions that interrupt control flow")
fun testUsedAsDecorator() = doTestError("Function foo is used as a decorator and cannot be inlined. Function definition will not be removed", isReferenceError = true)
fun testUsedAsReference() = doTestError("Function foo is used as a reference and cannot be inlined. Function definition will not be removed", isReferenceError = true)
fun testUsesArgumentUnpacking() = doTestError("Function foo uses argument unpacking and cannot be inlined. Function definition will not be removed", isReferenceError = true)
}