mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-01-04 17:20:55 +07:00
[ml-local-models] make models extendable from different languages and independent from completion
GitOrigin-RevId: 312c422165285bbe1202deec04e259daca67584a
This commit is contained in:
committed by
intellij-monorepo-bot
parent
b4ddea5d82
commit
08debd44b6
2
.idea/modules.xml
generated
2
.idea/modules.xml
generated
@@ -479,7 +479,6 @@
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/commander/intellij.commander.iml" filepath="$PROJECT_DIR$/plugins/commander/intellij.commander.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.tests.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.tests.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking-local/intellij.completionMlRankingLocal.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking-local/intellij.completionMlRankingLocal.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.tests.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.tests.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/configuration-script/intellij.configurationScript.iml" filepath="$PROJECT_DIR$/plugins/configuration-script/intellij.configurationScript.iml" />
|
||||
@@ -636,6 +635,7 @@
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/maven/maven3-server-impl/intellij.maven.server.m3.impl.iml" filepath="$PROJECT_DIR$/plugins/maven/maven3-server-impl/intellij.maven.server.m3.impl.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/maven/maven30-server-impl/intellij.maven.server.m30.impl.iml" filepath="$PROJECT_DIR$/plugins/maven/maven30-server-impl/intellij.maven.server.m30.impl.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/maven/maven36-server-impl/intellij.maven.server.m36.impl.iml" filepath="$PROJECT_DIR$/plugins/maven/maven36-server-impl/intellij.maven.server.m36.impl.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/ml-local-models/intellij.mlLocalModels.iml" filepath="$PROJECT_DIR$/plugins/ml-local-models/intellij.mlLocalModels.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/platform/built-in-server/client/node-rpc-client/intellij.nodeRpcClient.iml" filepath="$PROJECT_DIR$/platform/built-in-server/client/node-rpc-client/intellij.nodeRpcClient.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/package-search/intellij.packageSearch.iml" filepath="$PROJECT_DIR$/plugins/package-search/intellij.packageSearch.iml" />
|
||||
<module fileurl="file://$PROJECT_DIR$/plugins/package-search/pkgs-tests/intellij.packageSearch.tests.iml" filepath="$PROJECT_DIR$/plugins/package-search/pkgs-tests/intellij.packageSearch.tests.iml" />
|
||||
|
||||
@@ -153,6 +153,6 @@
|
||||
<orderEntry type="module" module-name="intellij.java.featuresTrainer" scope="RUNTIME" />
|
||||
<orderEntry type="module" module-name="intellij.idea.community.build.tasks" scope="TEST" />
|
||||
<orderEntry type="module" module-name="intellij.junit.v5.rt.tests" scope="TEST" />
|
||||
<orderEntry type="module" module-name="intellij.completionMlRankingLocal" scope="RUNTIME" />
|
||||
<orderEntry type="module" module-name="intellij.mlLocalModels" scope="RUNTIME" />
|
||||
</component>
|
||||
</module>
|
||||
@@ -84,6 +84,7 @@
|
||||
<orderEntry type="module" module-name="intellij.platform.core.ui" />
|
||||
<orderEntry type="module" module-name="intellij.platform.codeStyle.impl" />
|
||||
<orderEntry type="module" module-name="intellij.platform.ide.util.io" />
|
||||
<orderEntry type="module" module-name="intellij.mlLocalModels" scope="PROVIDED" />
|
||||
</component>
|
||||
<component name="copyright">
|
||||
<Base>
|
||||
|
||||
@@ -17,6 +17,7 @@
|
||||
<module value="com.intellij.modules.java"/>
|
||||
|
||||
<depends optional="true" config-file="images-integration.xml">com.intellij.platform.images</depends>
|
||||
<depends optional="true" config-file="java-ml-local-models.xml">com.intellij.ml.local.models</depends>
|
||||
|
||||
<xi:include href="/idea/JavaActions.xml" xpointer="xpointer(/idea-plugin/*)"/>
|
||||
|
||||
|
||||
10
java/java-impl/src/META-INF/java-ml-local-models.xml
Normal file
10
java/java-impl/src/META-INF/java-ml-local-models.xml
Normal file
@@ -0,0 +1,10 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
|
||||
<idea-plugin>
|
||||
<extensions defaultExtensionNs="com.intellij">
|
||||
<ml.local.models.factory language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaMethodsFrequencyModelFactory"/>
|
||||
<ml.local.models.factory language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaClassesFrequencyModelFactory"/>
|
||||
|
||||
<completion.ml.contextFeatures language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaFrequencyContextFeatureProvider"/>
|
||||
<completion.ml.elementFeatures language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaFrequencyElementFeatureProvider"/>
|
||||
</extensions>
|
||||
</idea-plugin>
|
||||
@@ -1,28 +1,13 @@
|
||||
package com.intellij.completion.ml.local.models.frequency
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.intellij.codeInsight.completion.ml.local
|
||||
|
||||
import com.intellij.completion.ml.local.models.api.LocalModel
|
||||
import com.intellij.completion.ml.local.models.storage.ClassesFrequencyStorage
|
||||
import com.intellij.completion.ml.local.util.LocalModelsUtil
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyModelFactory
|
||||
import com.intellij.ml.local.models.frequency.classes.ClassesUsagesTracker
|
||||
import com.intellij.psi.*
|
||||
import com.intellij.psi.util.PsiTypesUtil
|
||||
|
||||
class ClassesFrequencyLocalModel private constructor(private val storage: ClassesFrequencyStorage) : LocalModel {
|
||||
companion object {
|
||||
fun create(project: Project): ClassesFrequencyLocalModel {
|
||||
val storagesPath = LocalModelsUtil.storagePath(project)
|
||||
val storage = ClassesFrequencyStorage.getStorage(storagesPath)
|
||||
return ClassesFrequencyLocalModel(storage)
|
||||
}
|
||||
}
|
||||
|
||||
fun totalClassesCount(): Int = storage.totalClasses
|
||||
|
||||
fun totalClassesUsages(): Int = storage.totalClassesUsages
|
||||
|
||||
fun getClass(className: String): Int? = storage.get(className)
|
||||
|
||||
override fun fileVisitor(): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
|
||||
class JavaClassesFrequencyModelFactory : ClassesFrequencyModelFactory() {
|
||||
override fun fileVisitor(usagesTracker: ClassesUsagesTracker): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
|
||||
|
||||
override fun visitNewExpression(expression: PsiNewExpression) {
|
||||
val cls = expression.classReference?.resolve()
|
||||
@@ -70,8 +55,8 @@ class ClassesFrequencyLocalModel private constructor(private val storage: Classe
|
||||
}
|
||||
|
||||
private fun addClassUsage(cls: PsiClass) {
|
||||
LocalModelsUtil.getClassName(cls)?.let {
|
||||
storage.addClassUsage(it)
|
||||
JavaLocalModelsUtil.getClassName(cls)?.let {
|
||||
usagesTracker.classUsed(it)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -79,14 +64,4 @@ class ClassesFrequencyLocalModel private constructor(private val storage: Classe
|
||||
override fun visitImportStatement(statement: PsiImportStatement) = Unit
|
||||
override fun visitImportStaticStatement(statement: PsiImportStaticStatement) = Unit
|
||||
}
|
||||
|
||||
override fun onStarted() {
|
||||
storage.setValid(false)
|
||||
}
|
||||
|
||||
override fun onFinished() {
|
||||
storage.setValid(true)
|
||||
}
|
||||
|
||||
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
|
||||
}
|
||||
@@ -1,18 +1,19 @@
|
||||
package com.intellij.completion.ml.local.features
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.intellij.codeInsight.completion.ml.local
|
||||
|
||||
import com.intellij.codeInsight.completion.CompletionParameters
|
||||
import com.intellij.codeInsight.completion.ml.CompletionEnvironment
|
||||
import com.intellij.codeInsight.completion.ml.ContextFeatureProvider
|
||||
import com.intellij.codeInsight.completion.ml.MLFeatureValue
|
||||
import com.intellij.completion.ml.local.models.LocalModelsManager
|
||||
import com.intellij.completion.ml.local.models.frequency.ClassesFrequencyLocalModel
|
||||
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencyLocalModel
|
||||
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencies
|
||||
import com.intellij.completion.ml.local.util.LocalModelsUtil
|
||||
import com.intellij.lang.java.JavaLanguage
|
||||
import com.intellij.ml.local.models.LocalModelsManager
|
||||
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyLocalModel
|
||||
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencyLocalModel
|
||||
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencies
|
||||
import com.intellij.openapi.util.Key
|
||||
import com.intellij.psi.*
|
||||
|
||||
class FrequencyContextFeaturesProvider : ContextFeatureProvider {
|
||||
class JavaFrequencyContextFeatureProvider : ContextFeatureProvider {
|
||||
companion object {
|
||||
val RECEIVER_CLASS_NAME_KEY: Key<String> = Key.create("ml.completion.local.models.receiver.class.name")
|
||||
val RECEIVER_CLASS_FREQUENCIES_KEY: Key<MethodsFrequencies> = Key.create("ml.completion.local.models.receiver.class.frequencies")
|
||||
@@ -24,10 +25,11 @@ class FrequencyContextFeaturesProvider : ContextFeatureProvider {
|
||||
override fun calculateFeatures(environment: CompletionEnvironment): MutableMap<String, MLFeatureValue> {
|
||||
val features = mutableMapOf<String, MLFeatureValue>()
|
||||
val project = environment.parameters.position.project
|
||||
val methodsModel = LocalModelsManager.getInstance(project).getModel<MethodsFrequencyLocalModel>()
|
||||
val modelsManager = LocalModelsManager.getInstance(project)
|
||||
val methodsModel = modelsManager.getModel<MethodsFrequencyLocalModel>(JavaLanguage.INSTANCE)
|
||||
if (methodsModel != null && methodsModel.readyToUse()) {
|
||||
getReceiverClass(environment.parameters)?.let { cls ->
|
||||
LocalModelsUtil.getClassName(cls)?.let {
|
||||
JavaLocalModelsUtil.getClassName(cls)?.let {
|
||||
environment.putUserData(RECEIVER_CLASS_NAME_KEY, it)
|
||||
methodsModel.getMethodsByClass(it)?.let { frequencies ->
|
||||
environment.putUserData(RECEIVER_CLASS_FREQUENCIES_KEY, frequencies)
|
||||
@@ -37,7 +39,7 @@ class FrequencyContextFeaturesProvider : ContextFeatureProvider {
|
||||
features["total_methods"] = MLFeatureValue.numerical(methodsModel.totalMethodsCount())
|
||||
features["total_methods_usages"] = MLFeatureValue.numerical(methodsModel.totalMethodsUsages())
|
||||
}
|
||||
val classesModel = LocalModelsManager.getInstance(project).getModel<ClassesFrequencyLocalModel>()
|
||||
val classesModel = modelsManager.getModel<ClassesFrequencyLocalModel>(JavaLanguage.INSTANCE)
|
||||
if (classesModel != null && classesModel.readyToUse()) {
|
||||
features["total_classes"] = MLFeatureValue.numerical(classesModel.totalClassesCount())
|
||||
features["total_classes_usages"] = MLFeatureValue.numerical(classesModel.totalClassesUsages())
|
||||
@@ -0,0 +1,52 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.intellij.codeInsight.completion.ml.local
|
||||
|
||||
import com.intellij.codeInsight.completion.CompletionLocation
|
||||
import com.intellij.codeInsight.completion.ml.ContextFeatures
|
||||
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
|
||||
import com.intellij.codeInsight.completion.ml.MLFeatureValue
|
||||
import com.intellij.codeInsight.lookup.LookupElement
|
||||
import com.intellij.lang.java.JavaLanguage
|
||||
import com.intellij.ml.local.models.LocalModelsManager
|
||||
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyLocalModel
|
||||
import com.intellij.psi.PsiClass
|
||||
import com.intellij.psi.PsiMethod
|
||||
|
||||
class JavaFrequencyElementFeatureProvider : ElementFeatureProvider {
|
||||
override fun getName(): String = "local"
|
||||
|
||||
override fun calculateFeatures(element: LookupElement,
|
||||
location: CompletionLocation,
|
||||
contextFeatures: ContextFeatures): MutableMap<String, MLFeatureValue> {
|
||||
val features = mutableMapOf<String, MLFeatureValue>()
|
||||
val psi = element.psiElement
|
||||
val receiverClassName = contextFeatures.getUserData(JavaFrequencyContextFeatureProvider.RECEIVER_CLASS_NAME_KEY)
|
||||
val classFrequencies = contextFeatures.getUserData(JavaFrequencyContextFeatureProvider.RECEIVER_CLASS_FREQUENCIES_KEY)
|
||||
if (psi is PsiMethod && receiverClassName != null && classFrequencies != null) {
|
||||
psi.containingClass?.let { cls ->
|
||||
JavaLocalModelsUtil.getClassName(cls)?.let { className ->
|
||||
if (receiverClassName == className) {
|
||||
JavaLocalModelsUtil.getMethodName(psi)?.let { methodName ->
|
||||
val frequency = classFrequencies.getMethodFrequency(methodName)
|
||||
if (frequency > 0) {
|
||||
val totalUsages = classFrequencies.getTotalFrequency()
|
||||
features["absolute_method_frequency"] = MLFeatureValue.numerical(frequency)
|
||||
features["relative_method_frequency"] = MLFeatureValue.numerical(frequency.toDouble() / totalUsages)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
val classesModel = LocalModelsManager.getInstance(location.project).getModel<ClassesFrequencyLocalModel>(JavaLanguage.INSTANCE)
|
||||
if (psi is PsiClass && classesModel != null && classesModel.readyToUse()) {
|
||||
JavaLocalModelsUtil.getClassName(psi)?.let { className ->
|
||||
classesModel.getClass(className)?.let {
|
||||
features["absolute_class_frequency"] = MLFeatureValue.numerical(it)
|
||||
features["relative_class_frequency"] = MLFeatureValue.numerical(it.toDouble() / classesModel.totalClassesUsages())
|
||||
}
|
||||
}
|
||||
}
|
||||
return features
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.intellij.codeInsight.completion.ml.local
|
||||
|
||||
import com.intellij.psi.PsiClass
|
||||
import com.intellij.psi.PsiMethod
|
||||
|
||||
internal object JavaLocalModelsUtil {
|
||||
|
||||
fun getMethodName(method: PsiMethod): String? = method.presentation?.presentableText
|
||||
|
||||
fun getClassName(cls: PsiClass): String? = cls.qualifiedName
|
||||
}
|
||||
@@ -0,0 +1,28 @@
|
||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.intellij.codeInsight.completion.ml.local
|
||||
|
||||
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencyModelFactory
|
||||
import com.intellij.ml.local.models.frequency.methods.MethodsUsagesTracker
|
||||
import com.intellij.psi.*
|
||||
|
||||
class JavaMethodsFrequencyModelFactory : MethodsFrequencyModelFactory() {
|
||||
override fun fileVisitor(usagesTracker: MethodsUsagesTracker): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
|
||||
|
||||
override fun visitMethodCallExpression(expression: PsiMethodCallExpression) {
|
||||
expression.resolveMethod()?.let { method ->
|
||||
JavaLocalModelsUtil.getMethodName(method)?.let { methodName ->
|
||||
method.containingClass?.let { cls ->
|
||||
JavaLocalModelsUtil.getClassName(cls)?.let { clsName ->
|
||||
usagesTracker.methodUsed(clsName, methodName)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
super.visitMethodCallExpression(expression)
|
||||
}
|
||||
|
||||
override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) = Unit
|
||||
override fun visitImportStatement(statement: PsiImportStatement) = Unit
|
||||
override fun visitImportStaticStatement(statement: PsiImportStaticStatement) = Unit
|
||||
}
|
||||
}
|
||||
@@ -1,26 +0,0 @@
|
||||
<idea-plugin>
|
||||
<id>com.intellij.completion.ml.ranking.local</id>
|
||||
<name>Machine Learning Code Completion Local Models</name>
|
||||
<vendor>JetBrains</vendor>
|
||||
<category>Other Tools</category>
|
||||
|
||||
<description><![CDATA[
|
||||
<p>The plugin contains logic for training local models for code completion based on machine learning.</p>
|
||||
]]></description>
|
||||
|
||||
<actions>
|
||||
<action id="TrainLocalModelsAction" class="com.intellij.completion.ml.local.actions.TrainLocalModelsAction"/>
|
||||
</actions>
|
||||
|
||||
<resource-bundle>messages.CompletionRankingLocalBundle</resource-bundle>
|
||||
|
||||
<extensions defaultExtensionNs="com.intellij">
|
||||
<projectService serviceImplementation="com.intellij.completion.ml.local.models.LocalModelsManager"/>
|
||||
|
||||
<completion.ml.contextFeatures language="JAVA" implementationClass="com.intellij.completion.ml.local.features.FrequencyContextFeaturesProvider"/>
|
||||
<completion.ml.elementFeatures language="JAVA" implementationClass="com.intellij.completion.ml.local.features.FrequencyElementFeatureProvider"/>
|
||||
</extensions>
|
||||
|
||||
<depends>com.intellij.modules.java</depends>
|
||||
<depends>com.intellij.completion.ml.ranking</depends>
|
||||
</idea-plugin>
|
||||
@@ -1,3 +0,0 @@
|
||||
ml.completion.local.models.training.title=Training local ML completion models
|
||||
ml.completion.local.models.training.action=Train Local ML Completion Models
|
||||
ml.completion.local.models.training.files.processing=Processing source code files
|
||||
@@ -1,21 +0,0 @@
|
||||
package com.intellij.completion.ml.local;
|
||||
|
||||
import com.intellij.AbstractBundle;
|
||||
import com.intellij.DynamicBundle;
|
||||
import org.jetbrains.annotations.Nls;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.PropertyKey;
|
||||
|
||||
public class CompletionRankingLocalBundle extends DynamicBundle {
|
||||
private static final String COMPLETION_RANKING_LOCAL_BUNDLE = "messages.CompletionRankingLocalBundle";
|
||||
|
||||
public static @Nls String message(@NotNull @PropertyKey(resourceBundle = COMPLETION_RANKING_LOCAL_BUNDLE) String key, Object @NotNull ... params) {
|
||||
return ourInstance.getMessage(key, params);
|
||||
}
|
||||
|
||||
private static final AbstractBundle ourInstance = new CompletionRankingLocalBundle();
|
||||
|
||||
protected CompletionRankingLocalBundle() {
|
||||
super(COMPLETION_RANKING_LOCAL_BUNDLE);
|
||||
}
|
||||
}
|
||||
@@ -1,18 +0,0 @@
|
||||
package com.intellij.completion.ml.local.actions
|
||||
|
||||
import com.intellij.completion.ml.local.CompletionRankingLocalBundle
|
||||
import com.intellij.completion.ml.local.models.LocalModelsTraining
|
||||
import com.intellij.openapi.actionSystem.AnAction
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent
|
||||
|
||||
class TrainLocalModelsAction : AnAction(CompletionRankingLocalBundle.message("ml.completion.local.models.training.action")) {
|
||||
override fun actionPerformed(e: AnActionEvent) {
|
||||
val project = e.project ?: return
|
||||
if (LocalModelsTraining.isTraining()) {
|
||||
//TODO: Show message that model is training right now
|
||||
return
|
||||
}
|
||||
LocalModelsTraining.train(project)
|
||||
//TODO: Show message that model trained successfully
|
||||
}
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package com.intellij.completion.ml.local.features
|
||||
|
||||
import com.intellij.codeInsight.completion.CompletionLocation
|
||||
import com.intellij.codeInsight.completion.ml.ContextFeatures
|
||||
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
|
||||
import com.intellij.codeInsight.completion.ml.MLFeatureValue
|
||||
import com.intellij.codeInsight.lookup.LookupElement
|
||||
import com.intellij.completion.ml.local.models.LocalModelsManager
|
||||
import com.intellij.completion.ml.local.models.frequency.ClassesFrequencyLocalModel
|
||||
import com.intellij.completion.ml.local.util.LocalModelsUtil
|
||||
import com.intellij.psi.PsiClass
|
||||
import com.intellij.psi.PsiMethod
|
||||
|
||||
class FrequencyElementFeatureProvider : ElementFeatureProvider {
|
||||
override fun getName(): String = "local"
|
||||
|
||||
override fun calculateFeatures(element: LookupElement,
|
||||
location: CompletionLocation,
|
||||
contextFeatures: ContextFeatures): MutableMap<String, MLFeatureValue> {
|
||||
val features = mutableMapOf<String, MLFeatureValue>()
|
||||
val psi = element.psiElement
|
||||
val receiverClassName = contextFeatures.getUserData(FrequencyContextFeaturesProvider.RECEIVER_CLASS_NAME_KEY)
|
||||
val classFrequencies = contextFeatures.getUserData(FrequencyContextFeaturesProvider.RECEIVER_CLASS_FREQUENCIES_KEY)
|
||||
if (psi is PsiMethod && receiverClassName != null && classFrequencies != null) {
|
||||
LocalModelsUtil.getClassName(psi.containingClass)?.let { className ->
|
||||
if (receiverClassName == className) {
|
||||
LocalModelsUtil.getMethodName(psi)?.let { methodName ->
|
||||
val frequency = classFrequencies.getMethodFrequency(methodName)
|
||||
if (frequency > 0) {
|
||||
val totalUsages = classFrequencies.getTotalFrequency()
|
||||
features["absolute_method_frequency"] = MLFeatureValue.numerical(frequency)
|
||||
features["relative_method_frequency"] = MLFeatureValue.numerical(frequency.toDouble() / totalUsages)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
val classesModel = LocalModelsManager.getInstance(location.project).getModel<ClassesFrequencyLocalModel>()
|
||||
if (psi is PsiClass && classesModel != null && classesModel.readyToUse()) {
|
||||
LocalModelsUtil.getClassName(psi)?.let { className ->
|
||||
classesModel.getClass(className)?.let {
|
||||
features["absolute_class_frequency"] = MLFeatureValue.numerical(it)
|
||||
features["relative_class_frequency"] = MLFeatureValue.numerical(it.toDouble() / classesModel.totalClassesUsages())
|
||||
}
|
||||
}
|
||||
}
|
||||
return features
|
||||
}
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
package com.intellij.completion.ml.local.models
|
||||
|
||||
import com.intellij.completion.ml.local.models.api.LocalModel
|
||||
import com.intellij.completion.ml.local.models.frequency.ClassesFrequencyLocalModel
|
||||
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencyLocalModel
|
||||
import com.intellij.openapi.project.Project
|
||||
|
||||
class LocalModelsManager private constructor(private val project: Project) {
|
||||
companion object {
|
||||
fun getInstance(project: Project): LocalModelsManager = project.getService(LocalModelsManager::class.java)
|
||||
}
|
||||
private val models = mutableMapOf<String, LocalModel>()
|
||||
|
||||
fun getModels(): List<LocalModel> = listOf(
|
||||
models.getOrPut("methods_frequency") { MethodsFrequencyLocalModel.create(project) },
|
||||
models.getOrPut("classes_frequency") { ClassesFrequencyLocalModel.create(project) }
|
||||
)
|
||||
|
||||
inline fun <reified T : LocalModel> getModel(): T? = getModels().filterIsInstance<T>().firstOrNull()
|
||||
}
|
||||
@@ -1,51 +0,0 @@
|
||||
package com.intellij.completion.ml.local.models.frequency
|
||||
|
||||
import com.intellij.completion.ml.local.models.api.LocalModel
|
||||
import com.intellij.completion.ml.local.models.storage.MethodsFrequencyStorage
|
||||
import com.intellij.completion.ml.local.util.LocalModelsUtil
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.*
|
||||
|
||||
class MethodsFrequencyLocalModel private constructor(private val storage: MethodsFrequencyStorage) : LocalModel {
|
||||
companion object {
|
||||
fun create(project: Project): MethodsFrequencyLocalModel {
|
||||
val storagesPath = LocalModelsUtil.storagePath(project)
|
||||
val methodsFrequencyStorage = MethodsFrequencyStorage.getStorage(storagesPath)
|
||||
return MethodsFrequencyLocalModel(methodsFrequencyStorage)
|
||||
}
|
||||
}
|
||||
|
||||
fun totalMethodsCount(): Int = storage.totalMethods
|
||||
|
||||
fun totalMethodsUsages(): Int = storage.totalMethodsUsages
|
||||
|
||||
fun getMethodsByClass(className: String): MethodsFrequencies? = storage.get(className)
|
||||
|
||||
override fun fileVisitor(): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
|
||||
|
||||
override fun visitMethodCallExpression(expression: PsiMethodCallExpression) {
|
||||
expression.resolveMethod()?.let { method ->
|
||||
LocalModelsUtil.getMethodName(method)?.let { methodName ->
|
||||
LocalModelsUtil.getClassName(method.containingClass)?.let { clsName ->
|
||||
storage.addMethodUsage(clsName, methodName)
|
||||
}
|
||||
}
|
||||
}
|
||||
super.visitMethodCallExpression(expression)
|
||||
}
|
||||
|
||||
override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) = Unit
|
||||
override fun visitImportStatement(statement: PsiImportStatement) = Unit
|
||||
override fun visitImportStaticStatement(statement: PsiImportStaticStatement) = Unit
|
||||
}
|
||||
|
||||
override fun onStarted() {
|
||||
storage.setValid(false)
|
||||
}
|
||||
|
||||
override fun onFinished() {
|
||||
storage.setValid(true)
|
||||
}
|
||||
|
||||
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
package com.intellij.completion.ml.local.util
|
||||
|
||||
import com.intellij.openapi.application.PathManager
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.PsiClass
|
||||
import com.intellij.psi.PsiMethod
|
||||
import java.nio.file.Path
|
||||
|
||||
object LocalModelsUtil {
|
||||
fun storagePath(project: Project): Path = PathManager.getIndexRoot().toPath().resolve("completion.ml.local").resolve(project.locationHash)
|
||||
|
||||
fun getMethodName(method: PsiMethod): String? = method.presentation?.presentableText
|
||||
|
||||
fun getClassName(cls: PsiClass?): String? = cls?.qualifiedName
|
||||
}
|
||||
@@ -11,7 +11,6 @@
|
||||
<orderEntry type="library" name="kotlin-stdlib-jdk8" level="project" />
|
||||
<orderEntry type="library" name="gson" level="project" />
|
||||
<orderEntry type="module" module-name="intellij.platform.ide.impl" />
|
||||
<orderEntry type="module" module-name="intellij.java.impl" />
|
||||
<orderEntry type="module" module-name="intellij.completionMlRanking" />
|
||||
<orderEntry type="module" module-name="intellij.platform.indexing" />
|
||||
</component>
|
||||
</module>
|
||||
27
plugins/ml-local-models/resources/META-INF/plugin.xml
Normal file
27
plugins/ml-local-models/resources/META-INF/plugin.xml
Normal file
@@ -0,0 +1,27 @@
|
||||
<idea-plugin>
|
||||
<id>com.intellij.ml.local.models</id>
|
||||
<name>Machine Learning Local Models</name>
|
||||
<vendor>JetBrains</vendor>
|
||||
<category>Other Tools</category>
|
||||
|
||||
<description><![CDATA[
|
||||
<p>The plugin contains logic for training local models based on machine learning.</p>
|
||||
]]></description>
|
||||
|
||||
<actions>
|
||||
<action id="TrainLocalModelsAction" class="com.intellij.ml.local.actions.TrainLocalModelsAction"/>
|
||||
</actions>
|
||||
|
||||
<resource-bundle>messages.MlLocalModelsBundle</resource-bundle>
|
||||
|
||||
<extensionPoints>
|
||||
<extensionPoint name="factory" beanClass="com.intellij.lang.LanguageExtensionPoint" dynamic="true">
|
||||
<with attribute="implementationClass" implements="com.intellij.ml.local.models.api.LocalModelFactory"/>
|
||||
</extensionPoint>
|
||||
</extensionPoints>
|
||||
|
||||
<extensions defaultExtensionNs="com.intellij">
|
||||
<projectService serviceImplementation="com.intellij.ml.local.models.LocalModelsManager"/>
|
||||
<postStartupActivity implementation="com.intellij.ml.local.models.LocalModelsStartupActivity"/>
|
||||
</extensions>
|
||||
</idea-plugin>
|
||||
@@ -0,0 +1,3 @@
|
||||
ml.local.models.training.title=Training local ML models
|
||||
ml.local.models.training.action=Train Local ML Models
|
||||
ml.local.models.training.files.processing=Processing source code files
|
||||
@@ -0,0 +1,21 @@
|
||||
package com.intellij.ml.local;
|
||||
|
||||
import com.intellij.AbstractBundle;
|
||||
import com.intellij.DynamicBundle;
|
||||
import org.jetbrains.annotations.Nls;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.PropertyKey;
|
||||
|
||||
public class MlLocalModelsBundle extends DynamicBundle {
|
||||
private static final String ML_LOCAL_MODELS_BUNDLE = "messages.MlLocalModelsBundle";
|
||||
|
||||
public static @Nls String message(@NotNull @PropertyKey(resourceBundle = ML_LOCAL_MODELS_BUNDLE) String key, Object @NotNull ... params) {
|
||||
return ourInstance.getMessage(key, params);
|
||||
}
|
||||
|
||||
private static final AbstractBundle ourInstance = new MlLocalModelsBundle();
|
||||
|
||||
protected MlLocalModelsBundle() {
|
||||
super(ML_LOCAL_MODELS_BUNDLE);
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.intellij.ml.local.actions
|
||||
|
||||
import com.intellij.ml.local.MlLocalModelsBundle
|
||||
import com.intellij.ml.local.models.LocalModelsTraining
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.actionSystem.AnAction
|
||||
import com.intellij.openapi.actionSystem.AnActionEvent
|
||||
|
||||
class TrainLocalModelsAction : AnAction(MlLocalModelsBundle.message("ml.local.models.training.action")) {
|
||||
override fun actionPerformed(e: AnActionEvent) {
|
||||
val project = e.project ?: return
|
||||
if (LocalModelsTraining.isTraining()) {
|
||||
//TODO: Show message that model is training right now
|
||||
return
|
||||
}
|
||||
//TODO: Dialog for different languages
|
||||
LocalModelsTraining.train(project, Language.findLanguageByID("JAVA")!!)
|
||||
//TODO: Show message that model trained successfully
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.intellij.ml.local.models
|
||||
|
||||
import com.intellij.ml.local.models.api.LocalModel
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.project.Project
|
||||
|
||||
class LocalModelsManager private constructor(private val project: Project) {
|
||||
companion object {
|
||||
fun getInstance(project: Project): LocalModelsManager = project.getService(LocalModelsManager::class.java)
|
||||
}
|
||||
private val models = mutableMapOf<String, MutableList<LocalModel>>()
|
||||
|
||||
fun getModels(language: Language): List<LocalModel> = models.getOrDefault(language.id, emptyList())
|
||||
|
||||
fun registerModel(language: Language, model: LocalModel) {
|
||||
models.getOrPut(language.id, { mutableListOf() }).add(model)
|
||||
}
|
||||
|
||||
inline fun <reified T : LocalModel> getModel(language: Language): T? = getModels(language).filterIsInstance<T>().firstOrNull()
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
package com.intellij.ml.local.models
|
||||
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.ml.local.models.api.LocalModelFactory
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.startup.StartupActivity
|
||||
|
||||
class LocalModelsStartupActivity : StartupActivity {
|
||||
override fun runActivity(project: Project) {
|
||||
val modelsManager = LocalModelsManager.getInstance(project)
|
||||
for (language in Language.getRegisteredLanguages()) {
|
||||
val factories = LocalModelFactory.forLanguage(language)
|
||||
for (factory in factories) {
|
||||
factory.modelBuilder(project, language).build()?.let { model ->
|
||||
modelsManager.registerModel(language, model)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package com.intellij.completion.ml.local.models
|
||||
package com.intellij.ml.local.models
|
||||
|
||||
import com.intellij.completion.ml.local.CompletionRankingLocalBundle
|
||||
import com.intellij.completion.ml.local.models.api.LocalModel
|
||||
import com.intellij.ide.highlighter.JavaFileType
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.ml.local.MlLocalModelsBundle
|
||||
import com.intellij.ml.local.models.api.LocalModelBuilder
|
||||
import com.intellij.ml.local.models.api.LocalModelFactory
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.application.runReadAction
|
||||
import com.intellij.openapi.fileTypes.FileType
|
||||
import com.intellij.openapi.progress.ProgressIndicator
|
||||
import com.intellij.openapi.progress.ProgressManager
|
||||
import com.intellij.openapi.progress.Task
|
||||
@@ -16,6 +18,7 @@ import com.intellij.psi.PsiManager
|
||||
import com.intellij.psi.search.FileTypeIndex
|
||||
import com.intellij.psi.search.GlobalSearchScope
|
||||
import com.intellij.util.indexing.FileBasedIndex
|
||||
import java.lang.IllegalArgumentException
|
||||
import java.util.concurrent.Executors
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.concurrent.atomic.AtomicInteger
|
||||
@@ -26,16 +29,23 @@ object LocalModelsTraining {
|
||||
|
||||
fun isTraining(): Boolean = isTraining
|
||||
|
||||
fun train(project: Project) = ApplicationManager.getApplication().executeOnPooledThread {
|
||||
fun train(project: Project, language: Language) = ApplicationManager.getApplication().executeOnPooledThread {
|
||||
val fileType = language.associatedFileType ?: throw IllegalArgumentException("Unsupported language")
|
||||
isTraining = true
|
||||
val modelsManager = LocalModelsManager.getInstance(project)
|
||||
val task = object : Task.Backgroundable(project, CompletionRankingLocalBundle.message("ml.completion.local.models.training.title"), true) {
|
||||
val task = object : Task.Backgroundable(project, MlLocalModelsBundle.message("ml.local.models.training.title"), true) {
|
||||
override fun run(indicator: ProgressIndicator) {
|
||||
val files = getFiles(project)
|
||||
val models = modelsManager.getModels()
|
||||
models.forEach { it.onStarted() }
|
||||
processFiles(files, models, project, indicator)
|
||||
models.forEach { it.onFinished() }
|
||||
val files = getFiles(project, fileType)
|
||||
val factories = LocalModelFactory.forLanguage(language)
|
||||
val builders = factories.map { it.modelBuilder(project, language) }
|
||||
builders.forEach { it.onStarted() }
|
||||
processFiles(files, builders, project, indicator)
|
||||
builders.forEach { it.onFinished() }
|
||||
builders.forEach {
|
||||
it.build()?.let {
|
||||
modelsManager.registerModel(language, it)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun onFinished() {
|
||||
@@ -45,12 +55,12 @@ object LocalModelsTraining {
|
||||
ProgressManager.getInstance().runProcessWithProgressAsynchronously(task, BackgroundableProcessIndicator(task))
|
||||
}
|
||||
|
||||
private fun processFiles(files: List<VirtualFile>, models: Iterable<LocalModel>, project: Project, indicator: ProgressIndicator) {
|
||||
private fun processFiles(files: List<VirtualFile>, modelBuilders: List<LocalModelBuilder>, project: Project, indicator: ProgressIndicator) {
|
||||
val dumbService = DumbService.getInstance(project)
|
||||
val psiManager = PsiManager.getInstance(project)
|
||||
val executorService = Executors.newFixedThreadPool((Runtime.getRuntime().availableProcessors() - 1).coerceAtLeast(1))
|
||||
indicator.isIndeterminate = false
|
||||
indicator.text = CompletionRankingLocalBundle.message("ml.completion.local.models.training.files.processing")
|
||||
indicator.text = MlLocalModelsBundle.message("ml.local.models.training.files.processing")
|
||||
indicator.fraction = 0.0
|
||||
val processed = AtomicInteger(0)
|
||||
|
||||
@@ -58,8 +68,8 @@ object LocalModelsTraining {
|
||||
executorService.submit {
|
||||
dumbService.runReadActionInSmartMode {
|
||||
psiManager.findFile(file)?.let { psiFile ->
|
||||
for (model in models) {
|
||||
psiFile.accept(model.fileVisitor())
|
||||
for (model in modelBuilders.map { it.fileVisitor() }) {
|
||||
psiFile.accept(model)
|
||||
}
|
||||
}
|
||||
indicator.fraction = processed.incrementAndGet().toDouble() / files.size
|
||||
@@ -81,9 +91,9 @@ object LocalModelsTraining {
|
||||
}
|
||||
}
|
||||
|
||||
fun getFiles(project: Project): List<VirtualFile> = runReadAction {
|
||||
fun getFiles(project: Project, fileType: FileType): List<VirtualFile> = runReadAction {
|
||||
FileBasedIndex.getInstance().getContainingFiles(FileTypeIndex.NAME,
|
||||
JavaFileType.INSTANCE,
|
||||
fileType,
|
||||
GlobalSearchScope.projectScope(project)).toList()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
package com.intellij.ml.local.models.api
|
||||
|
||||
interface LocalModel {
|
||||
val id: String
|
||||
fun readyToUse(): Boolean
|
||||
}
|
||||
@@ -1,10 +1,10 @@
|
||||
package com.intellij.completion.ml.local.models.api
|
||||
package com.intellij.ml.local.models.api
|
||||
|
||||
import com.intellij.psi.PsiElementVisitor
|
||||
|
||||
interface LocalModel {
|
||||
fun fileVisitor(): PsiElementVisitor
|
||||
interface LocalModelBuilder {
|
||||
fun onStarted()
|
||||
fun onFinished()
|
||||
fun readyToUse(): Boolean
|
||||
fun fileVisitor(): PsiElementVisitor
|
||||
fun build(): LocalModel?
|
||||
}
|
||||
@@ -0,0 +1,17 @@
|
||||
package com.intellij.ml.local.models.api
|
||||
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.lang.LanguageExtension
|
||||
import com.intellij.openapi.project.Project
|
||||
|
||||
interface LocalModelFactory {
|
||||
companion object {
|
||||
private val EP_NAME = LanguageExtension<LocalModelFactory>("com.intellij.ml.local.models.factory")
|
||||
|
||||
fun forLanguage(language: Language): List<LocalModelFactory> {
|
||||
return EP_NAME.allForLanguage(language)
|
||||
}
|
||||
}
|
||||
|
||||
fun modelBuilder(project: Project, language: Language): LocalModelBuilder
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.intellij.completion.ml.local.models.storage
|
||||
package com.intellij.ml.local.models.api
|
||||
|
||||
interface LocalModelStorage {
|
||||
fun version(): Int
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.intellij.ml.local.models.frequency
|
||||
|
||||
import com.intellij.ml.local.models.api.LocalModelBuilder
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.ml.local.models.api.LocalModelFactory
|
||||
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyModelFactory
|
||||
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencyModelFactory
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.PsiElementVisitor
|
||||
|
||||
abstract class FrequencyModelFactory<UsagesTracker> : LocalModelFactory {
|
||||
|
||||
protected abstract fun fileVisitor(usagesTracker: UsagesTracker): PsiElementVisitor
|
||||
|
||||
abstract override fun modelBuilder(project: Project, language: Language): LocalModelBuilder
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.intellij.ml.local.models.frequency.classes
|
||||
|
||||
import com.intellij.ml.local.models.api.LocalModel
|
||||
|
||||
class ClassesFrequencyLocalModel(private val storage: ClassesFrequencyStorage) : LocalModel {
|
||||
|
||||
override val id: String = "classes_frequency"
|
||||
|
||||
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
|
||||
|
||||
fun totalClassesCount(): Int = storage.totalClasses
|
||||
|
||||
fun totalClassesUsages(): Int = storage.totalClassesUsages
|
||||
|
||||
fun getClass(className: String): Int? = storage.get(className)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.intellij.ml.local.models.frequency.classes
|
||||
|
||||
import com.intellij.ml.local.models.api.LocalModelBuilder
|
||||
import com.intellij.ml.local.models.frequency.FrequencyModelFactory
|
||||
import com.intellij.ml.local.util.StorageUtil
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.lang.LanguageExtension
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.PsiElementVisitor
|
||||
|
||||
abstract class ClassesFrequencyModelFactory : FrequencyModelFactory<ClassesUsagesTracker>() {
|
||||
|
||||
abstract override fun fileVisitor(usagesTracker: ClassesUsagesTracker): PsiElementVisitor
|
||||
|
||||
override fun modelBuilder(project: Project, language: Language): LocalModelBuilder {
|
||||
val storagesPath = StorageUtil.storagePath(project, language)
|
||||
val storage = ClassesFrequencyStorage.getStorage(storagesPath)
|
||||
|
||||
return object : LocalModelBuilder {
|
||||
|
||||
override fun onStarted() {
|
||||
storage.setValid(false)
|
||||
}
|
||||
|
||||
override fun onFinished() {
|
||||
storage.setValid(true)
|
||||
}
|
||||
|
||||
override fun fileVisitor(): PsiElementVisitor = fileVisitor(ClassesUsagesTracker(storage))
|
||||
|
||||
override fun build(): ClassesFrequencyLocalModel? {
|
||||
if (!storage.isValid() || storage.isEmpty()) {
|
||||
return null
|
||||
}
|
||||
return ClassesFrequencyLocalModel(storage)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,7 +1,9 @@
|
||||
package com.intellij.completion.ml.local.models.storage
|
||||
package com.intellij.ml.local.models.frequency.classes
|
||||
|
||||
import com.intellij.completion.ml.local.models.storage.StorageUtil.clear
|
||||
import com.intellij.completion.ml.local.models.storage.StorageUtil.isEmpty
|
||||
import com.intellij.ml.local.models.api.LocalModelStorage
|
||||
import com.intellij.ml.local.util.StorageUtil
|
||||
import com.intellij.ml.local.util.StorageUtil.clear
|
||||
import com.intellij.ml.local.util.StorageUtil.isEmpty
|
||||
import com.intellij.util.Processor
|
||||
import com.intellij.util.io.EnumeratorStringDescriptor
|
||||
import com.intellij.util.io.IntInlineKeyDescriptor
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.intellij.ml.local.models.frequency.classes
|
||||
|
||||
class ClassesUsagesTracker(private val storage: ClassesFrequencyStorage) {
|
||||
fun classUsed(className: String) {
|
||||
storage.addClassUsage(className)
|
||||
}
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
package com.intellij.completion.ml.local.models.frequency
|
||||
package com.intellij.ml.local.models.frequency.methods
|
||||
|
||||
class MethodsFrequencies(private var totalFrequency: Int = 0,
|
||||
private val methods: MutableMap<Int, Int> = mutableMapOf()) {
|
||||
@@ -0,0 +1,16 @@
|
||||
package com.intellij.ml.local.models.frequency.methods
|
||||
|
||||
import com.intellij.ml.local.models.api.LocalModel
|
||||
|
||||
class MethodsFrequencyLocalModel internal constructor(private val storage: MethodsFrequencyStorage) : LocalModel {
|
||||
|
||||
override val id: String = "methods_frequency"
|
||||
|
||||
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
|
||||
|
||||
fun totalMethodsCount(): Int = storage.totalMethods
|
||||
|
||||
fun totalMethodsUsages(): Int = storage.totalMethodsUsages
|
||||
|
||||
fun getMethodsByClass(className: String): MethodsFrequencies? = storage.get(className)
|
||||
}
|
||||
@@ -0,0 +1,39 @@
|
||||
package com.intellij.ml.local.models.frequency.methods
|
||||
|
||||
import com.intellij.ml.local.models.api.LocalModel
|
||||
import com.intellij.ml.local.models.api.LocalModelBuilder
|
||||
import com.intellij.ml.local.models.frequency.FrequencyModelFactory
|
||||
import com.intellij.ml.local.util.StorageUtil
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.psi.PsiElementVisitor
|
||||
|
||||
abstract class MethodsFrequencyModelFactory : FrequencyModelFactory<MethodsUsagesTracker>() {
|
||||
|
||||
abstract override fun fileVisitor(usagesTracker: MethodsUsagesTracker): PsiElementVisitor
|
||||
|
||||
override fun modelBuilder(project: Project, language: Language): LocalModelBuilder {
|
||||
val storagesPath = StorageUtil.storagePath(project, language)
|
||||
val storage = MethodsFrequencyStorage.getStorage(storagesPath)
|
||||
|
||||
return object : LocalModelBuilder {
|
||||
|
||||
override fun onStarted() {
|
||||
storage.setValid(false)
|
||||
}
|
||||
|
||||
override fun onFinished() {
|
||||
storage.setValid(true)
|
||||
}
|
||||
|
||||
override fun fileVisitor(): PsiElementVisitor = fileVisitor(MethodsUsagesTracker(storage))
|
||||
|
||||
override fun build(): LocalModel? {
|
||||
if (!storage.isValid() || storage.isEmpty()) {
|
||||
return null
|
||||
}
|
||||
return MethodsFrequencyLocalModel(storage)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,8 +1,9 @@
|
||||
package com.intellij.completion.ml.local.models.storage
|
||||
package com.intellij.ml.local.models.frequency.methods
|
||||
|
||||
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencies
|
||||
import com.intellij.completion.ml.local.models.storage.StorageUtil.clear
|
||||
import com.intellij.completion.ml.local.models.storage.StorageUtil.isEmpty
|
||||
import com.intellij.ml.local.models.api.LocalModelStorage
|
||||
import com.intellij.ml.local.util.StorageUtil
|
||||
import com.intellij.ml.local.util.StorageUtil.clear
|
||||
import com.intellij.ml.local.util.StorageUtil.isEmpty
|
||||
import com.intellij.util.Processor
|
||||
import com.intellij.util.io.DataExternalizer
|
||||
import com.intellij.util.io.EnumeratorStringDescriptor
|
||||
@@ -0,0 +1,7 @@
|
||||
package com.intellij.ml.local.models.frequency.methods
|
||||
|
||||
class MethodsUsagesTracker(private val storage: MethodsFrequencyStorage) {
|
||||
fun methodUsed(className: String, methodName: String) {
|
||||
storage.addMethodUsage(className, methodName)
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,9 @@
|
||||
package com.intellij.completion.ml.local.models.storage
|
||||
package com.intellij.ml.local.util
|
||||
|
||||
import com.google.gson.Gson
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.application.PathManager
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.util.io.*
|
||||
import java.nio.file.Path
|
||||
import kotlin.io.path.writeText
|
||||
@@ -17,6 +20,11 @@ object StorageUtil {
|
||||
return GSON.fromJson(infoFile.readText(), StorageInfo::class.java)
|
||||
}
|
||||
|
||||
fun storagePath(project: Project, language: Language): Path = PathManager.getIndexRoot().toPath()
|
||||
.resolve("ml.local.models")
|
||||
.resolve(project.locationHash)
|
||||
.resolve(language.id)
|
||||
|
||||
fun saveInfo(version: Int, isValid: Boolean, storageDirectory: Path) {
|
||||
val infoFile = storageDirectory.resolve(STORAGE_INFO_FILE)
|
||||
infoFile.writeText(GSON.toJson(StorageInfo(version, isValid)))
|
||||
Reference in New Issue
Block a user