Support type inference for cls inside namedtuple's class methods (PY-30870, PY-33140, PY-45473)

GitOrigin-RevId: a299d681d20230acd9443f0fa37c7fd64a51d76a
This commit is contained in:
Semyon Proshev
2020-11-17 22:04:46 +03:00
committed by intellij-monorepo-bot
parent b017f822c2
commit 6fcdc21997
4 changed files with 36 additions and 9 deletions

View File

@@ -3,6 +3,7 @@ package com.jetbrains.python.codeInsight.stdlib
import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.util.ArrayUtil
import com.intellij.util.containers.mapSmartNotNull
import com.jetbrains.python.PyNames
@@ -70,11 +71,13 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() {
return when {
referenceTarget is PyFunction && anchor is PyCallExpression -> getNamedTupleFunctionType(referenceTarget, context, anchor)
referenceTarget is PyTargetExpression -> getNamedTupleTypeForTarget(referenceTarget, context)
referenceTarget is PyClass && anchor is PyCallExpression -> {
getNamedTupleTypeForNTInheritorAsCallee(referenceTarget, context) ?:
PyUnionType.union(
referenceTarget.multiFindInitOrNew(false, context).mapSmartNotNull { getNamedTupleFunctionType(it, context, anchor) }
)
referenceTarget is PyClass && anchor is PyCallExpression -> getNamedTupleTypeForClass(referenceTarget, context, anchor)
referenceTarget is PyParameter && anchor is PyCallExpression && referenceTarget.isSelf -> {
PsiTreeUtil.getParentOfType(referenceTarget, PyFunction::class.java)
?.takeIf { it.modifier == PyFunction.Modifier.CLASSMETHOD }
?.let { method ->
method.containingClass?.let { getNamedTupleTypeForClass(it, context, anchor) }
}
}
else -> null
}
@@ -158,6 +161,13 @@ class PyNamedTupleTypeProvider : PyTypeProviderBase() {
.compute(context)
}
private fun getNamedTupleTypeForClass(cls: PyClass, context: TypeEvalContext, call: PyCallExpression): PyType? {
return getNamedTupleTypeForNTInheritorAsCallee(cls, context)
?: PyUnionType.union(
cls.multiFindInitOrNew(false, context).mapSmartNotNull { getNamedTupleFunctionType(it, context, call) }
)
}
private fun getNamedTupleTypeForNTInheritorAsCallee(cls: PyClass, context: TypeEvalContext): PyType? {
if (cls.findInitOrNew(false, context) != null) return null

View File

@@ -8,5 +8,12 @@ class MyTup2(namedtuple("MyTup2", "bar baz")):
pass
class MyTup3(namedtuple("MyTup3", "bar baz")):
@classmethod
def factory(cls):
return cls(<arg3>)
MyTup1(<arg1>)
MyTup2(<arg2>)

View File

@@ -31,6 +31,15 @@ class MyTup8(typing.NamedTuple):
baz: str = ""
class MyTup9(typing.NamedTuple):
bar: int
baz: str
@classmethod
def factory(cls):
return cls(<arg8>)
MyTup2(<arg1>)
MyTup3(<arg2>)
MyTup4(<arg3>)

View File

@@ -629,9 +629,9 @@ public class PyParameterInfoTest extends LightMarkedTestCase {
);
}
// PY-22249
// PY-22249, PY-45473
public void testInitializingCollectionsNamedTuple() {
final Map<String, PsiElement> test = loadTest(2);
final Map<String, PsiElement> test = loadTest(3);
for (int offset : StreamEx.of(test.values()).map(PsiElement::getTextOffset)) {
final List<String> texts = Collections.singletonList("bar, baz");
@@ -641,13 +641,14 @@ public class PyParameterInfoTest extends LightMarkedTestCase {
}
}
// PY-33140
public void testInitializingTypingNamedTuple() {
runWithLanguageLevel(
LanguageLevel.PYTHON36,
() -> {
final Map<String, PsiElement> test = loadTest(7);
final Map<String, PsiElement> test = loadTest(8);
for (int offset : StreamEx.of(1, 2, 3, 4).map(number -> test.get("<arg" + number + ">").getTextOffset())) {
for (int offset : StreamEx.of(1, 2, 3, 4, 8).map(number -> test.get("<arg" + number + ">").getTextOffset())) {
final List<String> texts = Collections.singletonList("bar: int, baz: str");
final List<String[]> highlighted = Collections.singletonList(new String[]{"bar: int, "});