[java-intentions] IDEA-327634 CreateSealedClassMissingSwitchBranchesFix support generics

- try to calculate generics for generic sealed classes using a selector type

GitOrigin-RevId: 80fafc8f359a8841814f0e13f87f67d4e81039cb
This commit is contained in:
Mikhail Pyltsin
2023-08-11 13:23:23 +02:00
committed by intellij-monorepo-bot
parent d60be7afdc
commit 584d755d71
6 changed files with 156 additions and 9 deletions

View File

@@ -24,7 +24,7 @@ public final class CreateEnumMissingSwitchBranchesFix extends CreateMissingSwitc
}
@Override
protected @NotNull List<String> getAllNames(@NotNull PsiClass aClass) {
protected @NotNull List<String> getAllNames(@NotNull PsiClass aClass, @NotNull PsiSwitchBlock switchBlock) {
return StreamEx.of(aClass.getAllFields()).select(PsiEnumConstant.class).map(PsiField::getName).toList();
}

View File

@@ -41,7 +41,7 @@ public final class CreateMissingDeconstructionRecordClassBranchesFix extends Cre
}
@Override
protected @NotNull List<String> getAllNames(@NotNull PsiClass aClass) {
protected @NotNull List<String> getAllNames(@NotNull PsiClass aClass, @NotNull PsiSwitchBlock switchBlock) {
return allNames;
}

View File

@@ -34,10 +34,15 @@ public abstract class CreateMissingSwitchBranchesFix extends BaseSwitchFix {
final PsiClass psiClass = switchType.resolve();
if (psiClass == null) return;
List<PsiSwitchLabelStatementBase> addedLabels = CreateSwitchBranchesUtil
.createMissingBranches(switchBlock, getAllNames(psiClass), myNames, getCaseExtractor());
.createMissingBranches(switchBlock, getAllNames(psiClass, switchBlock), getNames(switchBlock), getCaseExtractor());
CreateSwitchBranchesUtil.createTemplate(switchBlock, addedLabels, updater);
}
abstract protected @NotNull List<String> getAllNames(@NotNull PsiClass aClass);
@NotNull
protected Set<String> getNames(@NotNull PsiSwitchBlock switchBlock) {
return myNames;
}
abstract protected @NotNull List<String> getAllNames(@NotNull PsiClass aClass, @NotNull PsiSwitchBlock switchBlock);
abstract protected @NotNull Function<PsiSwitchLabelStatementBase, List<String>> getCaseExtractor();
}

View File

@@ -1,15 +1,18 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.siyeh.ig.fixes;
import com.intellij.openapi.project.Project;
import com.intellij.psi.*;
import com.intellij.psi.search.GlobalSearchScope;
import com.intellij.psi.util.InheritanceUtil;
import com.intellij.psi.util.PsiUtil;
import com.intellij.psi.util.TypeConversionUtil;
import com.intellij.util.containers.ContainerUtil;
import com.siyeh.InspectionGadgetsBundle;
import org.jetbrains.annotations.Nls;
import org.jetbrains.annotations.NotNull;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.*;
import java.util.function.Function;
public class CreateSealedClassMissingSwitchBranchesFix extends CreateMissingSwitchBranchesFix {
@@ -27,8 +30,93 @@ public class CreateSealedClassMissingSwitchBranchesFix extends CreateMissingSwit
}
@Override
protected @NotNull List<String> getAllNames(@NotNull PsiClass ignored) {
return myAllNames;
protected @NotNull List<String> getAllNames(@NotNull PsiClass ignored, @NotNull PsiSwitchBlock switchBlock) {
Map<String, String> mapToConvert = getConversionNewTypeWithGeneric(switchBlock);
List<String> result = new ArrayList<>();
for (String name : myAllNames) {
if (mapToConvert.containsKey(name)) {
result.add(mapToConvert.get(name));
}
else {
result.add(name);
}
}
return result;
}
@NotNull
private Map<String, String> getConversionNewTypeWithGeneric(@NotNull PsiSwitchBlock switchBlock) {
HashMap<String, String> mapToConvert = new HashMap<>();
PsiExpression expression = switchBlock.getExpression();
if (expression == null) {
return Map.of();
}
PsiType expressionType = expression.getType();
if (!(expressionType instanceof PsiClassType expressionClassType)) {
return Map.of();
}
PsiClassType.ClassResolveResult classResolveResult = expressionClassType.resolveGenerics();
PsiClass expressionClass = classResolveResult.getElement();
PsiSubstitutor superSubstitutor = classResolveResult.getSubstitutor();
if (expressionClass == null) {
return Map.of();
}
for (String myName : myNames) {
Project project = switchBlock.getProject();
PsiClass[] classes = JavaPsiFacade.getInstance(project).findClasses(myName, GlobalSearchScope.projectScope(project));
if (classes.length != 1) {
continue;
}
PsiClass classToAdd = classes[0];
if (!classToAdd.hasTypeParameters()) {
continue;
}
if (!InheritanceUtil.isInheritorOrSelf(classToAdd, expressionClass, true)) {
return Map.of();
}
PsiSubstitutor inheritorSubstitutor = TypeConversionUtil.getSuperClassSubstitutor(expressionClass, classToAdd, PsiSubstitutor.EMPTY);
PsiSubstitutor targetSubstitutor = PsiSubstitutor.EMPTY;
for (Map.Entry<PsiTypeParameter, PsiType> entry : inheritorSubstitutor.getSubstitutionMap().entrySet()) {
PsiType value = entry.getValue();
PsiClass psiClass = PsiUtil.resolveClassInClassTypeOnly(value);
if (!(psiClass instanceof PsiTypeParameter derivedTypeParameter)) {
continue;
}
PsiType substituted = superSubstitutor.substitute(entry.getKey());
targetSubstitutor = targetSubstitutor.put(derivedTypeParameter, substituted);
}
for (PsiTypeParameter parameter : classToAdd.getTypeParameters()) {
if (targetSubstitutor.getSubstitutionMap().containsKey(parameter)) {
continue;
}
targetSubstitutor = targetSubstitutor.put(parameter, PsiWildcardType.createUnbounded(parameter.getManager()));
}
PsiClassType classTypeToAdd = PsiElementFactory.getInstance(project).createType(classToAdd, targetSubstitutor);
if (TypeConversionUtil.isAssignable(expressionClassType, classTypeToAdd)) {
mapToConvert.put(myName, classTypeToAdd.getCanonicalText());
}
else {
return Map.of();
}
}
return mapToConvert;
}
@Override
protected @NotNull Set<String> getNames(@NotNull PsiSwitchBlock switchBlock) {
Map<String, String> mapToConvert = getConversionNewTypeWithGeneric(switchBlock);
Set<String> result = new LinkedHashSet<>();
for (String name : myNames) {
if (mapToConvert.containsKey(name)) {
result.add(mapToConvert.get(name));
}
else {
result.add(name);
}
}
return result;
}
@Override

View File

@@ -0,0 +1,28 @@
// "Create missing branches: 'Test.Bar', and 'Test.Foo'" "true-preview"
import java.util.List;
class Test {
public static void main(String[] args) {
List<Example<String, Integer>> examples = List.of();
for (Example<String, Integer> example : examples) {
String res = switch (example) {
case Bar<Integer> v -> null;
case Foo<String, Integer, ?> v -> null;
};
}
}
interface AB<A, B> {
}
sealed interface Example<A, B> extends AB<A, B> permits Foo, Bar {
}
record Foo<A, B, C>(A a, C c) implements Example<A, B> {
}
static final class Bar<B> implements Example<String, B> {
}
}

View File

@@ -0,0 +1,26 @@
// "Create missing branches: 'Test.Bar', and 'Test.Foo'" "true-preview"
import java.util.List;
class Test {
public static void main(String[] args) {
List<Example<String, Integer>> examples = List.of();
for (Example<String, Integer> example : examples) {
String res = switch (example<caret>) {
};
}
}
interface AB<A, B> {
}
sealed interface Example<A, B> extends AB<A, B> permits Foo, Bar {
}
record Foo<A, B, C>(A a, C c) implements Example<A, B> {
}
static final class Bar<B> implements Example<String, B> {
}
}