mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-15 02:59:33 +07:00
[ml] ML-2246 Create a prototype of the common ML API
Fix recursion issue while initializing FUS Fix project structure Add ml api initialization logger Add Api.Internal annotations Rename a field for better understanding Add some missing documentation fixup! Add the API's code documentation fixup! Add the API's code documentation fixup! Add the API's code documentation fixup! Add the API's code documentation Add analysis of the not started sessions Add API startup listener & move logger out of approach initialization Implement API listeners Add the API's code documentation Refactor level storage format Remove redundant changes Update demo Fix non logging nested sessions issue Secure fus logger's initialization Access the API only via MLApiPlatform Add model's type & add model analysers Add double quotes to the code-like categorical feature representation Allow logging strings from a particular set In logs, split non-declared features by usage Acknowledge non declared features in logs Add tier to logs' tier id Add environment resolve logic Update performance task declaration Enable testing FUS logs Split MLTask and MLTaskApproach Add informative failure messages Replace feature set requirements with feature selectors Make feature a class Extend demo a little bit Remove unchecked cast suppressing Add type-safe environment initialisation Add type-safe feature declaration Add example lookup analysis Think better about beautiful names Add item tier && attempt to beautify nested sessions creation Update demo Attempt to create another ml api Merge-request: IJ-MR-125165 Merged-by: Gleb Marin <Gleb.Marin@jetbrains.com> GitOrigin-RevId: 2aeed3030ab9a5f43e51783fd95c36cda99a18ab
This commit is contained in:
committed by
intellij-monorepo-bot
parent
d3682b6383
commit
ef4328d797
@@ -3,6 +3,7 @@
|
||||
<component name="NewModuleRootManager" inherit-compiler-output="true">
|
||||
<exclude-output />
|
||||
<content url="file://$MODULE_DIR$">
|
||||
<sourceFolder url="file://$MODULE_DIR$/resources" type="java-resource" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
|
||||
2
platform/ml-api/resources/META-INF/ml-api.xml
Normal file
2
platform/ml-api/resources/META-INF/ml-api.xml
Normal file
@@ -0,0 +1,2 @@
|
||||
<idea-plugin>
|
||||
</idea-plugin>
|
||||
69
platform/ml-api/src/com/intellij/platform/ml/Environment.kt
Normal file
69
platform/ml-api/src/com/intellij/platform/ml/Environment.kt
Normal file
@@ -0,0 +1,69 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Represents an environment that is being assembled to be described by [TierDescriptor]s,
|
||||
* to acquire a new ML Model, or for another reason.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface Environment {
|
||||
/**
|
||||
* The set of tiers, that the environment contains.
|
||||
*/
|
||||
val tiers: Set<Tier<*>>
|
||||
|
||||
/**
|
||||
* @return an instance, that corresponds to the given [tier].
|
||||
* @throws IllegalArgumentException if the [tier] is not present
|
||||
*/
|
||||
fun <T : Any> getInstance(tier: Tier<T>): T
|
||||
|
||||
/**
|
||||
* The set of tier instances that are present in the environment.
|
||||
*/
|
||||
val tierInstances: Set<TierInstance<*>>
|
||||
get() {
|
||||
return tiers.map { this.getTierInstance(it) }.toSet()
|
||||
}
|
||||
|
||||
/**
|
||||
* @return a tier instance wrapped into [TierInstance] class.
|
||||
* @throws IllegalArgumentException if the tier is not present.
|
||||
*/
|
||||
fun <T : Any> getTierInstance(tier: Tier<T>) = TierInstance(tier, this[tier])
|
||||
|
||||
/**
|
||||
* @return if the tier is present in the environment
|
||||
*/
|
||||
operator fun contains(tier: Tier<*>): Boolean = tier in this.tiers
|
||||
|
||||
companion object {
|
||||
/**
|
||||
* @return An environment that contains all tiers of all given environments.
|
||||
* @throws IllegalArgumentException if there is a tier that is present in more than two
|
||||
* environments.
|
||||
*/
|
||||
fun joined(environments: Iterable<Environment>): Environment = TierInstanceStorage.joined(environments)
|
||||
|
||||
fun empty(): Environment = joined(emptySet())
|
||||
|
||||
/**
|
||||
* Builds an environment that contains all the given tiers.
|
||||
* @throws IllegalArgumentException if there is more than one instance of a particular tier.
|
||||
*/
|
||||
fun of(entries: Iterable<TierInstance<*>>): Environment {
|
||||
val storage = TierInstanceStorage()
|
||||
fun <T : Any> putToStorage(tierInstance: TierInstance<T>) {
|
||||
storage[tierInstance.tier] = tierInstance.instance
|
||||
}
|
||||
entries.forEach { putToStorage(it) }
|
||||
return storage
|
||||
}
|
||||
|
||||
fun of(vararg entries: TierInstance<*>): Environment = of(entries.toList())
|
||||
}
|
||||
}
|
||||
|
||||
operator fun <T : Any> Environment.get(tier: Tier<T>): T = getInstance(tier)
|
||||
@@ -0,0 +1,53 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import com.intellij.openapi.extensions.ExtensionPointName
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Provides additional tiers on the top of the main ones, making an "extended"
|
||||
* environment.
|
||||
*
|
||||
* If there are some other tiers that an [EnvironmentExtender] needs to build or access the [extendingTier],
|
||||
* then it could "order" them by defining the [requiredTiers].
|
||||
*
|
||||
* The extender will be called if and only if all the requirements are satisfied.
|
||||
*
|
||||
* Additional tiers will be later described by [TierDescriptor]. This description will be used for the model's
|
||||
* inference and then logged. However, they cannot be analyzed via [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser].
|
||||
* Because they do not make a part of the ML Task, and they could be absent in the session.
|
||||
*
|
||||
* An "extended environment" (i.e., the one that contains main as well as additional tiers) is built
|
||||
* each time when a [TierRequester] performs the desired action. For example, an [EnvironmentExtender.extend] or
|
||||
* [com.intellij.platform.ml.impl.model.MLModel.Provider.provideModel] is called.
|
||||
*
|
||||
* As each extender has some requirements, as it also produces some other tier, we must first determine
|
||||
* the order in which the available extenders will run, or resolve it.
|
||||
* To lean more about that, address [com.intellij.platform.ml.impl.environment.ExtendedEnvironment]'s documentation.
|
||||
*
|
||||
* Additional tiers,
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface EnvironmentExtender<T : Any> : TierRequester {
|
||||
/**
|
||||
* The tier that the extender will be providing.
|
||||
*/
|
||||
val extendingTier: Tier<T>
|
||||
|
||||
/**
|
||||
* Provides an instance of the [extendingTier] based on the [environment].
|
||||
*
|
||||
* @param environment includes tiers requested in [requiredTiers]
|
||||
*/
|
||||
fun extend(environment: Environment): T?
|
||||
|
||||
companion object {
|
||||
val EP_NAME: ExtensionPointName<EnvironmentExtender<*>> = ExtensionPointName("com.intellij.platform.ml.environmentExtender")
|
||||
|
||||
fun <T : Any> EnvironmentExtender<T>.extendTierInstance(environment: Environment): TierInstance<T>? {
|
||||
return extend(environment)?.let {
|
||||
this.extendingTier with it
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
90
platform/ml-api/src/com/intellij/platform/ml/Feature.kt
Normal file
90
platform/ml-api/src/com/intellij/platform/ml/Feature.kt
Normal file
@@ -0,0 +1,90 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* An instantiated tier's feature. It could be either
|
||||
* description's feature (that was provided by a [TierDescriptor]), or an
|
||||
* analysis feature (that was provided by a [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser]).
|
||||
*
|
||||
* If you need another type of feature, and it is not supported yet, contact the ML API developers.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
sealed class Feature {
|
||||
/**
|
||||
* A statically defined declaration of this feature, that includes all the information, except the value.
|
||||
*/
|
||||
abstract val declaration: FeatureDeclaration<*>
|
||||
|
||||
/**
|
||||
* A computed nullable value
|
||||
*/
|
||||
abstract val value: Any?
|
||||
|
||||
override fun equals(other: Any?): kotlin.Boolean {
|
||||
if (other !is Feature) return false
|
||||
return other.declaration == this.declaration && other.value == this.value
|
||||
}
|
||||
|
||||
sealed class TypedFeature<T>(
|
||||
val name: String,
|
||||
override val value: T,
|
||||
) : Feature() {
|
||||
abstract val valueType: FeatureValueType<T>
|
||||
|
||||
override val declaration: FeatureDeclaration<*>
|
||||
get() = FeatureDeclaration(name, valueType)
|
||||
}
|
||||
|
||||
override fun hashCode(): kotlin.Int {
|
||||
return this.declaration.hashCode().xor(this.value.hashCode())
|
||||
}
|
||||
|
||||
override fun toString(): String {
|
||||
return "Feature{declaration=$declaration, value=$value}"
|
||||
}
|
||||
|
||||
class Enum<T : kotlin.Enum<*>>(name: String, value: T) : TypedFeature<T>(name, value) {
|
||||
override val valueType = FeatureValueType.Enum(value.javaClass)
|
||||
}
|
||||
|
||||
class Int(name: String, value: kotlin.Int) : TypedFeature<kotlin.Int>(name, value) {
|
||||
override val valueType = FeatureValueType.Int
|
||||
}
|
||||
|
||||
class Boolean(name: String, value: kotlin.Boolean) : TypedFeature<kotlin.Boolean>(name, value) {
|
||||
override val valueType = FeatureValueType.Boolean
|
||||
}
|
||||
|
||||
class Float(name: String, value: kotlin.Float) : TypedFeature<kotlin.Float>(name, value) {
|
||||
override val valueType = FeatureValueType.Float
|
||||
}
|
||||
|
||||
class Double(name: String, value: kotlin.Double) : TypedFeature<kotlin.Double>(name, value) {
|
||||
override val valueType = FeatureValueType.Double
|
||||
}
|
||||
|
||||
class Long(name: String, value: kotlin.Long) : TypedFeature<kotlin.Long>(name, value) {
|
||||
override val valueType = FeatureValueType.Long
|
||||
}
|
||||
|
||||
class Class(name: String, value: java.lang.Class<*>) : TypedFeature<java.lang.Class<*>>(name, value) {
|
||||
override val valueType = FeatureValueType.Class
|
||||
}
|
||||
|
||||
class Nullable<T>(name: String, value: T?, val baseType: FeatureValueType<T>)
|
||||
: TypedFeature<T?>(name, value) {
|
||||
override val valueType = FeatureValueType.Nullable(baseType)
|
||||
}
|
||||
|
||||
class Categorical(name: String, value: String, possibleValues: Set<String>)
|
||||
: TypedFeature<String>(name, value) {
|
||||
override val valueType = FeatureValueType.Categorical(possibleValues)
|
||||
}
|
||||
|
||||
class Version(name: String, value: com.intellij.openapi.util.Version)
|
||||
: TypedFeature<com.intellij.openapi.util.Version>(name, value) {
|
||||
override val valueType = FeatureValueType.Version
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Represents declaration of a Tier's feature.
|
||||
*
|
||||
* All tiers that end up in the ML Model's inference must be declared statically
|
||||
* (i.e. via Extension Points), as FUS logs' validator must be aware of every
|
||||
* feature that will be logged.
|
||||
*
|
||||
* @param name The feature's name that is unique for this tier. It may not contain special symbols.
|
||||
* @param type The feature's type.
|
||||
* @param T The type of the value, with which the [type] can be instantiated ([FeatureValueType.instantiate]).
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class FeatureDeclaration<T>(
|
||||
val name: String,
|
||||
val type: FeatureValueType<T>
|
||||
) {
|
||||
init {
|
||||
require(name.all { it.isLetterOrDigit() || it == '_' }) {
|
||||
"Invalid feature name '$name': it shall not contain special symbols"
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Shortcut for the feature's instantiation
|
||||
*/
|
||||
infix fun with(value: T): Feature {
|
||||
return type.instantiate(name, value)
|
||||
}
|
||||
|
||||
/**
|
||||
* Signifies that the feature can be instantiated with null values.
|
||||
*/
|
||||
fun nullable(): FeatureDeclaration<T?> {
|
||||
require(type !is FeatureValueType.Nullable<*>) { "Repeated declaration as 'nullable'" }
|
||||
return FeatureDeclaration(name, FeatureValueType.Nullable(type))
|
||||
}
|
||||
|
||||
companion object {
|
||||
inline fun <reified T : Enum<*>> enum(name: String) = FeatureDeclaration(name, FeatureValueType.Enum(T::class.java))
|
||||
|
||||
fun int(name: String) = FeatureDeclaration(name, FeatureValueType.Int)
|
||||
|
||||
fun double(name: String) = FeatureDeclaration(name, FeatureValueType.Double)
|
||||
|
||||
fun float(name: String) = FeatureDeclaration(name, FeatureValueType.Float)
|
||||
|
||||
fun long(name: String) = FeatureDeclaration(name, FeatureValueType.Long)
|
||||
|
||||
fun aClass(name: String) = FeatureDeclaration(name, FeatureValueType.Class)
|
||||
|
||||
fun boolean(name: String) = FeatureDeclaration(name, FeatureValueType.Boolean)
|
||||
|
||||
fun categorical(name: String, possibleValues: Set<String>) = FeatureDeclaration(name, FeatureValueType.Categorical(possibleValues))
|
||||
|
||||
fun version(name: String) = FeatureDeclaration(name, FeatureValueType.Version)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,18 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A filter that indicates whether a feature matches a feature set or not.
|
||||
* A feature filter functions within a particular tier.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
fun interface FeatureFilter {
|
||||
fun accept(featureDeclaration: FeatureDeclaration<*>): Boolean
|
||||
|
||||
companion object {
|
||||
val REJECT_ALL = FeatureFilter { false }
|
||||
val ACCEPT_ALL = FeatureFilter { true }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,83 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A type of [Feature].
|
||||
*
|
||||
* If you need to use another type of feature to use in your ML model or in analysis,
|
||||
* consider contacting the ML API developers.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
sealed class FeatureValueType<T> {
|
||||
abstract fun instantiate(name: String, value: T): Feature
|
||||
|
||||
data class Nullable<T>(val baseType: FeatureValueType<T>) : FeatureValueType<T?>() {
|
||||
override fun instantiate(name: String, value: T?): Feature {
|
||||
return Feature.Nullable(name, value, baseType)
|
||||
}
|
||||
}
|
||||
|
||||
data class Enum<T : kotlin.Enum<*>>(val enumClass: java.lang.Class<T>) : FeatureValueType<T>() {
|
||||
override fun instantiate(name: String, value: T): Feature {
|
||||
return Feature.Enum(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
data class Categorical(val possibleValues: Set<String>) : FeatureValueType<String>() {
|
||||
override fun instantiate(name: String, value: String): Feature {
|
||||
require(value in possibleValues) {
|
||||
val caseNonMatchingValue = possibleValues.find { it.equals(name, ignoreCase = true) }
|
||||
"Feature $name cannot be assigned to value $value," +
|
||||
"all possible values are $possibleValues. " +
|
||||
"Possible match (but case does not match): $caseNonMatchingValue"
|
||||
}
|
||||
return Feature.Categorical(name, value, possibleValues)
|
||||
}
|
||||
}
|
||||
|
||||
object Int : FeatureValueType<kotlin.Int>() {
|
||||
override fun instantiate(name: String, value: kotlin.Int): Feature {
|
||||
return Feature.Int(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
object Double : FeatureValueType<kotlin.Double>() {
|
||||
override fun instantiate(name: String, value: kotlin.Double): Feature {
|
||||
return Feature.Double(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
object Float : FeatureValueType<kotlin.Float>() {
|
||||
override fun instantiate(name: String, value: kotlin.Float): Feature {
|
||||
return Feature.Float(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
object Long : FeatureValueType<kotlin.Long>() {
|
||||
override fun instantiate(name: String, value: kotlin.Long): Feature {
|
||||
return Feature.Long(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
object Class : FeatureValueType<java.lang.Class<*>>() {
|
||||
override fun instantiate(name: String, value: java.lang.Class<*>): Feature {
|
||||
return Feature.Class(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
object Boolean : FeatureValueType<kotlin.Boolean>() {
|
||||
override fun instantiate(name: String, value: kotlin.Boolean): Feature {
|
||||
return Feature.Boolean(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
object Version : FeatureValueType<com.intellij.openapi.util.Version>() {
|
||||
override fun instantiate(name: String, value: com.intellij.openapi.util.Version): Feature {
|
||||
return Feature.Version(name, value)
|
||||
}
|
||||
}
|
||||
|
||||
override fun toString(): String = this.javaClass.simpleName
|
||||
}
|
||||
@@ -0,0 +1,22 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A mutable version of the [Environment], that can be extended.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface MutableEnvironment : Environment {
|
||||
/**
|
||||
* Adds new tier instance to the existing ones.
|
||||
* @throws IllegalArgumentException if there is already an instance of such [tier] in the environment.
|
||||
*/
|
||||
fun <T : Any> putTierInstance(tier: Tier<T>, instance: T)
|
||||
}
|
||||
|
||||
operator fun <T : Any> MutableEnvironment.set(tier: Tier<T>, instance: T) = putTierInstance(tier, instance)
|
||||
|
||||
fun <T : Any> MutableEnvironment.putTierInstance(tierInstance: TierInstance<T>) {
|
||||
this[tierInstance.tier] = tierInstance.instance
|
||||
}
|
||||
@@ -0,0 +1,38 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A [TierDescriptor] that is not fully aware of the features it describes with.
|
||||
* It is used to make a smooth transition from the old forms of features providers to the new API.
|
||||
*
|
||||
* It will be removed eventually.
|
||||
*/
|
||||
@Deprecated(message = "ObsoleteTierDescriptors' features are not logged, as they are missing a feature declaration",
|
||||
replaceWith = ReplaceWith("TierDescriptor", "com.intellij.platform.ml"),
|
||||
level = DeprecationLevel.WARNING)
|
||||
@ApiStatus.Internal
|
||||
interface ObsoleteTierDescriptor : TierDescriptor {
|
||||
/**
|
||||
* The case when a [TierDescriptor] is an [ObsoleteTierDescriptor] is handled in the API
|
||||
* individually each time. And in those times, we cannot rely on the [descriptionDeclaration],
|
||||
* because it may not be correct.
|
||||
*/
|
||||
override val descriptionDeclaration: Set<FeatureDeclaration<*>>
|
||||
get() = throw IllegalAccessError("Obsolete descriptor does not provide a description declaration")
|
||||
|
||||
override fun couldBeUseful(usefulFeaturesFilter: FeatureFilter): Boolean {
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* The declaration that is already known.
|
||||
* If there is a feature that is computed but not declared, then they can be logged,
|
||||
* so you can add them to the declaration.
|
||||
*
|
||||
* Turn on the ml.description.logMissing registry key to log the missing features.
|
||||
*/
|
||||
val partialDescriptionDeclaration: Set<FeatureDeclaration<*>>
|
||||
get() = emptySet()
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* An environment, that restricts access to the [baseEnvironment] with the [scope].
|
||||
* It is created before passing an [Environment] to a [TierDescriptor], so it will
|
||||
* not be accessing non-declared tiers.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class ScopeEnvironment private constructor(
|
||||
private val baseEnvironment: Environment,
|
||||
private val scope: Set<Tier<*>>
|
||||
) : Environment {
|
||||
override val tiers: Set<Tier<*>> = scope
|
||||
|
||||
override fun <T : Any> getInstance(tier: Tier<T>): T {
|
||||
require(tier in scope) { "$tier was not supposed to be used, allowed scope: ${scope}" }
|
||||
return baseEnvironment.getInstance(tier)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun Environment.restrictedBy(scope: Set<Tier<*>>) = ScopeEnvironment(this, scope.intersect(this.tiers))
|
||||
|
||||
fun Environment.narrowedTo(scope: Set<Tier<*>>): ScopeEnvironment {
|
||||
require(scope.all { it in this })
|
||||
return ScopeEnvironment(this, scope)
|
||||
}
|
||||
|
||||
fun Environment.accessibleSafelyByOrNull(requester: TierRequester): ScopeEnvironment? {
|
||||
if (requester.requiredTiers.any { it !in this }) {
|
||||
return null
|
||||
}
|
||||
return this.accessibleSafelyBy(requester)
|
||||
}
|
||||
|
||||
fun Environment.accessibleSafelyBy(requester: TierRequester): ScopeEnvironment = this.narrowedTo(requester.requiredTiers)
|
||||
}
|
||||
}
|
||||
100
platform/ml-api/src/com/intellij/platform/ml/Session.kt
Normal file
100
platform/ml-api/src/com/intellij/platform/ml/Session.kt
Normal file
@@ -0,0 +1,100 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import com.intellij.platform.ml.Session.StartOutcome.Failure
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A period of making predictions with an ML model.
|
||||
*
|
||||
* The session's depth corresponds to the number of levels in a [com.intellij.platform.ml.impl.MLTask].
|
||||
*
|
||||
* This is a builder of the tree-like session structure.
|
||||
* To learn how sessions work, please address this interface's implementations.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
sealed interface Session<P : Any> {
|
||||
/**
|
||||
* An outcome of an ML session's start.
|
||||
*
|
||||
* It is considered to be a non-exceptional situation when it is not possible to start an ML session.
|
||||
* If so, then a [Failure] is returned and could be handled by the user.
|
||||
*/
|
||||
sealed interface StartOutcome<P : Any> {
|
||||
/**
|
||||
* A ready-to-use ml session, in case the start was successful
|
||||
*/
|
||||
val session: Session<P>?
|
||||
|
||||
fun requireSuccess(): Session<P> = when (this) {
|
||||
is Failure -> throw this.asThrowable()
|
||||
is Success -> this.session
|
||||
}
|
||||
|
||||
/**
|
||||
* Indicates that nothing went wrong, and the start was successful
|
||||
*/
|
||||
class Success<P : Any>(override val session: Session<P>) : StartOutcome<P>
|
||||
|
||||
/**
|
||||
* Indicates that there was some issue during the start.
|
||||
* The problem could be identified more precisely by looking at
|
||||
* the [Failure]'s class.
|
||||
*/
|
||||
open class Failure<P : Any> : StartOutcome<P> {
|
||||
override val session: Session<P>? = null
|
||||
|
||||
open val failureDetails: String
|
||||
get() = "Unable to start ml session, failure: $this"
|
||||
|
||||
open fun asThrowable(): Throwable {
|
||||
return Exception(failureDetails)
|
||||
}
|
||||
}
|
||||
|
||||
class UncaughtException<P : Any>(val exception: Throwable) : Failure<P>() {
|
||||
override fun asThrowable(): Throwable {
|
||||
return Exception("An unexpected exception", exception)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A session, that holds other sessions.
|
||||
*/
|
||||
interface NestableMLSession<P : Any> : Session<P> {
|
||||
/**
|
||||
* Start another nested session within this one, that will inherit
|
||||
* this session's features.
|
||||
*
|
||||
* @param levelMainEnvironment The main session's environment that contains main ML task's tiers.
|
||||
*
|
||||
* @return Either [NestableMLSession] or [SinglePrediction], depending on whether the
|
||||
* last ML task's level has been reached.
|
||||
*/
|
||||
fun createNestedSession(levelMainEnvironment: Environment): Session<P>
|
||||
|
||||
/**
|
||||
* Declare that no more nested sessions will be created from this moment on.
|
||||
* It must be called.
|
||||
*/
|
||||
fun onLastNestedSessionCreated()
|
||||
}
|
||||
|
||||
/**
|
||||
* A session, that is dedicated to create one prediction at most.
|
||||
*/
|
||||
interface SinglePrediction<P : Any> : Session<P> {
|
||||
/**
|
||||
* Call ML model's inference and produce the prediction.
|
||||
* On one object, exactly one function must be called during its lifetime: this or [cancelPrediction]
|
||||
*/
|
||||
fun predict(): P
|
||||
|
||||
/**
|
||||
* Declare that model's inference will not be called.
|
||||
* It must be called in case it's decided that a prediction is not needed.
|
||||
*/
|
||||
fun cancelPrediction()
|
||||
}
|
||||
76
platform/ml-api/src/com/intellij/platform/ml/Tier.kt
Normal file
76
platform/ml-api/src/com/intellij/platform/ml/Tier.kt
Normal file
@@ -0,0 +1,76 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A category of your application's objects that could be used to run the ML API.
|
||||
*
|
||||
* Tiers' main purpose is to be sources for features production, to pass more information to an ML model and to have the most precise
|
||||
* prediction.
|
||||
* Tiers are described per [TierDescriptor].
|
||||
* There are two categories of tiers when it comes to description:
|
||||
*
|
||||
* - Main tiers:
|
||||
* the ones that are declared in an [com.intellij.platform.ml.impl.MLTask] and provided at the problem's application point,
|
||||
* when [com.intellij.platform.ml.NestableMLSession.createNestedSession] is called.
|
||||
* They exist in each session.
|
||||
*
|
||||
* - Additional tiers: the ones that are declared in [com.intellij.platform.ml.impl.approach.LogDrivenModelInference].
|
||||
* They are provided by various [EnvironmentExtender]s occasionally when an [com.intellij.platform.ml.impl.environment.ExtendedEnvironment]
|
||||
* is crated, and they could be absent, as it was not able to satisfy extenders' requirements, or a tier instance was simply not
|
||||
* present this time.
|
||||
* So we can't rely on their existence.
|
||||
*
|
||||
* On the top of that, tiers could serve as helper objects for other user-defined objects, as [TierDescriptor]s and [EnvironmentExtender]s.
|
||||
* The listed interfaces are [TierRequester]s, which means that they require other tiers for proper functioning.
|
||||
* You could create additional [EnvironmentExtender]s to create new tiers, or to define other ways to instantiate existing tiers, but
|
||||
* from other sources.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
abstract class Tier<T : Any> {
|
||||
/**
|
||||
* A unique name of a tier (among other tiers in your application).
|
||||
* Class name is used by default.
|
||||
*/
|
||||
open val name: String
|
||||
get() = this.javaClass.simpleName
|
||||
|
||||
override fun toString(): String = name
|
||||
|
||||
infix fun with(instance: T) = TierInstance(this, instance)
|
||||
}
|
||||
|
||||
/**
|
||||
* A helper class to type-safely handle pairs of [tier] and the corresponding [instance].
|
||||
*/
|
||||
data class TierInstance<T : Any>(val tier: Tier<T>, val instance: T)
|
||||
|
||||
typealias PerTier<T> = Map<Tier<*>, T>
|
||||
|
||||
typealias PerTierInstance<T> = Map<TierInstance<*>, T>
|
||||
|
||||
fun <T> Iterable<PerTier<T>>.joinByUniqueTier(): PerTier<T> {
|
||||
val joinedPerTier = mutableMapOf<Tier<*>, T>()
|
||||
|
||||
this.forEach { perTierMapping ->
|
||||
perTierMapping.forEach { (tier, value) ->
|
||||
require(tier !in joinedPerTier)
|
||||
joinedPerTier[tier] = value
|
||||
}
|
||||
}
|
||||
|
||||
return joinedPerTier
|
||||
}
|
||||
|
||||
fun <T, CI : Iterable<T>, CO : MutableCollection<T>> Iterable<PerTier<CI>>.mergePerTier(createCollection: () -> CO): PerTier<CO> {
|
||||
val joinedPerTier = mutableMapOf<Tier<*>, CO>()
|
||||
for (perTierMapping in this) {
|
||||
for ((tier, anotherCollection) in perTierMapping) {
|
||||
val existingCollection = joinedPerTier[tier] ?: emptyList()
|
||||
require(anotherCollection.all { it !in existingCollection })
|
||||
joinedPerTier.computeIfAbsent(tier) { createCollection() }.addAll(anotherCollection)
|
||||
}
|
||||
}
|
||||
return joinedPerTier
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import com.intellij.openapi.extensions.ExtensionPointName
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Provides features for a particular [tier].
|
||||
*
|
||||
* It is also a [TierRequester], which implies that [additionallyRequiredTiers] could be defined, in
|
||||
* case additional objects are required to describe the [tier]'s instance.
|
||||
* If so, an attempt will be made to create an [com.intellij.platform.ml.impl.environment.ExtendedEnvironment],
|
||||
* and the extended environment that contains the described [tier] and [additionallyRequiredTiers]
|
||||
* will be passed to the [describe] function.
|
||||
*
|
||||
* @see ObsoleteTierDescriptor if you are looking for an opportunity to smoothly transfer your old feature providers to this API.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface TierDescriptor : TierRequester {
|
||||
/**
|
||||
* The tier, that this descriptor is describing (giving features to).
|
||||
*/
|
||||
val tier: Tier<*>
|
||||
|
||||
/**
|
||||
* All features that could ever be used in the declaration.
|
||||
*
|
||||
* _Important_: Do not change features' names.
|
||||
* Names serve as features identifiers, and they are used to pass to ML models.
|
||||
*/
|
||||
val descriptionDeclaration: Set<FeatureDeclaration<*>>
|
||||
|
||||
/**
|
||||
* Computes [tier]'s features from with the features from [descriptionDeclaration].
|
||||
*
|
||||
* If [additionallyRequiredTiers] could not be satisfied, then this descriptor is not called at all.
|
||||
* Be aware that _each_ feature given in [descriptionDeclaration] must be calculated here.
|
||||
* If it is possible that a feature is not computable withing particular circumstances, then
|
||||
* you could declare your feature as nullable: [FeatureDeclaration.nullable].
|
||||
*
|
||||
* @param environment Contains [tier] and [additionallyRequiredTiers].
|
||||
* @param usefulFeaturesFilter Accepts features, that could make any difference to compute this time.
|
||||
* A feature is considered to be useful if an ML model is aware of the feature,
|
||||
* or it is explicitly said that "ML model is not aware of this feature, but it must be logged"
|
||||
* (@see [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]).
|
||||
*
|
||||
* @throws IllegalArgumentException If a feature is missing: it is accepted by [usefulFeaturesFilter]
|
||||
* and it is declared, but not present in the result.
|
||||
* @throws IllegalArgumentException If a redundant feature was computed, that was not declared.
|
||||
*/
|
||||
fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set<Feature>
|
||||
|
||||
/**
|
||||
* Declares a requirement and ensures, that [describe] will be called if and only if
|
||||
* the [environment] will contain both [tier] and [additionallyRequiredTiers].
|
||||
*/
|
||||
override val requiredTiers: Set<Tier<*>>
|
||||
get() = additionallyRequiredTiers + tier
|
||||
|
||||
/**
|
||||
* Tiers that are required additionally to perform description.
|
||||
*
|
||||
* Such tiers' instances could be created via existing or additionally
|
||||
* created [EnvironmentExtender]s.
|
||||
*/
|
||||
val additionallyRequiredTiers: Set<Tier<*>>
|
||||
get() = emptySet()
|
||||
|
||||
/**
|
||||
* Tells if the descriptor could generate any useful features at all.
|
||||
* @param usefulFeaturesFilter Accepts features, that make sense to calculate at the time of this invocation.
|
||||
* If the filter does not accept a feature, it means that its computation will not make any difference.
|
||||
*/
|
||||
fun couldBeUseful(usefulFeaturesFilter: FeatureFilter): Boolean {
|
||||
return descriptionDeclaration.any { usefulFeaturesFilter.accept(it) }
|
||||
}
|
||||
|
||||
companion object {
|
||||
val EP_NAME: ExtensionPointName<TierDescriptor> = ExtensionPointName.create("com.intellij.platform.ml.descriptor")
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,65 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* An object that contains tier instances, and it could be extended.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class TierInstanceStorage : MutableEnvironment {
|
||||
private val instances: MutableMap<Tier<*>, Any> = mutableMapOf()
|
||||
|
||||
override val tiers: Set<Tier<*>>
|
||||
get() = instances.keys
|
||||
|
||||
override val tierInstances: Set<TierInstance<*>>
|
||||
get() = instances
|
||||
.map { (tier, tierInstance) -> tier withUnsafe tierInstance }
|
||||
.toSet()
|
||||
|
||||
private infix fun <T : Any, P : Any> Tier<T>.withUnsafe(value: P): TierInstance<T> {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return this.with(value as T)
|
||||
}
|
||||
|
||||
override fun <T : Any> getInstance(tier: Tier<T>): T {
|
||||
val tierInstance = instances[tier]
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return requireNotNull(tierInstance) as T
|
||||
}
|
||||
|
||||
override fun <T : Any> putTierInstance(tier: Tier<T>, instance: T) {
|
||||
require(tier !in this) {
|
||||
"Tier $tier is already registered in the storage. Old value: '${instances[tier]}', new value: '$instance'"
|
||||
}
|
||||
instances[tier] = instance
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun copyOf(environment: Environment): TierInstanceStorage {
|
||||
val storage = TierInstanceStorage()
|
||||
fun <T : Any> putInstance(tier: Tier<T>) {
|
||||
storage[tier] = environment[tier]
|
||||
}
|
||||
environment.tiers.forEach { putInstance(it) }
|
||||
return storage
|
||||
}
|
||||
|
||||
fun joined(environments: Iterable<Environment>): Environment {
|
||||
val commonStorage = TierInstanceStorage()
|
||||
|
||||
fun <T : Any> putCapturingType(tier: Tier<T>, environment: Environment) {
|
||||
commonStorage[tier] = environment[tier]
|
||||
}
|
||||
|
||||
environments.forEach { environment ->
|
||||
environment.tiers.forEach { tier ->
|
||||
putCapturingType(tier, environment)
|
||||
}
|
||||
}
|
||||
|
||||
return commonStorage
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,26 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* An interface that represents an object, that requires some additional tiers
|
||||
* for proper functioning.
|
||||
*
|
||||
* If the requirements could not be satisfied, then the main object's functionality will not be run.
|
||||
* Please address the interface's inheritors, such as [TierDescriptor], [EnvironmentExtender],
|
||||
* and [com.intellij.platform.ml.impl.model.MLModel.Provider] for more details.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface TierRequester {
|
||||
/**
|
||||
* The tiers, that are required to use the object.
|
||||
*/
|
||||
val requiredTiers: Set<Tier<*>>
|
||||
|
||||
companion object {
|
||||
fun <T : TierRequester> Iterable<T>.fulfilledBy(environment: Environment): List<T> {
|
||||
return this.filter { it.requiredTiers.all { requiredTier -> requiredTier in environment } }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -28,6 +28,7 @@
|
||||
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/test" isTestSource="true" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/resources" type="java-resource" />
|
||||
<sourceFolder url="file://$MODULE_DIR$/testResources" type="java-test-resource" />
|
||||
</content>
|
||||
<orderEntry type="inheritedJdk" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
@@ -44,5 +45,7 @@
|
||||
<orderEntry type="module" module-name="intellij.platform.util.text.matching" />
|
||||
<orderEntry type="module" module-name="intellij.platform.lang.impl" />
|
||||
<orderEntry type="library" name="kotlinx-serialization-core" level="project" />
|
||||
<orderEntry type="module" module-name="intellij.platform.ml" />
|
||||
<orderEntry type="module" module-name="intellij.platform.statistics" />
|
||||
</component>
|
||||
</module>
|
||||
@@ -6,6 +6,9 @@
|
||||
|
||||
<projectService serviceInterface="com.intellij.internal.statistic.local.LanguageUsageStatisticsProvider"
|
||||
serviceImplementation="com.intellij.internal.statistic.local.LanguageUsageStatisticsProviderImpl"/>
|
||||
<statistics.counterUsagesCollector implementationClass="com.intellij.platform.ml.impl.logger.MLEventsLogger"/>
|
||||
<registryKey defaultValue="false" description="Log features that were computed but missing in an ObsoleteTierDescriptor's declaration"
|
||||
key="ml.description.logMissing"/>
|
||||
</extensions>
|
||||
|
||||
<projectListeners>
|
||||
@@ -17,9 +20,14 @@
|
||||
|
||||
<extensionPoints>
|
||||
<extensionPoint qualifiedName="com.intellij.platform.ml.impl.turboComplete.smartPipelineRunner"
|
||||
interface="com.intellij.platform.ml.impl.turboComplete.SmartPipelineRunner" dynamic="true"/>
|
||||
interface="com.intellij.platform.ml.impl.turboComplete.SmartPipelineRunner"
|
||||
dynamic="true"/>
|
||||
<extensionPoint qualifiedName="com.intellij.platform.ml.impl.approach"
|
||||
interface="com.intellij.platform.ml.impl.MLTaskApproachInitializer"
|
||||
dynamic="true"/>
|
||||
|
||||
<extensionPoint name="mlCompletionCorrectnessSupporter" beanClass="com.intellij.platform.ml.impl.correctness.MLCompletionCorrectnessSupporterEP" dynamic="true">
|
||||
<extensionPoint name="mlCompletionCorrectnessSupporter"
|
||||
beanClass="com.intellij.platform.ml.impl.correctness.MLCompletionCorrectnessSupporterEP" dynamic="true">
|
||||
<with attribute="implementationClass" implements="com.intellij.platform.ml.impl.correctness.MLCompletionCorrectnessSupporter"/>
|
||||
</extensionPoint>
|
||||
</extensionPoints>
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl
|
||||
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyBy
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Computes tiers descriptions, calling [TierDescriptor.describe],
|
||||
* or caching the descriptions.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface DescriptionComputer {
|
||||
/**
|
||||
*
|
||||
* @param tier The tier that all is being described.
|
||||
* @param descriptors All relevant descriptors, that can be called within [environment]
|
||||
* (it is guaranteed that they all describe [tier] and their requirements are fulfilled by [environment]).
|
||||
* @param environment An environment, that contains all tiers required to run any of the [descriptors].
|
||||
* @param usefulFeaturesFilter Accepts features, that are meaningful to compute. If the filter does not
|
||||
* accept a feature, its computation will not make any difference later.
|
||||
*/
|
||||
fun computeDescription(
|
||||
tier: Tier<*>,
|
||||
descriptors: List<TierDescriptor>,
|
||||
environment: Environment,
|
||||
usefulFeaturesFilter: FeatureFilter,
|
||||
): Map<TierDescriptor, Set<Feature>>
|
||||
}
|
||||
|
||||
/**
|
||||
* Does not cache descriptors, primitively computes all descriptors all over again each time.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class StateFreeDescriptionComputer : DescriptionComputer {
|
||||
override fun computeDescription(tier: Tier<*>,
|
||||
descriptors: List<TierDescriptor>,
|
||||
environment: Environment,
|
||||
usefulFeaturesFilter: FeatureFilter): Map<TierDescriptor, Set<Feature>> {
|
||||
return descriptors.associateWith { descriptor ->
|
||||
require(descriptor.tier == tier)
|
||||
require(descriptor.requiredTiers.all { it in environment.tiers })
|
||||
descriptor.describe(environment.accessibleSafelyBy(descriptor), usefulFeaturesFilter)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,91 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl
|
||||
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.PerTier
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Determines a set of features, and can tell if a set of features is "complete" or not.
|
||||
* The Meaning of the "completeness" depends on the selector's usage.
|
||||
*
|
||||
* Selectors are used to determine if there are enough features to run an ML model,
|
||||
* and to not compute redundant features, the ones that will not be used by the ML model,
|
||||
* and are not desired to be logged.
|
||||
*
|
||||
* Two types of instances have the power to select features:
|
||||
* - [com.intellij.platform.ml.impl.model.MLModel] to tell which features are known.
|
||||
* - [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] to tell which features are not known by the ML model,
|
||||
* but they are still must be computed and then logged.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface FeatureSelector {
|
||||
/**
|
||||
* @param availableFeatures A set of features that the selection will be built from.
|
||||
*
|
||||
* @return A set of features that are 'selected' and a completeness marker
|
||||
* [Selection.Incomplete] is returned if there are some more features that are 'selected',
|
||||
* but they are not present among [availableFeatures].
|
||||
* [Selection.Complete] is returned otherwise.
|
||||
*/
|
||||
fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection
|
||||
|
||||
/**
|
||||
* @param featureDeclaration A single feature, that needs to be selected.
|
||||
*
|
||||
* @return If the feature belongs to the determined set of features.
|
||||
*/
|
||||
fun select(featureDeclaration: FeatureDeclaration<*>): Boolean {
|
||||
return select(setOf(featureDeclaration)).selectedFeatures.isNotEmpty()
|
||||
}
|
||||
|
||||
sealed class Selection(val selectedFeatures: Set<FeatureDeclaration<*>>) {
|
||||
class Complete(selectedFeatures: Set<FeatureDeclaration<*>>) : Selection(selectedFeatures)
|
||||
|
||||
open class Incomplete(selectedFeatures: Set<FeatureDeclaration<*>>) : Selection(selectedFeatures) {
|
||||
open val details: String = "Incomplete selection, only these were selected: $selectedFeatures"
|
||||
}
|
||||
|
||||
companion object {
|
||||
val NOTHING = Complete(emptySet())
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
val NOTHING = object : FeatureSelector {
|
||||
override fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection = Selection.NOTHING
|
||||
}
|
||||
|
||||
val EVERYTHING = object : FeatureSelector {
|
||||
override fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection = Selection.Complete(availableFeatures)
|
||||
}
|
||||
|
||||
infix fun FeatureSelector.or(other: FeatureSelector): FeatureSelector {
|
||||
return object : FeatureSelector {
|
||||
override fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection {
|
||||
val thisSelection = this@or.select(availableFeatures)
|
||||
val otherSelection = other.select(availableFeatures)
|
||||
val joinedSelection = (thisSelection.selectedFeatures + otherSelection.selectedFeatures).toSet()
|
||||
return if (thisSelection is Selection.Incomplete || otherSelection is Selection.Incomplete) {
|
||||
val incompleteSelection = if (thisSelection is Selection.Incomplete)
|
||||
thisSelection
|
||||
else
|
||||
otherSelection as Selection.Incomplete
|
||||
|
||||
object : Selection.Incomplete(joinedSelection) {
|
||||
override val details: String
|
||||
get() = incompleteSelection.details
|
||||
}
|
||||
}
|
||||
else
|
||||
Selection.Complete(joinedSelection)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
infix fun PerTier<FeatureSelector>.or(other: PerTier<FeatureSelector>): PerTier<FeatureSelector> {
|
||||
require(this.keys == other.keys)
|
||||
return keys.associateWith { this.getValue(it) or other.getValue(it) }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl
|
||||
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.DescribedLevel
|
||||
import com.intellij.platform.ml.impl.session.NestableStructureCollector
|
||||
import com.intellij.platform.ml.impl.session.PredictionCollector
|
||||
import com.intellij.platform.ml.impl.session.SessionTree
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A [SinglePrediction] performed by an ML model.
|
||||
* The session's structure is collected by [collector], after the prediction is done or canceled.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class MLModelPrediction<T : SessionTree.PredictionContainer<M, DescribedLevel, P>, M : MLModel<P>, P : Any>(
|
||||
private val mlModel: M,
|
||||
private val collector: PredictionCollector<T, M, P>,
|
||||
) : SinglePrediction<P> {
|
||||
override fun predict(): P {
|
||||
val prediction = mlModel.predict(collector.usableDescription)
|
||||
collector.submitPrediction(prediction)
|
||||
return prediction
|
||||
}
|
||||
|
||||
override fun cancelPrediction() {
|
||||
collector.submitPrediction(null)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A [NestableMLSession] of utilizing [mlModel].
|
||||
* The session's structure is collected by [collector] after [onLastNestedSessionCreated],
|
||||
* and all nested sessions' structures are collected.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class MLModelPredictionBranching<T : SessionTree.ChildrenContainer<M, DescribedLevel, P>, M : MLModel<P>, P : Any>(
|
||||
private val mlModel: M,
|
||||
private val collector: NestableStructureCollector<T, M, P>
|
||||
) : NestableMLSession<P> {
|
||||
override fun createNestedSession(levelMainEnvironment: Environment): Session<P> {
|
||||
val nestedLevelScheme = collector.levelPositioning.lowerTiers.first()
|
||||
verifyTiersInMain(nestedLevelScheme.main, levelMainEnvironment.tiers)
|
||||
val levelAdditionalTiers = nestedLevelScheme.additional
|
||||
|
||||
return if (collector.levelPositioning.lowerTiers.size == 1) {
|
||||
val nestedCollector = collector.nestPrediction(levelMainEnvironment, levelAdditionalTiers)
|
||||
MLModelPrediction(mlModel, nestedCollector)
|
||||
}
|
||||
else {
|
||||
assert(collector.levelPositioning.lowerTiers.size > 1)
|
||||
val nestedCollector = collector.nestBranch(levelMainEnvironment, levelAdditionalTiers)
|
||||
MLModelPredictionBranching(mlModel, nestedCollector)
|
||||
}
|
||||
}
|
||||
|
||||
override fun onLastNestedSessionCreated() {
|
||||
collector.onLastNestedCollectorCreated()
|
||||
}
|
||||
}
|
||||
|
||||
private fun verifyTiersInMain(expected: Set<Tier<*>>, actual: Set<*>) {
|
||||
require(expected == actual) {
|
||||
"Tier set in the main environment is not like it was declared. " +
|
||||
"Declared $expected, " +
|
||||
"but given $actual"
|
||||
}
|
||||
}
|
||||
129
platform/ml-impl/src/com/intellij/platform/ml/impl/MLTask.kt
Normal file
129
platform/ml-impl/src/com/intellij/platform/ml/impl/MLTask.kt
Normal file
@@ -0,0 +1,129 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl
|
||||
|
||||
import com.intellij.openapi.extensions.ExtensionPointName
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform
|
||||
import com.intellij.platform.ml.impl.session.AdditionalTierScheme
|
||||
import com.intellij.platform.ml.impl.session.Level
|
||||
import com.intellij.platform.ml.impl.session.MainTierScheme
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
|
||||
/**
|
||||
* Is a declaration of a place in code, where a classical machine learning approach
|
||||
* is desired to be applied.
|
||||
*
|
||||
* The proper way to create new tasks is to create a static object-inheritor of this class
|
||||
* (see inheritors of this class to find all implemented tasks in your project).
|
||||
*
|
||||
* @param name The unique name of an ML Task
|
||||
* @param levels The main tiers of the task that will be provided within the task's application place
|
||||
* @param predictionClass The class of an object, that will serve as "prediction"
|
||||
* @param T The type of prediction
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
abstract class MLTask<T : Any> protected constructor(
|
||||
val name: String,
|
||||
val levels: List<Set<Tier<*>>>,
|
||||
val predictionClass: Class<T>
|
||||
)
|
||||
|
||||
/**
|
||||
* A method of approaching an ML task.
|
||||
* Usually, it is inferencing an ML model and collecting logs.
|
||||
*
|
||||
* Each [MLTaskApproach] is initialized once by the corresponding [MLTaskApproachInitializer],
|
||||
* then the [apiPlatform] is fixed.
|
||||
*
|
||||
* @see [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] for currently used approach.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface MLTaskApproach<P : Any> {
|
||||
/**
|
||||
* The task this approach is solving.
|
||||
* Each approach is dedicated to one and only task, and it is aware of it.
|
||||
*/
|
||||
val task: MLTask<P>
|
||||
|
||||
/**
|
||||
* The platform, this approach is called within, that was provided by [MLTaskApproachInitializer]
|
||||
*/
|
||||
val apiPlatform: MLApiPlatform
|
||||
|
||||
/**
|
||||
* A static declaration of the features, used in the approach.
|
||||
*/
|
||||
val approachDeclaration: Declaration
|
||||
|
||||
/**
|
||||
* Acquire the ML model and start the session.
|
||||
*
|
||||
* @return [Session.StartOutcome.Failure] if something went wrong during the start, [Session.StartOutcome.Success]
|
||||
* which contains the started session otherwise.
|
||||
*/
|
||||
fun startSession(permanentSessionEnvironment: Environment): Session.StartOutcome<P>
|
||||
|
||||
data class Declaration(
|
||||
val sessionFeatures: Map<String, Set<FeatureDeclaration<*>>>,
|
||||
val levelsScheme: List<LevelScheme>
|
||||
)
|
||||
|
||||
companion object {
|
||||
fun <P : Any> findMlApproach(task: MLTask<P>, apiPlatform: MLApiPlatform = ReplaceableIJPlatform): MLTaskApproach<P> {
|
||||
return apiPlatform.accessApproachFor(task)
|
||||
}
|
||||
|
||||
fun <P : Any> startMLSession(task: MLTask<P>,
|
||||
permanentSessionEnvironment: Environment,
|
||||
apiPlatform: MLApiPlatform = ReplaceableIJPlatform): Session.StartOutcome<P> {
|
||||
val approach = findMlApproach(task, apiPlatform)
|
||||
return approach.startSession(permanentSessionEnvironment)
|
||||
}
|
||||
|
||||
fun <P : Any> MLTask<P>.startMLSession(permanentSessionEnvironment: Environment): Session.StartOutcome<P> {
|
||||
return startMLSession(this, permanentSessionEnvironment)
|
||||
}
|
||||
|
||||
fun <P : Any> MLTask<P>.startMLSession(permanentTierInstances: Iterable<TierInstance<*>>): Session.StartOutcome<P> {
|
||||
return this.startMLSession(Environment.of(permanentTierInstances))
|
||||
}
|
||||
|
||||
fun <P : Any> MLTask<P>.startMLSession(vararg permanentTierInstances: TierInstance<*>): Session.StartOutcome<P> {
|
||||
return this.startMLSession(Environment.of(*permanentTierInstances))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Initializes an [MLTaskApproach]
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface MLTaskApproachInitializer<P : Any> {
|
||||
/**
|
||||
* The task, that the created [MLTaskApproach] is dedicated to solve.
|
||||
*/
|
||||
val task: MLTask<P>
|
||||
|
||||
/**
|
||||
* Initializes the approach.
|
||||
* It is called only once during the application's runtime.
|
||||
* So it is crucial that this function will accept the [MLApiPlatform] you want it to.
|
||||
*
|
||||
* To access the API in order to build event validator statically,
|
||||
* FUS uses the actual [com.intellij.platform.ml.impl.apiPlatform.IJPlatform], which could be problematic if you
|
||||
* want to test FUS logs.
|
||||
* So make sure that you will replace it with your test platform in
|
||||
* time via [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.replacingWith].
|
||||
*/
|
||||
fun initializeApproachWithin(apiPlatform: MLApiPlatform): MLTaskApproach<P>
|
||||
|
||||
companion object {
|
||||
val EP_NAME = ExtensionPointName<MLTaskApproachInitializer<*>>("com.intellij.platform.ml.impl.approach")
|
||||
}
|
||||
}
|
||||
|
||||
typealias LevelScheme = Level<PerTier<MainTierScheme>, PerTier<AdditionalTierScheme>>
|
||||
|
||||
typealias LevelTiers = Level<Set<Tier<*>>, Set<Tier<*>>>
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.apiPlatform
|
||||
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.FeatureValueType
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Prints code-like features representation, when they have to be logged for an API user's
|
||||
* convenience (so they can just copy-paste the logged features into their code).
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class CodeLikePrinter {
|
||||
private val <T : Enum<*>> FeatureValueType.Enum<T>.codeLikeType: String
|
||||
get() = this.enumClass.name
|
||||
|
||||
private fun <T> FeatureValueType<T>.makeCodeLikeString(name: String): String = when (this) {
|
||||
FeatureValueType.Boolean -> "FeatureDeclaration.boolean(\"$name\")"
|
||||
FeatureValueType.Class -> "FeatureDeclaration.aClass(\"$name\")"
|
||||
FeatureValueType.Double -> "FeatureDeclaration.double(\"$name\")"
|
||||
is FeatureValueType.Enum<*> -> "FeatureDeclaration.enum<${this.codeLikeType}>(\"$name\")"
|
||||
FeatureValueType.Float -> "FeatureDeclaration.float(\"${name}\")"
|
||||
FeatureValueType.Int -> "FeatureDeclaration.int(\"${name}\")"
|
||||
FeatureValueType.Long -> "FeatureDeclaration.long(\"${name}\")"
|
||||
is FeatureValueType.Nullable<*> -> "${this.baseType.makeCodeLikeString(name)}.nullable()"
|
||||
is FeatureValueType.Categorical -> {
|
||||
val possibleValuesSerialized = possibleValues.joinToString(", ") { "\"$it\"" }
|
||||
"FeatureDeclaration.categorical(\"$name\", setOf(${possibleValuesSerialized}))"
|
||||
}
|
||||
FeatureValueType.Version -> "FeatureDeclaration.version(\"${name}\")"
|
||||
}
|
||||
|
||||
fun <T> printCodeLikeString(featureDeclaration: FeatureDeclaration<T>): String {
|
||||
return featureDeclaration.type.makeCodeLikeString(featureDeclaration.name)
|
||||
}
|
||||
|
||||
fun printCodeLikeString(featureDeclarations: Collection<FeatureDeclaration<*>>): String {
|
||||
return featureDeclarations.joinToString(", ") { printCodeLikeString(it) }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,186 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.apiPlatform
|
||||
|
||||
import com.intellij.openapi.diagnostic.thisLogger
|
||||
import com.intellij.openapi.util.registry.Registry
|
||||
import com.intellij.platform.ml.EnvironmentExtender
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.ObsoleteTierDescriptor
|
||||
import com.intellij.platform.ml.TierDescriptor
|
||||
import com.intellij.platform.ml.impl.MLTaskApproachInitializer
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.ExtensionController
|
||||
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.replacingWith
|
||||
import com.intellij.platform.ml.impl.logger.MLEvent
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener
|
||||
import com.intellij.util.application
|
||||
import com.intellij.util.messages.Topic
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import org.jetbrains.annotations.NonNls
|
||||
|
||||
@ApiStatus.Internal
|
||||
fun interface MessagingProvider<T> {
|
||||
fun provide(collector: (T) -> Unit)
|
||||
|
||||
companion object {
|
||||
inline fun <T, reified P : MessagingProvider<T>> createTopic(displayName: @NonNls String): Topic<P> {
|
||||
return Topic.create(displayName, P::class.java)
|
||||
}
|
||||
|
||||
fun <T, P : MessagingProvider<T>> collect(topic: Topic<P>): List<T> {
|
||||
val collected = mutableListOf<T>()
|
||||
application.messageBus.syncPublisher(topic).provide { collected.add(it) }
|
||||
return collected
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun interface MLTaskListenerProvider : MessagingProvider<MLTaskGroupListener> {
|
||||
companion object {
|
||||
val TOPIC = MessagingProvider.createTopic<MLTaskGroupListener, MLTaskListenerProvider>("ml.task")
|
||||
}
|
||||
}
|
||||
|
||||
fun interface MLEventProvider : MessagingProvider<MLEvent> {
|
||||
companion object {
|
||||
val TOPIC = MessagingProvider.createTopic<MLEvent, MLEventProvider>("ml.event")
|
||||
}
|
||||
}
|
||||
|
||||
fun interface MLApiStartupListenerProvider : MessagingProvider<MLApiStartupListener> {
|
||||
companion object {
|
||||
val TOPIC = MessagingProvider.createTopic<MLApiStartupListener, MLApiStartupListenerProvider>("ml.startup")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A representation of the "real-life" [MLApiPlatform], whose content is
|
||||
* the content of the corresponding Extension Points.
|
||||
* It is used at the API's entry point, unless it is not replaced by another.
|
||||
*
|
||||
* It shouldn't be used due to low testability.
|
||||
* Use [ReplaceableIJPlatform] instead.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
private data object IJPlatform : MLApiPlatform() {
|
||||
override val tierDescriptors: List<TierDescriptor>
|
||||
get() = TierDescriptor.EP_NAME.extensionList
|
||||
|
||||
override val environmentExtenders: List<EnvironmentExtender<*>>
|
||||
get() = EnvironmentExtender.EP_NAME.extensionList
|
||||
|
||||
override val taskApproaches: List<MLTaskApproachInitializer<*>>
|
||||
get() = MLTaskApproachInitializer.EP_NAME.extensionList
|
||||
|
||||
override val taskListeners: List<MLTaskGroupListener>
|
||||
get() = MessagingProvider.collect(MLTaskListenerProvider.TOPIC)
|
||||
|
||||
override val events: List<MLEvent>
|
||||
get() = MessagingProvider.collect(MLEventProvider.TOPIC)
|
||||
|
||||
override val startupListeners: List<MLApiStartupListener>
|
||||
get() = MessagingProvider.collect(MLApiStartupListenerProvider.TOPIC)
|
||||
|
||||
override fun addStartupListener(listener: MLApiStartupListener): ExtensionController {
|
||||
val connection = application.messageBus.connect()
|
||||
connection.subscribe(MLApiStartupListenerProvider.TOPIC, MLApiStartupListenerProvider { collector -> collector(listener) })
|
||||
return ExtensionController { connection.disconnect() }
|
||||
}
|
||||
|
||||
override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController {
|
||||
val connection = application.messageBus.connect()
|
||||
connection.subscribe(MLTaskListenerProvider.TOPIC, MLTaskListenerProvider { it(taskListener) })
|
||||
return ExtensionController { connection.disconnect() }
|
||||
}
|
||||
|
||||
override fun addEvent(event: MLEvent): ExtensionController {
|
||||
val connection = application.messageBus.connect()
|
||||
connection.subscribe(MLEventProvider.TOPIC, MLEventProvider { it(event) })
|
||||
return ExtensionController { connection.disconnect() }
|
||||
}
|
||||
|
||||
override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>) {
|
||||
if (!Registry.`is`("ml.description.logMissing")) return
|
||||
val printer = CodeLikePrinter()
|
||||
val codeLikeMissingDeclaration = printer.printCodeLikeString(nonDeclaredFeatures.map { it.declaration })
|
||||
thisLogger().info("${descriptor::class.java} is missing declaration: setOf($codeLikeMissingDeclaration)")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Also a "real-life" [MLApiPlatform], but it can be replaced with another one any time.
|
||||
*
|
||||
* We always want to test [com.intellij.platform.ml.impl.MLTaskApproach]es.
|
||||
* But after they are initialized by [com.intellij.platform.ml.impl.MLTaskApproachInitializer],
|
||||
* the passed [MLApiPlatform] could already spread all the way within the API.
|
||||
* But the user-defined instances of the api could be overridden for testing sake.
|
||||
*
|
||||
* To replace all [TierDescriptor], [EnvironmentExtender] and [MLTaskApproachInitializer]
|
||||
* to test your code, you may call [replacingWith] and pass the desired environment,
|
||||
* that contains all the objects you need for your test.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
object ReplaceableIJPlatform : MLApiPlatform() {
|
||||
private var replacement: MLApiPlatform? = null
|
||||
|
||||
private val platform: MLApiPlatform
|
||||
get() = replacement ?: IJPlatform
|
||||
|
||||
|
||||
override val tierDescriptors: List<TierDescriptor>
|
||||
get() = platform.tierDescriptors
|
||||
|
||||
override val environmentExtenders: List<EnvironmentExtender<*>>
|
||||
get() = platform.environmentExtenders
|
||||
|
||||
override val taskApproaches: List<MLTaskApproachInitializer<*>>
|
||||
get() = platform.taskApproaches
|
||||
|
||||
|
||||
override val taskListeners: List<MLTaskGroupListener>
|
||||
get() = platform.taskListeners
|
||||
|
||||
override val events: List<MLEvent>
|
||||
get() = platform.events
|
||||
|
||||
override val startupListeners: List<MLApiStartupListener>
|
||||
get() = platform.startupListeners
|
||||
|
||||
|
||||
override fun addStartupListener(listener: MLApiStartupListener): ExtensionController {
|
||||
return extend(listener) { platform -> platform.addStartupListener(listener) }
|
||||
}
|
||||
|
||||
override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController {
|
||||
return extend(taskListener) { platform -> platform.addTaskListener(taskListener) }
|
||||
}
|
||||
|
||||
override fun addEvent(event: MLEvent): ExtensionController {
|
||||
return extend(event) { platform -> platform.addEvent(event) }
|
||||
}
|
||||
|
||||
override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>) =
|
||||
platform.manageNonDeclaredFeatures(descriptor, nonDeclaredFeatures)
|
||||
|
||||
private fun <T> extend(obj: T, method: (MLApiPlatform) -> ExtensionController): ExtensionController {
|
||||
val initialPlatform = platform
|
||||
method(initialPlatform)
|
||||
return ExtensionController {
|
||||
require(initialPlatform == platform) {
|
||||
"$obj should be removed within the same platform it was added in." +
|
||||
"It was added in $initialPlatform, but removed from $platform"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fun <T> replacingWith(apiPlatform: MLApiPlatform, action: () -> T): T {
|
||||
val oldApiPlatform = replacement
|
||||
return try {
|
||||
replacement = apiPlatform
|
||||
action()
|
||||
}
|
||||
finally {
|
||||
replacement = oldApiPlatform
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,264 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.apiPlatform
|
||||
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.impl.MLTask
|
||||
import com.intellij.platform.ml.impl.MLTaskApproach
|
||||
import com.intellij.platform.ml.impl.MLTaskApproachInitializer
|
||||
import com.intellij.platform.ml.impl.logger.MLEvent
|
||||
import com.intellij.platform.ml.impl.logger.MLEventsLogger
|
||||
import com.intellij.platform.ml.impl.monitoring.*
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApproachListener.Companion.asJoinedListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.Companion.onAttemptedToStartSession
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.Companion.targetedApproaches
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Represents an environment, that provides extendable parts of the ML API.
|
||||
*
|
||||
* Each entity inside the API could access the platform, it is running within,
|
||||
* as everything happens after [com.intellij.platform.ml.impl.MLTaskApproachInitializer.initializeApproachWithin],
|
||||
* where the platform is acknowledged.
|
||||
*
|
||||
* All usages of the ij platform functionality (extension points, registry keys, etc.) shall be
|
||||
* accessed via this class.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
abstract class MLApiPlatform {
|
||||
private val finishedInitialization = lazy { MLApiPlatformInitializationProcess() }
|
||||
private var initializationStage: InitializationStage = InitializationStage.NotStarted
|
||||
|
||||
/**
|
||||
* Each [MLTaskApproach] is initialized only once during the application's lifetime.
|
||||
* This function keeps track of all approaches that were initialized already, and initializes
|
||||
* them when they are first needed.
|
||||
*/
|
||||
fun <P : Any> accessApproachFor(task: MLTask<P>): MLTaskApproach<P> {
|
||||
return finishedInitialization.value.getApproachFor(task)
|
||||
}
|
||||
|
||||
|
||||
/**
|
||||
* The extendable static state of the platform that must be fixed to create FUS group's validator.
|
||||
*
|
||||
* These values must be static, as they define the FUS scheme
|
||||
*/
|
||||
val staticState: StaticState
|
||||
get() = StaticState(tierDescriptors, environmentExtenders, taskApproaches)
|
||||
|
||||
/**
|
||||
* The descriptors that are available in the platform.
|
||||
* This value is interchangeable during the application runtime,
|
||||
* see [staticState].
|
||||
*/
|
||||
abstract val tierDescriptors: List<TierDescriptor>
|
||||
|
||||
/**
|
||||
* The complete list of environment extenders, available in the platform.
|
||||
* This value is interchangeable during the application runtime,
|
||||
* see [staticState].
|
||||
*/
|
||||
abstract val environmentExtenders: List<EnvironmentExtender<*>>
|
||||
|
||||
/**
|
||||
* The complete list of the approaches for ML tasks, available in the platform.
|
||||
* This value is interchangeable during the application runtime,
|
||||
* see [staticState].
|
||||
*/
|
||||
abstract val taskApproaches: List<MLTaskApproachInitializer<*>>
|
||||
|
||||
|
||||
/**
|
||||
* All the objects, that are listening execution of ML tasks.
|
||||
* The collection is mutable, so new listeners could be added via [addTaskListener].
|
||||
*
|
||||
* This value is mutable, new listeners could be added anytime.
|
||||
*/
|
||||
abstract val taskListeners: List<MLTaskGroupListener>
|
||||
|
||||
/**
|
||||
* Adds another provider for ML tasks' execution process monitoring dynamically.
|
||||
* The event could be removed via the corresponding [ExtensionController.removeExtension] call.
|
||||
* See [taskListeners].
|
||||
*/
|
||||
abstract fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController
|
||||
|
||||
/**
|
||||
* ML events that will be written to FUS logs.
|
||||
* As FUS is initialized only once, on the application's startup, they all must be registered
|
||||
* before that via [addMLEventBeforeFusInitialized].
|
||||
*
|
||||
* This value could be mutable, however, only during a short period of time: after the application's startup,
|
||||
* and before FUS logs initialization.
|
||||
*/
|
||||
abstract val events: List<MLEvent>
|
||||
|
||||
/**
|
||||
* Adds another ML event dynamically.
|
||||
* The event could be removed via the corresponding [ExtensionController.removeExtension] call.
|
||||
*/
|
||||
fun addMLEventBeforeFusInitialized(event: MLEvent): ExtensionController {
|
||||
when (val stage = initializationStage) {
|
||||
is InitializationStage.Failed ->
|
||||
throw Exception("Initialization of ML Api Platform has failed, events could not be added.", stage.asException)
|
||||
is InitializationStage.PotentiallySuccessful -> {
|
||||
require(stage.order <= InitializationStage.InitializingApproaches.order) {
|
||||
"FUS group initialization has already been started, not allowed to register more ML Events"
|
||||
}
|
||||
return addEvent(event)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* The complete list of the listeners that are listening to the process of an MLApiPlatform's initialization.
|
||||
* The initialization is performed in [finishedInitialization]'s init block.
|
||||
*
|
||||
* This value could be mutable, so
|
||||
* additional listeners could be added via [addStartupListener].
|
||||
*
|
||||
* If a listener was added after a certain initialization stage,
|
||||
* only callbacks of those stages will be triggered later that have not happened yet.
|
||||
*/
|
||||
abstract val startupListeners: List<MLApiStartupListener>
|
||||
|
||||
/**
|
||||
* Adds another startup listener.
|
||||
* The listener could be removed via the corresponding [ExtensionController.removeExtension] call.
|
||||
*/
|
||||
abstract fun addStartupListener(listener: MLApiStartupListener): ExtensionController
|
||||
|
||||
|
||||
/**
|
||||
* Declares how the computed but non-declared features will be handled.
|
||||
*/
|
||||
abstract fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>)
|
||||
|
||||
|
||||
internal abstract fun addEvent(event: MLEvent): ExtensionController
|
||||
|
||||
fun interface ExtensionController {
|
||||
fun removeExtension()
|
||||
}
|
||||
|
||||
data class StaticState(
|
||||
val tierDescriptors: List<TierDescriptor>,
|
||||
val environmentExtenders: List<EnvironmentExtender<*>>,
|
||||
val taskApproaches: List<MLTaskApproachInitializer<*>>,
|
||||
)
|
||||
|
||||
private sealed class InitializationStage(val callListener: (MLApiStartupProcessListener) -> Unit) {
|
||||
sealed class PotentiallySuccessful(val order: Int, callListener: (MLApiStartupProcessListener) -> Unit) : InitializationStage(callListener)
|
||||
|
||||
class Failed(lastStage: InitializationStage, nextStage: InitializationStage, exception: Throwable, callListener: (MLApiStartupProcessListener) -> Unit) : InitializationStage(callListener) {
|
||||
val asException = Exception("Failed to proceed from the initialization stage $lastStage to $nextStage", exception)
|
||||
}
|
||||
|
||||
data object NotStarted : PotentiallySuccessful(0, {})
|
||||
data object InitializingApproaches : PotentiallySuccessful(1, { it.onStartedInitializingApproaches() })
|
||||
data class InitializingFUS(val initializedApproaches: Collection<InitializerAndApproach<*>>) : PotentiallySuccessful(2, {
|
||||
it.onStartedInitializingFus(initializedApproaches)
|
||||
})
|
||||
|
||||
data object Finished : PotentiallySuccessful(3, { it.onFinished() })
|
||||
}
|
||||
|
||||
private inner class MLApiPlatformInitializationProcess {
|
||||
val approachPerTask: Map<MLTask<*>, MLTaskApproach<*>>
|
||||
private val completeInitializersList: List<MLTaskApproachInitializer<*>> = taskApproaches.toMutableList()
|
||||
|
||||
init {
|
||||
require(initializationStage == InitializationStage.NotStarted) { "ML API Platform's initialization should not be run twice" }
|
||||
|
||||
fun currentStartupListeners(): List<MLApiStartupProcessListener> = startupListeners.map { it.onBeforeStarted(this@MLApiPlatform) }
|
||||
|
||||
fun <T> proceedToNextStage(nextStage: InitializationStage.PotentiallySuccessful, action: () -> T): T {
|
||||
return try {
|
||||
action().also {
|
||||
currentStartupListeners().forEach { nextStage.callListener(it) }
|
||||
initializationStage = nextStage
|
||||
}
|
||||
}
|
||||
catch (ex: Throwable) {
|
||||
val failure = InitializationStage.Failed(initializationStage, nextStage, ex) { it.onFailed(ex) }
|
||||
initializationStage = failure
|
||||
throw failure.asException
|
||||
}
|
||||
}
|
||||
|
||||
proceedToNextStage(InitializationStage.InitializingApproaches) {}
|
||||
|
||||
val initializedApproachPerTask = mutableListOf<InitializerAndApproach<*>>()
|
||||
|
||||
approachPerTask = proceedToNextStage(InitializationStage.InitializingFUS(initializedApproachPerTask)) {
|
||||
completeInitializersList.validate()
|
||||
|
||||
fun <T : Any> initializeApproach(approachInitializer: MLTaskApproachInitializer<T>) {
|
||||
initializedApproachPerTask.add(InitializerAndApproach(
|
||||
approachInitializer,
|
||||
approachInitializer.initializeApproachWithin(this@MLApiPlatform)
|
||||
))
|
||||
}
|
||||
completeInitializersList.forEach { initializeApproach(it) }
|
||||
initializedApproachPerTask.associate { it.initializer.task to it.approach }
|
||||
}
|
||||
|
||||
proceedToNextStage(InitializationStage.Finished) {
|
||||
MLEventsLogger.Manager.ensureInitialized(okIfInitializing = true, this@MLApiPlatform)
|
||||
}
|
||||
}
|
||||
|
||||
fun <P : Any> getApproachFor(task: MLTask<P>): MLTaskApproach<P> {
|
||||
val taskApproach = requireNotNull(approachPerTask[task]) {
|
||||
val mainMessage = "No approach for task $task was found"
|
||||
val lateRegistrationMessage = getLateApproachRegistrationAssumption(task)
|
||||
if (lateRegistrationMessage != null) "$mainMessage. $lateRegistrationMessage" else mainMessage
|
||||
}
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
return taskApproach as MLTaskApproach<P>
|
||||
}
|
||||
|
||||
private fun getLateApproachRegistrationAssumption(task: MLTask<*>): String? {
|
||||
val currentInitializersList = taskApproaches.toMutableList()
|
||||
if (completeInitializersList == currentInitializersList) return null
|
||||
val taskApproaches = currentInitializersList.filter { it.task == task }
|
||||
if (taskApproaches.isEmpty()) return null
|
||||
require(taskApproaches.size == 1) { "More than one approach for task $task: $taskApproaches" }
|
||||
return "Approach ${taskApproaches.first()} for task ${task.name} was registered after the ML API Platform was initialized"
|
||||
}
|
||||
|
||||
private fun List<MLTaskApproachInitializer<*>>.validate() {
|
||||
val duplicateInitializerPerTask = this.groupBy { it.task }.filter { it.value.size > 1 }
|
||||
require(duplicateInitializerPerTask.isEmpty()) {
|
||||
"Found more than one approach for the following tasks: ${duplicateInitializerPerTask}"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun MLApiPlatform.getDescriptorsOfTiers(tiers: Set<Tier<*>>): PerTier<List<TierDescriptor>> {
|
||||
val descriptorsPerTier = tierDescriptors.groupBy { it.tier }
|
||||
return tiers.associateWith { descriptorsPerTier[it] ?: emptyList() }
|
||||
}
|
||||
|
||||
fun MLApiPlatform.ensureApproachesInitialized() {
|
||||
when (val stage = initializationStage) {
|
||||
is InitializationStage.Failed -> throw Exception("Unable to ensure that approaches are initialized", stage.asException)
|
||||
InitializationStage.NotStarted -> finishedInitialization.value
|
||||
InitializationStage.InitializingApproaches -> throw Exception("Recursion detected while initializing approaches")
|
||||
is InitializationStage.InitializingFUS -> return
|
||||
InitializationStage.Finished -> return
|
||||
}
|
||||
}
|
||||
|
||||
fun <R, P : Any> MLApiPlatform.getJoinedListenerForTask(taskApproach: MLTaskApproach<P>,
|
||||
permanentSessionEnvironment: Environment): MLApproachListener<R, P> {
|
||||
val relevantGroupListeners = taskListeners.filter { taskApproach.javaClass in it.targetedApproaches }
|
||||
val approachListeners = relevantGroupListeners.mapNotNull {
|
||||
it.onAttemptedToStartSession<P, R>(taskApproach, permanentSessionEnvironment)
|
||||
}
|
||||
return approachListeners.asJoinedListener()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,31 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.approach
|
||||
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.PerTier
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
|
||||
import com.intellij.platform.ml.impl.session.DescribedRootContainer
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
/**
|
||||
* Represents the method, that is utilized for the [LogDrivenModelInference]'s analysis.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface AnalysisMethod<M : MLModel<P>, P : Any> {
|
||||
/**
|
||||
* Static declaration of the features, that are used in the session tree's analysis.
|
||||
*/
|
||||
val structureAnalysisDeclaration: PerTier<Set<FeatureDeclaration<*>>>
|
||||
|
||||
/**
|
||||
* Static declaration of the session's entities, that are not tiers.
|
||||
*/
|
||||
val sessionAnalysisDeclaration: Map<String, Set<FeatureDeclaration<*>>>
|
||||
|
||||
/**
|
||||
* Perform the completed session's analysis.
|
||||
*/
|
||||
fun analyseTree(treeRoot: DescribedRootContainer<M, P>): CompletableFuture<AnalysedRootContainer<P>>
|
||||
}
|
||||
@@ -0,0 +1,62 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.approach
|
||||
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.DescribedRootContainer
|
||||
import com.intellij.platform.ml.impl.session.analysis.*
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
/**
|
||||
* The session's assembled analysis declaration.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class GroupedAnalysisDeclaration(
|
||||
val structureAnalysis: StructureAnalysisDeclaration,
|
||||
val mlModelAnalysis: Set<FeatureDeclaration<*>>
|
||||
)
|
||||
|
||||
/**
|
||||
* The session's assembled analysis itself.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class GroupedAnalysis<M : MLModel<P>, P : Any>(
|
||||
val structureAnalysis: StructureAnalysis<M, P>,
|
||||
val mlModelAnalysis: Set<Feature>
|
||||
)
|
||||
|
||||
/**
|
||||
* Analyzes both structure and ML model.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class JoinedGroupedSessionAnalyser<M : MLModel<P>, P : Any>(
|
||||
private val structureAnalysers: Collection<StructureAnalyser<M, P>>,
|
||||
private val mlModelAnalysers: Collection<MLModelAnalyser<M, P>>,
|
||||
) : SessionAnalyser<GroupedAnalysisDeclaration, GroupedAnalysis<M, P>, M, P> {
|
||||
override val analysisDeclaration = GroupedAnalysisDeclaration(
|
||||
structureAnalysis = SessionStructureAnalysisJoiner<M, P>().joinDeclarations(structureAnalysers.map { it.analysisDeclaration }),
|
||||
mlModelAnalysis = MLModelAnalysisJoiner<M, P>().joinDeclarations(mlModelAnalysers.map { it.analysisDeclaration })
|
||||
)
|
||||
|
||||
override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<GroupedAnalysis<M, P>> {
|
||||
val joinedStructureAnalyser = JoinedSessionAnalyser(
|
||||
structureAnalysers, SessionStructureAnalysisJoiner()
|
||||
)
|
||||
val joinedMLModelAnalyser = JoinedSessionAnalyser(
|
||||
mlModelAnalysers, MLModelAnalysisJoiner()
|
||||
)
|
||||
val structureAnalysis = joinedStructureAnalyser.analyse(sessionTreeRoot)
|
||||
val mlModelAnalysis = joinedMLModelAnalyser.analyse(sessionTreeRoot)
|
||||
|
||||
val futureGroupAnalysis = CompletableFuture.allOf(structureAnalysis, mlModelAnalysis)
|
||||
val completeGroupAnalysis = CompletableFuture<GroupedAnalysis<M, P>>()
|
||||
|
||||
futureGroupAnalysis.thenRun {
|
||||
completeGroupAnalysis.complete(GroupedAnalysis(structureAnalysis.get(), mlModelAnalysis.get()))
|
||||
}
|
||||
|
||||
return completeGroupAnalysis
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.approach
|
||||
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.DescribedRootContainer
|
||||
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
/**
|
||||
* Something, that is dedicated for one language only.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface LanguageSpecific {
|
||||
val languageId: String
|
||||
}
|
||||
|
||||
/**
|
||||
* The analyzer, that adds information about ML model's language to logs.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class ModelLanguageAnalyser<M, P : Any> : MLModelAnalyser<M, P>
|
||||
where M : MLModel<P>,
|
||||
M : LanguageSpecific {
|
||||
|
||||
private val LANGUAGE_ID = FeatureDeclaration.categorical("language_id", Language.getRegisteredLanguages().map { it.id }.toSet())
|
||||
|
||||
override val analysisDeclaration = setOf(LANGUAGE_ID)
|
||||
|
||||
override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<Set<Feature>> = CompletableFuture.completedFuture(
|
||||
setOf(LANGUAGE_ID with sessionTreeRoot.root.languageId)
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.approach
|
||||
|
||||
import com.intellij.openapi.diagnostic.thisLogger
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyByOrNull
|
||||
import com.intellij.platform.ml.impl.*
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getDescriptorsOfTiers
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getJoinedListenerForTask
|
||||
import com.intellij.platform.ml.impl.environment.ExtendedEnvironment
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApproachListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLSessionListener
|
||||
import com.intellij.platform.ml.impl.session.*
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* The main way to apply classical machine learning approaches: run the ML model, collect the logs, retrain the model, repeat.
|
||||
*
|
||||
* @param task The task that is solved by this approach.
|
||||
* @param apiPlatform The platform, that the approach will be running within.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
abstract class LogDrivenModelInference<M : MLModel<P>, P : Any>(
|
||||
override val task: MLTask<P>,
|
||||
override val apiPlatform: MLApiPlatform
|
||||
) : MLTaskApproach<P> {
|
||||
/**
|
||||
* The method that is used to analyze sessions.
|
||||
*
|
||||
* [StructureAndModelAnalysis] is currently used analysis method
|
||||
* that is dedicated to analyze session's tree-like structure,
|
||||
* and the ML model.
|
||||
*/
|
||||
abstract val analysisMethod: AnalysisMethod<M, P>
|
||||
|
||||
/**
|
||||
* Provides an ML model to use during session's lifetime.
|
||||
*/
|
||||
abstract val mlModelProvider: MLModel.Provider<M, P>
|
||||
|
||||
/**
|
||||
* Declares features, that are not used by the ML model, but must be computed anyway,
|
||||
* so they make it to logs.
|
||||
*
|
||||
* A feature cannot be simultaneously declared as "not used description" and as used by the [mlModelProvider]'s
|
||||
* provided model.
|
||||
* If a feature is not declared as "not used but still computed" or as "used by the model", then it will be computed.
|
||||
*
|
||||
* It must contain explicitly declared selectors for each tier used in [task], as well as in [additionallyDescribedTiers].
|
||||
*/
|
||||
abstract val notUsedDescription: PerTier<FeatureSelector>
|
||||
|
||||
/**
|
||||
* Performs description's computation.
|
||||
* Could perform caching mechanisms to avoid recomputing features every time.
|
||||
*/
|
||||
abstract val descriptionComputer: DescriptionComputer
|
||||
|
||||
/**
|
||||
* Tiers that do not make a part of te [task], but they could be described and passed to the ML model.
|
||||
*
|
||||
* The size of this list must correspond to the number of levels in the solved [task].
|
||||
*/
|
||||
abstract val additionallyDescribedTiers: List<Set<Tier<*>>>
|
||||
|
||||
private val levels: List<LevelTiers> by lazy {
|
||||
(task.levels zip additionallyDescribedTiers).map { Level(it.first, it.second) }
|
||||
}
|
||||
|
||||
private val approachValidation: Unit by lazy { validateApproach() }
|
||||
|
||||
override fun startSession(permanentSessionEnvironment: Environment): Session.StartOutcome<P> {
|
||||
return startSessionMonitoring(permanentSessionEnvironment)
|
||||
}
|
||||
|
||||
private fun startSessionMonitoring(permanentSessionEnvironment: Environment): Session.StartOutcome<P> {
|
||||
val approachListener = apiPlatform.getJoinedListenerForTask<M, P>(this, permanentSessionEnvironment)
|
||||
try {
|
||||
return acquireModelAndStartSession(permanentSessionEnvironment, approachListener)
|
||||
}
|
||||
catch (e: Throwable) {
|
||||
approachListener.onFailedToStartSessionWithException(e)
|
||||
return Session.StartOutcome.UncaughtException(e)
|
||||
}
|
||||
}
|
||||
|
||||
private fun acquireModelAndStartSession(permanentSessionEnvironment: Environment,
|
||||
approachListener: MLApproachListener<M, P>): Session.StartOutcome<P> {
|
||||
approachValidation
|
||||
|
||||
val extendedPermanentSessionEnvironment = ExtendedEnvironment(
|
||||
apiPlatform.environmentExtenders,
|
||||
permanentSessionEnvironment,
|
||||
mlModelProvider.requiredTiers
|
||||
)
|
||||
|
||||
val mlModel: M = run {
|
||||
val mlModelProviderEnvironment = extendedPermanentSessionEnvironment.accessibleSafelyByOrNull(mlModelProvider)
|
||||
if (mlModelProviderEnvironment == null) {
|
||||
val failure = InsufficientEnvironmentForModelProviderOutcome<P>(mlModelProvider.requiredTiers,
|
||||
extendedPermanentSessionEnvironment.tiers)
|
||||
approachListener.onFailedToStartSession(failure)
|
||||
return failure
|
||||
}
|
||||
val nullableMlModel = mlModelProvider.provideModel(levels, mlModelProviderEnvironment)
|
||||
if (nullableMlModel == null) {
|
||||
val failure = ModelNotAcquiredOutcome<P>()
|
||||
approachListener.onFailedToStartSession(failure)
|
||||
return failure
|
||||
}
|
||||
nullableMlModel
|
||||
}
|
||||
|
||||
var sessionListener: MLSessionListener<M, P>? = null
|
||||
|
||||
val analyseThenLogStructure = SessionTreeHandler<DescribedRootContainer<M, P>, M, P> { treeRoot ->
|
||||
sessionListener?.onSessionDescriptionFinished(treeRoot)
|
||||
analysisMethod.analyseTree(treeRoot).thenApplyAsync { analysedSession ->
|
||||
sessionListener?.onSessionAnalysisFinished(analysedSession)
|
||||
}.exceptionally {
|
||||
thisLogger().error(it)
|
||||
}
|
||||
}
|
||||
|
||||
val session = if (levels.size == 1) {
|
||||
val collector = SolitaryLeafCollector(
|
||||
apiPlatform, levels.first(), descriptionComputer, notUsedDescription,
|
||||
permanentSessionEnvironment, levels.first().additional, mlModel
|
||||
)
|
||||
collector.handleCollectedTree(analyseThenLogStructure)
|
||||
MLModelPrediction(mlModel, collector)
|
||||
}
|
||||
else {
|
||||
val collector = RootCollector(
|
||||
apiPlatform, levels, descriptionComputer, notUsedDescription,
|
||||
permanentSessionEnvironment, levels.first().additional, mlModel
|
||||
)
|
||||
collector.handleCollectedTree(analyseThenLogStructure)
|
||||
MLModelPredictionBranching(mlModel, collector)
|
||||
}
|
||||
|
||||
sessionListener = approachListener.onStartedSession(session)
|
||||
|
||||
return Session.StartOutcome.Success(session)
|
||||
}
|
||||
|
||||
override val approachDeclaration: MLTaskApproach.Declaration
|
||||
get() {
|
||||
approachValidation
|
||||
|
||||
return MLTaskApproach.Declaration(
|
||||
sessionFeatures = analysisMethod.sessionAnalysisDeclaration,
|
||||
levelsScheme = levels.map { levelTiers ->
|
||||
Level(
|
||||
buildMainTiersScheme(levelTiers.main, apiPlatform),
|
||||
buildAdditionalTiersScheme(levelTiers.additional, apiPlatform),
|
||||
)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
private fun validateApproach() {
|
||||
require(task.levels.size == additionallyDescribedTiers.size) {
|
||||
"Task $task has ${task.levels.size} levels, when 'additionallyDescribedTiers' has ${additionallyDescribedTiers.size}"
|
||||
}
|
||||
|
||||
require(levels.isNotEmpty()) { "Task must declare at least one level" }
|
||||
|
||||
val maybeDuplicatedTaskTiers = levels.flatMap { it.main + it.additional }
|
||||
val taskTiers = maybeDuplicatedTaskTiers.toSet()
|
||||
|
||||
require(maybeDuplicatedTaskTiers.size == taskTiers.size) {
|
||||
"There are duplicated tiers in the declaration: ${maybeDuplicatedTaskTiers - taskTiers}"
|
||||
}
|
||||
require(notUsedDescription.keys == taskTiers) {
|
||||
"Selectors for those and only those tiers must be represented in the 'notUsedDescription' that are present in the task. " +
|
||||
"Missing: ${taskTiers - notUsedDescription.keys}, " +
|
||||
"Redundant: ${notUsedDescription.keys - taskTiers}"
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildTierDescriptionDeclaration(tierDescriptors: Collection<TierDescriptor>): Set<FeatureDeclaration<*>> {
|
||||
return tierDescriptors.flatMap {
|
||||
if (it is ObsoleteTierDescriptor) it.partialDescriptionDeclaration else it.descriptionDeclaration
|
||||
}.toSet()
|
||||
}
|
||||
|
||||
private fun buildMainTiersScheme(tiers: Set<Tier<*>>, apiEnvironment: MLApiPlatform): PerTier<MainTierScheme> {
|
||||
val tiersDescriptors = apiEnvironment.getDescriptorsOfTiers(tiers)
|
||||
|
||||
return tiers.associateWith { tier ->
|
||||
val tierDescriptors = tiersDescriptors.getValue(tier)
|
||||
val descriptionDeclaration = buildTierDescriptionDeclaration(tierDescriptors)
|
||||
val analysisDeclaration = analysisMethod.structureAnalysisDeclaration[tier] ?: emptySet()
|
||||
MainTierScheme(descriptionDeclaration, analysisDeclaration)
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildAdditionalTiersScheme(tiers: Set<Tier<*>>, apiEnvironment: MLApiPlatform): PerTier<AdditionalTierScheme> {
|
||||
val tiersDescriptors = apiEnvironment.getDescriptorsOfTiers(tiers)
|
||||
|
||||
return tiers.associateWith { tier ->
|
||||
val tierDescriptors = tiersDescriptors.getValue(tier)
|
||||
val descriptionDeclaration = buildTierDescriptionDeclaration(tierDescriptors)
|
||||
AdditionalTierScheme(descriptionDeclaration)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An exception that indicates that for some reason, it was not possible to provide an ML model
|
||||
* when calling [com.intellij.platform.ml.impl.model.MLModel.Provider.provideModel].
|
||||
* The session's start is considered as failed.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class ModelNotAcquiredOutcome<P : Any> : Session.StartOutcome.Failure<P>() {
|
||||
override val failureDetails: String = "ML Model was not provided"
|
||||
}
|
||||
|
||||
/**
|
||||
* There were not enough tiers to satisfy [MLModel.Provider]'s requirements, so it could not provide the model.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class InsufficientEnvironmentForModelProviderOutcome<P : Any>(
|
||||
expectedTiers: Set<Tier<*>>,
|
||||
existingTiers: Set<Tier<*>>
|
||||
) : Session.StartOutcome.Failure<P>() {
|
||||
override val failureDetails: String = "ML Model could not be provided: environment is not sufficient. Missing: ${expectedTiers - existingTiers}"
|
||||
}
|
||||
@@ -0,0 +1,81 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.approach
|
||||
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.PerTierInstance
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.*
|
||||
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
|
||||
import com.intellij.platform.ml.impl.session.analysis.StructureAnalyser
|
||||
import com.intellij.platform.ml.impl.session.analysis.StructureAnalysisDeclaration
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
/**
|
||||
* An analysis method, that asynchronously analyzes session structure after
|
||||
* it finished, and the ML model that has been used during the session.
|
||||
*
|
||||
* @param structureAnalysers Analyzing session's structure - main tier instances.
|
||||
* @param mlModelAnalysers Analyzing the ML model which was producing predictions during the session.
|
||||
* @param sessionAnalysisKeyModel Key that will be used in logs, to write ML model's features into.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class StructureAndModelAnalysis<M : MLModel<P>, P : Any>(
|
||||
structureAnalysers: Collection<StructureAnalyser<M, P>>,
|
||||
mlModelAnalysers: Collection<MLModelAnalyser<M, P>>,
|
||||
private val sessionAnalysisKeyModel: String = DEFAULT_SESSION_KEY_ML_MODEL
|
||||
) : AnalysisMethod<M, P> {
|
||||
private val groupedAnalyser = JoinedGroupedSessionAnalyser(structureAnalysers, mlModelAnalysers)
|
||||
|
||||
override val structureAnalysisDeclaration: StructureAnalysisDeclaration
|
||||
get() = groupedAnalyser.analysisDeclaration.structureAnalysis
|
||||
|
||||
override val sessionAnalysisDeclaration: Map<String, Set<FeatureDeclaration<*>>> = mapOf(
|
||||
sessionAnalysisKeyModel to groupedAnalyser.analysisDeclaration.mlModelAnalysis
|
||||
)
|
||||
|
||||
override fun analyseTree(treeRoot: DescribedRootContainer<M, P>): CompletableFuture<AnalysedRootContainer<P>> {
|
||||
return groupedAnalyser.analyse(treeRoot).thenApply {
|
||||
buildAnalysedSessionTree(treeRoot, it) as AnalysedRootContainer<P>
|
||||
}
|
||||
}
|
||||
|
||||
private fun buildAnalysedSessionTree(tree: DescribedSessionTree<M, P>, analysis: GroupedAnalysis<M, P>): AnalysedSessionTree<P> {
|
||||
val treeAnalysisPerInstance: PerTierInstance<AnalysedTierData> = tree.level.main.entries
|
||||
.associate { (tierInstance, data) ->
|
||||
tierInstance to AnalysedTierData(data.description,
|
||||
analysis.structureAnalysis[tree]?.get(tierInstance.tier) ?: emptySet())
|
||||
}
|
||||
|
||||
val analysedLevel = Level(
|
||||
main = treeAnalysisPerInstance,
|
||||
additional = tree.level.additional
|
||||
)
|
||||
|
||||
return when (tree) {
|
||||
is SessionTree.Branching<M, DescribedLevel, P> -> {
|
||||
SessionTree.Branching(analysedLevel,
|
||||
tree.children.map { buildAnalysedSessionTree(it, analysis) })
|
||||
}
|
||||
is SessionTree.Leaf<M, DescribedLevel, P> -> {
|
||||
SessionTree.Leaf(analysedLevel, tree.prediction)
|
||||
}
|
||||
is SessionTree.ComplexRoot -> {
|
||||
SessionTree.ComplexRoot(mapOf(sessionAnalysisKeyModel to analysis.mlModelAnalysis),
|
||||
analysedLevel,
|
||||
tree.children.map { buildAnalysedSessionTree(it, analysis) }
|
||||
)
|
||||
}
|
||||
is SessionTree.SolitaryLeaf -> {
|
||||
SessionTree.SolitaryLeaf(mapOf(sessionAnalysisKeyModel to analysis.mlModelAnalysis),
|
||||
analysedLevel,
|
||||
tree.prediction
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val DEFAULT_SESSION_KEY_ML_MODEL: String = "ml_model"
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,37 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.approach
|
||||
|
||||
import com.intellij.openapi.util.Version
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.DescribedRootContainer
|
||||
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
/**
|
||||
* Something, that has versions.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface Versioned {
|
||||
val version: Version?
|
||||
}
|
||||
|
||||
/**
|
||||
* Adds model's version to the ML logs.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class ModelVersionAnalyser<M, P : Any> : MLModelAnalyser<M, P>
|
||||
where M : MLModel<P>,
|
||||
M : Versioned {
|
||||
companion object {
|
||||
val VERSION = FeatureDeclaration.version("version").nullable()
|
||||
}
|
||||
|
||||
override val analysisDeclaration = setOf(VERSION)
|
||||
|
||||
override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<Set<Feature>> = CompletableFuture.completedFuture(
|
||||
setOf(VERSION with sessionTreeRoot.root.version)
|
||||
)
|
||||
}
|
||||
@@ -0,0 +1,35 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.environment
|
||||
|
||||
import com.intellij.platform.ml.EnvironmentExtender
|
||||
import com.intellij.platform.ml.Tier
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* There was circle within available extenders' set's requirements.
|
||||
* Meaning that to create some tier, we must eventually run the extender itself.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class CircularRequirementException(val extensionPath: List<EnvironmentExtender<*>>) : IllegalArgumentException() {
|
||||
override val message: String = "A circular resolve path found among EnvironmentExtenders: ${serializePath()}"
|
||||
|
||||
private fun serializePath(): String {
|
||||
val extensions = extensionPath.map { extender -> "[$extender] -> ${extender.extendingTier.name}" }
|
||||
return extensions.joinToString(" - ")
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* An algorithm for resolving order of [EnvironmentExtender]s' execution.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface EnvironmentResolver {
|
||||
/**
|
||||
* @return The order, which guarantees that for each extender, all the requirements will be fulfilled by the previously runned
|
||||
* extenders.
|
||||
* But it still could happen that an [EnvironmentExtender] will not return the tier it extends, then some subsequent
|
||||
* extenders' requirements could not be satisfied.
|
||||
* @throws CircularRequirementException
|
||||
*/
|
||||
fun resolve(extenderPerTier: Map<Tier<*>, EnvironmentExtender<*>>): List<EnvironmentExtender<*>>
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.environment
|
||||
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.EnvironmentExtender.Companion.extendTierInstance
|
||||
import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyByOrNull
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* An environment that is built in to fulfill a [TierRequester]'s requirements.
|
||||
*
|
||||
* When built, it accepts all available extenders, and it is trying to resolve their
|
||||
* order to extend particular tiers.
|
||||
*
|
||||
* If there are some main tiers passed to build the extended environment, they will
|
||||
* not be overridden.
|
||||
*
|
||||
* Among the available extenders, passed to the constructor, there could be some that
|
||||
* extend the same tier.
|
||||
* This could signify that an instance of the same tier can be mined from
|
||||
* different sets of objects (requirements).
|
||||
* But if there is more than one extender, that could potentially be run that will
|
||||
* extend a tier of the same instance, then [IllegalArgumentException] will be thrown
|
||||
* telling, that there is an ambiguity.
|
||||
*
|
||||
* When we have determined which extenders could potentially be run, we try to determine
|
||||
* the order with topological sort: [com.intellij.platform.ml.impl.environment.TopologicalSortingResolver].
|
||||
* If there is a circle in the requirements, it will throw [com.intellij.platform.ml.impl.environment.CircularRequirementException].
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class ExtendedEnvironment : Environment {
|
||||
private val storage: Environment
|
||||
|
||||
/**
|
||||
* @param environmentExtenders Extenders, that will be used to extend the [tiersToExtend].
|
||||
* @param mainEnvironment Tiers that are already determined and shall not be replaced.
|
||||
* @param tiersToExtend Tiers that should be put to the extended environment via [environmentExtenders].
|
||||
* It could not be guaranteed that all desired tiers will be extended.
|
||||
*
|
||||
* @return An environment that contains all tiers from [mainEnvironment] plus
|
||||
* some tiers from [tiersToExtend], if it was possible to extend them.
|
||||
*/
|
||||
constructor(environmentExtenders: List<EnvironmentExtender<*>>,
|
||||
mainEnvironment: Environment,
|
||||
tiersToExtend: Set<Tier<*>>) {
|
||||
require(tiersToExtend.all { it !in mainEnvironment })
|
||||
val nonOverridingExtenders = environmentExtenders.filter { it.extendingTier !in mainEnvironment }
|
||||
storage = buildExtendedEnvironment(
|
||||
tiersToExtend + mainEnvironment.tiers,
|
||||
nonOverridingExtenders + mainEnvironment.separateIntoExtenders()
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* @param environmentExtenders Extenders that will be utilized to build the extended environment.
|
||||
* @param mainEnvironment An already existing environment, instances from which shall not be overridden.
|
||||
*
|
||||
* @return An environment that contains all tiers from [mainEnvironment] plus
|
||||
* all tiers that it was possible to acquire via [environmentExtenders].
|
||||
*/
|
||||
constructor(environmentExtenders: List<EnvironmentExtender<*>>,
|
||||
mainEnvironment: Environment) {
|
||||
val nonOverridingExtenders = environmentExtenders.filter { it.extendingTier !in mainEnvironment }
|
||||
storage = buildExtendedEnvironment(
|
||||
nonOverridingExtenders.map { it.extendingTier }.toSet() + mainEnvironment.tiers,
|
||||
nonOverridingExtenders + mainEnvironment.separateIntoExtenders()
|
||||
)
|
||||
}
|
||||
|
||||
override val tiers: Set<Tier<*>>
|
||||
get() = storage.tiers
|
||||
|
||||
override fun <T : Any> getInstance(tier: Tier<T>): T {
|
||||
return storage.getInstance(tier)
|
||||
}
|
||||
|
||||
companion object {
|
||||
private val ENVIRONMENT_RESOLVER = TopologicalSortingResolver()
|
||||
|
||||
/**
|
||||
* Creates an [Environment] that contents tiers from [tiers], that were successfully extended by [extenders]
|
||||
*/
|
||||
private fun buildExtendedEnvironment(tiers: Set<Tier<*>>,
|
||||
extenders: List<EnvironmentExtender<*>>): Environment {
|
||||
val validatedExtendersPerTier = validateExtenders(tiers, extenders)
|
||||
val extensionOrder = ENVIRONMENT_RESOLVER.resolve(validatedExtendersPerTier)
|
||||
val storage = TierInstanceStorage()
|
||||
|
||||
extensionOrder.map {
|
||||
val safelyAccessibleEnvironment = storage.accessibleSafelyByOrNull(it) ?: return@map
|
||||
it.extendTierInstance(safelyAccessibleEnvironment)?.let { extendedTierInstance ->
|
||||
storage.putTierInstance(extendedTierInstance)
|
||||
}
|
||||
}
|
||||
return storage
|
||||
}
|
||||
|
||||
private fun validateExtenders(tiers: Set<Tier<*>>, extenders: List<EnvironmentExtender<*>>): Map<Tier<*>, EnvironmentExtender<*>> {
|
||||
val extendableTiers: Set<Tier<*>> = extenders.map { it.extendingTier }.toSet()
|
||||
|
||||
val runnableExtenders = extenders
|
||||
.filter { desiredExtender ->
|
||||
desiredExtender.requiredTiers.all { requirementForDesiredExtender -> requirementForDesiredExtender in extendableTiers }
|
||||
}
|
||||
|
||||
val ambiguouslyExtendableTiers: MutableList<Pair<Tier<*>, List<EnvironmentExtender<*>>>> = mutableListOf()
|
||||
val extendersPerTier: Map<Tier<*>, EnvironmentExtender<*>> = runnableExtenders
|
||||
.groupBy { it.extendingTier }
|
||||
.mapNotNull { (tier, tierExtenders) ->
|
||||
if (tierExtenders.size > 1) {
|
||||
ambiguouslyExtendableTiers.add(tier to tierExtenders)
|
||||
null
|
||||
}
|
||||
else
|
||||
tierExtenders.first()
|
||||
}
|
||||
.associateBy { it.extendingTier }
|
||||
.filterKeys { it in tiers }
|
||||
|
||||
require(ambiguouslyExtendableTiers.isEmpty()) { "Some tiers could be extended ambiguously: $ambiguouslyExtendableTiers" }
|
||||
|
||||
return extendersPerTier
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun Environment.separateIntoExtenders(): List<EnvironmentExtender<*>> {
|
||||
class ContainingExtender<T : Any>(private val tier: Tier<T>) : EnvironmentExtender<T> {
|
||||
override val extendingTier: Tier<T> = tier
|
||||
|
||||
override fun extend(environment: Environment): T {
|
||||
return this@separateIntoExtenders[tier]
|
||||
}
|
||||
|
||||
override val requiredTiers: Set<Tier<*>> = emptySet()
|
||||
}
|
||||
|
||||
return this.tiers.map { tier -> ContainingExtender(tier) }
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.environment
|
||||
|
||||
import com.intellij.platform.ml.EnvironmentExtender
|
||||
import com.intellij.platform.ml.Tier
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
private typealias Node = EnvironmentExtender<*>
|
||||
|
||||
/**
|
||||
* Resolves order using topological sort
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class TopologicalSortingResolver : EnvironmentResolver {
|
||||
|
||||
override fun resolve(extenderPerTier: Map<Tier<*>, EnvironmentExtender<*>>): List<EnvironmentExtender<*>> {
|
||||
val graph: Map<Node, List<Node>> = extenderPerTier.values
|
||||
.associateWith { desiredExtender ->
|
||||
desiredExtender.requiredTiers.map { requirementForDesiredExtender -> extenderPerTier.getValue(requirementForDesiredExtender) }
|
||||
}
|
||||
|
||||
val reverseTopologicalOrder: MutableList<Node> = mutableListOf()
|
||||
val resolveStatus: MutableMap<Node, ResolveState> = mutableMapOf()
|
||||
|
||||
fun Node.resolve(path: List<Node>) {
|
||||
when (resolveStatus[this]) {
|
||||
ResolveState.STARTED -> throw CircularRequirementException(path + this)
|
||||
ResolveState.RESOLVED -> return
|
||||
null -> {
|
||||
resolveStatus[this] = ResolveState.STARTED
|
||||
for (nextNode in graph.getValue(this)) {
|
||||
nextNode.resolve(path + this)
|
||||
}
|
||||
resolveStatus[this] = ResolveState.RESOLVED
|
||||
reverseTopologicalOrder.add(this)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
graph.keys.forEach { it.resolve(emptyList()) }
|
||||
|
||||
return reverseTopologicalOrder
|
||||
}
|
||||
|
||||
private enum class ResolveState {
|
||||
STARTED,
|
||||
RESOLVED
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,49 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.logger
|
||||
|
||||
import com.intellij.internal.statistic.eventLog.events.EventPair
|
||||
import com.intellij.internal.statistic.eventLog.events.ObjectDescription
|
||||
import com.intellij.internal.statistic.eventLog.events.ObjectEventData
|
||||
import com.intellij.platform.ml.impl.MLTaskApproach
|
||||
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
|
||||
import com.intellij.platform.ml.impl.session.AnalysedSessionTree
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Represents FUS fields of a session's subtree.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
abstract class SessionFields<P : Any> : ObjectDescription() {
|
||||
fun buildObjectEventData(sessionStructure: AnalysedSessionTree<P>) = ObjectEventData(buildEventPairs(sessionStructure))
|
||||
|
||||
abstract fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>>
|
||||
}
|
||||
|
||||
/**
|
||||
* Represents a logging scheme for the FUS event.
|
||||
*
|
||||
* @param P The type of the ML task's prediction
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface FusSessionEventBuilder<P : Any> {
|
||||
/**
|
||||
* Configuration of a [FusSessionEventBuilder], that builds it when accepts approach's declaration.
|
||||
*/
|
||||
interface FusScheme<P : Any> {
|
||||
fun createEventBuilder(approachDeclaration: MLTaskApproach.Declaration): FusSessionEventBuilder<P>
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds declaration of all features, that will be logged for the session tiers' description and analysis.
|
||||
* It is required because FUS logs validators are built 'statically'.
|
||||
*/
|
||||
fun buildFusDeclaration(): SessionFields<P>
|
||||
|
||||
/**
|
||||
* Builds a concrete FUS record, that contains fields that were built by [buildFusDeclaration].
|
||||
*
|
||||
* @param sessionStructure A session tree that been already analyzed and is ready to be logged.
|
||||
* @param sessionFields Session fields that were built by [buildFusDeclaration] earlier.
|
||||
*/
|
||||
fun buildRecord(sessionStructure: AnalysedRootContainer<P>, sessionFields: SessionFields<P>): Array<EventPair<*>>
|
||||
}
|
||||
@@ -0,0 +1,390 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.logger
|
||||
|
||||
import com.intellij.internal.statistic.eventLog.FeatureUsageData
|
||||
import com.intellij.internal.statistic.eventLog.events.*
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.impl.*
|
||||
import com.intellij.platform.ml.impl.session.*
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* The currently used FUS logging scheme.
|
||||
* Inplace means that the features are logged beside the tiers' instances and
|
||||
* are not compressed in any way.
|
||||
*
|
||||
* An example of a FUS record could be found in the test resources:
|
||||
* [testResources/ml_logs.js](community/platform/ml-impl/testResources/ml_logs.js)
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class InplaceFeaturesScheme<P : Any> internal constructor(
|
||||
private val predictionValidationRule: List<String>,
|
||||
private val predictionTransform: (P?) -> String,
|
||||
private val approachDeclaration: MLTaskApproach.Declaration
|
||||
) : FusSessionEventBuilder<P> {
|
||||
class FusScheme<P : Any>(
|
||||
private val predictionValidationRule: List<String>,
|
||||
private val predictionTransform: (P?) -> String
|
||||
) : FusSessionEventBuilder.FusScheme<P> {
|
||||
override fun createEventBuilder(approachDeclaration: MLTaskApproach.Declaration): FusSessionEventBuilder<P> = InplaceFeaturesScheme(
|
||||
predictionValidationRule,
|
||||
predictionTransform,
|
||||
approachDeclaration
|
||||
)
|
||||
|
||||
companion object {
|
||||
val DOUBLE: FusScheme<Double> = FusScheme(listOf("{regexp#float}")) { it.toString() }
|
||||
}
|
||||
}
|
||||
|
||||
override fun buildFusDeclaration(): SessionFields<P> {
|
||||
require(approachDeclaration.levelsScheme.isNotEmpty())
|
||||
return if (approachDeclaration.levelsScheme.size == 1)
|
||||
PredictionSessionFields(approachDeclaration.levelsScheme.first(), predictionValidationRule, predictionTransform,
|
||||
approachDeclaration.sessionFeatures)
|
||||
else
|
||||
NestableSessionFields(approachDeclaration.levelsScheme.first(), approachDeclaration.levelsScheme.drop(1), predictionValidationRule,
|
||||
predictionTransform, approachDeclaration.sessionFeatures)
|
||||
}
|
||||
|
||||
override fun buildRecord(sessionStructure: AnalysedRootContainer<P>, sessionFields: SessionFields<P>): Array<EventPair<*>> {
|
||||
return sessionFields.buildEventPairs(sessionStructure).toTypedArray()
|
||||
}
|
||||
}
|
||||
|
||||
private class PredictionField<T : Any>(
|
||||
override val name: String,
|
||||
override val validationRule: List<String>,
|
||||
val transform: (T?) -> String
|
||||
) : PrimitiveEventField<T?>() {
|
||||
override fun addData(fuData: FeatureUsageData, value: T?) {
|
||||
fuData.addData(name, transform(value))
|
||||
}
|
||||
}
|
||||
|
||||
private data class StringField(override val name: String, private val possibleValues: Set<String>) : PrimitiveEventField<String>() {
|
||||
override fun addData(fuData: FeatureUsageData, value: String) {
|
||||
fuData.addData(name, value)
|
||||
}
|
||||
|
||||
override val validationRule = listOf(
|
||||
"{enum:${possibleValues.joinToString("|")}}"
|
||||
)
|
||||
}
|
||||
|
||||
private data class VersionField(override val name: String) : PrimitiveEventField<String?>() {
|
||||
override val validationRule: List<String>
|
||||
get() = listOf("{regexp#version}")
|
||||
|
||||
override fun addData(fuData: FeatureUsageData, value: String?) {
|
||||
fuData.addVersionByString(value)
|
||||
}
|
||||
}
|
||||
|
||||
private fun FeatureDeclaration<*>.toEventField(): EventField<*> {
|
||||
return when (val valueType = type) {
|
||||
is FeatureValueType.Enum<*> -> EnumEventField(name, valueType.enumClass, Enum<*>::name)
|
||||
is FeatureValueType.Int -> IntEventField(name)
|
||||
is FeatureValueType.Long -> LongEventField(name)
|
||||
is FeatureValueType.Class -> ClassEventField(name)
|
||||
is FeatureValueType.Boolean -> BooleanEventField(name)
|
||||
is FeatureValueType.Double -> DoubleEventField(name)
|
||||
is FeatureValueType.Float -> FloatEventField(name)
|
||||
is FeatureValueType.Nullable -> FeatureDeclaration(name, valueType.baseType).toEventField()
|
||||
is FeatureValueType.Categorical -> StringField(name, valueType.possibleValues)
|
||||
is FeatureValueType.Version -> VersionField(name)
|
||||
}
|
||||
}
|
||||
|
||||
private fun <T : Enum<*>> Feature.Enum<T>.toEventPair(): EventPair<*> {
|
||||
return EnumEventField(declaration.name, valueType.enumClass, Enum<*>::name) with value
|
||||
}
|
||||
|
||||
private fun <T> Feature.Nullable<T>.toEventPair(): EventPair<*>? {
|
||||
return value?.let {
|
||||
baseType.instantiate(this.declaration.name, it).toEventPair()
|
||||
}
|
||||
}
|
||||
|
||||
private fun Feature.toEventPair(): EventPair<*>? {
|
||||
return when (this) {
|
||||
is Feature.TypedFeature<*> -> typedToEventPair()
|
||||
}
|
||||
}
|
||||
|
||||
private fun <T> Feature.TypedFeature<T>.typedToEventPair(): EventPair<*>? {
|
||||
return when (this) {
|
||||
is Feature.Boolean -> BooleanEventField(declaration.name) with this.value
|
||||
is Feature.Categorical -> StringField(declaration.name, this.valueType.possibleValues) with this.value
|
||||
is Feature.Class -> ClassEventField(declaration.name) with this.value
|
||||
is Feature.Double -> DoubleEventField(declaration.name) with this.value
|
||||
is Feature.Enum<*> -> toEventPair()
|
||||
is Feature.Float -> FloatEventField(declaration.name) with this.value
|
||||
is Feature.Int -> IntEventField(declaration.name) with this.value
|
||||
is Feature.Long -> LongEventField(declaration.name) with this.value
|
||||
is Feature.Nullable<*> -> toEventPair()
|
||||
is Feature.Version -> VersionField(declaration.name) with this.value.toString()
|
||||
}
|
||||
}
|
||||
|
||||
private class FeatureSet(featuresDeclarations: Set<FeatureDeclaration<*>>) : ObjectDescription() {
|
||||
init {
|
||||
for (featureDeclaration in featuresDeclarations) {
|
||||
field(featureDeclaration.toEventField())
|
||||
}
|
||||
}
|
||||
|
||||
fun toObjectEventData(features: Set<Feature>) = ObjectEventData(features.mapNotNull { it.toEventPair() })
|
||||
}
|
||||
|
||||
private fun Set<Feature>.toObjectEventData() = FeatureSet(this.map { it.declaration }.toSet()).toObjectEventData(this)
|
||||
|
||||
private data class TierDescriptionFields(
|
||||
val used: FeatureSet,
|
||||
val notUsed: FeatureSet,
|
||||
) : ObjectDescription() {
|
||||
private val fieldUsed = ObjectEventField("used", used)
|
||||
private val fieldNotUsed = ObjectEventField("not_used", notUsed)
|
||||
private val fieldAmountUsedNonDeclaredFeatures = IntEventField("n_used_non_declared")
|
||||
private val fieldAmountNotUsedNonDeclaredFeatures = IntEventField("n_not_used_non_declared")
|
||||
|
||||
init {
|
||||
field(fieldUsed)
|
||||
field(fieldNotUsed)
|
||||
field(fieldAmountUsedNonDeclaredFeatures)
|
||||
field(fieldAmountNotUsedNonDeclaredFeatures)
|
||||
}
|
||||
|
||||
fun buildEventPairs(descriptionPartition: DescriptionPartition): List<EventPair<*>> {
|
||||
val result = mutableListOf<EventPair<*>>(
|
||||
fieldUsed with descriptionPartition.declared.used.toObjectEventData(),
|
||||
fieldNotUsed with descriptionPartition.declared.notUsed.toObjectEventData(),
|
||||
)
|
||||
descriptionPartition.nonDeclared.used.let {
|
||||
if (it.isNotEmpty()) result += fieldAmountUsedNonDeclaredFeatures with it.size
|
||||
}
|
||||
descriptionPartition.nonDeclared.notUsed.let {
|
||||
if (it.isNotEmpty()) result += fieldAmountUsedNonDeclaredFeatures with it.size
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
fun buildObjectEventData(descriptionPartition: DescriptionPartition) = ObjectEventData(
|
||||
buildEventPairs(descriptionPartition)
|
||||
)
|
||||
}
|
||||
|
||||
private data class AdditionalTierFields(val description: TierDescriptionFields) : ObjectDescription() {
|
||||
private val fieldInstanceId = IntEventField("id")
|
||||
private val fieldDescription = ObjectEventField("description", description)
|
||||
|
||||
constructor(descriptionFeatures: Set<FeatureDeclaration<*>>)
|
||||
: this(TierDescriptionFields(used = FeatureSet(descriptionFeatures),
|
||||
notUsed = FeatureSet(descriptionFeatures)))
|
||||
|
||||
init {
|
||||
field(fieldInstanceId)
|
||||
field(fieldDescription)
|
||||
}
|
||||
|
||||
fun buildObjectEventData(tierInstance: TierInstance<*>,
|
||||
descriptionPartition: DescriptionPartition) = ObjectEventData(
|
||||
fieldInstanceId with tierInstance.instance.hashCode(),
|
||||
fieldDescription with this.description.buildObjectEventData(descriptionPartition),
|
||||
)
|
||||
}
|
||||
|
||||
private data class MainTierFields(
|
||||
val description: TierDescriptionFields,
|
||||
val analysis: FeatureSet,
|
||||
) : ObjectDescription() {
|
||||
private val fieldInstanceId = IntEventField("id")
|
||||
private val fieldDescription = ObjectEventField("description", description)
|
||||
private val fieldAnalysis = ObjectEventField("analysis", analysis)
|
||||
|
||||
constructor(descriptionFeatures: Set<FeatureDeclaration<*>>, analysisFeatures: Set<FeatureDeclaration<*>>)
|
||||
: this(TierDescriptionFields(used = FeatureSet(descriptionFeatures),
|
||||
notUsed = FeatureSet(descriptionFeatures)),
|
||||
FeatureSet(analysisFeatures))
|
||||
|
||||
init {
|
||||
field(fieldInstanceId)
|
||||
field(fieldDescription)
|
||||
field(fieldAnalysis)
|
||||
}
|
||||
|
||||
fun buildObjectEventData(tierInstance: TierInstance<*>,
|
||||
descriptionPartition: DescriptionPartition,
|
||||
analysis: Set<Feature>) = ObjectEventData(
|
||||
fieldInstanceId with tierInstance.instance.hashCode(),
|
||||
fieldDescription with this.description.buildObjectEventData(descriptionPartition),
|
||||
fieldAnalysis with this.analysis.toObjectEventData(analysis)
|
||||
)
|
||||
}
|
||||
|
||||
private data class SessionAnalysisFields<P : Any>(
|
||||
val featuresPerKey: Map<String, Set<FeatureDeclaration<*>>>
|
||||
) : SessionFields<P>() {
|
||||
val fieldsPerKey: Map<String, ObjectEventField> = featuresPerKey.entries.associate { (key, keyFeatures) ->
|
||||
key to ObjectEventField(key, FeatureSet(keyFeatures))
|
||||
}
|
||||
|
||||
init {
|
||||
fieldsPerKey.values.forEach { field(it) }
|
||||
}
|
||||
|
||||
override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
|
||||
require(sessionStructure is SessionTree.RootContainer<SessionAnalysis, AnalysedLevel, P>)
|
||||
return sessionStructure.root.entries.map { (key, keyFeatures) ->
|
||||
val keyFeaturesDeclaration = requireNotNull(featuresPerKey[key]) {
|
||||
"Key $key was not declared as session features key, declared keys: ${featuresPerKey.keys}"
|
||||
}
|
||||
val objectEventField = fieldsPerKey.getValue(key)
|
||||
val keyFeatureSet = FeatureSet(keyFeaturesDeclaration)
|
||||
objectEventField with keyFeatureSet.toObjectEventData(keyFeatures)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class MainTierSet<P : Any>(mainTierScheme: PerTier<MainTierScheme>) : SessionFields<P>() {
|
||||
val tiersDeclarations: PerTier<MainTierFields> = mainTierScheme.entries.associate { (tier, tierScheme) ->
|
||||
tier to MainTierFields(tierScheme.description, tierScheme.analysis)
|
||||
}
|
||||
val fieldPerTier: PerTier<ObjectEventField> = tiersDeclarations.entries.associate { (tier, tierFields) ->
|
||||
tier to ObjectEventField(tier.name, tierFields)
|
||||
}
|
||||
|
||||
init {
|
||||
fieldPerTier.values.forEach { field(it) }
|
||||
}
|
||||
|
||||
override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
|
||||
val level = sessionStructure.level.main
|
||||
return level.entries.map { (tierInstance, data) ->
|
||||
val tierField = requireNotNull(fieldPerTier[tierInstance.tier]) {
|
||||
"Tier ${tierInstance.tier} is now allowed here: only ${fieldPerTier.keys} are registered"
|
||||
}
|
||||
val tierDeclaration = tiersDeclarations.getValue(tierInstance.tier)
|
||||
tierField with tierDeclaration.buildObjectEventData(tierInstance, data.description, data.analysis)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class AdditionalTierSet<P : Any>(additionalTierScheme: PerTier<AdditionalTierScheme>) : SessionFields<P>() {
|
||||
val tiersDeclarations: PerTier<AdditionalTierFields> = additionalTierScheme.entries.associate { (tier, tierScheme) ->
|
||||
tier to AdditionalTierFields(tierScheme.description)
|
||||
}
|
||||
val fieldPerTier: PerTier<ObjectEventField> = tiersDeclarations.entries.associate { (tier, tierFields) ->
|
||||
tier to ObjectEventField(tier.name, tierFields)
|
||||
}
|
||||
|
||||
init {
|
||||
fieldPerTier.values.forEach { field(it) }
|
||||
}
|
||||
|
||||
override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
|
||||
val level = sessionStructure.level.additional
|
||||
return level.entries.map { (tierInstance, data) ->
|
||||
val tierField = requireNotNull(fieldPerTier[tierInstance.tier]) {
|
||||
"Tier ${tierInstance.tier} is now allowed here: only ${fieldPerTier.keys} are registered"
|
||||
}
|
||||
val tierDeclaration = tiersDeclarations.getValue(tierInstance.tier)
|
||||
tierField with tierDeclaration.buildObjectEventData(tierInstance, data.description)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private data class PredictionSessionFields<P : Any>(
|
||||
val declarationMainTierSet: MainTierSet<P>,
|
||||
val declarationAdditionalTierSet: AdditionalTierSet<P>,
|
||||
val predictionValidationRule: List<String>,
|
||||
val predictionTransform: (P?) -> String,
|
||||
val sessionAnalysisFields: SessionAnalysisFields<P>?
|
||||
) : SessionFields<P>() {
|
||||
private val fieldMainInstances = ObjectEventField("main", declarationMainTierSet)
|
||||
private val fieldAdditionalInstances = ObjectEventField("additional", declarationAdditionalTierSet)
|
||||
private val fieldPrediction = PredictionField("prediction", predictionValidationRule, predictionTransform)
|
||||
private val fieldSessionAnalysis = sessionAnalysisFields?.let { ObjectEventField("session", sessionAnalysisFields) }
|
||||
|
||||
constructor(levelScheme: LevelScheme,
|
||||
predictionValidationRule: List<String>,
|
||||
predictionTransform: (P?) -> String,
|
||||
sessionAnalysisFields: Map<String, Set<FeatureDeclaration<*>>>?)
|
||||
: this(MainTierSet(levelScheme.main),
|
||||
AdditionalTierSet(levelScheme.additional),
|
||||
predictionValidationRule,
|
||||
predictionTransform,
|
||||
sessionAnalysisFields?.let { SessionAnalysisFields(it) })
|
||||
|
||||
init {
|
||||
field(fieldMainInstances)
|
||||
field(fieldAdditionalInstances)
|
||||
field(fieldPrediction)
|
||||
fieldSessionAnalysis?.let { field(it) }
|
||||
}
|
||||
|
||||
override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
|
||||
require(sessionStructure is SessionTree.Leaf<*, *, P>)
|
||||
val eventPairs = mutableListOf<EventPair<*>>(
|
||||
fieldMainInstances with declarationMainTierSet.buildObjectEventData(sessionStructure),
|
||||
fieldAdditionalInstances with declarationAdditionalTierSet.buildObjectEventData(sessionStructure),
|
||||
fieldPrediction with sessionStructure.prediction,
|
||||
)
|
||||
fieldSessionAnalysis?.let {
|
||||
eventPairs += fieldSessionAnalysis with sessionAnalysisFields!!.buildObjectEventData(sessionStructure)
|
||||
}
|
||||
return eventPairs
|
||||
}
|
||||
}
|
||||
|
||||
private data class NestableSessionFields<P : Any>(
|
||||
val declarationMainTierSet: MainTierSet<P>,
|
||||
val declarationAdditionalTierSet: AdditionalTierSet<P>,
|
||||
val declarationNestedSession: SessionFields<P>,
|
||||
val sessionAnalysisFields: SessionAnalysisFields<P>?
|
||||
) : SessionFields<P>() {
|
||||
private val fieldMainInstances = ObjectEventField("main", declarationMainTierSet)
|
||||
private val fieldAdditionalInstances = ObjectEventField("additional", declarationAdditionalTierSet)
|
||||
private val fieldNestedSessions = ObjectListEventField("nested", declarationNestedSession)
|
||||
private val fieldSessionAnalysis = sessionAnalysisFields?.let { ObjectEventField("session", sessionAnalysisFields) }
|
||||
|
||||
constructor(levelScheme: LevelScheme,
|
||||
deeperLevelsSchemes: List<LevelScheme>,
|
||||
predictionValidationRule: List<String>,
|
||||
predictionTransform: (P?) -> String,
|
||||
sessionAnalysisFields: Map<String, Set<FeatureDeclaration<*>>>?)
|
||||
: this(MainTierSet(levelScheme.main),
|
||||
AdditionalTierSet(levelScheme.additional),
|
||||
if (deeperLevelsSchemes.size == 1)
|
||||
PredictionSessionFields(deeperLevelsSchemes.first(), predictionValidationRule, predictionTransform, null)
|
||||
else {
|
||||
require(deeperLevelsSchemes.size > 1)
|
||||
NestableSessionFields(deeperLevelsSchemes.first(), deeperLevelsSchemes.drop(1), predictionValidationRule,
|
||||
predictionTransform, null)
|
||||
},
|
||||
sessionAnalysisFields?.let { SessionAnalysisFields(it) }
|
||||
)
|
||||
|
||||
init {
|
||||
field(fieldMainInstances)
|
||||
field(fieldAdditionalInstances)
|
||||
field(fieldNestedSessions)
|
||||
fieldSessionAnalysis?.let { field(it) }
|
||||
}
|
||||
|
||||
override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
|
||||
require(sessionStructure is SessionTree.ChildrenContainer)
|
||||
val children = sessionStructure.children
|
||||
val eventPairs = mutableListOf<EventPair<*>>(
|
||||
fieldMainInstances with declarationMainTierSet.buildObjectEventData(sessionStructure),
|
||||
fieldAdditionalInstances with declarationAdditionalTierSet.buildObjectEventData(sessionStructure),
|
||||
fieldNestedSessions with children.map { nestedSession -> declarationNestedSession.buildObjectEventData(nestedSession) }
|
||||
)
|
||||
|
||||
fieldSessionAnalysis?.let {
|
||||
eventPairs += fieldSessionAnalysis with sessionAnalysisFields!!.buildObjectEventData(sessionStructure)
|
||||
}
|
||||
|
||||
return eventPairs
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,36 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.logger
|
||||
|
||||
import com.intellij.internal.statistic.eventLog.events.BooleanEventField
|
||||
import com.intellij.internal.statistic.eventLog.events.ClassEventField
|
||||
import com.intellij.internal.statistic.eventLog.events.EventField
|
||||
import com.intellij.internal.statistic.eventLog.events.EventPair
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApiStartupProcessListener
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
@ApiStatus.Internal
|
||||
class MLApiPlatformStartupLogger : EventIdRecordingMLEvent(), MLApiStartupListener {
|
||||
companion object {
|
||||
val SUCCESS = BooleanEventField("success")
|
||||
val EXCEPTION = ClassEventField("exception")
|
||||
}
|
||||
|
||||
override val eventName: String = "startup"
|
||||
|
||||
override val declaration: Array<EventField<*>> = arrayOf(SUCCESS)
|
||||
|
||||
override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener {
|
||||
val eventId = getEventId(apiPlatform)
|
||||
return object : MLApiStartupProcessListener {
|
||||
override fun onFinished() = eventId.log(SUCCESS with true)
|
||||
|
||||
override fun onFailed(exception: Throwable?) {
|
||||
val fields = mutableListOf<EventPair<*>>(SUCCESS with false)
|
||||
exception?.let { fields += EXCEPTION with it.javaClass }
|
||||
eventId.log(*fields.toTypedArray())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,121 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.logger
|
||||
|
||||
import com.intellij.internal.statistic.eventLog.EventLogGroup
|
||||
import com.intellij.internal.statistic.eventLog.events.EventField
|
||||
import com.intellij.internal.statistic.eventLog.events.VarargEventId
|
||||
import com.intellij.internal.statistic.service.fus.collectors.CounterUsagesCollector
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.ensureApproachesInitialized
|
||||
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
@ApiStatus.Internal
|
||||
interface MLEvent {
|
||||
val eventName: String
|
||||
|
||||
val declaration: Array<EventField<*>>
|
||||
|
||||
fun onEventGroupInitialized(eventId: VarargEventId)
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
abstract class EventIdRecordingMLEvent : MLEvent {
|
||||
private var providedEventId: VarargEventId? = null
|
||||
|
||||
protected fun getEventId(apiPlatform: MLApiPlatform): VarargEventId {
|
||||
MLEventsLogger.Manager.ensureInitialized(okIfInitializing = false, apiPlatform)
|
||||
return requireNotNull(providedEventId)
|
||||
}
|
||||
|
||||
final override fun onEventGroupInitialized(eventId: VarargEventId) {
|
||||
providedEventId = eventId
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* It logs ML sessions with tiers' descriptions and the analysis.
|
||||
* Each session is logged after it is finished, and all analyzers [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser]
|
||||
* have yielded the analysis.
|
||||
*
|
||||
* As the FUS logs' validators are initialized once during the application's start,
|
||||
* before giving the [getGroup] for the FUS to use, we must first walk through all declared
|
||||
* [com.intellij.platform.ml.impl.MLTaskApproach] in the platform and register them in the FUS scheme
|
||||
* (in case they require FUS logging).
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class MLEventsLogger : CounterUsagesCollector() {
|
||||
override fun getGroup(): EventLogGroup = Initializer.GROUP.also { Manager.ensureInitialized(okIfInitializing = false) }
|
||||
|
||||
object Manager {
|
||||
private val defaultPlatform = ReplaceableIJPlatform
|
||||
|
||||
internal fun ensureInitialized(okIfInitializing: Boolean, apiPlatform: MLApiPlatform = defaultPlatform) {
|
||||
when (val state = Initializer.state) {
|
||||
is State.FailedToInitialize -> throw Exception("ML Event Log already has failed to initialize", state.exception)
|
||||
State.Initializing -> if (okIfInitializing) return else throw IllegalStateException("Initialization recursion")
|
||||
State.NonInitialized -> Initializer.initializeGroup(apiPlatform)
|
||||
is State.Initialized -> {
|
||||
val currentApiPlatformState = apiPlatform.staticState
|
||||
require(currentApiPlatformState == state.context.staticState) {
|
||||
"""
|
||||
FUS ML Logger was initialized from ${state.context.apiPlatform} with presumably immutable state ${state.context.staticState},
|
||||
but it is used from ${apiPlatform} with state ${currentApiPlatformState}, which differs from the initial state.
|
||||
Hence, something that was expected to be logged will not be.
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
internal data class InitializationContext(
|
||||
val apiPlatform: MLApiPlatform,
|
||||
val staticState: MLApiPlatform.StaticState
|
||||
)
|
||||
|
||||
internal sealed interface State {
|
||||
data object NonInitialized : State
|
||||
|
||||
data object Initializing : State
|
||||
|
||||
class FailedToInitialize(val exception: Throwable) : State
|
||||
|
||||
class Initialized(val context: InitializationContext) : State
|
||||
}
|
||||
|
||||
internal object Initializer {
|
||||
val GROUP = EventLogGroup("ml", 1)
|
||||
|
||||
var state: State = State.NonInitialized
|
||||
|
||||
fun initializeGroup(apiPlatform: MLApiPlatform) {
|
||||
require(state == State.NonInitialized)
|
||||
state = State.Initializing
|
||||
try {
|
||||
apiPlatform.ensureApproachesInitialized()
|
||||
val apiPlatformState = apiPlatform.staticState
|
||||
|
||||
apiPlatform.addDefaultEventsAndListeners()
|
||||
|
||||
apiPlatform.events.forEach { event ->
|
||||
val eventId = GROUP.registerVarargEvent(event.eventName, *event.declaration)
|
||||
event.onEventGroupInitialized(eventId)
|
||||
}
|
||||
state = State.Initialized(InitializationContext(apiPlatform, apiPlatformState))
|
||||
}
|
||||
catch (e: Throwable) {
|
||||
state = State.FailedToInitialize(e)
|
||||
}
|
||||
(state as? State.FailedToInitialize)?.let {
|
||||
throw Exception("Failed to initialize FUS ML Logger", it.exception)
|
||||
}
|
||||
}
|
||||
|
||||
private fun MLApiPlatform.addDefaultEventsAndListeners() {
|
||||
val startupLogger = MLApiPlatformStartupLogger()
|
||||
addEvent(startupLogger)
|
||||
addStartupListener(startupLogger)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,98 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.logger
|
||||
|
||||
import com.intellij.internal.statistic.eventLog.events.EventField
|
||||
import com.intellij.internal.statistic.eventLog.events.ObjectEventData
|
||||
import com.intellij.internal.statistic.eventLog.events.ObjectEventField
|
||||
import com.intellij.platform.ml.Session
|
||||
import com.intellij.platform.ml.impl.MLTaskApproach
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.monitoring.*
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
|
||||
import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser
|
||||
import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser.Companion.declarationObjectDescription
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Logs to FUS information about an ML session, that has failed to start.
|
||||
*
|
||||
* A way to start logging, is to create a [FailedSessionLoggerRegister] and add it via [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.addStartupListener]
|
||||
*
|
||||
* @param taskApproach The approach that is monitored
|
||||
* @param exceptionalAnalysers Analyzers, that are triggered, when [MLTaskApproach.startSession] fails with an unhandled exception
|
||||
* @param normalFailureAnalysers Analyzers, that aer triggered, when [MLTaskApproach.startSession] returns a [Session.StartOutcome.Failure]
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class MLSessionFailedLogger<M, P : Any>(
|
||||
private val taskApproach: MLTaskApproach<P>,
|
||||
private val exceptionalAnalysers: Collection<ShallowSessionAnalyser<Throwable>>,
|
||||
private val normalFailureAnalysers: Collection<ShallowSessionAnalyser<Session.StartOutcome.Failure<P>>>,
|
||||
private val apiPlatform: MLApiPlatform,
|
||||
) : EventIdRecordingMLEvent(), MLTaskGroupListener {
|
||||
override val eventName: String = "${taskApproach.task.name}.failed"
|
||||
|
||||
private val fields: Map<String, ObjectEventField> = (
|
||||
exceptionalAnalysers.map { ObjectEventField(it.name, it.declarationObjectDescription) } +
|
||||
normalFailureAnalysers.map { ObjectEventField(it.name, it.declarationObjectDescription) })
|
||||
.associateBy { it.name }
|
||||
|
||||
override val declaration: Array<EventField<*>> = fields.values.toTypedArray()
|
||||
|
||||
override val approachListeners: Collection<MLTaskGroupListener.ApproachListeners<*, *>>
|
||||
get() {
|
||||
val eventId = getEventId(apiPlatform)
|
||||
return listOf(
|
||||
taskApproach.javaClass monitoredBy MLApproachInitializationListener { permanentSessionEnvironment ->
|
||||
object : MLApproachListener<M, P> {
|
||||
override fun onFailedToStartSessionWithException(exception: Throwable) {
|
||||
val analysis = exceptionalAnalysers.map {
|
||||
val analyserField = fields.getValue(it.name)
|
||||
analyserField with ObjectEventData(it.analyse(permanentSessionEnvironment, exception))
|
||||
}
|
||||
eventId.log(*analysis.toTypedArray())
|
||||
}
|
||||
|
||||
override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>) {
|
||||
val analysis = normalFailureAnalysers.map {
|
||||
val analyserField = fields.getValue(it.name)
|
||||
analyserField with ObjectEventData(it.analyse(permanentSessionEnvironment, failure))
|
||||
}
|
||||
eventId.log(*analysis.toTypedArray())
|
||||
}
|
||||
|
||||
override fun onStartedSession(session: Session<P>) = null
|
||||
}
|
||||
}
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers failed sessions' logging as a separate FUS event.
|
||||
*
|
||||
* See [MLSessionFailedLogger]
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
open class FailedSessionLoggerRegister<M, P : Any>(
|
||||
private val targetApproachClass: Class<out MLTaskApproach<P>>,
|
||||
private val exceptionalAnalysers: Collection<ShallowSessionAnalyser<Throwable>>,
|
||||
private val normalFailureAnalysers: Collection<ShallowSessionAnalyser<Session.StartOutcome.Failure<P>>>,
|
||||
) : MLApiStartupListener {
|
||||
override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener {
|
||||
return object : MLApiStartupProcessListener {
|
||||
override fun onStartedInitializingFus(initializedApproaches: Collection<InitializerAndApproach<*>>) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val targetInitializedApproach: MLTaskApproach<P> = requireNotNull(
|
||||
initializedApproaches.find { it.approach.javaClass == targetApproachClass }) {
|
||||
"""
|
||||
Could not create logger for failed sessions for $targetApproachClass in platform $apiPlatform, because the corresponding approach was not initialized.
|
||||
Initialized approaches: $initializedApproaches
|
||||
""".trimIndent()
|
||||
}.approach as MLTaskApproach<P>
|
||||
val finishedEventLogger = MLSessionFailedLogger<M, P>(targetInitializedApproach, exceptionalAnalysers, normalFailureAnalysers, apiPlatform)
|
||||
apiPlatform.addMLEventBeforeFusInitialized(finishedEventLogger)
|
||||
apiPlatform.addTaskListener(finishedEventLogger)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,89 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.logger
|
||||
|
||||
import com.intellij.internal.statistic.eventLog.events.EventField
|
||||
import com.intellij.platform.ml.Environment
|
||||
import com.intellij.platform.ml.Session
|
||||
import com.intellij.platform.ml.impl.MLTaskApproach
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.monitoring.*
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
|
||||
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
|
||||
import com.intellij.platform.ml.impl.session.DescribedRootContainer
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Logs to FUS information about a finished ML session, tier instances' descriptions, analysis.
|
||||
*
|
||||
* A way to start logging, is to create a [FinishedSessionLoggerRegister] and add it with [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.addStartupListener].
|
||||
* If you are using a regression model, see [InplaceFeaturesScheme.FusScheme.Companion.DOUBLE].
|
||||
*
|
||||
* @param approach The approach whose sessions are monitored
|
||||
* @param configuration The particular scheme that is used to serialize sessions
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
class MLSessionFinishedLogger<R, P : Any>(
|
||||
approach: MLTaskApproach<P>,
|
||||
configuration: FusSessionEventBuilder.FusScheme<P>,
|
||||
private val apiPlatform: MLApiPlatform,
|
||||
) : EventIdRecordingMLEvent(), MLTaskGroupListener {
|
||||
private val loggingScheme: FusSessionEventBuilder<P> = configuration.createEventBuilder(approach.approachDeclaration)
|
||||
private val fusDeclaration: SessionFields<P> = loggingScheme.buildFusDeclaration()
|
||||
override val declaration: Array<EventField<*>> = fusDeclaration.getFields()
|
||||
|
||||
override val eventName: String = "${approach.task.name}.finished"
|
||||
|
||||
override val approachListeners: Collection<MLTaskGroupListener.ApproachListeners<*, *>> = listOf(
|
||||
approach.javaClass monitoredBy InitializationLogger()
|
||||
)
|
||||
|
||||
inner class InitializationLogger : MLApproachInitializationListener<R, P> {
|
||||
override fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener<R, P> = ApproachLogger()
|
||||
}
|
||||
|
||||
inner class ApproachLogger : MLApproachListener<R, P> {
|
||||
override fun onFailedToStartSessionWithException(exception: Throwable) {}
|
||||
|
||||
override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>) {}
|
||||
|
||||
override fun onStartedSession(session: Session<P>): MLSessionListener<R, P> = SessionLogger()
|
||||
}
|
||||
|
||||
inner class SessionLogger : MLSessionListener<R, P> {
|
||||
override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<R, P>) {}
|
||||
|
||||
override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<P>) {
|
||||
val eventId = getEventId(apiPlatform)
|
||||
val fusEventData = loggingScheme.buildRecord(sessionTree, fusDeclaration)
|
||||
eventId.log(*fusEventData)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Registers successful ML sessions' logging as a separate FUS event.
|
||||
*
|
||||
* See [MLSessionFinishedLogger]
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
open class FinishedSessionLoggerRegister<M, P : Any>(
|
||||
private val targetApproachClass: Class<out MLTaskApproach<P>>,
|
||||
private val fusScheme: FusSessionEventBuilder.FusScheme<P>,
|
||||
) : MLApiStartupListener {
|
||||
override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener {
|
||||
return object : MLApiStartupProcessListener {
|
||||
override fun onStartedInitializingFus(initializedApproaches: Collection<InitializerAndApproach<*>>) {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val targetInitializedApproach: MLTaskApproach<P> = requireNotNull(initializedApproaches.find { it.approach.javaClass == targetApproachClass }) {
|
||||
"""
|
||||
Could not create logger for finished sessions of $targetApproachClass in platform $apiPlatform, because the corresponding approach was not initialized.
|
||||
Initialized approaches: $initializedApproaches
|
||||
""".trimIndent()
|
||||
}.approach as MLTaskApproach<P>
|
||||
val finishedEventLogger = MLSessionFinishedLogger<M, P>(targetInitializedApproach, fusScheme, apiPlatform)
|
||||
apiPlatform.addMLEventBeforeFusInitialized(finishedEventLogger)
|
||||
apiPlatform.addTaskListener(finishedEventLogger)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,61 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.model
|
||||
|
||||
import com.intellij.platform.ml.Environment
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.PerTier
|
||||
import com.intellij.platform.ml.TierRequester
|
||||
import com.intellij.platform.ml.impl.FeatureSelector
|
||||
import com.intellij.platform.ml.impl.LevelTiers
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Performs a prediction based on the given features.
|
||||
*
|
||||
* @param P The prediction's type.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface MLModel<P : Any> {
|
||||
/**
|
||||
* Provides a model, that will be used during an ML session.
|
||||
*
|
||||
* It extends [TierRequester] interface which implies, that you can request
|
||||
* additional tiers that will help you acquire the right model.
|
||||
*/
|
||||
interface Provider<M : MLModel<P>, P : Any> : TierRequester {
|
||||
/**
|
||||
* Provides an ML model from the task's environment, and the additional tiers
|
||||
* declared via [requiredTiers].
|
||||
* If [requiredTiers] could not be fulfilled, then the session will not be started
|
||||
* and [com.intellij.platform.ml.impl.approach.InsufficientEnvironmentForModelProviderOutcome]
|
||||
* will be returned as the start's outcome.
|
||||
*
|
||||
* @param sessionTiers Contains the main as well as additional tiers that will be used during the session.
|
||||
* @param environment Contains the all-embracing "permanent" tiers of an ML session - the ones that sit on the first position
|
||||
* of an [com.intellij.platform.ml.impl.MLTask]'s declaration.
|
||||
* Plus additional tiers that were requested in [requiredTiers].
|
||||
* extendedPermanentSessionEnvironment
|
||||
*/
|
||||
fun provideModel(sessionTiers: List<LevelTiers>, environment: Environment): M?
|
||||
}
|
||||
|
||||
/**
|
||||
* Declares a set of features, that are known and can be used by the ML model.
|
||||
* For each tier, it returns a selection.
|
||||
* Selection will be 'complete' if and only if the model could be executed with it.
|
||||
*
|
||||
* The set of tiers it is aware of could not include some additional tiers, declared
|
||||
* in [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]'s 'not used description'.
|
||||
* Because old models could be not aware of the newly established tiers.
|
||||
*/
|
||||
val knownFeatures: PerTier<FeatureSelector>
|
||||
|
||||
/**
|
||||
* Performs the prediction.
|
||||
*
|
||||
* @param features Features computed for this model to take into account.
|
||||
* It is guaranteed that for each tier, this set of features is a 'complete'
|
||||
* selection of features.
|
||||
*/
|
||||
fun predict(features: PerTier<Set<Feature>>): P
|
||||
}
|
||||
@@ -0,0 +1,170 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.model
|
||||
|
||||
import com.intellij.internal.ml.DecisionFunction
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.PerTier
|
||||
import com.intellij.platform.ml.Tier
|
||||
import com.intellij.platform.ml.impl.FeatureSelector
|
||||
import com.intellij.platform.ml.impl.LevelTiers
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A wrapper for using legacy [DecisionFunction] as the new API's [MLModel].
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
open class RegressionModel private constructor(
|
||||
private val decisionFunction: DecisionFunction,
|
||||
private val featuresTiers: Set<Tier<*>>,
|
||||
availableTiers: Set<Tier<*>>,
|
||||
private val featureSerialization: FeatureNameSerialization
|
||||
) : MLModel<Double> {
|
||||
constructor(decisionFunction: DecisionFunction,
|
||||
featureSerialization: FeatureNameSerialization,
|
||||
sessionTiers: List<LevelTiers>) : this(
|
||||
decisionFunction = decisionFunction,
|
||||
featuresTiers = decisionFunction.featuresOrder.map {
|
||||
featureSerialization.deserialize(it.featureName, sessionTiers.flatten().associateBy { it.name }).first
|
||||
}.toSet(),
|
||||
availableTiers = sessionTiers.flatten(),
|
||||
featureSerialization = featureSerialization
|
||||
)
|
||||
|
||||
override val knownFeatures: PerTier<FeatureSelector> = createFeatureSelectors(
|
||||
DecisionFunctionWrapper(decisionFunction, availableTiers, featureSerialization),
|
||||
featuresTiers
|
||||
)
|
||||
|
||||
override fun predict(features: PerTier<Set<Feature>>): Double {
|
||||
val array = DoubleArray(decisionFunction.featuresOrder.size)
|
||||
val featurePerSerializedName = features
|
||||
.flatMap { (tier, tierFeatures) -> tierFeatures.map { tier to it } }
|
||||
.associate { (tier, feature) -> featureSerialization.serialize(tier, feature.declaration.name) to feature }
|
||||
|
||||
require(features.keys == featuresTiers) {
|
||||
"Given features tiers are ${features.keys}, but this model needs ${featuresTiers}"
|
||||
}
|
||||
|
||||
for (featureI in decisionFunction.featuresOrder.indices) {
|
||||
val featureMapper = decisionFunction.featuresOrder[featureI]
|
||||
val featureSerializedName = featureMapper.featureName
|
||||
val featureValue = featureMapper.asArrayValue(featurePerSerializedName[featureSerializedName]?.value)
|
||||
array[featureI] = featureValue
|
||||
}
|
||||
|
||||
return decisionFunction.predict(array)
|
||||
}
|
||||
|
||||
interface FeatureNameSerialization {
|
||||
fun serialize(tier: Tier<*>, featureName: String): String
|
||||
|
||||
fun deserialize(serializedFeatureName: String, availableTiersPerName: Map<String, Tier<*>>): Pair<Tier<*>, String>
|
||||
}
|
||||
|
||||
class DefaultSerialization : FeatureNameSerialization {
|
||||
private val SERIALIZED_FEATURE_SEPARATOR = '/'
|
||||
|
||||
override fun serialize(tier: Tier<*>, featureName: String): String {
|
||||
return "${tier.name}${SERIALIZED_FEATURE_SEPARATOR}${featureName}"
|
||||
}
|
||||
|
||||
override fun deserialize(serializedFeatureName: String, availableTiersPerName: Map<String, Tier<*>>): Pair<Tier<*>, String> {
|
||||
val indexOfLastSeparator = serializedFeatureName.indexOfLast { it == SERIALIZED_FEATURE_SEPARATOR }
|
||||
require(indexOfLastSeparator >= 0) { "Feature name '$serializedFeatureName' does not contain tier's name" }
|
||||
val featureTierName = serializedFeatureName.slice(0 until indexOfLastSeparator)
|
||||
val featureName = serializedFeatureName.slice(indexOfLastSeparator until serializedFeatureName.length)
|
||||
val featureTier = requireNotNull(
|
||||
availableTiersPerName[featureTierName]) { "Serialized feature '$serializedFeatureName' has tier $featureTierName, but all available tiers are ${availableTiersPerName.keys}" }
|
||||
return featureTier to featureName
|
||||
}
|
||||
}
|
||||
|
||||
class SelectionMissingFeatures(
|
||||
selectedFeatures: Set<FeatureDeclaration<*>>,
|
||||
missingFeatures: Set<String>
|
||||
) : FeatureSelector.Selection.Incomplete(selectedFeatures) {
|
||||
override val details: String = "Regression model requires more features to run. " +
|
||||
"Missing: $missingFeatures, " +
|
||||
"Has: $selectedFeatures"
|
||||
}
|
||||
|
||||
private class DecisionFunctionWrapper(
|
||||
private val decisionFunction: DecisionFunction,
|
||||
private val availableTiers: Set<Tier<*>>,
|
||||
private val featureNameSerialization: FeatureNameSerialization
|
||||
) {
|
||||
private val availableTiersPerName: Map<String, Tier<*>> = availableTiers.associateBy { it.name }
|
||||
|
||||
fun getKnownFeatures(): PerTier<Set<String>> {
|
||||
val knownFeaturesSerializedNames = decisionFunction.featuresOrder.map { it.featureName }.toSet()
|
||||
return knownFeaturesSerializedNames
|
||||
.map { featureNameSerialization.deserialize(it, availableTiersPerName) }
|
||||
.groupBy({ it.first }, { it.second })
|
||||
.mapValues { it.value.toSet() }
|
||||
}
|
||||
|
||||
fun getRequiredFeaturesPerTier(): PerTier<Set<String>> {
|
||||
val availableTiersPerName = availableTiers.associateBy { it.name }
|
||||
val requiredFeaturesSerializedNames = decisionFunction.requiredFeatures.filterNotNull().toSet()
|
||||
return requiredFeaturesSerializedNames
|
||||
.map { serializedFeatureName -> featureNameSerialization.deserialize(serializedFeatureName, availableTiersPerName) }
|
||||
.groupBy({ it.first }, { it.second })
|
||||
.mapValues { it.value.toSet() }
|
||||
}
|
||||
|
||||
fun getUnknownFeatures(tier: Tier<*>, featuresNames: Set<String>): Set<String> {
|
||||
val featureNamePerSerializedName = featuresNames
|
||||
.associateBy { featureNameSerialization.serialize(tier, it) }
|
||||
|
||||
val unknownFeaturesSerializedNames = decisionFunction.getUnknownFeatures(featureNamePerSerializedName.keys).filterNotNull()
|
||||
|
||||
return unknownFeaturesSerializedNames.map {
|
||||
requireNotNull(featureNamePerSerializedName[it]) { "Decision function returned an unknown feature that was not given: '$it'" }
|
||||
}.toSet()
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private fun createFeatureSelectors(decisionFunction: DecisionFunctionWrapper,
|
||||
featuresTiers: Set<Tier<*>>): PerTier<FeatureSelector> {
|
||||
val requiredFeaturesPerTier = decisionFunction.getRequiredFeaturesPerTier()
|
||||
|
||||
fun createFeatureSelector(tier: Tier<*>) = object : FeatureSelector {
|
||||
init {
|
||||
val knownFeatures = decisionFunction.getKnownFeatures()
|
||||
knownFeatures.forEach { (tier, tierFeatures) ->
|
||||
val nonConsistentlyKnownFeatures = decisionFunction.getUnknownFeatures(tier, tierFeatures)
|
||||
require(nonConsistentlyKnownFeatures.isEmpty()) {
|
||||
"These features are known and unknown at the same time: $nonConsistentlyKnownFeatures"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override fun select(availableFeatures: Set<FeatureDeclaration<*>>): FeatureSelector.Selection {
|
||||
val availableFeaturesPerName = availableFeatures.associateBy { it.name }
|
||||
val availableFeaturesNames = availableFeatures.map { it.name }.toSet()
|
||||
val unknownFeaturesNames = decisionFunction.getUnknownFeatures(tier, availableFeaturesNames)
|
||||
val knownAvailableFeaturesNames = availableFeaturesNames - unknownFeaturesNames
|
||||
val knownAvailableFeatures = knownAvailableFeaturesNames.map { availableFeaturesPerName.getValue(it) }.toSet()
|
||||
val requiredFeaturesNames = requiredFeaturesPerTier[tier] ?: emptySet()
|
||||
|
||||
return if (availableFeaturesNames.containsAll(requiredFeaturesNames))
|
||||
FeatureSelector.Selection.Complete(knownAvailableFeatures)
|
||||
else
|
||||
SelectionMissingFeatures(knownAvailableFeatures, requiredFeaturesNames - availableFeaturesNames)
|
||||
}
|
||||
|
||||
override fun select(featureDeclaration: FeatureDeclaration<*>): Boolean {
|
||||
return decisionFunction.getUnknownFeatures(tier, setOf(featureDeclaration.name)).isEmpty()
|
||||
}
|
||||
}
|
||||
|
||||
return featuresTiers.associateWith { createFeatureSelector(it) }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private fun List<LevelTiers>.flatten(): Set<Tier<*>> {
|
||||
return this.flatMap { it.main + it.additional }.toSet()
|
||||
}
|
||||
@@ -0,0 +1,29 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.monitoring
|
||||
|
||||
import com.intellij.platform.ml.impl.MLTaskApproach
|
||||
import com.intellij.platform.ml.impl.MLTaskApproachInitializer
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
@ApiStatus.Internal
|
||||
fun interface MLApiStartupListener {
|
||||
fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
data class InitializerAndApproach<T : Any>(
|
||||
val initializer: MLTaskApproachInitializer<T>,
|
||||
val approach: MLTaskApproach<T>
|
||||
)
|
||||
|
||||
@ApiStatus.Internal
|
||||
interface MLApiStartupProcessListener {
|
||||
fun onStartedInitializingApproaches() {}
|
||||
|
||||
fun onStartedInitializingFus(initializedApproaches: Collection<InitializerAndApproach<*>>) {}
|
||||
|
||||
fun onFinished() {}
|
||||
|
||||
fun onFailed(exception: Throwable?) {}
|
||||
}
|
||||
@@ -0,0 +1,166 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.monitoring
|
||||
|
||||
import com.intellij.platform.ml.Environment
|
||||
import com.intellij.platform.ml.Session
|
||||
import com.intellij.platform.ml.impl.MLTaskApproach
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApproachInitializationListener.Companion.asJoinedListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApproachListener.Companion.asJoinedListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLSessionListener.Companion.asJoinedListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
|
||||
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
|
||||
import com.intellij.platform.ml.impl.session.DescribedRootContainer
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Provides listeners for a set of [com.intellij.platform.ml.impl.MLTaskApproach]
|
||||
*
|
||||
* Only [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] and the subclasses, that are
|
||||
* calling [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] are monitored.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface MLTaskGroupListener {
|
||||
/**
|
||||
* For every approach, the [MLTaskGroupListener] is interested in this value provides a collection of
|
||||
* [MLApproachInitializationListener]
|
||||
*
|
||||
* The comfortable way to create this accordance would be by using
|
||||
* [com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy] infix function.
|
||||
*/
|
||||
val approachListeners: Collection<ApproachListeners<*, *>>
|
||||
|
||||
/**
|
||||
* A type-safe pair of approach's class and a set of listeners
|
||||
*
|
||||
* A proper way to create it is to use [monitoredBy]
|
||||
*/
|
||||
data class ApproachListeners<R, P : Any> internal constructor(
|
||||
val taskApproach: Class<out MLTaskApproach<P>>,
|
||||
val approachListener: Collection<MLApproachInitializationListener<R, P>>
|
||||
) {
|
||||
companion object {
|
||||
infix fun <R, P : Any> Class<out MLTaskApproach<P>>.monitoredBy(approachListener: MLApproachInitializationListener<R, P>) = ApproachListeners(
|
||||
this, listOf(approachListener))
|
||||
|
||||
infix fun <R, P : Any> Class<out MLTaskApproach<P>>.monitoredBy(approachListeners: Collection<MLApproachInitializationListener<R, P>>) = ApproachListeners(
|
||||
this, approachListeners)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
internal val MLTaskGroupListener.targetedApproaches: Set<Class<out MLTaskApproach<*>>>
|
||||
get() = approachListeners.map { it.taskApproach }.toSet()
|
||||
|
||||
internal fun <P : Any, R> MLTaskGroupListener.onAttemptedToStartSession(taskApproach: MLTaskApproach<P>,
|
||||
permanentSessionEnvironment: Environment): MLApproachListener<R, P>? {
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
val approachListeners: List<MLApproachInitializationListener<R, P>> = approachListeners
|
||||
.filter { it.taskApproach == taskApproach.javaClass }
|
||||
.flatMap { it.approachListener } as List<MLApproachInitializationListener<R, P>>
|
||||
return approachListeners.asJoinedListener().onAttemptedToStartSession(permanentSessionEnvironment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Listens to the attempt of starting new [Session] of the [MLTaskApproach], that this listener was put
|
||||
* into correspondence to via [com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy]
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
fun interface MLApproachInitializationListener<R, P : Any> {
|
||||
/**
|
||||
* Called each time, when [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] is invoked
|
||||
*
|
||||
* @return A listener, that will be monitoring how successful the start was. If it is not needed, null is returned.
|
||||
*/
|
||||
fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener<R, P>?
|
||||
|
||||
companion object {
|
||||
fun <R, P : Any> Collection<MLApproachInitializationListener<R, P>>.asJoinedListener(): MLApproachInitializationListener<R, P> =
|
||||
MLApproachInitializationListener { permanentSessionEnvironment ->
|
||||
val approachListeners = this@asJoinedListener.mapNotNull { it.onAttemptedToStartSession(permanentSessionEnvironment) }
|
||||
if (approachListeners.isEmpty()) null else approachListeners.asJoinedListener()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Listens to the process of starting new [Session] of [com.intellij.platform.ml.impl.approach.LogDrivenModelInference].
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface MLApproachListener<R, P : Any> {
|
||||
/**
|
||||
* Called if the session was not started,
|
||||
* on exceptionally rare occasions,
|
||||
* when the [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] failed with an exception
|
||||
*/
|
||||
fun onFailedToStartSessionWithException(exception: Throwable)
|
||||
|
||||
/**
|
||||
* Called if the session was not started,
|
||||
* but the failure is 'ordinary'.
|
||||
*/
|
||||
fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>)
|
||||
|
||||
/**
|
||||
* Called when a new [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]'s session was started successfully.
|
||||
*
|
||||
* @return A listener for tracking the session's progress, null if the session will not be tracked.
|
||||
*/
|
||||
fun onStartedSession(session: Session<P>): MLSessionListener<R, P>?
|
||||
|
||||
companion object {
|
||||
fun <R, P : Any> Collection<MLApproachListener<R, P>>.asJoinedListener(): MLApproachListener<R, P> {
|
||||
val approachListeners = this@asJoinedListener
|
||||
|
||||
return object : MLApproachListener<R, P> {
|
||||
override fun onFailedToStartSessionWithException(exception: Throwable) =
|
||||
approachListeners.forEach { it.onFailedToStartSessionWithException(exception) }
|
||||
|
||||
override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>) = approachListeners.forEach {
|
||||
it.onFailedToStartSession(failure)
|
||||
}
|
||||
|
||||
override fun onStartedSession(session: Session<P>): MLSessionListener<R, P>? {
|
||||
val listeners = approachListeners.mapNotNull { it.onStartedSession(session) }
|
||||
return if (listeners.isEmpty()) null else listeners.asJoinedListener()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Listens to session events of a [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface MLSessionListener<R, P : Any> {
|
||||
/**
|
||||
* All tier instances were established (the tree will not be growing further),
|
||||
* described, and predictions in the [sessionTree] were finished.
|
||||
*/
|
||||
fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<R, P>)
|
||||
|
||||
/**
|
||||
* Called only after [onSessionDescriptionFinished]
|
||||
*
|
||||
* All tree nodes were analyzed.
|
||||
*/
|
||||
fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<P>)
|
||||
|
||||
companion object {
|
||||
fun <R, P : Any> Collection<MLSessionListener<R, P>>.asJoinedListener(): MLSessionListener<R, P> {
|
||||
val sessionListeners = this@asJoinedListener
|
||||
|
||||
return object : MLSessionListener<R, P> {
|
||||
override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<R, P>) = sessionListeners.forEach {
|
||||
it.onSessionDescriptionFinished(sessionTree)
|
||||
}
|
||||
|
||||
override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<P>) = sessionListeners.forEach {
|
||||
it.onSessionAnalysisFinished(sessionTree)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,182 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session
|
||||
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.ScopeEnvironment.Companion.restrictedBy
|
||||
import com.intellij.platform.ml.TierRequester.Companion.fulfilledBy
|
||||
import com.intellij.platform.ml.impl.DescriptionComputer
|
||||
import com.intellij.platform.ml.impl.FeatureSelector
|
||||
import com.intellij.platform.ml.impl.FeatureSelector.Companion.or
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getDescriptorsOfTiers
|
||||
import com.intellij.platform.ml.impl.environment.ExtendedEnvironment
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
@ApiStatus.Internal
|
||||
data class LevelDescriptor(
|
||||
val apiPlatform: MLApiPlatform,
|
||||
val descriptionComputer: DescriptionComputer,
|
||||
val usedFeaturesSelectors: PerTier<FeatureSelector>,
|
||||
val notUsedFeaturesSelectors: PerTier<FeatureSelector>,
|
||||
) {
|
||||
fun describe(
|
||||
upperLevels: List<DescribedLevel>,
|
||||
nextLevelMainEnvironment: Environment,
|
||||
nextLevelAdditionalTiers: Set<Tier<*>>
|
||||
): DescribedLevel {
|
||||
val mainEnvironment = Environment.joined(listOf(
|
||||
Environment.of(upperLevels.flatMap { it.main.keys }),
|
||||
nextLevelMainEnvironment
|
||||
))
|
||||
|
||||
val availableEnvironment = ExtendedEnvironment(apiPlatform.environmentExtenders, mainEnvironment)
|
||||
val extendedEnvironment = availableEnvironment.restrictedBy(nextLevelMainEnvironment.tiers + nextLevelAdditionalTiers)
|
||||
|
||||
val runnableDescriptorsPerTier = apiPlatform.getDescriptorsOfTiers(extendedEnvironment.tiers)
|
||||
.mapValues { (_, descriptors) -> descriptors.fulfilledBy(availableEnvironment) }
|
||||
|
||||
val extendedEnvironmentDescription = runnableDescriptorsPerTier
|
||||
.mapValues { (tier, tierDescriptors) ->
|
||||
describeTier(tier, tierDescriptors, availableEnvironment)
|
||||
}
|
||||
|
||||
val nextLevel = DescribedLevel(
|
||||
main = nextLevelMainEnvironment.tierInstances.associateWith { mainTierInstance ->
|
||||
DescribedTierData(extendedEnvironmentDescription.getValue(mainTierInstance.tier))
|
||||
},
|
||||
additional = extendedEnvironment.restrictedBy(nextLevelAdditionalTiers).tierInstances.associateWith { mainTierInstance ->
|
||||
DescribedTierData(extendedEnvironmentDescription.getValue(mainTierInstance.tier))
|
||||
},
|
||||
)
|
||||
|
||||
return nextLevel
|
||||
}
|
||||
|
||||
private fun createFilterOfFeaturesToCompute(tier: Tier<*>, tierDescriptors: List<TierDescriptor>): FeatureFilter {
|
||||
if (tierDescriptors.any { it is ObsoleteTierDescriptor }) {
|
||||
return FeatureFilter.ACCEPT_ALL
|
||||
}
|
||||
|
||||
val tierFeaturesSelector = (usedFeaturesSelectors[tier] ?: FeatureSelector.NOTHING) or notUsedFeaturesSelectors.getValue(tier)
|
||||
val tierComputableFeatures = tierDescriptors.flatMap { it.descriptionDeclaration }.toSet()
|
||||
val tierToComputeSelection = tierFeaturesSelector.select(tierComputableFeatures)
|
||||
|
||||
if (tierToComputeSelection is FeatureSelector.Selection.Incomplete) {
|
||||
throw IncompleteDescriptionException(tier, tierToComputeSelection.selectedFeatures, tierToComputeSelection.details)
|
||||
}
|
||||
|
||||
return FeatureFilter { it in tierToComputeSelection.selectedFeatures }
|
||||
}
|
||||
|
||||
// Filter assumes that the feature is either used or not used by the model
|
||||
private fun createFilterOfUsedFeatures(tier: Tier<*>): FeatureFilter {
|
||||
val usedFeaturesSelector = usedFeaturesSelectors[tier] ?: FeatureSelector.NOTHING
|
||||
val notUsedFeaturesSelector = notUsedFeaturesSelectors.getValue(tier)
|
||||
return FeatureFilter {
|
||||
val featureIsUsed = usedFeaturesSelector.select(it)
|
||||
val featureIsNotUsed = notUsedFeaturesSelector.select(it)
|
||||
assert(featureIsUsed || featureIsNotUsed) {
|
||||
"Feature $it of $tier must not have been computed. It is not used by the ML model or marked as not used"
|
||||
}
|
||||
require(featureIsUsed xor featureIsNotUsed) {
|
||||
"${it} of $tier is used by the ML model, but marked as not used at the same time"
|
||||
}
|
||||
featureIsUsed
|
||||
}
|
||||
}
|
||||
|
||||
private fun Set<Feature>.splitByUsage(usableFeaturesFilter: FeatureFilter): Usage<Set<Feature>> {
|
||||
val usedFeatures = this.filter { usableFeaturesFilter.accept(it.declaration) }
|
||||
val notUsedFeatures = this.filter { !usableFeaturesFilter.accept(it.declaration) }
|
||||
return Usage(usedFeatures.toSet(), notUsedFeatures.toSet())
|
||||
}
|
||||
|
||||
private fun makeDescriptionPartition(descriptor: TierDescriptor,
|
||||
computedDescription: Set<Feature>,
|
||||
usedFeaturesFilter: FeatureFilter): DescriptionPartition {
|
||||
|
||||
val computedDescriptionDeclaration = computedDescription.map { it.declaration }.toSet()
|
||||
|
||||
if (descriptor is ObsoleteTierDescriptor) {
|
||||
val nonDeclaredDescription = computedDescription.filter { it.declaration !in descriptor.partialDescriptionDeclaration }.toSet()
|
||||
apiPlatform.manageNonDeclaredFeatures(descriptor, nonDeclaredDescription)
|
||||
}
|
||||
else {
|
||||
val notDeclaredFeaturesDeclarations = computedDescriptionDeclaration - descriptor.descriptionDeclaration
|
||||
require(notDeclaredFeaturesDeclarations.isEmpty()) {
|
||||
"""
|
||||
$descriptor described environment with some features that were not declared:
|
||||
${notDeclaredFeaturesDeclarations.map { it.name }}
|
||||
computed declaration: ${computedDescriptionDeclaration}
|
||||
declared declaration: ${descriptor.descriptionDeclaration}
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
|
||||
val maybePartialDescriptionDeclaration = if (descriptor is ObsoleteTierDescriptor)
|
||||
descriptor.partialDescriptionDeclaration
|
||||
else
|
||||
descriptor.descriptionDeclaration
|
||||
|
||||
val notComputedDescriptionDeclaration = maybePartialDescriptionDeclaration - computedDescriptionDeclaration
|
||||
for (notComputedFeatureDeclaration in notComputedDescriptionDeclaration) {
|
||||
require(!usedFeaturesFilter.accept(notComputedFeatureDeclaration)) {
|
||||
"Feature ${notComputedFeatureDeclaration} was expected to be computed by $descriptor, " +
|
||||
"because was declared and accepted by the feature filter. Computed declaration: $computedDescriptionDeclaration"
|
||||
}
|
||||
}
|
||||
|
||||
val declaredFeatures = mutableSetOf<Feature>()
|
||||
val nonDeclaredFeatures = mutableSetOf<Feature>()
|
||||
|
||||
if (descriptor is ObsoleteTierDescriptor) {
|
||||
computedDescription.forEach {
|
||||
if (it.declaration in descriptor.partialDescriptionDeclaration)
|
||||
declaredFeatures += it
|
||||
else
|
||||
nonDeclaredFeatures += it
|
||||
}
|
||||
}
|
||||
else {
|
||||
declaredFeatures.addAll(computedDescription)
|
||||
}
|
||||
|
||||
return Declaredness(declaredFeatures.splitByUsage(usedFeaturesFilter), nonDeclaredFeatures.splitByUsage(usedFeaturesFilter))
|
||||
}
|
||||
|
||||
private fun describeTier(tier: Tier<*>, tierDescriptors: List<TierDescriptor>, environment: Environment): DescriptionPartition {
|
||||
val toComputeFilter = createFilterOfFeaturesToCompute(tier, tierDescriptors)
|
||||
val usefulTierDescriptors = tierDescriptors.filter { it.couldBeUseful(toComputeFilter) }
|
||||
|
||||
val description: Map<TierDescriptor, Set<Feature>> = descriptionComputer.computeDescription(
|
||||
tier,
|
||||
usefulTierDescriptors,
|
||||
environment,
|
||||
toComputeFilter
|
||||
)
|
||||
|
||||
val usedByModelFilter = createFilterOfUsedFeatures(tier)
|
||||
|
||||
val descriptionPartition = description.entries
|
||||
.map { (tierDescriptor, computedDescription) ->
|
||||
makeDescriptionPartition(tierDescriptor, computedDescription, usedByModelFilter)
|
||||
}
|
||||
.reduceOrNull { first, second ->
|
||||
Declaredness(
|
||||
declared = Usage(used = first.declared.used + second.declared.used,
|
||||
notUsed = first.declared.notUsed + second.declared.notUsed),
|
||||
nonDeclared = Usage(used = first.nonDeclared.used + second.nonDeclared.used,
|
||||
notUsed = first.nonDeclared.notUsed + second.nonDeclared.notUsed)
|
||||
)
|
||||
} ?: Declaredness(Usage(emptySet(), emptySet()), Usage(emptySet(), emptySet()))
|
||||
|
||||
return descriptionPartition
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
class IncompleteDescriptionException(tier: Tier<*>,
|
||||
selectedFeatures: Set<FeatureDeclaration<*>>,
|
||||
missingFeaturesDetails: String) : Exception() {
|
||||
override val message = "Computable description of tier $tier is not sufficient: $missingFeaturesDetails. Computed features: $selectedFeatures"
|
||||
}
|
||||
@@ -0,0 +1,231 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session
|
||||
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.impl.DescriptionComputer
|
||||
import com.intellij.platform.ml.impl.FeatureSelector
|
||||
import com.intellij.platform.ml.impl.LevelTiers
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
@ApiStatus.Internal
|
||||
class RootCollector<M : MLModel<P>, P : Any>(
|
||||
apiPlatform: MLApiPlatform,
|
||||
levelsTiers: List<LevelTiers>,
|
||||
descriptionComputer: DescriptionComputer,
|
||||
notUsedFeaturesSelectors: PerTier<FeatureSelector>,
|
||||
levelMainEnvironment: Environment,
|
||||
levelAdditionalTiers: Set<Tier<*>>,
|
||||
private val mlModel: M
|
||||
) : NestableStructureCollector<SessionTree.ComplexRoot<M, DescribedLevel, P>, M, P>() {
|
||||
override val levelDescriptor = LevelDescriptor(apiPlatform, descriptionComputer, mlModel.knownFeatures, notUsedFeaturesSelectors)
|
||||
override val levelPositioning = LevelPositioning.superior(levelsTiers, levelDescriptor, levelMainEnvironment, levelAdditionalTiers)
|
||||
|
||||
init {
|
||||
validateSuperiorCollector(levelsTiers, levelMainEnvironment, levelAdditionalTiers, mlModel, notUsedFeaturesSelectors)
|
||||
}
|
||||
|
||||
override fun createTree(thisLevel: DescribedLevel,
|
||||
collectedNestedStructureTrees: List<DescribedSessionTree<M, P>>): SessionTree.ComplexRoot<M, DescribedLevel, P> {
|
||||
return SessionTree.ComplexRoot(mlModel, thisLevel, collectedNestedStructureTrees)
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
class SolitaryLeafCollector<M : MLModel<P>, P : Any>(
|
||||
apiPlatform: MLApiPlatform,
|
||||
levelScheme: LevelTiers,
|
||||
descriptionComputer: DescriptionComputer,
|
||||
notUsedFeaturesSelectors: PerTier<FeatureSelector>,
|
||||
levelMainEnvironment: Environment,
|
||||
levelAdditionalTiers: Set<Tier<*>>,
|
||||
private val mlModel: M
|
||||
) : PredictionCollector<SessionTree.SolitaryLeaf<M, DescribedLevel, P>, M, P>() {
|
||||
override val levelDescriptor = LevelDescriptor(apiPlatform, descriptionComputer, mlModel.knownFeatures, notUsedFeaturesSelectors)
|
||||
override val levelPositioning = LevelPositioning.superior(listOf(levelScheme), levelDescriptor, levelMainEnvironment,
|
||||
levelAdditionalTiers)
|
||||
|
||||
init {
|
||||
validateSuperiorCollector(listOf(levelScheme), levelMainEnvironment, levelAdditionalTiers, mlModel, notUsedFeaturesSelectors)
|
||||
}
|
||||
|
||||
override fun createTree(thisLevel: DescribedLevel, prediction: P?): SessionTree.SolitaryLeaf<M, DescribedLevel, P> {
|
||||
return SessionTree.SolitaryLeaf(mlModel, levelPositioning.thisLevel, prediction)
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
class BranchingCollector<M : MLModel<P>, P : Any>(
|
||||
override val levelDescriptor: LevelDescriptor,
|
||||
override val levelPositioning: LevelPositioning
|
||||
) : NestableStructureCollector<SessionTree.Branching<M, DescribedLevel, P>, M, P>() {
|
||||
override fun createTree(thisLevel: DescribedLevel,
|
||||
collectedNestedStructureTrees: List<DescribedSessionTree<M, P>>): SessionTree.Branching<M, DescribedLevel, P> {
|
||||
return SessionTree.Branching(thisLevel, collectedNestedStructureTrees)
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
abstract class PredictionCollector<T : SessionTree.PredictionContainer<M, DescribedLevel, P>, M : MLModel<P>, P : Any> : StructureCollector<T, M, P>() {
|
||||
private var predictionSubmitted = false
|
||||
private var submittedPrediction: P? = null
|
||||
|
||||
abstract fun createTree(thisLevel: DescribedLevel, prediction: P?): T
|
||||
|
||||
val usableDescription: PerTier<Set<Feature>>
|
||||
get() = levelPositioning.levels.extractDescriptionForModel()
|
||||
|
||||
fun submitPrediction(prediction: P?) {
|
||||
require(!predictionSubmitted)
|
||||
submittedPrediction = prediction
|
||||
predictionSubmitted = true
|
||||
submitTreeToHandlers(createTree(levelPositioning.thisLevel, submittedPrediction))
|
||||
}
|
||||
|
||||
private fun PerTierInstance<DescribedTierData>.extractDescriptionForModel(): PerTier<Set<Feature>> {
|
||||
return this.entries.associate { (tierInstance, data) ->
|
||||
tierInstance.tier to data.description.declared.used + data.description.nonDeclared.used
|
||||
}
|
||||
}
|
||||
|
||||
private fun DescribedLevel.extractDescriptionForModel(): PerTier<Set<Feature>> {
|
||||
val mainDescription = this.main.extractDescriptionForModel()
|
||||
val additionalDescription = this.additional.extractDescriptionForModel()
|
||||
return listOf(mainDescription + additionalDescription).joinByUniqueTier()
|
||||
}
|
||||
|
||||
private fun Iterable<DescribedLevel>.extractDescriptionForModel(): PerTier<Set<Feature>> {
|
||||
return this.map { it.extractDescriptionForModel() }.joinByUniqueTier()
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
class LeafCollector<M : MLModel<P>, P : Any>(
|
||||
override val levelDescriptor: LevelDescriptor,
|
||||
override val levelPositioning: LevelPositioning
|
||||
) : PredictionCollector<SessionTree.Leaf<M, DescribedLevel, P>, M, P>() {
|
||||
|
||||
override fun createTree(thisLevel: DescribedLevel, prediction: P?): SessionTree.Leaf<M, DescribedLevel, P> {
|
||||
return SessionTree.Leaf(levelPositioning.thisLevel, prediction)
|
||||
}
|
||||
}
|
||||
|
||||
private fun <M : MLModel<P>, P : Any> validateSuperiorCollector(levelsTiers: List<LevelTiers>,
|
||||
levelMainEnvironment: Environment,
|
||||
levelAdditionalTiers: Set<Tier<*>>,
|
||||
mlModel: M,
|
||||
notUsedDescriptionSelectors: PerTier<FeatureSelector>) {
|
||||
val allTiers = (levelsTiers.flatMap { it.main + it.additional } + levelMainEnvironment.tiers + levelAdditionalTiers).toSet()
|
||||
val usedFeaturesSelectors = mlModel.knownFeatures
|
||||
|
||||
require(allTiers.containsAll(usedFeaturesSelectors.keys)) {
|
||||
"ML Model uses tiers that are not main or additional: ${usedFeaturesSelectors.keys - allTiers}"
|
||||
}
|
||||
require(notUsedDescriptionSelectors.keys == allTiers) {
|
||||
"""
|
||||
Not used description's tiers must be same as the task tiersAll tiers
|
||||
Missing: ${allTiers - notUsedDescriptionSelectors.keys}
|
||||
Redundant: ${notUsedDescriptionSelectors.keys - allTiers}
|
||||
""".trimIndent()
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
data class LevelPositioning(
|
||||
val upperLevels: List<DescribedLevel>,
|
||||
val lowerTiers: List<LevelTiers>,
|
||||
val thisLevel: DescribedLevel,
|
||||
) {
|
||||
val levels: List<DescribedLevel> = upperLevels + thisLevel
|
||||
|
||||
fun nestNextLevel(
|
||||
levelDescriptor: LevelDescriptor,
|
||||
nextLevelMainEnvironment: Environment,
|
||||
nextLevelAdditionalTiers: Set<Tier<*>>
|
||||
): LevelPositioning {
|
||||
return LevelPositioning(
|
||||
upperLevels = upperLevels + thisLevel,
|
||||
lowerTiers = lowerTiers.drop(1),
|
||||
thisLevel = levelDescriptor.describe(upperLevels, nextLevelMainEnvironment, nextLevelAdditionalTiers)
|
||||
)
|
||||
}
|
||||
|
||||
companion object {
|
||||
fun superior(levelsTiers: List<LevelTiers>,
|
||||
levelDescriptor: LevelDescriptor,
|
||||
levelMainEnvironment: Environment,
|
||||
levelAdditionalTiers: Set<Tier<*>>): LevelPositioning {
|
||||
return LevelPositioning(emptyList(), levelsTiers.drop(1),
|
||||
levelDescriptor.describe(emptyList(), levelMainEnvironment, levelAdditionalTiers))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
sealed class StructureCollector<T : DescribedSessionTree<M, P>, M : MLModel<P>, P : Any> {
|
||||
protected abstract val levelDescriptor: LevelDescriptor
|
||||
abstract val levelPositioning: LevelPositioning
|
||||
|
||||
private val sessionTreeHandlers: MutableList<SessionTreeHandler<in T, M, P>> = mutableListOf()
|
||||
|
||||
fun handleCollectedTree(handler: SessionTreeHandler<in T, M, P>) {
|
||||
sessionTreeHandlers.add(handler)
|
||||
}
|
||||
|
||||
protected fun submitTreeToHandlers(sessionTree: T) {
|
||||
sessionTreeHandlers.forEach { it.handleTree(sessionTree) }
|
||||
}
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
abstract class NestableStructureCollector<T : DescribedChildrenContainer<M, P>, M : MLModel<P>, P : Any> : StructureCollector<T, M, P>() {
|
||||
private val nestedSessionsStructures: MutableList<CompletableFuture<DescribedSessionTree<M, P>>> = mutableListOf()
|
||||
private var nestingFinished = false
|
||||
|
||||
fun nestBranch(levelMainEnvironment: Environment, levelAdditionalTiers: Set<Tier<*>>): BranchingCollector<M, P> {
|
||||
verifyNestedLevelEnvironment(levelMainEnvironment, levelAdditionalTiers)
|
||||
return BranchingCollector<M, P>(levelDescriptor,
|
||||
levelPositioning.nestNextLevel(levelDescriptor, levelMainEnvironment, levelAdditionalTiers))
|
||||
.also { it.trackCollectedStructure() }
|
||||
}
|
||||
|
||||
fun nestPrediction(levelMainEnvironment: Environment, levelAdditionalTiers: Set<Tier<*>>): LeafCollector<M, P> {
|
||||
verifyNestedLevelEnvironment(levelMainEnvironment, levelAdditionalTiers)
|
||||
return LeafCollector<M, P>(levelDescriptor, levelPositioning.nestNextLevel(levelDescriptor, levelMainEnvironment, levelAdditionalTiers))
|
||||
.also { it.trackCollectedStructure() }
|
||||
}
|
||||
|
||||
fun onLastNestedCollectorCreated() {
|
||||
require(!nestingFinished)
|
||||
nestingFinished = true
|
||||
maybeSubmitStructure()
|
||||
}
|
||||
|
||||
private fun <K : DescribedSessionTree<M, P>> StructureCollector<K, M, P>.trackCollectedStructure() {
|
||||
val collectedNestedTreeContainer = CompletableFuture<DescribedSessionTree<M, P>>()
|
||||
nestedSessionsStructures += collectedNestedTreeContainer
|
||||
this.handleCollectedTree {
|
||||
collectedNestedTreeContainer.complete(it)
|
||||
}
|
||||
}
|
||||
|
||||
private fun maybeSubmitStructure() {
|
||||
val collectedNestedStructureTrees: List<SessionTree<M, DescribedLevel, P>> = nestedSessionsStructures
|
||||
.filter { it.isDone }
|
||||
.map { it.get() }
|
||||
|
||||
if (nestingFinished && collectedNestedStructureTrees.size == nestedSessionsStructures.size) {
|
||||
val describedSessionTree = createTree(levelPositioning.thisLevel, collectedNestedStructureTrees)
|
||||
submitTreeToHandlers(describedSessionTree)
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract fun createTree(thisLevel: DescribedLevel, collectedNestedStructureTrees: List<DescribedSessionTree<M, P>>): T
|
||||
|
||||
private fun verifyNestedLevelEnvironment(levelMainEnvironment: Environment, levelAdditionalTiers: Set<Tier<*>>) {
|
||||
require(levelPositioning.lowerTiers.first().main == levelMainEnvironment.tiers)
|
||||
require(levelPositioning.lowerTiers.first().additional == levelAdditionalTiers)
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,275 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session
|
||||
|
||||
import com.intellij.platform.ml.Environment
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.PerTierInstance
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* A partition of [Feature]s that indicates that they could be either used or not used by an ML model.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class Usage<T>(
|
||||
val used: T,
|
||||
val notUsed: T,
|
||||
)
|
||||
|
||||
/**
|
||||
* A partition of [Feature]s that indicates that they could be either declared (by [com.intellij.platform.ml.TierDescriptor]) or not.
|
||||
*
|
||||
* Used to process [com.intellij.platform.ml.ObsoleteTierDescriptor]'s partial descriptions.
|
||||
* It shall be removed when all obsolete descriptors will be transferred to the new API.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class Declaredness<T>(
|
||||
val declared: T,
|
||||
val nonDeclared: T
|
||||
)
|
||||
|
||||
/**
|
||||
* There are two characteristics of a feature as for now: whether it is declared (statically, in a tier descriptor),
|
||||
* and whether it is used by the ML model.
|
||||
* So, this container creates a partition for each category of features.
|
||||
*/
|
||||
typealias DescriptionPartition = Declaredness<Usage<Set<Feature>>>
|
||||
|
||||
/**
|
||||
* A main tier has a description (that is used to run the ML model), and it also could contain analysis features
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class MainTierScheme(
|
||||
val description: Set<FeatureDeclaration<*>>,
|
||||
val analysis: Set<FeatureDeclaration<*>>
|
||||
)
|
||||
|
||||
/**
|
||||
* An additional tier is provided occasionally, and it has only description
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class AdditionalTierScheme(
|
||||
val description: Set<FeatureDeclaration<*>>
|
||||
)
|
||||
|
||||
/**
|
||||
* All the data, that a tier instance that has been described has
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class DescribedTierData(
|
||||
val description: DescriptionPartition,
|
||||
)
|
||||
|
||||
/**
|
||||
* All the data, that a tier instance that has been described and analyzed has
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class AnalysedTierData(
|
||||
val description: DescriptionPartition,
|
||||
val analysis: Set<Feature>
|
||||
)
|
||||
|
||||
/**
|
||||
* A template for a container that contains data of main tier instances, as well as additional instances.
|
||||
* A Level is a collection of tiers that were declared on the same depth of an [com.intellij.platform.ml.impl.MLTask]'s declaration,
|
||||
* plus additional tiers, declared by [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.additionallyDescribedTiers]
|
||||
* on the corresponding level.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
data class Level<M, A>(val main: M, val additional: A)
|
||||
|
||||
typealias DescribedLevel = Level<PerTierInstance<DescribedTierData>, PerTierInstance<DescribedTierData>>
|
||||
|
||||
typealias AnalysedLevel = Level<PerTierInstance<AnalysedTierData>, PerTierInstance<DescribedTierData>>
|
||||
|
||||
/**
|
||||
* Tree-like ml session's structure.
|
||||
*
|
||||
* All trees leaves have the same depths.
|
||||
* The depth corresponds to the number of levels in an [com.intellij.platform.ml.impl.MLTask].
|
||||
* And the tree's structure is built by calling [com.intellij.platform.ml.NestableMLSession.createNestedSession].
|
||||
*
|
||||
* @param RootT Type of the data, that is stored in the root node.
|
||||
* @param LevelT Type of the data, that is stored in each tree's node.
|
||||
* @param PredictionT Type of the session's prediction.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
sealed interface SessionTree<RootT, LevelT, PredictionT> {
|
||||
/**
|
||||
* Data, that is stored in each tree's node.
|
||||
*/
|
||||
val level: LevelT
|
||||
|
||||
/**
|
||||
* Accepts the [visitor], calling the corresponding interface's function.
|
||||
*/
|
||||
fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T
|
||||
|
||||
/**
|
||||
* Something that contains tree's root data.
|
||||
* There are two such classes: a [SolitaryLeaf], and [ComplexRoot].
|
||||
*/
|
||||
sealed interface RootContainer<RootT, LevelT, PredictionT> : SessionTree<RootT, LevelT, PredictionT> {
|
||||
/**
|
||||
* Data, that is stored only in the tree's root.
|
||||
* ML model, for example.
|
||||
*/
|
||||
val root: RootT
|
||||
}
|
||||
|
||||
/**
|
||||
* Something that has nested nodes.
|
||||
* It could be either [ComplexRoot], or [Branching].
|
||||
* The number of children corresponds to number of calls of
|
||||
* [com.intellij.platform.ml.NestableMLSession.createNestedSession].
|
||||
*/
|
||||
sealed interface ChildrenContainer<RootT, LevelT, PredictionT> : SessionTree<RootT, LevelT, PredictionT> {
|
||||
/**
|
||||
* All nested trees, that were built by calling
|
||||
* [com.intellij.platform.ml.NestableMLSession.createNestedSession] on this level.
|
||||
*/
|
||||
val children: List<SessionTree<RootT, LevelT, PredictionT>>
|
||||
}
|
||||
|
||||
/**
|
||||
* Something that contains session's prediction.
|
||||
* It is produced by [com.intellij.platform.ml.SinglePrediction], and the prediction could be either
|
||||
* produced or canceled, hence the [prediction] is nullable.
|
||||
*/
|
||||
sealed interface PredictionContainer<RootT, LevelT, PredictionT> : SessionTree<RootT, LevelT, PredictionT> {
|
||||
val prediction: PredictionT?
|
||||
}
|
||||
|
||||
/**
|
||||
* Corresponds to an ML task session's structure, that had only one level, and could not have been nested.
|
||||
* Hence, it contains root data and a prediction simultaneously.
|
||||
*/
|
||||
data class SolitaryLeaf<RootT, LevelT, PredictionT>(
|
||||
override val root: RootT,
|
||||
override val level: LevelT,
|
||||
override val prediction: PredictionT?
|
||||
) : RootContainer<RootT, LevelT, PredictionT>, PredictionContainer<RootT, LevelT, PredictionT> {
|
||||
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T {
|
||||
return visitor.acceptSolitaryLeaf(this)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Corresponds to an ML task session's structure, that had more than one level.
|
||||
*/
|
||||
data class ComplexRoot<RootT, LevelT, PredictionT>(
|
||||
override val root: RootT,
|
||||
override val level: LevelT,
|
||||
override val children: List<SessionTree<RootT, LevelT, PredictionT>>
|
||||
) : RootContainer<RootT, LevelT, PredictionT>, ChildrenContainer<RootT, LevelT, PredictionT> {
|
||||
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T = visitor.acceptRoot(this)
|
||||
}
|
||||
|
||||
/**
|
||||
* Corresponds to a node in an ML task session's structure, that had more than one level.
|
||||
*/
|
||||
data class Branching<RootT, LevelT, PredictionT>(
|
||||
override val level: LevelT,
|
||||
override val children: List<SessionTree<RootT, LevelT, PredictionT>>
|
||||
) : SessionTree<RootT, LevelT, PredictionT>, ChildrenContainer<RootT, LevelT, PredictionT> {
|
||||
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T = visitor.acceptBranching(this)
|
||||
}
|
||||
|
||||
/**
|
||||
* Corresponds to a leaf in an ML task session's structure, that ad more than one level.
|
||||
*/
|
||||
data class Leaf<RootT, LevelT, PredictionT>(
|
||||
override val level: LevelT,
|
||||
override val prediction: PredictionT?
|
||||
) : SessionTree<RootT, LevelT, PredictionT>, PredictionContainer<RootT, LevelT, PredictionT> {
|
||||
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T = visitor.acceptLeaf(this)
|
||||
}
|
||||
|
||||
/**
|
||||
* Visits a [SessionTree]'s node.
|
||||
*
|
||||
* @param RootT Data's type, stored in the tree's root.
|
||||
* @param LevelT Data's type, stored in each node.
|
||||
* @param PredictionT Prediction's type in the ML session.
|
||||
* @param T Data type that is returned by the visitor.
|
||||
*/
|
||||
interface Visitor<RootT, LevelT, PredictionT, out T> {
|
||||
fun acceptBranching(branching: Branching<RootT, LevelT, PredictionT>): T
|
||||
|
||||
fun acceptLeaf(leaf: Leaf<RootT, LevelT, PredictionT>): T
|
||||
|
||||
fun acceptRoot(root: ComplexRoot<RootT, LevelT, PredictionT>): T
|
||||
|
||||
fun acceptSolitaryLeaf(solitaryLeaf: SolitaryLeaf<RootT, LevelT, PredictionT>): T
|
||||
}
|
||||
|
||||
/**
|
||||
* Visits all tree's nodes on [levelIndex] depth.
|
||||
*/
|
||||
abstract class LevelVisitor<RootT, PredictionT : Any> private constructor(
|
||||
private val levelIndex: Int,
|
||||
private val thisVisitorLevel: Int,
|
||||
) : Visitor<RootT, DescribedLevel, PredictionT, Unit> {
|
||||
constructor(levelIndex: Int) : this(levelIndex, 0)
|
||||
|
||||
private inner class DeeperLevelVisitor : LevelVisitor<RootT, PredictionT>(levelIndex, thisVisitorLevel + 1) {
|
||||
override fun visitLevel(level: DescribedLevel, levelRoot: SessionTree<RootT, DescribedLevel, PredictionT>) {
|
||||
this@LevelVisitor.visitLevel(level, levelRoot)
|
||||
}
|
||||
}
|
||||
|
||||
private fun maybeVisitLevel(level: DescribedLevel,
|
||||
levelRoot: SessionTree<RootT, DescribedLevel, PredictionT>): Boolean =
|
||||
if (levelIndex == thisVisitorLevel) {
|
||||
visitLevel(level, levelRoot)
|
||||
true
|
||||
}
|
||||
else false
|
||||
|
||||
final override fun acceptBranching(branching: Branching<RootT, DescribedLevel, PredictionT>) {
|
||||
if (maybeVisitLevel(branching.level, branching)) return
|
||||
for (child in branching.children) {
|
||||
child.accept(DeeperLevelVisitor())
|
||||
}
|
||||
}
|
||||
|
||||
final override fun acceptLeaf(leaf: Leaf<RootT, DescribedLevel, PredictionT>) {
|
||||
require(maybeVisitLevel(leaf.level, leaf)) {
|
||||
"The deepest level in the session tree is $thisVisitorLevel, given level $levelIndex does not exist"
|
||||
}
|
||||
}
|
||||
|
||||
final override fun acceptRoot(root: ComplexRoot<RootT, DescribedLevel, PredictionT>) {
|
||||
if (maybeVisitLevel(root.level, root)) return
|
||||
for (child in root.children) {
|
||||
child.accept(DeeperLevelVisitor())
|
||||
}
|
||||
}
|
||||
|
||||
final override fun acceptSolitaryLeaf(solitaryLeaf: SolitaryLeaf<RootT, DescribedLevel, PredictionT>) {
|
||||
require(maybeVisitLevel(solitaryLeaf.level, solitaryLeaf)) {
|
||||
"The only level in the session tree is $thisVisitorLevel, given level $levelIndex does not exist"
|
||||
}
|
||||
}
|
||||
|
||||
abstract fun visitLevel(level: DescribedLevel, levelRoot: SessionTree<RootT, DescribedLevel, PredictionT>)
|
||||
}
|
||||
}
|
||||
|
||||
typealias DescribedSessionTree<R, P> = SessionTree<R, DescribedLevel, P>
|
||||
|
||||
typealias DescribedChildrenContainer<R, P> = SessionTree.ChildrenContainer<R, DescribedLevel, P>
|
||||
|
||||
typealias DescribedRootContainer<R, P> = SessionTree.RootContainer<R, DescribedLevel, P>
|
||||
|
||||
typealias SessionAnalysis = Map<String, Set<Feature>>
|
||||
|
||||
typealias AnalysedSessionTree<P> = SessionTree<SessionAnalysis, AnalysedLevel, P>
|
||||
|
||||
typealias AnalysedRootContainer<P> = SessionTree.RootContainer<SessionAnalysis, AnalysedLevel, P>
|
||||
|
||||
val <R, P> DescribedSessionTree<R, P>.environment: Environment
|
||||
get() = Environment.of(this.level.main.keys)
|
||||
|
||||
val DescribedLevel.environment: Environment
|
||||
get() = Environment.of(this.main.keys)
|
||||
@@ -0,0 +1,9 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
@ApiStatus.Internal
|
||||
fun interface SessionTreeHandler<T : DescribedSessionTree<R, P>, R, P> {
|
||||
fun handleTree(sessionTree: T)
|
||||
}
|
||||
@@ -0,0 +1,25 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session.analysis
|
||||
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
|
||||
/**
|
||||
* An analyzer which is dedicated to give features to the session's model.
|
||||
*
|
||||
* [SessionAnalyser.analysisDeclaration] returns a set of [FeatureDeclaration] -
|
||||
* the features that the model will be described with.
|
||||
* [SessionAnalyser.analyse] in this case returns a set of features - the model's analysis.
|
||||
*/
|
||||
typealias MLModelAnalyser<M, P> = SessionAnalyser<Set<FeatureDeclaration<*>>, Set<Feature>, M, P>
|
||||
|
||||
internal class MLModelAnalysisJoiner<M : MLModel<P>, P : Any> : AnalysisJoiner<Set<FeatureDeclaration<*>>, Set<Feature>, M, P> {
|
||||
override fun joinDeclarations(declarations: Iterable<Set<FeatureDeclaration<*>>>): Set<FeatureDeclaration<*>> {
|
||||
return declarations.flatten().toSet()
|
||||
}
|
||||
|
||||
override fun joinAnalysis(analysis: Iterable<Set<Feature>>): Set<Feature> {
|
||||
return analysis.flatten().toSet()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session.analysis
|
||||
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.DescribedRootContainer
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
import java.util.concurrent.CompletableFuture
|
||||
|
||||
/**
|
||||
* An interface for classes, that are analyzing the ML session after it was finished.
|
||||
*
|
||||
* Analysis is indefinitely long process (which is emphasized by the fact [analyse] returns a [CompletableFuture]).
|
||||
*
|
||||
* @param D Type of the analysis' declaration
|
||||
* @param A Type of the analysis itself
|
||||
* @param M Type of the model that has been utilized during the ML session
|
||||
* @param P Type of the session's prediction
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface SessionAnalyser<D, A, M : MLModel<P>, P : Any> {
|
||||
/**
|
||||
* Contains all features' declarations - [com.intellij.platform.ml.FeatureDeclaration],
|
||||
* that are then used in the analysis as [analyse]'s return value.
|
||||
*/
|
||||
val analysisDeclaration: D
|
||||
|
||||
/**
|
||||
* Performs session tree's analysis. The analysis is performed asynchronously,
|
||||
* so the function is not required to return the final result, but only a [CompletableFuture]
|
||||
* that will be fulfilled when the analysis is finished.
|
||||
*/
|
||||
fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<A>
|
||||
}
|
||||
|
||||
/**
|
||||
* Gathers analyses from different [SessionAnalyser]s to one.
|
||||
*/
|
||||
internal class JoinedSessionAnalyser<D, A, M : MLModel<P>, P : Any>(
|
||||
private val baseAnalysers: Collection<SessionAnalyser<D, A, M, P>>,
|
||||
private val joiner: AnalysisJoiner<D, A, M, P>
|
||||
) : SessionAnalyser<D, A, M, P> {
|
||||
override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<A> {
|
||||
val scatteredAnalysis = mutableListOf<CompletableFuture<A>>()
|
||||
for (sessionAnalyser in baseAnalysers) {
|
||||
scatteredAnalysis.add(sessionAnalyser.analyse(sessionTreeRoot))
|
||||
}
|
||||
|
||||
val joinedAnalysis = CompletableFuture<A>()
|
||||
val eachScatteredAnalysisFinished = CompletableFuture.allOf(*scatteredAnalysis.toTypedArray())
|
||||
|
||||
eachScatteredAnalysisFinished.thenRun {
|
||||
val completeAnalysis = scatteredAnalysis.mapNotNull { it.get() }
|
||||
joinedAnalysis.complete(joiner.joinAnalysis(completeAnalysis))
|
||||
}
|
||||
|
||||
return joinedAnalysis
|
||||
}
|
||||
|
||||
override val analysisDeclaration: D
|
||||
get() = joiner.joinDeclarations(baseAnalysers.map { it.analysisDeclaration })
|
||||
}
|
||||
|
||||
internal interface AnalysisJoiner<D, A, M : MLModel<P>, P : Any> {
|
||||
fun joinDeclarations(declarations: Iterable<D>): D
|
||||
|
||||
fun joinAnalysis(analysis: Iterable<A>): A
|
||||
}
|
||||
@@ -0,0 +1,40 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session.analysis
|
||||
|
||||
import com.intellij.platform.ml.Feature
|
||||
import com.intellij.platform.ml.FeatureDeclaration
|
||||
import com.intellij.platform.ml.PerTier
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.session.DescribedSessionTree
|
||||
import com.intellij.platform.ml.mergePerTier
|
||||
|
||||
typealias StructureAnalysis<M, P> = Map<DescribedSessionTree<M, P>, PerTier<Set<Feature>>>
|
||||
|
||||
typealias StructureAnalysisDeclaration = PerTier<Set<FeatureDeclaration<*>>>
|
||||
|
||||
/**
|
||||
* An analyzer that gives analytical features to the session tree's nodes.
|
||||
*
|
||||
* [SessionAnalyser.analysisDeclaration] returns a set of [FeatureDeclaration] per each analyzed tier.
|
||||
* [SessionAnalyser.analyse] returns a mapping from each analyzed tree's node to sets of features.
|
||||
*
|
||||
* For example, if you want to give some analysis to the session tree's root, then you will return
|
||||
* mapOf(sessionTreeRoot to setOf(...)).
|
||||
*
|
||||
* @see com.intellij.platform.ml.impl.session.SessionTree.Visitor to learn how you could walk the tree's nodes
|
||||
* @see com.intellij.platform.ml.impl.session.SessionTree.LevelVisitor to see learn how you could session's levels.
|
||||
*/
|
||||
typealias StructureAnalyser<M, P> = SessionAnalyser<StructureAnalysisDeclaration, StructureAnalysis<M, P>, M, P>
|
||||
|
||||
internal class SessionStructureAnalysisJoiner<M : MLModel<P>, P : Any> : AnalysisJoiner<StructureAnalysisDeclaration, StructureAnalysis<M, P>, M, P> {
|
||||
override fun joinAnalysis(analysis: Iterable<StructureAnalysis<M, P>>): StructureAnalysis<M, P> {
|
||||
return analysis
|
||||
.flatMap { it.entries }
|
||||
.groupBy({ it.key }, { it.value })
|
||||
.mapValues { it.value.mergePerTier { mutableSetOf() } }
|
||||
}
|
||||
|
||||
override fun joinDeclarations(declarations: Iterable<StructureAnalysisDeclaration>): StructureAnalysisDeclaration {
|
||||
return declarations.mergePerTier { mutableSetOf() }
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,43 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session.analysis
|
||||
|
||||
import com.intellij.internal.statistic.eventLog.events.EventField
|
||||
import com.intellij.internal.statistic.eventLog.events.EventPair
|
||||
import com.intellij.internal.statistic.eventLog.events.ObjectDescription
|
||||
import com.intellij.platform.ml.Environment
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
/**
|
||||
* Analyzes not the ML session's tree, but some generic information.
|
||||
* Could be used to analyze started, as well as failed to start sessions.
|
||||
*/
|
||||
@ApiStatus.Internal
|
||||
interface ShallowSessionAnalyser<D> {
|
||||
/**
|
||||
* Name of the analyzer.
|
||||
*/
|
||||
val name: String
|
||||
|
||||
/**
|
||||
* A complete static declaration of the fields, that will be written during analysis.
|
||||
*/
|
||||
val declaration: List<EventField<*>>
|
||||
|
||||
/**
|
||||
* Analyze some generic information about an ML session.
|
||||
*
|
||||
* @param permanentSessionEnvironment The environment that is available during the whole ML session
|
||||
* @param data Some additional data, an insight about the place where the analyzer was called.
|
||||
* Could be a [Throwable], or a reason why it was not possible to start the session.
|
||||
*/
|
||||
fun analyse(permanentSessionEnvironment: Environment, data: D): List<EventPair<*>>
|
||||
|
||||
companion object {
|
||||
val <F> ShallowSessionAnalyser<F>.declarationObjectDescription: ObjectDescription
|
||||
get() = object : ObjectDescription() {
|
||||
init {
|
||||
declaration.forEach { field(it) }
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,77 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.session
|
||||
|
||||
import com.intellij.platform.ml.*
|
||||
|
||||
/**
|
||||
* A wrapper for convenient usage of a [NestableMLSession].
|
||||
*
|
||||
* This function assumes that this session is nestable.
|
||||
*/
|
||||
fun <P : Any> Session<P>.withNestedSessions(useCreator: (NestableSessionWrapper<P>) -> Unit) {
|
||||
val nestableMLSession = requireNotNull(this as? NestableMLSession<P>)
|
||||
|
||||
val creator = object : NestableSessionWrapper<P> {
|
||||
override fun nestConsidering(levelEnvironment: Environment): Session<P> {
|
||||
return nestableMLSession.createNestedSession(levelEnvironment)
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
return useCreator(creator)
|
||||
}
|
||||
finally {
|
||||
nestableMLSession.onLastNestedSessionCreated()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A wrapper for convenient usage of a [SinglePrediction].
|
||||
*
|
||||
* This function assumes that this session is [NestableMLSession], and the nested
|
||||
* sessions' types are [SinglePrediction].
|
||||
*/
|
||||
fun <T, P : Any> Session<P>.withPredictions(useModelWrapper: (ModelWrapper<P>) -> T): T {
|
||||
val nestableMLSession = requireNotNull(this as? NestableMLSession<P>)
|
||||
val predictor = object : ModelWrapper<P> {
|
||||
override fun predictConsidering(predictionEnvironment: Environment): P {
|
||||
val predictionSession = nestableMLSession.createNestedSession(predictionEnvironment)
|
||||
require(predictionSession is SinglePrediction<P>)
|
||||
return predictionSession.predict()
|
||||
}
|
||||
|
||||
override fun consider(predictionEnvironment: Environment) {
|
||||
val predictionSession = nestableMLSession.createNestedSession(predictionEnvironment)
|
||||
require(predictionSession is SinglePrediction<P>)
|
||||
predictionSession.cancelPrediction()
|
||||
}
|
||||
}
|
||||
try {
|
||||
return useModelWrapper(predictor)
|
||||
}
|
||||
finally {
|
||||
nestableMLSession.onLastNestedSessionCreated()
|
||||
}
|
||||
}
|
||||
|
||||
interface NestableSessionWrapper<P : Any> {
|
||||
fun nestConsidering(levelEnvironment: Environment): Session<P>
|
||||
|
||||
companion object {
|
||||
fun <P : Any> NestableSessionWrapper<P>.nestConsidering(vararg levelTierInstances: TierInstance<*>): Session<P> {
|
||||
return this.nestConsidering(Environment.of(*levelTierInstances))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interface ModelWrapper<P : Any> {
|
||||
fun predictConsidering(predictionEnvironment: Environment): P
|
||||
|
||||
fun consider(predictionEnvironment: Environment)
|
||||
|
||||
companion object {
|
||||
fun <P : Any> ModelWrapper<P>.predictConsidering(vararg predictionTierInstances: TierInstance<*>): P {
|
||||
return this.predictConsidering(Environment.of(*predictionTierInstances))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,8 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl.turboComplete
|
||||
|
||||
import com.intellij.platform.ml.Tier
|
||||
|
||||
/**
|
||||
* Data class representing a kind of [SuggestionGenerator]'s suggestions.
|
||||
*
|
||||
@@ -8,4 +10,6 @@ package com.intellij.platform.ml.impl.turboComplete
|
||||
* The completion kind's name is defined statically, it should be unique among
|
||||
* the corresponding [KindVariety].
|
||||
*/
|
||||
data class CompletionKind(val name: Enum<*>, val variety: KindVariety)
|
||||
data class CompletionKind(val name: Enum<*>, val variety: KindVariety)
|
||||
|
||||
object TierCompletionKind : Tier<CompletionKind>()
|
||||
|
||||
@@ -16,7 +16,7 @@ interface SuggestionGeneratorExecutorProvider {
|
||||
): SuggestionGeneratorExecutor
|
||||
|
||||
companion object {
|
||||
private val EP_NAME: ExtensionPointName<SuggestionGeneratorExecutorProvider> =
|
||||
val EP_NAME: ExtensionPointName<SuggestionGeneratorExecutorProvider> =
|
||||
ExtensionPointName("com.intellij.turboComplete.suggestionGeneratorExecutorProvider")
|
||||
|
||||
fun hasAnyToCall(parameters: CompletionParameters): Boolean {
|
||||
|
||||
483
platform/ml-impl/test/com/intellij/platform/ml/impl/Demo.kt
Normal file
483
platform/ml-impl/test/com/intellij/platform/ml/impl/Demo.kt
Normal file
@@ -0,0 +1,483 @@
|
||||
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl
|
||||
|
||||
import com.intellij.internal.statistic.FUCollectorTestCase
|
||||
import com.intellij.internal.statistic.eventLog.events.ClassEventField
|
||||
import com.intellij.internal.statistic.eventLog.events.EventField
|
||||
import com.intellij.internal.statistic.eventLog.events.EventPair
|
||||
import com.intellij.lang.Language
|
||||
import com.intellij.openapi.fileTypes.PlainTextLanguage
|
||||
import com.intellij.openapi.util.Version
|
||||
import com.intellij.platform.ml.*
|
||||
import com.intellij.platform.ml.impl.MLTaskApproach.Companion.startMLSession
|
||||
import com.intellij.platform.ml.impl.apiPlatform.CodeLikePrinter
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform
|
||||
import com.intellij.platform.ml.impl.approach.*
|
||||
import com.intellij.platform.ml.impl.logger.FailedSessionLoggerRegister
|
||||
import com.intellij.platform.ml.impl.logger.FinishedSessionLoggerRegister
|
||||
import com.intellij.platform.ml.impl.logger.InplaceFeaturesScheme
|
||||
import com.intellij.platform.ml.impl.logger.MLEvent
|
||||
import com.intellij.platform.ml.impl.model.MLModel
|
||||
import com.intellij.platform.ml.impl.monitoring.*
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
|
||||
import com.intellij.platform.ml.impl.session.*
|
||||
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
|
||||
import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser
|
||||
import com.intellij.platform.ml.impl.session.analysis.StructureAnalyser
|
||||
import com.intellij.platform.ml.impl.session.analysis.StructureAnalysis
|
||||
import com.intellij.testFramework.fixtures.BasePlatformTestCase
|
||||
import com.jetbrains.fus.reporting.model.lion3.LogEvent
|
||||
import java.io.BufferedWriter
|
||||
import java.io.FileWriter
|
||||
import java.net.URI
|
||||
import java.nio.file.Path
|
||||
import java.util.concurrent.CompletableFuture
|
||||
import java.util.concurrent.TimeUnit
|
||||
import java.util.function.Consumer
|
||||
import kotlin.io.path.div
|
||||
import kotlin.random.Random
|
||||
|
||||
enum class CompletionType {
|
||||
SMART,
|
||||
BASIC
|
||||
}
|
||||
|
||||
data class CompletionSession(
|
||||
val language: Language,
|
||||
val callOrder: Int,
|
||||
val completionType: CompletionType?
|
||||
)
|
||||
|
||||
data class LookupImpl(
|
||||
val foo: Boolean,
|
||||
val index: Int
|
||||
)
|
||||
|
||||
data class LookupItem(
|
||||
val lookupString: String,
|
||||
val decoration: Map<Any, Any>
|
||||
)
|
||||
|
||||
data class GitRepository(
|
||||
val user: String,
|
||||
val projectUri: URI,
|
||||
val commits: List<String>
|
||||
)
|
||||
|
||||
object TierCompletionSession : Tier<CompletionSession>()
|
||||
object TierLookup : Tier<LookupImpl>()
|
||||
object TierItem : Tier<LookupItem>()
|
||||
|
||||
object TierGit : Tier<GitRepository>()
|
||||
|
||||
class CompletionSessionFeatures1 : TierDescriptor {
|
||||
companion object {
|
||||
val CALL_ORDER = FeatureDeclaration.int("call_order")
|
||||
val LANGUAGE_ID = FeatureDeclaration.categorical("language_id", Language.getRegisteredLanguages().map { it.id }.toSet())
|
||||
val GIT_USER = FeatureDeclaration.boolean("git_user_is_Glebanister")
|
||||
val COMPLETION_TYPE = FeatureDeclaration.enum<CompletionType>("completion_type").nullable()
|
||||
}
|
||||
|
||||
override val tier: Tier<*> = TierCompletionSession
|
||||
|
||||
override val additionallyRequiredTiers: Set<Tier<*>> = setOf(TierGit)
|
||||
|
||||
override val descriptionDeclaration: Set<FeatureDeclaration<*>> = setOf(
|
||||
CALL_ORDER, LANGUAGE_ID, GIT_USER, COMPLETION_TYPE
|
||||
)
|
||||
|
||||
override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set<Feature> {
|
||||
val completionSession = environment[TierCompletionSession]
|
||||
val gitRepository = environment[TierGit]
|
||||
return setOf(
|
||||
CALL_ORDER with completionSession.callOrder,
|
||||
LANGUAGE_ID with completionSession.language.id,
|
||||
GIT_USER with (gitRepository.user == "Glebanister"),
|
||||
COMPLETION_TYPE with completionSession.completionType
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class ItemFeatures1 : TierDescriptor {
|
||||
companion object {
|
||||
val DECORATIONS = FeatureDeclaration.int("decorations")
|
||||
val LENGTH = FeatureDeclaration.int("length")
|
||||
}
|
||||
|
||||
override val tier: Tier<*> = TierItem
|
||||
|
||||
override val descriptionDeclaration: Set<FeatureDeclaration<*>> = setOf(
|
||||
DECORATIONS, LENGTH
|
||||
)
|
||||
|
||||
override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set<Feature> {
|
||||
val item = environment[TierItem]
|
||||
return setOf(
|
||||
DECORATIONS with item.decoration.size,
|
||||
LENGTH with item.lookupString.length
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class GitFeatures1 : TierDescriptor {
|
||||
companion object {
|
||||
val N_COMMITS = FeatureDeclaration.int("n_commits")
|
||||
val HAS_USER = FeatureDeclaration.boolean("has_user")
|
||||
}
|
||||
|
||||
override val tier: Tier<*> = TierGit
|
||||
|
||||
override val descriptionDeclaration: Set<FeatureDeclaration<*>> = setOf(
|
||||
N_COMMITS, HAS_USER
|
||||
)
|
||||
|
||||
override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter) = setOf(
|
||||
N_COMMITS with environment[TierGit].commits.size,
|
||||
HAS_USER with environment[TierGit].user.isNotEmpty()
|
||||
)
|
||||
}
|
||||
|
||||
class GitInformant : EnvironmentExtender<GitRepository> {
|
||||
override val extendingTier: Tier<GitRepository> = TierGit
|
||||
|
||||
override val requiredTiers: Set<Tier<*>> = setOf()
|
||||
|
||||
override fun extend(environment: Environment): GitRepository {
|
||||
return GitRepository(
|
||||
user = "Glebanister",
|
||||
projectUri = URI.create("ssh://git@git.jetbrains.team/ij/intellij.git"),
|
||||
commits = listOf(
|
||||
"0e47200fa3bf029d7244745eacbf9d495de818c1",
|
||||
"638fbc7840b85d6e34ef2320ca5b2c9ec2c4b23c",
|
||||
"5a74104a0901b8faa4cc1f76736347cd33917041"
|
||||
)
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
class SomeStructureAnalyser<M : MLModel<Double>> : StructureAnalyser<M, Double> {
|
||||
companion object {
|
||||
val SESSION_IS_GOOD = FeatureDeclaration.boolean("very_good_session")
|
||||
val LOOKUP_INDEX = FeatureDeclaration.int("lookup_index")
|
||||
}
|
||||
|
||||
override fun analyse(sessionTreeRoot: DescribedRootContainer<M, Double>): CompletableFuture<StructureAnalysis<M, Double>> {
|
||||
val analysis = mutableMapOf<DescribedSessionTree<M, Double>, PerTier<Set<Feature>>>()
|
||||
sessionTreeRoot.accept(LookupAnalyser(analysis))
|
||||
analysis[sessionTreeRoot] = mapOf(
|
||||
TierCompletionSession to setOf(SESSION_IS_GOOD with true)
|
||||
)
|
||||
|
||||
return CompletableFuture.supplyAsync {
|
||||
// Pretend that analysis is taking some long time
|
||||
TimeUnit.SECONDS.sleep(2)
|
||||
analysis
|
||||
}
|
||||
}
|
||||
|
||||
private class LookupAnalyser<M : MLModel<Double>>(
|
||||
private val analysis: MutableMap<DescribedSessionTree<M, Double>, PerTier<Set<Feature>>>
|
||||
) : SessionTree.LevelVisitor<M, Double>(levelIndex = 1) {
|
||||
override fun visitLevel(level: DescribedLevel, levelRoot: DescribedSessionTree<M, Double>) {
|
||||
val lookup = level.environment[TierLookup]
|
||||
analysis[levelRoot] = mapOf(TierLookup to setOf(LOOKUP_INDEX with lookup.index))
|
||||
}
|
||||
}
|
||||
|
||||
override val analysisDeclaration: PerTier<Set<FeatureDeclaration<*>>>
|
||||
get() = mapOf(
|
||||
TierCompletionSession to setOf(SESSION_IS_GOOD),
|
||||
TierLookup to setOf(LOOKUP_INDEX)
|
||||
)
|
||||
}
|
||||
|
||||
class RandomModelSeedAnalyser : MLModelAnalyser<RandomModel, Double> {
|
||||
companion object {
|
||||
val SEED = FeatureDeclaration.int("random_seed")
|
||||
}
|
||||
|
||||
override val analysisDeclaration: Set<FeatureDeclaration<*>> = setOf(
|
||||
SEED
|
||||
)
|
||||
|
||||
override fun analyse(sessionTreeRoot: DescribedRootContainer<RandomModel, Double>): CompletableFuture<Set<Feature>> {
|
||||
return CompletableFuture.supplyAsync {
|
||||
// Pretend that analysis is taking some long time
|
||||
TimeUnit.SECONDS.sleep(1)
|
||||
setOf(SEED with sessionTreeRoot.root.seed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
class RandomModel(val seed: Int) : MLModel<Double>, Versioned, LanguageSpecific {
|
||||
private val generator = Random(seed)
|
||||
|
||||
class Provider : MLModel.Provider<RandomModel, Double> {
|
||||
override val requiredTiers: Set<Tier<*>> = emptySet()
|
||||
|
||||
override fun provideModel(sessionTiers: List<LevelTiers>, environment: Environment): RandomModel? {
|
||||
return if (Random.nextBoolean()) {
|
||||
if (Random.nextBoolean()) RandomModel(1) else throw IllegalStateException()
|
||||
}
|
||||
else null
|
||||
}
|
||||
}
|
||||
|
||||
override val knownFeatures: PerTier<FeatureSelector> = mapOf(
|
||||
TierCompletionSession to FeatureSelector.EVERYTHING,
|
||||
TierLookup to FeatureSelector.EVERYTHING,
|
||||
TierGit to FeatureSelector.EVERYTHING,
|
||||
TierItem to FeatureSelector.EVERYTHING,
|
||||
)
|
||||
|
||||
override fun predict(features: PerTier<Set<Feature>>): Double {
|
||||
return generator.nextDouble()
|
||||
}
|
||||
|
||||
override val languageId: String = PlainTextLanguage.INSTANCE.id
|
||||
|
||||
override val version: Version = Version(0, 0, 1)
|
||||
}
|
||||
|
||||
class SomeListener(private val name: String) : MLTaskGroupListener {
|
||||
override val approachListeners = listOf(
|
||||
MockTaskApproach::class.java monitoredBy InitializationListener()
|
||||
)
|
||||
|
||||
private fun log(message: String) = println("[Listener $name says] $message")
|
||||
|
||||
inner class InitializationListener : MLApproachInitializationListener<RandomModel, Double> {
|
||||
override fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener<RandomModel, Double> {
|
||||
log("attempted to initialize session")
|
||||
return ApproachListener()
|
||||
}
|
||||
}
|
||||
|
||||
inner class ApproachListener : MLApproachListener<RandomModel, Double> {
|
||||
override fun onFailedToStartSessionWithException(exception: Throwable) {
|
||||
log("failed to start session with exception: $exception")
|
||||
}
|
||||
|
||||
override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<Double>) {
|
||||
log("failed to start session with outcome: $failure")
|
||||
}
|
||||
|
||||
override fun onStartedSession(session: Session<Double>): MLSessionListener<RandomModel, Double> {
|
||||
log("session was started successfully: $session")
|
||||
return SessionListener()
|
||||
}
|
||||
}
|
||||
|
||||
inner class SessionListener : MLSessionListener<RandomModel, Double> {
|
||||
override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<RandomModel, Double>) {
|
||||
log("session successfully described: $sessionTree")
|
||||
}
|
||||
|
||||
override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<Double>) {
|
||||
log("session successfully analyzed: $sessionTree")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
object ExceptionLogger : ShallowSessionAnalyser<Throwable> {
|
||||
private val THROWABLE_CLASS = ClassEventField("throwable_class")
|
||||
|
||||
override val name: String = "exception"
|
||||
|
||||
override val declaration: List<EventField<*>> = listOf(THROWABLE_CLASS)
|
||||
|
||||
override fun analyse(permanentSessionEnvironment: Environment, data: Throwable): List<EventPair<*>> {
|
||||
return listOf(THROWABLE_CLASS with data.javaClass)
|
||||
}
|
||||
}
|
||||
|
||||
object FailureLogger : ShallowSessionAnalyser<Session.StartOutcome.Failure<Double>> {
|
||||
private val REASON = ClassEventField("reason")
|
||||
|
||||
override val name: String = "normal_failure"
|
||||
|
||||
override val declaration: List<EventField<*>> = listOf(REASON)
|
||||
|
||||
override fun analyse(permanentSessionEnvironment: Environment,
|
||||
data: Session.StartOutcome.Failure<Double>): List<EventPair<*>> {
|
||||
return listOf(REASON with data.javaClass)
|
||||
}
|
||||
}
|
||||
|
||||
object ThisTestApiPlatform : TestApiPlatform() {
|
||||
override val tierDescriptors = listOf(
|
||||
CompletionSessionFeatures1(),
|
||||
ItemFeatures1(),
|
||||
GitFeatures1(),
|
||||
)
|
||||
|
||||
override val environmentExtenders = listOf(
|
||||
GitInformant(),
|
||||
)
|
||||
|
||||
override val taskApproaches = listOf(
|
||||
MockTaskApproach.Initializer()
|
||||
)
|
||||
|
||||
|
||||
override val initialStartupListeners: List<MLApiStartupListener> = listOf(
|
||||
FinishedSessionLoggerRegister<RandomModel, Double>(
|
||||
MockTaskApproach::class.java,
|
||||
InplaceFeaturesScheme.FusScheme.DOUBLE
|
||||
),
|
||||
FailedSessionLoggerRegister<RandomModel, Double>(
|
||||
MockTaskApproach::class.java,
|
||||
exceptionalAnalysers = listOf(ExceptionLogger),
|
||||
normalFailureAnalysers = listOf(FailureLogger)
|
||||
)
|
||||
)
|
||||
|
||||
override val initialTaskListeners: List<MLTaskGroupListener> = listOf(
|
||||
SomeListener("Nika"),
|
||||
SomeListener("Alex"),
|
||||
)
|
||||
|
||||
override val initialEvents: List<MLEvent> = listOf()
|
||||
|
||||
|
||||
override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>) {
|
||||
val printer = CodeLikePrinter()
|
||||
println("$descriptor is missing the following declaration: ${printer.printCodeLikeString(nonDeclaredFeatures.map { it.declaration })}")
|
||||
}
|
||||
}
|
||||
|
||||
object MockTask : MLTask<Double>(
|
||||
name = "mock",
|
||||
predictionClass = Double::class.java,
|
||||
levels = listOf(
|
||||
setOf(TierCompletionSession),
|
||||
setOf(TierLookup),
|
||||
setOf(TierItem)
|
||||
)
|
||||
)
|
||||
|
||||
class MockTaskApproach(
|
||||
apiPlatform: MLApiPlatform,
|
||||
task: MLTask<Double>
|
||||
) : LogDrivenModelInference<RandomModel, Double>(task, apiPlatform) {
|
||||
|
||||
override val additionallyDescribedTiers: List<Set<Tier<*>>> = listOf(
|
||||
setOf(TierGit),
|
||||
setOf(),
|
||||
setOf(),
|
||||
)
|
||||
|
||||
override val analysisMethod: AnalysisMethod<RandomModel, Double> = StructureAndModelAnalysis(
|
||||
structureAnalysers = listOf(SomeStructureAnalyser()),
|
||||
mlModelAnalysers = listOf(
|
||||
RandomModelSeedAnalyser(),
|
||||
ModelVersionAnalyser(),
|
||||
ModelLanguageAnalyser()
|
||||
)
|
||||
)
|
||||
|
||||
override val mlModelProvider = RandomModel.Provider()
|
||||
|
||||
override val notUsedDescription: PerTier<FeatureSelector> = mapOf(
|
||||
TierCompletionSession to FeatureSelector.NOTHING,
|
||||
TierLookup to FeatureSelector.NOTHING,
|
||||
TierItem to FeatureSelector.NOTHING,
|
||||
TierGit to FeatureSelector.NOTHING
|
||||
)
|
||||
|
||||
override val descriptionComputer: DescriptionComputer = StateFreeDescriptionComputer()
|
||||
|
||||
class Initializer : MLTaskApproachInitializer<Double> {
|
||||
override val task: MLTask<Double> = MockTask
|
||||
override fun initializeApproachWithin(apiPlatform: MLApiPlatform) = MockTaskApproach(apiPlatform, task)
|
||||
}
|
||||
}
|
||||
|
||||
class TestTask : BasePlatformTestCase() {
|
||||
fun `test demo ml task`() {
|
||||
// After the session is finished, it will be logged to community/platform/ml-impl/testResources/ml_logs.js
|
||||
|
||||
val logs: MutableList<Pair<String, Map<String, Any>>> = mutableListOf()
|
||||
val collectLogs = Consumer { fusLog: LogEvent ->
|
||||
logs.add(fusLog.event.id to fusLog.event.data)
|
||||
}
|
||||
|
||||
// TODO: Handle this as well (when it has been initialized before we had the control flow)
|
||||
//MLEventLogger.Manager.ensureNotInitialized()
|
||||
|
||||
ReplaceableIJPlatform.replacingWith(ThisTestApiPlatform) {
|
||||
FUCollectorTestCase.listenForEvents("FUS", this.testRootDisposable, collectLogs) {
|
||||
|
||||
repeat(3) { sessionIndex ->
|
||||
|
||||
println("Demo session #$sessionIndex has started")
|
||||
|
||||
val startOutcome = startMLSession(MockTask, Environment.of(
|
||||
TierCompletionSession with CompletionSession(
|
||||
language = PlainTextLanguage.INSTANCE,
|
||||
callOrder = 1,
|
||||
completionType = CompletionType.SMART
|
||||
)
|
||||
))
|
||||
|
||||
val completionSession = startOutcome.session ?: return@repeat
|
||||
|
||||
completionSession.withNestedSessions { lookupSessionCreator ->
|
||||
|
||||
lookupSessionCreator.nestConsidering(Environment.of(TierLookup with LookupImpl(true, 1)))
|
||||
.withPredictions {
|
||||
it.predictConsidering(Environment.of(TierItem with LookupItem("hello", emptyMap())))
|
||||
it.predictConsidering(Environment.of(TierItem with LookupItem("world", emptyMap())))
|
||||
}
|
||||
|
||||
lookupSessionCreator.nestConsidering(Environment.of(TierLookup with LookupImpl(false, 2)))
|
||||
.withPredictions {
|
||||
it.predictConsidering(Environment.of(TierItem with LookupItem("hello!!!", mapOf("bold" to true))))
|
||||
it.predictConsidering(Environment.of(TierItem with LookupItem("AAAAA!!", mapOf("strikethrough" to true))))
|
||||
it.consider(Environment.of(TierItem with LookupItem("AAAAAAAAAAAAAAAAA", mapOf("cursive" to true))))
|
||||
}
|
||||
}
|
||||
|
||||
// Wait until the analysis will be over
|
||||
// TODO: Think if it is possible to track it via API
|
||||
|
||||
Thread.sleep(3 * 1000)
|
||||
|
||||
println("Demo session #$sessionIndex has finished")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
convertListToJsonAndWriteToFile(logs)
|
||||
}
|
||||
}
|
||||
|
||||
fun mapToJson(map: Map<String, Any>): String {
|
||||
val entrySet = map.entries.joinToString(", ") {
|
||||
"\"${it.key}\": ${valueToJson(it.value)}"
|
||||
}
|
||||
return "{$entrySet}"
|
||||
}
|
||||
|
||||
fun listToJson(list: List<Any>): String =
|
||||
list.joinToString(", ") { valueToJson(it) }
|
||||
|
||||
@Suppress("UNCHECKED_CAST")
|
||||
fun valueToJson(value: Any): String =
|
||||
when (value) {
|
||||
is String -> "\"$value\""
|
||||
is Map<*, *> -> mapToJson(value as Map<String, Any>)
|
||||
is List<*> -> "[${listToJson(value as List<Any>)}]"
|
||||
else -> value.toString()
|
||||
}
|
||||
|
||||
fun convertListToJsonAndWriteToFile(list: MutableList<Pair<String, Map<String, Any>>>) {
|
||||
// Prepare content to write to file
|
||||
val logs = list.joinToString(",\n") { mapToJson(mapOf("eventId" to it.first, "data" to it.second)) }
|
||||
val content = "let logs = [\n$logs\n];"
|
||||
|
||||
// Write content to file
|
||||
val filePath = Path.of(".") / "testResources" / "ml_logs.js"
|
||||
BufferedWriter(FileWriter(filePath.toFile())).use { it.write(content) }
|
||||
}
|
||||
@@ -0,0 +1,46 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.impl
|
||||
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
|
||||
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.ExtensionController
|
||||
import com.intellij.platform.ml.impl.logger.MLEvent
|
||||
import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener
|
||||
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener
|
||||
|
||||
abstract class TestApiPlatform : MLApiPlatform() {
|
||||
private val dynamicTaskListeners: MutableList<MLTaskGroupListener> = mutableListOf()
|
||||
private val dynamicStartupListeners: MutableList<MLApiStartupListener> = mutableListOf()
|
||||
private val dynamicEvents: MutableList<MLEvent> = mutableListOf()
|
||||
|
||||
abstract val initialTaskListeners: List<MLTaskGroupListener>
|
||||
|
||||
abstract val initialStartupListeners: List<MLApiStartupListener>
|
||||
|
||||
abstract val initialEvents: List<MLEvent>
|
||||
|
||||
final override val events: List<MLEvent>
|
||||
get() = initialEvents + dynamicEvents
|
||||
|
||||
final override val taskListeners: List<MLTaskGroupListener>
|
||||
get() = initialTaskListeners + dynamicTaskListeners
|
||||
|
||||
final override val startupListeners: List<MLApiStartupListener>
|
||||
get() = initialStartupListeners + dynamicStartupListeners
|
||||
|
||||
override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController {
|
||||
return extend(taskListener, dynamicTaskListeners)
|
||||
}
|
||||
|
||||
override fun addEvent(event: MLEvent): ExtensionController {
|
||||
return extend(event, dynamicEvents)
|
||||
}
|
||||
|
||||
override fun addStartupListener(listener: MLApiStartupListener): ExtensionController {
|
||||
return extend(listener, dynamicStartupListeners)
|
||||
}
|
||||
|
||||
private fun <T> extend(obj: T, collection: MutableCollection<T>): ExtensionController {
|
||||
collection.add(obj)
|
||||
return ExtensionController { collection.remove(obj) }
|
||||
}
|
||||
}
|
||||
6
platform/ml-impl/testResources/ml_logs.js
Normal file
6
platform/ml-impl/testResources/ml_logs.js
Normal file
@@ -0,0 +1,6 @@
|
||||
let logs = [
|
||||
{"eventId": "startup", "data": {"success": true}},
|
||||
{"eventId": "mock.failed", "data": {"normal_failure": {"reason": "com.intellij.platform.ml.impl.approach.ModelNotAcquiredOutcome"}}},
|
||||
{"eventId": "mock.finished", "data": {"session": {"ml_model": {"random_seed": 1, "language_id": "TEXT", "version": "0.0"}}, "additional": {"TierGit": {"description": {"not_used": {}, "used": {"has_user": true, "n_commits": 3}}, "id": -401817549}}, "main": {"TierCompletionSession": {"description": {"not_used": {}, "used": {"language_id": "TEXT", "git_user_is_Glebanister": true, "completion_type": "SMART", "call_order": 1}}, "id": 1444967957, "analysis": {"very_good_session": true}}}, "nested": [{"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38162, "analysis": {"lookup_index": 1}}}, "nested": [{"additional": {}, "prediction": "0.1397272444266996", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -1220935314, "analysis": {}}}}, {"additional": {}, "prediction": "0.8772473577824259", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -782084434, "analysis": {}}}}]}, {"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38349, "analysis": {"lookup_index": 2}}}, "nested": [{"additional": {}, "prediction": "0.8739363143786573", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 8, "decorations": 1}}, "id": 1198136891, "analysis": {}}}}, {"additional": {}, "prediction": "0.9611620534809023", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 7, "decorations": 1}}, "id": 122089051, "analysis": {}}}}, {"additional": {}, "prediction": "null", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 17, "decorations": 1}}, "id": 1444769257, "analysis": {}}}}]}]}},
|
||||
{"eventId": "mock.finished", "data": {"session": {"ml_model": {"random_seed": 1, "language_id": "TEXT", "version": "0.0"}}, "additional": {"TierGit": {"description": {"not_used": {}, "used": {"has_user": true, "n_commits": 3}}, "id": -401817549}}, "main": {"TierCompletionSession": {"description": {"not_used": {}, "used": {"language_id": "TEXT", "git_user_is_Glebanister": true, "completion_type": "SMART", "call_order": 1}}, "id": 1444967957, "analysis": {"very_good_session": true}}}, "nested": [{"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38162, "analysis": {"lookup_index": 1}}}, "nested": [{"additional": {}, "prediction": "0.1397272444266996", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -1220935314, "analysis": {}}}}, {"additional": {}, "prediction": "0.8772473577824259", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -782084434, "analysis": {}}}}]}, {"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38349, "analysis": {"lookup_index": 2}}}, "nested": [{"additional": {}, "prediction": "0.8739363143786573", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 8, "decorations": 1}}, "id": 1198136891, "analysis": {}}}}, {"additional": {}, "prediction": "0.9611620534809023", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 7, "decorations": 1}}, "id": 122089051, "analysis": {}}}}, {"additional": {}, "prediction": "null", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 17, "decorations": 1}}, "id": 1444769257, "analysis": {}}}}]}]}}
|
||||
];
|
||||
BIN
platform/ml-impl/testResources/mockModel/local_model.zip
Normal file
BIN
platform/ml-impl/testResources/mockModel/local_model.zip
Normal file
Binary file not shown.
@@ -520,6 +520,16 @@
|
||||
<with attribute="implementationClass" implements="com.intellij.internal.ml.MLFeatureProvider"/>
|
||||
</extensionPoint>
|
||||
|
||||
<extensionPoint qualifiedName="com.intellij.platform.ml.environmentExtender"
|
||||
interface="com.intellij.platform.ml.EnvironmentExtender"
|
||||
dynamic="true"/>
|
||||
<extensionPoint qualifiedName="com.intellij.platform.ml.descriptor"
|
||||
interface="com.intellij.platform.ml.TierDescriptor"
|
||||
dynamic="true"/>
|
||||
<extensionPoint qualifiedName="com.intellij.platform.ml.taskListener"
|
||||
interface="com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener"
|
||||
dynamic="true"/>
|
||||
|
||||
<extensionPoint name="defender.config" interface="com.intellij.diagnostic.WindowsDefenderChecker$Extension" dynamic="true" />
|
||||
<extensionPoint name="authorizationProvider" interface="com.intellij.ide.impl.AuthorizationProvider" dynamic="true" />
|
||||
|
||||
|
||||
Reference in New Issue
Block a user