[java-inspection] Support multi-line single-return lambdas in StreamToLoopInspection

Part of IDEABKL-7718
Fixes IDEA-317735

GitOrigin-RevId: 120245c2b1f4abb464d52c43dd39078a83f4bbcd
This commit is contained in:
Tagir Valeev
2023-04-14 14:47:54 +02:00
committed by intellij-monorepo-bot
parent 6076d18d3f
commit 78c8e66901
10 changed files with 241 additions and 119 deletions

View File

@@ -150,16 +150,23 @@ public abstract class FunctionHelper {
PsiElement body = lambda.getBody();
PsiExpression lambdaExpression = LambdaUtil.extractSingleExpressionFromBody(body);
if (lambdaExpression == null) {
if (PsiTypes.voidType().equals(returnType) && body instanceof PsiCodeBlock) {
if (body instanceof PsiCodeBlock block) {
List<PsiReturnStatement> returns = getReturns(body);
if (!allowReturns && (!returns.isEmpty() || !ControlFlowUtils.codeBlockMayCompleteNormally((PsiCodeBlock)body))) return null;
// Return inside loop is not supported yet
for (PsiReturnStatement ret : returns) {
if (PsiTreeUtil.getParentOfType(ret, PsiLoopStatement.class, true, PsiLambdaExpression.class) != null) {
return null;
if (PsiTypes.voidType().equals(returnType)) {
if (!allowReturns && (!returns.isEmpty() || !ControlFlowUtils.codeBlockMayCompleteNormally(block))) return null;
// Return inside loop is not supported yet
for (PsiReturnStatement ret : returns) {
if (PsiTreeUtil.getParentOfType(ret, PsiLoopStatement.class, true, PsiLambdaExpression.class) != null) {
return null;
}
}
return new VoidBlockLambdaFunctionHelper(block, parameters);
} else if (returns.size() == 1 && ArrayUtil.getLastElement(block.getStatements()) == returns.get(0)) {
PsiExpression trivialCall = JavaPsiFacade.getElementFactory(lambda.getProject())
.createExpressionFromText("((" + type.getCanonicalText() + ")" + lambda.getText() + ")." +
interfaceMethod.getName() + "(" + String.join(",", parameters) + ")", null);
return new LambdaFunctionHelper(returnType, trivialCall, parameters);
}
return new VoidBlockLambdaFunctionHelper((PsiCodeBlock)body, parameters);
}
return null;
}
@@ -533,10 +540,10 @@ public abstract class FunctionHelper {
}
private static class LambdaFunctionHelper extends FunctionHelper {
String[] myParameters;
PsiElement myBody;
@NotNull String @NotNull [] myParameters;
@NotNull PsiElement myBody;
LambdaFunctionHelper(PsiType returnType, PsiElement body, String[] parameters) {
LambdaFunctionHelper(PsiType returnType, @NotNull PsiElement body, @NotNull String @NotNull [] parameters) {
super(returnType);
myParameters = parameters;
myBody = body;

View File

@@ -16,6 +16,7 @@ import com.intellij.psi.*;
import com.intellij.psi.codeStyle.JavaCodeStyleManager;
import com.intellij.psi.impl.source.PsiImmediateClassType;
import com.intellij.psi.util.InheritanceUtil;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.PsiUtil;
import com.intellij.psi.util.RedundantCastUtil;
import com.siyeh.ig.callMatcher.CallMatcher;
@@ -290,24 +291,26 @@ public class StreamToLoopInspection extends AbstractBaseJavaLocalInspectionTool
for (OperationRecord or : StreamEx.ofReversed(operations)) {
replacement = or.myOperation.wrap(or.myInVar, or.myOutVar, replacement, context);
}
PsiElement firstAdded = null;
PsiElement previous = PsiTreeUtil.skipWhitespacesAndCommentsBackward(statement);
PsiElement next = PsiTreeUtil.skipWhitespacesAndCommentsForward(statement);
PsiElementFactory factory = JavaPsiFacade.getElementFactory(project);
for (PsiStatement addedStatement : ((PsiBlockStatement)factory.createStatementFromText("{" + replacement + "}", statement))
.getCodeBlock().getStatements()) {
PsiElement res = addStatement(project, statement, addedStatement);
if (firstAdded == null) {
firstAdded = res;
}
addStatement(statement, addedStatement);
}
PsiElement result = context.makeFinalReplacement();
if(result != null) {
result = normalize(project, result);
if (firstAdded == null) {
firstAdded = result;
}
if (result != null) {
normalize(result);
}
if (firstAdded != null) {
ct.insertCommentsBefore(firstAdded);
if (previous != null) {
PsiElement firstAdded = PsiTreeUtil.skipWhitespacesAndCommentsForward(previous);
if (firstAdded != null) {
ct.insertCommentsBefore(firstAdded);
JavaCodeStyleManager codeStyleManager = JavaCodeStyleManager.getInstance(project);
for (PsiElement e = firstAdded; e != null && e != next; e = e.getNextSibling()) {
codeStyleManager.shortenClassReferences(e);
}
}
}
}
catch (Exception ex) {
@@ -315,16 +318,15 @@ public class StreamToLoopInspection extends AbstractBaseJavaLocalInspectionTool
}
}
private static PsiElement addStatement(@NotNull Project project, PsiStatement statement, PsiStatement context) {
private static void addStatement(PsiStatement statement, PsiStatement context) {
PsiElement element = statement.getParent().addBefore(context, statement);
return normalize(project, element);
normalize(element);
}
private static PsiElement normalize(@NotNull Project project, PsiElement element) {
element = JavaCodeStyleManager.getInstance(project).shortenClassReferences(element);
private static void normalize(@NotNull PsiElement element) {
RemoveRedundantTypeArgumentsUtil.removeRedundantTypeArguments(element);
RedundantCastUtil.getRedundantCastsInside(element).forEach(RemoveRedundantCastUtil::removeCast);
return element;
TrivialFunctionalExpressionUsageInspection.simplifyAllLambdas(element);
}
private static StreamEx<OperationRecord> allOperations(List<OperationRecord> operations) {

View File

@@ -37,16 +37,17 @@ public class Main {
}
void sample3(List<String> people) {
List<String> list2 = people.stream().collect( // comment
Collectors.collectingAndThen(Collectors.<String, List<String>>toCollection(LinkedList::new),
list -> {
List<String> result = new ArrayList<>();
for (Iterator<String> it = Stream.concat(list.stream(), list.stream()).iterator(); it.hasNext(); ) {
String s = it.next();
result.add(s);
}
return result;
}));
// comment
List<String> strings = new LinkedList<>();
for (String person : people) {
strings.add(person);
}
List<String> result = new ArrayList<>();
for (Iterator<String> it = Stream.concat(strings.stream(), strings.stream()).iterator(); it.hasNext(); ) {
String s = it.next();
result.add(s);
}
List<String> list2 = result;
}
void sample4(List<String> people) {

View File

@@ -77,7 +77,7 @@ public class Main {
static Integer testReducing3() {
Integer totalLength = 0;
for (String s : Arrays.asList("a", "bb", "ccc")) {
for (String s : asList("a", "bb", "ccc")) {
Integer length = s.length();
totalLength = totalLength + length;
}

View File

@@ -11,12 +11,12 @@ import static java.util.Arrays.asList;
public class Main {
private static long testChain(List<? extends String> list) {
long count = 0L;
for (Object o : Arrays.asList(0, null, "1", list)) {
for (Object object : Arrays.asList(o)) {
for (Object o1 : Arrays.asList(object)) {
for (Object object1 : Arrays.asList(o1)) {
for (Object o2 : Arrays.asList(object1)) {
for (Object object2 : Arrays.asList(o2)) {
for (Object o : asList(0, null, "1", list)) {
for (Object object : asList(o)) {
for (Object o1 : asList(object)) {
for (Object object1 : asList(o1)) {
for (Object o2 : asList(object1)) {
for (Object object2 : asList(o2)) {
count++;
}
}

View File

@@ -0,0 +1,41 @@
// "Fix all 'Stream API call chain can be replaced with loop' problems in file" "true"
import java.util.*;
import java.util.stream.*;
class X {
record N(N parent, List<N> children, String whatever) {}
private static N reproducer(N parent, String frame) {
for (N child : parent.children) {
if (child.whatever.equals(frame)) {
return child;
}
}
N result = new N(parent, new ArrayList<>(), frame);
parent.children.add(result);
return result;
}
void testMap(List<String> list) {
List<List<String>> newList = new ArrayList<>();
for (String s : list) {
List<String> result = new ArrayList<>();
result.add(s);
result.add(s + s);
List<String> apply = result;
newList.add(apply);
}
}
void testFilter(List<String> list) {
List<String> newList = new ArrayList<>();
for (String string : list) {
boolean result = false;
if (string.isEmpty()) result = true;
if (result) {
newList.add(string);
}
}
}
}

View File

@@ -65,7 +65,7 @@ public class Main {
map.put(true, new HashMap<>());
for (String string : strings) {
String s = string/*trimming*/.trim();
if (map.get(s.length() /*too big!*/ > 2).put(((UnaryOperator<String>) /* cast is necessary here */ x -> x).apply(s), s.length()) != null) {
if (map.get(s.length() /*too big!*/ > 2).put(((UnaryOperator<String>) /* cast is necessary here */ x -> x = x).apply(s), s.length()) != null) {
throw new IllegalStateException("Duplicate key");
}
}

View File

@@ -0,0 +1,40 @@
// "Fix all 'Stream API call chain can be replaced with loop' problems in file" "true"
import java.util.*;
import java.util.stream.*;
class X {
record N(N parent, List<N> children, String whatever) {}
private static N reproducer(N parent, String frame) {
return parent.children.<caret>stream()
.filter(child -> child.whatever.equals(frame))
.findAny()
.orElseGet(() -> {
N result = new N(parent, new ArrayList<>(), frame);
parent.children.add(result);
return result;
});
}
void testMap(List<String> list) {
List<List<String>> newList = list.stream()
.map(item -> {
List<String> result = new ArrayList<>();
result.add(item);
result.add(item + item);
return result;
})
.collect(Collectors.toList());
}
void testFilter(List<String> list) {
List<String> newList = list.stream()
.filter(s -> {
boolean result = false;
if (s.isEmpty()) result = true;
return result;
})
.collect(Collectors.toList());
}
}

View File

@@ -34,7 +34,7 @@ public class Main {
public static void testToMapNameConflict(List<String> strings) {
System.out.println(strings.stream().map(x -> x/*trimming*/.trim()) // and collect
.collect(Collectors.partitioningBy(s -> s.length() /*too big!*/ > 2,
Collectors.toMap(s -> ((UnaryOperator<String>) /* cast is necessary here */ x -> x).apply(s),
Collectors.toMap(s -> ((UnaryOperator<String>) /* cast is necessary here */ x -> x = x).apply(s),
String::length))));
}

View File

@@ -1,4 +1,4 @@
// Copyright 2000-2022 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.codeInspection;
import com.intellij.codeInsight.BlockUtils;
@@ -19,6 +19,7 @@ import com.siyeh.ig.psiutils.*;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.Nls;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.Collections;
@@ -33,52 +34,25 @@ public class TrivialFunctionalExpressionUsageInspection extends AbstractBaseJava
@Override
public void visitMethodReferenceExpression(final @NotNull PsiMethodReferenceExpression expression) {
doCheckMethodCallOnFunctionalExpression(expression, call -> expression.resolve() != null);
Problem problem = doCheckMethodCallOnFunctionalExpression(expression, call -> expression.resolve() != null);
if (problem != null) {
problem.register(holder);
}
}
@Override
public void visitLambdaExpression(final @NotNull PsiLambdaExpression expression) {
final PsiElement body = expression.getBody();
if (body == null) return;
Predicate<PsiMethodCallExpression> checkBody = call -> {
final PsiElement callParent = call.getParent();
if (!(body instanceof PsiCodeBlock)) {
return callParent instanceof PsiStatement || callParent instanceof PsiLocalVariable || expression.isValueCompatible();
}
PsiStatement[] statements = ((PsiCodeBlock)body).getStatements();
if (statements.length == 1) {
return callParent instanceof PsiStatement
|| callParent instanceof PsiLocalVariable
|| statements[0] instanceof PsiReturnStatement && expression.isValueCompatible();
}
final PsiReturnStatement[] returnStatements = PsiUtil.findReturnStatements((PsiCodeBlock)body);
if (returnStatements.length > 1) {
return false;
}
if (returnStatements.length == 1) {
if (!(ArrayUtil.getLastElement(statements) instanceof PsiReturnStatement)) {
return false;
}
}
return CodeBlockSurrounder.canSurround(call);
};
Predicate<PsiMethodCallExpression> checkWrites = call ->
!ContainerUtil.exists(expression.getParameterList().getParameters(), parameter -> VariableAccessUtils.variableIsAssigned(parameter, body));
doCheckMethodCallOnFunctionalExpression(expression, checkBody.and(checkWrites));
Problem problem = doCheckLambda(expression);
if (problem != null) {
problem.register(holder);
}
}
@Override
public void visitAnonymousClass(final @NotNull PsiAnonymousClass aClass) {
if (AnonymousCanBeLambdaInspection.canBeConvertedToLambda(aClass, false, Collections.emptySet())) {
final PsiNewExpression newExpression = ObjectUtils.tryCast(aClass.getParent(), PsiNewExpression.class);
doCheckMethodCallOnFunctionalExpression(call -> {
Problem problem = doCheckMethodCallOnFunctionalExpression(call -> {
final PsiMethod method = aClass.getMethods()[0];
final PsiCodeBlock body = method.getBody();
final PsiReturnStatement[] returnStatements = PsiUtil.findReturnStatements(body);
@@ -89,45 +63,99 @@ public class TrivialFunctionalExpressionUsageInspection extends AbstractBaseJava
return callParent instanceof PsiStatement ||
callParent instanceof PsiLocalVariable;
}, newExpression, aClass.getBaseClassType(), new ReplaceAnonymousWithLambdaBodyFix());
}
}
private void doCheckMethodCallOnFunctionalExpression(PsiElement expression,
Predicate<? super PsiMethodCallExpression> elementContainerPredicate) {
final PsiTypeCastExpression parent =
ObjectUtils.tryCast(PsiUtil.skipParenthesizedExprUp(expression.getParent()), PsiTypeCastExpression.class);
if (parent != null) {
final PsiType interfaceType = parent.getType();
doCheckMethodCallOnFunctionalExpression(elementContainerPredicate, parent, interfaceType,
expression instanceof PsiLambdaExpression ? new ReplaceWithLambdaBodyFix()
: new ReplaceWithMethodReferenceFix());
}
}
private void doCheckMethodCallOnFunctionalExpression(Predicate<? super PsiMethodCallExpression> elementContainerPredicate,
PsiExpression qualifier,
PsiType interfaceType,
LocalQuickFix fix) {
final PsiMethodCallExpression call = ExpressionUtils.getCallForQualifier(qualifier);
if (call == null) return;
final PsiMethod method = call.resolveMethod();
final PsiElement referenceNameElement = call.getMethodExpression().getReferenceNameElement();
boolean suitableMethod = method != null &&
referenceNameElement != null &&
!method.isVarArgs() &&
call.getArgumentList().getExpressionCount() == method.getParameterList().getParametersCount() &&
elementContainerPredicate.test(call);
if (!suitableMethod) return;
final PsiMethod interfaceMethod = LambdaUtil.getFunctionalInterfaceMethod(interfaceType);
if (method == interfaceMethod || interfaceMethod != null && MethodSignatureUtil.isSuperMethod(interfaceMethod, method)) {
holder.registerProblem(referenceNameElement,
InspectionGadgetsBundle.message("inspection.trivial.functional.expression.usage.description"),
fix);
if (problem != null) {
problem.register(holder);
}
}
}
};
}
public static void simplifyAllLambdas(@NotNull PsiElement context) {
List<@NotNull Problem> problems = SyntaxTraverser.psiTraverser(context)
.filter(PsiLambdaExpression.class)
.filterMap(TrivialFunctionalExpressionUsageInspection::doCheckLambda)
.toList();
for (Problem problem : problems) {
if (!problem.place().isValid()) continue;
problem.fix().apply(problem.place());
}
}
@Nullable
private static Problem doCheckLambda(@NotNull PsiLambdaExpression expression) {
final PsiElement body = expression.getBody();
if (body == null) return null;
Predicate<PsiMethodCallExpression> checkBody = call -> {
final PsiElement callParent = call.getParent();
if (!(body instanceof PsiCodeBlock)) {
return callParent instanceof PsiStatement || callParent instanceof PsiLocalVariable || expression.isValueCompatible();
}
PsiStatement[] statements = ((PsiCodeBlock)body).getStatements();
if (statements.length == 1) {
return callParent instanceof PsiStatement
|| callParent instanceof PsiLocalVariable
|| statements[0] instanceof PsiReturnStatement && expression.isValueCompatible();
}
final PsiReturnStatement[] returnStatements = PsiUtil.findReturnStatements((PsiCodeBlock)body);
if (returnStatements.length > 1) {
return false;
}
if (returnStatements.length == 1) {
if (!(ArrayUtil.getLastElement(statements) instanceof PsiReturnStatement)) {
return false;
}
}
return CodeBlockSurrounder.canSurround(call);
};
Predicate<PsiMethodCallExpression> checkWrites = call ->
!ContainerUtil.exists(expression.getParameterList().getParameters(), parameter -> VariableAccessUtils.variableIsAssigned(parameter, body));
return doCheckMethodCallOnFunctionalExpression(expression, checkBody.and(checkWrites));
}
private static @Nullable Problem doCheckMethodCallOnFunctionalExpression(@NotNull PsiElement expression,
@NotNull Predicate<? super PsiMethodCallExpression> elementContainerPredicate) {
if (!(PsiUtil.skipParenthesizedExprUp(expression.getParent()) instanceof PsiTypeCastExpression parent)) return null;
final PsiType interfaceType = parent.getType();
return doCheckMethodCallOnFunctionalExpression(elementContainerPredicate, parent, interfaceType,
expression instanceof PsiLambdaExpression ? new ReplaceWithLambdaBodyFix()
: new ReplaceWithMethodReferenceFix());
}
private static @Nullable Problem doCheckMethodCallOnFunctionalExpression(@NotNull Predicate<? super PsiMethodCallExpression> elementContainerPredicate,
PsiExpression qualifier,
PsiType interfaceType,
@NotNull ReplaceFix fix) {
final PsiMethodCallExpression call = ExpressionUtils.getCallForQualifier(qualifier);
if (call == null) return null;
final PsiMethod method = call.resolveMethod();
final PsiElement referenceNameElement = call.getMethodExpression().getReferenceNameElement();
boolean suitableMethod = method != null &&
referenceNameElement != null &&
!method.isVarArgs() &&
call.getArgumentList().getExpressionCount() == method.getParameterList().getParametersCount() &&
elementContainerPredicate.test(call);
if (!suitableMethod) return null;
final PsiMethod interfaceMethod = LambdaUtil.getFunctionalInterfaceMethod(interfaceType);
if (method == interfaceMethod || interfaceMethod != null && MethodSignatureUtil.isSuperMethod(interfaceMethod, method)) {
return new Problem(referenceNameElement, fix);
}
return null;
}
private record Problem(@NotNull PsiElement place, @NotNull ReplaceFix fix) {
void register(@NotNull ProblemsHolder holder) {
holder.registerProblem(place, InspectionGadgetsBundle.message("inspection.trivial.functional.expression.usage.description"), fix);
}
}
private static void replaceWithLambdaBody(PsiLambdaExpression lambda) {
lambda = extractSideEffects(lambda);
@@ -342,7 +370,10 @@ public class TrivialFunctionalExpressionUsageInspection extends AbstractBaseJava
@Override
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
final PsiElement psiElement = descriptor.getPsiElement();
apply(descriptor.getPsiElement());
}
void apply(@NotNull PsiElement psiElement) {
final PsiMethodCallExpression callExpression = PsiTreeUtil.getParentOfType(psiElement, PsiMethodCallExpression.class);
if (callExpression != null) {
fixExpression(callExpression, PsiUtil.skipParenthesizedExprDown(callExpression.getMethodExpression().getQualifierExpression()));