diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java index 7e676eef81a1..54a317d58629 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java @@ -1826,6 +1826,12 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< .withParameters(ContainerUtil .map(defaultArgumentTypes, argType -> PyCallableParameterImpl.nonPsi(argType)), context.getTypeContext()); } + if (expression instanceof PyReferenceExpression) { + PyType referenceType = Ref.deref(getType(expression, context)); + if (referenceType instanceof PyParamSpecType paramSpecType) { + return paramSpecType; + } + } return null; } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java index cda9c30683eb..381b139dd272 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java @@ -1147,24 +1147,16 @@ public final class PyTypeChecker { substitution = invert(invertedSubstitution); } } - if (substitution instanceof PyTypeVarType) { - PyQualifiedNameOwner scope = typeVar.getScopeOwner(); - if (scope instanceof PyClass pyClass) { - PyType genericTypeFromClass = getGenericTypeForClass(context, pyClass); - if (genericTypeFromClass instanceof PyCollectionType) { - Generics generics = collectGenerics(genericTypeFromClass, context); - PyType finalSubstitution = substitution; + if (substitution instanceof PyTypeVarType typeVarSubstitution) { - PyTypeVarType sameScopeSubstitution = StreamEx.of(generics.typeVars) - .findFirst(typeVarType -> { - return typeVarType.getDeclarationElement() != null - && typeVarType.getDeclarationElement().equals(finalSubstitution.getDeclarationElement()); - }).orElse(null); + PyTypeVarType sameScopeSubstitution = StreamEx.of(substitutions.typeVars.keySet()) + .findFirst(typeVarType -> { + return typeVarType.getDeclarationElement() != null + && typeVarType.getDeclarationElement().equals(typeVarSubstitution.getDeclarationElement()); + }).orElse(null); - if (sameScopeSubstitution != null && sameScopeSubstitution.getDefaultType() != null) { - return substitute(sameScopeSubstitution, substitutions, context, substituting); - } - } + if (sameScopeSubstitution != null && typeVarSubstitution.getDefaultType() != null) { + return substitute(sameScopeSubstitution, substitutions, context, substituting); } } // TODO remove !typeVar.equals(substitution) part, it's necessary due to the logic in unifyReceiverWithParamSpecs @@ -1175,11 +1167,20 @@ public final class PyTypeChecker { } else if (type instanceof PyParamSpecType paramSpecType && paramSpecType.getParameters() == null) { if (!substitutions.paramSpecs.containsKey(paramSpecType)) { + PyParamSpecType sameScopeSubstitution = StreamEx.of(substitutions.paramSpecs.keySet()) + .findFirst(typeVarType -> { + return typeVarType.getDeclarationElement() != null + && typeVarType.getDeclarationElement().equals(paramSpecType.getDeclarationElement()); + }).orElse(null); + + if (sameScopeSubstitution != null) { + return substitute(sameScopeSubstitution, substitutions, context, substituting); + } return paramSpecType; } PyParamSpecType substitution = substitutions.paramSpecs.get(paramSpecType); if (substitution != null && !substitution.equals(paramSpecType) && hasGenerics(substitution, context)) { - return substitute(substitution, substitutions, context); + return substitute(substitution, substitutions, context, substituting); } // TODO For ParamSpecs, replace Any with (*args: Any, **kwargs: Any) as it's a logical "wildcard" for this kind of type parameter return substitution; diff --git a/python/testSrc/com/jetbrains/python/PyTypingTest.java b/python/testSrc/com/jetbrains/python/PyTypingTest.java index f231c9b409b3..81983744eb39 100644 --- a/python/testSrc/com/jetbrains/python/PyTypingTest.java +++ b/python/testSrc/com/jetbrains/python/PyTypingTest.java @@ -5644,6 +5644,60 @@ public class PyTypingTest extends PyTestCase { """); } + // PY-71002 + public void testParamSpecDefaultTypeRefersToAnotherParamSpecNewStyle() { + doTest("Clazz[[str], [str], [str]]", """ + class Clazz[**P1, **P2 = P1, **P3 = P2]: ... + expr = Clazz[[str]]() + """); + } + + // PY-71002 + public void testParamSpecDefaultTypeRefersToAnotherParamSpecOldStyle() { + doTest("Clazz[[str], [str], [str]]", """ + from typing import Generic, ParamSpec + P1 = ParamSpec("P1") + P2 = ParamSpec("P2", default=P1) + P3 = ParamSpec("P3", default=P2) + class Clazz(Generic[P1, P2, P3]): ... + expr = Clazz[[str]]() + """); + } + + // PY-71002 + public void testParamSpecDefaultTypeRefersToAnotherParamSpecOldStyleNoExplicit() { + doTest("Clazz[[str], [str], [bool, bool], [bool, bool]]", """ + from typing import Generic, ParamSpec + P1 = ParamSpec("P1", default=[str]) + P2 = ParamSpec("P2", default=P1) + P3 = ParamSpec("P3", default=[bool, bool]) + P4 = ParamSpec("P4", default=P3) + class Clazz(Generic[P1, P2, P3, P4]): ... + expr = Clazz() + """); + } + + // PY-71002 + public void testParamSpecWithDefaultInConstructor() { + doTest("(int, str, str) -> None | None", """ + from typing import Generic, ParamSpec, Callable + P = ParamSpec("P", default=[int, str, str]) + class ClassA(Generic[P]): + def __init__(self, x: Callable[P, None] = None) -> None: + self.x = x + ... + expr = ClassA().x + """); + } + + // PY-71002 + public void testParamSpecDefaultTypeRefersToAnotherParamSpecWithEllipsis() { + doTest("Clazz[Any, [float], [float]]", """ + class Clazz[**P1, **P2 = P1, **P3 = P2]: ... + expr = Clazz[..., [float]]() + """); + } + public void testDataclassTransformConstructorSignature() { doTestExpressionUnderCaret("(id: int, name: str) -> MyClass", """ from typing import dataclass_transform