[evaluation] LME-94, LME-112: initial implementation of API Recall for generation-in-chat evaluation

GitOrigin-RevId: 748bbac2646f9671c86a752bda7f915f411d2530
This commit is contained in:
Nikolai.Palchikov
2024-09-24 11:01:28 +02:00
committed by intellij-monorepo-bot
parent 7bc3c12aab
commit 5b0e72c848
9 changed files with 208 additions and 0 deletions

View File

@@ -223,6 +223,9 @@ function addCommonFeatures(sessionDiv, popup, lookup) {
addRelevanceModelBlock(popup, lookup, "filter")
addAiaDiagnosticsBlock("Response", "aia_response", popup, lookup)
addAiaDiagnosticsBlock("Context", "aia_context", popup, lookup)
addAiaDiagnosticsBlock("Code snippets from response", "extracted_code_snippets", popup, lookup)
addAiaDiagnosticsBlock("Internal api calls from original code snippet", "ground_truth_internal_api_calls", popup, lookup)
addAiaDiagnosticsBlock("Extracted api calls from generated code snippet", "predicted_api_calls", popup, lookup)
addDiagnosticsBlock("RAW SUGGESTIONS", "raw_proposals", popup, lookup)
addDiagnosticsBlock("RAW FILTERED", "raw_filtered", popup, lookup)
addDiagnosticsBlock("ANALYZED SUGGESTIONS", "analyzed_proposals", popup, lookup)

View File

@@ -0,0 +1,16 @@
package com.intellij.cce.metric
import com.intellij.lang.Language
import com.intellij.openapi.extensions.ExtensionPointName
import com.intellij.openapi.project.Project
interface ApiCallExtractor {
companion object {
val EP_NAME: ExtensionPointName<ApiCallExtractor> = ExtensionPointName.create("com.intellij.cce.apiCallExtractor")
fun getForLanguage(language: Language): ApiCallExtractor? = EP_NAME.findFirstSafe { it.language == language }
}
val language: Language
suspend fun extractForGeneratedCode(code: String, project: Project): List<String>
}

View File

@@ -0,0 +1,41 @@
package com.intellij.cce.metric
import com.intellij.cce.core.Session
import com.intellij.cce.metric.util.Sample
class ApiRecall : Metric {
private val sample = Sample()
override val name: String = "API Recall"
override val description: String = "The fraction of correctly guessed project-defined API calls"
override val showByDefault: Boolean = true
override val valueType = MetricValueType.DOUBLE
override val value: Double
get() = sample.mean()
@Suppress("UNCHECKED_CAST")
override fun evaluate(sessions: List<Session>): Number {
val fileSample = Sample()
sessions
.flatMap { it.lookups }
.forEach {
val predictedApiCalls = it.additionalInfo["predicted_api_calls"] as List<String>
val groundTruthApiCalls = it.additionalInfo["ground_truth_internal_api_calls"] as List<String>
val apiRecall = calculateApiRecallForLookupSnippets(predictedApiCalls, groundTruthApiCalls)
fileSample.add(apiRecall)
sample.add(apiRecall)
}
return fileSample.mean()
}
private fun calculateApiRecallForLookupSnippets(
predictedApiCalls: List<String>,
groundTruthApiCalls: List<String>,
): Double {
if (groundTruthApiCalls.isEmpty()) return 1.0
val uniqueGroundTruthApiCalls = groundTruthApiCalls.toSet()
val uniquePredictedApiCalls = predictedApiCalls.toSet()
val intersection = uniquePredictedApiCalls.intersect(uniqueGroundTruthApiCalls)
return intersection.size.toDouble() / uniqueGroundTruthApiCalls.size.toDouble()
}
}

View File

@@ -0,0 +1,30 @@
{
"language": "Java",
"projectPath": "../../evaluation_projects/calculator-java",
"outputDir": "ml-eval-chat-code-generation-java-output",
"strategy": {
"behaviour": "REMOVE_METHOD_BODY_ASK_TO_GENERATE",
"collectContextOnly": false,
"minimumInternalApiCallsToConsider": 1
},
"actions": {
"evaluationRoots": [
"."
]
},
"interpret": {
"saveLogs": true,
"sessionProbability": 1.0,
"sessionSeed": null,
"sessionsLimit": 100,
"filesLimit": 1,
"order": "LINEAR",
"trainTestSplit": 70
},
"reports": {
"evaluationTitle": "Default",
"defaultMetrics": null,
"sessionsFilters": [],
"comparisonFilters": []
}
}

View File

@@ -8,8 +8,10 @@
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaTestGenerationVisitor"/>
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaDocGenerationVisitor"/>
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaCodeGenerationVisitor"/>
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaCodeGenerationInChatVisitor"/>
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaCompletionContextEvaluationVisitor"/>
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaFunctionCallingVisitor"/>
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaSelfIdentificationVisitor"/>
<apiCallExtractor implementation="com.intellij.cce.visitor.JavaApiCallExtractor"/>
</extensions>
</idea-plugin>

View File

@@ -0,0 +1,54 @@
package com.intellij.cce.visitor
import com.intellij.cce.metric.ApiCallExtractor
import com.intellij.ide.actions.QualifiedNameProvider
import com.intellij.lang.Language
import com.intellij.openapi.application.smartReadActionBlocking
import com.intellij.openapi.application.writeAction
import com.intellij.openapi.project.Project
import com.intellij.psi.*
fun extractCallExpressions(
psiElement: PsiElement,
filter: (PsiCallExpression) -> Boolean = { true },
): List<PsiCallExpression> {
val result: MutableList<PsiCallExpression> = mutableListOf()
val visitor = object : JavaRecursiveElementVisitor() {
override fun visitCallExpression(callExpression: PsiCallExpression) {
if (!filter(callExpression)) return
result.add(callExpression)
super.visitCallExpression(callExpression)
}
}
psiElement.accept(visitor)
return result
}
fun PsiElement.getQualifiedName(): String? {
QualifiedNameProvider.EP_NAME.extensionList.forEach { provider ->
val qualifiedName = provider.getQualifiedName(this)
if (qualifiedName != null) return qualifiedName
}
return null
}
class JavaApiCallExtractor : ApiCallExtractor {
override val language = Language.findLanguageByID("JAVA")!!
override suspend fun extractForGeneratedCode(code: String, project: Project): List<String> {
val psiFile = writeAction { createPsiFile(code, project) }
return smartReadActionBlocking(project) {
val callExpressions = extractCallExpressions(psiFile)
callExpressions.mapNotNull { it.resolveMethod()?.getQualifiedName() }
}
}
private fun createPsiFile(code: String, project: Project): PsiFile {
return PsiFileFactory.getInstance(project).createFileFromText(
"dummy.java",
language,
code
)
}
}

View File

@@ -0,0 +1,59 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.cce.visitor
import com.intellij.cce.core.*
import com.intellij.cce.visitor.exceptions.PsiConverterException
import com.intellij.openapi.application.smartReadActionBlocking
import com.intellij.openapi.progress.runBlockingCancellable
import com.intellij.openapi.roots.ProjectFileIndex
import com.intellij.psi.JavaRecursiveElementVisitor
import com.intellij.psi.PsiCallExpression
import com.intellij.psi.PsiJavaFile
import com.intellij.psi.PsiMethod
import com.intellij.psi.util.startOffset
class JavaCodeGenerationInChatVisitor : EvaluationVisitor, JavaRecursiveElementVisitor() {
private var codeFragment: CodeFragment? = null
override val language: Language = Language.JAVA
override val feature: String = "chat-code-generation"
override fun getFile(): CodeFragment = codeFragment
?: throw PsiConverterException("Invoke 'accept' with visitor on PSI first")
override fun visitJavaFile(file: PsiJavaFile) {
codeFragment = CodeFragment(file.textOffset, file.textLength)
super.visitJavaFile(file)
}
override fun visitMethod(method: PsiMethod) {
val internalApiCalls = runBlockingCancellable { extractInternalApiCalls(method) }
codeFragment?.addChild(
CodeToken(
method.text,
method.startOffset,
SimpleTokenProperties.create(tokenType = TypeProperty.METHOD, SymbolLocation.PROJECT) {
put("apiCalls", internalApiCalls.joinToString("\n"))
})
)
}
private suspend fun extractInternalApiCalls(method: PsiMethod): List<String> {
return smartReadActionBlocking(method.project) {
val callExpressions = extractCallExpressions(method)
callExpressions.mapNotNull { it.tryGetCorrespondingInternalApi() }
}
}
private fun PsiCallExpression.tryGetCorrespondingInternalApi(): String? {
val resolvedElement = this.resolveMethod() ?: return null
val containingFile = resolvedElement.containingFile?.virtualFile ?: return null
val projectFileIndex = ProjectFileIndex.getInstance(this.project)
if (projectFileIndex.isInContent(containingFile)) {
return resolvedElement.getQualifiedName()
}
return null
}
}

View File

@@ -39,6 +39,9 @@
<extensionPoint qualifiedName="com.intellij.cce.evaluableFeature"
interface="com.intellij.cce.evaluable.EvaluableFeature"
dynamic="true"/>
<extensionPoint qualifiedName="com.intellij.cce.apiCallExtractor"
interface="com.intellij.cce.metric.ApiCallExtractor"
dynamic="true"/>
</extensionPoints>
<extensions defaultExtensionNs="com.intellij.cce">