PY-79816 Introduce PyType#getMemberTypes and use it to infer __hash__ type of dataclasses

`getMemberTypes` should be used for members which have no PSI which can be used to resolve to. For example, `__init__` method in dataclasses are sometimes not present in the source code. Yet its parameters are always useful for the code analysis. In this case, `getMemberTypes` should be used.

GitOrigin-RevId: 2455ed05099842fc50e1fa2a196c4952b6444795
This commit is contained in:
evgeny.bovykin
2025-07-02 08:33:56 +00:00
committed by intellij-monorepo-bot
parent 9094c7acdc
commit abe6184e8a
12 changed files with 267 additions and 52 deletions

View File

@@ -14,6 +14,7 @@ import com.jetbrains.python.codeInsight.PyDataclassNames.Attrs
import com.jetbrains.python.codeInsight.PyDataclassNames.Dataclasses
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyBuiltinCache
import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.*
@@ -76,6 +77,37 @@ class PyDataclassTypeProvider : PyTypeProviderBase() {
return null
}
override fun getMemberTypes(type: PyType, name: String, location: PyExpression?, direction: AccessDirection, context: PyResolveContext): List<PyTypedResolveResult>? {
if (type !is PyClassType) {
return null
}
if (PyNames.HASH == name) {
// See `unsafe_hash` section here https://docs.python.org/3/library/dataclasses.html
val dataclassParameters = parseDataclassParameters(type.pyClass, context.typeEvalContext)
if (dataclassParameters == null) return null
if (dataclassParameters.unsafeHash) {
return null
}
if (!dataclassParameters.eq) {
return null
}
if (dataclassParameters.frozen) {
return null
}
val resolvedMembers = type.resolveMember(name, location, direction, context, false)
if (resolvedMembers?.isNotEmpty() == true) {
return null
}
return listOf(PyTypedResolveResult(null, PyBuiltinCache.getInstance(type.pyClass).noneType))
}
return null
}
companion object {
@ApiStatus.Internal
fun getInitVars(
@@ -199,8 +231,8 @@ class PyDataclassTypeProvider : PyTypeProviderBase() {
fieldsInfo.forEachIndexed { index, (name, kwOnly, parameter) ->
// note: attributes are visited from inheritors to ancestors, in reversed order for every of them
if ((seenKeywordOnlyClass && (parameters.type == PyDataclassParameters.PredefinedType.ATTRS || kwOnly != false)
|| index < indexOfKeywordOnlyAttribute || kwOnly == true)
if ((seenKeywordOnlyClass && (parameters.type == PyDataclassParameters.PredefinedType.ATTRS || kwOnly != false)
|| index < indexOfKeywordOnlyAttribute || kwOnly == true)
&& name !in collected) {
keywordOnly += name
}

View File

@@ -4,18 +4,11 @@ package com.jetbrains.python.codeInsight.typing
import com.jetbrains.python.PyNames
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.PROTOCOL
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.PROTOCOL_EXT
import com.jetbrains.python.psi.AccessDirection
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyPossibleClassMember
import com.jetbrains.python.psi.PyTypeParameter
import com.jetbrains.python.psi.PyTypedElement
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.getImplicitlyInvokedMethodTypes
import com.jetbrains.python.psi.impl.resolveImplicitlyInvokedMethods
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.RatedResolveResult
import com.jetbrains.python.psi.types.PyClassLikeType
import com.jetbrains.python.psi.types.PyClassType
import com.jetbrains.python.psi.types.PyType
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.psi.types.*
fun isProtocol(classLikeType: PyClassLikeType, context: TypeEvalContext): Boolean = containsProtocol(classLikeType.getSuperClassTypes(context))
@@ -29,11 +22,11 @@ fun matchingProtocolDefinitions(expected: PyType?, actual: PyType?, context: Typ
isProtocol(expected, context) &&
isProtocol(actual, context)
typealias ProtocolAndSubclassElements = Pair<PyTypedElement, List<RatedResolveResult>?>
typealias ProtocolAndSubclassElements = Pair<PyTypedElement, List<PyTypedResolveResult>>
fun inspectProtocolSubclass(protocol: PyClassType, subclass: PyClassType, context: TypeEvalContext): List<ProtocolAndSubclassElements> {
val resolveContext = PyResolveContext.defaultContext(context)
val result = mutableListOf<Pair<PyTypedElement, List<RatedResolveResult>?>>()
val result = mutableListOf<Pair<PyTypedElement, List<PyTypedResolveResult>>>()
protocol.toInstance().visitMembers(
{ e ->
@@ -51,8 +44,33 @@ fun inspectProtocolSubclass(protocol: PyClassType, subclass: PyClassType, contex
val name = e.name ?: return@visitMembers true
when (name) {
PyNames.CLASS_GETITEM -> return@visitMembers true
PyNames.CALL -> result.add(Pair(e, subclass.resolveImplicitlyInvokedMethods(null, resolveContext)))
else -> result.add(Pair(e, subclass.resolveMember(name, null, AccessDirection.READ, resolveContext)))
PyNames.CALL -> {
val types = subclass.getImplicitlyInvokedMethodTypes(null, resolveContext)
if (types.isNotEmpty()) {
result.add(Pair(e, types))
}
else {
val fallbackTypes = subclass.resolveImplicitlyInvokedMethods(null, resolveContext)
.mapNotNull { it.element }
.filterIsInstance<PyTypedElement>()
.mapNotNull {
val type = resolveContext.typeEvalContext.getType(it)
if (type != null) {
it to type
}
else {
null
}
}
result.add(Pair(e, fallbackTypes.map { PyTypedResolveResult(it.first, it.second) }))
}
}
else -> {
val types = subclass.getMemberTypes(name, null, AccessDirection.READ, resolveContext)
if (types != null) {
result.add(Pair(e, types))
}
}
}
}

View File

@@ -15,11 +15,7 @@ import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.PyKnownDecorator.*
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.PyResolveUtil
import com.jetbrains.python.psi.resolve.RatedResolveResult
import com.jetbrains.python.psi.types.PyClassLikeType
import com.jetbrains.python.psi.types.PyClassType
import com.jetbrains.python.psi.types.PyTypeChecker
import com.jetbrains.python.psi.types.TypeEvalContext
import com.jetbrains.python.psi.types.*
class PyProtocolInspection : PyInspection() {
@@ -55,7 +51,7 @@ class PyProtocolInspection : PyInspection() {
.forEach { protocol ->
inspectProtocolSubclass(protocol, type, myTypeEvalContext).forEach {
val subclassElements = it.second
if (!subclassElements.isNullOrEmpty()) {
if (subclassElements.isNotEmpty()) {
checkMemberCompatibility(it.first, subclassElements, type, protocol)
}
}
@@ -127,21 +123,23 @@ class PyProtocolInspection : PyInspection() {
}
}
private fun checkMemberCompatibility(protocolElement: PyTypedElement,
subclassElements: List<RatedResolveResult>,
type: PyClassType,
protocol: PyClassType) {
private fun checkMemberCompatibility(
protocolElement: PyTypedElement,
subclassElements: List<PyTypedResolveResult>,
type: PyClassType,
protocol: PyClassType
) {
val expectedMemberType = myTypeEvalContext.getType(protocolElement)
subclassElements
.asSequence()
.map { it.element }
.filterIsInstance<PyTypedElement>()
.filter { it.containingFile == type.pyClass.containingFile }
.filterNot { PyTypeChecker.match(expectedMemberType, myTypeEvalContext.getType(it), myTypeEvalContext) }
.filter { it.element?.containingFile == type.pyClass.containingFile }
.filterNot { PyTypeChecker.match(expectedMemberType, it.type, myTypeEvalContext) }
.forEach {
val place = if (it is PsiNameIdentifierOwner) it.nameIdentifier else it
registerProblem(place, PyPsiBundle.message("INSP.protocol.element.type.incompatible.with.protocol", it.name, protocol.name))
val element = it.element
val place = if (element is PsiNameIdentifierOwner) element.nameIdentifier else element ?: return@forEach
val elementName = if (element is PsiNameIdentifierOwner) element.name else return@forEach
registerProblem(place, PyPsiBundle.message("INSP.protocol.element.type.incompatible.with.protocol", elementName, protocol.name))
}
}
}

View File

@@ -812,6 +812,14 @@ fun PyClassType.resolveImplicitlyInvokedMethods(
else resolveDunderCall(callSite, resolveContext)
}
fun PyClassType.getImplicitlyInvokedMethodTypes(
callSite: PyCallSiteExpression?,
resolveContext: PyResolveContext,
): List<PyTypedResolveResult> {
return if (isDefinition()) getConstructorTypes(callSite, resolveContext)
else getDunderCallType(callSite, resolveContext)
}
private fun PyClassType.changeToImplicitlyInvokedMethods(
implicitlyInvokedMethods: List<PsiElement>,
call: PyCallExpression,
@@ -858,6 +866,20 @@ private fun PyClassType.resolveConstructors(callSite: PyCallSiteExpression?, res
return initAndNew.preferInitOverNew().map { RatedResolveResult(PyReferenceImpl.getRate(it, context), it) }
}
private fun PyClassType.getConstructorTypes(callSite: PyCallSiteExpression?, resolveContext: PyResolveContext): List<PyTypedResolveResult> {
val initTypes = getMemberTypes(PyNames.INIT, callSite, AccessDirection.READ, resolveContext)
if (initTypes != null) {
return initTypes
}
val newTypes = getMemberTypes(PyNames.NEW, callSite, AccessDirection.READ, resolveContext)
if (newTypes != null) {
return newTypes
}
return emptyList()
}
private fun PyCallableType.isReturnTypeAnnotated(context: TypeEvalContext): Boolean {
val callable = this.callable
if (callable is PyFunction) {
@@ -901,6 +923,10 @@ private fun PyClassLikeType.resolveDunderCall(location: PyExpression?, resolveCo
return resolveMember(PyNames.CALL, location, AccessDirection.READ, resolveContext) ?: emptyList()
}
private fun PyClassLikeType.getDunderCallType(location: PyExpression?, resolveContext: PyResolveContext): List<PyTypedResolveResult> {
return getMemberTypes(PyNames.CALL, location, AccessDirection.READ, resolveContext) ?: emptyList()
}
fun analyzeArguments(
arguments: List<PyExpression>,
parameters: List<PyCallableParameter>,

View File

@@ -139,6 +139,11 @@ public class PyCallableTypeImpl implements PyCallableType {
return myImplicitOffset;
}
@Override
public @Nullable PyQualifiedNameOwner getDeclarationElement() {
return myCallable;
}
@Override
public boolean equals(Object o) {
if (this == o) return true;

View File

@@ -4,7 +4,6 @@ package com.jetbrains.python.psi.types;
import com.intellij.openapi.util.*;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.ResolveResult;
import com.intellij.util.ArrayUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
@@ -544,18 +543,15 @@ public final class PyTypeChecker {
GenericSubstitutions substitutions = collectTypeSubstitutions(actual, matchContext.context);
MatchContext protocolContext = new MatchContext(matchContext.context, new GenericSubstitutions(), matchContext.reversedSubstitutions);
for (kotlin.Pair<PyTypedElement, List<RatedResolveResult>> pair : PyProtocolsKt.inspectProtocolSubclass(expected, actual, matchContext.context)) {
final List<RatedResolveResult> subclassElements = pair.getSecond();
if (ContainerUtil.isEmpty(subclassElements)) {
for (kotlin.Pair<PyTypedElement, List<PyTypedResolveResult>> pair : PyProtocolsKt.inspectProtocolSubclass(expected, actual, matchContext.context)) {
final List<PyType> subclassElementTypes = ContainerUtil.map(pair.getSecond(), member -> member.getType());
if (ContainerUtil.isEmpty(subclassElementTypes)) {
return false;
}
final PyType protocolElementType = dropSelfIfNeeded(expected, matchContext.context.getType(pair.getFirst()), matchContext.context);
final boolean elementResult = StreamEx
.of(subclassElements)
.map(ResolveResult::getElement)
.select(PyTypedElement.class)
.map(matchContext.context::getType)
.of(subclassElementTypes)
.map(type -> dropSelfIfNeeded(actual, type, matchContext.context))
.map(type -> substitute(type, substitutions, matchContext.context))
.anyMatch(