Update overriding method with overloads (PY-30287)

* Select implementation as super method, not an overload
* Copy all its overloads
* Add imports for decorators (e.g. typing.overload)
This commit is contained in:
Semyon Proshev
2018-06-27 14:50:06 +03:00
parent 98582a7785
commit 6eeb3b25be
7 changed files with 150 additions and 18 deletions

View File

@@ -27,6 +27,7 @@ import com.jetbrains.python.psi.impl.PyFunctionBuilder;
import com.jetbrains.python.psi.impl.PyPsiUtils; import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.types.*; import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.pyi.PyiUtil;
import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil; import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil;
import one.util.streamex.StreamEx; import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
@@ -157,12 +158,7 @@ public class PyOverrideImplementUtil {
PyFunction element = null; PyFunction element = null;
for (PyMethodMember newMember : Lists.reverse(newMembers)) { for (PyMethodMember newMember : Lists.reverse(newMembers)) {
final PyFunction baseFunction = (PyFunction)newMember.getPsiElement(); element = writeMember(pyClass, (PyFunction)newMember.getPsiElement(), anchor, implement);
final PyFunctionBuilder builder = buildOverriddenFunction(pyClass, baseFunction, implement);
final PyFunction function = builder.addFunctionAfter(statementList, anchor);
addImports(baseFunction, function);
element = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
} }
PyPsiUtils.removeRedundantPass(statementList); PyPsiUtils.removeRedundantPass(statementList);
@@ -175,6 +171,29 @@ public class PyOverrideImplementUtil {
} }
} }
@Nullable
private static PyFunction writeMember(@NotNull PyClass cls,
@NotNull PyFunction baseFunction,
@Nullable PsiElement anchor,
boolean implement) {
final PyStatementList statementList = cls.getStatementList();
final TypeEvalContext context = TypeEvalContext.userInitiated(cls.getProject(), cls.getContainingFile());
final PyFunction function = buildOverriddenFunction(cls, baseFunction, implement).addFunctionAfter(statementList, anchor);
addImports(baseFunction, function, context);
PyiUtil
.getOverloads(baseFunction, context)
.forEach(
baseOverload -> {
final PyFunction overload = (PyFunction)statementList.addBefore(baseOverload, function);
addImports(baseOverload, overload, context);
}
);
return CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
}
private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass, private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass,
PyFunction baseFunction, PyFunction baseFunction,
boolean implement) { boolean implement) {
@@ -352,8 +371,10 @@ public class PyOverrideImplementUtil {
if (type != null) { if (type != null) {
for (PyFunction function : PyTypeUtil.getMembersOfType(type, PyFunction.class, false, context)) { for (PyFunction function : PyTypeUtil.getMembersOfType(type, PyFunction.class, false, context)) {
final String name = function.getName(); final String name = function.getName();
if (name != null && !functions.containsKey(name)) { if (name != null) {
functions.put(name, function); if (!functions.containsKey(name) || PyiUtil.isOverload(functions.get(name), context) && !PyiUtil.isOverload(function, context)) {
functions.put(name, function);
}
} }
} }
} }
@@ -362,22 +383,19 @@ public class PyOverrideImplementUtil {
} }
/** /**
* Adds imports for type hints in overridden function (PY-18553). * Adds imports for type hints and decorators in overridden function.
* *
* @param baseFunction base function used to resolve types * @param baseFunction base function used to resolve types
* @param function overridden function * @param function overridden function
*/ */
private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function) { private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function, @NotNull TypeEvalContext context) {
final TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(baseFunction.getProject(), baseFunction.getContainingFile());
final UnresolvedExpressionVisitor unresolvedExpressionVisitor = new UnresolvedExpressionVisitor(); final UnresolvedExpressionVisitor unresolvedExpressionVisitor = new UnresolvedExpressionVisitor();
final List<PyAnnotation> annotations = getAnnotations(function, typeEvalContext); getAnnotations(function, context).forEach(annotation -> annotation.accept(unresolvedExpressionVisitor));
annotations.forEach(annotation -> unresolvedExpressionVisitor.visitPyElement(annotation)); getDecorators(function).forEach(decorator -> decorator.accept(unresolvedExpressionVisitor));
final List<PyReferenceExpression> unresolved = unresolvedExpressionVisitor.getUnresolved();
final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolved); final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolvedExpressionVisitor.getUnresolved());
final List<PyAnnotation> baseAnnotations = getAnnotations(baseFunction, typeEvalContext); getAnnotations(baseFunction, context).forEach(annotation -> annotation.accept(resolveExpressionVisitor));
baseAnnotations.forEach(annotation -> resolveExpressionVisitor.visitPyElement(annotation)); getDecorators(baseFunction).forEach(decorator -> decorator.accept(resolveExpressionVisitor));
} }
/** /**
@@ -387,6 +405,7 @@ public class PyOverrideImplementUtil {
* @param typeEvalContext * @param typeEvalContext
* @return * @return
*/ */
@NotNull
private static List<PyAnnotation> getAnnotations(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) { private static List<PyAnnotation> getAnnotations(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
return StreamEx.of(function.getParameters(typeEvalContext)) return StreamEx.of(function.getParameters(typeEvalContext))
.map(PyCallableParameter::getParameter) .map(PyCallableParameter::getParameter)
@@ -398,6 +417,12 @@ public class PyOverrideImplementUtil {
.toList(); .toList();
} }
@NotNull
private static List<PyDecorator> getDecorators(@NotNull PyFunction function) {
final PyDecoratorList decoratorList = function.getDecoratorList();
return decoratorList == null ? Collections.emptyList() : Arrays.asList(decoratorList.getDecorators());
}
/** /**
* Collects unresolved {@link PyReferenceExpression} objects. * Collects unresolved {@link PyReferenceExpression} objects.
*/ */

View File

@@ -0,0 +1,4 @@
from methodWithOverloadsInAnotherFile_parent import Foo
class B(Foo):
<caret>pass

View File

@@ -0,0 +1,14 @@
from typing import overload
from methodWithOverloadsInAnotherFile_parent import Foo
class B(Foo):
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
super().fun(x)

View File

@@ -0,0 +1,11 @@
from typing import overload
class Foo(object):
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
pass

View File

@@ -0,0 +1,14 @@
from typing import overload
class Foo:
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
pass
class B(Foo):
<caret>pass

View File

@@ -0,0 +1,22 @@
from typing import overload
class Foo:
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
pass
class B(Foo):
@overload
def fun(self, s:str) -> str: pass
@overload
def fun(self, i:int) -> int: pass
def fun(self, x):
super().fun(x)

View File

@@ -4,6 +4,7 @@
package com.jetbrains.python; package com.jetbrains.python;
import com.intellij.psi.PsiElement; import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.util.PsiTreeUtil; import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil; import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.codeInsight.override.PyMethodMember; import com.jetbrains.python.codeInsight.override.PyMethodMember;
@@ -228,4 +229,45 @@ public class PyOverrideTest extends PyTestCase {
public void testAsyncMethod() { public void testAsyncMethod() {
runWithLanguageLevel(LanguageLevel.PYTHON36, () -> doTest()); runWithLanguageLevel(LanguageLevel.PYTHON36, () -> doTest());
} }
// PY-30287
public void testMethodWithOverloadsInTheSameFile() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
myFixture.configureByFile("override/" + getTestName(true) + ".py");
PyOverrideImplementUtil.overrideMethods(
myFixture.getEditor(),
getTopLevelClass(1),
Collections.singletonList(new PyMethodMember(getTopLevelClass(0).getMethods()[2])),
false
);
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
}
);
}
// PY-30287
public void testMethodWithOverloadsInAnotherFile() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
final PsiFile[] files = myFixture.configureByFiles(
"override/" + getTestName(true) + ".py",
"override/" + getTestName(true) + "_parent.py"
);
PyOverrideImplementUtil.overrideMethods(
myFixture.getEditor(),
getTopLevelClass(0),
Collections.singletonList(new PyMethodMember(((PyFile)files[1]).getTopLevelClasses().get(0).getMethods()[2])),
false
);
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
}
);
}
} }