[ml-local-models] make models extendable from different languages and independent from completion

GitOrigin-RevId: 312c422165285bbe1202deec04e259daca67584a
This commit is contained in:
Alexey Kalina
2021-01-22 15:32:17 +03:00
committed by intellij-monorepo-bot
parent b4ddea5d82
commit 08debd44b6
41 changed files with 453 additions and 281 deletions

2
.idea/modules.xml generated
View File

@@ -479,7 +479,6 @@
<module fileurl="file://$PROJECT_DIR$/plugins/commander/intellij.commander.iml" filepath="$PROJECT_DIR$/plugins/commander/intellij.commander.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.tests.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking/intellij.completionMlRanking.tests.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking-local/intellij.completionMlRankingLocal.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking-local/intellij.completionMlRankingLocal.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.tests.iml" filepath="$PROJECT_DIR$/plugins/completion-ml-ranking-models/intellij.completionMlRankingModels.tests.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/configuration-script/intellij.configurationScript.iml" filepath="$PROJECT_DIR$/plugins/configuration-script/intellij.configurationScript.iml" />
@@ -636,6 +635,7 @@
<module fileurl="file://$PROJECT_DIR$/plugins/maven/maven3-server-impl/intellij.maven.server.m3.impl.iml" filepath="$PROJECT_DIR$/plugins/maven/maven3-server-impl/intellij.maven.server.m3.impl.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/maven/maven30-server-impl/intellij.maven.server.m30.impl.iml" filepath="$PROJECT_DIR$/plugins/maven/maven30-server-impl/intellij.maven.server.m30.impl.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/maven/maven36-server-impl/intellij.maven.server.m36.impl.iml" filepath="$PROJECT_DIR$/plugins/maven/maven36-server-impl/intellij.maven.server.m36.impl.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/ml-local-models/intellij.mlLocalModels.iml" filepath="$PROJECT_DIR$/plugins/ml-local-models/intellij.mlLocalModels.iml" />
<module fileurl="file://$PROJECT_DIR$/platform/built-in-server/client/node-rpc-client/intellij.nodeRpcClient.iml" filepath="$PROJECT_DIR$/platform/built-in-server/client/node-rpc-client/intellij.nodeRpcClient.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/package-search/intellij.packageSearch.iml" filepath="$PROJECT_DIR$/plugins/package-search/intellij.packageSearch.iml" />
<module fileurl="file://$PROJECT_DIR$/plugins/package-search/pkgs-tests/intellij.packageSearch.tests.iml" filepath="$PROJECT_DIR$/plugins/package-search/pkgs-tests/intellij.packageSearch.tests.iml" />

View File

@@ -153,6 +153,6 @@
<orderEntry type="module" module-name="intellij.java.featuresTrainer" scope="RUNTIME" />
<orderEntry type="module" module-name="intellij.idea.community.build.tasks" scope="TEST" />
<orderEntry type="module" module-name="intellij.junit.v5.rt.tests" scope="TEST" />
<orderEntry type="module" module-name="intellij.completionMlRankingLocal" scope="RUNTIME" />
<orderEntry type="module" module-name="intellij.mlLocalModels" scope="RUNTIME" />
</component>
</module>

View File

@@ -84,6 +84,7 @@
<orderEntry type="module" module-name="intellij.platform.core.ui" />
<orderEntry type="module" module-name="intellij.platform.codeStyle.impl" />
<orderEntry type="module" module-name="intellij.platform.ide.util.io" />
<orderEntry type="module" module-name="intellij.mlLocalModels" scope="PROVIDED" />
</component>
<component name="copyright">
<Base>

View File

@@ -17,6 +17,7 @@
<module value="com.intellij.modules.java"/>
<depends optional="true" config-file="images-integration.xml">com.intellij.platform.images</depends>
<depends optional="true" config-file="java-ml-local-models.xml">com.intellij.ml.local.models</depends>
<xi:include href="/idea/JavaActions.xml" xpointer="xpointer(/idea-plugin/*)"/>

View File

@@ -0,0 +1,10 @@
<?xml version="1.0" encoding="utf-8"?>
<idea-plugin>
<extensions defaultExtensionNs="com.intellij">
<ml.local.models.factory language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaMethodsFrequencyModelFactory"/>
<ml.local.models.factory language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaClassesFrequencyModelFactory"/>
<completion.ml.contextFeatures language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaFrequencyContextFeatureProvider"/>
<completion.ml.elementFeatures language="JAVA" implementationClass="com.intellij.codeInsight.completion.ml.local.JavaFrequencyElementFeatureProvider"/>
</extensions>
</idea-plugin>

View File

@@ -1,28 +1,13 @@
package com.intellij.completion.ml.local.models.frequency
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.codeInsight.completion.ml.local
import com.intellij.completion.ml.local.models.api.LocalModel
import com.intellij.completion.ml.local.models.storage.ClassesFrequencyStorage
import com.intellij.completion.ml.local.util.LocalModelsUtil
import com.intellij.openapi.project.Project
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyModelFactory
import com.intellij.ml.local.models.frequency.classes.ClassesUsagesTracker
import com.intellij.psi.*
import com.intellij.psi.util.PsiTypesUtil
class ClassesFrequencyLocalModel private constructor(private val storage: ClassesFrequencyStorage) : LocalModel {
companion object {
fun create(project: Project): ClassesFrequencyLocalModel {
val storagesPath = LocalModelsUtil.storagePath(project)
val storage = ClassesFrequencyStorage.getStorage(storagesPath)
return ClassesFrequencyLocalModel(storage)
}
}
fun totalClassesCount(): Int = storage.totalClasses
fun totalClassesUsages(): Int = storage.totalClassesUsages
fun getClass(className: String): Int? = storage.get(className)
override fun fileVisitor(): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
class JavaClassesFrequencyModelFactory : ClassesFrequencyModelFactory() {
override fun fileVisitor(usagesTracker: ClassesUsagesTracker): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
override fun visitNewExpression(expression: PsiNewExpression) {
val cls = expression.classReference?.resolve()
@@ -70,8 +55,8 @@ class ClassesFrequencyLocalModel private constructor(private val storage: Classe
}
private fun addClassUsage(cls: PsiClass) {
LocalModelsUtil.getClassName(cls)?.let {
storage.addClassUsage(it)
JavaLocalModelsUtil.getClassName(cls)?.let {
usagesTracker.classUsed(it)
}
}
@@ -79,14 +64,4 @@ class ClassesFrequencyLocalModel private constructor(private val storage: Classe
override fun visitImportStatement(statement: PsiImportStatement) = Unit
override fun visitImportStaticStatement(statement: PsiImportStaticStatement) = Unit
}
override fun onStarted() {
storage.setValid(false)
}
override fun onFinished() {
storage.setValid(true)
}
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
}

View File

@@ -1,18 +1,19 @@
package com.intellij.completion.ml.local.features
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.codeInsight.completion.ml.local
import com.intellij.codeInsight.completion.CompletionParameters
import com.intellij.codeInsight.completion.ml.CompletionEnvironment
import com.intellij.codeInsight.completion.ml.ContextFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.completion.ml.local.models.LocalModelsManager
import com.intellij.completion.ml.local.models.frequency.ClassesFrequencyLocalModel
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencyLocalModel
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencies
import com.intellij.completion.ml.local.util.LocalModelsUtil
import com.intellij.lang.java.JavaLanguage
import com.intellij.ml.local.models.LocalModelsManager
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyLocalModel
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencyLocalModel
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencies
import com.intellij.openapi.util.Key
import com.intellij.psi.*
class FrequencyContextFeaturesProvider : ContextFeatureProvider {
class JavaFrequencyContextFeatureProvider : ContextFeatureProvider {
companion object {
val RECEIVER_CLASS_NAME_KEY: Key<String> = Key.create("ml.completion.local.models.receiver.class.name")
val RECEIVER_CLASS_FREQUENCIES_KEY: Key<MethodsFrequencies> = Key.create("ml.completion.local.models.receiver.class.frequencies")
@@ -24,10 +25,11 @@ class FrequencyContextFeaturesProvider : ContextFeatureProvider {
override fun calculateFeatures(environment: CompletionEnvironment): MutableMap<String, MLFeatureValue> {
val features = mutableMapOf<String, MLFeatureValue>()
val project = environment.parameters.position.project
val methodsModel = LocalModelsManager.getInstance(project).getModel<MethodsFrequencyLocalModel>()
val modelsManager = LocalModelsManager.getInstance(project)
val methodsModel = modelsManager.getModel<MethodsFrequencyLocalModel>(JavaLanguage.INSTANCE)
if (methodsModel != null && methodsModel.readyToUse()) {
getReceiverClass(environment.parameters)?.let { cls ->
LocalModelsUtil.getClassName(cls)?.let {
JavaLocalModelsUtil.getClassName(cls)?.let {
environment.putUserData(RECEIVER_CLASS_NAME_KEY, it)
methodsModel.getMethodsByClass(it)?.let { frequencies ->
environment.putUserData(RECEIVER_CLASS_FREQUENCIES_KEY, frequencies)
@@ -37,7 +39,7 @@ class FrequencyContextFeaturesProvider : ContextFeatureProvider {
features["total_methods"] = MLFeatureValue.numerical(methodsModel.totalMethodsCount())
features["total_methods_usages"] = MLFeatureValue.numerical(methodsModel.totalMethodsUsages())
}
val classesModel = LocalModelsManager.getInstance(project).getModel<ClassesFrequencyLocalModel>()
val classesModel = modelsManager.getModel<ClassesFrequencyLocalModel>(JavaLanguage.INSTANCE)
if (classesModel != null && classesModel.readyToUse()) {
features["total_classes"] = MLFeatureValue.numerical(classesModel.totalClassesCount())
features["total_classes_usages"] = MLFeatureValue.numerical(classesModel.totalClassesUsages())

View File

@@ -0,0 +1,52 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.codeInsight.completion.ml.local
import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.completion.ml.ContextFeatures
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.lang.java.JavaLanguage
import com.intellij.ml.local.models.LocalModelsManager
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyLocalModel
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiMethod
class JavaFrequencyElementFeatureProvider : ElementFeatureProvider {
override fun getName(): String = "local"
override fun calculateFeatures(element: LookupElement,
location: CompletionLocation,
contextFeatures: ContextFeatures): MutableMap<String, MLFeatureValue> {
val features = mutableMapOf<String, MLFeatureValue>()
val psi = element.psiElement
val receiverClassName = contextFeatures.getUserData(JavaFrequencyContextFeatureProvider.RECEIVER_CLASS_NAME_KEY)
val classFrequencies = contextFeatures.getUserData(JavaFrequencyContextFeatureProvider.RECEIVER_CLASS_FREQUENCIES_KEY)
if (psi is PsiMethod && receiverClassName != null && classFrequencies != null) {
psi.containingClass?.let { cls ->
JavaLocalModelsUtil.getClassName(cls)?.let { className ->
if (receiverClassName == className) {
JavaLocalModelsUtil.getMethodName(psi)?.let { methodName ->
val frequency = classFrequencies.getMethodFrequency(methodName)
if (frequency > 0) {
val totalUsages = classFrequencies.getTotalFrequency()
features["absolute_method_frequency"] = MLFeatureValue.numerical(frequency)
features["relative_method_frequency"] = MLFeatureValue.numerical(frequency.toDouble() / totalUsages)
}
}
}
}
}
}
val classesModel = LocalModelsManager.getInstance(location.project).getModel<ClassesFrequencyLocalModel>(JavaLanguage.INSTANCE)
if (psi is PsiClass && classesModel != null && classesModel.readyToUse()) {
JavaLocalModelsUtil.getClassName(psi)?.let { className ->
classesModel.getClass(className)?.let {
features["absolute_class_frequency"] = MLFeatureValue.numerical(it)
features["relative_class_frequency"] = MLFeatureValue.numerical(it.toDouble() / classesModel.totalClassesUsages())
}
}
}
return features
}
}

View File

@@ -0,0 +1,12 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.codeInsight.completion.ml.local
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiMethod
internal object JavaLocalModelsUtil {
fun getMethodName(method: PsiMethod): String? = method.presentation?.presentableText
fun getClassName(cls: PsiClass): String? = cls.qualifiedName
}

View File

@@ -0,0 +1,28 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.codeInsight.completion.ml.local
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencyModelFactory
import com.intellij.ml.local.models.frequency.methods.MethodsUsagesTracker
import com.intellij.psi.*
class JavaMethodsFrequencyModelFactory : MethodsFrequencyModelFactory() {
override fun fileVisitor(usagesTracker: MethodsUsagesTracker): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
override fun visitMethodCallExpression(expression: PsiMethodCallExpression) {
expression.resolveMethod()?.let { method ->
JavaLocalModelsUtil.getMethodName(method)?.let { methodName ->
method.containingClass?.let { cls ->
JavaLocalModelsUtil.getClassName(cls)?.let { clsName ->
usagesTracker.methodUsed(clsName, methodName)
}
}
}
}
super.visitMethodCallExpression(expression)
}
override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) = Unit
override fun visitImportStatement(statement: PsiImportStatement) = Unit
override fun visitImportStaticStatement(statement: PsiImportStaticStatement) = Unit
}
}

View File

@@ -1,26 +0,0 @@
<idea-plugin>
<id>com.intellij.completion.ml.ranking.local</id>
<name>Machine Learning Code Completion Local Models</name>
<vendor>JetBrains</vendor>
<category>Other Tools</category>
<description><![CDATA[
<p>The plugin contains logic for training local models for code completion based on machine learning.</p>
]]></description>
<actions>
<action id="TrainLocalModelsAction" class="com.intellij.completion.ml.local.actions.TrainLocalModelsAction"/>
</actions>
<resource-bundle>messages.CompletionRankingLocalBundle</resource-bundle>
<extensions defaultExtensionNs="com.intellij">
<projectService serviceImplementation="com.intellij.completion.ml.local.models.LocalModelsManager"/>
<completion.ml.contextFeatures language="JAVA" implementationClass="com.intellij.completion.ml.local.features.FrequencyContextFeaturesProvider"/>
<completion.ml.elementFeatures language="JAVA" implementationClass="com.intellij.completion.ml.local.features.FrequencyElementFeatureProvider"/>
</extensions>
<depends>com.intellij.modules.java</depends>
<depends>com.intellij.completion.ml.ranking</depends>
</idea-plugin>

View File

@@ -1,3 +0,0 @@
ml.completion.local.models.training.title=Training local ML completion models
ml.completion.local.models.training.action=Train Local ML Completion Models
ml.completion.local.models.training.files.processing=Processing source code files

View File

@@ -1,21 +0,0 @@
package com.intellij.completion.ml.local;
import com.intellij.AbstractBundle;
import com.intellij.DynamicBundle;
import org.jetbrains.annotations.Nls;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.PropertyKey;
public class CompletionRankingLocalBundle extends DynamicBundle {
private static final String COMPLETION_RANKING_LOCAL_BUNDLE = "messages.CompletionRankingLocalBundle";
public static @Nls String message(@NotNull @PropertyKey(resourceBundle = COMPLETION_RANKING_LOCAL_BUNDLE) String key, Object @NotNull ... params) {
return ourInstance.getMessage(key, params);
}
private static final AbstractBundle ourInstance = new CompletionRankingLocalBundle();
protected CompletionRankingLocalBundle() {
super(COMPLETION_RANKING_LOCAL_BUNDLE);
}
}

View File

@@ -1,18 +0,0 @@
package com.intellij.completion.ml.local.actions
import com.intellij.completion.ml.local.CompletionRankingLocalBundle
import com.intellij.completion.ml.local.models.LocalModelsTraining
import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
class TrainLocalModelsAction : AnAction(CompletionRankingLocalBundle.message("ml.completion.local.models.training.action")) {
override fun actionPerformed(e: AnActionEvent) {
val project = e.project ?: return
if (LocalModelsTraining.isTraining()) {
//TODO: Show message that model is training right now
return
}
LocalModelsTraining.train(project)
//TODO: Show message that model trained successfully
}
}

View File

@@ -1,49 +0,0 @@
package com.intellij.completion.ml.local.features
import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.completion.ml.ContextFeatures
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.completion.ml.local.models.LocalModelsManager
import com.intellij.completion.ml.local.models.frequency.ClassesFrequencyLocalModel
import com.intellij.completion.ml.local.util.LocalModelsUtil
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiMethod
class FrequencyElementFeatureProvider : ElementFeatureProvider {
override fun getName(): String = "local"
override fun calculateFeatures(element: LookupElement,
location: CompletionLocation,
contextFeatures: ContextFeatures): MutableMap<String, MLFeatureValue> {
val features = mutableMapOf<String, MLFeatureValue>()
val psi = element.psiElement
val receiverClassName = contextFeatures.getUserData(FrequencyContextFeaturesProvider.RECEIVER_CLASS_NAME_KEY)
val classFrequencies = contextFeatures.getUserData(FrequencyContextFeaturesProvider.RECEIVER_CLASS_FREQUENCIES_KEY)
if (psi is PsiMethod && receiverClassName != null && classFrequencies != null) {
LocalModelsUtil.getClassName(psi.containingClass)?.let { className ->
if (receiverClassName == className) {
LocalModelsUtil.getMethodName(psi)?.let { methodName ->
val frequency = classFrequencies.getMethodFrequency(methodName)
if (frequency > 0) {
val totalUsages = classFrequencies.getTotalFrequency()
features["absolute_method_frequency"] = MLFeatureValue.numerical(frequency)
features["relative_method_frequency"] = MLFeatureValue.numerical(frequency.toDouble() / totalUsages)
}
}
}
}
}
val classesModel = LocalModelsManager.getInstance(location.project).getModel<ClassesFrequencyLocalModel>()
if (psi is PsiClass && classesModel != null && classesModel.readyToUse()) {
LocalModelsUtil.getClassName(psi)?.let { className ->
classesModel.getClass(className)?.let {
features["absolute_class_frequency"] = MLFeatureValue.numerical(it)
features["relative_class_frequency"] = MLFeatureValue.numerical(it.toDouble() / classesModel.totalClassesUsages())
}
}
}
return features
}
}

View File

@@ -1,20 +0,0 @@
package com.intellij.completion.ml.local.models
import com.intellij.completion.ml.local.models.api.LocalModel
import com.intellij.completion.ml.local.models.frequency.ClassesFrequencyLocalModel
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencyLocalModel
import com.intellij.openapi.project.Project
class LocalModelsManager private constructor(private val project: Project) {
companion object {
fun getInstance(project: Project): LocalModelsManager = project.getService(LocalModelsManager::class.java)
}
private val models = mutableMapOf<String, LocalModel>()
fun getModels(): List<LocalModel> = listOf(
models.getOrPut("methods_frequency") { MethodsFrequencyLocalModel.create(project) },
models.getOrPut("classes_frequency") { ClassesFrequencyLocalModel.create(project) }
)
inline fun <reified T : LocalModel> getModel(): T? = getModels().filterIsInstance<T>().firstOrNull()
}

View File

@@ -1,51 +0,0 @@
package com.intellij.completion.ml.local.models.frequency
import com.intellij.completion.ml.local.models.api.LocalModel
import com.intellij.completion.ml.local.models.storage.MethodsFrequencyStorage
import com.intellij.completion.ml.local.util.LocalModelsUtil
import com.intellij.openapi.project.Project
import com.intellij.psi.*
class MethodsFrequencyLocalModel private constructor(private val storage: MethodsFrequencyStorage) : LocalModel {
companion object {
fun create(project: Project): MethodsFrequencyLocalModel {
val storagesPath = LocalModelsUtil.storagePath(project)
val methodsFrequencyStorage = MethodsFrequencyStorage.getStorage(storagesPath)
return MethodsFrequencyLocalModel(methodsFrequencyStorage)
}
}
fun totalMethodsCount(): Int = storage.totalMethods
fun totalMethodsUsages(): Int = storage.totalMethodsUsages
fun getMethodsByClass(className: String): MethodsFrequencies? = storage.get(className)
override fun fileVisitor(): PsiElementVisitor = object : JavaRecursiveElementWalkingVisitor() {
override fun visitMethodCallExpression(expression: PsiMethodCallExpression) {
expression.resolveMethod()?.let { method ->
LocalModelsUtil.getMethodName(method)?.let { methodName ->
LocalModelsUtil.getClassName(method.containingClass)?.let { clsName ->
storage.addMethodUsage(clsName, methodName)
}
}
}
super.visitMethodCallExpression(expression)
}
override fun visitReferenceElement(reference: PsiJavaCodeReferenceElement) = Unit
override fun visitImportStatement(statement: PsiImportStatement) = Unit
override fun visitImportStaticStatement(statement: PsiImportStaticStatement) = Unit
}
override fun onStarted() {
storage.setValid(false)
}
override fun onFinished() {
storage.setValid(true)
}
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
}

View File

@@ -1,15 +0,0 @@
package com.intellij.completion.ml.local.util
import com.intellij.openapi.application.PathManager
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiClass
import com.intellij.psi.PsiMethod
import java.nio.file.Path
object LocalModelsUtil {
fun storagePath(project: Project): Path = PathManager.getIndexRoot().toPath().resolve("completion.ml.local").resolve(project.locationHash)
fun getMethodName(method: PsiMethod): String? = method.presentation?.presentableText
fun getClassName(cls: PsiClass?): String? = cls?.qualifiedName
}

View File

@@ -11,7 +11,6 @@
<orderEntry type="library" name="kotlin-stdlib-jdk8" level="project" />
<orderEntry type="library" name="gson" level="project" />
<orderEntry type="module" module-name="intellij.platform.ide.impl" />
<orderEntry type="module" module-name="intellij.java.impl" />
<orderEntry type="module" module-name="intellij.completionMlRanking" />
<orderEntry type="module" module-name="intellij.platform.indexing" />
</component>
</module>

View File

@@ -0,0 +1,27 @@
<idea-plugin>
<id>com.intellij.ml.local.models</id>
<name>Machine Learning Local Models</name>
<vendor>JetBrains</vendor>
<category>Other Tools</category>
<description><![CDATA[
<p>The plugin contains logic for training local models based on machine learning.</p>
]]></description>
<actions>
<action id="TrainLocalModelsAction" class="com.intellij.ml.local.actions.TrainLocalModelsAction"/>
</actions>
<resource-bundle>messages.MlLocalModelsBundle</resource-bundle>
<extensionPoints>
<extensionPoint name="factory" beanClass="com.intellij.lang.LanguageExtensionPoint" dynamic="true">
<with attribute="implementationClass" implements="com.intellij.ml.local.models.api.LocalModelFactory"/>
</extensionPoint>
</extensionPoints>
<extensions defaultExtensionNs="com.intellij">
<projectService serviceImplementation="com.intellij.ml.local.models.LocalModelsManager"/>
<postStartupActivity implementation="com.intellij.ml.local.models.LocalModelsStartupActivity"/>
</extensions>
</idea-plugin>

View File

@@ -0,0 +1,3 @@
ml.local.models.training.title=Training local ML models
ml.local.models.training.action=Train Local ML Models
ml.local.models.training.files.processing=Processing source code files

View File

@@ -0,0 +1,21 @@
package com.intellij.ml.local;
import com.intellij.AbstractBundle;
import com.intellij.DynamicBundle;
import org.jetbrains.annotations.Nls;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.PropertyKey;
public class MlLocalModelsBundle extends DynamicBundle {
private static final String ML_LOCAL_MODELS_BUNDLE = "messages.MlLocalModelsBundle";
public static @Nls String message(@NotNull @PropertyKey(resourceBundle = ML_LOCAL_MODELS_BUNDLE) String key, Object @NotNull ... params) {
return ourInstance.getMessage(key, params);
}
private static final AbstractBundle ourInstance = new MlLocalModelsBundle();
protected MlLocalModelsBundle() {
super(ML_LOCAL_MODELS_BUNDLE);
}
}

View File

@@ -0,0 +1,20 @@
package com.intellij.ml.local.actions
import com.intellij.ml.local.MlLocalModelsBundle
import com.intellij.ml.local.models.LocalModelsTraining
import com.intellij.lang.Language
import com.intellij.openapi.actionSystem.AnAction
import com.intellij.openapi.actionSystem.AnActionEvent
class TrainLocalModelsAction : AnAction(MlLocalModelsBundle.message("ml.local.models.training.action")) {
override fun actionPerformed(e: AnActionEvent) {
val project = e.project ?: return
if (LocalModelsTraining.isTraining()) {
//TODO: Show message that model is training right now
return
}
//TODO: Dialog for different languages
LocalModelsTraining.train(project, Language.findLanguageByID("JAVA")!!)
//TODO: Show message that model trained successfully
}
}

View File

@@ -0,0 +1,20 @@
package com.intellij.ml.local.models
import com.intellij.ml.local.models.api.LocalModel
import com.intellij.lang.Language
import com.intellij.openapi.project.Project
class LocalModelsManager private constructor(private val project: Project) {
companion object {
fun getInstance(project: Project): LocalModelsManager = project.getService(LocalModelsManager::class.java)
}
private val models = mutableMapOf<String, MutableList<LocalModel>>()
fun getModels(language: Language): List<LocalModel> = models.getOrDefault(language.id, emptyList())
fun registerModel(language: Language, model: LocalModel) {
models.getOrPut(language.id, { mutableListOf() }).add(model)
}
inline fun <reified T : LocalModel> getModel(language: Language): T? = getModels(language).filterIsInstance<T>().firstOrNull()
}

View File

@@ -0,0 +1,20 @@
package com.intellij.ml.local.models
import com.intellij.lang.Language
import com.intellij.ml.local.models.api.LocalModelFactory
import com.intellij.openapi.project.Project
import com.intellij.openapi.startup.StartupActivity
class LocalModelsStartupActivity : StartupActivity {
override fun runActivity(project: Project) {
val modelsManager = LocalModelsManager.getInstance(project)
for (language in Language.getRegisteredLanguages()) {
val factories = LocalModelFactory.forLanguage(language)
for (factory in factories) {
factory.modelBuilder(project, language).build()?.let { model ->
modelsManager.registerModel(language, model)
}
}
}
}
}

View File

@@ -1,10 +1,12 @@
package com.intellij.completion.ml.local.models
package com.intellij.ml.local.models
import com.intellij.completion.ml.local.CompletionRankingLocalBundle
import com.intellij.completion.ml.local.models.api.LocalModel
import com.intellij.ide.highlighter.JavaFileType
import com.intellij.lang.Language
import com.intellij.ml.local.MlLocalModelsBundle
import com.intellij.ml.local.models.api.LocalModelBuilder
import com.intellij.ml.local.models.api.LocalModelFactory
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.application.runReadAction
import com.intellij.openapi.fileTypes.FileType
import com.intellij.openapi.progress.ProgressIndicator
import com.intellij.openapi.progress.ProgressManager
import com.intellij.openapi.progress.Task
@@ -16,6 +18,7 @@ import com.intellij.psi.PsiManager
import com.intellij.psi.search.FileTypeIndex
import com.intellij.psi.search.GlobalSearchScope
import com.intellij.util.indexing.FileBasedIndex
import java.lang.IllegalArgumentException
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicInteger
@@ -26,16 +29,23 @@ object LocalModelsTraining {
fun isTraining(): Boolean = isTraining
fun train(project: Project) = ApplicationManager.getApplication().executeOnPooledThread {
fun train(project: Project, language: Language) = ApplicationManager.getApplication().executeOnPooledThread {
val fileType = language.associatedFileType ?: throw IllegalArgumentException("Unsupported language")
isTraining = true
val modelsManager = LocalModelsManager.getInstance(project)
val task = object : Task.Backgroundable(project, CompletionRankingLocalBundle.message("ml.completion.local.models.training.title"), true) {
val task = object : Task.Backgroundable(project, MlLocalModelsBundle.message("ml.local.models.training.title"), true) {
override fun run(indicator: ProgressIndicator) {
val files = getFiles(project)
val models = modelsManager.getModels()
models.forEach { it.onStarted() }
processFiles(files, models, project, indicator)
models.forEach { it.onFinished() }
val files = getFiles(project, fileType)
val factories = LocalModelFactory.forLanguage(language)
val builders = factories.map { it.modelBuilder(project, language) }
builders.forEach { it.onStarted() }
processFiles(files, builders, project, indicator)
builders.forEach { it.onFinished() }
builders.forEach {
it.build()?.let {
modelsManager.registerModel(language, it)
}
}
}
override fun onFinished() {
@@ -45,12 +55,12 @@ object LocalModelsTraining {
ProgressManager.getInstance().runProcessWithProgressAsynchronously(task, BackgroundableProcessIndicator(task))
}
private fun processFiles(files: List<VirtualFile>, models: Iterable<LocalModel>, project: Project, indicator: ProgressIndicator) {
private fun processFiles(files: List<VirtualFile>, modelBuilders: List<LocalModelBuilder>, project: Project, indicator: ProgressIndicator) {
val dumbService = DumbService.getInstance(project)
val psiManager = PsiManager.getInstance(project)
val executorService = Executors.newFixedThreadPool((Runtime.getRuntime().availableProcessors() - 1).coerceAtLeast(1))
indicator.isIndeterminate = false
indicator.text = CompletionRankingLocalBundle.message("ml.completion.local.models.training.files.processing")
indicator.text = MlLocalModelsBundle.message("ml.local.models.training.files.processing")
indicator.fraction = 0.0
val processed = AtomicInteger(0)
@@ -58,8 +68,8 @@ object LocalModelsTraining {
executorService.submit {
dumbService.runReadActionInSmartMode {
psiManager.findFile(file)?.let { psiFile ->
for (model in models) {
psiFile.accept(model.fileVisitor())
for (model in modelBuilders.map { it.fileVisitor() }) {
psiFile.accept(model)
}
}
indicator.fraction = processed.incrementAndGet().toDouble() / files.size
@@ -81,9 +91,9 @@ object LocalModelsTraining {
}
}
fun getFiles(project: Project): List<VirtualFile> = runReadAction {
fun getFiles(project: Project, fileType: FileType): List<VirtualFile> = runReadAction {
FileBasedIndex.getInstance().getContainingFiles(FileTypeIndex.NAME,
JavaFileType.INSTANCE,
fileType,
GlobalSearchScope.projectScope(project)).toList()
}
}

View File

@@ -0,0 +1,6 @@
package com.intellij.ml.local.models.api
interface LocalModel {
val id: String
fun readyToUse(): Boolean
}

View File

@@ -1,10 +1,10 @@
package com.intellij.completion.ml.local.models.api
package com.intellij.ml.local.models.api
import com.intellij.psi.PsiElementVisitor
interface LocalModel {
fun fileVisitor(): PsiElementVisitor
interface LocalModelBuilder {
fun onStarted()
fun onFinished()
fun readyToUse(): Boolean
fun fileVisitor(): PsiElementVisitor
fun build(): LocalModel?
}

View File

@@ -0,0 +1,17 @@
package com.intellij.ml.local.models.api
import com.intellij.lang.Language
import com.intellij.lang.LanguageExtension
import com.intellij.openapi.project.Project
interface LocalModelFactory {
companion object {
private val EP_NAME = LanguageExtension<LocalModelFactory>("com.intellij.ml.local.models.factory")
fun forLanguage(language: Language): List<LocalModelFactory> {
return EP_NAME.allForLanguage(language)
}
}
fun modelBuilder(project: Project, language: Language): LocalModelBuilder
}

View File

@@ -1,4 +1,4 @@
package com.intellij.completion.ml.local.models.storage
package com.intellij.ml.local.models.api
interface LocalModelStorage {
fun version(): Int

View File

@@ -0,0 +1,16 @@
package com.intellij.ml.local.models.frequency
import com.intellij.ml.local.models.api.LocalModelBuilder
import com.intellij.lang.Language
import com.intellij.ml.local.models.api.LocalModelFactory
import com.intellij.ml.local.models.frequency.classes.ClassesFrequencyModelFactory
import com.intellij.ml.local.models.frequency.methods.MethodsFrequencyModelFactory
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElementVisitor
abstract class FrequencyModelFactory<UsagesTracker> : LocalModelFactory {
protected abstract fun fileVisitor(usagesTracker: UsagesTracker): PsiElementVisitor
abstract override fun modelBuilder(project: Project, language: Language): LocalModelBuilder
}

View File

@@ -0,0 +1,16 @@
package com.intellij.ml.local.models.frequency.classes
import com.intellij.ml.local.models.api.LocalModel
class ClassesFrequencyLocalModel(private val storage: ClassesFrequencyStorage) : LocalModel {
override val id: String = "classes_frequency"
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
fun totalClassesCount(): Int = storage.totalClasses
fun totalClassesUsages(): Int = storage.totalClassesUsages
fun getClass(className: String): Int? = storage.get(className)
}

View File

@@ -0,0 +1,39 @@
package com.intellij.ml.local.models.frequency.classes
import com.intellij.ml.local.models.api.LocalModelBuilder
import com.intellij.ml.local.models.frequency.FrequencyModelFactory
import com.intellij.ml.local.util.StorageUtil
import com.intellij.lang.Language
import com.intellij.lang.LanguageExtension
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElementVisitor
abstract class ClassesFrequencyModelFactory : FrequencyModelFactory<ClassesUsagesTracker>() {
abstract override fun fileVisitor(usagesTracker: ClassesUsagesTracker): PsiElementVisitor
override fun modelBuilder(project: Project, language: Language): LocalModelBuilder {
val storagesPath = StorageUtil.storagePath(project, language)
val storage = ClassesFrequencyStorage.getStorage(storagesPath)
return object : LocalModelBuilder {
override fun onStarted() {
storage.setValid(false)
}
override fun onFinished() {
storage.setValid(true)
}
override fun fileVisitor(): PsiElementVisitor = fileVisitor(ClassesUsagesTracker(storage))
override fun build(): ClassesFrequencyLocalModel? {
if (!storage.isValid() || storage.isEmpty()) {
return null
}
return ClassesFrequencyLocalModel(storage)
}
}
}
}

View File

@@ -1,7 +1,9 @@
package com.intellij.completion.ml.local.models.storage
package com.intellij.ml.local.models.frequency.classes
import com.intellij.completion.ml.local.models.storage.StorageUtil.clear
import com.intellij.completion.ml.local.models.storage.StorageUtil.isEmpty
import com.intellij.ml.local.models.api.LocalModelStorage
import com.intellij.ml.local.util.StorageUtil
import com.intellij.ml.local.util.StorageUtil.clear
import com.intellij.ml.local.util.StorageUtil.isEmpty
import com.intellij.util.Processor
import com.intellij.util.io.EnumeratorStringDescriptor
import com.intellij.util.io.IntInlineKeyDescriptor

View File

@@ -0,0 +1,7 @@
package com.intellij.ml.local.models.frequency.classes
class ClassesUsagesTracker(private val storage: ClassesFrequencyStorage) {
fun classUsed(className: String) {
storage.addClassUsage(className)
}
}

View File

@@ -1,4 +1,4 @@
package com.intellij.completion.ml.local.models.frequency
package com.intellij.ml.local.models.frequency.methods
class MethodsFrequencies(private var totalFrequency: Int = 0,
private val methods: MutableMap<Int, Int> = mutableMapOf()) {

View File

@@ -0,0 +1,16 @@
package com.intellij.ml.local.models.frequency.methods
import com.intellij.ml.local.models.api.LocalModel
class MethodsFrequencyLocalModel internal constructor(private val storage: MethodsFrequencyStorage) : LocalModel {
override val id: String = "methods_frequency"
override fun readyToUse(): Boolean = storage.isValid() && !storage.isEmpty()
fun totalMethodsCount(): Int = storage.totalMethods
fun totalMethodsUsages(): Int = storage.totalMethodsUsages
fun getMethodsByClass(className: String): MethodsFrequencies? = storage.get(className)
}

View File

@@ -0,0 +1,39 @@
package com.intellij.ml.local.models.frequency.methods
import com.intellij.ml.local.models.api.LocalModel
import com.intellij.ml.local.models.api.LocalModelBuilder
import com.intellij.ml.local.models.frequency.FrequencyModelFactory
import com.intellij.ml.local.util.StorageUtil
import com.intellij.lang.Language
import com.intellij.openapi.project.Project
import com.intellij.psi.PsiElementVisitor
abstract class MethodsFrequencyModelFactory : FrequencyModelFactory<MethodsUsagesTracker>() {
abstract override fun fileVisitor(usagesTracker: MethodsUsagesTracker): PsiElementVisitor
override fun modelBuilder(project: Project, language: Language): LocalModelBuilder {
val storagesPath = StorageUtil.storagePath(project, language)
val storage = MethodsFrequencyStorage.getStorage(storagesPath)
return object : LocalModelBuilder {
override fun onStarted() {
storage.setValid(false)
}
override fun onFinished() {
storage.setValid(true)
}
override fun fileVisitor(): PsiElementVisitor = fileVisitor(MethodsUsagesTracker(storage))
override fun build(): LocalModel? {
if (!storage.isValid() || storage.isEmpty()) {
return null
}
return MethodsFrequencyLocalModel(storage)
}
}
}
}

View File

@@ -1,8 +1,9 @@
package com.intellij.completion.ml.local.models.storage
package com.intellij.ml.local.models.frequency.methods
import com.intellij.completion.ml.local.models.frequency.MethodsFrequencies
import com.intellij.completion.ml.local.models.storage.StorageUtil.clear
import com.intellij.completion.ml.local.models.storage.StorageUtil.isEmpty
import com.intellij.ml.local.models.api.LocalModelStorage
import com.intellij.ml.local.util.StorageUtil
import com.intellij.ml.local.util.StorageUtil.clear
import com.intellij.ml.local.util.StorageUtil.isEmpty
import com.intellij.util.Processor
import com.intellij.util.io.DataExternalizer
import com.intellij.util.io.EnumeratorStringDescriptor

View File

@@ -0,0 +1,7 @@
package com.intellij.ml.local.models.frequency.methods
class MethodsUsagesTracker(private val storage: MethodsFrequencyStorage) {
fun methodUsed(className: String, methodName: String) {
storage.addMethodUsage(className, methodName)
}
}

View File

@@ -1,6 +1,9 @@
package com.intellij.completion.ml.local.models.storage
package com.intellij.ml.local.util
import com.google.gson.Gson
import com.intellij.lang.Language
import com.intellij.openapi.application.PathManager
import com.intellij.openapi.project.Project
import com.intellij.util.io.*
import java.nio.file.Path
import kotlin.io.path.writeText
@@ -17,6 +20,11 @@ object StorageUtil {
return GSON.fromJson(infoFile.readText(), StorageInfo::class.java)
}
fun storagePath(project: Project, language: Language): Path = PathManager.getIndexRoot().toPath()
.resolve("ml.local.models")
.resolve(project.locationHash)
.resolve(language.id)
fun saveInfo(version: Int, isValid: Boolean, storageDirectory: Path) {
val infoFile = storageDirectory.resolve(STORAGE_INFO_FILE)
infoFile.writeText(GSON.toJson(StorageInfo(version, isValid)))