Update overriding method with overloads (PY-30287)

* Select implementation as super method, not an overload
* Copy all its overloads
* Add imports for decorators (e.g. typing.overload)
This commit is contained in:
Semyon Proshev
2018-06-27 14:50:06 +03:00
parent 98582a7785
commit 6eeb3b25be
7 changed files with 150 additions and 18 deletions

View File

@@ -27,6 +27,7 @@ import com.jetbrains.python.psi.impl.PyFunctionBuilder;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.pyi.PyiUtil;
import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
@@ -157,12 +158,7 @@ public class PyOverrideImplementUtil {
PyFunction element = null;
for (PyMethodMember newMember : Lists.reverse(newMembers)) {
final PyFunction baseFunction = (PyFunction)newMember.getPsiElement();
final PyFunctionBuilder builder = buildOverriddenFunction(pyClass, baseFunction, implement);
final PyFunction function = builder.addFunctionAfter(statementList, anchor);
addImports(baseFunction, function);
element = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
element = writeMember(pyClass, (PyFunction)newMember.getPsiElement(), anchor, implement);
}
PyPsiUtils.removeRedundantPass(statementList);
@@ -175,6 +171,29 @@ public class PyOverrideImplementUtil {
}
}
@Nullable
private static PyFunction writeMember(@NotNull PyClass cls,
@NotNull PyFunction baseFunction,
@Nullable PsiElement anchor,
boolean implement) {
final PyStatementList statementList = cls.getStatementList();
final TypeEvalContext context = TypeEvalContext.userInitiated(cls.getProject(), cls.getContainingFile());
final PyFunction function = buildOverriddenFunction(cls, baseFunction, implement).addFunctionAfter(statementList, anchor);
addImports(baseFunction, function, context);
PyiUtil
.getOverloads(baseFunction, context)
.forEach(
baseOverload -> {
final PyFunction overload = (PyFunction)statementList.addBefore(baseOverload, function);
addImports(baseOverload, overload, context);
}
);
return CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
}
private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass,
PyFunction baseFunction,
boolean implement) {
@@ -352,8 +371,10 @@ public class PyOverrideImplementUtil {
if (type != null) {
for (PyFunction function : PyTypeUtil.getMembersOfType(type, PyFunction.class, false, context)) {
final String name = function.getName();
if (name != null && !functions.containsKey(name)) {
functions.put(name, function);
if (name != null) {
if (!functions.containsKey(name) || PyiUtil.isOverload(functions.get(name), context) && !PyiUtil.isOverload(function, context)) {
functions.put(name, function);
}
}
}
}
@@ -362,22 +383,19 @@ public class PyOverrideImplementUtil {
}
/**
* Adds imports for type hints in overridden function (PY-18553).
* Adds imports for type hints and decorators in overridden function.
*
* @param baseFunction base function used to resolve types
* @param function overridden function
*/
private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function) {
final TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(baseFunction.getProject(), baseFunction.getContainingFile());
private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function, @NotNull TypeEvalContext context) {
final UnresolvedExpressionVisitor unresolvedExpressionVisitor = new UnresolvedExpressionVisitor();
final List<PyAnnotation> annotations = getAnnotations(function, typeEvalContext);
annotations.forEach(annotation -> unresolvedExpressionVisitor.visitPyElement(annotation));
final List<PyReferenceExpression> unresolved = unresolvedExpressionVisitor.getUnresolved();
getAnnotations(function, context).forEach(annotation -> annotation.accept(unresolvedExpressionVisitor));
getDecorators(function).forEach(decorator -> decorator.accept(unresolvedExpressionVisitor));
final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolved);
final List<PyAnnotation> baseAnnotations = getAnnotations(baseFunction, typeEvalContext);
baseAnnotations.forEach(annotation -> resolveExpressionVisitor.visitPyElement(annotation));
final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolvedExpressionVisitor.getUnresolved());
getAnnotations(baseFunction, context).forEach(annotation -> annotation.accept(resolveExpressionVisitor));
getDecorators(baseFunction).forEach(decorator -> decorator.accept(resolveExpressionVisitor));
}
/**
@@ -387,6 +405,7 @@ public class PyOverrideImplementUtil {
* @param typeEvalContext
* @return
*/
@NotNull
private static List<PyAnnotation> getAnnotations(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
return StreamEx.of(function.getParameters(typeEvalContext))
.map(PyCallableParameter::getParameter)
@@ -398,6 +417,12 @@ public class PyOverrideImplementUtil {
.toList();
}
@NotNull
private static List<PyDecorator> getDecorators(@NotNull PyFunction function) {
final PyDecoratorList decoratorList = function.getDecoratorList();
return decoratorList == null ? Collections.emptyList() : Arrays.asList(decoratorList.getDecorators());
}
/**
* Collects unresolved {@link PyReferenceExpression} objects.
*/

View File

@@ -0,0 +1,4 @@
from methodWithOverloadsInAnotherFile_parent import Foo
class B(Foo):
<caret>pass

View File

@@ -0,0 +1,14 @@
from typing import overload
from methodWithOverloadsInAnotherFile_parent import Foo
class B(Foo):
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
super().fun(x)

View File

@@ -0,0 +1,11 @@
from typing import overload
class Foo(object):
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
pass

View File

@@ -0,0 +1,14 @@
from typing import overload
class Foo:
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
pass
class B(Foo):
<caret>pass

View File

@@ -0,0 +1,22 @@
from typing import overload
class Foo:
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
pass
class B(Foo):
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
super().fun(x)

View File

@@ -4,6 +4,7 @@
package com.jetbrains.python;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.codeInsight.override.PyMethodMember;
@@ -228,4 +229,45 @@ public class PyOverrideTest extends PyTestCase {
public void testAsyncMethod() {
runWithLanguageLevel(LanguageLevel.PYTHON36, () -> doTest());
}
// PY-30287
public void testMethodWithOverloadsInTheSameFile() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
myFixture.configureByFile("override/" + getTestName(true) + ".py");
PyOverrideImplementUtil.overrideMethods(
myFixture.getEditor(),
getTopLevelClass(1),
Collections.singletonList(new PyMethodMember(getTopLevelClass(0).getMethods()[2])),
false
);
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
}
);
}
// PY-30287
public void testMethodWithOverloadsInAnotherFile() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
final PsiFile[] files = myFixture.configureByFiles(
"override/" + getTestName(true) + ".py",
"override/" + getTestName(true) + "_parent.py"
);
PyOverrideImplementUtil.overrideMethods(
myFixture.getEditor(),
getTopLevelClass(0),
Collections.singletonList(new PyMethodMember(((PyFile)files[1]).getTopLevelClasses().get(0).getMethods()[2])),
false
);
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
}
);
}
}