PY-26184 fix type hinting information for bound generics lost in descriptors

As far as `__get__` call when accessing the attribute is implicit, create a synthetic call considering the type of callsite (access via instance or via class) and use its type as a type of property typed with descriptor class.

GitOrigin-RevId: acc36ebd2d62acfe99a5202b2478356f7b7aea46
This commit is contained in:
Daniil Kalinin
2024-05-21 11:22:37 +02:00
committed by intellij-monorepo-bot
parent ea989a3e05
commit c71a02fa78
6 changed files with 300 additions and 20 deletions

View File

@@ -34,6 +34,7 @@ import java.util.function.Predicate;
import static com.jetbrains.python.psi.PyUtil.as;
import static com.jetbrains.python.psi.impl.PyCallExpressionHelper.getCalleeType;
import static com.jetbrains.python.psi.types.PyTypeUtil.notNullToRef;
/**
* Implements reference expression PSI.
@@ -203,7 +204,7 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
if (qualified && typeFromTargets instanceof PyNoneType) {
return null;
}
final Ref<PyType> descriptorType = getDescriptorType(typeFromTargets, context);
final Ref<PyType> descriptorType = PyDescriptorTypeUtil.getDescriptorType(this, typeFromTargets, context);
if (descriptorType != null) {
return descriptorType.get();
}
@@ -225,23 +226,6 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
return null;
}
@Nullable
private Ref<PyType> getDescriptorType(@Nullable PyType typeFromTargets, @NotNull TypeEvalContext context) {
if (!isQualified()) return null;
final PyClassLikeType targetType = as(typeFromTargets, PyClassLikeType.class);
if (targetType == null || targetType.isDefinition()) return null;
final PyResolveContext resolveContext = PyResolveContext.noProperties(context);
final List<? extends RatedResolveResult> members = targetType.resolveMember(PyNames.GET, this, AccessDirection.READ,
resolveContext);
if (members == null || members.isEmpty()) return null;
final PyType type = StreamEx.of(members)
.map(RatedResolveResult::getElement)
.select(PyCallable.class)
.map(context::getReturnType)
.collect(PyTypeUtil.toUnion());
return Ref.create(type);
}
@Nullable
private Ref<PyType> getQualifiedReferenceType(@NotNull TypeEvalContext context) {
if (!context.maySwitchToAST(this)) {

View File

@@ -95,8 +95,8 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
return this;
}
@Override
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
@Nullable
private PyType getTargetExpressionType(@NotNull TypeEvalContext context) {
if (PyNames.ALL.equals(getName())) {
// no type for __all__, to avoid unresolved reference errors for expressions where a qualifier is a name
// imported via __all__
@@ -185,6 +185,19 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
return null;
}
@Override
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
PyType type = getTargetExpressionType(context);
if (type != null) {
Ref<PyType> typeFromDescriptors = PyDescriptorTypeUtil.getDescriptorType(this, type, context);
if (typeFromDescriptors != null) {
return typeFromDescriptors.get();
}
return type;
}
return null;
}
private @Nullable PyType getTargetTypeFromIterableUnpacking(@NotNull PySequenceExpression topmostContainingTupleOrList,
@NotNull PyExpression assignedIterable,
@NotNull TypeEvalContext context) {

View File

@@ -0,0 +1,57 @@
package com.jetbrains.python.psi.types;
import com.intellij.openapi.util.Ref;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import static com.jetbrains.python.psi.PyUtil.as;
public final class PyDescriptorTypeUtil {
private PyDescriptorTypeUtil() { }
@Nullable
public static Ref<PyType> getDescriptorType(@NotNull PyQualifiedExpression expression, @Nullable PyType typeFromTargets, @NotNull TypeEvalContext context) {
if (!expression.isQualified()) return null;
final PyClassLikeType targetType = as(typeFromTargets, PyClassLikeType.class);
if (targetType == null || targetType.isDefinition()) return null;
final PyResolveContext resolveContext = PyResolveContext.noProperties(context);
final List<? extends RatedResolveResult> members = targetType.resolveMember(PyNames.GET, expression, AccessDirection.READ,
resolveContext);
if (members == null || members.isEmpty()) return null;
return getTypeFromSyntheticDunderGetCall(expression, typeFromTargets, context);
}
@Nullable
private static Ref<PyType> getTypeFromSyntheticDunderGetCall(@NotNull PyQualifiedExpression expression, @NotNull PyType typeFromTargets, @NotNull TypeEvalContext context) {
PyExpression qualifier = expression.getQualifier();
if (qualifier != null && typeFromTargets instanceof PyCallableType receiverType) {
PyType qualifierType = context.getType(qualifier);
if (qualifierType instanceof PyClassType classType) {
PyType instanceArgumentType;
PyType instanceTypeArgument;
if (classType.isDefinition()) {
instanceArgumentType = PyNoneType.INSTANCE;
instanceTypeArgument = classType;
}
else {
instanceArgumentType = classType;
instanceTypeArgument = PyNoneType.INSTANCE;
}
List<PyType> argumentTypes = List.of(instanceArgumentType, instanceTypeArgument);
PyType type = PySyntheticCallHelper.getCallTypeByFunctionName(PyNames.GET, receiverType, argumentTypes, context);
return Ref.create(type);
}
}
return null;
}
}

View File

@@ -0,0 +1,14 @@
from typing import Optional, Any, overload, Union
class MyDescriptor[T]:
@overload
def __get__(self, instance: None, owner: Any) -> T: # access via class
...
@overload
def __get__(self, instance: object, owner: Any) -> str: # access via instance
...
def __get__(self, instance: Optional[object], owner: Any) -> Union[str, T]:
...
class Test():
member: MyDescriptor[int]

View File

@@ -0,0 +1,14 @@
from typing import Optional, Any, overload, Union
class MyDescriptor[T]:
@overload
def __get__(self, instance: None, owner: Any) -> T: # access via class
...
@overload
def __get__(self, instance: object, owner: Any) -> str: # access via instance
...
def __get__(self, instance: Optional[object], owner: Any) -> Union[str, T]:
...
class Test():
member: MyDescriptor[int]

View File

@@ -2541,6 +2541,204 @@ public class Py3TypeTest extends PyTestCase {
""");
}
// PY-26184
public void testGenericTypeFromDescriptor() {
doTest("list", """
import typing
class MyDescriptor[T]:
def __init__(self, requested_type: typing.Type[T]):
self.requested_type = requested_type
def __get__(self, instance: typing.Any, owner: typing.Any) -> T:
raise Exception("Not implemented")
class Test:
member = MyDescriptor(list)
def foo(self):
test = self.member
expr = test
""");
}
// PY-26184
public void testGenericTypeFromDescriptorWithTypeAnnotationOnly() {
doTest("list", """
from typing import Type, Any
class MyDescriptor[T]:
def __get__(self, instance: typing.Any, owner: typing.Any) -> T:
raise Exception("Not implemented")
class Test:
member: MyDescriptor[list]
def foo(self):
test = self.member
expr = test
""");
}
// PY-26184
public void testGenericTypeFromDescriptorWithTypeAnnotationPriority() {
doTest("list", """
from typing import Type, Any
class MyDescriptor[T]:
def __init__(self, requested_type: Type[T]):
self.requested_type = requested_type
def __get__(self, instance: Any, owner: Any) -> T:
raise Exception("Not implemented")
class Test:
member: MyDescriptor[list] = MyDescriptor(str)
def foo(self):
test = self.member
expr = test
""");
}
// PY-26184
public void testGenericDescriptorAccessViaInstance() {
doTest("int", """
from typing import Optional, Any, overload, Union
class MyDescriptor[T]:
@overload
def __get__(self, instance: None, owner: Any) -> str: # access via class
...
@overload
def __get__(self, instance: object, owner: Any) -> T: # access via instance
...
def __get__(self, instance: Optional[object], owner: Any) -> Union[str, T]:
...
class Foo():
x = MyDescriptor[int]()
foo = Foo()
expr = foo.x
""");
}
// PY-26184
public void testGenericDescriptorAccessViaInstanceReturnsExplicitAny() {
doTest("Any", """
from typing import Optional, Any, overload, Union
class MyDescriptor[T]:
@overload
def __get__(self, instance: None, owner: Any) -> str: # access via class
...
@overload
def __get__(self, instance: object, owner: Any) -> Any: # access via instance
...
def __get__(self, instance: Optional[object], owner: Any) -> Union[str, T]:
...
class Foo():
x = MyDescriptor[int]()
foo = Foo()
expr = foo.x
""");
}
// PY-26184
public void testGenericDescriptorAccessViaClass() {
doTest("int", """
from typing import Optional, Any, overload, Union
class MyDescriptor[T]:
@overload
def __get__(self, instance: None, owner: Any) -> T: # access via class
...
@overload
def __get__(self, instance: object, owner: Any) -> str: # access via instance
...
def __get__(self, instance: Optional[object], owner: Any) -> Union[str, T]:
...
class Foo():
x = MyDescriptor[int]()
expr = Foo.x
""");
}
// PY-26184
public void testGenericDescriptorAccessViaClassReturnsExplicitAny() {
doTest("Any", """
from typing import Optional, Any, overload, Union
class MyDescriptor[T]:
@overload
def __get__(self, instance: None, owner: Any) -> Any: # access via class
...
@overload
def __get__(self, instance: object, owner: Any) -> str: # access via instance
...
def __get__(self, instance: Optional[object], owner: Any) -> Union[str, T]:
...
class Foo():
x = MyDescriptor[int]()
expr = Foo.x
""");
}
// PY-26184
public void testGenericDescriptorAccessViaClassReturnsNothing() {
doTest("None", """
from typing import Optional, Any, overload, Union
class MyDescriptor[T]:
@overload
def __get__(self, instance: None, owner: Any): # access via class
...
@overload
def __get__(self, instance: object, owner: Any) -> T: # access via instance
...
def __get__(self, instance: Optional[object], owner: Any) -> Union[str, T]:
...
class Foo():
x = MyDescriptor[int]()
expr = Foo.x
""");
}
// PY-26184
public void testGenericTypeFromParameterizedOnInheritanceDescriptorWithTypeAnnotationOnly() {
doTest("str", """
from typing import Any
class MyDescriptor[T]:
def __get__(self, instance: Any, owner: Any) -> T:
...
class StrDescriptor(MyDescriptor[str]):
pass
class Test:
member: StrDescriptor
def foo(self):
test = self.member
expr = test
""");
}
// PY-26184
public void testGenericTypeFromDescriptorDefinedWithTypeAnnotationInExternalFileAccessViaInstance() {
doMultiFileTest("str", """
from a import Test
test = Test()
expr = test.member
""");
}
private void doTest(final String expectedType, final String text) {
myFixture.configureByText(PythonFileType.INSTANCE, text);
final PyExpression expr = myFixture.findElementByText("expr", PyExpression.class);