PY-75760 - allow reference to another ParamSpec be default of ParamSpec type, simplify logic of generic substitution for TypeVars

GitOrigin-RevId: 9ca5d7f3529513c683424d2f4d6da75f40d58e4a
This commit is contained in:
Daniil Kalinin
2024-09-26 17:19:21 +02:00
committed by intellij-monorepo-bot
parent d29d55476e
commit 783bbde096
3 changed files with 78 additions and 17 deletions

View File

@@ -1826,6 +1826,12 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
.withParameters(ContainerUtil
.map(defaultArgumentTypes, argType -> PyCallableParameterImpl.nonPsi(argType)), context.getTypeContext());
}
if (expression instanceof PyReferenceExpression) {
PyType referenceType = Ref.deref(getType(expression, context));
if (referenceType instanceof PyParamSpecType paramSpecType) {
return paramSpecType;
}
}
return null;
}

View File

@@ -1147,24 +1147,16 @@ public final class PyTypeChecker {
substitution = invert(invertedSubstitution);
}
}
if (substitution instanceof PyTypeVarType) {
PyQualifiedNameOwner scope = typeVar.getScopeOwner();
if (scope instanceof PyClass pyClass) {
PyType genericTypeFromClass = getGenericTypeForClass(context, pyClass);
if (genericTypeFromClass instanceof PyCollectionType) {
Generics generics = collectGenerics(genericTypeFromClass, context);
PyType finalSubstitution = substitution;
if (substitution instanceof PyTypeVarType typeVarSubstitution) {
PyTypeVarType sameScopeSubstitution = StreamEx.of(generics.typeVars)
.findFirst(typeVarType -> {
return typeVarType.getDeclarationElement() != null
&& typeVarType.getDeclarationElement().equals(finalSubstitution.getDeclarationElement());
}).orElse(null);
PyTypeVarType sameScopeSubstitution = StreamEx.of(substitutions.typeVars.keySet())
.findFirst(typeVarType -> {
return typeVarType.getDeclarationElement() != null
&& typeVarType.getDeclarationElement().equals(typeVarSubstitution.getDeclarationElement());
}).orElse(null);
if (sameScopeSubstitution != null && sameScopeSubstitution.getDefaultType() != null) {
return substitute(sameScopeSubstitution, substitutions, context, substituting);
}
}
if (sameScopeSubstitution != null && typeVarSubstitution.getDefaultType() != null) {
return substitute(sameScopeSubstitution, substitutions, context, substituting);
}
}
// TODO remove !typeVar.equals(substitution) part, it's necessary due to the logic in unifyReceiverWithParamSpecs
@@ -1175,11 +1167,20 @@ public final class PyTypeChecker {
}
else if (type instanceof PyParamSpecType paramSpecType && paramSpecType.getParameters() == null) {
if (!substitutions.paramSpecs.containsKey(paramSpecType)) {
PyParamSpecType sameScopeSubstitution = StreamEx.of(substitutions.paramSpecs.keySet())
.findFirst(typeVarType -> {
return typeVarType.getDeclarationElement() != null
&& typeVarType.getDeclarationElement().equals(paramSpecType.getDeclarationElement());
}).orElse(null);
if (sameScopeSubstitution != null) {
return substitute(sameScopeSubstitution, substitutions, context, substituting);
}
return paramSpecType;
}
PyParamSpecType substitution = substitutions.paramSpecs.get(paramSpecType);
if (substitution != null && !substitution.equals(paramSpecType) && hasGenerics(substitution, context)) {
return substitute(substitution, substitutions, context);
return substitute(substitution, substitutions, context, substituting);
}
// TODO For ParamSpecs, replace Any with (*args: Any, **kwargs: Any) as it's a logical "wildcard" for this kind of type parameter
return substitution;

View File

@@ -5644,6 +5644,60 @@ public class PyTypingTest extends PyTestCase {
""");
}
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecNewStyle() {
doTest("Clazz[[str], [str], [str]]", """
class Clazz[**P1, **P2 = P1, **P3 = P2]: ...
expr = Clazz[[str]]()
""");
}
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecOldStyle() {
doTest("Clazz[[str], [str], [str]]", """
from typing import Generic, ParamSpec
P1 = ParamSpec("P1")
P2 = ParamSpec("P2", default=P1)
P3 = ParamSpec("P3", default=P2)
class Clazz(Generic[P1, P2, P3]): ...
expr = Clazz[[str]]()
""");
}
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecOldStyleNoExplicit() {
doTest("Clazz[[str], [str], [bool, bool], [bool, bool]]", """
from typing import Generic, ParamSpec
P1 = ParamSpec("P1", default=[str])
P2 = ParamSpec("P2", default=P1)
P3 = ParamSpec("P3", default=[bool, bool])
P4 = ParamSpec("P4", default=P3)
class Clazz(Generic[P1, P2, P3, P4]): ...
expr = Clazz()
""");
}
// PY-71002
public void testParamSpecWithDefaultInConstructor() {
doTest("(int, str, str) -> None | None", """
from typing import Generic, ParamSpec, Callable
P = ParamSpec("P", default=[int, str, str])
class ClassA(Generic[P]):
def __init__(self, x: Callable[P, None] = None) -> None:
self.x = x
...
expr = ClassA().x
""");
}
// PY-71002
public void testParamSpecDefaultTypeRefersToAnotherParamSpecWithEllipsis() {
doTest("Clazz[Any, [float], [float]]", """
class Clazz[**P1, **P2 = P1, **P3 = P2]: ...
expr = Clazz[..., [float]]()
""");
}
public void testDataclassTransformConstructorSignature() {
doTestExpressionUnderCaret("(id: int, name: str) -> MyClass", """
from typing import dataclass_transform