[java-intentions] IDEA-313226. Suggest the correct type for switch

GitOrigin-RevId: bd99d2b31049b94542a1745f48a33d0ee787e83d
This commit is contained in:
Mikhail Pyltsin
2023-02-17 18:32:12 +01:00
committed by intellij-monorepo-bot
parent bf8ce55c6b
commit 938ed3204b
12 changed files with 317 additions and 2 deletions

View File

@@ -554,7 +554,14 @@ public final class ExpectedTypesProvider {
}
public void processSwitchBlock(@NotNull PsiSwitchBlock statement) {
myResult.add(createInfoImpl(PsiTypes.longType(), PsiTypes.intType()));
if (statement.getExpression() == this.myExpr) {
List<ExpectedTypeInfo> collectedTypes = collectFromLabels(statement);
if (!collectedTypes.isEmpty()) {
myResult.addAll(collectedTypes);
return;
}
}
myResult.add(createInfoImpl(PsiTypes.intType(), PsiTypes.intType()));
LanguageLevel level = PsiUtil.getLanguageLevel(statement);
if (level.isAtLeast(LanguageLevel.JDK_1_5)) {
PsiClassType enumType = TypeUtils.getType(CommonClassNames.JAVA_LANG_ENUM, statement);
@@ -567,6 +574,148 @@ public final class ExpectedTypesProvider {
}
}
@NotNull
private static List<ExpectedTypeInfo> collectFromLabels(@NotNull PsiSwitchBlock statement) {
List<PsiType> labeledExpressionTypes = new ArrayList<>();
List<PsiType> labeledPatternsTypes = new ArrayList<>();
List<ExpectedTypeInfo> result = new ArrayList<>();
boolean mustBeReference = false;
PsiCodeBlock body = statement.getBody();
if (body == null) {
return result;
}
for (PsiStatement psiStatement : body.getStatements()) {
if (psiStatement instanceof PsiSwitchLabelStatementBase labelStatement) {
if (labelStatement.isDefaultCase()) {
continue;
}
PsiCaseLabelElementList labelElementList = labelStatement.getCaseLabelElementList();
if (labelElementList == null) {
continue;
}
for (PsiCaseLabelElement caseLabelElement : labelElementList.getElements()) {
if (caseLabelElement instanceof PsiExpression expression) {
PsiType type = expression.getType();
if (type == null) {
continue;
}
if (type == PsiTypes.nullType()) {
mustBeReference = true;
continue;
}
labeledExpressionTypes.add(type);
}
if (caseLabelElement instanceof PsiPattern pattern) {
PsiType type = JavaPsiPatternUtil.getPatternType(pattern);
if (type == null) {
continue;
}
labeledPatternsTypes.add(type);
}
}
}
}
result.addAll(processedExpressionTypes(labeledExpressionTypes, mustBeReference, statement));
result.addAll(processedPatternTypes(labeledPatternsTypes));
return result;
}
private static List<ExpectedTypeInfo> processedPatternTypes(@NotNull List<PsiType> expectedTypes) {
List<ExpectedTypeInfo> result = new ArrayList<>();
Set<PsiType> processedTypes = new HashSet<>();
for (PsiType type : expectedTypes) {
PsiClass currentClass = PsiUtil.resolveClassInClassTypeOnly(type);
if (currentClass == null) {
continue;
}
PsiClassType[] types = currentClass.getSuperTypes();
if (processedTypes.isEmpty()) {
Collections.addAll(processedTypes, types);
processedTypes.add(type);
}
else {
List<PsiType> combined = new ArrayList<>();
Collections.addAll(combined, types);
combined.add(type);
processedTypes.retainAll(combined);
}
}
for (PsiType type : processedTypes) {
result.add(createInfo(type, ExpectedTypeInfo.TYPE_OR_SUPERTYPE, type, TailType.NONE));
}
return result;
}
@NotNull
private static List<ExpectedTypeInfo> processedExpressionTypes(@NotNull List<PsiType> expectedTypes,
boolean mustBeReference,
@NotNull PsiSwitchBlock context) {
List<ExpectedTypeInfo> result = new ArrayList<>();
Set<PsiType> processedTypes = new HashSet<>();
for (PsiType expectedType : expectedTypes) {
if (expectedType == null) {
continue;
}
if (expectedType instanceof PsiClassType classType) {
PsiClass resolved = classType.resolve();
if (resolved != null && resolved.isEnum()) {
processedTypes.add(expectedType);
result.add(createInfoImpl(expectedType, expectedType));
continue;
}
}
if (expectedType.equalsToText(CommonClassNames.JAVA_LANG_STRING) ||
TypeConversionUtil.isPrimitiveAndNotNullOrWrapper(expectedType)) {
if (expectedType instanceof PsiPrimitiveType primitiveType) {
if (primitiveType.equals(PsiTypes.longType()) ||
primitiveType.equals(PsiTypes.doubleType()) ||
primitiveType.equals(PsiTypes.floatType())) {
return List.of(); //unexpected types, let's suggest default
}
if (mustBeReference) {
expectedType = primitiveType.getBoxedType(context);
if (expectedType == null) {
continue;
}
} else {
addWithWrapper(processedTypes, PsiTypes.intType(), result, context);
}
}
addWithWrapper(processedTypes, expectedType, result, context);
continue;
}
return List.of(); //something unexpected, let's suggest default
}
return result;
}
private static void addWithWrapper(@NotNull Set<PsiType> processedTypes,
@NotNull PsiType expectedType,
@NotNull List<ExpectedTypeInfo> result,
@Nullable PsiElement context) {
addIfNotExist(processedTypes, expectedType, result, createInfoImpl(expectedType, expectedType));
if (context!=null && expectedType instanceof PsiPrimitiveType primitiveType) {
PsiClassType type = primitiveType.getBoxedType(context);
if (type != null) {
addIfNotExist(processedTypes, type, result, createInfoImpl(type, type));
}
}
}
private static void addIfNotExist(@NotNull Set<PsiType> processedTypes,
@NotNull PsiType expectedType,
@NotNull List<ExpectedTypeInfo> result,
@NotNull ExpectedTypeInfo expectedTypeInfo) {
if (!processedTypes.contains(expectedType)) {
processedTypes.add(expectedType);
result.add(expectedTypeInfo);
}
}
@Override
public void visitSynchronizedStatement(@NotNull PsiSynchronizedStatement statement) {
PsiElementFactory factory = JavaPsiFacade.getElementFactory(statement.getProject());

View File

@@ -0,0 +1,18 @@
// "Create local variable 'x2'" "true-preview"
import static A.Month.APRIL;
class A {
public void foo() {
Month x2;
var x = switch (x2)
{
case APRIL ->
{
yield "bar";
}
default -> "foo";
};
}
enum Month{APRIL, MAY};
}

View File

@@ -0,0 +1,15 @@
// "Create local variable 'x2'" "true-preview"
class A {
public void foo() {
int x2;
var x = switch (x2)
{
case 1 ->
{
yield "bar";
}
default -> "foo";
};
}
}

View File

@@ -0,0 +1,16 @@
// "Create local variable 'x2'" "true-preview"
class A {
public void foo() {
Integer x2;
var x = switch (x2)
{
case 1 ->
{
yield "bar";
}
case null -> "null";
default -> "foo";
};
}
}

View File

@@ -0,0 +1,21 @@
// "Create local variable 'x2'" "true-preview"
class A {
String testPattern() {
BaseInterface x2;
return switch(x2){
case BaseInterface.Record1 record1 -> "1";
case BaseInterface.Record2 record1 -> "1";
default -> "2";
};
}
sealed interface BaseInterface permits BaseInterface.Record1, BaseInterface.Record2 {
sealed class Record1() implements BaseInterface {
}
record Record2() implements BaseInterface {
}
}
}

View File

@@ -0,0 +1,15 @@
// "Create local variable 'x2'" "true-preview"
class A {
public void foo() {
String x2;
var x = switch (x2)
{
case "bar" ->
{
yield "bar";
}
default -> "foo";
};
}
}

View File

@@ -0,0 +1,17 @@
// "Create local variable 'x2'" "true-preview"
import static A.Month.APRIL;
class A {
public void foo() {
var x = switch (x2<caret>)
{
case APRIL ->
{
yield "bar";
}
default -> "foo";
};
}
enum Month{APRIL, MAY};
}

View File

@@ -0,0 +1,14 @@
// "Create local variable 'x2'" "true-preview"
class A {
public void foo() {
var x = switch (x2<caret>)
{
case 1 ->
{
yield "bar";
}
default -> "foo";
};
}
}

View File

@@ -0,0 +1,15 @@
// "Create local variable 'x2'" "true-preview"
class A {
public void foo() {
var x = switch (x2<caret>)
{
case 1 ->
{
yield "bar";
}
case null -> "null";
default -> "foo";
};
}
}

View File

@@ -0,0 +1,20 @@
// "Create local variable 'x2'" "true-preview"
class A {
String testPattern() {
return switch(x2<caret>){
case BaseInterface.Record1 record1 -> "1";
case BaseInterface.Record2 record1 -> "1";
default -> "2";
};
}
sealed interface BaseInterface permits BaseInterface.Record1, BaseInterface.Record2 {
sealed class Record1() implements BaseInterface {
}
record Record2() implements BaseInterface {
}
}
}

View File

@@ -0,0 +1,14 @@
// "Create local variable 'x2'" "true-preview"
class A {
public void foo() {
var x = switch (x2<caret>)
{
case "bar" ->
{
yield "bar";
}
default -> "foo";
};
}
}

View File

@@ -2,6 +2,7 @@
package com.intellij.codeInsight.daemon.impl.quickfix;
import com.intellij.codeInsight.daemon.quickFix.LightQuickFixParameterizedTestCase;
import com.intellij.pom.java.LanguageLevel;
import com.intellij.psi.codeStyle.JavaCodeStyleSettings;
public class CreateLocalFromUsageTest extends LightQuickFixParameterizedTestCase {
@@ -13,7 +14,7 @@ public class CreateLocalFromUsageTest extends LightQuickFixParameterizedTestCase
@Override
protected void setUp() throws Exception {
super.setUp();
setLanguageLevel(LanguageLevel.JDK_20_PREVIEW);
JavaCodeStyleSettings.getInstance(getProject()).GENERATE_FINAL_LOCALS = getTestName(true).contains("final");
}
}