diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyCallable.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyCallable.java index 35e748f9a832..d243bf1e590e 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyCallable.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyCallable.java @@ -59,12 +59,27 @@ public interface PyCallable extends PyAstCallable, PyTypedElement, PyQualifiedNa @Nullable PyType getCallType(@NotNull TypeEvalContext context, @NotNull PyCallSiteExpression callSite); + /** + * Please use getCallType with four arguments instead + */ + @Nullable + @Deprecated(forRemoval = true) + default PyType getCallType(@Nullable PyExpression receiver, + @NotNull Map parameters, + @NotNull TypeEvalContext context) { + return getCallType(receiver, null, parameters, context); + } + + + /** * Returns the type of the call to the callable where the call site is specified by the optional receiver and the arguments to parameters * mapping. */ @Nullable - PyType getCallType(@Nullable PyExpression receiver, @NotNull Map parameters, + PyType getCallType(@Nullable PyExpression receiver, + @Nullable PyCallSiteExpression pyCallSiteExpression, + @NotNull Map parameters, @NotNull TypeEvalContext context); /** diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java index 11498f38c190..292742a41355 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java @@ -138,34 +138,31 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { public static Ref createAssertionType(@Nullable PyType initial, @Nullable PyType suggested, boolean positive, - boolean transformToDefinition, boolean isStrict, - @NotNull TypeEvalContext context, - @Nullable PyExpression typeElement) { - final PyType transformedType = transformTypeFromAssertion(suggested, transformToDefinition, context, typeElement); + @NotNull TypeEvalContext context) { // non-strict type guard if (!isStrict) return Ref.create((positive) ? suggested : initial); if (positive) { if (!(initial instanceof PyUnionType) && !(initial instanceof PyStructuralType) && !PyTypeChecker.isUnknown(initial, context) && - PyTypeChecker.match(transformedType, initial, context)) { + PyTypeChecker.match(suggested, initial, context)) { return Ref.create(initial); } if (initial instanceof PyUnionType unionType) { if (!unionType.isWeak()) { - var matched = unionType.getMembers().stream().filter((member) -> match(member, transformedType, context)).toList(); + var matched = unionType.getMembers().stream().filter((member) -> match(member, suggested, context)).toList(); if (!matched.isEmpty()) { return Ref.create(PyUnionType.union(matched)); } } } - return Ref.create(transformedType); + return Ref.create(suggested); } else if (initial instanceof PyUnionType) { - return Ref.create(((PyUnionType)initial).exclude(transformedType, context)); + return Ref.create(((PyUnionType)initial).exclude(suggested, context)); } - else if (match(initial, transformedType, context)) { + else if (match(initial, suggested, context)) { return null; } return Ref.create(initial); @@ -180,6 +177,12 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { @Nullable private static PyType transformTypeFromAssertion(@Nullable PyType type, boolean transformToDefinition, @NotNull TypeEvalContext context, @Nullable PyExpression typeElement) { + /* + * We need to distinguish: + * if isinstance(x, (int, str)): + * And: + * if isinstance(x, (1, "")): + */ if (type instanceof PyTupleType tupleType) { final List members = new ArrayList<>(); final int count = tupleType.getElementCount(); @@ -227,12 +230,10 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { @Override public Ref getType(TypeEvalContext context, @Nullable PsiElement anchor) { return createAssertionType(context.getType(target), - suggestedType.apply(context), + transformTypeFromAssertion(suggestedType.apply(context), transformToDefinition, context, typeElement), positive, - transformToDefinition, /*isStrict*/ true, - context, - typeElement); + context); } }; diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java index 6f57b787451d..61380b52a26e 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java @@ -19,7 +19,6 @@ import com.jetbrains.python.PyCustomType; import com.jetbrains.python.PyNames; import com.jetbrains.python.PyTokenTypes; import com.jetbrains.python.ast.PyAstFunction; -import com.jetbrains.python.ast.impl.PyPsiUtilsCore; import com.jetbrains.python.ast.impl.PyUtilCore; import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache; import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; @@ -347,10 +346,6 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< @Override public Ref getReturnType(@NotNull PyCallable callable, @NotNull Context context) { if (callable instanceof PyFunction function) { - - if (getTypeGuardKind(function, context.myContext) != TypeGuardKind.None) { - return Ref.create(PyBuiltinCache.getInstance(callable).getBoolType()); - } final PyExpression returnTypeAnnotation = getReturnTypeAnnotation(function, context.myContext); if (returnTypeAnnotation != null) { final Ref typeRef = getType(returnTypeAnnotation, context); @@ -403,25 +398,25 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< } if (callSite instanceof PyCallExpression callExpression) { - var typeGuardKind = getTypeGuardKind(function, context.myContext); - if (typeGuardKind != TypeGuardKind.None) { - var arguments = callSite.getArguments(function); - if (!arguments.isEmpty() && arguments.get(0) instanceof PyReferenceExpression refExpr) { - var qname = PyPsiUtilsCore.asQualifiedName(refExpr); - if (qname != null) { - var narrowedType = getTypeFromTypeGuardLikeType(function, context.myContext); - if (narrowedType != null) { - return Ref.create(PyNarrowedType.Companion.create(callSite, - qname.toString(), - narrowedType, - callExpression, - false, - TypeGuardKind.TypeIs.equals(typeGuardKind))); - } - } - } - return Ref.create(PyBuiltinCache.getInstance(function).getBoolType()); - } + //var typeGuardKind = getTypeGuardKind(function, context.myContext); + //if (typeGuardKind != TypeGuardKind.None) { + // var arguments = callSite.getArguments(function); + // if (!arguments.isEmpty() && arguments.get(0) instanceof PyReferenceExpression refExpr) { + // var qname = PyPsiUtilsCore.asQualifiedName(refExpr); + // if (qname != null) { + // var narrowedType = getTypeFromTypeGuardLikeType(function, context.myContext); + // if (narrowedType != null) { + // return Ref.create(PyNarrowedType.Companion.create(callSite, + // qname.toString(), + // narrowedType, + // callExpression, + // false, + // TypeGuardKind.TypeIs.equals(typeGuardKind))); + // } + // } + // } + // return Ref.create(PyBuiltinCache.getInstance(function).getBoolType()); + //} } if (callSite instanceof PyCallExpression) { @@ -945,6 +940,10 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< if (selfType != null) { return selfType; } + final Ref narrowedType = getNarrowedType(resolved, context.getTypeContext()); + if (narrowedType != null) { + return narrowedType; + } return null; } finally { @@ -954,6 +953,18 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< } } + private static Ref getNarrowedType(@NotNull PsiElement resolved, @NotNull TypeEvalContext context) { + if (resolved instanceof PyExpression expression) { + Collection names = resolveToQualifiedNames(expression, context); + var isTypeIs = names.contains(TYPE_IS) || names.contains(TYPE_IS_EXT); + var isTypeGuard = names.contains(TYPE_GUARD) || names.contains(TYPE_GUARD_EXT); + if (isTypeIs || isTypeGuard) { + return Ref.create(PyNarrowedType.Companion.create(expression, isTypeIs)); + } + } + return null; + } + private static Ref getSelfType(@NotNull PsiElement resolved, @NotNull PyExpression typeHint, @NotNull Context context) { if (resolved instanceof PyQualifiedNameOwner && (SELF.equals(((PyQualifiedNameOwner)resolved).getQualifiedName()) || @@ -1952,9 +1963,12 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< return returnType; } + /** + * Bound narrowed types shouldn't leak out of its scope, since it is bound to a particular call site. + */ @Nullable public static PyType removeNarrowedTypeIfNeeded(@Nullable PyType type) { - if (type instanceof PyNarrowedType pyNarrowedType) { + if (type instanceof PyNarrowedType pyNarrowedType && pyNarrowedType.isBound()) { return PyBuiltinCache.getInstance(pyNarrowedType.getOriginal()).getBoolType(); } else { diff --git a/python/python-psi-impl/src/com/jetbrains/python/documentation/PyTypeModelBuilder.java b/python/python-psi-impl/src/com/jetbrains/python/documentation/PyTypeModelBuilder.java index 1f8eb6936377..ddfc075da3fe 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/documentation/PyTypeModelBuilder.java +++ b/python/python-psi-impl/src/com/jetbrains/python/documentation/PyTypeModelBuilder.java @@ -93,11 +93,16 @@ public class PyTypeModelBuilder { private final TypeModel collectionType; private final List elementTypes; private final boolean useTypingAlias; + private final Boolean isTypeIs; - private CollectionOf(TypeModel collectionType, List elementTypes, boolean useTypingAlias) { + private CollectionOf(TypeModel collectionType, + List elementTypes, + boolean useTypingAlias, + @Nullable Boolean isTypeIs) { this.collectionType = collectionType; this.elementTypes = elementTypes; this.useTypingAlias = useTypingAlias; + this.isTypeIs = isTypeIs; } @Override @@ -305,7 +310,10 @@ public class PyTypeModelBuilder { if (!elementModels.isEmpty()) { final TypeModel collectionType = build(new PyClassTypeImpl(asCollection.getPyClass(), asCollection.isDefinition()), false); boolean useTypingAlias = PyiUtil.getOriginalLanguageLevel(asCollection.getPyClass()).isOlderThan(LanguageLevel.PYTHON39); - result = new CollectionOf(collectionType, elementModels, useTypingAlias); + result = new CollectionOf(collectionType, + elementModels, + useTypingAlias, + asCollection instanceof PyNarrowedType pyNarrowedType ? pyNarrowedType.getTypeIs() : null); } } else if (type instanceof PyUnionType unionType && allowUnions) { @@ -649,7 +657,13 @@ public class PyTypeModelBuilder { protected void typingGenericFormat(CollectionOf collectionOf) { final boolean prevSwitchBuiltinToTyping = switchBuiltinToTyping; switchBuiltinToTyping = collectionOf.useTypingAlias; - collectionOf.collectionType.accept(this); + if (collectionOf.isTypeIs == null) { + collectionOf.collectionType.accept(this); + } else if (collectionOf.isTypeIs) { + add(styled("TypeIs", PyHighlighter.PY_CLASS_DEFINITION)); + } else { + add(styled("TypeGuard", PyHighlighter.PY_CLASS_DEFINITION)); + } switchBuiltinToTyping = prevSwitchBuiltinToTyping; if (!collectionOf.elementTypes.isEmpty()) { diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassImpl.java index 8c2d6f28121a..ee1ace6acb23 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassImpl.java @@ -914,7 +914,7 @@ public class PyClassImpl extends PyBaseElementImpl implements PyCla if (!(callable instanceof StubBasedPsiElement) && !context.maySwitchToAST(callable)) { return null; } - return callable.getCallType(receiver, buildArgumentsToParametersMap(receiver, callable, context), context); + return callable.getCallType(receiver, null, buildArgumentsToParametersMap(receiver, callable, context), context); } return null; } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java index 88e2d56132b7..2c8766eb7757 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java @@ -187,7 +187,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements } allMappedParameters.putAll(mappedExplicitParameters); - return getCallType(receiver, allMappedParameters, context); + return getCallType(receiver, callSite, allMappedParameters, context); } private static @Nullable PyType derefType(@NotNull Ref typeRef, @NotNull PyTypeProvider typeProvider) { @@ -200,13 +200,15 @@ public class PyFunctionImpl extends PyBaseElementImpl implements @Override public @Nullable PyType getCallType(@Nullable PyExpression receiver, + @Nullable PyCallSiteExpression callSiteExpression, @NotNull Map parameters, @NotNull TypeEvalContext context) { - return analyzeCallType(PyUtil.getReturnTypeToAnalyzeAsCallType(this, context), receiver, parameters, context); + return analyzeCallType(PyUtil.getReturnTypeToAnalyzeAsCallType(this, context), receiver, callSiteExpression, parameters, context); } private @Nullable PyType analyzeCallType(@Nullable PyType type, @Nullable PyExpression receiver, + @Nullable PyCallSiteExpression callSiteExpression, @NotNull Map parameters, @NotNull TypeEvalContext context) { if (PyTypeChecker.hasGenerics(type, context)) { @@ -233,7 +235,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements if (type != null && isDynamicallyEvaluated(parameters.values(), context)) { type = PyUnionType.createWeakType(type); } - return type; + return PyNarrowedType.Companion.bindIfNeeded(type, callSiteExpression); } @Override diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java index 7bc86f4f01f1..3c1a1616d6d5 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java @@ -68,6 +68,7 @@ public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExp @Nullable @Override public PyType getCallType(@Nullable PyExpression receiver, + @Nullable PyCallSiteExpression pyCallSiteExpression, @NotNull Map parameters, @NotNull TypeEvalContext context) { return context.getReturnType(this); diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyReferenceExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyReferenceExpressionImpl.java index ca83e2dfb738..f095cbe4e247 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyReferenceExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyReferenceExpressionImpl.java @@ -488,19 +488,18 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere return readWriteInstruction.getType(context, anchor); } if (instr instanceof ConditionalInstruction conditionalInstruction) { - if (context.getType((PyTypedElement)conditionalInstruction.getCondition()) instanceof PyNarrowedType narrowedType) { - PyExpression[] arguments = narrowedType.getOriginal().getArguments(); - if (arguments.length > 0) { - var firstArgument = arguments[0]; + if (context.getType((PyTypedElement)conditionalInstruction.getCondition()) instanceof PyNarrowedType narrowedType + && narrowedType.isBound()) { + var arguments = narrowedType.getOriginal().getArguments(null); + if (!arguments.isEmpty()) { + var firstArgument = arguments.get(0); if (firstArgument instanceof PyReferenceExpression) { return PyTypeAssertionEvaluator.createAssertionType( context.getType(firstArgument), narrowedType.getNarrowedType(), conditionalInstruction.getResult() ^ narrowedType.getNegated(), - false, narrowedType.getTypeIs(), - context, - null); + context); } } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java index ea8206e34218..6e30cc80bf2e 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java @@ -435,7 +435,7 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl actualParameters, @NotNull Collection allParameters, @Nullable PyExpression receiver, + @NotNull PyCallSiteExpression callsite, @NotNull TypeEvalContext context) { final var substitutions = PyTypeChecker.unifyGenericCall(receiver, actualParameters, context); final var substitutionsWithUnresolvedReturnGenerics = PyTypeChecker.getSubstitutionsWithUnresolvedReturnGenerics(allParameters, type, substitutions, context); - return PyTypeChecker.substitute(type, substitutionsWithUnresolvedReturnGenerics, context); + PyType typeAfterSubstitution = PyTypeChecker.substitute(type, substitutionsWithUnresolvedReturnGenerics, context); + return PyNarrowedType.Companion.bindIfNeeded(typeAfterSubstitution, callsite); } @Nullable diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyCollectionTypeImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyCollectionTypeImpl.java index 02f5c8967c01..ac9652e05d8a 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyCollectionTypeImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyCollectionTypeImpl.java @@ -30,7 +30,7 @@ import java.util.List; public class PyCollectionTypeImpl extends PyClassTypeImpl implements PyCollectionType { - @NotNull private final List myElementTypes; + @NotNull protected final List myElementTypes; public PyCollectionTypeImpl(@NotNull PyClass source, boolean isDefinition, @NotNull List elementTypes) { super(source, isDefinition); diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyNarrowedType.kt b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyNarrowedType.kt index 7efe798cbbb2..ab3fe987ee88 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyNarrowedType.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyNarrowedType.kt @@ -1,8 +1,10 @@ package com.jetbrains.python.psi.types -import com.jetbrains.python.psi.PyCallExpression +import com.jetbrains.python.ast.impl.PyPsiUtilsCore +import com.jetbrains.python.psi.PyCallSiteExpression import com.jetbrains.python.psi.PyClass import com.jetbrains.python.psi.PyElement +import com.jetbrains.python.psi.PyReferenceExpression import com.jetbrains.python.psi.impl.PyBuiltinCache import org.jetbrains.annotations.ApiStatus @@ -12,22 +14,57 @@ import org.jetbrains.annotations.ApiStatus @ApiStatus.Internal class PyNarrowedType private constructor( pyClass: PyClass, - val qname: String, - val narrowedType: PyType, - val original: PyCallExpression, + val qname: String?, + val original: PyCallSiteExpression?, val negated: Boolean, - val typeIs: Boolean) - : PyClassTypeImpl(pyClass, false) { + val typeIs: Boolean, + typeVars: List, +) + : PyCollectionTypeImpl(pyClass, false, typeVars) { fun negate(): PyNarrowedType { - return PyNarrowedType(pyClass, qname, narrowedType, original, !negated, typeIs) + return PyNarrowedType(pyClass, qname, original, !negated, typeIs, myElementTypes) } + fun bind(callExpression: PyCallSiteExpression, name: String): PyNarrowedType { + return PyNarrowedType(pyClass, name, callExpression, negated, typeIs, myElementTypes) + } + + fun substitute(type: List): PyNarrowedType { + return PyNarrowedType(pyClass, qname, original, negated, typeIs, type) + } + + /** + * A type is considered bound if it has an associated original call site expression, + * indicating that it is valid only within the call-side scope. + */ + fun isBound(): Boolean = original != null + + val narrowedType: PyType + get() = requireNotNull(iteratedItemType) + companion object { - fun create(anchor: PyElement, name: String, narrowedType: PyType, original: PyCallExpression, negated: Boolean = false, typeIs: Boolean): PyNarrowedType? { + + private val myTypeVar = object : PyGenericType(toString(), null) {} + + fun create(anchor: PyElement, typeIs: Boolean): PyNarrowedType? { val pyClass = PyBuiltinCache.getInstance(anchor).getClass("bool") if (pyClass == null) return null - return PyNarrowedType(pyClass, name, narrowedType, original, negated, typeIs) + return PyNarrowedType(pyClass, null, null, false, typeIs, listOf(myTypeVar)) + } + + fun bindIfNeeded(type: PyType?, callSiteExpression: PyCallSiteExpression?): PyType? { + if (type is PyNarrowedType && callSiteExpression != null) { + val arguments = callSiteExpression.getArguments(null) + val pyReferenceExpression = arguments.firstOrNull() + if (pyReferenceExpression is PyReferenceExpression) { + val qname = PyPsiUtilsCore.asQualifiedName(pyReferenceExpression) + if (qname != null) { + return type.bind(callSiteExpression, qname.toString()) + } + } + } + return type } } } \ No newline at end of file diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java index edb02c7fd7fc..c8375ec90a12 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyTypeChecker.java @@ -1158,6 +1158,10 @@ public final class PyTypeChecker { ); return PyTypedDictType.Companion.createFromKeysToValueTypes(typedDictType.myClass, substitutedTDFields, false); } + else if (type instanceof PyNarrowedType pyNarrowedType) { + return pyNarrowedType.substitute(ContainerUtil.flatMap(pyNarrowedType.getElementTypes(), + t -> substituteExpand(t, substitutions, context, substituting))); + } else if (type instanceof final PyCollectionTypeImpl collection) { return new PyCollectionTypeImpl(collection.getPyClass(), collection.isDefinition(), ContainerUtil.flatMap(collection.getElementTypes(), diff --git a/python/python-psi-impl/src/com/jetbrains/python/refactoring/PyDefUseUtil.java b/python/python-psi-impl/src/com/jetbrains/python/refactoring/PyDefUseUtil.java index 8cea27ff8f18..83cb0a635ee2 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/refactoring/PyDefUseUtil.java +++ b/python/python-psi-impl/src/com/jetbrains/python/refactoring/PyDefUseUtil.java @@ -105,7 +105,7 @@ public final class PyDefUseUtil { var newContext = (MAX_CONTROL_FLOW_SIZE > instructions.length) ? TypeEvalContext.codeAnalysis(context.getOrigin().getProject(), context.getOrigin()) : TypeEvalContext.codeInsightFallback(context.getOrigin().getProject()); - if (newContext.getType(typedElement) instanceof PyNarrowedType narrowedType) { + if (newContext.getType(typedElement) instanceof PyNarrowedType narrowedType && narrowedType.isBound()) { if (narrowedType.getQname().equals(varName)) { pendingTypeGuard.put(narrowedType.getOriginal(), conditionalInstruction); } diff --git a/python/testSrc/com/jetbrains/python/Py3TypeTest.java b/python/testSrc/com/jetbrains/python/Py3TypeTest.java index 835c64ba1cab..a7ea9d5469c6 100644 --- a/python/testSrc/com/jetbrains/python/Py3TypeTest.java +++ b/python/testSrc/com/jetbrains/python/Py3TypeTest.java @@ -1924,8 +1924,8 @@ public class Py3TypeTest extends PyTestCase { - public void testTypeGuardBool() { - doTest("bool", + public void testTypeGuardPresentation() { + doTest("TypeGuard[list[str]]", """ from typing import List from typing import TypeGuard @@ -1940,6 +1940,88 @@ public class Py3TypeTest extends PyTestCase { """); } + public void testTypeIsPresentation() { + doTest("TypeIs[list[str]]", + """ + from typing import List + from typing_extensions import TypeIs + + + def is_str_list(val: List[object]) -> TypeIs[List[str]]: + return all(isinstance(x, str) for x in val) + + + def func1(val: List[object]): + expr = is_str_list(val) + """); + } + + public void testTypeGuardIsErasedOnReturn() { + doTest("bool", + """ + from typing import List + from typing_extensions import TypeIs + + def is_str_list(val: List[object]) -> TypeIs[List[str]]: + return all(isinstance(x, str) for x in val) + + def func1(val: List[object]): + return is_str_list(val) + + expr = func1([]) + """); + } + + public void testTypeAliasesWithTypeIs() { + doTest("list[str]", """ + from typing import List + from typing_extensions import TypeIs + + MyTypeIs = TypeIs[List[str]] + + def is_str_list(val: List[object]) -> MyTypeIs: + return all(isinstance(x, str) for x in val) + + def func1(val: List[object]): + if is_str_list(val): + expr = val + """); + } + + public void testTypeAliasWithGenericTypeIs() { + doTest("list[str]", """ + from typing import List + from typing_extensions import TypeIs + + type MyTypeIs[T] = TypeIs[T] + + def is_str_list(val: List[object]) -> MyTypeIs[List[str]]: + return all(isinstance(x, str) for x in val) + + def func1(val: List[object]): + if is_str_list(val): + expr = val + """); + } + + public void testTypeIsWithGenerics() { + doTest("tuple[str, str]", """ + from typing_extensions import TypeIs + from typing import TypeVar + + T = TypeVar("T") + + def is_two_element_tuple(val: tuple[T, ...]) -> TypeIs[tuple[T, T]]: + return len(val) == 2 + + + def func7(names: tuple[str, ...]): + if is_two_element_tuple(names): + expr = names + """); + } + + public void testTypeGuardListInStringLiteral() { doTest("list[str]", """ @@ -2044,7 +2126,6 @@ public class Py3TypeTest extends PyTestCase { name: str age: int - def is_person(val: dict) -> TypeGuard[Person]: try: return isinstance(val["name"], str) and isinstance(val["age"], int) @@ -2059,6 +2140,20 @@ public class Py3TypeTest extends PyTestCase { print("Not a person!")"""); } + public void testTypeIsInCallable() { + doTest("str", """ + from typing import Callable + from typing import assert_type + from typing_extensions import TypeIs + + def takes_narrower(x: int | str, narrower: Callable[[object], TypeIs[int]]): + if narrower(x): + pass + else: + expr = x + """); + } + public void testTypeGuardDoubleCheckNegation() { doTest("Person", """ @@ -2106,11 +2201,10 @@ public class Py3TypeTest extends PyTestCase { """ from typing import List from typing_extensions import TypeIs - - + def is_str_list(val: List[object]) -> TypeIs[List[str]]: return all(isinstance(x, str) for x in val) - + def func1(val: List[int] | List[str]): if not is_str_list(val): expr = val @@ -2119,6 +2213,35 @@ public class Py3TypeTest extends PyTestCase { """); } + public void testHandleGenericReturnType() { + doTest("list[str]", """ + from typing import List + + def create_list_of_type[T](item: T, count: int) -> List[T]: + return [item] * count + + expr = create_list_of_type("foo", 3) + """); + } + + public void testHandleGenericWithAliasesReturnType() { + doTest("int | None", """ + from typing import Dict, TypeAlias, TypeVar + + V = TypeVar("V") + + StringDict = Dict[str, V] + + def create_dict_of_type[T](item: T,) -> StringDict[T]: + return {"foo": item} + + + dict = create_dict_of_type(23) + + expr = dict.get("foo") + """); + } + public void testNoReturn() { doTest("Bar", """ diff --git a/python/testSrc/com/jetbrains/python/PyTypeTest.java b/python/testSrc/com/jetbrains/python/PyTypeTest.java index c60814e19f82..863a56b6daf3 100644 --- a/python/testSrc/com/jetbrains/python/PyTypeTest.java +++ b/python/testSrc/com/jetbrains/python/PyTypeTest.java @@ -304,6 +304,15 @@ public class PyTypeTest extends PyTestCase { expr = x"""); } + public void testIsInstance2() { + doTest("str", + """ + x = "" + if isinstance(x, (1, "")): + expr = x + """); + } + // PY-4383 public void testAssertIsInstance() { doTest("int",