[java-decompiler] IDEA-353995 improve support for switch expressions with patterns

GitOrigin-RevId: 98a19ccca5d898e3be713a5af584b9e81f0d295a
This commit is contained in:
Mikhail Pyltsin
2024-06-07 16:59:37 +02:00
committed by intellij-monorepo-bot
parent 97a9366c5f
commit db12d5db2d
11 changed files with 685 additions and 6 deletions

View File

@@ -446,7 +446,19 @@ public final class PatternHelper {
static boolean processAtLeastOneBlock(@NotNull VarTracker tracker, @NotNull Statement statement) {
if (statement instanceof BasicBlockStatement && statement.getExprents() != null) {
boolean found = false;
for (Exprent statementExprent : statement.getExprents()) {
List<Exprent> exprents = statement.getExprents();
List<StatEdge> edges = statement.getSuccessorEdges(StatEdge.EdgeType.DIRECT_ALL);
//should keep this variable, because it is used outside
if (edges.size() == 1 && edges.get(0).getType() == StatEdge.EdgeType.BREAK && exprents.size() > 1) {
Exprent exprent = exprents.get(exprents.size() - 1);
if (exprent instanceof AssignmentExprent assignmentExprent &&
assignmentExprent.getLeft() instanceof VarExprent preserveVarExprent &&
usedOutside(edges.get(0), assignmentExprent)) {
exprents = exprents.subList(0, exprents.size() - 1);
tracker.addPreserve(preserveVarExprent, statement);
}
}
for (Exprent statementExprent : exprents) {
if (collectRecordAssignment(tracker, statementExprent, statement)) {
found = true;
continue;
@@ -460,6 +472,18 @@ public final class PatternHelper {
return true;
}
private static boolean usedOutside(@NotNull StatEdge edge,
@NotNull AssignmentExprent assignmentExprent) {
Exprent left = assignmentExprent.getLeft();
if (!(left instanceof VarExprent varExprent)) return false;
Statement destination = edge.getDestination();
if (destination == null) return false;
List<Exprent> exprents = destination.getExprents();
if (exprents == null) return false;
return exprents.stream()
.anyMatch(exp -> exp.containsExprent(varExprent));
}
/**
* Collects record assignment in the given statement into tracker.
@@ -604,6 +628,10 @@ public final class PatternHelper {
return varTempAssignmentTracker;
}
void addPreserve(@NotNull VarExprent varExprent, @NotNull Statement statement) {
varTempAssignmentTracker.add(new TempVarAssignmentItem(varExprent, statement, false));
}
@Nullable
VarTracker copy() {
RecordVarExprent copy = root.copy();

View File

@@ -95,6 +95,7 @@ public final class SwitchHelper {
}
return false;
}
if (right instanceof FunctionExprent functionExprent &&
functionExprent.getFuncType() == FUNCTION_CAST &&
functionExprent.getLstOperands().size() == 2 &&
@@ -222,15 +223,21 @@ public final class SwitchHelper {
return isJavacEnumArray || isEclipseEnumArray;
}
record TempVarAssignmentItem(@NotNull VarExprent varExprent, @NotNull Statement statement) {
record TempVarAssignmentItem(@NotNull VarExprent varExprent,
@NotNull Statement statement,
boolean delete) {
TempVarAssignmentItem(@NotNull VarExprent varExprent, @NotNull Statement statement) {
this(varExprent, statement, true);
}
}
static void removeTempVariableDeclarations(@NotNull List<TempVarAssignmentItem> tempVarAssignments) {
if (tempVarAssignments.isEmpty()) return;
Set<Statement> visited = new HashSet<>();
Set<Statement> statements = tempVarAssignments.stream().map(a -> a.statement()).collect(Collectors.toSet());
Map<VarExprent, List<VarExprent>> vars = tempVarAssignments.stream().map(a -> a.varExprent()).collect(Collectors.groupingBy(t -> t));
Set<Statement> statements = tempVarAssignments.stream().filter(a -> a.delete).map(a -> a.statement()).collect(Collectors.toSet());
Map<VarExprent, List<VarExprent>> vars =
tempVarAssignments.stream().filter(a -> a.delete).map(a -> a.varExprent()).collect(Collectors.groupingBy(t -> t));
Set<VarExprent> preserve = tempVarAssignments.stream().filter(a->!a.delete).map(t->t.varExprent).collect(Collectors.toSet());
for (Statement statement : statements) {
Statement parent = statement;
while (parent != null) {
@@ -268,6 +275,9 @@ public final class SwitchHelper {
if (exprent.type != Exprent.EXPRENT_VAR) continue;
VarExprent varExprent = (VarExprent)exprent;
if (containVar(vars, varExprent) || (varExprent.isDefinition() && vars.containsKey(varExprent))) {
if (varExprent.isDefinition() && preserve.contains(varExprent)) {
continue;
}
toDelete.add(assignmentExprent == null ? varExprent : assignmentExprent);
}
}

View File

@@ -1443,14 +1443,81 @@ public final class SwitchPatternHelper {
}
extendCases(switchStatement, patternContainer);
addGuards(switchStatement, patternContainer);
Set<Statement> collectedPatterns =
patternContainer.patternsByStatement.values().stream().flatMap(t -> t.stream())
.map(t -> t.caseStatement)
.collect(Collectors.toSet());
if (upperDoStatement != null) {
upperDoStatement.getParent().replaceStatement(upperDoStatement, switchStatement);
Optional<Statement> lastStatementOpt = findNestedLastStatement(patternContainer);
//try to process fast exit
Statement containerStatement = patternContainer.patternsByStatement.keySet().iterator().next();
Optional<Statement> baseSwitchStatementOpt =
containerStatement.getStats().stream().filter(t -> t instanceof SwitchStatement).findAny();
if (lastStatementOpt.isPresent() &&
baseSwitchStatementOpt.isPresent() &&
new HashSet<>(switchStatement.getCaseStatements()).containsAll(collectedPatterns)) {
Statement lastStatement = lastStatementOpt.get();
Statement baseSwitchStatement = baseSwitchStatementOpt.get();
if (baseSwitchStatement.getFirst() != null &&
baseSwitchStatement.getFirst().getExprents() != null &&
lastStatement.getExprents() != null &&
switchStatement.getFirst() != null &&
switchStatement.getFirst().getExprents() != null) {
for (Exprent exprent : baseSwitchStatement.getFirst().getExprents()) {
if (exprent instanceof VarExprent varExprent &&
varExprent.isDefinition() &&
lastStatement.getExprents().stream().anyMatch(e -> e.containsExprent(varExprent))) {
switchStatement.getFirst().getExprents().add(varExprent);
break;
}
}
}
ArrayList<Statement> lst = new ArrayList<>();
lst.add(switchStatement);
lst.add(lastStatement);
List<StatEdge> edges = new ArrayList<>(lastStatement.getAllPredecessorEdges());
for (StatEdge edge : edges) {
lastStatement.removePredecessor(edge);
}
for (Statement statement : switchStatement.getCaseStatements()) {
lastStatement.addPredecessor(new StatEdge(EdgeType.BREAK, statement, lastStatement));
}
SequenceStatement statement = new SequenceStatement(lst);
List<StatEdge> lastSuccessors = new ArrayList<>(lastStatement.getAllSuccessorEdges());
for (StatEdge edge : lastSuccessors) {
statement.addSuccessor(edge);
lastStatement.removeSuccessor(edge);
}
upperDoStatement.getParent().replaceStatement(upperDoStatement, statement);
}
else {
upperDoStatement.getParent().replaceStatement(upperDoStatement, switchStatement);
}
normalizeCaseLabels(switchStatement, upperDoStatement);
}
normalizeLabels(switchStatement, tempVarAssignments);
deleteNullCases(switchStatement);
}
private static @NotNull Optional<Statement> findNestedLastStatement(@NotNull PatternContainer patternContainer) {
return Optional.of(patternContainer)
.map(c -> c.patternsByStatement)
.filter(patterns -> patterns.size() == 1)
.map(patterns -> patterns.keySet().iterator().next())
.map(p -> {
if (p.getStats() != null &&
p.getStats().size() >= 2 &&
p.getStats().get(p.getStats().size() - 2) instanceof SwitchStatement) {
return p.getStats().get(p.getStats().size() - 1);
}
else {
return null;
}
})
.filter(t -> t instanceof BasicBlockStatement basicBlockStatement &&
basicBlockStatement.getExprents() != null);
}
/**
* Deletes null cases from a SwitchStatement if these labels contain var exprent
* (it is impossible to have null and pattern named variable at the same time).

View File

@@ -242,6 +242,11 @@ public class SingleClassesTest {
@Test public void testInstanceofWithPattern() {
doTest("patterns/TestInstanceofWithPattern");
}
//it is not actual expressions, but convert expressions into statements
@Test public void testSwitchPatternWithExpression() {
doTest("patterns/TestSwitchPatternWithExpression");
}
@Test public void testInstanceofVarNotSupported() {
// the bytecode version of this test data doesn't support patterns in `instanceof`, so no modifications regarding that are applied
doTest("patterns/TestInstanceofPatternNotSupported");

View File

@@ -0,0 +1,432 @@
public class TestSwitchPatternWithExpression {
public static void main(String[] args) {
new A("4265111111");// 15
}// 16
private static String getX(I i) {
String var11;
switch (i) {
case A(String a):
var11 = a;
break;// 20
case B(String a):
var11 = a;
break;
}
return var11;// 19
}
private static String getX8(I i) {
switch (i) {// 26
case A(String a):
return a;// 28
case B(String a):
return a;// 31
}
}
private static String getX0(AA i) {
String var16;
switch (i) {
case AA(A(String a)):
var16 = a;
break;// 38
case AA(B(String a)):
var16 = a;
}
return var16;// 37
}
private static void getX11(AA i) {
String var17;
switch (i) {// 44
case AA(A(String a)):
var17 = a;
break;// 45
case AA(B(String a)):
var17 = a;
}
String aa = var17;// 46
System.out.println(aa + "1");// 48
return;
}
private static String getX4(I i) {
String var11;
switch (i) {// 52
case A(String a):
var11 = a;
break;// 53
case B(String a):
var11 = a;
break;
}
String string = var11;// 54
return string;// 56
}
private static void getX10(I i) {
String var11;
switch (i) {// 60
case A(String a):
var11 = a;
break;// 61
case B(String a):
var11 = a;
break;
}
String string = var11;// 62
System.out.println(string + "2");// 64
}
private static String getX5(I i) {
String var11;
switch (i) {// 68
case A(String a):
var11 = a + "1";// 69
break;
case B(String a):
var11 = a;
break;
}
String string = var11;// 70
return string;// 72
}
private static void getX9(I i) {
String var11;
switch (i) {// 76
case A(String a):
var11 = a + "1";// 77
break;
case B(String a):
var11 = a;
break;
}
String string = var11;// 78
System.out.println(string + "2");// 80
}
private static String getX3(I i) {
String var11;
switch (i) {
case A(String a):
System.out.println(a);// 86
var11 = a;
break;// 87
case B(String a):
System.out.println(a);// 90
var11 = a;
break;
}
return var11;// 84
}
private static String getX6(I i) {
String var11;
switch (i) {
case A(String a):
System.out.println(a);// 100
var11 = a;
break;// 101
case B(String a):
System.out.println(a);// 104
var11 = a + "1";// 105
break;
}
return var11;// 98
}
private static String getX7(I i) {
String var11;
switch (i) {
case A(String a):
System.out.println(a);// 113
System.out.println(a);// 114
System.out.println(a + "1");// 115
var11 = a;
break;// 116
case B(String a):
System.out.println(a);// 119
var11 = a + "1";// 120
break;
}
return var11;// 111
}
private static String getX2(I i) {
switch (i) {// 126
case A(String a):
return a;// 128
case B(String a):
return a;// 130
default:
throw new IllegalArgumentException();// 132
}
}
static record A(String a) implements I {
A(String a) {
this.a = a;
}
public String a() {
return this.a;// 5
}
}
static record B(String a) implements I {
B(String a) {
this.a = a;
}
public String a() {
return this.a;// 8
}
}
static record AA(I i) {
AA(I i) {
this.i = i;
}
public I i() {
return this.i;// 11
}
}
sealed interface I permits TestSwitchPatternWithExpression.A, TestSwitchPatternWithExpression.B {
}
}
class 'TestSwitchPatternWithExpression' {
method 'main ([Ljava/lang/String;)V' {
4 2
a 3
}
method 'getX (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 7
47 10
5d 16
}
method 'getX8 (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 20
47 22
5b 24
}
method 'getX0 (LTestSwitchPatternWithExpression$AA;)Ljava/lang/String;' {
10 30
70 33
8f 38
}
method 'getX11 (LTestSwitchPatternWithExpression$AA;)V' {
10 43
74 46
93 51
94 52
98 52
9d 52
a0 53
}
method 'getX4 (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 58
49 61
5f 67
61 68
}
method 'getX10 (LTestSwitchPatternWithExpression$I;)V' {
10 73
49 76
5f 82
60 83
64 83
69 83
6c 84
}
method 'getX5 (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 88
49 90
4e 91
64 97
66 98
}
method 'getX9 (LTestSwitchPatternWithExpression$I;)V' {
10 103
49 105
4e 106
64 112
65 113
69 113
6e 113
71 114
}
method 'getX3 (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 118
45 120
4a 120
4f 122
63 124
68 124
6d 129
}
method 'getX6 (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 134
45 136
4a 136
4f 138
63 140
68 140
6d 141
72 145
}
method 'getX7 (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 150
45 152
4a 152
4d 153
52 153
55 154
5a 154
5f 154
64 156
78 158
7d 158
82 159
87 163
}
method 'getX2 (LTestSwitchPatternWithExpression$I;)Ljava/lang/String;' {
10 167
3d 169
51 171
59 173
}
}
class 'TestSwitchPatternWithExpression$A' {
method '<init> (Ljava/lang/String;)V' {
6 179
9 180
}
method 'a ()Ljava/lang/String;' {
1 183
4 183
}
}
class 'TestSwitchPatternWithExpression$B' {
method '<init> (Ljava/lang/String;)V' {
6 189
9 190
}
method 'a ()Ljava/lang/String;' {
1 193
4 193
}
}
class 'TestSwitchPatternWithExpression$AA' {
method '<init> (LTestSwitchPatternWithExpression$I;)V' {
6 199
9 200
}
method 'i ()LTestSwitchPatternWithExpression$I;' {
1 203
4 203
}
}
Lines mapping:
5 <-> 184
8 <-> 194
11 <-> 204
15 <-> 3
16 <-> 4
19 <-> 17
20 <-> 11
26 <-> 21
28 <-> 23
31 <-> 25
37 <-> 39
38 <-> 34
44 <-> 44
45 <-> 47
46 <-> 52
48 <-> 53
52 <-> 59
53 <-> 62
54 <-> 68
56 <-> 69
60 <-> 74
61 <-> 77
62 <-> 83
64 <-> 84
68 <-> 89
69 <-> 91
70 <-> 98
72 <-> 99
76 <-> 104
77 <-> 106
78 <-> 113
80 <-> 114
84 <-> 130
86 <-> 121
87 <-> 123
90 <-> 125
98 <-> 146
100 <-> 137
101 <-> 139
104 <-> 141
105 <-> 142
111 <-> 164
113 <-> 153
114 <-> 154
115 <-> 155
116 <-> 157
119 <-> 159
120 <-> 160
126 <-> 168
128 <-> 170
130 <-> 172
132 <-> 174
Not mapped:
21
27
30
39
49
65
81
85
89
91
99
103
112
118
127
129

View File

@@ -0,0 +1,137 @@
package decompiler;
public class TestSwitchPatternWithExpression {
sealed interface I {
}
record A(String a) implements I {
}
record B(String a) implements I {
}
record AA(I i) {
}
public static void main(String[] args) {
I i = new A("4265111111");
}
private static String getX(I i) {
return switch (i) {
case A(var a) -> a;
case B(var a) -> a;
};
}
private static String getX8(I i) {
switch (i) {
case A(var a) -> {
return a;
}
case B(var a) -> {
return a;
}
}
}
private static String getX0(AA i) {
return switch (i) {
case AA(A(var a)) -> a;
case AA(B(var a)) -> a;
};
}
private static void getX11(AA i) {
String aa = switch (i) {
case AA(A(var a)) -> a;
case AA(B(var a)) -> a;
};
System.out.println(aa + "1");
}
private static String getX4(I i) {
String string = switch (i) {
case A(var a) -> a;
case B(var a) -> a;
};
return string;
}
private static void getX10(I i) {
String string = switch (i) {
case A(var a) -> a;
case B(var a) -> a;
};
System.out.println(string + "2");
}
private static String getX5(I i) {
String string = switch (i) {
case A(var a) -> a + "1";
case B(var a) -> a;
};
return string;
}
private static void getX9(I i) {
String string = switch (i) {
case A(var a) -> a + "1";
case B(var a) -> a;
};
System.out.println(string + "2");
}
private static String getX3(I i) {
return switch (i) {
case A(var a) -> {
System.out.println(a);
yield a;
}
case B(var a) -> {
System.out.println(a);
yield a;
}
};
}
private static String getX6(I i) {
return switch (i) {
case A(var a) -> {
System.out.println(a);
yield a;
}
case B(var a) -> {
System.out.println(a);
yield a + "1";
}
};
}
private static String getX7(I i) {
return switch (i) {
case A(var a) -> {
System.out.println(a);
System.out.println(a);
System.out.println(a + "1");
yield a;
}
case B(var a) -> {
System.out.println(a);
yield a + "1";
}
};
}
private static String getX2(I i) {
switch (i) {
case A(var a):
return a;
case B(var a):
return a;
default:
throw new IllegalArgumentException();
}
}
}