From 0dfd1f65e459a7bc8e6114977d3851ec1a262de5 Mon Sep 17 00:00:00 2001 From: "Aleksandr.Govenko" Date: Tue, 1 Jul 2025 18:21:26 +0200 Subject: [PATCH] [python] Refactor PyCapturePatternImpl. Introduce PyCaptureContext (cherry picked from commit 2e3fbf4c7d79e6031c7c087e5c7e7e49046587fd) IJ-MR-168826 GitOrigin-RevId: b87eda39543460451311fc875d6ae3722d671db0 --- .../jetbrains/python/psi/PyCaseClause.java | 2 +- .../jetbrains/python/psi/PyClassPattern.java | 7 +- .../python/psi/PyMappingPattern.java | 2 +- .../src/com/jetbrains/python/psi/PyPattern.kt | 72 +++++-- .../python/psi/PySequencePattern.java | 2 +- .../python/psi/PySingleStarPattern.java | 6 - .../python/psi/impl/PyCapturePatternImpl.java | 204 +----------------- .../python/psi/impl/PyCaseClauseImpl.kt | 39 +++- .../python/psi/impl/PyClassPatternImpl.java | 72 ++++++- .../python/psi/impl/PyMappingPatternImpl.java | 36 +++- .../psi/impl/PySequencePatternImpl.java | 57 ++++- .../psi/impl/PySingleStarPatternImpl.java | 25 +-- .../psi/impl/PyWildcardPatternImpl.java | 3 +- 13 files changed, 260 insertions(+), 267 deletions(-) diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyCaseClause.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyCaseClause.java index 036170f28dee..2be5be9eef60 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyCaseClause.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyCaseClause.java @@ -4,7 +4,7 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstCaseClause; import org.jetbrains.annotations.Nullable; -public interface PyCaseClause extends PyAstCaseClause, PyStatementPart { +public interface PyCaseClause extends PyAstCaseClause, PyStatementPart, PyCaptureContext { @Override default @Nullable PyPattern getPattern() { return (PyPattern)PyAstCaseClause.super.getPattern(); diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyClassPattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyClassPattern.java index cae0782e34ab..5f14c641b42d 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyClassPattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyClassPattern.java @@ -4,7 +4,12 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstClassPattern; import org.jetbrains.annotations.NotNull; -public interface PyClassPattern extends PyAstClassPattern, PyPattern { +import java.util.Set; + +public interface PyClassPattern extends PyAstClassPattern, PyPattern, PyCaptureContext { + Set SPECIAL_BUILTINS = Set.of( + "bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple"); + @Override default @NotNull PyReferenceExpression getClassNameReference() { return (PyReferenceExpression)PyAstClassPattern.super.getClassNameReference(); diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.java index 31caf342dc1b..ec07888ec5e9 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyMappingPattern.java @@ -3,5 +3,5 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstMappingPattern; -public interface PyMappingPattern extends PyAstMappingPattern, PyPattern { +public interface PyMappingPattern extends PyAstMappingPattern, PyPattern, PyCaptureContext { } diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.kt b/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.kt index 639f01963b0a..a5a1d4a8f328 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.kt +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.kt @@ -1,43 +1,77 @@ // Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. -package com.jetbrains.python.psi; +package com.jetbrains.python.psi -import com.jetbrains.python.ast.PyAstPattern; -import com.jetbrains.python.psi.types.PyType; -import com.jetbrains.python.psi.types.TypeEvalContext; -import org.jetbrains.annotations.ApiStatus; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; +import com.intellij.psi.util.findParentOfType +import com.jetbrains.python.ast.PyAstPattern +import com.jetbrains.python.psi.types.PyType +import com.jetbrains.python.psi.types.TypeEvalContext +import org.jetbrains.annotations.ApiStatus -public interface PyPattern extends PyAstPattern, PyTypedElement { +interface PyPattern : PyAstPattern, PyTypedElement { /** * Returns the type that would be captured by this pattern when matching. - *

+ * + * * Unlike other PyTypedElements where getType returns their own type, pattern's getType * returns the type that would result from a successful match. For example: * - *

{@code
+   * ```python
    * class Plant: pass
    * class Animal: pass
    * class Dog(Animal): pass
    *
    * x: Dog | Plant
    * match x:
-   *     case Animal():
-   *         # getType returns Dog here, even though the pattern is Animal()
-   * }
+ * case Animal(): + * # getType returns Dog here, even though the pattern is Animal() + * ``` * - * @see PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext) + * @see PyCapturePatternImpl.getCaptureType */ - @Override - @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key); + override fun getType(context: TypeEvalContext, key: TypeEvalContext.Key): PyType? /** - * Decides if the set of values described by a pattern is suitable + * Decides if the set of values described by a pattern is suitable * to be subtracted (excluded) from a subject type on the negative edge, * or if this pattern is too specific. */ @ApiStatus.Experimental - default boolean canExcludePatternType(@NotNull TypeEvalContext context) { - return true; + fun canExcludePatternType(context: TypeEvalContext): Boolean { + return true + } +} + +interface PyCaptureContext : PyElement { + fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? + + companion object { + /** + * Determines what type this pattern would have if it was a capture pattern (like a bare name or _). + * + * In pattern matching, a capture pattern takes on the type of the entire matched expression, + * regardless of any specific pattern constraints. + * + * For example: + * ```python + * x: int | str + * match x: + * case a: # This is a capture pattern + * # Here 'a' has type int | str + * case str(): # This is a class pattern + * # Capture type: int | str (same as what 'case a:' would get) + * # Regular getType: str + * + * y: int + * match y: + * case str() as a: + * # Capture type: int (same as what 'case a:' would get) + * # Regular getType: intersect(int, str) (just 'str' for now) + * ``` + * @see PyPattern#getType(TypeEvalContext, TypeEvalContext.Key) + */ + @JvmStatic + fun getCaptureType(pattern: PyPattern, context: TypeEvalContext): PyType? { + return pattern.findParentOfType()?.getCaptureTypeForChild(pattern, context) + } } } diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java index 9ac9df06e6aa..3bf1c26e6416 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java @@ -8,7 +8,7 @@ import java.util.List; import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass; -public interface PySequencePattern extends PyAstSequencePattern, PyPattern { +public interface PySequencePattern extends PyAstSequencePattern, PyPattern, PyCaptureContext { default @NotNull List<@NotNull PyPattern> getElements() { return List.of(findChildrenByClass(this, PyPattern.class)); } diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java index 71b5d6026515..1ab04e9f529d 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java @@ -2,12 +2,8 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstSingleStarPattern; -import com.jetbrains.python.psi.types.PyType; -import com.jetbrains.python.psi.types.TypeEvalContext; import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; -import java.util.List; import java.util.Objects; import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass; @@ -17,6 +13,4 @@ public interface PySingleStarPattern extends PyAstSingleStarPattern, PyPattern { default PyPattern getPattern() { return Objects.requireNonNull(findChildByClass(this, PyPattern.class)); } - - @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context); } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java index e575a496d37d..22d74e158d9e 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java @@ -1,25 +1,14 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; -import com.intellij.openapi.util.Ref; -import com.intellij.psi.PsiElement; -import com.intellij.psi.util.PsiTreeUtil; -import com.intellij.util.containers.ContainerUtil; -import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider; -import com.jetbrains.python.psi.*; -import com.jetbrains.python.psi.resolve.PyResolveContext; -import com.jetbrains.python.psi.resolve.RatedResolveResult; -import com.jetbrains.python.psi.types.*; +import com.jetbrains.python.psi.PyCaptureContext; +import com.jetbrains.python.psi.PyCapturePattern; +import com.jetbrains.python.psi.PyElementVisitor; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import java.util.List; -import java.util.Set; - -import static com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator.createAssertionType; -import static com.jetbrains.python.psi.PyUtil.as; -import static com.jetbrains.python.psi.impl.PySequencePatternImpl.wrapInListType; - public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePattern { public PyCapturePatternImpl(ASTNode astNode) { super(astNode); @@ -32,187 +21,6 @@ public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePatt @Override public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { - return getCaptureType(this, context); - } - - static Set SPECIAL_BUILTINS = Set.of( - "bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple"); - - /** - * Determines what type this pattern would have if it was a capture pattern (like a bare name or _). - *

- * In pattern matching, a capture pattern takes on the type of the entire matched expression, - * regardless of any specific pattern constraints. - *

- * For example: - *

{@code
-   * x: int | str
-   * match x:
-   *     case a:         # This is a capture pattern
-   *         # Here 'a' has type int | str
-   *     case str():     # This is a class pattern
-   *         # Capture type: int | str (same as what 'case a:' would get)
-   *         # Regular getType: str
-   *
-   * y: int
-   * match y:
-   *     case str() as a:
-   *         # Capture type: int (same as what 'case a:' would get)
-   *         # Regular getType: intersect(int, str) (just 'str' for now)
-   * }
- * @see PyPattern#getType(TypeEvalContext, TypeEvalContext.Key) - */ - static @Nullable PyType getCaptureType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { - final PyElement parentPattern = PsiTreeUtil.getParentOfType( - pattern, // Capture corresponds to: - PyCaseClause.class, // - Subject of a match statement - PySingleStarPattern.class, // - Type of parent sequence pattern - PyDoubleStarPattern.class, // - Type of parent mapping pattern - PyKeyValuePattern.class, // - Value type of parent mapping pattern - PySequencePattern.class, // - Iterated item type of sequence - PyClassPattern.class, // - Attribute type in the corresponding class - PyKeywordPattern.class // - Attribute type in the corresponding class - ); - - if (parentPattern instanceof PyCaseClause caseClause) { - final PyMatchStatement matchStatement = as(caseClause.getParent(), PyMatchStatement.class); - if (matchStatement == null) return null; - - final PyExpression subject = matchStatement.getSubject(); - if (subject == null) return null; - - PyType subjectType = context.getType(subject); - for (PyCaseClause cs : matchStatement.getCaseClauses()) { - if (cs == caseClause) break; - if (cs.getPattern() == null) continue; - if (cs.getGuardCondition() != null && !PyEvaluator.evaluateAsBoolean(cs.getGuardCondition(), false)) continue; - if (cs.getPattern().canExcludePatternType(context)) { - subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, context)); - } - } - - return subjectType; - } - if (parentPattern instanceof PySingleStarPattern starPattern) { - final PySequencePattern sequenceParent = as(starPattern.getParent(), PySequencePattern.class); - if (sequenceParent == null) return null; - final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequenceParent, context); - final PyType iteratedType = PyTypeUtil.toStream(sequenceType) - .flatMap(it -> starPattern.getCapturedTypesFromSequenceType(it, context).stream()).collect(PyTypeUtil.toUnion()); - return wrapInListType(iteratedType, pattern); - } - if (parentPattern instanceof PyDoubleStarPattern) { - final PyMappingPattern mappingParent = as(parentPattern.getParent(), PyMappingPattern.class); - if (mappingParent == null) return null; - var mappingType = PyTypeUtil.convertToType(context.getType(mappingParent), "typing.Mapping", pattern, context); - if (mappingType instanceof PyCollectionType collectionType) { - final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict"); - return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null; - } - return null; - } - if (parentPattern instanceof PyKeyValuePattern keyValuePattern) { - final PyMappingPattern mappingParent = as(keyValuePattern.getParent(), PyMappingPattern.class); - if (mappingParent == null) return null; - - return PyTypeUtil.toStream(getCaptureType(mappingParent, context)).map(type -> { - if (type instanceof PyTypedDictType typedDictType) { - if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l && - l.getExpression() instanceof PyStringLiteralExpression str) { - return typedDictType.getElementType(str.getStringValue()); - } - } - - PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context); - if (mappingType == null) { - return PyNeverType.NEVER; - } - else if (mappingType instanceof PyCollectionType collectionType) { - return collectionType.getElementTypes().get(1); - } - return null; - }).collect(PyTypeUtil.toUnion()); - } - if (parentPattern instanceof PySequencePattern sequencePattern) { - final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequencePattern, context); - if (sequenceType == null) return null; - - // This is done to skip group- and as-patterns - final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> el.getParent() == sequencePattern); - final List elements = sequencePattern.getElements(); - final int idx = elements.indexOf(sequenceMember); - final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern); - - return PyTypeUtil.toStream(sequenceType).map(it -> { - if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { - if (starIdx == -1 || idx < starIdx) { - return tupleType.getElementType(idx); - } - else { - final int starSpan = tupleType.getElementCount() - elements.size(); - return tupleType.getElementType(idx + starSpan); - } - } - var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context); - if (upcast instanceof PyCollectionType collectionType) { - return collectionType.getIteratedItemType(); - } - return null; - }).collect(PyTypeUtil.toUnion()); - } - if (parentPattern instanceof PyClassPattern classPattern) { - final List arguments = classPattern.getArgumentList().getPatterns(); - int index = arguments.indexOf(pattern); - if (index < 0) return null; - - // capture type can be a union like: list[int] | list[str] - return PyTypeUtil.toStream(context.getType(classPattern)).map(type -> { - if (type instanceof PyClassType classType) { - if (SPECIAL_BUILTINS.contains(classType.getClassQName())) { - if (index == 0) { - return classType; - } - return null; - } - - List matchArgs = getMatchArgs(classType, context); - if (matchArgs == null || matchArgs.size() > arguments.size()) return null; - - final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, matchArgs.get(index), context), PyTypedElement.class); - if (instanceAttribute == null) return null; - - return PyTypeChecker.substitute(context.getType(instanceAttribute), PyTypeChecker.unifyReceiver(classType, context), context); - } - return null; - }).collect(PyTypeUtil.toUnion()); - } - if (parentPattern instanceof PyKeywordPattern keywordPattern) { - final PyClassPattern classPattern = PsiTreeUtil.getParentOfType(keywordPattern, PyClassPattern.class); - if (classPattern == null) return null; - - if (context.getType(classPattern) instanceof PyClassType classType) { - final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyTypedElement.class); - if (instanceAttribute == null) return null; - return context.getType(instanceAttribute); - } - return null; - } - return null; - } - - static @Nullable List<@NotNull String> getMatchArgs(@NotNull PyClassType type, @NotNull TypeEvalContext context) { - final PyClass cls = type.getPyClass(); - List matchArgs = cls.getOwnMatchArgs(); - if (matchArgs == null) { - matchArgs = PyDataclassTypeProvider.Companion.getGeneratedMatchArgs(cls, context); - } - return matchArgs; - } - - @Nullable - static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) { - final PyResolveContext resolveContext = PyResolveContext.defaultContext(context); - final List results = type.resolveMember(name, null, AccessDirection.READ, resolveContext); - return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null; + return PyCaptureContext.getCaptureType(this, context); } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCaseClauseImpl.kt b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCaseClauseImpl.kt index 2c04771ec584..d9ff07375cae 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCaseClauseImpl.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCaseClauseImpl.kt @@ -1,16 +1,35 @@ -package com.jetbrains.python.psi.impl; +package com.jetbrains.python.psi.impl -import com.intellij.lang.ASTNode; -import com.jetbrains.python.psi.PyCaseClause; -import com.jetbrains.python.psi.PyElementVisitor; +import com.intellij.lang.ASTNode +import com.intellij.openapi.util.Ref +import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator +import com.jetbrains.python.psi.PyCaseClause +import com.jetbrains.python.psi.PyElementVisitor +import com.jetbrains.python.psi.PyMatchStatement +import com.jetbrains.python.psi.PyPattern +import com.jetbrains.python.psi.types.PyType +import com.jetbrains.python.psi.types.TypeEvalContext -public class PyCaseClauseImpl extends PyElementImpl implements PyCaseClause { - public PyCaseClauseImpl(ASTNode astNode) { - super(astNode); +class PyCaseClauseImpl(astNode: ASTNode?) : PyElementImpl(astNode), PyCaseClause { + override fun acceptPyVisitor(pyVisitor: PyElementVisitor) { + pyVisitor.visitPyCaseClause(this) } - @Override - protected void acceptPyVisitor(PyElementVisitor pyVisitor) { - pyVisitor.visitPyCaseClause(this); + override fun getCaptureTypeForChild(pattern: PyPattern, context: TypeEvalContext): PyType? { + val matchStatement = getParent() as? PyMatchStatement ?: return null + val subject = matchStatement.subject ?: return null + + var subjectType = context.getType(subject) + for (cs in matchStatement.caseClauses) { + if (cs === this) break + if (cs.pattern == null) continue + if (cs.guardCondition != null && !PyEvaluator.evaluateAsBoolean(cs.guardCondition, false)) continue + if (cs.pattern!!.canExcludePatternType(context)) { + subjectType = Ref.deref( + PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.pattern!!), false, context)) + } + } + + return subjectType } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java index d7e5cdf720af..1c3498194261 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java @@ -2,8 +2,14 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.intellij.openapi.util.Ref; +import com.intellij.psi.PsiElement; +import com.intellij.psi.util.PsiTreeUtil; +import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator; +import com.jetbrains.python.codeInsight.stdlib.PyDataclassTypeProvider; import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.resolve.PyResolveContext; +import com.jetbrains.python.psi.resolve.RatedResolveResult; import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -11,8 +17,8 @@ import org.jetbrains.annotations.Nullable; import java.util.List; import static com.intellij.util.containers.ContainerUtil.getFirstItem; +import static com.jetbrains.python.psi.PyUtil.as; import static com.jetbrains.python.psi.PyUtil.multiResolveTopPriority; -import static com.jetbrains.python.psi.impl.PyCapturePatternImpl.*; public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern { public PyClassPatternImpl(ASTNode astNode) { @@ -29,7 +35,7 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern final PyType type = context.getType(getClassNameReference()); if (type instanceof PyClassType classType) { final PyType instanceType = classType.toInstance(); - final PyType captureType = PyCapturePatternImpl.getCaptureType(this, context); + final PyType captureType = PyCaptureContext.getCaptureType(this, context); return Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, instanceType, true, context)); } return null; @@ -75,13 +81,53 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern } } } - final PyType captureType = getCaptureType(this, context); + final PyType captureType = PyCaptureContext.getCaptureType(this, context); final PyType patternType = context.getType(this); return PyTypeUtil.toStream(captureType).anyMatch(it -> PyTypeChecker.match(patternType, it, context)); } - + + @Override + public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { + pattern = as(PsiTreeUtil.findFirstParent(pattern, el -> this.getArgumentList() == el.getParent()), PyPattern.class); + if (pattern == null) return null; + + if (pattern instanceof PyKeywordPattern keywordPattern) { + if (context.getType(this) instanceof PyClassType classType) { + final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyTypedElement.class); + if (instanceAttribute == null) return null; + return context.getType(instanceAttribute); + } + return null; + } + + final List arguments = getArgumentList().getPatterns(); + int index = arguments.indexOf(pattern); + if (index < 0) return null; + + // capture type can be a union like: list[int] | list[str] + return PyTypeUtil.toStream(context.getType(this)).map(type -> { + if (type instanceof PyClassType classType) { + if (SPECIAL_BUILTINS.contains(classType.getClassQName())) { + if (index == 0) { + return classType; + } + return null; + } + + List matchArgs = getMatchArgs(classType, context); + if (matchArgs == null || matchArgs.size() > arguments.size()) return null; + + final PyTypedElement instanceAttribute = as(resolveTypeMember(classType, matchArgs.get(index), context), PyTypedElement.class); + if (instanceAttribute == null) return null; + + return PyTypeChecker.substitute(context.getType(instanceAttribute), PyTypeChecker.unifyReceiver(classType, context), context); + } + return null; + }).collect(PyTypeUtil.toUnion()); + } + static boolean canExcludeArgumentPatternType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { - final var captureType = getCaptureType(pattern, context); + final var captureType = PyCaptureContext.getCaptureType(pattern, context); final var patternType = context.getType(pattern); // For class pattern arguments, we need to ensure that the argument pattern covers its capture type fully if (Ref.deref(PyTypeAssertionEvaluator.createAssertionType(captureType, patternType, false, context)) instanceof PyNeverType) { @@ -90,4 +136,20 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern } return false; } + + private static @Nullable List<@NotNull String> getMatchArgs(@NotNull PyClassType type, @NotNull TypeEvalContext context) { + final PyClass cls = type.getPyClass(); + List matchArgs = cls.getOwnMatchArgs(); + if (matchArgs == null) { + matchArgs = PyDataclassTypeProvider.Companion.getGeneratedMatchArgs(cls, context); + } + return matchArgs; + } + + @Nullable + static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) { + final PyResolveContext resolveContext = PyResolveContext.defaultContext(context); + final List results = type.resolveMember(name, null, AccessDirection.READ, resolveContext); + return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null; + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java index e08964c5a2f3..8931fb775448 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java @@ -3,6 +3,7 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiListLikeElement; +import com.intellij.psi.util.PsiTreeUtil; import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; @@ -45,7 +46,7 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt PyType patternMappingType = wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this); - PyType captureTypes = PyCapturePatternImpl.getCaptureType(this, context); + PyType captureTypes = PyCaptureContext.getCaptureType(this, context); PyType filteredType = PyTypeUtil.toStream(captureTypes).filter(captureType -> { var mappingType = PyTypeUtil.convertToType(captureType, "typing.Mapping", this, context); if (mappingType == null) return false; @@ -56,6 +57,39 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt return filteredType == null ? patternMappingType : filteredType; } + @Override + public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { + final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent()); + if (sequenceMember instanceof PyDoubleStarPattern) { + var mappingType = PyTypeUtil.convertToType(context.getType(this), "typing.Mapping", pattern, context); + if (mappingType instanceof PyCollectionType collectionType) { + final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict"); + return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null; + } + return null; + } + else if (sequenceMember instanceof PyKeyValuePattern keyValuePattern) { + return PyTypeUtil.toStream(PyCaptureContext.getCaptureType(this, context)).map(type -> { + if (type instanceof PyTypedDictType typedDictType) { + if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l && + l.getExpression() instanceof PyStringLiteralExpression str) { + return typedDictType.getElementType(str.getStringValue()); + } + } + + PyType mappingType = PyTypeUtil.convertToType(type, "typing.Mapping", pattern, context); + if (mappingType == null) { + return PyNeverType.NEVER; + } + else if (mappingType instanceof PyCollectionType collectionType) { + return collectionType.getElementTypes().get(1); + } + return null; + }).collect(PyTypeUtil.toUnion()); + } + return null; + } + private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) { keyType = PyLiteralType.upcastLiteralToClass(keyType); valueType = PyLiteralType.upcastLiteralToClass(valueType); diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java index 728e0c070df1..dc2df9735b26 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java @@ -3,6 +3,7 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiListLikeElement; +import com.intellij.psi.util.PsiTreeUtil; import com.intellij.util.ArrayUtil; import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.psi.*; @@ -39,7 +40,7 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa final ArrayList types = new ArrayList<>(); for (PyPattern pattern : getElements()) { if (pattern instanceof PySingleStarPattern starPattern) { - types.addAll(starPattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context)); + types.addAll(getCapturedTypesFromSequenceType(starPattern, sequenceCaptureType, context)); } else { types.add(context.getType(pattern)); @@ -74,6 +75,40 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa return true; } + @Override + public @Nullable PyType getCaptureTypeForChild(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { + final PyType sequenceType = getSequenceCaptureType(this, context); + if (sequenceType == null) return null; + + // This is done to skip group- and as-patterns + final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> this == el.getParent()); + if (sequenceMember instanceof PySingleStarPattern starPattern) { + final PyType iteratedType = PyTypeUtil.toStream(sequenceType) + .flatMap(it -> getCapturedTypesFromSequenceType(starPattern, it, context).stream()).collect(PyTypeUtil.toUnion()); + return wrapInListType(iteratedType, pattern); + } + final List elements = getElements(); + final int idx = elements.indexOf(sequenceMember); + final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern); + + return PyTypeUtil.toStream(sequenceType).map(it -> { + if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { + if (starIdx == -1 || idx < starIdx) { + return tupleType.getElementType(idx); + } + else { + final int starSpan = tupleType.getElementCount() - elements.size(); + return tupleType.getElementType(idx + starSpan); + } + } + var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context); + if (upcast instanceof PyCollectionType collectionType) { + return collectionType.getIteratedItemType(); + } + return null; + }).collect(PyTypeUtil.toUnion()); + } + static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) { final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list"); return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null; @@ -86,12 +121,12 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa } /** - * Similar to {@link PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)}, + * Similar to {@link PyCaptureContext#getCaptureType(PyPattern, TypeEvalContext)}, * but only chooses types that would match to typing.Sequence, and have correct length */ @Nullable static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) { - final PyType captureTypes = PyCapturePatternImpl.getCaptureType(pattern, context); + final PyType captureTypes = PyCaptureContext.getCaptureType(pattern, context); final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern); List types = new ArrayList<>(); @@ -131,4 +166,20 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa } return PyUnionType.union(types); } + + private static @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@NotNull PySingleStarPattern starPattern, + @Nullable PyType sequenceType, + @NotNull TypeEvalContext context) { + if (starPattern.getParent() instanceof PySequencePattern sequenceParent) { + final int idx = sequenceParent.getElements().indexOf(starPattern); + if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { + return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1); + } + var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", starPattern, context); + if (upcast instanceof PyCollectionType collectionType) { + return Collections.singletonList(collectionType.getIteratedItemType()); + } + } + return List.of(); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java index 9ee3745357d9..7156f7fe8281 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java @@ -1,14 +1,14 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; -import com.jetbrains.python.psi.*; -import com.jetbrains.python.psi.types.*; +import com.jetbrains.python.psi.PyElementVisitor; +import com.jetbrains.python.psi.PyInstantTypeProvider; +import com.jetbrains.python.psi.PySingleStarPattern; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; -import java.util.Collections; -import java.util.List; - public class PySingleStarPatternImpl extends PyElementImpl implements PySingleStarPattern, PyInstantTypeProvider { public PySingleStarPatternImpl(ASTNode astNode) { super(astNode); @@ -23,19 +23,4 @@ public class PySingleStarPatternImpl extends PyElementImpl implements PySingleSt public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { return null; } - - @Override - public @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context) { - if (getParent() instanceof PySequencePattern sequenceParent) { - final int idx = sequenceParent.getElements().indexOf(this); - if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { - return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1); - } - var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context); - if (upcast instanceof PyCollectionType collectionType) { - return Collections.singletonList(collectionType.getIteratedItemType()); - } - } - return List.of(); - } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyWildcardPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyWildcardPatternImpl.java index b5037ea8b5a1..2e65e800b4f3 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyWildcardPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyWildcardPatternImpl.java @@ -1,6 +1,7 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; +import com.jetbrains.python.psi.PyCaptureContext; import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.PyWildcardPattern; import com.jetbrains.python.psi.types.PyType; @@ -20,6 +21,6 @@ public class PyWildcardPatternImpl extends PyElementImpl implements PyWildcardPa @Override public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { - return PyCapturePatternImpl.getCaptureType(this, context); + return PyCaptureContext.getCaptureType(this, context); } }