PY-42664 Code deduplication (PyCollectionTypeUtil)

GitOrigin-RevId: cde6db9b023b33fc5955ed16ae465e00bb6a5160
This commit is contained in:
Petr
2024-04-03 13:57:45 +02:00
committed by intellij-monorepo-bot
parent 541d4fad47
commit c58f577ee8

View File

@@ -3,6 +3,7 @@
*/
package com.jetbrains.python.psi.types
import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.util.ArrayUtil
@@ -336,7 +337,7 @@ object PyCollectionTypeUtil {
return if (isModificationExist) valueTypes else null
}
private fun getRightValue(node: PySubscriptionExpression): PyExpression? {
private fun getRightValue(node: PySubscriptionExpression, element: PsiElement, typeEvalContext: TypeEvalContext): Ref<PyExpression?>? {
var parent = node.parent
var tupleParent: PyTupleExpression? = null
@@ -358,6 +359,12 @@ object PyCollectionTypeUtil {
}
}
val referenceOwner = node.operand as? PyReferenceOwner ?: return null
val resolveContext = PyResolveContext.defaultContext(typeEvalContext)
if (!referenceOwner.getReference(resolveContext).isReferenceTo(element)) {
return null
}
var rightValue = assignment.assignedValue
if (tupleParent != null) {
val rightElements = PyUtil.flattenedParensAndLists(rightValue)
@@ -366,7 +373,7 @@ object PyCollectionTypeUtil {
rightValue = rightElements[indexInAssignment]
}
}
return rightValue
return Ref(rightValue)
}
return null
}
@@ -374,48 +381,17 @@ object PyCollectionTypeUtil {
private fun getTypeByModifications(node: PySubscriptionExpression,
element: PsiElement,
typeEvalContext: TypeEvalContext): Pair<List<PyType?>, List<PyType?>>? {
var parent = node.parent
val keyTypes = ArrayList<PyType?>()
val valueTypes = ArrayList<PyType?>()
var isModificationExist = false
var tupleParent: PyTupleExpression? = null
if (parent is PyTupleExpression) {
tupleParent = parent
parent = tupleParent.parent
fun getType(element: PyTypedElement?): List<PyType?> {
return if (element != null)
listOf(typeEvalContext.getType(element))
else
emptyList()
}
if (parent is PyAssignmentStatement) {
val assignment = parent
val leftExpression = assignment.leftHandSideExpression
if (tupleParent == null) {
if (leftExpression !== node) return null
}
else {
if (leftExpression !== tupleParent || !ArrayUtil.contains(node, *tupleParent.elements)) {
return null
}
}
val resolveContext = PyResolveContext.defaultContext(typeEvalContext)
val referenceOwner = node.operand as? PyReferenceOwner ?: return null
val reference = referenceOwner.getReference(resolveContext)
isModificationExist = if (reference.isReferenceTo(element)) true else return null
val indexExpression = node.indexExpression
if (indexExpression != null) {
keyTypes.add(typeEvalContext.getType(indexExpression))
}
val rightValue = getRightValue(node)
if (rightValue != null) {
val rightValueType = typeEvalContext.getType(rightValue)
valueTypes.add(rightValueType)
}
}
return if (isModificationExist) Pair(keyTypes, valueTypes) else null
val rightValue = getRightValue(node, element, typeEvalContext) ?: return null
val keyTypes = getType(node.indexExpression)
val valueTypes = getType(rightValue.get())
return Pair(keyTypes, valueTypes)
}
private abstract class PyCollectionTypeVisitor(protected val myElement: PyTargetExpression,
@@ -576,9 +552,7 @@ object PyCollectionTypeUtil {
else null
override val elementTypes: List<PyType?>
get() = if (isModificationExist && hasAllStrKeys)
PyTypedDictType.createFromKeysToValueTypes(myElement, strKeysToValueTypes)?.elementTypes ?: emptyList()
else emptyList()
get() = typedDictType?.elementTypes ?: emptyList()
init {
modificationMethods = initMethods()
@@ -629,12 +603,7 @@ object PyCollectionTypeUtil {
}
override fun visitPySubscriptionExpression(node: PySubscriptionExpression) {
val rightValue = getRightValue(node) ?: return
val subscriptionTarget = node.operand as? PyReferenceOwner ?: return
val resolveContext = PyResolveContext.defaultContext(myTypeEvalContext)
if (!subscriptionTarget.getReference(resolveContext).isReferenceTo(myElement)) {
return
}
val rightValue = getRightValue(node, myElement, myTypeEvalContext)?.get() ?: return
isModificationExist = true
val indexExpression = node.indexExpression