diff --git a/python/python-psi-api/src/com/jetbrains/python/PyNames.java b/python/python-psi-api/src/com/jetbrains/python/PyNames.java index eeb5b5cbe7cc..50dc8929fcbf 100644 --- a/python/python-psi-api/src/com/jetbrains/python/PyNames.java +++ b/python/python-psi-api/src/com/jetbrains/python/PyNames.java @@ -209,8 +209,6 @@ public class PyNames { public static final String COLLECTIONS_NAMEDTUPLE_PY2 = COLLECTIONS + "." + NAMEDTUPLE; public static final String COLLECTIONS_NAMEDTUPLE_PY3 = COLLECTIONS + "." + INIT + "." + NAMEDTUPLE; - public static final String TYPED_DICT = "TypedDict"; - public static final String FORMAT = "format"; public static final String ABSTRACTMETHOD = "abstractmethod"; diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt index aaf80e366eec..c49f674af263 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypedDictTypeProvider.kt @@ -3,13 +3,11 @@ package com.jetbrains.python.codeInsight.typing import com.intellij.openapi.util.Ref import com.intellij.psi.PsiElement -import com.intellij.psi.util.QualifiedName import com.jetbrains.python.PyNames import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.* import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyBuiltinCache import com.jetbrains.python.psi.impl.PyCallExpressionNavigator -import com.jetbrains.python.psi.impl.stubs.PyClassElementType import com.jetbrains.python.psi.impl.stubs.PyTypedDictStubImpl import com.jetbrains.python.psi.resolve.PyResolveContext import com.jetbrains.python.psi.stubs.PyTypedDictStub @@ -31,34 +29,17 @@ class PyTypedDictTypeProvider : PyTypeProviderBase() { companion object { val nameIsTypedDict = { name: String? -> name == TYPED_DICT || name == TYPED_DICT_EXT } + fun isTypedDict(expression: PyExpression, context: TypeEvalContext): Boolean { + return resolveToQualifiedNames(expression, context).any(nameIsTypedDict) + } + fun isTypingTypedDictInheritor(cls: PyClass, context: TypeEvalContext): Boolean { val isTypingTD = { type: PyClassLikeType? -> type is PyTypedDictType || nameIsTypedDict(type?.classQName) } val ancestors = cls.getAncestorTypes(context) - if (ancestors.any(isTypingTD)) return true - - val hasTDAsSuperclass = hasTypedDictAsSuperclass(cls, context) - val hasTDAncestors = ancestors.filterIsInstance() - .any { hasTypedDictAsSuperclass(it.pyClass, context) } - return hasTDAsSuperclass || hasTDAncestors - } - - private fun hasTypedDictAsSuperclass(cls: PyClass, context: TypeEvalContext): Boolean { - when { - context.maySwitchToAST(cls) -> return cls.superClassExpressions.any { superClassExpr -> - resolveToQualifiedNames(superClassExpr, context).any(nameIsTypedDict) - } - cls.stub != null -> { - return containsTypedDictQName(cls.stub.superClasses) - } - else -> return containsTypedDictQName(PyClassElementType.getSuperClassQNames(cls)) - } - } - - private fun containsTypedDictQName(map: Map): Boolean { - return map.any { name -> name == QualifiedName.fromDottedString(TYPED_DICT) || name == QualifiedName.fromDottedString(TYPED_DICT_EXT) } + return ancestors.any(isTypingTD) } fun getTypedDictTypeForResolvedCallee(referenceTarget: PsiElement, context: TypeEvalContext): PyTypedDictType? { @@ -103,7 +84,7 @@ class PyTypedDictTypeProvider : PyTypeProviderBase() { } } - if (resolveToQualifiedNames(referenceExpression, context).contains(TYPED_DICT)) { + if (isTypedDict(referenceExpression, context)) { val parameters = mutableListOf() val builtinCache = PyBuiltinCache.getInstance(referenceExpression) @@ -126,7 +107,9 @@ class PyTypedDictTypeProvider : PyTypeProviderBase() { return getTypedDictTypeForTypingTDInheritorAsCallee(cls, context, false) } - private fun getTypedDictTypeForTypingTDInheritorAsCallee(cls: PyClass, context: TypeEvalContext, isInstance: Boolean): PyTypedDictType? { + private fun getTypedDictTypeForTypingTDInheritorAsCallee(cls: PyClass, + context: TypeEvalContext, + isInstance: Boolean): PyTypedDictType? { if (isTypingTypedDictInheritor(cls, context)) { val ancestors = cls.getAncestorTypes(context).filterIsInstance() val name = cls.name ?: return null diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java index 1a1b9172cb0b..0535014daf12 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java @@ -51,6 +51,7 @@ import org.jetbrains.annotations.Nullable; import java.util.*; import java.util.regex.Matcher; import java.util.regex.Pattern; +import java.util.stream.Collectors; import static com.intellij.openapi.util.RecursionManager.doPreventingRecursion; import static com.jetbrains.python.psi.PyKnownDecoratorUtil.KnownDecorator.TYPING_FINAL; @@ -446,6 +447,16 @@ public class PyTypingTypeProvider extends PyTypeProviderBase { return null; } + @Nullable + private static PyType getTypedDictTypeForTarget(@NotNull PyTargetExpression referenceTarget, @NotNull TypeEvalContext context) { + if (PyTypedDictTypeProvider.Companion.isTypedDict(referenceTarget, context)) { + return new PyCustomType(TYPED_DICT, null, false, true, + PyBuiltinCache.getInstance(referenceTarget).getDictType()); + } + + return null; + } + @Nullable private static PyType getNewTypeCreationForTarget(@NotNull PyTargetExpression referenceTarget, @NotNull TypeEvalContext context) { final PyTargetExpressionStub stub = referenceTarget.getStub(); @@ -516,6 +527,11 @@ public class PyTypingTypeProvider extends PyTypeProviderBase { return Ref.create(newType); } + final PyType typedDictType = getTypedDictTypeForTarget(target, context); + if (typedDictType != null) { + return Ref.create(typedDictType); + } + final Ref annotatedType = getTypeFromTargetExpressionAnnotation(target, context); if (annotatedType != null) { return annotatedType; @@ -1008,11 +1024,6 @@ public class PyTypingTypeProvider extends PyTypeProviderBase { return null; } - public static boolean isTypedDict(@NotNull PyExpression expression, @NotNull TypeEvalContext context) { - Collection qualifiedNames = resolveToQualifiedNames(expression, context); - return qualifiedNames.stream().anyMatch(name -> TYPED_DICT.equals(name) || TYPED_DICT_EXT.equals(name)); - } - public static boolean isFinal(@NotNull PyDecoratable decoratable, @NotNull TypeEvalContext context) { return ContainerUtil.exists(PyKnownDecoratorUtil.getKnownDecorators(decoratable, context), d -> d == TYPING_FINAL || d == TYPING_FINAL_EXT); diff --git a/python/src/com/jetbrains/python/inspections/PyTypedDictInspection.kt b/python/src/com/jetbrains/python/inspections/PyTypedDictInspection.kt index e74b4a2cc3d0..89d9a1ee46a6 100644 --- a/python/src/com/jetbrains/python/inspections/PyTypedDictInspection.kt +++ b/python/src/com/jetbrains/python/inspections/PyTypedDictInspection.kt @@ -54,10 +54,7 @@ class PyTypedDictInspection : PyInspection() { if (node.hasAssignedValue()) { val value = node.findAssignedValue() - if (value is PyCallExpression && value.callee != null && - PyTypingTypeProvider.resolveToQualifiedNames(value.callee!!, myTypeEvalContext).any { - PyTypedDictTypeProvider.nameIsTypedDict(it) - }) { + if (value is PyCallExpression && value.callee != null && PyTypedDictTypeProvider.isTypedDict(value.callee!!, myTypeEvalContext)) { if (value.arguments.isNotEmpty() && node.name != (value.arguments[0] as? PyStringLiteralExpression)?.stringValue) { registerProblem(value.arguments[0], "First argument has to match the variable name") } @@ -72,8 +69,9 @@ class PyTypedDictInspection : PyInspection() { val arguments = node.arguments for (argument in arguments) { val type = myTypeEvalContext.getType(argument) - if (argument !is PyKeywordArgument && !PyTypingTypeProvider.isTypedDict(argument, - myTypeEvalContext) && type !is PyTypedDictType) { + if (argument !is PyKeywordArgument + && type !is PyTypedDictType + && !PyTypedDictTypeProvider.isTypedDict(argument, myTypeEvalContext)) { registerProblem(argument, "TypedDict cannot inherit from a non-TypedDict base class") } if (argument is PyKeywordArgument && argument.keyword == "total" && !checkValidTotality(argument.valueExpression)) { @@ -83,9 +81,7 @@ class PyTypedDictInspection : PyInspection() { } else if (node.callExpression != null) { val callee = node.callExpression!!.callee - if (callee != null && PyTypingTypeProvider.resolveToQualifiedNames(callee, myTypeEvalContext).any { - PyTypedDictTypeProvider.nameIsTypedDict(it) - }) { + if (callee != null && PyTypedDictTypeProvider.isTypedDict(callee, myTypeEvalContext)) { val totality = node.getKeywordArgument("total")?.valueExpression if (!checkValidTotality(totality)) { registerProblem(totality, "Value of 'total' must be True or False")