PY-35287 Extract method with typehints

- customize extract method refactoring for Python using prefix Py
- option to enable/disable type annotations
- persist value of checkbox.isSelected
- run all extract method tests using types
- add specific typed test
- adjust api-dump.txt since some members of AbstractExtractMethodDialog were raised from private to protected
- small adjustment in test expectation of the inferred type

GitOrigin-RevId: be6e70dcb61c451debb98c10c0a001234188cb7d
This commit is contained in:
Marcus Mews
2025-07-31 14:08:38 +00:00
committed by intellij-monorepo-bot
parent 589b225420
commit 3d25b26d0f
53 changed files with 871 additions and 177 deletions

View File

@@ -17,9 +17,12 @@ package com.jetbrains.python.codeInsight.codeFragment;
import com.intellij.codeInsight.codeFragment.CodeFragment;
import java.util.Map;
import java.util.Set;
public class PyCodeFragment extends CodeFragment {
private final Map<String, String> myInputTypes;
private final String myOutputType;
private final Set<String> myGlobalWrites;
private final Set<String> myNonlocalWrites;
private final boolean myYieldInside;
@@ -27,18 +30,30 @@ public class PyCodeFragment extends CodeFragment {
public PyCodeFragment(final Set<String> input,
final Set<String> output,
final Map<String, String> inputTypes,
final String outputType,
final Set<String> globalWrites,
final Set<String> nonlocalWrites,
final boolean returnInside,
final boolean yieldInside,
final boolean isAsync) {
super(input, output, returnInside);
myInputTypes = inputTypes;
myOutputType = outputType;
myGlobalWrites = globalWrites;
myNonlocalWrites = nonlocalWrites;
myYieldInside = yieldInside;
myAsync = isAsync;
}
public Map<String, String> getInputTypes() {
return myInputTypes;
}
public String getOutputType() {
return myOutputType;
}
public Set<String> getGlobalWrites() {
return myGlobalWrites;
}

View File

@@ -19,8 +19,10 @@ import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.Scope;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -32,7 +34,10 @@ public final class PyCodeFragmentUtil {
public static @NotNull PyCodeFragment createCodeFragment(final @NotNull ScopeOwner owner,
final @NotNull PsiElement startInScope,
final @NotNull PsiElement endInScope) throws CannotCreateCodeFragmentException {
final @NotNull PsiElement endInScope,
final @Nullable PsiElement singleExpression)
throws CannotCreateCodeFragmentException {
final int start = startInScope.getTextOffset();
final int end = endInScope.getTextOffset() + endInScope.getTextLength();
final ControlFlow flow = ControlFlowCache.getControlFlow(owner);
@@ -50,31 +55,27 @@ public final class PyCodeFragmentUtil {
final Set<String> globalWrites = getGlobalWrites(subGraph, owner);
final Set<String> nonlocalWrites = getNonlocalWrites(subGraph, owner);
final TypeEvalContext context = TypeEvalContext.userInitiated(startInScope.getProject(), startInScope.getContainingFile());
final Set<String> inputNames = new HashSet<>();
final Map<String, String> inputTypeNames = new HashMap<>();
for (PsiElement element : filterElementsInScope(getInputElements(subGraph, graph), owner)) {
final String name = getName(element);
if (name != null) {
// Ignore "self" and "cls", they are generated automatically when extracting any method fragment
if (resolvesToBoundMethodParameter(element)) {
continue;
}
if (globalWrites.contains(name) || nonlocalWrites.contains(name)) {
continue;
}
inputNames.add(name);
// Ignore "self" and "cls", they are generated automatically when extracting any method fragment
if (resolvesToBoundMethodParameter(element)) {
continue;
}
addNameReturnType(globalWrites, nonlocalWrites, element, inputNames, inputTypeNames, null, context);
}
final Set<String> outputNames = new HashSet<>();
final List<PyType> outputTypes = new ArrayList<>();
for (PsiElement element : getOutputElements(subGraph, graph)) {
final String name = getName(element);
if (name != null) {
if (globalWrites.contains(name) || nonlocalWrites.contains(name)) {
continue;
}
outputNames.add(name);
}
addNameReturnType(globalWrites, nonlocalWrites, element, outputNames, null, outputTypes, context);
}
if (singleExpression != null) {
PyType returnType = getType(singleExpression, context);
outputTypes.add(returnType);
}
final String outputTypeName = getOutputTypeName(startInScope, outputTypes, context);
final boolean yieldsFound = subGraphAnalysis.yieldExpressions > 0;
if (yieldsFound && LanguageLevel.forElement(owner).isPython2()) {
@@ -82,7 +83,54 @@ public final class PyCodeFragmentUtil {
}
final boolean isAsync = owner instanceof PyFunction && ((PyFunction)owner).isAsync();
return new PyCodeFragment(inputNames, outputNames, globalWrites, nonlocalWrites, subGraphAnalysis.returns > 0, yieldsFound, isAsync);
return new PyCodeFragment(inputNames, outputNames, inputTypeNames, outputTypeName, globalWrites, nonlocalWrites,
subGraphAnalysis.returns > 0, yieldsFound, isAsync);
}
private static void addNameReturnType(@NotNull Set<String> globalWrites,
@NotNull Set<String> nonlocalWrites,
@NotNull PsiElement element,
@NotNull Set<String> varNames,
@Nullable Map<String, String> varTypeNames,
@Nullable List<PyType> outputTypes,
@NotNull TypeEvalContext context) {
String name = getName(element);
if (name == null || globalWrites.contains(name) || nonlocalWrites.contains(name) || varNames.contains(name)) {
return;
}
varNames.add(name);
PyType type = getType(element, context);
if (varTypeNames != null) {
String typeName = type == null ? null : PythonDocumentationProvider.getTypeHint(type, context);
varTypeNames.put(name, typeName);
}
if (outputTypes != null) {
outputTypes.add(type);
}
}
private static @Nullable PyType getType(@NotNull PsiElement element, @NotNull TypeEvalContext context) {
if (element instanceof PyTypedElement typedElement) {
PyType type = context.getType(typedElement);
if (type != null && !(type instanceof PyStructuralType) && !PyNoneTypeKt.isNoneType(type)) {
return type;
}
}
return null;
}
private static @Nullable String getOutputTypeName(@NotNull PsiElement startInScope,
@NotNull List<PyType> outputTypes,
@NotNull TypeEvalContext context) {
return switch (outputTypes.size()) {
case 0 -> null;
case 1 -> PythonDocumentationProvider.getTypeHint(outputTypes.get(0), context);
default -> {
PyType returnType = PyTupleType.create(startInScope, outputTypes);
yield PythonDocumentationProvider.getTypeHint(returnType, context);
}
};
}
private static boolean resolvesToBoundMethodParameter(@NotNull PsiElement element) {

View File

@@ -15,14 +15,17 @@
*/
package com.jetbrains.python.psi.impl;
import com.intellij.openapi.util.text.StringUtil;
import com.google.common.base.Strings;
import com.intellij.openapi.util.Pair;
import com.intellij.psi.PsiElement;
import com.intellij.util.ArrayUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.documentation.docstrings.DocStringParser;
import com.jetbrains.python.documentation.docstrings.PyDocstringGenerator;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
@@ -30,11 +33,12 @@ import java.util.*;
public class PyFunctionBuilder {
private final String myName;
private final List<String> myParameters = new ArrayList<>();
private final Map<String, String> myParameterTypes = new HashMap<>();
private final List<String> myStatements = new ArrayList<>();
private final List<String> myDecorators = new ArrayList<>();
private final PsiElement mySettingAnchor;
private String myAnnotation = null;
private final @NotNull Map<String, String> myDecoratorValues = new HashMap<>();
private String myReturnType;
private boolean myAsync = false;
private PyDocstringGenerator myDocStringGenerator;
@@ -70,9 +74,9 @@ public class PyFunctionBuilder {
}
/**
* @param settingsAnchor any PSI element, presumably in the same file/module where generated function is going to be inserted.
* It's needed to detect configured docstring format and Python indentation size and, as result,
* generate properly formatted docstring.
* @param settingsAnchor any PSI element, presumably in the same file/module where the generated function is going to be inserted.
* It's necessary to detect configured docstring format and Python indentation size and, as a result,
* generate a properly formatted docstring.
*/
public PyFunctionBuilder(@NotNull String name, @NotNull PsiElement settingsAnchor) {
myName = name;
@@ -87,13 +91,17 @@ public class PyFunctionBuilder {
* @param name param name
* @param type param type
*/
public @NotNull PyFunctionBuilder parameterWithType(@NotNull String name, @NotNull String type) {
parameter(name);
public @NotNull PyFunctionBuilder parameterWithDocString(@NotNull String name, @NotNull String type) {
parameter(name, type);
myDocStringGenerator.withParamTypedByName(name, type);
return this;
}
public PyFunctionBuilder parameter(String baseName) {
public @NotNull PyFunctionBuilder parameter(@NotNull String baseName) {
return parameter(baseName, null);
}
public @NotNull PyFunctionBuilder parameter(@NotNull String baseName, @Nullable String type) {
String name = baseName;
int uniqueIndex = 0;
while (myParameters.contains(name)) {
@@ -101,38 +109,41 @@ public class PyFunctionBuilder {
name = baseName + uniqueIndex;
}
myParameters.add(name);
if (!Strings.isNullOrEmpty(type)) {
myParameterTypes.put(name, type);
}
return this;
}
public PyFunctionBuilder annotation(String text) {
myAnnotation = text;
public @NotNull PyFunctionBuilder returnType(String returnType) {
myReturnType = returnType;
return this;
}
public PyFunctionBuilder makeAsync() {
public @NotNull PyFunctionBuilder makeAsync() {
myAsync = true;
return this;
}
public PyFunctionBuilder statement(String text) {
public @NotNull PyFunctionBuilder statement(String text) {
myStatements.add(text);
return this;
}
public PyFunction addFunction(PsiElement target) {
public @NotNull PyFunction addFunction(PsiElement target) {
return (PyFunction)target.add(buildFunction());
}
public PyFunction addFunctionAfter(PsiElement target, PsiElement anchor) {
public @NotNull PyFunction addFunctionAfter(PsiElement target, PsiElement anchor) {
return (PyFunction)target.addAfter(buildFunction(), anchor);
}
public PyFunction buildFunction() {
public @NotNull PyFunction buildFunction() {
PyElementGenerator generator = PyElementGenerator.getInstance(mySettingAnchor.getProject());
return generator.createFromText(LanguageLevel.forElement(mySettingAnchor), PyFunction.class, buildText(generator));
}
private String buildText(PyElementGenerator generator) {
private @NotNull String buildText(PyElementGenerator generator) {
StringBuilder builder = new StringBuilder();
for (String decorator : myDecorators) {
final StringBuilder decoratorAppender = builder.append('@' + decorator);
@@ -146,18 +157,15 @@ public class PyFunctionBuilder {
if (myAsync) {
builder.append("async ");
}
builder.append("def ");
builder.append(myName).append("(");
builder.append(StringUtil.join(myParameters, ", "));
builder.append(")");
if (myAnnotation != null) {
builder.append(myAnnotation);
}
List<Pair<@NotNull String, @Nullable String>> parameters =
ContainerUtil.map(myParameters, paramName -> Pair.create(paramName, myParameterTypes.get(paramName)));
appendMethodSignature(builder, myName, parameters, myReturnType);
builder.append(":");
List<String> statements = myStatements.isEmpty() ? Collections.singletonList(PyNames.PASS) : myStatements;
final String indent = PyIndentUtil.getIndentFromSettings(mySettingAnchor.getContainingFile());
// There was original docstring or some parameters were added via parameterWithType()
// There was an original docstring or some parameters were added via parameterWithType()
if (!myDocStringGenerator.isNewMode() || myDocStringGenerator.hasParametersToAdd()) {
final String docstring = PyIndentUtil.changeIndent(myDocStringGenerator.buildDocString(), true, indent);
builder.append('\n').append(indent).append(docstring);
@@ -182,4 +190,30 @@ public class PyFunctionBuilder {
public void decorate(String decoratorName) {
myDecorators.add(decoratorName);
}
public static void appendMethodSignature(@NotNull StringBuilder builder, @NotNull String name,
@NotNull List<Pair<@NotNull String, @Nullable String>> parameters,
@Nullable String returnTypeName
) {
builder.append("def ");
builder.append(name);
builder.append("(");
for (int i = 0; i < parameters.size(); i++) {
Pair<@NotNull String, @Nullable String> parameter = parameters.get(i);
if (i > 0) {
builder.append(", ");
}
builder.append(parameter.first);
if (parameter.second != null) {
builder.append(": ");
builder.append(parameter.second);
}
}
builder.append(")");
if (returnTypeName != null) {
builder.append(" -> ");
builder.append(returnTypeName);
builder.append(" ");
}
}
}

View File

@@ -558,6 +558,10 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
// null means empty set of possible types, Ref(null) means Any
final @Nullable Ref<PyType> combinedType = StreamEx.of(defs)
.map(instr -> {
if (instr.getElement() == anchor) {
// exclude recursive definition (example: type of 'i++' inside a loop)
return null;
}
if (instr instanceof ReadWriteInstruction readWriteInstruction) {
return readWriteInstruction.getType(context, anchor);
}

View File

@@ -1,6 +1,5 @@
package com.jetbrains.python.refactoring;
import com.intellij.codeInsight.codeFragment.CodeFragment;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.fileTypes.FileType;
@@ -9,11 +8,13 @@ import com.intellij.openapi.util.NlsContexts;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiReference;
import com.intellij.refactoring.extractMethod.ExtractMethodDecorator;
import com.intellij.refactoring.extractMethod.ExtractMethodSettings;
import com.intellij.refactoring.extractMethod.ExtractMethodValidator;
import com.intellij.refactoring.util.AbstractVariableData;
import com.jetbrains.python.codeInsight.codeFragment.PyCodeFragment;
import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.PyFunction;
import com.jetbrains.python.refactoring.extractmethod.PyExtractMethodSettings;
import com.jetbrains.python.refactoring.extractmethod.PyExtractMethodUtil;
import com.jetbrains.python.refactoring.extractmethod.PyVariableData;
import com.jetbrains.python.refactoring.introduce.IntroduceOperation;
import com.jetbrains.python.refactoring.introduce.IntroduceValidator;
import org.jetbrains.annotations.ApiStatus;
@@ -46,29 +47,15 @@ public class PyRefactoringUiService {
callback.accept(operation);
}
public @Nullable <T> ExtractMethodSettings<T> showExtractMethodDialog(final Project project,
final String defaultName,
final CodeFragment fragment,
final T[] visibilityVariants,
final ExtractMethodValidator validator,
final ExtractMethodDecorator<T> decorator,
final FileType type, String helpId) {
return new ExtractMethodSettings<T>() {
@Override
public @NotNull String getMethodName() {
return defaultName;
}
@Override
public AbstractVariableData @NotNull [] getAbstractVariableData() {
return new AbstractVariableData[0];
}
@Override
public @Nullable T getVisibility() {
return null;
}
};
public @Nullable PyExtractMethodSettings showExtractMethodDialog(final Project project,
final String defaultName,
final PyCodeFragment fragment,
final Object[] visibilityVariants,
final ExtractMethodValidator validator,
final ExtractMethodDecorator<Object> decorator,
final FileType type, String helpId) {
return new PyExtractMethodSettings(defaultName, new PyVariableData[0], fragment.getOutputType(),
PyExtractMethodUtil.getAddTypeAnnotations(project));
}
public void showPyInlineFunctionDialog(@NotNull Project project,

View File

@@ -9,7 +9,6 @@ import com.intellij.openapi.util.Couple;
import com.intellij.psi.PsiComment;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiWhiteSpace;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.refactoring.RefactoringActionHandler;
import com.intellij.refactoring.RefactoringBundle;
@@ -21,9 +20,6 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.psi.PyCallExpression;
import com.jetbrains.python.psi.PyClass;
import com.jetbrains.python.psi.PyElement;
import com.jetbrains.python.psi.PyExpressionStatement;
import com.jetbrains.python.psi.PyKeywordArgument;
import com.jetbrains.python.psi.PyStatement;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.refactoring.PyRefactoringUtil;
import org.jetbrains.annotations.NotNull;
@@ -102,7 +98,7 @@ public class PyExtractMethodHandler implements RefactoringActionHandler {
}
final PyCodeFragment fragment;
try {
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2);
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2, null);
}
catch (CannotCreateCodeFragmentException e) {
CommonRefactoringUtil.showErrorHint(project, editor, e.getMessage(),
@@ -121,7 +117,7 @@ public class PyExtractMethodHandler implements RefactoringActionHandler {
}
final PyCodeFragment fragment;
try {
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2);
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2, expression);
}
catch (CannotCreateCodeFragmentException e) {
CommonRefactoringUtil.showErrorHint(project, editor, e.getMessage(),

View File

@@ -0,0 +1,45 @@
package com.jetbrains.python.refactoring.extractmethod;
import com.intellij.refactoring.extractMethod.ExtractMethodSettings;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyExtractMethodSettings implements ExtractMethodSettings<Object> {
private final String myMethodName;
private final PyVariableData @NotNull [] myVariableData;
private final String myReturnTypeName;
private final boolean myUseTypeAnnotations;
public PyExtractMethodSettings(@NotNull String methodName,
PyVariableData @NotNull [] variableData,
String returnTypeName,
boolean useTypeAnnotations) {
myMethodName = methodName;
myVariableData = variableData;
myReturnTypeName = returnTypeName;
myUseTypeAnnotations = useTypeAnnotations;
}
@Override
public @NotNull String getMethodName() {
return myMethodName;
}
@Override
public PyVariableData @NotNull [] getAbstractVariableData() {
return myVariableData;
}
public String getReturnTypeName() {
return myReturnTypeName;
}
public boolean isUseTypeAnnotations() {
return myUseTypeAnnotations;
}
@Override
public @Nullable Object getVisibility() {
return null;
}
}

View File

@@ -2,7 +2,7 @@
package com.jetbrains.python.refactoring.extractmethod;
import com.intellij.codeInsight.CodeInsightUtilCore;
import com.intellij.codeInsight.codeFragment.CodeFragment;
import com.intellij.ide.util.PropertiesComponent;
import com.intellij.lang.LanguageNamesValidation;
import com.intellij.openapi.application.ApplicationManager;
import com.intellij.openapi.application.WriteAction;
@@ -22,7 +22,6 @@ import com.intellij.refactoring.listeners.RefactoringElementListenerComposite;
import com.intellij.refactoring.listeners.RefactoringEventData;
import com.intellij.refactoring.listeners.RefactoringEventListener;
import com.intellij.refactoring.rename.RenameUtil;
import com.intellij.refactoring.util.AbstractVariableData;
import com.intellij.refactoring.util.CommonRefactoringUtil;
import com.intellij.usageView.UsageInfo;
import com.intellij.util.ArrayUtilRt;
@@ -51,6 +50,8 @@ import java.util.*;
public final class PyExtractMethodUtil {
public static final String NAME = "extract.method.name";
private static final String ADD_TYPE_ANNOTATIONS_VALUE_KEY = "settings.extract.method.addTypeAnnotations";
private static final boolean ADD_TYPE_ANNOTATIONS_DEFAULT = true;
private PyExtractMethodUtil() {
}
@@ -91,15 +92,13 @@ public final class PyExtractMethodUtil {
return pointers;
}
final Pair<String, AbstractVariableData[]> data = getNameAndVariableData(project, fragment, statement1, isClassMethod, isStaticMethod);
if (data.first == null || data.second == null) {
final PyExtractMethodSettings methodSettings = getNameAndVariableData(project, fragment, statement1, isClassMethod, isStaticMethod);
if (methodSettings == null) {
return pointers;
}
PsiFile file = statement1.getContainingFile();
final String methodName = data.first;
final AbstractVariableData[] variableData = data.second;
final PyVariableData[] variableData = methodSettings.getAbstractVariableData();
final SimpleDuplicatesFinder finder = new SimpleDuplicatesFinder(statement1, statement2, fragment.getOutputVariables(), variableData);
@@ -132,14 +131,14 @@ public final class PyExtractMethodUtil {
}
// Generate method
final PyFunction generatedMethod = generateMethodFromElements(methodName, variableData, newMethodElements, flags, isAsync);
final PyFunction generatedMethod = generateMethodFromElements(methodSettings, newMethodElements, flags, isAsync);
final PyFunction insertedMethod = WriteAction.compute(() -> insertGeneratedMethod(statement1, generatedMethod));
// Process parameters
final PsiElement firstElement = elementsRange.get(0);
final boolean isMethod = PyPsiUtils.isMethodContext(firstElement);
WriteAction.run(() -> {
processParameters(project, insertedMethod, variableData, isMethod, isClassMethod, isStaticMethod);
processParameters(project, insertedMethod, methodSettings, isMethod, isClassMethod, isStaticMethod);
processGlobalWrites(insertedMethod, fragment);
processNonlocalWrites(insertedMethod, fragment);
});
@@ -160,7 +159,7 @@ public final class PyExtractMethodUtil {
if (isMethod) {
appendSelf(firstElement, builder, isStaticMethod);
}
builder.append(methodName).append("(");
builder.append(methodSettings.getMethodName()).append("(");
builder.append(createCallArgsString(variableData)).append(")");
final PyFunction function1 = generator.createFromText(languageLevel, PyFunction.class, builder.toString());
final PsiElement callElement = function1.getStatementList().getStatements()[0];
@@ -313,25 +312,23 @@ public final class PyExtractMethodUtil {
final boolean isClassMethod = flags != null && flags.isClassMethod();
final boolean isStaticMethod = flags != null && flags.isClassMethod();
final Pair<String, AbstractVariableData[]> data = getNameAndVariableData(project, fragment, expression, isClassMethod, isStaticMethod);
if (data.first == null || data.second == null) {
final PyExtractMethodSettings methodSettings = getNameAndVariableData(project, fragment, expression, isClassMethod, isStaticMethod);
if (methodSettings == null) {
return pointers;
}
final String methodName = data.first;
final AbstractVariableData[] variableData = data.second;
final PyVariableData[] variableData = methodSettings.getAbstractVariableData();
final SimpleDuplicatesFinder finder = new SimpleDuplicatesFinder(expression, expression, fragment.getOutputVariables(), variableData);
if (fragment.getOutputVariables().isEmpty()) {
CommandProcessor.getInstance().executeCommand(project, () -> {
// Generate method
final boolean isAsync = fragment.isAsync();
final PyFunction generatedMethod = generateMethodFromExpression(methodName, variableData, expression, flags, isAsync);
final PyFunction generatedMethod = generateMethodFromExpression(methodSettings, expression, flags, isAsync);
final PyFunction insertedMethod = WriteAction.compute(() -> insertGeneratedMethod(expression, generatedMethod));
// Process parameters
final boolean isMethod = PyPsiUtils.isMethodContext(expression);
WriteAction.run(() -> processParameters(project, insertedMethod, variableData, isMethod, isClassMethod, isStaticMethod));
WriteAction.run(() -> processParameters(project, insertedMethod, methodSettings, isMethod, isClassMethod, isStaticMethod));
// Generating call element
final StringBuilder builder = new StringBuilder();
@@ -351,11 +348,10 @@ public final class PyExtractMethodUtil {
if (isMethod) {
appendSelf(expression, builder, isStaticMethod);
}
builder.append(methodName);
builder.append(methodSettings.getMethodName());
builder.append("(").append(createCallArgsString(variableData)).append(")");
final PyElementGenerator generator = PyElementGenerator.getInstance(project);
final PyFunction function1 = generator.createFromText(LanguageLevel.forElement(expression), PyFunction.class,
builder.toString());
final PyFunction function1 = generator.createFromText(LanguageLevel.forElement(expression), PyFunction.class, builder.toString());
final PyElement generated = function1.getStatementList().getStatements()[0];
final PsiElement callElement;
if (generated instanceof PyReturnStatement) {
@@ -451,17 +447,17 @@ public final class PyExtractMethodUtil {
}
// Creates string for call
private static @NotNull String createCallArgsString(final AbstractVariableData @NotNull [] variableDatas) {
private static @NotNull String createCallArgsString(final PyVariableData @NotNull [] variableDatas) {
return StringUtil.join(ContainerUtil.mapNotNull(variableDatas, data -> data.isPassAsParameter() ? data.getOriginalName() : null), ",");
}
private static void processParameters(final @NotNull Project project,
final @NotNull PyFunction generatedMethod,
final AbstractVariableData @NotNull [] variableData,
final @NotNull PyExtractMethodSettings methodSettings,
final boolean isMethod,
final boolean isClassMethod,
final boolean isStaticMethod) {
final Map<String, String> map = createMap(variableData);
final Map<String, String> map = createMap(methodSettings.getAbstractVariableData());
// Rename parameters
for (PyParameter parameter : generatedMethod.getParameterList().getParameters()) {
final String name = parameter.getName();
@@ -487,18 +483,19 @@ public final class PyExtractMethodUtil {
else if (isMethod && !isStaticMethod) {
builder.parameter("self");
}
for (AbstractVariableData data : variableData) {
for (PyVariableData data : methodSettings.getAbstractVariableData()) {
if (data.isPassAsParameter()) {
builder.parameter(data.getName());
String typeName = methodSettings.isUseTypeAnnotations() ? data.getTypeName() : null;
builder.parameter(data.getName(), typeName);
}
}
final PyParameterList pyParameterList = builder.buildFunction().getParameterList();
generatedMethod.getParameterList().replace(pyParameterList);
}
private static @NotNull Map<String, String> createMap(final AbstractVariableData @NotNull [] variableData) {
private static @NotNull Map<String, String> createMap(final PyVariableData @NotNull [] variableData) {
final Map<String, String> map = new HashMap<>();
for (AbstractVariableData data : variableData) {
for (PyVariableData data : variableData) {
map.put(data.getOriginalName(), data.getName());
}
return map;
@@ -534,13 +531,12 @@ public final class PyExtractMethodUtil {
return (PyFunction)result;
}
private static @NotNull PyFunction generateMethodFromExpression(final @NotNull String methodName,
final AbstractVariableData @NotNull [] variableData,
private static @NotNull PyFunction generateMethodFromExpression(final @NotNull PyExtractMethodSettings methodSettings,
final @NotNull PsiElement expression,
final @Nullable PyUtil.MethodFlags flags, boolean isAsync) {
final PyFunctionBuilder builder = new PyFunctionBuilder(methodName, expression);
final PyFunctionBuilder builder = new PyFunctionBuilder(methodSettings.getMethodName(), expression);
addDecorators(builder, flags);
addFakeParameters(builder, variableData);
addParametersAndReturnType(builder, methodSettings);
if (isAsync) {
builder.makeAsync();
}
@@ -555,19 +551,18 @@ public final class PyExtractMethodUtil {
return builder.buildFunction();
}
private static @NotNull PyFunction generateMethodFromElements(final @NotNull String methodName,
final AbstractVariableData @NotNull [] variableData,
private static @NotNull PyFunction generateMethodFromElements(final @NotNull PyExtractMethodSettings methodSettings,
final @NotNull List<PsiElement> elementsRange,
@Nullable PyUtil.MethodFlags flags,
boolean isAsync) {
assert !elementsRange.isEmpty() : "Empty statements list was selected!";
final PyFunctionBuilder builder = new PyFunctionBuilder(methodName, elementsRange.get(0));
final PyFunctionBuilder builder = new PyFunctionBuilder(methodSettings.getMethodName(), elementsRange.get(0));
if (isAsync) {
builder.makeAsync();
}
addDecorators(builder, flags);
addFakeParameters(builder, variableData);
addParametersAndReturnType(builder, methodSettings);
final PyFunction method = builder.buildFunction();
final PyStatementList statementList = method.getStatementList();
for (PsiElement element : elementsRange) {
@@ -576,11 +571,11 @@ public final class PyExtractMethodUtil {
}
statementList.add(element);
}
// remove last instruction
final PsiElement child = statementList.getFirstChild();
if (child != null) {
child.delete();
}
// remove last instruction
PsiElement last = statementList;
while (last != null) {
last = last.getLastChild();
@@ -602,17 +597,21 @@ public final class PyExtractMethodUtil {
}
}
private static void addFakeParameters(@NotNull PyFunctionBuilder builder, AbstractVariableData @NotNull [] variableData) {
for (AbstractVariableData data : variableData) {
builder.parameter(data.getOriginalName());
private static void addParametersAndReturnType(@NotNull PyFunctionBuilder builder, PyExtractMethodSettings methodSettings) {
for (PyVariableData data : methodSettings.getAbstractVariableData()) {
String typeName = methodSettings.isUseTypeAnnotations() ? data.getTypeName() : null;
builder.parameter(data.getOriginalName(), typeName);
}
if (methodSettings.isUseTypeAnnotations()) {
builder.returnType(methodSettings.getReturnTypeName());
}
}
private static @NotNull Pair<String, AbstractVariableData[]> getNameAndVariableData(final @NotNull Project project,
final @NotNull CodeFragment fragment,
final @NotNull PsiElement element,
final boolean isClassMethod,
final boolean isStaticMethod) {
private static @Nullable PyExtractMethodSettings getNameAndVariableData(final @NotNull Project project,
final @NotNull PyCodeFragment fragment,
final @NotNull PsiElement element,
final boolean isClassMethod,
final boolean isStaticMethod) {
final ExtractMethodValidator validator = new PyExtractMethodValidator(element, project);
if (ApplicationManager.getApplication().isUnitTestMode()) {
String name = System.getProperty(NAME);
@@ -629,54 +628,49 @@ public final class PyExtractMethodUtil {
throw new CommonRefactoringUtil.RefactoringErrorHintException(error);
}
}
final List<AbstractVariableData> data = new ArrayList<>();
final List<PyVariableData> data = new ArrayList<>();
for (String in : fragment.getInputVariables()) {
final AbstractVariableData d = new AbstractVariableData();
final PyVariableData d = new PyVariableData();
d.name = in + "_new";
d.originalName = in;
d.passAsParameter = true;
d.typeName = fragment.getInputTypes().get(in);
data.add(d);
}
return Pair.create(name, data.toArray(new AbstractVariableData[0]));
return new PyExtractMethodSettings(name, data.toArray(new PyVariableData[0]), fragment.getOutputType(),
getAddTypeAnnotations(project));
}
final boolean isMethod = PyPsiUtils.isMethodContext(element);
final ExtractMethodDecorator<Object> decorator = new ExtractMethodDecorator<>() {
@Override
public @NotNull String createMethodSignature(@NotNull ExtractMethodSettings<Object> settings) {
final StringBuilder builder = new StringBuilder();
public @NotNull String createMethodSignature(@NotNull ExtractMethodSettings settings) {
PyExtractMethodSettings pySettings = (PyExtractMethodSettings)settings;
List<Pair<@NotNull String, @Nullable String>> parameters = new ArrayList<>();
if (isClassMethod) {
builder.append("cls");
parameters.add(Pair.create("cls", null));
}
else if (isMethod && !isStaticMethod) {
builder.append("self");
parameters.add(Pair.create("self", null));
}
for (AbstractVariableData variableData : settings.getAbstractVariableData()) {
for (PyVariableData variableData : pySettings.getAbstractVariableData()) {
if (variableData.passAsParameter) {
if (!builder.isEmpty()) {
builder.append(", ");
}
builder.append(variableData.name);
parameters.add(Pair.create(variableData.name, pySettings.isUseTypeAnnotations() ? variableData.typeName : null));
}
}
builder.insert(0, "(");
builder.insert(0, settings.getMethodName());
builder.insert(0, "def ");
builder.append(")");
final StringBuilder builder = new StringBuilder();
PyFunctionBuilder.appendMethodSignature(builder, pySettings.getMethodName(), parameters,
pySettings.isUseTypeAnnotations() ? pySettings.getReturnTypeName() : null);
return builder.toString();
}
};
ExtractMethodSettings<?> extractMethodSettings = PyRefactoringUiService.getInstance().showExtractMethodDialog(project, "method_name", fragment,
ArrayUtilRt.EMPTY_OBJECT_ARRAY, validator,
decorator, PythonFileType.INSTANCE,
"python.reference.extractMethod");
//return if don`t want to extract method
if (extractMethodSettings == null) {
return Pair.empty();
}
PyExtractMethodSettings extractMethodSettings = PyRefactoringUiService.getInstance()
.showExtractMethodDialog(project, "method_name", fragment, ArrayUtilRt.EMPTY_OBJECT_ARRAY, validator, decorator,
PythonFileType.INSTANCE, "python.reference.extractMethod");
return Pair.create(extractMethodSettings.getMethodName(), extractMethodSettings.getAbstractVariableData());
return extractMethodSettings;
}
public static @NotNull String getRefactoringId() {
@@ -727,4 +721,13 @@ public final class PyExtractMethodUtil {
return LanguageNamesValidation.isIdentifier(PythonLanguage.getInstance(), name, myProject);
}
}
public static void setAddTypeAnnotations(Project project, boolean value) {
PropertiesComponent.getInstance(project).setValue(ADD_TYPE_ANNOTATIONS_VALUE_KEY, value, ADD_TYPE_ANNOTATIONS_DEFAULT);
}
public static boolean getAddTypeAnnotations(Project project) {
boolean selected = PropertiesComponent.getInstance(project).getBoolean(ADD_TYPE_ANNOTATIONS_VALUE_KEY, ADD_TYPE_ANNOTATIONS_DEFAULT);
return selected;
}
}

View File

@@ -0,0 +1,13 @@
package com.jetbrains.python.refactoring.extractmethod;
import com.intellij.refactoring.util.AbstractVariableData;
import org.jetbrains.annotations.Nullable;
public class PyVariableData extends AbstractVariableData {
public @Nullable String typeName;
public @Nullable String getTypeName() {
return typeName;
}
}