mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-03-22 15:19:59 +07:00
[evaluation] LME-94, LME-112: initial implementation of API Recall for generation-in-chat evaluation
GitOrigin-RevId: 748bbac2646f9671c86a752bda7f915f411d2530
This commit is contained in:
committed by
intellij-monorepo-bot
parent
7bc3c12aab
commit
5b0e72c848
@@ -223,6 +223,9 @@ function addCommonFeatures(sessionDiv, popup, lookup) {
|
||||
addRelevanceModelBlock(popup, lookup, "filter")
|
||||
addAiaDiagnosticsBlock("Response", "aia_response", popup, lookup)
|
||||
addAiaDiagnosticsBlock("Context", "aia_context", popup, lookup)
|
||||
addAiaDiagnosticsBlock("Code snippets from response", "extracted_code_snippets", popup, lookup)
|
||||
addAiaDiagnosticsBlock("Internal api calls from original code snippet", "ground_truth_internal_api_calls", popup, lookup)
|
||||
addAiaDiagnosticsBlock("Extracted api calls from generated code snippet", "predicted_api_calls", popup, lookup)
|
||||
addDiagnosticsBlock("RAW SUGGESTIONS", "raw_proposals", popup, lookup)
|
||||
addDiagnosticsBlock("RAW FILTERED", "raw_filtered", popup, lookup)
|
||||
addDiagnosticsBlock("ANALYZED SUGGESTIONS", "analyzed_proposals", popup, lookup)
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.intellij.cce.metric
|
||||
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.extensions.ExtensionPointName
|
||||
import com.intellij.openapi.project.Project
|
||||
|
||||
/**
 * Extension point contract for extracting API calls from a generated code snippet.
 *
 * Implementations are registered under the `com.intellij.cce.apiCallExtractor` extension point
 * and are looked up by the [language] they declare.
 */
interface ApiCallExtractor {
  /** Language this extractor handles; used by [getForLanguage] to pick an implementation. */
  val language: Language

  /**
   * Extracts the API calls found in [code].
   *
   * @param code text of the generated code snippet
   * @param project project passed to the implementation for its analysis
   * @return the extracted API calls as strings; the exact string format is implementation-defined
   */
  suspend fun extractForGeneratedCode(code: String, project: Project): List<String>

  companion object {
    val EP_NAME: ExtensionPointName<ApiCallExtractor> = ExtensionPointName.create("com.intellij.cce.apiCallExtractor")

    /** Returns the first registered extractor whose [ApiCallExtractor.language] equals [language], or `null` if none matches. */
    fun getForLanguage(language: Language): ApiCallExtractor? {
      return EP_NAME.findFirstSafe { extractor -> extractor.language == language }
    }
  }
}
|
||||
@@ -0,0 +1,41 @@
|
||||
package com.intellij.cce.metric
|
||||
|
||||
import com.intellij.cce.core.Session
|
||||
import com.intellij.cce.metric.util.Sample
|
||||
|
||||
/**
 * Metric measuring, per lookup, which share of the distinct ground-truth internal API calls
 * also appears among the predicted API calls.
 *
 * Both lists are read from the lookup's `additionalInfo` under the keys
 * `predicted_api_calls` and `ground_truth_internal_api_calls`.
 */
class ApiRecall : Metric {
  // Collects the recall of every lookup seen across all evaluate() calls; backs [value].
  private val sample = Sample()

  override val name: String = "API Recall"

  override val description: String = "The fraction of correctly guessed project-defined API calls"

  override val showByDefault: Boolean = true

  override val valueType = MetricValueType.DOUBLE

  override val value: Double
    get() = sample.mean()

  /**
   * Computes the recall for every lookup of [sessions], records it in both the overall sample
   * and a sample local to this call, and returns the mean of the local sample.
   *
   * NOTE(review): both `additionalInfo` entries are cast to `List<String>` without a check,
   * so a lookup lacking either key (or holding another type) makes this method throw.
   */
  @Suppress("UNCHECKED_CAST")
  override fun evaluate(sessions: List<Session>): Number {
    val currentSample = Sample()
    val lookups = sessions.flatMap { session -> session.lookups }
    for (lookup in lookups) {
      val predicted = lookup.additionalInfo["predicted_api_calls"] as List<String>
      val expected = lookup.additionalInfo["ground_truth_internal_api_calls"] as List<String>
      val recall = calculateApiRecallForLookupSnippets(predicted, expected)
      currentSample.add(recall)
      sample.add(recall)
    }
    return currentSample.mean()
  }

  /**
   * Returns the fraction of distinct ground-truth calls that are present in the predicted calls.
   * Duplicates on either side are ignored; an empty ground truth yields `1.0`.
   */
  private fun calculateApiRecallForLookupSnippets(
    predictedApiCalls: List<String>,
    groundTruthApiCalls: List<String>,
  ): Double {
    // Nothing to recall: treated as a perfect score.
    if (groundTruthApiCalls.isEmpty()) return 1.0

    val expected = groundTruthApiCalls.toSet()
    val predicted = predictedApiCalls.toSet()
    val hits = expected.count { call -> call in predicted }
    return hits.toDouble() / expected.size.toDouble()
  }
}
|
||||
@@ -0,0 +1,30 @@
|
||||
{
|
||||
"language": "Java",
|
||||
"projectPath": "../../evaluation_projects/calculator-java",
|
||||
"outputDir": "ml-eval-chat-code-generation-java-output",
|
||||
"strategy": {
|
||||
"behaviour": "REMOVE_METHOD_BODY_ASK_TO_GENERATE",
|
||||
"collectContextOnly": false,
|
||||
"minimumInternalApiCallsToConsider": 1
|
||||
},
|
||||
"actions": {
|
||||
"evaluationRoots": [
|
||||
"."
|
||||
]
|
||||
},
|
||||
"interpret": {
|
||||
"saveLogs": true,
|
||||
"sessionProbability": 1.0,
|
||||
"sessionSeed": null,
|
||||
"sessionsLimit": 100,
|
||||
"filesLimit": 1,
|
||||
"order": "LINEAR",
|
||||
"trainTestSplit": 70
|
||||
},
|
||||
"reports": {
|
||||
"evaluationTitle": "Default",
|
||||
"defaultMetrics": null,
|
||||
"sessionsFilters": [],
|
||||
"comparisonFilters": []
|
||||
}
|
||||
}
|
||||
@@ -8,8 +8,10 @@
|
||||
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaTestGenerationVisitor"/>
|
||||
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaDocGenerationVisitor"/>
|
||||
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaCodeGenerationVisitor"/>
|
||||
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaCodeGenerationInChatVisitor"/>
|
||||
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaCompletionContextEvaluationVisitor"/>
|
||||
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaFunctionCallingVisitor"/>
|
||||
<completionEvaluationVisitor implementation="com.intellij.cce.visitor.JavaSelfIdentificationVisitor"/>
|
||||
<apiCallExtractor implementation="com.intellij.cce.visitor.JavaApiCallExtractor"/>
|
||||
</extensions>
|
||||
</idea-plugin>
|
||||
@@ -0,0 +1,54 @@
|
||||
package com.intellij.cce.visitor
|
||||
|
||||
import com.intellij.cce.metric.ApiCallExtractor
|
||||
import com.intellij.ide.actions.QualifiedNameProvider
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.application.smartReadActionBlocking
|
||||
import com.intellij.openapi.application.writeAction
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.*
|
||||
|
||||
/**
 * Collects the call expressions found while recursively visiting [psiElement].
 *
 * A call expression rejected by [filter] is not collected and its children are not visited,
 * so calls nested inside a rejected call are skipped as well.
 *
 * @param psiElement root element to traverse
 * @param filter predicate deciding whether a call expression is collected and descended into;
 *   accepts everything by default
 * @return the accepted call expressions in visiting order
 */
fun extractCallExpressions(
  psiElement: PsiElement,
  filter: (PsiCallExpression) -> Boolean = { true },
): List<PsiCallExpression> {
  val collected = mutableListOf<PsiCallExpression>()
  psiElement.accept(object : JavaRecursiveElementVisitor() {
    override fun visitCallExpression(callExpression: PsiCallExpression) {
      if (filter(callExpression)) {
        collected += callExpression
        // Descend only into accepted calls.
        super.visitCallExpression(callExpression)
      }
    }
  })
  return collected
}
|
||||
|
||||
/**
 * Asks each registered [QualifiedNameProvider] in turn for the qualified name of this element
 * and returns the first non-null answer, or `null` when no provider supplies one.
 */
fun PsiElement.getQualifiedName(): String? =
  QualifiedNameProvider.EP_NAME.extensionList.firstNotNullOfOrNull { provider ->
    provider.getQualifiedName(this)
  }
|
||||
|
||||
|
||||
/**
 * [ApiCallExtractor] for Java: parses the generated code into an in-memory PSI file and returns
 * the qualified names of the methods its call expressions resolve to.
 */
class JavaApiCallExtractor : ApiCallExtractor {
  // Fail with an explanatory message instead of a bare NPE when the Java language is not registered.
  override val language: Language = checkNotNull(Language.findLanguageByID("JAVA")) {
    "Language with ID 'JAVA' is not registered; JavaApiCallExtractor requires Java language support"
  }

  /**
   * Extracts API calls from [code].
   *
   * @param code Java source text of the generated snippet
   * @param project project used to create the PSI file and to run the smart read action
   * @return qualified names of the resolved methods, one entry per call expression that resolves
   *   and for which a qualified name is available; unresolved calls are dropped
   */
  override suspend fun extractForGeneratedCode(code: String, project: Project): List<String> {
    val psiFile = writeAction { createPsiFile(code, project) }
    return smartReadActionBlocking(project) {
      extractCallExpressions(psiFile).mapNotNull { call -> call.resolveMethod()?.getQualifiedName() }
    }
  }

  /** Creates a PSI file named `dummy.java` holding [code]. */
  private fun createPsiFile(code: String, project: Project): PsiFile =
    PsiFileFactory.getInstance(project).createFileFromText("dummy.java", language, code)
}
|
||||
@@ -0,0 +1,59 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.cce.visitor
|
||||
|
||||
import com.intellij.cce.core.*
|
||||
import com.intellij.cce.visitor.exceptions.PsiConverterException
|
||||
import com.intellij.openapi.application.smartReadActionBlocking
|
||||
import com.intellij.openapi.progress.runBlockingCancellable
|
||||
import com.intellij.openapi.roots.ProjectFileIndex
|
||||
import com.intellij.psi.JavaRecursiveElementVisitor
|
||||
import com.intellij.psi.PsiCallExpression
|
||||
import com.intellij.psi.PsiJavaFile
|
||||
import com.intellij.psi.PsiMethod
|
||||
import com.intellij.psi.util.startOffset
|
||||
|
||||
/**
 * Evaluation visitor for the `chat-code-generation` feature on Java files.
 *
 * Visiting a Java file creates a [CodeFragment] for it; each visited method is added to that
 * fragment as a [CodeToken] carrying the method's internal API calls in its `apiCalls` property.
 */
class JavaCodeGenerationInChatVisitor : EvaluationVisitor, JavaRecursiveElementVisitor() {
  // Set by visitJavaFile; stays null until a Java file has been visited.
  private var codeFragment: CodeFragment? = null

  override val language: Language = Language.JAVA

  override val feature: String = "chat-code-generation"

  /**
   * Returns the fragment built during the visit.
   *
   * @throws PsiConverterException if no Java file has been visited yet
   */
  override fun getFile(): CodeFragment {
    return codeFragment ?: throw PsiConverterException("Invoke 'accept' with visitor on PSI first")
  }

  /** Starts a new fragment sized after [file], then continues the traversal into it. */
  override fun visitJavaFile(file: PsiJavaFile) {
    codeFragment = CodeFragment(file.textOffset, file.textLength)
    super.visitJavaFile(file)
  }

  /**
   * Adds [method] to the current fragment as a method token whose `apiCalls` property holds the
   * method's internal API calls joined with newlines. Does not descend into the method.
   */
  override fun visitMethod(method: PsiMethod) {
    val internalApiCalls = runBlockingCancellable { extractInternalApiCalls(method) }
    // API calls are extracted first; without a fragment there is nothing to attach the token to.
    val fragment = codeFragment ?: return
    val token = CodeToken(
      method.text,
      method.startOffset,
      SimpleTokenProperties.create(tokenType = TypeProperty.METHOD, SymbolLocation.PROJECT) {
        put("apiCalls", internalApiCalls.joinToString("\n"))
      })
    fragment.addChild(token)
  }

  /** Collects, inside a smart read action, the internal API names of all call expressions in [method]. */
  private suspend fun extractInternalApiCalls(method: PsiMethod): List<String> =
    smartReadActionBlocking(method.project) {
      extractCallExpressions(method).mapNotNull { call -> call.tryGetCorrespondingInternalApi() }
    }

  /**
   * Returns the qualified name of the method this call resolves to when that method's file is
   * in the project content; returns `null` if the call does not resolve, the resolved method
   * has no backing virtual file, or the file lies outside the project content.
   */
  private fun PsiCallExpression.tryGetCorrespondingInternalApi(): String? {
    val target = this.resolveMethod() ?: return null
    val targetFile = target.containingFile?.virtualFile ?: return null
    val inProjectContent = ProjectFileIndex.getInstance(this.project).isInContent(targetFile)
    return if (inProjectContent) target.getQualifiedName() else null
  }
}
|
||||
@@ -39,6 +39,9 @@
|
||||
<extensionPoint qualifiedName="com.intellij.cce.evaluableFeature"
|
||||
interface="com.intellij.cce.evaluable.EvaluableFeature"
|
||||
dynamic="true"/>
|
||||
<extensionPoint qualifiedName="com.intellij.cce.apiCallExtractor"
|
||||
interface="com.intellij.cce.metric.ApiCallExtractor"
|
||||
dynamic="true"/>
|
||||
</extensionPoints>
|
||||
|
||||
<extensions defaultExtensionNs="com.intellij.cce">
|
||||
|
||||
Reference in New Issue
Block a user