mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-15 02:59:33 +07:00
[python] Convert PySequencePattern and PyMappingPattern to kotlin
(cherry picked from commit 074ed9f865556d561237fc894d202d76995ab562) IJ-MR-168826 GitOrigin-RevId: 6776375e6566521744ccc1f6c54254b887a1574b
This commit is contained in:
committed by
intellij-monorepo-bot
parent
21e8b573a7
commit
05203527a4
@@ -1,7 +0,0 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstMappingPattern;
|
||||
|
||||
public interface PyMappingPattern extends PyAstMappingPattern, PyPattern, PyCaptureContext {
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.psi
|
||||
|
||||
import com.jetbrains.python.ast.PyAstMappingPattern
|
||||
|
||||
interface PyMappingPattern : PyAstMappingPattern, PyPattern, PyCaptureContext
|
||||
@@ -1,15 +0,0 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstSequencePattern;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass;
|
||||
|
||||
public interface PySequencePattern extends PyAstSequencePattern, PyPattern, PyCaptureContext {
|
||||
default @NotNull List<@NotNull PyPattern> getElements() {
|
||||
return List.of(findChildrenByClass(this, PyPattern.class));
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,10 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.psi
|
||||
|
||||
import com.jetbrains.python.ast.PyAstSequencePattern
|
||||
import com.jetbrains.python.ast.findChildrenByClass
|
||||
|
||||
interface PySequencePattern : PyAstSequencePattern, PyPattern, PyCaptureContext {
|
||||
val elements: List<PyPattern>
|
||||
get() = findChildrenByClass(PyPattern::class.java).toList()
|
||||
}
|
||||
@@ -1,99 +0,0 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiListLikeElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPattern, PsiListLikeElement {
|
||||
public PyMappingPatternImpl(ASTNode astNode) {
|
||||
super(astNode);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyMappingPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<? extends PyKeyValuePattern> getComponents() {
|
||||
return Arrays.asList(findChildrenByClass(PyKeyValuePattern.class));
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canExcludePatternType(@NotNull TypeEvalContext context) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
ArrayList<PyType> keyTypes = new ArrayList<>();
|
||||
ArrayList<PyType> valueTypes = new ArrayList<>();
|
||||
for (PyKeyValuePattern it : getComponents()) {
|
||||
keyTypes.add(context.getType(it.getKeyPattern()));
|
||||
if (it.getValuePattern() != null) {
|
||||
valueTypes.add(context.getType(it.getValuePattern()));
|
||||
}
|
||||
}
|
||||
|
||||
PyType patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this);
|
||||
|
||||
PyType captureTypes = PyCaptureContext.getCaptureType(this, context);
|
||||
PyType filteredType = PyTypeUtil.toStream(captureTypes).filter(captureType -> {
|
||||
var mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context);
|
||||
if (mappingType == null) return false;
|
||||
if (!PyTypeChecker.match(mappingType, patternMappingType, context)) return false;
|
||||
return true;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
|
||||
return filteredType == null ? patternMappingType : filteredType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
|
||||
if (sequenceMember instanceof PyDoubleStarPattern) {
|
||||
var mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context);
|
||||
if (mappingType instanceof PyCollectionType collectionType) {
|
||||
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
|
||||
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
else if (sequenceMember instanceof PyKeyValuePattern keyValuePattern) {
|
||||
return PyTypeUtil.toStream(PyCaptureContext.getCaptureType(this, context)).map(type -> {
|
||||
if (type instanceof PyTypedDictType typedDictType) {
|
||||
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l &&
|
||||
l.getExpression() instanceof PyStringLiteralExpression str) {
|
||||
return typedDictType.getElementType(str.getStringValue());
|
||||
}
|
||||
}
|
||||
|
||||
PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context);
|
||||
if (mappingType == null) {
|
||||
return PyNeverType.NEVER;
|
||||
}
|
||||
else if (mappingType instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getElementTypes().get(1);
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) {
|
||||
keyType = PyLiteralType.upcastLiteralToClass(keyType);
|
||||
valueType = PyLiteralType.upcastLiteralToClass(valueType);
|
||||
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Mapping", resolveAnchor);
|
||||
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Arrays.asList(keyType, valueType)) : null;
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,86 @@
|
||||
package com.jetbrains.python.psi.impl
|
||||
|
||||
import com.intellij.lang.ASTNode
|
||||
import com.intellij.psi.PsiListLikeElement
|
||||
import com.intellij.psi.util.findParentInFile
|
||||
import com.jetbrains.python.psi.*
|
||||
import com.jetbrains.python.psi.PyCaptureContext.Companion.getCaptureType
|
||||
import com.jetbrains.python.psi.impl.PyBuiltinCache.Companion.getInstance
|
||||
import com.jetbrains.python.psi.types.*
|
||||
import com.jetbrains.python.psi.types.PyLiteralType.Companion.upcastLiteralToClass
|
||||
|
||||
class PyMappingPatternImpl(astNode: ASTNode?) : PyElementImpl(astNode), PyMappingPattern, PsiListLikeElement {
|
||||
override fun acceptPyVisitor(pyVisitor: PyElementVisitor) {
|
||||
pyVisitor.visitPyMappingPattern(this)
|
||||
}
|
||||
|
||||
override fun getComponents(): List<PyKeyValuePattern> = findChildrenByClass(PyKeyValuePattern::class.java).toList()
|
||||
|
||||
override fun canExcludePatternType(context: TypeEvalContext): Boolean = false
|
||||
|
||||
override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType? {
|
||||
val keyTypes = mutableListOf<PyType?>()
|
||||
val valueTypes = mutableListOf<PyType?>()
|
||||
for (it in components) {
|
||||
keyTypes.add(context.getType(it.keyPattern))
|
||||
if (it.valuePattern != null) {
|
||||
valueTypes.add(context.getType(it.valuePattern!!))
|
||||
}
|
||||
}
|
||||
|
||||
val patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes))
|
||||
|
||||
val filteredType = getCaptureType(this, context).toList().filter { captureType: PyType? ->
|
||||
val mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context) ?: return@filter false
|
||||
PyTypeChecker.match(mappingType, patternMappingType, context)
|
||||
}.let {
|
||||
PyUnionType.union(it)
|
||||
}
|
||||
|
||||
return filteredType ?: patternMappingType
|
||||
}
|
||||
|
||||
override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? {
|
||||
val sequenceMember = pattern.findParentInFile(withSelf = true) { this === it.parent }
|
||||
if (sequenceMember is PyDoubleStarPattern) {
|
||||
val mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context)
|
||||
if (mappingType is PyCollectionType) {
|
||||
val dict = getInstance(pattern).getClass("dict") ?: return null
|
||||
return PyCollectionTypeImpl(dict, false, mappingType.getElementTypes())
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
if (sequenceMember !is PyKeyValuePattern) return null
|
||||
|
||||
return getCaptureType(this, context).toList()
|
||||
.map { possibleMapping -> possibleMapping.getValueType(sequenceMember, context) }
|
||||
.let { PyUnionType.union(it) }
|
||||
}
|
||||
|
||||
private fun wrapInMappingType(keyType: PyType?, valueType: PyType?): PyType? {
|
||||
val sequence = PyPsiFacade.getInstance(getProject()).createClassByQName("typing.Mapping", this) ?: return null
|
||||
return PyCollectionTypeImpl(sequence, false, listOf(keyType, valueType).map { upcastLiteralToClass(it) })
|
||||
}
|
||||
}
|
||||
|
||||
private fun PyType?.getValueType(sequenceMember: PyKeyValuePattern, context: TypeEvalContext): PyType? {
|
||||
if (this is PyTypedDictType) {
|
||||
val key = sequenceMember.getKeyString(context)
|
||||
if (key != null) return this.getElementType(key)
|
||||
}
|
||||
val mappingType = PyTypeUtil.convertToType(this, "typing.Mapping", sequenceMember, context)
|
||||
?: return PyNeverType.NEVER
|
||||
if (mappingType is PyCollectionType) {
|
||||
return mappingType.elementTypes[1]
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private fun PyKeyValuePattern.getKeyString(context: TypeEvalContext): String? {
|
||||
val keyType = context.getType(keyPattern)
|
||||
if (keyType is PyLiteralType && keyType.expression is PyStringLiteralExpression) {
|
||||
return keyType.expression.getStringValue()
|
||||
}
|
||||
return null
|
||||
}
|
||||
@@ -1,185 +0,0 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiListLikeElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.ArrayUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import static com.jetbrains.python.psi.impl.PyClassPatternImpl.canExcludeArgumentPatternType;
|
||||
import static com.jetbrains.python.psi.types.PyLiteralType.upcastLiteralToClass;
|
||||
|
||||
public class PySequencePatternImpl extends PyElementImpl implements PySequencePattern, PsiListLikeElement {
|
||||
public PySequencePatternImpl(ASTNode astNode) {
|
||||
super(astNode);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPySequencePattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<? extends PyPattern> getComponents() {
|
||||
return getElements();
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
final PyType sequenceCaptureType = getSequenceCaptureType(this, context);
|
||||
boolean isHomogeneous = !(sequenceCaptureType instanceof PyTupleType tupleType) || tupleType.isHomogeneous();
|
||||
final ArrayList<PyType> types = new ArrayList<>();
|
||||
for (PyPattern pattern : getElements()) {
|
||||
if (pattern instanceof PySingleStarPattern starPattern) {
|
||||
types.addAll(getCapturedTypesFromSequenceType(starPattern, sequenceCaptureType, context));
|
||||
}
|
||||
else {
|
||||
types.add(context.getType(pattern));
|
||||
}
|
||||
}
|
||||
PyType expectedType = isHomogeneous ? wrapInSequenceType(PyUnionType.union(types), this) : PyTupleType.create(this, types);
|
||||
if (sequenceCaptureType == null) return expectedType;
|
||||
return PyTypeUtil.toStream(sequenceCaptureType)
|
||||
.map(it -> {
|
||||
if (PyTypeChecker.match(expectedType, it, context)) {
|
||||
return it;
|
||||
}
|
||||
return expectedType;
|
||||
})
|
||||
.collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canExcludePatternType(@NotNull TypeEvalContext context) {
|
||||
for (var p : getElements()) {
|
||||
if (!canExcludeArgumentPatternType(p, context)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
for (var type : PyTypeUtil.toStream(getSequenceCaptureType(this, context))) {
|
||||
if (!(type instanceof PyTupleType tupleType) || tupleType.isHomogeneous()) {
|
||||
// Not clear how to do this accurately, so for now matching
|
||||
// types like list[Something] will not exclude type in any case
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
final PyType sequenceType = getSequenceCaptureType(this, context);
|
||||
if (sequenceType == null) return null;
|
||||
|
||||
// This is done to skip group- and as-patterns
|
||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
|
||||
if (sequenceMember instanceof PySingleStarPattern starPattern) {
|
||||
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
|
||||
.flatMap(it -> getCapturedTypesFromSequenceType(starPattern, it, context).stream()).collect(PyTypeUtil.toUnion());
|
||||
return wrapInListType(iteratedType, pattern);
|
||||
}
|
||||
final List<PyPattern> elements = getElements();
|
||||
final int idx = elements.indexOf(sequenceMember);
|
||||
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
|
||||
|
||||
return PyTypeUtil.toStream(sequenceType).map(it -> {
|
||||
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
if (starIdx == -1 || idx < starIdx) {
|
||||
return tupleType.getElementType(idx);
|
||||
}
|
||||
else {
|
||||
final int starSpan = tupleType.getElementCount() - elements.size();
|
||||
return tupleType.getElementType(idx + starSpan);
|
||||
}
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getIteratedItemType();
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
|
||||
static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
|
||||
final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list");
|
||||
return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static PyType wrapInSequenceType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
|
||||
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Sequence", resolveAnchor);
|
||||
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to {@link PyCaptureContext#getCaptureType(PyPattern, TypeEvalContext)},
|
||||
* but only chooses types that would match to typing.Sequence, and have correct length
|
||||
*/
|
||||
@Nullable
|
||||
static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) {
|
||||
final PyType captureTypes = PyCaptureContext.getCaptureType(pattern, context);
|
||||
final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern);
|
||||
|
||||
List<PyType> types = new ArrayList<>();
|
||||
for (PyType captureType : PyTypeUtil.toStream(captureTypes)) {
|
||||
if (captureType instanceof PyClassType classType &&
|
||||
ArrayUtil.contains(classType.getClassQName(), "str", "bytes", "bytearray")) continue;
|
||||
|
||||
PyType sequenceType = PyTypeUtil.convertToType(captureType, "typing.Sequence", pattern, context);
|
||||
if (sequenceType == null) continue;
|
||||
|
||||
if (captureType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
final List<PyPattern> elements = pattern.getElements();
|
||||
List<PyType> tupleElementTypes = tupleType.getElementTypes();
|
||||
int unpackedTupleIndex = ContainerUtil.indexOf(tupleElementTypes, it -> it instanceof PyUnpackedTupleType);
|
||||
if (unpackedTupleIndex != -1) {
|
||||
PyUnpackedTupleType unpackedTupleType = (PyUnpackedTupleType)tupleElementTypes.get(unpackedTupleIndex);
|
||||
assert unpackedTupleType.isUnbound();
|
||||
int variadicElementsCount = elements.size() - tupleElementTypes.size() + 1;
|
||||
if (variadicElementsCount >= 0) {
|
||||
List<PyType> adjustedTupleElementTypes = new ArrayList<>(elements.size());
|
||||
adjustedTupleElementTypes.addAll(tupleElementTypes.subList(0, unpackedTupleIndex));
|
||||
for (int i = 0; i < variadicElementsCount; i++) {
|
||||
adjustedTupleElementTypes.add(unpackedTupleType.getElementTypes().get(0));
|
||||
}
|
||||
adjustedTupleElementTypes.addAll(tupleElementTypes.subList(unpackedTupleIndex + 1, tupleElementTypes.size()));
|
||||
types.add(new PyTupleType(tupleType.getPyClass(), adjustedTupleElementTypes, false));
|
||||
}
|
||||
} else {
|
||||
if (hasStar && elements.size() <= tupleType.getElementCount() || elements.size() == tupleType.getElementCount()) {
|
||||
types.add(captureType);
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
types.add(captureType);
|
||||
}
|
||||
}
|
||||
return PyUnionType.union(types);
|
||||
}
|
||||
|
||||
private static @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@NotNull PySingleStarPattern starPattern,
|
||||
@Nullable PyType sequenceType,
|
||||
@NotNull TypeEvalContext context) {
|
||||
if (starPattern.getParent() instanceof PySequencePattern sequenceParent) {
|
||||
final int idx = sequenceParent.getElements().indexOf(starPattern);
|
||||
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", starPattern, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return Collections.singletonList(collectionType.getIteratedItemType());
|
||||
}
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,163 @@
|
||||
package com.jetbrains.python.psi.impl
|
||||
|
||||
import com.intellij.lang.ASTNode
|
||||
import com.intellij.psi.PsiListLikeElement
|
||||
import com.intellij.psi.util.findParentInFile
|
||||
import com.jetbrains.python.psi.*
|
||||
import com.jetbrains.python.psi.types.*
|
||||
import com.jetbrains.python.psi.types.PyLiteralType.Companion.upcastLiteralToClass
|
||||
import kotlin.contracts.ExperimentalContracts
|
||||
import kotlin.contracts.contract
|
||||
|
||||
class PySequencePatternImpl(astNode: ASTNode?) : PyElementImpl(astNode), PySequencePattern, PsiListLikeElement {
|
||||
override fun acceptPyVisitor(pyVisitor: PyElementVisitor) {
|
||||
pyVisitor.visitPySequencePattern(this)
|
||||
}
|
||||
|
||||
override fun getComponents(): List<PyPattern> = elements
|
||||
|
||||
override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType? {
|
||||
val sequenceCaptureType = this.getSequenceCaptureType(context)
|
||||
val types = elements.flatMap { pattern ->
|
||||
when (pattern) {
|
||||
is PySingleStarPattern -> pattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context)
|
||||
else -> listOf(context.getType(pattern))
|
||||
}
|
||||
}
|
||||
|
||||
val expectedType = when {
|
||||
sequenceCaptureType.isHeterogeneousTuple() -> PyTupleType.create(this, types)
|
||||
else -> wrapInSequenceType(PyUnionType.union(types))
|
||||
}
|
||||
|
||||
if (sequenceCaptureType == null) return expectedType
|
||||
return sequenceCaptureType.toList()
|
||||
.map { if (PyTypeChecker.match(expectedType, it, context)) it else expectedType }
|
||||
.let { PyUnionType.union(it) }
|
||||
}
|
||||
|
||||
override fun canExcludePatternType(context: TypeEvalContext): Boolean {
|
||||
val allElementsCoverCapture = elements.all { PyClassPatternImpl.canExcludeArgumentPatternType(it, context) }
|
||||
val allCapturesOfThisAreHeteroTuples = this.getSequenceCaptureType(context).toList().all { it.isHeterogeneousTuple() }
|
||||
return allElementsCoverCapture && allCapturesOfThisAreHeteroTuples
|
||||
}
|
||||
|
||||
override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? {
|
||||
val sequenceType = this.getSequenceCaptureType(context) ?: return null
|
||||
|
||||
// This is done to skip group- and as-patterns
|
||||
val sequenceMember = pattern.findParentInFile(withSelf = true) { el -> this === el.parent }
|
||||
if (sequenceMember is PySingleStarPattern) {
|
||||
return sequenceType.toList()
|
||||
.flatMap { sequenceMember.getCapturedTypesFromSequenceType(it, context) }
|
||||
.let { PyUnionType.union(it) }
|
||||
.let { wrapInListType(it) }
|
||||
}
|
||||
|
||||
val idx = elements.indexOf(sequenceMember)
|
||||
|
||||
return sequenceType.toList()
|
||||
.map { getElementTypeSkippingStar(it, idx, context) }
|
||||
.let { PyUnionType.union(it) }
|
||||
}
|
||||
|
||||
private fun getElementTypeSkippingStar(sequence: PyType?, idx: Int, context: TypeEvalContext): PyType? {
|
||||
if (sequence.isHeterogeneousTuple()) {
|
||||
val starIdx = elements.indexOfFirst { it is PySingleStarPattern }
|
||||
if (starIdx == -1 || idx < starIdx) {
|
||||
return sequence.getElementType(idx)
|
||||
}
|
||||
else {
|
||||
val starSpan = sequence.elementCount - this.elements.size
|
||||
return sequence.getElementType(idx + starSpan)
|
||||
}
|
||||
}
|
||||
else {
|
||||
val upcast = PyTypeUtil.convertToType(sequence, "typing.Sequence", this, context)
|
||||
return (upcast as? PyCollectionType)?.iteratedItemType
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to [PyCaptureContext.getCaptureType],
|
||||
* but only chooses types that would match to typing.Sequence, and have correct length
|
||||
*/
|
||||
private fun PySequencePattern.getSequenceCaptureType(context: TypeEvalContext): PyType? {
|
||||
val captureTypes: PyType? = PyCaptureContext.getCaptureType(this, context)
|
||||
|
||||
val potentialMatchingTypes = captureTypes.toList()
|
||||
.filter { it !is PyClassType || it.classQName !in listOf("str", "bytes", "bytearray") }
|
||||
.filter { PyTypeUtil.convertToType(it, "typing.Sequence", this, context) != null }
|
||||
|
||||
val hasStar = elements.any { it is PySingleStarPattern }
|
||||
val types = potentialMatchingTypes.mapNotNull {
|
||||
if (it.isHeterogeneousTuple()) it.takeIfSizeMatches(elements.size, hasStar) else it
|
||||
}
|
||||
return PyUnionType.union(types)
|
||||
}
|
||||
|
||||
fun wrapInListType(elementType: PyType?): PyType? {
|
||||
val list = PyBuiltinCache.getInstance(this).getClass("list") ?: return null
|
||||
return PyCollectionTypeImpl(list, false, listOf(upcastLiteralToClass(elementType)))
|
||||
}
|
||||
|
||||
fun wrapInSequenceType(elementType: PyType?): PyType? {
|
||||
val sequence = PyPsiFacade.getInstance(getProject()).createClassByQName("typing.Sequence", this) ?: return null
|
||||
return PyCollectionTypeImpl(sequence, false, listOf(upcastLiteralToClass(elementType)))
|
||||
}
|
||||
}
|
||||
|
||||
private fun PySingleStarPattern.getCapturedTypesFromSequenceType(sequenceType: PyType?, context: TypeEvalContext): List<PyType?> {
|
||||
if (sequenceType.isHeterogeneousTuple()) {
|
||||
val sequenceParent = this.parent as? PySequencePattern ?: return listOf()
|
||||
val idx = sequenceParent.elements.indexOf(this)
|
||||
return sequenceType.elementTypes.subList(idx, idx + sequenceType.elementCount - sequenceParent.elements.size + 1)
|
||||
}
|
||||
val upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context)
|
||||
if (upcast is PyCollectionType) {
|
||||
return listOf(upcast.getIteratedItemType())
|
||||
}
|
||||
return listOf()
|
||||
}
|
||||
|
||||
// Use it like PyTypeUtil#toStream
|
||||
internal fun PyType?.toList(): List<PyType?> = if (this is PyUnionType) members.toList() else listOf(this)
|
||||
|
||||
@OptIn(ExperimentalContracts::class)
|
||||
private fun PyType?.isHeterogeneousTuple(): Boolean {
|
||||
contract { returns(true) implies (this@isHeterogeneousTuple is PyTupleType) }
|
||||
return this is PyTupleType && !isHomogeneous
|
||||
}
|
||||
|
||||
private fun PyTupleType.takeIfSizeMatches(desiredSize: Int, hasStar: Boolean): PyTupleType? {
|
||||
if (this.elementTypes.any { it is PyUnpackedTupleType }) {
|
||||
val variadicElementsCount: Int = desiredSize - this.elementTypes.size + 1
|
||||
if (variadicElementsCount >= 0) {
|
||||
return this.expandVariadics(variadicElementsCount)
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (hasStar && desiredSize <= this.elementCount || desiredSize == this.elementCount) {
|
||||
return this
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private fun PyTupleType.expandVariadics(variadicElementCount: Int): PyTupleType {
|
||||
require(!this.isHomogeneous) { "Supplied tuple must not be homogeneous: $this" }
|
||||
require(variadicElementCount >= 0) { "Supplied variadic element count must not be negative: $variadicElementCount" }
|
||||
|
||||
val unpackedTupleIndex = elementTypes.indexOfFirst { it is PyUnpackedTupleType }
|
||||
val unpackedTupleType = elementTypes[unpackedTupleIndex] as PyUnpackedTupleType
|
||||
assert(unpackedTupleType.isUnbound)
|
||||
|
||||
val adjustedTupleElementTypes = buildList {
|
||||
addAll(elementTypes.subList(0, unpackedTupleIndex))
|
||||
repeat(variadicElementCount) {
|
||||
add(unpackedTupleType.elementTypes[0])
|
||||
}
|
||||
addAll(elementTypes.subList(unpackedTupleIndex + 1, elementTypes.size))
|
||||
}
|
||||
return PyTupleType(this.pyClass, adjustedTupleElementTypes, false)
|
||||
}
|
||||
Reference in New Issue
Block a user