diff --git a/python/python-parser/src/com/jetbrains/python/PyNames.java b/python/python-parser/src/com/jetbrains/python/PyNames.java index ebb901237871..9a1d2fd51a5a 100644 --- a/python/python-parser/src/com/jetbrains/python/PyNames.java +++ b/python/python-parser/src/com/jetbrains/python/PyNames.java @@ -73,6 +73,7 @@ public final @NonNls class PyNames { public static final String GETATTR = "__getattr__"; public static final String GETATTRIBUTE = "__getattribute__"; public static final String GET = "__get__"; + public static final String DUNDER_SET = "__set__"; public static final String __CLASS__ = "__class__"; public static final String DUNDER_METACLASS = "__metaclass__"; public static final @NlsSafe String METACLASS = "metaclass"; diff --git a/python/python-psi-impl/resources/messages/PyPsiBundle.properties b/python/python-psi-impl/resources/messages/PyPsiBundle.properties index 38f8fa532027..33c94f4845cd 100644 --- a/python/python-psi-impl/resources/messages/PyPsiBundle.properties +++ b/python/python-psi-impl/resources/messages/PyPsiBundle.properties @@ -1066,6 +1066,7 @@ INSP.type.checker.expected.types.prefix=Possible type(s): INSP.type.checker.unexpected.argument.from.paramspec=Unexpected argument (from ParamSpec ''{0}'') INSP.type.checker.unfilled.parameter.for.paramspec=Parameter ''{0}'' unfilled (from ParamSpec ''{1}'') INSP.type.checker.unfilled.vararg=Parameter ''{0}'' unfilled, expected ''{1}'' +INSP.type.checker.assigned.value.do.not.match.expected.type.from.dunder.set=Assigned type ''{0}'' do not match expected type ''{1}'' of value from '__set__' descriptor of class ''{2}'' # PyTypedDictInspection INSP.NAME.typed.dict=Invalid TypedDict definition and usages diff --git a/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java b/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java index 1e3391bcebc3..618a025b2602 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java +++ b/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java @@ -5,6 +5,7 @@ import com.intellij.codeInspection.LocalInspectionToolSession; import com.intellij.codeInspection.ProblemsHolder; import com.intellij.openapi.diagnostic.Logger; import com.intellij.openapi.util.Key; +import com.intellij.openapi.util.Pair; import com.intellij.openapi.util.Ref; import com.intellij.openapi.util.text.StringUtil; import com.intellij.psi.PsiElementVisitor; @@ -18,6 +19,8 @@ import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; import com.jetbrains.python.documentation.PythonDocumentationProvider; import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix; import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.resolve.PyResolveContext; +import com.jetbrains.python.psi.resolve.RatedResolveResult; import com.jetbrains.python.psi.types.*; import one.util.streamex.StreamEx; import org.jetbrains.annotations.NotNull; @@ -144,6 +147,60 @@ public class PyTypeCheckerInspection extends PyInspection { String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext); registerProblem(value, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName)); } + + matchAssignedAndExpectedDunderSetDescriptorValue(node, value); + } + + private void matchAssignedAndExpectedDunderSetDescriptorValue(@NotNull PyTargetExpression targetExpression, + @NotNull PyExpression assignedValue) { + final ScopeOwner scopeOwner = ScopeUtil.getScopeOwner(targetExpression); + if (scopeOwner == null) return; + + Pair infoFromDunderSet = + getExpectedTypeFromDunderSet(targetExpression, scopeOwner, myTypeEvalContext); + if (infoFromDunderSet == null) return; + + final PyType expectedTypeFromDunderSet = infoFromDunderSet.first; + final PyType actual = tryPromotingType(assignedValue, expectedTypeFromDunderSet); + + String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext); + if (expectedTypeFromDunderSet != null && !PyTypeChecker.match(expectedTypeFromDunderSet, actual, myTypeEvalContext)) { + String expectedName = + PythonDocumentationProvider.getVerboseTypeName(expectedTypeFromDunderSet, myTypeEvalContext); + String className = infoFromDunderSet.second; + + if (className != null) { + registerProblem(assignedValue, + PyPsiBundle.message("INSP.type.checker.assigned.value.do.not.match.expected.type.from.dunder.set", actualName, + expectedName, className)); + } + } + } + + private Pair getExpectedTypeFromDunderSet(@NotNull PyTargetExpression targetExpression, + @NotNull ScopeOwner scopeOwner, + @NotNull TypeEvalContext context) { + PyExpression referenceExpressionFromTarget = PyUtil.createExpressionFromFragment(targetExpression.getText(), scopeOwner); + if (referenceExpressionFromTarget == null) return null; + + PyType referenceType = myTypeEvalContext.getType(referenceExpressionFromTarget); + if (referenceType instanceof PyClassType classType) { + Ref expectedTypeRefFromSet = + PyDescriptorTypeUtil.getExpectedValueTypeForDunderSet(targetExpression, classType, context); + if (expectedTypeRefFromSet != null) { + final PyResolveContext resolveContext = PyResolveContext.noProperties(context); + final List members = + classType.resolveMember(PyNames.DUNDER_SET, targetExpression, AccessDirection.READ, + resolveContext); + if (members == null || members.isEmpty() || !(members.get(0).getElement() instanceof PyFunction dunderSetFunc)) return null; + PyClass classContainingDunderSet = dunderSetFunc.getContainingClass(); + + if (classContainingDunderSet == null || classContainingDunderSet.getName() == null) return null; + + return Pair.create(Ref.deref(expectedTypeRefFromSet), classContainingDunderSet.getName()); + } + } + return null; } private boolean reportTypedDictProblems(@NotNull PyType expected, @NotNull PyTypedDictType actual, @NotNull PyExpression value) { diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyDescriptorTypeUtil.java b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyDescriptorTypeUtil.java index 4697f92ab5fa..64d0c6c2a91a 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyDescriptorTypeUtil.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyDescriptorTypeUtil.java @@ -5,6 +5,7 @@ import com.jetbrains.python.PyNames; import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.resolve.PyResolveContext; import com.jetbrains.python.psi.resolve.RatedResolveResult; +import one.util.streamex.StreamEx; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -30,6 +31,19 @@ public final class PyDescriptorTypeUtil { return getTypeFromSyntheticDunderGetCall(expression, typeFromTargets, context); } + @Nullable + public static Ref getExpectedValueTypeForDunderSet(@NotNull PyQualifiedExpression expression, @Nullable PyType typeFromTargets, @NotNull TypeEvalContext context) { + final PyClassLikeType targetType = as(typeFromTargets, PyClassLikeType.class); + if (targetType == null || targetType.isDefinition()) return null; + + final PyResolveContext resolveContext = PyResolveContext.noProperties(context); + final List members = targetType.resolveMember(PyNames.DUNDER_SET, expression, AccessDirection.READ, + resolveContext); + if (members == null || members.isEmpty()) return null; + + return getExpectedTypeFromDunderSet(expression, typeFromTargets, context); + } + @Nullable private static Ref getTypeFromSyntheticDunderGetCall(@NotNull PyQualifiedExpression expression, @NotNull PyType typeFromTargets, @NotNull TypeEvalContext context) { PyExpression qualifier = expression.getQualifier(); @@ -54,4 +68,45 @@ public final class PyDescriptorTypeUtil { return null; } + @Nullable + private static Ref getExpectedTypeFromDunderSet(@NotNull PyQualifiedExpression expression, + @NotNull PyType typeFromTargets, + @NotNull TypeEvalContext context) { + PyExpression qualifier = expression.getQualifier(); + PyType objectArgumentType = PyNoneType.INSTANCE; + + if (qualifier != null && typeFromTargets instanceof PyCallableType) { + PyType qualifierType = context.getType(qualifier); + if (qualifierType instanceof PyClassType classType && !classType.isDefinition()) { + objectArgumentType = qualifierType; + } + } + + List functions = + PySyntheticCallHelper.resolveFunctionsByArgumentTypes(PyNames.DUNDER_SET, List.of(objectArgumentType), typeFromTargets, context); + if (functions.isEmpty()) return null; + + PyType expectedSetValueType = StreamEx.of(functions) + .nonNull() + .map(function -> getExpectedDunderSetValueType(function, typeFromTargets, context)) + .collect(PyTypeUtil.toUnion()); + return Ref.create(expectedSetValueType); + } + + @Nullable + private static PyType getExpectedDunderSetValueType(@NotNull PyFunction function, @NotNull PyType receiverType, @NotNull TypeEvalContext context) { + List parameters = function.getParameters(context); + if (parameters.size() != 3) return null; + // Parameter names may differ, but 'value' parameter should always be the third one + PyCallableParameter valueParameter = parameters.get(2); + if (valueParameter != null) { + PyType type = valueParameter.getArgumentType(context); + if (type != null && receiverType instanceof PyClassType) { + PyTypeChecker.GenericSubstitutions subs = PyTypeChecker.unifyReceiver(receiverType, context); + return PyTypeChecker.substitute(type, subs, context); + } + } + return null; + } + } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PySyntheticCallHelper.java b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PySyntheticCallHelper.java index d0904b88f3dd..dfb9f90f0429 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PySyntheticCallHelper.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PySyntheticCallHelper.java @@ -76,7 +76,7 @@ public final class PySyntheticCallHelper { } - private static @NotNull List resolveFunctionsByArgumentTypes(@NotNull String functionName, + public static @NotNull List resolveFunctionsByArgumentTypes(@NotNull String functionName, @NotNull List argumentTypes, @Nullable PyType receiverType, @NotNull TypeEvalContext context) { @@ -154,7 +154,7 @@ public final class PySyntheticCallHelper { Map, PyCallableParameter> mappedParams = new HashMap<>(); for (int i = 0; i < explicitParameters.size(); i++) { - mappedParams.put(Ref.create(arguments.get(i)), explicitParameters.get(i)); + mappedParams.put(Ref.create(i < arguments.size() ? arguments.get(i) : null), explicitParameters.get(i)); } return new SyntheticCallArgumentsMapping(functionType, implicitParameters, mappedParams, unmappedArguments); diff --git a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java index 007fc1a1b447..493c6b4506ba 100644 --- a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java +++ b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java @@ -2170,6 +2170,98 @@ def foo(param: str | int) -> TypeGuard[str]: """); } + // PY-76399 + public void testAssignedValueMatchesWithDunderSetSimpleCase() { + doTestByText(""" + class MyDescriptor: + + def __set__(self, obj, value: str) -> None: + ... + + class Test: + member: MyDescriptor + + t = Test() + t.member = "str" + t.member = 123 + t.member = list + """); + } + + // PY-76399 + public void testAssignedValueMatchesWithGenericDunderSetSimpleCase() { + doTestByText(""" + class MyDescriptor[T]: + + def __set__(self, obj, value: T) -> None: + ... + + class Test: + member: MyDescriptor[str] + + t = Test() + t.member = "str" + t.member = 123 + t.member = list + """); + } + + // PY-76399 + public void testAssignedValueMatchesWithDunderSetWithOverloads() { + doTestByText(""" + from typing import overload + + class MyDescriptor: + + @overload + def __set__(self, obj: "Test", value: str) -> None: + ... + @overload + def __set__(self, obj: "Prod", value: "LocalizedString") -> None: + ... + def __set__(self, obj, value) -> None: + ... + + class Test: + member: MyDescriptor + + class Prod: + member: MyDescriptor + + class LocalizedString: + def __init__(self, value: str): + ... + + t = Test() + t.member = "abc" + t.member = 42 + p = Prod() + p.member = "abc" + p.member = 42 + p.member = LocalizedString("abc") + """); + } + + // PY-76399 + public void testAssignedValueMatchesWithDunderSetWithLiteralValue() { + doTestByText(""" + from typing import Literal + + + class MyDescriptor: + def __set__(self, obj, value: Literal[42]) -> None: + ... + + class Test: + member: MyDescriptor + + t = Test() + t.member = 42 + t.member = 43 + t.member = "42" + """); + } + // PY-23067 public void testFunctoolsWrapsMultiFile() { doMultiFileTest();