PY-45381: Fix false-positive when overriden non-default fields follow defaults defined in parent dataclass

GitOrigin-RevId: ee08f0db368951ea8f02835f1af8c57fae34dabd
This commit is contained in:
Irina Fediaeva
2023-04-03 12:39:39 +03:00
committed by intellij-monorepo-bot
parent f5a541c541
commit 069ee247c0
3 changed files with 104 additions and 26 deletions

View File

@@ -13,6 +13,7 @@ import com.jetbrains.python.psi.LanguageLevel
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.types.TypeEvalContext
import one.util.streamex.StreamEx
import java.util.*
class PyNamedTupleInspection : PyInspection() {
@@ -25,21 +26,18 @@ class PyNamedTupleInspection : PyInspection() {
callback: (PsiElement, String, ProblemHighlightType) -> Unit,
fieldsFilter: (PyTargetExpression) -> Boolean = { true },
hasAssignedValue: (PyTargetExpression) -> Boolean = PyTargetExpression::hasAssignedValue) {
val fieldsProcessor = if (classFieldsFilter(cls)) processFields(
cls, fieldsFilter, hasAssignedValue)
val fieldsProcessor = if (classFieldsFilter(cls)) processFields(cls, fieldsFilter, hasAssignedValue)
else null
if (fieldsProcessor?.lastFieldWithoutDefaultValue == null && !checkInheritedOrder) return
val ancestors = cls.getAncestorClasses(context)
val ancestorsFields = ancestors.map {
val ancestorsParameters = ancestors.map {
if (!classFieldsFilter(it)) {
Ancestor.FILTERED
}
else {
val processor = processFields(it,
fieldsFilter,
hasAssignedValue)
val processor = processFields(it, fieldsFilter, hasAssignedValue)
if (processor.fieldsWithDefaultValue.isNotEmpty()) {
Ancestor.HAS_FIELD_WITH_DEFAULT_VALUE
}
@@ -54,7 +52,7 @@ class PyNamedTupleInspection : PyInspection() {
if (checkInheritedOrder) {
var seenAncestorHavingFieldWithDefaultValue: PyClass? = null
for (ancestorAndFields in ancestors.zip(ancestorsFields).asReversed()) {
for (ancestorAndFields in ancestors.zip(ancestorsParameters).asReversed()) {
if (ancestorAndFields.second == Ancestor.HAS_FIELD_WITH_DEFAULT_VALUE) seenAncestorHavingFieldWithDefaultValue = ancestorAndFields.first
else if (ancestorAndFields.second == Ancestor.HAS_NOT_FIELD_WITH_DEFAULT_VALUE && seenAncestorHavingFieldWithDefaultValue != null) {
callback(
@@ -69,24 +67,42 @@ class PyNamedTupleInspection : PyInspection() {
}
}
val lastFieldWithoutDefaultValue = fieldsProcessor?.lastFieldWithoutDefaultValue
if (lastFieldWithoutDefaultValue != null) {
if (ancestorsFields.contains(
Ancestor.HAS_FIELD_WITH_DEFAULT_VALUE)) {
cls.nameIdentifier?.let { name ->
val ancestorsNames = ancestors
.asSequence()
.zip(ancestorsFields.asSequence())
.filter { it.second == Ancestor.HAS_FIELD_WITH_DEFAULT_VALUE }
.joinToString { "'${it.first.name}'" }
if (fieldsProcessor == null) return
callback(name,
"Non-default argument(s) follows default argument(s) defined in $ancestorsNames",
ProblemHighlightType.GENERIC_ERROR)
}
val ancestorFieldNames = StreamEx.of(ancestors).toFlatList { it.classAttributes }.map { it.name }
val fieldsWithoutDefaultNotOverriden = fieldsProcessor.fieldsWithoutDefaultValue
.filterNot {
it.name in ancestorFieldNames
}
fieldsProcessor.fieldsWithDefaultValue.headSet(lastFieldWithoutDefaultValue).forEach {
if (fieldsWithoutDefaultNotOverriden.isNotEmpty() && ancestorsParameters.contains(Ancestor.HAS_FIELD_WITH_DEFAULT_VALUE)) {
cls.nameIdentifier?.let { name ->
val ancestorsNames = ancestors
.asSequence()
.zip(ancestorsParameters.asSequence())
.filter { it.second == Ancestor.HAS_FIELD_WITH_DEFAULT_VALUE }
.joinToString { "'${it.first.name}'" }
callback(name,
"Non-default argument(s) follows default argument(s) defined in $ancestorsNames",
ProblemHighlightType.GENERIC_ERROR)
}
}
val lastFieldWithoutDefault = fieldsProcessor.lastFieldWithoutDefaultValue
val lastFieldWithoutDefaultNotOverriden =
if (lastFieldWithoutDefault != null && lastFieldWithoutDefault in fieldsWithoutDefaultNotOverriden) {
lastFieldWithoutDefault
}
else {
fieldsProcessor
.fieldsWithoutDefaultValue
.descendingSet()
.firstOrNull { it in fieldsWithoutDefaultNotOverriden }
}
if (lastFieldWithoutDefaultNotOverriden != null) {
fieldsProcessor.fieldsWithDefaultValue.headSet(lastFieldWithoutDefaultNotOverriden).forEach {
callback(it,
"Fields with a default value must come after any fields without a default.",
ProblemHighlightType.GENERIC_ERROR)
@@ -97,10 +113,8 @@ class PyNamedTupleInspection : PyInspection() {
private fun processFields(cls: PyClass,
filter: (PyTargetExpression) -> Boolean,
hasAssignedValue: (PyTargetExpression) -> Boolean): LocalFieldsProcessor {
val fieldsProcessor = LocalFieldsProcessor(filter,
hasAssignedValue)
val fieldsProcessor = LocalFieldsProcessor(filter, hasAssignedValue)
cls.processClassLevelDeclarations(fieldsProcessor)
return fieldsProcessor
}
@@ -134,6 +148,7 @@ class PyNamedTupleInspection : PyInspection() {
val lastFieldWithoutDefaultValue: PyTargetExpression?
get() = lastFieldWithoutDefaultValueBox.result
val fieldsWithDefaultValue: TreeSet<PyTargetExpression>
val fieldsWithoutDefaultValue: TreeSet<PyTargetExpression>
private val lastFieldWithoutDefaultValueBox: MaxBy<PyTargetExpression>
@@ -141,13 +156,17 @@ class PyNamedTupleInspection : PyInspection() {
val offsetComparator = compareBy(PyTargetExpression::getTextOffset)
lastFieldWithoutDefaultValueBox = MaxBy(offsetComparator)
fieldsWithDefaultValue = TreeSet(offsetComparator)
fieldsWithoutDefaultValue = TreeSet(offsetComparator)
}
override fun execute(element: PsiElement, state: ResolveState): Boolean {
if (element is PyTargetExpression && filter(element)) {
when {
hasAssignedValue(element) -> fieldsWithDefaultValue.add(element)
else -> lastFieldWithoutDefaultValueBox.apply(element)
else -> {
fieldsWithoutDefaultValue.add(element)
lastFieldWithoutDefaultValueBox.apply(element)
}
}
}

View File

@@ -0,0 +1,54 @@
from dataclasses import dataclass, field
@dataclass
class B:
a: int
b: int = 0
@dataclass
class A1(B):
a: int
c: int = 1
b: int
@dataclass
class <error descr="Non-default argument(s) follows default argument(s) defined in 'B'">A2</error>(B):
a: int
c: int
b: int
@dataclass
class <error descr="Non-default argument(s) follows default argument(s) defined in 'B'">A3</error>(B):
a: int
b: int
c: int
@dataclass
class A4(B):
a: int
b: int
c: int = 1
@dataclass
class <error descr="Non-default argument(s) follows default argument(s) defined in 'B'">A5</error>(B):
a: int
<error descr="Fields with a default value must come after any fields without a default.">c</error>: int = 1
<error descr="Fields with a default value must come after any fields without a default.">d</error>: int = 1
b: int
e: int
@dataclass
class <error descr="Non-default argument(s) follows default argument(s) defined in 'B'">A6</error>(B):
a: int
<error descr="Fields with a default value must come after any fields without a default.">c</error>: int = 1
d: int
e: int
b: int = 1
@dataclass
class <error descr="Non-default argument(s) follows default argument(s) defined in 'B'">A7</error>(B):
a: int
b: int
<error descr="Fields with a default value must come after any fields without a default.">c</error>: int = 1
<error descr="Fields with a default value must come after any fields without a default.">d</error>: int = 1
e: int

View File

@@ -325,6 +325,11 @@ public class PyDataclassInspectionTest extends PyInspectionTestCase {
doTest();
}
// PY-49946
public void testFieldsOrderOverridden() {
doTest();
}
@Override
protected void doTest() {
myFixture.copyDirectoryToProject("packages/attr", "attr");