PY-76399 Support __set__ descriptor

Add API for inferring the expected type of `__set__` from `value` parameter
Add corresponding logic to PyTypeCheckerInspection to check if assigned value matches the expected descriptor type
Add tests on it

(cherry picked from commit b14ab7b2e40e225b508875a778ceae8986cbb291)

GitOrigin-RevId: 2b15b2b4527a95e5912897ba256dcc73d71c3dcd
This commit is contained in:
Daniil Kalinin
2024-10-02 17:50:33 +02:00
committed by intellij-monorepo-bot
parent 09a1e4b0e1
commit 4ad6f08f45
6 changed files with 208 additions and 2 deletions

View File

@@ -73,6 +73,7 @@ public final @NonNls class PyNames {
public static final String GETATTR = "__getattr__";
public static final String GETATTRIBUTE = "__getattribute__";
public static final String GET = "__get__";
public static final String DUNDER_SET = "__set__";
public static final String __CLASS__ = "__class__";
public static final String DUNDER_METACLASS = "__metaclass__";
public static final @NlsSafe String METACLASS = "metaclass";

View File

@@ -1066,6 +1066,7 @@ INSP.type.checker.expected.types.prefix=Possible type(s):
INSP.type.checker.unexpected.argument.from.paramspec=Unexpected argument (from ParamSpec ''{0}'')
INSP.type.checker.unfilled.parameter.for.paramspec=Parameter ''{0}'' unfilled (from ParamSpec ''{1}'')
INSP.type.checker.unfilled.vararg=Parameter ''{0}'' unfilled, expected ''{1}''
INSP.type.checker.assigned.value.do.not.match.expected.type.from.dunder.set=Assigned type ''{0}'' do not match expected type ''{1}'' of value from '__set__' descriptor of class ''{2}''
# PyTypedDictInspection
INSP.NAME.typed.dict=Invalid TypedDict definition and usages

View File

@@ -5,6 +5,7 @@ import com.intellij.codeInspection.LocalInspectionToolSession;
import com.intellij.codeInspection.ProblemsHolder;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.util.Key;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.Ref;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiElementVisitor;
@@ -18,6 +19,8 @@ import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.psi.types.*;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
@@ -144,6 +147,60 @@ public class PyTypeCheckerInspection extends PyInspection {
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
registerProblem(value, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName));
}
matchAssignedAndExpectedDunderSetDescriptorValue(node, value);
}
private void matchAssignedAndExpectedDunderSetDescriptorValue(@NotNull PyTargetExpression targetExpression,
@NotNull PyExpression assignedValue) {
final ScopeOwner scopeOwner = ScopeUtil.getScopeOwner(targetExpression);
if (scopeOwner == null) return;
Pair<PyType, String> infoFromDunderSet =
getExpectedTypeFromDunderSet(targetExpression, scopeOwner, myTypeEvalContext);
if (infoFromDunderSet == null) return;
final PyType expectedTypeFromDunderSet = infoFromDunderSet.first;
final PyType actual = tryPromotingType(assignedValue, expectedTypeFromDunderSet);
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
if (expectedTypeFromDunderSet != null && !PyTypeChecker.match(expectedTypeFromDunderSet, actual, myTypeEvalContext)) {
String expectedName =
PythonDocumentationProvider.getVerboseTypeName(expectedTypeFromDunderSet, myTypeEvalContext);
String className = infoFromDunderSet.second;
if (className != null) {
registerProblem(assignedValue,
PyPsiBundle.message("INSP.type.checker.assigned.value.do.not.match.expected.type.from.dunder.set", actualName,
expectedName, className));
}
}
}
private Pair<PyType, String> getExpectedTypeFromDunderSet(@NotNull PyTargetExpression targetExpression,
@NotNull ScopeOwner scopeOwner,
@NotNull TypeEvalContext context) {
PyExpression referenceExpressionFromTarget = PyUtil.createExpressionFromFragment(targetExpression.getText(), scopeOwner);
if (referenceExpressionFromTarget == null) return null;
PyType referenceType = myTypeEvalContext.getType(referenceExpressionFromTarget);
if (referenceType instanceof PyClassType classType) {
Ref<PyType> expectedTypeRefFromSet =
PyDescriptorTypeUtil.getExpectedValueTypeForDunderSet(targetExpression, classType, context);
if (expectedTypeRefFromSet != null) {
final PyResolveContext resolveContext = PyResolveContext.noProperties(context);
final List<? extends RatedResolveResult> members =
classType.resolveMember(PyNames.DUNDER_SET, targetExpression, AccessDirection.READ,
resolveContext);
if (members == null || members.isEmpty() || !(members.get(0).getElement() instanceof PyFunction dunderSetFunc)) return null;
PyClass classContainingDunderSet = dunderSetFunc.getContainingClass();
if (classContainingDunderSet == null || classContainingDunderSet.getName() == null) return null;
return Pair.create(Ref.deref(expectedTypeRefFromSet), classContainingDunderSet.getName());
}
}
return null;
}
private boolean reportTypedDictProblems(@NotNull PyType expected, @NotNull PyTypedDictType actual, @NotNull PyExpression value) {

View File

@@ -5,6 +5,7 @@ import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -30,6 +31,19 @@ public final class PyDescriptorTypeUtil {
return getTypeFromSyntheticDunderGetCall(expression, typeFromTargets, context);
}
@Nullable
public static Ref<PyType> getExpectedValueTypeForDunderSet(@NotNull PyQualifiedExpression expression, @Nullable PyType typeFromTargets, @NotNull TypeEvalContext context) {
final PyClassLikeType targetType = as(typeFromTargets, PyClassLikeType.class);
if (targetType == null || targetType.isDefinition()) return null;
final PyResolveContext resolveContext = PyResolveContext.noProperties(context);
final List<? extends RatedResolveResult> members = targetType.resolveMember(PyNames.DUNDER_SET, expression, AccessDirection.READ,
resolveContext);
if (members == null || members.isEmpty()) return null;
return getExpectedTypeFromDunderSet(expression, typeFromTargets, context);
}
@Nullable
private static Ref<PyType> getTypeFromSyntheticDunderGetCall(@NotNull PyQualifiedExpression expression, @NotNull PyType typeFromTargets, @NotNull TypeEvalContext context) {
PyExpression qualifier = expression.getQualifier();
@@ -54,4 +68,45 @@ public final class PyDescriptorTypeUtil {
return null;
}
@Nullable
private static Ref<PyType> getExpectedTypeFromDunderSet(@NotNull PyQualifiedExpression expression,
@NotNull PyType typeFromTargets,
@NotNull TypeEvalContext context) {
PyExpression qualifier = expression.getQualifier();
PyType objectArgumentType = PyNoneType.INSTANCE;
if (qualifier != null && typeFromTargets instanceof PyCallableType) {
PyType qualifierType = context.getType(qualifier);
if (qualifierType instanceof PyClassType classType && !classType.isDefinition()) {
objectArgumentType = qualifierType;
}
}
List<PyFunction> functions =
PySyntheticCallHelper.resolveFunctionsByArgumentTypes(PyNames.DUNDER_SET, List.of(objectArgumentType), typeFromTargets, context);
if (functions.isEmpty()) return null;
PyType expectedSetValueType = StreamEx.of(functions)
.nonNull()
.map(function -> getExpectedDunderSetValueType(function, typeFromTargets, context))
.collect(PyTypeUtil.toUnion());
return Ref.create(expectedSetValueType);
}
@Nullable
private static PyType getExpectedDunderSetValueType(@NotNull PyFunction function, @NotNull PyType receiverType, @NotNull TypeEvalContext context) {
List<PyCallableParameter> parameters = function.getParameters(context);
if (parameters.size() != 3) return null;
// Parameter names may differ, but 'value' parameter should always be the third one
PyCallableParameter valueParameter = parameters.get(2);
if (valueParameter != null) {
PyType type = valueParameter.getArgumentType(context);
if (type != null && receiverType instanceof PyClassType) {
PyTypeChecker.GenericSubstitutions subs = PyTypeChecker.unifyReceiver(receiverType, context);
return PyTypeChecker.substitute(type, subs, context);
}
}
return null;
}
}

View File

@@ -76,7 +76,7 @@ public final class PySyntheticCallHelper {
}
private static @NotNull List<PyFunction> resolveFunctionsByArgumentTypes(@NotNull String functionName,
public static @NotNull List<PyFunction> resolveFunctionsByArgumentTypes(@NotNull String functionName,
@NotNull List<PyType> argumentTypes,
@Nullable PyType receiverType,
@NotNull TypeEvalContext context) {
@@ -154,7 +154,7 @@ public final class PySyntheticCallHelper {
Map<Ref<PyType>, PyCallableParameter> mappedParams = new HashMap<>();
for (int i = 0; i < explicitParameters.size(); i++) {
mappedParams.put(Ref.create(arguments.get(i)), explicitParameters.get(i));
mappedParams.put(Ref.create(i < arguments.size() ? arguments.get(i) : null), explicitParameters.get(i));
}
return new SyntheticCallArgumentsMapping(functionType, implicitParameters, mappedParams, unmappedArguments);

View File

@@ -2170,6 +2170,98 @@ def foo(param: str | int) -> TypeGuard[str]:
""");
}
// PY-76399
public void testAssignedValueMatchesWithDunderSetSimpleCase() {
doTestByText("""
class MyDescriptor:
def __set__(self, obj, value: str) -> None:
...
class Test:
member: MyDescriptor
t = Test()
t.member = "str"
t.member = <warning descr="Assigned type 'int' do not match expected type 'str' of value from __set__ descriptor of class 'MyDescriptor'">123</warning>
t.member = <warning descr="Assigned type 'Type[list]' do not match expected type 'str' of value from __set__ descriptor of class 'MyDescriptor'">list</warning>
""");
}
// PY-76399
public void testAssignedValueMatchesWithGenericDunderSetSimpleCase() {
doTestByText("""
class MyDescriptor[T]:
def __set__(self, obj, value: T) -> None:
...
class Test:
member: MyDescriptor[str]
t = Test()
t.member = "str"
t.member = <warning descr="Assigned type 'int' do not match expected type 'str' of value from __set__ descriptor of class 'MyDescriptor'">123</warning>
t.member = <warning descr="Assigned type 'Type[list]' do not match expected type 'str' of value from __set__ descriptor of class 'MyDescriptor'">list</warning>
""");
}
// PY-76399
public void testAssignedValueMatchesWithDunderSetWithOverloads() {
doTestByText("""
from typing import overload
class MyDescriptor:
@overload
def __set__(self, obj: "Test", value: str) -> None:
...
@overload
def __set__(self, obj: "Prod", value: "LocalizedString") -> None:
...
def __set__(self, obj, value) -> None:
...
class Test:
member: MyDescriptor
class Prod:
member: MyDescriptor
class LocalizedString:
def __init__(self, value: str):
...
t = Test()
t.member = "abc"
t.member = <warning descr="Assigned type 'int' do not match expected type 'str' of value from __set__ descriptor of class 'MyDescriptor'">42</warning>
p = Prod()
p.member = <warning descr="Assigned type 'str' do not match expected type 'LocalizedString' of value from __set__ descriptor of class 'MyDescriptor'">"abc"</warning>
p.member = <warning descr="Assigned type 'int' do not match expected type 'LocalizedString' of value from __set__ descriptor of class 'MyDescriptor'">42</warning>
p.member = LocalizedString("abc")
""");
}
// PY-76399
public void testAssignedValueMatchesWithDunderSetWithLiteralValue() {
doTestByText("""
from typing import Literal
class MyDescriptor:
def __set__(self, obj, value: Literal[42]) -> None:
...
class Test:
member: MyDescriptor
t = Test()
t.member = 42
t.member = <warning descr="Assigned type 'Literal[43]' do not match expected type 'Literal[42]' of value from __set__ descriptor of class 'MyDescriptor'">43</warning>
t.member = <warning descr="Assigned type 'Literal[\\"42\\"]' do not match expected type 'Literal[42]' of value from __set__ descriptor of class 'MyDescriptor'">"42"</warning>
""");
}
// PY-23067
public void testFunctoolsWrapsMultiFile() {
doMultiFileTest();