mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-01-08 15:09:39 +07:00
[ml-completion] bump default java model to 0.3.1; drop SortingRestriction; bump experimental java model to 0.3.2;
GitOrigin-RevId: 50fc8e418d359d1720d05174d7feb80617a1a874
This commit is contained in:
committed by
intellij-monorepo-bot
parent
da89a26ce4
commit
ed2710d5e9
@@ -59,13 +59,13 @@
|
||||
<orderEntry type="module" module-name="intellij.platform.diff.impl" />
|
||||
<orderEntry type="module-library">
|
||||
<library name="completion-ranking-java" type="repository">
|
||||
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.2.1" />
|
||||
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.3.1" />
|
||||
<CLASSES>
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.2.1/completion-ranking-java-0.2.1.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1.jar!/" />
|
||||
</CLASSES>
|
||||
<JAVADOC />
|
||||
<SOURCES>
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.2.1/completion-ranking-java-0.2.1-sources.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1-sources.jar!/" />
|
||||
</SOURCES>
|
||||
</library>
|
||||
</orderEntry>
|
||||
|
||||
@@ -197,6 +197,8 @@ class CommunityLibraryLicenses {
|
||||
url: "https://github.com/raphw/byte-buddy", licenseUrl: "http://www.apache.org/licenses/LICENSE-2.0"),
|
||||
new LibraryLicense(name: "caffeine", libraryName: "caffeine", license: "Apache 2.0",
|
||||
url: "https://github.com/ben-manes/caffeine", licenseUrl: "https://github.com/ben-manes/caffeine/blob/master/LICENSE"),
|
||||
new LibraryLicense(name: "CatBoost Model Applier", libraryName: "ai.catboost:catboost-prediction:0.24", license: "Apache 2.0",
|
||||
url: "https://github.com/catboost/catboost", licenseUrl: "https://github.com/catboost/catboost/blob/master/LICENSE"),
|
||||
new LibraryLicense(name: "CGLib", libraryName: "CGLIB", license: "Apache", url: "http://cglib.sourceforge.net/",
|
||||
licenseUrl: "http://www.apache.org/foundation/licence-FAQ.html"),
|
||||
new LibraryLicense(name: "classworlds", libraryName: "Maven", transitiveDependency: true, version: "1.1", license: "codehaus",
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.intellij.internal.ml
|
||||
|
||||
class ResourcesModelMetadataReader(private val metadataHolder: Class<*>, private val featuresDirectory: String): ModelMetadataReader {
|
||||
open class ResourcesModelMetadataReader(protected val metadataHolder: Class<*>, private val featuresDirectory: String): ModelMetadataReader {
|
||||
|
||||
override fun binaryFeatures(): String = resourceContent("binary.json")
|
||||
override fun floatFeatures(): String = resourceContent("float.json")
|
||||
|
||||
@@ -12,13 +12,13 @@
|
||||
<orderEntry type="module" module-name="intellij.completionMlRanking" />
|
||||
<orderEntry type="module-library">
|
||||
<library name="completion-ranking-java-exp" type="repository">
|
||||
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.3.1" />
|
||||
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.3.2" />
|
||||
<CLASSES>
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.2/completion-ranking-java-0.3.2.jar!/" />
|
||||
</CLASSES>
|
||||
<JAVADOC />
|
||||
<SOURCES>
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1-sources.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.2/completion-ranking-java-0.3.2-sources.jar!/" />
|
||||
</SOURCES>
|
||||
</library>
|
||||
</orderEntry>
|
||||
@@ -67,5 +67,19 @@
|
||||
</SOURCES>
|
||||
</library>
|
||||
</orderEntry>
|
||||
<orderEntry type="module-library">
|
||||
<library name="ai.catboost:catboost-prediction:0.24" type="repository">
|
||||
<properties maven-id="ai.catboost:catboost-prediction:0.24" />
|
||||
<CLASSES>
|
||||
<root url="jar://$MAVEN_REPOSITORY$/ai/catboost/catboost-prediction/0.24/catboost-prediction-0.24.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/javax/validation/validation-api/1.1.0.Final/validation-api-1.1.0.Final.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/com/google/code/findbugs/jsr305/3.0.2/jsr305-3.0.2.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar!/" />
|
||||
<root url="jar://$MAVEN_REPOSITORY$/ai/catboost/catboost-common/0.24/catboost-common-0.24.jar!/" />
|
||||
</CLASSES>
|
||||
<JAVADOC />
|
||||
<SOURCES />
|
||||
</library>
|
||||
</orderEntry>
|
||||
</component>
|
||||
</module>
|
||||
@@ -1,21 +1,11 @@
|
||||
package com.jetbrains.completion.ml.ranker
|
||||
|
||||
import com.intellij.completion.ml.ranker.ExperimentModelProvider
|
||||
import com.intellij.internal.ml.DecisionFunction
|
||||
import com.intellij.internal.ml.ModelMetadata
|
||||
import com.intellij.internal.ml.completion.CompletionRankingModelBase
|
||||
import com.intellij.internal.ml.completion.JarCompletionModelProvider
|
||||
import com.intellij.lang.Language
|
||||
import com.jetbrains.completion.ranker.model.java.MLGlassBox
|
||||
import com.jetbrains.completion.ml.ranker.cb.JarCatBoostCompletionModelProvider
|
||||
|
||||
class ExperimentJavaMLRankingProvider: JarCompletionModelProvider(
|
||||
CompletionRankingModelsBundle.message("ml.completion.experiment.model.java"), "java_features"), ExperimentModelProvider {
|
||||
|
||||
override fun createModel(metadata: ModelMetadata): DecisionFunction {
|
||||
return object : CompletionRankingModelBase(metadata) {
|
||||
override fun predict(features: DoubleArray?): Double = MLGlassBox.makePredict(features)
|
||||
}
|
||||
}
|
||||
class ExperimentJavaMLRankingProvider: JarCatBoostCompletionModelProvider(
|
||||
CompletionRankingModelsBundle.message("ml.completion.experiment.model.java"), "java_features", "java_model"), ExperimentModelProvider {
|
||||
|
||||
override fun isLanguageSupported(language: Language): Boolean = language.id.compareTo("Java", ignoreCase = true) == 0
|
||||
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
package com.jetbrains.completion.ml.ranker.cb
|
||||
|
||||
import ai.catboost.CatBoostModel
|
||||
import com.intellij.internal.ml.InconsistentMetadataException
|
||||
import com.intellij.internal.ml.ResourcesModelMetadataReader
|
||||
|
||||
class CatBoostResourcesModelMetadataReader(metadataHolder: Class<*>,
|
||||
featuresDirectory: String,
|
||||
private val modelDirectory: String) : ResourcesModelMetadataReader(metadataHolder, featuresDirectory) {
|
||||
|
||||
fun loadModel(): CatBoostModel {
|
||||
val resource = "$modelDirectory/model.cbm"
|
||||
val fileStream = metadataHolder.classLoader.getResourceAsStream(resource)
|
||||
?: throw InconsistentMetadataException(
|
||||
"Metadata file not found: $resource. Resources holder: ${metadataHolder.name}")
|
||||
return CatBoostModel.loadModel(fileStream)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,66 @@
|
||||
package com.jetbrains.completion.ml.ranker.cb
|
||||
|
||||
import com.intellij.internal.ml.DecisionFunction
|
||||
import com.intellij.internal.ml.FeaturesInfo
|
||||
import com.intellij.internal.ml.InconsistentMetadataException
|
||||
import com.intellij.internal.ml.completion.CompletionRankingModelBase
|
||||
import com.intellij.internal.ml.completion.RankingModelProvider
|
||||
import com.intellij.openapi.diagnostic.logger
|
||||
import org.jetbrains.annotations.Nls
|
||||
import org.jetbrains.annotations.NonNls
|
||||
import org.jetbrains.annotations.TestOnly
|
||||
|
||||
abstract class JarCatBoostCompletionModelProvider(@Nls(capitalization = Nls.Capitalization.Title) private val displayName: String,
|
||||
@NonNls private val resourceDirectory: String,
|
||||
@NonNls private val modelDirectory: String) : RankingModelProvider {
|
||||
private val lazyModel: DecisionFunction by lazy {
|
||||
val metadataReader = CatBoostResourcesModelMetadataReader(this::class.java, resourceDirectory, modelDirectory)
|
||||
val metadata = FeaturesInfo.buildInfo(metadataReader)
|
||||
val model = metadataReader.loadModel()
|
||||
return@lazy object : CompletionRankingModelBase(metadata) {
|
||||
override fun predict(features: DoubleArray): Double {
|
||||
|
||||
val floatArray = FloatArray(features.size)
|
||||
for (i in features.indices) {
|
||||
floatArray[i] = features[i].toFloat()
|
||||
}
|
||||
|
||||
try {
|
||||
return model.predict(floatArray, emptyArray<String>()).get(0, 0)
|
||||
} catch (t: Throwable) {
|
||||
LOG.error(t)
|
||||
return 0.0
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun getModel(): DecisionFunction = lazyModel
|
||||
|
||||
override fun getDisplayNameInSettings(): String = displayName
|
||||
|
||||
@TestOnly
|
||||
fun assertModelMetadataConsistent() {
|
||||
try {
|
||||
val decisionFunction = model
|
||||
decisionFunction.version()
|
||||
|
||||
val unknownRequiredFeatures = decisionFunction.getUnknownFeatures(decisionFunction.requiredFeatures)
|
||||
assert(unknownRequiredFeatures.isEmpty()) { "All required features should be known, but $unknownRequiredFeatures unknown" }
|
||||
|
||||
val featuresOrder = decisionFunction.featuresOrder
|
||||
val unknownUsedFeatures = decisionFunction.getUnknownFeatures(featuresOrder.map { it.featureName }.distinct())
|
||||
assert(unknownUsedFeatures.isEmpty()) { "All used features should be known, but $unknownUsedFeatures unknown" }
|
||||
|
||||
val features = DoubleArray(featuresOrder.size)
|
||||
decisionFunction.predict(features)
|
||||
}
|
||||
catch (e: InconsistentMetadataException) {
|
||||
throw AssertionError("Model metadata inconsistent", e)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val LOG = logger<JarCatBoostCompletionModelProvider>()
|
||||
}
|
||||
}
|
||||
@@ -13,8 +13,6 @@ import com.intellij.completion.ml.util.RelevanceUtil
|
||||
import com.intellij.completion.ml.common.PrefixMatchingUtil
|
||||
import com.intellij.completion.ml.performance.CompletionPerformanceTracker
|
||||
import com.intellij.completion.ml.settings.CompletionMLRankingSettings
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.util.Pair
|
||||
import com.intellij.openapi.util.registry.Registry
|
||||
import com.intellij.openapi.diagnostic.logger
|
||||
@@ -36,7 +34,6 @@ class MLSorter : CompletionFinalSorter() {
|
||||
}
|
||||
|
||||
private val cachedScore: MutableMap<LookupElement, ItemRankInfo> = IdentityHashMap()
|
||||
private lateinit var sortingRestrictions: SortingRestriction
|
||||
private val reorderOnlyTopItems: Boolean = Registry.`is`("completion.ml.reorder.only.top.items", true)
|
||||
|
||||
override fun getRelevanceObjects(items: MutableIterable<LookupElement>): Map<LookupElement, List<Pair<String, Any>>> {
|
||||
@@ -81,10 +78,6 @@ class MLSorter : CompletionFinalSorter() {
|
||||
val queryLength = lookup.queryLength()
|
||||
val prefix = lookup.prefix()
|
||||
|
||||
if (!this::sortingRestrictions.isInitialized) {
|
||||
sortingRestrictions = SortingRestriction.forLanguage(lookupStorage.language, lookupStorage.model?.version() ?: "")
|
||||
}
|
||||
|
||||
val element2score = mutableMapOf<LookupElement, Double?>()
|
||||
val elements = items.toList()
|
||||
|
||||
@@ -157,9 +150,7 @@ class MLSorter : CompletionFinalSorter() {
|
||||
val score = tracker.measure {
|
||||
val position = positionsBefore.getValue(element)
|
||||
val elementFeatures = features.withElementFeatures(relevance, additional)
|
||||
val score = calculateElementScore(rankingModel, element, position, elementFeatures, queryLength)
|
||||
sortingRestrictions.itemScored(elementFeatures)
|
||||
return@measure score
|
||||
return@measure calculateElementScore(rankingModel, element, position, elementFeatures, queryLength)
|
||||
}
|
||||
element2score[element] = score
|
||||
|
||||
@@ -175,7 +166,7 @@ class MLSorter : CompletionFinalSorter() {
|
||||
positionsBefore: Map<LookupElement, Int>,
|
||||
lookupStorage: MutableLookupStorage,
|
||||
lookup: LookupImpl): Iterable<LookupElement> {
|
||||
val mlScoresUsed = element2score.values.none { it == null } && sortingRestrictions.shouldSort()
|
||||
val mlScoresUsed = element2score.values.none { it == null }
|
||||
if (LOG.isDebugEnabled) {
|
||||
LOG.debug("ML sorting in completion used=$mlScoresUsed for language=${lookupStorage.language.id}")
|
||||
}
|
||||
@@ -297,46 +288,6 @@ class MLSorter : CompletionFinalSorter() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface SortingRestriction {
|
||||
companion object {
|
||||
fun forLanguage(language: Language, version: String): SortingRestriction {
|
||||
if (language.id.equals("Java", ignoreCase = true)
|
||||
&& version.equals("0.2.1", ignoreCase = true)
|
||||
&& !ApplicationManager.getApplication().isUnitTestMode) {
|
||||
return SortOnlyWithRecommendersScore()
|
||||
}
|
||||
return SortAll()
|
||||
}
|
||||
}
|
||||
|
||||
fun itemScored(features: RankingFeatures)
|
||||
|
||||
fun shouldSort(): Boolean
|
||||
}
|
||||
|
||||
private class SortAll : SortingRestriction {
|
||||
override fun shouldSort(): Boolean = true
|
||||
override fun itemScored(features: RankingFeatures) {}
|
||||
}
|
||||
|
||||
private class SortOnlyWithRecommendersScore : SortingRestriction {
|
||||
companion object {
|
||||
private val REC_FEATURES_NAMES: List<String> = listOf("ml_rec-instances_probability", "ml_rec-statics2_probability")
|
||||
}
|
||||
|
||||
private var recommendersScoreFound: Boolean = false
|
||||
|
||||
override fun itemScored(features: RankingFeatures) {
|
||||
if (!recommendersScoreFound) {
|
||||
recommendersScoreFound = REC_FEATURES_NAMES.any { features.hasFeature(it) }
|
||||
}
|
||||
}
|
||||
|
||||
override fun shouldSort(): Boolean {
|
||||
return recommendersScoreFound
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private data class ItemRankInfo(val positionBefore: Int, val mlRank: Double?, val prefixLength: Int)
|
||||
|
||||
Reference in New Issue
Block a user