Provide type for parameter in __post_init__ (PY-27398)

This commit is contained in:
Semyon Proshev
2017-12-20 16:05:53 +03:00
parent 4fc97d088f
commit dfb801aba9
4 changed files with 80 additions and 4 deletions

View File

@@ -3,11 +3,9 @@
*/
package com.jetbrains.python.codeInsight.stdlib
import com.intellij.openapi.util.Ref
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyReferenceExpression
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.PyUtil
import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyCallExpressionNavigator
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.*
@@ -18,6 +16,30 @@ class PyDataclassesTypeProvider : PyTypeProviderBase() {
return getDataclassTypeForCallee(referenceExpression, context)
}
override fun getParameterType(param: PyNamedParameter, func: PyFunction, context: TypeEvalContext): Ref<PyType>? {
if (!param.isPositionalContainer && !param.isKeywordContainer && param.annotationValue == null && func.name == "__post_init__") {
val cls = func.containingClass
if (cls != null && parseDataclassParameters(cls, context)?.init == true) {
var result: Ref<PyType>? = null
cls.processClassLevelDeclarations { element, _ ->
if (element is PyTargetExpression && element.name == param.name && element.annotationValue != null) {
result = Ref.create(getTypeForParameter(element, context))
false
}
else {
true
}
}
return result
}
}
return null
}
private fun getDataclassTypeForCallee(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyCallableType? {
if (PyCallExpressionNavigator.getPyCallExpressionByCallee(referenceExpression) == null) return null

View File

@@ -0,0 +1,11 @@
class _InitVarMeta(type):
def __getitem__(self, params):
return self
class InitVar(metaclass=_InitVarMeta):
pass
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
hash=None, frozen=False):
pass

View File

@@ -0,0 +1,11 @@
class _InitVarMeta(type):
def __getitem__(self, params):
return self
class InitVar(metaclass=_InitVarMeta):
pass
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
hash=None, frozen=False):
pass

View File

@@ -902,6 +902,38 @@ public class Py3TypeTest extends PyTestCase {
);
}
// PY-27398
public void testDataclassPostInitParameter() {
runWithLanguageLevel(
LanguageLevel.PYTHON37,
() -> doMultiFileTest("int",
"from dataclasses import dataclass, InitVar\n" +
"@dataclass\n" +
"class Foo:\n" +
" i: int\n" +
" j: int\n" +
" d: InitVar[int]\n" +
" def __post_init__(self, d):\n" +
" expr = d")
);
}
// PY-27398
public void testDataclassPostInitParameterNoInit() {
runWithLanguageLevel(
LanguageLevel.PYTHON37,
() -> doMultiFileTest("Any",
"from dataclasses import dataclass, InitVar\n" +
"@dataclass(init=False)\n" +
"class Foo:\n" +
" i: int\n" +
" j: int\n" +
" d: InitVar[int]\n" +
" def __post_init__(self, d):\n" +
" expr = d")
);
}
private void doTest(final String expectedType, final String text) {
myFixture.configureByText(PythonFileType.INSTANCE, text);
final PyExpression expr = myFixture.findElementByText("expr", PyExpression.class);