mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-16 22:51:17 +07:00
[python] Convert PySequencePattern and PyMappingPattern to kotlin
(cherry picked from commit 074ed9f865556d561237fc894d202d76995ab562) IJ-MR-168826 GitOrigin-RevId: 6776375e6566521744ccc1f6c54254b887a1574b
This commit is contained in:
committed by
intellij-monorepo-bot
parent
21e8b573a7
commit
05203527a4
@@ -1,7 +0,0 @@
|
|||||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
package com.jetbrains.python.psi;
|
|
||||||
|
|
||||||
import com.jetbrains.python.ast.PyAstMappingPattern;
|
|
||||||
|
|
||||||
public interface PyMappingPattern extends PyAstMappingPattern, PyPattern, PyCaptureContext {
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
package com.jetbrains.python.psi
|
||||||
|
|
||||||
|
import com.jetbrains.python.ast.PyAstMappingPattern
|
||||||
|
|
||||||
|
interface PyMappingPattern : PyAstMappingPattern, PyPattern, PyCaptureContext
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
|
||||||
package com.jetbrains.python.psi;
|
|
||||||
|
|
||||||
import com.jetbrains.python.ast.PyAstSequencePattern;
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
|
||||||
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass;
|
|
||||||
|
|
||||||
public interface PySequencePattern extends PyAstSequencePattern, PyPattern, PyCaptureContext {
|
|
||||||
default @NotNull List<@NotNull PyPattern> getElements() {
|
|
||||||
return List.of(findChildrenByClass(this, PyPattern.class));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,10 @@
|
|||||||
|
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
|
package com.jetbrains.python.psi
|
||||||
|
|
||||||
|
import com.jetbrains.python.ast.PyAstSequencePattern
|
||||||
|
import com.jetbrains.python.ast.findChildrenByClass
|
||||||
|
|
||||||
|
interface PySequencePattern : PyAstSequencePattern, PyPattern, PyCaptureContext {
|
||||||
|
val elements: List<PyPattern>
|
||||||
|
get() = findChildrenByClass(PyPattern::class.java).toList()
|
||||||
|
}
|
||||||
@@ -1,99 +0,0 @@
|
|||||||
package com.jetbrains.python.psi.impl;
|
|
||||||
|
|
||||||
import com.intellij.lang.ASTNode;
|
|
||||||
import com.intellij.psi.PsiElement;
|
|
||||||
import com.intellij.psi.PsiListLikeElement;
|
|
||||||
import com.intellij.psi.util.PsiTreeUtil;
|
|
||||||
import com.jetbrains.python.psi.*;
|
|
||||||
import com.jetbrains.python.psi.types.*;
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
|
||||||
import org.jetbrains.annotations.Nullable;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPattern, PsiListLikeElement {
|
|
||||||
public PyMappingPatternImpl(ASTNode astNode) {
|
|
||||||
super(astNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
|
||||||
pyVisitor.visitPyMappingPattern(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public @NotNull List<? extends PyKeyValuePattern> getComponents() {
|
|
||||||
return Arrays.asList(findChildrenByClass(PyKeyValuePattern.class));
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean canExcludePatternType(@NotNull TypeEvalContext context) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
|
||||||
ArrayList<PyType> keyTypes = new ArrayList<>();
|
|
||||||
ArrayList<PyType> valueTypes = new ArrayList<>();
|
|
||||||
for (PyKeyValuePattern it : getComponents()) {
|
|
||||||
keyTypes.add(context.getType(it.getKeyPattern()));
|
|
||||||
if (it.getValuePattern() != null) {
|
|
||||||
valueTypes.add(context.getType(it.getValuePattern()));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PyType patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this);
|
|
||||||
|
|
||||||
PyType captureTypes = PyCaptureContext.getCaptureType(this, context);
|
|
||||||
PyType filteredType = PyTypeUtil.toStream(captureTypes).filter(captureType -> {
|
|
||||||
var mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context);
|
|
||||||
if (mappingType == null) return false;
|
|
||||||
if (!PyTypeChecker.match(mappingType, patternMappingType, context)) return false;
|
|
||||||
return true;
|
|
||||||
}).collect(PyTypeUtil.toUnion());
|
|
||||||
|
|
||||||
return filteredType == null ? patternMappingType : filteredType;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
|
||||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
|
|
||||||
if (sequenceMember instanceof PyDoubleStarPattern) {
|
|
||||||
var mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context);
|
|
||||||
if (mappingType instanceof PyCollectionType collectionType) {
|
|
||||||
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
|
|
||||||
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
else if (sequenceMember instanceof PyKeyValuePattern keyValuePattern) {
|
|
||||||
return PyTypeUtil.toStream(PyCaptureContext.getCaptureType(this, context)).map(type -> {
|
|
||||||
if (type instanceof PyTypedDictType typedDictType) {
|
|
||||||
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l &&
|
|
||||||
l.getExpression() instanceof PyStringLiteralExpression str) {
|
|
||||||
return typedDictType.getElementType(str.getStringValue());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context);
|
|
||||||
if (mappingType == null) {
|
|
||||||
return PyNeverType.NEVER;
|
|
||||||
}
|
|
||||||
else if (mappingType instanceof PyCollectionType collectionType) {
|
|
||||||
return collectionType.getElementTypes().get(1);
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}).collect(PyTypeUtil.toUnion());
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
|
|
||||||
private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) {
|
|
||||||
keyType = PyLiteralType.upcastLiteralToClass(keyType);
|
|
||||||
valueType = PyLiteralType.upcastLiteralToClass(valueType);
|
|
||||||
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Mapping", resolveAnchor);
|
|
||||||
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Arrays.asList(keyType, valueType)) : null;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,86 @@
|
|||||||
|
package com.jetbrains.python.psi.impl
|
||||||
|
|
||||||
|
import com.intellij.lang.ASTNode
|
||||||
|
import com.intellij.psi.PsiListLikeElement
|
||||||
|
import com.intellij.psi.util.findParentInFile
|
||||||
|
import com.jetbrains.python.psi.*
|
||||||
|
import com.jetbrains.python.psi.PyCaptureContext.Companion.getCaptureType
|
||||||
|
import com.jetbrains.python.psi.impl.PyBuiltinCache.Companion.getInstance
|
||||||
|
import com.jetbrains.python.psi.types.*
|
||||||
|
import com.jetbrains.python.psi.types.PyLiteralType.Companion.upcastLiteralToClass
|
||||||
|
|
||||||
|
class PyMappingPatternImpl(astNode: ASTNode?) : PyElementImpl(astNode), PyMappingPattern, PsiListLikeElement {
|
||||||
|
override fun acceptPyVisitor(pyVisitor: PyElementVisitor) {
|
||||||
|
pyVisitor.visitPyMappingPattern(this)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getComponents(): List<PyKeyValuePattern> = findChildrenByClass(PyKeyValuePattern::class.java).toList()
|
||||||
|
|
||||||
|
override fun canExcludePatternType(context: TypeEvalContext): Boolean = false
|
||||||
|
|
||||||
|
override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType? {
|
||||||
|
val keyTypes = mutableListOf<PyType?>()
|
||||||
|
val valueTypes = mutableListOf<PyType?>()
|
||||||
|
for (it in components) {
|
||||||
|
keyTypes.add(context.getType(it.keyPattern))
|
||||||
|
if (it.valuePattern != null) {
|
||||||
|
valueTypes.add(context.getType(it.valuePattern!!))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes))
|
||||||
|
|
||||||
|
val filteredType = getCaptureType(this, context).toList().filter { captureType: PyType? ->
|
||||||
|
val mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context) ?: return@filter false
|
||||||
|
PyTypeChecker.match(mappingType, patternMappingType, context)
|
||||||
|
}.let {
|
||||||
|
PyUnionType.union(it)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filteredType ?: patternMappingType
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? {
|
||||||
|
val sequenceMember = pattern.findParentInFile(withSelf = true) { this === it.parent }
|
||||||
|
if (sequenceMember is PyDoubleStarPattern) {
|
||||||
|
val mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context)
|
||||||
|
if (mappingType is PyCollectionType) {
|
||||||
|
val dict = getInstance(pattern).getClass("dict") ?: return null
|
||||||
|
return PyCollectionTypeImpl(dict, false, mappingType.getElementTypes())
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sequenceMember !is PyKeyValuePattern) return null
|
||||||
|
|
||||||
|
return getCaptureType(this, context).toList()
|
||||||
|
.map { possibleMapping -> possibleMapping.getValueType(sequenceMember, context) }
|
||||||
|
.let { PyUnionType.union(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun wrapInMappingType(keyType: PyType?, valueType: PyType?): PyType? {
|
||||||
|
val sequence = PyPsiFacade.getInstance(getProject()).createClassByQName("typing.Mapping", this) ?: return null
|
||||||
|
return PyCollectionTypeImpl(sequence, false, listOf(keyType, valueType).map { upcastLiteralToClass(it) })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun PyType?.getValueType(sequenceMember: PyKeyValuePattern, context: TypeEvalContext): PyType? {
|
||||||
|
if (this is PyTypedDictType) {
|
||||||
|
val key = sequenceMember.getKeyString(context)
|
||||||
|
if (key != null) return this.getElementType(key)
|
||||||
|
}
|
||||||
|
val mappingType = PyTypeUtil.convertToType(this, "typing.Mapping", sequenceMember, context)
|
||||||
|
?: return PyNeverType.NEVER
|
||||||
|
if (mappingType is PyCollectionType) {
|
||||||
|
return mappingType.elementTypes[1]
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun PyKeyValuePattern.getKeyString(context: TypeEvalContext): String? {
|
||||||
|
val keyType = context.getType(keyPattern)
|
||||||
|
if (keyType is PyLiteralType && keyType.expression is PyStringLiteralExpression) {
|
||||||
|
return keyType.expression.getStringValue()
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
package com.jetbrains.python.psi.impl;
|
|
||||||
|
|
||||||
import com.intellij.lang.ASTNode;
|
|
||||||
import com.intellij.psi.PsiElement;
|
|
||||||
import com.intellij.psi.PsiListLikeElement;
|
|
||||||
import com.intellij.psi.util.PsiTreeUtil;
|
|
||||||
import com.intellij.util.ArrayUtil;
|
|
||||||
import com.intellij.util.containers.ContainerUtil;
|
|
||||||
import com.jetbrains.python.psi.*;
|
|
||||||
import com.jetbrains.python.psi.types.*;
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
|
||||||
import org.jetbrains.annotations.Nullable;
|
|
||||||
|
|
||||||
import java.util.ArrayList;
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.List;
|
|
||||||
|
|
||||||
import static com.jetbrains.python.psi.impl.PyClassPatternImpl.canExcludeArgumentPatternType;
|
|
||||||
import static com.jetbrains.python.psi.types.PyLiteralType.upcastLiteralToClass;
|
|
||||||
|
|
||||||
public class PySequencePatternImpl extends PyElementImpl implements PySequencePattern, PsiListLikeElement {
|
|
||||||
public PySequencePatternImpl(ASTNode astNode) {
|
|
||||||
super(astNode);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
|
||||||
pyVisitor.visitPySequencePattern(this);
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public @NotNull List<? extends PyPattern> getComponents() {
|
|
||||||
return getElements();
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
|
||||||
final PyType sequenceCaptureType = getSequenceCaptureType(this, context);
|
|
||||||
boolean isHomogeneous = !(sequenceCaptureType instanceof PyTupleType tupleType) || tupleType.isHomogeneous();
|
|
||||||
final ArrayList<PyType> types = new ArrayList<>();
|
|
||||||
for (PyPattern pattern : getElements()) {
|
|
||||||
if (pattern instanceof PySingleStarPattern starPattern) {
|
|
||||||
types.addAll(getCapturedTypesFromSequenceType(starPattern, sequenceCaptureType, context));
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
types.add(context.getType(pattern));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
PyType expectedType = isHomogeneous ? wrapInSequenceType(PyUnionType.union(types), this) : PyTupleType.create(this, types);
|
|
||||||
if (sequenceCaptureType == null) return expectedType;
|
|
||||||
return PyTypeUtil.toStream(sequenceCaptureType)
|
|
||||||
.map(it -> {
|
|
||||||
if (PyTypeChecker.match(expectedType, it, context)) {
|
|
||||||
return it;
|
|
||||||
}
|
|
||||||
return expectedType;
|
|
||||||
})
|
|
||||||
.collect(PyTypeUtil.toUnion());
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public boolean canExcludePatternType(@NotNull TypeEvalContext context) {
|
|
||||||
for (var p : getElements()) {
|
|
||||||
if (!canExcludeArgumentPatternType(p, context)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for (var type : PyTypeUtil.toStream(getSequenceCaptureType(this, context))) {
|
|
||||||
if (!(type instanceof PyTupleType tupleType) || tupleType.isHomogeneous()) {
|
|
||||||
// Not clear how to do this accurately, so for now matching
|
|
||||||
// types like list[Something] will not exclude type in any case
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
|
||||||
final PyType sequenceType = getSequenceCaptureType(this, context);
|
|
||||||
if (sequenceType == null) return null;
|
|
||||||
|
|
||||||
// This is done to skip group- and as-patterns
|
|
||||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
|
|
||||||
if (sequenceMember instanceof PySingleStarPattern starPattern) {
|
|
||||||
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
|
|
||||||
.flatMap(it -> getCapturedTypesFromSequenceType(starPattern, it, context).stream()).collect(PyTypeUtil.toUnion());
|
|
||||||
return wrapInListType(iteratedType, pattern);
|
|
||||||
}
|
|
||||||
final List<PyPattern> elements = getElements();
|
|
||||||
final int idx = elements.indexOf(sequenceMember);
|
|
||||||
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
|
|
||||||
|
|
||||||
return PyTypeUtil.toStream(sequenceType).map(it -> {
|
|
||||||
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
|
||||||
if (starIdx == -1 || idx < starIdx) {
|
|
||||||
return tupleType.getElementType(idx);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
final int starSpan = tupleType.getElementCount() - elements.size();
|
|
||||||
return tupleType.getElementType(idx + starSpan);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
|
|
||||||
if (upcast instanceof PyCollectionType collectionType) {
|
|
||||||
return collectionType.getIteratedItemType();
|
|
||||||
}
|
|
||||||
return null;
|
|
||||||
}).collect(PyTypeUtil.toUnion());
|
|
||||||
}
|
|
||||||
|
|
||||||
static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
|
|
||||||
final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list");
|
|
||||||
return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
@Nullable
|
|
||||||
public static PyType wrapInSequenceType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
|
|
||||||
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Sequence", resolveAnchor);
|
|
||||||
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Similar to {@link PyCaptureContext#getCaptureType(PyPattern, TypeEvalContext)},
|
|
||||||
* but only chooses types that would match to typing.Sequence, and have correct length
|
|
||||||
*/
|
|
||||||
@Nullable
|
|
||||||
static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) {
|
|
||||||
final PyType captureTypes = PyCaptureContext.getCaptureType(pattern, context);
|
|
||||||
final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern);
|
|
||||||
|
|
||||||
List<PyType> types = new ArrayList<>();
|
|
||||||
for (PyType captureType : PyTypeUtil.toStream(captureTypes)) {
|
|
||||||
if (captureType instanceof PyClassType classType &&
|
|
||||||
ArrayUtil.contains(classType.getClassQName(), "str", "bytes", "bytearray")) continue;
|
|
||||||
|
|
||||||
PyType sequenceType = PyTypeUtil.convertToType(captureType, "typing.Sequence", pattern, context);
|
|
||||||
if (sequenceType == null) continue;
|
|
||||||
|
|
||||||
if (captureType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
|
||||||
final List<PyPattern> elements = pattern.getElements();
|
|
||||||
List<PyType> tupleElementTypes = tupleType.getElementTypes();
|
|
||||||
int unpackedTupleIndex = ContainerUtil.indexOf(tupleElementTypes, it -> it instanceof PyUnpackedTupleType);
|
|
||||||
if (unpackedTupleIndex != -1) {
|
|
||||||
PyUnpackedTupleType unpackedTupleType = (PyUnpackedTupleType)tupleElementTypes.get(unpackedTupleIndex);
|
|
||||||
assert unpackedTupleType.isUnbound();
|
|
||||||
int variadicElementsCount = elements.size() - tupleElementTypes.size() + 1;
|
|
||||||
if (variadicElementsCount >= 0) {
|
|
||||||
List<PyType> adjustedTupleElementTypes = new ArrayList<>(elements.size());
|
|
||||||
adjustedTupleElementTypes.addAll(tupleElementTypes.subList(0, unpackedTupleIndex));
|
|
||||||
for (int i = 0; i < variadicElementsCount; i++) {
|
|
||||||
adjustedTupleElementTypes.add(unpackedTupleType.getElementTypes().get(0));
|
|
||||||
}
|
|
||||||
adjustedTupleElementTypes.addAll(tupleElementTypes.subList(unpackedTupleIndex + 1, tupleElementTypes.size()));
|
|
||||||
types.add(new PyTupleType(tupleType.getPyClass(), adjustedTupleElementTypes, false));
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if (hasStar && elements.size() <= tupleType.getElementCount() || elements.size() == tupleType.getElementCount()) {
|
|
||||||
types.add(captureType);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
types.add(captureType);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return PyUnionType.union(types);
|
|
||||||
}
|
|
||||||
|
|
||||||
private static @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@NotNull PySingleStarPattern starPattern,
|
|
||||||
@Nullable PyType sequenceType,
|
|
||||||
@NotNull TypeEvalContext context) {
|
|
||||||
if (starPattern.getParent() instanceof PySequencePattern sequenceParent) {
|
|
||||||
final int idx = sequenceParent.getElements().indexOf(starPattern);
|
|
||||||
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
|
||||||
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
|
|
||||||
}
|
|
||||||
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", starPattern, context);
|
|
||||||
if (upcast instanceof PyCollectionType collectionType) {
|
|
||||||
return Collections.singletonList(collectionType.getIteratedItemType());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return List.of();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -0,0 +1,163 @@
|
|||||||
|
package com.jetbrains.python.psi.impl
|
||||||
|
|
||||||
|
import com.intellij.lang.ASTNode
|
||||||
|
import com.intellij.psi.PsiListLikeElement
|
||||||
|
import com.intellij.psi.util.findParentInFile
|
||||||
|
import com.jetbrains.python.psi.*
|
||||||
|
import com.jetbrains.python.psi.types.*
|
||||||
|
import com.jetbrains.python.psi.types.PyLiteralType.Companion.upcastLiteralToClass
|
||||||
|
import kotlin.contracts.ExperimentalContracts
|
||||||
|
import kotlin.contracts.contract
|
||||||
|
|
||||||
|
class PySequencePatternImpl(astNode: ASTNode?) : PyElementImpl(astNode), PySequencePattern, PsiListLikeElement {
|
||||||
|
override fun acceptPyVisitor(pyVisitor: PyElementVisitor) {
|
||||||
|
pyVisitor.visitPySequencePattern(this)
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getComponents(): List<PyPattern> = elements
|
||||||
|
|
||||||
|
override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType? {
|
||||||
|
val sequenceCaptureType = this.getSequenceCaptureType(context)
|
||||||
|
val types = elements.flatMap { pattern ->
|
||||||
|
when (pattern) {
|
||||||
|
is PySingleStarPattern -> pattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context)
|
||||||
|
else -> listOf(context.getType(pattern))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
val expectedType = when {
|
||||||
|
sequenceCaptureType.isHeterogeneousTuple() -> PyTupleType.create(this, types)
|
||||||
|
else -> wrapInSequenceType(PyUnionType.union(types))
|
||||||
|
}
|
||||||
|
|
||||||
|
if (sequenceCaptureType == null) return expectedType
|
||||||
|
return sequenceCaptureType.toList()
|
||||||
|
.map { if (PyTypeChecker.match(expectedType, it, context)) it else expectedType }
|
||||||
|
.let { PyUnionType.union(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun canExcludePatternType(context: TypeEvalContext): Boolean {
|
||||||
|
val allElementsCoverCapture = elements.all { PyClassPatternImpl.canExcludeArgumentPatternType(it, context) }
|
||||||
|
val allCapturesOfThisAreHeteroTuples = this.getSequenceCaptureType(context).toList().all { it.isHeterogeneousTuple() }
|
||||||
|
return allElementsCoverCapture && allCapturesOfThisAreHeteroTuples
|
||||||
|
}
|
||||||
|
|
||||||
|
override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? {
|
||||||
|
val sequenceType = this.getSequenceCaptureType(context) ?: return null
|
||||||
|
|
||||||
|
// This is done to skip group- and as-patterns
|
||||||
|
val sequenceMember = pattern.findParentInFile(withSelf = true) { el -> this === el.parent }
|
||||||
|
if (sequenceMember is PySingleStarPattern) {
|
||||||
|
return sequenceType.toList()
|
||||||
|
.flatMap { sequenceMember.getCapturedTypesFromSequenceType(it, context) }
|
||||||
|
.let { PyUnionType.union(it) }
|
||||||
|
.let { wrapInListType(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
val idx = elements.indexOf(sequenceMember)
|
||||||
|
|
||||||
|
return sequenceType.toList()
|
||||||
|
.map { getElementTypeSkippingStar(it, idx, context) }
|
||||||
|
.let { PyUnionType.union(it) }
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun getElementTypeSkippingStar(sequence: PyType?, idx: Int, context: TypeEvalContext): PyType? {
|
||||||
|
if (sequence.isHeterogeneousTuple()) {
|
||||||
|
val starIdx = elements.indexOfFirst { it is PySingleStarPattern }
|
||||||
|
if (starIdx == -1 || idx < starIdx) {
|
||||||
|
return sequence.getElementType(idx)
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
val starSpan = sequence.elementCount - this.elements.size
|
||||||
|
return sequence.getElementType(idx + starSpan)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
val upcast = PyTypeUtil.convertToType(sequence, "typing.Sequence", this, context)
|
||||||
|
return (upcast as? PyCollectionType)?.iteratedItemType
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Similar to [PyCaptureContext.getCaptureType],
|
||||||
|
* but only chooses types that would match to typing.Sequence, and have correct length
|
||||||
|
*/
|
||||||
|
private fun PySequencePattern.getSequenceCaptureType(context: TypeEvalContext): PyType? {
|
||||||
|
val captureTypes: PyType? = PyCaptureContext.getCaptureType(this, context)
|
||||||
|
|
||||||
|
val potentialMatchingTypes = captureTypes.toList()
|
||||||
|
.filter { it !is PyClassType || it.classQName !in listOf("str", "bytes", "bytearray") }
|
||||||
|
.filter { PyTypeUtil.convertToType(it, "typing.Sequence", this, context) != null }
|
||||||
|
|
||||||
|
val hasStar = elements.any { it is PySingleStarPattern }
|
||||||
|
val types = potentialMatchingTypes.mapNotNull {
|
||||||
|
if (it.isHeterogeneousTuple()) it.takeIfSizeMatches(elements.size, hasStar) else it
|
||||||
|
}
|
||||||
|
return PyUnionType.union(types)
|
||||||
|
}
|
||||||
|
|
||||||
|
fun wrapInListType(elementType: PyType?): PyType? {
|
||||||
|
val list = PyBuiltinCache.getInstance(this).getClass("list") ?: return null
|
||||||
|
return PyCollectionTypeImpl(list, false, listOf(upcastLiteralToClass(elementType)))
|
||||||
|
}
|
||||||
|
|
||||||
|
fun wrapInSequenceType(elementType: PyType?): PyType? {
|
||||||
|
val sequence = PyPsiFacade.getInstance(getProject()).createClassByQName("typing.Sequence", this) ?: return null
|
||||||
|
return PyCollectionTypeImpl(sequence, false, listOf(upcastLiteralToClass(elementType)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun PySingleStarPattern.getCapturedTypesFromSequenceType(sequenceType: PyType?, context: TypeEvalContext): List<PyType?> {
|
||||||
|
if (sequenceType.isHeterogeneousTuple()) {
|
||||||
|
val sequenceParent = this.parent as? PySequencePattern ?: return listOf()
|
||||||
|
val idx = sequenceParent.elements.indexOf(this)
|
||||||
|
return sequenceType.elementTypes.subList(idx, idx + sequenceType.elementCount - sequenceParent.elements.size + 1)
|
||||||
|
}
|
||||||
|
val upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context)
|
||||||
|
if (upcast is PyCollectionType) {
|
||||||
|
return listOf(upcast.getIteratedItemType())
|
||||||
|
}
|
||||||
|
return listOf()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Use it like PyTypeUtil#toStream
|
||||||
|
internal fun PyType?.toList(): List<PyType?> = if (this is PyUnionType) members.toList() else listOf(this)
|
||||||
|
|
||||||
|
@OptIn(ExperimentalContracts::class)
|
||||||
|
private fun PyType?.isHeterogeneousTuple(): Boolean {
|
||||||
|
contract { returns(true) implies (this@isHeterogeneousTuple is PyTupleType) }
|
||||||
|
return this is PyTupleType && !isHomogeneous
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun PyTupleType.takeIfSizeMatches(desiredSize: Int, hasStar: Boolean): PyTupleType? {
|
||||||
|
if (this.elementTypes.any { it is PyUnpackedTupleType }) {
|
||||||
|
val variadicElementsCount: Int = desiredSize - this.elementTypes.size + 1
|
||||||
|
if (variadicElementsCount >= 0) {
|
||||||
|
return this.expandVariadics(variadicElementsCount)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
if (hasStar && desiredSize <= this.elementCount || desiredSize == this.elementCount) {
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
private fun PyTupleType.expandVariadics(variadicElementCount: Int): PyTupleType {
|
||||||
|
require(!this.isHomogeneous) { "Supplied tuple must not be homogeneous: $this" }
|
||||||
|
require(variadicElementCount >= 0) { "Supplied variadic element count must not be negative: $variadicElementCount" }
|
||||||
|
|
||||||
|
val unpackedTupleIndex = elementTypes.indexOfFirst { it is PyUnpackedTupleType }
|
||||||
|
val unpackedTupleType = elementTypes[unpackedTupleIndex] as PyUnpackedTupleType
|
||||||
|
assert(unpackedTupleType.isUnbound)
|
||||||
|
|
||||||
|
val adjustedTupleElementTypes = buildList {
|
||||||
|
addAll(elementTypes.subList(0, unpackedTupleIndex))
|
||||||
|
repeat(variadicElementCount) {
|
||||||
|
add(unpackedTupleType.elementTypes[0])
|
||||||
|
}
|
||||||
|
addAll(elementTypes.subList(unpackedTupleIndex + 1, elementTypes.size))
|
||||||
|
}
|
||||||
|
return PyTupleType(this.pyClass, adjustedTupleElementTypes, false)
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user