[pycharm] PY-41827 check for union types in PySimplifyBooleanCheckInspection

to avoid ambiguous conditions


(cherry picked from commit f11c08f8dfea453a6676173bae57f96233429970)

IJ-MR-159379

GitOrigin-RevId: b700b1e5e7c077da7b61661d33c2a3aa135d9e21
This commit is contained in:
Morgan Bartholomew
2025-04-02 18:00:14 +10:00
committed by intellij-monorepo-bot
parent df7a682006
commit fa4b4eec34
3 changed files with 86 additions and 6 deletions

View File

@@ -20,6 +20,7 @@ import com.intellij.codeInspection.LocalInspectionToolSession;
import com.intellij.codeInspection.ProblemsHolder;
import com.intellij.codeInspection.options.OptPane;
import com.intellij.psi.PsiElementVisitor;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.inspections.quickfix.SimplifyBooleanCheckQuickFix;
@@ -27,6 +28,8 @@ import com.jetbrains.python.psi.PyBinaryExpression;
import com.jetbrains.python.psi.PyConditionalStatementPart;
import com.jetbrains.python.psi.PyElementType;
import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.types.PyNoneType;
import com.jetbrains.python.psi.types.PyUnionType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -89,16 +92,53 @@ public final class PySimplifyBooleanCheckInspection extends PyInspection {
public void visitPyBinaryExpression(@NotNull PyBinaryExpression node) {
super.visitPyBinaryExpression(node);
final PyElementType operator = node.getOperator();
final var leftExpression = node.getLeftExpression();
final PyExpression rightExpression = node.getRightExpression();
if (rightExpression == null || rightExpression instanceof PyBinaryExpression ||
node.getLeftExpression() instanceof PyBinaryExpression) {
leftExpression instanceof PyBinaryExpression) {
return;
}
if (PyTokenTypes.EQUALITY_OPERATIONS.contains(operator)) {
if (operandsEqualTo(node, COMPARISON_LITERALS) ||
(!myIgnoreComparisonToZero && operandsEqualTo(node, Collections.singleton("0")))) {
registerProblem(node);
}
final var leftType = myTypeEvalContext.getType(leftExpression);
final var rightType = myTypeEvalContext.getType(rightExpression);
final var isIdentity = node.isOperator(PyNames.IS) || node.isOperator("isnot");
// if no type and `is`, then it's unsafe
if ((leftType == null || rightType == null) && isIdentity) {
return;
}
// because we are comparing to literal values, there will only ever be a union on one side
final var unionMembers = (leftType instanceof PyUnionType unionType)
? unionType.getMembers()
: (rightType instanceof PyUnionType unionType)
? unionType.getMembers()
: null;
final var isOptional = unionMembers != null && unionMembers.contains(PyNoneType.INSTANCE);
// if the union is `X | Y | None` or just `X | Y` then it is unsafe to simplify
if (isOptional && unionMembers.size() > 2 || !isOptional && unionMembers != null) {
return;
}
if (!isIdentity
&& !PyTokenTypes.EQUALITY_OPERATIONS.contains(operator)) {
return;
}
final var compareWithZero = !myIgnoreComparisonToZero && operandsEqualTo(node, Collections.singleton("0"));
boolean compareWithFalsey = operandsEqualTo(node, ImmutableList.of(PyNames.FALSE, "[]"));
// 'x is falsey' where `x` is `T | None`, then it is unsafe to simplify
// because the falsey value will evaluate to `False` which will be ambiguous with `None`
if (isOptional && (compareWithFalsey || compareWithZero)) {
return;
}
if (operandsEqualTo(node, COMPARISON_LITERALS) || compareWithZero) {
registerProblem(node);
}
}