KTIJ-30925 [ssr] Add typed parameter replacement handling for SSR in Kotlin

Analogously to Java SSR override `handleSubstitution` method to form correct replacement of typed parameters. Added regression test.

GitOrigin-RevId: 8401f6f82d14d886c0fd60e20c0ce989387f24ab
This commit is contained in:
Aleksandr.Govenko
2024-08-09 13:29:36 +02:00
committed by intellij-monorepo-bot
parent 86e037e396
commit 1ea6bb957b
5 changed files with 263 additions and 7 deletions

View File

@@ -20,6 +20,10 @@ import com.intellij.structuralsearch.impl.matcher.compiler.GlobalCompilingVisito
import com.intellij.structuralsearch.impl.matcher.predicates.MatchPredicate
import com.intellij.structuralsearch.impl.matcher.predicates.NotPredicate
import com.intellij.structuralsearch.plugin.replace.ReplaceOptions
import com.intellij.structuralsearch.plugin.replace.ReplacementInfo
import com.intellij.structuralsearch.plugin.replace.impl.ParameterInfo
import com.intellij.structuralsearch.plugin.replace.impl.ReplacementBuilder
import com.intellij.structuralsearch.plugin.replace.impl.Replacer
import com.intellij.structuralsearch.plugin.ui.Configuration
import com.intellij.structuralsearch.plugin.ui.UIUtil
import com.intellij.util.SmartList
@@ -34,11 +38,13 @@ import org.jetbrains.kotlin.idea.structuralsearch.predicates.KotlinExprTypePredi
import org.jetbrains.kotlin.idea.structuralsearch.predicates.KotlinMatchCallSemantics
import org.jetbrains.kotlin.idea.structuralsearch.visitor.KotlinCompilingVisitor
import org.jetbrains.kotlin.idea.structuralsearch.visitor.KotlinMatchingVisitor
import org.jetbrains.kotlin.idea.structuralsearch.visitor.KotlinRecursiveElementVisitor
import org.jetbrains.kotlin.idea.structuralsearch.visitor.KotlinRecursiveElementWalkingVisitor
import org.jetbrains.kotlin.kdoc.psi.impl.KDocTag
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.utils.KotlinExceptionWithAttachments
import kotlin.math.min
class KotlinStructuralSearchProfile : StructuralSearchProfile() {
override fun isMatchNode(element: PsiElement?): Boolean = element !is PsiWhiteSpace
@@ -368,6 +374,71 @@ class KotlinStructuralSearchProfile : StructuralSearchProfile() {
override fun getPatternContexts(): MutableList<PatternContext> = PATTERN_CONTEXTS
override fun compileReplacementTypedVariable(name: String): String {
return TYPED_VAR_PREFIX + name
}
override fun isReplacementTypedVariable(name: String): Boolean {
return name.substring(0, min(TYPED_VAR_PREFIX.length, name.length)) == TYPED_VAR_PREFIX
}
override fun stripReplacementTypedVariableDecorations(name: String): String {
return name.removePrefix(TYPED_VAR_PREFIX)
}
override fun provideAdditionalReplaceOptions(node: PsiElement, options: ReplaceOptions, builder: ReplacementBuilder) {
val profile = this
node.accept(object : KotlinRecursiveElementVisitor() {
override fun visitParameter(parameter: KtParameter) {
val name = parameter.nameIdentifier
val type = parameter.typeReference ?: return
val nameInfo = builder.findParameterization(name) ?: return
nameInfo.isArgumentContext = false
val infos = mutableMapOf(nameInfo.name to nameInfo)
nameInfo.putUserData(PARAMETER_CONTEXT, infos)
nameInfo.element = parameter
if (profile.isReplacementTypedVariable(type.text)) {
val typeInfo = builder.findParameterization(type) ?: return
typeInfo.isArgumentContext = false
typeInfo.putUserData(PARAMETER_CONTEXT, mapOf(typeInfo.name to typeInfo))
infos[typeInfo.name] = typeInfo
}
val dflt = parameter.defaultValue ?: return
if (profile.isReplacementTypedVariable(dflt.text)) {
val dfltInfo = builder.findParameterization(dflt) ?: return
dfltInfo.isArgumentContext = false
dfltInfo.putUserData(PARAMETER_CONTEXT, mapOf(dfltInfo.name to dfltInfo))
infos[dfltInfo.name] = dfltInfo
}
}
})
}
override fun handleSubstitution(info: ParameterInfo, match: MatchResult, result: StringBuilder, replacementInfo: ReplacementInfo) {
if (info.name == match.name) {
val typeInfos = info.getUserData(PARAMETER_CONTEXT)
if (typeInfos == null) {
return super.handleSubstitution(info, match, result, replacementInfo)
}
if (info.element !is KtParameter) {
return
}
val parameterStart = info.startIndex
val length = info.element.getTextLength() - typeInfos.keys.sumOf { key: String -> key.length + TYPED_VAR_PREFIX.length }
val parameterEnd = parameterStart + length
val template = result.substring(parameterStart, parameterEnd)
val replacementString = handleParameter(info, replacementInfo, -parameterStart, template)
result.delete(parameterStart, parameterEnd)
Replacer.insertSubstitution(result, 0, info, replacementString)
}
}
companion object {
const val TYPED_VAR_PREFIX: String = "_____"
val DEFAULT_CONTEXT: PatternContext = PatternContext("default", KotlinBundle.lazyMessage("context.default"))
@@ -384,5 +455,44 @@ class KotlinStructuralSearchProfile : StructuralSearchProfile() {
}
return result
}
private val PARAMETER_CONTEXT: Key<Map<String, ParameterInfo>> = Key("PARAMETER_CONTEXT")
private fun appendParameter(parameterInfo: ParameterInfo, matchResult: MatchResult, offset: Int, out: StringBuilder) {
val infos = checkNotNull(parameterInfo.getUserData(PARAMETER_CONTEXT))
val matches: MutableList<MatchResult> = SmartList(matchResult.children)
matches.add(matchResult)
matches.sortWith(Comparator.comparingInt { result: MatchResult -> result.match.textOffset }.reversed())
for (match in matches) {
val typeInfo = infos[match.name]
if (typeInfo != null) out.insert(typeInfo.startIndex + offset, match.matchImage)
}
}
private fun handleParameter(info: ParameterInfo, replacementInfo: ReplacementInfo, offset: Int, template: String): String {
val matchResult = checkNotNull(replacementInfo.getNamedMatchResult(info.name))
val result = StringBuilder()
if (matchResult.isMultipleMatch) {
var previous: PsiElement? = null
for (child in matchResult.children) {
val match = child.match.parent
if (previous != null) addSeparatorText(previous, match, result)
appendParameter(info, child, offset + result.length, result.append(template))
previous = match
}
} else {
result.append(template)
appendParameter(info, matchResult, offset, result)
}
return result.toString()
}
private fun addSeparatorText(left: PsiElement, right: PsiElement, out: StringBuilder) {
var e = left.nextSibling
while (e != null && e !== right) {
out.append(e.text)
e = e.nextSibling
}
}
}
}

View File

@@ -306,6 +306,10 @@ class KotlinMatchingVisitor(private val myMatchingVisitor: GlobalMatchingVisitor
override fun visitTypeReference(typeReference: KtTypeReference) {
val other = getTreeElementDepar<KtTypeReference>() ?: return
myMatchingVisitor.result = myMatchingVisitor.matchSons(typeReference, other)
val handler = getHandler(typeReference)
if (myMatchingVisitor.result && handler is SubstitutionHandler) {
handler.reset()
}
}
override fun visitQualifiedExpression(expression: KtQualifiedExpression) {
@@ -524,6 +528,7 @@ class KotlinMatchingVisitor(private val myMatchingVisitor: GlobalMatchingVisitor
&& parameter.nameIdentifier != null
&& other.nameIdentifier == null
) other else other.nameIdentifier
myMatchingVisitor.getMatchContext().pushResult()
myMatchingVisitor.result = myMatchingVisitor.match(parameter.typeReference, other.typeReference)
&& myMatchingVisitor.match(parameter.defaultValue, other.defaultValue)
&& (parameter.isVarArg == other.isVarArg || getHandler(parameter) is SubstitutionHandler)
@@ -534,7 +539,8 @@ class KotlinMatchingVisitor(private val myMatchingVisitor: GlobalMatchingVisitor
parameter.nameIdentifier?.let { nameIdentifier ->
val handler = getHandler(nameIdentifier)
if (myMatchingVisitor.result && handler is SubstitutionHandler) {
handler.handle(other.nameIdentifier, myMatchingVisitor.matchContext)
myMatchingVisitor.scopeMatch(parameter.nameIdentifier,
myMatchingVisitor.matchContext.pattern.isTypedVar(parameter.nameIdentifier), otherNameIdentifier)
}
}
}

View File

@@ -5,11 +5,7 @@ import com.intellij.lang.Language
import com.intellij.openapi.fileTypes.LanguageFileType
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.Key
import com.intellij.psi.PsiCodeFragment
import com.intellij.psi.PsiComment
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiErrorElement
import com.intellij.psi.PsiWhiteSpace
import com.intellij.psi.*
import com.intellij.psi.util.elementType
import com.intellij.structuralsearch.*
import com.intellij.structuralsearch.impl.matcher.CompiledPattern
@@ -19,6 +15,10 @@ import com.intellij.structuralsearch.impl.matcher.compiler.GlobalCompilingVisito
import com.intellij.structuralsearch.impl.matcher.predicates.MatchPredicate
import com.intellij.structuralsearch.impl.matcher.predicates.NotPredicate
import com.intellij.structuralsearch.plugin.replace.ReplaceOptions
import com.intellij.structuralsearch.plugin.replace.ReplacementInfo
import com.intellij.structuralsearch.plugin.replace.impl.ParameterInfo
import com.intellij.structuralsearch.plugin.replace.impl.ReplacementBuilder
import com.intellij.structuralsearch.plugin.replace.impl.Replacer
import com.intellij.structuralsearch.plugin.ui.Configuration
import com.intellij.structuralsearch.plugin.ui.UIUtil
import com.intellij.util.SmartList
@@ -32,12 +32,14 @@ import org.jetbrains.kotlin.idea.k2.codeinsight.structuralsearch.predicates.Kotl
import org.jetbrains.kotlin.idea.k2.codeinsight.structuralsearch.predicates.KotlinMatchCallSemantics
import org.jetbrains.kotlin.idea.k2.codeinsight.structuralsearch.visitor.KotlinCompilingVisitor
import org.jetbrains.kotlin.idea.k2.codeinsight.structuralsearch.visitor.KotlinMatchingVisitor
import org.jetbrains.kotlin.idea.k2.codeinsight.structuralsearch.visitor.KotlinRecursiveElementVisitor
import org.jetbrains.kotlin.idea.k2.codeinsight.structuralsearch.visitor.KotlinRecursiveElementWalkingVisitor
import org.jetbrains.kotlin.idea.liveTemplates.KotlinTemplateContextType
import org.jetbrains.kotlin.kdoc.psi.impl.KDocTag
import org.jetbrains.kotlin.lexer.KtTokens
import org.jetbrains.kotlin.psi.*
import org.jetbrains.kotlin.utils.KotlinExceptionWithAttachments
import kotlin.math.min
class KotlinStructuralSearchProfile : StructuralSearchProfile() {
override fun isMatchNode(element: PsiElement?): Boolean = element !is PsiWhiteSpace
@@ -377,6 +379,71 @@ class KotlinStructuralSearchProfile : StructuralSearchProfile() {
return result
}
override fun compileReplacementTypedVariable(name: String): String {
return TYPED_VAR_PREFIX + name
}
override fun isReplacementTypedVariable(name: String): Boolean {
return name.substring(0, min(TYPED_VAR_PREFIX.length, name.length)) == TYPED_VAR_PREFIX
}
override fun stripReplacementTypedVariableDecorations(name: String): String {
return name.removePrefix(TYPED_VAR_PREFIX)
}
override fun provideAdditionalReplaceOptions(node: PsiElement, options: ReplaceOptions, builder: ReplacementBuilder) {
val profile = this
node.accept(object : KotlinRecursiveElementVisitor() {
override fun visitParameter(parameter: KtParameter) {
val name = parameter.nameIdentifier
val type = parameter.typeReference ?: return
val nameInfo = builder.findParameterization(name) ?: return
nameInfo.isArgumentContext = false
val infos = mutableMapOf(nameInfo.name to nameInfo)
nameInfo.putUserData(PARAMETER_CONTEXT, infos)
nameInfo.element = parameter
if (profile.isReplacementTypedVariable(type.text)) {
val typeInfo = builder.findParameterization(type) ?: return
typeInfo.isArgumentContext = false
typeInfo.putUserData(PARAMETER_CONTEXT, mapOf(typeInfo.name to typeInfo))
infos[typeInfo.name] = typeInfo
}
val dflt = parameter.defaultValue ?: return
if (profile.isReplacementTypedVariable(dflt.text)) {
val dfltInfo = builder.findParameterization(dflt) ?: return
dfltInfo.isArgumentContext = false
dfltInfo.putUserData(PARAMETER_CONTEXT, mapOf(dfltInfo.name to dfltInfo))
infos[dfltInfo.name] = dfltInfo
}
}
})
}
override fun handleSubstitution(info: ParameterInfo, match: MatchResult, result: StringBuilder, replacementInfo: ReplacementInfo) {
if (info.name == match.name) {
val typeInfos = info.getUserData(PARAMETER_CONTEXT)
if (typeInfos == null) {
return super.handleSubstitution(info, match, result, replacementInfo)
}
if (info.element !is KtParameter) {
return
}
val parameterStart = info.startIndex
val length = info.element.getTextLength() - typeInfos.keys.sumOf { key: String -> key.length + TYPED_VAR_PREFIX.length }
val parameterEnd = parameterStart + length
val template = result.substring(parameterStart, parameterEnd)
val replacementString = handleParameter(info, replacementInfo, -parameterStart, template)
result.delete(parameterStart, parameterEnd)
Replacer.insertSubstitution(result, 0, info, replacementString)
}
}
companion object {
const val TYPED_VAR_PREFIX: String = "_____"
@@ -387,5 +454,54 @@ class KotlinStructuralSearchProfile : StructuralSearchProfile() {
private val PATTERN_CONTEXTS: MutableList<PatternContext> = mutableListOf(DEFAULT_CONTEXT, PROPERTY_CONTEXT)
private val PATTERN_ERROR: Key<String> = Key("patternError")
fun getNonWhitespaceChildren(fragment: PsiElement): List<PsiElement> {
var element = fragment.firstChild
val result: MutableList<PsiElement> = SmartList()
while (element != null) {
if (element !is PsiWhiteSpace) result.add(element)
element = element.nextSibling
}
return result
}
private val PARAMETER_CONTEXT: Key<Map<String, ParameterInfo>> = Key("PARAMETER_CONTEXT")
private fun appendParameter(parameterInfo: ParameterInfo, matchResult: MatchResult, offset: Int, out: StringBuilder) {
val infos = checkNotNull(parameterInfo.getUserData(PARAMETER_CONTEXT))
val matches: MutableList<MatchResult> = SmartList(matchResult.children)
matches.add(matchResult)
matches.sortWith(Comparator.comparingInt { result: MatchResult -> result.match.textOffset }.reversed())
for (match in matches) {
val typeInfo = infos[match.name]
if (typeInfo != null) out.insert(typeInfo.startIndex + offset, match.matchImage)
}
}
private fun handleParameter(info: ParameterInfo, replacementInfo: ReplacementInfo, offset: Int, template: String): String {
val matchResult = checkNotNull(replacementInfo.getNamedMatchResult(info.name))
val result = StringBuilder()
if (matchResult.isMultipleMatch) {
var previous: PsiElement? = null
for (child in matchResult.children) {
val match = child.match.parent
if (previous != null) addSeparatorText(previous, match, result)
appendParameter(info, child, offset + result.length, result.append(template))
previous = match
}
} else {
result.append(template)
appendParameter(info, matchResult, offset, result)
}
return result.toString()
}
private fun addSeparatorText(left: PsiElement, right: PsiElement, out: StringBuilder) {
var e = left.nextSibling
while (e != null && e !== right) {
out.append(e.text)
e = e.nextSibling
}
}
}
}

View File

@@ -287,6 +287,10 @@ class KotlinMatchingVisitor(private val myMatchingVisitor: GlobalMatchingVisitor
override fun visitTypeReference(typeReference: KtTypeReference) {
val other = getTreeElementDepar<KtTypeReference>() ?: return
myMatchingVisitor.result = myMatchingVisitor.matchSons(typeReference, other)
val handler = getHandler(typeReference)
if (myMatchingVisitor.result && handler is SubstitutionHandler) {
handler.reset()
}
}
override fun visitQualifiedExpression(expression: KtQualifiedExpression) {
@@ -516,6 +520,7 @@ class KotlinMatchingVisitor(private val myMatchingVisitor: GlobalMatchingVisitor
&& parameter.nameIdentifier != null
&& other.nameIdentifier == null
) other else other.nameIdentifier
myMatchingVisitor.matchContext.pushResult()
myMatchingVisitor.result = myMatchingVisitor.match(parameter.typeReference, other.typeReference)
&& myMatchingVisitor.match(parameter.defaultValue, other.defaultValue)
&& (parameter.isVarArg == other.isVarArg || getHandler(parameter) is SubstitutionHandler)
@@ -526,7 +531,8 @@ class KotlinMatchingVisitor(private val myMatchingVisitor: GlobalMatchingVisitor
parameter.nameIdentifier?.let { nameIdentifier ->
val handler = getHandler(nameIdentifier)
if (myMatchingVisitor.result && handler is SubstitutionHandler) {
handler.handle(other.nameIdentifier, myMatchingVisitor.matchContext)
myMatchingVisitor.scopeMatch(parameter.nameIdentifier,
myMatchingVisitor.matchContext.pattern.isTypedVar(parameter.nameIdentifier), otherNameIdentifier)
}
}
}

View File

@@ -59,6 +59,15 @@ class KotlinSSRFunctionReplaceTest : KotlinStructuralReplaceTest() {
)
}
fun testFunctionMultipleTypedParameterFormatCopy() {
doTest(
searchPattern = "fun '_ID('_PARAM* : '_TYPE)",
replacePattern = "fun '_ID('_PARAM : '_TYPE)",
match = "public fun foo(bar : Int = 0, baz : Boolean = true) {}",
result = "public fun foo(bar : Int = 0, baz : Boolean = true) {}"
)
}
fun testFunctionDefaultParameterFormatCopy() {
doTest(
searchPattern = "fun '_ID('_PARAM : '_TYPE = '_INIT)",
@@ -68,6 +77,15 @@ class KotlinSSRFunctionReplaceTest : KotlinStructuralReplaceTest() {
)
}
fun testFunctionMultipleDefaultParameterFormatCopy() {
doTest(
searchPattern = "fun '_ID('_PARAM* : '_TYPE = '_INIT)",
replacePattern = "fun '_ID('_PARAM : '_TYPE = '_INIT)",
match = "public fun foo(bar : Int = 0, baz : Boolean = true) {}",
result = "public fun foo(bar : Int = 0, baz : Boolean = true) {}"
)
}
fun testFunctionMultiParamCountFilter() {
doTest(
searchPattern = "fun '_ID('_PARAM*)",