PY-29489 Support unpacking to targets in square brackets

GitOrigin-RevId: 470b0bd74bf3dc2ed5c8b6902df264624b50628d
This commit is contained in:
Mikhail Golubev
2024-04-12 15:30:49 +03:00
committed by intellij-monorepo-bot
parent 8c0b817248
commit b70b23ca27
3 changed files with 58 additions and 30 deletions

View File

@@ -4,7 +4,10 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.navigation.ItemPresentation;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.*;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiPolyVariantReference;
import com.intellij.psi.PsiReference;
import com.intellij.psi.ResolveResult;
import com.intellij.psi.search.GlobalSearchScope;
import com.intellij.psi.search.LocalSearchScope;
import com.intellij.psi.search.SearchScope;
@@ -14,13 +17,9 @@ import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.QualifiedName;
import com.intellij.ui.IconManager;
import com.intellij.util.IncorrectOperationException;
import com.intellij.util.ObjectUtils;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyStubElementTypes;
import com.jetbrains.python.ast.*;
import com.jetbrains.python.ast.impl.PyPsiUtilsCore;
import com.jetbrains.python.ast.impl.PyUtilCore;
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.Scope;
@@ -134,19 +133,17 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
return context.getType(assignedValue);
}
}
if (parent instanceof PyTupleExpression) {
PsiElement nextParent = parent.getParent();
while (nextParent instanceof PyParenthesizedExpression || nextParent instanceof PyTupleExpression) {
nextParent = nextParent.getParent();
}
if (parent instanceof PyTupleExpression || parent instanceof PyListLiteralExpression) {
PsiElement nextParent =
PsiTreeUtil.skipParentsOfType(parent, PyParenthesizedExpression.class, PyTupleExpression.class, PyListLiteralExpression.class);
if (nextParent instanceof PyAssignmentStatement assignment) {
final PyExpression value = assignment.getAssignedValue();
final PyExpression lhs = assignment.getLeftHandSideExpression();
final PyTupleExpression targetTuple = PsiTreeUtil.findChildOfType(lhs, PyTupleExpression.class, false);
if (value != null && targetTuple != null) {
final PySequenceExpression targetTupleOrList = PsiTreeUtil.findChildOfType(lhs, PySequenceExpression.class, false);
if (value != null && (targetTupleOrList instanceof PyTupleExpression || targetTupleOrList instanceof PyListLiteralExpression)) {
final PyType assignedType = PyUnionType.toNonWeakType(context.getType(value));
if (assignedType != null) {
final PyType t = PyTypeChecker.getTargetTypeFromTupleAssignment(this, targetTuple, assignedType, context);
final PyType t = PyTypeChecker.getTargetTypeFromTupleAssignment(this, targetTupleOrList, assignedType, context);
if (t != null) {
return t;
}
@@ -302,8 +299,8 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
final PyType sourceType = context.getType(source);
final PyType type = getIterationType(sourceType, source, this, context);
target = PyPsiUtils.flattenParens(target);
if (type instanceof PyTupleType tupleType && target instanceof PyTupleExpression tupleExpression) {
return PyTypeChecker.getTargetTypeFromTupleAssignment(this, tupleExpression, tupleType);
if (type instanceof PyTupleType tupleType && (target instanceof PyTupleExpression || target instanceof PyListLiteralExpression)) {
return PyTypeChecker.getTargetTypeFromTupleAssignment(this, (PySequenceExpression)target, tupleType);
}
if (target == this && type != null) {
return type;

View File

@@ -17,6 +17,7 @@ import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyBuiltinCache;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.impl.PyTypeProvider;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
@@ -1447,16 +1448,16 @@ public final class PyTypeChecker {
@Nullable
public static PyType getTargetTypeFromTupleAssignment(@NotNull PyTargetExpression target,
@NotNull PyTupleExpression parentTuple,
@NotNull PySequenceExpression parentTupleOrList,
@NotNull PyType assignedType,
@NotNull TypeEvalContext context) {
if (assignedType instanceof PyTupleType) {
return getTargetTypeFromTupleAssignment(target, parentTuple, (PyTupleType)assignedType);
return getTargetTypeFromTupleAssignment(target, parentTupleOrList, (PyTupleType)assignedType);
}
else if (assignedType instanceof PyClassLikeType classLikeType) {
PyNamedTupleType namedTupleType = ContainerUtil.findInstance(classLikeType.getAncestorTypes(context), PyNamedTupleType.class);
if (namedTupleType != null) {
return getTargetTypeFromTupleAssignment(target, parentTuple, namedTupleType);
return getTargetTypeFromTupleAssignment(target, parentTupleOrList, namedTupleType);
}
else if (assignedType instanceof PyCollectionType generic) {
return generic.getIteratedItemType();
@@ -1466,24 +1467,22 @@ public final class PyTypeChecker {
}
@Nullable
public static PyType getTargetTypeFromTupleAssignment(@NotNull PyTargetExpression target, @NotNull PyTupleExpression parentTuple,
public static PyType getTargetTypeFromTupleAssignment(@NotNull PyTargetExpression target,
@NotNull PySequenceExpression parentTupleOrList,
@NotNull PyTupleType assignedTupleType) {
final int count = assignedTupleType.getElementCount();
final PyExpression[] elements = parentTuple.getElements();
final PyExpression[] elements = parentTupleOrList.getElements();
if (elements.length == count || assignedTupleType.isHomogeneous()) {
final int index = ArrayUtil.indexOf(elements, target);
if (index >= 0) {
return assignedTupleType.getElementType(index);
}
for (int i = 0; i < count; i++) {
PyExpression element = elements[i];
while (element instanceof PyParenthesizedExpression) {
element = ((PyParenthesizedExpression)element).getContainedExpression();
}
if (element instanceof PyTupleExpression) {
PyExpression element = PyPsiUtils.flattenParens(elements[i]);
if (element instanceof PyTupleExpression || element instanceof PyListLiteralExpression) {
final PyType elementType = assignedTupleType.getElementType(i);
if (elementType instanceof PyTupleType) {
final PyType result = getTargetTypeFromTupleAssignment(target, (PyTupleExpression)element, (PyTupleType)elementType);
if (elementType instanceof PyTupleType nestedAssignedTupleType) {
final PyType result = getTargetTypeFromTupleAssignment(target, (PySequenceExpression)element, nestedAssignedTupleType);
if (result != null) {
return result;
}

View File

@@ -3127,7 +3127,7 @@ public class PyTypeTest extends PyTestCase {
}
// PY-29489
public void testListLiteralUnpacking() {
public void testGenericIterableUnpackingNoBrackets() {
doTest("int",
"""
_, expr, _ = [1, 2, 3]
@@ -3135,10 +3135,42 @@ public class PyTypeTest extends PyTestCase {
}
// PY-29489
public void testGenericIterableUnpacking() {
public void testGenericIterableUnpackingParentheses() {
doTest("int",
"""
_, expr = map(int, ["1", "2"])
(_, expr, _) = [1, 2, 3]
""");
}
// PY-29489
public void testGenericIterableUnpackingSquareBrackets() {
doTest("int",
"""
[_, expr] = [1, 2, 3]
""");
}
public void testUnpackingToNestedTargetsInSquareBracketsInAssignments() {
doTest("int",
"""
[_, [[expr], _]] = "foo", ((42,), "bar")
""");
}
public void testUnpackingToNestedTargetsInSquareBracketsInForLoops() {
doTest("str",
"""
xs = [(1, ("foo",))]
for [_, [expr]] in xs:
pass
""");
}
public void testUnpackingToNestedTargetsInSquareBracketsInComprehensions() {
doTest("str",
"""
xs = [(1, ("foo",))]
ys = [expr for [_, [expr]] in xs]
""");
}