[ml-completion] bump default java model to 0.3.1; drop SortingRestriction; bump experimental java model to 0.3.2;

GitOrigin-RevId: 50fc8e418d359d1720d05174d7feb80617a1a874
This commit is contained in:
Vadim Lomshakov
2020-09-22 14:31:53 +03:00
committed by intellij-monorepo-bot
parent da89a26ce4
commit ed2710d5e9
8 changed files with 112 additions and 71 deletions

View File

@@ -59,13 +59,13 @@
<orderEntry type="module" module-name="intellij.platform.diff.impl" />
<orderEntry type="module-library">
<library name="completion-ranking-java" type="repository">
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.2.1" />
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.3.1" />
<CLASSES>
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.2.1/completion-ranking-java-0.2.1.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1.jar!/" />
</CLASSES>
<JAVADOC />
<SOURCES>
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.2.1/completion-ranking-java-0.2.1-sources.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1-sources.jar!/" />
</SOURCES>
</library>
</orderEntry>

View File

@@ -197,6 +197,8 @@ class CommunityLibraryLicenses {
url: "https://github.com/raphw/byte-buddy", licenseUrl: "http://www.apache.org/licenses/LICENSE-2.0"),
new LibraryLicense(name: "caffeine", libraryName: "caffeine", license: "Apache 2.0",
url: "https://github.com/ben-manes/caffeine", licenseUrl: "https://github.com/ben-manes/caffeine/blob/master/LICENSE"),
new LibraryLicense(name: "CatBoost Model Applier", libraryName: "ai.catboost:catboost-prediction:0.24", license: "Apache 2.0",
url: "https://github.com/catboost/catboost", licenseUrl: "https://github.com/catboost/catboost/blob/master/LICENSE"),
new LibraryLicense(name: "CGLib", libraryName: "CGLIB", license: "Apache", url: "http://cglib.sourceforge.net/",
licenseUrl: "http://www.apache.org/foundation/licence-FAQ.html"),
new LibraryLicense(name: "classworlds", libraryName: "Maven", transitiveDependency: true, version: "1.1", license: "codehaus",

View File

@@ -1,7 +1,7 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.intellij.internal.ml
class ResourcesModelMetadataReader(private val metadataHolder: Class<*>, private val featuresDirectory: String): ModelMetadataReader {
open class ResourcesModelMetadataReader(protected val metadataHolder: Class<*>, private val featuresDirectory: String): ModelMetadataReader {
override fun binaryFeatures(): String = resourceContent("binary.json")
override fun floatFeatures(): String = resourceContent("float.json")

View File

@@ -12,13 +12,13 @@
<orderEntry type="module" module-name="intellij.completionMlRanking" />
<orderEntry type="module-library">
<library name="completion-ranking-java-exp" type="repository">
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.3.1" />
<properties include-transitive-deps="false" maven-id="org.jetbrains.intellij.deps.completion:completion-ranking-java:0.3.2" />
<CLASSES>
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.2/completion-ranking-java-0.3.2.jar!/" />
</CLASSES>
<JAVADOC />
<SOURCES>
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.1/completion-ranking-java-0.3.1-sources.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/org/jetbrains/intellij/deps/completion/completion-ranking-java/0.3.2/completion-ranking-java-0.3.2-sources.jar!/" />
</SOURCES>
</library>
</orderEntry>
@@ -67,5 +67,19 @@
</SOURCES>
</library>
</orderEntry>
<orderEntry type="module-library">
<library name="ai.catboost:catboost-prediction:0.24" type="repository">
<properties maven-id="ai.catboost:catboost-prediction:0.24" />
<CLASSES>
<root url="jar://$MAVEN_REPOSITORY$/ai/catboost/catboost-prediction/0.24/catboost-prediction-0.24.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/javax/validation/validation-api/1.1.0.Final/validation-api-1.1.0.Final.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/com/google/code/findbugs/jsr305/3.0.2/jsr305-3.0.2.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/org/slf4j/slf4j-api/1.7.25/slf4j-api-1.7.25.jar!/" />
<root url="jar://$MAVEN_REPOSITORY$/ai/catboost/catboost-common/0.24/catboost-common-0.24.jar!/" />
</CLASSES>
<JAVADOC />
<SOURCES />
</library>
</orderEntry>
</component>
</module>

View File

@@ -1,21 +1,11 @@
package com.jetbrains.completion.ml.ranker
import com.intellij.completion.ml.ranker.ExperimentModelProvider
import com.intellij.internal.ml.DecisionFunction
import com.intellij.internal.ml.ModelMetadata
import com.intellij.internal.ml.completion.CompletionRankingModelBase
import com.intellij.internal.ml.completion.JarCompletionModelProvider
import com.intellij.lang.Language
import com.jetbrains.completion.ranker.model.java.MLGlassBox
import com.jetbrains.completion.ml.ranker.cb.JarCatBoostCompletionModelProvider
class ExperimentJavaMLRankingProvider: JarCompletionModelProvider(
CompletionRankingModelsBundle.message("ml.completion.experiment.model.java"), "java_features"), ExperimentModelProvider {
override fun createModel(metadata: ModelMetadata): DecisionFunction {
return object : CompletionRankingModelBase(metadata) {
override fun predict(features: DoubleArray?): Double = MLGlassBox.makePredict(features)
}
}
class ExperimentJavaMLRankingProvider: JarCatBoostCompletionModelProvider(
CompletionRankingModelsBundle.message("ml.completion.experiment.model.java"), "java_features", "java_model"), ExperimentModelProvider {
/** Accepts only the language whose id is "Java", compared without regard to letter case. */
override fun isLanguageSupported(language: Language): Boolean = language.id.equals("Java", ignoreCase = true)

View File

@@ -0,0 +1,18 @@
package com.jetbrains.completion.ml.ranker.cb
import ai.catboost.CatBoostModel
import com.intellij.internal.ml.InconsistentMetadataException
import com.intellij.internal.ml.ResourcesModelMetadataReader
/**
 * A [ResourcesModelMetadataReader] that can additionally load a serialized CatBoost model
 * from the resources visible to the class loader of the metadata holder class.
 *
 * @param metadataHolder class whose class loader is used to look up the model resource
 * @param featuresDirectory forwarded unchanged to [ResourcesModelMetadataReader]
 * @param modelDirectory resource directory that is expected to contain `model.cbm`
 */
class CatBoostResourcesModelMetadataReader(metadataHolder: Class<*>,
                                           featuresDirectory: String,
                                           private val modelDirectory: String) : ResourcesModelMetadataReader(metadataHolder, featuresDirectory) {
  /**
   * Loads the model stored at `<modelDirectory>/model.cbm`.
   *
   * @return the model produced by [CatBoostModel.loadModel] from the resource stream
   * @throws InconsistentMetadataException if the resource cannot be found by the holder's class loader
   */
  fun loadModel(): CatBoostModel {
    val resource = "$modelDirectory/model.cbm"
    val fileStream = metadataHolder.classLoader.getResourceAsStream(resource)
                     ?: throw InconsistentMetadataException(
                       "Metadata file not found: $resource. Resources holder: ${metadataHolder.name}")
    // The stream is opened here, so it is closed here too - on success as well as when loading throws.
    return fileStream.use { stream -> CatBoostModel.loadModel(stream) }
  }
}

View File

@@ -0,0 +1,66 @@
package com.jetbrains.completion.ml.ranker.cb
import com.intellij.internal.ml.DecisionFunction
import com.intellij.internal.ml.FeaturesInfo
import com.intellij.internal.ml.InconsistentMetadataException
import com.intellij.internal.ml.completion.CompletionRankingModelBase
import com.intellij.internal.ml.completion.RankingModelProvider
import com.intellij.openapi.diagnostic.logger
import org.jetbrains.annotations.Nls
import org.jetbrains.annotations.NonNls
import org.jetbrains.annotations.TestOnly
/**
 * Base [RankingModelProvider] whose decision function delegates scoring to a CatBoost model.
 *
 * The decision function is created on first access of [getModel] through a `lazy` delegate:
 * a [CatBoostResourcesModelMetadataReader] is built with the runtime class of this provider as the
 * resources holder, features metadata is built from it, and then the model itself is loaded.
 *
 * @param displayName name returned by [getDisplayNameInSettings]
 * @param resourceDirectory features directory handed to the metadata reader
 * @param modelDirectory model directory handed to the metadata reader
 */
abstract class JarCatBoostCompletionModelProvider(@Nls(capitalization = Nls.Capitalization.Title) private val displayName: String,
                                                  @NonNls private val resourceDirectory: String,
                                                  @NonNls private val modelDirectory: String) : RankingModelProvider {
  private val lazyModel: DecisionFunction by lazy {
    val reader = CatBoostResourcesModelMetadataReader(this::class.java, resourceDirectory, modelDirectory)
    val featuresInfo = FeaturesInfo.buildInfo(reader)
    val catBoostModel = reader.loadModel()

    object : CompletionRankingModelBase(featuresInfo) {
      /**
       * Narrows [features] to floats and asks the CatBoost model for a prediction, passing an empty
       * string array as the second argument. Any failure of the model call is logged and turned into `0.0`.
       */
      override fun predict(features: DoubleArray): Double {
        val floatFeatures = FloatArray(features.size) { index -> features[index].toFloat() }
        return try {
          catBoostModel.predict(floatFeatures, emptyArray<String>()).get(0, 0)
        }
        catch (t: Throwable) {
          LOG.error(t)
          0.0
        }
      }
    }
  }

  /** Returns the lazily created decision function. */
  override fun getModel(): DecisionFunction = lazyModel

  /** Returns the display name supplied to the constructor. */
  override fun getDisplayNameInSettings(): String = displayName

  /**
   * Test helper: checks that no required or used feature is reported as unknown by the decision
   * function and that a prediction can be computed for an all-zero features array.
   *
   * @throws AssertionError if an [InconsistentMetadataException] is raised along the way
   */
  @TestOnly
  fun assertModelMetadataConsistent() {
    try {
      val function = model
      function.version()

      val missingRequired = function.getUnknownFeatures(function.requiredFeatures)
      assert(missingRequired.isEmpty()) { "All required features should be known, but $missingRequired unknown" }

      val order = function.featuresOrder
      val usedFeatureNames = order.map { it.featureName }.distinct()
      val missingUsed = function.getUnknownFeatures(usedFeatureNames)
      assert(missingUsed.isEmpty()) { "All used features should be known, but $missingUsed unknown" }

      // DoubleArray(n) is zero-filled; this only verifies that prediction runs for the expected array size.
      function.predict(DoubleArray(order.size))
    }
    catch (e: InconsistentMetadataException) {
      throw AssertionError("Model metadata inconsistent", e)
    }
  }

  companion object {
    private val LOG = logger<JarCatBoostCompletionModelProvider>()
  }
}

View File

@@ -13,8 +13,6 @@ import com.intellij.completion.ml.util.RelevanceUtil
import com.intellij.completion.ml.common.PrefixMatchingUtil
import com.intellij.completion.ml.performance.CompletionPerformanceTracker
import com.intellij.completion.ml.settings.CompletionMLRankingSettings
import com.intellij.lang.Language
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.util.Pair
import com.intellij.openapi.util.registry.Registry
import com.intellij.openapi.diagnostic.logger
@@ -36,7 +34,6 @@ class MLSorter : CompletionFinalSorter() {
}
private val cachedScore: MutableMap<LookupElement, ItemRankInfo> = IdentityHashMap()
private lateinit var sortingRestrictions: SortingRestriction
private val reorderOnlyTopItems: Boolean = Registry.`is`("completion.ml.reorder.only.top.items", true)
override fun getRelevanceObjects(items: MutableIterable<LookupElement>): Map<LookupElement, List<Pair<String, Any>>> {
@@ -81,10 +78,6 @@ class MLSorter : CompletionFinalSorter() {
val queryLength = lookup.queryLength()
val prefix = lookup.prefix()
if (!this::sortingRestrictions.isInitialized) {
sortingRestrictions = SortingRestriction.forLanguage(lookupStorage.language, lookupStorage.model?.version() ?: "")
}
val element2score = mutableMapOf<LookupElement, Double?>()
val elements = items.toList()
@@ -157,9 +150,7 @@ class MLSorter : CompletionFinalSorter() {
val score = tracker.measure {
val position = positionsBefore.getValue(element)
val elementFeatures = features.withElementFeatures(relevance, additional)
val score = calculateElementScore(rankingModel, element, position, elementFeatures, queryLength)
sortingRestrictions.itemScored(elementFeatures)
return@measure score
return@measure calculateElementScore(rankingModel, element, position, elementFeatures, queryLength)
}
element2score[element] = score
@@ -175,7 +166,7 @@ class MLSorter : CompletionFinalSorter() {
positionsBefore: Map<LookupElement, Int>,
lookupStorage: MutableLookupStorage,
lookup: LookupImpl): Iterable<LookupElement> {
val mlScoresUsed = element2score.values.none { it == null } && sortingRestrictions.shouldSort()
val mlScoresUsed = element2score.values.none { it == null }
if (LOG.isDebugEnabled) {
LOG.debug("ML sorting in completion used=$mlScoresUsed for language=${lookupStorage.language.id}")
}
@@ -297,46 +288,6 @@ class MLSorter : CompletionFinalSorter() {
}
}
}
interface SortingRestriction {
companion object {
fun forLanguage(language: Language, version: String): SortingRestriction {
if (language.id.equals("Java", ignoreCase = true)
&& version.equals("0.2.1", ignoreCase = true)
&& !ApplicationManager.getApplication().isUnitTestMode) {
return SortOnlyWithRecommendersScore()
}
return SortAll()
}
}
fun itemScored(features: RankingFeatures)
fun shouldSort(): Boolean
}
private class SortAll : SortingRestriction {
override fun shouldSort(): Boolean = true
override fun itemScored(features: RankingFeatures) {}
}
private class SortOnlyWithRecommendersScore : SortingRestriction {
companion object {
private val REC_FEATURES_NAMES: List<String> = listOf("ml_rec-instances_probability", "ml_rec-statics2_probability")
}
private var recommendersScoreFound: Boolean = false
override fun itemScored(features: RankingFeatures) {
if (!recommendersScoreFound) {
recommendersScoreFound = REC_FEATURES_NAMES.any { features.hasFeature(it) }
}
}
override fun shouldSort(): Boolean {
return recommendersScoreFound
}
}
}
/**
 * Per-element ranking record; it is the value type of the sorter's `cachedScore` map keyed by lookup element.
 *
 * NOTE(review): the meanings below are inferred from the property names only - confirm against the code
 * that populates the cache.
 *
 * @property positionBefore presumably the element's position before ML re-ranking
 * @property mlRank presumably the score produced by the ranking model; nullable, so it may be absent
 * @property prefixLength presumably the length of the typed prefix at the time the record was created
 */
private data class ItemRankInfo(val positionBefore: Int, val mlRank: Double?, val prefixLength: Int)