[junit 5] use multiResolve instead of fastResolveFor for @MethodSource

GitOrigin-RevId: 5f36021929cf0b2017cd09aad869c88998eeed27
This commit is contained in:
Aleksey Dobrynin
2025-06-16 18:27:11 +02:00
committed by intellij-monorepo-bot
parent 4542127eeb
commit 913adcc0c5
3 changed files with 55 additions and 37 deletions

View File

@@ -16,6 +16,7 @@ import com.intellij.codeInspection.util.InspectionMessage
import com.intellij.execution.JUnitBundle
import com.intellij.execution.junit.*
import com.intellij.execution.junit.references.MethodSourceReference
import com.intellij.execution.junit.references.PsiMethodSourceResolveResult
import com.intellij.jvm.analysis.quickFix.CompositeModCommandQuickFix
import com.intellij.jvm.analysis.quickFix.createModifierQuickfixes
import com.intellij.lang.Language
@@ -562,12 +563,15 @@ private class JUnitMalformedSignatureVisitor(
annotationMemberValue.forEach { attributeValue ->
for (reference in attributeValue.references) {
if (reference is MethodSourceReference) {
val parametrizedMethod = reference.fastResolveFor(method)
if (parametrizedMethod !is PsiMethod) {
val sourceProviders = reference.multiResolve(false)
val sourceProvider = sourceProviders
.mapNotNull { it as? PsiMethodSourceResolveResult }
.firstNotNullOfOrNull { it.getSourceMethodForClass(method.javaPsi.containingClass ?: return) }
if (sourceProvider == null) {
return checkAbsentSourceProvider(containingClass, attributeValue, reference.value, method)
}
else {
val uSourceProvider = parametrizedMethod.toUElementOfType<UMethod>() ?: return
val uSourceProvider = sourceProvider.toUElementOfType<UMethod>() ?: return
return checkSourceProvider(uSourceProvider, containingClass, attributeValue, method)
}
}
@@ -576,6 +580,12 @@ private class JUnitMalformedSignatureVisitor(
}
}
private fun PsiMethodSourceResolveResult.getSourceMethodForClass(owner: PsiClass): PsiMethod? {
val method = element as? PsiMethod ?: return null
if (owners.isEmpty()) return method // direct link
return owner.findMethodBySignature(method, true)
}
private fun checkAbsentSourceProvider(
containingClass: PsiClass, attributeValue: PsiElement, sourceProviderName: String, method: UMethod
) {

View File

@@ -10,7 +10,6 @@ import com.intellij.psi.impl.source.resolve.ResolveCache
import com.intellij.psi.search.searches.AnnotatedElementsSearch
import com.intellij.psi.search.searches.ClassInheritorsSearch
import com.intellij.psi.util.ClassUtil
import com.intellij.psi.util.PsiUtil
import org.jetbrains.uast.*
abstract class BaseJunitAnnotationReference(
@@ -33,10 +32,13 @@ abstract class BaseJunitAnnotationReference(
}
override fun isReferenceTo(element: PsiElement): Boolean {
val literal = getElement().toUElement(UExpression::class.java) ?: return false
val scope = element.toUElement(UMethod::class.java)?.getParentOfType(UClass::class.java) ?: return false
return directLink(literal, scope) == element ||
fastResolveFor(literal, scope) == element
val results: Array<ResolveResult> = multiResolve(false)
for (result in results) {
if (element == result.getElement()) {
return true
}
}
return false
}
override fun resolve(): PsiElement? {
@@ -44,12 +46,6 @@ abstract class BaseJunitAnnotationReference(
return if (results.size == 1) results[0].element else null
}
private fun filteredMethod(clazzMethods: Array<PsiMethod>, uClass: UClass, uMethod: UMethod?): PsiMethod? {
return clazzMethods.firstOrNull { method ->
hasNoStaticProblem(method, uClass, uMethod)
} ?: if (clazzMethods.isEmpty()) null else clazzMethods.first()
}
override fun getVariants(): Array<Any> {
val myLiteral = element.toUElement(UExpression::class.java) ?: return emptyArray()
val topLevelClass = myLiteral.getParentOfType(UClass::class.java) ?: return emptyArray()
@@ -87,33 +83,32 @@ abstract class BaseJunitAnnotationReference(
val methodName = StringUtil.getShortName(string, '#')
if (methodName.isEmpty()) return null
val directClass = ClassUtil.findPsiClass(scope.javaPsi.manager, className, null, false, scope.javaPsi.resolveScope) ?: return null
val directUClass = directClass.toUElement(UClass::class.java) ?: return null
return filteredMethod(directClass.findMethodsByName(methodName, false), directUClass, literal.getParentOfType(UMethod::class.java))
return directClass.findMethodsByName(methodName, false).firstOrNull()
}
private fun fastResolveFor(literal: UExpression, scope: UClass): PsiElement? {
val methodName = literal.evaluate() as String? ?: return null
private fun fastResolveFor(literal: UExpression, scope: UClass): Set<PsiMethod> {
val methodName = literal.evaluate() as String? ?: return setOf()
val psiClazz = scope.javaPsi
val clazzMethods = psiClazz.findMethodsByName(methodName, true)
if (clazzMethods.isEmpty() && (scope.isInterface || PsiUtil.isAbstractClass(psiClazz))) {
val methods = ClassInheritorsSearch.search(psiClazz, psiClazz.resolveScope, false)
.findAll()
.flatMap { aClazz -> aClazz.findMethodsByName(methodName, false).toList() }
return filteredMethod(methods.toTypedArray(), scope, literal.getParentOfType(UMethod::class.java))
}
return filteredMethod(clazzMethods, scope, literal.getParentOfType(UMethod::class.java))
val methods = ClassInheritorsSearch.search(psiClazz, psiClazz.resolveScope, true)
.findAll()
.flatMap { aClazz -> aClazz.findMethodsByName(methodName, true).toList() }
.toMutableSet()
methods.addAll(clazzMethods)
return methods
}
/**
* @param testMethod test method marked with JUnit annotation
* @return the method referenced from the annotation
*/
fun fastResolveFor(testMethod: UMethod): PsiElement? {
val literal = element.toUElement(UExpression::class.java) ?: return null
val scope = literal.getParentOfType(UClass::class.java) ?: return null
private fun fastResolveFor(testMethod: UMethod): Set<PsiMethod> {
val literal = element.toUElement(UExpression::class.java) ?: return setOf()
val scope = literal.getParentOfType(UClass::class.java) ?: return setOf()
val directLink = directLink(literal, scope)
if (directLink != null) return directLink
var currentClass = testMethod.getParentOfType(UClass::class.java) ?: return null
if (directLink != null) return setOf(directLink)
val currentClass = testMethod.getParentOfType(UClass::class.java) ?: return setOf()
return fastResolveFor(literal, currentClass)
}
@@ -131,13 +126,14 @@ abstract class BaseJunitAnnotationReference(
val literal = ref.element.toUElement(UExpression::class.java) ?: return ResolveResult.EMPTY_ARRAY
val uClass = literal.getParentOfType(UClass::class.java) ?: return ResolveResult.EMPTY_ARRAY
val directLink = ref.directLink(literal, uClass)
if (directLink != null) return arrayOf(PsiElementResolveResult(directLink))
if (directLink != null) return arrayOf(PsiMethodSourceResolveResult(directLink, listOf()))
val method = literal.getParentOfType(UMethod::class.java)
if (method != null) { // direct annotation
val resolved = ref.fastResolveFor(method)
return if (resolved is PsiMethod) arrayOf(PsiElementResolveResult(resolved)) else ResolveResult.EMPTY_ARRAY
} else if (uClass.isAnnotationType) { // inherited annotation from another annotation
val owners = method.javaPsi.containingClass?.let { listOf(it) } ?: emptyList()
return ref.fastResolveFor(method).map { PsiMethodSourceResolveResult(it, owners) }.toTypedArray()
}
else if (uClass.isAnnotationType) { // inherited annotation from another annotation
val scope = uClass.sourcePsi?.resolveScope ?: ref.element.resolveScope
val process = ArrayDeque<PsiClass>()
val processed = mutableSetOf<PsiClass>()
@@ -158,9 +154,12 @@ abstract class BaseJunitAnnotationReference(
.mapNotNull { method -> method.toUElement(UMethod::class.java) }
.mapNotNull { method -> method.getParentOfType(UClass::class.java) }
.distinct() // process only classes
.mapNotNull { clazz -> ref.fastResolveFor(literal, clazz) }
.map { method -> PsiElementResolveResult(method) }.toTypedArray()
} else {
.map{ clazz -> clazz to ref.fastResolveFor(literal, clazz) }
.flatMap { (clazz, methods) -> methods.map { method -> method to clazz } }
.groupBy({ it.first }, { it.second })
.map { (method, classes) -> PsiMethodSourceResolveResult(method, classes) }.toTypedArray()
}
else {
return ResolveResult.EMPTY_ARRAY
}
}

View File

@@ -0,0 +1,9 @@
// Copyright 2000-2025 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.execution.junit.references
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiElementResolveResult
import com.intellij.psi.PsiMethod
class PsiMethodSourceResolveResult(method: PsiMethod, val owners: List<PsiClass>): PsiElementResolveResult(method) {
}