From 5b0e72c84862de27494aab83f07667bcb965c50d Mon Sep 17 00:00:00 2001 From: "Nikolai.Palchikov" Date: Tue, 24 Sep 2024 11:01:28 +0200 Subject: [PATCH] [evaluation] LME-94, LME-112: initial implementation of API Recall for generation-in-chat evaluation GitOrigin-RevId: 748bbac2646f9671c86a752bda7f915f411d2530 --- .../core/resources/script.js | 3 + .../intellij/cce/metric/ApiCallExtractor.kt | 16 +++++ .../src/com/intellij/cce/metric/ApiRecall.kt | 41 +++++++++++++ .../chatcodegeneration/java.json | 30 ++++++++++ .../META-INF/evaluationPlugin-java.xml | 2 + .../cce/visitor/JavaApiCallExtractor.kt | 54 +++++++++++++++++ .../JavaCodeGenerationInChatVisitor.kt | 59 +++++++++++++++++++ .../resources/META-INF/plugin.xml | 3 + .../cce/evaluable/chat/ChatFeature.kt | 0 9 files changed, 208 insertions(+) create mode 100644 plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiCallExtractor.kt create mode 100644 plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiRecall.kt create mode 100644 plugins/evaluation-plugin/languages/evaluationconfig/chatcodegeneration/java.json create mode 100644 plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaApiCallExtractor.kt create mode 100644 plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaCodeGenerationInChatVisitor.kt create mode 100644 plugins/evaluation-plugin/src/com/intellij/cce/evaluable/chat/ChatFeature.kt diff --git a/plugins/evaluation-plugin/core/resources/script.js b/plugins/evaluation-plugin/core/resources/script.js index 2e2cb3015082..2a69d0e645df 100644 --- a/plugins/evaluation-plugin/core/resources/script.js +++ b/plugins/evaluation-plugin/core/resources/script.js @@ -223,6 +223,9 @@ function addCommonFeatures(sessionDiv, popup, lookup) { addRelevanceModelBlock(popup, lookup, "filter") addAiaDiagnosticsBlock("Response", "aia_response", popup, lookup) addAiaDiagnosticsBlock("Context", "aia_context", popup, lookup) + addAiaDiagnosticsBlock("Code snippets from response", "extracted_code_snippets", popup, lookup) + addAiaDiagnosticsBlock("Internal api calls from original code snippet", "ground_truth_internal_api_calls", popup, lookup) + addAiaDiagnosticsBlock("Extracted api calls from generated code snippet", "predicted_api_calls", popup, lookup) addDiagnosticsBlock("RAW SUGGESTIONS", "raw_proposals", popup, lookup) addDiagnosticsBlock("RAW FILTERED", "raw_filtered", popup, lookup) addDiagnosticsBlock("ANALYZED SUGGESTIONS", "analyzed_proposals", popup, lookup) diff --git a/plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiCallExtractor.kt b/plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiCallExtractor.kt new file mode 100644 index 000000000000..1b625c7ad260 --- /dev/null +++ b/plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiCallExtractor.kt @@ -0,0 +1,16 @@ +package com.intellij.cce.metric + +import com.intellij.lang.Language +import com.intellij.openapi.extensions.ExtensionPointName +import com.intellij.openapi.project.Project + +interface ApiCallExtractor { + + companion object { + val EP_NAME: ExtensionPointName = ExtensionPointName.create("com.intellij.cce.apiCallExtractor") + fun getForLanguage(language: Language): ApiCallExtractor? = EP_NAME.findFirstSafe { it.language == language } + } + + val language: Language + suspend fun extractForGeneratedCode(code: String, project: Project): List +} \ No newline at end of file diff --git a/plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiRecall.kt b/plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiRecall.kt new file mode 100644 index 000000000000..5e958ca95395 --- /dev/null +++ b/plugins/evaluation-plugin/core/src/com/intellij/cce/metric/ApiRecall.kt @@ -0,0 +1,41 @@ +package com.intellij.cce.metric + +import com.intellij.cce.core.Session +import com.intellij.cce.metric.util.Sample + +class ApiRecall : Metric { + private val sample = Sample() + override val name: String = "API Recall" + override val description: String = "The fraction of correctly guessed project-defined API calls" + override val showByDefault: Boolean = true + override val valueType = MetricValueType.DOUBLE + override val value: Double + get() = sample.mean() + + @Suppress("UNCHECKED_CAST") + override fun evaluate(sessions: List): Number { + val fileSample = Sample() + sessions + .flatMap { it.lookups } + .forEach { + val predictedApiCalls = it.additionalInfo["predicted_api_calls"] as List + val groundTruthApiCalls = it.additionalInfo["ground_truth_internal_api_calls"] as List + val apiRecall = calculateApiRecallForLookupSnippets(predictedApiCalls, groundTruthApiCalls) + fileSample.add(apiRecall) + sample.add(apiRecall) + } + return fileSample.mean() + } + + private fun calculateApiRecallForLookupSnippets( + predictedApiCalls: List, + groundTruthApiCalls: List, + ): Double { + if (groundTruthApiCalls.isEmpty()) return 1.0 + + val uniqueGroundTruthApiCalls = groundTruthApiCalls.toSet() + val uniquePredictedApiCalls = predictedApiCalls.toSet() + val intersection = uniquePredictedApiCalls.intersect(uniqueGroundTruthApiCalls) + return intersection.size.toDouble() / uniqueGroundTruthApiCalls.size.toDouble() + } +} \ No newline at end of file diff --git a/plugins/evaluation-plugin/languages/evaluationconfig/chatcodegeneration/java.json b/plugins/evaluation-plugin/languages/evaluationconfig/chatcodegeneration/java.json new file mode 100644 index 000000000000..de092d3453ed --- /dev/null +++ b/plugins/evaluation-plugin/languages/evaluationconfig/chatcodegeneration/java.json @@ -0,0 +1,30 @@ +{ + "language": "Java", + "projectPath": "../../evaluation_projects/calculator-java", + "outputDir": "ml-eval-chat-code-generation-java-output", + "strategy": { + "behaviour": "REMOVE_METHOD_BODY_ASK_TO_GENERATE", + "collectContextOnly": false, + "minimumInternalApiCallsToConsider": 1 + }, + "actions": { + "evaluationRoots": [ + "." + ] + }, + "interpret": { + "saveLogs": true, + "sessionProbability": 1.0, + "sessionSeed": null, + "sessionsLimit": 100, + "filesLimit": 1, + "order": "LINEAR", + "trainTestSplit": 70 + }, + "reports": { + "evaluationTitle": "Default", + "defaultMetrics": null, + "sessionsFilters": [], + "comparisonFilters": [] + } +} diff --git a/plugins/evaluation-plugin/languages/java/resources/META-INF/evaluationPlugin-java.xml b/plugins/evaluation-plugin/languages/java/resources/META-INF/evaluationPlugin-java.xml index f610827993c4..98eead2dc913 100644 --- a/plugins/evaluation-plugin/languages/java/resources/META-INF/evaluationPlugin-java.xml +++ b/plugins/evaluation-plugin/languages/java/resources/META-INF/evaluationPlugin-java.xml @@ -8,8 +8,10 @@ + + \ No newline at end of file diff --git a/plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaApiCallExtractor.kt b/plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaApiCallExtractor.kt new file mode 100644 index 000000000000..91de1a2dd1d9 --- /dev/null +++ b/plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaApiCallExtractor.kt @@ -0,0 +1,54 @@ +package com.intellij.cce.visitor + +import com.intellij.cce.metric.ApiCallExtractor +import com.intellij.ide.actions.QualifiedNameProvider +import com.intellij.lang.Language +import com.intellij.openapi.application.smartReadActionBlocking +import com.intellij.openapi.application.writeAction +import com.intellij.openapi.project.Project +import com.intellij.psi.* + +fun extractCallExpressions( + psiElement: PsiElement, + filter: (PsiCallExpression) -> Boolean = { true }, +): List { + val result: MutableList = mutableListOf() + val visitor = object : JavaRecursiveElementVisitor() { + override fun visitCallExpression(callExpression: PsiCallExpression) { + if (!filter(callExpression)) return + result.add(callExpression) + super.visitCallExpression(callExpression) + } + } + psiElement.accept(visitor) + return result +} + +fun PsiElement.getQualifiedName(): String? { + QualifiedNameProvider.EP_NAME.extensionList.forEach { provider -> + val qualifiedName = provider.getQualifiedName(this) + if (qualifiedName != null) return qualifiedName + } + return null +} + + +class JavaApiCallExtractor : ApiCallExtractor { + override val language = Language.findLanguageByID("JAVA")!! + + override suspend fun extractForGeneratedCode(code: String, project: Project): List { + val psiFile = writeAction { createPsiFile(code, project) } + return smartReadActionBlocking(project) { + val callExpressions = extractCallExpressions(psiFile) + callExpressions.mapNotNull { it.resolveMethod()?.getQualifiedName() } + } + } + + private fun createPsiFile(code: String, project: Project): PsiFile { + return PsiFileFactory.getInstance(project).createFileFromText( + "dummy.java", + language, + code + ) + } +} \ No newline at end of file diff --git a/plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaCodeGenerationInChatVisitor.kt b/plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaCodeGenerationInChatVisitor.kt new file mode 100644 index 000000000000..1be2f40b5e35 --- /dev/null +++ b/plugins/evaluation-plugin/languages/java/src/com/intellij/cce/visitor/JavaCodeGenerationInChatVisitor.kt @@ -0,0 +1,59 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.cce.visitor + +import com.intellij.cce.core.* +import com.intellij.cce.visitor.exceptions.PsiConverterException +import com.intellij.openapi.application.smartReadActionBlocking +import com.intellij.openapi.progress.runBlockingCancellable +import com.intellij.openapi.roots.ProjectFileIndex +import com.intellij.psi.JavaRecursiveElementVisitor +import com.intellij.psi.PsiCallExpression +import com.intellij.psi.PsiJavaFile +import com.intellij.psi.PsiMethod +import com.intellij.psi.util.startOffset + +class JavaCodeGenerationInChatVisitor : EvaluationVisitor, JavaRecursiveElementVisitor() { + private var codeFragment: CodeFragment? = null + + override val language: Language = Language.JAVA + override val feature: String = "chat-code-generation" + + override fun getFile(): CodeFragment = codeFragment + ?: throw PsiConverterException("Invoke 'accept' with visitor on PSI first") + + override fun visitJavaFile(file: PsiJavaFile) { + codeFragment = CodeFragment(file.textOffset, file.textLength) + super.visitJavaFile(file) + } + + override fun visitMethod(method: PsiMethod) { + val internalApiCalls = runBlockingCancellable { extractInternalApiCalls(method) } + codeFragment?.addChild( + CodeToken( + method.text, + method.startOffset, + SimpleTokenProperties.create(tokenType = TypeProperty.METHOD, SymbolLocation.PROJECT) { + put("apiCalls", internalApiCalls.joinToString("\n")) + }) + ) + } + + private suspend fun extractInternalApiCalls(method: PsiMethod): List { + return smartReadActionBlocking(method.project) { + val callExpressions = extractCallExpressions(method) + callExpressions.mapNotNull { it.tryGetCorrespondingInternalApi() } + } + } + + + private fun PsiCallExpression.tryGetCorrespondingInternalApi(): String? { + val resolvedElement = this.resolveMethod() ?: return null + val containingFile = resolvedElement.containingFile?.virtualFile ?: return null + val projectFileIndex = ProjectFileIndex.getInstance(this.project) + if (projectFileIndex.isInContent(containingFile)) { + return resolvedElement.getQualifiedName() + } + return null + } + +} \ No newline at end of file diff --git a/plugins/evaluation-plugin/resources/META-INF/plugin.xml b/plugins/evaluation-plugin/resources/META-INF/plugin.xml index 5ea8c3841e27..01409f10ac96 100644 --- a/plugins/evaluation-plugin/resources/META-INF/plugin.xml +++ b/plugins/evaluation-plugin/resources/META-INF/plugin.xml @@ -39,6 +39,9 @@ + diff --git a/plugins/evaluation-plugin/src/com/intellij/cce/evaluable/chat/ChatFeature.kt b/plugins/evaluation-plugin/src/com/intellij/cce/evaluable/chat/ChatFeature.kt new file mode 100644 index 000000000000..e69de29bb2d1