IDEA-165063 Stream API migration: support simple limit conversions

This commit is contained in:
Tagir Valeev
2016-12-20 16:30:23 +07:00
parent b9b2151f24
commit a15916ae24
14 changed files with 355 additions and 38 deletions

View File

@@ -65,6 +65,7 @@ class ReplaceWithCollectFix extends MigrateToStreamFix {
@NotNull PsiLoopStatement loopStatement,
@NotNull PsiStatement body,
@NotNull TerminalBlock tb) {
tb = tb.tryPeelLimit(loopStatement);
PsiElementFactory factory = JavaPsiFacade.getElementFactory(project);
PsiMethodCallExpression call = tb.getSingleMethodCall();
if (call == null) return null;

View File

@@ -15,6 +15,7 @@
*/
package com.intellij.codeInspection.streamMigration;
import com.intellij.codeInspection.streamMigration.StreamApiMigrationInspection.LimitOp;
import com.intellij.openapi.project.Project;
import com.intellij.psi.*;
import org.jetbrains.annotations.NotNull;
@@ -35,12 +36,18 @@ class ReplaceWithCountFix extends MigrateToStreamFix {
@NotNull PsiLoopStatement loopStatement,
@NotNull PsiStatement body,
@NotNull StreamApiMigrationInspection.TerminalBlock tb) {
PsiExpression operand = StreamApiMigrationInspection.extractIncrementedLValue(tb.getSingleExpression(PsiExpression.class));
tb = tb.tryPeelLimit(loopStatement);
PsiExpression expression = tb.getSingleExpression(PsiExpression.class);
StreamApiMigrationInspection.Operation lastOperation = tb.getLastOperation();
if (expression == null && lastOperation instanceof LimitOp) {
expression = ((LimitOp)lastOperation).getCountExpression();
}
PsiExpression operand = StreamApiMigrationInspection.extractIncrementedLValue(expression);
if (!(operand instanceof PsiReferenceExpression)) return null;
PsiElement element = ((PsiReferenceExpression)operand).resolve();
if (!(element instanceof PsiLocalVariable)) return null;
PsiLocalVariable var = (PsiLocalVariable)element;
StringBuilder builder = generateStream(tb.getLastOperation()).append(".count()");
StringBuilder builder = generateStream(lastOperation).append(".count()");
return replaceWithNumericAddition(project, loopStatement, var, builder, PsiType.LONG);
}
}

View File

@@ -206,12 +206,12 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
}
@Nullable
private static PsiLocalVariable getIncrementedVariable(TerminalBlock tb, List<PsiVariable> variables) {
private static PsiLocalVariable getIncrementedVariable(PsiExpression expression, TerminalBlock tb, List<PsiVariable> variables) {
// have only one non-final variable
if(variables.size() != 1) return null;
// have single expression which is either ++x or x++ or x+=1 or x=x+1
PsiExpression operand = extractIncrementedLValue(tb.getSingleExpression(PsiExpression.class));
PsiExpression operand = extractIncrementedLValue(expression);
if(!(operand instanceof PsiReferenceExpression)) return null;
PsiElement element = ((PsiReferenceExpression)operand).resolve();
@@ -306,6 +306,25 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
return extractQualifierClass(tb, call) != null && !tb.dependsOn(qualifierExpression) && canCollect(loop, call);
}
private static boolean isCountOperation(PsiLoopStatement statement, List<PsiVariable> nonFinalVariables, TerminalBlock tb) {
PsiLocalVariable variable = getIncrementedVariable(tb.getSingleExpression(PsiExpression.class), tb, nonFinalVariables);
LimitOp limitOp = tb.getLastOperation(LimitOp.class);
if (limitOp == null) {
return variable != null;
}
PsiExpression counter = PsiUtil.skipParenthesizedExprDown(limitOp.getCountExpression());
if (tb.isEmpty()) {
// like "if(++count == limit) break"
if(!(counter instanceof PsiPrefixExpression)) return false;
variable = getIncrementedVariable(counter, tb, nonFinalVariables);
} else if (!ExpressionUtils.isReferenceTo(counter, variable)) {
return false;
}
return variable != null &&
ExpressionUtils.isZero(variable.getInitializer()) &&
getInitializerUsageStatus(variable, statement) != UNKNOWN;
}
private static boolean isCollectCall(TerminalBlock tb) {
PsiMethodCallExpression call = tb.getSingleMethodCall();
if (!isCallOf(call, CommonClassNames.JAVA_UTIL_COLLECTION, "add")) return false;
@@ -314,8 +333,20 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
if (tb.dependsOn(qualifierExpression)) return false;
if (extractQualifierClass(tb, call) != null) return true;
if (qualifierExpression instanceof PsiMethodCallExpression) {
LimitOp limitOp = tb.getLastOperation(LimitOp.class);
PsiClass qualifierClass = extractQualifierClass(tb, call);
if (qualifierClass != null) {
if (limitOp == null) return true;
// like "list.add(x); if(list.size() >= limit) break;"
PsiExpression count = limitOp.getCountExpression();
if(!(count instanceof PsiMethodCallExpression)) return false;
PsiMethodCallExpression sizeCall = (PsiMethodCallExpression)count;
PsiExpression sizeQualifier = sizeCall.getMethodExpression().getQualifierExpression();
return isCallOf(sizeCall, CommonClassNames.JAVA_UTIL_COLLECTION, "size") &&
EquivalenceChecker.getCanonicalPsiEquivalence().expressionsAreEquivalent(sizeQualifier, qualifierExpression) &&
InheritanceUtil.isInheritor(qualifierClass, CommonClassNames.JAVA_UTIL_LIST);
}
if (qualifierExpression instanceof PsiMethodCallExpression && limitOp == null) {
PsiMethodCallExpression qualifierCall = (PsiMethodCallExpression)qualifierExpression;
if (isCallOf(qualifierCall, CommonClassNames.JAVA_UTIL_MAP, "computeIfAbsent")) {
PsiExpression[] args = qualifierCall.getArgumentList().getExpressions();
@@ -614,39 +645,20 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
.remove(variable -> PsiTreeUtil.getParentOfType(variable, PsiLambdaExpression.class, PsiClass.class) != surrounder)
.remove(variable -> isVariableSuitableForStream(variable, statement, tb)).toList();
TerminalBlock tbWithLimit = tb.tryPeelLimit(statement);
if (isCountOperation(statement, nonFinalVariables, tbWithLimit)) {
registerProblem(statement, "count", new ReplaceWithCountFix());
} else if (nonFinalVariables.isEmpty() && isCollectCall(tbWithLimit)) {
handleCollect(statement, tbWithLimit);
return;
} else if (getAccumulatedVariable(tb, nonFinalVariables) != null) {
registerProblem(statement, "sum", new ReplaceWithSumFix());
}
if (exitPoints.isEmpty()) {
if(getIncrementedVariable(tb, nonFinalVariables) != null) {
registerProblem(statement, "count", new ReplaceWithCountFix());
}
if(getAccumulatedVariable(tb, nonFinalVariables) != null) {
registerProblem(statement, "sum", new ReplaceWithSumFix());
}
if(!nonFinalVariables.isEmpty()) {
return;
}
if (isCollectCall(tb)) {
boolean addAll = statement instanceof PsiForeachStatement && !tb.hasOperations() && isAddAllCall(tb);
String methodName;
if(addAll) {
methodName = "addAll";
} else {
PsiMethodCallExpression call = tb.getSingleMethodCall();
if(call != null && call.getMethodExpression().getQualifierExpression() instanceof PsiMethodCallExpression) {
call = (PsiMethodCallExpression)call.getMethodExpression().getQualifierExpression();
}
if(canCollect(statement, call)) {
if(extractToArrayExpression(statement, call) != null)
methodName = "toArray";
else
methodName = "collect";
} else {
if (!SUGGEST_FOREACH) return;
methodName = "forEach";
}
}
registerProblem(statement, methodName, new ReplaceWithCollectFix(methodName));
}
else if (isCollectMapCall(statement, tb) && (REPLACE_TRIVIAL_FOREACH || tb.hasOperations())) {
if (isCollectMapCall(statement, tb) && (REPLACE_TRIVIAL_FOREACH || tb.hasOperations())) {
registerProblem(statement, "collect", new ReplaceWithCollectFix("collect"));
}
// do not replace for(T e : arr) {} with Arrays.stream(arr).forEach(e -> {}) even if flag is set
@@ -667,7 +679,7 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
// Source and intermediate ops should not refer to non-final variables
if (tb.intermediateAndSourceExpressions()
.flatCollection(expr -> PsiTreeUtil.collectElementsOfType(expr, PsiReferenceExpression.class))
.map(PsiReferenceExpression::resolve).anyMatch(nonFinalVariables::contains)) {
.map(PsiReferenceExpression::resolve).select(PsiVariable.class).anyMatch(nonFinalVariables::contains)) {
return;
}
PsiStatement[] statements = tb.getStatements();
@@ -699,6 +711,29 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
}
}
private void handleCollect(PsiLoopStatement statement, TerminalBlock tb) {
boolean addAll = statement instanceof PsiForeachStatement && !tb.hasOperations() && isAddAllCall(tb);
String methodName;
if(addAll) {
methodName = "addAll";
} else {
PsiMethodCallExpression call = tb.getSingleMethodCall();
if(call != null && call.getMethodExpression().getQualifierExpression() instanceof PsiMethodCallExpression) {
call = (PsiMethodCallExpression)call.getMethodExpression().getQualifierExpression();
}
if(canCollect(statement, call)) {
if(extractToArrayExpression(statement, call) != null)
methodName = "toArray";
else
methodName = "collect";
} else {
if (!SUGGEST_FOREACH || tb.getLastOperation() instanceof LimitOp) return;
methodName = "forEach";
}
}
registerProblem(statement, methodName, new ReplaceWithCollectFix(methodName));
}
void handleSingleReturn(PsiLoopStatement statement, TerminalBlock tb) {
PsiReturnStatement returnStatement = (PsiReturnStatement)tb.getSingleStatement();
PsiExpression value = returnStatement.getReturnValue();
@@ -712,8 +747,8 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
}
else {
methodName = "noneMatch";
Operation lastOp = tb.getLastOperation();
if(lastOp instanceof FilterOp && (((FilterOp)lastOp).isNegated() ^ BoolUtils.isNegation(lastOp.getExpression()))) {
FilterOp lastFilter = tb.getLastOperation(FilterOp.class);
if(lastFilter != null && (lastFilter.isNegated() ^ BoolUtils.isNegation(lastFilter.getExpression()))) {
methodName = "allMatch";
}
}
@@ -1057,6 +1092,39 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
}
}
static class LimitOp extends Operation {
private final boolean myInclusive;
private final PsiExpression myCounter;
LimitOp(@Nullable Operation previousOp, PsiExpression counter, PsiExpression expression, PsiVariable variable, boolean inclusive) {
super(previousOp, expression, variable);
myInclusive = inclusive;
myCounter = counter;
}
@Override
String createReplacement() {
return ".limit(" + getLimitExpression() + ")";
}
PsiExpression getCountExpression() {
return myCounter;
}
private String getLimitExpression() {
if(!myInclusive) {
return myExpression.getText();
}
if (myExpression instanceof PsiLiteralExpression) {
Object value = ((PsiLiteralExpression)myExpression).getValue();
if (value instanceof Integer || value instanceof Long) {
return String.valueOf(((Number)value).longValue() + 1);
}
}
return ParenthesesUtils.getText(myExpression, ParenthesesUtils.ADDITIVE_PRECEDENCE) + "+1";
}
}
abstract static class StreamSource extends Operation {
protected StreamSource(PsiVariable variable, PsiExpression expression) {
super(null, expression, variable);
@@ -1453,11 +1521,72 @@ public class StreamApiMigrationInspection extends BaseJavaBatchLocalInspectionTo
return null;
}
/**
* Try to peel off the condition like if(count > limit) break; from the end of current terminal block.
*
* <p>It's not guaranteed that the peeled condition actually could be translated to the limit operation:
* additional checks will be necessary</p>
*
* @param loop a main loop for which the condition should be peeled off
* @return new terminal block with additional limit operation or self if peeling is failed.
*
*/
TerminalBlock tryPeelLimit(PsiLoopStatement loop) {
if(myStatements.length == 0) return this;
TerminalBlock tb = this;
PsiStatement[] statements = {};
if(myStatements.length > 1) {
statements = new PsiStatement[]{myStatements[0]};
tb = new TerminalBlock(myPreviousOp, myVariable, Arrays.copyOfRange(myStatements, 1, myStatements.length)).extractFilter();
}
if (tb == null || !ControlFlowUtils.statementBreaksLoop(tb.getSingleStatement(), loop)) return this;
FilterOp filter = tb.getLastOperation(FilterOp.class);
if(filter == null) return this;
PsiExpression condition = PsiUtil.skipParenthesizedExprDown(filter.getExpression());
if(!(condition instanceof PsiBinaryExpression)) return this;
PsiBinaryExpression binOp = (PsiBinaryExpression)condition;
if(!ComparisonUtils.isComparison(binOp)) return this;
String comparison = filter.isNegated() ? ComparisonUtils.getNegatedComparison(binOp.getOperationTokenType())
: binOp.getOperationSign().getText();
boolean inclusive = false, flipped = false;
switch (comparison) {
case "==":
case ">=":
break;
case ">":
inclusive = true;
break;
case "<":
inclusive = true;
flipped = true;
break;
case "<=":
flipped = true;
break;
default:
return this;
}
PsiExpression counter = flipped ? binOp.getROperand() : binOp.getLOperand();
if(counter == null || VariableAccessUtils.variableIsUsed(myVariable, counter)) return this;
PsiExpression limit = flipped ? binOp.getLOperand() : binOp.getROperand();
if(!ExpressionUtils.isSimpleExpression(limit) || VariableAccessUtils.variableIsUsed(myVariable, limit)) return this;
PsiType type = limit.getType();
if(!PsiType.INT.equals(type) && !PsiType.LONG.equals(type)) return this;
LimitOp limitOp = new LimitOp(filter.getPreviousOp(), counter, limit, myVariable, inclusive);
return new TerminalBlock(limitOp, myVariable, statements);
}
@NotNull
public Operation getLastOperation() {
return myPreviousOp;
}
@Nullable
public <T extends Operation> T getLastOperation(Class<T> clazz) {
return clazz.isInstance(myPreviousOp) ? clazz.cast(myPreviousOp) : null;
}
/**
* Extract all possible intermediate operations
* @return the terminal block with all possible terminal operations extracted (may return this if no operations could be extracted)

View File

@@ -0,0 +1,17 @@
// "Replace with collect" "true"
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
public class Main {
private List<String> test(String[] list, int limit) {
List<String> result;
List<String> other = new ArrayList<>();
System.out.println("hello");
result = Arrays.stream(list).filter(Objects::nonNull).limit(limit).map(s -> s + s).sorted().collect(Collectors.toList());
return result;
}
}

View File

@@ -0,0 +1,10 @@
// "Replace with count()" "true"
import java.util.Arrays;
public class Main {
public long test(String[] array) {
long longStrings = Arrays.stream(array).map(String::trim).filter(trimmed -> trimmed.length() > 10).limit(100).count();
return longStrings;
}
}

View File

@@ -0,0 +1,10 @@
// "Replace with count()" "true"
import java.util.Arrays;
public class Main {
public long test(String[] array, long limit) {
long longStrings = Arrays.stream(array).map(String::trim).filter(trimmed -> trimmed.length() > 10).limit(limit + 1).count();
return longStrings;
}
}

View File

@@ -0,0 +1,10 @@
// "Replace with count()" "true"
import java.util.Arrays;
public class Main {
public long test(String[] array) {
long longStrings = Arrays.stream(array).map(String::trim).filter(trimmed -> trimmed.length() > 10).limit(100).count();
return longStrings;
}
}

View File

@@ -0,0 +1,24 @@
// "Replace with collect" "true"
import java.util.ArrayList;
import java.util.List;
public class Main {
private List<String> test(String[] list, int limit) {
List<String> result = new ArrayList<>();
List<String> other = new ArrayList<>();
System.out.println("hello");
for(String s : li<caret>st) {
if (s == null) {
continue;
}
result.add(s+s);
if(result.size() != limit) {
continue;
}
break;
}
result.sort(null);
return result;
}
}

View File

@@ -0,0 +1,23 @@
// "Replace with collect" "false"
import java.util.*;
public class Main {
private Set<String> test(String[] list, int limit) {
Set<String> result = new HashSet<>();
List<String> other = new ArrayList<>();
System.out.println("hello");
for(String s : li<caret>st) {
if (s == null) {
continue;
}
result.add(s+s);
if(result.size() != limit) {
continue;
}
break;
}
result.sort(null);
return result;
}
}

View File

@@ -0,0 +1,24 @@
// "Replace with collect" "false"
import java.util.ArrayList;
import java.util.List;
public class Main {
private List<String> test(String[] list, int limit) {
List<String> result = new ArrayList<>();
List<String> other = new ArrayList<>();
System.out.println("hello");
for(String s : li<caret>st) {
if (s == null) {
continue;
}
result.add(s+s);
if(other.size() != limit) {
continue;
}
break;
}
result.sort(null);
return result;
}
}

View File

@@ -0,0 +1,15 @@
// "Replace with count()" "true"
public class Main {
public long test(String[] array) {
long longStrings = 0;
for(String str : a<caret>rray) {
String trimmed = str.trim();
if(trimmed.length() > 10) {
longStrings = longStrings + 1;
if(longStrings >= 100) break;
}
}
return longStrings;
}
}

View File

@@ -0,0 +1,17 @@
// "Replace with count()" "true"
public class Main {
public long test(String[] array, long limit) {
long longStrings = 0;
for(String str : a<caret>rray) {
String trimmed = str.trim();
if(trimmed.length() > 10) {
longStrings++;
if(longStrings > limit) {
break;
}
}
}
return longStrings;
}
}

View File

@@ -0,0 +1,15 @@
// "Replace with count()" "true"
public class Main {
public long test(String[] array) {
long longStrings = 0;
for(String str : a<caret>rray) {
String trimmed = str.trim();
if(trimmed.length() > 10) {
if(100 > ++longStrings) continue;
break;
}
}
return longStrings;
}
}

View File

@@ -0,0 +1,15 @@
// "Replace with count()" "false"
public class Main {
public long test(String[] array) {
long longStrings = 0;
for(String str : a<caret>rray) {
String trimmed = str.trim();
if(trimmed.length() > 10) {
longStrings = longStrings + 1;
if(longStrings >= trimmed.length()) break;
}
}
return longStrings;
}
}