PY-51329 Fix problem with overloaded bitwise or operator chains

(cherry picked from commit f265b05500ef8a6b35e7fef090a971f4a5449453)

IJ-MR-17506

GitOrigin-RevId: 5cac0b7d071f753023b97c4e4152acbe388b8f08
This commit is contained in:
andrey.matveev
2021-10-28 18:37:35 +07:00
committed by intellij-monorepo-bot
parent 43b95d766e
commit 806fa72b80
6 changed files with 108 additions and 49 deletions

View File

@@ -765,21 +765,32 @@ public class PyTypingTypeProvider extends PyTypeProviderBase {
PyExpression right = expression.getRightExpression();
if (left == null || right == null) return null;
Ref<PyType> leftType = getType(left, context);
Ref<PyType> rightType = getType(right, context);
if (leftType == null && rightType == null) return null;
Ref<PyType> leftTypeRef = getType(left, context);
Ref<PyType> rightTypeRef = getType(right, context);
if (leftTypeRef == null || rightTypeRef == null) return null;
PyType union;
if (leftType != null && rightType != null) {
union = PyUnionType.union(leftType.get(), rightType.get());
}
else {
union = PyUnionType.createWeakType(Objects.requireNonNullElse(leftType, rightType).get());
}
PyType leftType = leftTypeRef.get();
if (leftType != null && typeHasOverloadedBitwiseOr(leftType, left, context)) return null;
PyType union = PyUnionType.union(leftType, rightTypeRef.get());
return union != null ? Ref.create(union) : null;
}
private static boolean typeHasOverloadedBitwiseOr(@NotNull PyType type, @NotNull PyExpression expression,
@NotNull Context context) {
if (type instanceof PyUnionType) return false;
PyType typeToClass = type instanceof PyClassLikeType ? ((PyClassLikeType)type).toClass() : type;
var resolved = typeToClass.resolveMember("__or__", expression, AccessDirection.READ,
PyResolveContext.defaultContext(context.getTypeContext()));
if (resolved == null || resolved.isEmpty()) return false;
return StreamEx.of(resolved)
.map(it -> it.getElement())
.nonNull()
.noneMatch(it -> PyBuiltinCache.getInstance(it).isBuiltin(it));
}
public static boolean isBitwiseOrUnionAvailable(@NotNull TypeEvalContext context) {
final PsiFile originFile = context.getOrigin();
return originFile == null || isBitwiseOrUnionAvailable(originFile);

View File

@@ -8,7 +8,6 @@ import com.intellij.codeInspection.util.InspectionMessage;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.NlsSafe;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.Ref;
import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.StringUtil;
@@ -25,14 +24,12 @@ import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.codeInsight.imports.AddImportHelper;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.inspections.quickfix.*;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.PyUnionType;
import com.jetbrains.python.psi.types.TypeEvalContext;
@@ -42,7 +39,6 @@ import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
@@ -712,16 +708,6 @@ public abstract class CompatibilityVisitor extends PyAnnotator {
final TypeEvalContext context = TypeEvalContext.codeAnalysis(node.getProject(), node.getContainingFile());
final List<PsiElement> resolvedVariants = PyUtil.multiResolveTopPriority(node.getReference(PyResolveContext.defaultContext(context)));
for (PsiElement resolved : resolvedVariants) {
if (resolved instanceof PyFunction) {
final PyClass containingClass = ((PyFunction)resolved).getContainingClass();
if (containingClass == null) return;
final String classQualifiedName = containingClass.getQualifiedName();
if (!PyNames.TYPE.equals(classQualifiedName) && !"types.UnionType".equals(classQualifiedName)) return;
}
}
// Consider only full expression not parts to have only one registered problem
if (PsiTreeUtil.getParentOfType(node, PyBinaryExpression.class, true, PyStatement.class) != null) return;

View File

@@ -0,0 +1,12 @@
class MyMeta(type):
def __or__(self, other):
return other
class Foo(metaclass=MyMeta):
...
class Bar(metaclass=MyMeta):
...
class Baz(metaclass=MyMeta):
...
print(Foo | Bar | Baz)

View File

@@ -16,6 +16,7 @@
package com.jetbrains.python;
import com.jetbrains.python.fixtures.PyTestCase;
import com.jetbrains.python.psi.LanguageLevel;
import org.jetbrains.annotations.NotNull;
public class Py3HighlightingTest extends PyTestCase {
@@ -88,6 +89,40 @@ public class Py3HighlightingTest extends PyTestCase {
doTest(true, false);
}
// PY-49774
public void testMatchStatementBefore310() {
doTestWithLanguageLevel(LanguageLevel.PYTHON39, true, true);
}
// PY-44974
public void testBitwiseOrUnionInOlderVersionsError() {
doTestWithLanguageLevel(LanguageLevel.PYTHON39, false, false);
}
// PY-44974
public void testBitwiseOrUnionInOlderVersionsErrorIsInstance() {
doTestWithLanguageLevel(LanguageLevel.PYTHON39, false, false);
}
// PY-49697
public void testNoErrorMetaClassOverloadBitwiseOrOperator() {
doTestWithLanguageLevel(LanguageLevel.PYTHON39, false, false);
}
// PY-49697
public void testNoErrorMetaClassOverloadBitwiseOrOperatorReturnTypesUnion() {
doTestWithLanguageLevel(LanguageLevel.PYTHON39, false, false);
}
// PY-51329
public void testNoErrorMetaClassOverloadBitwiseOrChain() {
doTestWithLanguageLevel(LanguageLevel.PYTHON39, false, false);
}
private void doTestWithLanguageLevel(LanguageLevel languageLevel, boolean checkWarnings, boolean checkInfos) {
runWithLanguageLevel(languageLevel, () -> doTest(checkWarnings, checkInfos));
}
private void doTest(boolean checkWarnings, boolean checkInfos) {
myFixture.testHighlighting(checkWarnings, checkInfos, false, TEST_PATH + getTestName(true) + PyNames.DOT_PY);
}

View File

@@ -1323,6 +1323,46 @@ public class Py3TypeTest extends PyTestCase {
"expr = transform(bar)");
}
// PY-51329
public void testBitwiseOrOperatorOverloadUnion() {
doTest("UnionType",
"class MyMeta(type):\n" +
" def __or__(self, other) -> Any:\n" +
" return other\n" +
"\n" +
"class Foo(metaclass=MyMeta):\n" +
" ...\n" +
"\n" +
"expr = Foo | None");
}
// PY-51329
public void testBitwiseOrOperatorOverloadUnionTypeAlias() {
doTest("Any",
"class MyMeta(type):\n" +
" def __or__(self, other) -> Any:\n" +
" return other\n" +
"\n" +
"class Foo(metaclass=MyMeta):\n" +
" ...\n" +
"\n" +
"Alias = Foo | None\n" +
"expr: Alias");
}
// PY-51329
public void testBitwiseOrOperatorOverloadUnionTypeAnnotation() {
doTest("Any",
"class MyMeta(type):\n" +
" def __or__(self, other) -> Any:\n" +
" return other\n" +
"\n" +
"class Foo(metaclass=MyMeta):\n" +
" ...\n" +
"\n" +
"expr: Foo | None");
}
/**
* @see #testRecursiveDictTopDown()
* @see PyTypeCheckerInspectionTest#testRecursiveDictAttribute()

View File

@@ -508,31 +508,6 @@ public class PythonHighlightingTest extends PyTestCase {
doTest(LanguageLevel.PYTHON310, false, true);
}
// PY-49774
public void testMatchStatementBefore310() {
doTest(LanguageLevel.PYTHON39, true, true);
}
// PY-44974
public void testBitwiseOrUnionInOlderVersionsError() {
doTest(LanguageLevel.PYTHON39, false, false);
}
// PY-44974
public void testBitwiseOrUnionInOlderVersionsErrorIsInstance() {
doTest(LanguageLevel.PYTHON39, false, false);
}
// PY-49697
public void testNoErrorMetaClassOverloadBitwiseOrOperator() {
doTest(LanguageLevel.PYTHON39, false, false);
}
// PY-49697
public void testNoErrorMetaClassOverloadBitwiseOrOperatorReturnTypesUnion() {
doTest(LanguageLevel.PYTHON39, false, false);
}
// PY-24653
public void testSelfHighlightingInInnerFunc() {
doTest(LanguageLevel.getLatest(), false, true);