diff --git a/python/python-ast/src/com/jetbrains/python/ast/PyAstAsPattern.kt b/python/python-ast/src/com/jetbrains/python/ast/PyAstAsPattern.kt index 5c43b1e0ef9d..4dbe3beb542d 100644 --- a/python/python-ast/src/com/jetbrains/python/ast/PyAstAsPattern.kt +++ b/python/python-ast/src/com/jetbrains/python/ast/PyAstAsPattern.kt @@ -8,6 +8,10 @@ interface PyAstAsPattern : PyAstPattern { return requireNotNull(findChildByClass(PyAstPattern::class.java)) { "${this}: pattern cannot be null" } } + fun getTarget(): PyAstTargetExpression { + return requireNotNull(findChildByClass(PyAstTargetExpression::class.java)) { "${this}: target cannot be null" } + } + override fun isIrrefutable(): Boolean { return getPattern().isIrrefutable } diff --git a/python/python-ast/src/com/jetbrains/python/ast/PyAstValuePattern.java b/python/python-ast/src/com/jetbrains/python/ast/PyAstValuePattern.java index 3e43a9787329..9eb5562ae399 100644 --- a/python/python-ast/src/com/jetbrains/python/ast/PyAstValuePattern.java +++ b/python/python-ast/src/com/jetbrains/python/ast/PyAstValuePattern.java @@ -2,6 +2,11 @@ package com.jetbrains.python.ast; import org.jetbrains.annotations.ApiStatus; +import org.jetbrains.annotations.NotNull; + +import java.util.Objects; + +import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass; @ApiStatus.Experimental public interface PyAstValuePattern extends PyAstPattern { @@ -10,6 +15,11 @@ public interface PyAstValuePattern extends PyAstPattern { return false; } + @NotNull + default PyAstReferenceExpression getValue() { + return Objects.requireNonNull(findChildByClass(this, PyAstReferenceExpression.class)); + } + @Override default void acceptPyVisitor(PyAstElementVisitor pyVisitor) { pyVisitor.visitPyValuePattern(this); diff --git a/python/python-parser/src/com/jetbrains/python/PyNames.java b/python/python-parser/src/com/jetbrains/python/PyNames.java index 9b841823caf9..d6a65cc78aa8 100644 --- a/python/python-parser/src/com/jetbrains/python/PyNames.java +++ b/python/python-parser/src/com/jetbrains/python/PyNames.java @@ -186,6 +186,7 @@ public final @NonNls class PyNames { public static final String ROUND = "__round__"; public static final String CLASS_GETITEM = "__class_getitem__"; public static final String PREPARE = "__prepare__"; + public static final String MATCH_ARGS = "__match_args__"; public static final String NAME = "__name__"; public static final String ENTER = "__enter__"; diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.java index edbaf975233a..77d22ed16c31 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyPattern.java @@ -3,5 +3,5 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstPattern; -public interface PyPattern extends PyAstPattern, PyElement { +public interface PyPattern extends PyAstPattern, PyTypedElement { } diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java index 00c0cb79a8b6..9ac9df06e6aa 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PySequencePattern.java @@ -2,6 +2,14 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstSequencePattern; +import org.jetbrains.annotations.NotNull; + +import java.util.List; + +import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass; public interface PySequencePattern extends PyAstSequencePattern, PyPattern { + default @NotNull List<@NotNull PyPattern> getElements() { + return List.of(findChildrenByClass(this, PyPattern.class)); + } } diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java index b19058fc3847..71b5d6026515 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PySingleStarPattern.java @@ -2,6 +2,21 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstSingleStarPattern; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.List; +import java.util.Objects; + +import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass; public interface PySingleStarPattern extends PyAstSingleStarPattern, PyPattern { + @NotNull + default PyPattern getPattern() { + return Objects.requireNonNull(findChildByClass(this, PyPattern.class)); + } + + @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context); } diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyValuePattern.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyValuePattern.java index 607face32b73..443dde0bb0ce 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyValuePattern.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyValuePattern.java @@ -2,6 +2,11 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstValuePattern; +import org.jetbrains.annotations.NotNull; public interface PyValuePattern extends PyAstValuePattern, PyPattern { + @Override + default @NotNull PyReferenceExpression getValue() { + return (PyReferenceExpression)PyAstValuePattern.super.getValue(); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowBuilder.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowBuilder.java index 6047bb0563a5..691c14d67a8a 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowBuilder.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowBuilder.java @@ -27,6 +27,7 @@ import com.intellij.psi.PsiElement; import com.intellij.psi.PsiNamedElement; import com.intellij.psi.util.PsiTreeUtil; import com.intellij.psi.util.QualifiedName; +import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.PyNames; import com.jetbrains.python.PyTokenTypes; import com.jetbrains.python.psi.*; @@ -330,7 +331,134 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor { @Override public void visitPyMatchStatement(@NotNull PyMatchStatement matchStatement) { - new PyMatchStatementControlFlowBuilder(myBuilder, this).build(matchStatement); + myBuilder.startNode(matchStatement); + PyExpression subject = matchStatement.getSubject(); + if (subject != null) { + subject.accept(this); + } + for (PyCaseClause caseClause : matchStatement.getCaseClauses()) { + visitPyCaseClause(caseClause); + } + myBuilder.addNodeAndCheckPending(new TransparentInstructionImpl(myBuilder, matchStatement, "")); + if (!myBuilder.prevInstruction.allPred().isEmpty()) { + addTypeAssertionNodes(matchStatement, false); + } + myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction); + myBuilder.prevInstruction = null; + } + + @Override + public void visitPyCaseClause(@NotNull PyCaseClause clause) { + PyPattern pattern = clause.getPattern(); + if (pattern != null) { + pattern.accept(this); + addTypeAssertionNodes(pattern, true); + } + + TransparentInstruction trueNode = addTransparentInstruction(); + TransparentInstruction falseNode = addTransparentInstruction(); + PyExpression guard = clause.getGuardCondition(); + if (guard != null) { + visitCondition(guard, trueNode, falseNode); + addTypeAssertionNodes(guard, true); + } + else { + myBuilder.addEdge(myBuilder.prevInstruction, trueNode); + } + myBuilder.addPendingEdge(clause, falseNode); + myBuilder.prevInstruction = trueNode; + + clause.getStatementList().accept(this); + + if (clause.getParent() instanceof PyMatchStatement matchStatement) { + myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction); + myBuilder.updatePendingElementScope(clause.getStatementList(), matchStatement); + } + myBuilder.prevInstruction = null; + } + + @Override + public void visitWildcardPattern(@NotNull PyWildcardPattern node) { + myBuilder.startNode(node); + } + + @Override + public void visitPyPattern(@NotNull PyPattern node) { + boolean isRefutable = !node.isIrrefutable(); + if (isRefutable) { + myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false)); + myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction); + } + + node.acceptChildren(this); + myBuilder.updatePendingElementScope(node, node.getParent()); + + if (isRefutable) { + myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true)); + } + } + + @Override + public void visitPyOrPattern(@NotNull PyOrPattern node) { + myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false)); + + TransparentInstruction onSuccess = new TransparentInstructionImpl(myBuilder, node, "onSuccess"); + List alternatives = node.getAlternatives(); + PyPattern lastAlternative = ContainerUtil.getLastItem(alternatives); + + for (PyPattern alternative : alternatives) { + alternative.accept(this); + if (alternative != lastAlternative) { + // Allow next alternative to handle the fail edge of this alternative + myBuilder.updatePendingElementScope(node, alternative); + } + myBuilder.addEdge(myBuilder.prevInstruction, onSuccess); + myBuilder.prevInstruction = null; + } + myBuilder.addNode(onSuccess); + myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true)); + myBuilder.updatePendingElementScope(node, node.getParent()); + } + + @Override + public void visitPyClassPattern(@NotNull PyClassPattern node) { + myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false)); + + node.getClassNameReference().accept(this); + myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction); + + node.getArgumentList().acceptChildren(this); + myBuilder.updatePendingElementScope(node, node.getParent()); + + myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true)); + } + + @Override + public void visitPyValuePattern(@NotNull PyValuePattern node) { + myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false)); + + node.getValue().accept(this); + myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction); + + myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true)); + } + + @Override + public void visitPyAsPattern(@NotNull PyAsPattern node) { + // AsPattern cannot fail by itself - it fails only if its child fails. + // So no need to create additional fail edge + myBuilder.startNode(node); + node.acceptChildren(this); + myBuilder.updatePendingElementScope(node, node.getParent()); + } + + @Override + public void visitPyGroupPattern(@NotNull PyGroupPattern node) { + // GroupPattern cannot fail by itself - it fails only if its child fails. + // So no need to create additional fail edge + // Also no need to dedicated node for GroupPattern itself + node.acceptChildren(this); + myBuilder.updatePendingElementScope(node, node.getParent()); } @Override @@ -955,7 +1083,7 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor { || element instanceof PyStatementList); } - private void addTypeAssertionNodes(@NotNull PyExpression condition, boolean positive) { + private void addTypeAssertionNodes(@NotNull PyElement condition, boolean positive) { final PyTypeAssertionEvaluator evaluator = new PyTypeAssertionEvaluator(positive); condition.accept(evaluator); for (PyTypeAssertionEvaluator.Assertion def : evaluator.getDefinitions()) { diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowProvider.kt b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowProvider.kt index 2d2e734314c6..f960721c155d 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowProvider.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyControlFlowProvider.kt @@ -16,6 +16,7 @@ class PyControlFlowProvider : ControlFlowProvider { override fun getAdditionalInfo(instruction: Instruction): String? { return when (instruction) { is ReadWriteInstruction -> "${instruction.access} ${instruction.name}" + is RefutablePatternInstruction -> instruction.elementPresentation else -> null } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyMatchStatementControlFlowBuilder.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyMatchStatementControlFlowBuilder.java deleted file mode 100644 index 29b832df2f70..000000000000 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyMatchStatementControlFlowBuilder.java +++ /dev/null @@ -1,165 +0,0 @@ -package com.jetbrains.python.codeInsight.controlflow; - -import com.intellij.codeInsight.controlflow.ConditionalInstruction; -import com.intellij.codeInsight.controlflow.ControlFlowBuilder; -import com.intellij.codeInsight.controlflow.Instruction; -import com.intellij.psi.PsiElement; -import com.intellij.psi.util.PsiTreeUtil; -import com.intellij.util.containers.ContainerUtil; -import com.jetbrains.python.PyTokenTypes; -import com.jetbrains.python.psi.*; -import org.jetbrains.annotations.NotNull; -import org.jetbrains.annotations.Nullable; - -import java.util.List; -import java.util.function.BiFunction; - -public final class PyMatchStatementControlFlowBuilder { - private final ControlFlowBuilder myBuilder; - private final PyElementVisitor myBaseVisitor; - - public PyMatchStatementControlFlowBuilder(@NotNull ControlFlowBuilder builder, @NotNull PyElementVisitor baseVisitor) { - myBuilder = builder; - myBaseVisitor = baseVisitor; - } - - public void build(@NotNull PyMatchStatement matchStatement) { - myBuilder.startNode(matchStatement); - PyExpression subject = matchStatement.getSubject(); - if (subject != null) { - subject.accept(myBaseVisitor); - } - for (PyCaseClause caseClause : matchStatement.getCaseClauses()) { - processCaseClause(caseClause); - } - } - - private void processCaseClause(@NotNull PyCaseClause clause) { - PyPattern pattern = clause.getPattern(); - if (pattern != null) { - processPattern(pattern); - retargetOutgoingPatternEdges(pattern, (oldScope, instr) -> { - return instr.isMatched() ? pattern : clause; - }); - } - PyStatementList statementList = clause.getStatementList(); - PyExpression guard = clause.getGuardCondition(); - if (guard != null) { - guard.accept(myBaseVisitor); - // Retarget failure edges coming from inner OR and AND expressions - retargetOutgoingEdges(guard, (pendingScope, instr) -> { - if (instr instanceof ConditionalInstruction && !((ConditionalInstruction)instr).getResult()) { - return clause; - } - return pendingScope; - }); - // Top-level OR and AND expressions should have had their own outgoing failure edges - if (!isConjunctionOrDisjunction(guard)) { - myBuilder.addPendingEdge(clause, myBuilder.prevInstruction); - } - myBuilder.startConditionalNode(statementList, guard, true); - } - statementList.accept(myBaseVisitor); - PyMatchStatement matchStatement = PsiTreeUtil.getParentOfType(clause, PyMatchStatement.class); - assert matchStatement != null; - retargetOutgoingEdges(statementList, (pendingScope, instruction) -> matchStatement); - myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction); - myBuilder.prevInstruction = null; - } - - private void processPattern(@NotNull PyPattern pattern) { - boolean isRefutable = !pattern.isIrrefutable(); - if (isRefutable) { - RefutablePatternInstruction instruction = new RefutablePatternInstruction(myBuilder, pattern, false); - myBuilder.addNodeAndCheckPending(instruction); - myBuilder.addPendingEdge(pattern, instruction); - } - - if (pattern instanceof PyWildcardPattern) { - myBuilder.startNode(pattern); - } - else if (pattern instanceof PyOrPattern) { - List alternatives = ((PyOrPattern)pattern).getAlternatives(); - PyPattern lastAlternative = ContainerUtil.getLastItem(alternatives); - for (PyPattern alternative : alternatives) { - processPattern(alternative); - if (alternative != lastAlternative) { - myBuilder.addPendingEdge(alternative, myBuilder.prevInstruction); - myBuilder.prevInstruction = null; - } - retargetOutgoingEdges(alternative, (pendingScope, instruction) -> { - if (instruction instanceof RefutablePatternInstruction && !((RefutablePatternInstruction)instruction).isMatched()) { - // Pattern has failed, jump to the next alternative if any - return alternative; - } - // Pattern succeeded, jump out of OR-pattern. It can be either a refutable pattern or a capture/wildcard node. - else { - return pattern; - } - }); - } - } - else { - pattern.acceptChildren(new PyElementVisitor() { - @Override - public void visitPyReferenceExpression(@NotNull PyReferenceExpression node) { - myBaseVisitor.visitPyReferenceExpression(node); - } - - @Override - public void visitPyTargetExpression(@NotNull PyTargetExpression node) { - myBaseVisitor.visitPyTargetExpression(node); - } - - @Override - public void visitPyPattern(@NotNull PyPattern node) { - processPattern(node); - retargetOutgoingPatternEdges(pattern, (oldScope, instr) -> { - // Mismatch in a non-OR pattern means mismatch of the containing pattern as well - return instr.isMatched() ? oldScope : pattern; - }); - } - - @Override - public void visitPyPatternArgumentList(@NotNull PyPatternArgumentList node) { - node.acceptChildren(this); - } - }); - } - - if (isRefutable) { - myBuilder.addNode(new RefutablePatternInstruction(myBuilder, pattern, true)); - } - } - - private void retargetOutgoingEdges(@NotNull PsiElement scopeAncestor, - @NotNull BiFunction newScopeProvider) { - myBuilder.processPending((oldScope, instruction) -> { - if (oldScope != null && PsiTreeUtil.isAncestor(scopeAncestor, oldScope, false)) { - myBuilder.addPendingEdge(newScopeProvider.apply(oldScope, instruction), instruction); - } - else { - myBuilder.addPendingEdge(oldScope, instruction); - } - }); - } - - private void retargetOutgoingPatternEdges(@NotNull PsiElement scopeAncestor, - @NotNull BiFunction newScopeProvider) { - retargetOutgoingEdges(scopeAncestor, (pendingScope, instruction) -> { - if (instruction instanceof RefutablePatternInstruction) { - return newScopeProvider.apply(pendingScope, (RefutablePatternInstruction)instruction); - } - return pendingScope; - }); - } - - private static boolean isConjunctionOrDisjunction(@Nullable PyExpression node) { - if (node instanceof PyBinaryExpression) { - final var operator = ((PyBinaryExpression)node).getOperator(); - return operator == PyTokenTypes.AND_KEYWORD || operator == PyTokenTypes.OR_KEYWORD; - } - - return false; - } -} diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java index 87c8fb22ca95..1519f435bd96 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/controlflow/PyTypeAssertionEvaluator.java @@ -3,6 +3,7 @@ package com.jetbrains.python.codeInsight.controlflow; import com.intellij.openapi.util.Ref; import com.intellij.psi.PsiElement; +import com.intellij.psi.util.PsiTreeUtil; import com.intellij.util.containers.ContainerUtil; import com.intellij.util.containers.Stack; import com.jetbrains.python.PyNames; @@ -22,6 +23,8 @@ import java.util.ArrayList; import java.util.List; import java.util.function.Function; +import static com.jetbrains.python.psi.PyUtil.as; + public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { private final Stack myStack = new Stack<>(); private boolean myPositive; @@ -41,14 +44,15 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { if (args.length == 2 && args[0] instanceof PyReferenceExpression target) { final PyExpression typeElement = args[1]; - pushAssertion(target, myPositive, false, context -> context.getType(typeElement), typeElement); + pushAssertion(target, myPositive, context -> + transformTypeFromAssertion(context.getType(typeElement), false, context, typeElement)); } } else if (node.isCalleeText(PyNames.CALLABLE_BUILTIN)) { final PyExpression[] args = node.getArguments(); if (args.length == 1 && args[0] instanceof PyReferenceExpression target) { - pushAssertion(target, myPositive, false, context -> PyTypingTypeProvider.createTypingCallableType(node), null); + pushAssertion(target, myPositive, context -> PyTypingTypeProvider.createTypingCallableType(node)); } } else if (node.isCalleeText(PyNames.ISSUBCLASS)) { @@ -56,7 +60,8 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { if (args.length == 2 && args[0] instanceof PyReferenceExpression target) { final PyExpression typeElement = args[1]; - pushAssertion(target, myPositive, true, context -> context.getType(typeElement), typeElement); + pushAssertion(target, myPositive, context -> + transformTypeFromAssertion(context.getType(typeElement), true, context, typeElement)); } } } @@ -66,7 +71,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { if (myPositive && (isIfReferenceStatement(node) || isIfReferenceConditionalStatement(node) || isIfNotReferenceStatement(node))) { // we could not suggest `None` because it could be a reference to an empty collection // so we could push only non-`None` assertions - pushAssertion(node, !myPositive, false, context -> PyNoneType.INSTANCE, null); + pushAssertion(node, !myPositive, context -> PyNoneType.INSTANCE); return; } @@ -105,26 +110,26 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { if (PyLiteralType.isNone(lhs)) { if (rhs instanceof PyReferenceExpression referenceExpr) { - pushAssertion(referenceExpr, myPositive, false, context -> PyNoneType.INSTANCE, null); + pushAssertion(referenceExpr, myPositive, context -> PyNoneType.INSTANCE); } return; } if (PyLiteralType.isNone(rhs)) { if (lhs instanceof PyReferenceExpression referenceExpr) { - pushAssertion(referenceExpr, myPositive, false, context -> PyNoneType.INSTANCE, null); + pushAssertion(referenceExpr, myPositive, context -> PyNoneType.INSTANCE); } return; } if (lhs instanceof PyReferenceExpression referenceExpr) { - pushAssertion(referenceExpr, myPositive, false, context -> getLiteralType(rhs, context), null); + pushAssertion(referenceExpr, myPositive, context -> getLiteralType(rhs, context)); } } private void processIn(@NotNull PyExpression lhs, @NotNull PyExpression rhs) { if (lhs instanceof PyReferenceExpression referenceExpr && rhs instanceof PyTupleExpression tupleExpr) { - pushAssertion(referenceExpr, myPositive, false, context -> { + pushAssertion(referenceExpr, myPositive, (TypeEvalContext context) -> { PyExpression[] elements = tupleExpr.getElements(); List types = new ArrayList<>(elements.length); for (PyExpression element : elements) { @@ -135,7 +140,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { types.add(type); } return PyUnionType.union(types); - }, null); + }); } } @@ -160,6 +165,41 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { } } + @Override + public void visitPyPattern(@NotNull PyPattern node) { + final PsiElement parent = PsiTreeUtil.skipParentsOfType(node, PyCaseClause.class); + if (parent instanceof PyMatchStatement matchStatement) { + final PyExpression subject = PyPsiUtils.flattenParens(matchStatement.getSubject()); + + if (subject instanceof PyReferenceExpression target) { + pushAssertion(target, myPositive, context -> context.getType(node)); + } + } + } + + /** + * Negative type assertion for when all cases fail + */ + @Override + public void visitPyMatchStatement(@NotNull PyMatchStatement matchStatement) { + assert !myPositive; // for match statement as a whole, only negative assertion can be made + final PyExpression subject = matchStatement.getSubject(); + if (subject == null) return; + + if (subject instanceof PyReferenceExpression target) { + pushAssertion(target, true, context -> { + PyType subjectType = context.getType(subject); + for (PyCaseClause cs : matchStatement.getCaseClauses()) { + if (cs.getPattern() == null) continue; + if (cs.getGuardCondition() != null) continue; + subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, context)); + } + + return subjectType; + }); + } + } + @ApiStatus.Internal public static @Nullable Ref createAssertionType(@Nullable PyType initial, @Nullable PyType suggested, @@ -233,6 +273,9 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { PyTypeChecker.match(expected, actual, context); } + /** + * @param transformToDefinition if true the result type will be Type[T], not T itself. + */ private static @Nullable PyType transformTypeFromAssertion(@Nullable PyType type, boolean transformToDefinition, @NotNull TypeEvalContext context, @Nullable PyExpression typeElement) { /* @@ -245,8 +288,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { final List members = new ArrayList<>(); final int count = tupleType.getElementCount(); - final PyTupleExpression tupleExpression = PyUtil - .as(PyPsiUtils.flattenParens(PyUtil.as(typeElement, PyParenthesizedExpression.class)), PyTupleExpression.class); + final PyTupleExpression tupleExpression = as(PyPsiUtils.flattenParens(typeElement), PyTupleExpression.class); if (tupleExpression != null && tupleExpression.getElements().length == count) { final PyExpression[] elements = tupleExpression.getElements(); for (int i = 0; i < count; i++) { @@ -276,21 +318,13 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor { return type; } - /** - * @param transformToDefinition if true the result type will be Type[T], not T itself. - */ private void pushAssertion(@NotNull PyReferenceExpression target, boolean positive, - boolean transformToDefinition, - @NotNull Function suggestedType, - @Nullable PyExpression typeElement) { + @NotNull Function suggestedType) { final InstructionTypeCallback typeCallback = new InstructionTypeCallback() { @Override public Ref getType(TypeEvalContext context) { - return createAssertionType(context.getType(target), - transformTypeFromAssertion(suggestedType.apply(context), transformToDefinition, context, typeElement), - positive, - context); + return createAssertionType(context.getType(target), suggestedType.apply(context), positive, context); } }; diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java index e69c2134e5bc..c8f1b979ffb4 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java @@ -128,10 +128,8 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< PARAM_SPEC, PARAM_SPEC_EXT, TYPE_VAR_TUPLE, TYPE_VAR_TUPLE_EXT ); - - public static final String CONTEXT_MANAGER = "contextlib.AbstractContextManager"; +public static final String CONTEXT_MANAGER = "contextlib.AbstractContextManager"; public static final String ASYNC_CONTEXT_MANAGER = "contextlib.AbstractAsyncContextManager"; - public static final Set TYPE_DICT_QUALIFIERS = Set.of(REQUIRED, REQUIRED_EXT, NOT_REQUIRED, NOT_REQUIRED_EXT, READONLY, READONLY_EXT); public static final String UNPACK = "typing.Unpack"; diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyAsPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyAsPatternImpl.java index c7217a21b7ac..56c92f0fa885 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyAsPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyAsPatternImpl.java @@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.jetbrains.python.psi.PyAsPattern; import com.jetbrains.python.psi.PyElementVisitor; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class PyAsPatternImpl extends PyElementImpl implements PyAsPattern { public PyAsPatternImpl(ASTNode astNode) { @@ -13,4 +17,9 @@ public class PyAsPatternImpl extends PyElementImpl implements PyAsPattern { protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyAsPattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + return context.getType(getPattern()); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java index cac32fbdb9e8..b08ffe8a6f52 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyCapturePatternImpl.java @@ -1,8 +1,24 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; -import com.jetbrains.python.psi.PyCapturePattern; -import com.jetbrains.python.psi.PyElementVisitor; +import com.intellij.openapi.util.Ref; +import com.intellij.psi.PsiElement; +import com.intellij.psi.util.PsiTreeUtil; +import com.intellij.util.containers.ContainerUtil; +import com.jetbrains.python.PyNames; +import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator; +import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.resolve.PyResolveContext; +import com.jetbrains.python.psi.resolve.RatedResolveResult; +import com.jetbrains.python.psi.types.*; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.List; +import java.util.Set; + +import static com.jetbrains.python.psi.PyUtil.as; +import static com.jetbrains.python.psi.impl.PySequencePatternImpl.wrapInListType; public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePattern { public PyCapturePatternImpl(ASTNode astNode) { @@ -13,4 +29,158 @@ public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePatt protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyCapturePattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + return getCaptureType(this, context); + } + + static Set SPECIAL_BUILTINS = Set.of( + "bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple"); + + /** + * Determines the type of a given pattern assuming it is a capture pattern (even when it is actually not), + * and looking up (to parents or subject expression of the match statement). + */ + static @Nullable PyType getCaptureType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) { + final PyElement parentPattern = PsiTreeUtil.getParentOfType( + pattern, // Capture corresponds to: + PyCaseClause.class, // - Subject of a match statement + PySingleStarPattern.class, // - Type of parent sequence pattern + PyDoubleStarPattern.class, // - Type of parent mapping pattern + PyKeyValuePattern.class, // - Any + PySequencePattern.class, // - Iterated item type of sequence + PyClassPattern.class, // - Attribute type in the corresponding class + PyKeywordPattern.class // - Attribute type in the corresponding class + ); + + if (parentPattern instanceof PyCaseClause caseClause) { + final PyMatchStatement matchStatement = as(caseClause.getParent(), PyMatchStatement.class); + if (matchStatement == null) return null; + + final PyExpression subject = matchStatement.getSubject(); + if (subject == null) return null; + + PyType subjectType = context.getType(subject); + for (PyCaseClause cs : matchStatement.getCaseClauses()) { + if (cs == caseClause) break; + if (cs.getPattern() == null) continue; + if (cs.getGuardCondition() != null) continue; + subjectType = Ref.deref( + PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.getPattern()), false, context)); + } + + return subjectType; + } + if (parentPattern instanceof PySingleStarPattern starPattern) { + final PySequencePattern sequenceParent = as(starPattern.getParent(), PySequencePattern.class); + if (sequenceParent == null) return null; + final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequenceParent, context); + final PyType iteratedType = PyTypeUtil.toStream(sequenceType) + .flatMap(it -> starPattern.getCapturedTypesFromSequenceType(it, context).stream()).collect(PyTypeUtil.toUnion()); + return wrapInListType(iteratedType, pattern); + } + if (parentPattern instanceof PyDoubleStarPattern) { + final PyMappingPattern mappingParent = as(parentPattern.getParent(), PyMappingPattern.class); + if (mappingParent == null) return null; + var parentType = context.getType(mappingParent); + if (parentType instanceof PyCollectionType collectionType) { + final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict"); + return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null; + } + return null; + } + if (parentPattern instanceof PyKeyValuePattern keyValuePattern) { + final PyMappingPattern mappingParent = as(keyValuePattern.getParent(), PyMappingPattern.class); + if (mappingParent == null) return null; + + var dictType = getCaptureType(mappingParent, context); + if (dictType == null) return null; + + if (dictType instanceof PyTypedDictType typedDictType) { + if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l && l.getExpression() instanceof PyStringLiteralExpression str) { + return typedDictType.getElementType(str.getStringValue()); + } + } + var mappingType = PyTypeUtil.convertToType(dictType, "typing.Mapping", pattern, context); + if (mappingType instanceof PyCollectionType collectionType) { + return collectionType.getElementTypes().get(1); + } + return null; + } + if (parentPattern instanceof PySequencePattern sequencePattern) { + final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequencePattern, context); + if (sequenceType == null) return null; + return PyTypeUtil.toStream(sequenceType).map(it -> { + if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { + // This is done to skip group- and as-patterns + final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> el.getParent() == sequencePattern); + final List elements = sequencePattern.getElements(); + final int idx = elements.indexOf(sequenceMember); + final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern); + if (starIdx == -1 || idx < starIdx) { + return tupleType.getElementType(idx); + } + else { + final int starSpan = tupleType.getElementCount() - elements.size(); + return tupleType.getElementType(idx + starSpan); + } + } + var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context); + if (upcast instanceof PyCollectionType collectionType) { + return collectionType.getIteratedItemType(); + } + return null; + }).collect(PyTypeUtil.toUnion()); + } + if (parentPattern instanceof PyClassPattern classPattern) { + if (context.getType(classPattern) instanceof PyClassType classType) { + final List arguments = classPattern.getArgumentList().getPatterns(); + int index = arguments.indexOf(pattern); + if (index < 0) return null; + + if (SPECIAL_BUILTINS.contains(classType.getClassQName())) { + if (index == 0) { + return context.getType(classPattern); + } + return null; + } + + final PyTargetExpression matchArgs = as(resolveTypeMember(classType, PyNames.MATCH_ARGS, context), PyTargetExpression.class); + if (matchArgs == null) return null; + + var matchArgsValue = PyPsiUtils.flattenParens(matchArgs.findAssignedValue()); + if (matchArgsValue instanceof PySequenceExpression sequenceExpression) { + if (sequenceExpression.getElements().length <= index) return null; + + final String attributeName = PyEvaluator.evaluate(sequenceExpression.getElements()[index], String.class); + if (attributeName == null) return null; + + final PyExpression instanceAttribute = as(resolveTypeMember(classType, attributeName, context), PyExpression.class); + if (instanceAttribute == null) return null; + return context.getType(instanceAttribute); + } + } + return null; + } + if (parentPattern instanceof PyKeywordPattern keywordPattern) { + final PyClassPattern classPattern = PsiTreeUtil.getParentOfType(keywordPattern, PyClassPattern.class); + if (classPattern == null) return null; + + if (context.getType(classPattern) instanceof PyClassType classType) { + final PyExpression instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyExpression.class); + if (instanceAttribute == null) return null; + return context.getType(instanceAttribute); + } + return null; + } + return null; + } + + @Nullable + private static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) { + final PyResolveContext resolveContext = PyResolveContext.defaultContext(context); + final List results = type.resolveMember(name, null, AccessDirection.READ, resolveContext); + return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null; + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java index 051e588b4aa8..8d776d4ce68b 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyClassPatternImpl.java @@ -3,6 +3,12 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.jetbrains.python.psi.PyClassPattern; import com.jetbrains.python.psi.PyElementVisitor; +import com.jetbrains.python.psi.types.PyClassType; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.PyTypeChecker; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern { public PyClassPatternImpl(ASTNode astNode) { @@ -13,4 +19,17 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyClassPattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + final PyType type = context.getType(getClassNameReference()); + if (type instanceof PyClassType classType) { + final PyType instanceType = classType.toInstance(); + final PyType captureType = PyCapturePatternImpl.getCaptureType(this, context); + if (PyTypeChecker.match(captureType, instanceType, context)) { + return instanceType; + } + } + return null; + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyDoubleStarPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyDoubleStarPatternImpl.java index 43ac7e7fda8b..f5c4dfd8e114 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyDoubleStarPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyDoubleStarPatternImpl.java @@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.jetbrains.python.psi.PyDoubleStarPattern; import com.jetbrains.python.psi.PyElementVisitor; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class PyDoubleStarPatternImpl extends PyElementImpl implements PyDoubleStarPattern { public PyDoubleStarPatternImpl(ASTNode astNode) { @@ -13,4 +17,9 @@ public class PyDoubleStarPatternImpl extends PyElementImpl implements PyDoubleSt protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyDoubleStarPattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + return null; + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGroupPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGroupPatternImpl.java index 28d04e40dd8d..376731c8a9e8 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGroupPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGroupPatternImpl.java @@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.PyGroupPattern; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class PyGroupPatternImpl extends PyElementImpl implements PyGroupPattern { public PyGroupPatternImpl(ASTNode astNode) { @@ -13,4 +17,9 @@ public class PyGroupPatternImpl extends PyElementImpl implements PyGroupPattern protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyGroupPattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + return context.getType(getPattern()); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeyValuePatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeyValuePatternImpl.java index aca9f67d98fb..5bc2c8eb59da 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeyValuePatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeyValuePatternImpl.java @@ -3,6 +3,13 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.PyKeyValuePattern; +import com.jetbrains.python.psi.PyPattern; +import com.jetbrains.python.psi.types.PyTupleType; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; + +import java.util.Arrays; public class PyKeyValuePatternImpl extends PyElementImpl implements PyKeyValuePattern { public PyKeyValuePatternImpl(ASTNode astNode) { @@ -13,4 +20,15 @@ public class PyKeyValuePatternImpl extends PyElementImpl implements PyKeyValuePa protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyKeyValuePattern(this); } + + @Override + public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { + final PyType keyType = context.getType(getKeyPattern()); + final PyPattern value = getValuePattern(); + PyType valueType = null; + if (value != null) { + valueType = context.getType(value); + } + return PyTupleType.create(this, Arrays.asList(keyType, valueType)); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeywordPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeywordPatternImpl.java index 9ad04f34b841..135f431a77a7 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeywordPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyKeywordPatternImpl.java @@ -4,7 +4,12 @@ import com.intellij.lang.ASTNode; import com.intellij.psi.PsiReference; import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.PyKeywordPattern; +import com.jetbrains.python.psi.PyPattern; import com.jetbrains.python.psi.impl.references.PyKeywordPatternReference; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class PyKeywordPatternImpl extends PyElementImpl implements PyKeywordPattern { public PyKeywordPatternImpl(ASTNode astNode) { @@ -20,4 +25,10 @@ public class PyKeywordPatternImpl extends PyElementImpl implements PyKeywordPatt public PsiReference getReference() { return new PyKeywordPatternReference(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + final PyPattern valuePattern = getValuePattern(); + return valuePattern != null ? context.getType(valuePattern) : null; + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLiteralPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLiteralPatternImpl.java index de348a27b96e..9fbfa9cdf3c3 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLiteralPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLiteralPatternImpl.java @@ -3,6 +3,11 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.PyLiteralPattern; +import com.jetbrains.python.psi.types.PyLiteralType; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class PyLiteralPatternImpl extends PyElementImpl implements PyLiteralPattern { public PyLiteralPatternImpl(ASTNode astNode) { @@ -13,4 +18,13 @@ public class PyLiteralPatternImpl extends PyElementImpl implements PyLiteralPatt protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyLiteralPattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + PyType literalType = PyLiteralType.Companion.fromLiteralParameter(getExpression(), context); + if (literalType != null) { + return literalType; + } + return context.getType(getExpression()); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java index 6d9dabe8c8a0..87674ff2ce9d 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyMappingPatternImpl.java @@ -3,11 +3,12 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiListLikeElement; -import com.jetbrains.python.psi.PyElementVisitor; -import com.jetbrains.python.psi.PyKeyValuePattern; -import com.jetbrains.python.psi.PyMappingPattern; +import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; +import java.util.ArrayList; import java.util.Arrays; import java.util.List; @@ -22,7 +23,30 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt } @Override - public @NotNull List getComponents() { + public @NotNull List getComponents() { return Arrays.asList(findChildrenByClass(PyKeyValuePattern.class)); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + ArrayList keyTypes = new ArrayList<>(); + ArrayList valueTypes = new ArrayList<>(); + for (PyKeyValuePattern it : getComponents()) { + PyType type = context.getType(it); + if (type instanceof PyTupleType tupleType) { + keyTypes.add(tupleType.getElementType(0)); + valueTypes.add(tupleType.getElementType(1)); + } + } + //keyTypes.add(null); + //valueTypes.add(null); + return wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this); + } + + private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) { + keyType = PyLiteralType.upcastLiteralToClass(keyType); + valueType = PyLiteralType.upcastLiteralToClass(valueType); + final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Mapping", resolveAnchor); + return sequence != null ? new PyCollectionTypeImpl(sequence, false, Arrays.asList(keyType, valueType)) : null; + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyOrPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyOrPatternImpl.java index 077452fd737e..ca5f3ba159aa 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyOrPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyOrPatternImpl.java @@ -1,8 +1,14 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; +import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.PyOrPattern; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.PyUnionType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; public class PyOrPatternImpl extends PyElementImpl implements PyOrPattern { public PyOrPatternImpl(ASTNode astNode) { @@ -13,4 +19,9 @@ public class PyOrPatternImpl extends PyElementImpl implements PyOrPattern { protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPyOrPattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + return PyUnionType.union(ContainerUtil.map(getAlternatives(), it -> context.getType(it))); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java index 5f3e3052bfeb..ff7e0719518c 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySequencePatternImpl.java @@ -3,13 +3,16 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.intellij.psi.PsiElement; import com.intellij.psi.PsiListLikeElement; -import com.jetbrains.python.psi.PyElementVisitor; -import com.jetbrains.python.psi.PyPattern; -import com.jetbrains.python.psi.PySequencePattern; +import com.intellij.util.ArrayUtil; +import com.intellij.util.containers.ContainerUtil; +import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; -import java.util.Arrays; -import java.util.List; +import java.util.*; + +import static com.jetbrains.python.psi.types.PyLiteralType.upcastLiteralToClass; public class PySequencePatternImpl extends PyElementImpl implements PySequencePattern, PsiListLikeElement { public PySequencePatternImpl(ASTNode astNode) { @@ -22,7 +25,74 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa } @Override - public @NotNull List getComponents() { - return Arrays.asList(findChildrenByClass(PyPattern.class)); + public @NotNull List getComponents() { + return getElements(); + } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + final PyType sequenceCaptureType = getSequenceCaptureType(this, context); + boolean isHomogeneous = !(sequenceCaptureType instanceof PyTupleType tupleType) || tupleType.isHomogeneous(); + final ArrayList types = new ArrayList<>(); + for (PyPattern pattern : getElements()) { + if (pattern instanceof PySingleStarPattern starPattern) { + types.addAll(starPattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context)); + } + else { + types.add(context.getType(pattern)); + } + } + PyType expectedType = isHomogeneous ? wrapInSequenceType(PyUnionType.union(types), this) : PyTupleType.create(this, types); + PyType captureType = PyCapturePatternImpl.getCaptureType(this, context); + if (captureType == null) return expectedType; + return PyTypeUtil.toStream(captureType) + .map(it -> { + if (PyTypeChecker.match(expectedType, it, context)) { + return it; + } + return expectedType; + }) + .collect(PyTypeUtil.toUnion()); + } + + static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) { + final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list"); + return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null; + } + + @Nullable + public static PyType wrapInSequenceType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) { + final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Sequence", resolveAnchor); + return sequence != null ? new PyCollectionTypeImpl(sequence, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null; + } + + /** + * Similar to {@link PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)}, + * but only chooses types that would match to typing.Sequence, and have correct length + */ + @Nullable + static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) { + final PyType captureTypes = PyCapturePatternImpl.getCaptureType(pattern, context); + final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern); + + List types = new ArrayList<>(); + for (PyType captureType : PyTypeUtil.toStream(captureTypes)) { + if (captureType instanceof PyClassType classType && + ArrayUtil.contains(classType.getClassQName(), "str", "bytes", "bytearray")) continue; + + PyType sequenceType = PyTypeUtil.convertToType(captureType, "typing.Sequence", pattern, context); + if (sequenceType == null) continue; + + if (captureType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { + final List elements = pattern.getElements(); + if (hasStar && elements.size() <= tupleType.getElementCount() || elements.size() == tupleType.getElementCount()) { + types.add(captureType); + } + } + else { + types.add(captureType); + } + } + return PyUnionType.union(types); } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java index 8d5a71b98b58..e50857893a49 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PySingleStarPatternImpl.java @@ -1,8 +1,13 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; -import com.jetbrains.python.psi.PyElementVisitor; -import com.jetbrains.python.psi.PySingleStarPattern; +import com.jetbrains.python.psi.*; +import com.jetbrains.python.psi.types.*; +import org.jetbrains.annotations.NotNull; +import org.jetbrains.annotations.Nullable; + +import java.util.Collections; +import java.util.List; public class PySingleStarPatternImpl extends PyElementImpl implements PySingleStarPattern { public PySingleStarPatternImpl(ASTNode astNode) { @@ -13,4 +18,24 @@ public class PySingleStarPatternImpl extends PyElementImpl implements PySingleSt protected void acceptPyVisitor(PyElementVisitor pyVisitor) { pyVisitor.visitPySingleStarPattern(this); } + + @Override + public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) { + return null; + } + + @Override + public @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context) { + if (getParent() instanceof PySequencePattern sequenceParent) { + final int idx = sequenceParent.getElements().indexOf(this); + if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) { + return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1); + } + var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context); + if (upcast instanceof PyCollectionType collectionType) { + return Collections.singletonList(collectionType.getIteratedItemType()); + } + } + return List.of(); + } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java index a5ef5fbfb5da..728ce5304a6d 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java @@ -154,8 +154,11 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl MAX_ANALYZED_ELEMENTS_OF_LITERALS) { PyUnionType.createWeakType(analyzedElementsType) @@ -66,7 +66,7 @@ object PyCollectionTypeUtil { if (keyExpression !is PyStringLiteralExpression) { return null } - strKeysToValueTypes[keyExpression.stringValue] = Pair(element.value, replaceLiteralWithItsClass(valueType)) + strKeysToValueTypes[keyExpression.stringValue] = Pair(element.value, PyLiteralType.upcastLiteralToClass(valueType)) } return strKeysToValueTypes @@ -82,8 +82,8 @@ object PyCollectionTypeUtil { .forEach { val type = context.getType(it) val (keyType, valueType) = getKeyValueType(type) - keyTypes.add(replaceLiteralWithItsClass(keyType)) - valueTypes.add(replaceLiteralWithItsClass(valueType)) + keyTypes.add(PyLiteralType.upcastLiteralToClass(keyType)) + valueTypes.add(PyLiteralType.upcastLiteralToClass(valueType)) } if (elements.size > MAX_ANALYZED_ELEMENTS_OF_LITERALS) { @@ -107,13 +107,4 @@ object PyCollectionTypeUtil { } return null to null } - - private fun replaceLiteralWithItsClass(type: PyType?): PyType? { - return when (type) { - is PyUnionType -> type.map(::replaceLiteralWithItsClass) - is PyLiteralStringType -> PyClassTypeImpl(type.cls, false) - is PyLiteralType -> PyClassTypeImpl(type.pyClass, false) - else -> type - } - } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyLiteralType.kt b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyLiteralType.kt index ccf6429b180d..57bfe3ad63d4 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyLiteralType.kt +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/types/PyLiteralType.kt @@ -10,6 +10,7 @@ import com.jetbrains.python.psi.* import com.jetbrains.python.psi.impl.PyEvaluator import com.jetbrains.python.psi.resolve.PyResolveContext import org.jetbrains.annotations.ApiStatus +import org.jetbrains.annotations.ApiStatus.Internal /** @@ -56,6 +57,17 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi return PyLiteralType(enumClass, expression) } + @Internal + @JvmStatic + fun upcastLiteralToClass(type: PyType?): PyType? { + return when (type) { + is PyUnionType -> type.map(::upcastLiteralToClass) + is PyLiteralStringType -> PyClassTypeImpl(type.cls, false) + is PyLiteralType -> PyClassTypeImpl(type.pyClass, false) + else -> type + } + } + private fun promoteToType( expectedType: PyType?, expression: PyExpression, diff --git a/python/testData/codeInsight/controlflow/MatchStatementClauseWithBreak.txt b/python/testData/codeInsight/controlflow/MatchStatementClauseWithBreak.txt index 9ad80a3a7c78..8cab4c6b6c22 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementClauseWithBreak.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementClauseWithBreak.txt @@ -1,16 +1,18 @@ 0(1) element: null 1(2) element: PyWhileStatement 2(3,4) READ ACCESS: x -3(13) element: null. Condition: x:false +3(15) element: null. Condition: x:false 4(5) element: null. Condition: x:true 5(6) element: PyStatementList 6(7) element: PyMatchStatement 7(8) READ ACCESS: x -8(9,11) refutable pattern: 42 +8(9,12) refutable pattern: 42 9(10) matched pattern: 42 -10(13) element: PyBreakStatement -11(12) element: PyExpressionStatement -12(1) READ ACCESS: y +10(11) ASSERTTYPE ACCESS: x +11(15) element: PyBreakStatement +12(13) ASSERTTYPE ACCESS: x 13(14) element: PyExpressionStatement -14(15) READ ACCESS: z -15() element: null \ No newline at end of file +14(1) READ ACCESS: y +15(16) element: PyExpressionStatement +16(17) READ ACCESS: z +17() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementClauseWithContinue.txt b/python/testData/codeInsight/controlflow/MatchStatementClauseWithContinue.txt index 1df3b52f8459..53ee1bf7a388 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementClauseWithContinue.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementClauseWithContinue.txt @@ -1,16 +1,18 @@ 0(1) element: null 1(2) element: PyWhileStatement 2(3,4) READ ACCESS: x -3(13) element: null. Condition: x:false +3(15) element: null. Condition: x:false 4(5) element: null. Condition: x:true 5(6) element: PyStatementList 6(7) element: PyMatchStatement 7(8) READ ACCESS: x -8(9,11) refutable pattern: 42 +8(9,12) refutable pattern: 42 9(10) matched pattern: 42 -10(1) element: PyContinueStatement -11(12) element: PyExpressionStatement -12(1) READ ACCESS: y +10(11) ASSERTTYPE ACCESS: x +11(1) element: PyContinueStatement +12(13) ASSERTTYPE ACCESS: x 13(14) element: PyExpressionStatement -14(15) READ ACCESS: z -15() element: null \ No newline at end of file +14(1) READ ACCESS: y +15(16) element: PyExpressionStatement +16(17) READ ACCESS: z +17() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementClauseWithReturn.txt b/python/testData/codeInsight/controlflow/MatchStatementClauseWithReturn.txt index 8d466859f916..7b9802699ae8 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementClauseWithReturn.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementClauseWithReturn.txt @@ -2,9 +2,11 @@ 1(2) WRITE ACCESS: x 2(3) element: PyMatchStatement 3(4) READ ACCESS: x -4(5,7) refutable pattern: 42 +4(5,8) refutable pattern: 42 5(6) matched pattern: 42 -6(9) element: PyReturnStatement -7(8) element: PyExpressionStatement -8(9) READ ACCESS: y -9() element: null \ No newline at end of file +6(7) ASSERTTYPE ACCESS: x +7(11) element: PyReturnStatement +8(9) ASSERTTYPE ACCESS: x +9(10) element: PyExpressionStatement +10(11) READ ACCESS: y +11() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseAliasedRefutableOrPattern.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseAliasedRefutableOrPattern.txt index 9eca5efbb529..9e2afe826d5e 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseAliasedRefutableOrPattern.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseAliasedRefutableOrPattern.txt @@ -1,19 +1,18 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,16) refutable pattern: [42] | foo.bar as x -3(4,16) refutable pattern: [42] | foo.bar +2(3) element: PyAsPattern +3(4) refutable pattern: [42] | foo.bar 4(5,8) refutable pattern: [42] 5(6,8) refutable pattern: 42 6(7) matched pattern: 42 -7(12) matched pattern: [42] -8(9,16) refutable pattern: foo.bar -9(10) READ ACCESS: foo +7(11) matched pattern: [42] +8(9) refutable pattern: foo.bar +9(10,15) READ ACCESS: foo 10(11) matched pattern: foo.bar 11(12) matched pattern: [42] | foo.bar 12(13) WRITE ACCESS: x -13(14) matched pattern: [42] | foo.bar as x -14(15) element: PyExpressionStatement -15(16) READ ACCESS: y -16(17) element: PyExpressionStatement -17(18) READ ACCESS: z -18() element: null \ No newline at end of file +13(14) element: PyExpressionStatement +14(15) READ ACCESS: y +15(16) element: PyExpressionStatement +16(17) READ ACCESS: z +17() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseClassPattern.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseClassPattern.txt index 1de01dd81a9f..1ec12446231d 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseClassPattern.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseClassPattern.txt @@ -1,12 +1,12 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,14) refutable pattern: Class(1, attr=foo.bar) -3(4) READ ACCESS: Class +2(3) refutable pattern: Class(1, attr=foo.bar) +3(4,14) READ ACCESS: Class 4(5,14) refutable pattern: 1 5(6) matched pattern: 1 6(7,14) refutable pattern: attr=foo.bar -7(8,14) refutable pattern: foo.bar -8(9) READ ACCESS: foo +7(8) refutable pattern: foo.bar +8(9,14) READ ACCESS: foo 9(10) matched pattern: foo.bar 10(11) matched pattern: attr=foo.bar 11(12) matched pattern: Class(1, attr=foo.bar) diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseConjunctionGuard.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseConjunctionGuard.txt index 1c4c2f2de4fe..9496a095f518 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseConjunctionGuard.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseConjunctionGuard.txt @@ -3,14 +3,13 @@ 2(3) WRITE ACCESS: x 3(4) element: PyBinaryExpression 4(5,6) READ ACCESS: x -5(10) element: null. Condition: x > 0:false +5(12) element: null. Condition: x > 0:false 6(7) element: null. Condition: x > 0:true 7(8,9) READ ACCESS: x -8(10) element: null. Condition: x < 10:false +8(12) element: null. Condition: x < 10:false 9(10) element: null. Condition: x < 10:true -10(11) element: PyStatementList. Condition: x > 0 and x < 10:true -11(12) element: PyExpressionStatement -12(13) READ ACCESS: y -13(14) element: PyExpressionStatement -14(15) READ ACCESS: z -15() element: null \ No newline at end of file +10(11) element: PyExpressionStatement +11(12) READ ACCESS: y +12(13) element: PyExpressionStatement +13(14) READ ACCESS: z +14() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionConjunctionGuard.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionConjunctionGuard.txt index ecd8773a5ad4..60d6b001ba46 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionConjunctionGuard.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionConjunctionGuard.txt @@ -3,18 +3,17 @@ 2(3) WRITE ACCESS: x 3(4) element: PyBinaryExpression 4(5,6) READ ACCESS: x -5(14) element: null. Condition: x % 4 == 0:false +5(16) element: null. Condition: x % 4 == 0:false 6(7) element: null. Condition: x % 4 == 0:true 7(8) element: PyBinaryExpression 8(9,10) READ ACCESS: x 9(11) element: null. Condition: x % 400 == 0:false 10(14) element: null. Condition: x % 400 == 0:true 11(12,13) READ ACCESS: x -12(14) element: null. Condition: x % 100 != 0:false +12(16) element: null. Condition: x % 100 != 0:false 13(14) element: null. Condition: x % 100 != 0:true -14(15) element: PyStatementList. Condition: x % 4 == 0 and (x % 400 == 0 or x % 100 != 0):true -15(16) element: PyExpressionStatement -16(17) READ ACCESS: y -17(18) element: PyExpressionStatement -18(19) READ ACCESS: z -19() element: null \ No newline at end of file +14(15) element: PyExpressionStatement +15(16) READ ACCESS: y +16(17) element: PyExpressionStatement +17(18) READ ACCESS: z +18() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionGuard.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionGuard.txt index 570a4cc770e4..1732f263795c 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionGuard.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseDisjunctionGuard.txt @@ -6,11 +6,10 @@ 5(7) element: null. Condition: x > 0:false 6(10) element: null. Condition: x > 0:true 7(8,9) READ ACCESS: x -8(10) element: null. Condition: x < 0:false +8(12) element: null. Condition: x < 0:false 9(10) element: null. Condition: x < 0:true -10(11) element: PyStatementList. Condition: x > 0 or x < 0:true -11(12) element: PyExpressionStatement -12(13) READ ACCESS: y -13(14) element: PyExpressionStatement -14(15) READ ACCESS: z -15() element: null \ No newline at end of file +10(11) element: PyExpressionStatement +11(12) READ ACCESS: y +12(13) element: PyExpressionStatement +13(14) READ ACCESS: z +14() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseGuardWithNonTopLevelDisjunction.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseGuardWithNonTopLevelDisjunction.txt index 262c9fa2dec6..50c119ae4888 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseGuardWithNonTopLevelDisjunction.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseGuardWithNonTopLevelDisjunction.txt @@ -1,6 +1,6 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,18) refutable pattern: [x1, x2, x3] +2(3,19) refutable pattern: [x1, x2, x3] 3(4) WRITE ACCESS: x1 4(5) WRITE ACCESS: x2 5(6) WRITE ACCESS: x3 @@ -8,14 +8,15 @@ 7(8) element: PyBinaryExpression 8(9,10) READ ACCESS: x1 9(11) element: null. Condition: x1:false -10(14) element: null. Condition: x1:true -11(12,13) READ ACCESS: x2 -12(14) element: null. Condition: x2:false -13(14) element: null. Condition: x2:true -14(15,18) READ ACCESS: x3 -15(16) element: PyStatementList. Condition: (x1 or x2) > x3:true -16(17) element: PyExpressionStatement -17(18) READ ACCESS: y -18(19) element: PyExpressionStatement -19(20) READ ACCESS: z -20() element: null \ No newline at end of file +10(17) element: null. Condition: x1:true +11(12,13,14) READ ACCESS: x2 +12(19) element: null. Condition: x2:false +13(17) element: null. Condition: x2:true +14(15,16) READ ACCESS: x3 +15(19) element: null. Condition: (x1 or x2) > x3:false +16(17) element: null. Condition: (x1 or x2) > x3:true +17(18) element: PyExpressionStatement +18(19) READ ACCESS: y +19(20) element: PyExpressionStatement +20(21) READ ACCESS: z +21() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantFirst.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantFirst.txt index fc25ce35960d..9d8462a214ee 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantFirst.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantFirst.txt @@ -1,11 +1,13 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(6) WRITE ACCESS: x -3(4,8) refutable pattern: [x] -4(5) WRITE ACCESS: x -5(6) matched pattern: [x] -6(7) element: PyExpressionStatement -7(8) READ ACCESS: y +2(3) refutable pattern: x | [x] +3(7) WRITE ACCESS: x +4(5,10) refutable pattern: [x] +5(6) WRITE ACCESS: x +6(7) matched pattern: [x] +7(8) matched pattern: x | [x] 8(9) element: PyExpressionStatement -9(10) READ ACCESS: z -10() element: null \ No newline at end of file +9(10) READ ACCESS: y +10(11) element: PyExpressionStatement +11(12) READ ACCESS: z +12() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantLast.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantLast.txt index 5e0c46b1cd52..6a1ba7ad2083 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantLast.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseIrrefutableOrPatternCaptureVariantLast.txt @@ -1,11 +1,13 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,5) refutable pattern: [x] -3(4) WRITE ACCESS: x -4(6) matched pattern: [x] -5(6) WRITE ACCESS: x -6(7) element: PyExpressionStatement -7(8) READ ACCESS: y +2(3) refutable pattern: [x] | x +3(4,6) refutable pattern: [x] +4(5) WRITE ACCESS: x +5(7) matched pattern: [x] +6(7) WRITE ACCESS: x +7(8) matched pattern: [x] | x 8(9) element: PyExpressionStatement -9(10) READ ACCESS: z -10() element: null \ No newline at end of file +9(10) READ ACCESS: y +10(11) element: PyExpressionStatement +11(12) READ ACCESS: z +12() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseMappingPattern.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseMappingPattern.txt index aafd1ac15ca1..e483666437a0 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseMappingPattern.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseMappingPattern.txt @@ -10,8 +10,8 @@ 9(10,19) refutable pattern: 'bar': foo.bar 10(11,19) refutable pattern: 'bar' 11(12) matched pattern: 'bar' -12(13,19) refutable pattern: foo.bar -13(14) READ ACCESS: foo +12(13) refutable pattern: foo.bar +13(14,19) READ ACCESS: foo 14(15) matched pattern: foo.bar 15(16) matched pattern: 'bar': foo.bar 16(17) matched pattern: {'foo': 1, 'bar': foo.bar} diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseNestedOrPatterns.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseNestedOrPatterns.txt index 496ceef2c5b5..3c3f4ec34cf3 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseNestedOrPatterns.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseNestedOrPatterns.txt @@ -1,19 +1,17 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,16) refutable pattern: 1 | (2 | 3) +2(3) refutable pattern: 1 | (2 | 3) 3(4,5) refutable pattern: 1 -4(14) matched pattern: 1 -5(6,16) refutable pattern: (2 | 3) -6(7,16) refutable pattern: 2 | 3 -7(8,9) refutable pattern: 2 -8(14) matched pattern: 2 -9(10,16) refutable pattern: 3 -10(11) matched pattern: 3 -11(12) matched pattern: 2 | 3 -12(13) matched pattern: (2 | 3) -13(14) matched pattern: 1 | (2 | 3) +4(11) matched pattern: 1 +5(6) refutable pattern: 2 | 3 +6(7,8) refutable pattern: 2 +7(10) matched pattern: 2 +8(9,14) refutable pattern: 3 +9(10) matched pattern: 3 +10(11) matched pattern: 2 | 3 +11(12) matched pattern: 1 | (2 | 3) +12(13) element: PyExpressionStatement +13(14) READ ACCESS: x 14(15) element: PyExpressionStatement -15(16) READ ACCESS: x -16(17) element: PyExpressionStatement -17(18) READ ACCESS: y -18() element: null \ No newline at end of file +15(16) READ ACCESS: y +16() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPattern.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPattern.txt index 2ede4ea6f2dd..df0d6bd5cb76 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPattern.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPattern.txt @@ -1,20 +1,17 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,17) refutable pattern: [x] | (foo.bar as x) +2(3) refutable pattern: [x] | (foo.bar as x) 3(4,6) refutable pattern: [x] 4(5) WRITE ACCESS: x -5(15) matched pattern: [x] -6(7,17) refutable pattern: (foo.bar as x) -7(8,17) refutable pattern: foo.bar as x -8(9,17) refutable pattern: foo.bar -9(10) READ ACCESS: foo -10(11) matched pattern: foo.bar -11(12) WRITE ACCESS: x -12(13) matched pattern: foo.bar as x -13(14) matched pattern: (foo.bar as x) -14(15) matched pattern: [x] | (foo.bar as x) -15(16) element: PyExpressionStatement -16(17) READ ACCESS: y -17(18) element: PyExpressionStatement -18(19) READ ACCESS: z -19() element: null \ No newline at end of file +5(11) matched pattern: [x] +6(7) element: PyAsPattern +7(8) refutable pattern: foo.bar +8(9,14) READ ACCESS: foo +9(10) matched pattern: foo.bar +10(11) WRITE ACCESS: x +11(12) matched pattern: [x] | (foo.bar as x) +12(13) element: PyExpressionStatement +13(14) READ ACCESS: y +14(15) element: PyExpressionStatement +15(16) READ ACCESS: z +16() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithNonBindingVariants.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithNonBindingVariants.txt index 84472e8f21e5..a2a8c9a26b61 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithNonBindingVariants.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithNonBindingVariants.txt @@ -1,8 +1,8 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,10) refutable pattern: [] | 42 +2(3) refutable pattern: [] | 42 3(4,5) refutable pattern: [] -4(8) matched pattern: [] +4(7) matched pattern: [] 5(6,10) refutable pattern: 42 6(7) matched pattern: 42 7(8) matched pattern: [] | 42 diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithWildcard.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithWildcard.txt index 078ac997cb91..cb3a741dc8ea 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithWildcard.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutableOrPatternWithWildcard.txt @@ -1,10 +1,12 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(5) element: PyWildcardPattern -3(4,7) refutable pattern: 42 -4(5) matched pattern: 42 -5(6) element: PyExpressionStatement -6(7) READ ACCESS: y +2(3) refutable pattern: _ | 42 +3(6) element: PyWildcardPattern +4(5,9) refutable pattern: 42 +5(6) matched pattern: 42 +6(7) matched pattern: _ | 42 7(8) element: PyExpressionStatement -8(9) READ ACCESS: z -9() element: null \ No newline at end of file +8(9) READ ACCESS: y +9(10) element: PyExpressionStatement +10(11) READ ACCESS: z +11() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutablePatternAndConjunctionGuard.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutablePatternAndConjunctionGuard.txt index 6681675831ab..5aba90a0d874 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutablePatternAndConjunctionGuard.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseRefutablePatternAndConjunctionGuard.txt @@ -1,18 +1,17 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,15) refutable pattern: [x] +2(3,14) refutable pattern: [x] 3(4) WRITE ACCESS: x 4(5) matched pattern: [x] 5(6) element: PyBinaryExpression 6(7,8) READ ACCESS: x -7(12) element: null. Condition: x > 0:false +7(14) element: null. Condition: x > 0:false 8(9) element: null. Condition: x > 0:true 9(10,11) READ ACCESS: x -10(12) element: null. Condition: x % 2 == 0:false +10(14) element: null. Condition: x % 2 == 0:false 11(12) element: null. Condition: x % 2 == 0:true -12(13) element: PyStatementList. Condition: x > 0 and x % 2 == 0:true -13(14) element: PyExpressionStatement -14(15) READ ACCESS: y -15(16) element: PyExpressionStatement -16(17) READ ACCESS: z -17() element: null \ No newline at end of file +12(13) element: PyExpressionStatement +13(14) READ ACCESS: y +14(15) element: PyExpressionStatement +15(16) READ ACCESS: z +16() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePattern.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePattern.txt index 05767136251f..181d77c8c0ad 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePattern.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePattern.txt @@ -3,8 +3,8 @@ 2(3,11) refutable pattern: [1, foo.bar] 3(4,11) refutable pattern: 1 4(5) matched pattern: 1 -5(6,11) refutable pattern: foo.bar -6(7) READ ACCESS: foo +5(6) refutable pattern: foo.bar +6(7,11) READ ACCESS: foo 7(8) matched pattern: foo.bar 8(9) matched pattern: [1, foo.bar] 9(10) element: PyExpressionStatement diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePatternWithSingleOrPattern.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePatternWithSingleOrPattern.txt index a6ae38d4a8a9..aeba64e0c8da 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePatternWithSingleOrPattern.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseSequencePatternWithSingleOrPattern.txt @@ -1,12 +1,14 @@ 0(1) element: null 1(2) element: PyMatchStatement -2(3,9) refutable pattern: [1 | x] -3(4,5) refutable pattern: 1 -4(7) matched pattern: 1 -5(6) WRITE ACCESS: x -6(7) matched pattern: [1 | x] -7(8) element: PyExpressionStatement -8(9) READ ACCESS: y +2(3,11) refutable pattern: [1 | x] +3(4) refutable pattern: 1 | x +4(5,6) refutable pattern: 1 +5(7) matched pattern: 1 +6(7) WRITE ACCESS: x +7(8) matched pattern: 1 | x +8(9) matched pattern: [1 | x] 9(10) element: PyExpressionStatement -10(11) READ ACCESS: z -11() element: null \ No newline at end of file +10(11) READ ACCESS: y +11(12) element: PyExpressionStatement +12(13) READ ACCESS: z +13() element: null \ No newline at end of file diff --git a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseTrivialGuard.txt b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseTrivialGuard.txt index d3b7dbfd8944..c58fbe9566b5 100644 --- a/python/testData/codeInsight/controlflow/MatchStatementSingleClauseTrivialGuard.txt +++ b/python/testData/codeInsight/controlflow/MatchStatementSingleClauseTrivialGuard.txt @@ -1,10 +1,11 @@ 0(1) element: null 1(2) element: PyMatchStatement 2(3) WRITE ACCESS: x -3(4,7) READ ACCESS: x -4(5) element: PyStatementList. Condition: x > 0:true -5(6) element: PyExpressionStatement -6(7) READ ACCESS: y -7(8) element: PyExpressionStatement -8(9) READ ACCESS: z -9() element: null \ No newline at end of file +3(4,5) READ ACCESS: x +4(8) element: null. Condition: x > 0:false +5(6) element: null. Condition: x > 0:true +6(7) element: PyExpressionStatement +7(8) READ ACCESS: y +8(9) element: PyExpressionStatement +9(10) READ ACCESS: z +10() element: null \ No newline at end of file diff --git a/python/testSrc/com/jetbrains/python/PyPatternTypeTest.java b/python/testSrc/com/jetbrains/python/PyPatternTypeTest.java new file mode 100644 index 000000000000..9c4f82a5303b --- /dev/null +++ b/python/testSrc/com/jetbrains/python/PyPatternTypeTest.java @@ -0,0 +1,466 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.jetbrains.python; + +import com.jetbrains.python.fixtures.PyInspectionTestCase; +import com.jetbrains.python.inspections.PyAssertTypeInspection; +import com.jetbrains.python.inspections.PyInspection; +import org.jetbrains.annotations.NotNull; + +public class PyPatternTypeTest extends PyInspectionTestCase { + + @Override + protected @NotNull Class getInspectionClass() { + return PyAssertTypeInspection.class; + } + + public void testMatchCapturePatternType() { + doTestByText(""" +from typing import assert_type +class A: ... +m: A + +match m: + case a: + assert_type(a, A) + """); + } + + public void testMatchLiteralPatternNarrows() { + doTestByText(""" +from typing import assert_type, Literal +m: object + +match m: + case 1: + assert_type(m, Literal[1]) + """); + } + + public void testMatchValuePatternNarrows() { + doTestByText(""" +from typing import assert_type +class B: + b: int + +m: object + +match m: + case B.b: + assert_type(m, int) + """); + } + + public void testMatchValuePatternAlreadyNarrower() { + doTestByText(""" +from typing import assert_type +class B: + b: int +m: bool + +match m: + case B.b: + assert_type(m, bool) + """); + } + + public void testMatchSequencePatternCaptures() { + doTestByText(""" +from typing import assert_type +m: list[int] + +match m: + case [a]: + assert_type(a, int) + """); + } + + public void testMatchSequencePatternCapturesStarred() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: Sequence[int] + +match m: + case [a, *b]: + assert_type(a, int) + assert_type(b, list[int]) + """); + } + + public void testMatchSequencePatternNarrowsInner() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: Sequence[object] + +match m: + case [1, True]: + assert_type(m, Sequence[int | bool]) + """); + } + + public void testMatchSequencePatternNarrowsOuter() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: object + +match m: + case [1, True]: + assert_type(m, Sequence[int | bool]) + """); + } + + public void testMatchSequencePatternAlreadyNarrowerInner() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: Sequence[bool] + +match m: + case [1, True]: + assert_type(m, Sequence[bool]) + """); + } + + public void testMatchSequencePatternAlreadyNarrowerOuter() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: Sequence[object] + +match m: + case [1, True]: + assert_type(m, Sequence[int | bool]) + """); + } + + public void testMatchSequencePatternAlreadyNarrowerBoth() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: Sequence[bool] + +match m: + case [1, True]: + assert_type(m, Sequence[bool]) + """); + } + + public void testMatchNestedSequencePatternNarrowsInner() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: Sequence[Sequence[object]] + +match m: + case [[1], [True]]: + assert_type(m, Sequence[Sequence[int] | Sequence[bool]]) + """); + } + + public void testMatchNestedSequencePatternNarrowsOuter() { + doTestByText(""" +from typing import assert_type +from typing import Sequence +m: object + +match m: + case [[1], [True]]: + assert_type(m, Sequence[Sequence[int] | Sequence[bool]]) + """); + } + + public void testMatchSequencePatternMatches() { + doTestByText(""" +from typing import assert_type +import array, collections +from typing import Sequence, Iterable + +m1: object +m2: Sequence[int] +m3: array.array[int] +m4: collections.deque[int] +m5: list[int] +m6: memoryview +m7: range +m8: tuple[int] + +m9: str +m10: bytes +m11: bytearray + +match m1: + case [a]: + assert_type(a, Any) + +match m2: + case [b]: + assert_type(b, int) + +match m3: + case [c]: + assert_type(c, int) + +match m4: + case [d]: + assert_type(d, int) + +match m5: + case [e]: + assert_type(e, int) + +match m6: + case [f]: + assert_type(f, int) + +match m7: + case [g]: + assert_type(g, int) + +match m8: + case [h]: + assert_type(h, int) + +match m9: + case [i]: + assert_type(i, Any) + +match m10: + case [j]: + assert_type(j, Any) + +match m11: + case [k]: + assert_type(k, Any) + """); + } + + public void testMatchSequencePatternCapturesTuple() { + doTestByText(""" +from typing import assert_type +m: tuple[int, str, bool] + +match m: + case [a, b, c]: + assert_type(a, int) + assert_type(b, str) + assert_type(c, bool) + assert_type(m, tuple[int, str, bool]) + """); + } + + public void testMatchSequencePatternTupleNarrows() { + doTestByText(""" +from typing import assert_type, Literal +m: tuple[object, object] + +match m: + case [1, "str"]: + assert_type(m, tuple[Literal[1], Literal['str']]) + """); + } + + public void testMatchSequencePatternTupleStarred() { + doTestByText(""" +from typing import assert_type, Literal +m: tuple[int, str, bool] + +match m: + case [a, *b, c]: + assert_type(a, int) + assert_type(b, list[str]) + assert_type(c, bool) + assert_type(m, tuple[int, str, bool]) + """); + } + + public void testMatchSequencePatternTupleStarredUnion() { + doTestByText(""" +from typing import assert_type +m: tuple[int, str, float, bool] + +match m: + case [a, *b, c]: + assert_type(a, int) + assert_type(b, list[str | float]) + assert_type(c, bool) + assert_type(m, tuple[int, str, float, bool]) + """); + } + + public void testMatchSequenceUnionSkip() { + doTestByText(""" +from typing import assert_type +from typing import List, Union +m: Union[List[List[str]], str] + +match m: + case [list(['str'])]: + assert_type(m, list[list[str]]) + """); + } + + public void testMatchMappingPatternCaptures() { + doTestByText(""" +from typing import Dict, assert_type +class B: + b: str +m: Dict[str, int] + +match m: + case {"key": v}: + assert_type(v, int) + +match m: + case {B.b: v2}: + assert_type(v2, int) + """); + } + + public void testMatchMappingPatternCapturesTypedDict() { + doTestByText(""" +from typing import TypedDict, Literal, assert_type + +class A(TypedDict): + a: str + b: int + +class K: + k: Literal['a'] + +m: A + +match m: + case {"a": v}: + assert_type(v, str) + case {"b": v2}: + assert_type(v2, int) + case {"a": v3, "b": v4}: + assert_type(v3, str) + assert_type(v4, int) + case {K.k: v5}: + assert_type(v5, str) + case {"o": v6}: + assert_type(v6, Any) + """); + } + + public void testMatchMappingPatternCaptureRest() { + doTestByText(""" +from typing import Mapping, assert_type + +m: object + +match m: + case {'k': 1, **r}: + assert_type(r, dict[str, int]) + +n: Mapping[str, int] + +match n: + case {'k': 1, **r}: + assert_type(r, dict[str, int]) + """); + } + + public void testMatchClassPatternCapturePositional() { + doTestByText(""" +from typing import assert_type + +class A: + __match_args__ = ("a", "b") + a: str + b: int + +m: A + +match m: + case A(i, j): + assert_type(i, str) + assert_type(j, int) + """); + } + + public void testMatchClassPatternCaptureKeyword() { + doTestByText(""" +from typing import assert_type + +class A: + a: str + b: int + +m: A + +match m: + case A(a=i, b=j): + assert_type(i, str) + assert_type(j, int) + """); + } + + public void testMatchClassPatternCaptureSelf() { + doTestByText(""" +from typing import assert_type + +m: object + +match m: + case bool(a): + assert_type(a, bool) + case bytearray(b): + assert_type(b, bytearray) + case bytes(c): + assert_type(c, bytes) + case dict(d): + assert_type(d, dict) + case float(e): + assert_type(e, float) + case frozenset(f): + assert_type(f, frozenset) + case int(g): + assert_type(g, int) + case list(h): + assert_type(h, list) + case set(i): + assert_type(i, set) + case str(j): + assert_type(j, str) + case tuple(k): + assert_type(k, tuple) + """ + ); + } + + public void testMatchClassPatternNarrowSelfCapture() { + doTestByText(""" + from typing import assert_type + + m: object + + match m: + case bool(): + assert_type(m, bool) + case bytearray(): + assert_type(m, bytearray) + case bytes(): + assert_type(m, bytes) + case dict(): + assert_type(m, dict) + case float(): + assert_type(m, float) + case frozenset(): + assert_type(m, frozenset) + case int(): + assert_type(m, int) + case list(): + assert_type(m, list) + case set(): + assert_type(m, set) + case str(): + assert_type(m, str) + case tuple(): + assert_type(m, tuple)""" + ); + } +} diff --git a/python/testSrc/com/jetbrains/python/PyTypeConversionTest.java b/python/testSrc/com/jetbrains/python/PyTypeConversionTest.java index 9fc6c4dca003..ff7b76ff7dd7 100644 --- a/python/testSrc/com/jetbrains/python/PyTypeConversionTest.java +++ b/python/testSrc/com/jetbrains/python/PyTypeConversionTest.java @@ -1,4 +1,4 @@ -// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +// Copyright 2000-2025 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. package com.jetbrains.python; import com.jetbrains.python.documentation.PythonDocumentationProvider;