mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-01-08 15:09:39 +07:00
PY-76906 Support type narrowing for comparisons with literal types
GitOrigin-RevId: 8094877f9b02b57f005391dd1b95befba621b604
This commit is contained in:
committed by
intellij-monorepo-bot
parent
4a13f8ad86
commit
04cac6392b
@@ -3,6 +3,7 @@ package com.jetbrains.python.codeInsight.controlflow;
|
||||
|
||||
import com.intellij.openapi.util.Ref;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.util.ObjectUtils;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.intellij.util.containers.Stack;
|
||||
import com.jetbrains.python.PyNames;
|
||||
@@ -74,62 +75,64 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
|
||||
@Override
|
||||
public void visitPyBinaryExpression(@NotNull PyBinaryExpression node) {
|
||||
if (node.isOperator(PyNames.AND) || node.isOperator(PyNames.OR)) return;
|
||||
final PyExpression lhs = PyPsiUtils.flattenParens(node.getLeftExpression());
|
||||
final PyExpression rhs = PyPsiUtils.flattenParens(node.getRightExpression());
|
||||
if (lhs == null || rhs == null) return;
|
||||
|
||||
final PyExpression lhs = node.getLeftExpression();
|
||||
final PyExpression rhs = node.getRightExpression();
|
||||
|
||||
boolean isOperator = node.isOperator(PyNames.IS);
|
||||
boolean isNotOperator = node.isOperator("isnot");
|
||||
if (lhs instanceof PyReferenceExpression && rhs instanceof PyReferenceExpression ||
|
||||
lhs instanceof PyReferenceExpression && rhs instanceof PyNoneLiteralExpression ||
|
||||
lhs instanceof PyNoneLiteralExpression && rhs instanceof PyReferenceExpression) {
|
||||
final boolean leftIsNone = lhs instanceof PyNoneLiteralExpression || PyNames.NONE.equals(lhs.getName());
|
||||
final boolean rightIsNone = rhs instanceof PyNoneLiteralExpression || PyNames.NONE.equals(rhs.getName());
|
||||
|
||||
if (leftIsNone ^ rightIsNone) {
|
||||
final PyReferenceExpression target = (PyReferenceExpression)(rightIsNone ? lhs : rhs);
|
||||
|
||||
if (isOperator) {
|
||||
pushAssertion(target, myPositive, false, context -> PyNoneType.INSTANCE, null);
|
||||
return;
|
||||
}
|
||||
|
||||
if (isNotOperator) {
|
||||
pushAssertion(target, !myPositive, false, context -> PyNoneType.INSTANCE, null);
|
||||
return;
|
||||
}
|
||||
}
|
||||
PyElementType operator = node.getOperator();
|
||||
boolean isOrEqualsOperator = node.isOperator(PyNames.IS) || PyTokenTypes.EQEQ.equals(operator);
|
||||
if (isOrEqualsOperator || node.isOperator("isnot") || PyTokenTypes.NE.equals(operator) || PyTokenTypes.NE_OLD.equals(operator)) {
|
||||
setPositive(isOrEqualsOperator, () -> processIsOrEquals(lhs, rhs));
|
||||
}
|
||||
}
|
||||
|
||||
final Object leftValue = PyEvaluator.evaluateNoResolve(lhs, Object.class);
|
||||
final Object rightValue = PyEvaluator.evaluateNoResolve(rhs, Object.class);
|
||||
private void processIsOrEquals(@NotNull PyExpression lhs, @NotNull PyExpression rhs) {
|
||||
final Boolean leftValue = PyEvaluator.evaluateNoResolve(lhs, Boolean.class);
|
||||
final Boolean rightValue = PyEvaluator.evaluateNoResolve(rhs, Boolean.class);
|
||||
|
||||
if (leftValue instanceof Boolean && rightValue instanceof Boolean) {
|
||||
if (leftValue != null && rightValue != null) {
|
||||
return;
|
||||
}
|
||||
if (leftValue != null) {
|
||||
setPositive(leftValue, () -> rhs.accept(this));
|
||||
return;
|
||||
}
|
||||
if (rightValue != null) {
|
||||
setPositive(rightValue, () -> lhs.accept(this));
|
||||
return;
|
||||
}
|
||||
|
||||
if (isOperator && (leftValue == Boolean.FALSE || rightValue == Boolean.FALSE) ||
|
||||
isNotOperator && (leftValue == Boolean.TRUE || rightValue == Boolean.TRUE)) {
|
||||
myPositive = !myPositive;
|
||||
super.visitPyBinaryExpression(node);
|
||||
myPositive = !myPositive;
|
||||
return;
|
||||
if (PyLiteralType.isNone(lhs)) {
|
||||
pushAssertion(rhs, lhs);
|
||||
}
|
||||
|
||||
if (isOperator || isNotOperator) {
|
||||
if (lhs instanceof PyReferenceExpression target && rhs instanceof PyReferenceExpression) {
|
||||
pushAssertion(target, isOperator == myPositive, false, context -> {
|
||||
PyType rhsType = context.getType(rhs);
|
||||
boolean isEnumMember = rhsType instanceof PyLiteralType literalType &&
|
||||
PyStdlibTypeProvider.isCustomEnum(literalType.getPyClass(), context);
|
||||
return isEnumMember ? rhsType : null;
|
||||
}, null);
|
||||
return;
|
||||
}
|
||||
else {
|
||||
pushAssertion(lhs, rhs);
|
||||
}
|
||||
}
|
||||
|
||||
super.visitPyBinaryExpression(node);
|
||||
private void pushAssertion(@NotNull PyExpression target, @NotNull PyExpression bound) {
|
||||
if (target instanceof PyReferenceExpression targetRefExpr) {
|
||||
pushAssertion(targetRefExpr, myPositive, false, context -> {
|
||||
final PyType literalType = PyLiteralType.getLiteralType(bound, context);
|
||||
if (literalType != null) {
|
||||
return literalType;
|
||||
}
|
||||
return ObjectUtils.tryCast(context.getType(bound), PyLiteralType.class);
|
||||
}, null);
|
||||
}
|
||||
}
|
||||
|
||||
private void setPositive(boolean positive, @NotNull Runnable runnable) {
|
||||
boolean oldPositive = myPositive;
|
||||
if (!positive) {
|
||||
myPositive = !myPositive;
|
||||
}
|
||||
try {
|
||||
runnable.run();
|
||||
}
|
||||
finally {
|
||||
myPositive = oldPositive;
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
|
||||
@@ -7,6 +7,7 @@ import com.jetbrains.python.PyTokenTypes
|
||||
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
|
||||
import com.jetbrains.python.psi.*
|
||||
import com.jetbrains.python.psi.impl.PyEvaluator
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
|
||||
/**
|
||||
@@ -68,7 +69,7 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
|
||||
}
|
||||
else -> {
|
||||
val type = if (inferLiteralTypes) {
|
||||
toLiteralType(value, context, false)
|
||||
getLiteralType(value, context)
|
||||
}
|
||||
else {
|
||||
null
|
||||
@@ -157,6 +158,10 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
|
||||
PyEvaluator.evaluateNoResolve(actual.expression, Any::class.java)
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
@JvmStatic
|
||||
fun getLiteralType(expression: PyExpression, context: TypeEvalContext): PyType? = toLiteralType(expression, context, false)
|
||||
|
||||
/**
|
||||
* If [expected] type is `typing.Literal[...]`,
|
||||
* then tries to infer `typing.Literal[...]` for [expression],
|
||||
@@ -178,11 +183,17 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
|
||||
type is PyCollectionType && type.elementTypes.any { containsLiteral(it) }
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
@JvmStatic
|
||||
fun isNone(expression: PyExpression): Boolean {
|
||||
return expression is PyNoneLiteralExpression && !expression.isEllipsis ||
|
||||
expression is PyReferenceExpression &&
|
||||
expression.name == PyNames.NONE &&
|
||||
LanguageLevel.forElement(expression).isPython2
|
||||
}
|
||||
|
||||
private fun toLiteralType(expression: PyExpression, context: TypeEvalContext, index: Boolean): PyType? {
|
||||
if (expression is PyNoneLiteralExpression && !expression.isEllipsis ||
|
||||
expression is PyReferenceExpression &&
|
||||
expression.name == PyNames.NONE &&
|
||||
LanguageLevel.forElement(expression).isPython2) return PyNoneType.INSTANCE
|
||||
if (isNone(expression)) return PyNoneType.INSTANCE
|
||||
|
||||
if (index && (expression is PyReferenceExpression || expression is PySubscriptionExpression)) {
|
||||
val subLiteralType = Ref.deref(PyTypingTypeProvider.getType(expression, context))
|
||||
|
||||
@@ -1,13 +1,15 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyForStatement
|
||||
2(3,12) READ ACCESS: lines
|
||||
2(3,14) READ ACCESS: lines
|
||||
3(4) element: PyTargetExpression: line
|
||||
4(5) WRITE ACCESS: line
|
||||
5(6) element: PyIfStatement
|
||||
6(7,8) READ ACCESS: line
|
||||
7(3,12) element: null. Condition: line == '@':false
|
||||
8(9) element: null. Condition: line == '@':true
|
||||
9(10) element: PyStatementList
|
||||
10(11) element: PyAssignmentStatement
|
||||
11(3,12) WRITE ACCESS: head
|
||||
12() element: null
|
||||
6(7,9) READ ACCESS: line
|
||||
7(8) element: null. Condition: line == '@':false
|
||||
8(3,14) ASSERTTYPE ACCESS: line
|
||||
9(10) element: null. Condition: line == '@':true
|
||||
10(11) ASSERTTYPE ACCESS: line
|
||||
11(12) element: PyStatementList
|
||||
12(13) element: PyAssignmentStatement
|
||||
13(3,14) WRITE ACCESS: head
|
||||
14() element: null
|
||||
@@ -3,31 +3,35 @@
|
||||
2(3) WRITE ACCESS: var
|
||||
3(4) element: PyIfStatement
|
||||
4(5) READ ACCESS: a
|
||||
5(6,7) READ ACCESS: b
|
||||
6(11) element: null. Condition: a == b:false
|
||||
7(8) element: null. Condition: a == b:true
|
||||
8(9) element: PyStatementList
|
||||
9(10) element: PyAssignmentStatement
|
||||
10(30) WRITE ACCESS: var
|
||||
11(12) READ ACCESS: aa
|
||||
12(13,14) READ ACCESS: bb
|
||||
13(27) element: null. Condition: aa == bb:false
|
||||
14(15) element: null. Condition: aa == bb:true
|
||||
15(16) element: PyStatementList
|
||||
16(17) element: PyAssignmentStatement
|
||||
17(18) READ ACCESS: same_changet_expression
|
||||
18(19) WRITE ACCESS: bbb
|
||||
19(20) element: PyIfStatement
|
||||
20(21,22) READ ACCESS: bbb
|
||||
21(30) element: null. Condition: bbb:false
|
||||
22(23) element: null. Condition: bbb:true
|
||||
23(24) ASSERTTYPE ACCESS: bbb
|
||||
24(25) element: PyStatementList
|
||||
25(26) element: PyAssignmentStatement
|
||||
26(30) WRITE ACCESS: var
|
||||
27(28) element: PyStatementList
|
||||
28(29) element: PyAssignmentStatement
|
||||
29(30) WRITE ACCESS: var
|
||||
30(31) element: PyReturnStatement
|
||||
31(32) READ ACCESS: var
|
||||
32() element: null
|
||||
5(6,8) READ ACCESS: b
|
||||
6(7) element: null. Condition: a == b:false
|
||||
7(13) ASSERTTYPE ACCESS: a
|
||||
8(9) element: null. Condition: a == b:true
|
||||
9(10) ASSERTTYPE ACCESS: a
|
||||
10(11) element: PyStatementList
|
||||
11(12) element: PyAssignmentStatement
|
||||
12(34) WRITE ACCESS: var
|
||||
13(14) READ ACCESS: aa
|
||||
14(15,17) READ ACCESS: bb
|
||||
15(16) element: null. Condition: aa == bb:false
|
||||
16(31) ASSERTTYPE ACCESS: aa
|
||||
17(18) element: null. Condition: aa == bb:true
|
||||
18(19) ASSERTTYPE ACCESS: aa
|
||||
19(20) element: PyStatementList
|
||||
20(21) element: PyAssignmentStatement
|
||||
21(22) READ ACCESS: same_changet_expression
|
||||
22(23) WRITE ACCESS: bbb
|
||||
23(24) element: PyIfStatement
|
||||
24(25,26) READ ACCESS: bbb
|
||||
25(34) element: null. Condition: bbb:false
|
||||
26(27) element: null. Condition: bbb:true
|
||||
27(28) ASSERTTYPE ACCESS: bbb
|
||||
28(29) element: PyStatementList
|
||||
29(30) element: PyAssignmentStatement
|
||||
30(34) WRITE ACCESS: var
|
||||
31(32) element: PyStatementList
|
||||
32(33) element: PyAssignmentStatement
|
||||
33(34) WRITE ACCESS: var
|
||||
34(35) element: PyReturnStatement
|
||||
35(36) READ ACCESS: var
|
||||
36() element: null
|
||||
@@ -2,69 +2,77 @@
|
||||
1(2) element: PyAssignmentStatement
|
||||
2(3) WRITE ACCESS: a
|
||||
3(4) element: PyTryExceptStatement
|
||||
4(5,61) element: PyTryPart
|
||||
5(6,61) element: PyAssignmentStatement
|
||||
6(7,61) WRITE ACCESS: b
|
||||
7(8,61) element: PyForStatement
|
||||
8(9,61) element: PyTargetExpression: x
|
||||
9(10,61) WRITE ACCESS: x
|
||||
10(11,61) element: PyTryExceptStatement
|
||||
11(12,51) element: PyTryPart
|
||||
12(13,51) element: PyAssignmentStatement
|
||||
13(14,51) WRITE ACCESS: c
|
||||
14(15,51) element: PyTryExceptStatement
|
||||
15(16,43) element: PyTryPart
|
||||
16(17,43) element: PyAssignmentStatement
|
||||
17(18,43) WRITE ACCESS: d
|
||||
18(19,43) element: PyIfStatement
|
||||
19(20,21,43) READ ACCESS: x
|
||||
20(24) element: null. Condition: x == 0:false
|
||||
21(22) element: null. Condition: x == 0:true
|
||||
22(23) element: PyStatementList
|
||||
23(43,46) element: PyBreakStatement
|
||||
24(25,26,43) READ ACCESS: x
|
||||
25(29) element: null. Condition: x == 1:false
|
||||
26(27) element: null. Condition: x == 1:true
|
||||
27(28) element: PyStatementList
|
||||
28(7,43,46) element: PyContinueStatement
|
||||
29(30,31,43) READ ACCESS: x
|
||||
30(36) element: null. Condition: x == 2:false
|
||||
31(32) element: null. Condition: x == 2:true
|
||||
32(33) element: PyStatementList
|
||||
33(34,43) element: PyRaiseStatement
|
||||
34(35,43) READ ACCESS: Exception
|
||||
35(43) element: PyCallExpression: Exception
|
||||
36(37,38,43) READ ACCESS: x
|
||||
37(41) element: null. Condition: x == 3:false
|
||||
38(39) element: null. Condition: x == 3:true
|
||||
39(40) element: PyStatementList
|
||||
40(43) element: PyReturnStatement
|
||||
41(42,43) element: PyAssignmentStatement
|
||||
42(43,46) WRITE ACCESS: e
|
||||
43(44,51) element: PyFinallyPart
|
||||
44(45,51) element: PyAssignmentStatement
|
||||
45(51) WRITE ACCESS: f
|
||||
46(47,51) element: PyFinallyPart
|
||||
47(48,51) element: PyAssignmentStatement
|
||||
48(49,51,54) WRITE ACCESS: f
|
||||
4(5,69) element: PyTryPart
|
||||
5(6,69) element: PyAssignmentStatement
|
||||
6(7,69) WRITE ACCESS: b
|
||||
7(8,69) element: PyForStatement
|
||||
8(9,69) element: PyTargetExpression: x
|
||||
9(10,69) WRITE ACCESS: x
|
||||
10(11,69) element: PyTryExceptStatement
|
||||
11(12,59) element: PyTryPart
|
||||
12(13,59) element: PyAssignmentStatement
|
||||
13(14,59) WRITE ACCESS: c
|
||||
14(15,59) element: PyTryExceptStatement
|
||||
15(16,51) element: PyTryPart
|
||||
16(17,51) element: PyAssignmentStatement
|
||||
17(18,51) WRITE ACCESS: d
|
||||
18(19,51) element: PyIfStatement
|
||||
19(20,22,51) READ ACCESS: x
|
||||
20(21) element: null. Condition: x == 0:false
|
||||
21(51,26) ASSERTTYPE ACCESS: x
|
||||
22(23) element: null. Condition: x == 0:true
|
||||
23(51,24) ASSERTTYPE ACCESS: x
|
||||
24(25) element: PyStatementList
|
||||
25(51,54) element: PyBreakStatement
|
||||
26(27,29,51) READ ACCESS: x
|
||||
27(28) element: null. Condition: x == 1:false
|
||||
28(51,33) ASSERTTYPE ACCESS: x
|
||||
29(30) element: null. Condition: x == 1:true
|
||||
30(51,31) ASSERTTYPE ACCESS: x
|
||||
31(32) element: PyStatementList
|
||||
32(7,51,54) element: PyContinueStatement
|
||||
33(34,36,51) READ ACCESS: x
|
||||
34(35) element: null. Condition: x == 2:false
|
||||
35(51,42) ASSERTTYPE ACCESS: x
|
||||
36(37) element: null. Condition: x == 2:true
|
||||
37(51,38) ASSERTTYPE ACCESS: x
|
||||
38(39) element: PyStatementList
|
||||
39(40,51) element: PyRaiseStatement
|
||||
40(41,51) READ ACCESS: Exception
|
||||
41(51) element: PyCallExpression: Exception
|
||||
42(43,45,51) READ ACCESS: x
|
||||
43(44) element: null. Condition: x == 3:false
|
||||
44(51,49) ASSERTTYPE ACCESS: x
|
||||
45(46) element: null. Condition: x == 3:true
|
||||
46(51,47) ASSERTTYPE ACCESS: x
|
||||
47(48) element: PyStatementList
|
||||
48(51) element: PyReturnStatement
|
||||
49(50,51) element: PyAssignmentStatement
|
||||
50(51,54) WRITE ACCESS: g
|
||||
51(52,61) element: PyFinallyPart
|
||||
52(53,61) element: PyAssignmentStatement
|
||||
53(61) WRITE ACCESS: h
|
||||
54(55,61) element: PyFinallyPart
|
||||
55(56,61) element: PyAssignmentStatement
|
||||
56(57,59,61) WRITE ACCESS: h
|
||||
57(58,61) element: PyAssignmentStatement
|
||||
58(8,59,61) WRITE ACCESS: i
|
||||
59(60,61) element: PyAssignmentStatement
|
||||
60(61,64) WRITE ACCESS: j
|
||||
61(62) element: PyFinallyPart
|
||||
62(63) element: PyAssignmentStatement
|
||||
63(69) WRITE ACCESS: k
|
||||
64(65) element: PyFinallyPart
|
||||
65(66) element: PyAssignmentStatement
|
||||
66(67) WRITE ACCESS: k
|
||||
67(68) element: PyAssignmentStatement
|
||||
68(69) WRITE ACCESS: l
|
||||
69() element: null
|
||||
50(51,54) WRITE ACCESS: e
|
||||
51(52,59) element: PyFinallyPart
|
||||
52(53,59) element: PyAssignmentStatement
|
||||
53(59) WRITE ACCESS: f
|
||||
54(55,59) element: PyFinallyPart
|
||||
55(56,59) element: PyAssignmentStatement
|
||||
56(57,59,62) WRITE ACCESS: f
|
||||
57(58,59) element: PyAssignmentStatement
|
||||
58(59,62) WRITE ACCESS: g
|
||||
59(60,69) element: PyFinallyPart
|
||||
60(61,69) element: PyAssignmentStatement
|
||||
61(69) WRITE ACCESS: h
|
||||
62(63,69) element: PyFinallyPart
|
||||
63(64,69) element: PyAssignmentStatement
|
||||
64(65,67,69) WRITE ACCESS: h
|
||||
65(66,69) element: PyAssignmentStatement
|
||||
66(8,67,69) WRITE ACCESS: i
|
||||
67(68,69) element: PyAssignmentStatement
|
||||
68(69,72) WRITE ACCESS: j
|
||||
69(70) element: PyFinallyPart
|
||||
70(71) element: PyAssignmentStatement
|
||||
71(77) WRITE ACCESS: k
|
||||
72(73) element: PyFinallyPart
|
||||
73(74) element: PyAssignmentStatement
|
||||
74(75) WRITE ACCESS: k
|
||||
75(76) element: PyAssignmentStatement
|
||||
76(77) WRITE ACCESS: l
|
||||
77() element: null
|
||||
@@ -590,6 +590,31 @@ public class Py3TypeTest extends PyTestCase {
|
||||
""");
|
||||
}
|
||||
|
||||
public void testLiteralTypeNarrowing() {
|
||||
doTest("Literal[\"abba\"]",
|
||||
"""
|
||||
from typing import Literal
|
||||
def foo(v: str):
|
||||
if (v == "abba"):
|
||||
expr = v
|
||||
""");
|
||||
doTest("Literal[\"ab\"]",
|
||||
"""
|
||||
from typing import Literal
|
||||
def foo(v: Literal["abba", "ab"]):
|
||||
if (v != "abba"):
|
||||
expr = v
|
||||
""");
|
||||
doTest("Literal[\"abc\"]",
|
||||
"""
|
||||
from typing import Literal
|
||||
abc: Literal["abc"] = "abc"
|
||||
def foo(v: str):
|
||||
if (v == abc):
|
||||
expr = v
|
||||
""");
|
||||
}
|
||||
|
||||
// PY-21083
|
||||
public void testFloatFromhex() {
|
||||
doTest("float",
|
||||
|
||||
@@ -465,6 +465,16 @@ public class PyTypeTest extends PyTestCase {
|
||||
expr = a""");
|
||||
}
|
||||
|
||||
public void testIfNotEqOperator() {
|
||||
doTest("Literal[\"ab\"]",
|
||||
"""
|
||||
from typing import Literal
|
||||
def foo(v: Literal["abba", "ab"]):
|
||||
if (v <> "abba"):
|
||||
expr = v
|
||||
""");
|
||||
}
|
||||
|
||||
// PY-4279
|
||||
public void testFieldReassignment() {
|
||||
doTest("C1",
|
||||
|
||||
Reference in New Issue
Block a user