[python] Fix PyStringConcatenationToFormatIntention for enclosing string concatenations

Noticed in IJ-CR-123777.

GitOrigin-RevId: f32832bca10e805fc609babf81382709d58fc480
This commit is contained in:
Mikhail Golubev
2024-02-02 22:36:02 +02:00
committed by intellij-monorepo-bot
parent c432ce7605
commit 454deac682
4 changed files with 25 additions and 19 deletions

View File

@@ -7,8 +7,6 @@ import com.intellij.modcommand.Presentation;
import com.intellij.modcommand.PsiUpdateModCommandAction;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.NotNullFunction;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.psi.*;
@@ -23,6 +21,7 @@ import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.function.Function;
public final class PyStringConcatenationToFormatIntention extends PsiUpdateModCommandAction<PyBinaryExpression> {
@@ -43,18 +42,16 @@ public final class PyStringConcatenationToFormatIntention extends PsiUpdateModCo
return null;
}
while (element.getParent() instanceof PyBinaryExpression pyBinaryExpression) {
element = pyBinaryExpression;
}
PyBinaryExpression topmostBinaryExpr = getTopmostBinaryExpression(element);
final Collection<PyElementType> operators = getOperators(element);
final Collection<PyElementType> operators = getOperators(topmostBinaryExpr);
for (PyElementType operator : operators) {
if (operator != PyTokenTypes.PLUS) {
return null;
}
}
final Collection<PyExpression> expressions = getSimpleExpressions(element);
final Collection<PyExpression> expressions = getSimpleExpressions(topmostBinaryExpr);
if (expressions.isEmpty()) {
return null;
}
@@ -84,16 +81,12 @@ public final class PyStringConcatenationToFormatIntention extends PsiUpdateModCo
@Override
protected void invoke(@NotNull ActionContext actionContext, @NotNull PyBinaryExpression element, @NotNull ModPsiUpdater updater) {
PyBinaryExpression binaryExpression = PsiTreeUtil.getTopmostParentOfType(element, PyBinaryExpression.class);
PyBinaryExpression topmostBinaryExpr = getTopmostBinaryExpression(element);
if (binaryExpression == null) {
binaryExpression = element;
}
final LanguageLevel languageLevel = LanguageLevel.forElement(binaryExpression);
final LanguageLevel languageLevel = LanguageLevel.forElement(topmostBinaryExpr);
final boolean useFormatMethod = languageLevel.isAtLeast(LanguageLevel.PYTHON27);
NotNullFunction<String, String> escaper = StringUtil.escaper(false, "\"'\\");
Function<String, String> escaper = StringUtil.escaper(false, "\"'\\");
StringBuilder stringLiteral = new StringBuilder();
List<String> parameters = new ArrayList<>();
Pair<String, String> quotes = Pair.create("\"", "\"");
@@ -101,9 +94,9 @@ public final class PyStringConcatenationToFormatIntention extends PsiUpdateModCo
final TypeEvalContext context = TypeEvalContext.userInitiated(actionContext.project(), actionContext.file());
int paramCount = 0;
boolean isUnicode = false;
final PyClassTypeImpl unicodeType = PyBuiltinCache.getInstance(binaryExpression).getObjectType("unicode");
final PyClassTypeImpl unicodeType = PyBuiltinCache.getInstance(topmostBinaryExpr).getObjectType("unicode");
for (PyExpression expression : getSimpleExpressions(binaryExpression)) {
for (PyExpression expression : getSimpleExpressions(topmostBinaryExpr)) {
if (expression instanceof PyStringLiteralExpression) {
final PyType type = context.getType(expression);
if (type != null && type.equals(unicodeType)) {
@@ -117,7 +110,7 @@ public final class PyStringConcatenationToFormatIntention extends PsiUpdateModCo
if (!useFormatMethod) {
value = value.replace("%", "%%");
}
stringLiteral.append(escaper.fun(value));
stringLiteral.append(escaper.apply(value));
}
else {
addParamToString(stringLiteral, paramCount, useFormatMethod);
@@ -148,15 +141,22 @@ public final class PyStringConcatenationToFormatIntention extends PsiUpdateModCo
final PyExpression expression = elementGenerator.createFromText(LanguageLevel.getDefault(),
PyExpressionStatement.class, stringLiteral.toString())
.getExpression();
binaryExpression.replace(expression);
topmostBinaryExpr.replace(expression);
}
else {
PyStringLiteralExpression stringLiteralExpression =
elementGenerator.createStringLiteralAlreadyEscaped(stringLiteral.toString());
binaryExpression.replace(stringLiteralExpression);
topmostBinaryExpr.replace(stringLiteralExpression);
}
}
private static @NotNull PyBinaryExpression getTopmostBinaryExpression(@NotNull PyBinaryExpression element) {
PyBinaryExpression result = element;
while (result.getParent() instanceof PyBinaryExpression binaryExpression) {
result = binaryExpression;
}
return result;
}
private static Collection<PyExpression> getSimpleExpressions(@NotNull PyBinaryExpression expression) {
List<PyExpression> res = new ArrayList<>();

View File

@@ -0,0 +1 @@
string = ascii("foo<caret>" + "bar") + "baz"

View File

@@ -79,4 +79,8 @@ public class PyStringConcatenationToFormatIntentionTest extends PyIntentionTestC
public void testPy3Unicode() {
doTest(PyPsiBundle.message("INTN.replace.plus.with.str.format"), LanguageLevel.PYTHON34);
}
public void testEnclosingConcatenationWithIntermediateCall() {
doTest(PyPsiBundle.message("INTN.replace.plus.with.str.format"), LanguageLevel.getLatest());
}
}