From 6eeb3b25bede59649f11814a9133d0192b69561e Mon Sep 17 00:00:00 2001 From: Semyon Proshev Date: Wed, 27 Jun 2018 14:50:06 +0300 Subject: [PATCH] Update overriding method with overloads (PY-30287) * Select implementation as super method, not an overload * Copy all its overloads * Add imports for decorators (e.g. typing.overload) --- .../override/PyOverrideImplementUtil.java | 61 +++++++++++++------ .../methodWithOverloadsInAnotherFile.py | 4 ++ .../methodWithOverloadsInAnotherFile_after.py | 14 +++++ ...methodWithOverloadsInAnotherFile_parent.py | 11 ++++ .../methodWithOverloadsInTheSameFile.py | 14 +++++ .../methodWithOverloadsInTheSameFile_after.py | 22 +++++++ .../com/jetbrains/python/PyOverrideTest.java | 42 +++++++++++++ 7 files changed, 150 insertions(+), 18 deletions(-) create mode 100644 python/testData/override/methodWithOverloadsInAnotherFile.py create mode 100644 python/testData/override/methodWithOverloadsInAnotherFile_after.py create mode 100644 python/testData/override/methodWithOverloadsInAnotherFile_parent.py create mode 100644 python/testData/override/methodWithOverloadsInTheSameFile.py create mode 100644 python/testData/override/methodWithOverloadsInTheSameFile_after.py diff --git a/python/src/com/jetbrains/python/codeInsight/override/PyOverrideImplementUtil.java b/python/src/com/jetbrains/python/codeInsight/override/PyOverrideImplementUtil.java index ac5aebcc6745..587a220fd831 100644 --- a/python/src/com/jetbrains/python/codeInsight/override/PyOverrideImplementUtil.java +++ b/python/src/com/jetbrains/python/codeInsight/override/PyOverrideImplementUtil.java @@ -27,6 +27,7 @@ import com.jetbrains.python.psi.impl.PyFunctionBuilder; import com.jetbrains.python.psi.impl.PyPsiUtils; import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.types.*; +import com.jetbrains.python.pyi.PyiUtil; import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil; import one.util.streamex.StreamEx; import org.jetbrains.annotations.NotNull; @@ -157,12 +158,7 @@ public class PyOverrideImplementUtil { PyFunction element = null; for (PyMethodMember newMember : Lists.reverse(newMembers)) { - final PyFunction baseFunction = (PyFunction)newMember.getPsiElement(); - final PyFunctionBuilder builder = buildOverriddenFunction(pyClass, baseFunction, implement); - final PyFunction function = builder.addFunctionAfter(statementList, anchor); - - addImports(baseFunction, function); - element = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function); + element = writeMember(pyClass, (PyFunction)newMember.getPsiElement(), anchor, implement); } PyPsiUtils.removeRedundantPass(statementList); @@ -175,6 +171,29 @@ public class PyOverrideImplementUtil { } } + @Nullable + private static PyFunction writeMember(@NotNull PyClass cls, + @NotNull PyFunction baseFunction, + @Nullable PsiElement anchor, + boolean implement) { + final PyStatementList statementList = cls.getStatementList(); + final TypeEvalContext context = TypeEvalContext.userInitiated(cls.getProject(), cls.getContainingFile()); + + final PyFunction function = buildOverriddenFunction(cls, baseFunction, implement).addFunctionAfter(statementList, anchor); + addImports(baseFunction, function, context); + + PyiUtil + .getOverloads(baseFunction, context) + .forEach( + baseOverload -> { + final PyFunction overload = (PyFunction)statementList.addBefore(baseOverload, function); + addImports(baseOverload, overload, context); + } + ); + + return CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function); + } + private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass, PyFunction baseFunction, boolean implement) { @@ -352,8 +371,10 @@ public class PyOverrideImplementUtil { if (type != null) { for (PyFunction function : PyTypeUtil.getMembersOfType(type, PyFunction.class, false, context)) { final String name = function.getName(); - if (name != null && !functions.containsKey(name)) { - functions.put(name, function); + if (name != null) { + if (!functions.containsKey(name) || PyiUtil.isOverload(functions.get(name), context) && !PyiUtil.isOverload(function, context)) { + functions.put(name, function); + } } } } @@ -362,22 +383,19 @@ public class PyOverrideImplementUtil { } /** - * Adds imports for type hints in overridden function (PY-18553). + * Adds imports for type hints and decorators in overridden function. * * @param baseFunction base function used to resolve types * @param function overridden function */ - private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function) { - final TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(baseFunction.getProject(), baseFunction.getContainingFile()); - + private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function, @NotNull TypeEvalContext context) { final UnresolvedExpressionVisitor unresolvedExpressionVisitor = new UnresolvedExpressionVisitor(); - final List annotations = getAnnotations(function, typeEvalContext); - annotations.forEach(annotation -> unresolvedExpressionVisitor.visitPyElement(annotation)); - final List unresolved = unresolvedExpressionVisitor.getUnresolved(); + getAnnotations(function, context).forEach(annotation -> annotation.accept(unresolvedExpressionVisitor)); + getDecorators(function).forEach(decorator -> decorator.accept(unresolvedExpressionVisitor)); - final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolved); - final List baseAnnotations = getAnnotations(baseFunction, typeEvalContext); - baseAnnotations.forEach(annotation -> resolveExpressionVisitor.visitPyElement(annotation)); + final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolvedExpressionVisitor.getUnresolved()); + getAnnotations(baseFunction, context).forEach(annotation -> annotation.accept(resolveExpressionVisitor)); + getDecorators(baseFunction).forEach(decorator -> decorator.accept(resolveExpressionVisitor)); } /** @@ -387,6 +405,7 @@ public class PyOverrideImplementUtil { * @param typeEvalContext * @return */ + @NotNull private static List getAnnotations(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) { return StreamEx.of(function.getParameters(typeEvalContext)) .map(PyCallableParameter::getParameter) @@ -398,6 +417,12 @@ public class PyOverrideImplementUtil { .toList(); } + @NotNull + private static List getDecorators(@NotNull PyFunction function) { + final PyDecoratorList decoratorList = function.getDecoratorList(); + return decoratorList == null ? Collections.emptyList() : Arrays.asList(decoratorList.getDecorators()); + } + /** * Collects unresolved {@link PyReferenceExpression} objects. */ diff --git a/python/testData/override/methodWithOverloadsInAnotherFile.py b/python/testData/override/methodWithOverloadsInAnotherFile.py new file mode 100644 index 000000000000..e3520773baa7 --- /dev/null +++ b/python/testData/override/methodWithOverloadsInAnotherFile.py @@ -0,0 +1,4 @@ +from methodWithOverloadsInAnotherFile_parent import Foo + +class B(Foo): + pass \ No newline at end of file diff --git a/python/testData/override/methodWithOverloadsInAnotherFile_after.py b/python/testData/override/methodWithOverloadsInAnotherFile_after.py new file mode 100644 index 000000000000..477bc7e76e4f --- /dev/null +++ b/python/testData/override/methodWithOverloadsInAnotherFile_after.py @@ -0,0 +1,14 @@ +from typing import overload + +from methodWithOverloadsInAnotherFile_parent import Foo + +class B(Foo): + @overload + def fun(self, s:str) -> str: pass + + @overload + def fun(self, i:int) -> int: pass + + def fun(self, x): + super().fun(x) + diff --git a/python/testData/override/methodWithOverloadsInAnotherFile_parent.py b/python/testData/override/methodWithOverloadsInAnotherFile_parent.py new file mode 100644 index 000000000000..b40f1d5bcb2d --- /dev/null +++ b/python/testData/override/methodWithOverloadsInAnotherFile_parent.py @@ -0,0 +1,11 @@ +from typing import overload + +class Foo(object): + @overload + def fun(self, s:str) -> str: pass + + @overload + def fun(self, i:int) -> int: pass + + def fun(self, x): + pass \ No newline at end of file diff --git a/python/testData/override/methodWithOverloadsInTheSameFile.py b/python/testData/override/methodWithOverloadsInTheSameFile.py new file mode 100644 index 000000000000..70f8f8ba8877 --- /dev/null +++ b/python/testData/override/methodWithOverloadsInTheSameFile.py @@ -0,0 +1,14 @@ +from typing import overload + +class Foo: + @overload + def fun(self, s:str) -> str: pass + + @overload + def fun(self, i:int) -> int: pass + + def fun(self, x): + pass + +class B(Foo): + pass \ No newline at end of file diff --git a/python/testData/override/methodWithOverloadsInTheSameFile_after.py b/python/testData/override/methodWithOverloadsInTheSameFile_after.py new file mode 100644 index 000000000000..45d7ae80fab8 --- /dev/null +++ b/python/testData/override/methodWithOverloadsInTheSameFile_after.py @@ -0,0 +1,22 @@ +from typing import overload + +class Foo: + @overload + def fun(self, s:str) -> str: pass + + @overload + def fun(self, i:int) -> int: pass + + def fun(self, x): + pass + +class B(Foo): + @overload + def fun(self, s:str) -> str: pass + + @overload + def fun(self, i:int) -> int: pass + + def fun(self, x): + super().fun(x) + diff --git a/python/testSrc/com/jetbrains/python/PyOverrideTest.java b/python/testSrc/com/jetbrains/python/PyOverrideTest.java index e8ae6540b8aa..f030892d5e2d 100644 --- a/python/testSrc/com/jetbrains/python/PyOverrideTest.java +++ b/python/testSrc/com/jetbrains/python/PyOverrideTest.java @@ -4,6 +4,7 @@ package com.jetbrains.python; import com.intellij.psi.PsiElement; +import com.intellij.psi.PsiFile; import com.intellij.psi.util.PsiTreeUtil; import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.codeInsight.override.PyMethodMember; @@ -228,4 +229,45 @@ public class PyOverrideTest extends PyTestCase { public void testAsyncMethod() { runWithLanguageLevel(LanguageLevel.PYTHON36, () -> doTest()); } + + // PY-30287 + public void testMethodWithOverloadsInTheSameFile() { + runWithLanguageLevel( + LanguageLevel.PYTHON35, + () -> { + myFixture.configureByFile("override/" + getTestName(true) + ".py"); + + PyOverrideImplementUtil.overrideMethods( + myFixture.getEditor(), + getTopLevelClass(1), + Collections.singletonList(new PyMethodMember(getTopLevelClass(0).getMethods()[2])), + false + ); + + myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true); + } + ); + } + + // PY-30287 + public void testMethodWithOverloadsInAnotherFile() { + runWithLanguageLevel( + LanguageLevel.PYTHON35, + () -> { + final PsiFile[] files = myFixture.configureByFiles( + "override/" + getTestName(true) + ".py", + "override/" + getTestName(true) + "_parent.py" + ); + + PyOverrideImplementUtil.overrideMethods( + myFixture.getEditor(), + getTopLevelClass(0), + Collections.singletonList(new PyMethodMember(((PyFile)files[1]).getTopLevelClasses().get(0).getMethods()[2])), + false + ); + + myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true); + } + ); + } }