PY-78008 Dataclass __post_init__ completion

GitOrigin-RevId: 88a2e5eb7d478be3221ad5ee69b3a95090972672
This commit is contained in:
Petr
2024-12-11 20:06:03 +01:00
committed by intellij-monorepo-bot
parent a8149cc0a5
commit e5be315559
5 changed files with 72 additions and 33 deletions

View File

@@ -15,13 +15,13 @@ import com.jetbrains.python.codeInsight.PyDataclassNames.Attrs
import com.jetbrains.python.codeInsight.PyDataclassNames.Dataclasses
import com.jetbrains.python.codeInsight.PyDataclassParameters
import com.jetbrains.python.codeInsight.parseDataclassParameters
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
import com.jetbrains.python.extensions.afterDefInMethod
import com.jetbrains.python.extensions.inParameterList
import com.jetbrains.python.psi.PyParameter
import com.jetbrains.python.psi.PyParameterList
import com.jetbrains.python.psi.PySubscriptionExpression
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.types.PyClassType
class PyDataclassCompletionContributor : CompletionContributor(), DumbAware {
@@ -44,22 +44,19 @@ class PyDataclassCompletionContributor : CompletionContributor(), DumbAware {
if (dataclassParameters.type.asPredefinedType == PyDataclassParameters.PredefinedType.STD) {
val postInitParameters = mutableListOf(PyNames.CANONICAL_SELF)
cls.processClassLevelDeclarations { element, _ ->
if (element is PyTargetExpression && element.annotationValue != null) {
val name = element.name
val annotationValue = element.annotation?.value as? PySubscriptionExpression
if (name != null && annotationValue != null) {
val type = typeEvalContext.getType(element)
if (type is PyClassType && type.classQName == Dataclasses.DATACLASSES_INITVAR) {
val typeHint = annotationValue.indexExpression.let { if (it == null) "" else ": ${it.text}" }
postInitParameters.add(name + typeHint)
}
PyDataclassTypeProvider.getInitVars(cls, dataclassParameters, typeEvalContext).orEmpty().forEach {
val name = it.targetExpression.name
val typeHint = PyTypingTypeProvider.getAnnotationValue(it.targetExpression, typeEvalContext)
if (name != null && typeHint is PySubscriptionExpression) {
val indexExpression = typeHint.indexExpression
val parameterString = if (indexExpression != null) {
"${name}: ${indexExpression.text}"
}
else {
name
}
postInitParameters.add(parameterString)
}
true
}
addMethodToResult(result, cls, typeEvalContext,

View File

@@ -18,6 +18,7 @@ import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.*
import one.util.streamex.StreamEx
import org.jetbrains.annotations.ApiStatus
class PyDataclassTypeProvider : PyTypeProviderBase() {
@@ -48,9 +49,10 @@ class PyDataclassTypeProvider : PyTypeProviderBase() {
if (parameterIndex == -1) return null
val cls = func.containingClass ?: return null
return getInitVarTypes(cls, context)
val initVars = getInitVars(cls, parseStdDataclassParameters(cls, context), context) ?: return null
return initVars
.drop(parameterIndex)
.map { Ref.create(it) }
.map { Ref.create(it.type) }
.firstOrNull()
}
@@ -58,24 +60,35 @@ class PyDataclassTypeProvider : PyTypeProviderBase() {
}
companion object {
fun getInitVarTypes(cls: PyClass, context: TypeEvalContext): Sequence<PyType?> {
return if (parseStdDataclassParameters(cls, context)?.init == true) {
cls.getAncestorClasses(context)
.asReversed()
.asSequence()
.filter { parseDataclassParameters(it, context) != null }
.plus(cls)
.flatMap { it.classAttributes }
.map { context.getType(it) }
.filterIsInstance<PyCollectionType>()
.filter { it.classQName == Dataclasses.DATACLASSES_INITVAR }
.map { it.elementTypes.firstOrNull() }
}
else {
emptySequence()
@ApiStatus.Internal
fun getInitVars(
cls: PyClass,
dataclassParams: PyDataclassParameters?,
context: TypeEvalContext,
): Sequence<InitVarInfo>? {
if (dataclassParams == null || !dataclassParams.init) {
return null
}
return cls.getAncestorClasses(context)
.asReversed()
.asSequence()
.filter { parseDataclassParameters(it, context) != null }
.plus(cls)
.flatMap { it.classAttributes }
.mapNotNull {
val type = context.getType(it)
if (type is PyCollectionType && type.classQName == Dataclasses.DATACLASSES_INITVAR) {
InitVarInfo(it, type.elementTypes.singleOrNull())
}
else {
null
}
}
}
@ApiStatus.Internal
class InitVarInfo(val targetExpression: PyTargetExpression, val type: PyType?)
private fun getDataclassesReplaceType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? {
val call = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null
val callee = call.callee as? PyReferenceExpression ?: return null

View File

@@ -0,0 +1,12 @@
import dataclasses
@dataclasses.dataclass
class Base:
a1: int = 0
a2: dataclasses.InitVar[str]
@dataclasses.dataclass
class Derived(Base):
a3 = ""
a4: dataclasses.InitVar[bool]
def __post_init__(self, a2: str, a4: bool):

View File

@@ -0,0 +1,12 @@
import dataclasses
@dataclasses.dataclass
class Base:
a1: int = 0
a2: dataclasses.InitVar[str]
@dataclasses.dataclass
class Derived(Base):
a3 = ""
a4: dataclasses.InitVar[bool]
def __post_<caret>

View File

@@ -332,6 +332,11 @@ public class Py3CompletionTest extends PyTestCase {
doMultiFileTest();
}
// PY-78008
public void testDataclassWithInheritedInitVarPostInit() {
doMultiFileTest();
}
// PY-27398
public void testDataclassPostInitNoInit() {
doMultiFileTest();