StreamToLoop: merge sorted().toArray() and sorted().collect(Collectors.toList()) into single step

This commit is contained in:
Tagir Valeev
2016-12-27 17:41:32 +07:00
parent 4305f611dc
commit d707262f9d
5 changed files with 129 additions and 46 deletions

View File

@@ -350,6 +350,16 @@ abstract class Operation {
myComparator = comparator;
}
@Override
Operation combineWithNext(Operation next) {
if (next instanceof TerminalOperation.ToArrayTerminalOperation ||
(next instanceof TerminalOperation.ToCollectionTerminalOperation
&& ((TerminalOperation.ToCollectionTerminalOperation)next).isList())) {
return new TerminalOperation.SortedTerminalOperation((TerminalOperation.AccumulatedOperation)next, myComparator);
}
return null;
}
@Override
String wrap(StreamVariable inVar, StreamVariable outVar, String code, StreamToLoopReplacementContext context) {
String list = context.registerVarName(Arrays.asList("toSort", "listToSort"));

View File

@@ -20,6 +20,7 @@ import com.intellij.codeInspection.streamToLoop.StreamToLoopInspection.StreamToL
import com.intellij.codeInspection.util.OptionalUtil;
import com.intellij.openapi.project.Project;
import com.intellij.psi.*;
import com.intellij.psi.util.InheritanceUtil;
import com.intellij.psi.util.PsiUtil;
import com.intellij.psi.util.TypeConversionUtil;
import com.siyeh.ig.psiutils.BoolUtils;
@@ -72,10 +73,10 @@ abstract class TerminalOperation extends Operation {
return null;
}
if(name.equals("count") && args.length == 0) {
return new AccumulatedTerminalOperation("count", "long", "0", "{acc}++;");
return new TemplateBasedOperation("count", "long", "0", "{acc}++;");
}
if(name.equals("sum") && args.length == 0) {
return AccumulatedTerminalOperation.summing(resultType);
return TemplateBasedOperation.summing(resultType);
}
if(name.equals("average") && args.length == 0) {
if(elementType.equals(PsiType.DOUBLE)) {
@@ -86,7 +87,7 @@ abstract class TerminalOperation extends Operation {
}
}
if(name.equals("summaryStatistics") && args.length == 0) {
return AccumulatedTerminalOperation.summarizing(resultType);
return TemplateBasedOperation.summarizing(resultType);
}
if((name.equals("findFirst") || name.equals("findAny")) && args.length == 0) {
PsiType optionalElementType = OptionalUtil.getOptionalElementType(resultType);
@@ -203,21 +204,21 @@ abstract class TerminalOperation extends Operation {
return null;
case "counting":
if (collectorArgs.length != 0) return null;
return new AccumulatedTerminalOperation("count", "long", "0", "{acc}++;");
return new TemplateBasedOperation("count", "long", "0", "{acc}++;");
case "summingInt":
case "summingLong":
case "summingDouble": {
if (collectorArgs.length != 1) return null;
fn = FunctionHelper.create(collectorArgs[0], 1);
PsiPrimitiveType type = PsiPrimitiveType.getUnboxedType(resultType);
return fn == null || type == null ? null : new InlineMappingTerminalOperation(fn, AccumulatedTerminalOperation.summing(type));
return fn == null || type == null ? null : new InlineMappingTerminalOperation(fn, TemplateBasedOperation.summing(type));
}
case "summarizingInt":
case "summarizingLong":
case "summarizingDouble": {
if (collectorArgs.length != 1) return null;
fn = FunctionHelper.create(collectorArgs[0], 1);
return fn == null ? null : new InlineMappingTerminalOperation(fn, AccumulatedTerminalOperation.summarizing(resultType));
return fn == null ? null : new InlineMappingTerminalOperation(fn, TemplateBasedOperation.summarizing(resultType));
}
case "averagingInt":
case "averagingLong":
@@ -269,16 +270,16 @@ abstract class TerminalOperation extends Operation {
case "joining":
switch (collectorArgs.length) {
case 0:
return new AccumulatedTerminalOperation("sb", CommonClassNames.JAVA_LANG_STRING_BUILDER,
return new TemplateBasedOperation("sb", CommonClassNames.JAVA_LANG_STRING_BUILDER,
"new " + CommonClassNames.JAVA_LANG_STRING_BUILDER + "()",
"{acc}.append({item});",
"{acc}.toString()");
"{acc}.append({item});",
"{acc}.toString()");
case 1:
case 3:
String initializer =
"new java.util.StringJoiner(" + StreamEx.of(collectorArgs).map(PsiElement::getText).joining(",") + ")";
return new AccumulatedTerminalOperation("joiner", "java.util.StringJoiner", initializer,
"{acc}.add({item});", "{acc}.toString()");
return new TemplateBasedOperation("joiner", "java.util.StringJoiner", initializer,
"{acc}.add({item});", "{acc}.toString()");
}
return null;
}
@@ -321,6 +322,16 @@ abstract class TerminalOperation extends Operation {
return substitutor == origSubstitutor ? type : JavaPsiFacade.getElementFactory(project).createType(aClass, substitutor);
}
abstract static class AccumulatedOperation extends TerminalOperation {
abstract String initAccumulator(StreamVariable inVar, StreamToLoopReplacementContext context);
abstract String getAccumulatorUpdater(StreamVariable inVar, String acc);
String generate(StreamVariable inVar, StreamToLoopReplacementContext context) {
String acc = initAccumulator(inVar, context);
return getAccumulatorUpdater(inVar, acc);
}
}
static class ReduceTerminalOperation extends TerminalOperation {
private PsiExpression myIdentity;
private String myType;
@@ -457,7 +468,7 @@ abstract class TerminalOperation extends Operation {
}
}
static class ToArrayTerminalOperation extends TerminalOperation {
static class ToArrayTerminalOperation extends AccumulatedOperation {
private final String myType;
private final FunctionHelper mySupplier;
@@ -467,7 +478,7 @@ abstract class TerminalOperation extends Operation {
}
@Override
String generate(StreamVariable inVar, StreamToLoopReplacementContext context) {
String initAccumulator(StreamVariable inVar, StreamToLoopReplacementContext context) {
String list = context.declareResult("list", CommonClassNames.JAVA_UTIL_LIST + "<" + myType + ">",
"new " + CommonClassNames.JAVA_UTIL_ARRAY_LIST + "<>()", ResultKind.UNKNOWN);
String toArrayArg = "";
@@ -476,6 +487,11 @@ abstract class TerminalOperation extends Operation {
toArrayArg = mySupplier.getText();
}
context.setFinisher(list + ".toArray(" + toArrayArg + ")");
return list;
}
@Override
String getAccumulatorUpdater(StreamVariable inVar, String list) {
return list+".add("+inVar+");\n";
}
}
@@ -550,12 +566,12 @@ abstract class TerminalOperation extends Operation {
default void suggestNames(StreamVariable inVar, StreamVariable outVar) {}
default void registerReusedElements(Consumer<PsiElement> consumer) {}
String getSupplier();
String getAccumulator(String acc, String item);
String getAccumulatorUpdater(StreamVariable inVar, String acc);
default PsiType correctReturnType(PsiType type) {return type;}
}
abstract static class CollectorBasedTerminalOperation extends TerminalOperation implements CollectorOperation {
abstract static class CollectorBasedTerminalOperation extends AccumulatedOperation implements CollectorOperation {
final String myType;
final Function<StreamToLoopReplacementContext, String> myAccNameSupplier;
final FunctionHelper mySupplier;
@@ -568,11 +584,10 @@ abstract class TerminalOperation extends Operation {
}
@Override
String generate(StreamVariable inVar, StreamToLoopReplacementContext context) {
String initAccumulator(StreamVariable inVar, StreamToLoopReplacementContext context) {
transform(context, inVar.getName());
PsiType resultType = correctReturnType(context.createType(myType));
String acc = context.declareResult(myAccNameSupplier.apply(context), resultType.getCanonicalText(), getSupplier(), ResultKind.FINAL);
return getAccumulator(acc, inVar.getName());
return context.declareResult(myAccNameSupplier.apply(context), resultType.getCanonicalText(), getSupplier(), ResultKind.FINAL);
}
@Override
@@ -596,7 +611,7 @@ abstract class TerminalOperation extends Operation {
}
}
static class AccumulatedTerminalOperation extends TerminalOperation implements CollectorOperation {
static class TemplateBasedOperation extends AccumulatedOperation implements CollectorOperation {
private String myAccName;
private String myAccType;
private String myAccInitializer;
@@ -612,7 +627,7 @@ abstract class TerminalOperation extends Operation {
* @param finisherTemplate template to final result. May contain {@code {acc}} - reference to accumulator variable.
* By default it's {@code "{acc}"}
*/
AccumulatedTerminalOperation(String accName, String accType, String accInitializer, String updateTemplate, String finisherTemplate) {
TemplateBasedOperation(String accName, String accType, String accInitializer, String updateTemplate, String finisherTemplate) {
myAccName = accName;
myAccType = accType;
myAccInitializer = accInitializer;
@@ -620,17 +635,17 @@ abstract class TerminalOperation extends Operation {
myFinisherTemplate = finisherTemplate;
}
AccumulatedTerminalOperation(String accName, String accType, String accInitializer, String updateTemplate) {
TemplateBasedOperation(String accName, String accType, String accInitializer, String updateTemplate) {
this(accName, accType, accInitializer, updateTemplate, "{acc}");
}
@Override
public String generate(StreamVariable inVar, StreamToLoopReplacementContext context) {
String initAccumulator(StreamVariable inVar, StreamToLoopReplacementContext context) {
ResultKind kind = myFinisherTemplate.equals("{acc}") ?
TypeConversionUtil.isPrimitive(myAccType) ? ResultKind.NON_FINAL : ResultKind.FINAL : ResultKind.UNKNOWN;
String varName = context.declareResult(myAccName, myAccType, myAccInitializer, kind);
context.setFinisher(myFinisherTemplate.replace("{acc}", varName));
return myUpdateTemplate.replace("{item}", inVar.getName()).replace("{acc}", varName);
return varName;
}
@Override
@@ -644,30 +659,33 @@ abstract class TerminalOperation extends Operation {
}
@Override
public String getAccumulator(String acc, String item) {
return myUpdateTemplate.replace("{acc}", acc).replace("{item}", item);
public String getAccumulatorUpdater(StreamVariable inVar, String acc) {
return myUpdateTemplate.replace("{acc}", acc).replace("{item}", inVar.getName());
}
@NotNull
static AccumulatedTerminalOperation summing(PsiType type) {
return new AccumulatedTerminalOperation("sum", type.getCanonicalText(), "0", "{acc}+={item};");
static TemplateBasedOperation summing(PsiType type) {
return new TemplateBasedOperation("sum", type.getCanonicalText(), "0", "{acc}+={item};");
}
@NotNull
static AccumulatedTerminalOperation summarizing(@NotNull PsiType resultType) {
return new AccumulatedTerminalOperation("stat", resultType.getCanonicalText(), "new " + resultType.getCanonicalText() + "()",
"{acc}.accept({item});");
static TemplateBasedOperation summarizing(@NotNull PsiType resultType) {
return new TemplateBasedOperation("stat", resultType.getCanonicalText(), "new " + resultType.getCanonicalText() + "()",
"{acc}.accept({item});");
}
}
static class ToCollectionTerminalOperation extends CollectorBasedTerminalOperation {
private final boolean myList;
public ToCollectionTerminalOperation(PsiType resultType, FunctionHelper fn, String desiredName) {
super(resultType.getCanonicalText(), context -> fn.suggestFinalOutputNames(context, desiredName, "collection").get(0), fn);
myList = InheritanceUtil.isInheritor(resultType, CommonClassNames.JAVA_UTIL_LIST);
}
@Override
public String getAccumulator(String acc, String item) {
return acc+".add("+item+");\n";
public String getAccumulatorUpdater(StreamVariable inVar, String acc) {
return acc + ".add(" + inVar + ");\n";
}
@Override
@@ -675,6 +693,10 @@ abstract class TerminalOperation extends Operation {
return correctTypeParameters(type, CommonClassNames.JAVA_UTIL_COLLECTION, Collections.emptyMap());
}
public boolean isList() {
return myList;
}
@NotNull
private static ToCollectionTerminalOperation toList(@NotNull PsiType resultType) {
return new ToCollectionTerminalOperation(resultType, FunctionHelper.newObjectSupplier(resultType, CommonClassNames.JAVA_UTIL_ARRAY_LIST), "list");
@@ -782,7 +804,7 @@ abstract class TerminalOperation extends Operation {
}
@Override
public String getAccumulator(String map, String item) {
public String getAccumulatorUpdater(StreamVariable inVar, String map) {
if(myMerger == null) {
return "if("+map+".put("+myKeyExtractor.getText()+","+myValueExtractor.getText()+")!=null) {\n"+
"throw new java.lang.IllegalStateException(\"Duplicate key\");\n}\n";
@@ -850,9 +872,9 @@ abstract class TerminalOperation extends Operation {
}
@Override
public String getAccumulator(String map, String item) {
public String getAccumulatorUpdater(StreamVariable inVar, String map) {
String acc = map+".computeIfAbsent("+myKeyExtractor.getText()+","+myKeyVar+"->"+myCollector.getSupplier()+")";
return myCollector.getAccumulator(acc, item);
return myCollector.getAccumulatorUpdater(inVar, acc);
}
}
@@ -889,7 +911,7 @@ abstract class TerminalOperation extends Operation {
myCollector.transform(context, inVar.getName());
context.addBeforeStep(map + ".put(false, " + myCollector.getSupplier() + ");");
context.addBeforeStep(map + ".put(true, " + myCollector.getSupplier() + ");");
return myCollector.getAccumulator(map + ".get(" + myPredicate.getText() + ")", inVar.getName());
return myCollector.getAccumulatorUpdater(inVar, map + ".get(" + myPredicate.getText() + ")");
}
}
@@ -959,9 +981,9 @@ abstract class TerminalOperation extends Operation {
}
@Override
public String getAccumulator(String acc, String item) {
public String getAccumulatorUpdater(StreamVariable inVar, String acc) {
return myVariable.getDeclaration() + "=" + myMapper.getText() + ";\n" +
myDownstreamCollector.getAccumulator(acc, myVariable.getName());
myDownstreamCollector.getAccumulatorUpdater(myVariable, acc);
}
}
@@ -984,8 +1006,8 @@ abstract class TerminalOperation extends Operation {
}
@Override
public String getAccumulator(String acc, String item) {
return myDownstreamCollector.getAccumulator(acc, myMapper.getText());
public String getAccumulatorUpdater(StreamVariable inVar, String acc) {
return myDownstreamCollector.getAccumulatorUpdater(new StreamVariable(myMapper.getResultType(), myMapper.getText()), acc);
}
}
@@ -1012,4 +1034,34 @@ abstract class TerminalOperation extends Operation {
return myFn.getText()+";\n";
}
}
static class SortedTerminalOperation extends TerminalOperation {
private final AccumulatedOperation myOrigin;
@Nullable private final PsiExpression myComparator;
SortedTerminalOperation(AccumulatedOperation origin, @Nullable PsiExpression comparator) {
myOrigin = origin;
myComparator = comparator;
}
@Override
public void registerReusedElements(Consumer<PsiElement> consumer) {
myOrigin.registerReusedElements(consumer);
if(myComparator != null) {
consumer.accept(myComparator);
}
}
@Override
public void suggestNames(StreamVariable inVar, StreamVariable outVar) {
myOrigin.suggestNames(inVar, outVar);
}
@Override
String generate(StreamVariable inVar, StreamToLoopReplacementContext context) {
String acc = myOrigin.initAccumulator(inVar, context);
context.addAfterStep(acc + ".sort(" + (myComparator == null ? "null" : myComparator.getText()) + ");\n");
return myOrigin.getAccumulatorUpdater(inVar, acc);
}
}
}

View File

@@ -5,15 +5,11 @@ import java.util.stream.*;
public class Main {
public List<String> testSorted(List<String> list) {
List<String> toSort = new ArrayList<>();
for (String s : list) {
toSort.add(s);
}
toSort.sort(String.CASE_INSENSITIVE_ORDER);
List<String> result = new ArrayList<>();
for (String s : toSort) {
for (String s : list) {
result.add(s);
}
result.sort(String.CASE_INSENSITIVE_ORDER);
return result;
}
}

View File

@@ -0,0 +1,15 @@
// "Replace Stream API chain with loop" "true"
import java.util.*;
import java.util.stream.*;
public class Main {
public List<String> testSorted(List<String> list) {
List<String> result = new ArrayList<>();
for (String s : list) {
result.add(s);
}
result.sort(null);
return result.toArray(new String[0]);
}
}

View File

@@ -0,0 +1,10 @@
// "Replace Stream API chain with loop" "true"
import java.util.*;
import java.util.stream.*;
public class Main {
public List<String> testSorted(List<String> list) {
return list.stream().sorted().<caret>toArray(String[]::new);
}
}