diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.java deleted file mode 100644 index ec07888ec5e9..000000000000 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.java +++ /dev/null @@ -1,7 +0,0 @@ -// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. -package com.jetbrains.python.psi; - -import com.jetbrains.python.ast.PyAstMappingPattern; - -public interface PyMappingPattern extends PyAstMappingPattern, PyPattern, PyCaptureContext { -} diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.kt b/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.kt new file mode 100644 index 000000000000..4e8225f14703 --- /dev/null +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.kt @@ -0,0 +1,6 @@ +// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. +package com.jetbrains.python.psi + +import com.jetbrains.python.ast.PyAstMappingPattern + +interface PyMappingPattern : PyAstMappingPattern, PyPattern, PyCaptureContext diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java deleted file mode 100644 index 3bf1c26e6416..000000000000 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java +++ /dev/null @@ -1,15 +0,0 @@ -// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. -package com.jetbrains.python.psi; - -import com.jetbrains.python.ast.PyAstSequencePattern; -import org.jetbrains.annotations.NotNull; - -import java.util.List; - -import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass; - -public interface PySequencePattern extends PyAstSequencePattern, PyPattern, PyCaptureContext { - default @NotNull List<@NotNull PyPattern> getElements() { - return List.of(findChildrenByClass(this, PyPattern.class)); - } -} diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.kt b/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.kt new file mode 100644 index 000000000000..4cff9dca4527 --- /dev/null +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.kt @@ -0,0 +1,10 @@ +// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. +package com.jetbrains.python.psi + +import com.jetbrains.python.ast.PyAstSequencePattern +import com.jetbrains.python.ast.findChildrenByClass + +interface PySequencePattern : PyAstSequencePattern, PyPattern, PyCaptureContext { + val elements: List + get() = findChildrenByClass(PyPattern::class.java).toList() +} diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java deleted file mode 100644 index 8931fb775448..000000000000 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java +++ /dev/null @@ -1,99 +0,0 @@ -package com.jetbrains.python.psi.impl; - -import com.intellij.lang.ASTNode; -import com.intellij.psi.PsiElement; -import com.intellij.psi.PsiListLikeElement; -import com.intellij.psi.util.PsiTreeUtil; -import com.jetbrains.python.psi.*; -import com.jetbrains.python.psi.types.*; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; - -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; - -public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPattern, PsiListLikeElement { - public PyMappingPatternImpl(ASTNode astNode) { - super(astNode); - } - - @Override - protected void acceptPyVisitor(PyElementVisitor pyVisitor) { - pyVisitor.visitPyMappingPattern(this); - } - - @Override - public @NotNull List getComponents() { - return Arrays.asList(findChildrenByClass(PyKeyValuePattern.class)); - } - - @Override - public boolean canExcludePatternType(@NotNull TypeEvalContext context) { - return false; - } - - @Override - public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { - ArrayList keyTypes = new ArrayList<>(); - ArrayList valueTypes = new ArrayList<>(); - for (PyKeyValuePattern it : getComponents()) { - keyTypes.add(context.getType(it.getKeyPattern())); - if (it.getValuePattern() != null) { - valueTypes.add(context.getType(it.getValuePattern())); - } - } - - PyType patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this); - - PyType captureTypes = PyCaptureContext.getCaptureType(this, context); - PyType filteredType = PyTypeUtil.toStream(captureTypes).filter(captureType -> { - var mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context); - if (mappingType == null) return false; - if (!PyTypeChecker.match(mappingType, patternMappingType, context)) return false; - return true; - }).collect(PyTypeUtil.toUnion()); - - return filteredType == null ? patternMappingType : filteredType; - } - - @Override - public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { - final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent()); - if (sequenceMember instanceof PyDoubleStarPattern) { - var mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context); - if (mappingType instanceof PyCollectionType collectionType) { - final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict"); - return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null; - } - return null; - } - else if (sequenceMember instanceof PyKeyValuePattern keyValuePattern) { - return PyTypeUtil.toStream(PyCaptureContext.getCaptureType(this, context)).map(type -> { - if (type instanceof PyTypedDictType typedDictType) { - if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l && - l.getExpression() instanceof PyStringLiteralExpression str) { - return typedDictType.getElementType(str.getStringValue()); - } - } - - PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context); - if (mappingType == null) { - return PyNeverType.NEVER; - } - else if (mappingType instanceof PyCollectionType collectionType) { - return collectionType.getElementTypes().get(1); - } - return null; - }).collect(PyTypeUtil.toUnion()); - } - return null; - } - - private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) { - keyType = PyLiteralType.upcastLiteralToClass(keyType); - valueType = PyLiteralType.upcastLiteralToClass(valueType); - final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Mapping", resolveAnchor); - return sequence != null ? new PyCollectionTypeImpl(sequence, false, Arrays.asList(keyType, valueType)) : null; - } -} diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.kt b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.kt new file mode 100644 index 000000000000..cf3df8513d07 --- /dev/null +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.kt @@ -0,0 +1,86 @@ +package com.jetbrains.python.psi.impl + +import com.intellij.lang.ASTNode +import com.intellij.psi.PsiListLikeElement +import com.intellij.psi.util.findParentInFile +import com.jetbrains.python.psi.* +import com.jetbrains.python.psi.PyCaptureContext.Companion.getCaptureType +import com.jetbrains.python.psi.impl.PyBuiltinCache.Companion.getInstance +import com.jetbrains.python.psi.types.* +import com.jetbrains.python.psi.types.PyLiteralType.Companion.upcastLiteralToClass + +class PyMappingPatternImpl(astNode: ASTNode?) : PyElementImpl(astNode), PyMappingPattern, PsiListLikeElement { + override fun acceptPyVisitor(pyVisitor: PyElementVisitor) { + pyVisitor.visitPyMappingPattern(this) + } + + override fun getComponents(): List = findChildrenByClass(PyKeyValuePattern::class.java).toList() + + override fun canExcludePatternType(context: TypeEvalContext): Boolean = false + + override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType? { + val keyTypes = mutableListOf() + val valueTypes = mutableListOf() + for (it in components) { + keyTypes.add(context.getType(it.keyPattern)) + if (it.valuePattern != null) { + valueTypes.add(context.getType(it.valuePattern!!)) + } + } + + val patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes)) + + val filteredType = getCaptureType(this, context).toList().filter { captureType: PyType? -> + val mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context) ?: return@filter false + PyTypeChecker.match(mappingType, patternMappingType, context) + }.let { + PyUnionType.union(it) + } + + return filteredType ?: patternMappingType + } + + override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? { + val sequenceMember = pattern.findParentInFile(withSelf = true) { this === it.parent } + if (sequenceMember is PyDoubleStarPattern) { + val mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context) + if (mappingType is PyCollectionType) { + val dict = getInstance(pattern).getClass("dict") ?: return null + return PyCollectionTypeImpl(dict, false, mappingType.getElementTypes()) + } + return null + } + + if (sequenceMember !is PyKeyValuePattern) return null + + return getCaptureType(this, context).toList() + .map { possibleMapping -> possibleMapping.getValueType(sequenceMember, context) } + .let { PyUnionType.union(it) } + } + + private fun wrapInMappingType(keyType: PyType?, valueType: PyType?): PyType? { + val sequence = PyPsiFacade.getInstance(getProject()).createClassByQName("typing.Mapping", this) ?: return null + return PyCollectionTypeImpl(sequence, false, listOf(keyType, valueType).map { upcastLiteralToClass(it) }) + } +} + +private fun PyType?.getValueType(sequenceMember: PyKeyValuePattern, context: TypeEvalContext): PyType? { + if (this is PyTypedDictType) { + val key = sequenceMember.getKeyString(context) + if (key != null) return this.getElementType(key) + } + val mappingType = PyTypeUtil.convertToType(this, "typing.Mapping", sequenceMember, context) + ?: return PyNeverType.NEVER + if (mappingType is PyCollectionType) { + return mappingType.elementTypes[1] + } + return null +} + +private fun PyKeyValuePattern.getKeyString(context: TypeEvalContext): String? { + val keyType = context.getType(keyPattern) + if (keyType is PyLiteralType && keyType.expression is PyStringLiteralExpression) { + return keyType.expression.getStringValue() + } + return null +} diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java deleted file mode 100644 index dc2df9735b26..000000000000 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java +++ /dev/null @@ -1,185 +0,0 @@ -package com.jetbrains.python.psi.impl; - -import com.intellij.lang.ASTNode; -import com.intellij.psi.PsiElement; -import com.intellij.psi.PsiListLikeElement; -import com.intellij.psi.util.PsiTreeUtil; -import com.intellij.util.ArrayUtil; -import com.intellij.util.containers.ContainerUtil; -import com.jetbrains.python.psi.*; -import com.jetbrains.python.psi.types.*; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; - -import static com.jetbrains.python.psi.impl.PyClassPatternImpl.canExcludeArgumentPatternType; -import static com.jetbrains.python.psi.types.PyLiteralType.upcastLiteralToClass; - -public class PySequencePatternImpl extends PyElementImpl implements PySequencePattern, PsiListLikeElement { - public PySequencePatternImpl(ASTNode astNode) { - super(astNode); - } - - @Override - protected void acceptPyVisitor(PyElementVisitor pyVisitor) { - pyVisitor.visitPySequencePattern(this); - } - - @Override - public @NotNull List getComponents() { - return getElements(); - } - - @Override - public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { - final PyType sequenceCaptureType = getSequenceCaptureType(this, context); - boolean isHomogeneous = !(sequenceCaptureType instanceof PyTupleType tupleType) || tupleType.isHomogeneous(); - final ArrayList types = new ArrayList<>(); - for (PyPattern pattern : getElements()) { - if (pattern instanceof PySingleStarPattern starPattern) { - types.addAll(getCapturedTypesFromSequenceType(starPattern, sequenceCaptureType, context)); - } - else { - types.add(context.getType(pattern)); - } - } - PyType expectedType = isHomogeneous ? wrapInSequenceType(PyUnionType.union(types), this) : PyTupleType.create(this, types); - if (sequenceCaptureType == null) return expectedType; - return PyTypeUtil.toStream(sequenceCaptureType) - .map(it -> { - if (PyTypeChecker.match(expectedType, it, context)) { - return it; - } - return expectedType; - }) - .collect(PyTypeUtil.toUnion()); - } - - @Override - public boolean canExcludePatternType(@NotNull TypeEvalContext context) { - for (var p : getElements()) { - if (!canExcludeArgumentPatternType(p, context)) { - return false; - } - } - for (var type : PyTypeUtil.toStream(getSequenceCaptureType(this, context))) { - if (!(type instanceof PyTupleType tupleType) || tupleType.isHomogeneous()) { - // Not clear how to do this accurately, so for now matching - // types like list[Something] will not exclude type in any case - return false; - } - } - return true; - } - - @Override - public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { - final PyType sequenceType = getSequenceCaptureType(this, context); - if (sequenceType == null) return null; - - // This is done to skip group- and as-patterns - final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent()); - if (sequenceMember instanceof PySingleStarPattern starPattern) { - final PyType iteratedType = PyTypeUtil.toStream(sequenceType) - .flatMap(it -> getCapturedTypesFromSequenceType(starPattern, it, context).stream()).collect(PyTypeUtil.toUnion()); - return wrapInListType(iteratedType, pattern); - } - final List elements = getElements(); - final int idx = elements.indexOf(sequenceMember); - final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern); - - return PyTypeUtil.toStream(sequenceType).map(it -> { - if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { - if (starIdx == -1 || idx < starIdx) { - return tupleType.getElementType(idx); - } - else { - final int starSpan = tupleType.getElementCount() - elements.size(); - return tupleType.getElementType(idx + starSpan); - } - } - var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context); - if (upcast instanceof PyCollectionType collectionType) { - return collectionType.getIteratedItemType(); - } - return null; - }).collect(PyTypeUtil.toUnion()); - } - - static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) { - final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list"); - return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null; - } - - @Nullable - public static PyType wrapInSequenceType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) { - final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Sequence", resolveAnchor); - return sequence != null ? new PyCollectionTypeImpl(sequence, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null; - } - - /** - * Similar to {@link PyCaptureContext#getCaptureType(PyPattern, TypeEvalContext)}, - * but only chooses types that would match to typing.Sequence, and have correct length - */ - @Nullable - static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) { - final PyType captureTypes = PyCaptureContext.getCaptureType(pattern, context); - final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern); - - List types = new ArrayList<>(); - for (PyType captureType : PyTypeUtil.toStream(captureTypes)) { - if (captureType instanceof PyClassType classType && - ArrayUtil.contains(classType.getClassQName(), "str", "bytes", "bytearray")) continue; - - PyType sequenceType = PyTypeUtil.convertToType(captureType, "typing.Sequence", pattern, context); - if (sequenceType == null) continue; - - if (captureType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { - final List elements = pattern.getElements(); - List tupleElementTypes = tupleType.getElementTypes(); - int unpackedTupleIndex = ContainerUtil.indexOf(tupleElementTypes, it -> it instanceof PyUnpackedTupleType); - if (unpackedTupleIndex != -1) { - PyUnpackedTupleType unpackedTupleType = (PyUnpackedTupleType)tupleElementTypes.get(unpackedTupleIndex); - assert unpackedTupleType.isUnbound(); - int variadicElementsCount = elements.size() - tupleElementTypes.size() + 1; - if (variadicElementsCount >= 0) { - List adjustedTupleElementTypes = new ArrayList<>(elements.size()); - adjustedTupleElementTypes.addAll(tupleElementTypes.subList(0, unpackedTupleIndex)); - for (int i = 0; i < variadicElementsCount; i++) { - adjustedTupleElementTypes.add(unpackedTupleType.getElementTypes().get(0)); - } - adjustedTupleElementTypes.addAll(tupleElementTypes.subList(unpackedTupleIndex + 1, tupleElementTypes.size())); - types.add(new PyTupleType(tupleType.getPyClass(), adjustedTupleElementTypes, false)); - } - } else { - if (hasStar && elements.size() <= tupleType.getElementCount() || elements.size() == tupleType.getElementCount()) { - types.add(captureType); - } - } - } - else { - types.add(captureType); - } - } - return PyUnionType.union(types); - } - - private static @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@NotNull PySingleStarPattern starPattern, - @Nullable PyType sequenceType, - @NotNull TypeEvalContext context) { - if (starPattern.getParent() instanceof PySequencePattern sequenceParent) { - final int idx = sequenceParent.getElements().indexOf(starPattern); - if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { - return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1); - } - var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", starPattern, context); - if (upcast instanceof PyCollectionType collectionType) { - return Collections.singletonList(collectionType.getIteratedItemType()); - } - } - return List.of(); - } -} diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.kt b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.kt new file mode 100644 index 000000000000..55940c51f2a0 --- /dev/null +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.kt @@ -0,0 +1,163 @@ +package com.jetbrains.python.psi.impl + +import com.intellij.lang.ASTNode +import com.intellij.psi.PsiListLikeElement +import com.intellij.psi.util.findParentInFile +import com.jetbrains.python.psi.* +import com.jetbrains.python.psi.types.* +import com.jetbrains.python.psi.types.PyLiteralType.Companion.upcastLiteralToClass +import kotlin.contracts.ExperimentalContracts +import kotlin.contracts.contract + +class PySequencePatternImpl(astNode: ASTNode?) : PyElementImpl(astNode), PySequencePattern, PsiListLikeElement { + override fun acceptPyVisitor(pyVisitor: PyElementVisitor) { + pyVisitor.visitPySequencePattern(this) + } + + override fun getComponents(): List = elements + + override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType? { + val sequenceCaptureType = this.getSequenceCaptureType(context) + val types = elements.flatMap { pattern -> + when (pattern) { + is PySingleStarPattern -> pattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context) + else -> listOf(context.getType(pattern)) + } + } + + val expectedType = when { + sequenceCaptureType.isHeterogeneousTuple() -> PyTupleType.create(this, types) + else -> wrapInSequenceType(PyUnionType.union(types)) + } + + if (sequenceCaptureType == null) return expectedType + return sequenceCaptureType.toList() + .map { if (PyTypeChecker.match(expectedType, it, context)) it else expectedType } + .let { PyUnionType.union(it) } + } + + override fun canExcludePatternType(context: TypeEvalContext): Boolean { + val allElementsCoverCapture = elements.all { PyClassPatternImpl.canExcludeArgumentPatternType(it, context) } + val allCapturesOfThisAreHeteroTuples = this.getSequenceCaptureType(context).toList().all { it.isHeterogeneousTuple() } + return allElementsCoverCapture && allCapturesOfThisAreHeteroTuples + } + + override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? { + val sequenceType = this.getSequenceCaptureType(context) ?: return null + + // This is done to skip group- and as-patterns + val sequenceMember = pattern.findParentInFile(withSelf = true) { el -> this === el.parent } + if (sequenceMember is PySingleStarPattern) { + return sequenceType.toList() + .flatMap { sequenceMember.getCapturedTypesFromSequenceType(it, context) } + .let { PyUnionType.union(it) } + .let { wrapInListType(it) } + } + + val idx = elements.indexOf(sequenceMember) + + return sequenceType.toList() + .map { getElementTypeSkippingStar(it, idx, context) } + .let { PyUnionType.union(it) } + } + + private fun getElementTypeSkippingStar(sequence: PyType?, idx: Int, context: TypeEvalContext): PyType? { + if (sequence.isHeterogeneousTuple()) { + val starIdx = elements.indexOfFirst { it is PySingleStarPattern } + if (starIdx == -1 || idx < starIdx) { + return sequence.getElementType(idx) + } + else { + val starSpan = sequence.elementCount - this.elements.size + return sequence.getElementType(idx + starSpan) + } + } + else { + val upcast = PyTypeUtil.convertToType(sequence, "typing.Sequence", this, context) + return (upcast as? PyCollectionType)?.iteratedItemType + } + } + + /** + * Similar to [PyCaptureContext.getCaptureType], + * but only chooses types that would match to typing.Sequence, and have correct length + */ + private fun PySequencePattern.getSequenceCaptureType(context: TypeEvalContext): PyType? { + val captureTypes: PyType? = PyCaptureContext.getCaptureType(this, context) + + val potentialMatchingTypes = captureTypes.toList() + .filter { it !is PyClassType || it.classQName !in listOf("str", "bytes", "bytearray") } + .filter { PyTypeUtil.convertToType(it, "typing.Sequence", this, context) != null } + + val hasStar = elements.any { it is PySingleStarPattern } + val types = potentialMatchingTypes.mapNotNull { + if (it.isHeterogeneousTuple()) it.takeIfSizeMatches(elements.size, hasStar) else it + } + return PyUnionType.union(types) + } + + fun wrapInListType(elementType: PyType?): PyType? { + val list = PyBuiltinCache.getInstance(this).getClass("list") ?: return null + return PyCollectionTypeImpl(list, false, listOf(upcastLiteralToClass(elementType))) + } + + fun wrapInSequenceType(elementType: PyType?): PyType? { + val sequence = PyPsiFacade.getInstance(getProject()).createClassByQName("typing.Sequence", this) ?: return null + return PyCollectionTypeImpl(sequence, false, listOf(upcastLiteralToClass(elementType))) + } +} + +private fun PySingleStarPattern.getCapturedTypesFromSequenceType(sequenceType: PyType?, context: TypeEvalContext): List { + if (sequenceType.isHeterogeneousTuple()) { + val sequenceParent = this.parent as? PySequencePattern ?: return listOf() + val idx = sequenceParent.elements.indexOf(this) + return sequenceType.elementTypes.subList(idx, idx + sequenceType.elementCount - sequenceParent.elements.size + 1) + } + val upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context) + if (upcast is PyCollectionType) { + return listOf(upcast.getIteratedItemType()) + } + return listOf() +} + +// Use it like PyTypeUtil#toStream +internal fun PyType?.toList(): List = if (this is PyUnionType) members.toList() else listOf(this) + +@OptIn(ExperimentalContracts::class) +private fun PyType?.isHeterogeneousTuple(): Boolean { + contract { returns(true) implies (this@isHeterogeneousTuple is PyTupleType) } + return this is PyTupleType && !isHomogeneous +} + +private fun PyTupleType.takeIfSizeMatches(desiredSize: Int, hasStar: Boolean): PyTupleType? { + if (this.elementTypes.any { it is PyUnpackedTupleType }) { + val variadicElementsCount: Int = desiredSize - this.elementTypes.size + 1 + if (variadicElementsCount >= 0) { + return this.expandVariadics(variadicElementsCount) + } + } + else { + if (hasStar && desiredSize <= this.elementCount || desiredSize == this.elementCount) { + return this + } + } + return null +} + +private fun PyTupleType.expandVariadics(variadicElementCount: Int): PyTupleType { + require(!this.isHomogeneous) { "Supplied tuple must not be homogeneous: $this" } + require(variadicElementCount >= 0) { "Supplied variadic element count must not be negative: $variadicElementCount" } + + val unpackedTupleIndex = elementTypes.indexOfFirst { it is PyUnpackedTupleType } + val unpackedTupleType = elementTypes[unpackedTupleIndex] as PyUnpackedTupleType + assert(unpackedTupleType.isUnbound) + + val adjustedTupleElementTypes = buildList { + addAll(elementTypes.subList(0, unpackedTupleIndex)) + repeat(variadicElementCount) { + add(unpackedTupleType.elementTypes[0]) + } + addAll(elementTypes.subList(unpackedTupleIndex + 1, elementTypes.size)) + } + return PyTupleType(this.pyClass, adjustedTupleElementTypes, false) +} \ No newline at end of file