diff --git a/python/python-psi-impl/resources/META-INF/PythonPsiImpl.xml b/python/python-psi-impl/resources/META-INF/PythonPsiImpl.xml index d509385ce941..c5e2c8534ba8 100644 --- a/python/python-psi-impl/resources/META-INF/PythonPsiImpl.xml +++ b/python/python-psi-impl/resources/META-INF/PythonPsiImpl.xml @@ -525,9 +525,13 @@ + - + + diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/decorator/PyFunctoolsWrapsDecoratedFunctionTypeProvider.kt b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/decorator/PyFunctoolsWrapsDecoratedFunctionTypeProvider.kt new file mode 100644 index 000000000000..975b4be38daa --- /dev/null +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/decorator/PyFunctoolsWrapsDecoratedFunctionTypeProvider.kt @@ -0,0 +1,63 @@ +package com.jetbrains.python.codeInsight.decorator + +import com.intellij.openapi.util.Ref +import com.intellij.psi.PsiElement +import com.intellij.psi.util.QualifiedName +import com.intellij.util.containers.ContainerUtil +import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil +import com.jetbrains.python.psi.PyCallable +import com.jetbrains.python.psi.PyFunction +import com.jetbrains.python.psi.PyKnownDecoratorUtil +import com.jetbrains.python.psi.PyUtil +import com.jetbrains.python.psi.impl.StubAwareComputation +import com.jetbrains.python.psi.resolve.PyResolveContext +import com.jetbrains.python.psi.resolve.PyResolveUtil +import com.jetbrains.python.psi.stubs.PyFunctoolsWrapsDecoratorStub +import com.jetbrains.python.psi.types.PyType +import com.jetbrains.python.psi.types.PyTypeProviderBase +import com.jetbrains.python.psi.types.TypeEvalContext + +/** + * Infer type for reference of a function decorated with 'functools.wraps'. + * Has to be used before {@link com.jetbrains.python.codeInsight.decorator.PyDecoratedFunctionTypeProvider} + */ +class PyFunctoolsWrapsDecoratedFunctionTypeProvider : PyTypeProviderBase() { + override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): Ref? { + if (referenceTarget !is PyFunction) return null + val wrappedFunction = ContainerUtil.findInstance(resolveWrapped(referenceTarget, context), PyFunction::class.java) ?: return null + return Ref.create(context.getType(wrappedFunction)) + } + + override fun getCallableType(callable: PyCallable, context: TypeEvalContext): PyType? { + return Ref.deref(getReferenceType(callable, context, null)) + } + + private fun resolveWrapped(function: PyFunction, context: TypeEvalContext): List { + val decorator = function.decoratorList?.decorators?.find { + val qName = it.qualifiedName + qName != null && PyKnownDecoratorUtil.asKnownDecorators(qName).contains(PyKnownDecoratorUtil.KnownDecorator.FUNCTOOLS_WRAPS) + } ?: return emptyList() + return StubAwareComputation.on(decorator) + .withCustomStub { it.getCustomStub(PyFunctoolsWrapsDecoratorStub::class.java) } + .overStub { + if (it == null) return@overStub emptyList() + var scopeOwner = ScopeUtil.getScopeOwner(decorator) + val wrappedQName = QualifiedName.fromDottedString(it.wrapped) + val resolved = mutableListOf() + while (scopeOwner != null) { + resolved.addAll(PyResolveUtil.resolveQualifiedNameInScope(wrappedQName, scopeOwner, context)) + scopeOwner = ScopeUtil.getScopeOwner(scopeOwner) + } + resolved + } + .overAst { + val wrappedExpr = it.argumentList?.getValueExpressionForParam(PyKnownDecoratorUtil.FunctoolsWrapsParameters.WRAPPED) + if (wrappedExpr == null) + emptyList() + else + PyUtil.multiResolveTopPriority(wrappedExpr, PyResolveContext.defaultContext(context)) + } + .withStubBuilder(PyFunctoolsWrapsDecoratorStub::create) + .compute(context) + } +} \ No newline at end of file diff --git a/python/python-psi-impl/src/com/jetbrains/python/documentation/PyDocumentationBuilder.java b/python/python-psi-impl/src/com/jetbrains/python/documentation/PyDocumentationBuilder.java index 49fccda72d82..cb677916aef0 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/documentation/PyDocumentationBuilder.java +++ b/python/python-psi-impl/src/com/jetbrains/python/documentation/PyDocumentationBuilder.java @@ -18,16 +18,14 @@ import com.intellij.util.containers.FactoryMap; import com.jetbrains.python.*; import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil; +import com.jetbrains.python.codeInsight.decorator.PyFunctoolsWrapsDecoratedFunctionTypeProvider; import com.jetbrains.python.documentation.docstrings.DocStringUtil; import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.impl.PyBuiltinCache; import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.resolve.QualifiedNameFinder; import com.jetbrains.python.psi.resolve.QualifiedResolveResult; -import com.jetbrains.python.psi.types.PyCallableParameter; -import com.jetbrains.python.psi.types.PyClassType; -import com.jetbrains.python.psi.types.PyType; -import com.jetbrains.python.psi.types.TypeEvalContext; +import com.jetbrains.python.psi.types.*; import com.jetbrains.python.pyi.PyiUtil; import com.jetbrains.python.toolbox.Maybe; import one.util.streamex.StreamEx; @@ -595,6 +593,16 @@ public class PyDocumentationBuilder { return resolved; } } + // Return wrapped function for functools.wraps decorated function + if (myElement instanceof PyFunction function) { + PyType type = new PyFunctoolsWrapsDecoratedFunctionTypeProvider().getCallableType(function, myContext); + if (type instanceof PyCallableType callableType) { + PyCallable callable = callableType.getCallable(); + if (callable != null) { + return callable; + } + } + } return myElement; } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/PyFileElementType.java b/python/python-psi-impl/src/com/jetbrains/python/psi/PyFileElementType.java index a339b0e2977f..f394f56cb437 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/PyFileElementType.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/PyFileElementType.java @@ -60,7 +60,7 @@ public class PyFileElementType extends IStubFileElementType { @Override public int getStubVersion() { // Don't forget to update versions of indexes that use the updated stub-based elements - return 91; + return 92; } @Nullable diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/PyKnownDecoratorUtil.java b/python/python-psi-impl/src/com/jetbrains/python/psi/PyKnownDecoratorUtil.java index d0fe3c99bef1..d5d754cda56e 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/PyKnownDecoratorUtil.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/PyKnownDecoratorUtil.java @@ -5,6 +5,7 @@ import com.intellij.psi.PsiElement; import com.intellij.psi.PsiFile; import com.intellij.psi.util.QualifiedName; import com.intellij.util.containers.ContainerUtil; +import com.jetbrains.python.FunctionParameter; import com.jetbrains.python.PyNames; import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; import com.jetbrains.python.psi.resolve.PyResolveContext; @@ -12,7 +13,9 @@ import com.jetbrains.python.psi.resolve.PyResolveUtil; import com.jetbrains.python.psi.types.TypeEvalContext; import com.jetbrains.python.pyi.PyiFile; import one.util.streamex.StreamEx; +import org.jetbrains.annotations.ApiStatus; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; import java.util.*; @@ -178,12 +181,18 @@ public final class PyKnownDecoratorUtil { .toImmutableList(); } else { - // The method might have been called during building of PSI stub indexes. Thus, we can't leave this file's boundaries. - // TODO Use proper local resolve to imported names here - return Collections.unmodifiableList(BY_SHORT_NAME.getOrDefault(qualifiedName.getLastComponent(), Collections.emptyList())); + return asKnownDecorators(qualifiedName); } } + @ApiStatus.Internal + @NotNull + public static List asKnownDecorators(@NotNull QualifiedName qualifiedName) { + // The method might have been called during building of PSI stub indexes. Thus, we can't leave this file's boundaries. + // TODO Use proper local resolve to imported names here + return Collections.unmodifiableList(BY_SHORT_NAME.getOrDefault(qualifiedName.getLastComponent(), Collections.emptyList())); + } + /** * Check that given element has any non-standard (read "unreliable") decorators. * @@ -270,4 +279,26 @@ public final class PyKnownDecoratorUtil { ? decorators.isEmpty() : decoratorList.getDecorators().length == StreamEx.of(decorators).groupingBy(KnownDecorator::getShortName).size(); } + + public enum FunctoolsWrapsParameters implements FunctionParameter { + WRAPPED(0, "wrapped"); + + private final int myPosition; + private final String myName; + + FunctoolsWrapsParameters(int position, @NotNull String name) { + myPosition = position; + myName = name; + } + + @Override + public int getPosition() { + return myPosition; + } + + @Override + public @NotNull String getName() { + return myName; + } + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/stubs/PyFunctoolsWrapsDecoratorStubType.kt b/python/python-psi-impl/src/com/jetbrains/python/psi/stubs/PyFunctoolsWrapsDecoratorStubType.kt new file mode 100644 index 000000000000..dedaae5946ae --- /dev/null +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/stubs/PyFunctoolsWrapsDecoratorStubType.kt @@ -0,0 +1,38 @@ +package com.jetbrains.python.psi.stubs + +import com.intellij.psi.stubs.StubInputStream +import com.intellij.psi.stubs.StubOutputStream +import com.jetbrains.python.psi.PyDecorator +import com.jetbrains.python.psi.PyKnownDecoratorUtil +import com.jetbrains.python.psi.PyReferenceExpression +import com.jetbrains.python.psi.impl.stubs.PyCustomDecoratorStub +import com.jetbrains.python.psi.impl.stubs.PyCustomDecoratorStubType + +class PyFunctoolsWrapsDecoratorStubType : PyCustomDecoratorStubType { + override fun createStub(psi: PyDecorator): PyFunctoolsWrapsDecoratorStub? { + return PyFunctoolsWrapsDecoratorStub.create(psi) + } + + override fun deserializeStub(stream: StubInputStream): PyFunctoolsWrapsDecoratorStub? { + val name = stream.readNameString() ?: return null + return PyFunctoolsWrapsDecoratorStub(name) + } +} + +class PyFunctoolsWrapsDecoratorStub(val wrapped: String) : PyCustomDecoratorStub { + override fun getTypeClass(): Class> = PyFunctoolsWrapsDecoratorStubType::class.java + + override fun serialize(stream: StubOutputStream) { + stream.writeName(wrapped) + } + + companion object { + fun create(psi: PyDecorator): PyFunctoolsWrapsDecoratorStub? { + val qName = psi.qualifiedName ?: return null + if (!PyKnownDecoratorUtil.asKnownDecorators(qName).contains(PyKnownDecoratorUtil.KnownDecorator.FUNCTOOLS_WRAPS)) return null + val wrappedExpr = psi.argumentList?.getValueExpressionForParam(PyKnownDecoratorUtil.FunctoolsWrapsParameters.WRAPPED) as? PyReferenceExpression + val wrappedExprQName = wrappedExpr?.asQualifiedName() ?: return null + return PyFunctoolsWrapsDecoratorStub(wrappedExprQName.toString()) + } + } +} \ No newline at end of file diff --git a/python/testData/inspections/PyArgumentListInspection/FunctoolsWrapsMultiFile/a.py b/python/testData/inspections/PyArgumentListInspection/FunctoolsWrapsMultiFile/a.py new file mode 100644 index 000000000000..640e0c546734 --- /dev/null +++ b/python/testData/inspections/PyArgumentListInspection/FunctoolsWrapsMultiFile/a.py @@ -0,0 +1,5 @@ +from m import Router + +r = Router() +r.route("", 13) +r.route("") \ No newline at end of file diff --git a/python/testData/inspections/PyArgumentListInspection/FunctoolsWrapsMultiFile/m.py b/python/testData/inspections/PyArgumentListInspection/FunctoolsWrapsMultiFile/m.py new file mode 100644 index 000000000000..3cf8a871d67f --- /dev/null +++ b/python/testData/inspections/PyArgumentListInspection/FunctoolsWrapsMultiFile/m.py @@ -0,0 +1,16 @@ +import functools + + +class MyClass: + def foo(self, s: str, i: int): + pass + +class Route: + @functools.wraps(MyClass.foo) + def __init__(self): + pass + +class Router: + @functools.wraps(wrapped=Route.__init__) + def route(self, s: str): + pass \ No newline at end of file diff --git a/python/testData/inspections/PyTypeCheckerInspection/FunctoolsWrapsMultiFile/a.py b/python/testData/inspections/PyTypeCheckerInspection/FunctoolsWrapsMultiFile/a.py new file mode 100644 index 000000000000..9774ebbc81a1 --- /dev/null +++ b/python/testData/inspections/PyTypeCheckerInspection/FunctoolsWrapsMultiFile/a.py @@ -0,0 +1,5 @@ +from m import Router + +router = Router() +router.route(-2) +router.route("") \ No newline at end of file diff --git a/python/testData/inspections/PyTypeCheckerInspection/FunctoolsWrapsMultiFile/m.py b/python/testData/inspections/PyTypeCheckerInspection/FunctoolsWrapsMultiFile/m.py new file mode 100644 index 000000000000..43bddd7768d0 --- /dev/null +++ b/python/testData/inspections/PyTypeCheckerInspection/FunctoolsWrapsMultiFile/m.py @@ -0,0 +1,18 @@ +import functools + + +class MyClass: + def foo(self, i: int): + pass + + +class Route: + @functools.wraps(MyClass.foo) + def __init__(self): + pass + + +class Router: + @functools.wraps(wrapped=Route.__init__) + def route(self, s: str): + pass \ No newline at end of file diff --git a/python/testData/paramInfo/FunctoolsWrap.py b/python/testData/paramInfo/FunctoolsWrap.py deleted file mode 100644 index fc25b8c9a0b3..000000000000 --- a/python/testData/paramInfo/FunctoolsWrap.py +++ /dev/null @@ -1,21 +0,0 @@ -from functools import wraps -import inspect - - -class Route: - def __init__(self, input_a: int, input_b: float): - ... - - -class Router: - def __init__(self): - self.routes = [] - - @wraps(Route.__init__) - def route(self, *args, **kwargs): - route = Route(*args, **kwargs) - self.routes.append(route) - - -r = Router() -r.route() diff --git a/python/testData/paramInfo/FunctoolsWraps.py b/python/testData/paramInfo/FunctoolsWraps.py new file mode 100644 index 000000000000..81b3e2d4123f --- /dev/null +++ b/python/testData/paramInfo/FunctoolsWraps.py @@ -0,0 +1,22 @@ +import functools + + +class MyClass: + def foo(self, s: str, b: bool): + pass + + +class Route: + @functools.wraps(MyClass.foo) + def __init__(self, a: int, b: float, c: object): + pass + + +class Router: + @functools.wraps(wrapped=Route.__init__) + def route(self, *args, **kwargs): + pass + + +r = Router() +r.route("", True) diff --git a/python/testData/quickdoc/FunctoolsWraps.html b/python/testData/quickdoc/FunctoolsWraps.html new file mode 100644 index 000000000000..95798eb4153c --- /dev/null +++ b/python/testData/quickdoc/FunctoolsWraps.html @@ -0,0 +1,3 @@ +
def foo(self,
+        s: str,
+        b: bool) -> None
Unittest placeholder
Params:

s – str

b – bool

Returns:None
\ No newline at end of file diff --git a/python/testData/quickdoc/FunctoolsWraps.py b/python/testData/quickdoc/FunctoolsWraps.py new file mode 100644 index 000000000000..d9c472faa29a --- /dev/null +++ b/python/testData/quickdoc/FunctoolsWraps.py @@ -0,0 +1,24 @@ +import functools + +class Cls: + def foo(self, s: str, b: bool): + """ + Doc text + :param s: str + :param b: bool + :return: None + """ + pass + +class Route: + @functools.wraps(Cls.foo) + def __init__(self): + pass + +class Router: + @functools.wraps(wrapped=Route.__init__) + def route(self, s: str): + pass + +r = Router() +r.route(13) \ No newline at end of file diff --git a/python/testSrc/com/jetbrains/python/Py3QuickDocTest.java b/python/testSrc/com/jetbrains/python/Py3QuickDocTest.java index f7d91e585bc3..8d6a8496a425 100644 --- a/python/testSrc/com/jetbrains/python/Py3QuickDocTest.java +++ b/python/testSrc/com/jetbrains/python/Py3QuickDocTest.java @@ -69,7 +69,7 @@ public class Py3QuickDocTest extends LightMarkedTestCase { assertNotNull(stringValue); PsiElement referenceElement = marks.get("").getParent(); // ident -> expr - final PyDocStringOwner docOwner = (PyDocStringOwner)((PyReferenceExpression)referenceElement).getReference().resolve(); + final PyDocStringOwner docOwner = (PyDocStringOwner)referenceElement.getReference().resolve(); assertNotNull(docOwner); assertEquals(docElement, docOwner.getDocStringExpression()); @@ -91,7 +91,7 @@ public class Py3QuickDocTest extends LightMarkedTestCase { private void checkHover() { Map marks = loadTest(); final PsiElement originalElement = marks.get(""); - final PsiElement docOwner = ((PyReferenceExpression)originalElement.getParent()).getReference().resolve(); + final PsiElement docOwner = originalElement.getParent().getReference().resolve(); checkByHTML(myProvider.getQuickNavigateInfo(docOwner, originalElement)); } @@ -165,18 +165,11 @@ public class Py3QuickDocTest extends LightMarkedTestCase { } public void testPropNewSetter() { - Map marks = loadTest(); - PsiElement referenceElement = marks.get(""); - final PyDocStringOwner docStringOwner = (PyDocStringOwner)referenceElement.getParent().getReference().resolve(); - checkByHTML(myProvider.generateDoc(docStringOwner, referenceElement)); + checkHTMLOnly(); } public void testPropNewDeleter() { - Map marks = loadTest(); - PsiElement referenceElement = marks.get(""); - final PyDocStringOwner docStringOwner = - (PyDocStringOwner)((PyReferenceExpression)(referenceElement.getParent())).getReference().resolve(); - checkByHTML(myProvider.generateDoc(docStringOwner, referenceElement)); + checkHTMLOnly(); } public void testPropOldGetter() { @@ -185,10 +178,7 @@ public class Py3QuickDocTest extends LightMarkedTestCase { public void testPropOldSetter() { - Map marks = loadTest(); - PsiElement referenceElement = marks.get(""); - final PyDocStringOwner docStringOwner = (PyDocStringOwner)referenceElement.getParent().getReference().resolve(); - checkByHTML(myProvider.generateDoc(docStringOwner, referenceElement)); + checkHTMLOnly(); } public void testPropOldDeleter() { @@ -871,6 +861,11 @@ public class Py3QuickDocTest extends LightMarkedTestCase { checkHTMLOnly(); } + // PY-23067 + public void testFunctoolsWraps() { + checkHTMLOnly(); + } + @Override protected String getTestDataPath() { return super.getTestDataPath() + "/quickdoc/"; diff --git a/python/testSrc/com/jetbrains/python/PyParameterInfoTest.java b/python/testSrc/com/jetbrains/python/PyParameterInfoTest.java index f050d7d7a320..f2d522a3001b 100644 --- a/python/testSrc/com/jetbrains/python/PyParameterInfoTest.java +++ b/python/testSrc/com/jetbrains/python/PyParameterInfoTest.java @@ -1258,6 +1258,14 @@ public class PyParameterInfoTest extends LightMarkedTestCase { feignCtrlP(marks.get("").getTextOffset()).check("a: int, *, name: str = ..., year: int", new String[]{"year: int"}); } + // PY-23067 + public void testFunctoolsWraps() { + final Map marks = loadTest(2); + + feignCtrlP(marks.get("").getTextOffset()).check("self: MyClass, s: str, b: bool", new String[]{"s: str, "}, new String[]{"self: MyClass, "}); + feignCtrlP(marks.get("").getTextOffset()).check("self: MyClass, s: str, b: bool", new String[]{"b: bool"}, new String[]{"self: MyClass, "}); + } + // PY-58497 public void testSimplePopupWithHintsOff() { Map marks = loadTest(5); diff --git a/python/testSrc/com/jetbrains/python/inspections/Py3ArgumentListInspectionTest.java b/python/testSrc/com/jetbrains/python/inspections/Py3ArgumentListInspectionTest.java index fbe17da2fae4..4b4e10bf6fee 100644 --- a/python/testSrc/com/jetbrains/python/inspections/Py3ArgumentListInspectionTest.java +++ b/python/testSrc/com/jetbrains/python/inspections/Py3ArgumentListInspectionTest.java @@ -311,4 +311,35 @@ public class Py3ArgumentListInspectionTest extends PyInspectionTestCase { Derived2(0, 0, b=0) """); } + + // PY-23067 + public void testFunctoolsWraps() { + doTestByText(""" + import functools + + class MyClass: + def foo(self, s: str, i: int): + pass + + class Route: + @functools.wraps(MyClass.foo) + def __init__(self): + pass + + class Router: + @functools.wraps(wrapped=Route.__init__) + def route(self, s: str): + pass + + r = Router() + r.route("", 13) + r.route("") + r.route("", 13, 1) + """); + } + + // PY-23067 + public void testFunctoolsWrapsMultiFile() { + doMultiFileTest(); + } } diff --git a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java index cbef7e33b7d6..007fc1a1b447 100644 --- a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java +++ b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java @@ -2144,4 +2144,34 @@ def foo(param: str | int) -> TypeGuard[str]: PosArgsT = TypeVarTuple("PosArgsT") """); } + + // PY-23067 + public void testFunctoolsWraps() { + doTestByText(""" + import functools + + class MyClass: + def foo(self, i: int): + pass + + class Route: + @functools.wraps(MyClass.foo) + def __init__(self): + pass + + class Router: + @functools.wraps(wrapped=Route.__init__) + def route(self, s: str): + pass + + router = Router() + router.route(-2) + router.route("") + """); + } + + // PY-23067 + public void testFunctoolsWrapsMultiFile() { + doMultiFileTest(); + } }