mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-14 18:05:27 +07:00
[python] Refactor PyCapturePatternImpl. Introduce PyCaptureContext
(cherry picked from commit 2e3fbf4c7d79e6031c7c087e5c7e7e49046587fd) IJ-MR-168826 GitOrigin-RevId: b87eda39543460451311fc875d6ae3722d671db0
This commit is contained in:
committed by
intellij-monorepo-bot
parent
8cc52b8cf3
commit
0dfd1f65e4
@@ -4,7 +4,7 @@ package com.jetbrains.python.psi;
|
||||
import com.jetbrains.python.ast.PyAstCaseClause;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public interface PyCaseClause extends PyAstCaseClause, PyStatementPart {
|
||||
public interface PyCaseClause extends PyAstCaseClause, PyStatementPart, PyCaptureContext {
|
||||
@Override
|
||||
default @Nullable PyPattern getPattern() {
|
||||
return (PyPattern)PyAstCaseClause.super.getPattern();
|
||||
|
||||
@@ -4,7 +4,12 @@ package com.jetbrains.python.psi;
|
||||
import com.jetbrains.python.ast.PyAstClassPattern;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
public interface PyClassPattern extends PyAstClassPattern, PyPattern {
|
||||
import java.util.Set;
|
||||
|
||||
public interface PyClassPattern extends PyAstClassPattern, PyPattern, PyCaptureContext {
|
||||
Set<String> SPECIAL_BUILTINS = Set.of(
|
||||
"bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple");
|
||||
|
||||
@Override
|
||||
default @NotNull PyReferenceExpression getClassNameReference() {
|
||||
return (PyReferenceExpression)PyAstClassPattern.super.getClassNameReference();
|
||||
|
||||
@@ -3,5 +3,5 @@ package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstMappingPattern;
|
||||
|
||||
public interface PyMappingPattern extends PyAstMappingPattern, PyPattern {
|
||||
public interface PyMappingPattern extends PyAstMappingPattern, PyPattern, PyCaptureContext {
|
||||
}
|
||||
|
||||
@@ -1,43 +1,77 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.python.psi;
|
||||
package com.jetbrains.python.psi
|
||||
|
||||
import com.jetbrains.python.ast.PyAstPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.ApiStatus;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
import com.intellij.psi.util.findParentOfType
|
||||
import com.jetbrains.python.ast.PyAstPattern
|
||||
import com.jetbrains.python.psi.types.PyType
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
public interface PyPattern extends PyAstPattern, PyTypedElement {
|
||||
interface PyPattern : PyAstPattern, PyTypedElement {
|
||||
/**
|
||||
* Returns the type that would be captured by this pattern when matching.
|
||||
* <p>
|
||||
*
|
||||
*
|
||||
* Unlike other PyTypedElements where getType returns their own type, pattern's getType
|
||||
* returns the type that would result from a successful match. For example:
|
||||
*
|
||||
* <pre>{@code
|
||||
* ```python
|
||||
* class Plant: pass
|
||||
* class Animal: pass
|
||||
* class Dog(Animal): pass
|
||||
*
|
||||
* x: Dog | Plant
|
||||
* match x:
|
||||
* case Animal():
|
||||
* # getType returns Dog here, even though the pattern is Animal()
|
||||
* }</pre>
|
||||
* case Animal():
|
||||
* # getType returns Dog here, even though the pattern is Animal()
|
||||
* ```
|
||||
*
|
||||
* @see PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)
|
||||
* @see PyCapturePatternImpl.getCaptureType
|
||||
*/
|
||||
@Override
|
||||
@Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key);
|
||||
override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType?
|
||||
|
||||
/**
|
||||
* Decides if the set of values described by a pattern is suitable
|
||||
* Decides if the set of values described by a pattern is suitable
|
||||
* to be subtracted (excluded) from a subject type on the negative edge,
|
||||
* or if this pattern is too specific.
|
||||
*/
|
||||
@ApiStatus.Experimental
|
||||
default boolean canExcludePatternType(@NotNull TypeEvalContext context) {
|
||||
return true;
|
||||
fun canExcludePatternType(context: TypeEvalContext): Boolean {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
interface PyCaptureContext : PyElement {
|
||||
fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType?
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* Determines what type this pattern would have if it was a capture pattern (like a bare name or _).
|
||||
*
|
||||
* In pattern matching, a capture pattern takes on the type of the entire matched expression,
|
||||
* regardless of any specific pattern constraints.
|
||||
*
|
||||
* For example:
|
||||
* ```python
|
||||
* x: int | str
|
||||
* match x:
|
||||
* case a: # This is a capture pattern
|
||||
* # Here 'a' has type int | str
|
||||
* case str(): # This is a class pattern
|
||||
* # Capture type: int | str (same as what 'case a:' would get)
|
||||
* # Regular getType: str
|
||||
*
|
||||
* y: int
|
||||
* match y:
|
||||
* case str() as a:
|
||||
* # Capture type: int (same as what 'case a:' would get)
|
||||
* # Regular getType: intersect(int, str) (just 'str' for now)
|
||||
* ```
|
||||
* @see PyPattern#getType(TypeEvalContext, TypeEvalContext.Key)
|
||||
*/
|
||||
@JvmStatic
|
||||
fun getCaptureType(pattern: PyPattern, context: TypeEvalContext): PyType? {
|
||||
return pattern.findParentOfType<PyCaptureContext>()?.getCaptureTypeForChild(pattern, context)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,7 +8,7 @@ import java.util.List;
|
||||
|
||||
import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass;
|
||||
|
||||
public interface PySequencePattern extends PyAstSequencePattern, PyPattern {
|
||||
public interface PySequencePattern extends PyAstSequencePattern, PyPattern, PyCaptureContext {
|
||||
default @NotNull List<@NotNull PyPattern> getElements() {
|
||||
return List.of(findChildrenByClass(this, PyPattern.class));
|
||||
}
|
||||
|
||||
@@ -2,12 +2,8 @@
|
||||
package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstSingleStarPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass;
|
||||
@@ -17,6 +13,4 @@ public interface PySingleStarPattern extends PyAstSingleStarPattern, PyPattern {
|
||||
default PyPattern getPattern() {
|
||||
return Objects.requireNonNull(findChildByClass(this, PyPattern.class));
|
||||
}
|
||||
|
||||
@NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context);
|
||||
}
|
||||
|
||||
@@ -1,25 +1,14 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.openapi.util.Ref;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext;
|
||||
import com.jetbrains.python.psi.resolve.RatedResolveResult;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import com.jetbrains.python.psi.PyCaptureContext;
|
||||
import com.jetbrains.python.psi.PyCapturePattern;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator.createAssertionType;
|
||||
import static com.jetbrains.python.psi.PyUtil.as;
|
||||
import static com.jetbrains.python.psi.impl.PySequencePatternImpl.wrapInListType;
|
||||
|
||||
public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePattern {
|
||||
public PyCapturePatternImpl(ASTNode astNode) {
|
||||
super(astNode);
|
||||
@@ -32,187 +21,6 @@ public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePatt
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return getCaptureType(this, context);
|
||||
}
|
||||
|
||||
static Set<String> SPECIAL_BUILTINS = Set.of(
|
||||
"bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple");
|
||||
|
||||
/**
|
||||
* Determines what type this pattern would have if it was a capture pattern (like a bare name or _).
|
||||
* <p>
|
||||
* In pattern matching, a capture pattern takes on the type of the entire matched expression,
|
||||
* regardless of any specific pattern constraints.
|
||||
* <p>
|
||||
* For example:
|
||||
* <pre>{@code
|
||||
* x: int | str
|
||||
* match x:
|
||||
* case a: # This is a capture pattern
|
||||
* # Here 'a' has type int | str
|
||||
* case str(): # This is a class pattern
|
||||
* # Capture type: int | str (same as what 'case a:' would get)
|
||||
* # Regular getType: str
|
||||
*
|
||||
* y: int
|
||||
* match y:
|
||||
* case str() as a:
|
||||
* # Capture type: int (same as what 'case a:' would get)
|
||||
* # Regular getType: intersect(int, str) (just 'str' for now)
|
||||
* }</pre>
|
||||
* @see PyPattern#getType(TypeEvalContext, TypeEvalContext.Key)
|
||||
*/
|
||||
static @Nullable PyType getCaptureType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
final PyElement parentPattern = PsiTreeUtil.getParentOfType(
|
||||
pattern, // Capture corresponds to:
|
||||
PyCaseClause.class, // - Subject of a match statement
|
||||
PySingleStarPattern.class, // - Type of parent sequence pattern
|
||||
PyDoubleStarPattern.class, // - Type of parent mapping pattern
|
||||
PyKeyValuePattern.class, // - Value type of parent mapping pattern
|
||||
PySequencePattern.class, // - Iterated item type of sequence
|
||||
PyClassPattern.class, // - Attribute type in the corresponding class
|
||||
PyKeywordPattern.class // - Attribute type in the corresponding class
|
||||
);
|
||||
|
||||
if (parentPattern instanceof PyCaseClause caseClause) {
|
||||
final PyMatchStatement matchStatement = as(caseClause.getParent(), PyMatchStatement.class);
|
||||
if (matchStatement == null) return null;
|
||||
|
||||
final PyExpression subject = matchStatement.getSubject();
|
||||
if (subject == null) return null;
|
||||
|
||||
PyType subjectType = context.getType(subject);
|
||||
for (PyCaseClause cs : matchStatement.getCaseClauses()) {
|
||||
if (cs == caseClause) break;
|
||||
if (cs.getPattern() == null) continue;
|
||||
if (cs.getGuardCondition() != null && !PyEvaluator.evaluateAsBoolean(cs.getGuardCondition(), false)) continue;
|
||||
if (cs.getPattern().canExcludePatternType(context)) {
|
||||
subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, context));
|
||||
}
|
||||
}
|
||||
|
||||
return subjectType;
|
||||
}
|
||||
if (parentPattern instanceof PySingleStarPattern starPattern) {
|
||||
final PySequencePattern sequenceParent = as(starPattern.getParent(), PySequencePattern.class);
|
||||
if (sequenceParent == null) return null;
|
||||
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequenceParent, context);
|
||||
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
|
||||
.flatMap(it -> starPattern.getCapturedTypesFromSequenceType(it, context).stream()).collect(PyTypeUtil.toUnion());
|
||||
return wrapInListType(iteratedType, pattern);
|
||||
}
|
||||
if (parentPattern instanceof PyDoubleStarPattern) {
|
||||
final PyMappingPattern mappingParent = as(parentPattern.getParent(), PyMappingPattern.class);
|
||||
if (mappingParent == null) return null;
|
||||
var mappingType = PyTypeUtil.convertToType(context.getType(mappingParent), "typing.Mapping", pattern, context);
|
||||
if (mappingType instanceof PyCollectionType collectionType) {
|
||||
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
|
||||
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (parentPattern instanceof PyKeyValuePattern keyValuePattern) {
|
||||
final PyMappingPattern mappingParent = as(keyValuePattern.getParent(), PyMappingPattern.class);
|
||||
if (mappingParent == null) return null;
|
||||
|
||||
return PyTypeUtil.toStream(getCaptureType(mappingParent, context)).map(type -> {
|
||||
if (type instanceof PyTypedDictType typedDictType) {
|
||||
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l &&
|
||||
l.getExpression() instanceof PyStringLiteralExpression str) {
|
||||
return typedDictType.getElementType(str.getStringValue());
|
||||
}
|
||||
}
|
||||
|
||||
PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context);
|
||||
if (mappingType == null) {
|
||||
return PyNeverType.NEVER;
|
||||
}
|
||||
else if (mappingType instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getElementTypes().get(1);
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
if (parentPattern instanceof PySequencePattern sequencePattern) {
|
||||
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequencePattern, context);
|
||||
if (sequenceType == null) return null;
|
||||
|
||||
// This is done to skip group- and as-patterns
|
||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> el.getParent() == sequencePattern);
|
||||
final List<PyPattern> elements = sequencePattern.getElements();
|
||||
final int idx = elements.indexOf(sequenceMember);
|
||||
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
|
||||
|
||||
return PyTypeUtil.toStream(sequenceType).map(it -> {
|
||||
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
if (starIdx == -1 || idx < starIdx) {
|
||||
return tupleType.getElementType(idx);
|
||||
}
|
||||
else {
|
||||
final int starSpan = tupleType.getElementCount() - elements.size();
|
||||
return tupleType.getElementType(idx + starSpan);
|
||||
}
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getIteratedItemType();
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
if (parentPattern instanceof PyClassPattern classPattern) {
|
||||
final List<PyPattern> arguments = classPattern.getArgumentList().getPatterns();
|
||||
int index = arguments.indexOf(pattern);
|
||||
if (index < 0) return null;
|
||||
|
||||
// capture type can be a union like: list[int] | list[str]
|
||||
return PyTypeUtil.toStream(context.getType(classPattern)).map(type -> {
|
||||
if (type instanceof PyClassType classType) {
|
||||
if (SPECIAL_BUILTINS.contains(classType.getClassQName())) {
|
||||
if (index == 0) {
|
||||
return classType;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
List<String> matchArgs = getMatchArgs(classType, context);
|
||||
if (matchArgs == null || matchArgs.size() > arguments.size()) return null;
|
||||
|
||||
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, matchArgs.get(index), context), PyTypedElement.class);
|
||||
if (instanceAttribute == null) return null;
|
||||
|
||||
return PyTypeChecker.substitute(context.getType(instanceAttribute), PyTypeChecker.unifyReceiver(classType, context), context);
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
if (parentPattern instanceof PyKeywordPattern keywordPattern) {
|
||||
final PyClassPattern classPattern = PsiTreeUtil.getParentOfType(keywordPattern, PyClassPattern.class);
|
||||
if (classPattern == null) return null;
|
||||
|
||||
if (context.getType(classPattern) instanceof PyClassType classType) {
|
||||
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyTypedElement.class);
|
||||
if (instanceAttribute == null) return null;
|
||||
return context.getType(instanceAttribute);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
static @Nullable List<@NotNull String> getMatchArgs(@NotNull PyClassType type, @NotNull TypeEvalContext context) {
|
||||
final PyClass cls = type.getPyClass();
|
||||
List<String> matchArgs = cls.getOwnMatchArgs();
|
||||
if (matchArgs == null) {
|
||||
matchArgs = PyDataclassTypeProvider.Companion.getGeneratedMatchArgs(cls, context);
|
||||
}
|
||||
return matchArgs;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) {
|
||||
final PyResolveContext resolveContext = PyResolveContext.defaultContext(context);
|
||||
final List<? extends RatedResolveResult> results = type.resolveMember(name, null, AccessDirection.READ, resolveContext);
|
||||
return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null;
|
||||
return PyCaptureContext.getCaptureType(this, context);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,16 +1,35 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
package com.jetbrains.python.psi.impl
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyCaseClause;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.intellij.lang.ASTNode
|
||||
import com.intellij.openapi.util.Ref
|
||||
import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator
|
||||
import com.jetbrains.python.psi.PyCaseClause
|
||||
import com.jetbrains.python.psi.PyElementVisitor
|
||||
import com.jetbrains.python.psi.PyMatchStatement
|
||||
import com.jetbrains.python.psi.PyPattern
|
||||
import com.jetbrains.python.psi.types.PyType
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext
|
||||
|
||||
public class PyCaseClauseImpl extends PyElementImpl implements PyCaseClause {
|
||||
public PyCaseClauseImpl(ASTNode astNode) {
|
||||
super(astNode);
|
||||
class PyCaseClauseImpl(astNode: ASTNode?) : PyElementImpl(astNode), PyCaseClause {
|
||||
override fun acceptPyVisitor(pyVisitor: PyElementVisitor) {
|
||||
pyVisitor.visitPyCaseClause(this)
|
||||
}
|
||||
|
||||
@Override
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyCaseClause(this);
|
||||
override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? {
|
||||
val matchStatement = getParent() as? PyMatchStatement ?: return null
|
||||
val subject = matchStatement.subject ?: return null
|
||||
|
||||
var subjectType = context.getType(subject)
|
||||
for (cs in matchStatement.caseClauses) {
|
||||
if (cs === this) break
|
||||
if (cs.pattern == null) continue
|
||||
if (cs.guardCondition != null && !PyEvaluator.evaluateAsBoolean(cs.guardCondition, false)) continue
|
||||
if (cs.pattern!!.canExcludePatternType(context)) {
|
||||
subjectType = Ref.deref(
|
||||
PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.pattern!!), false, context))
|
||||
}
|
||||
}
|
||||
|
||||
return subjectType
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,8 +2,14 @@ package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.openapi.util.Ref;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator;
|
||||
import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext;
|
||||
import com.jetbrains.python.psi.resolve.RatedResolveResult;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
@@ -11,8 +17,8 @@ import org.jetbrains.annotations.Nullable;
|
||||
import java.util.List;
|
||||
|
||||
import static com.intellij.util.containers.ContainerUtil.getFirstItem;
|
||||
import static com.jetbrains.python.psi.PyUtil.as;
|
||||
import static com.jetbrains.python.psi.PyUtil.multiResolveTopPriority;
|
||||
import static com.jetbrains.python.psi.impl.PyCapturePatternImpl.*;
|
||||
|
||||
public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern {
|
||||
public PyClassPatternImpl(ASTNode astNode) {
|
||||
@@ -29,7 +35,7 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
|
||||
final PyType type = context.getType(getClassNameReference());
|
||||
if (type instanceof PyClassType classType) {
|
||||
final PyType instanceType = classType.toInstance();
|
||||
final PyType captureType = PyCapturePatternImpl.getCaptureType(this, context);
|
||||
final PyType captureType = PyCaptureContext.getCaptureType(this, context);
|
||||
return Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, instanceType, true, context));
|
||||
}
|
||||
return null;
|
||||
@@ -75,13 +81,53 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
|
||||
}
|
||||
}
|
||||
}
|
||||
final PyType captureType = getCaptureType(this, context);
|
||||
final PyType captureType = PyCaptureContext.getCaptureType(this, context);
|
||||
final PyType patternType = context.getType(this);
|
||||
return PyTypeUtil.toStream(captureType).anyMatch(it -> PyTypeChecker.match(patternType, it, context));
|
||||
}
|
||||
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
pattern = as(PsiTreeUtil.findFirstParent(pattern, el -> this.getArgumentList() == el.getParent()), PyPattern.class);
|
||||
if (pattern == null) return null;
|
||||
|
||||
if (pattern instanceof PyKeywordPattern keywordPattern) {
|
||||
if (context.getType(this) instanceof PyClassType classType) {
|
||||
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyTypedElement.class);
|
||||
if (instanceAttribute == null) return null;
|
||||
return context.getType(instanceAttribute);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
final List<PyPattern> arguments = getArgumentList().getPatterns();
|
||||
int index = arguments.indexOf(pattern);
|
||||
if (index < 0) return null;
|
||||
|
||||
// capture type can be a union like: list[int] | list[str]
|
||||
return PyTypeUtil.toStream(context.getType(this)).map(type -> {
|
||||
if (type instanceof PyClassType classType) {
|
||||
if (SPECIAL_BUILTINS.contains(classType.getClassQName())) {
|
||||
if (index == 0) {
|
||||
return classType;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
List<String> matchArgs = getMatchArgs(classType, context);
|
||||
if (matchArgs == null || matchArgs.size() > arguments.size()) return null;
|
||||
|
||||
final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, matchArgs.get(index), context), PyTypedElement.class);
|
||||
if (instanceAttribute == null) return null;
|
||||
|
||||
return PyTypeChecker.substitute(context.getType(instanceAttribute), PyTypeChecker.unifyReceiver(classType, context), context);
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
|
||||
static boolean canExcludeArgumentPatternType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
final var captureType = getCaptureType(pattern, context);
|
||||
final var captureType = PyCaptureContext.getCaptureType(pattern, context);
|
||||
final var patternType = context.getType(pattern);
|
||||
// For class pattern arguments, we need to ensure that the argument pattern covers its capture type fully
|
||||
if (Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, patternType, false, context)) instanceof PyNeverType) {
|
||||
@@ -90,4 +136,20 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
private static @Nullable List<@NotNull String> getMatchArgs(@NotNull PyClassType type, @NotNull TypeEvalContext context) {
|
||||
final PyClass cls = type.getPyClass();
|
||||
List<String> matchArgs = cls.getOwnMatchArgs();
|
||||
if (matchArgs == null) {
|
||||
matchArgs = PyDataclassTypeProvider.Companion.getGeneratedMatchArgs(cls, context);
|
||||
}
|
||||
return matchArgs;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) {
|
||||
final PyResolveContext resolveContext = PyResolveContext.defaultContext(context);
|
||||
final List<? extends RatedResolveResult> results = type.resolveMember(name, null, AccessDirection.READ, resolveContext);
|
||||
return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiListLikeElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
@@ -45,7 +46,7 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt
|
||||
|
||||
PyType patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this);
|
||||
|
||||
PyType captureTypes = PyCapturePatternImpl.getCaptureType(this, context);
|
||||
PyType captureTypes = PyCaptureContext.getCaptureType(this, context);
|
||||
PyType filteredType = PyTypeUtil.toStream(captureTypes).filter(captureType -> {
|
||||
var mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context);
|
||||
if (mappingType == null) return false;
|
||||
@@ -56,6 +57,39 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt
|
||||
return filteredType == null ? patternMappingType : filteredType;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
|
||||
if (sequenceMember instanceof PyDoubleStarPattern) {
|
||||
var mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context);
|
||||
if (mappingType instanceof PyCollectionType collectionType) {
|
||||
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
|
||||
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
else if (sequenceMember instanceof PyKeyValuePattern keyValuePattern) {
|
||||
return PyTypeUtil.toStream(PyCaptureContext.getCaptureType(this, context)).map(type -> {
|
||||
if (type instanceof PyTypedDictType typedDictType) {
|
||||
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l &&
|
||||
l.getExpression() instanceof PyStringLiteralExpression str) {
|
||||
return typedDictType.getElementType(str.getStringValue());
|
||||
}
|
||||
}
|
||||
|
||||
PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context);
|
||||
if (mappingType == null) {
|
||||
return PyNeverType.NEVER;
|
||||
}
|
||||
else if (mappingType instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getElementTypes().get(1);
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) {
|
||||
keyType = PyLiteralType.upcastLiteralToClass(keyType);
|
||||
valueType = PyLiteralType.upcastLiteralToClass(valueType);
|
||||
|
||||
@@ -3,6 +3,7 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiListLikeElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.ArrayUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.psi.*;
|
||||
@@ -39,7 +40,7 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
|
||||
final ArrayList<PyType> types = new ArrayList<>();
|
||||
for (PyPattern pattern : getElements()) {
|
||||
if (pattern instanceof PySingleStarPattern starPattern) {
|
||||
types.addAll(starPattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context));
|
||||
types.addAll(getCapturedTypesFromSequenceType(starPattern, sequenceCaptureType, context));
|
||||
}
|
||||
else {
|
||||
types.add(context.getType(pattern));
|
||||
@@ -74,6 +75,40 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
|
||||
return true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
final PyType sequenceType = getSequenceCaptureType(this, context);
|
||||
if (sequenceType == null) return null;
|
||||
|
||||
// This is done to skip group- and as-patterns
|
||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent());
|
||||
if (sequenceMember instanceof PySingleStarPattern starPattern) {
|
||||
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
|
||||
.flatMap(it -> getCapturedTypesFromSequenceType(starPattern, it, context).stream()).collect(PyTypeUtil.toUnion());
|
||||
return wrapInListType(iteratedType, pattern);
|
||||
}
|
||||
final List<PyPattern> elements = getElements();
|
||||
final int idx = elements.indexOf(sequenceMember);
|
||||
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
|
||||
|
||||
return PyTypeUtil.toStream(sequenceType).map(it -> {
|
||||
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
if (starIdx == -1 || idx < starIdx) {
|
||||
return tupleType.getElementType(idx);
|
||||
}
|
||||
else {
|
||||
final int starSpan = tupleType.getElementCount() - elements.size();
|
||||
return tupleType.getElementType(idx + starSpan);
|
||||
}
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getIteratedItemType();
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
|
||||
static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
|
||||
final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list");
|
||||
return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
|
||||
@@ -86,12 +121,12 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to {@link PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)},
|
||||
* Similar to {@link PyCaptureContext#getCaptureType(PyPattern, TypeEvalContext)},
|
||||
* but only chooses types that would match to typing.Sequence, and have correct length
|
||||
*/
|
||||
@Nullable
|
||||
static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) {
|
||||
final PyType captureTypes = PyCapturePatternImpl.getCaptureType(pattern, context);
|
||||
final PyType captureTypes = PyCaptureContext.getCaptureType(pattern, context);
|
||||
final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern);
|
||||
|
||||
List<PyType> types = new ArrayList<>();
|
||||
@@ -131,4 +166,20 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
|
||||
}
|
||||
return PyUnionType.union(types);
|
||||
}
|
||||
|
||||
private static @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@NotNull PySingleStarPattern starPattern,
|
||||
@Nullable PyType sequenceType,
|
||||
@NotNull TypeEvalContext context) {
|
||||
if (starPattern.getParent() instanceof PySequencePattern sequenceParent) {
|
||||
final int idx = sequenceParent.getElements().indexOf(starPattern);
|
||||
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", starPattern, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return Collections.singletonList(collectionType.getIteratedItemType());
|
||||
}
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyInstantTypeProvider;
|
||||
import com.jetbrains.python.psi.PySingleStarPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class PySingleStarPatternImpl extends PyElementImpl implements PySingleStarPattern, PyInstantTypeProvider {
|
||||
public PySingleStarPatternImpl(ASTNode astNode) {
|
||||
super(astNode);
|
||||
@@ -23,19 +23,4 @@ public class PySingleStarPatternImpl extends PyElementImpl implements PySingleSt
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context) {
|
||||
if (getParent() instanceof PySequencePattern sequenceParent) {
|
||||
final int idx = sequenceParent.getElements().indexOf(this);
|
||||
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return Collections.singletonList(collectionType.getIteratedItemType());
|
||||
}
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyCaptureContext;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyWildcardPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
@@ -20,6 +21,6 @@ public class PyWildcardPatternImpl extends PyElementImpl implements PyWildcardPa
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return PyCapturePatternImpl.getCaptureType(this, context);
|
||||
return PyCaptureContext.getCaptureType(this, context);
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user