mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-04-30 02:09:59 +07:00
PY-78008 Dataclass __post_init__ completion
GitOrigin-RevId: 88a2e5eb7d478be3221ad5ee69b3a95090972672
This commit is contained in:
committed by
intellij-monorepo-bot
parent
a8149cc0a5
commit
e5be315559
@@ -15,13 +15,13 @@ import com.jetbrains.python.codeInsight.PyDataclassNames.Attrs
|
||||
import com.jetbrains.python.codeInsight.PyDataclassNames.Dataclasses
|
||||
import com.jetbrains.python.codeInsight.PyDataclassParameters
|
||||
import com.jetbrains.python.codeInsight.parseDataclassParameters
|
||||
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider
|
||||
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
|
||||
import com.jetbrains.python.extensions.afterDefInMethod
|
||||
import com.jetbrains.python.extensions.inParameterList
|
||||
import com.jetbrains.python.psi.PyParameter
|
||||
import com.jetbrains.python.psi.PyParameterList
|
||||
import com.jetbrains.python.psi.PySubscriptionExpression
|
||||
import com.jetbrains.python.psi.PyTargetExpression
|
||||
import com.jetbrains.python.psi.types.PyClassType
|
||||
|
||||
class PyDataclassCompletionContributor : CompletionContributor(), DumbAware {
|
||||
|
||||
@@ -44,22 +44,19 @@ class PyDataclassCompletionContributor : CompletionContributor(), DumbAware {
|
||||
if (dataclassParameters.type.asPredefinedType == PyDataclassParameters.PredefinedType.STD) {
|
||||
val postInitParameters = mutableListOf(PyNames.CANONICAL_SELF)
|
||||
|
||||
cls.processClassLevelDeclarations { element, _ ->
|
||||
if (element is PyTargetExpression && element.annotationValue != null) {
|
||||
val name = element.name
|
||||
val annotationValue = element.annotation?.value as? PySubscriptionExpression
|
||||
|
||||
if (name != null && annotationValue != null) {
|
||||
val type = typeEvalContext.getType(element)
|
||||
|
||||
if (type is PyClassType && type.classQName == Dataclasses.DATACLASSES_INITVAR) {
|
||||
val typeHint = annotationValue.indexExpression.let { if (it == null) "" else ": ${it.text}" }
|
||||
postInitParameters.add(name + typeHint)
|
||||
}
|
||||
PyDataclassTypeProvider.getInitVars(cls, dataclassParameters, typeEvalContext).orEmpty().forEach {
|
||||
val name = it.targetExpression.name
|
||||
val typeHint = PyTypingTypeProvider.getAnnotationValue(it.targetExpression, typeEvalContext)
|
||||
if (name != null && typeHint is PySubscriptionExpression) {
|
||||
val indexExpression = typeHint.indexExpression
|
||||
val parameterString = if (indexExpression != null) {
|
||||
"${name}: ${indexExpression.text}"
|
||||
}
|
||||
else {
|
||||
name
|
||||
}
|
||||
postInitParameters.add(parameterString)
|
||||
}
|
||||
|
||||
true
|
||||
}
|
||||
|
||||
addMethodToResult(result, cls, typeEvalContext,
|
||||
|
||||
@@ -18,6 +18,7 @@ import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext
|
||||
import com.jetbrains.python.psi.types.*
|
||||
import one.util.streamex.StreamEx
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
class PyDataclassTypeProvider : PyTypeProviderBase() {
|
||||
|
||||
@@ -48,9 +49,10 @@ class PyDataclassTypeProvider : PyTypeProviderBase() {
|
||||
if (parameterIndex == -1) return null
|
||||
|
||||
val cls = func.containingClass ?: return null
|
||||
return getInitVarTypes(cls, context)
|
||||
val initVars = getInitVars(cls, parseStdDataclassParameters(cls, context), context) ?: return null
|
||||
return initVars
|
||||
.drop(parameterIndex)
|
||||
.map { Ref.create(it) }
|
||||
.map { Ref.create(it.type) }
|
||||
.firstOrNull()
|
||||
}
|
||||
|
||||
@@ -58,24 +60,35 @@ class PyDataclassTypeProvider : PyTypeProviderBase() {
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun getInitVarTypes(cls: PyClass, context: TypeEvalContext): Sequence<PyType?> {
|
||||
return if (parseStdDataclassParameters(cls, context)?.init == true) {
|
||||
cls.getAncestorClasses(context)
|
||||
.asReversed()
|
||||
.asSequence()
|
||||
.filter { parseDataclassParameters(it, context) != null }
|
||||
.plus(cls)
|
||||
.flatMap { it.classAttributes }
|
||||
.map { context.getType(it) }
|
||||
.filterIsInstance<PyCollectionType>()
|
||||
.filter { it.classQName == Dataclasses.DATACLASSES_INITVAR }
|
||||
.map { it.elementTypes.firstOrNull() }
|
||||
}
|
||||
else {
|
||||
emptySequence()
|
||||
@ApiStatus.Internal
|
||||
fun getInitVars(
|
||||
cls: PyClass,
|
||||
dataclassParams: PyDataclassParameters?,
|
||||
context: TypeEvalContext,
|
||||
): Sequence<InitVarInfo>? {
|
||||
if (dataclassParams == null || !dataclassParams.init) {
|
||||
return null
|
||||
}
|
||||
return cls.getAncestorClasses(context)
|
||||
.asReversed()
|
||||
.asSequence()
|
||||
.filter { parseDataclassParameters(it, context) != null }
|
||||
.plus(cls)
|
||||
.flatMap { it.classAttributes }
|
||||
.mapNotNull {
|
||||
val type = context.getType(it)
|
||||
if (type is PyCollectionType && type.classQName == Dataclasses.DATACLASSES_INITVAR) {
|
||||
InitVarInfo(it, type.elementTypes.singleOrNull())
|
||||
}
|
||||
else {
|
||||
null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
class InitVarInfo(val targetExpression: PyTargetExpression, val type: PyType?)
|
||||
|
||||
private fun getDataclassesReplaceType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? {
|
||||
val call = PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) ?: return null
|
||||
val callee = call.callee as? PyReferenceExpression ?: return null
|
||||
|
||||
@@ -0,0 +1,12 @@
|
||||
import dataclasses
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Base:
|
||||
a1: int = 0
|
||||
a2: dataclasses.InitVar[str]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Derived(Base):
|
||||
a3 = ""
|
||||
a4: dataclasses.InitVar[bool]
|
||||
def __post_init__(self, a2: str, a4: bool):
|
||||
@@ -0,0 +1,12 @@
|
||||
import dataclasses
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Base:
|
||||
a1: int = 0
|
||||
a2: dataclasses.InitVar[str]
|
||||
|
||||
@dataclasses.dataclass
|
||||
class Derived(Base):
|
||||
a3 = ""
|
||||
a4: dataclasses.InitVar[bool]
|
||||
def __post_<caret>
|
||||
@@ -332,6 +332,11 @@ public class Py3CompletionTest extends PyTestCase {
|
||||
doMultiFileTest();
|
||||
}
|
||||
|
||||
// PY-78008
|
||||
public void testDataclassWithInheritedInitVarPostInit() {
|
||||
doMultiFileTest();
|
||||
}
|
||||
|
||||
// PY-27398
|
||||
public void testDataclassPostInitNoInit() {
|
||||
doMultiFileTest();
|
||||
|
||||
Reference in New Issue
Block a user