mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-16 22:51:17 +07:00
PY-48011 Pattern Matching: Type inference
Merge-request: IJ-MR-154823 Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com> GitOrigin-RevId: 42cb07bee63f34127c85574fc9c09e6043bc7591
This commit is contained in:
committed by
intellij-monorepo-bot
parent
6860d4accc
commit
0073c7a8bb
@@ -8,6 +8,10 @@ interface PyAstAsPattern : PyAstPattern {
|
||||
return requireNotNull(findChildByClass(PyAstPattern::class.java)) { "${this}: pattern cannot be null" }
|
||||
}
|
||||
|
||||
fun getTarget(): PyAstTargetExpression {
|
||||
return requireNotNull(findChildByClass(PyAstTargetExpression::class.java)) { "${this}: target cannot be null" }
|
||||
}
|
||||
|
||||
override fun isIrrefutable(): Boolean {
|
||||
return getPattern().isIrrefutable
|
||||
}
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
package com.jetbrains.python.ast;
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.Objects;
|
||||
|
||||
import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass;
|
||||
|
||||
@ApiStatus.Experimental
|
||||
public interface PyAstValuePattern extends PyAstPattern {
|
||||
@@ -10,6 +15,11 @@ public interface PyAstValuePattern extends PyAstPattern {
|
||||
return false;
|
||||
}
|
||||
|
||||
@NotNull
|
||||
default PyAstReferenceExpression getValue() {
|
||||
return Objects.requireNonNull(findChildByClass(this, PyAstReferenceExpression.class));
|
||||
}
|
||||
|
||||
@Override
|
||||
default void acceptPyVisitor(PyAstElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyValuePattern(this);
|
||||
|
||||
@@ -186,6 +186,7 @@ public final @NonNls class PyNames {
|
||||
public static final String ROUND = "__round__";
|
||||
public static final String CLASS_GETITEM = "__class_getitem__";
|
||||
public static final String PREPARE = "__prepare__";
|
||||
public static final String MATCH_ARGS = "__match_args__";
|
||||
|
||||
public static final String NAME = "__name__";
|
||||
public static final String ENTER = "__enter__";
|
||||
|
||||
@@ -3,5 +3,5 @@ package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstPattern;
|
||||
|
||||
public interface PyPattern extends PyAstPattern, PyElement {
|
||||
public interface PyPattern extends PyAstPattern, PyTypedElement {
|
||||
}
|
||||
|
||||
@@ -2,6 +2,14 @@
|
||||
package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstSequencePattern;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass;
|
||||
|
||||
public interface PySequencePattern extends PyAstSequencePattern, PyPattern {
|
||||
default @NotNull List<@NotNull PyPattern> getElements() {
|
||||
return List.of(findChildrenByClass(this, PyPattern.class));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,6 +2,21 @@
|
||||
package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstSingleStarPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass;
|
||||
|
||||
public interface PySingleStarPattern extends PyAstSingleStarPattern, PyPattern {
|
||||
@NotNull
|
||||
default PyPattern getPattern() {
|
||||
return Objects.requireNonNull(findChildByClass(this, PyPattern.class));
|
||||
}
|
||||
|
||||
@NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context);
|
||||
}
|
||||
|
||||
@@ -2,6 +2,11 @@
|
||||
package com.jetbrains.python.psi;
|
||||
|
||||
import com.jetbrains.python.ast.PyAstValuePattern;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
public interface PyValuePattern extends PyAstValuePattern, PyPattern {
|
||||
@Override
|
||||
default @NotNull PyReferenceExpression getValue() {
|
||||
return (PyReferenceExpression)PyAstValuePattern.super.getValue();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiNamedElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.psi.util.QualifiedName;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.PyNames;
|
||||
import com.jetbrains.python.PyTokenTypes;
|
||||
import com.jetbrains.python.psi.*;
|
||||
@@ -330,7 +331,134 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
|
||||
|
||||
@Override
|
||||
public void visitPyMatchStatement(@NotNull PyMatchStatement matchStatement) {
|
||||
new PyMatchStatementControlFlowBuilder(myBuilder, this).build(matchStatement);
|
||||
myBuilder.startNode(matchStatement);
|
||||
PyExpression subject = matchStatement.getSubject();
|
||||
if (subject != null) {
|
||||
subject.accept(this);
|
||||
}
|
||||
for (PyCaseClause caseClause : matchStatement.getCaseClauses()) {
|
||||
visitPyCaseClause(caseClause);
|
||||
}
|
||||
myBuilder.addNodeAndCheckPending(new TransparentInstructionImpl(myBuilder, matchStatement, ""));
|
||||
if (!myBuilder.prevInstruction.allPred().isEmpty()) {
|
||||
addTypeAssertionNodes(matchStatement, false);
|
||||
}
|
||||
myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction);
|
||||
myBuilder.prevInstruction = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyCaseClause(@NotNull PyCaseClause clause) {
|
||||
PyPattern pattern = clause.getPattern();
|
||||
if (pattern != null) {
|
||||
pattern.accept(this);
|
||||
addTypeAssertionNodes(pattern, true);
|
||||
}
|
||||
|
||||
TransparentInstruction trueNode = addTransparentInstruction();
|
||||
TransparentInstruction falseNode = addTransparentInstruction();
|
||||
PyExpression guard = clause.getGuardCondition();
|
||||
if (guard != null) {
|
||||
visitCondition(guard, trueNode, falseNode);
|
||||
addTypeAssertionNodes(guard, true);
|
||||
}
|
||||
else {
|
||||
myBuilder.addEdge(myBuilder.prevInstruction, trueNode);
|
||||
}
|
||||
myBuilder.addPendingEdge(clause, falseNode);
|
||||
myBuilder.prevInstruction = trueNode;
|
||||
|
||||
clause.getStatementList().accept(this);
|
||||
|
||||
if (clause.getParent() instanceof PyMatchStatement matchStatement) {
|
||||
myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction);
|
||||
myBuilder.updatePendingElementScope(clause.getStatementList(), matchStatement);
|
||||
}
|
||||
myBuilder.prevInstruction = null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitWildcardPattern(@NotNull PyWildcardPattern node) {
|
||||
myBuilder.startNode(node);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyPattern(@NotNull PyPattern node) {
|
||||
boolean isRefutable = !node.isIrrefutable();
|
||||
if (isRefutable) {
|
||||
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
|
||||
myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction);
|
||||
}
|
||||
|
||||
node.acceptChildren(this);
|
||||
myBuilder.updatePendingElementScope(node, node.getParent());
|
||||
|
||||
if (isRefutable) {
|
||||
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyOrPattern(@NotNull PyOrPattern node) {
|
||||
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
|
||||
|
||||
TransparentInstruction onSuccess = new TransparentInstructionImpl(myBuilder, node, "onSuccess");
|
||||
List<PyPattern> alternatives = node.getAlternatives();
|
||||
PyPattern lastAlternative = ContainerUtil.getLastItem(alternatives);
|
||||
|
||||
for (PyPattern alternative : alternatives) {
|
||||
alternative.accept(this);
|
||||
if (alternative != lastAlternative) {
|
||||
// Allow next alternative to handle the fail edge of this alternative
|
||||
myBuilder.updatePendingElementScope(node, alternative);
|
||||
}
|
||||
myBuilder.addEdge(myBuilder.prevInstruction, onSuccess);
|
||||
myBuilder.prevInstruction = null;
|
||||
}
|
||||
myBuilder.addNode(onSuccess);
|
||||
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
|
||||
myBuilder.updatePendingElementScope(node, node.getParent());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyClassPattern(@NotNull PyClassPattern node) {
|
||||
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
|
||||
|
||||
node.getClassNameReference().accept(this);
|
||||
myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction);
|
||||
|
||||
node.getArgumentList().acceptChildren(this);
|
||||
myBuilder.updatePendingElementScope(node, node.getParent());
|
||||
|
||||
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyValuePattern(@NotNull PyValuePattern node) {
|
||||
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
|
||||
|
||||
node.getValue().accept(this);
|
||||
myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction);
|
||||
|
||||
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyAsPattern(@NotNull PyAsPattern node) {
|
||||
// AsPattern cannot fail by itself - it fails only if its child fails.
|
||||
// So no need to create additional fail edge
|
||||
myBuilder.startNode(node);
|
||||
node.acceptChildren(this);
|
||||
myBuilder.updatePendingElementScope(node, node.getParent());
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyGroupPattern(@NotNull PyGroupPattern node) {
|
||||
// GroupPattern cannot fail by itself - it fails only if its child fails.
|
||||
// So no need to create additional fail edge
|
||||
// Also no need to dedicated node for GroupPattern itself
|
||||
node.acceptChildren(this);
|
||||
myBuilder.updatePendingElementScope(node, node.getParent());
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -955,7 +1083,7 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
|
||||
|| element instanceof PyStatementList);
|
||||
}
|
||||
|
||||
private void addTypeAssertionNodes(@NotNull PyExpression condition, boolean positive) {
|
||||
private void addTypeAssertionNodes(@NotNull PyElement condition, boolean positive) {
|
||||
final PyTypeAssertionEvaluator evaluator = new PyTypeAssertionEvaluator(positive);
|
||||
condition.accept(evaluator);
|
||||
for (PyTypeAssertionEvaluator.Assertion def : evaluator.getDefinitions()) {
|
||||
|
||||
@@ -16,6 +16,7 @@ class PyControlFlowProvider : ControlFlowProvider {
|
||||
override fun getAdditionalInfo(instruction: Instruction): String? {
|
||||
return when (instruction) {
|
||||
is ReadWriteInstruction -> "${instruction.access} ${instruction.name}"
|
||||
is RefutablePatternInstruction -> instruction.elementPresentation
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,165 +0,0 @@
|
||||
package com.jetbrains.python.codeInsight.controlflow;
|
||||
|
||||
import com.intellij.codeInsight.controlflow.ConditionalInstruction;
|
||||
import com.intellij.codeInsight.controlflow.ControlFlowBuilder;
|
||||
import com.intellij.codeInsight.controlflow.Instruction;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.PyTokenTypes;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.function.BiFunction;
|
||||
|
||||
public final class PyMatchStatementControlFlowBuilder {
|
||||
private final ControlFlowBuilder myBuilder;
|
||||
private final PyElementVisitor myBaseVisitor;
|
||||
|
||||
public PyMatchStatementControlFlowBuilder(@NotNull ControlFlowBuilder builder, @NotNull PyElementVisitor baseVisitor) {
|
||||
myBuilder = builder;
|
||||
myBaseVisitor = baseVisitor;
|
||||
}
|
||||
|
||||
public void build(@NotNull PyMatchStatement matchStatement) {
|
||||
myBuilder.startNode(matchStatement);
|
||||
PyExpression subject = matchStatement.getSubject();
|
||||
if (subject != null) {
|
||||
subject.accept(myBaseVisitor);
|
||||
}
|
||||
for (PyCaseClause caseClause : matchStatement.getCaseClauses()) {
|
||||
processCaseClause(caseClause);
|
||||
}
|
||||
}
|
||||
|
||||
private void processCaseClause(@NotNull PyCaseClause clause) {
|
||||
PyPattern pattern = clause.getPattern();
|
||||
if (pattern != null) {
|
||||
processPattern(pattern);
|
||||
retargetOutgoingPatternEdges(pattern, (oldScope, instr) -> {
|
||||
return instr.isMatched() ? pattern : clause;
|
||||
});
|
||||
}
|
||||
PyStatementList statementList = clause.getStatementList();
|
||||
PyExpression guard = clause.getGuardCondition();
|
||||
if (guard != null) {
|
||||
guard.accept(myBaseVisitor);
|
||||
// Retarget failure edges coming from inner OR and AND expressions
|
||||
retargetOutgoingEdges(guard, (pendingScope, instr) -> {
|
||||
if (instr instanceof ConditionalInstruction && !((ConditionalInstruction)instr).getResult()) {
|
||||
return clause;
|
||||
}
|
||||
return pendingScope;
|
||||
});
|
||||
// Top-level OR and AND expressions should have had their own outgoing failure edges
|
||||
if (!isConjunctionOrDisjunction(guard)) {
|
||||
myBuilder.addPendingEdge(clause, myBuilder.prevInstruction);
|
||||
}
|
||||
myBuilder.startConditionalNode(statementList, guard, true);
|
||||
}
|
||||
statementList.accept(myBaseVisitor);
|
||||
PyMatchStatement matchStatement = PsiTreeUtil.getParentOfType(clause, PyMatchStatement.class);
|
||||
assert matchStatement != null;
|
||||
retargetOutgoingEdges(statementList, (pendingScope, instruction) -> matchStatement);
|
||||
myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction);
|
||||
myBuilder.prevInstruction = null;
|
||||
}
|
||||
|
||||
private void processPattern(@NotNull PyPattern pattern) {
|
||||
boolean isRefutable = !pattern.isIrrefutable();
|
||||
if (isRefutable) {
|
||||
RefutablePatternInstruction instruction = new RefutablePatternInstruction(myBuilder, pattern, false);
|
||||
myBuilder.addNodeAndCheckPending(instruction);
|
||||
myBuilder.addPendingEdge(pattern, instruction);
|
||||
}
|
||||
|
||||
if (pattern instanceof PyWildcardPattern) {
|
||||
myBuilder.startNode(pattern);
|
||||
}
|
||||
else if (pattern instanceof PyOrPattern) {
|
||||
List<PyPattern> alternatives = ((PyOrPattern)pattern).getAlternatives();
|
||||
PyPattern lastAlternative = ContainerUtil.getLastItem(alternatives);
|
||||
for (PyPattern alternative : alternatives) {
|
||||
processPattern(alternative);
|
||||
if (alternative != lastAlternative) {
|
||||
myBuilder.addPendingEdge(alternative, myBuilder.prevInstruction);
|
||||
myBuilder.prevInstruction = null;
|
||||
}
|
||||
retargetOutgoingEdges(alternative, (pendingScope, instruction) -> {
|
||||
if (instruction instanceof RefutablePatternInstruction && !((RefutablePatternInstruction)instruction).isMatched()) {
|
||||
// Pattern has failed, jump to the next alternative if any
|
||||
return alternative;
|
||||
}
|
||||
// Pattern succeeded, jump out of OR-pattern. It can be either a refutable pattern or a capture/wildcard node.
|
||||
else {
|
||||
return pattern;
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
else {
|
||||
pattern.acceptChildren(new PyElementVisitor() {
|
||||
@Override
|
||||
public void visitPyReferenceExpression(@NotNull PyReferenceExpression node) {
|
||||
myBaseVisitor.visitPyReferenceExpression(node);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyTargetExpression(@NotNull PyTargetExpression node) {
|
||||
myBaseVisitor.visitPyTargetExpression(node);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyPattern(@NotNull PyPattern node) {
|
||||
processPattern(node);
|
||||
retargetOutgoingPatternEdges(pattern, (oldScope, instr) -> {
|
||||
// Mismatch in a non-OR pattern means mismatch of the containing pattern as well
|
||||
return instr.isMatched() ? oldScope : pattern;
|
||||
});
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyPatternArgumentList(@NotNull PyPatternArgumentList node) {
|
||||
node.acceptChildren(this);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
if (isRefutable) {
|
||||
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, pattern, true));
|
||||
}
|
||||
}
|
||||
|
||||
private void retargetOutgoingEdges(@NotNull PsiElement scopeAncestor,
|
||||
@NotNull BiFunction<PsiElement, Instruction, PsiElement> newScopeProvider) {
|
||||
myBuilder.processPending((oldScope, instruction) -> {
|
||||
if (oldScope != null && PsiTreeUtil.isAncestor(scopeAncestor, oldScope, false)) {
|
||||
myBuilder.addPendingEdge(newScopeProvider.apply(oldScope, instruction), instruction);
|
||||
}
|
||||
else {
|
||||
myBuilder.addPendingEdge(oldScope, instruction);
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
private void retargetOutgoingPatternEdges(@NotNull PsiElement scopeAncestor,
|
||||
@NotNull BiFunction<PsiElement, RefutablePatternInstruction, PsiElement> newScopeProvider) {
|
||||
retargetOutgoingEdges(scopeAncestor, (pendingScope, instruction) -> {
|
||||
if (instruction instanceof RefutablePatternInstruction) {
|
||||
return newScopeProvider.apply(pendingScope, (RefutablePatternInstruction)instruction);
|
||||
}
|
||||
return pendingScope;
|
||||
});
|
||||
}
|
||||
|
||||
private static boolean isConjunctionOrDisjunction(@Nullable PyExpression node) {
|
||||
if (node instanceof PyBinaryExpression) {
|
||||
final var operator = ((PyBinaryExpression)node).getOperator();
|
||||
return operator == PyTokenTypes.AND_KEYWORD || operator == PyTokenTypes.OR_KEYWORD;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
@@ -3,6 +3,7 @@ package com.jetbrains.python.codeInsight.controlflow;
|
||||
|
||||
import com.intellij.openapi.util.Ref;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.intellij.util.containers.Stack;
|
||||
import com.jetbrains.python.PyNames;
|
||||
@@ -22,6 +23,8 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.function.Function;
|
||||
|
||||
import static com.jetbrains.python.psi.PyUtil.as;
|
||||
|
||||
public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
private final Stack<Assertion> myStack = new Stack<>();
|
||||
private boolean myPositive;
|
||||
@@ -41,14 +44,15 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
if (args.length == 2 && args[0] instanceof PyReferenceExpression target) {
|
||||
final PyExpression typeElement = args[1];
|
||||
|
||||
pushAssertion(target, myPositive, false, context -> context.getType(typeElement), typeElement);
|
||||
pushAssertion(target, myPositive, context ->
|
||||
transformTypeFromAssertion(context.getType(typeElement), false, context, typeElement));
|
||||
}
|
||||
}
|
||||
else if (node.isCalleeText(PyNames.CALLABLE_BUILTIN)) {
|
||||
final PyExpression[] args = node.getArguments();
|
||||
if (args.length == 1 && args[0] instanceof PyReferenceExpression target) {
|
||||
|
||||
pushAssertion(target, myPositive, false, context -> PyTypingTypeProvider.createTypingCallableType(node), null);
|
||||
pushAssertion(target, myPositive, context -> PyTypingTypeProvider.createTypingCallableType(node));
|
||||
}
|
||||
}
|
||||
else if (node.isCalleeText(PyNames.ISSUBCLASS)) {
|
||||
@@ -56,7 +60,8 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
if (args.length == 2 && args[0] instanceof PyReferenceExpression target) {
|
||||
final PyExpression typeElement = args[1];
|
||||
|
||||
pushAssertion(target, myPositive, true, context -> context.getType(typeElement), typeElement);
|
||||
pushAssertion(target, myPositive, context ->
|
||||
transformTypeFromAssertion(context.getType(typeElement), true, context, typeElement));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -66,7 +71,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
if (myPositive && (isIfReferenceStatement(node) || isIfReferenceConditionalStatement(node) || isIfNotReferenceStatement(node))) {
|
||||
// we could not suggest `None` because it could be a reference to an empty collection
|
||||
// so we could push only non-`None` assertions
|
||||
pushAssertion(node, !myPositive, false, context -> PyNoneType.INSTANCE, null);
|
||||
pushAssertion(node, !myPositive, context -> PyNoneType.INSTANCE);
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -105,26 +110,26 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
|
||||
if (PyLiteralType.isNone(lhs)) {
|
||||
if (rhs instanceof PyReferenceExpression referenceExpr) {
|
||||
pushAssertion(referenceExpr, myPositive, false, context -> PyNoneType.INSTANCE, null);
|
||||
pushAssertion(referenceExpr, myPositive, context -> PyNoneType.INSTANCE);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (PyLiteralType.isNone(rhs)) {
|
||||
if (lhs instanceof PyReferenceExpression referenceExpr) {
|
||||
pushAssertion(referenceExpr, myPositive, false, context -> PyNoneType.INSTANCE, null);
|
||||
pushAssertion(referenceExpr, myPositive, context -> PyNoneType.INSTANCE);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (lhs instanceof PyReferenceExpression referenceExpr) {
|
||||
pushAssertion(referenceExpr, myPositive, false, context -> getLiteralType(rhs, context), null);
|
||||
pushAssertion(referenceExpr, myPositive, context -> getLiteralType(rhs, context));
|
||||
}
|
||||
}
|
||||
|
||||
private void processIn(@NotNull PyExpression lhs, @NotNull PyExpression rhs) {
|
||||
if (lhs instanceof PyReferenceExpression referenceExpr && rhs instanceof PyTupleExpression tupleExpr) {
|
||||
pushAssertion(referenceExpr, myPositive, false, context -> {
|
||||
pushAssertion(referenceExpr, myPositive, (TypeEvalContext context) -> {
|
||||
PyExpression[] elements = tupleExpr.getElements();
|
||||
List<PyType> types = new ArrayList<>(elements.length);
|
||||
for (PyExpression element : elements) {
|
||||
@@ -135,7 +140,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
types.add(type);
|
||||
}
|
||||
return PyUnionType.union(types);
|
||||
}, null);
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,6 +165,41 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyPattern(@NotNull PyPattern node) {
|
||||
final PsiElement parent = PsiTreeUtil.skipParentsOfType(node, PyCaseClause.class);
|
||||
if (parent instanceof PyMatchStatement matchStatement) {
|
||||
final PyExpression subject = PyPsiUtils.flattenParens(matchStatement.getSubject());
|
||||
|
||||
if (subject instanceof PyReferenceExpression target) {
|
||||
pushAssertion(target, myPositive, context -> context.getType(node));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Negative type assertion for when all cases fail
|
||||
*/
|
||||
@Override
|
||||
public void visitPyMatchStatement(@NotNull PyMatchStatement matchStatement) {
|
||||
assert !myPositive; // for match statement as a whole, only negative assertion can be made
|
||||
final PyExpression subject = matchStatement.getSubject();
|
||||
if (subject == null) return;
|
||||
|
||||
if (subject instanceof PyReferenceExpression target) {
|
||||
pushAssertion(target, true, context -> {
|
||||
PyType subjectType = context.getType(subject);
|
||||
for (PyCaseClause cs : matchStatement.getCaseClauses()) {
|
||||
if (cs.getPattern() == null) continue;
|
||||
if (cs.getGuardCondition() != null) continue;
|
||||
subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, context));
|
||||
}
|
||||
|
||||
return subjectType;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
public static @Nullable Ref<PyType> createAssertionType(@Nullable PyType initial,
|
||||
@Nullable PyType suggested,
|
||||
@@ -233,6 +273,9 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
PyTypeChecker.match(expected, actual, context);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param transformToDefinition if true the result type will be Type[T], not T itself.
|
||||
*/
|
||||
private static @Nullable PyType transformTypeFromAssertion(@Nullable PyType type, boolean transformToDefinition, @NotNull TypeEvalContext context,
|
||||
@Nullable PyExpression typeElement) {
|
||||
/*
|
||||
@@ -245,8 +288,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
final List<PyType> members = new ArrayList<>();
|
||||
final int count = tupleType.getElementCount();
|
||||
|
||||
final PyTupleExpression tupleExpression = PyUtil
|
||||
.as(PyPsiUtils.flattenParens(PyUtil.as(typeElement, PyParenthesizedExpression.class)), PyTupleExpression.class);
|
||||
final PyTupleExpression tupleExpression = as(PyPsiUtils.flattenParens(typeElement), PyTupleExpression.class);
|
||||
if (tupleExpression != null && tupleExpression.getElements().length == count) {
|
||||
final PyExpression[] elements = tupleExpression.getElements();
|
||||
for (int i = 0; i < count; i++) {
|
||||
@@ -276,21 +318,13 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
|
||||
return type;
|
||||
}
|
||||
|
||||
/**
|
||||
* @param transformToDefinition if true the result type will be Type[T], not T itself.
|
||||
*/
|
||||
private void pushAssertion(@NotNull PyReferenceExpression target,
|
||||
boolean positive,
|
||||
boolean transformToDefinition,
|
||||
@NotNull Function<TypeEvalContext, PyType> suggestedType,
|
||||
@Nullable PyExpression typeElement) {
|
||||
@NotNull Function<TypeEvalContext, PyType> suggestedType) {
|
||||
final InstructionTypeCallback typeCallback = new InstructionTypeCallback() {
|
||||
@Override
|
||||
public Ref<PyType> getType(TypeEvalContext context) {
|
||||
return createAssertionType(context.getType(target),
|
||||
transformTypeFromAssertion(suggestedType.apply(context), transformToDefinition, context, typeElement),
|
||||
positive,
|
||||
context);
|
||||
return createAssertionType(context.getType(target), suggestedType.apply(context), positive, context);
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
@@ -128,10 +128,8 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
||||
PARAM_SPEC, PARAM_SPEC_EXT,
|
||||
TYPE_VAR_TUPLE, TYPE_VAR_TUPLE_EXT
|
||||
);
|
||||
|
||||
public static final String CONTEXT_MANAGER = "contextlib.AbstractContextManager";
|
||||
public static final String CONTEXT_MANAGER = "contextlib.AbstractContextManager";
|
||||
public static final String ASYNC_CONTEXT_MANAGER = "contextlib.AbstractAsyncContextManager";
|
||||
|
||||
public static final Set<String> TYPE_DICT_QUALIFIERS = Set.of(REQUIRED, REQUIRED_EXT, NOT_REQUIRED, NOT_REQUIRED_EXT, READONLY, READONLY_EXT);
|
||||
|
||||
public static final String UNPACK = "typing.Unpack";
|
||||
|
||||
@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyAsPattern;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyAsPatternImpl extends PyElementImpl implements PyAsPattern {
|
||||
public PyAsPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +17,9 @@ public class PyAsPatternImpl extends PyElementImpl implements PyAsPattern {
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyAsPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return context.getType(getPattern());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,24 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyCapturePattern;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.intellij.openapi.util.Ref;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.util.PsiTreeUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.PyNames;
|
||||
import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext;
|
||||
import com.jetbrains.python.psi.resolve.RatedResolveResult;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import static com.jetbrains.python.psi.PyUtil.as;
|
||||
import static com.jetbrains.python.psi.impl.PySequencePatternImpl.wrapInListType;
|
||||
|
||||
public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePattern {
|
||||
public PyCapturePatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +29,158 @@ public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePatt
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyCapturePattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return getCaptureType(this, context);
|
||||
}
|
||||
|
||||
static Set<String> SPECIAL_BUILTINS = Set.of(
|
||||
"bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple");
|
||||
|
||||
/**
|
||||
* Determines the type of a given pattern assuming it is a capture pattern (even when it is actually not),
|
||||
* and looking up (to parents or subject expression of the match statement).
|
||||
*/
|
||||
static @Nullable PyType getCaptureType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
|
||||
final PyElement parentPattern = PsiTreeUtil.getParentOfType(
|
||||
pattern, // Capture corresponds to:
|
||||
PyCaseClause.class, // - Subject of a match statement
|
||||
PySingleStarPattern.class, // - Type of parent sequence pattern
|
||||
PyDoubleStarPattern.class, // - Type of parent mapping pattern
|
||||
PyKeyValuePattern.class, // - Any
|
||||
PySequencePattern.class, // - Iterated item type of sequence
|
||||
PyClassPattern.class, // - Attribute type in the corresponding class
|
||||
PyKeywordPattern.class // - Attribute type in the corresponding class
|
||||
);
|
||||
|
||||
if (parentPattern instanceof PyCaseClause caseClause) {
|
||||
final PyMatchStatement matchStatement = as(caseClause.getParent(), PyMatchStatement.class);
|
||||
if (matchStatement == null) return null;
|
||||
|
||||
final PyExpression subject = matchStatement.getSubject();
|
||||
if (subject == null) return null;
|
||||
|
||||
PyType subjectType = context.getType(subject);
|
||||
for (PyCaseClause cs : matchStatement.getCaseClauses()) {
|
||||
if (cs == caseClause) break;
|
||||
if (cs.getPattern() == null) continue;
|
||||
if (cs.getGuardCondition() != null) continue;
|
||||
subjectType = Ref.deref(
|
||||
PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.getPattern()), false, context));
|
||||
}
|
||||
|
||||
return subjectType;
|
||||
}
|
||||
if (parentPattern instanceof PySingleStarPattern starPattern) {
|
||||
final PySequencePattern sequenceParent = as(starPattern.getParent(), PySequencePattern.class);
|
||||
if (sequenceParent == null) return null;
|
||||
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequenceParent, context);
|
||||
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
|
||||
.flatMap(it -> starPattern.getCapturedTypesFromSequenceType(it, context).stream()).collect(PyTypeUtil.toUnion());
|
||||
return wrapInListType(iteratedType, pattern);
|
||||
}
|
||||
if (parentPattern instanceof PyDoubleStarPattern) {
|
||||
final PyMappingPattern mappingParent = as(parentPattern.getParent(), PyMappingPattern.class);
|
||||
if (mappingParent == null) return null;
|
||||
var parentType = context.getType(mappingParent);
|
||||
if (parentType instanceof PyCollectionType collectionType) {
|
||||
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
|
||||
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (parentPattern instanceof PyKeyValuePattern keyValuePattern) {
|
||||
final PyMappingPattern mappingParent = as(keyValuePattern.getParent(), PyMappingPattern.class);
|
||||
if (mappingParent == null) return null;
|
||||
|
||||
var dictType = getCaptureType(mappingParent, context);
|
||||
if (dictType == null) return null;
|
||||
|
||||
if (dictType instanceof PyTypedDictType typedDictType) {
|
||||
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l && l.getExpression() instanceof PyStringLiteralExpression str) {
|
||||
return typedDictType.getElementType(str.getStringValue());
|
||||
}
|
||||
}
|
||||
var mappingType = PyTypeUtil.convertToType(dictType, "typing.Mapping", pattern, context);
|
||||
if (mappingType instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getElementTypes().get(1);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (parentPattern instanceof PySequencePattern sequencePattern) {
|
||||
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequencePattern, context);
|
||||
if (sequenceType == null) return null;
|
||||
return PyTypeUtil.toStream(sequenceType).map(it -> {
|
||||
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
// This is done to skip group- and as-patterns
|
||||
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> el.getParent() == sequencePattern);
|
||||
final List<PyPattern> elements = sequencePattern.getElements();
|
||||
final int idx = elements.indexOf(sequenceMember);
|
||||
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
|
||||
if (starIdx == -1 || idx < starIdx) {
|
||||
return tupleType.getElementType(idx);
|
||||
}
|
||||
else {
|
||||
final int starSpan = tupleType.getElementCount() - elements.size();
|
||||
return tupleType.getElementType(idx + starSpan);
|
||||
}
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return collectionType.getIteratedItemType();
|
||||
}
|
||||
return null;
|
||||
}).collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
if (parentPattern instanceof PyClassPattern classPattern) {
|
||||
if (context.getType(classPattern) instanceof PyClassType classType) {
|
||||
final List<PyPattern> arguments = classPattern.getArgumentList().getPatterns();
|
||||
int index = arguments.indexOf(pattern);
|
||||
if (index < 0) return null;
|
||||
|
||||
if (SPECIAL_BUILTINS.contains(classType.getClassQName())) {
|
||||
if (index == 0) {
|
||||
return context.getType(classPattern);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
final PyTargetExpression matchArgs = as(resolveTypeMember(classType, PyNames.MATCH_ARGS, context), PyTargetExpression.class);
|
||||
if (matchArgs == null) return null;
|
||||
|
||||
var matchArgsValue = PyPsiUtils.flattenParens(matchArgs.findAssignedValue());
|
||||
if (matchArgsValue instanceof PySequenceExpression sequenceExpression) {
|
||||
if (sequenceExpression.getElements().length <= index) return null;
|
||||
|
||||
final String attributeName = PyEvaluator.evaluate(sequenceExpression.getElements()[index], String.class);
|
||||
if (attributeName == null) return null;
|
||||
|
||||
final PyExpression instanceAttribute = as(resolveTypeMember(classType, attributeName, context), PyExpression.class);
|
||||
if (instanceAttribute == null) return null;
|
||||
return context.getType(instanceAttribute);
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
if (parentPattern instanceof PyKeywordPattern keywordPattern) {
|
||||
final PyClassPattern classPattern = PsiTreeUtil.getParentOfType(keywordPattern, PyClassPattern.class);
|
||||
if (classPattern == null) return null;
|
||||
|
||||
if (context.getType(classPattern) instanceof PyClassType classType) {
|
||||
final PyExpression instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyExpression.class);
|
||||
if (instanceAttribute == null) return null;
|
||||
return context.getType(instanceAttribute);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
private static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) {
|
||||
final PyResolveContext resolveContext = PyResolveContext.defaultContext(context);
|
||||
final List<? extends RatedResolveResult> results = type.resolveMember(name, null, AccessDirection.READ, resolveContext);
|
||||
return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,12 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyClassPattern;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.types.PyClassType;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.PyTypeChecker;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern {
|
||||
public PyClassPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +19,17 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyClassPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
final PyType type = context.getType(getClassNameReference());
|
||||
if (type instanceof PyClassType classType) {
|
||||
final PyType instanceType = classType.toInstance();
|
||||
final PyType captureType = PyCapturePatternImpl.getCaptureType(this, context);
|
||||
if (PyTypeChecker.match(captureType, instanceType, context)) {
|
||||
return instanceType;
|
||||
}
|
||||
}
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyDoubleStarPattern;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyDoubleStarPatternImpl extends PyElementImpl implements PyDoubleStarPattern {
|
||||
public PyDoubleStarPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +17,9 @@ public class PyDoubleStarPatternImpl extends PyElementImpl implements PyDoubleSt
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyDoubleStarPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyGroupPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyGroupPatternImpl extends PyElementImpl implements PyGroupPattern {
|
||||
public PyGroupPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +17,9 @@ public class PyGroupPatternImpl extends PyElementImpl implements PyGroupPattern
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyGroupPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return context.getType(getPattern());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,13 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyKeyValuePattern;
|
||||
import com.jetbrains.python.psi.PyPattern;
|
||||
import com.jetbrains.python.psi.types.PyTupleType;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
import java.util.Arrays;
|
||||
|
||||
public class PyKeyValuePatternImpl extends PyElementImpl implements PyKeyValuePattern {
|
||||
public PyKeyValuePatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +20,15 @@ public class PyKeyValuePatternImpl extends PyElementImpl implements PyKeyValuePa
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyKeyValuePattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
|
||||
final PyType keyType = context.getType(getKeyPattern());
|
||||
final PyPattern value = getValuePattern();
|
||||
PyType valueType = null;
|
||||
if (value != null) {
|
||||
valueType = context.getType(value);
|
||||
}
|
||||
return PyTupleType.create(this, Arrays.asList(keyType, valueType));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,12 @@ import com.intellij.lang.ASTNode;
|
||||
import com.intellij.psi.PsiReference;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyKeywordPattern;
|
||||
import com.jetbrains.python.psi.PyPattern;
|
||||
import com.jetbrains.python.psi.impl.references.PyKeywordPatternReference;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyKeywordPatternImpl extends PyElementImpl implements PyKeywordPattern {
|
||||
public PyKeywordPatternImpl(ASTNode astNode) {
|
||||
@@ -20,4 +25,10 @@ public class PyKeywordPatternImpl extends PyElementImpl implements PyKeywordPatt
|
||||
public PsiReference getReference() {
|
||||
return new PyKeywordPatternReference(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
final PyPattern valuePattern = getValuePattern();
|
||||
return valuePattern != null ? context.getType(valuePattern) : null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,11 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyLiteralPattern;
|
||||
import com.jetbrains.python.psi.types.PyLiteralType;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyLiteralPatternImpl extends PyElementImpl implements PyLiteralPattern {
|
||||
public PyLiteralPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +18,13 @@ public class PyLiteralPatternImpl extends PyElementImpl implements PyLiteralPatt
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyLiteralPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
PyType literalType = PyLiteralType.Companion.fromLiteralParameter(getExpression(), context);
|
||||
if (literalType != null) {
|
||||
return literalType;
|
||||
}
|
||||
return context.getType(getExpression());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,11 +3,12 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiListLikeElement;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyKeyValuePattern;
|
||||
import com.jetbrains.python.psi.PyMappingPattern;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
|
||||
@@ -22,7 +23,30 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<? extends PsiElement> getComponents() {
|
||||
public @NotNull List<? extends PyKeyValuePattern> getComponents() {
|
||||
return Arrays.asList(findChildrenByClass(PyKeyValuePattern.class));
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
ArrayList<PyType> keyTypes = new ArrayList<>();
|
||||
ArrayList<PyType> valueTypes = new ArrayList<>();
|
||||
for (PyKeyValuePattern it : getComponents()) {
|
||||
PyType type = context.getType(it);
|
||||
if (type instanceof PyTupleType tupleType) {
|
||||
keyTypes.add(tupleType.getElementType(0));
|
||||
valueTypes.add(tupleType.getElementType(1));
|
||||
}
|
||||
}
|
||||
//keyTypes.add(null);
|
||||
//valueTypes.add(null);
|
||||
return wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this);
|
||||
}
|
||||
|
||||
private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) {
|
||||
keyType = PyLiteralType.upcastLiteralToClass(keyType);
|
||||
valueType = PyLiteralType.upcastLiteralToClass(valueType);
|
||||
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Mapping", resolveAnchor);
|
||||
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Arrays.asList(keyType, valueType)) : null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyOrPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.PyUnionType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyOrPatternImpl extends PyElementImpl implements PyOrPattern {
|
||||
public PyOrPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +19,9 @@ public class PyOrPatternImpl extends PyElementImpl implements PyOrPattern {
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyOrPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return PyUnionType.union(ContainerUtil.map(getAlternatives(), it -> context.getType(it)));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,13 +3,16 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiListLikeElement;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyPattern;
|
||||
import com.jetbrains.python.psi.PySequencePattern;
|
||||
import com.intellij.util.ArrayUtil;
|
||||
import com.intellij.util.containers.ContainerUtil;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.*;
|
||||
|
||||
import static com.jetbrains.python.psi.types.PyLiteralType.upcastLiteralToClass;
|
||||
|
||||
public class PySequencePatternImpl extends PyElementImpl implements PySequencePattern, PsiListLikeElement {
|
||||
public PySequencePatternImpl(ASTNode astNode) {
|
||||
@@ -22,7 +25,74 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<? extends PsiElement> getComponents() {
|
||||
return Arrays.asList(findChildrenByClass(PyPattern.class));
|
||||
public @NotNull List<? extends PyPattern> getComponents() {
|
||||
return getElements();
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
final PyType sequenceCaptureType = getSequenceCaptureType(this, context);
|
||||
boolean isHomogeneous = !(sequenceCaptureType instanceof PyTupleType tupleType) || tupleType.isHomogeneous();
|
||||
final ArrayList<PyType> types = new ArrayList<>();
|
||||
for (PyPattern pattern : getElements()) {
|
||||
if (pattern instanceof PySingleStarPattern starPattern) {
|
||||
types.addAll(starPattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context));
|
||||
}
|
||||
else {
|
||||
types.add(context.getType(pattern));
|
||||
}
|
||||
}
|
||||
PyType expectedType = isHomogeneous ? wrapInSequenceType(PyUnionType.union(types), this) : PyTupleType.create(this, types);
|
||||
PyType captureType = PyCapturePatternImpl.getCaptureType(this, context);
|
||||
if (captureType == null) return expectedType;
|
||||
return PyTypeUtil.toStream(captureType)
|
||||
.map(it -> {
|
||||
if (PyTypeChecker.match(expectedType, it, context)) {
|
||||
return it;
|
||||
}
|
||||
return expectedType;
|
||||
})
|
||||
.collect(PyTypeUtil.toUnion());
|
||||
}
|
||||
|
||||
static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
|
||||
final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list");
|
||||
return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static PyType wrapInSequenceType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
|
||||
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Sequence", resolveAnchor);
|
||||
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
|
||||
}
|
||||
|
||||
/**
|
||||
* Similar to {@link PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)},
|
||||
* but only chooses types that would match to typing.Sequence, and have correct length
|
||||
*/
|
||||
@Nullable
|
||||
static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) {
|
||||
final PyType captureTypes = PyCapturePatternImpl.getCaptureType(pattern, context);
|
||||
final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern);
|
||||
|
||||
List<PyType> types = new ArrayList<>();
|
||||
for (PyType captureType : PyTypeUtil.toStream(captureTypes)) {
|
||||
if (captureType instanceof PyClassType classType &&
|
||||
ArrayUtil.contains(classType.getClassQName(), "str", "bytes", "bytearray")) continue;
|
||||
|
||||
PyType sequenceType = PyTypeUtil.convertToType(captureType, "typing.Sequence", pattern, context);
|
||||
if (sequenceType == null) continue;
|
||||
|
||||
if (captureType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
final List<PyPattern> elements = pattern.getElements();
|
||||
if (hasStar && elements.size() <= tupleType.getElementCount() || elements.size() == tupleType.getElementCount()) {
|
||||
types.add(captureType);
|
||||
}
|
||||
}
|
||||
else {
|
||||
types.add(captureType);
|
||||
}
|
||||
}
|
||||
return PyUnionType.union(types);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,13 @@
|
||||
package com.jetbrains.python.psi.impl;
|
||||
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PySingleStarPattern;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.*;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
public class PySingleStarPatternImpl extends PyElementImpl implements PySingleStarPattern {
|
||||
public PySingleStarPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +18,24 @@ public class PySingleStarPatternImpl extends PyElementImpl implements PySingleSt
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPySingleStarPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return null;
|
||||
}
|
||||
|
||||
@Override
|
||||
public @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context) {
|
||||
if (getParent() instanceof PySequencePattern sequenceParent) {
|
||||
final int idx = sequenceParent.getElements().indexOf(this);
|
||||
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
|
||||
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
|
||||
}
|
||||
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context);
|
||||
if (upcast instanceof PyCollectionType collectionType) {
|
||||
return Collections.singletonList(collectionType.getIteratedItemType());
|
||||
}
|
||||
}
|
||||
return List.of();
|
||||
}
|
||||
}
|
||||
|
||||
@@ -154,8 +154,11 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
|
||||
}
|
||||
}
|
||||
}
|
||||
if (parent instanceof PyWithItem) {
|
||||
return getWithItemVariableType((PyWithItem)parent, context);
|
||||
if (parent instanceof PyWithItem withItem) {
|
||||
return getWithItemVariableType(withItem, context);
|
||||
}
|
||||
if (parent instanceof PyPattern pattern) {
|
||||
return context.getType(pattern);
|
||||
}
|
||||
if (parent instanceof PyAssignmentExpression) {
|
||||
final PyExpression assignedValue = ((PyAssignmentExpression)parent).getAssignedValue();
|
||||
|
||||
@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyValuePattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyValuePatternImpl extends PyElementImpl implements PyValuePattern {
|
||||
public PyValuePatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +17,9 @@ public class PyValuePatternImpl extends PyElementImpl implements PyValuePattern
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitPyValuePattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return context.getType(getValue());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
|
||||
import com.intellij.lang.ASTNode;
|
||||
import com.jetbrains.python.psi.PyElementVisitor;
|
||||
import com.jetbrains.python.psi.PyWildcardPattern;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
public class PyWildcardPatternImpl extends PyElementImpl implements PyWildcardPattern {
|
||||
public PyWildcardPatternImpl(ASTNode astNode) {
|
||||
@@ -13,4 +17,9 @@ public class PyWildcardPatternImpl extends PyElementImpl implements PyWildcardPa
|
||||
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
|
||||
pyVisitor.visitWildcardPattern(this);
|
||||
}
|
||||
|
||||
@Override
|
||||
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
|
||||
return PyCapturePatternImpl.getCaptureType(this, context);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -26,7 +26,7 @@ object PyCollectionTypeUtil {
|
||||
private fun getListOrSetIteratedValueType(sequence: PySequenceExpression, context: TypeEvalContext): PyType? {
|
||||
val elements = sequence.elements
|
||||
val analyzedElementsType = PyUnionType.union(
|
||||
elements.take(MAX_ANALYZED_ELEMENTS_OF_LITERALS).map { replaceLiteralWithItsClass(context.getType(it)) }
|
||||
elements.take(MAX_ANALYZED_ELEMENTS_OF_LITERALS).map { PyLiteralType.upcastLiteralToClass(context.getType(it)) }
|
||||
)
|
||||
return if (elements.size > MAX_ANALYZED_ELEMENTS_OF_LITERALS) {
|
||||
PyUnionType.createWeakType(analyzedElementsType)
|
||||
@@ -66,7 +66,7 @@ object PyCollectionTypeUtil {
|
||||
if (keyExpression !is PyStringLiteralExpression) {
|
||||
return null
|
||||
}
|
||||
strKeysToValueTypes[keyExpression.stringValue] = Pair(element.value, replaceLiteralWithItsClass(valueType))
|
||||
strKeysToValueTypes[keyExpression.stringValue] = Pair(element.value, PyLiteralType.upcastLiteralToClass(valueType))
|
||||
}
|
||||
|
||||
return strKeysToValueTypes
|
||||
@@ -82,8 +82,8 @@ object PyCollectionTypeUtil {
|
||||
.forEach {
|
||||
val type = context.getType(it)
|
||||
val (keyType, valueType) = getKeyValueType(type)
|
||||
keyTypes.add(replaceLiteralWithItsClass(keyType))
|
||||
valueTypes.add(replaceLiteralWithItsClass(valueType))
|
||||
keyTypes.add(PyLiteralType.upcastLiteralToClass(keyType))
|
||||
valueTypes.add(PyLiteralType.upcastLiteralToClass(valueType))
|
||||
}
|
||||
|
||||
if (elements.size > MAX_ANALYZED_ELEMENTS_OF_LITERALS) {
|
||||
@@ -107,13 +107,4 @@ object PyCollectionTypeUtil {
|
||||
}
|
||||
return null to null
|
||||
}
|
||||
|
||||
private fun replaceLiteralWithItsClass(type: PyType?): PyType? {
|
||||
return when (type) {
|
||||
is PyUnionType -> type.map(::replaceLiteralWithItsClass)
|
||||
is PyLiteralStringType -> PyClassTypeImpl(type.cls, false)
|
||||
is PyLiteralType -> PyClassTypeImpl(type.pyClass, false)
|
||||
else -> type
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import com.jetbrains.python.psi.*
|
||||
import com.jetbrains.python.psi.impl.PyEvaluator
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import org.jetbrains.annotations.ApiStatus.Internal
|
||||
|
||||
|
||||
/**
|
||||
@@ -56,6 +57,17 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
|
||||
return PyLiteralType(enumClass, expression)
|
||||
}
|
||||
|
||||
@Internal
|
||||
@JvmStatic
|
||||
fun upcastLiteralToClass(type: PyType?): PyType? {
|
||||
return when (type) {
|
||||
is PyUnionType -> type.map(::upcastLiteralToClass)
|
||||
is PyLiteralStringType -> PyClassTypeImpl(type.cls, false)
|
||||
is PyLiteralType -> PyClassTypeImpl(type.pyClass, false)
|
||||
else -> type
|
||||
}
|
||||
}
|
||||
|
||||
private fun promoteToType(
|
||||
expectedType: PyType?,
|
||||
expression: PyExpression,
|
||||
|
||||
@@ -1,16 +1,18 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyWhileStatement
|
||||
2(3,4) READ ACCESS: x
|
||||
3(13) element: null. Condition: x:false
|
||||
3(15) element: null. Condition: x:false
|
||||
4(5) element: null. Condition: x:true
|
||||
5(6) element: PyStatementList
|
||||
6(7) element: PyMatchStatement
|
||||
7(8) READ ACCESS: x
|
||||
8(9,11) refutable pattern: 42
|
||||
8(9,12) refutable pattern: 42
|
||||
9(10) matched pattern: 42
|
||||
10(13) element: PyBreakStatement
|
||||
11(12) element: PyExpressionStatement
|
||||
12(1) READ ACCESS: y
|
||||
10(11) ASSERTTYPE ACCESS: x
|
||||
11(15) element: PyBreakStatement
|
||||
12(13) ASSERTTYPE ACCESS: x
|
||||
13(14) element: PyExpressionStatement
|
||||
14(15) READ ACCESS: z
|
||||
15() element: null
|
||||
14(1) READ ACCESS: y
|
||||
15(16) element: PyExpressionStatement
|
||||
16(17) READ ACCESS: z
|
||||
17() element: null
|
||||
@@ -1,16 +1,18 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyWhileStatement
|
||||
2(3,4) READ ACCESS: x
|
||||
3(13) element: null. Condition: x:false
|
||||
3(15) element: null. Condition: x:false
|
||||
4(5) element: null. Condition: x:true
|
||||
5(6) element: PyStatementList
|
||||
6(7) element: PyMatchStatement
|
||||
7(8) READ ACCESS: x
|
||||
8(9,11) refutable pattern: 42
|
||||
8(9,12) refutable pattern: 42
|
||||
9(10) matched pattern: 42
|
||||
10(1) element: PyContinueStatement
|
||||
11(12) element: PyExpressionStatement
|
||||
12(1) READ ACCESS: y
|
||||
10(11) ASSERTTYPE ACCESS: x
|
||||
11(1) element: PyContinueStatement
|
||||
12(13) ASSERTTYPE ACCESS: x
|
||||
13(14) element: PyExpressionStatement
|
||||
14(15) READ ACCESS: z
|
||||
15() element: null
|
||||
14(1) READ ACCESS: y
|
||||
15(16) element: PyExpressionStatement
|
||||
16(17) READ ACCESS: z
|
||||
17() element: null
|
||||
@@ -2,9 +2,11 @@
|
||||
1(2) WRITE ACCESS: x
|
||||
2(3) element: PyMatchStatement
|
||||
3(4) READ ACCESS: x
|
||||
4(5,7) refutable pattern: 42
|
||||
4(5,8) refutable pattern: 42
|
||||
5(6) matched pattern: 42
|
||||
6(9) element: PyReturnStatement
|
||||
7(8) element: PyExpressionStatement
|
||||
8(9) READ ACCESS: y
|
||||
9() element: null
|
||||
6(7) ASSERTTYPE ACCESS: x
|
||||
7(11) element: PyReturnStatement
|
||||
8(9) ASSERTTYPE ACCESS: x
|
||||
9(10) element: PyExpressionStatement
|
||||
10(11) READ ACCESS: y
|
||||
11() element: null
|
||||
@@ -1,19 +1,18 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,16) refutable pattern: [42] | foo.bar as x
|
||||
3(4,16) refutable pattern: [42] | foo.bar
|
||||
2(3) element: PyAsPattern
|
||||
3(4) refutable pattern: [42] | foo.bar
|
||||
4(5,8) refutable pattern: [42]
|
||||
5(6,8) refutable pattern: 42
|
||||
6(7) matched pattern: 42
|
||||
7(12) matched pattern: [42]
|
||||
8(9,16) refutable pattern: foo.bar
|
||||
9(10) READ ACCESS: foo
|
||||
7(11) matched pattern: [42]
|
||||
8(9) refutable pattern: foo.bar
|
||||
9(10,15) READ ACCESS: foo
|
||||
10(11) matched pattern: foo.bar
|
||||
11(12) matched pattern: [42] | foo.bar
|
||||
12(13) WRITE ACCESS: x
|
||||
13(14) matched pattern: [42] | foo.bar as x
|
||||
14(15) element: PyExpressionStatement
|
||||
15(16) READ ACCESS: y
|
||||
16(17) element: PyExpressionStatement
|
||||
17(18) READ ACCESS: z
|
||||
18() element: null
|
||||
13(14) element: PyExpressionStatement
|
||||
14(15) READ ACCESS: y
|
||||
15(16) element: PyExpressionStatement
|
||||
16(17) READ ACCESS: z
|
||||
17() element: null
|
||||
@@ -1,12 +1,12 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,14) refutable pattern: Class(1, attr=foo.bar)
|
||||
3(4) READ ACCESS: Class
|
||||
2(3) refutable pattern: Class(1, attr=foo.bar)
|
||||
3(4,14) READ ACCESS: Class
|
||||
4(5,14) refutable pattern: 1
|
||||
5(6) matched pattern: 1
|
||||
6(7,14) refutable pattern: attr=foo.bar
|
||||
7(8,14) refutable pattern: foo.bar
|
||||
8(9) READ ACCESS: foo
|
||||
7(8) refutable pattern: foo.bar
|
||||
8(9,14) READ ACCESS: foo
|
||||
9(10) matched pattern: foo.bar
|
||||
10(11) matched pattern: attr=foo.bar
|
||||
11(12) matched pattern: Class(1, attr=foo.bar)
|
||||
|
||||
@@ -3,14 +3,13 @@
|
||||
2(3) WRITE ACCESS: x
|
||||
3(4) element: PyBinaryExpression
|
||||
4(5,6) READ ACCESS: x
|
||||
5(10) element: null. Condition: x > 0:false
|
||||
5(12) element: null. Condition: x > 0:false
|
||||
6(7) element: null. Condition: x > 0:true
|
||||
7(8,9) READ ACCESS: x
|
||||
8(10) element: null. Condition: x < 10:false
|
||||
8(12) element: null. Condition: x < 10:false
|
||||
9(10) element: null. Condition: x < 10:true
|
||||
10(11) element: PyStatementList. Condition: x > 0 and x < 10:true
|
||||
11(12) element: PyExpressionStatement
|
||||
12(13) READ ACCESS: y
|
||||
13(14) element: PyExpressionStatement
|
||||
14(15) READ ACCESS: z
|
||||
15() element: null
|
||||
10(11) element: PyExpressionStatement
|
||||
11(12) READ ACCESS: y
|
||||
12(13) element: PyExpressionStatement
|
||||
13(14) READ ACCESS: z
|
||||
14() element: null
|
||||
@@ -3,18 +3,17 @@
|
||||
2(3) WRITE ACCESS: x
|
||||
3(4) element: PyBinaryExpression
|
||||
4(5,6) READ ACCESS: x
|
||||
5(14) element: null. Condition: x % 4 == 0:false
|
||||
5(16) element: null. Condition: x % 4 == 0:false
|
||||
6(7) element: null. Condition: x % 4 == 0:true
|
||||
7(8) element: PyBinaryExpression
|
||||
8(9,10) READ ACCESS: x
|
||||
9(11) element: null. Condition: x % 400 == 0:false
|
||||
10(14) element: null. Condition: x % 400 == 0:true
|
||||
11(12,13) READ ACCESS: x
|
||||
12(14) element: null. Condition: x % 100 != 0:false
|
||||
12(16) element: null. Condition: x % 100 != 0:false
|
||||
13(14) element: null. Condition: x % 100 != 0:true
|
||||
14(15) element: PyStatementList. Condition: x % 4 == 0 and (x % 400 == 0 or x % 100 != 0):true
|
||||
15(16) element: PyExpressionStatement
|
||||
16(17) READ ACCESS: y
|
||||
17(18) element: PyExpressionStatement
|
||||
18(19) READ ACCESS: z
|
||||
19() element: null
|
||||
14(15) element: PyExpressionStatement
|
||||
15(16) READ ACCESS: y
|
||||
16(17) element: PyExpressionStatement
|
||||
17(18) READ ACCESS: z
|
||||
18() element: null
|
||||
@@ -6,11 +6,10 @@
|
||||
5(7) element: null. Condition: x > 0:false
|
||||
6(10) element: null. Condition: x > 0:true
|
||||
7(8,9) READ ACCESS: x
|
||||
8(10) element: null. Condition: x < 0:false
|
||||
8(12) element: null. Condition: x < 0:false
|
||||
9(10) element: null. Condition: x < 0:true
|
||||
10(11) element: PyStatementList. Condition: x > 0 or x < 0:true
|
||||
11(12) element: PyExpressionStatement
|
||||
12(13) READ ACCESS: y
|
||||
13(14) element: PyExpressionStatement
|
||||
14(15) READ ACCESS: z
|
||||
15() element: null
|
||||
10(11) element: PyExpressionStatement
|
||||
11(12) READ ACCESS: y
|
||||
12(13) element: PyExpressionStatement
|
||||
13(14) READ ACCESS: z
|
||||
14() element: null
|
||||
@@ -1,6 +1,6 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,18) refutable pattern: [x1, x2, x3]
|
||||
2(3,19) refutable pattern: [x1, x2, x3]
|
||||
3(4) WRITE ACCESS: x1
|
||||
4(5) WRITE ACCESS: x2
|
||||
5(6) WRITE ACCESS: x3
|
||||
@@ -8,14 +8,15 @@
|
||||
7(8) element: PyBinaryExpression
|
||||
8(9,10) READ ACCESS: x1
|
||||
9(11) element: null. Condition: x1:false
|
||||
10(14) element: null. Condition: x1:true
|
||||
11(12,13) READ ACCESS: x2
|
||||
12(14) element: null. Condition: x2:false
|
||||
13(14) element: null. Condition: x2:true
|
||||
14(15,18) READ ACCESS: x3
|
||||
15(16) element: PyStatementList. Condition: (x1 or x2) > x3:true
|
||||
16(17) element: PyExpressionStatement
|
||||
17(18) READ ACCESS: y
|
||||
18(19) element: PyExpressionStatement
|
||||
19(20) READ ACCESS: z
|
||||
20() element: null
|
||||
10(17) element: null. Condition: x1:true
|
||||
11(12,13,14) READ ACCESS: x2
|
||||
12(19) element: null. Condition: x2:false
|
||||
13(17) element: null. Condition: x2:true
|
||||
14(15,16) READ ACCESS: x3
|
||||
15(19) element: null. Condition: (x1 or x2) > x3:false
|
||||
16(17) element: null. Condition: (x1 or x2) > x3:true
|
||||
17(18) element: PyExpressionStatement
|
||||
18(19) READ ACCESS: y
|
||||
19(20) element: PyExpressionStatement
|
||||
20(21) READ ACCESS: z
|
||||
21() element: null
|
||||
@@ -1,11 +1,13 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(6) WRITE ACCESS: x
|
||||
3(4,8) refutable pattern: [x]
|
||||
4(5) WRITE ACCESS: x
|
||||
5(6) matched pattern: [x]
|
||||
6(7) element: PyExpressionStatement
|
||||
7(8) READ ACCESS: y
|
||||
2(3) refutable pattern: x | [x]
|
||||
3(7) WRITE ACCESS: x
|
||||
4(5,10) refutable pattern: [x]
|
||||
5(6) WRITE ACCESS: x
|
||||
6(7) matched pattern: [x]
|
||||
7(8) matched pattern: x | [x]
|
||||
8(9) element: PyExpressionStatement
|
||||
9(10) READ ACCESS: z
|
||||
10() element: null
|
||||
9(10) READ ACCESS: y
|
||||
10(11) element: PyExpressionStatement
|
||||
11(12) READ ACCESS: z
|
||||
12() element: null
|
||||
@@ -1,11 +1,13 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,5) refutable pattern: [x]
|
||||
3(4) WRITE ACCESS: x
|
||||
4(6) matched pattern: [x]
|
||||
5(6) WRITE ACCESS: x
|
||||
6(7) element: PyExpressionStatement
|
||||
7(8) READ ACCESS: y
|
||||
2(3) refutable pattern: [x] | x
|
||||
3(4,6) refutable pattern: [x]
|
||||
4(5) WRITE ACCESS: x
|
||||
5(7) matched pattern: [x]
|
||||
6(7) WRITE ACCESS: x
|
||||
7(8) matched pattern: [x] | x
|
||||
8(9) element: PyExpressionStatement
|
||||
9(10) READ ACCESS: z
|
||||
10() element: null
|
||||
9(10) READ ACCESS: y
|
||||
10(11) element: PyExpressionStatement
|
||||
11(12) READ ACCESS: z
|
||||
12() element: null
|
||||
@@ -10,8 +10,8 @@
|
||||
9(10,19) refutable pattern: 'bar': foo.bar
|
||||
10(11,19) refutable pattern: 'bar'
|
||||
11(12) matched pattern: 'bar'
|
||||
12(13,19) refutable pattern: foo.bar
|
||||
13(14) READ ACCESS: foo
|
||||
12(13) refutable pattern: foo.bar
|
||||
13(14,19) READ ACCESS: foo
|
||||
14(15) matched pattern: foo.bar
|
||||
15(16) matched pattern: 'bar': foo.bar
|
||||
16(17) matched pattern: {'foo': 1, 'bar': foo.bar}
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,16) refutable pattern: 1 | (2 | 3)
|
||||
2(3) refutable pattern: 1 | (2 | 3)
|
||||
3(4,5) refutable pattern: 1
|
||||
4(14) matched pattern: 1
|
||||
5(6,16) refutable pattern: (2 | 3)
|
||||
6(7,16) refutable pattern: 2 | 3
|
||||
7(8,9) refutable pattern: 2
|
||||
8(14) matched pattern: 2
|
||||
9(10,16) refutable pattern: 3
|
||||
10(11) matched pattern: 3
|
||||
11(12) matched pattern: 2 | 3
|
||||
12(13) matched pattern: (2 | 3)
|
||||
13(14) matched pattern: 1 | (2 | 3)
|
||||
4(11) matched pattern: 1
|
||||
5(6) refutable pattern: 2 | 3
|
||||
6(7,8) refutable pattern: 2
|
||||
7(10) matched pattern: 2
|
||||
8(9,14) refutable pattern: 3
|
||||
9(10) matched pattern: 3
|
||||
10(11) matched pattern: 2 | 3
|
||||
11(12) matched pattern: 1 | (2 | 3)
|
||||
12(13) element: PyExpressionStatement
|
||||
13(14) READ ACCESS: x
|
||||
14(15) element: PyExpressionStatement
|
||||
15(16) READ ACCESS: x
|
||||
16(17) element: PyExpressionStatement
|
||||
17(18) READ ACCESS: y
|
||||
18() element: null
|
||||
15(16) READ ACCESS: y
|
||||
16() element: null
|
||||
@@ -1,20 +1,17 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,17) refutable pattern: [x] | (foo.bar as x)
|
||||
2(3) refutable pattern: [x] | (foo.bar as x)
|
||||
3(4,6) refutable pattern: [x]
|
||||
4(5) WRITE ACCESS: x
|
||||
5(15) matched pattern: [x]
|
||||
6(7,17) refutable pattern: (foo.bar as x)
|
||||
7(8,17) refutable pattern: foo.bar as x
|
||||
8(9,17) refutable pattern: foo.bar
|
||||
9(10) READ ACCESS: foo
|
||||
10(11) matched pattern: foo.bar
|
||||
11(12) WRITE ACCESS: x
|
||||
12(13) matched pattern: foo.bar as x
|
||||
13(14) matched pattern: (foo.bar as x)
|
||||
14(15) matched pattern: [x] | (foo.bar as x)
|
||||
15(16) element: PyExpressionStatement
|
||||
16(17) READ ACCESS: y
|
||||
17(18) element: PyExpressionStatement
|
||||
18(19) READ ACCESS: z
|
||||
19() element: null
|
||||
5(11) matched pattern: [x]
|
||||
6(7) element: PyAsPattern
|
||||
7(8) refutable pattern: foo.bar
|
||||
8(9,14) READ ACCESS: foo
|
||||
9(10) matched pattern: foo.bar
|
||||
10(11) WRITE ACCESS: x
|
||||
11(12) matched pattern: [x] | (foo.bar as x)
|
||||
12(13) element: PyExpressionStatement
|
||||
13(14) READ ACCESS: y
|
||||
14(15) element: PyExpressionStatement
|
||||
15(16) READ ACCESS: z
|
||||
16() element: null
|
||||
@@ -1,8 +1,8 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,10) refutable pattern: [] | 42
|
||||
2(3) refutable pattern: [] | 42
|
||||
3(4,5) refutable pattern: []
|
||||
4(8) matched pattern: []
|
||||
4(7) matched pattern: []
|
||||
5(6,10) refutable pattern: 42
|
||||
6(7) matched pattern: 42
|
||||
7(8) matched pattern: [] | 42
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(5) element: PyWildcardPattern
|
||||
3(4,7) refutable pattern: 42
|
||||
4(5) matched pattern: 42
|
||||
5(6) element: PyExpressionStatement
|
||||
6(7) READ ACCESS: y
|
||||
2(3) refutable pattern: _ | 42
|
||||
3(6) element: PyWildcardPattern
|
||||
4(5,9) refutable pattern: 42
|
||||
5(6) matched pattern: 42
|
||||
6(7) matched pattern: _ | 42
|
||||
7(8) element: PyExpressionStatement
|
||||
8(9) READ ACCESS: z
|
||||
9() element: null
|
||||
8(9) READ ACCESS: y
|
||||
9(10) element: PyExpressionStatement
|
||||
10(11) READ ACCESS: z
|
||||
11() element: null
|
||||
@@ -1,18 +1,17 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,15) refutable pattern: [x]
|
||||
2(3,14) refutable pattern: [x]
|
||||
3(4) WRITE ACCESS: x
|
||||
4(5) matched pattern: [x]
|
||||
5(6) element: PyBinaryExpression
|
||||
6(7,8) READ ACCESS: x
|
||||
7(12) element: null. Condition: x > 0:false
|
||||
7(14) element: null. Condition: x > 0:false
|
||||
8(9) element: null. Condition: x > 0:true
|
||||
9(10,11) READ ACCESS: x
|
||||
10(12) element: null. Condition: x % 2 == 0:false
|
||||
10(14) element: null. Condition: x % 2 == 0:false
|
||||
11(12) element: null. Condition: x % 2 == 0:true
|
||||
12(13) element: PyStatementList. Condition: x > 0 and x % 2 == 0:true
|
||||
13(14) element: PyExpressionStatement
|
||||
14(15) READ ACCESS: y
|
||||
15(16) element: PyExpressionStatement
|
||||
16(17) READ ACCESS: z
|
||||
17() element: null
|
||||
12(13) element: PyExpressionStatement
|
||||
13(14) READ ACCESS: y
|
||||
14(15) element: PyExpressionStatement
|
||||
15(16) READ ACCESS: z
|
||||
16() element: null
|
||||
@@ -3,8 +3,8 @@
|
||||
2(3,11) refutable pattern: [1, foo.bar]
|
||||
3(4,11) refutable pattern: 1
|
||||
4(5) matched pattern: 1
|
||||
5(6,11) refutable pattern: foo.bar
|
||||
6(7) READ ACCESS: foo
|
||||
5(6) refutable pattern: foo.bar
|
||||
6(7,11) READ ACCESS: foo
|
||||
7(8) matched pattern: foo.bar
|
||||
8(9) matched pattern: [1, foo.bar]
|
||||
9(10) element: PyExpressionStatement
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3,9) refutable pattern: [1 | x]
|
||||
3(4,5) refutable pattern: 1
|
||||
4(7) matched pattern: 1
|
||||
5(6) WRITE ACCESS: x
|
||||
6(7) matched pattern: [1 | x]
|
||||
7(8) element: PyExpressionStatement
|
||||
8(9) READ ACCESS: y
|
||||
2(3,11) refutable pattern: [1 | x]
|
||||
3(4) refutable pattern: 1 | x
|
||||
4(5,6) refutable pattern: 1
|
||||
5(7) matched pattern: 1
|
||||
6(7) WRITE ACCESS: x
|
||||
7(8) matched pattern: 1 | x
|
||||
8(9) matched pattern: [1 | x]
|
||||
9(10) element: PyExpressionStatement
|
||||
10(11) READ ACCESS: z
|
||||
11() element: null
|
||||
10(11) READ ACCESS: y
|
||||
11(12) element: PyExpressionStatement
|
||||
12(13) READ ACCESS: z
|
||||
13() element: null
|
||||
@@ -1,10 +1,11 @@
|
||||
0(1) element: null
|
||||
1(2) element: PyMatchStatement
|
||||
2(3) WRITE ACCESS: x
|
||||
3(4,7) READ ACCESS: x
|
||||
4(5) element: PyStatementList. Condition: x > 0:true
|
||||
5(6) element: PyExpressionStatement
|
||||
6(7) READ ACCESS: y
|
||||
7(8) element: PyExpressionStatement
|
||||
8(9) READ ACCESS: z
|
||||
9() element: null
|
||||
3(4,5) READ ACCESS: x
|
||||
4(8) element: null. Condition: x > 0:false
|
||||
5(6) element: null. Condition: x > 0:true
|
||||
6(7) element: PyExpressionStatement
|
||||
7(8) READ ACCESS: y
|
||||
8(9) element: PyExpressionStatement
|
||||
9(10) READ ACCESS: z
|
||||
10() element: null
|
||||
466
python/testSrc/com/jetbrains/python/PyPatternTypeTest.java
Normal file
466
python/testSrc/com/jetbrains/python/PyPatternTypeTest.java
Normal file
@@ -0,0 +1,466 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.jetbrains.python;
|
||||
|
||||
import com.jetbrains.python.fixtures.PyInspectionTestCase;
|
||||
import com.jetbrains.python.inspections.PyAssertTypeInspection;
|
||||
import com.jetbrains.python.inspections.PyInspection;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
|
||||
public class PyPatternTypeTest extends PyInspectionTestCase {
|
||||
|
||||
@Override
|
||||
protected @NotNull Class<? extends PyInspection> getInspectionClass() {
|
||||
return PyAssertTypeInspection.class;
|
||||
}
|
||||
|
||||
public void testMatchCapturePatternType() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
class A: ...
|
||||
m: A
|
||||
|
||||
match m:
|
||||
case a:
|
||||
assert_type(a, A)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchLiteralPatternNarrows() {
|
||||
doTestByText("""
|
||||
from typing import assert_type, Literal
|
||||
m: object
|
||||
|
||||
match m:
|
||||
case 1:
|
||||
assert_type(m, Literal[1])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchValuePatternNarrows() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
class B:
|
||||
b: int
|
||||
|
||||
m: object
|
||||
|
||||
match m:
|
||||
case B.b:
|
||||
assert_type(m, int)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchValuePatternAlreadyNarrower() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
class B:
|
||||
b: int
|
||||
m: bool
|
||||
|
||||
match m:
|
||||
case B.b:
|
||||
assert_type(m, bool)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternCaptures() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
m: list[int]
|
||||
|
||||
match m:
|
||||
case [a]:
|
||||
assert_type(a, int)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternCapturesStarred() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: Sequence[int]
|
||||
|
||||
match m:
|
||||
case [a, *b]:
|
||||
assert_type(a, int)
|
||||
assert_type(b, list[int])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternNarrowsInner() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: Sequence[object]
|
||||
|
||||
match m:
|
||||
case [1, True]:
|
||||
assert_type(m, Sequence[int | bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternNarrowsOuter() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: object
|
||||
|
||||
match m:
|
||||
case [1, True]:
|
||||
assert_type(m, Sequence[int | bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternAlreadyNarrowerInner() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: Sequence[bool]
|
||||
|
||||
match m:
|
||||
case [1, True]:
|
||||
assert_type(m, Sequence[bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternAlreadyNarrowerOuter() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: Sequence[object]
|
||||
|
||||
match m:
|
||||
case [1, True]:
|
||||
assert_type(m, Sequence[int | bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternAlreadyNarrowerBoth() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: Sequence[bool]
|
||||
|
||||
match m:
|
||||
case [1, True]:
|
||||
assert_type(m, Sequence[bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchNestedSequencePatternNarrowsInner() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: Sequence[Sequence[object]]
|
||||
|
||||
match m:
|
||||
case [[1], [True]]:
|
||||
assert_type(m, Sequence[Sequence[int] | Sequence[bool]])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchNestedSequencePatternNarrowsOuter() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import Sequence
|
||||
m: object
|
||||
|
||||
match m:
|
||||
case [[1], [True]]:
|
||||
assert_type(m, Sequence[Sequence[int] | Sequence[bool]])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternMatches() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
import array, collections
|
||||
from typing import Sequence, Iterable
|
||||
|
||||
m1: object
|
||||
m2: Sequence[int]
|
||||
m3: array.array[int]
|
||||
m4: collections.deque[int]
|
||||
m5: list[int]
|
||||
m6: memoryview
|
||||
m7: range
|
||||
m8: tuple[int]
|
||||
|
||||
m9: str
|
||||
m10: bytes
|
||||
m11: bytearray
|
||||
|
||||
match m1:
|
||||
case [a]:
|
||||
assert_type(a, Any)
|
||||
|
||||
match m2:
|
||||
case [b]:
|
||||
assert_type(b, int)
|
||||
|
||||
match m3:
|
||||
case [c]:
|
||||
assert_type(c, int)
|
||||
|
||||
match m4:
|
||||
case [d]:
|
||||
assert_type(d, int)
|
||||
|
||||
match m5:
|
||||
case [e]:
|
||||
assert_type(e, int)
|
||||
|
||||
match m6:
|
||||
case [f]:
|
||||
assert_type(f, int)
|
||||
|
||||
match m7:
|
||||
case [g]:
|
||||
assert_type(g, int)
|
||||
|
||||
match m8:
|
||||
case [h]:
|
||||
assert_type(h, int)
|
||||
|
||||
match m9:
|
||||
case [i]:
|
||||
assert_type(i, Any)
|
||||
|
||||
match m10:
|
||||
case [j]:
|
||||
assert_type(j, Any)
|
||||
|
||||
match m11:
|
||||
case [k]:
|
||||
assert_type(k, Any)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternCapturesTuple() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
m: tuple[int, str, bool]
|
||||
|
||||
match m:
|
||||
case [a, b, c]:
|
||||
assert_type(a, int)
|
||||
assert_type(b, str)
|
||||
assert_type(c, bool)
|
||||
assert_type(m, tuple[int, str, bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternTupleNarrows() {
|
||||
doTestByText("""
|
||||
from typing import assert_type, Literal
|
||||
m: tuple[object, object]
|
||||
|
||||
match m:
|
||||
case [1, "str"]:
|
||||
assert_type(m, tuple[Literal[1], Literal['str']])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternTupleStarred() {
|
||||
doTestByText("""
|
||||
from typing import assert_type, Literal
|
||||
m: tuple[int, str, bool]
|
||||
|
||||
match m:
|
||||
case [a, *b, c]:
|
||||
assert_type(a, int)
|
||||
assert_type(b, list[str])
|
||||
assert_type(c, bool)
|
||||
assert_type(m, tuple[int, str, bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequencePatternTupleStarredUnion() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
m: tuple[int, str, float, bool]
|
||||
|
||||
match m:
|
||||
case [a, *b, c]:
|
||||
assert_type(a, int)
|
||||
assert_type(b, list[str | float])
|
||||
assert_type(c, bool)
|
||||
assert_type(m, tuple[int, str, float, bool])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchSequenceUnionSkip() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
from typing import List, Union
|
||||
m: Union[List[List[str]], str]
|
||||
|
||||
match m:
|
||||
case [list(['str'])]:
|
||||
assert_type(m, list[list[str]])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchMappingPatternCaptures() {
|
||||
doTestByText("""
|
||||
from typing import Dict, assert_type
|
||||
class B:
|
||||
b: str
|
||||
m: Dict[str, int]
|
||||
|
||||
match m:
|
||||
case {"key": v}:
|
||||
assert_type(v, int)
|
||||
|
||||
match m:
|
||||
case {B.b: v2}:
|
||||
assert_type(v2, int)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchMappingPatternCapturesTypedDict() {
|
||||
doTestByText("""
|
||||
from typing import TypedDict, Literal, assert_type
|
||||
|
||||
class A(TypedDict):
|
||||
a: str
|
||||
b: int
|
||||
|
||||
class K:
|
||||
k: Literal['a']
|
||||
|
||||
m: A
|
||||
|
||||
match m:
|
||||
case {"a": v}:
|
||||
assert_type(v, str)
|
||||
case {"b": v2}:
|
||||
assert_type(v2, int)
|
||||
case {"a": v3, "b": v4}:
|
||||
assert_type(v3, str)
|
||||
assert_type(v4, int)
|
||||
case {K.k: v5}:
|
||||
assert_type(v5, str)
|
||||
case {"o": v6}:
|
||||
assert_type(v6, Any)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchMappingPatternCaptureRest() {
|
||||
doTestByText("""
|
||||
from typing import Mapping, assert_type
|
||||
|
||||
m: object
|
||||
|
||||
match m:
|
||||
case {'k': 1, **r}:
|
||||
assert_type(r, dict[str, int])
|
||||
|
||||
n: Mapping[str, int]
|
||||
|
||||
match n:
|
||||
case {'k': 1, **r}:
|
||||
assert_type(r, dict[str, int])
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchClassPatternCapturePositional() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
|
||||
class A:
|
||||
__match_args__ = ("a", "b")
|
||||
a: str
|
||||
b: int
|
||||
|
||||
m: A
|
||||
|
||||
match m:
|
||||
case A(i, j):
|
||||
assert_type(i, str)
|
||||
assert_type(j, int)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchClassPatternCaptureKeyword() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
|
||||
class A:
|
||||
a: str
|
||||
b: int
|
||||
|
||||
m: A
|
||||
|
||||
match m:
|
||||
case A(a=i, b=j):
|
||||
assert_type(i, str)
|
||||
assert_type(j, int)
|
||||
""");
|
||||
}
|
||||
|
||||
public void testMatchClassPatternCaptureSelf() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
|
||||
m: object
|
||||
|
||||
match m:
|
||||
case bool(a):
|
||||
assert_type(a, bool)
|
||||
case bytearray(b):
|
||||
assert_type(b, bytearray)
|
||||
case bytes(c):
|
||||
assert_type(c, bytes)
|
||||
case dict(d):
|
||||
assert_type(d, dict)
|
||||
case float(e):
|
||||
assert_type(e, float)
|
||||
case frozenset(f):
|
||||
assert_type(f, frozenset)
|
||||
case int(g):
|
||||
assert_type(g, int)
|
||||
case list(h):
|
||||
assert_type(h, list)
|
||||
case set(i):
|
||||
assert_type(i, set)
|
||||
case str(j):
|
||||
assert_type(j, str)
|
||||
case tuple(k):
|
||||
assert_type(k, tuple)
|
||||
"""
|
||||
);
|
||||
}
|
||||
|
||||
public void testMatchClassPatternNarrowSelfCapture() {
|
||||
doTestByText("""
|
||||
from typing import assert_type
|
||||
|
||||
m: object
|
||||
|
||||
match m:
|
||||
case bool():
|
||||
assert_type(m, bool)
|
||||
case bytearray():
|
||||
assert_type(m, bytearray)
|
||||
case bytes():
|
||||
assert_type(m, bytes)
|
||||
case dict():
|
||||
assert_type(m, dict)
|
||||
case float():
|
||||
assert_type(m, float)
|
||||
case frozenset():
|
||||
assert_type(m, frozenset)
|
||||
case int():
|
||||
assert_type(m, int)
|
||||
case list():
|
||||
assert_type(m, list)
|
||||
case set():
|
||||
assert_type(m, set)
|
||||
case str():
|
||||
assert_type(m, str)
|
||||
case tuple():
|
||||
assert_type(m, tuple)"""
|
||||
);
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
// Copyright 2000-2025 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.jetbrains.python;
|
||||
|
||||
import com.jetbrains.python.documentation.PythonDocumentationProvider;
|
||||
|
||||
Reference in New Issue
Block a user