mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-04-20 05:21:29 +07:00
[pycharm] PY-71726 add initial support for TypeIs
GitOrigin-RevId: ae3d5d7c88450bf2851fa1dabbf264f7206d347a
This commit is contained in:
committed by
intellij-monorepo-bot
parent
86421462bb
commit
6a08992395
@@ -130,15 +130,21 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
super.visitPyBinaryExpression(node);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param isStrict is false means that a type guard makes the assertion
|
||||
*/
|
||||
@Nullable
|
||||
@ApiStatus.Internal
|
||||
public static Ref<PyType> createAssertionType(@Nullable PyType initial,
|
||||
@Nullable PyType suggested,
|
||||
boolean positive,
|
||||
boolean transformToDefinition,
|
||||
@NotNull TypeEvalContext context,
|
||||
@Nullable PyExpression typeElement) {
|
||||
@Nullable PyType suggested,
|
||||
boolean positive,
|
||||
boolean transformToDefinition,
|
||||
boolean isStrict,
|
||||
@NotNull TypeEvalContext context,
|
||||
@Nullable PyExpression typeElement) {
|
||||
final PyType transformedType = transformTypeFromAssertion(suggested, transformToDefinition, context, typeElement);
|
||||
// non-strict type guard
|
||||
if (!isStrict) return Ref.create((positive) ? suggested : initial);
|
||||
if (positive) {
|
||||
if (!(initial instanceof PyUnionType) &&
|
||||
!(initial instanceof PyStructuralType) &&
|
||||
@@ -146,19 +152,31 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
PyTypeChecker.match(transformedType, initial, context)) {
|
||||
return Ref.create(initial);
|
||||
}
|
||||
if (initial instanceof PyUnionType unionType) {
|
||||
if (!unionType.isWeak()) {
|
||||
var matched = unionType.getMembers().stream().filter((member) -> match(member, transformedType, context)).toList();
|
||||
if (!matched.isEmpty()) {
|
||||
return Ref.create(PyUnionType.union(matched));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ref.create(transformedType);
|
||||
}
|
||||
else if (initial instanceof PyUnionType) {
|
||||
return Ref.create(((PyUnionType)initial).exclude(transformedType, context));
|
||||
}
|
||||
else if (!(initial instanceof PyStructuralType) &&
|
||||
!PyTypeChecker.isUnknown(initial, context) &&
|
||||
PyTypeChecker.match(transformedType, initial, context)) {
|
||||
else if (match(initial, transformedType, context)) {
|
||||
return null;
|
||||
}
|
||||
return Ref.create(initial);
|
||||
}
|
||||
|
||||
private static boolean match(@Nullable PyType initial, PyType transformedType, @NotNull TypeEvalContext context) {
|
||||
return !(initial instanceof PyStructuralType) &&
|
||||
!PyTypeChecker.isUnknown(initial, context) &&
|
||||
PyTypeChecker.match(transformedType, initial, context);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private static PyType transformTypeFromAssertion(@Nullable PyType type, boolean transformToDefinition, @NotNull TypeEvalContext context,
|
||||
@Nullable PyExpression typeElement) {
|
||||
@@ -208,7 +226,12 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
final InstructionTypeCallback typeCallback = new InstructionTypeCallback() {
|
||||
@Override
|
||||
public Ref<PyType> getType(TypeEvalContext context, @Nullable PsiElement anchor) {
|
||||
return createAssertionType(context.getType(target), suggestedType.apply(context), positive, transformToDefinition, context,
|
||||
return createAssertionType(context.getType(target),
|
||||
suggestedType.apply(context),
|
||||
positive,
|
||||
transformToDefinition,
|
||||
/*isStrict*/ true,
|
||||
context,
|
||||
typeElement);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -64,6 +64,8 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
||||
public static final String TYPED_DICT_EXT = "typing_extensions.TypedDict";
|
||||
public static final String TYPE_GUARD = "typing.TypeGuard";
|
||||
public static final String TYPE_GUARD_EXT = "typing_extensions.TypeGuard";
|
||||
public static final String TYPE_IS = "typing.TypeIs";
|
||||
public static final String TYPE_IS_EXT = "typing_extensions.TypeIs";
|
||||
public static final String GENERIC = "typing.Generic";
|
||||
public static final String PROTOCOL = "typing.Protocol";
|
||||
public static final String PROTOCOL_EXT = "typing_extensions.Protocol";
|
||||
@@ -341,7 +343,7 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
||||
public Ref<PyType> getReturnType(@NotNull PyCallable callable, @NotNull Context context) {
|
||||
if (callable instanceof PyFunction function) {
|
||||
|
||||
if (isTypeGuard(function, context.myContext)) {
|
||||
if (getTypeGuardKind(function, context.myContext) != TypeGuardKind.None) {
|
||||
return Ref.create(PyBuiltinCache.getInstance(callable).getBoolType());
|
||||
}
|
||||
final PyExpression returnTypeAnnotation = getReturnTypeAnnotation(function, context.myContext);
|
||||
@@ -395,18 +397,26 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
||||
.orElse(null);
|
||||
}
|
||||
|
||||
if (callSite instanceof PyCallExpression callExpression && isTypeGuard(function, context.myContext)) {
|
||||
var arguments = callSite.getArguments(function);
|
||||
if (!arguments.isEmpty() && arguments.get(0) instanceof PyReferenceExpression refExpr) {
|
||||
var qname = PyPsiUtilsCore.asQualifiedName(refExpr);
|
||||
if (qname != null) {
|
||||
var narrowedType = getTypeFromTypeGuardLikeType(function, context.myContext);
|
||||
if (narrowedType != null) {
|
||||
return Ref.create(PyNarrowedType.Companion.create(callSite, qname.toString(), narrowedType, callExpression, false));
|
||||
if (callSite instanceof PyCallExpression callExpression) {
|
||||
var typeGuardKind = getTypeGuardKind(function, context.myContext);
|
||||
if (typeGuardKind != TypeGuardKind.None) {
|
||||
var arguments = callSite.getArguments(function);
|
||||
if (!arguments.isEmpty() && arguments.get(0) instanceof PyReferenceExpression refExpr) {
|
||||
var qname = PyPsiUtilsCore.asQualifiedName(refExpr);
|
||||
if (qname != null) {
|
||||
var narrowedType = getTypeFromTypeGuardLikeType(function, context.myContext);
|
||||
if (narrowedType != null) {
|
||||
return Ref.create(PyNarrowedType.Companion.create(callSite,
|
||||
qname.toString(),
|
||||
narrowedType,
|
||||
callExpression,
|
||||
false,
|
||||
TypeGuardKind.TypeIs.equals(typeGuardKind)));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ref.create(PyBuiltinCache.getInstance(function).getBoolType());
|
||||
}
|
||||
return Ref.create(PyBuiltinCache.getInstance(function).getBoolType());
|
||||
}
|
||||
|
||||
if (callSite instanceof PyCallExpression) {
|
||||
@@ -1178,30 +1188,37 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
||||
private static <T extends PyTypeCommentOwner & PyAnnotationOwner> boolean typeHintedWithName(@NotNull T owner,
|
||||
@NotNull TypeEvalContext context,
|
||||
String... names) {
|
||||
return ContainerUtil.exists(names, resolveTypeHintsToQualifiedNames(owner, context)::contains);
|
||||
}
|
||||
|
||||
private static <T extends PyTypeCommentOwner & PyAnnotationOwner> Collection<String> resolveTypeHintsToQualifiedNames(
|
||||
@NotNull T owner,
|
||||
@NotNull TypeEvalContext context
|
||||
) {
|
||||
var annotation = getAnnotationValue(owner, context);
|
||||
if (annotation instanceof PyStringLiteralExpression stringLiteralExpression) {
|
||||
final var annotationText = stringLiteralExpression.getStringValue();
|
||||
annotation = toExpression(annotationText, owner);
|
||||
if (annotation == null) return false;
|
||||
if (annotation == null) return Collections.emptyList();
|
||||
}
|
||||
|
||||
if (annotation instanceof PySubscriptionExpression) {
|
||||
return resolvesToQualifiedNames(((PySubscriptionExpression)annotation).getOperand(), context, names);
|
||||
if (annotation instanceof PySubscriptionExpression pySubscriptionExpression) {
|
||||
return resolveToQualifiedNames(pySubscriptionExpression.getOperand(), context);
|
||||
}
|
||||
else if (annotation instanceof PyReferenceExpression) {
|
||||
return resolvesToQualifiedNames(annotation, context, names);
|
||||
return resolveToQualifiedNames(annotation, context);
|
||||
}
|
||||
|
||||
final String typeCommentValue = owner.getTypeCommentAnnotation();
|
||||
final PyExpression typeComment = typeCommentValue == null ? null : toExpression(typeCommentValue, owner);
|
||||
if (typeComment instanceof PySubscriptionExpression) {
|
||||
return resolvesToQualifiedNames(((PySubscriptionExpression)typeComment).getOperand(), context, names);
|
||||
if (typeComment instanceof PySubscriptionExpression pySubscriptionExpression) {
|
||||
return resolveToQualifiedNames(pySubscriptionExpression.getOperand(), context);
|
||||
}
|
||||
else if (typeComment instanceof PyReferenceExpression) {
|
||||
return resolvesToQualifiedNames(typeComment, context, names);
|
||||
return resolveToQualifiedNames(typeComment, context);
|
||||
}
|
||||
|
||||
return false;
|
||||
return Collections.emptyList();
|
||||
}
|
||||
|
||||
public static boolean isFinal(@NotNull PyDecoratable decoratable, @NotNull TypeEvalContext context) {
|
||||
@@ -1224,11 +1241,16 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
||||
typeHintedWithName(function, context, NO_RETURN, NO_RETURN_EXT, NEVER, NEVER_EXT));
|
||||
}
|
||||
|
||||
public static boolean isTypeGuard(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||
return PyUtil.getParameterizedCachedValue(function, context, p ->
|
||||
typeHintedWithName(function, context, TYPE_GUARD, TYPE_GUARD_EXT));
|
||||
public static TypeGuardKind getTypeGuardKind(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||
return PyUtil.getParameterizedCachedValue(function, context, p -> {
|
||||
var typeHints = resolveTypeHintsToQualifiedNames(function, context);
|
||||
if (typeHints.contains(TYPE_GUARD) || typeHints.contains(TYPE_GUARD_EXT)) return TypeGuardKind.TypeGuard;
|
||||
if (typeHints.contains(TYPE_IS) || typeHints.contains(TYPE_IS_EXT)) return TypeGuardKind.TypeIs;
|
||||
return TypeGuardKind.None;
|
||||
});
|
||||
}
|
||||
|
||||
|
||||
@Nullable
|
||||
public static PyType getTypeFromTypeGuardLikeType(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||
var returnType = getReturnTypeAnnotation(function, context);
|
||||
@@ -2123,4 +2145,10 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
||||
return Objects.hash(myContext);
|
||||
}
|
||||
}
|
||||
|
||||
public enum TypeGuardKind {
|
||||
TypeGuard,
|
||||
TypeIs,
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
@@ -498,6 +498,7 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
|
||||
narrowedType.getNarrowedType(),
|
||||
conditionalInstruction.getResult() ^ narrowedType.getNegated(),
|
||||
false,
|
||||
narrowedType.getTypeIs(),
|
||||
context,
|
||||
null);
|
||||
}
|
||||
|
||||
@@ -4,28 +4,30 @@ import com.jetbrains.python.psi.PyCallExpression
|
||||
import com.jetbrains.python.psi.PyClass
|
||||
import com.jetbrains.python.psi.PyElement
|
||||
import com.jetbrains.python.psi.impl.PyBuiltinCache
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Class is used for representing TypeGuard and TypeIs behavior
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class PyNarrowedType private constructor(
|
||||
pyClass: PyClass,
|
||||
val qname: String,
|
||||
val narrowedType: PyType,
|
||||
val original: PyCallExpression,
|
||||
// used
|
||||
val negated: Boolean)
|
||||
val negated: Boolean,
|
||||
val typeIs: Boolean)
|
||||
: PyClassTypeImpl(pyClass, false) {
|
||||
|
||||
fun negate(): PyNarrowedType {
|
||||
return PyNarrowedType(pyClass, qname, narrowedType, original, !negated)
|
||||
return PyNarrowedType(pyClass, qname, narrowedType, original, !negated, typeIs)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun create(anchor: PyElement, name: String, narrowedType: PyType, original: PyCallExpression, negated: Boolean = false): PyNarrowedType? {
|
||||
fun create(anchor: PyElement, name: String, narrowedType: PyType, original: PyCallExpression, negated: Boolean = false, typeIs: Boolean): PyNarrowedType? {
|
||||
val pyClass = PyBuiltinCache.getInstance(anchor).getClass("bool")
|
||||
if (pyClass == null) return null
|
||||
return PyNarrowedType(pyClass, name, narrowedType, original, negated)
|
||||
return PyNarrowedType(pyClass, name, narrowedType, original, negated, typeIs)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -2083,6 +2083,42 @@ public class Py3TypeTest extends PyTestCase {
|
||||
""");
|
||||
}
|
||||
|
||||
public void testFailedTypeGuardCheckDoesntAffectOriginalType() {
|
||||
doTest("list[int] | list[str]",
|
||||
"""
|
||||
from typing import List
|
||||
from typing import TypeGuard
|
||||
|
||||
def is_str_list(val: List[object]) -> TypeGuard[List[str]]:
|
||||
return all(isinstance(x, str) for x in val)
|
||||
|
||||
|
||||
def func1(val: List[int] | List[str]):
|
||||
if not is_str_list(val):
|
||||
expr = val
|
||||
else:
|
||||
pass
|
||||
""");
|
||||
}
|
||||
|
||||
public void testFailedTypeIsCheckDoesAffectOriginalType() {
|
||||
doTest("list[int]",
|
||||
"""
|
||||
from typing import List
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
|
||||
def is_str_list(val: List[object]) -> TypeIs[List[str]]:
|
||||
return all(isinstance(x, str) for x in val)
|
||||
|
||||
def func1(val: List[int] | List[str]):
|
||||
if not is_str_list(val):
|
||||
expr = val
|
||||
else:
|
||||
pass
|
||||
""");
|
||||
}
|
||||
|
||||
public void testNoReturn() {
|
||||
doTest("Bar",
|
||||
"""
|
||||
@@ -2103,6 +2139,59 @@ public class Py3TypeTest extends PyTestCase {
|
||||
""");
|
||||
}
|
||||
|
||||
public void testTypeIs1() {
|
||||
doTest("str", """
|
||||
from typing import Any, Callable, Literal, Mapping, Sequence, TypeVar, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
|
||||
def is_str1(val: Union[str, int]) -> TypeIs[str]:
|
||||
return isinstance(val, str)
|
||||
|
||||
|
||||
def func1(val: Union[str, int]):
|
||||
if is_str1(val):
|
||||
expr = val
|
||||
else:
|
||||
pass
|
||||
""");
|
||||
}
|
||||
|
||||
public void testTypeIs2() {
|
||||
doTest("int", """
|
||||
from typing import Any, Callable, Literal, Mapping, Sequence, TypeVar, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
|
||||
def is_str1(val: Union[str, int]) -> TypeIs[str]:
|
||||
return isinstance(val, str)
|
||||
|
||||
|
||||
def func1(val: Union[str, int]):
|
||||
if is_str1(val):
|
||||
pass
|
||||
else:
|
||||
expr = val
|
||||
""");
|
||||
}
|
||||
|
||||
public void testTypeIs3() {
|
||||
doTest("list[str] | list[int]", """
|
||||
from typing import Any, Callable, Literal, Mapping, Sequence, TypeVar, Union
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
def is_list(val: object) -> TypeIs[list[Any]]:
|
||||
return isinstance(val, list)
|
||||
|
||||
|
||||
def func3(val: dict[str, str] | list[str] | list[int] | Sequence[int]):
|
||||
if is_list(val):
|
||||
expr = val
|
||||
else:
|
||||
pass
|
||||
""");
|
||||
}
|
||||
|
||||
// PY-61137
|
||||
public void testLiteralStringIsNotInferredWithoutExplicitAnnotation() {
|
||||
doTest("list[str]",
|
||||
|
||||
Reference in New Issue
Block a user