PY-75537 Implement PyAstAssignmentStatement.getTargetsToValuesMapping() (PyFrontendElementTypesFacadeImpl.kt)

(cherry picked from commit 8c4926689e42f117275986c2e1246d7e04bde489)

GitOrigin-RevId: f83b02700ce17a182807ac8bfe1a0108e09d4741
This commit is contained in:
Petr
2024-10-15 18:26:09 +02:00
committed by intellij-monorepo-bot
parent 63ccfc1fb1
commit 001266a5b7
6 changed files with 89 additions and 75 deletions

View File

@@ -21,13 +21,21 @@ import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiErrorElement;
import com.intellij.psi.PsiNamedElement;
import com.intellij.psi.tree.TokenSet;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.IncorrectOperationException;
import com.intellij.util.SmartList;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.ast.impl.PyPsiUtilsCore;
import com.jetbrains.python.ast.impl.PyUtilCore;
import com.jetbrains.python.psi.LanguageLevel;
import com.jetbrains.python.psi.PyAstElementGenerator;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import static com.jetbrains.python.ast.PyAstElementKt.findChildByClass;
@@ -141,7 +149,66 @@ public interface PyAstAssignmentStatement extends PyAstStatement, PyAstNamedElem
* @return a list of [target, value] pairs; either part of a pair may be null, but not both.
*/
@NotNull
List<? extends Pair<? extends PyAstExpression, ? extends PyAstExpression>> getTargetsToValuesMapping();
default List<? extends Pair<? extends PyAstExpression, ? extends PyAstExpression>> getTargetsToValuesMapping() {
List<Pair<PyAstExpression, PyAstExpression>> ret = new SmartList<>();
if (!PsiTreeUtil.hasErrorElements(this)) { // no parse errors
PyAstExpression[] constituents = PsiTreeUtil.getChildrenOfType(this, PyAstExpression.class); // "a = b = c" -> [a, b, c]
if (constituents != null && constituents.length > 1) {
int lastIndex = constituents.length - 1;
PyAstExpression rhs = constituents[lastIndex];
for (int i = 0; i < lastIndex; i++) {
mapToValues(constituents[i], rhs, ret);
}
}
}
return ret;
}
private static void mapToValues(@Nullable PyAstExpression lhs,
@Nullable PyAstExpression rhs,
List<Pair<PyAstExpression, PyAstExpression>> map) {
// cast for convenience
PyAstSequenceExpression lhs_tuple = null;
PyAstExpression lhs_one = null;
if (PyPsiUtilsCore.flattenParens(lhs) instanceof PyAstTupleExpression<?> tupleExpr) lhs_tuple = tupleExpr;
else if (lhs != null) lhs_one = lhs;
PyAstSequenceExpression rhs_tuple = null;
PyAstExpression rhs_one = null;
if (PyPsiUtilsCore.flattenParens(rhs) instanceof PyAstTupleExpression<?> tupleExpr) rhs_tuple = tupleExpr;
else if (rhs != null) rhs_one = rhs;
//
if (lhs_one != null) { // single LHS, single RHS (direct mapping) or multiple RHS (packing)
map.add(Pair.create(lhs_one, rhs));
}
else if (lhs_tuple != null && rhs_one != null) { // multiple LHS, single RHS: unpacking
// PY-2648, PY-2649
PyAstElementGenerator elementGenerator = PyAstElementGenerator.getInstance(rhs_one.getProject());
final LanguageLevel languageLevel = LanguageLevel.forElement(lhs);
int counter = 0;
for (PyAstExpression tuple_elt : lhs_tuple.getElements()) {
try {
final PyAstExpression expression =
elementGenerator.createExpressionFromText(languageLevel, "(" + rhs_one.getText() + ")[" + counter + "]");
mapToValues(tuple_elt, expression, map);
}
catch (IncorrectOperationException e) {
// not parsed, no problem
}
++counter;
}
}
else if (lhs_tuple != null && rhs_tuple != null) { // multiple both sides: piecewise mapping
final List<PyAstExpression> lhsTupleElements = Arrays.asList(lhs_tuple.getElements());
final List<PyAstExpression> rhsTupleElements = Arrays.asList(rhs_tuple.getElements());
final int size = Math.max(lhsTupleElements.size(), rhsTupleElements.size());
for (int index = 0; index < size; index++) {
mapToValues(ContainerUtil.getOrElse(lhsTupleElements, index, null),
ContainerUtil.getOrElse(rhsTupleElements, index, null), map);
}
}
}
@Nullable
default PyAstExpression getLeftHandSideExpression() {

View File

@@ -8,8 +8,10 @@ import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiFileFactory;
import com.intellij.psi.impl.PsiFileFactoryImpl;
import com.intellij.testFramework.LightVirtualFile;
import com.intellij.util.IncorrectOperationException;
import com.jetbrains.python.PythonFileType;
import com.jetbrains.python.PythonLanguage;
import com.jetbrains.python.ast.PyAstExpression;
import com.jetbrains.python.ast.PyAstExpressionStatement;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
@@ -126,4 +128,15 @@ public class PyAstElementGenerator {
}
protected void specifyFileLanguageLevel(@NotNull VirtualFile virtualFile, @Nullable LanguageLevel langLevel) { }
@NotNull
public PyAstExpression createExpressionFromText(@NotNull LanguageLevel languageLevel, @NotNull String text)
throws IncorrectOperationException {
final PsiFile dummyFile = createDummyFile(languageLevel, text);
final PsiElement element = dummyFile.getFirstChild();
if (element instanceof PyAstExpressionStatement expressionStatement) {
return expressionStatement.getExpression();
}
throw new IncorrectOperationException("could not parse text as expression: " + text);
}
}