PY-75760 Refactor the implementation of PEP 696 support

1. At first, get rid of the explicit mapping of generics to default types (remove all these not-good-looking methods which were added earlier, such as `PyTypeChecker.trySubstituteByDefaultsOnl`y and `PyTypeChecker.getSubstitutionsWithDefaults`) and their usages. All the related logic now will be handled in `PyTypeParameterMapping`, as we wanted it to be.

2. Do some changes in `PyTypeChecker` to be able to correctly parameterize class via constructor call, and also take defaults into account in `PyTypeChecker.getSubstitutionsWithUnresolvedReturnGenerics` for methods

3. Get rid of the explicit calls of `PyTypingTypeProvider.tryParameterizeClassWithDefaults` in `PyCallExpressionHelper`, `PyReferenceExpressionImpl`, rename this method to `parameterizeClassDefaultAware` and call it directly in `PyTypingTypeProvider.getReferenceType`

4. Add a new flag to `PyTypeParameterMapping` to be able to correctly match type parameters (see `PyTypeChecker.matchTypeParameters`)

GitOrigin-RevId: 5dd90ee3bdf8319b36f1945ce22a33a8edf6bc93
This commit is contained in:
Daniil Kalinin
2024-09-20 15:51:17 +02:00
committed by intellij-monorepo-bot
parent 15f26f8d9f
commit 23cee33c35
7 changed files with 60 additions and 125 deletions

View File

@@ -583,6 +583,12 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
}
}
}
if (referenceTarget instanceof PyClass pyClass && anchor instanceof PyExpression) {
PyCollectionType parameterizedType = parameterizeClassDefaultAware(pyClass, List.of(), context);
if (parameterizedType != null) {
return Ref.create(parameterizedType.toClass());
}
}
return null;
}
@@ -655,7 +661,7 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
List<PyType> superTypeArguments = superClassType instanceof PyCollectionType parameterized ?
parameterized.getElementTypes() : Collections.emptyList();
PyTypeParameterMapping mapping =
PyTypeParameterMapping.mapByShape(superTypeParameters, superTypeArguments, Option.MAP_UNMATCHED_EXPECTED_TYPES_TO_ANY);
PyTypeParameterMapping.mapByShape(superTypeParameters, superTypeArguments, Option.MAP_UNMATCHED_EXPECTED_TYPES_TO_ANY, Option.USE_DEFAULTS);
if (mapping != null) {
for (Couple<PyType> pair : mapping.getMappedTypes()) {
PyType expectedType = pair.getFirst();
@@ -1863,32 +1869,19 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
}
@Nullable
public static PyType tryParameterizeClassWithDefaults(@NotNull PyClassType classType,
@NotNull PyExpression anchor,
boolean toInstance,
@NotNull TypeEvalContext context) {
return staticWithCustomContext(context, customContext -> parameterizeClassWithDefaults(classType, anchor, toInstance, customContext));
}
@Nullable
private static PyType parameterizeClassWithDefaults(@NotNull PyClassType classType,
@NotNull PyExpression anchor,
boolean toInstance,
@NotNull Context context) {
PyClass pyClass = classType.getPyClass();
private static PyCollectionType parameterizeClassDefaultAware(@NotNull PyClass pyClass,
@NotNull List<PyType> actualTypeParams,
@NotNull Context context) {
if (isGeneric(pyClass, context.getTypeContext())) {
PyCollectionType genericDefinitionType =
doPreventingRecursion(pyClass, false, () -> PyTypeChecker.findGenericDefinitionType(pyClass, context.getTypeContext()));
if (genericDefinitionType != null && ContainerUtil.exists(genericDefinitionType.getElementTypes(),
t -> t instanceof PyTypeParameterType typeParameterType &&
typeParameterType.getDefaultType() != null)) {
List<PyType> indexTypes = anchor instanceof PySubscriptionExpression subscriptionExpression
? getIndexTypes(subscriptionExpression, context)
: List.of();
PyType parameterizedType = PyTypeChecker.parameterizeType(genericDefinitionType, indexTypes, context.myContext);
PyType parameterizedType = PyTypeChecker.parameterizeType(genericDefinitionType, actualTypeParams, context.myContext);
if (parameterizedType instanceof PyCollectionType collectionType) {
return toInstance ? collectionType.toInstance() : collectionType.toClass();
return collectionType;
}
}
}
@@ -1985,9 +1978,9 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
}
if (operandType != null) {
if (operandType instanceof PyClassType classType) {
PyType parameterizedType = parameterizeClassWithDefaults(classType, subscriptionExpr, true, context);
if (parameterizedType instanceof PyCollectionType) {
return parameterizedType;
PyCollectionType parameterizedType = parameterizeClassDefaultAware(classType.getPyClass(), indexTypes, context);
if (parameterizedType != null) {
return parameterizedType.toInstance();
}
}
return PyTypeChecker.parameterizeType(operandType, indexTypes, context.getTypeContext());

View File

@@ -650,13 +650,6 @@ public final class PyCallExpressionHelper {
return new PyCollectionTypeImpl(receiverClass, false, elementTypes);
}
if (initOrNewCallType instanceof PyClassType classType) {
PyType implicitlyParameterized = PyTypingTypeProvider.tryParameterizeClassWithDefaults(classType, callSite, true, context);
if (implicitlyParameterized instanceof PyCollectionType collectionType) {
return collectionType.toInstance();
}
}
return new PyClassTypeImpl(receiverClass, false);
}

View File

@@ -223,8 +223,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
}
final var substitutionsWithUnresolvedReturnGenerics =
PyTypeChecker.getSubstitutionsWithUnresolvedReturnGenerics(getParameters(context), type, substitutions, context);
final var substitutionsWithDefaults = PyTypeChecker.getSubstitutionsWithDefaults(substitutionsWithUnresolvedReturnGenerics);
type = PyTypeChecker.substitute(type, substitutionsWithDefaults, context);
type = PyTypeChecker.substitute(type, substitutionsWithUnresolvedReturnGenerics, context);
}
else {
type = null;

View File

@@ -18,7 +18,6 @@ import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator;
import com.jetbrains.python.codeInsight.controlflow.ReadWriteInstruction;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.references.PyImportReference;
import com.jetbrains.python.psi.impl.references.PyQualifiedReference;
@@ -368,8 +367,7 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
}
final var substitutions = PyTypeChecker.unifyGenericCall(qualifier, Collections.emptyMap(), context);
if (substitutions != null) {
var substitutionsWithDefaults = PyTypeChecker.getSubstitutionsWithDefaults(substitutions);
final PyType substituted = PyTypeChecker.substitute(type, substitutionsWithDefaults, context);
final PyType substituted = PyTypeChecker.substitute(type, substitutions, context);
if (substituted != null) {
return substituted;
}
@@ -378,13 +376,6 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
}
}
if (type instanceof PyClassType classType && !(type instanceof PyCollectionType)) {
PyType parameterizedType = PyTypingTypeProvider.tryParameterizeClassWithDefaults(classType, anchor, false, context);
if (parameterizedType instanceof PyCollectionType collectionType) {
return collectionType;
}
}
return type;
}

View File

@@ -16,6 +16,7 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiPolyVariantReference;
import com.intellij.psi.PsiReference;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
@@ -59,10 +60,10 @@ public class PySubscriptionExpressionImpl extends PyElementImpl implements PySub
.map(typedDictType::getElementType)
.orElse(null);
}
if (type instanceof PyClassType classType) {
PyType parameterizedType = PyTypingTypeProvider.tryParameterizeClassWithDefaults(classType, this, false, context);
if (type instanceof PyClassType) {
PyType parameterizedType = Ref.deref(PyTypingTypeProvider.getType(this, context));
if (parameterizedType instanceof PyCollectionType collectionType) {
return collectionType;
return collectionType.toClass();
}
}
return PyCallExpressionHelper.getCallType(this, context, key);

View File

@@ -237,7 +237,12 @@ public final class PyTypeChecker {
return false;
}
final PyType substitution = context.mySubstitutions.typeVars.get(expected);
PyType substituted = context.mySubstitutions.typeVars.get(expected);
if (expected.getDefaultType() != null && expected.getDefaultType().equals(substituted)) {
// Skip default substitution
substituted = null;
}
final PyType substitution = substituted;
PyType bound = expected.getBound();
// Promote int in Type[TypeVar('T', int)] to Type[int] before checking that bounds match
if (expected.isDefinition()) {
@@ -865,7 +870,7 @@ public final class PyTypeChecker {
@NotNull List<PyType> actualTypeParameters,
@NotNull MatchContext context) {
PyTypeParameterMapping mapping =
PyTypeParameterMapping.mapByShape(expectedTypeParameters, actualTypeParameters);
PyTypeParameterMapping.mapByShape(expectedTypeParameters, actualTypeParameters, Option.USE_DEFAULTS);
if (mapping == null) {
return false;
}
@@ -950,7 +955,7 @@ public final class PyTypeChecker {
boolean isAlreadyBound = existingSubstitutions.typeVars.containsKey(returnTypeParam) ||
existingSubstitutions.typeVars.containsKey(invert(returnTypeParam));
if (canGetBoundFromArguments && !isAlreadyBound) {
existingSubstitutions.typeVars.put(returnTypeParam, null);
existingSubstitutions.typeVars.put(returnTypeParam, returnTypeParam.getDefaultType());
}
}
for (PyParamSpecType paramSpecType : typeParamsFromReturnType.paramSpecs) {
@@ -1391,41 +1396,6 @@ public final class PyTypeChecker {
return match(expectedType, actualType, context, substitutions);
}
@NotNull
public static GenericSubstitutions getSubstitutionsWithDefaults(@NotNull GenericSubstitutions substitutions) {
Map<PyTypeVarType, PyType> typeVersWithDefaults = new HashMap<>();
Map<PyParamSpecType, PyParamSpecType> paramSpecsWithDefaults = new HashMap<>();
Map<PyTypeVarTupleType, PyVariadicType> typeVarTuplesWithDefaults = new HashMap<>();
substitutions.getTypeVars().forEach((key, value) -> {
if (key.getDefaultType() != null && value == null) {
typeVersWithDefaults.put(key, key.getDefaultType());
}
else if (value instanceof PyTypeVarType typeVarType
&& typeVarType.getDefaultType() != null
&& !typeVarType.getDefaultType().equals(key)) {
typeVersWithDefaults.put(key, typeVarType.getDefaultType());
}
});
substitutions.getParamSpecs().forEach((key, value) -> {
if (key.getDefaultType() instanceof PyParamSpecType defaultType
&& (value == null || key.equals(value))) {
paramSpecsWithDefaults.put(key, defaultType);
}
});
substitutions.getTypeVarTuples().forEach((key, value) -> {
if (key.getDefaultType() instanceof PyVariadicType variadicType
&& (value == null
|| (value instanceof PyUnpackedTupleType unpackedTupleType && unpackedTupleType.getElementTypes().isEmpty())
|| key.equals(value))) {
typeVarTuplesWithDefaults.put(key, variadicType);
}
});
substitutions.typeVars.putAll(typeVersWithDefaults);
substitutions.paramSpecs.putAll(paramSpecsWithDefaults);
substitutions.typeVarTuples.putAll(typeVarTuplesWithDefaults);
return substitutions;
}
private static boolean matchContainer(@Nullable PyCallableParameter container, @NotNull List<? extends PyExpression> arguments,
@NotNull GenericSubstitutions substitutions, @NotNull TypeEvalContext context) {
if (container == null) {
@@ -1685,24 +1655,6 @@ public final class PyTypeChecker {
return substitutions;
}
@ApiStatus.Internal
@Nullable
public static PyType trySubstituteByDefaultsOnly(@NotNull PyType targetType,
List<PyTypeParameterType> typeParameterTypes,
boolean unmappedToAny,
@NotNull TypeEvalContext context) {
if (!typeParameterTypes.isEmpty()) {
List<Couple<PyType>> typeParametersToSubstitutions =
ContainerUtil.map(typeParameterTypes, typeParameterType
-> new Couple<>(typeParameterType, unmappedToAny ? null : typeParameterType));
var substitutions = fillSubstitutionsWithTypeParameters(new GenericSubstitutions(), typeParametersToSubstitutions);
var subsWithDefaults = getSubstitutionsWithDefaults(substitutions);
return substitute(targetType, subsWithDefaults, context);
}
return null;
}
@ApiStatus.Internal
public static class Generics {
@NotNull

View File

@@ -205,36 +205,15 @@ public final class PyTypeParameterMapping {
}
sizeMismatch = false;
}
else if (optionSet.contains(Option.MAP_UNMATCHED_EXPECTED_TYPES_TO_ANY)) {
if (optionSet.contains(Option.USE_DEFAULTS)
&& onlyLeftExpectedType instanceof PyTypeParameterType typeParameterType
&& typeParameterType.getDefaultType() != null) {
centerMappedTypes.add(Couple.of(onlyLeftExpectedType, typeParameterType.getDefaultType()));
}
else {
centerMappedTypes.add(Couple.of(onlyLeftExpectedType, null));
}
sizeMismatch = false;
}
else {
sizeMismatch = true;
sizeMismatch = handleSizeMismatch(onlyLeftExpectedType, centerMappedTypes, optionSet);
}
}
else if (optionSet.contains(Option.MAP_UNMATCHED_EXPECTED_TYPES_TO_ANY)) {
for (PyType unmatchedType : expectedTypesDeque.toList()) {
if (optionSet.contains(Option.USE_DEFAULTS) &&
unmatchedType instanceof PyTypeParameterType typeParameterType
&& typeParameterType.getDefaultType() != null) {
centerMappedTypes.add(Couple.of(unmatchedType, typeParameterType.getDefaultType()));
}
else {
centerMappedTypes.add(Couple.of(unmatchedType, null));
}
}
sizeMismatch = false;
}
else {
sizeMismatch = true;
for (PyType unmatchedType : expectedTypesDeque.toList()) {
sizeMismatch = handleSizeMismatch(unmatchedType, centerMappedTypes, optionSet);
}
}
if (sizeMismatch) {
return null;
@@ -246,6 +225,33 @@ public final class PyTypeParameterMapping {
return new PyTypeParameterMapping(resultMapping);
}
private static boolean handleSizeMismatch(@Nullable PyType unmatchedType,
@NotNull List<Couple<PyType>> centerMappedTypes,
@NotNull EnumSet<Option> optionSet) {
boolean sizeMismatch;
if (optionSet.contains(Option.USE_DEFAULTS)) {
if (unmatchedType instanceof PyTypeParameterType typeParameterType && typeParameterType.getDefaultType() != null) {
centerMappedTypes.add(Couple.of(unmatchedType, typeParameterType.getDefaultType()));
sizeMismatch = false;
}
else if (optionSet.contains(Option.MAP_UNMATCHED_EXPECTED_TYPES_TO_ANY)) {
sizeMismatch = false;
centerMappedTypes.add(Couple.of(unmatchedType, null));
}
else {
sizeMismatch = true;
}
}
else if (optionSet.contains(Option.MAP_UNMATCHED_EXPECTED_TYPES_TO_ANY) && !optionSet.contains(Option.USE_DEFAULTS)) {
centerMappedTypes.add(Couple.of(unmatchedType, null));
sizeMismatch = false;
}
else {
sizeMismatch = true;
}
return sizeMismatch;
}
private static @NotNull List<PyType> flattenUnpackedTupleTypes(List<? extends PyType> types) {
return ContainerUtil.flatMap(types, type -> {
if (type instanceof PyUnpackedTupleType unpackedTupleType && !unpackedTupleType.isUnbound()) {