PY-48011 Pattern Matching: Type inference

Merge-request: IJ-MR-154823
Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com>

GitOrigin-RevId: 42cb07bee63f34127c85574fc9c09e6043bc7591
This commit is contained in:
Aleksandr.Govenko
2025-03-07 22:56:00 +00:00
committed by intellij-monorepo-bot
parent 6860d4accc
commit 0073c7a8bb
51 changed files with 1301 additions and 377 deletions

View File

@@ -8,6 +8,10 @@ interface PyAstAsPattern : PyAstPattern {
return requireNotNull(findChildByClass(PyAstPattern::class.java)) { "${this}: pattern cannot be null" }
}
fun getTarget(): PyAstTargetExpression {
return requireNotNull(findChildByClass(PyAstTargetExpression::class.java)) { "${this}: target cannot be null" }
}
override fun isIrrefutable(): Boolean {
return getPattern().isIrrefutable
}

View File

@@ -2,6 +2,11 @@
package com.jetbrains.python.ast;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
import java.util.Objects;
import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass;
@ApiStatus.Experimental
public interface PyAstValuePattern extends PyAstPattern {
@@ -10,6 +15,11 @@ public interface PyAstValuePattern extends PyAstPattern {
return false;
}
@NotNull
default PyAstReferenceExpression getValue() {
return Objects.requireNonNull(findChildByClass(this, PyAstReferenceExpression.class));
}
@Override
default void acceptPyVisitor(PyAstElementVisitor pyVisitor) {
pyVisitor.visitPyValuePattern(this);

View File

@@ -186,6 +186,7 @@ public final @NonNls class PyNames {
public static final String ROUND = "__round__";
public static final String CLASS_GETITEM = "__class_getitem__";
public static final String PREPARE = "__prepare__";
public static final String MATCH_ARGS = "__match_args__";
public static final String NAME = "__name__";
public static final String ENTER = "__enter__";

View File

@@ -3,5 +3,5 @@ package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstPattern;
public interface PyPattern extends PyAstPattern, PyElement {
public interface PyPattern extends PyAstPattern, PyTypedElement {
}

View File

@@ -2,6 +2,14 @@
package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstSequencePattern;
import org.jetbrains.annotations.NotNull;
import java.util.List;
import static com.jetbrains.python.ast.PyAstElementKt.findChildrenByClass;
public interface PySequencePattern extends PyAstSequencePattern, PyPattern {
default @NotNull List<@NotNull PyPattern> getElements() {
return List.of(findChildrenByClass(this, PyPattern.class));
}
}

View File

@@ -2,6 +2,21 @@
package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstSingleStarPattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import java.util.Objects;
import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass;
public interface PySingleStarPattern extends PyAstSingleStarPattern, PyPattern {
@NotNull
default PyPattern getPattern() {
return Objects.requireNonNull(findChildByClass(this, PyPattern.class));
}
@NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context);
}

View File

@@ -2,6 +2,11 @@
package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstValuePattern;
import org.jetbrains.annotations.NotNull;
public interface PyValuePattern extends PyAstValuePattern, PyPattern {
@Override
default @NotNull PyReferenceExpression getValue() {
return (PyReferenceExpression)PyAstValuePattern.super.getValue();
}
}

View File

@@ -27,6 +27,7 @@ import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiNamedElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.QualifiedName;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.psi.*;
@@ -330,7 +331,134 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
@Override
public void visitPyMatchStatement(@NotNull PyMatchStatement matchStatement) {
new PyMatchStatementControlFlowBuilder(myBuilder, this).build(matchStatement);
myBuilder.startNode(matchStatement);
PyExpression subject = matchStatement.getSubject();
if (subject != null) {
subject.accept(this);
}
for (PyCaseClause caseClause : matchStatement.getCaseClauses()) {
visitPyCaseClause(caseClause);
}
myBuilder.addNodeAndCheckPending(new TransparentInstructionImpl(myBuilder, matchStatement, ""));
if (!myBuilder.prevInstruction.allPred().isEmpty()) {
addTypeAssertionNodes(matchStatement, false);
}
myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction);
myBuilder.prevInstruction = null;
}
@Override
public void visitPyCaseClause(@NotNull PyCaseClause clause) {
PyPattern pattern = clause.getPattern();
if (pattern != null) {
pattern.accept(this);
addTypeAssertionNodes(pattern, true);
}
TransparentInstruction trueNode = addTransparentInstruction();
TransparentInstruction falseNode = addTransparentInstruction();
PyExpression guard = clause.getGuardCondition();
if (guard != null) {
visitCondition(guard, trueNode, falseNode);
addTypeAssertionNodes(guard, true);
}
else {
myBuilder.addEdge(myBuilder.prevInstruction, trueNode);
}
myBuilder.addPendingEdge(clause, falseNode);
myBuilder.prevInstruction = trueNode;
clause.getStatementList().accept(this);
if (clause.getParent() instanceof PyMatchStatement matchStatement) {
myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction);
myBuilder.updatePendingElementScope(clause.getStatementList(), matchStatement);
}
myBuilder.prevInstruction = null;
}
@Override
public void visitWildcardPattern(@NotNull PyWildcardPattern node) {
myBuilder.startNode(node);
}
@Override
public void visitPyPattern(@NotNull PyPattern node) {
boolean isRefutable = !node.isIrrefutable();
if (isRefutable) {
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction);
}
node.acceptChildren(this);
myBuilder.updatePendingElementScope(node, node.getParent());
if (isRefutable) {
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
}
}
@Override
public void visitPyOrPattern(@NotNull PyOrPattern node) {
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
TransparentInstruction onSuccess = new TransparentInstructionImpl(myBuilder, node, "onSuccess");
List<PyPattern> alternatives = node.getAlternatives();
PyPattern lastAlternative = ContainerUtil.getLastItem(alternatives);
for (PyPattern alternative : alternatives) {
alternative.accept(this);
if (alternative != lastAlternative) {
// Allow next alternative to handle the fail edge of this alternative
myBuilder.updatePendingElementScope(node, alternative);
}
myBuilder.addEdge(myBuilder.prevInstruction, onSuccess);
myBuilder.prevInstruction = null;
}
myBuilder.addNode(onSuccess);
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
myBuilder.updatePendingElementScope(node, node.getParent());
}
@Override
public void visitPyClassPattern(@NotNull PyClassPattern node) {
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
node.getClassNameReference().accept(this);
myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction);
node.getArgumentList().acceptChildren(this);
myBuilder.updatePendingElementScope(node, node.getParent());
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
}
@Override
public void visitPyValuePattern(@NotNull PyValuePattern node) {
myBuilder.addNodeAndCheckPending(new RefutablePatternInstruction(myBuilder, node, false));
node.getValue().accept(this);
myBuilder.addPendingEdge(node.getParent(), myBuilder.prevInstruction);
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, node, true));
}
@Override
public void visitPyAsPattern(@NotNull PyAsPattern node) {
// AsPattern cannot fail by itself - it fails only if its child fails.
// So no need to create additional fail edge
myBuilder.startNode(node);
node.acceptChildren(this);
myBuilder.updatePendingElementScope(node, node.getParent());
}
@Override
public void visitPyGroupPattern(@NotNull PyGroupPattern node) {
// GroupPattern cannot fail by itself - it fails only if its child fails.
// So no need to create additional fail edge
// Also no need to dedicated node for GroupPattern itself
node.acceptChildren(this);
myBuilder.updatePendingElementScope(node, node.getParent());
}
@Override
@@ -955,7 +1083,7 @@ public class PyControlFlowBuilder extends PyRecursiveElementVisitor {
|| element instanceof PyStatementList);
}
private void addTypeAssertionNodes(@NotNull PyExpression condition, boolean positive) {
private void addTypeAssertionNodes(@NotNull PyElement condition, boolean positive) {
final PyTypeAssertionEvaluator evaluator = new PyTypeAssertionEvaluator(positive);
condition.accept(evaluator);
for (PyTypeAssertionEvaluator.Assertion def : evaluator.getDefinitions()) {

View File

@@ -16,6 +16,7 @@ class PyControlFlowProvider : ControlFlowProvider {
override fun getAdditionalInfo(instruction: Instruction): String? {
return when (instruction) {
is ReadWriteInstruction -> "${instruction.access} ${instruction.name}"
is RefutablePatternInstruction -> instruction.elementPresentation
else -> null
}
}

View File

@@ -1,165 +0,0 @@
package com.jetbrains.python.codeInsight.controlflow;
import com.intellij.codeInsight.controlflow.ConditionalInstruction;
import com.intellij.codeInsight.controlflow.ControlFlowBuilder;
import com.intellij.codeInsight.controlflow.Instruction;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import java.util.function.BiFunction;
public final class PyMatchStatementControlFlowBuilder {
private final ControlFlowBuilder myBuilder;
private final PyElementVisitor myBaseVisitor;
public PyMatchStatementControlFlowBuilder(@NotNull ControlFlowBuilder builder, @NotNull PyElementVisitor baseVisitor) {
myBuilder = builder;
myBaseVisitor = baseVisitor;
}
public void build(@NotNull PyMatchStatement matchStatement) {
myBuilder.startNode(matchStatement);
PyExpression subject = matchStatement.getSubject();
if (subject != null) {
subject.accept(myBaseVisitor);
}
for (PyCaseClause caseClause : matchStatement.getCaseClauses()) {
processCaseClause(caseClause);
}
}
private void processCaseClause(@NotNull PyCaseClause clause) {
PyPattern pattern = clause.getPattern();
if (pattern != null) {
processPattern(pattern);
retargetOutgoingPatternEdges(pattern, (oldScope, instr) -> {
return instr.isMatched() ? pattern : clause;
});
}
PyStatementList statementList = clause.getStatementList();
PyExpression guard = clause.getGuardCondition();
if (guard != null) {
guard.accept(myBaseVisitor);
// Retarget failure edges coming from inner OR and AND expressions
retargetOutgoingEdges(guard, (pendingScope, instr) -> {
if (instr instanceof ConditionalInstruction && !((ConditionalInstruction)instr).getResult()) {
return clause;
}
return pendingScope;
});
// Top-level OR and AND expressions should have had their own outgoing failure edges
if (!isConjunctionOrDisjunction(guard)) {
myBuilder.addPendingEdge(clause, myBuilder.prevInstruction);
}
myBuilder.startConditionalNode(statementList, guard, true);
}
statementList.accept(myBaseVisitor);
PyMatchStatement matchStatement = PsiTreeUtil.getParentOfType(clause, PyMatchStatement.class);
assert matchStatement != null;
retargetOutgoingEdges(statementList, (pendingScope, instruction) -> matchStatement);
myBuilder.addPendingEdge(matchStatement, myBuilder.prevInstruction);
myBuilder.prevInstruction = null;
}
private void processPattern(@NotNull PyPattern pattern) {
boolean isRefutable = !pattern.isIrrefutable();
if (isRefutable) {
RefutablePatternInstruction instruction = new RefutablePatternInstruction(myBuilder, pattern, false);
myBuilder.addNodeAndCheckPending(instruction);
myBuilder.addPendingEdge(pattern, instruction);
}
if (pattern instanceof PyWildcardPattern) {
myBuilder.startNode(pattern);
}
else if (pattern instanceof PyOrPattern) {
List<PyPattern> alternatives = ((PyOrPattern)pattern).getAlternatives();
PyPattern lastAlternative = ContainerUtil.getLastItem(alternatives);
for (PyPattern alternative : alternatives) {
processPattern(alternative);
if (alternative != lastAlternative) {
myBuilder.addPendingEdge(alternative, myBuilder.prevInstruction);
myBuilder.prevInstruction = null;
}
retargetOutgoingEdges(alternative, (pendingScope, instruction) -> {
if (instruction instanceof RefutablePatternInstruction && !((RefutablePatternInstruction)instruction).isMatched()) {
// Pattern has failed, jump to the next alternative if any
return alternative;
}
// Pattern succeeded, jump out of OR-pattern. It can be either a refutable pattern or a capture/wildcard node.
else {
return pattern;
}
});
}
}
else {
pattern.acceptChildren(new PyElementVisitor() {
@Override
public void visitPyReferenceExpression(@NotNull PyReferenceExpression node) {
myBaseVisitor.visitPyReferenceExpression(node);
}
@Override
public void visitPyTargetExpression(@NotNull PyTargetExpression node) {
myBaseVisitor.visitPyTargetExpression(node);
}
@Override
public void visitPyPattern(@NotNull PyPattern node) {
processPattern(node);
retargetOutgoingPatternEdges(pattern, (oldScope, instr) -> {
// Mismatch in a non-OR pattern means mismatch of the containing pattern as well
return instr.isMatched() ? oldScope : pattern;
});
}
@Override
public void visitPyPatternArgumentList(@NotNull PyPatternArgumentList node) {
node.acceptChildren(this);
}
});
}
if (isRefutable) {
myBuilder.addNode(new RefutablePatternInstruction(myBuilder, pattern, true));
}
}
private void retargetOutgoingEdges(@NotNull PsiElement scopeAncestor,
@NotNull BiFunction<PsiElement, Instruction, PsiElement> newScopeProvider) {
myBuilder.processPending((oldScope, instruction) -> {
if (oldScope != null && PsiTreeUtil.isAncestor(scopeAncestor, oldScope, false)) {
myBuilder.addPendingEdge(newScopeProvider.apply(oldScope, instruction), instruction);
}
else {
myBuilder.addPendingEdge(oldScope, instruction);
}
});
}
private void retargetOutgoingPatternEdges(@NotNull PsiElement scopeAncestor,
@NotNull BiFunction<PsiElement, RefutablePatternInstruction, PsiElement> newScopeProvider) {
retargetOutgoingEdges(scopeAncestor, (pendingScope, instruction) -> {
if (instruction instanceof RefutablePatternInstruction) {
return newScopeProvider.apply(pendingScope, (RefutablePatternInstruction)instruction);
}
return pendingScope;
});
}
private static boolean isConjunctionOrDisjunction(@Nullable PyExpression node) {
if (node instanceof PyBinaryExpression) {
final var operator = ((PyBinaryExpression)node).getOperator();
return operator == PyTokenTypes.AND_KEYWORD || operator == PyTokenTypes.OR_KEYWORD;
}
return false;
}
}

View File

@@ -3,6 +3,7 @@ package com.jetbrains.python.codeInsight.controlflow;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.containers.Stack;
import com.jetbrains.python.PyNames;
@@ -22,6 +23,8 @@ import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;
import static com.jetbrains.python.psi.PyUtil.as;
public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
private final Stack<Assertion> myStack = new Stack<>();
private boolean myPositive;
@@ -41,14 +44,15 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
if (args.length == 2 && args[0] instanceof PyReferenceExpression target) {
final PyExpression typeElement = args[1];
pushAssertion(target, myPositive, false, context -> context.getType(typeElement), typeElement);
pushAssertion(target, myPositive, context ->
transformTypeFromAssertion(context.getType(typeElement), false, context, typeElement));
}
}
else if (node.isCalleeText(PyNames.CALLABLE_BUILTIN)) {
final PyExpression[] args = node.getArguments();
if (args.length == 1 && args[0] instanceof PyReferenceExpression target) {
pushAssertion(target, myPositive, false, context -> PyTypingTypeProvider.createTypingCallableType(node), null);
pushAssertion(target, myPositive, context -> PyTypingTypeProvider.createTypingCallableType(node));
}
}
else if (node.isCalleeText(PyNames.ISSUBCLASS)) {
@@ -56,7 +60,8 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
if (args.length == 2 && args[0] instanceof PyReferenceExpression target) {
final PyExpression typeElement = args[1];
pushAssertion(target, myPositive, true, context -> context.getType(typeElement), typeElement);
pushAssertion(target, myPositive, context ->
transformTypeFromAssertion(context.getType(typeElement), true, context, typeElement));
}
}
}
@@ -66,7 +71,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
if (myPositive && (isIfReferenceStatement(node) || isIfReferenceConditionalStatement(node) || isIfNotReferenceStatement(node))) {
// we could not suggest `None` because it could be a reference to an empty collection
// so we could push only non-`None` assertions
pushAssertion(node, !myPositive, false, context -> PyNoneType.INSTANCE, null);
pushAssertion(node, !myPositive, context -> PyNoneType.INSTANCE);
return;
}
@@ -105,26 +110,26 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
if (PyLiteralType.isNone(lhs)) {
if (rhs instanceof PyReferenceExpression referenceExpr) {
pushAssertion(referenceExpr, myPositive, false, context -> PyNoneType.INSTANCE, null);
pushAssertion(referenceExpr, myPositive, context -> PyNoneType.INSTANCE);
}
return;
}
if (PyLiteralType.isNone(rhs)) {
if (lhs instanceof PyReferenceExpression referenceExpr) {
pushAssertion(referenceExpr, myPositive, false, context -> PyNoneType.INSTANCE, null);
pushAssertion(referenceExpr, myPositive, context -> PyNoneType.INSTANCE);
}
return;
}
if (lhs instanceof PyReferenceExpression referenceExpr) {
pushAssertion(referenceExpr, myPositive, false, context -> getLiteralType(rhs, context), null);
pushAssertion(referenceExpr, myPositive, context -> getLiteralType(rhs, context));
}
}
private void processIn(@NotNull PyExpression lhs, @NotNull PyExpression rhs) {
if (lhs instanceof PyReferenceExpression referenceExpr && rhs instanceof PyTupleExpression tupleExpr) {
pushAssertion(referenceExpr, myPositive, false, context -> {
pushAssertion(referenceExpr, myPositive, (TypeEvalContext context) -> {
PyExpression[] elements = tupleExpr.getElements();
List<PyType> types = new ArrayList<>(elements.length);
for (PyExpression element : elements) {
@@ -135,7 +140,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
types.add(type);
}
return PyUnionType.union(types);
}, null);
});
}
}
@@ -160,6 +165,41 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
}
}
@Override
public void visitPyPattern(@NotNull PyPattern node) {
final PsiElement parent = PsiTreeUtil.skipParentsOfType(node, PyCaseClause.class);
if (parent instanceof PyMatchStatement matchStatement) {
final PyExpression subject = PyPsiUtils.flattenParens(matchStatement.getSubject());
if (subject instanceof PyReferenceExpression target) {
pushAssertion(target, myPositive, context -> context.getType(node));
}
}
}
/**
* Negative type assertion for when all cases fail
*/
@Override
public void visitPyMatchStatement(@NotNull PyMatchStatement matchStatement) {
assert !myPositive; // for match statement as a whole, only negative assertion can be made
final PyExpression subject = matchStatement.getSubject();
if (subject == null) return;
if (subject instanceof PyReferenceExpression target) {
pushAssertion(target, true, context -> {
PyType subjectType = context.getType(subject);
for (PyCaseClause cs : matchStatement.getCaseClauses()) {
if (cs.getPattern() == null) continue;
if (cs.getGuardCondition() != null) continue;
subjectType = Ref.deref(createAssertionType(subjectType, context.getType(cs.getPattern()), false, context));
}
return subjectType;
});
}
}
@ApiStatus.Internal
public static @Nullable Ref<PyType> createAssertionType(@Nullable PyType initial,
@Nullable PyType suggested,
@@ -233,6 +273,9 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
PyTypeChecker.match(expected, actual, context);
}
/**
* @param transformToDefinition if true the result type will be Type[T], not T itself.
*/
private static @Nullable PyType transformTypeFromAssertion(@Nullable PyType type, boolean transformToDefinition, @NotNull TypeEvalContext context,
@Nullable PyExpression typeElement) {
/*
@@ -245,8 +288,7 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
final List<PyType> members = new ArrayList<>();
final int count = tupleType.getElementCount();
final PyTupleExpression tupleExpression = PyUtil
.as(PyPsiUtils.flattenParens(PyUtil.as(typeElement, PyParenthesizedExpression.class)), PyTupleExpression.class);
final PyTupleExpression tupleExpression = as(PyPsiUtils.flattenParens(typeElement), PyTupleExpression.class);
if (tupleExpression != null && tupleExpression.getElements().length == count) {
final PyExpression[] elements = tupleExpression.getElements();
for (int i = 0; i < count; i++) {
@@ -276,21 +318,13 @@ public class PyTypeAssertionEvaluator extends PyRecursiveElementVisitor {
return type;
}
/**
* @param transformToDefinition if true the result type will be Type[T], not T itself.
*/
private void pushAssertion(@NotNull PyReferenceExpression target,
boolean positive,
boolean transformToDefinition,
@NotNull Function<TypeEvalContext, PyType> suggestedType,
@Nullable PyExpression typeElement) {
@NotNull Function<TypeEvalContext, PyType> suggestedType) {
final InstructionTypeCallback typeCallback = new InstructionTypeCallback() {
@Override
public Ref<PyType> getType(TypeEvalContext context) {
return createAssertionType(context.getType(target),
transformTypeFromAssertion(suggestedType.apply(context), transformToDefinition, context, typeElement),
positive,
context);
return createAssertionType(context.getType(target), suggestedType.apply(context), positive, context);
}
};

View File

@@ -128,10 +128,8 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
PARAM_SPEC, PARAM_SPEC_EXT,
TYPE_VAR_TUPLE, TYPE_VAR_TUPLE_EXT
);
public static final String CONTEXT_MANAGER = "contextlib.AbstractContextManager";
public static final String CONTEXT_MANAGER = "contextlib.AbstractContextManager";
public static final String ASYNC_CONTEXT_MANAGER = "contextlib.AbstractAsyncContextManager";
public static final Set<String> TYPE_DICT_QUALIFIERS = Set.of(REQUIRED, REQUIRED_EXT, NOT_REQUIRED, NOT_REQUIRED_EXT, READONLY, READONLY_EXT);
public static final String UNPACK = "typing.Unpack";

View File

@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyAsPattern;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyAsPatternImpl extends PyElementImpl implements PyAsPattern {
public PyAsPatternImpl(ASTNode astNode) {
@@ -13,4 +17,9 @@ public class PyAsPatternImpl extends PyElementImpl implements PyAsPattern {
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyAsPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return context.getType(getPattern());
}
}

View File

@@ -1,8 +1,24 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyCapturePattern;
import com.jetbrains.python.psi.PyElementVisitor;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.codeInsight.controlflow.PyTypeAssertionEvaluator;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import java.util.Set;
import static com.jetbrains.python.psi.PyUtil.as;
import static com.jetbrains.python.psi.impl.PySequencePatternImpl.wrapInListType;
public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePattern {
public PyCapturePatternImpl(ASTNode astNode) {
@@ -13,4 +29,158 @@ public class PyCapturePatternImpl extends PyElementImpl implements PyCapturePatt
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyCapturePattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return getCaptureType(this, context);
}
static Set<String> SPECIAL_BUILTINS = Set.of(
"bool", "bytearray", "bytes", "dict", "float", "frozenset", "int", "list", "set", "str", "tuple");
/**
* Determines the type of a given pattern assuming it is a capture pattern (even when it is actually not),
* and looking up (to parents or subject expression of the match statement).
*/
static @Nullable PyType getCaptureType(@NotNull PyPattern pattern, @NotNull TypeEvalContext context) {
final PyElement parentPattern = PsiTreeUtil.getParentOfType(
pattern, // Capture corresponds to:
PyCaseClause.class, // - Subject of a match statement
PySingleStarPattern.class, // - Type of parent sequence pattern
PyDoubleStarPattern.class, // - Type of parent mapping pattern
PyKeyValuePattern.class, // - Any
PySequencePattern.class, // - Iterated item type of sequence
PyClassPattern.class, // - Attribute type in the corresponding class
PyKeywordPattern.class // - Attribute type in the corresponding class
);
if (parentPattern instanceof PyCaseClause caseClause) {
final PyMatchStatement matchStatement = as(caseClause.getParent(), PyMatchStatement.class);
if (matchStatement == null) return null;
final PyExpression subject = matchStatement.getSubject();
if (subject == null) return null;
PyType subjectType = context.getType(subject);
for (PyCaseClause cs : matchStatement.getCaseClauses()) {
if (cs == caseClause) break;
if (cs.getPattern() == null) continue;
if (cs.getGuardCondition() != null) continue;
subjectType = Ref.deref(
PyTypeAssertionEvaluator.createAssertionType(subjectType, context.getType(cs.getPattern()), false, context));
}
return subjectType;
}
if (parentPattern instanceof PySingleStarPattern starPattern) {
final PySequencePattern sequenceParent = as(starPattern.getParent(), PySequencePattern.class);
if (sequenceParent == null) return null;
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequenceParent, context);
final PyType iteratedType = PyTypeUtil.toStream(sequenceType)
.flatMap(it -> starPattern.getCapturedTypesFromSequenceType(it, context).stream()).collect(PyTypeUtil.toUnion());
return wrapInListType(iteratedType, pattern);
}
if (parentPattern instanceof PyDoubleStarPattern) {
final PyMappingPattern mappingParent = as(parentPattern.getParent(), PyMappingPattern.class);
if (mappingParent == null) return null;
var parentType = context.getType(mappingParent);
if (parentType instanceof PyCollectionType collectionType) {
final PyClass dict = PyBuiltinCache.getInstance(pattern).getClass("dict");
return dict != null ? new PyCollectionTypeImpl(dict, false, collectionType.getElementTypes()) : null;
}
return null;
}
if (parentPattern instanceof PyKeyValuePattern keyValuePattern) {
final PyMappingPattern mappingParent = as(keyValuePattern.getParent(), PyMappingPattern.class);
if (mappingParent == null) return null;
var dictType = getCaptureType(mappingParent, context);
if (dictType == null) return null;
if (dictType instanceof PyTypedDictType typedDictType) {
if (context.getType(keyValuePattern.getKeyPattern()) instanceof PyLiteralType l && l.getExpression() instanceof PyStringLiteralExpression str) {
return typedDictType.getElementType(str.getStringValue());
}
}
var mappingType = PyTypeUtil.convertToType(dictType, "typing.Mapping", pattern, context);
if (mappingType instanceof PyCollectionType collectionType) {
return collectionType.getElementTypes().get(1);
}
return null;
}
if (parentPattern instanceof PySequencePattern sequencePattern) {
final PyType sequenceType = PySequencePatternImpl.getSequenceCaptureType(sequencePattern, context);
if (sequenceType == null) return null;
return PyTypeUtil.toStream(sequenceType).map(it -> {
if (it instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
// This is done to skip group- and as-patterns
final var sequenceMember = PsiTreeUtil.findFirstParent(pattern, el -> el.getParent() == sequencePattern);
final List<PyPattern> elements = sequencePattern.getElements();
final int idx = elements.indexOf(sequenceMember);
final int starIdx = ContainerUtil.indexOf(elements, it2 -> it2 instanceof PySingleStarPattern);
if (starIdx == -1 || idx < starIdx) {
return tupleType.getElementType(idx);
}
else {
final int starSpan = tupleType.getElementCount() - elements.size();
return tupleType.getElementType(idx + starSpan);
}
}
var upcast = PyTypeUtil.convertToType(it, "typing.Sequence", pattern, context);
if (upcast instanceof PyCollectionType collectionType) {
return collectionType.getIteratedItemType();
}
return null;
}).collect(PyTypeUtil.toUnion());
}
if (parentPattern instanceof PyClassPattern classPattern) {
if (context.getType(classPattern) instanceof PyClassType classType) {
final List<PyPattern> arguments = classPattern.getArgumentList().getPatterns();
int index = arguments.indexOf(pattern);
if (index < 0) return null;
if (SPECIAL_BUILTINS.contains(classType.getClassQName())) {
if (index == 0) {
return context.getType(classPattern);
}
return null;
}
final PyTargetExpression matchArgs = as(resolveTypeMember(classType, PyNames.MATCH_ARGS, context), PyTargetExpression.class);
if (matchArgs == null) return null;
var matchArgsValue = PyPsiUtils.flattenParens(matchArgs.findAssignedValue());
if (matchArgsValue instanceof PySequenceExpression sequenceExpression) {
if (sequenceExpression.getElements().length <= index) return null;
final String attributeName = PyEvaluator.evaluate(sequenceExpression.getElements()[index], String.class);
if (attributeName == null) return null;
final PyExpression instanceAttribute = as(resolveTypeMember(classType, attributeName, context), PyExpression.class);
if (instanceAttribute == null) return null;
return context.getType(instanceAttribute);
}
}
return null;
}
if (parentPattern instanceof PyKeywordPattern keywordPattern) {
final PyClassPattern classPattern = PsiTreeUtil.getParentOfType(keywordPattern, PyClassPattern.class);
if (classPattern == null) return null;
if (context.getType(classPattern) instanceof PyClassType classType) {
final PyExpression instanceAttribute = as(resolveTypeMember(classType, keywordPattern.getKeyword(), context), PyExpression.class);
if (instanceAttribute == null) return null;
return context.getType(instanceAttribute);
}
return null;
}
return null;
}
@Nullable
private static PsiElement resolveTypeMember(@NotNull PyType type, @NotNull String name, @NotNull TypeEvalContext context) {
final PyResolveContext resolveContext = PyResolveContext.defaultContext(context);
final List<? extends RatedResolveResult> results = type.resolveMember(name, null, AccessDirection.READ, resolveContext);
return !ContainerUtil.isEmpty(results) ? results.get(0).getElement() : null;
}
}

View File

@@ -3,6 +3,12 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyClassPattern;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.types.PyClassType;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.PyTypeChecker;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern {
public PyClassPatternImpl(ASTNode astNode) {
@@ -13,4 +19,17 @@ public class PyClassPatternImpl extends PyElementImpl implements PyClassPattern
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyClassPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
final PyType type = context.getType(getClassNameReference());
if (type instanceof PyClassType classType) {
final PyType instanceType = classType.toInstance();
final PyType captureType = PyCapturePatternImpl.getCaptureType(this, context);
if (PyTypeChecker.match(captureType, instanceType, context)) {
return instanceType;
}
}
return null;
}
}

View File

@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyDoubleStarPattern;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyDoubleStarPatternImpl extends PyElementImpl implements PyDoubleStarPattern {
public PyDoubleStarPatternImpl(ASTNode astNode) {
@@ -13,4 +17,9 @@ public class PyDoubleStarPatternImpl extends PyElementImpl implements PyDoubleSt
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyDoubleStarPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return null;
}
}

View File

@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyGroupPattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyGroupPatternImpl extends PyElementImpl implements PyGroupPattern {
public PyGroupPatternImpl(ASTNode astNode) {
@@ -13,4 +17,9 @@ public class PyGroupPatternImpl extends PyElementImpl implements PyGroupPattern
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyGroupPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return context.getType(getPattern());
}
}

View File

@@ -3,6 +3,13 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyKeyValuePattern;
import com.jetbrains.python.psi.PyPattern;
import com.jetbrains.python.psi.types.PyTupleType;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import java.util.Arrays;
public class PyKeyValuePatternImpl extends PyElementImpl implements PyKeyValuePattern {
public PyKeyValuePatternImpl(ASTNode astNode) {
@@ -13,4 +20,15 @@ public class PyKeyValuePatternImpl extends PyElementImpl implements PyKeyValuePa
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyKeyValuePattern(this);
}
@Override
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
final PyType keyType = context.getType(getKeyPattern());
final PyPattern value = getValuePattern();
PyType valueType = null;
if (value != null) {
valueType = context.getType(value);
}
return PyTupleType.create(this, Arrays.asList(keyType, valueType));
}
}

View File

@@ -4,7 +4,12 @@ import com.intellij.lang.ASTNode;
import com.intellij.psi.PsiReference;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyKeywordPattern;
import com.jetbrains.python.psi.PyPattern;
import com.jetbrains.python.psi.impl.references.PyKeywordPatternReference;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyKeywordPatternImpl extends PyElementImpl implements PyKeywordPattern {
public PyKeywordPatternImpl(ASTNode astNode) {
@@ -20,4 +25,10 @@ public class PyKeywordPatternImpl extends PyElementImpl implements PyKeywordPatt
public PsiReference getReference() {
return new PyKeywordPatternReference(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
final PyPattern valuePattern = getValuePattern();
return valuePattern != null ? context.getType(valuePattern) : null;
}
}

View File

@@ -3,6 +3,11 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyLiteralPattern;
import com.jetbrains.python.psi.types.PyLiteralType;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyLiteralPatternImpl extends PyElementImpl implements PyLiteralPattern {
public PyLiteralPatternImpl(ASTNode astNode) {
@@ -13,4 +18,13 @@ public class PyLiteralPatternImpl extends PyElementImpl implements PyLiteralPatt
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyLiteralPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
PyType literalType = PyLiteralType.Companion.fromLiteralParameter(getExpression(), context);
if (literalType != null) {
return literalType;
}
return context.getType(getExpression());
}
}

View File

@@ -3,11 +3,12 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiListLikeElement;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyKeyValuePattern;
import com.jetbrains.python.psi.PyMappingPattern;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -22,7 +23,30 @@ public class PyMappingPatternImpl extends PyElementImpl implements PyMappingPatt
}
@Override
public @NotNull List<? extends PsiElement> getComponents() {
public @NotNull List<? extends PyKeyValuePattern> getComponents() {
return Arrays.asList(findChildrenByClass(PyKeyValuePattern.class));
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
ArrayList<PyType> keyTypes = new ArrayList<>();
ArrayList<PyType> valueTypes = new ArrayList<>();
for (PyKeyValuePattern it : getComponents()) {
PyType type = context.getType(it);
if (type instanceof PyTupleType tupleType) {
keyTypes.add(tupleType.getElementType(0));
valueTypes.add(tupleType.getElementType(1));
}
}
//keyTypes.add(null);
//valueTypes.add(null);
return wrapInMappingType(PyUnionType.union(keyTypes), PyUnionType.union(valueTypes), this);
}
private static @Nullable PyType wrapInMappingType(@Nullable PyType keyType, @Nullable PyType valueType, @NotNull PsiElement resolveAnchor) {
keyType = PyLiteralType.upcastLiteralToClass(keyType);
valueType = PyLiteralType.upcastLiteralToClass(valueType);
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Mapping", resolveAnchor);
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Arrays.asList(keyType, valueType)) : null;
}
}

View File

@@ -1,8 +1,14 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyOrPattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.PyUnionType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyOrPatternImpl extends PyElementImpl implements PyOrPattern {
public PyOrPatternImpl(ASTNode astNode) {
@@ -13,4 +19,9 @@ public class PyOrPatternImpl extends PyElementImpl implements PyOrPattern {
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyOrPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return PyUnionType.union(ContainerUtil.map(getAlternatives(), it -> context.getType(it)));
}
}

View File

@@ -3,13 +3,16 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiListLikeElement;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyPattern;
import com.jetbrains.python.psi.PySequencePattern;
import com.intellij.util.ArrayUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.Arrays;
import java.util.List;
import java.util.*;
import static com.jetbrains.python.psi.types.PyLiteralType.upcastLiteralToClass;
public class PySequencePatternImpl extends PyElementImpl implements PySequencePattern, PsiListLikeElement {
public PySequencePatternImpl(ASTNode astNode) {
@@ -22,7 +25,74 @@ public class PySequencePatternImpl extends PyElementImpl implements PySequencePa
}
@Override
public @NotNull List<? extends PsiElement> getComponents() {
return Arrays.asList(findChildrenByClass(PyPattern.class));
public @NotNull List<? extends PyPattern> getComponents() {
return getElements();
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
final PyType sequenceCaptureType = getSequenceCaptureType(this, context);
boolean isHomogeneous = !(sequenceCaptureType instanceof PyTupleType tupleType) || tupleType.isHomogeneous();
final ArrayList<PyType> types = new ArrayList<>();
for (PyPattern pattern : getElements()) {
if (pattern instanceof PySingleStarPattern starPattern) {
types.addAll(starPattern.getCapturedTypesFromSequenceType(sequenceCaptureType, context));
}
else {
types.add(context.getType(pattern));
}
}
PyType expectedType = isHomogeneous ? wrapInSequenceType(PyUnionType.union(types), this) : PyTupleType.create(this, types);
PyType captureType = PyCapturePatternImpl.getCaptureType(this, context);
if (captureType == null) return expectedType;
return PyTypeUtil.toStream(captureType)
.map(it -> {
if (PyTypeChecker.match(expectedType, it, context)) {
return it;
}
return expectedType;
})
.collect(PyTypeUtil.toUnion());
}
static @Nullable PyType wrapInListType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
final PyClass list = PyBuiltinCache.getInstance(resolveAnchor).getClass("list");
return list != null ? new PyCollectionTypeImpl(list, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
}
@Nullable
public static PyType wrapInSequenceType(@Nullable PyType elementType, @NotNull PsiElement resolveAnchor) {
final PyClass sequence = PyPsiFacade.getInstance(resolveAnchor.getProject()).createClassByQName("typing.Sequence", resolveAnchor);
return sequence != null ? new PyCollectionTypeImpl(sequence, false, Collections.singletonList(upcastLiteralToClass(elementType))) : null;
}
/**
* Similar to {@link PyCapturePatternImpl#getCaptureType(PyPattern, TypeEvalContext)},
* but only chooses types that would match to typing.Sequence, and have correct length
*/
@Nullable
static PyType getSequenceCaptureType(@NotNull PySequencePattern pattern, @NotNull TypeEvalContext context) {
final PyType captureTypes = PyCapturePatternImpl.getCaptureType(pattern, context);
final boolean hasStar = ContainerUtil.exists(pattern.getElements(), it -> it instanceof PySingleStarPattern);
List<PyType> types = new ArrayList<>();
for (PyType captureType : PyTypeUtil.toStream(captureTypes)) {
if (captureType instanceof PyClassType classType &&
ArrayUtil.contains(classType.getClassQName(), "str", "bytes", "bytearray")) continue;
PyType sequenceType = PyTypeUtil.convertToType(captureType, "typing.Sequence", pattern, context);
if (sequenceType == null) continue;
if (captureType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
final List<PyPattern> elements = pattern.getElements();
if (hasStar && elements.size() <= tupleType.getElementCount() || elements.size() == tupleType.getElementCount()) {
types.add(captureType);
}
}
else {
types.add(captureType);
}
}
return PyUnionType.union(types);
}
}

View File

@@ -1,8 +1,13 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PySingleStarPattern;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.Collections;
import java.util.List;
public class PySingleStarPatternImpl extends PyElementImpl implements PySingleStarPattern {
public PySingleStarPatternImpl(ASTNode astNode) {
@@ -13,4 +18,24 @@ public class PySingleStarPatternImpl extends PyElementImpl implements PySingleSt
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPySingleStarPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return null;
}
@Override
public @NotNull List<@Nullable PyType> getCapturedTypesFromSequenceType(@Nullable PyType sequenceType, @NotNull TypeEvalContext context) {
if (getParent() instanceof PySequencePattern sequenceParent) {
final int idx = sequenceParent.getElements().indexOf(this);
if (sequenceType instanceof PyTupleType tupleType && !tupleType.isHomogeneous()) {
return tupleType.getElementTypes().subList(idx, idx + tupleType.getElementCount() - sequenceParent.getElements().size() + 1);
}
var upcast = PyTypeUtil.convertToType(sequenceType, "typing.Sequence", this, context);
if (upcast instanceof PyCollectionType collectionType) {
return Collections.singletonList(collectionType.getIteratedItemType());
}
}
return List.of();
}
}

View File

@@ -154,8 +154,11 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
}
}
}
if (parent instanceof PyWithItem) {
return getWithItemVariableType((PyWithItem)parent, context);
if (parent instanceof PyWithItem withItem) {
return getWithItemVariableType(withItem, context);
}
if (parent instanceof PyPattern pattern) {
return context.getType(pattern);
}
if (parent instanceof PyAssignmentExpression) {
final PyExpression assignedValue = ((PyAssignmentExpression)parent).getAssignedValue();

View File

@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyValuePattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyValuePatternImpl extends PyElementImpl implements PyValuePattern {
public PyValuePatternImpl(ASTNode astNode) {
@@ -13,4 +17,9 @@ public class PyValuePatternImpl extends PyElementImpl implements PyValuePattern
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyValuePattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return context.getType(getValue());
}
}

View File

@@ -3,6 +3,10 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyWildcardPattern;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyWildcardPatternImpl extends PyElementImpl implements PyWildcardPattern {
public PyWildcardPatternImpl(ASTNode astNode) {
@@ -13,4 +17,9 @@ public class PyWildcardPatternImpl extends PyElementImpl implements PyWildcardPa
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitWildcardPattern(this);
}
@Override
public @Nullable PyType getType(@NotNull TypeEvalContext context, TypeEvalContext.@NotNull Key key) {
return PyCapturePatternImpl.getCaptureType(this, context);
}
}

View File

@@ -26,7 +26,7 @@ object PyCollectionTypeUtil {
private fun getListOrSetIteratedValueType(sequence: PySequenceExpression, context: TypeEvalContext): PyType? {
val elements = sequence.elements
val analyzedElementsType = PyUnionType.union(
elements.take(MAX_ANALYZED_ELEMENTS_OF_LITERALS).map { replaceLiteralWithItsClass(context.getType(it)) }
elements.take(MAX_ANALYZED_ELEMENTS_OF_LITERALS).map { PyLiteralType.upcastLiteralToClass(context.getType(it)) }
)
return if (elements.size > MAX_ANALYZED_ELEMENTS_OF_LITERALS) {
PyUnionType.createWeakType(analyzedElementsType)
@@ -66,7 +66,7 @@ object PyCollectionTypeUtil {
if (keyExpression !is PyStringLiteralExpression) {
return null
}
strKeysToValueTypes[keyExpression.stringValue] = Pair(element.value, replaceLiteralWithItsClass(valueType))
strKeysToValueTypes[keyExpression.stringValue] = Pair(element.value, PyLiteralType.upcastLiteralToClass(valueType))
}
return strKeysToValueTypes
@@ -82,8 +82,8 @@ object PyCollectionTypeUtil {
.forEach {
val type = context.getType(it)
val (keyType, valueType) = getKeyValueType(type)
keyTypes.add(replaceLiteralWithItsClass(keyType))
valueTypes.add(replaceLiteralWithItsClass(valueType))
keyTypes.add(PyLiteralType.upcastLiteralToClass(keyType))
valueTypes.add(PyLiteralType.upcastLiteralToClass(valueType))
}
if (elements.size > MAX_ANALYZED_ELEMENTS_OF_LITERALS) {
@@ -107,13 +107,4 @@ object PyCollectionTypeUtil {
}
return null to null
}
private fun replaceLiteralWithItsClass(type: PyType?): PyType? {
return when (type) {
is PyUnionType -> type.map(::replaceLiteralWithItsClass)
is PyLiteralStringType -> PyClassTypeImpl(type.cls, false)
is PyLiteralType -> PyClassTypeImpl(type.pyClass, false)
else -> type
}
}
}

View File

@@ -10,6 +10,7 @@ import com.jetbrains.python.psi.*
import com.jetbrains.python.psi.impl.PyEvaluator
import com.jetbrains.python.psi.resolve.PyResolveContext
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.annotations.ApiStatus.Internal
/**
@@ -56,6 +57,17 @@ class PyLiteralType private constructor(cls: PyClass, val expression: PyExpressi
return PyLiteralType(enumClass, expression)
}
@Internal
@JvmStatic
fun upcastLiteralToClass(type: PyType?): PyType? {
return when (type) {
is PyUnionType -> type.map(::upcastLiteralToClass)
is PyLiteralStringType -> PyClassTypeImpl(type.cls, false)
is PyLiteralType -> PyClassTypeImpl(type.pyClass, false)
else -> type
}
}
private fun promoteToType(
expectedType: PyType?,
expression: PyExpression,

View File

@@ -1,16 +1,18 @@
0(1) element: null
1(2) element: PyWhileStatement
2(3,4) READ ACCESS: x
3(13) element: null. Condition: x:false
3(15) element: null. Condition: x:false
4(5) element: null. Condition: x:true
5(6) element: PyStatementList
6(7) element: PyMatchStatement
7(8) READ ACCESS: x
8(9,11) refutable pattern: 42
8(9,12) refutable pattern: 42
9(10) matched pattern: 42
10(13) element: PyBreakStatement
11(12) element: PyExpressionStatement
12(1) READ ACCESS: y
10(11) ASSERTTYPE ACCESS: x
11(15) element: PyBreakStatement
12(13) ASSERTTYPE ACCESS: x
13(14) element: PyExpressionStatement
14(15) READ ACCESS: z
15() element: null
14(1) READ ACCESS: y
15(16) element: PyExpressionStatement
16(17) READ ACCESS: z
17() element: null

View File

@@ -1,16 +1,18 @@
0(1) element: null
1(2) element: PyWhileStatement
2(3,4) READ ACCESS: x
3(13) element: null. Condition: x:false
3(15) element: null. Condition: x:false
4(5) element: null. Condition: x:true
5(6) element: PyStatementList
6(7) element: PyMatchStatement
7(8) READ ACCESS: x
8(9,11) refutable pattern: 42
8(9,12) refutable pattern: 42
9(10) matched pattern: 42
10(1) element: PyContinueStatement
11(12) element: PyExpressionStatement
12(1) READ ACCESS: y
10(11) ASSERTTYPE ACCESS: x
11(1) element: PyContinueStatement
12(13) ASSERTTYPE ACCESS: x
13(14) element: PyExpressionStatement
14(15) READ ACCESS: z
15() element: null
14(1) READ ACCESS: y
15(16) element: PyExpressionStatement
16(17) READ ACCESS: z
17() element: null

View File

@@ -2,9 +2,11 @@
1(2) WRITE ACCESS: x
2(3) element: PyMatchStatement
3(4) READ ACCESS: x
4(5,7) refutable pattern: 42
4(5,8) refutable pattern: 42
5(6) matched pattern: 42
6(9) element: PyReturnStatement
7(8) element: PyExpressionStatement
8(9) READ ACCESS: y
9() element: null
6(7) ASSERTTYPE ACCESS: x
7(11) element: PyReturnStatement
8(9) ASSERTTYPE ACCESS: x
9(10) element: PyExpressionStatement
10(11) READ ACCESS: y
11() element: null

View File

@@ -1,19 +1,18 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,16) refutable pattern: [42] | foo.bar as x
3(4,16) refutable pattern: [42] | foo.bar
2(3) element: PyAsPattern
3(4) refutable pattern: [42] | foo.bar
4(5,8) refutable pattern: [42]
5(6,8) refutable pattern: 42
6(7) matched pattern: 42
7(12) matched pattern: [42]
8(9,16) refutable pattern: foo.bar
9(10) READ ACCESS: foo
7(11) matched pattern: [42]
8(9) refutable pattern: foo.bar
9(10,15) READ ACCESS: foo
10(11) matched pattern: foo.bar
11(12) matched pattern: [42] | foo.bar
12(13) WRITE ACCESS: x
13(14) matched pattern: [42] | foo.bar as x
14(15) element: PyExpressionStatement
15(16) READ ACCESS: y
16(17) element: PyExpressionStatement
17(18) READ ACCESS: z
18() element: null
13(14) element: PyExpressionStatement
14(15) READ ACCESS: y
15(16) element: PyExpressionStatement
16(17) READ ACCESS: z
17() element: null

View File

@@ -1,12 +1,12 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,14) refutable pattern: Class(1, attr=foo.bar)
3(4) READ ACCESS: Class
2(3) refutable pattern: Class(1, attr=foo.bar)
3(4,14) READ ACCESS: Class
4(5,14) refutable pattern: 1
5(6) matched pattern: 1
6(7,14) refutable pattern: attr=foo.bar
7(8,14) refutable pattern: foo.bar
8(9) READ ACCESS: foo
7(8) refutable pattern: foo.bar
8(9,14) READ ACCESS: foo
9(10) matched pattern: foo.bar
10(11) matched pattern: attr=foo.bar
11(12) matched pattern: Class(1, attr=foo.bar)

View File

@@ -3,14 +3,13 @@
2(3) WRITE ACCESS: x
3(4) element: PyBinaryExpression
4(5,6) READ ACCESS: x
5(10) element: null. Condition: x > 0:false
5(12) element: null. Condition: x > 0:false
6(7) element: null. Condition: x > 0:true
7(8,9) READ ACCESS: x
8(10) element: null. Condition: x < 10:false
8(12) element: null. Condition: x < 10:false
9(10) element: null. Condition: x < 10:true
10(11) element: PyStatementList. Condition: x > 0 and x < 10:true
11(12) element: PyExpressionStatement
12(13) READ ACCESS: y
13(14) element: PyExpressionStatement
14(15) READ ACCESS: z
15() element: null
10(11) element: PyExpressionStatement
11(12) READ ACCESS: y
12(13) element: PyExpressionStatement
13(14) READ ACCESS: z
14() element: null

View File

@@ -3,18 +3,17 @@
2(3) WRITE ACCESS: x
3(4) element: PyBinaryExpression
4(5,6) READ ACCESS: x
5(14) element: null. Condition: x % 4 == 0:false
5(16) element: null. Condition: x % 4 == 0:false
6(7) element: null. Condition: x % 4 == 0:true
7(8) element: PyBinaryExpression
8(9,10) READ ACCESS: x
9(11) element: null. Condition: x % 400 == 0:false
10(14) element: null. Condition: x % 400 == 0:true
11(12,13) READ ACCESS: x
12(14) element: null. Condition: x % 100 != 0:false
12(16) element: null. Condition: x % 100 != 0:false
13(14) element: null. Condition: x % 100 != 0:true
14(15) element: PyStatementList. Condition: x % 4 == 0 and (x % 400 == 0 or x % 100 != 0):true
15(16) element: PyExpressionStatement
16(17) READ ACCESS: y
17(18) element: PyExpressionStatement
18(19) READ ACCESS: z
19() element: null
14(15) element: PyExpressionStatement
15(16) READ ACCESS: y
16(17) element: PyExpressionStatement
17(18) READ ACCESS: z
18() element: null

View File

@@ -6,11 +6,10 @@
5(7) element: null. Condition: x > 0:false
6(10) element: null. Condition: x > 0:true
7(8,9) READ ACCESS: x
8(10) element: null. Condition: x < 0:false
8(12) element: null. Condition: x < 0:false
9(10) element: null. Condition: x < 0:true
10(11) element: PyStatementList. Condition: x > 0 or x < 0:true
11(12) element: PyExpressionStatement
12(13) READ ACCESS: y
13(14) element: PyExpressionStatement
14(15) READ ACCESS: z
15() element: null
10(11) element: PyExpressionStatement
11(12) READ ACCESS: y
12(13) element: PyExpressionStatement
13(14) READ ACCESS: z
14() element: null

View File

@@ -1,6 +1,6 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,18) refutable pattern: [x1, x2, x3]
2(3,19) refutable pattern: [x1, x2, x3]
3(4) WRITE ACCESS: x1
4(5) WRITE ACCESS: x2
5(6) WRITE ACCESS: x3
@@ -8,14 +8,15 @@
7(8) element: PyBinaryExpression
8(9,10) READ ACCESS: x1
9(11) element: null. Condition: x1:false
10(14) element: null. Condition: x1:true
11(12,13) READ ACCESS: x2
12(14) element: null. Condition: x2:false
13(14) element: null. Condition: x2:true
14(15,18) READ ACCESS: x3
15(16) element: PyStatementList. Condition: (x1 or x2) > x3:true
16(17) element: PyExpressionStatement
17(18) READ ACCESS: y
18(19) element: PyExpressionStatement
19(20) READ ACCESS: z
20() element: null
10(17) element: null. Condition: x1:true
11(12,13,14) READ ACCESS: x2
12(19) element: null. Condition: x2:false
13(17) element: null. Condition: x2:true
14(15,16) READ ACCESS: x3
15(19) element: null. Condition: (x1 or x2) > x3:false
16(17) element: null. Condition: (x1 or x2) > x3:true
17(18) element: PyExpressionStatement
18(19) READ ACCESS: y
19(20) element: PyExpressionStatement
20(21) READ ACCESS: z
21() element: null

View File

@@ -1,11 +1,13 @@
0(1) element: null
1(2) element: PyMatchStatement
2(6) WRITE ACCESS: x
3(4,8) refutable pattern: [x]
4(5) WRITE ACCESS: x
5(6) matched pattern: [x]
6(7) element: PyExpressionStatement
7(8) READ ACCESS: y
2(3) refutable pattern: x | [x]
3(7) WRITE ACCESS: x
4(5,10) refutable pattern: [x]
5(6) WRITE ACCESS: x
6(7) matched pattern: [x]
7(8) matched pattern: x | [x]
8(9) element: PyExpressionStatement
9(10) READ ACCESS: z
10() element: null
9(10) READ ACCESS: y
10(11) element: PyExpressionStatement
11(12) READ ACCESS: z
12() element: null

View File

@@ -1,11 +1,13 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,5) refutable pattern: [x]
3(4) WRITE ACCESS: x
4(6) matched pattern: [x]
5(6) WRITE ACCESS: x
6(7) element: PyExpressionStatement
7(8) READ ACCESS: y
2(3) refutable pattern: [x] | x
3(4,6) refutable pattern: [x]
4(5) WRITE ACCESS: x
5(7) matched pattern: [x]
6(7) WRITE ACCESS: x
7(8) matched pattern: [x] | x
8(9) element: PyExpressionStatement
9(10) READ ACCESS: z
10() element: null
9(10) READ ACCESS: y
10(11) element: PyExpressionStatement
11(12) READ ACCESS: z
12() element: null

View File

@@ -10,8 +10,8 @@
9(10,19) refutable pattern: 'bar': foo.bar
10(11,19) refutable pattern: 'bar'
11(12) matched pattern: 'bar'
12(13,19) refutable pattern: foo.bar
13(14) READ ACCESS: foo
12(13) refutable pattern: foo.bar
13(14,19) READ ACCESS: foo
14(15) matched pattern: foo.bar
15(16) matched pattern: 'bar': foo.bar
16(17) matched pattern: {'foo': 1, 'bar': foo.bar}

View File

@@ -1,19 +1,17 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,16) refutable pattern: 1 | (2 | 3)
2(3) refutable pattern: 1 | (2 | 3)
3(4,5) refutable pattern: 1
4(14) matched pattern: 1
5(6,16) refutable pattern: (2 | 3)
6(7,16) refutable pattern: 2 | 3
7(8,9) refutable pattern: 2
8(14) matched pattern: 2
9(10,16) refutable pattern: 3
10(11) matched pattern: 3
11(12) matched pattern: 2 | 3
12(13) matched pattern: (2 | 3)
13(14) matched pattern: 1 | (2 | 3)
4(11) matched pattern: 1
5(6) refutable pattern: 2 | 3
6(7,8) refutable pattern: 2
7(10) matched pattern: 2
8(9,14) refutable pattern: 3
9(10) matched pattern: 3
10(11) matched pattern: 2 | 3
11(12) matched pattern: 1 | (2 | 3)
12(13) element: PyExpressionStatement
13(14) READ ACCESS: x
14(15) element: PyExpressionStatement
15(16) READ ACCESS: x
16(17) element: PyExpressionStatement
17(18) READ ACCESS: y
18() element: null
15(16) READ ACCESS: y
16() element: null

View File

@@ -1,20 +1,17 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,17) refutable pattern: [x] | (foo.bar as x)
2(3) refutable pattern: [x] | (foo.bar as x)
3(4,6) refutable pattern: [x]
4(5) WRITE ACCESS: x
5(15) matched pattern: [x]
6(7,17) refutable pattern: (foo.bar as x)
7(8,17) refutable pattern: foo.bar as x
8(9,17) refutable pattern: foo.bar
9(10) READ ACCESS: foo
10(11) matched pattern: foo.bar
11(12) WRITE ACCESS: x
12(13) matched pattern: foo.bar as x
13(14) matched pattern: (foo.bar as x)
14(15) matched pattern: [x] | (foo.bar as x)
15(16) element: PyExpressionStatement
16(17) READ ACCESS: y
17(18) element: PyExpressionStatement
18(19) READ ACCESS: z
19() element: null
5(11) matched pattern: [x]
6(7) element: PyAsPattern
7(8) refutable pattern: foo.bar
8(9,14) READ ACCESS: foo
9(10) matched pattern: foo.bar
10(11) WRITE ACCESS: x
11(12) matched pattern: [x] | (foo.bar as x)
12(13) element: PyExpressionStatement
13(14) READ ACCESS: y
14(15) element: PyExpressionStatement
15(16) READ ACCESS: z
16() element: null

View File

@@ -1,8 +1,8 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,10) refutable pattern: [] | 42
2(3) refutable pattern: [] | 42
3(4,5) refutable pattern: []
4(8) matched pattern: []
4(7) matched pattern: []
5(6,10) refutable pattern: 42
6(7) matched pattern: 42
7(8) matched pattern: [] | 42

View File

@@ -1,10 +1,12 @@
0(1) element: null
1(2) element: PyMatchStatement
2(5) element: PyWildcardPattern
3(4,7) refutable pattern: 42
4(5) matched pattern: 42
5(6) element: PyExpressionStatement
6(7) READ ACCESS: y
2(3) refutable pattern: _ | 42
3(6) element: PyWildcardPattern
4(5,9) refutable pattern: 42
5(6) matched pattern: 42
6(7) matched pattern: _ | 42
7(8) element: PyExpressionStatement
8(9) READ ACCESS: z
9() element: null
8(9) READ ACCESS: y
9(10) element: PyExpressionStatement
10(11) READ ACCESS: z
11() element: null

View File

@@ -1,18 +1,17 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,15) refutable pattern: [x]
2(3,14) refutable pattern: [x]
3(4) WRITE ACCESS: x
4(5) matched pattern: [x]
5(6) element: PyBinaryExpression
6(7,8) READ ACCESS: x
7(12) element: null. Condition: x > 0:false
7(14) element: null. Condition: x > 0:false
8(9) element: null. Condition: x > 0:true
9(10,11) READ ACCESS: x
10(12) element: null. Condition: x % 2 == 0:false
10(14) element: null. Condition: x % 2 == 0:false
11(12) element: null. Condition: x % 2 == 0:true
12(13) element: PyStatementList. Condition: x > 0 and x % 2 == 0:true
13(14) element: PyExpressionStatement
14(15) READ ACCESS: y
15(16) element: PyExpressionStatement
16(17) READ ACCESS: z
17() element: null
12(13) element: PyExpressionStatement
13(14) READ ACCESS: y
14(15) element: PyExpressionStatement
15(16) READ ACCESS: z
16() element: null

View File

@@ -3,8 +3,8 @@
2(3,11) refutable pattern: [1, foo.bar]
3(4,11) refutable pattern: 1
4(5) matched pattern: 1
5(6,11) refutable pattern: foo.bar
6(7) READ ACCESS: foo
5(6) refutable pattern: foo.bar
6(7,11) READ ACCESS: foo
7(8) matched pattern: foo.bar
8(9) matched pattern: [1, foo.bar]
9(10) element: PyExpressionStatement

View File

@@ -1,12 +1,14 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3,9) refutable pattern: [1 | x]
3(4,5) refutable pattern: 1
4(7) matched pattern: 1
5(6) WRITE ACCESS: x
6(7) matched pattern: [1 | x]
7(8) element: PyExpressionStatement
8(9) READ ACCESS: y
2(3,11) refutable pattern: [1 | x]
3(4) refutable pattern: 1 | x
4(5,6) refutable pattern: 1
5(7) matched pattern: 1
6(7) WRITE ACCESS: x
7(8) matched pattern: 1 | x
8(9) matched pattern: [1 | x]
9(10) element: PyExpressionStatement
10(11) READ ACCESS: z
11() element: null
10(11) READ ACCESS: y
11(12) element: PyExpressionStatement
12(13) READ ACCESS: z
13() element: null

View File

@@ -1,10 +1,11 @@
0(1) element: null
1(2) element: PyMatchStatement
2(3) WRITE ACCESS: x
3(4,7) READ ACCESS: x
4(5) element: PyStatementList. Condition: x > 0:true
5(6) element: PyExpressionStatement
6(7) READ ACCESS: y
7(8) element: PyExpressionStatement
8(9) READ ACCESS: z
9() element: null
3(4,5) READ ACCESS: x
4(8) element: null. Condition: x > 0:false
5(6) element: null. Condition: x > 0:true
6(7) element: PyExpressionStatement
7(8) READ ACCESS: y
8(9) element: PyExpressionStatement
9(10) READ ACCESS: z
10() element: null

View File

@@ -0,0 +1,466 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.jetbrains.python;
import com.jetbrains.python.fixtures.PyInspectionTestCase;
import com.jetbrains.python.inspections.PyAssertTypeInspection;
import com.jetbrains.python.inspections.PyInspection;
import org.jetbrains.annotations.NotNull;
public class PyPatternTypeTest extends PyInspectionTestCase {
@Override
protected @NotNull Class<? extends PyInspection> getInspectionClass() {
return PyAssertTypeInspection.class;
}
public void testMatchCapturePatternType() {
doTestByText("""
from typing import assert_type
class A: ...
m: A
match m:
case a:
assert_type(a, A)
""");
}
public void testMatchLiteralPatternNarrows() {
doTestByText("""
from typing import assert_type, Literal
m: object
match m:
case 1:
assert_type(m, Literal[1])
""");
}
public void testMatchValuePatternNarrows() {
doTestByText("""
from typing import assert_type
class B:
b: int
m: object
match m:
case B.b:
assert_type(m, int)
""");
}
public void testMatchValuePatternAlreadyNarrower() {
doTestByText("""
from typing import assert_type
class B:
b: int
m: bool
match m:
case B.b:
assert_type(m, bool)
""");
}
public void testMatchSequencePatternCaptures() {
doTestByText("""
from typing import assert_type
m: list[int]
match m:
case [a]:
assert_type(a, int)
""");
}
public void testMatchSequencePatternCapturesStarred() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: Sequence[int]
match m:
case [a, *b]:
assert_type(a, int)
assert_type(b, list[int])
""");
}
public void testMatchSequencePatternNarrowsInner() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: Sequence[object]
match m:
case [1, True]:
assert_type(m, Sequence[int | bool])
""");
}
public void testMatchSequencePatternNarrowsOuter() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: object
match m:
case [1, True]:
assert_type(m, Sequence[int | bool])
""");
}
public void testMatchSequencePatternAlreadyNarrowerInner() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: Sequence[bool]
match m:
case [1, True]:
assert_type(m, Sequence[bool])
""");
}
public void testMatchSequencePatternAlreadyNarrowerOuter() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: Sequence[object]
match m:
case [1, True]:
assert_type(m, Sequence[int | bool])
""");
}
public void testMatchSequencePatternAlreadyNarrowerBoth() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: Sequence[bool]
match m:
case [1, True]:
assert_type(m, Sequence[bool])
""");
}
public void testMatchNestedSequencePatternNarrowsInner() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: Sequence[Sequence[object]]
match m:
case [[1], [True]]:
assert_type(m, Sequence[Sequence[int] | Sequence[bool]])
""");
}
public void testMatchNestedSequencePatternNarrowsOuter() {
doTestByText("""
from typing import assert_type
from typing import Sequence
m: object
match m:
case [[1], [True]]:
assert_type(m, Sequence[Sequence[int] | Sequence[bool]])
""");
}
public void testMatchSequencePatternMatches() {
doTestByText("""
from typing import assert_type
import array, collections
from typing import Sequence, Iterable
m1: object
m2: Sequence[int]
m3: array.array[int]
m4: collections.deque[int]
m5: list[int]
m6: memoryview
m7: range
m8: tuple[int]
m9: str
m10: bytes
m11: bytearray
match m1:
case [a]:
assert_type(a, Any)
match m2:
case [b]:
assert_type(b, int)
match m3:
case [c]:
assert_type(c, int)
match m4:
case [d]:
assert_type(d, int)
match m5:
case [e]:
assert_type(e, int)
match m6:
case [f]:
assert_type(f, int)
match m7:
case [g]:
assert_type(g, int)
match m8:
case [h]:
assert_type(h, int)
match m9:
case [i]:
assert_type(i, Any)
match m10:
case [j]:
assert_type(j, Any)
match m11:
case [k]:
assert_type(k, Any)
""");
}
public void testMatchSequencePatternCapturesTuple() {
doTestByText("""
from typing import assert_type
m: tuple[int, str, bool]
match m:
case [a, b, c]:
assert_type(a, int)
assert_type(b, str)
assert_type(c, bool)
assert_type(m, tuple[int, str, bool])
""");
}
public void testMatchSequencePatternTupleNarrows() {
doTestByText("""
from typing import assert_type, Literal
m: tuple[object, object]
match m:
case [1, "str"]:
assert_type(m, tuple[Literal[1], Literal['str']])
""");
}
public void testMatchSequencePatternTupleStarred() {
doTestByText("""
from typing import assert_type, Literal
m: tuple[int, str, bool]
match m:
case [a, *b, c]:
assert_type(a, int)
assert_type(b, list[str])
assert_type(c, bool)
assert_type(m, tuple[int, str, bool])
""");
}
public void testMatchSequencePatternTupleStarredUnion() {
doTestByText("""
from typing import assert_type
m: tuple[int, str, float, bool]
match m:
case [a, *b, c]:
assert_type(a, int)
assert_type(b, list[str | float])
assert_type(c, bool)
assert_type(m, tuple[int, str, float, bool])
""");
}
public void testMatchSequenceUnionSkip() {
doTestByText("""
from typing import assert_type
from typing import List, Union
m: Union[List[List[str]], str]
match m:
case [list(['str'])]:
assert_type(m, list[list[str]])
""");
}
public void testMatchMappingPatternCaptures() {
doTestByText("""
from typing import Dict, assert_type
class B:
b: str
m: Dict[str, int]
match m:
case {"key": v}:
assert_type(v, int)
match m:
case {B.b: v2}:
assert_type(v2, int)
""");
}
public void testMatchMappingPatternCapturesTypedDict() {
doTestByText("""
from typing import TypedDict, Literal, assert_type
class A(TypedDict):
a: str
b: int
class K:
k: Literal['a']
m: A
match m:
case {"a": v}:
assert_type(v, str)
case {"b": v2}:
assert_type(v2, int)
case {"a": v3, "b": v4}:
assert_type(v3, str)
assert_type(v4, int)
case {K.k: v5}:
assert_type(v5, str)
case {"o": v6}:
assert_type(v6, Any)
""");
}
public void testMatchMappingPatternCaptureRest() {
doTestByText("""
from typing import Mapping, assert_type
m: object
match m:
case {'k': 1, **r}:
assert_type(r, dict[str, int])
n: Mapping[str, int]
match n:
case {'k': 1, **r}:
assert_type(r, dict[str, int])
""");
}
public void testMatchClassPatternCapturePositional() {
doTestByText("""
from typing import assert_type
class A:
__match_args__ = ("a", "b")
a: str
b: int
m: A
match m:
case A(i, j):
assert_type(i, str)
assert_type(j, int)
""");
}
public void testMatchClassPatternCaptureKeyword() {
doTestByText("""
from typing import assert_type
class A:
a: str
b: int
m: A
match m:
case A(a=i, b=j):
assert_type(i, str)
assert_type(j, int)
""");
}
public void testMatchClassPatternCaptureSelf() {
doTestByText("""
from typing import assert_type
m: object
match m:
case bool(a):
assert_type(a, bool)
case bytearray(b):
assert_type(b, bytearray)
case bytes(c):
assert_type(c, bytes)
case dict(d):
assert_type(d, dict)
case float(e):
assert_type(e, float)
case frozenset(f):
assert_type(f, frozenset)
case int(g):
assert_type(g, int)
case list(h):
assert_type(h, list)
case set(i):
assert_type(i, set)
case str(j):
assert_type(j, str)
case tuple(k):
assert_type(k, tuple)
"""
);
}
public void testMatchClassPatternNarrowSelfCapture() {
doTestByText("""
from typing import assert_type
m: object
match m:
case bool():
assert_type(m, bool)
case bytearray():
assert_type(m, bytearray)
case bytes():
assert_type(m, bytes)
case dict():
assert_type(m, dict)
case float():
assert_type(m, float)
case frozenset():
assert_type(m, frozenset)
case int():
assert_type(m, int)
case list():
assert_type(m, list)
case set():
assert_type(m, set)
case str():
assert_type(m, str)
case tuple():
assert_type(m, tuple)"""
);
}
}

View File

@@ -1,4 +1,4 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
// Copyright 2000-2025 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.jetbrains.python;
import com.jetbrains.python.documentation.PythonDocumentationProvider;