[pycharm] PY-71726 add initial support for TypeIs

GitOrigin-RevId: ae3d5d7c88450bf2851fa1dabbf264f7206d347a
This commit is contained in:
Vladimir Koshelev
2024-07-19 12:24:11 +02:00
committed by intellij-monorepo-bot
parent 86421462bb
commit 6a08992395
5 changed files with 178 additions and 35 deletions

View File

@@ -130,15 +130,21 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
super.visitPyBinaryExpression(node);
}
/**
* @param isStrict is false means that a type guard makes the assertion
*/
@Nullable
@ApiStatus.Internal
public static Ref<PyType> createAssertionType(@Nullable PyType initial,
@Nullable PyType suggested,
boolean positive,
boolean transformToDefinition,
@NotNull TypeEvalContext context,
@Nullable PyExpression typeElement) {
@Nullable PyType suggested,
boolean positive,
boolean transformToDefinition,
boolean isStrict,
@NotNull TypeEvalContext context,
@Nullable PyExpression typeElement) {
final PyType transformedType = transformTypeFromAssertion(suggested, transformToDefinition, context, typeElement);
// non-strict type guard
if (!isStrict) return Ref.create((positive) ? suggested : initial);
if (positive) {
if (!(initial instanceof PyUnionType) &&
!(initial instanceof PyStructuralType) &&
@@ -146,19 +152,31 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
PyTypeChecker.match(transformedType, initial, context)) {
return Ref.create(initial);
}
if (initial instanceof PyUnionType unionType) {
if (!unionType.isWeak()) {
var matched = unionType.getMembers().stream().filter((member) -> match(member, transformedType, context)).toList();
if (!matched.isEmpty()) {
return Ref.create(PyUnionType.union(matched));
}
}
}
return Ref.create(transformedType);
}
else if (initial instanceof PyUnionType) {
return Ref.create(((PyUnionType)initial).exclude(transformedType, context));
}
else if (!(initial instanceof PyStructuralType) &&
!PyTypeChecker.isUnknown(initial, context) &&
PyTypeChecker.match(transformedType, initial, context)) {
else if (match(initial, transformedType, context)) {
return null;
}
return Ref.create(initial);
}
private static boolean match(@Nullable PyType initial, PyType transformedType, @NotNull TypeEvalContext context) {
return !(initial instanceof PyStructuralType) &&
!PyTypeChecker.isUnknown(initial, context) &&
PyTypeChecker.match(transformedType, initial, context);
}
@Nullable
private static PyType transformTypeFromAssertion(@Nullable PyType type, boolean transformToDefinition, @NotNull TypeEvalContext context,
@Nullable PyExpression typeElement) {
@@ -208,7 +226,12 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
final InstructionTypeCallback typeCallback = new InstructionTypeCallback() {
@Override
public Ref<PyType> getType(TypeEvalContext context, @Nullable PsiElement anchor) {
return createAssertionType(context.getType(target), suggestedType.apply(context), positive, transformToDefinition, context,
return createAssertionType(context.getType(target),
suggestedType.apply(context),
positive,
transformToDefinition,
/*isStrict*/ true,
context,
typeElement);
}
};

View File

@@ -64,6 +64,8 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
public static final String TYPED_DICT_EXT = "typing_extensions.TypedDict";
public static final String TYPE_GUARD = "typing.TypeGuard";
public static final String TYPE_GUARD_EXT = "typing_extensions.TypeGuard";
public static final String TYPE_IS = "typing.TypeIs";
public static final String TYPE_IS_EXT = "typing_extensions.TypeIs";
public static final String GENERIC = "typing.Generic";
public static final String PROTOCOL = "typing.Protocol";
public static final String PROTOCOL_EXT = "typing_extensions.Protocol";
@@ -341,7 +343,7 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
public Ref<PyType> getReturnType(@NotNull PyCallable callable, @NotNull Context context) {
if (callable instanceof PyFunction function) {
if (isTypeGuard(function, context.myContext)) {
if (getTypeGuardKind(function, context.myContext) != TypeGuardKind.None) {
return Ref.create(PyBuiltinCache.getInstance(callable).getBoolType());
}
final PyExpression returnTypeAnnotation = getReturnTypeAnnotation(function, context.myContext);
@@ -395,18 +397,26 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
.orElse(null);
}
if (callSite instanceof PyCallExpression callExpression && isTypeGuard(function, context.myContext)) {
var arguments = callSite.getArguments(function);
if (!arguments.isEmpty() && arguments.get(0) instanceof PyReferenceExpression refExpr) {
var qname = PyPsiUtilsCore.asQualifiedName(refExpr);
if (qname != null) {
var narrowedType = getTypeFromTypeGuardLikeType(function, context.myContext);
if (narrowedType != null) {
return Ref.create(PyNarrowedType.Companion.create(callSite, qname.toString(), narrowedType, callExpression, false));
if (callSite instanceof PyCallExpression callExpression) {
var typeGuardKind = getTypeGuardKind(function, context.myContext);
if (typeGuardKind != TypeGuardKind.None) {
var arguments = callSite.getArguments(function);
if (!arguments.isEmpty() && arguments.get(0) instanceof PyReferenceExpression refExpr) {
var qname = PyPsiUtilsCore.asQualifiedName(refExpr);
if (qname != null) {
var narrowedType = getTypeFromTypeGuardLikeType(function, context.myContext);
if (narrowedType != null) {
return Ref.create(PyNarrowedType.Companion.create(callSite,
qname.toString(),
narrowedType,
callExpression,
false,
TypeGuardKind.TypeIs.equals(typeGuardKind)));
}
}
}
return Ref.create(PyBuiltinCache.getInstance(function).getBoolType());
}
return Ref.create(PyBuiltinCache.getInstance(function).getBoolType());
}
if (callSite instanceof PyCallExpression) {
@@ -1178,30 +1188,37 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
private static <T extends PyTypeCommentOwner & PyAnnotationOwner> boolean typeHintedWithName(@NotNull T owner,
@NotNull TypeEvalContext context,
String... names) {
return ContainerUtil.exists(names, resolveTypeHintsToQualifiedNames(owner, context)::contains);
}
private static <T extends PyTypeCommentOwner & PyAnnotationOwner> Collection<String> resolveTypeHintsToQualifiedNames(
@NotNull T owner,
@NotNull TypeEvalContext context
) {
var annotation = getAnnotationValue(owner, context);
if (annotation instanceof PyStringLiteralExpression stringLiteralExpression) {
final var annotationText = stringLiteralExpression.getStringValue();
annotation = toExpression(annotationText, owner);
if (annotation == null) return false;
if (annotation == null) return Collections.emptyList();
}
if (annotation instanceof PySubscriptionExpression) {
return resolvesToQualifiedNames(((PySubscriptionExpression)annotation).getOperand(), context, names);
if (annotation instanceof PySubscriptionExpression pySubscriptionExpression) {
return resolveToQualifiedNames(pySubscriptionExpression.getOperand(), context);
}
else if (annotation instanceof PyReferenceExpression) {
return resolvesToQualifiedNames(annotation, context, names);
return resolveToQualifiedNames(annotation, context);
}
final String typeCommentValue = owner.getTypeCommentAnnotation();
final PyExpression typeComment = typeCommentValue == null ? null : toExpression(typeCommentValue, owner);
if (typeComment instanceof PySubscriptionExpression) {
return resolvesToQualifiedNames(((PySubscriptionExpression)typeComment).getOperand(), context, names);
if (typeComment instanceof PySubscriptionExpression pySubscriptionExpression) {
return resolveToQualifiedNames(pySubscriptionExpression.getOperand(), context);
}
else if (typeComment instanceof PyReferenceExpression) {
return resolvesToQualifiedNames(typeComment, context, names);
return resolveToQualifiedNames(typeComment, context);
}
return false;
return Collections.emptyList();
}
public static boolean isFinal(@NotNull PyDecoratable decoratable, @NotNull TypeEvalContext context) {
@@ -1224,11 +1241,16 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
typeHintedWithName(function, context, NO_RETURN, NO_RETURN_EXT, NEVER, NEVER_EXT));
}
public static boolean isTypeGuard(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
return PyUtil.getParameterizedCachedValue(function, context, p ->
typeHintedWithName(function, context, TYPE_GUARD, TYPE_GUARD_EXT));
public static TypeGuardKind getTypeGuardKind(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
return PyUtil.getParameterizedCachedValue(function, context, p -> {
var typeHints = resolveTypeHintsToQualifiedNames(function, context);
if (typeHints.contains(TYPE_GUARD) || typeHints.contains(TYPE_GUARD_EXT)) return TypeGuardKind.TypeGuard;
if (typeHints.contains(TYPE_IS) || typeHints.contains(TYPE_IS_EXT)) return TypeGuardKind.TypeIs;
return TypeGuardKind.None;
});
}
@Nullable
public static PyType getTypeFromTypeGuardLikeType(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
var returnType = getReturnTypeAnnotation(function, context);
@@ -2123,4 +2145,10 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
return Objects.hash(myContext);
}
}
public enum TypeGuardKind {
TypeGuard,
TypeIs,
None
}
}

View File

@@ -498,6 +498,7 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
narrowedType.getNarrowedType(),
conditionalInstruction.getResult() ^ narrowedType.getNegated(),
false,
narrowedType.getTypeIs(),
context,
null);
}

View File

@@ -4,28 +4,30 @@ import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyElement
import com.jetbrains.python.psi.impl.PyBuiltinCache
import org.jetbrains.annotations.ApiStatus
/**
* Class is used for representing TypeGuard and TypeIs behavior
*/
@ApiStatus.Internal
class PyNarrowedType private constructor(
pyClass: PyClass,
val qname: String,
val narrowedType: PyType,
val original: PyCallExpression,
// used
val negated: Boolean)
val negated: Boolean,
val typeIs: Boolean)
: PyClassTypeImpl(pyClass, false) {
fun negate(): PyNarrowedType {
return PyNarrowedType(pyClass, qname, narrowedType, original, !negated)
return PyNarrowedType(pyClass, qname, narrowedType, original, !negated, typeIs)
}
companion object {
fun create(anchor: PyElement, name: String, narrowedType: PyType, original: PyCallExpression, negated: Boolean = false): PyNarrowedType? {
fun create(anchor: PyElement, name: String, narrowedType: PyType, original: PyCallExpression, negated: Boolean = false, typeIs: Boolean): PyNarrowedType? {
val pyClass = PyBuiltinCache.getInstance(anchor).getClass("bool")
if (pyClass == null) return null
return PyNarrowedType(pyClass, name, narrowedType, original, negated)
return PyNarrowedType(pyClass, name, narrowedType, original, negated, typeIs)
}
}
}

View File

@@ -2083,6 +2083,42 @@ public class Py3TypeTest extends PyTestCase {
""");
}
public void testFailedTypeGuardCheckDoesntAffectOriginalType() {
doTest("list[int] | list[str]",
"""
from typing import List
from typing import TypeGuard
def is_str_list(val: List[object]) -> TypeGuard[List[str]]:
return all(isinstance(x, str) for x in val)
def func1(val: List[int] | List[str]):
if not is_str_list(val):
expr = val
else:
pass
""");
}
public void testFailedTypeIsCheckDoesAffectOriginalType() {
doTest("list[int]",
"""
from typing import List
from typing_extensions import TypeIs
def is_str_list(val: List[object]) -> TypeIs[List[str]]:
return all(isinstance(x, str) for x in val)
def func1(val: List[int] | List[str]):
if not is_str_list(val):
expr = val
else:
pass
""");
}
public void testNoReturn() {
doTest("Bar",
"""
@@ -2103,6 +2139,59 @@ public class Py3TypeTest extends PyTestCase {
""");
}
public void testTypeIs1() {
doTest("str", """
from typing import Any, Callable, Literal, Mapping, Sequence, TypeVar, Union
from typing_extensions import TypeIs
def is_str1(val: Union[str, int]) -> TypeIs[str]:
return isinstance(val, str)
def func1(val: Union[str, int]):
if is_str1(val):
expr = val
else:
pass
""");
}
public void testTypeIs2() {
doTest("int", """
from typing import Any, Callable, Literal, Mapping, Sequence, TypeVar, Union
from typing_extensions import TypeIs
def is_str1(val: Union[str, int]) -> TypeIs[str]:
return isinstance(val, str)
def func1(val: Union[str, int]):
if is_str1(val):
pass
else:
expr = val
""");
}
public void testTypeIs3() {
doTest("list[str] | list[int]", """
from typing import Any, Callable, Literal, Mapping, Sequence, TypeVar, Union
from typing_extensions import TypeIs
def is_list(val: object) -> TypeIs[list[Any]]:
return isinstance(val, list)
def func3(val: dict[str, str] | list[str] | list[int] | Sequence[int]):
if is_list(val):
expr = val
else:
pass
""");
}
// PY-61137
public void testLiteralStringIsNotInferredWithoutExplicitAnnotation() {
doTest("list[str]",