mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-04-30 10:20:15 +07:00
IDEA-CR-48018: inline function refactoring for python (PY-21287)
- ability to inline single/all invocations and keep/remove the declaration - handled name conflicts of local and imported vars/functions - cases when this refactoring is not applicable can be found in PyInlineFunctionHandler (cherry picked from commit 40e298ba2b833a4408ee774628b84fe422468b91) GitOrigin-RevId: dd3fab59800163b4f300b410cd28a62c1ee1b6f3
This commit is contained in:
committed by
intellij-monorepo-bot
parent
ce764be06a
commit
c6e4181228
@@ -445,6 +445,7 @@
|
||||
<lang.refactoringSupport.classMembersRefactoringSupport language="Python"
|
||||
implementationClass="com.jetbrains.python.refactoring.classes.PyMembersRefactoringSupport"/>
|
||||
<inlineActionHandler implementation="com.jetbrains.python.refactoring.inline.PyInlineLocalHandler"/>
|
||||
<inlineActionHandler implementation="com.jetbrains.python.refactoring.inline.PyInlineFunctionHandler"/>
|
||||
<codeInsight.gotoSuper language="Python" implementationClass="com.jetbrains.python.codeInsight.PyGotoSuperHandler"/>
|
||||
<gotoDeclarationHandler implementation="com.jetbrains.python.codeInsight.PyBreakContinueGotoProvider" order="FIRST"/>
|
||||
<gotoDeclarationHandler implementation="com.jetbrains.python.psi.impl.PyGotoDeclarationHandler"/>
|
||||
|
||||
@@ -645,6 +645,27 @@ refactoring.push.down.error.cannot.perform.refactoring.not.inside.class=Cannot p
|
||||
# inline
|
||||
refactoring.inline.local.multiassignment=Definition is in multi-assign
|
||||
|
||||
# inline function
|
||||
refactoring.inline.function.title=Inline Function
|
||||
refactoring.inline.this.only=Inline this invocation only and keep the declaration
|
||||
refactoring.inline.all.keep.declaration=Inline all invocations and keep the declaration
|
||||
refactoring.inline.all.remove.declaration=Inline all invocations and remove the declaration
|
||||
refactoring.inline.function.is.decorator=Function {0} is used as a decorator and cannot be inlined. Function definition will not be removed
|
||||
refactoring.inline.function.is.reference=Function {0} is used as a reference and cannot be inlined. Function definition will not be removed
|
||||
refactoring.inline.function.uses.unpacking=Function {0} uses argument unpacking and cannot be inlined. Function definition will not be removed
|
||||
refactoring.inline.function.generator=Cannot inline generators
|
||||
refactoring.inline.function.async=Cannot inline async functions
|
||||
refactoring.inline.function.constructor=Cannot inline constructor calls
|
||||
refactoring.inline.function.builtin=Cannot inline builtin functions
|
||||
refactoring.inline.function.decorator=Cannot inline functions with decorators
|
||||
refactoring.inline.function.self.referrent=Cannot inline functions that reference themselves
|
||||
refactoring.inline.function.star=Cannot inline functions with * arguments
|
||||
refactoring.inline.function.overridden=Cannot inline overridden functions
|
||||
refactoring.inline.function.global=Cannot inline functions with global variables
|
||||
refactoring.inline.function.nonlocal=Cannot inline functions with nonlocals variables
|
||||
refactoring.inline.function.nested=Cannot inline functions with another function declaration
|
||||
refactoring.inline.function.interrupts.flow=Cannot inline functions that interrupt control flow
|
||||
|
||||
# extract method
|
||||
refactoring.extract.method=Extract method
|
||||
refactoring.extract.method.error.interrupted.execution.flow=Cannot perform refactoring when execution flow is interrupted
|
||||
|
||||
@@ -36,10 +36,13 @@ public interface Scope {
|
||||
@Nullable
|
||||
ScopeVariable getDeclaredVariable(@NotNull PsiElement anchorElement, @NotNull String name) throws DFALimitExceededException;
|
||||
|
||||
boolean hasGlobals();
|
||||
boolean isGlobal(String name);
|
||||
|
||||
boolean hasNonLocals();
|
||||
boolean isNonlocal(String name);
|
||||
|
||||
boolean hasNestedScopes();
|
||||
boolean containsDeclaration(String name);
|
||||
|
||||
@NotNull
|
||||
|
||||
@@ -74,6 +74,15 @@ public class ScopeImpl implements Scope {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasGlobals() {
|
||||
if (myGlobals == null || myNestedScopes == null) {
|
||||
collectDeclarations();
|
||||
}
|
||||
if (!myGlobals.isEmpty()) return true;
|
||||
return myNestedScopes.stream().anyMatch(scope -> scope.hasGlobals());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isGlobal(final String name) {
|
||||
if (myGlobals == null || myNestedScopes == null) {
|
||||
@@ -90,6 +99,15 @@ public class ScopeImpl implements Scope {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNonLocals() {
|
||||
if (myNonlocals == null || myNestedScopes == null) {
|
||||
collectDeclarations();
|
||||
}
|
||||
if (!myNonlocals.isEmpty()) return true;
|
||||
return myNestedScopes.stream().anyMatch(scope -> scope.hasNonLocals());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean isNonlocal(final String name) {
|
||||
if (myNonlocals == null || myNestedScopes == null) {
|
||||
@@ -98,6 +116,14 @@ public class ScopeImpl implements Scope {
|
||||
return myNonlocals.contains(name);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNestedScopes() {
|
||||
if (myNestedScopes == null) {
|
||||
collectDeclarations();
|
||||
}
|
||||
return !myNestedScopes.isEmpty();
|
||||
}
|
||||
|
||||
private boolean isAugAssignment(final String name) {
|
||||
if (myAugAssignments == null || myNestedScopes == null) {
|
||||
collectDeclarations();
|
||||
|
||||
@@ -20,6 +20,7 @@ import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.*;
|
||||
import java.util.function.BiPredicate;
|
||||
|
||||
public class PyRefactoringUtil {
|
||||
private PyRefactoringUtil() {
|
||||
@@ -326,7 +327,7 @@ public class PyRefactoringUtil {
|
||||
*/
|
||||
@NotNull
|
||||
public static String selectUniqueNameFromType(@NotNull String typeName, @NotNull PsiElement scopeAnchor) {
|
||||
return selectUniqueName(typeName, true, scopeAnchor);
|
||||
return selectUniqueName(typeName, true, scopeAnchor, PyRefactoringUtil::isValidNewName);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -339,11 +340,16 @@ public class PyRefactoringUtil {
|
||||
*/
|
||||
@NotNull
|
||||
public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor) {
|
||||
return selectUniqueName(templateName, false, scopeAnchor);
|
||||
return selectUniqueName(templateName, false, scopeAnchor, PyRefactoringUtil::isValidNewName);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static String selectUniqueName(@NotNull String templateName, boolean templateIsType, @NotNull PsiElement scopeAnchor) {
|
||||
public static String selectUniqueName(@NotNull String templateName, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
|
||||
return selectUniqueName(templateName, false, scopeAnchor, isValid);
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static String selectUniqueName(@NotNull String templateName, boolean templateIsType, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> isValid) {
|
||||
final Collection<String> suggestions;
|
||||
if (templateIsType) {
|
||||
suggestions = NameSuggesterUtil.generateNamesByType(templateName);
|
||||
@@ -352,14 +358,14 @@ public class PyRefactoringUtil {
|
||||
suggestions = NameSuggesterUtil.generateNames(templateName);
|
||||
}
|
||||
for (String name : suggestions) {
|
||||
if (isValidNewName(name, scopeAnchor)) {
|
||||
if (isValid.test(name, scopeAnchor)) {
|
||||
return name;
|
||||
}
|
||||
}
|
||||
|
||||
final String shortestName = ContainerUtil.getFirstItem(suggestions);
|
||||
//noinspection ConstantConditions
|
||||
return appendNumberUntilValid(shortestName, scopeAnchor);
|
||||
return appendNumberUntilValid(shortestName, scopeAnchor, isValid);
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -367,13 +373,14 @@ public class PyRefactoringUtil {
|
||||
*
|
||||
* @param name initial name
|
||||
* @param scopeAnchor PSI element used to determine correct scope
|
||||
* @param predicate used to test if suggested name is valid
|
||||
* @return unique name in the scope probably with number suffix appended
|
||||
*/
|
||||
@NotNull
|
||||
public static String appendNumberUntilValid(@NotNull String name, @NotNull PsiElement scopeAnchor) {
|
||||
public static String appendNumberUntilValid(@NotNull String name, @NotNull PsiElement scopeAnchor, @NotNull BiPredicate<String, PsiElement> predicate) {
|
||||
int counter = 1;
|
||||
String candidate = name;
|
||||
while (!isValidNewName(candidate, scopeAnchor)) {
|
||||
while (!predicate.test(candidate, scopeAnchor)) {
|
||||
candidate = name + counter;
|
||||
counter++;
|
||||
}
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.refactoring.inline
|
||||
|
||||
import com.intellij.openapi.editor.Editor
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.PsiReference
|
||||
import com.intellij.psi.util.PsiTreeUtil
|
||||
import com.intellij.refactoring.inline.InlineOptionsDialog
|
||||
import com.jetbrains.python.PyBundle
|
||||
import com.jetbrains.python.psi.PyFunction
|
||||
import com.jetbrains.python.psi.PyImportStatementBase
|
||||
|
||||
/**
|
||||
* @author Aleksei.Kniazev
|
||||
*/
|
||||
class PyInlineFunctionDialog(project: Project,
|
||||
private val myEditor: Editor,
|
||||
private val myFunction: PyFunction,
|
||||
private val myReference: PsiReference?) : InlineOptionsDialog(project, true, myFunction) {
|
||||
private val isMethod = myFunction.asMethod() != null
|
||||
|
||||
init {
|
||||
myInvokedOnReference = myReference != null
|
||||
title = if (isMethod) "Inline method ${myFunction.name}" else "Inline function ${myFunction.name}"
|
||||
init()
|
||||
}
|
||||
|
||||
override fun doAction() {
|
||||
invokeRefactoring(PyInlineFunctionProcessor(myProject, myEditor, myFunction, myReference, isInlineThisOnly, !isKeepTheDeclaration))
|
||||
}
|
||||
|
||||
override fun getNameLabelText(): String = "The number of occurrences: ${getNumberOfOccurrences(myFunction)}"
|
||||
override fun getBorderTitle(): String = "Inline"
|
||||
override fun getInlineAllText(): String = PyBundle.message("refactoring.inline.all.remove.declaration")
|
||||
override fun getKeepTheDeclarationText(): String = PyBundle.message("refactoring.inline.all.keep.declaration")
|
||||
override fun getInlineThisText(): String = PyBundle.message("refactoring.inline.this.only")
|
||||
override fun getHelpId(): String = PyInlineFunctionHandler.REFACTORING_ID
|
||||
|
||||
override fun allowInlineAll(): Boolean = true
|
||||
override fun isInlineThis(): Boolean = true
|
||||
|
||||
override fun ignoreOccurrence(reference: PsiReference): Boolean {
|
||||
return PsiTreeUtil.getParentOfType(reference.element, PyImportStatementBase::class.java) == null
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,127 @@
|
||||
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.refactoring.inline
|
||||
|
||||
import com.intellij.codeInsight.TargetElementUtil
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.lang.refactoring.InlineActionHandler
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.editor.Editor
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.PsiElement
|
||||
import com.intellij.psi.SyntaxTraverser
|
||||
import com.intellij.psi.util.PsiTreeUtil
|
||||
import com.intellij.refactoring.util.CommonRefactoringUtil
|
||||
import com.jetbrains.python.PyBundle
|
||||
import com.jetbrains.python.PyNames
|
||||
import com.jetbrains.python.PythonLanguage
|
||||
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache
|
||||
import com.jetbrains.python.psi.*
|
||||
import com.jetbrains.python.psi.impl.PyBuiltinCache
|
||||
import com.jetbrains.python.psi.search.PySuperMethodsSearch
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext
|
||||
import com.jetbrains.python.pyi.PyiFile
|
||||
|
||||
/**
|
||||
* @author Aleksei.Kniazev
|
||||
*/
|
||||
class PyInlineFunctionHandler : InlineActionHandler() {
|
||||
override fun isEnabledForLanguage(l: Language?) = l is PythonLanguage
|
||||
|
||||
override fun canInlineElement(element: PsiElement?) = element is PyFunction && element.containingFile !is PyiFile
|
||||
|
||||
override fun inlineElement(project: Project?, editor: Editor?, element: PsiElement?) {
|
||||
if (project == null || editor == null || element !is PyFunction) return
|
||||
val functionScope = ControlFlowCache.getScope(element)
|
||||
val error = when {
|
||||
element.isAsync -> "refactoring.inline.function.async"
|
||||
element.isGenerator -> "refactoring.inline.function.generator"
|
||||
PyNames.INIT == element.name -> "refactoring.inline.function.constructor"
|
||||
PyBuiltinCache.getInstance(element).isBuiltin(element) -> "refactoring.inline.function.builtin"
|
||||
hasDecorators(element) -> "refactoring.inline.function.decorator"
|
||||
hasReferencesToSelf(element) -> "refactoring.inline.function.self.referrent"
|
||||
hasStarArgs(element) -> "refactoring.inline.function.star"
|
||||
isOverride(element, project) -> "refactoring.inline.function.overridden"
|
||||
functionScope.hasGlobals() -> "refactoring.inline.function.global"
|
||||
functionScope.hasNonLocals() -> "refactoring.inline.function.nonlocal"
|
||||
functionScope.hasNestedScopes() -> "refactoring.inline.function.nested"
|
||||
hasNonExhaustiveIfs(element) -> "refactoring.inline.function.interrupts.flow"
|
||||
else -> null
|
||||
}
|
||||
if (error != null) {
|
||||
CommonRefactoringUtil.showErrorHint(project, editor, PyBundle.message(error), PyBundle.message("refactoring.inline.function.title"), REFACTORING_ID)
|
||||
return
|
||||
}
|
||||
if (!ApplicationManager.getApplication().isUnitTestMode){
|
||||
PyInlineFunctionDialog(project, editor, element, TargetElementUtil.findReference(editor)).show()
|
||||
}
|
||||
}
|
||||
|
||||
private fun hasNonExhaustiveIfs(function: PyFunction): Boolean {
|
||||
val returns = mutableListOf<PyReturnStatement>()
|
||||
|
||||
function.accept(object : PyRecursiveElementVisitor() {
|
||||
override fun visitPyReturnStatement(node: PyReturnStatement) {
|
||||
returns.add(node)
|
||||
}
|
||||
})
|
||||
|
||||
if (returns.isEmpty()) return false
|
||||
val cache = mutableSetOf<PyIfStatement>()
|
||||
return returns.asSequence()
|
||||
.map { PsiTreeUtil.getParentOfType(it, PyIfStatement::class.java) }
|
||||
.distinct()
|
||||
.filterNotNull()
|
||||
.any { checkInterruptsControlFlow(it, cache) }
|
||||
}
|
||||
|
||||
private fun checkInterruptsControlFlow(ifStatement: PyIfStatement, cache: MutableSet<PyIfStatement>): Boolean {
|
||||
if (ifStatement in cache) return false
|
||||
cache.add(ifStatement)
|
||||
val elsePart = ifStatement.elsePart
|
||||
if (elsePart == null) return true
|
||||
|
||||
if (checkLastStatement(ifStatement.ifPart.statementList, cache)) return true
|
||||
if (checkLastStatement(elsePart.statementList, cache)) return true
|
||||
ifStatement.elifParts.forEach { if (checkLastStatement(it.statementList, cache)) return true }
|
||||
|
||||
val parentIfStatement = PsiTreeUtil.getParentOfType(ifStatement, PyIfStatement::class.java)
|
||||
if (parentIfStatement != null && checkInterruptsControlFlow(parentIfStatement, cache)) return true
|
||||
return false
|
||||
}
|
||||
|
||||
private fun checkLastStatement(statementList: PyStatementList, cache: MutableSet<PyIfStatement>): Boolean {
|
||||
val statements = statementList.statements
|
||||
if (statements.isEmpty()) return true
|
||||
when(val last = statements.last()) {
|
||||
is PyIfStatement -> if (checkInterruptsControlFlow(last, cache)) return true
|
||||
!is PyReturnStatement -> return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
private fun hasDecorators(function: PyFunction): Boolean = function.decoratorList?.decorators?.isNotEmpty() == true
|
||||
|
||||
private fun isOverride(function: PyFunction, project: Project): Boolean {
|
||||
return function.containingClass != null
|
||||
&& PySuperMethodsSearch.search(function, TypeEvalContext.codeAnalysis(project, function.containingFile)).any()
|
||||
}
|
||||
|
||||
private fun hasStarArgs(function: PyFunction): Boolean {
|
||||
return function.parameterList.parameters.asSequence()
|
||||
.filterIsInstance<PyNamedParameter>()
|
||||
.any { it.isPositionalContainer || it.isKeywordContainer }
|
||||
}
|
||||
|
||||
private fun hasReferencesToSelf(function: PyFunction): Boolean = SyntaxTraverser.psiTraverser(function.statementList)
|
||||
.any { it is PyReferenceExpression && it.reference.isReferenceTo(function) }
|
||||
|
||||
companion object {
|
||||
@JvmStatic
|
||||
fun getInstance(): PyInlineFunctionHandler {
|
||||
return InlineActionHandler.EP_NAME.findExtensionOrFail(PyInlineFunctionHandler::class.java)
|
||||
}
|
||||
|
||||
@JvmStatic
|
||||
val REFACTORING_ID = "refactoring.inlineMethod"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,311 @@
|
||||
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.refactoring.inline
|
||||
|
||||
import com.intellij.history.LocalHistory
|
||||
import com.intellij.openapi.editor.Editor
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.util.Ref
|
||||
import com.intellij.psi.PsiElement
|
||||
import com.intellij.psi.PsiReference
|
||||
import com.intellij.psi.SyntaxTraverser
|
||||
import com.intellij.psi.codeStyle.CodeStyleManager
|
||||
import com.intellij.psi.search.searches.ReferencesSearch
|
||||
import com.intellij.psi.util.PsiTreeUtil
|
||||
import com.intellij.refactoring.BaseRefactoringProcessor
|
||||
import com.intellij.refactoring.util.CommonRefactoringUtil
|
||||
import com.intellij.usageView.UsageInfo
|
||||
import com.intellij.usageView.UsageViewDescriptor
|
||||
import com.intellij.util.containers.MultiMap
|
||||
import com.jetbrains.python.PyBundle
|
||||
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache
|
||||
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil
|
||||
import com.jetbrains.python.codeInsight.imports.PyImportOptimizer
|
||||
import com.jetbrains.python.psi.*
|
||||
import com.jetbrains.python.psi.impl.PyBuiltinCache
|
||||
import com.jetbrains.python.psi.impl.PyCallExpressionHelper
|
||||
import com.jetbrains.python.psi.impl.PyPsiUtils
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext
|
||||
import com.jetbrains.python.psi.resolve.PyResolveUtil
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext
|
||||
import com.jetbrains.python.pyi.PyiUtil
|
||||
import com.jetbrains.python.refactoring.PyRefactoringUtil
|
||||
import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil
|
||||
import org.jetbrains.annotations.PropertyKey
|
||||
|
||||
/**
|
||||
* @author Aleksei.Kniazev
|
||||
*/
|
||||
class PyInlineFunctionProcessor(project: Project,
|
||||
private val myEditor: Editor,
|
||||
private val myFunction: PyFunction,
|
||||
private val myReference: PsiReference?,
|
||||
private val myInlineThis: Boolean,
|
||||
removeDeclaration: Boolean) : BaseRefactoringProcessor(project) {
|
||||
|
||||
private val myFunctionClass = myFunction.containingClass
|
||||
private val myGenerator = PyElementGenerator.getInstance(myProject)
|
||||
private var myRemoveDeclaration = !myInlineThis && removeDeclaration
|
||||
|
||||
override fun preprocessUsages(refUsages: Ref<Array<UsageInfo>>): Boolean {
|
||||
if (refUsages.isNull) return false
|
||||
val conflicts = MultiMap.create<PsiElement, String>()
|
||||
val usagesAndImports = refUsages.get()
|
||||
val (imports, usages) = usagesAndImports.partition { PsiTreeUtil.getParentOfType(it.element, PyImportStatementBase::class.java) != null }
|
||||
val filteredUsages = usages.filter { usage ->
|
||||
val element = usage.element!!
|
||||
if (element.parent is PyDecorator) {
|
||||
if (!handleUsageError(element, "refactoring.inline.function.is.decorator", conflicts)) return false
|
||||
return@filter false
|
||||
}
|
||||
else if (element.parent !is PyCallExpression) {
|
||||
if (!handleUsageError(element, "refactoring.inline.function.is.reference", conflicts)) return false
|
||||
return@filter false
|
||||
}
|
||||
else {
|
||||
val callExpression = element.parent as PyCallExpression
|
||||
if (callExpression.arguments.any { it is PyStarArgument}) {
|
||||
if (!handleUsageError(element, "refactoring.inline.function.uses.unpacking", conflicts)) return false
|
||||
return@filter false
|
||||
}
|
||||
}
|
||||
return@filter true
|
||||
}
|
||||
|
||||
val conflictLocations = conflicts.keySet().map { it.containingFile }
|
||||
val filteredImports = imports.filter { it.file !in conflictLocations }
|
||||
val filtered = filteredUsages + filteredImports
|
||||
refUsages.set(filtered.toTypedArray())
|
||||
return showConflicts(conflicts, filtered.toTypedArray())
|
||||
}
|
||||
|
||||
private fun handleUsageError(element: PsiElement, @PropertyKey(resourceBundle = "com.jetbrains.python.PyBundle") error: String, conflicts: MultiMap<PsiElement, String>): Boolean {
|
||||
val errorText = PyBundle.message(error, myFunction.name)
|
||||
if (myInlineThis) {
|
||||
// shortcut for inlining single reference: show error hint instead of modal dialog
|
||||
CommonRefactoringUtil.showErrorHint(myProject, myEditor, errorText, PyBundle.message("refactoring.inline.function.title"), PyInlineFunctionHandler.REFACTORING_ID)
|
||||
prepareSuccessful()
|
||||
return false
|
||||
}
|
||||
conflicts.putValue(element, errorText)
|
||||
myRemoveDeclaration = false
|
||||
return true
|
||||
}
|
||||
|
||||
override fun findUsages(): Array<UsageInfo> {
|
||||
if (myInlineThis) {
|
||||
val element = myReference!!.element as PyReferenceExpression
|
||||
val import = PyResolveUtil.resolveLocally(ScopeUtil.getScopeOwner(element)!!, element.name!!).firstOrNull { it is PyImportElement }
|
||||
return if (import != null) arrayOf(UsageInfo(element), UsageInfo(import)) else arrayOf(UsageInfo(element))
|
||||
}
|
||||
|
||||
return ReferencesSearch.search(myFunction, myRefactoringScope).asSequence()
|
||||
.distinct()
|
||||
.map(PsiReference::getElement)
|
||||
.map(::UsageInfo)
|
||||
.toList()
|
||||
.toTypedArray()
|
||||
}
|
||||
|
||||
override fun performRefactoring(usages: Array<out UsageInfo>) {
|
||||
val action = LocalHistory.getInstance().startAction(commandName)
|
||||
try {
|
||||
doRefactor(usages)
|
||||
}
|
||||
finally {
|
||||
action.finish()
|
||||
}
|
||||
}
|
||||
|
||||
private fun doRefactor(usages: Array<out UsageInfo>) {
|
||||
val (unsortedRefs, imports) = usages.partition { PsiTreeUtil.getParentOfType(it.element, PyImportStatementBase::class.java) == null }
|
||||
|
||||
val references = unsortedRefs.sortedByDescending { usage ->
|
||||
SyntaxTraverser.psiApi().parents(usage.element).asSequence().filter { it is PyCallExpression }.count()
|
||||
}
|
||||
|
||||
val functionScope = ControlFlowCache.getScope(myFunction)
|
||||
PyClassRefactoringUtil.rememberNamedReferences(myFunction)
|
||||
|
||||
references.forEach { usage ->
|
||||
val reference = usage.element as PyReferenceExpression
|
||||
val languageLevel = LanguageLevel.forElement(reference)
|
||||
val refScopeOwner = ScopeUtil.getScopeOwner(reference) ?: error("Unable to find scope owner for ${reference.name}")
|
||||
val declarations = mutableListOf<PyAssignmentStatement>()
|
||||
val generatedNames = mutableSetOf<String>()
|
||||
|
||||
|
||||
val callSite = PsiTreeUtil.getParentOfType(reference, PyCallExpression::class.java) ?: error("Unable to find call expression for ${reference.name}")
|
||||
val containingStatement = PsiTreeUtil.getParentOfType(callSite, PyStatement::class.java) ?: error("Unable to find statement for ${reference.name}")
|
||||
|
||||
val replacementFunction = myFunction.statementList.copy() as PyStatementList
|
||||
val namesInOuterScope = PyRefactoringUtil.collectUsedNames(refScopeOwner)
|
||||
|
||||
val argumentReplacements = mutableMapOf<PyReferenceExpression, PyExpression>()
|
||||
val nameClashes = MultiMap.create<String, PyExpression>()
|
||||
val returnStatements = mutableListOf<PyReturnStatement>()
|
||||
|
||||
val mappedArguments = prepareArguments(callSite, declarations, generatedNames, reference, languageLevel)
|
||||
|
||||
val builtinCache = PyBuiltinCache.getInstance(reference)
|
||||
replacementFunction.accept(object : PyRecursiveElementVisitor() {
|
||||
override fun visitPyReferenceExpression(node: PyReferenceExpression) {
|
||||
if (node.qualifier == null) {
|
||||
val name = node.name
|
||||
if (name in mappedArguments) {
|
||||
argumentReplacements[node] = mappedArguments[name]!!
|
||||
}
|
||||
else if (name in namesInOuterScope && !builtinCache.isBuiltin(node.reference.resolve())) {
|
||||
nameClashes.putValue(name!!, node)
|
||||
}
|
||||
}
|
||||
super.visitPyReferenceExpression(node)
|
||||
}
|
||||
|
||||
override fun visitPyReturnStatement(node: PyReturnStatement) {
|
||||
returnStatements.add(node)
|
||||
super.visitPyReturnStatement(node)
|
||||
}
|
||||
|
||||
override fun visitPyTargetExpression(node: PyTargetExpression) {
|
||||
if (node.qualifier == null) {
|
||||
val name = node.name
|
||||
if (name in namesInOuterScope && name !in mappedArguments && functionScope.containsDeclaration(name)) {
|
||||
nameClashes.putValue(name!!, node)
|
||||
}
|
||||
}
|
||||
super.visitPyTargetExpression(node)
|
||||
}
|
||||
})
|
||||
|
||||
// Replacing
|
||||
argumentReplacements.forEach { (old, new) -> old.replace(new) }
|
||||
nameClashes.entrySet().forEach { (name, elements) ->
|
||||
val generated = generateUniqueAssignment(languageLevel, name, generatedNames, reference)
|
||||
elements.forEach {
|
||||
when (it) {
|
||||
is PyTargetExpression -> it.replace(generated.targets[0])
|
||||
is PyReferenceExpression -> it.replace(generated.assignedValue!!)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if (returnStatements.size == 1 && returnStatements[0].expression !is PyTupleExpression) {
|
||||
// replace single return with expression itself
|
||||
val statement = returnStatements[0]
|
||||
callSite.replace(statement.expression!!)
|
||||
statement.delete()
|
||||
}
|
||||
else if (returnStatements.isNotEmpty()) {
|
||||
val newReturn = generateUniqueAssignment(languageLevel, "result", generatedNames, reference)
|
||||
returnStatements.forEach {
|
||||
val copy = newReturn.copy() as PyAssignmentStatement
|
||||
copy.assignedValue!!.replace(it.expression!!)
|
||||
it.replace(copy)
|
||||
}
|
||||
callSite.replace(newReturn.assignedValue!!)
|
||||
}
|
||||
|
||||
CodeStyleManager.getInstance(myProject).reformat(replacementFunction, true)
|
||||
|
||||
declarations.forEach { containingStatement.parent.addBefore(it, containingStatement) }
|
||||
if (replacementFunction.firstChild != null) {
|
||||
val statements = replacementFunction.statements
|
||||
statements.asSequence()
|
||||
.map { containingStatement.parent.addBefore(it, containingStatement) }
|
||||
.forEach { PyClassRefactoringUtil.restoreNamedReferences(it) }
|
||||
}
|
||||
|
||||
if (returnStatements.isEmpty()) {
|
||||
if (callSite.parent is PyExpressionStatement) {
|
||||
containingStatement.delete()
|
||||
}
|
||||
else {
|
||||
callSite.replace(myGenerator.createExpressionFromText(languageLevel, "None"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
imports.forEach { PyClassRefactoringUtil.optimizeImports(it.element!!.containingFile!!) }
|
||||
|
||||
if (myRemoveDeclaration) {
|
||||
val stubFunction = PyiUtil.getPythonStub(myFunction)
|
||||
if (stubFunction != null && stubFunction.isWritable) {
|
||||
stubFunction.delete()
|
||||
}
|
||||
myFunction.delete()
|
||||
}
|
||||
}
|
||||
|
||||
private fun prepareArguments(callSite: PyCallExpression, declarations: MutableList<PyAssignmentStatement>, generatedNames: MutableSet<String>,
|
||||
reference: PyReferenceExpression, languageLevel: LanguageLevel): Map<String, PyExpression> {
|
||||
val context = PyResolveContext.noImplicits().withTypeEvalContext(TypeEvalContext.userInitiated(myProject, reference.containingFile))
|
||||
val mapping = PyCallExpressionHelper.mapArguments(callSite, context).firstOrNull() ?: error("Can't map arguments for ${reference.name}")
|
||||
val mappedParams = mapping.mappedParameters
|
||||
|
||||
|
||||
val self = mapping.implicitParameters.firstOrNull()?.let { first ->
|
||||
val implicitName = first.name!!
|
||||
val selfReplacement = reference.qualifier?.let { qualifier ->
|
||||
myFunctionClass?.let {
|
||||
when {
|
||||
qualifier is PyReferenceExpression && !qualifier.isQualified -> qualifier
|
||||
else -> {
|
||||
val qualifierDeclaration = generateUniqueAssignment(languageLevel, myFunctionClass.name!!, generatedNames, reference)
|
||||
val newRef = qualifierDeclaration.assignedValue!!.copy() as PyExpression
|
||||
qualifierDeclaration.assignedValue!!.replace(qualifier)
|
||||
declarations.add(qualifierDeclaration)
|
||||
newRef
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
mapOf(implicitName to selfReplacement!!)
|
||||
} ?: emptyMap()
|
||||
|
||||
val passedArguments = mappedParams.asSequence()
|
||||
.map { (arg, param) ->
|
||||
val argValue = if (arg is PyKeywordArgument) arg.valueExpression!! else arg
|
||||
tryExtractDeclaration(param.name!!, argValue, declarations, generatedNames, reference, languageLevel)
|
||||
}
|
||||
.toMap()
|
||||
|
||||
val defaultValues = myFunction.parameterList.parameters.asSequence()
|
||||
.filter { it.name !in passedArguments }
|
||||
.filter { it.hasDefaultValue() }
|
||||
.map { tryExtractDeclaration(it.name!!, it.defaultValue!!, declarations, generatedNames, reference, languageLevel) }
|
||||
.toMap()
|
||||
|
||||
return self + passedArguments + defaultValues
|
||||
}
|
||||
|
||||
private fun tryExtractDeclaration(paramName: String, arg: PyExpression, declarations: MutableList<PyAssignmentStatement>, generatedNames: MutableSet<String>,
|
||||
reference: PyReferenceExpression, languageLevel: LanguageLevel): Pair<String, PyExpression> {
|
||||
if (arg !is PyReferenceExpression && arg !is PyLiteralExpression) {
|
||||
val statement = generateUniqueAssignment(languageLevel, paramName, generatedNames, reference)
|
||||
statement.assignedValue!!.replace(arg)
|
||||
declarations.add(statement)
|
||||
return paramName to statement.targets[0]
|
||||
}
|
||||
return paramName to arg
|
||||
|
||||
}
|
||||
|
||||
private fun generateUniqueAssignment(level: LanguageLevel, name: String, previouslyGeneratedNames: MutableSet<String>, scopeAnchor: PsiElement): PyAssignmentStatement {
|
||||
val uniqueName = PyRefactoringUtil.selectUniqueName(name, scopeAnchor) { newName, anchor ->
|
||||
PyRefactoringUtil.isValidNewName(newName, anchor) && newName !in previouslyGeneratedNames
|
||||
}
|
||||
previouslyGeneratedNames.add(uniqueName)
|
||||
return myGenerator.createFromText(level, PyAssignmentStatement::class.java, "$uniqueName = $uniqueName")
|
||||
}
|
||||
|
||||
override fun getCommandName() = "Inlining ${myFunction.name}"
|
||||
override fun getRefactoringId() = PyInlineFunctionHandler.REFACTORING_ID
|
||||
|
||||
override fun createUsageViewDescriptor(usages: Array<out UsageInfo>) = object : UsageViewDescriptor {
|
||||
override fun getElements(): Array<PsiElement> = arrayOf(myFunction)
|
||||
override fun getProcessedElementsHeader(): String = "Function to inline "
|
||||
override fun getCodeReferencesText(usagesCount: Int, filesCount: Int): String = "Invocations to be inlined in $filesCount files"
|
||||
override fun getCommentReferencesText(usagesCount: Int, filesCount: Int): String = ""
|
||||
}
|
||||
}
|
||||
@@ -51,7 +51,7 @@ public class PyInlineLocalHandler extends InlineActionHandler {
|
||||
|
||||
private static final String REFACTORING_NAME = RefactoringBundle.message("inline.variable.title");
|
||||
private static final Pair<PyStatement, Boolean> EMPTY_DEF_RESULT = Pair.create(null, false);
|
||||
private static final String HELP_ID = "python.reference.inline";
|
||||
private static final String HELP_ID = "refactoring.inlineVariable";
|
||||
|
||||
public static PyInlineLocalHandler getInstance() {
|
||||
return EP_NAME.findExtensionOrFail(PyInlineLocalHandler.class);
|
||||
|
||||
@@ -229,7 +229,7 @@ public class PyMakeMethodTopLevelProcessor extends PyBaseMakeFunctionTopLevelPro
|
||||
final PsiElement anchor = ContainerUtil.getFirstItem(reads);
|
||||
//noinspection ConstantConditions
|
||||
if (!PyRefactoringUtil.isValidNewName(name, anchor)) {
|
||||
final String indexedName = PyRefactoringUtil.appendNumberUntilValid(name, anchor);
|
||||
final String indexedName = PyRefactoringUtil.appendNumberUntilValid(name, anchor, PyRefactoringUtil::isValidNewName);
|
||||
myAttributeToParameterName.put(name, indexedName);
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -0,0 +1,19 @@
|
||||
def foo(x, y, z):
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
return x
|
||||
|
||||
|
||||
def id(x):
|
||||
return x
|
||||
|
||||
|
||||
def bar():
|
||||
x = id(1)
|
||||
y = id(2)
|
||||
z = id(3)
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
res = x
|
||||
@@ -0,0 +1,13 @@
|
||||
def foo(x, y, z):
|
||||
print(x)
|
||||
print(y)
|
||||
print(z)
|
||||
return x
|
||||
|
||||
|
||||
def id(x):
|
||||
return x
|
||||
|
||||
|
||||
def bar():
|
||||
res = fo<caret>o(id(1), id(2), z = id(3))
|
||||
@@ -0,0 +1,6 @@
|
||||
async def foo():
|
||||
pass
|
||||
|
||||
|
||||
async def bar():
|
||||
await f<caret>oo()
|
||||
1
python/testData/refactoring/inlineFunction/builtin.py
Normal file
1
python/testData/refactoring/inlineFunction/builtin.py
Normal file
@@ -0,0 +1 @@
|
||||
inp<caret>ut()
|
||||
@@ -0,0 +1,16 @@
|
||||
class A:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def doStuff(self):
|
||||
print(self.a)
|
||||
print(self.b)
|
||||
return self.a + self.b
|
||||
|
||||
@classmethod
|
||||
def foo(cls):
|
||||
my_a = A(1, 2)
|
||||
print(my_a.a)
|
||||
print(my_a.b)
|
||||
res = my_a.a + my_a.b
|
||||
@@ -0,0 +1,14 @@
|
||||
class A:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def doStuff(self):
|
||||
print(self.a)
|
||||
print(self.b)
|
||||
return self.a + self.b
|
||||
|
||||
@classmethod
|
||||
def foo(cls):
|
||||
my_a = A(1, 2)
|
||||
res = my_a.doSt<caret>uff()
|
||||
@@ -0,0 +1,16 @@
|
||||
class A:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def doStuff(self):
|
||||
print(self.a)
|
||||
print(self.b)
|
||||
return self.a + self.b
|
||||
|
||||
@staticmethod
|
||||
def foo():
|
||||
my_a = A(1, 2)
|
||||
print(my_a.a)
|
||||
print(my_a.b)
|
||||
res = my_a.a + my_a.b
|
||||
@@ -0,0 +1,14 @@
|
||||
class A:
|
||||
def __init__(self, a, b):
|
||||
self.a = a
|
||||
self.b = b
|
||||
|
||||
def doStuff(self):
|
||||
print(self.a)
|
||||
print(self.b)
|
||||
return self.a + self.b
|
||||
|
||||
@staticmethod
|
||||
def foo():
|
||||
my_a = A(1, 2)
|
||||
res = my_a.doSt<caret>uff()
|
||||
@@ -0,0 +1,15 @@
|
||||
class MyClass:
|
||||
def __init__(self, attr):
|
||||
self.attr = attr
|
||||
|
||||
def __add__(self, other):
|
||||
return MyClass(self.attr + other.attr)
|
||||
|
||||
def method(self):
|
||||
print(self.attr)
|
||||
print(self.attr)
|
||||
|
||||
|
||||
my_class = (MyClass(1) + MyClass(2))
|
||||
print(my_class.attr)
|
||||
print(my_class.attr)
|
||||
@@ -0,0 +1,13 @@
|
||||
class MyClass:
|
||||
def __init__(self, attr):
|
||||
self.attr = attr
|
||||
|
||||
def __add__(self, other):
|
||||
return MyClass(self.attr + other.attr)
|
||||
|
||||
def method(self):
|
||||
print(self.attr)
|
||||
print(self.attr)
|
||||
|
||||
|
||||
(MyClass(1) + MyClass(2)).meth<caret>od()
|
||||
@@ -0,0 +1,6 @@
|
||||
class MyClass():
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
|
||||
MyCla<caret>ss()
|
||||
10
python/testData/refactoring/inlineFunction/decorator.py
Normal file
10
python/testData/refactoring/inlineFunction/decorator.py
Normal file
@@ -0,0 +1,10 @@
|
||||
def foo(x):
|
||||
return x
|
||||
|
||||
|
||||
@foo
|
||||
def bar():
|
||||
pass
|
||||
|
||||
|
||||
ba<caret>r()
|
||||
@@ -0,0 +1,8 @@
|
||||
def foo(arg1=1, arg2=2, arg3=3, arg4=4):
|
||||
res = arg1 + arg2 + arg3 + arg4
|
||||
print(res)
|
||||
|
||||
|
||||
def bar():
|
||||
res = 10 + 20 + 3 + 4
|
||||
print(res)
|
||||
@@ -0,0 +1,7 @@
|
||||
def foo(arg1=1, arg2=2, arg3=3, arg4=4):
|
||||
res = arg1 + arg2 + arg3 + arg4
|
||||
print(res)
|
||||
|
||||
|
||||
def bar():
|
||||
f<caret>oo(10, 20)
|
||||
@@ -0,0 +1,10 @@
|
||||
from source import do_more_work
|
||||
from source import do_work
|
||||
from source import do_substantially_more_work as do_nothing
|
||||
|
||||
|
||||
def bar():
|
||||
do_work()
|
||||
do_more_work(42)
|
||||
do_nothing()
|
||||
res = 42
|
||||
@@ -0,0 +1,7 @@
|
||||
from source import foo
|
||||
from source import do_work
|
||||
from source import do_substantially_more_work as do_nothing
|
||||
|
||||
|
||||
def bar():
|
||||
res = fo<caret>o(42)
|
||||
@@ -0,0 +1,19 @@
|
||||
|
||||
|
||||
def do_work():
|
||||
pass
|
||||
|
||||
|
||||
def do_more_work(x):
|
||||
pass
|
||||
|
||||
|
||||
def do_substantially_more_work():
|
||||
pass
|
||||
|
||||
|
||||
def foo(arg):
|
||||
do_work()
|
||||
do_more_work(arg)
|
||||
do_substantially_more_work()
|
||||
return arg
|
||||
@@ -0,0 +1,8 @@
|
||||
import source
|
||||
|
||||
|
||||
def bar():
|
||||
source.do_work()
|
||||
source.do_more_work(42)
|
||||
source.do_substantially_more_work()
|
||||
res = 42
|
||||
@@ -0,0 +1,5 @@
|
||||
from source import foo
|
||||
|
||||
|
||||
def bar():
|
||||
res = fo<caret>o(42)
|
||||
@@ -0,0 +1,19 @@
|
||||
|
||||
|
||||
def do_work():
|
||||
pass
|
||||
|
||||
|
||||
def do_more_work(x):
|
||||
pass
|
||||
|
||||
|
||||
def do_substantially_more_work():
|
||||
pass
|
||||
|
||||
|
||||
def foo(arg):
|
||||
do_work()
|
||||
do_more_work(arg)
|
||||
do_substantially_more_work()
|
||||
return arg
|
||||
5
python/testData/refactoring/inlineFunction/generator.py
Normal file
5
python/testData/refactoring/inlineFunction/generator.py
Normal file
@@ -0,0 +1,5 @@
|
||||
def foo():
|
||||
yield 1
|
||||
|
||||
|
||||
f<caret>oo()
|
||||
@@ -0,0 +1,8 @@
|
||||
from source import do_work, do_more_work, do_substantially_more_work
|
||||
|
||||
|
||||
def bar():
|
||||
do_work()
|
||||
do_more_work(42)
|
||||
do_substantially_more_work()
|
||||
res = 42
|
||||
@@ -0,0 +1,5 @@
|
||||
from source import foo
|
||||
|
||||
|
||||
def bar():
|
||||
res = fo<caret>o(42)
|
||||
@@ -0,0 +1,19 @@
|
||||
|
||||
|
||||
def do_work():
|
||||
pass
|
||||
|
||||
|
||||
def do_more_work(x):
|
||||
pass
|
||||
|
||||
|
||||
def do_substantially_more_work():
|
||||
pass
|
||||
|
||||
|
||||
def foo(arg):
|
||||
do_work()
|
||||
do_more_work(arg)
|
||||
do_substantially_more_work()
|
||||
return arg
|
||||
@@ -0,0 +1,36 @@
|
||||
def foo(arg):
|
||||
local = 1
|
||||
if arg:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
return local
|
||||
|
||||
|
||||
def bar():
|
||||
x = 1
|
||||
local = 1
|
||||
if x:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
res = local
|
||||
|
||||
|
||||
def baz():
|
||||
y = 2
|
||||
local = 1
|
||||
if y:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
res = local
|
||||
|
||||
|
||||
z = 1
|
||||
local = 1
|
||||
if z:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
res = local
|
||||
21
python/testData/refactoring/inlineFunction/inlineAll/main.py
Normal file
21
python/testData/refactoring/inlineFunction/inlineAll/main.py
Normal file
@@ -0,0 +1,21 @@
|
||||
def fo<caret>o(arg):
|
||||
local = 1
|
||||
if arg:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
return local
|
||||
|
||||
|
||||
def bar():
|
||||
x = 1
|
||||
res = foo(x)
|
||||
|
||||
|
||||
def baz():
|
||||
y = 2
|
||||
res = foo(y)
|
||||
|
||||
|
||||
z = 1
|
||||
res = foo(z)
|
||||
@@ -0,0 +1,5 @@
|
||||
print(1)
|
||||
print(2)
|
||||
z = 1 + 2
|
||||
print(z)
|
||||
res = z
|
||||
@@ -0,0 +1,3 @@
|
||||
from src import foo as bar
|
||||
|
||||
res = ba<caret>r(1, 2)
|
||||
@@ -0,0 +1,6 @@
|
||||
def foo(x, y):
|
||||
print(x)
|
||||
print(y)
|
||||
z = x + y
|
||||
print(z)
|
||||
return z
|
||||
@@ -0,0 +1,8 @@
|
||||
def foo(arg):
|
||||
if not arg:
|
||||
return None
|
||||
print("working")
|
||||
return arg
|
||||
|
||||
|
||||
f<caret>oo(1)
|
||||
@@ -0,0 +1,8 @@
|
||||
def foo(*, a, b):
|
||||
print(a)
|
||||
print(b)
|
||||
|
||||
|
||||
def bar():
|
||||
print(1)
|
||||
print(2)
|
||||
@@ -0,0 +1,7 @@
|
||||
def foo(*, a, b):
|
||||
print(a)
|
||||
print(b)
|
||||
|
||||
|
||||
def bar():
|
||||
fo<caret>o(a = 1, b = 2)
|
||||
@@ -0,0 +1,27 @@
|
||||
class MyClass:
|
||||
|
||||
def do_stuff(self, x, y):
|
||||
print(x)
|
||||
print(y)
|
||||
self.do_something_else()
|
||||
if x:
|
||||
print(x)
|
||||
elif y:
|
||||
print(y)
|
||||
else:
|
||||
print("nothing")
|
||||
result = x, y
|
||||
return result
|
||||
|
||||
def for_inline(self, a, b):
|
||||
self.do_something_else()
|
||||
if a:
|
||||
print(a)
|
||||
elif b:
|
||||
print(b)
|
||||
else:
|
||||
print("nothing")
|
||||
return a, b
|
||||
|
||||
def do_something_else(self):
|
||||
pass
|
||||
@@ -0,0 +1,19 @@
|
||||
class MyClass:
|
||||
|
||||
def do_stuff(self, x, y):
|
||||
print(x)
|
||||
print(y)
|
||||
return self.for_in<caret>line(x, y)
|
||||
|
||||
def for_inline(self, a, b):
|
||||
self.do_something_else()
|
||||
if a:
|
||||
print(a)
|
||||
elif b:
|
||||
print(b)
|
||||
else:
|
||||
print("nothing")
|
||||
return a, b
|
||||
|
||||
def do_something_else(self):
|
||||
pass
|
||||
@@ -0,0 +1,33 @@
|
||||
class MyClass:
|
||||
|
||||
def do_stuff(self, x, y):
|
||||
print(x)
|
||||
print(y)
|
||||
return self.for_inline(x, y)
|
||||
|
||||
def for_inline(self, a, b):
|
||||
self.do_something_else()
|
||||
if a:
|
||||
print(a)
|
||||
elif b:
|
||||
print(b)
|
||||
else:
|
||||
print("nothing")
|
||||
return a, b
|
||||
|
||||
def do_something_else(self):
|
||||
pass
|
||||
|
||||
|
||||
x = 1
|
||||
y = 2
|
||||
cls = MyClass()
|
||||
cls.do_something_else()
|
||||
if x:
|
||||
print(x)
|
||||
elif y:
|
||||
print(y)
|
||||
else:
|
||||
print("nothing")
|
||||
result = x, y
|
||||
res = result
|
||||
@@ -0,0 +1,25 @@
|
||||
class MyClass:
|
||||
|
||||
def do_stuff(self, x, y):
|
||||
print(x)
|
||||
print(y)
|
||||
return self.for_inline(x, y)
|
||||
|
||||
def for_inline(self, a, b):
|
||||
self.do_something_else()
|
||||
if a:
|
||||
print(a)
|
||||
elif b:
|
||||
print(b)
|
||||
else:
|
||||
print("nothing")
|
||||
return a, b
|
||||
|
||||
def do_something_else(self):
|
||||
pass
|
||||
|
||||
|
||||
x = 1
|
||||
y = 2
|
||||
cls = MyClass()
|
||||
res = cls.for_in<caret>line(x, y)
|
||||
@@ -0,0 +1,26 @@
|
||||
def foo(x, y, z):
|
||||
if x:
|
||||
return x + 2
|
||||
elif y:
|
||||
return y + 2
|
||||
else:
|
||||
if z:
|
||||
return z + 2
|
||||
else:
|
||||
return 2
|
||||
|
||||
|
||||
def bar():
|
||||
a = 1
|
||||
b = 2
|
||||
c = 3
|
||||
if a:
|
||||
result = a + 2
|
||||
elif b:
|
||||
result = b + 2
|
||||
else:
|
||||
if c:
|
||||
result = c + 2
|
||||
else:
|
||||
result = 2
|
||||
res = result
|
||||
@@ -0,0 +1,17 @@
|
||||
def foo(x, y, z):
|
||||
if x:
|
||||
return x + 2
|
||||
elif y:
|
||||
return y + 2
|
||||
else:
|
||||
if z:
|
||||
return z + 2
|
||||
else:
|
||||
return 2
|
||||
|
||||
|
||||
def bar():
|
||||
a = 1
|
||||
b = 2
|
||||
c = 3
|
||||
res = fo<caret>o(a, b, c)
|
||||
@@ -0,0 +1,19 @@
|
||||
def foo(arg):
|
||||
local = 1
|
||||
if arg:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
return local
|
||||
|
||||
|
||||
def bar():
|
||||
x = 1
|
||||
local = 2
|
||||
another = 3
|
||||
local1 = 1
|
||||
if x:
|
||||
another1 = 2
|
||||
else:
|
||||
another1 = 3
|
||||
res = local1
|
||||
14
python/testData/refactoring/inlineFunction/nameClash/main.py
Normal file
14
python/testData/refactoring/inlineFunction/nameClash/main.py
Normal file
@@ -0,0 +1,14 @@
|
||||
def foo(arg):
|
||||
local = 1
|
||||
if arg:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
return local
|
||||
|
||||
|
||||
def bar():
|
||||
x = 1
|
||||
local = 2
|
||||
another = 3
|
||||
res = fo<caret>o(x)
|
||||
8
python/testData/refactoring/inlineFunction/nested.py
Normal file
8
python/testData/refactoring/inlineFunction/nested.py
Normal file
@@ -0,0 +1,8 @@
|
||||
def foo():
|
||||
def bar():
|
||||
pass
|
||||
|
||||
bar()
|
||||
|
||||
|
||||
fo<caret>o()
|
||||
@@ -0,0 +1,3 @@
|
||||
print("fun called")
|
||||
print("fun called")
|
||||
res = 1
|
||||
@@ -0,0 +1,6 @@
|
||||
def fun(x):
|
||||
print("fun called")
|
||||
return x
|
||||
|
||||
|
||||
res = fu<caret>n(fun(1))
|
||||
@@ -0,0 +1,21 @@
|
||||
def foo(x, y):
|
||||
s = x + y
|
||||
if s > 10:
|
||||
print("s>10")
|
||||
elif s > 5:
|
||||
print("s>5")
|
||||
else:
|
||||
print("less")
|
||||
print("over")
|
||||
|
||||
|
||||
def bar():
|
||||
s = 1 + 2
|
||||
if s > 10:
|
||||
print("s>10")
|
||||
elif s > 5:
|
||||
print("s>5")
|
||||
else:
|
||||
print("less")
|
||||
print("over")
|
||||
res = None
|
||||
@@ -0,0 +1,13 @@
|
||||
def foo(x, y):
|
||||
s = x + y
|
||||
if s > 10:
|
||||
print("s>10")
|
||||
elif s > 5:
|
||||
print("s>5")
|
||||
else:
|
||||
print("less")
|
||||
print("over")
|
||||
|
||||
|
||||
def bar():
|
||||
res = f<caret>oo(1, 2)
|
||||
@@ -0,0 +1,20 @@
|
||||
def foo(x, y):
|
||||
s = x + y
|
||||
if s > 10:
|
||||
print("s>10")
|
||||
elif s > 5:
|
||||
print("s>5")
|
||||
else:
|
||||
print("less")
|
||||
print("over")
|
||||
|
||||
|
||||
def bar():
|
||||
s = 1 + 2
|
||||
if s > 10:
|
||||
print("s>10")
|
||||
elif s > 5:
|
||||
print("s>5")
|
||||
else:
|
||||
print("less")
|
||||
print("over")
|
||||
@@ -0,0 +1,13 @@
|
||||
def foo(x, y):
|
||||
s = x + y
|
||||
if s > 10:
|
||||
print("s>10")
|
||||
elif s > 5:
|
||||
print("s>5")
|
||||
else:
|
||||
print("less")
|
||||
print("over")
|
||||
|
||||
|
||||
def bar():
|
||||
f<caret>oo(1, 2)
|
||||
12
python/testData/refactoring/inlineFunction/overridden.py
Normal file
12
python/testData/refactoring/inlineFunction/overridden.py
Normal file
@@ -0,0 +1,12 @@
|
||||
class A():
|
||||
def method(self):
|
||||
print(1)
|
||||
|
||||
|
||||
class B(A):
|
||||
def method(self):
|
||||
print(2)
|
||||
|
||||
|
||||
b = B()
|
||||
b.meth<caret>od()
|
||||
@@ -0,0 +1,8 @@
|
||||
def foo(a, b, /):
|
||||
print(a)
|
||||
print(b)
|
||||
|
||||
|
||||
def bar():
|
||||
print(1)
|
||||
print(2)
|
||||
@@ -0,0 +1,7 @@
|
||||
def foo(a, b, /):
|
||||
print(a)
|
||||
print(b)
|
||||
|
||||
|
||||
def bar():
|
||||
fo<caret>o(1, 2)
|
||||
5
python/testData/refactoring/inlineFunction/recursive.py
Normal file
5
python/testData/refactoring/inlineFunction/recursive.py
Normal file
@@ -0,0 +1,5 @@
|
||||
def foo():
|
||||
foo()
|
||||
|
||||
|
||||
fo<caret>o()
|
||||
@@ -0,0 +1,3 @@
|
||||
print(1)
|
||||
print(2)
|
||||
1 + 2
|
||||
@@ -0,0 +1,7 @@
|
||||
def foo(x, y):
|
||||
print(x)
|
||||
print(y)
|
||||
return x + y
|
||||
|
||||
|
||||
f<caret>oo(1, 2)
|
||||
@@ -0,0 +1,2 @@
|
||||
def foo(x: int, y: int) -> int:
|
||||
...
|
||||
@@ -0,0 +1,8 @@
|
||||
def bar():
|
||||
x = 1
|
||||
local = 1
|
||||
if x:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
res = local
|
||||
12
python/testData/refactoring/inlineFunction/removing/main.py
Normal file
12
python/testData/refactoring/inlineFunction/removing/main.py
Normal file
@@ -0,0 +1,12 @@
|
||||
def foo(arg):
|
||||
local = 1
|
||||
if arg:
|
||||
another = 2
|
||||
else:
|
||||
another = 3
|
||||
return local
|
||||
|
||||
|
||||
def bar():
|
||||
x = 1
|
||||
res = fo<caret>o(x)
|
||||
@@ -0,0 +1,7 @@
|
||||
def foo(arg):
|
||||
local = arg + arg
|
||||
return local
|
||||
|
||||
|
||||
local = 42 + 42
|
||||
x = local
|
||||
@@ -0,0 +1,6 @@
|
||||
def foo(arg):
|
||||
local = arg + arg
|
||||
return local
|
||||
|
||||
|
||||
x = fo<caret>o(42)
|
||||
4
python/testData/refactoring/inlineFunction/star.py
Normal file
4
python/testData/refactoring/inlineFunction/star.py
Normal file
@@ -0,0 +1,4 @@
|
||||
def foo(*args, **kwargs):
|
||||
pass
|
||||
|
||||
fo<caret>o(1, 2, 3, x = 4)
|
||||
@@ -0,0 +1,7 @@
|
||||
def foo(arg):
|
||||
return arg
|
||||
|
||||
|
||||
@fo<caret>o
|
||||
def bar():
|
||||
pass
|
||||
@@ -0,0 +1,9 @@
|
||||
def call_with_one(fun):
|
||||
fun(1)
|
||||
|
||||
|
||||
def foo(x):
|
||||
print(x)
|
||||
|
||||
|
||||
call_with_one(fo<caret>o)
|
||||
@@ -0,0 +1,6 @@
|
||||
def foo(a, b, c, d):
|
||||
print(a, b, c, d)
|
||||
|
||||
|
||||
arg = (1, 2, 3, 4)
|
||||
fo<caret>o(*arg)
|
||||
@@ -0,0 +1,96 @@
|
||||
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.refactoring
|
||||
|
||||
import com.intellij.codeInsight.TargetElementUtil
|
||||
import com.intellij.refactoring.util.CommonRefactoringUtil
|
||||
import com.jetbrains.python.codeInsight.PyCodeInsightSettings
|
||||
import com.jetbrains.python.fixtures.PyTestCase
|
||||
import com.jetbrains.python.psi.LanguageLevel
|
||||
import com.jetbrains.python.psi.PyElement
|
||||
import com.jetbrains.python.psi.PyFunction
|
||||
import com.jetbrains.python.pyi.PyiFile
|
||||
import com.jetbrains.python.pyi.PyiUtil
|
||||
import com.jetbrains.python.refactoring.inline.PyInlineFunctionHandler
|
||||
import com.jetbrains.python.refactoring.inline.PyInlineFunctionProcessor
|
||||
import junit.framework.TestCase
|
||||
|
||||
/**
|
||||
* @author Aleksei.Kniazev
|
||||
*/
|
||||
class PyInlineFunctionTest : PyTestCase() {
|
||||
|
||||
override fun getTestDataPath(): String = super.getTestDataPath() + "/refactoring/inlineFunction"
|
||||
|
||||
private fun doTest(inlineThis: Boolean = true, remove: Boolean = false) {
|
||||
val testName = getTestName(true)
|
||||
myFixture.copyDirectoryToProject(testName, "")
|
||||
myFixture.configureByFile("main.py")
|
||||
var element = TargetElementUtil.findTargetElement(myFixture.editor, TargetElementUtil.getInstance().referenceSearchFlags)
|
||||
if (element!!.containingFile is PyiFile) element = PyiUtil.getOriginalElement(element as PyElement)
|
||||
val reference = TargetElementUtil.findReference(myFixture.editor)
|
||||
TestCase.assertTrue(element is PyFunction)
|
||||
PyInlineFunctionProcessor(myFixture.project, myFixture. editor, element as PyFunction, reference, inlineThis, remove).run()
|
||||
myFixture.checkResultByFile("$testName/main.after.py")
|
||||
}
|
||||
|
||||
private fun doTestError(expectedError: String, isReferenceError: Boolean = false) {
|
||||
myFixture.configureByFile("${getTestName(true)}.py")
|
||||
val element = TargetElementUtil.findTargetElement(myFixture.editor, TargetElementUtil.getInstance().referenceSearchFlags)
|
||||
try {
|
||||
if (isReferenceError) {
|
||||
val reference = TargetElementUtil.findReference(myFixture.editor)
|
||||
PyInlineFunctionProcessor(myFixture.project, myFixture. editor, element as PyFunction, reference, myInlineThis = true, removeDeclaration = false).run()
|
||||
}
|
||||
else {
|
||||
PyInlineFunctionHandler.getInstance().inlineElement(myFixture.project, myFixture.editor, element)
|
||||
}
|
||||
TestCase.fail("Expected error: $expectedError, but got none")
|
||||
}
|
||||
catch (e: CommonRefactoringUtil.RefactoringErrorHintException) {
|
||||
TestCase.assertEquals(expectedError, e.message)
|
||||
}
|
||||
}
|
||||
|
||||
fun testSimple() = doTest()
|
||||
fun testNameClash() = doTest()
|
||||
fun testArgumentExtraction() = doTest()
|
||||
fun testMultipleReturns() = doTest()
|
||||
fun testImporting() = doTest()
|
||||
//fun testExistingImports() = doTest()
|
||||
fun testMethodInsideClass() = doTest()
|
||||
fun testMethodOutsideClass() = doTest()
|
||||
fun testNoReturnsAsExpressionStatement() = doTest()
|
||||
fun testNoReturnsAsCallExpression() = doTest()
|
||||
fun testInlineAll() = doTest(inlineThis = false)
|
||||
fun testRemoving() = doTest(inlineThis = false, remove = true)
|
||||
fun testDefaultValues() = doTest()
|
||||
fun testPositionalOnlyArgs() = doTest()
|
||||
fun testKeywordOnlyArgs() = doTest()
|
||||
fun testNestedCalls() = doTest(inlineThis = false, remove = true)
|
||||
fun testCallFromStaticMethod() = doTest()
|
||||
fun testCallFromClassMethod() = doTest()
|
||||
fun testComplexQualifier() = doTest()
|
||||
//fun testInlineImportedAs() = doTest(inlineThis = false)
|
||||
fun testRemoveFunctionWithStub() {
|
||||
doTest(inlineThis = false, remove = true)
|
||||
val testName = getTestName(true)
|
||||
myFixture.checkResultByFile("main.pyi", "$testName/main.after.pyi",true)
|
||||
}
|
||||
fun testGenerator() = doTestError("Cannot inline generators")
|
||||
fun testAsyncFunction() {
|
||||
runWithLanguageLevel(LanguageLevel.PYTHON37) {
|
||||
doTestError("Cannot inline async functions")
|
||||
}
|
||||
}
|
||||
fun testConstructor() = doTestError("Cannot inline constructor calls")
|
||||
fun testBuiltin() = doTestError("Cannot inline builtin functions")
|
||||
fun testDecorator() = doTestError("Cannot inline functions with decorators")
|
||||
fun testRecursive() = doTestError("Cannot inline functions that reference themselves")
|
||||
fun testStar() = doTestError("Cannot inline functions with * arguments")
|
||||
fun testOverridden() = doTestError("Cannot inline overridden functions")
|
||||
fun testNested() = doTestError("Cannot inline functions with another function declaration")
|
||||
fun testInterruptedFlow() = doTestError("Cannot inline functions that interrupt control flow")
|
||||
fun testUsedAsDecorator() = doTestError("Function foo is used as a decorator and cannot be inlined. Function definition will not be removed", isReferenceError = true)
|
||||
fun testUsedAsReference() = doTestError("Function foo is used as a reference and cannot be inlined. Function definition will not be removed", isReferenceError = true)
|
||||
fun testUsesArgumentUnpacking() = doTestError("Function foo uses argument unpacking and cannot be inlined. Function definition will not be removed", isReferenceError = true)
|
||||
}
|
||||
Reference in New Issue
Block a user