PY-76906 Support type narrowing for comparisons with literal types

GitOrigin-RevId: 8094877f9b02b57f005391dd1b95befba621b604
This commit is contained in:
Petr
2025-01-13 17:15:33 +01:00
committed by intellij-monorepo-bot
parent 4a13f8ad86
commit 04cac6392b
7 changed files with 215 additions and 152 deletions

View File

@@ -3,6 +3,7 @@ package com.jetbrains.python.codeInsight.controlflow;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.util.ObjectUtils;
import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.containers.Stack;
import com.jetbrains.python.PyNames;
@@ -74,62 +75,64 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
@Override
public void visitPyBinaryExpression(@NotNull PyBinaryExpression node) {
if (node.isOperator(PyNames.AND) || node.isOperator(PyNames.OR)) return;
final PyExpression lhs = PyPsiUtils.flattenParens(node.getLeftExpression());
final PyExpression rhs = PyPsiUtils.flattenParens(node.getRightExpression());
if (lhs == null || rhs == null) return;
final PyExpression lhs = node.getLeftExpression();
final PyExpression rhs = node.getRightExpression();
boolean isOperator = node.isOperator(PyNames.IS);
boolean isNotOperator = node.isOperator("isnot");
if (lhs instanceof PyReferenceExpression && rhs instanceof PyReferenceExpression ||
lhs instanceof PyReferenceExpression && rhs instanceof PyNoneLiteralExpression ||
lhs instanceof PyNoneLiteralExpression && rhs instanceof PyReferenceExpression) {
final boolean leftIsNone = lhs instanceof PyNoneLiteralExpression || PyNames.NONE.equals(lhs.getName());
final boolean rightIsNone = rhs instanceof PyNoneLiteralExpression || PyNames.NONE.equals(rhs.getName());
if (leftIsNone ^ rightIsNone) {
final PyReferenceExpression target = (PyReferenceExpression)(rightIsNone ? lhs : rhs);
if (isOperator) {
pushAssertion(target, myPositive, false, context -> PyNoneType.INSTANCE, null);
return;
}
if (isNotOperator) {
pushAssertion(target, !myPositive, false, context -> PyNoneType.INSTANCE, null);
return;
}
}
PyElementType operator = node.getOperator();
boolean isOrEqualsOperator = node.isOperator(PyNames.IS) || PyTokenTypes.EQEQ.equals(operator);
if (isOrEqualsOperator || node.isOperator("isnot") || PyTokenTypes.NE.equals(operator) || PyTokenTypes.NE_OLD.equals(operator)) {
setPositive(isOrEqualsOperator, () -> processIsOrEquals(lhs, rhs));
}
}
final Object leftValue = PyEvaluator.evaluateNoResolve(lhs, Object.class);
final Object rightValue = PyEvaluator.evaluateNoResolve(rhs, Object.class);
private void processIsOrEquals(@NotNull PyExpression lhs, @NotNull PyExpression rhs) {
final Boolean leftValue = PyEvaluator.evaluateNoResolve(lhs, Boolean.class);
final Boolean rightValue = PyEvaluator.evaluateNoResolve(rhs, Boolean.class);
if (leftValue instanceof Boolean && rightValue instanceof Boolean) {
if (leftValue != null && rightValue != null) {
return;
}
if (leftValue != null) {
setPositive(leftValue, () -> rhs.accept(this));
return;
}
if (rightValue != null) {
setPositive(rightValue, () -> lhs.accept(this));
return;
}
if (isOperator && (leftValue == Boolean.FALSE || rightValue == Boolean.FALSE) ||
isNotOperator && (leftValue == Boolean.TRUE || rightValue == Boolean.TRUE)) {
myPositive = !myPositive;
super.visitPyBinaryExpression(node);
myPositive = !myPositive;
return;
if (PyLiteralType.isNone(lhs)) {
pushAssertion(rhs, lhs);
}
if (isOperator || isNotOperator) {
if (lhs instanceof PyReferenceExpression target && rhs instanceof PyReferenceExpression) {
pushAssertion(target, isOperator == myPositive, false, context -> {
PyType rhsType = context.getType(rhs);
boolean isEnumMember = rhsType instanceof PyLiteralType literalType &&
PyStdlibTypeProvider.isCustomEnum(literalType.getPyClass(), context);
return isEnumMember ? rhsType : null;
}, null);
return;
}
else {
pushAssertion(lhs, rhs);
}
}
super.visitPyBinaryExpression(node);
private void pushAssertion(@NotNull PyExpression target, @NotNull PyExpression bound) {
if (target instanceof PyReferenceExpression targetRefExpr) {
pushAssertion(targetRefExpr, myPositive, false, context -> {
final PyType literalType = PyLiteralType.getLiteralType(bound, context);
if (literalType != null) {
return literalType;
}
return ObjectUtils.tryCast(context.getType(bound), PyLiteralType.class);
}, null);
}
}
private void setPositive(boolean positive, @NotNull Runnable runnable) {
boolean oldPositive = myPositive;
if (!positive) {
myPositive = !myPositive;
}
try {
runnable.run();
}
finally {
myPositive = oldPositive;
}
}
@ApiStatus.Internal

View File

@@ -7,6 +7,7 @@ import com.jetbrains.python.PyTokenTypes
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyEvaluator
import org.jetbrains.annotations.ApiStatus
/**
@@ -68,7 +69,7 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
}
else -> {
val type = if (inferLiteralTypes) {
toLiteralType(value, context, false)
getLiteralType(value, context)
}
else {
null
@@ -157,6 +158,10 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
PyEvaluator.evaluateNoResolve(actual.expression, Any::class.java)
}
@ApiStatus.Internal
@JvmStatic
fun getLiteralType(expression: PyExpression, context: TypeEvalContext): PyType? = toLiteralType(expression, context, false)
/**
* If [expected] type is `typing.Literal[...]`,
* then tries to infer `typing.Literal[...]` for [expression],
@@ -178,11 +183,17 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
type is PyCollectionType && type.elementTypes.any { containsLiteral(it) }
}
@ApiStatus.Internal
@JvmStatic
fun isNone(expression: PyExpression): Boolean {
return expression is PyNoneLiteralExpression && !expression.isEllipsis ||
expression is PyReferenceExpression &&
expression.name == PyNames.NONE &&
LanguageLevel.forElement(expression).isPython2
}
private fun toLiteralType(expression: PyExpression, context: TypeEvalContext, index: Boolean): PyType? {
if (expression is PyNoneLiteralExpression && !expression.isEllipsis ||
expression is PyReferenceExpression &&
expression.name == PyNames.NONE &&
LanguageLevel.forElement(expression).isPython2) return PyNoneType.INSTANCE
if (isNone(expression)) return PyNoneType.INSTANCE
if (index && (expression is PyReferenceExpression || expression is PySubscriptionExpression)) {
val subLiteralType = Ref.deref(PyTypingTypeProvider.getType(expression, context))

View File

@@ -1,13 +1,15 @@
0(1) element: null
1(2) element: PyForStatement
2(3,12) READ ACCESS: lines
2(3,14) READ ACCESS: lines
3(4) element: PyTargetExpression: line
4(5) WRITE ACCESS: line
5(6) element: PyIfStatement
6(7,8) READ ACCESS: line
7(3,12) element: null. Condition: line == '@':false
8(9) element: null. Condition: line == '@':true
9(10) element: PyStatementList
10(11) element: PyAssignmentStatement
11(3,12) WRITE ACCESS: head
12() element: null
6(7,9) READ ACCESS: line
7(8) element: null. Condition: line == '@':false
8(3,14) ASSERTTYPE ACCESS: line
9(10) element: null. Condition: line == '@':true
10(11) ASSERTTYPE ACCESS: line
11(12) element: PyStatementList
12(13) element: PyAssignmentStatement
13(3,14) WRITE ACCESS: head
14() element: null

View File

@@ -3,31 +3,35 @@
2(3) WRITE ACCESS: var
3(4) element: PyIfStatement
4(5) READ ACCESS: a
5(6,7) READ ACCESS: b
6(11) element: null. Condition: a == b:false
7(8) element: null. Condition: a == b:true
8(9) element: PyStatementList
9(10) element: PyAssignmentStatement
10(30) WRITE ACCESS: var
11(12) READ ACCESS: aa
12(13,14) READ ACCESS: bb
13(27) element: null. Condition: aa == bb:false
14(15) element: null. Condition: aa == bb:true
15(16) element: PyStatementList
16(17) element: PyAssignmentStatement
17(18) READ ACCESS: same_changet_expression
18(19) WRITE ACCESS: bbb
19(20) element: PyIfStatement
20(21,22) READ ACCESS: bbb
21(30) element: null. Condition: bbb:false
22(23) element: null. Condition: bbb:true
23(24) ASSERTTYPE ACCESS: bbb
24(25) element: PyStatementList
25(26) element: PyAssignmentStatement
26(30) WRITE ACCESS: var
27(28) element: PyStatementList
28(29) element: PyAssignmentStatement
29(30) WRITE ACCESS: var
30(31) element: PyReturnStatement
31(32) READ ACCESS: var
32() element: null
5(6,8) READ ACCESS: b
6(7) element: null. Condition: a == b:false
7(13) ASSERTTYPE ACCESS: a
8(9) element: null. Condition: a == b:true
9(10) ASSERTTYPE ACCESS: a
10(11) element: PyStatementList
11(12) element: PyAssignmentStatement
12(34) WRITE ACCESS: var
13(14) READ ACCESS: aa
14(15,17) READ ACCESS: bb
15(16) element: null. Condition: aa == bb:false
16(31) ASSERTTYPE ACCESS: aa
17(18) element: null. Condition: aa == bb:true
18(19) ASSERTTYPE ACCESS: aa
19(20) element: PyStatementList
20(21) element: PyAssignmentStatement
21(22) READ ACCESS: same_changet_expression
22(23) WRITE ACCESS: bbb
23(24) element: PyIfStatement
24(25,26) READ ACCESS: bbb
25(34) element: null. Condition: bbb:false
26(27) element: null. Condition: bbb:true
27(28) ASSERTTYPE ACCESS: bbb
28(29) element: PyStatementList
29(30) element: PyAssignmentStatement
30(34) WRITE ACCESS: var
31(32) element: PyStatementList
32(33) element: PyAssignmentStatement
33(34) WRITE ACCESS: var
34(35) element: PyReturnStatement
35(36) READ ACCESS: var
36() element: null

View File

@@ -2,69 +2,77 @@
1(2) element: PyAssignmentStatement
2(3) WRITE ACCESS: a
3(4) element: PyTryExceptStatement
4(5,61) element: PyTryPart
5(6,61) element: PyAssignmentStatement
6(7,61) WRITE ACCESS: b
7(8,61) element: PyForStatement
8(9,61) element: PyTargetExpression: x
9(10,61) WRITE ACCESS: x
10(11,61) element: PyTryExceptStatement
11(12,51) element: PyTryPart
12(13,51) element: PyAssignmentStatement
13(14,51) WRITE ACCESS: c
14(15,51) element: PyTryExceptStatement
15(16,43) element: PyTryPart
16(17,43) element: PyAssignmentStatement
17(18,43) WRITE ACCESS: d
18(19,43) element: PyIfStatement
19(20,21,43) READ ACCESS: x
20(24) element: null. Condition: x == 0:false
21(22) element: null. Condition: x == 0:true
22(23) element: PyStatementList
23(43,46) element: PyBreakStatement
24(25,26,43) READ ACCESS: x
25(29) element: null. Condition: x == 1:false
26(27) element: null. Condition: x == 1:true
27(28) element: PyStatementList
28(7,43,46) element: PyContinueStatement
29(30,31,43) READ ACCESS: x
30(36) element: null. Condition: x == 2:false
31(32) element: null. Condition: x == 2:true
32(33) element: PyStatementList
33(34,43) element: PyRaiseStatement
34(35,43) READ ACCESS: Exception
35(43) element: PyCallExpression: Exception
36(37,38,43) READ ACCESS: x
37(41) element: null. Condition: x == 3:false
38(39) element: null. Condition: x == 3:true
39(40) element: PyStatementList
40(43) element: PyReturnStatement
41(42,43) element: PyAssignmentStatement
42(43,46) WRITE ACCESS: e
43(44,51) element: PyFinallyPart
44(45,51) element: PyAssignmentStatement
45(51) WRITE ACCESS: f
46(47,51) element: PyFinallyPart
47(48,51) element: PyAssignmentStatement
48(49,51,54) WRITE ACCESS: f
4(5,69) element: PyTryPart
5(6,69) element: PyAssignmentStatement
6(7,69) WRITE ACCESS: b
7(8,69) element: PyForStatement
8(9,69) element: PyTargetExpression: x
9(10,69) WRITE ACCESS: x
10(11,69) element: PyTryExceptStatement
11(12,59) element: PyTryPart
12(13,59) element: PyAssignmentStatement
13(14,59) WRITE ACCESS: c
14(15,59) element: PyTryExceptStatement
15(16,51) element: PyTryPart
16(17,51) element: PyAssignmentStatement
17(18,51) WRITE ACCESS: d
18(19,51) element: PyIfStatement
19(20,22,51) READ ACCESS: x
20(21) element: null. Condition: x == 0:false
21(51,26) ASSERTTYPE ACCESS: x
22(23) element: null. Condition: x == 0:true
23(51,24) ASSERTTYPE ACCESS: x
24(25) element: PyStatementList
25(51,54) element: PyBreakStatement
26(27,29,51) READ ACCESS: x
27(28) element: null. Condition: x == 1:false
28(51,33) ASSERTTYPE ACCESS: x
29(30) element: null. Condition: x == 1:true
30(51,31) ASSERTTYPE ACCESS: x
31(32) element: PyStatementList
32(7,51,54) element: PyContinueStatement
33(34,36,51) READ ACCESS: x
34(35) element: null. Condition: x == 2:false
35(51,42) ASSERTTYPE ACCESS: x
36(37) element: null. Condition: x == 2:true
37(51,38) ASSERTTYPE ACCESS: x
38(39) element: PyStatementList
39(40,51) element: PyRaiseStatement
40(41,51) READ ACCESS: Exception
41(51) element: PyCallExpression: Exception
42(43,45,51) READ ACCESS: x
43(44) element: null. Condition: x == 3:false
44(51,49) ASSERTTYPE ACCESS: x
45(46) element: null. Condition: x == 3:true
46(51,47) ASSERTTYPE ACCESS: x
47(48) element: PyStatementList
48(51) element: PyReturnStatement
49(50,51) element: PyAssignmentStatement
50(51,54) WRITE ACCESS: g
51(52,61) element: PyFinallyPart
52(53,61) element: PyAssignmentStatement
53(61) WRITE ACCESS: h
54(55,61) element: PyFinallyPart
55(56,61) element: PyAssignmentStatement
56(57,59,61) WRITE ACCESS: h
57(58,61) element: PyAssignmentStatement
58(8,59,61) WRITE ACCESS: i
59(60,61) element: PyAssignmentStatement
60(61,64) WRITE ACCESS: j
61(62) element: PyFinallyPart
62(63) element: PyAssignmentStatement
63(69) WRITE ACCESS: k
64(65) element: PyFinallyPart
65(66) element: PyAssignmentStatement
66(67) WRITE ACCESS: k
67(68) element: PyAssignmentStatement
68(69) WRITE ACCESS: l
69() element: null
50(51,54) WRITE ACCESS: e
51(52,59) element: PyFinallyPart
52(53,59) element: PyAssignmentStatement
53(59) WRITE ACCESS: f
54(55,59) element: PyFinallyPart
55(56,59) element: PyAssignmentStatement
56(57,59,62) WRITE ACCESS: f
57(58,59) element: PyAssignmentStatement
58(59,62) WRITE ACCESS: g
59(60,69) element: PyFinallyPart
60(61,69) element: PyAssignmentStatement
61(69) WRITE ACCESS: h
62(63,69) element: PyFinallyPart
63(64,69) element: PyAssignmentStatement
64(65,67,69) WRITE ACCESS: h
65(66,69) element: PyAssignmentStatement
66(8,67,69) WRITE ACCESS: i
67(68,69) element: PyAssignmentStatement
68(69,72) WRITE ACCESS: j
69(70) element: PyFinallyPart
70(71) element: PyAssignmentStatement
71(77) WRITE ACCESS: k
72(73) element: PyFinallyPart
73(74) element: PyAssignmentStatement
74(75) WRITE ACCESS: k
75(76) element: PyAssignmentStatement
76(77) WRITE ACCESS: l
77() element: null

View File

@@ -590,6 +590,31 @@ public class Py3TypeTest extends PyTestCase {
""");
}
public void testLiteralTypeNarrowing() {
doTest("Literal[\"abba\"]",
"""
from typing import Literal
def foo(v: str):
if (v == "abba"):
expr = v
""");
doTest("Literal[\"ab\"]",
"""
from typing import Literal
def foo(v: Literal["abba", "ab"]):
if (v != "abba"):
expr = v
""");
doTest("Literal[\"abc\"]",
"""
from typing import Literal
abc: Literal["abc"] = "abc"
def foo(v: str):
if (v == abc):
expr = v
""");
}
// PY-21083
public void testFloatFromhex() {
doTest("float",

View File

@@ -465,6 +465,16 @@ public class PyTypeTest extends PyTestCase {
expr = a""");
}
public void testIfNotEqOperator() {
doTest("Literal[\"ab\"]",
"""
from typing import Literal
def foo(v: Literal["abba", "ab"]):
if (v <> "abba"):
expr = v
""");
}
// PY-4279
public void testFieldReassignment() {
doTest("C1",