Migrate CompatibilityPrintCallQuickFix, PyTransformConditionalExpressionIntention and PyYieldFromIntention to ModCommand

PY-65297

GitOrigin-RevId: f77789cecd5d8d83e33242b7e800b63b00dafe1e
This commit is contained in:
Georgii Ustinov
2024-01-11 12:54:28 +02:00
committed by intellij-monorepo-bot
parent 72a859a808
commit f8990a192d
3 changed files with 71 additions and 70 deletions

View File

@@ -1,14 +1,14 @@
// Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.intentions;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.project.Project;
import com.intellij.modcommand.*;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.IncorrectOperationException;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
/**
* User: catherine
@@ -24,38 +24,23 @@ import org.jetbrains.annotations.NotNull;
* else:
* x = b
*/
public final class PyTransformConditionalExpressionIntention extends PyBaseIntentionAction {
public final class PyTransformConditionalExpressionIntention extends PsiUpdateModCommandAction<PsiElement> {
PyTransformConditionalExpressionIntention() {
super(PsiElement.class);
}
@Override
@NotNull
public String getFamilyName() {
return PyPsiBundle.message("INTN.transform.into.if.else.statement");
}
@Override
@NotNull
public String getText() {
return PyPsiBundle.message("INTN.transform.into.if.else.statement");
}
@Override
public boolean isAvailable(@NotNull Project project, Editor editor, PsiFile file) {
if (!(file instanceof PyFile)) {
return false;
}
PyAssignmentStatement expression =
PsiTreeUtil.getParentOfType(file.findElementAt(editor.getCaretModel().getOffset()), PyAssignmentStatement.class);
if (expression != null && expression.getAssignedValue() instanceof PyConditionalExpression) {
return true;
}
return false;
}
@Override
public void doInvoke(@NotNull Project project, Editor editor, PsiFile file) throws IncorrectOperationException {
final PyAssignmentStatement assignmentStatement =
PsiTreeUtil.getParentOfType(file.findElementAt(editor.getCaretModel().getOffset()), PyAssignmentStatement.class);
assert assignmentStatement != null;
protected void invoke(@NotNull ActionContext context, @NotNull PsiElement element, @NotNull ModPsiUpdater updater) {
PyAssignmentStatement assignmentStatement =
PsiTreeUtil.getParentOfType(element, PyAssignmentStatement.class);
final PyExpression assignedValue =
assignmentStatement.getAssignedValue();
if (assignedValue instanceof PyConditionalExpression expression) {
@@ -66,14 +51,29 @@ public final class PyTransformConditionalExpressionIntention extends PyBaseInten
final PyExpression leftHandSideExpression = assignmentStatement.getLeftHandSideExpression();
if (leftHandSideExpression != null) {
final String targetText = leftHandSideExpression.getText();
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(context.project());
final String text = "if " + condition.getText() + ":\n\t" + targetText + " = " + truePartText
+ "\nelse:\n\t" + targetText + " = " + falsePart.getText();
final PyIfStatement ifStatement = elementGenerator.createFromText(LanguageLevel.forElement(expression), PyIfStatement.class, text);
+ "\nelse:\n\t" + targetText + " = " + falsePart.getText();
final PyIfStatement ifStatement =
elementGenerator.createFromText(LanguageLevel.forElement(expression), PyIfStatement.class, text);
assignmentStatement.replace(ifStatement);
}
}
}
}
@Override
protected @Nullable Presentation getPresentation(@NotNull ActionContext context, @NotNull PsiElement element) {
PsiFile file = element.getContainingFile();
if (!(file instanceof PyFile)) {
return null;
}
PyAssignmentStatement expression =
PsiTreeUtil.getParentOfType(element, PyAssignmentStatement.class);
if (expression != null && expression.getAssignedValue() instanceof PyConditionalExpression) {
return super.getPresentation(context, element);
}
return null;
}
}

View File

@@ -15,57 +15,55 @@
*/
package com.jetbrains.python.codeInsight.intentions;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.project.Project;
import com.intellij.modcommand.ActionContext;
import com.intellij.modcommand.ModPsiUpdater;
import com.intellij.modcommand.Presentation;
import com.intellij.modcommand.PsiUpdateModCommandAction;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.IncorrectOperationException;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public final class PyYieldFromIntention extends PyBaseIntentionAction {
public final class PyYieldFromIntention extends PsiUpdateModCommandAction<PsiElement> {
PyYieldFromIntention() {
super(PsiElement.class);
}
@NotNull
@Override
public String getFamilyName() {
return PyPsiBundle.message("INTN.yield.from");
}
@NotNull
@Override
public String getText() {
return PyPsiBundle.message("INTN.yield.from");
}
@Override
public boolean isAvailable(@NotNull Project project, Editor editor, PsiFile file) {
if (!LanguageLevel.forElement(file).isPython2()) {
final PyForStatement forLoop = findForStatementAtCaret(editor, file);
protected @Nullable Presentation getPresentation(@NotNull ActionContext context, @NotNull PsiElement element) {
if (!LanguageLevel.forElement(element.getContainingFile()).isPython2()) {
final PyForStatement forLoop = PsiTreeUtil.getParentOfType(element, PyForStatement.class);
if (forLoop != null) {
final PyTargetExpression forTarget = findSingleForLoopTarget(forLoop);
final PyReferenceExpression yieldValue = findSingleYieldValue(forLoop);
if (forTarget != null && yieldValue != null) {
final String targetName = forTarget.getName();
if (targetName != null && targetName.equals(yieldValue.getName())) {
return true;
return super.getPresentation(context, element);
}
}
}
}
return false;
return null;
}
@Override
public void doInvoke(@NotNull Project project, Editor editor, PsiFile file) throws IncorrectOperationException {
final PyForStatement forLoop = findForStatementAtCaret(editor, file);
protected void invoke(@NotNull ActionContext context, @NotNull PsiElement element, @NotNull ModPsiUpdater updater) {
final PyForStatement forLoop = PsiTreeUtil.getParentOfType(element, PyForStatement.class);
if (forLoop != null) {
final PyExpression source = forLoop.getForPart().getSource();
if (source != null) {
final PyElementGenerator generator = PyElementGenerator.getInstance(project);
final PyElementGenerator generator = PyElementGenerator.getInstance(context.project());
final String text = "yield from foo";
final PyExpressionStatement exprStmt = generator.createFromText(LanguageLevel.forElement(file), PyExpressionStatement.class, text);
final PyExpressionStatement exprStmt = generator.createFromText(LanguageLevel.forElement(element.getContainingFile()), PyExpressionStatement.class, text);
final PyExpression expr = exprStmt.getExpression();
if (expr instanceof PyYieldExpression) {
final PyExpression yieldValue = ((PyYieldExpression)expr).getExpression();
@@ -78,12 +76,6 @@ public final class PyYieldFromIntention extends PyBaseIntentionAction {
}
}
@Nullable
private static PyForStatement findForStatementAtCaret(@NotNull Editor editor, @NotNull PsiFile file) {
final PsiElement elementAtCaret = file.findElementAt(editor.getCaretModel().getOffset());
return PsiTreeUtil.getParentOfType(elementAtCaret, PyForStatement.class);
}
@Nullable
private static PyTargetExpression findSingleForLoopTarget(@NotNull PyForStatement forLoop) {
final PyForPart forPart = forLoop.getForPart();

View File

@@ -1,23 +1,27 @@
// Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.inspections.quickfix;
import com.intellij.codeInspection.LocalQuickFix;
import com.intellij.codeInspection.ProblemDescriptor;
import com.intellij.modcommand.*;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.codeInsight.imports.AddImportHelper;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NotNull;
import java.util.List;
/**
* User: catherine
*
* QuickFix to replace statement that has no effect with function call
*/
public class CompatibilityPrintCallQuickFix implements LocalQuickFix {
public class CompatibilityPrintCallQuickFix extends ModCommandBatchQuickFix {
@Override
@NotNull
public String getFamilyName() {
@@ -25,15 +29,27 @@ public class CompatibilityPrintCallQuickFix implements LocalQuickFix {
}
@Override
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
PsiElement expression = descriptor.getPsiElement();
public @NotNull ModCommand perform(@NotNull Project project, @NotNull List<ProblemDescriptor> descriptors) {
PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
replacePrint(expression, elementGenerator);
return ModCommand.psiUpdate(ActionContext.from(descriptors.get(0)), updater -> {
List<PsiElement> elements = ContainerUtil.map(descriptors, d -> updater.getWritable(d.getStartElement()));
PsiFile file = elements.get(0).getContainingFile();
for (PsiElement element : elements) {
if (element.isValid()) {
replace(element, elementGenerator);
}
}
AddImportHelper.addOrUpdateFromImportStatement(file,
"__future__",
"print_function",
null,
AddImportHelper.ImportPriority.FUTURE,
null);
});
}
private static void replacePrint(PsiElement expression, PyElementGenerator elementGenerator) {
private static void replace(PsiElement expression, PyElementGenerator elementGenerator) {
final StringBuilder stringBuilder = new StringBuilder("print(");
final PyFile file = (PyFile)expression.getContainingFile();
final PyExpression[] target = PsiTreeUtil.getChildrenOfType(expression, PyExpression.class);
if (target != null) {
stringBuilder.append(StringUtil.join(target, o -> o.getText(), ", "));
@@ -41,12 +57,5 @@ public class CompatibilityPrintCallQuickFix implements LocalQuickFix {
stringBuilder.append(")");
expression.replace(elementGenerator.createFromText(LanguageLevel.forElement(expression), PyElement.class,
stringBuilder.toString()));
AddImportHelper.addOrUpdateFromImportStatement(file,
"__future__",
"print_function",
null,
AddImportHelper.ImportPriority.FUTURE,
null);
}
}