mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-04-19 04:51:24 +07:00
PY-35287 Extract method with typehints
- customize extract method refactoring for Python using prefix Py - option to enable/disable type annotations - persist value of checkbox.isSelected - run all extract method tests using types - add specific typed test - adjust api-dump.txt since some members of AbstractExtractMethodDialog were raised from private to protected - small adjustment in test expectation of the inferred type GitOrigin-RevId: be6e70dcb61c451debb98c10c0a001234188cb7d
This commit is contained in:
committed by
intellij-monorepo-bot
parent
589b225420
commit
3d25b26d0f
@@ -17,9 +17,12 @@ package com.jetbrains.python.codeInsight.codeFragment;
|
||||
|
||||
import com.intellij.codeInsight.codeFragment.CodeFragment;
|
||||
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
|
||||
public class PyCodeFragment extends CodeFragment {
|
||||
private final Map<String, String> myInputTypes;
|
||||
private final String myOutputType;
|
||||
private final Set<String> myGlobalWrites;
|
||||
private final Set<String> myNonlocalWrites;
|
||||
private final boolean myYieldInside;
|
||||
@@ -27,18 +30,30 @@ public class PyCodeFragment extends CodeFragment {
|
||||
|
||||
public PyCodeFragment(final Set<String> input,
|
||||
final Set<String> output,
|
||||
final Map<String, String> inputTypes,
|
||||
final String outputType,
|
||||
final Set<String> globalWrites,
|
||||
final Set<String> nonlocalWrites,
|
||||
final boolean returnInside,
|
||||
final boolean yieldInside,
|
||||
final boolean isAsync) {
|
||||
super(input, output, returnInside);
|
||||
myInputTypes = inputTypes;
|
||||
myOutputType = outputType;
|
||||
myGlobalWrites = globalWrites;
|
||||
myNonlocalWrites = nonlocalWrites;
|
||||
myYieldInside = yieldInside;
|
||||
myAsync = isAsync;
|
||||
}
|
||||
|
||||
public Map<String, String> getInputTypes() {
|
||||
return myInputTypes;
|
||||
}
|
||||
|
||||
public String getOutputType() {
|
||||
return myOutputType;
|
||||
}
|
||||
|
||||
public Set<String> getGlobalWrites() {
|
||||
return myGlobalWrites;
|
||||
}
|
||||
|
||||
@@ -19,8 +19,10 @@ import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
|
||||
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
|
||||
import com.jetbrains.python.codeInsight.dataflow.scope.Scope;
|
||||
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
|
||||
import com.jetbrains.python.documentation.PythonDocumentationProvider;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.impl.PyPsiUtils;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
@@ -32,7 +34,10 @@ public final class PyCodeFragmentUtil {
|
||||
|
||||
public static @NotNull PyCodeFragment createCodeFragment(final @NotNull ScopeOwner owner,
|
||||
final @NotNull PsiElement startInScope,
|
||||
final @NotNull PsiElement endInScope) throws CannotCreateCodeFragmentException {
|
||||
final @NotNull PsiElement endInScope,
|
||||
final @Nullable PsiElement singleExpression)
|
||||
throws CannotCreateCodeFragmentException {
|
||||
|
||||
final int start = startInScope.getTextOffset();
|
||||
final int end = endInScope.getTextOffset() + endInScope.getTextLength();
|
||||
final ControlFlow flow = ControlFlowCache.getControlFlow(owner);
|
||||
@@ -50,31 +55,27 @@ public final class PyCodeFragmentUtil {
|
||||
final Set<String> globalWrites = getGlobalWrites(subGraph, owner);
|
||||
final Set<String> nonlocalWrites = getNonlocalWrites(subGraph, owner);
|
||||
|
||||
final TypeEvalContext context = TypeEvalContext.userInitiated(startInScope.getProject(), startInScope.getContainingFile());
|
||||
final Set<String> inputNames = new HashSet<>();
|
||||
final Map<String, String> inputTypeNames = new HashMap<>();
|
||||
for (PsiElement element : filterElementsInScope(getInputElements(subGraph, graph), owner)) {
|
||||
final String name = getName(element);
|
||||
if (name != null) {
|
||||
// Ignore "self" and "cls", they are generated automatically when extracting any method fragment
|
||||
if (resolvesToBoundMethodParameter(element)) {
|
||||
continue;
|
||||
}
|
||||
if (globalWrites.contains(name) || nonlocalWrites.contains(name)) {
|
||||
continue;
|
||||
}
|
||||
inputNames.add(name);
|
||||
// Ignore "self" and "cls", they are generated automatically when extracting any method fragment
|
||||
if (resolvesToBoundMethodParameter(element)) {
|
||||
continue;
|
||||
}
|
||||
addNameReturnType(globalWrites, nonlocalWrites, element, inputNames, inputTypeNames, null, context);
|
||||
}
|
||||
|
||||
final Set<String> outputNames = new HashSet<>();
|
||||
final List<PyType> outputTypes = new ArrayList<>();
|
||||
for (PsiElement element : getOutputElements(subGraph, graph)) {
|
||||
final String name = getName(element);
|
||||
if (name != null) {
|
||||
if (globalWrites.contains(name) || nonlocalWrites.contains(name)) {
|
||||
continue;
|
||||
}
|
||||
outputNames.add(name);
|
||||
}
|
||||
addNameReturnType(globalWrites, nonlocalWrites, element, outputNames, null, outputTypes, context);
|
||||
}
|
||||
if (singleExpression != null) {
|
||||
PyType returnType = getType(singleExpression, context);
|
||||
outputTypes.add(returnType);
|
||||
}
|
||||
final String outputTypeName = getOutputTypeName(startInScope, outputTypes, context);
|
||||
|
||||
final boolean yieldsFound = subGraphAnalysis.yieldExpressions > 0;
|
||||
if (yieldsFound && LanguageLevel.forElement(owner).isPython2()) {
|
||||
@@ -82,7 +83,54 @@ public final class PyCodeFragmentUtil {
|
||||
}
|
||||
final boolean isAsync = owner instanceof PyFunction && ((PyFunction)owner).isAsync();
|
||||
|
||||
return new PyCodeFragment(inputNames, outputNames, globalWrites, nonlocalWrites, subGraphAnalysis.returns > 0, yieldsFound, isAsync);
|
||||
return new PyCodeFragment(inputNames, outputNames, inputTypeNames, outputTypeName, globalWrites, nonlocalWrites,
|
||||
subGraphAnalysis.returns > 0, yieldsFound, isAsync);
|
||||
}
|
||||
|
||||
private static void addNameReturnType(@NotNull Set<String> globalWrites,
|
||||
@NotNull Set<String> nonlocalWrites,
|
||||
@NotNull PsiElement element,
|
||||
@NotNull Set<String> varNames,
|
||||
@Nullable Map<String, String> varTypeNames,
|
||||
@Nullable List<PyType> outputTypes,
|
||||
@NotNull TypeEvalContext context) {
|
||||
String name = getName(element);
|
||||
if (name == null || globalWrites.contains(name) || nonlocalWrites.contains(name) || varNames.contains(name)) {
|
||||
return;
|
||||
}
|
||||
varNames.add(name);
|
||||
PyType type = getType(element, context);
|
||||
if (varTypeNames != null) {
|
||||
String typeName = type == null ? null : PythonDocumentationProvider.getTypeHint(type, context);
|
||||
varTypeNames.put(name, typeName);
|
||||
}
|
||||
if (outputTypes != null) {
|
||||
outputTypes.add(type);
|
||||
}
|
||||
}
|
||||
|
||||
private static @Nullable PyType getType(@NotNull PsiElement element, @NotNull TypeEvalContext context) {
|
||||
if (element instanceof PyTypedElement typedElement) {
|
||||
PyType type = context.getType(typedElement);
|
||||
if (type != null && !(type instanceof PyStructuralType) && !PyNoneTypeKt.isNoneType(type)) {
|
||||
return type;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static @Nullable String getOutputTypeName(@NotNull PsiElement startInScope,
|
||||
@NotNull List<PyType> outputTypes,
|
||||
@NotNull TypeEvalContext context) {
|
||||
|
||||
return switch (outputTypes.size()) {
|
||||
case 0 -> null;
|
||||
case 1 -> PythonDocumentationProvider.getTypeHint(outputTypes.get(0), context);
|
||||
default -> {
|
||||
PyType returnType = PyTupleType.create(startInScope, outputTypes);
|
||||
yield PythonDocumentationProvider.getTypeHint(returnType, context);
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
private static boolean resolvesToBoundMethodParameter(@NotNull PsiElement element) {
|
||||
|
||||
@@ -15,14 +15,17 @@
|
||||
*/
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.openapi.util.text.StringUtil;
|
||||
import com.google.common.base.Strings;
|
||||
import com.intellij.openapi.util.Pair;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.util.ArrayUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.PyNames;
|
||||
import com.jetbrains.python.documentation.docstrings.DocStringParser;
|
||||
import com.jetbrains.python.documentation.docstrings.PyDocstringGenerator;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.*;
|
||||
|
||||
@@ -30,11 +33,12 @@ import java.util.*;
|
||||
public class PyFunctionBuilder {
|
||||
private final String myName;
|
||||
private final List<String> myParameters = new ArrayList<>();
|
||||
private final Map<String, String> myParameterTypes = new HashMap<>();
|
||||
private final List<String> myStatements = new ArrayList<>();
|
||||
private final List<String> myDecorators = new ArrayList<>();
|
||||
private final PsiElement mySettingAnchor;
|
||||
private String myAnnotation = null;
|
||||
private final @NotNull Map<String, String> myDecoratorValues = new HashMap<>();
|
||||
private String myReturnType;
|
||||
private boolean myAsync = false;
|
||||
private PyDocstringGenerator myDocStringGenerator;
|
||||
|
||||
@@ -70,9 +74,9 @@ public class PyFunctionBuilder {
|
||||
}
|
||||
|
||||
/**
|
||||
* @param settingsAnchor any PSI element, presumably in the same file/module where generated function is going to be inserted.
|
||||
* It's needed to detect configured docstring format and Python indentation size and, as result,
|
||||
* generate properly formatted docstring.
|
||||
* @param settingsAnchor any PSI element, presumably in the same file/module where the generated function is going to be inserted.
|
||||
* It's necessary to detect configured docstring format and Python indentation size and, as a result,
|
||||
* generate a properly formatted docstring.
|
||||
*/
|
||||
public PyFunctionBuilder(@NotNull String name, @NotNull PsiElement settingsAnchor) {
|
||||
myName = name;
|
||||
@@ -87,13 +91,17 @@ public class PyFunctionBuilder {
|
||||
* @param name param name
|
||||
* @param type param type
|
||||
*/
|
||||
public @NotNull PyFunctionBuilder parameterWithType(@NotNull String name, @NotNull String type) {
|
||||
parameter(name);
|
||||
public @NotNull PyFunctionBuilder parameterWithDocString(@NotNull String name, @NotNull String type) {
|
||||
parameter(name, type);
|
||||
myDocStringGenerator.withParamTypedByName(name, type);
|
||||
return this;
|
||||
}
|
||||
|
||||
public PyFunctionBuilder parameter(String baseName) {
|
||||
public @NotNull PyFunctionBuilder parameter(@NotNull String baseName) {
|
||||
return parameter(baseName, null);
|
||||
}
|
||||
|
||||
public @NotNull PyFunctionBuilder parameter(@NotNull String baseName, @Nullable String type) {
|
||||
String name = baseName;
|
||||
int uniqueIndex = 0;
|
||||
while (myParameters.contains(name)) {
|
||||
@@ -101,38 +109,41 @@ public class PyFunctionBuilder {
|
||||
name = baseName + uniqueIndex;
|
||||
}
|
||||
myParameters.add(name);
|
||||
if (!Strings.isNullOrEmpty(type)) {
|
||||
myParameterTypes.put(name, type);
|
||||
}
|
||||
return this;
|
||||
}
|
||||
|
||||
public PyFunctionBuilder annotation(String text) {
|
||||
myAnnotation = text;
|
||||
public @NotNull PyFunctionBuilder returnType(String returnType) {
|
||||
myReturnType = returnType;
|
||||
return this;
|
||||
}
|
||||
|
||||
public PyFunctionBuilder makeAsync() {
|
||||
public @NotNull PyFunctionBuilder makeAsync() {
|
||||
myAsync = true;
|
||||
return this;
|
||||
}
|
||||
|
||||
public PyFunctionBuilder statement(String text) {
|
||||
public @NotNull PyFunctionBuilder statement(String text) {
|
||||
myStatements.add(text);
|
||||
return this;
|
||||
}
|
||||
|
||||
public PyFunction addFunction(PsiElement target) {
|
||||
public @NotNull PyFunction addFunction(PsiElement target) {
|
||||
return (PyFunction)target.add(buildFunction());
|
||||
}
|
||||
|
||||
public PyFunction addFunctionAfter(PsiElement target, PsiElement anchor) {
|
||||
public @NotNull PyFunction addFunctionAfter(PsiElement target, PsiElement anchor) {
|
||||
return (PyFunction)target.addAfter(buildFunction(), anchor);
|
||||
}
|
||||
|
||||
public PyFunction buildFunction() {
|
||||
public @NotNull PyFunction buildFunction() {
|
||||
PyElementGenerator generator = PyElementGenerator.getInstance(mySettingAnchor.getProject());
|
||||
return generator.createFromText(LanguageLevel.forElement(mySettingAnchor), PyFunction.class, buildText(generator));
|
||||
}
|
||||
|
||||
private String buildText(PyElementGenerator generator) {
|
||||
private @NotNull String buildText(PyElementGenerator generator) {
|
||||
StringBuilder builder = new StringBuilder();
|
||||
for (String decorator : myDecorators) {
|
||||
final StringBuilder decoratorAppender = builder.append('@' + decorator);
|
||||
@@ -146,18 +157,15 @@ public class PyFunctionBuilder {
|
||||
if (myAsync) {
|
||||
builder.append("async ");
|
||||
}
|
||||
builder.append("def ");
|
||||
builder.append(myName).append("(");
|
||||
builder.append(StringUtil.join(myParameters, ", "));
|
||||
builder.append(")");
|
||||
if (myAnnotation != null) {
|
||||
builder.append(myAnnotation);
|
||||
}
|
||||
List<Pair<@NotNull String, @Nullable String>> parameters =
|
||||
ContainerUtil.map(myParameters, paramName -> Pair.create(paramName, myParameterTypes.get(paramName)));
|
||||
|
||||
appendMethodSignature(builder, myName, parameters, myReturnType);
|
||||
builder.append(":");
|
||||
List<String> statements = myStatements.isEmpty() ? Collections.singletonList(PyNames.PASS) : myStatements;
|
||||
|
||||
final String indent = PyIndentUtil.getIndentFromSettings(mySettingAnchor.getContainingFile());
|
||||
// There was original docstring or some parameters were added via parameterWithType()
|
||||
// There was an original docstring or some parameters were added via parameterWithType()
|
||||
if (!myDocStringGenerator.isNewMode() || myDocStringGenerator.hasParametersToAdd()) {
|
||||
final String docstring = PyIndentUtil.changeIndent(myDocStringGenerator.buildDocString(), true, indent);
|
||||
builder.append('\n').append(indent).append(docstring);
|
||||
@@ -182,4 +190,30 @@ public class PyFunctionBuilder {
|
||||
public void decorate(String decoratorName) {
|
||||
myDecorators.add(decoratorName);
|
||||
}
|
||||
|
||||
public static void appendMethodSignature(@NotNull StringBuilder builder, @NotNull String name,
|
||||
@NotNull List<Pair<@NotNull String, @Nullable String>> parameters,
|
||||
@Nullable String returnTypeName
|
||||
) {
|
||||
builder.append("def ");
|
||||
builder.append(name);
|
||||
builder.append("(");
|
||||
for (int i = 0; i < parameters.size(); i++) {
|
||||
Pair<@NotNull String, @Nullable String> parameter = parameters.get(i);
|
||||
if (i > 0) {
|
||||
builder.append(", ");
|
||||
}
|
||||
builder.append(parameter.first);
|
||||
if (parameter.second != null) {
|
||||
builder.append(": ");
|
||||
builder.append(parameter.second);
|
||||
}
|
||||
}
|
||||
builder.append(")");
|
||||
if (returnTypeName != null) {
|
||||
builder.append(" -> ");
|
||||
builder.append(returnTypeName);
|
||||
builder.append(" ");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -558,6 +558,10 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
|
||||
// null means empty set of possible types, Ref(null) means Any
|
||||
final @Nullable Ref<PyType> combinedType = StreamEx.of(defs)
|
||||
.map(instr -> {
|
||||
if (instr.getElement() == anchor) {
|
||||
// exclude recursive definition (example: type of 'i++' inside a loop)
|
||||
return null;
|
||||
}
|
||||
if (instr instanceof ReadWriteInstruction readWriteInstruction) {
|
||||
return readWriteInstruction.getType(context, anchor);
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
package com.jetbrains.python.refactoring;
|
||||
|
||||
import com.intellij.codeInsight.codeFragment.CodeFragment;
|
||||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.editor.Editor;
|
||||
import com.intellij.openapi.fileTypes.FileType;
|
||||
@@ -9,11 +8,13 @@ import com.intellij.openapi.util.NlsContexts;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiReference;
|
||||
import com.intellij.refactoring.extractMethod.ExtractMethodDecorator;
|
||||
import com.intellij.refactoring.extractMethod.ExtractMethodSettings;
|
||||
import com.intellij.refactoring.extractMethod.ExtractMethodValidator;
|
||||
import com.intellij.refactoring.util.AbstractVariableData;
|
||||
import com.jetbrains.python.codeInsight.codeFragment.PyCodeFragment;
|
||||
import com.jetbrains.python.psi.PyExpression;
|
||||
import com.jetbrains.python.psi.PyFunction;
|
||||
import com.jetbrains.python.refactoring.extractmethod.PyExtractMethodSettings;
|
||||
import com.jetbrains.python.refactoring.extractmethod.PyExtractMethodUtil;
|
||||
import com.jetbrains.python.refactoring.extractmethod.PyVariableData;
|
||||
import com.jetbrains.python.refactoring.introduce.IntroduceOperation;
|
||||
import com.jetbrains.python.refactoring.introduce.IntroduceValidator;
|
||||
import org.jetbrains.annotations.ApiStatus;
|
||||
@@ -46,29 +47,15 @@ public class PyRefactoringUiService {
|
||||
callback.accept(operation);
|
||||
}
|
||||
|
||||
public @Nullable <T> ExtractMethodSettings<T> showExtractMethodDialog(final Project project,
|
||||
final String defaultName,
|
||||
final CodeFragment fragment,
|
||||
final T[] visibilityVariants,
|
||||
final ExtractMethodValidator validator,
|
||||
final ExtractMethodDecorator<T> decorator,
|
||||
final FileType type, String helpId) {
|
||||
return new ExtractMethodSettings<T>() {
|
||||
@Override
|
||||
public @NotNull String getMethodName() {
|
||||
return defaultName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AbstractVariableData @NotNull [] getAbstractVariableData() {
|
||||
return new AbstractVariableData[0];
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable T getVisibility() {
|
||||
return null;
|
||||
}
|
||||
};
|
||||
public @Nullable PyExtractMethodSettings showExtractMethodDialog(final Project project,
|
||||
final String defaultName,
|
||||
final PyCodeFragment fragment,
|
||||
final Object[] visibilityVariants,
|
||||
final ExtractMethodValidator validator,
|
||||
final ExtractMethodDecorator<Object> decorator,
|
||||
final FileType type, String helpId) {
|
||||
return new PyExtractMethodSettings(defaultName, new PyVariableData[0], fragment.getOutputType(),
|
||||
PyExtractMethodUtil.getAddTypeAnnotations(project));
|
||||
}
|
||||
|
||||
public void showPyInlineFunctionDialog(@NotNull Project project,
|
||||
|
||||
@@ -9,7 +9,6 @@ import com.intellij.openapi.util.Couple;
|
||||
import com.intellij.psi.PsiComment;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiFile;
|
||||
import com.intellij.psi.PsiWhiteSpace;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.refactoring.RefactoringActionHandler;
|
||||
import com.intellij.refactoring.RefactoringBundle;
|
||||
@@ -21,9 +20,6 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
|
||||
import com.jetbrains.python.psi.PyCallExpression;
|
||||
import com.jetbrains.python.psi.PyClass;
|
||||
import com.jetbrains.python.psi.PyElement;
|
||||
import com.jetbrains.python.psi.PyExpressionStatement;
|
||||
import com.jetbrains.python.psi.PyKeywordArgument;
|
||||
import com.jetbrains.python.psi.PyStatement;
|
||||
import com.jetbrains.python.psi.impl.PyPsiUtils;
|
||||
import com.jetbrains.python.refactoring.PyRefactoringUtil;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
@@ -102,7 +98,7 @@ public class PyExtractMethodHandler implements RefactoringActionHandler {
|
||||
}
|
||||
final PyCodeFragment fragment;
|
||||
try {
|
||||
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2);
|
||||
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2, null);
|
||||
}
|
||||
catch (CannotCreateCodeFragmentException e) {
|
||||
CommonRefactoringUtil.showErrorHint(project, editor, e.getMessage(),
|
||||
@@ -121,7 +117,7 @@ public class PyExtractMethodHandler implements RefactoringActionHandler {
|
||||
}
|
||||
final PyCodeFragment fragment;
|
||||
try {
|
||||
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2);
|
||||
fragment = PyCodeFragmentUtil.createCodeFragment(owner, element1, element2, expression);
|
||||
}
|
||||
catch (CannotCreateCodeFragmentException e) {
|
||||
CommonRefactoringUtil.showErrorHint(project, editor, e.getMessage(),
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
package com.jetbrains.python.refactoring.extractmethod;
|
||||
|
||||
import com.intellij.refactoring.extractMethod.ExtractMethodSettings;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyExtractMethodSettings implements ExtractMethodSettings<Object> {
|
||||
private final String myMethodName;
|
||||
private final PyVariableData @NotNull [] myVariableData;
|
||||
private final String myReturnTypeName;
|
||||
private final boolean myUseTypeAnnotations;
|
||||
|
||||
public PyExtractMethodSettings(@NotNull String methodName,
|
||||
PyVariableData @NotNull [] variableData,
|
||||
String returnTypeName,
|
||||
boolean useTypeAnnotations) {
|
||||
myMethodName = methodName;
|
||||
myVariableData = variableData;
|
||||
myReturnTypeName = returnTypeName;
|
||||
myUseTypeAnnotations = useTypeAnnotations;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull String getMethodName() {
|
||||
return myMethodName;
|
||||
}
|
||||
|
||||
@Override
|
||||
public PyVariableData @NotNull [] getAbstractVariableData() {
|
||||
return myVariableData;
|
||||
}
|
||||
|
||||
public String getReturnTypeName() {
|
||||
return myReturnTypeName;
|
||||
}
|
||||
|
||||
public boolean isUseTypeAnnotations() {
|
||||
return myUseTypeAnnotations;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable Object getVisibility() {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@
|
||||
package com.jetbrains.python.refactoring.extractmethod;
|
||||
|
||||
import com.intellij.codeInsight.CodeInsightUtilCore;
|
||||
import com.intellij.codeInsight.codeFragment.CodeFragment;
|
||||
import com.intellij.ide.util.PropertiesComponent;
|
||||
import com.intellij.lang.LanguageNamesValidation;
|
||||
import com.intellij.openapi.application.ApplicationManager;
|
||||
import com.intellij.openapi.application.WriteAction;
|
||||
@@ -22,7 +22,6 @@ import com.intellij.refactoring.listeners.RefactoringElementListenerComposite;
|
||||
import com.intellij.refactoring.listeners.RefactoringEventData;
|
||||
import com.intellij.refactoring.listeners.RefactoringEventListener;
|
||||
import com.intellij.refactoring.rename.RenameUtil;
|
||||
import com.intellij.refactoring.util.AbstractVariableData;
|
||||
import com.intellij.refactoring.util.CommonRefactoringUtil;
|
||||
import com.intellij.usageView.UsageInfo;
|
||||
import com.intellij.util.ArrayUtilRt;
|
||||
@@ -51,6 +50,8 @@ import java.util.*;
|
||||
|
||||
public final class PyExtractMethodUtil {
|
||||
public static final String NAME = "extract.method.name";
|
||||
private static final String ADD_TYPE_ANNOTATIONS_VALUE_KEY = "settings.extract.method.addTypeAnnotations";
|
||||
private static final boolean ADD_TYPE_ANNOTATIONS_DEFAULT = true;
|
||||
|
||||
private PyExtractMethodUtil() {
|
||||
}
|
||||
@@ -91,15 +92,13 @@ public final class PyExtractMethodUtil {
|
||||
return pointers;
|
||||
}
|
||||
|
||||
final Pair<String, AbstractVariableData[]> data = getNameAndVariableData(project, fragment, statement1, isClassMethod, isStaticMethod);
|
||||
if (data.first == null || data.second == null) {
|
||||
final PyExtractMethodSettings methodSettings = getNameAndVariableData(project, fragment, statement1, isClassMethod, isStaticMethod);
|
||||
if (methodSettings == null) {
|
||||
return pointers;
|
||||
}
|
||||
|
||||
PsiFile file = statement1.getContainingFile();
|
||||
|
||||
final String methodName = data.first;
|
||||
final AbstractVariableData[] variableData = data.second;
|
||||
final PyVariableData[] variableData = methodSettings.getAbstractVariableData();
|
||||
|
||||
final SimpleDuplicatesFinder finder = new SimpleDuplicatesFinder(statement1, statement2, fragment.getOutputVariables(), variableData);
|
||||
|
||||
@@ -132,14 +131,14 @@ public final class PyExtractMethodUtil {
|
||||
}
|
||||
|
||||
// Generate method
|
||||
final PyFunction generatedMethod = generateMethodFromElements(methodName, variableData, newMethodElements, flags, isAsync);
|
||||
final PyFunction generatedMethod = generateMethodFromElements(methodSettings, newMethodElements, flags, isAsync);
|
||||
final PyFunction insertedMethod = WriteAction.compute(() -> insertGeneratedMethod(statement1, generatedMethod));
|
||||
|
||||
// Process parameters
|
||||
final PsiElement firstElement = elementsRange.get(0);
|
||||
final boolean isMethod = PyPsiUtils.isMethodContext(firstElement);
|
||||
WriteAction.run(() -> {
|
||||
processParameters(project, insertedMethod, variableData, isMethod, isClassMethod, isStaticMethod);
|
||||
processParameters(project, insertedMethod, methodSettings, isMethod, isClassMethod, isStaticMethod);
|
||||
processGlobalWrites(insertedMethod, fragment);
|
||||
processNonlocalWrites(insertedMethod, fragment);
|
||||
});
|
||||
@@ -160,7 +159,7 @@ public final class PyExtractMethodUtil {
|
||||
if (isMethod) {
|
||||
appendSelf(firstElement, builder, isStaticMethod);
|
||||
}
|
||||
builder.append(methodName).append("(");
|
||||
builder.append(methodSettings.getMethodName()).append("(");
|
||||
builder.append(createCallArgsString(variableData)).append(")");
|
||||
final PyFunction function1 = generator.createFromText(languageLevel, PyFunction.class, builder.toString());
|
||||
final PsiElement callElement = function1.getStatementList().getStatements()[0];
|
||||
@@ -313,25 +312,23 @@ public final class PyExtractMethodUtil {
|
||||
final boolean isClassMethod = flags != null && flags.isClassMethod();
|
||||
final boolean isStaticMethod = flags != null && flags.isClassMethod();
|
||||
|
||||
final Pair<String, AbstractVariableData[]> data = getNameAndVariableData(project, fragment, expression, isClassMethod, isStaticMethod);
|
||||
if (data.first == null || data.second == null) {
|
||||
final PyExtractMethodSettings methodSettings = getNameAndVariableData(project, fragment, expression, isClassMethod, isStaticMethod);
|
||||
if (methodSettings == null) {
|
||||
return pointers;
|
||||
}
|
||||
|
||||
final String methodName = data.first;
|
||||
final AbstractVariableData[] variableData = data.second;
|
||||
|
||||
final PyVariableData[] variableData = methodSettings.getAbstractVariableData();
|
||||
final SimpleDuplicatesFinder finder = new SimpleDuplicatesFinder(expression, expression, fragment.getOutputVariables(), variableData);
|
||||
if (fragment.getOutputVariables().isEmpty()) {
|
||||
CommandProcessor.getInstance().executeCommand(project, () -> {
|
||||
// Generate method
|
||||
final boolean isAsync = fragment.isAsync();
|
||||
final PyFunction generatedMethod = generateMethodFromExpression(methodName, variableData, expression, flags, isAsync);
|
||||
final PyFunction generatedMethod = generateMethodFromExpression(methodSettings, expression, flags, isAsync);
|
||||
final PyFunction insertedMethod = WriteAction.compute(() -> insertGeneratedMethod(expression, generatedMethod));
|
||||
|
||||
// Process parameters
|
||||
final boolean isMethod = PyPsiUtils.isMethodContext(expression);
|
||||
WriteAction.run(() -> processParameters(project, insertedMethod, variableData, isMethod, isClassMethod, isStaticMethod));
|
||||
WriteAction.run(() -> processParameters(project, insertedMethod, methodSettings, isMethod, isClassMethod, isStaticMethod));
|
||||
|
||||
// Generating call element
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
@@ -351,11 +348,10 @@ public final class PyExtractMethodUtil {
|
||||
if (isMethod) {
|
||||
appendSelf(expression, builder, isStaticMethod);
|
||||
}
|
||||
builder.append(methodName);
|
||||
builder.append(methodSettings.getMethodName());
|
||||
builder.append("(").append(createCallArgsString(variableData)).append(")");
|
||||
final PyElementGenerator generator = PyElementGenerator.getInstance(project);
|
||||
final PyFunction function1 = generator.createFromText(LanguageLevel.forElement(expression), PyFunction.class,
|
||||
builder.toString());
|
||||
final PyFunction function1 = generator.createFromText(LanguageLevel.forElement(expression), PyFunction.class, builder.toString());
|
||||
final PyElement generated = function1.getStatementList().getStatements()[0];
|
||||
final PsiElement callElement;
|
||||
if (generated instanceof PyReturnStatement) {
|
||||
@@ -451,17 +447,17 @@ public final class PyExtractMethodUtil {
|
||||
}
|
||||
|
||||
// Creates string for call
|
||||
private static @NotNull String createCallArgsString(final AbstractVariableData @NotNull [] variableDatas) {
|
||||
private static @NotNull String createCallArgsString(final PyVariableData @NotNull [] variableDatas) {
|
||||
return StringUtil.join(ContainerUtil.mapNotNull(variableDatas, data -> data.isPassAsParameter() ? data.getOriginalName() : null), ",");
|
||||
}
|
||||
|
||||
private static void processParameters(final @NotNull Project project,
|
||||
final @NotNull PyFunction generatedMethod,
|
||||
final AbstractVariableData @NotNull [] variableData,
|
||||
final @NotNull PyExtractMethodSettings methodSettings,
|
||||
final boolean isMethod,
|
||||
final boolean isClassMethod,
|
||||
final boolean isStaticMethod) {
|
||||
final Map<String, String> map = createMap(variableData);
|
||||
final Map<String, String> map = createMap(methodSettings.getAbstractVariableData());
|
||||
// Rename parameters
|
||||
for (PyParameter parameter : generatedMethod.getParameterList().getParameters()) {
|
||||
final String name = parameter.getName();
|
||||
@@ -487,18 +483,19 @@ public final class PyExtractMethodUtil {
|
||||
else if (isMethod && !isStaticMethod) {
|
||||
builder.parameter("self");
|
||||
}
|
||||
for (AbstractVariableData data : variableData) {
|
||||
for (PyVariableData data : methodSettings.getAbstractVariableData()) {
|
||||
if (data.isPassAsParameter()) {
|
||||
builder.parameter(data.getName());
|
||||
String typeName = methodSettings.isUseTypeAnnotations() ? data.getTypeName() : null;
|
||||
builder.parameter(data.getName(), typeName);
|
||||
}
|
||||
}
|
||||
final PyParameterList pyParameterList = builder.buildFunction().getParameterList();
|
||||
generatedMethod.getParameterList().replace(pyParameterList);
|
||||
}
|
||||
|
||||
private static @NotNull Map<String, String> createMap(final AbstractVariableData @NotNull [] variableData) {
|
||||
private static @NotNull Map<String, String> createMap(final PyVariableData @NotNull [] variableData) {
|
||||
final Map<String, String> map = new HashMap<>();
|
||||
for (AbstractVariableData data : variableData) {
|
||||
for (PyVariableData data : variableData) {
|
||||
map.put(data.getOriginalName(), data.getName());
|
||||
}
|
||||
return map;
|
||||
@@ -534,13 +531,12 @@ public final class PyExtractMethodUtil {
|
||||
return (PyFunction)result;
|
||||
}
|
||||
|
||||
private static @NotNull PyFunction generateMethodFromExpression(final @NotNull String methodName,
|
||||
final AbstractVariableData @NotNull [] variableData,
|
||||
private static @NotNull PyFunction generateMethodFromExpression(final @NotNull PyExtractMethodSettings methodSettings,
|
||||
final @NotNull PsiElement expression,
|
||||
final @Nullable PyUtil.MethodFlags flags, boolean isAsync) {
|
||||
final PyFunctionBuilder builder = new PyFunctionBuilder(methodName, expression);
|
||||
final PyFunctionBuilder builder = new PyFunctionBuilder(methodSettings.getMethodName(), expression);
|
||||
addDecorators(builder, flags);
|
||||
addFakeParameters(builder, variableData);
|
||||
addParametersAndReturnType(builder, methodSettings);
|
||||
if (isAsync) {
|
||||
builder.makeAsync();
|
||||
}
|
||||
@@ -555,19 +551,18 @@ public final class PyExtractMethodUtil {
|
||||
return builder.buildFunction();
|
||||
}
|
||||
|
||||
private static @NotNull PyFunction generateMethodFromElements(final @NotNull String methodName,
|
||||
final AbstractVariableData @NotNull [] variableData,
|
||||
private static @NotNull PyFunction generateMethodFromElements(final @NotNull PyExtractMethodSettings methodSettings,
|
||||
final @NotNull List<PsiElement> elementsRange,
|
||||
@Nullable PyUtil.MethodFlags flags,
|
||||
boolean isAsync) {
|
||||
assert !elementsRange.isEmpty() : "Empty statements list was selected!";
|
||||
|
||||
final PyFunctionBuilder builder = new PyFunctionBuilder(methodName, elementsRange.get(0));
|
||||
final PyFunctionBuilder builder = new PyFunctionBuilder(methodSettings.getMethodName(), elementsRange.get(0));
|
||||
if (isAsync) {
|
||||
builder.makeAsync();
|
||||
}
|
||||
addDecorators(builder, flags);
|
||||
addFakeParameters(builder, variableData);
|
||||
addParametersAndReturnType(builder, methodSettings);
|
||||
final PyFunction method = builder.buildFunction();
|
||||
final PyStatementList statementList = method.getStatementList();
|
||||
for (PsiElement element : elementsRange) {
|
||||
@@ -576,11 +571,11 @@ public final class PyExtractMethodUtil {
|
||||
}
|
||||
statementList.add(element);
|
||||
}
|
||||
// remove last instruction
|
||||
final PsiElement child = statementList.getFirstChild();
|
||||
if (child != null) {
|
||||
child.delete();
|
||||
}
|
||||
// remove last instruction
|
||||
PsiElement last = statementList;
|
||||
while (last != null) {
|
||||
last = last.getLastChild();
|
||||
@@ -602,17 +597,21 @@ public final class PyExtractMethodUtil {
|
||||
}
|
||||
}
|
||||
|
||||
private static void addFakeParameters(@NotNull PyFunctionBuilder builder, AbstractVariableData @NotNull [] variableData) {
|
||||
for (AbstractVariableData data : variableData) {
|
||||
builder.parameter(data.getOriginalName());
|
||||
private static void addParametersAndReturnType(@NotNull PyFunctionBuilder builder, PyExtractMethodSettings methodSettings) {
|
||||
for (PyVariableData data : methodSettings.getAbstractVariableData()) {
|
||||
String typeName = methodSettings.isUseTypeAnnotations() ? data.getTypeName() : null;
|
||||
builder.parameter(data.getOriginalName(), typeName);
|
||||
}
|
||||
if (methodSettings.isUseTypeAnnotations()) {
|
||||
builder.returnType(methodSettings.getReturnTypeName());
|
||||
}
|
||||
}
|
||||
|
||||
private static @NotNull Pair<String, AbstractVariableData[]> getNameAndVariableData(final @NotNull Project project,
|
||||
final @NotNull CodeFragment fragment,
|
||||
final @NotNull PsiElement element,
|
||||
final boolean isClassMethod,
|
||||
final boolean isStaticMethod) {
|
||||
private static @Nullable PyExtractMethodSettings getNameAndVariableData(final @NotNull Project project,
|
||||
final @NotNull PyCodeFragment fragment,
|
||||
final @NotNull PsiElement element,
|
||||
final boolean isClassMethod,
|
||||
final boolean isStaticMethod) {
|
||||
final ExtractMethodValidator validator = new PyExtractMethodValidator(element, project);
|
||||
if (ApplicationManager.getApplication().isUnitTestMode()) {
|
||||
String name = System.getProperty(NAME);
|
||||
@@ -629,54 +628,49 @@ public final class PyExtractMethodUtil {
|
||||
throw new CommonRefactoringUtil.RefactoringErrorHintException(error);
|
||||
}
|
||||
}
|
||||
final List<AbstractVariableData> data = new ArrayList<>();
|
||||
final List<PyVariableData> data = new ArrayList<>();
|
||||
for (String in : fragment.getInputVariables()) {
|
||||
final AbstractVariableData d = new AbstractVariableData();
|
||||
final PyVariableData d = new PyVariableData();
|
||||
d.name = in + "_new";
|
||||
d.originalName = in;
|
||||
d.passAsParameter = true;
|
||||
d.typeName = fragment.getInputTypes().get(in);
|
||||
data.add(d);
|
||||
}
|
||||
return Pair.create(name, data.toArray(new AbstractVariableData[0]));
|
||||
return new PyExtractMethodSettings(name, data.toArray(new PyVariableData[0]), fragment.getOutputType(),
|
||||
getAddTypeAnnotations(project));
|
||||
}
|
||||
|
||||
final boolean isMethod = PyPsiUtils.isMethodContext(element);
|
||||
final ExtractMethodDecorator<Object> decorator = new ExtractMethodDecorator<>() {
|
||||
@Override
|
||||
public @NotNull String createMethodSignature(@NotNull ExtractMethodSettings<Object> settings) {
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
public @NotNull String createMethodSignature(@NotNull ExtractMethodSettings settings) {
|
||||
PyExtractMethodSettings pySettings = (PyExtractMethodSettings)settings;
|
||||
List<Pair<@NotNull String, @Nullable String>> parameters = new ArrayList<>();
|
||||
if (isClassMethod) {
|
||||
builder.append("cls");
|
||||
parameters.add(Pair.create("cls", null));
|
||||
}
|
||||
else if (isMethod && !isStaticMethod) {
|
||||
builder.append("self");
|
||||
parameters.add(Pair.create("self", null));
|
||||
}
|
||||
for (AbstractVariableData variableData : settings.getAbstractVariableData()) {
|
||||
for (PyVariableData variableData : pySettings.getAbstractVariableData()) {
|
||||
if (variableData.passAsParameter) {
|
||||
if (!builder.isEmpty()) {
|
||||
builder.append(", ");
|
||||
}
|
||||
builder.append(variableData.name);
|
||||
parameters.add(Pair.create(variableData.name, pySettings.isUseTypeAnnotations() ? variableData.typeName : null));
|
||||
}
|
||||
}
|
||||
builder.insert(0, "(");
|
||||
builder.insert(0, settings.getMethodName());
|
||||
builder.insert(0, "def ");
|
||||
builder.append(")");
|
||||
final StringBuilder builder = new StringBuilder();
|
||||
PyFunctionBuilder.appendMethodSignature(builder, pySettings.getMethodName(), parameters,
|
||||
pySettings.isUseTypeAnnotations() ? pySettings.getReturnTypeName() : null);
|
||||
|
||||
return builder.toString();
|
||||
}
|
||||
};
|
||||
|
||||
ExtractMethodSettings<?> extractMethodSettings = PyRefactoringUiService.getInstance().showExtractMethodDialog(project, "method_name", fragment,
|
||||
ArrayUtilRt.EMPTY_OBJECT_ARRAY, validator,
|
||||
decorator, PythonFileType.INSTANCE,
|
||||
"python.reference.extractMethod");
|
||||
//return if don`t want to extract method
|
||||
if (extractMethodSettings == null) {
|
||||
return Pair.empty();
|
||||
}
|
||||
PyExtractMethodSettings extractMethodSettings = PyRefactoringUiService.getInstance()
|
||||
.showExtractMethodDialog(project, "method_name", fragment, ArrayUtilRt.EMPTY_OBJECT_ARRAY, validator, decorator,
|
||||
PythonFileType.INSTANCE, "python.reference.extractMethod");
|
||||
|
||||
return Pair.create(extractMethodSettings.getMethodName(), extractMethodSettings.getAbstractVariableData());
|
||||
return extractMethodSettings;
|
||||
}
|
||||
|
||||
public static @NotNull String getRefactoringId() {
|
||||
@@ -727,4 +721,13 @@ public final class PyExtractMethodUtil {
|
||||
return LanguageNamesValidation.isIdentifier(PythonLanguage.getInstance(), name, myProject);
|
||||
}
|
||||
}
|
||||
|
||||
public static void setAddTypeAnnotations(Project project, boolean value) {
|
||||
PropertiesComponent.getInstance(project).setValue(ADD_TYPE_ANNOTATIONS_VALUE_KEY, value, ADD_TYPE_ANNOTATIONS_DEFAULT);
|
||||
}
|
||||
|
||||
public static boolean getAddTypeAnnotations(Project project) {
|
||||
boolean selected = PropertiesComponent.getInstance(project).getBoolean(ADD_TYPE_ANNOTATIONS_VALUE_KEY, ADD_TYPE_ANNOTATIONS_DEFAULT);
|
||||
return selected;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
package com.jetbrains.python.refactoring.extractmethod;
|
||||
|
||||
import com.intellij.refactoring.util.AbstractVariableData;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyVariableData extends AbstractVariableData {
|
||||
public @Nullable String typeName;
|
||||
|
||||
|
||||
public @Nullable String getTypeName() {
|
||||
return typeName;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user