[python] Disable strict narrowing, add registry flag to enable it

(cherry picked from commit ba403db011803b4aef1ea2b1582374559b70c32a)

GitOrigin-RevId: 07a2f2fc25f28abfb1f7bca1a75ab86a20e16bc4
This commit is contained in:
Aleksandr.Govenko
2025-08-20 12:49:59 +03:00
committed by intellij-monorepo-bot
parent 9095762730
commit 5e5a46bb51
5 changed files with 24 additions and 15 deletions

View File

@@ -495,6 +495,8 @@
description="It enables the use of a library-level cache for PSI elements from packages."/>
<registryKey key="python.statement.lists.incremental.reparse" defaultValue="false"
description="Enables incremental reparse for statement lists"/>
<registryKey key="python.strict.type.narrow" defaultValue="false"
description="Allows narrowing types exhaustively to Never"/>
</extensions>
<extensionPoints>

View File

@@ -2,6 +2,7 @@
package com.jetbrains.python.codeInsight.controlflow;
import com.intellij.openapi.util.Ref;
import com.intellij.openapi.util.registry.Registry;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
@@ -185,7 +186,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
final PyExpression subject = matchStatement.getSubject();
if (subject == null) return;
// allowAnyExpr is here because we need negative edges with Never even when subject is not reference expression
pushAssertion(subject, true, true, context -> {
pushAssertion(subject, true, true, true, context -> {
PyType subjectType = context.getType(subject);
for (PyCaseClause cs : matchStatement.getCaseClauses()) {
if (cs.getPattern() == null) continue;
@@ -194,7 +195,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
subjectType = PyNeverType.NEVER;
break;
}
subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, context));
subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, true, context));
}
return subjectType;
@@ -205,6 +206,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
public static @Nullable Ref<PyType> createAssertionType(@Nullable PyType initial,
@Nullable PyType suggested,
boolean positive,
boolean forceStrictNarrow,
@NotNull TypeEvalContext context) {
if (suggested == null) return null;
if (positive) {
@@ -222,10 +224,10 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
}
else {
if (initial instanceof PyUnionType unionType) {
return Ref.create(excludeFromUnion(unionType, suggested, context));
return Ref.create(excludeFromUnion(unionType, suggested, context, forceStrictNarrow));
}
if (match(suggested, initial, context)) {
return Ref.create(PyNeverType.NEVER);
return (forceStrictNarrow || isStrictNarrowingAllowed()) ? Ref.create(PyNeverType.NEVER) : null;
}
Ref<@Nullable PyType> diff = trySubtract(initial, suggested, context);
return diff != null ? diff : Ref.create(initial);
@@ -234,7 +236,8 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
private static @Nullable PyType excludeFromUnion(@NotNull PyUnionType unionType,
@Nullable PyType type,
@NotNull TypeEvalContext context) {
@NotNull TypeEvalContext context,
boolean forceStrictNarrow) {
final List<PyType> members = new ArrayList<>();
for (PyType m : unionType.getMembers()) {
Ref<@Nullable PyType> diff = trySubtract(m, type, context);
@@ -245,12 +248,16 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
members.add(m);
}
}
if (members.isEmpty()) {
if ((forceStrictNarrow || isStrictNarrowingAllowed()) && members.isEmpty()) {
return PyNeverType.NEVER;
}
return PyUnionType.union(members);
}
public static boolean isStrictNarrowingAllowed() {
return Registry.is("python.strict.type.narrow");
}
private static @Nullable Ref<@Nullable PyType> trySubtract(@Nullable PyType type1,
@Nullable PyType type2,
@NotNull TypeEvalContext context) {
@@ -337,26 +344,26 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
}
private void pushAssertion(@Nullable PyExpression expr, boolean positive, @NotNull Function<TypeEvalContext, PyType> suggestedType) {
pushAssertion(expr, positive, false, suggestedType);
pushAssertion(expr, positive, false, isStrictNarrowingAllowed(), suggestedType);
}
private void pushAssertion(@Nullable PyExpression expr, boolean positive, boolean allowAnyExpr, @NotNull Function<TypeEvalContext, PyType> suggestedType) {
private void pushAssertion(@Nullable PyExpression expr, boolean positive, boolean allowAnyExpr, boolean forceStrictNarrow, @NotNull Function<TypeEvalContext, PyType> suggestedType) {
expr = PyPsiUtils.flattenParens(expr);
if (expr instanceof PySequenceExpression seqExpr) {
PyExpression[] elements = seqExpr.getElements();
for (int i = 0; i < elements.length; i++) {
pushAssertion(elements[i], positive, allowAnyExpr, getIteratedType(suggestedType, i));
pushAssertion(elements[i], positive, allowAnyExpr, forceStrictNarrow, getIteratedType(suggestedType, i));
}
}
else if (expr instanceof PyAssignmentExpression walrus) {
pushAssertion(walrus.getTarget(), positive, allowAnyExpr, suggestedType);
pushAssertion(walrus.getTarget(), positive, allowAnyExpr, forceStrictNarrow, suggestedType);
}
else if (expr != null) {
final var target = expr;
final InstructionTypeCallback typeCallback = new InstructionTypeCallback() {
@Override
public Ref<PyType> getType(TypeEvalContext context) {
return createAssertionType(context.getType(target), suggestedType.apply(context), positive, context);
return createAssertionType(context.getType(target), suggestedType.apply(context), positive, forceStrictNarrow, context);
}
};

View File

@@ -26,7 +26,7 @@ class PyCaseClauseImpl(astNode: ASTNode?) : PyElementImpl(astNode), PyCaseClause
if (cs.guardCondition != null && !PyEvaluator.evaluateAsBoolean(cs.guardCondition, false)) continue
if (cs.pattern!!.canExcludePatternType(context)) {
subjectType = Ref.deref(
PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.pattern!!), false, context))
PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.pattern!!), false, true, context))
}
}

View File

@@ -35,7 +35,7 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
if (type instanceof PyClassType classType) {
final PyType instanceType = classType.toInstance();
final PyType captureType = PyCaptureContext.getCaptureType(this, context);
return Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, instanceType, true, context));
return Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, instanceType, true, true, context));
}
return null;
}
@@ -127,7 +127,7 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
final var captureType = PyCaptureContext.getCaptureType(pattern, context);
final var patternType = context.getType(pattern);
// For class pattern arguments, we need to ensure that the argument pattern covers its capture type fully
if (Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, patternType, false, context)) instanceof PyNeverType) {
if (Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, patternType, false, true, context)) instanceof PyNeverType) {
// in case the argument pattern is also class pattern with arguments
return pattern.canExcludePatternType(context);
}

View File

@@ -492,7 +492,7 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
@Nullable PyType initial = context.getType(firstArgument);
boolean positive = conditionalInstruction.getResult() ^ narrowedType.getNegated();
if (narrowedType.getTypeIs()) {
return PyTypeAssertionEvaluator.createAssertionType(initial, type, positive, context);
return PyTypeAssertionEvaluator.createAssertionType(initial, type, positive, false, context);
}
return Ref.create((positive) ? type : initial);
}