PY-59548 Recognize generic class inheritance through aliases in stubbed files

Namely, evaluate expressions in the list of base classes as regular
type hints in PyTypingTypeProvider. It allowed us to properly consider type
aliases, as previously used PyResolveUtil.resolveQualifiedNameInScope reached
only up to the first type alias not following assignment chains further to
class definitions.
It simultaneously made collecting type substitutions for generic classes simpler
since we don't need to replicate the logic of parsing subscription expressions
in type hints specifically for base class expressions and can instead
retrieve type arguments corresponding to expressions in brackets directly from
PyCollectionType.getElementTypes.

Since we process both subscription expressions and plain qualified names in
the list of base classes uniformly now, evaluating them all as type hints,
PyClassElementType.getSubscriptedSuperClassesStubLike also became obsolete.

GitOrigin-RevId: 8a2a45c1be0259ee62f82595289842f043ae02b2
This commit is contained in:
Mikhail Golubev
2023-03-16 13:07:28 +02:00
committed by intellij-monorepo-bot
parent 869642e273
commit b1fd219835
4 changed files with 80 additions and 86 deletions

View File

@@ -38,7 +38,6 @@ import com.jetbrains.python.psi.impl.stubs.PyTypingAliasStubType;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.PyResolveUtil;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.psi.stubs.PyClassStub;
import com.jetbrains.python.psi.types.*;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
@@ -614,99 +613,51 @@ public class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<PyTypi
if (!isGeneric(cls, context.myContext)) {
return Collections.emptyMap();
}
final Map<PyType, PyType> results = new HashMap<>();
for (Map.Entry<PyClass, PySubscriptionExpression> e : getResolvedSuperClassesAndTypeParameters(cls, context.myContext).entrySet()) {
final PySubscriptionExpression subscriptionExpr = e.getValue();
final PyClass superClass = e.getKey();
final Map<PyType, PyType> superSubstitutions =
doPreventingRecursion(superClass, false, () -> getGenericSubstitutions(superClass, context));
Map<PyType, PyType> results = new HashMap<>();
for (PyClassType superClassType : evaluateSuperClassesAsTypeHints(cls, context.myContext)) {
Map<PyType, PyType> superSubstitutions =
doPreventingRecursion(superClassType.getPyClass(), false, () -> getGenericSubstitutions(superClassType.getPyClass(), context));
if (superSubstitutions != null) {
results.putAll(superSubstitutions);
}
final List<PyType> superGenerics = collectGenericTypes(superClass, context);
final List<PyExpression> indices = subscriptionExpr != null ? getSubscriptionIndices(subscriptionExpr) : Collections.emptyList();
for (int i = 0; i < superGenerics.size(); i++) {
final PyExpression expr = ContainerUtil.getOrElse(indices, i, null);
final PyType superGeneric = superGenerics.get(i);
final Ref<PyType> typeRef = expr != null ? getType(expr, context) : null;
final PyType actualType = typeRef != null ? typeRef.get() : null;
if (!superGeneric.equals(actualType)) {
results.put(superGeneric, actualType);
// TODO Share this logic with PyTypeChecker.collectTypeSubstitutions
List<PyType> superTypeParameters = collectGenericTypes(superClassType.getPyClass(), context);
List<PyType> superTypeArguments = superClassType instanceof PyCollectionType parameterized ?
parameterized.getElementTypes() : Collections.emptyList();
for (int i = 0; i < superTypeParameters.size(); i++) {
PyType superTypeParameter = superTypeParameters.get(i);
PyType superTypeArgument = ContainerUtil.getOrElse(superTypeArguments, i, null);
if (!superTypeParameter.equals(superTypeArgument)) {
results.put(superTypeParameter, superTypeArgument);
}
}
}
return results;
}
@NotNull
private static Map<PyClass, PySubscriptionExpression> getResolvedSuperClassesAndTypeParameters(@NotNull PyClass pyClass,
@NotNull TypeEvalContext context) {
final Map<PyClass, PySubscriptionExpression> results = new LinkedHashMap<>();
final PyClassStub classStub = pyClass.getStub();
if (context.maySwitchToAST(pyClass)) {
for (PyExpression e : pyClass.getSuperClassExpressions()) {
final PySubscriptionExpression subscriptionExpr = as(e, PySubscriptionExpression.class);
final PyExpression superExpr = subscriptionExpr != null ? subscriptionExpr.getOperand() : e;
final PyType superType = context.getType(superExpr);
final PyClassType superClassType = as(superType, PyClassType.class);
final PyClass superClass = superClassType != null ? superClassType.getPyClass() : null;
if (superClass != null) {
results.put(superClass, subscriptionExpr);
}
}
return results;
}
final Iterable<QualifiedName> allBaseClassesQNames;
final List<PySubscriptionExpression> subscriptedBaseClasses = PyClassElementType.getSubscriptedSuperClassesStubLike(pyClass);
for (PySubscriptionExpression subscrExpr : subscriptedBaseClasses) {
PsiFile containingFile = subscrExpr.getContainingFile();
private static @NotNull List<PyClassType> evaluateSuperClassesAsTypeHints(@NotNull PyClass pyClass, @NotNull TypeEvalContext context) {
List<PyClassType> results = new ArrayList<>();
for (PyExpression superClassExpression : PyClassElementType.getSuperClassExpressions(pyClass)) {
PsiFile containingFile = superClassExpression.getContainingFile();
if (containingFile instanceof PyExpressionCodeFragment) {
containingFile.putUserData(FRAGMENT_OWNER, pyClass);
}
}
final Map<QualifiedName, PySubscriptionExpression> baseClassQNameToExpr = new HashMap<>();
if (classStub == null) {
allBaseClassesQNames = PyClassElementType.getSuperClassQNames(pyClass).keySet();
}
else {
allBaseClassesQNames = classStub.getSuperClasses().keySet();
}
for (PySubscriptionExpression subscriptedBase : subscriptedBaseClasses) {
final PyExpression operand = subscriptedBase.getOperand();
if (operand instanceof PyReferenceExpression) {
final QualifiedName className = PyPsiUtils.asQualifiedName(operand);
baseClassQNameToExpr.put(className, subscriptedBase);
}
}
for (QualifiedName qName : allBaseClassesQNames) {
if (qName == null) continue;
final List<PsiElement> classes = PyResolveUtil.resolveQualifiedNameInScope(qName, (PyFile)pyClass.getContainingFile(), context);
// Better way to handle results of the multiresove
final PyClass firstFound = ContainerUtil.findInstance(classes, PyClass.class);
if (firstFound != null) {
results.put(firstFound, baseClassQNameToExpr.get(qName));
PyType type = Ref.deref(getType(superClassExpression, context));
if (type instanceof PyClassType classType) {
results.add(classType);
}
}
return results;
}
@NotNull
private static List<PyExpression> getSubscriptionIndices(@NotNull PySubscriptionExpression expr) {
final PyExpression indexExpr = expr.getIndexExpression();
final PyTupleExpression tupleExpr = as(indexExpr, PyTupleExpression.class);
return tupleExpr != null ? Arrays.asList(tupleExpr.getElements()) : Collections.singletonList(indexExpr);
}
@NotNull
private static List<PyType> collectGenericTypes(@NotNull PyClass cls, @NotNull Context context) {
if (!isGeneric(cls, context.getTypeContext())) {
return Collections.emptyList();
}
// See https://mypy.readthedocs.io/en/stable/generics.html#defining-sub-classes-of-generic-classes
List<PySubscriptionExpression> parameterizedSuperClassExpressions = PyClassElementType.getSubscriptedSuperClassesStubLike(cls);
List<PySubscriptionExpression> parameterizedSuperClassExpressions =
ContainerUtil.filterIsInstance(PyClassElementType.getSuperClassExpressions(cls), PySubscriptionExpression.class);
PySubscriptionExpression genericAsSuperClass = ContainerUtil.find(parameterizedSuperClassExpressions, s -> {
return resolveToQualifiedNames(s.getOperand(), context.myContext).contains(GENERIC);
});

View File

@@ -21,8 +21,6 @@ import org.jetbrains.annotations.Nullable;
import java.io.IOException;
import java.util.*;
import static com.jetbrains.python.psi.PyUtil.as;
public class PyClassElementType extends PyStubElementType<PyClassStub, PyClass>
implements PyCustomizableStubElementType<PyClass, PyCustomClassStub, PyCustomClassStubType<? extends PyCustomClassStub>> {
@@ -73,26 +71,19 @@ public class PyClassElementType extends PyStubElementType<PyClassStub, PyClass>
return result;
}
@NotNull
private static List<PySubscriptionExpression> getSubscriptedSuperClasses(@NotNull PyClass pyClass) {
return ContainerUtil.mapNotNull(pyClass.getSuperClassExpressions(), x -> as(x, PySubscriptionExpression.class));
}
/**
* If the class' stub is present, return subscription expressions in the base classes list, converting
* If the class' stub is present, return expressions in the base classes list, converting
* their saved text chunks into {@link PyExpressionCodeFragment} and extracting top-level expressions
* from them. Otherwise, get suitable expressions directly from AST, but process them in the same way as
* if they were going to be saved in the stub.
* from them. Otherwise, get superclass expressions directly from AST.
*/
@NotNull
public static List<PySubscriptionExpression> getSubscriptedSuperClassesStubLike(@NotNull PyClass pyClass) {
public static List<PyExpression> getSuperClassExpressions(@NotNull PyClass pyClass) {
final PyClassStub classStub = pyClass.getStub();
if (classStub == null) {
return getSubscriptedSuperClasses(pyClass);
return List.of(pyClass.getSuperClassExpressions());
}
return ContainerUtil.mapNotNull(classStub.getSuperClassesText(),
x -> as(PyUtil.createExpressionFromFragment(x, pyClass.getContainingFile()),
PySubscriptionExpression.class));
return ContainerUtil.mapNotNull(classStub.getSuperClassesText(),
x -> PyUtil.createExpressionFromFragment(x, pyClass.getContainingFile()));
}
@Nullable

View File

@@ -0,0 +1,11 @@
from typing import Generic, TypeVar
T = TypeVar('T')
class Super(Generic[T]):
pass
Alias = Super
class Sub(Alias[T]):
pass

View File

@@ -2734,6 +2734,47 @@ public class PyTypingTest extends PyTestCase {
return result;
}
// PY-59548
public void testGenericBaseClassSpecifiedThroughAlias() {
doTest("int",
"""
from typing import Generic, TypeVar
T = TypeVar('T')
class Super(Generic[T]):
pass
Alias = Super
class Sub(Alias[T]):
pass
def f(x: Super[T]) -> T:
pass
arg: Sub[int]
expr = f(arg)
""");
}
// PY-59548
public void testGenericBaseClassSpecifiedThroughAliasInImportedFile() {
doMultiFileStubAwareTest("int",
"""
from typing import TypeVar
from mod import Sub, Super
T = TypeVar('T')
def f(x: Super[T]) -> T:
pass
arg: Sub[int]
expr = f(arg)
""");
}
private void doTestNoInjectedText(@NotNull String text) {
myFixture.configureByText(PythonFileType.INSTANCE, text);
final InjectedLanguageManager languageManager = InjectedLanguageManager.getInstance(myFixture.getProject());