StreamToLoopInspection: fixed flatMap with method reference argument

This commit is contained in:
Tagir Valeev
2016-11-01 18:44:17 +07:00
parent 2fb544975c
commit 0ae5c0f1ae
10 changed files with 189 additions and 18 deletions

View File

@@ -49,6 +49,15 @@ abstract class FunctionHelper {
abstract PsiExpression getExpression();
/**
* Try to perform "light" transformation. Works only for single-argument SAM. The function helper decides by itself
* how to name the SAM argument and returns it. After this method invocation normal transform cannot be performed.
*
* @return SAM argument name or null if function helper refused to perform a transformation.
* @param type
*/
abstract String tryLightTransform(PsiType type);
/**
* Perform an adaptation of current function helper to the replacement context with given parameter names.
* Must be called exactly once prior using getExpression() or getText()
@@ -163,6 +172,30 @@ abstract class FunctionHelper {
myQualifierType = type == null ? null : type.getCanonicalText();
}
@Override
String tryLightTransform(PsiType type) {
if(myMethodRef.isConstructor()) return null;
PsiElement element = myMethodRef.resolve();
if(!(element instanceof PsiMethod)) return null;
PsiMethod method = (PsiMethod)element;
String var = "x";
PsiLambdaExpression lambda;
PsiClass aClass = method.getContainingClass();
if(aClass == null) return null;
if(method.getModifierList().hasExplicitModifier(PsiModifier.STATIC)) {
if(method.getParameterList().getParametersCount() != 1) return null;
lambda = (PsiLambdaExpression)JavaPsiFacade.getElementFactory(myMethodRef.getProject())
.createExpressionFromText("(" + type.getCanonicalText() + " " + var + ")->" +
aClass.getQualifiedName() + "." + method.getName() + "(" + var + ")", myMethodRef);
} else {
lambda =
(PsiLambdaExpression)JavaPsiFacade.getElementFactory(myMethodRef.getProject()).createExpressionFromText(
"(" + type.getCanonicalText() + " " + var + ")->" + var + "." + myMethodRef.getReferenceName() + "()", myMethodRef);
}
myExpression = (PsiExpression)lambda.getBody();
return var;
}
@Override
PsiExpression getExpression() {
LOG.assertTrue(myExpression != null);
@@ -240,6 +273,11 @@ abstract class FunctionHelper {
myName = methodName;
}
@Override
String tryLightTransform(PsiType type) {
return null;
}
@Override
PsiExpression getExpression() {
LOG.assertTrue(myExpression != null);
@@ -277,6 +315,12 @@ abstract class FunctionHelper {
myBody = body;
}
@Override
String tryLightTransform(PsiType type) {
LOG.assertTrue(myParameters.length == 1);
return myParameters[0];
}
PsiExpression getExpression() {
return myBody;
}

View File

@@ -18,6 +18,7 @@ package com.intellij.codeInspection.streamToLoop;
import com.intellij.codeInspection.streamToLoop.StreamToLoopInspection.StreamToLoopReplacementContext;
import com.intellij.psi.PsiExpression;
import com.intellij.psi.PsiMethodCallExpression;
import com.intellij.psi.PsiType;
import com.intellij.psi.util.PsiTypesUtil;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.Nullable;
@@ -51,7 +52,7 @@ abstract class Operation {
public void suggestNames(StreamVariable inVar, StreamVariable outVar) {}
@Nullable
static Operation createIntermediate(String name, PsiExpression[] args, StreamVariable outVar) {
static Operation createIntermediate(String name, PsiExpression[] args, StreamVariable outVar, PsiType inType) {
if(name.equals("distinct") && args.length == 0) {
return new DistinctOperation();
}
@@ -74,7 +75,7 @@ abstract class Operation {
}
if ((name.equals("flatMap") || name.equals("flatMapToInt") || name.equals("flatMapToLong") || name.equals("flatMapToDouble")) &&
args.length == 1) {
return FlatMapOperation.from(outVar, args[0]);
return FlatMapOperation.from(outVar, args[0], inType);
}
if ((name.equals("map") ||
name.equals("mapToInt") ||
@@ -172,11 +173,13 @@ abstract class Operation {
}
static class FlatMapOperation extends Operation {
private final @Nullable String myOriginalName;
private String myVarName;
private final FunctionHelper myFn;
private final List<StreamToLoopInspection.OperationRecord> myRecords;
private FlatMapOperation(@Nullable String name, List<StreamToLoopInspection.OperationRecord> records) {
myOriginalName = name;
private FlatMapOperation(String varName, FunctionHelper fn, List<StreamToLoopInspection.OperationRecord> records) {
myVarName = varName;
myFn = fn;
myRecords = records;
}
@@ -192,9 +195,7 @@ abstract class Operation {
@Override
public void suggestNames(StreamVariable inVar, StreamVariable outVar) {
if(myOriginalName != null) {
inVar.addBestNameCandidate(myOriginalName);
}
myFn.suggestVariableName(inVar, 0);
}
@Override
@@ -209,11 +210,10 @@ abstract class Operation {
@Override
String wrap(StreamVariable inVar, StreamVariable outVar, String code, StreamToLoopReplacementContext context) {
if(myOriginalName != null && !myOriginalName.equals(inVar.getName())) {
rename(myOriginalName, inVar.getName(), context);
if(!myVarName.equals(inVar.getName())) {
rename(myVarName, inVar.getName(), context);
}
StreamToLoopReplacementContext
innerContext = new StreamToLoopReplacementContext(context, myRecords);
StreamToLoopReplacementContext innerContext = new StreamToLoopReplacementContext(context, myRecords);
String replacement = code;
for(StreamToLoopInspection.OperationRecord or : StreamEx.ofReversed(myRecords)) {
replacement = or.myOperation.wrap(or.myInVar, or.myOutVar, replacement, innerContext);
@@ -222,15 +222,17 @@ abstract class Operation {
}
@Nullable
public static FlatMapOperation from(StreamVariable outVar, PsiExpression arg) {
public static FlatMapOperation from(StreamVariable outVar, PsiExpression arg, PsiType inType) {
FunctionHelper fn = FunctionHelper.create(arg, 1, true);
if(fn == null) return null;
String varName = fn.tryLightTransform(inType);
if(varName == null) return null;
PsiExpression body = fn.getExpression();
if(!(body instanceof PsiMethodCallExpression)) return null;
PsiMethodCallExpression terminalCall = (PsiMethodCallExpression)body;
List<StreamToLoopInspection.OperationRecord> records = StreamToLoopInspection.extractOperations(outVar, terminalCall);
if(records == null || StreamToLoopInspection.getTerminal(records) != null) return null;
return new FlatMapOperation(fn.getParameterName(0), records);
return new FlatMapOperation(varName, fn, records);
}
}

View File

@@ -131,10 +131,13 @@ public class StreamToLoopInspection extends BaseJavaBatchLocalInspectionTool {
PsiType callType = call.getType();
if(callType == null) return null;
if(InheritanceUtil.isInheritor(aClass, CommonClassNames.JAVA_UTIL_STREAM_BASE_STREAM)) {
Operation op = Operation.createIntermediate(name, args, outVar);
if (op != null) return op;
op = TerminalOperation.createTerminal(name, args, callType, className, call.getParent() instanceof PsiExpressionStatement);
if(op != null) return op;
PsiExpression qualifier = call.getMethodExpression().getQualifierExpression();
if(qualifier != null) {
Operation op = Operation.createIntermediate(name, args, outVar, getStreamElementType(qualifier.getType()));
if (op != null) return op;
op = TerminalOperation.createTerminal(name, args, callType, className, call.getParent() instanceof PsiExpressionStatement);
if (op != null) return op;
}
}
return SourceOperation.createSource(call);
}

View File

@@ -433,6 +433,11 @@ abstract class TerminalOperation extends Operation {
myFn = fn;
}
@Override
public void suggestNames(StreamVariable inVar, StreamVariable outVar) {
myFn.suggestVariableName(inVar, 0);
}
@Override
void registerUsedNames(Consumer<String> usedNameConsumer) {
myFn.registerUsedNames(usedNameConsumer);

View File

@@ -0,0 +1,23 @@
// "Replace Stream API chain with loop" "true"
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
public class Main {
private static List<String> test(List<List<String>> list) {
List<String> result = new ArrayList<>();
for (List<String> strings : list) {
for (String s : strings) {
result.add(s);
}
}
return result;
}
public static void main(String[] args) {
System.out.println(test(Arrays.asList(Arrays.asList("", "a", "abcd", "xyz"), Arrays.asList("x", "y"))));
}
}

View File

@@ -0,0 +1,23 @@
// "Replace Stream API chain with loop" "true"
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class Main {
private static List<String> test(List<String[]> list) {
List<String> result = new ArrayList<>();
for (String[] strings : list) {
for (String s : strings) {
result.add(s);
}
}
return result;
}
public static void main(String[] args) {
System.out.println(test(Arrays.asList(new String[] {"", "a", "abcd", "xyz"}, new String[] {"x", "y"})));
}
}

View File

@@ -0,0 +1,23 @@
// "Replace Stream API chain with loop" "true"
import java.util.ArrayList;
import java.util.List;
import java.util.stream.Collectors;
import static java.util.Arrays.asList;
public class Main {
private static List<List<String>> test(List<List<List<String>>> list) {
List<List<String>> result = new ArrayList<>();
for (List<List<String>> lists : list) {
for (List<String> strings : lists) {
result.add(strings);
}
}
return result;
}
public static void main(String[] args) {
System.out.println(test(asList(asList(asList("a", "d")), asList(asList("c"), asList("b")))));
}
}

View File

@@ -0,0 +1,16 @@
// "Replace Stream API chain with loop" "true"
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
public class Main {
private static List<String> test(List<List<String>> list) {
return list.stream().flatMap(Collection::stream).col<caret>lect(Collectors.toList());
}
public static void main(String[] args) {
System.out.println(test(Arrays.asList(Arrays.asList("", "a", "abcd", "xyz"), Arrays.asList("x", "y"))));
}
}

View File

@@ -0,0 +1,16 @@
// "Replace Stream API chain with loop" "true"
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.Stream;
public class Main {
private static List<String> test(List<String[]> list) {
return list.stream().flatMap(Stream::of).col<caret>lect(Collectors.toList());
}
public static void main(String[] args) {
System.out.println(test(Arrays.asList(new String[] {"", "a", "abcd", "xyz"}, new String[] {"x", "y"})));
}
}

View File

@@ -0,0 +1,16 @@
// "Replace Stream API chain with loop" "true"
import java.util.List;
import java.util.stream.Collectors;
import static java.util.Arrays.asList;
public class Main {
private static List<List<String>> test(List<List<List<String>>> list) {
return list.stream().flatMap(List::stream).c<caret>ollect(Collectors.toList());
}
public static void main(String[] args) {
System.out.println(test(asList(asList(asList("a", "d")), asList(asList("c"), asList("b")))));
}
}