mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-16 22:51:17 +07:00
Update overriding method with overloads (PY-30287)
* Select implementation as super method, not an overload * Copy all its overloads * Add imports for decorators (e.g. typing.overload)
This commit is contained in:
@@ -27,6 +27,7 @@ import com.jetbrains.python.psi.impl.PyFunctionBuilder;
|
|||||||
import com.jetbrains.python.psi.impl.PyPsiUtils;
|
import com.jetbrains.python.psi.impl.PyPsiUtils;
|
||||||
import com.jetbrains.python.psi.resolve.PyResolveContext;
|
import com.jetbrains.python.psi.resolve.PyResolveContext;
|
||||||
import com.jetbrains.python.psi.types.*;
|
import com.jetbrains.python.psi.types.*;
|
||||||
|
import com.jetbrains.python.pyi.PyiUtil;
|
||||||
import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil;
|
import com.jetbrains.python.refactoring.classes.PyClassRefactoringUtil;
|
||||||
import one.util.streamex.StreamEx;
|
import one.util.streamex.StreamEx;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
@@ -157,12 +158,7 @@ public class PyOverrideImplementUtil {
|
|||||||
|
|
||||||
PyFunction element = null;
|
PyFunction element = null;
|
||||||
for (PyMethodMember newMember : Lists.reverse(newMembers)) {
|
for (PyMethodMember newMember : Lists.reverse(newMembers)) {
|
||||||
final PyFunction baseFunction = (PyFunction)newMember.getPsiElement();
|
element = writeMember(pyClass, (PyFunction)newMember.getPsiElement(), anchor, implement);
|
||||||
final PyFunctionBuilder builder = buildOverriddenFunction(pyClass, baseFunction, implement);
|
|
||||||
final PyFunction function = builder.addFunctionAfter(statementList, anchor);
|
|
||||||
|
|
||||||
addImports(baseFunction, function);
|
|
||||||
element = CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
PyPsiUtils.removeRedundantPass(statementList);
|
PyPsiUtils.removeRedundantPass(statementList);
|
||||||
@@ -175,6 +171,29 @@ public class PyOverrideImplementUtil {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
private static PyFunction writeMember(@NotNull PyClass cls,
|
||||||
|
@NotNull PyFunction baseFunction,
|
||||||
|
@Nullable PsiElement anchor,
|
||||||
|
boolean implement) {
|
||||||
|
final PyStatementList statementList = cls.getStatementList();
|
||||||
|
final TypeEvalContext context = TypeEvalContext.userInitiated(cls.getProject(), cls.getContainingFile());
|
||||||
|
|
||||||
|
final PyFunction function = buildOverriddenFunction(cls, baseFunction, implement).addFunctionAfter(statementList, anchor);
|
||||||
|
addImports(baseFunction, function, context);
|
||||||
|
|
||||||
|
PyiUtil
|
||||||
|
.getOverloads(baseFunction, context)
|
||||||
|
.forEach(
|
||||||
|
baseOverload -> {
|
||||||
|
final PyFunction overload = (PyFunction)statementList.addBefore(baseOverload, function);
|
||||||
|
addImports(baseOverload, overload, context);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
return CodeInsightUtilCore.forcePsiPostprocessAndRestoreElement(function);
|
||||||
|
}
|
||||||
|
|
||||||
private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass,
|
private static PyFunctionBuilder buildOverriddenFunction(PyClass pyClass,
|
||||||
PyFunction baseFunction,
|
PyFunction baseFunction,
|
||||||
boolean implement) {
|
boolean implement) {
|
||||||
@@ -352,8 +371,10 @@ public class PyOverrideImplementUtil {
|
|||||||
if (type != null) {
|
if (type != null) {
|
||||||
for (PyFunction function : PyTypeUtil.getMembersOfType(type, PyFunction.class, false, context)) {
|
for (PyFunction function : PyTypeUtil.getMembersOfType(type, PyFunction.class, false, context)) {
|
||||||
final String name = function.getName();
|
final String name = function.getName();
|
||||||
if (name != null && !functions.containsKey(name)) {
|
if (name != null) {
|
||||||
functions.put(name, function);
|
if (!functions.containsKey(name) || PyiUtil.isOverload(functions.get(name), context) && !PyiUtil.isOverload(function, context)) {
|
||||||
|
functions.put(name, function);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -362,22 +383,19 @@ public class PyOverrideImplementUtil {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Adds imports for type hints in overridden function (PY-18553).
|
* Adds imports for type hints and decorators in overridden function.
|
||||||
*
|
*
|
||||||
* @param baseFunction base function used to resolve types
|
* @param baseFunction base function used to resolve types
|
||||||
* @param function overridden function
|
* @param function overridden function
|
||||||
*/
|
*/
|
||||||
private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function) {
|
private static void addImports(@NotNull PyFunction baseFunction, @NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||||
final TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(baseFunction.getProject(), baseFunction.getContainingFile());
|
|
||||||
|
|
||||||
final UnresolvedExpressionVisitor unresolvedExpressionVisitor = new UnresolvedExpressionVisitor();
|
final UnresolvedExpressionVisitor unresolvedExpressionVisitor = new UnresolvedExpressionVisitor();
|
||||||
final List<PyAnnotation> annotations = getAnnotations(function, typeEvalContext);
|
getAnnotations(function, context).forEach(annotation -> annotation.accept(unresolvedExpressionVisitor));
|
||||||
annotations.forEach(annotation -> unresolvedExpressionVisitor.visitPyElement(annotation));
|
getDecorators(function).forEach(decorator -> decorator.accept(unresolvedExpressionVisitor));
|
||||||
final List<PyReferenceExpression> unresolved = unresolvedExpressionVisitor.getUnresolved();
|
|
||||||
|
|
||||||
final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolved);
|
final ResolveExpressionVisitor resolveExpressionVisitor = new ResolveExpressionVisitor(unresolvedExpressionVisitor.getUnresolved());
|
||||||
final List<PyAnnotation> baseAnnotations = getAnnotations(baseFunction, typeEvalContext);
|
getAnnotations(baseFunction, context).forEach(annotation -> annotation.accept(resolveExpressionVisitor));
|
||||||
baseAnnotations.forEach(annotation -> resolveExpressionVisitor.visitPyElement(annotation));
|
getDecorators(baseFunction).forEach(decorator -> decorator.accept(resolveExpressionVisitor));
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -387,6 +405,7 @@ public class PyOverrideImplementUtil {
|
|||||||
* @param typeEvalContext
|
* @param typeEvalContext
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
|
@NotNull
|
||||||
private static List<PyAnnotation> getAnnotations(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
|
private static List<PyAnnotation> getAnnotations(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
|
||||||
return StreamEx.of(function.getParameters(typeEvalContext))
|
return StreamEx.of(function.getParameters(typeEvalContext))
|
||||||
.map(PyCallableParameter::getParameter)
|
.map(PyCallableParameter::getParameter)
|
||||||
@@ -398,6 +417,12 @@ public class PyOverrideImplementUtil {
|
|||||||
.toList();
|
.toList();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@NotNull
|
||||||
|
private static List<PyDecorator> getDecorators(@NotNull PyFunction function) {
|
||||||
|
final PyDecoratorList decoratorList = function.getDecoratorList();
|
||||||
|
return decoratorList == null ? Collections.emptyList() : Arrays.asList(decoratorList.getDecorators());
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Collects unresolved {@link PyReferenceExpression} objects.
|
* Collects unresolved {@link PyReferenceExpression} objects.
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from methodWithOverloadsInAnotherFile_parent import Foo
|
||||||
|
|
||||||
|
class B(Foo):
|
||||||
|
<caret>pass
|
||||||
@@ -0,0 +1,14 @@
|
|||||||
|
from typing import overload
|
||||||
|
|
||||||
|
from methodWithOverloadsInAnotherFile_parent import Foo
|
||||||
|
|
||||||
|
class B(Foo):
|
||||||
|
@overload
|
||||||
|
def fun(self, s:str) -> str: pass
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def fun(self, i:int) -> int: pass
|
||||||
|
|
||||||
|
def fun(self, x):
|
||||||
|
super().fun(x)
|
||||||
|
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
from typing import overload
|
||||||
|
|
||||||
|
class Foo(object):
|
||||||
|
@overload
|
||||||
|
def fun(self, s:str) -> str: pass
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def fun(self, i:int) -> int: pass
|
||||||
|
|
||||||
|
def fun(self, x):
|
||||||
|
pass
|
||||||
14
python/testData/override/methodWithOverloadsInTheSameFile.py
Normal file
14
python/testData/override/methodWithOverloadsInTheSameFile.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
from typing import overload
|
||||||
|
|
||||||
|
class Foo:
|
||||||
|
@overload
|
||||||
|
def fun(self, s:str) -> str: pass
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def fun(self, i:int) -> int: pass
|
||||||
|
|
||||||
|
def fun(self, x):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class B(Foo):
|
||||||
|
<caret>pass
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
from typing import overload
|
||||||
|
|
||||||
|
class Foo:
|
||||||
|
@overload
|
||||||
|
def fun(self, s:str) -> str: pass
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def fun(self, i:int) -> int: pass
|
||||||
|
|
||||||
|
def fun(self, x):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class B(Foo):
|
||||||
|
@overload
|
||||||
|
def fun(self, s:str) -> str: pass
|
||||||
|
|
||||||
|
@overload
|
||||||
|
def fun(self, i:int) -> int: pass
|
||||||
|
|
||||||
|
def fun(self, x):
|
||||||
|
super().fun(x)
|
||||||
|
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
package com.jetbrains.python;
|
package com.jetbrains.python;
|
||||||
|
|
||||||
import com.intellij.psi.PsiElement;
|
import com.intellij.psi.PsiElement;
|
||||||
|
import com.intellij.psi.PsiFile;
|
||||||
import com.intellij.psi.util.PsiTreeUtil;
|
import com.intellij.psi.util.PsiTreeUtil;
|
||||||
import com.intellij.util.containers.ContainerUtil;
|
import com.intellij.util.containers.ContainerUtil;
|
||||||
import com.jetbrains.python.codeInsight.override.PyMethodMember;
|
import com.jetbrains.python.codeInsight.override.PyMethodMember;
|
||||||
@@ -228,4 +229,45 @@ public class PyOverrideTest extends PyTestCase {
|
|||||||
public void testAsyncMethod() {
|
public void testAsyncMethod() {
|
||||||
runWithLanguageLevel(LanguageLevel.PYTHON36, () -> doTest());
|
runWithLanguageLevel(LanguageLevel.PYTHON36, () -> doTest());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PY-30287
|
||||||
|
public void testMethodWithOverloadsInTheSameFile() {
|
||||||
|
runWithLanguageLevel(
|
||||||
|
LanguageLevel.PYTHON35,
|
||||||
|
() -> {
|
||||||
|
myFixture.configureByFile("override/" + getTestName(true) + ".py");
|
||||||
|
|
||||||
|
PyOverrideImplementUtil.overrideMethods(
|
||||||
|
myFixture.getEditor(),
|
||||||
|
getTopLevelClass(1),
|
||||||
|
Collections.singletonList(new PyMethodMember(getTopLevelClass(0).getMethods()[2])),
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// PY-30287
|
||||||
|
public void testMethodWithOverloadsInAnotherFile() {
|
||||||
|
runWithLanguageLevel(
|
||||||
|
LanguageLevel.PYTHON35,
|
||||||
|
() -> {
|
||||||
|
final PsiFile[] files = myFixture.configureByFiles(
|
||||||
|
"override/" + getTestName(true) + ".py",
|
||||||
|
"override/" + getTestName(true) + "_parent.py"
|
||||||
|
);
|
||||||
|
|
||||||
|
PyOverrideImplementUtil.overrideMethods(
|
||||||
|
myFixture.getEditor(),
|
||||||
|
getTopLevelClass(0),
|
||||||
|
Collections.singletonList(new PyMethodMember(((PyFile)files[1]).getTopLevelClasses().get(0).getMethods()[2])),
|
||||||
|
false
|
||||||
|
);
|
||||||
|
|
||||||
|
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
|
||||||
|
}
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user