PY-75760 - allow reference to another ParamSpec be default of ParamSpec type, simplify logic of generic substitution for TypeVars

GitOrigin-RevId: 9ca5d7f3529513c683424d2f4d6da75f40d58e4a
This commit is contained in:
Daniil Kalinin
2024-09-26 17:19:21 +02:00
committed by intellij-monorepo-bot
parent d29d55476e
commit 783bbde096
3 changed files with 78 additions and 17 deletions

View File

@@ -1826,6 +1826,12 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
.withParameters(ContainerUtil .withParameters(ContainerUtil
.map(defaultArgumentTypes, argType -> PyCallableParameterImpl.nonPsi(argType)), context.getTypeContext()); .map(defaultArgumentTypes, argType -> PyCallableParameterImpl.nonPsi(argType)), context.getTypeContext());
} }
if (expression instanceof PyReferenceExpression) {
PyType referenceType = Ref.deref(getType(expression, context));
if (referenceType instanceof PyParamSpecType paramSpecType) {
return paramSpecType;
}
}
return null; return null;
} }

View File

@@ -1147,26 +1147,18 @@ public final class PyTypeChecker {
substitution = invert(invertedSubstitution); substitution = invert(invertedSubstitution);
} }
} }
if (substitution instanceof PyTypeVarType) { if (substitution instanceof PyTypeVarType typeVarSubstitution) {
PyQualifiedNameOwner scope = typeVar.getScopeOwner();
if (scope instanceof PyClass pyClass) {
PyType genericTypeFromClass = getGenericTypeForClass(context, pyClass);
if (genericTypeFromClass instanceof PyCollectionType) {
Generics generics = collectGenerics(genericTypeFromClass, context);
PyType finalSubstitution = substitution;
PyTypeVarType sameScopeSubstitution = StreamEx.of(generics.typeVars) PyTypeVarType sameScopeSubstitution = StreamEx.of(substitutions.typeVars.keySet())
.findFirst(typeVarType -> { .findFirst(typeVarType -> {
return typeVarType.getDeclarationElement() != null return typeVarType.getDeclarationElement() != null
&& typeVarType.getDeclarationElement().equals(finalSubstitution.getDeclarationElement()); && typeVarType.getDeclarationElement().equals(typeVarSubstitution.getDeclarationElement());
}).orElse(null); }).orElse(null);
if (sameScopeSubstitution != null && sameScopeSubstitution.getDefaultType() != null) { if (sameScopeSubstitution != null && typeVarSubstitution.getDefaultType() != null) {
return substitute(sameScopeSubstitution, substitutions, context, substituting); return substitute(sameScopeSubstitution, substitutions, context, substituting);
} }
} }
}
}
// TODO remove !typeVar.equals(substitution) part, it's necessary due to the logic in unifyReceiverWithParamSpecs // TODO remove !typeVar.equals(substitution) part, it's necessary due to the logic in unifyReceiverWithParamSpecs
if (!typeVar.equals(substitution) && hasGenerics(substitution, context)) { if (!typeVar.equals(substitution) && hasGenerics(substitution, context)) {
return substitute(substitution, substitutions, context, substituting); return substitute(substitution, substitutions, context, substituting);
@@ -1175,11 +1167,20 @@ public final class PyTypeChecker {
} }
else if (type instanceof PyParamSpecType paramSpecType && paramSpecType.getParameters() == null) { else if (type instanceof PyParamSpecType paramSpecType && paramSpecType.getParameters() == null) {
if (!substitutions.paramSpecs.containsKey(paramSpecType)) { if (!substitutions.paramSpecs.containsKey(paramSpecType)) {
PyParamSpecType sameScopeSubstitution = StreamEx.of(substitutions.paramSpecs.keySet())
.findFirst(typeVarType -> {
return typeVarType.getDeclarationElement() != null
&& typeVarType.getDeclarationElement().equals(paramSpecType.getDeclarationElement());
}).orElse(null);
if (sameScopeSubstitution != null) {
return substitute(sameScopeSubstitution, substitutions, context, substituting);
}
return paramSpecType; return paramSpecType;
} }
PyParamSpecType substitution = substitutions.paramSpecs.get(paramSpecType); PyParamSpecType substitution = substitutions.paramSpecs.get(paramSpecType);
if (substitution != null && !substitution.equals(paramSpecType) && hasGenerics(substitution, context)) { if (substitution != null && !substitution.equals(paramSpecType) && hasGenerics(substitution, context)) {
return substitute(substitution, substitutions, context); return substitute(substitution, substitutions, context, substituting);
} }
// TODO For ParamSpecs, replace Any with (*args: Any, **kwargs: Any) as it's a logical "wildcard" for this kind of type parameter // TODO For ParamSpecs, replace Any with (*args: Any, **kwargs: Any) as it's a logical "wildcard" for this kind of type parameter
return substitution; return substitution;

View File

@@ -5644,6 +5644,60 @@ public class PyTypingTest extends PyTestCase {
"""); """);
} }
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecNewStyle() {
doTest("Clazz[[str], [str], [str]]", """
class Clazz[**P1, **P2 = P1, **P3 = P2]: ...
expr = Clazz[[str]]()
""");
}
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecOldStyle() {
doTest("Clazz[[str], [str], [str]]", """
from typing import Generic, ParamSpec
P1 = ParamSpec("P1")
P2 = ParamSpec("P2", default=P1)
P3 = ParamSpec("P3", default=P2)
class Clazz(Generic[P1, P2, P3]): ...
expr = Clazz[[str]]()
""");
}
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecOldStyleNoExplicit() {
doTest("Clazz[[str], [str], [bool, bool], [bool, bool]]", """
from typing import Generic, ParamSpec
P1 = ParamSpec("P1", default=[str])
P2 = ParamSpec("P2", default=P1)
P3 = ParamSpec("P3", default=[bool, bool])
P4 = ParamSpec("P4", default=P3)
class Clazz(Generic[P1, P2, P3, P4]): ...
expr = Clazz()
""");
}
// PY-71002
public void testParamSpecWithDefaultInConstructor() {
doTest("(int, str, str) -> None | None", """
from typing import Generic, ParamSpec, Callable
P = ParamSpec("P", default=[int, str, str])
class ClassA(Generic[P]):
def __init__(self, x: Callable[P, None] = None) -> None:
self.x = x
...
expr = ClassA().x
""");
}
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecWithEllipsis() {
doTest("Clazz[Any, [float], [float]]", """
class Clazz[**P1, **P2 = P1, **P3 = P2]: ...
expr = Clazz[..., [float]]()
""");
}
public void testDataclassTransformConstructorSignature() { public void testDataclassTransformConstructorSignature() {
doTestExpressionUnderCaret("(id: int, name: str) -> MyClass", """ doTestExpressionUnderCaret("(id: int, name: str) -> MyClass", """
from typing import dataclass_transform from typing import dataclass_transform