[python] Refactor PyCapturePatternImpl. Introduce PyCaptureContext

(cherry picked from commit 2e3fbf4c7d79e6031c7c087e5c7e7e49046587fd)

IJ-MR-168826

GitOrigin-RevId: b87eda39543460451311fc875d6ae3722d671db0
This commit is contained in:
Aleksandr.Govenko
2025-07-01 18:21:26 +02:00
committed by intellij-monorepo-bot
parent 8cc52b8cf3
commit 0dfd1f65e4
13 changed files with 260 additions and 267 deletions

View File

@@ -4,7 +4,7 @@ package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstCaseClause;
import org.jetbrains.annotations.Nullable;
public interface PyCaseClause extends PyAstCaseClause, PyStatementPart {
public interface PyCaseClause extends PyAstCaseClause, PyStatementPart, PyCaptureContext {
@Override
default @Nullable PyPattern getPattern() {
return (PyPattern)PyAstCaseClause.super.getPattern();

View File

@@ -4,7 +4,12 @@ package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstClassPattern;
import org.jetbrains.annotations.NotNull;
public interface PyClassPattern extends PyAstClassPattern, PyPattern {
import java.util.Set;
public interface PyClassPattern extends PyAstClassPattern, PyPattern, PyCaptureContext {
Set<String> SPECIAL_BUILTINS = Set.of(
"bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple");
@Override
default @NotNull PyReferenceExpression getClassNameReference() {
return (PyReferenceExpression)PyAstClassPattern.super.getClassNameReference();

View File

@@ -3,5 +3,5 @@ package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstMappingPattern;
public interface PyMappingPattern extends PyAstMappingPattern, PyPattern {
public interface PyMappingPattern extends PyAstMappingPattern, PyPattern, PyCaptureContext {
}

View File

@@ -1,43 +1,77 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.psi;
package com.jetbrains.python.psi
import com.jetbrains.python.ast.PyAstPattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import com.intellij.psi.util.findParentOfType
import com.jetbrains.python.ast.PyAstPattern
import com.jetbrains.python.psi.types.PyType
import com.jetbrains.python.psi.types.TypeEvalContext
import org.jetbrains.annotations.ApiStatus
public interface PyPattern extends PyAstPattern, PyTypedElement {
interface PyPattern : PyAstPattern, PyTypedElement {
/**
* Returns the type that would be captured by this pattern when matching.
* <p>
*
*
* Unlike other PyTypedElements where getType returns their own type, pattern's getType
* returns the type that would result from a successful match. For example:
*
* <pre>{@code
* ```python
* class Plant: pass
* class Animal: pass
* class Dog(Animal): pass
*
* x: Dog | Plant
* match x:
* case Animal():
* # getType returns Dog here, even though the pattern is Animal()
* }</pre>
* case Animal():
* # getType returns Dog here, even though the pattern is Animal()
* ```
*
* @see PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)
* @see PyCapturePatternImpl.getCaptureType
*/
@Override
@Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key);
override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType?
/**
* Decides if the set of values described by a pattern is suitable
* Decides if the set of values described by a pattern is suitable
* to be subtracted (excluded) from a subject type on the negative edge,
* or if this pattern is too specific.
*/
@ApiStatus.Experimental
default boolean canExcludePatternType(@NotNull TypeEvalContext context) {
return true;
fun canExcludePatternType(context: TypeEvalContext): Boolean {
return true
}
}
interface PyCaptureContext : PyElement {
fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType?
companion object {
/**
* Determines what type this pattern would have if it was a capture pattern (like a bare name or _).
*
* In pattern matching, a capture pattern takes on the type of the entire matched expression,
* regardless of any specific pattern constraints.
*
* For example:
* ```python
* x: int | str
* match x:
* case a: # This is a capture pattern
* # Here 'a' has type int | str
* case str(): # This is a class pattern
* # Capture type: int | str (same as what 'case a:' would get)
* # Regular getType: str
*
* y: int
* match y:
* case str() as a:
* # Capture type: int (same as what 'case a:' would get)
* # Regular getType: intersect(int, str) (just 'str' for now)
* ```
* @see PyPattern#getType(TypeEvalContext, TypeEvalContext.Key)
*/
@JvmStatic
fun getCaptureType(pattern: PyPattern, context: TypeEvalContext): PyType? {
return pattern.findParentOfType<PyCaptureContext>()?.getCaptureTypeForChild(pattern, context)
}
}
}

View File

@@ -8,7 +8,7 @@ import java.util.List;
import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass;
public interface PySequencePattern extends PyAstSequencePattern, PyPattern {
public interface PySequencePattern extends PyAstSequencePattern, PyPattern, PyCaptureContext {
default @NotNull List<@NotNull PyPattern> getElements() {
return List.of(findChildrenByClass(this, PyPattern.class));
}

View File

@@ -2,12 +2,8 @@
package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstSingleStarPattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import java.util.Objects;
import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass;
@@ -17,6 +13,4 @@ public interface PySingleStarPattern extends PyAstSingleStarPattern, PyPattern {
default PyPattern getPattern() {
return Objects.requireNonNull(findChildByClass(this, PyPattern.class));
}
@NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context);
}

View File

@@ -1,25 +1,14 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.psi.PyCaptureContext;
import com.jetbrains.python.psi.PyCapturePattern;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import java.util.Set;
import static com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator.createAssertionType;
import static com.jetbrains.python.psi.PyUtil.as;
import static com.jetbrains.python.psi.impl.PySequencePatternImpl.wrapInListType;
public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePattern {
public PyCapturePatternImpl(ASTNode astNode) {
super(astNode);
@@ -32,187 +21,6 @@ public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePatt
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return getCaptureType(this, context);
}
static Set<String> SPECIAL_BUILTINS = Set.of(
"bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple");
/**
* Determines what type this pattern would have if it was a capture pattern (like a bare name or _).
* <p>
* In pattern matching, a capture pattern takes on the type of the entire matched expression,
* regardless of any specific pattern constraints.
* <p>
* For example:
* <pre>{@code
* x: int | str
* match x:
* case a: # This is a capture pattern
* # Here 'a' has type int | str
* case str(): # This is a class pattern
* # Capture type: int | str (same as what 'case a:' would get)
* # Regular getType: str
*
* y: int
* match y:
* case str() as a:
* # Capture type: int (same as what 'case a:' would get)
* # Regular getType: intersect(int, str) (just 'str' for now)
* }</pre>
* @see PyPattern#getType(TypeEvalContext, TypeEvalContext.Key)
*/
static @Nullable PyType getCaptureType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
final PyElement parentPattern = PsiTreeUtil.getParentOfType(
pattern, // Capture corresponds to:
PyCaseClause.class, // - Subject of a match statement
PySingleStarPattern.class, // - Type of parent sequence pattern
PyDoubleStarPattern.class, // - Type of parent mapping pattern
PyKeyValuePattern.class, // - Value type of parent mapping pattern
PySequencePattern.class, // - Iterated item type of sequence
PyClassPattern.class, // - Attribute type in the corresponding class
PyKeywordPattern.class // - Attribute type in the corresponding class
);
if (parentPattern instanceof PyCaseClause caseClause) {
final PyMatchStatement matchStatement = as(caseClause.getParent(), PyMatchStatement.class);
if (matchStatement == null) return null;
final PyExpression subject = matchStatement.getSubject();
if (subject == null) return null;
PyType subjectType = context.getType(subject);
for (PyCaseClause cs : matchStatement.getCaseClauses()) {
if (cs == caseClause) break;
if (cs.getPattern() == null) continue;
if (cs.getGuardCondition() != null && !PyEvaluator.evaluateAsBoolean(cs.getGuardCondition(), false)) continue;
if (cs.getPattern().canExcludePatternType(context)) {
subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, context));
}
}
return subjectType;
}
if (parentPattern instanceof PySingleStarPattern starPattern) {
final PySequencePattern sequenceParent = as(starPattern.getParent(), PySequencePattern.class);
if (sequenceParent == null) return null;
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequenceParent, context);
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
.flatMap(it -> starPattern.getCapturedTypesFromSequenceType(it, context).stream()).collect(PyTypeUtil.toUnion());
return wrapInListType(iteratedType, pattern);
}
if (parentPattern instanceof PyDoubleStarPattern) {
final PyMappingPattern mappingParent = as(parentPattern.getParent(), PyMappingPattern.class);
if (mappingParent == null) return null;
var mappingType = PyTypeUtil.convertToType(context.getType(mappingParent), "typing.Mapping", pattern, context);
if (mappingType instanceof PyCollectionType collectionType) {
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
}
return null;
}
if (parentPattern instanceof PyKeyValuePattern keyValuePattern) {
final PyMappingPattern mappingParent = as(keyValuePattern.getParent(), PyMappingPattern.class);
if (mappingParent == null) return null;
return PyTypeUtil.toStream(getCaptureType(mappingParent, context)).map(type -> {
if (type instanceof PyTypedDictType typedDictType) {
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l &&
l.getExpression() instanceof PyStringLiteralExpression str) {
return typedDictType.getElementType(str.getStringValue());
}
}
PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context);
if (mappingType == null) {
return PyNeverType.NEVER;
}
else if (mappingType instanceof PyCollectionType collectionType) {
return collectionType.getElementTypes().get(1);
}
return null;
}).collect(PyTypeUtil.toUnion());
}
if (parentPattern instanceof PySequencePattern sequencePattern) {
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequencePattern, context);
if (sequenceType == null) return null;
// This is done to skip group- and as-patterns
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> el.getParent() == sequencePattern);
final List<PyPattern> elements = sequencePattern.getElements();
final int idx = elements.indexOf(sequenceMember);
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
return PyTypeUtil.toStream(sequenceType).map(it -> {
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
if (starIdx == -1 || idx < starIdx) {
return tupleType.getElementType(idx);
}
else {
final int starSpan = tupleType.getElementCount() - elements.size();
return tupleType.getElementType(idx + starSpan);
}
}
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
if (upcast instanceof PyCollectionType collectionType) {
return collectionType.getIteratedItemType();
}
return null;
}).collect(PyTypeUtil.toUnion());
}
if (parentPattern instanceof PyClassPattern classPattern) {
final List<PyPattern> arguments = classPattern.getArgumentList().getPatterns();
int index = arguments.indexOf(pattern);
if (index < 0) return null;
// capture type can be a union like: list[int] | list[str]
return PyTypeUtil.toStream(context.getType(classPattern)).map(type -> {
if (type instanceof PyClassType classType) {
if (SPECIAL_BUILTINS.contains(classType.getClassQName())) {
if (index == 0) {
return classType;
}
return null;
}
List<String> matchArgs = getMatchArgs(classType, context);
if (matchArgs == null || matchArgs.size() > arguments.size()) return null;
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, matchArgs.get(index), context), PyTypedElement.class);
if (instanceAttribute == null) return null;
return PyTypeChecker.substitute(context.getType(instanceAttribute), PyTypeChecker.unifyReceiver(classType, context), context);
}
return null;
}).collect(PyTypeUtil.toUnion());
}
if (parentPattern instanceof PyKeywordPattern keywordPattern) {
final PyClassPattern classPattern = PsiTreeUtil.getParentOfType(keywordPattern, PyClassPattern.class);
if (classPattern == null) return null;
if (context.getType(classPattern) instanceof PyClassType classType) {
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyTypedElement.class);
if (instanceAttribute == null) return null;
return context.getType(instanceAttribute);
}
return null;
}
return null;
}
static @Nullable List<@NotNull String> getMatchArgs(@NotNull PyClassType type, @NotNull TypeEvalContext context) {
final PyClass cls = type.getPyClass();
List<String> matchArgs = cls.getOwnMatchArgs();
if (matchArgs == null) {
matchArgs = PyDataclassTypeProvider.Companion.getGeneratedMatchArgs(cls, context);
}
return matchArgs;
}
@Nullable
static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) {
final PyResolveContext resolveContext = PyResolveContext.defaultContext(context);
final List<? extends RatedResolveResult> results = type.resolveMember(name, null, AccessDirection.READ, resolveContext);
return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null;
return PyCaptureContext.getCaptureType(this, context);
}
}

View File

@@ -1,16 +1,35 @@
package com.jetbrains.python.psi.impl;
package com.jetbrains.python.psi.impl
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyCaseClause;
import com.jetbrains.python.psi.PyElementVisitor;
import com.intellij.lang.ASTNode
import com.intellij.openapi.util.Ref
import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator
import com.jetbrains.python.psi.PyCaseClause
import com.jetbrains.python.psi.PyElementVisitor
import com.jetbrains.python.psi.PyMatchStatement
import com.jetbrains.python.psi.PyPattern
import com.jetbrains.python.psi.types.PyType
import com.jetbrains.python.psi.types.TypeEvalContext
public class PyCaseClauseImpl extends PyElementImpl implements PyCaseClause {
public PyCaseClauseImpl(ASTNode astNode) {
super(astNode);
class PyCaseClauseImpl(astNode: ASTNode?) : PyElementImpl(astNode), PyCaseClause {
override fun acceptPyVisitor(pyVisitor: PyElementVisitor) {
pyVisitor.visitPyCaseClause(this)
}
@Override
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyCaseClause(this);
override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? {
val matchStatement = getParent() as? PyMatchStatement ?: return null
val subject = matchStatement.subject ?: return null
var subjectType = context.getType(subject)
for (cs in matchStatement.caseClauses) {
if (cs === this) break
if (cs.pattern == null) continue
if (cs.guardCondition != null && !PyEvaluator.evaluateAsBoolean(cs.guardCondition, false)) continue
if (cs.pattern!!.canExcludePatternType(context)) {
subjectType = Ref.deref(
PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.pattern!!), false, context))
}
}
return subjectType
}
}

View File

@@ -2,8 +2,14 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator;
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -11,8 +17,8 @@ import org.jetbrains.annotations.Nullable;
import java.util.List;
import static com.intellij.util.containers.ContainerUtil.getFirstItem;
import static com.jetbrains.python.psi.PyUtil.as;
import static com.jetbrains.python.psi.PyUtil.multiResolveTopPriority;
import static com.jetbrains.python.psi.impl.PyCapturePatternImpl.*;
public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern {
public PyClassPatternImpl(ASTNode astNode) {
@@ -29,7 +35,7 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
final PyType type = context.getType(getClassNameReference());
if (type instanceof PyClassType classType) {
final PyType instanceType = classType.toInstance();
final PyType captureType = PyCapturePatternImpl.getCaptureType(this, context);
final PyType captureType = PyCaptureContext.getCaptureType(this, context);
return Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, instanceType, true, context));
}
return null;
@@ -75,13 +81,53 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
}
}
}
final PyType captureType = getCaptureType(this, context);
final PyType captureType = PyCaptureContext.getCaptureType(this, context);
final PyType patternType = context.getType(this);
return PyTypeUtil.toStream(captureType).anyMatch(it -> PyTypeChecker.match(patternType, it, context));
}
@Override
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
pattern = as(PsiTreeUtil.findFirstParent(pattern, el -> this.getArgumentList() == el.getParent()), PyPattern.class);
if (pattern == null) return null;
if (pattern instanceof PyKeywordPattern keywordPattern) {
if (context.getType(this) instanceof PyClassType classType) {
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyTypedElement.class);
if (instanceAttribute == null) return null;
return context.getType(instanceAttribute);
}
return null;
}
final List<PyPattern> arguments = getArgumentList().getPatterns();
int index = arguments.indexOf(pattern);
if (index < 0) return null;
// capture type can be a union like: list[int] | list[str]
return PyTypeUtil.toStream(context.getType(this)).map(type -> {
if (type instanceof PyClassType classType) {
if (SPECIAL_BUILTINS.contains(classType.getClassQName())) {
if (index == 0) {
return classType;
}
return null;
}
List<String> matchArgs = getMatchArgs(classType, context);
if (matchArgs == null || matchArgs.size() > arguments.size()) return null;
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, matchArgs.get(index), context), PyTypedElement.class);
if (instanceAttribute == null) return null;
return PyTypeChecker.substitute(context.getType(instanceAttribute), PyTypeChecker.unifyReceiver(classType, context), context);
}
return null;
}).collect(PyTypeUtil.toUnion());
}
static boolean canExcludeArgumentPatternType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
final var captureType = getCaptureType(pattern, context);
final var captureType = PyCaptureContext.getCaptureType(pattern, context);
final var patternType = context.getType(pattern);
// For class pattern arguments, we need to ensure that the argument pattern covers its capture type fully
if (Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, patternType, false, context)) instanceof PyNeverType) {
@@ -90,4 +136,20 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
}
return false;
}
private static @Nullable List<@NotNull String> getMatchArgs(@NotNull PyClassType type, @NotNull TypeEvalContext context) {
final PyClass cls = type.getPyClass();
List<String> matchArgs = cls.getOwnMatchArgs();
if (matchArgs == null) {
matchArgs = PyDataclassTypeProvider.Companion.getGeneratedMatchArgs(cls, context);
}
return matchArgs;
}
@Nullable
static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) {
final PyResolveContext resolveContext = PyResolveContext.defaultContext(context);
final List<? extends RatedResolveResult> results = type.resolveMember(name, null, AccessDirection.READ, resolveContext);
return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null;
}
}

View File

@@ -3,6 +3,7 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiListLikeElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
@@ -45,7 +46,7 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt
PyType patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this);
PyType captureTypes = PyCapturePatternImpl.getCaptureType(this, context);
PyType captureTypes = PyCaptureContext.getCaptureType(this, context);
PyType filteredType = PyTypeUtil.toStream(captureTypes).filter(captureType -> {
var mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context);
if (mappingType == null) return false;
@@ -56,6 +57,39 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt
return filteredType == null ? patternMappingType : filteredType;
}
@Override
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
if (sequenceMember instanceof PyDoubleStarPattern) {
var mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context);
if (mappingType instanceof PyCollectionType collectionType) {
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
}
return null;
}
else if (sequenceMember instanceof PyKeyValuePattern keyValuePattern) {
return PyTypeUtil.toStream(PyCaptureContext.getCaptureType(this, context)).map(type -> {
if (type instanceof PyTypedDictType typedDictType) {
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l &&
l.getExpression() instanceof PyStringLiteralExpression str) {
return typedDictType.getElementType(str.getStringValue());
}
}
PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context);
if (mappingType == null) {
return PyNeverType.NEVER;
}
else if (mappingType instanceof PyCollectionType collectionType) {
return collectionType.getElementTypes().get(1);
}
return null;
}).collect(PyTypeUtil.toUnion());
}
return null;
}
private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) {
keyType = PyLiteralType.upcastLiteralToClass(keyType);
valueType = PyLiteralType.upcastLiteralToClass(valueType);

View File

@@ -3,6 +3,7 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiListLikeElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.ArrayUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.psi.*;
@@ -39,7 +40,7 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
final ArrayList<PyType> types = new ArrayList<>();
for (PyPattern pattern : getElements()) {
if (pattern instanceof PySingleStarPattern starPattern) {
types.addAll(starPattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context));
types.addAll(getCapturedTypesFromSequenceType(starPattern, sequenceCaptureType, context));
}
else {
types.add(context.getType(pattern));
@@ -74,6 +75,40 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
return true;
}
@Override
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
final PyType sequenceType = getSequenceCaptureType(this, context);
if (sequenceType == null) return null;
// This is done to skip group- and as-patterns
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
if (sequenceMember instanceof PySingleStarPattern starPattern) {
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
.flatMap(it -> getCapturedTypesFromSequenceType(starPattern, it, context).stream()).collect(PyTypeUtil.toUnion());
return wrapInListType(iteratedType, pattern);
}
final List<PyPattern> elements = getElements();
final int idx = elements.indexOf(sequenceMember);
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
return PyTypeUtil.toStream(sequenceType).map(it -> {
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
if (starIdx == -1 || idx < starIdx) {
return tupleType.getElementType(idx);
}
else {
final int starSpan = tupleType.getElementCount() - elements.size();
return tupleType.getElementType(idx + starSpan);
}
}
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
if (upcast instanceof PyCollectionType collectionType) {
return collectionType.getIteratedItemType();
}
return null;
}).collect(PyTypeUtil.toUnion());
}
static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list");
return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
@@ -86,12 +121,12 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
}
/**
* Similar to {@link PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)},
* Similar to {@link PyCaptureContext#getCaptureType(PyPattern, TypeEvalContext)},
* but only chooses types that would match to typing.Sequence, and have correct length
*/
@Nullable
static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) {
final PyType captureTypes = PyCapturePatternImpl.getCaptureType(pattern, context);
final PyType captureTypes = PyCaptureContext.getCaptureType(pattern, context);
final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern);
List<PyType> types = new ArrayList<>();
@@ -131,4 +166,20 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
}
return PyUnionType.union(types);
}
private static @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@NotNull PySingleStarPattern starPattern,
@Nullable PyType sequenceType,
@NotNull TypeEvalContext context) {
if (starPattern.getParent() instanceof PySequencePattern sequenceParent) {
final int idx = sequenceParent.getElements().indexOf(starPattern);
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
}
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", starPattern, context);
if (upcast instanceof PyCollectionType collectionType) {
return Collections.singletonList(collectionType.getIteratedItemType());
}
}
return List.of();
}
}

View File

@@ -1,14 +1,14 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyInstantTypeProvider;
import com.jetbrains.python.psi.PySingleStarPattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.Collections;
import java.util.List;
public class PySingleStarPatternImpl extends PyElementImpl implements PySingleStarPattern, PyInstantTypeProvider {
public PySingleStarPatternImpl(ASTNode astNode) {
super(astNode);
@@ -23,19 +23,4 @@ public class PySingleStarPatternImpl extends PyElementImpl implements PySingleSt
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return null;
}
@Override
public @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context) {
if (getParent() instanceof PySequencePattern sequenceParent) {
final int idx = sequenceParent.getElements().indexOf(this);
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
}
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context);
if (upcast instanceof PyCollectionType collectionType) {
return Collections.singletonList(collectionType.getIteratedItemType());
}
}
return List.of();
}
}

View File

@@ -1,6 +1,7 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyCaptureContext;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyWildcardPattern;
import com.jetbrains.python.psi.types.PyType;
@@ -20,6 +21,6 @@ public class PyWildcardPatternImpl extends PyElementImpl implements PyWildcardPa
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return PyCapturePatternImpl.getCaptureType(this, context);
return PyCaptureContext.getCaptureType(this, context);
}
}