mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-15 02:59:33 +07:00
Update overriding method with overloads (PY-30287)
* Select implementation as super method, not an overload * Copy all its overloads * Add imports for decorators (e.g. typing.overload)
This commit is contained in:
@@ -27,6 +27,7 @@ import com.jetbrains.python.psi.impl.PyFunctionBuilder;
|
||||
import com.jetbrains.python.psi.impl.PyPsiUtils;
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import com.jetbrains.python.pyi.PyiUtil;
|
||||
import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil;
|
||||
import one.util.streamex.StreamEx;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
@@ -157,12 +158,7 @@ public class PyOverrideImplementUtil {
|
||||
|
||||
PyFunction element = null;
|
||||
for (PyMethodMember newMember : Lists.reverse(newMembers)) {
|
||||
final PyFunction baseFunction = (PyFunction)newMember.getPsiElement();
|
||||
final PyFunctionBuilder builder = buildOverriddenFunction(pyClass, baseFunction, implement);
|
||||
final PyFunction function = builder.addFunctionAfter(statementList, anchor);
|
||||
|
||||
addImports(baseFunction, function);
|
||||
element = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
|
||||
element = writeMember(pyClass, (PyFunction)newMember.getPsiElement(), anchor, implement);
|
||||
}
|
||||
|
||||
PyPsiUtils.removeRedundantPass(statementList);
|
||||
@@ -175,6 +171,29 @@ public class PyOverrideImplementUtil {
|
||||
}
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private static PyFunction writeMember(@NotNull PyClass cls,
|
||||
@NotNull PyFunction baseFunction,
|
||||
@Nullable PsiElement anchor,
|
||||
boolean implement) {
|
||||
final PyStatementList statementList = cls.getStatementList();
|
||||
final TypeEvalContext context = TypeEvalContext.userInitiated(cls.getProject(), cls.getContainingFile());
|
||||
|
||||
final PyFunction function = buildOverriddenFunction(cls, baseFunction, implement).addFunctionAfter(statementList, anchor);
|
||||
addImports(baseFunction, function, context);
|
||||
|
||||
PyiUtil
|
||||
.getOverloads(baseFunction, context)
|
||||
.forEach(
|
||||
baseOverload -> {
|
||||
final PyFunction overload = (PyFunction)statementList.addBefore(baseOverload, function);
|
||||
addImports(baseOverload, overload, context);
|
||||
}
|
||||
);
|
||||
|
||||
return CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
|
||||
}
|
||||
|
||||
private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass,
|
||||
PyFunction baseFunction,
|
||||
boolean implement) {
|
||||
@@ -352,8 +371,10 @@ public class PyOverrideImplementUtil {
|
||||
if (type != null) {
|
||||
for (PyFunction function : PyTypeUtil.getMembersOfType(type, PyFunction.class, false, context)) {
|
||||
final String name = function.getName();
|
||||
if (name != null && !functions.containsKey(name)) {
|
||||
functions.put(name, function);
|
||||
if (name != null) {
|
||||
if (!functions.containsKey(name) || PyiUtil.isOverload(functions.get(name), context) && !PyiUtil.isOverload(function, context)) {
|
||||
functions.put(name, function);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -362,22 +383,19 @@ public class PyOverrideImplementUtil {
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds imports for type hints in overridden function (PY-18553).
|
||||
* Adds imports for type hints and decorators in overridden function.
|
||||
*
|
||||
* @param baseFunction base function used to resolve types
|
||||
* @param function overridden function
|
||||
*/
|
||||
private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function) {
|
||||
final TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(baseFunction.getProject(), baseFunction.getContainingFile());
|
||||
|
||||
private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||
final UnresolvedExpressionVisitor unresolvedExpressionVisitor = new UnresolvedExpressionVisitor();
|
||||
final List<PyAnnotation> annotations = getAnnotations(function, typeEvalContext);
|
||||
annotations.forEach(annotation -> unresolvedExpressionVisitor.visitPyElement(annotation));
|
||||
final List<PyReferenceExpression> unresolved = unresolvedExpressionVisitor.getUnresolved();
|
||||
getAnnotations(function, context).forEach(annotation -> annotation.accept(unresolvedExpressionVisitor));
|
||||
getDecorators(function).forEach(decorator -> decorator.accept(unresolvedExpressionVisitor));
|
||||
|
||||
final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolved);
|
||||
final List<PyAnnotation> baseAnnotations = getAnnotations(baseFunction, typeEvalContext);
|
||||
baseAnnotations.forEach(annotation -> resolveExpressionVisitor.visitPyElement(annotation));
|
||||
final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolvedExpressionVisitor.getUnresolved());
|
||||
getAnnotations(baseFunction, context).forEach(annotation -> annotation.accept(resolveExpressionVisitor));
|
||||
getDecorators(baseFunction).forEach(decorator -> decorator.accept(resolveExpressionVisitor));
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -387,6 +405,7 @@ public class PyOverrideImplementUtil {
|
||||
* @param typeEvalContext
|
||||
* @return
|
||||
*/
|
||||
@NotNull
|
||||
private static List<PyAnnotation> getAnnotations(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
|
||||
return StreamEx.of(function.getParameters(typeEvalContext))
|
||||
.map(PyCallableParameter::getParameter)
|
||||
@@ -398,6 +417,12 @@ public class PyOverrideImplementUtil {
|
||||
.toList();
|
||||
}
|
||||
|
||||
@NotNull
|
||||
private static List<PyDecorator> getDecorators(@NotNull PyFunction function) {
|
||||
final PyDecoratorList decoratorList = function.getDecoratorList();
|
||||
return decoratorList == null ? Collections.emptyList() : Arrays.asList(decoratorList.getDecorators());
|
||||
}
|
||||
|
||||
/**
|
||||
* Collects unresolved {@link PyReferenceExpression} objects.
|
||||
*/
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from methodWithOverloadsInAnotherFile_parent import Foo
|
||||
|
||||
class B(Foo):
|
||||
<caret>pass
|
||||
@@ -0,0 +1,14 @@
|
||||
from typing import overload
|
||||
|
||||
from methodWithOverloadsInAnotherFile_parent import Foo
|
||||
|
||||
class B(Foo):
|
||||
@overload
|
||||
def fun(self, s:str) -> str: pass
|
||||
|
||||
@overload
|
||||
def fun(self, i:int) -> int: pass
|
||||
|
||||
def fun(self, x):
|
||||
super().fun(x)
|
||||
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import overload
|
||||
|
||||
class Foo(object):
|
||||
@overload
|
||||
def fun(self, s:str) -> str: pass
|
||||
|
||||
@overload
|
||||
def fun(self, i:int) -> int: pass
|
||||
|
||||
def fun(self, x):
|
||||
pass
|
||||
14
python/testData/override/methodWithOverloadsInTheSameFile.py
Normal file
14
python/testData/override/methodWithOverloadsInTheSameFile.py
Normal file
@@ -0,0 +1,14 @@
|
||||
from typing import overload
|
||||
|
||||
class Foo:
|
||||
@overload
|
||||
def fun(self, s:str) -> str: pass
|
||||
|
||||
@overload
|
||||
def fun(self, i:int) -> int: pass
|
||||
|
||||
def fun(self, x):
|
||||
pass
|
||||
|
||||
class B(Foo):
|
||||
<caret>pass
|
||||
@@ -0,0 +1,22 @@
|
||||
from typing import overload
|
||||
|
||||
class Foo:
|
||||
@overload
|
||||
def fun(self, s:str) -> str: pass
|
||||
|
||||
@overload
|
||||
def fun(self, i:int) -> int: pass
|
||||
|
||||
def fun(self, x):
|
||||
pass
|
||||
|
||||
class B(Foo):
|
||||
@overload
|
||||
def fun(self, s:str) -> str: pass
|
||||
|
||||
@overload
|
||||
def fun(self, i:int) -> int: pass
|
||||
|
||||
def fun(self, x):
|
||||
super().fun(x)
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package com.jetbrains.python;
|
||||
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiFile;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.codeInsight.override.PyMethodMember;
|
||||
@@ -228,4 +229,45 @@ public class PyOverrideTest extends PyTestCase {
|
||||
public void testAsyncMethod() {
|
||||
runWithLanguageLevel(LanguageLevel.PYTHON36, () -> doTest());
|
||||
}
|
||||
|
||||
// PY-30287
|
||||
public void testMethodWithOverloadsInTheSameFile() {
|
||||
runWithLanguageLevel(
|
||||
LanguageLevel.PYTHON35,
|
||||
() -> {
|
||||
myFixture.configureByFile("override/" + getTestName(true) + ".py");
|
||||
|
||||
PyOverrideImplementUtil.overrideMethods(
|
||||
myFixture.getEditor(),
|
||||
getTopLevelClass(1),
|
||||
Collections.singletonList(new PyMethodMember(getTopLevelClass(0).getMethods()[2])),
|
||||
false
|
||||
);
|
||||
|
||||
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// PY-30287
|
||||
public void testMethodWithOverloadsInAnotherFile() {
|
||||
runWithLanguageLevel(
|
||||
LanguageLevel.PYTHON35,
|
||||
() -> {
|
||||
final PsiFile[] files = myFixture.configureByFiles(
|
||||
"override/" + getTestName(true) + ".py",
|
||||
"override/" + getTestName(true) + "_parent.py"
|
||||
);
|
||||
|
||||
PyOverrideImplementUtil.overrideMethods(
|
||||
myFixture.getEditor(),
|
||||
getTopLevelClass(0),
|
||||
Collections.singletonList(new PyMethodMember(((PyFile)files[1]).getTopLevelClasses().get(0).getMethods()[2])),
|
||||
false
|
||||
);
|
||||
|
||||
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user