diff --git a/platform/ml-api/intellij.platform.ml.iml b/platform/ml-api/intellij.platform.ml.iml index 22b2ad5e983a..c0615a6ea52d 100644 --- a/platform/ml-api/intellij.platform.ml.iml +++ b/platform/ml-api/intellij.platform.ml.iml @@ -3,6 +3,7 @@ + diff --git a/platform/ml-api/resources/META-INF/ml-api.xml b/platform/ml-api/resources/META-INF/ml-api.xml new file mode 100644 index 000000000000..164f46cfcfb2 --- /dev/null +++ b/platform/ml-api/resources/META-INF/ml-api.xml @@ -0,0 +1,2 @@ + + diff --git a/platform/ml-api/src/com/intellij/platform/ml/Environment.kt b/platform/ml-api/src/com/intellij/platform/ml/Environment.kt new file mode 100644 index 000000000000..453f2b67dd6d --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/Environment.kt @@ -0,0 +1,69 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * Represents an environment that is being assembled to be described by [TierDescriptor]s, + * to acquire a new ML Model, or for another reason. + */ +@ApiStatus.Internal +interface Environment { + /** + * The set of tiers, that the environment contains. + */ + val tiers: Set> + + /** + * @return an instance, that corresponds to the given [tier]. + * @throws IllegalArgumentException if the [tier] is not present + */ + fun getInstance(tier: Tier): T + + /** + * The set of tier instances that are present in the environment. + */ + val tierInstances: Set> + get() { + return tiers.map { this.getTierInstance(it) }.toSet() + } + + /** + * @return a tier instance wrapped into [TierInstance] class. + * @throws IllegalArgumentException if the tier is not present. 
+ */ + fun getTierInstance(tier: Tier) = TierInstance(tier, this[tier]) + + /** + * @return if the tier is present in the environment + */ + operator fun contains(tier: Tier<*>): Boolean = tier in this.tiers + + companion object { + /** + * @return An environment that contains all tiers of all given environments. + * @throws IllegalArgumentException if there is a tier that is present in more than two + * environments. + */ + fun joined(environments: Iterable): Environment = TierInstanceStorage.joined(environments) + + fun empty(): Environment = joined(emptySet()) + + /** + * Builds an environment that contains all the given tiers. + * @throws IllegalArgumentException if there is more than one instance of a particular tier. + */ + fun of(entries: Iterable>): Environment { + val storage = TierInstanceStorage() + fun putToStorage(tierInstance: TierInstance) { + storage[tierInstance.tier] = tierInstance.instance + } + entries.forEach { putToStorage(it) } + return storage + } + + fun of(vararg entries: TierInstance<*>): Environment = of(entries.toList()) + } +} + +operator fun Environment.get(tier: Tier): T = getInstance(tier) diff --git a/platform/ml-api/src/com/intellij/platform/ml/EnvironmentExtender.kt b/platform/ml-api/src/com/intellij/platform/ml/EnvironmentExtender.kt new file mode 100644 index 000000000000..d9f5f6c8b007 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/EnvironmentExtender.kt @@ -0,0 +1,53 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import com.intellij.openapi.extensions.ExtensionPointName +import org.jetbrains.annotations.ApiStatus + +/** + * Provides additional tiers on the top of the main ones, making an "extended" + * environment. + * + * If there are some other tiers that an [EnvironmentExtender] needs to build or access the [extendingTier], + * then it could "order" them by defining the [requiredTiers]. 
+ * + * The extender will be called if and only if all the requirements are satisfied. + * + * Additional tiers will be later described by [TierDescriptor]. This description will be used for the model's + * inference and then logged. However, they cannot be analyzed via [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser], + * because they are not a part of the ML Task, and they could be absent in the session. + * + * An "extended environment" (i.e., the one that contains main as well as additional tiers) is built + * each time when a [TierRequester] performs the desired action. For example, an [EnvironmentExtender.extend] or + * [com.intellij.platform.ml.impl.model.MLModel.Provider.provideModel] is called. + * + * As each extender has some requirements, as it also produces some other tier, we must first determine + * the order in which the available extenders will run, or resolve it. + * To learn more about that, address [com.intellij.platform.ml.impl.environment.ExtendedEnvironment]'s documentation. + * + * Additional tiers, + */ +@ApiStatus.Internal +interface EnvironmentExtender : TierRequester { + /** + * The tier that the extender will be providing. + */ + val extendingTier: Tier + + /** + * Provides an instance of the [extendingTier] based on the [environment]. + * + * @param environment includes tiers requested in [requiredTiers] + */ + fun extend(environment: Environment): T? + + companion object { + val EP_NAME: ExtensionPointName> = ExtensionPointName("com.intellij.platform.ml.environmentExtender") + + fun EnvironmentExtender.extendTierInstance(environment: Environment): TierInstance?
{ + return extend(environment)?.let { + this.extendingTier with it + } + } + } +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/Feature.kt b/platform/ml-api/src/com/intellij/platform/ml/Feature.kt new file mode 100644 index 000000000000..2e9d721b22a0 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/Feature.kt @@ -0,0 +1,90 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * An instantiated tier's feature. It could be either + * description's feature (that was provided by a [TierDescriptor]), or an + * analysis feature (that was provided by a [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser]). + * + * If you need another type of feature, and it is not supported yet, contact the ML API developers. + */ +@ApiStatus.Internal +sealed class Feature { + /** + * A statically defined declaration of this feature, that includes all the information, except the value. + */ + abstract val declaration: FeatureDeclaration<*> + + /** + * A computed nullable value + */ + abstract val value: Any? 
+ + override fun equals(other: Any?): kotlin.Boolean { + if (other !is Feature) return false + return other.declaration == this.declaration && other.value == this.value + } + + sealed class TypedFeature( + val name: String, + override val value: T, + ) : Feature() { + abstract val valueType: FeatureValueType + + override val declaration: FeatureDeclaration<*> + get() = FeatureDeclaration(name, valueType) + } + + override fun hashCode(): kotlin.Int { + return this.declaration.hashCode().xor(this.value.hashCode()) + } + + override fun toString(): String { + return "Feature{declaration=$declaration, value=$value}" + } + + class Enum>(name: String, value: T) : TypedFeature(name, value) { + override val valueType = FeatureValueType.Enum(value.javaClass) + } + + class Int(name: String, value: kotlin.Int) : TypedFeature(name, value) { + override val valueType = FeatureValueType.Int + } + + class Boolean(name: String, value: kotlin.Boolean) : TypedFeature(name, value) { + override val valueType = FeatureValueType.Boolean + } + + class Float(name: String, value: kotlin.Float) : TypedFeature(name, value) { + override val valueType = FeatureValueType.Float + } + + class Double(name: String, value: kotlin.Double) : TypedFeature(name, value) { + override val valueType = FeatureValueType.Double + } + + class Long(name: String, value: kotlin.Long) : TypedFeature(name, value) { + override val valueType = FeatureValueType.Long + } + + class Class(name: String, value: java.lang.Class<*>) : TypedFeature>(name, value) { + override val valueType = FeatureValueType.Class + } + + class Nullable(name: String, value: T?, val baseType: FeatureValueType) + : TypedFeature(name, value) { + override val valueType = FeatureValueType.Nullable(baseType) + } + + class Categorical(name: String, value: String, possibleValues: Set) + : TypedFeature(name, value) { + override val valueType = FeatureValueType.Categorical(possibleValues) + } + + class Version(name: String, value: 
com.intellij.openapi.util.Version) + : TypedFeature(name, value) { + override val valueType = FeatureValueType.Version + } +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/FeatureDeclaration.kt b/platform/ml-api/src/com/intellij/platform/ml/FeatureDeclaration.kt new file mode 100644 index 000000000000..1b4f78c76c81 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/FeatureDeclaration.kt @@ -0,0 +1,62 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * Represents declaration of a Tier's feature. + * + * All tiers that end up in the ML Model's inference must be declared statically + * (i.e. via Extension Points), as FUS logs' validator must be aware of every + * feature that will be logged. + * + * @param name The feature's name that is unique for this tier. It may not contain special symbols. + * @param type The feature's type. + * @param T The type of the value, with which the [type] can be instantiated ([FeatureValueType.instantiate]). + */ +@ApiStatus.Internal +data class FeatureDeclaration( + val name: String, + val type: FeatureValueType +) { + init { + require(name.all { it.isLetterOrDigit() || it == '_' }) { + "Invalid feature name '$name': it shall not contain special symbols" + } + } + + /** + * Shortcut for the feature's instantiation + */ + infix fun with(value: T): Feature { + return type.instantiate(name, value) + } + + /** + * Signifies that the feature can be instantiated with null values. 
+ */ + fun nullable(): FeatureDeclaration { + require(type !is FeatureValueType.Nullable<*>) { "Repeated declaration as 'nullable'" } + return FeatureDeclaration(name, FeatureValueType.Nullable(type)) + } + + companion object { + inline fun > enum(name: String) = FeatureDeclaration(name, FeatureValueType.Enum(T::class.java)) + + fun int(name: String) = FeatureDeclaration(name, FeatureValueType.Int) + + fun double(name: String) = FeatureDeclaration(name, FeatureValueType.Double) + + fun float(name: String) = FeatureDeclaration(name, FeatureValueType.Float) + + fun long(name: String) = FeatureDeclaration(name, FeatureValueType.Long) + + fun aClass(name: String) = FeatureDeclaration(name, FeatureValueType.Class) + + fun boolean(name: String) = FeatureDeclaration(name, FeatureValueType.Boolean) + + fun categorical(name: String, possibleValues: Set) = FeatureDeclaration(name, FeatureValueType.Categorical(possibleValues)) + + fun version(name: String) = FeatureDeclaration(name, FeatureValueType.Version) + } +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/FeatureFilter.kt b/platform/ml-api/src/com/intellij/platform/ml/FeatureFilter.kt new file mode 100644 index 000000000000..c8faa0a3f218 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/FeatureFilter.kt @@ -0,0 +1,18 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * A filter that indicates whether a feature matches a feature set or not. + * A feature filter functions within a particular tier. 
+ */ +@ApiStatus.Internal +fun interface FeatureFilter { + fun accept(featureDeclaration: FeatureDeclaration<*>): Boolean + + companion object { + val REJECT_ALL = FeatureFilter { false } + val ACCEPT_ALL = FeatureFilter { true } + } +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/FeatureValueType.kt b/platform/ml-api/src/com/intellij/platform/ml/FeatureValueType.kt new file mode 100644 index 000000000000..d8e5760914e9 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/FeatureValueType.kt @@ -0,0 +1,83 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * A type of [Feature]. + * + * If you need to use another type of feature to use in your ML model or in analysis, + * consider contacting the ML API developers. + */ +@ApiStatus.Internal +sealed class FeatureValueType { + abstract fun instantiate(name: String, value: T): Feature + + data class Nullable(val baseType: FeatureValueType) : FeatureValueType() { + override fun instantiate(name: String, value: T?): Feature { + return Feature.Nullable(name, value, baseType) + } + } + + data class Enum>(val enumClass: java.lang.Class) : FeatureValueType() { + override fun instantiate(name: String, value: T): Feature { + return Feature.Enum(name, value) + } + } + + /** A string feature type limited to [possibleValues]; [instantiate] throws [IllegalArgumentException] for any other value. */ data class Categorical(val possibleValues: Set) : FeatureValueType() { + override fun instantiate(name: String, value: String): Feature { + require(value in possibleValues) { + /* Look the rejected value (not the feature's name) up case-insensitively to hint at a likely typo. */ val caseNonMatchingValue = possibleValues.find { it.equals(value, ignoreCase = true) } + "Feature $name cannot be assigned to value $value, " + + "all possible values are $possibleValues. " +
+ "Possible match (but case does not match): $caseNonMatchingValue" + } + return Feature.Categorical(name, value, possibleValues) + } + } + + object Int : FeatureValueType() { + override fun instantiate(name: String, value: kotlin.Int): Feature { + return Feature.Int(name, value) + } + } + + object Double : FeatureValueType() { + override fun instantiate(name: String, value: kotlin.Double): Feature { + return Feature.Double(name, value) + } + } + + object Float : FeatureValueType() { + override fun instantiate(name: String, value: kotlin.Float): Feature { + return Feature.Float(name, value) + } + } + + object Long : FeatureValueType() { + override fun instantiate(name: String, value: kotlin.Long): Feature { + return Feature.Long(name, value) + } + } + + object Class : FeatureValueType>() { + override fun instantiate(name: String, value: java.lang.Class<*>): Feature { + return Feature.Class(name, value) + } + } + + object Boolean : FeatureValueType() { + override fun instantiate(name: String, value: kotlin.Boolean): Feature { + return Feature.Boolean(name, value) + } + } + + object Version : FeatureValueType() { + override fun instantiate(name: String, value: com.intellij.openapi.util.Version): Feature { + return Feature.Version(name, value) + } + } + + override fun toString(): String = this.javaClass.simpleName +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/MutableEnvironment.kt b/platform/ml-api/src/com/intellij/platform/ml/MutableEnvironment.kt new file mode 100644 index 000000000000..40f1a9db0d49 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/MutableEnvironment.kt @@ -0,0 +1,22 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * A mutable version of the [Environment], that can be extended.
+ */ +@ApiStatus.Internal +interface MutableEnvironment : Environment { + /** + * Adds new tier instance to the existing ones. + * @throws IllegalArgumentException if there is already an instance of such [tier] in the environment. + */ + fun putTierInstance(tier: Tier, instance: T) +} + +operator fun MutableEnvironment.set(tier: Tier, instance: T) = putTierInstance(tier, instance) + +fun MutableEnvironment.putTierInstance(tierInstance: TierInstance) { + this[tierInstance.tier] = tierInstance.instance +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/ObsoleteTierDescriptor.kt b/platform/ml-api/src/com/intellij/platform/ml/ObsoleteTierDescriptor.kt new file mode 100644 index 000000000000..53a4d00794aa --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/ObsoleteTierDescriptor.kt @@ -0,0 +1,38 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * A [TierDescriptor] that is not fully aware of the features it describes with. + * It is used to make a smooth transition from the old forms of features providers to the new API. + * + * It will be removed eventually. + */ +@Deprecated(message = "ObsoleteTierDescriptors' features are not logged, as they are missing a feature declaration", + replaceWith = ReplaceWith("TierDescriptor", "com.intellij.platform.ml"), + level = DeprecationLevel.WARNING) +@ApiStatus.Internal +interface ObsoleteTierDescriptor : TierDescriptor { + /** + * The case when a [TierDescriptor] is an [ObsoleteTierDescriptor] is handled in the API + * individually each time. And in those times, we cannot rely on the [descriptionDeclaration], + * because it may not be correct. 
+ */ + override val descriptionDeclaration: Set> + get() = throw IllegalAccessError("Obsolete descriptor does not provide a description declaration") + + override fun couldBeUseful(usefulFeaturesFilter: FeatureFilter): Boolean { + return true + } + + /** + * The declaration that is already known. + * If there is a feature that is computed but not declared, then they can be logged, + * so you can add them to the declaration. + * + * Turn on the ml.description.logMissing registry key to log the missing features. + */ + val partialDescriptionDeclaration: Set> + get() = emptySet() +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/ScopeEnvironment.kt b/platform/ml-api/src/com/intellij/platform/ml/ScopeEnvironment.kt new file mode 100644 index 000000000000..b695c06ffd28 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/ScopeEnvironment.kt @@ -0,0 +1,40 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * An environment, that restricts access to the [baseEnvironment] with the [scope]. + * It is created before passing an [Environment] to a [TierDescriptor], so it will + * not be accessing non-declared tiers. 
+ */ +@ApiStatus.Internal +class ScopeEnvironment private constructor( + private val baseEnvironment: Environment, + private val scope: Set> +) : Environment { + override val tiers: Set> = scope + + override fun getInstance(tier: Tier): T { + require(tier in scope) { "$tier was not supposed to be used, allowed scope: ${scope}" } + return baseEnvironment.getInstance(tier) + } + + companion object { + fun Environment.restrictedBy(scope: Set>) = ScopeEnvironment(this, scope.intersect(this.tiers)) + + /** Restricts this environment to exactly [scope]. @throws IllegalArgumentException if some tier of [scope] is not present in this environment. */ fun Environment.narrowedTo(scope: Set>): ScopeEnvironment { + require(scope.all { it in this }) { "Tiers ${scope.filter { it !in this }} are not present in the environment, available tiers: ${this.tiers}" } + return ScopeEnvironment(this, scope) + } + + fun Environment.accessibleSafelyByOrNull(requester: TierRequester): ScopeEnvironment? { + if (requester.requiredTiers.any { it !in this }) { + return null + } + return this.accessibleSafelyBy(requester) + } + + fun Environment.accessibleSafelyBy(requester: TierRequester): ScopeEnvironment = this.narrowedTo(requester.requiredTiers) + } +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/Session.kt b/platform/ml-api/src/com/intellij/platform/ml/Session.kt new file mode 100644 index 000000000000..c1bfe063b94d --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/Session.kt @@ -0,0 +1,100 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import com.intellij.platform.ml.Session.StartOutcome.Failure +import org.jetbrains.annotations.ApiStatus + +/** + * A period of making predictions with an ML model. + * + * The session's depth corresponds to the number of levels in a [com.intellij.platform.ml.impl.MLTask]. + * + * This is a builder of the tree-like session structure. + * To learn how sessions work, please address this interface's implementations. + */ +@ApiStatus.Internal +sealed interface Session

{ + /** + * An outcome of an ML session's start. + * + * It is considered to be a non-exceptional situation when it is not possible to start an ML session. + * If so, then a [Failure] is returned and could be handled by the user. + */ + sealed interface StartOutcome

{ + /** + * A ready-to-use ml session, in case the start was successful + */ + val session: Session

? + + fun requireSuccess(): Session

= when (this) { + is Failure -> throw this.asThrowable() + is Success -> this.session + } + + /** + * Indicates that nothing went wrong, and the start was successful + */ + class Success

(override val session: Session

) : StartOutcome

+ + /** + * Indicates that there was some issue during the start. + * The problem could be identified more precisely by looking at + * the [Failure]'s class. + */ + open class Failure

: StartOutcome

{ + override val session: Session

? = null + + open val failureDetails: String + get() = "Unable to start ml session, failure: $this" + + open fun asThrowable(): Throwable { + return Exception(failureDetails) + } + } + + class UncaughtException

(val exception: Throwable) : Failure

() { + override fun asThrowable(): Throwable { + return Exception("An unexpected exception", exception) + } + } + } +} + +/** + * A session, that holds other sessions. + */ +interface NestableMLSession

: Session

{ + /** + * Start another nested session within this one, that will inherit + * this session's features. + * + * @param levelMainEnvironment The main session's environment that contains main ML task's tiers. + * + * @return Either [NestableMLSession] or [SinglePrediction], depending on whether the + * last ML task's level has been reached. + */ + fun createNestedSession(levelMainEnvironment: Environment): Session

+ + /** + * Declare that no more nested sessions will be created from this moment on. + * It must be called. + */ + fun onLastNestedSessionCreated() +} + +/** + * A session, that is dedicated to create one prediction at most. + */ +interface SinglePrediction

: Session

{ + /** + * Call ML model's inference and produce the prediction. + * On one object, exactly one function must be called during its lifetime: this or [cancelPrediction] + */ + fun predict(): P + + /** + * Declare that model's inference will not be called. + * It must be called in case it's decided that a prediction is not needed. + */ + fun cancelPrediction() +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/Tier.kt b/platform/ml-api/src/com/intellij/platform/ml/Tier.kt new file mode 100644 index 000000000000..43641b5a77e4 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/Tier.kt @@ -0,0 +1,76 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * A category of your application's objects that could be used to run the ML API. + * + * Tiers' main purpose is to be sources for features production, to pass more information to an ML model and to have the most precise + * prediction. + * Tiers are described per [TierDescriptor]. + * There are two categories of tiers when it comes to description: + * + * - Main tiers: + * the ones that are declared in an [com.intellij.platform.ml.impl.MLTask] and provided at the problem's application point, + * when [com.intellij.platform.ml.NestableMLSession.createNestedSession] is called. + * They exist in each session. + * + * - Additional tiers: the ones that are declared in [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]. + * They are provided by various [EnvironmentExtender]s occasionally when an [com.intellij.platform.ml.impl.environment.ExtendedEnvironment] + * is created, and they could be absent, as it was not able to satisfy extenders' requirements, or a tier instance was simply not + * present this time. + * So we can't rely on their existence.
+ * + * On the top of that, tiers could serve as helper objects for other user-defined objects, as [TierDescriptor]s and [EnvironmentExtender]s. + * The listed interfaces are [TierRequester]s, which means that they require other tiers for proper functioning. + * You could create additional [EnvironmentExtender]s to create new tiers, or to define other ways to instantiate existing tiers, but + * from other sources. + */ +@ApiStatus.Internal +abstract class Tier { + /** + * A unique name of a tier (among other tiers in your application). + * Class name is used by default. + */ + open val name: String + get() = this.javaClass.simpleName + + override fun toString(): String = name + + infix fun with(instance: T) = TierInstance(this, instance) +} + +/** + * A helper class to type-safely handle pairs of [tier] and the corresponding [instance]. + */ +data class TierInstance(val tier: Tier, val instance: T) + +typealias PerTier = Map, T> + +typealias PerTierInstance = Map, T> + +/** Joins the mappings into a single one. @throws IllegalArgumentException if a tier occurs in more than one mapping. */ fun Iterable>.joinByUniqueTier(): PerTier { + val joinedPerTier = mutableMapOf, T>() + + this.forEach { perTierMapping -> + perTierMapping.forEach { (tier, value) -> + require(tier !in joinedPerTier) { "Tier $tier is present in more than one mapping" } + joinedPerTier[tier] = value + } + } + + return joinedPerTier +} + +/** Merges the per-tier collections into collections made by [createCollection]. @throws IllegalArgumentException if a value is already present among the values merged earlier for the same tier. */ fun , CO : MutableCollection> Iterable>.mergePerTier(createCollection: () -> CO): PerTier { + val joinedPerTier = mutableMapOf, CO>() + for (perTierMapping in this) { + for ((tier, anotherCollection) in perTierMapping) { + val existingCollection = joinedPerTier[tier] ?: emptyList() + require(anotherCollection.all { it !in existingCollection }) { "Values ${anotherCollection.filter { it in existingCollection }} of tier $tier have already been merged" } + joinedPerTier.computeIfAbsent(tier) { createCollection() }.addAll(anotherCollection) + } + } + return joinedPerTier +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/TierDescriptor.kt b/platform/ml-api/src/com/intellij/platform/ml/TierDescriptor.kt new file mode 100644 index 000000000000..2c28a2113989 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/TierDescriptor.kt
@@ -0,0 +1,81 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import com.intellij.openapi.extensions.ExtensionPointName +import org.jetbrains.annotations.ApiStatus + +/** + * Provides features for a particular [tier]. + * + * It is also a [TierRequester], which implies that [additionallyRequiredTiers] could be defined, in + * case additional objects are required to describe the [tier]'s instance. + * If so, an attempt will be made to create an [com.intellij.platform.ml.impl.environment.ExtendedEnvironment], + * and the extended environment that contains the described [tier] and [additionallyRequiredTiers] + * will be passed to the [describe] function. + * + * @see ObsoleteTierDescriptor if you are looking for an opportunity to smoothly transfer your old feature providers to this API. + */ +@ApiStatus.Internal +interface TierDescriptor : TierRequester { + /** + * The tier, that this descriptor is describing (giving features to). + */ + val tier: Tier<*> + + /** + * All features that could ever be used in the declaration. + * + * _Important_: Do not change features' names. + * Names serve as features identifiers, and they are used to pass to ML models. + */ + val descriptionDeclaration: Set> + + /** + * Computes [tier]'s features from with the features from [descriptionDeclaration]. + * + * If [additionallyRequiredTiers] could not be satisfied, then this descriptor is not called at all. + * Be aware that _each_ feature given in [descriptionDeclaration] must be calculated here. + * If it is possible that a feature is not computable withing particular circumstances, then + * you could declare your feature as nullable: [FeatureDeclaration.nullable]. + * + * @param environment Contains [tier] and [additionallyRequiredTiers]. + * @param usefulFeaturesFilter Accepts features, that could make any difference to compute this time. 
+ * A feature is considered to be useful if an ML model is aware of the feature, + * or it is explicitly said that "ML model is not aware of this feature, but it must be logged" + * (@see [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]). + * + * @throws IllegalArgumentException If a feature is missing: it is accepted by [usefulFeaturesFilter] + * and it is declared, but not present in the result. + * @throws IllegalArgumentException If a redundant feature was computed, that was not declared. + */ + fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set + + /** + * Declares a requirement and ensures, that [describe] will be called if and only if + * the [environment] will contain both [tier] and [additionallyRequiredTiers]. + */ + override val requiredTiers: Set> + get() = additionallyRequiredTiers + tier + + /** + * Tiers that are required additionally to perform description. + * + * Such tiers' instances could be created via existing or additionally + * created [EnvironmentExtender]s. + */ + val additionallyRequiredTiers: Set> + get() = emptySet() + + /** + * Tells if the descriptor could generate any useful features at all. + * @param usefulFeaturesFilter Accepts features, that make sense to calculate at the time of this invocation. + * If the filter does not accept a feature, it means that its computation will not make any difference. 
+ */ + fun couldBeUseful(usefulFeaturesFilter: FeatureFilter): Boolean { + return descriptionDeclaration.any { usefulFeaturesFilter.accept(it) } + } + + companion object { + val EP_NAME: ExtensionPointName = ExtensionPointName.create("com.intellij.platform.ml.descriptor") + } +} diff --git a/platform/ml-api/src/com/intellij/platform/ml/TierInstanceStorage.kt b/platform/ml-api/src/com/intellij/platform/ml/TierInstanceStorage.kt new file mode 100644 index 000000000000..a3c023a4e93f --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/TierInstanceStorage.kt @@ -0,0 +1,65 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * An object that contains tier instances, and it could be extended. + */ +@ApiStatus.Internal +class TierInstanceStorage : MutableEnvironment { + private val instances: MutableMap, Any> = mutableMapOf() + + override val tiers: Set> + get() = instances.keys + + override val tierInstances: Set> + get() = instances + .map { (tier, tierInstance) -> tier withUnsafe tierInstance } + .toSet() + + private infix fun Tier.withUnsafe(value: P): TierInstance { + @Suppress("UNCHECKED_CAST") + return this.with(value as T) + } + + override fun getInstance(tier: Tier): T { + val tierInstance = instances[tier] + @Suppress("UNCHECKED_CAST") + return requireNotNull(tierInstance) as T + } + + override fun putTierInstance(tier: Tier, instance: T) { + require(tier !in this) { + "Tier $tier is already registered in the storage. 
Old value: '${instances[tier]}', new value: '$instance'" + } + instances[tier] = instance + } + + companion object { + fun copyOf(environment: Environment): TierInstanceStorage { + val storage = TierInstanceStorage() + fun putInstance(tier: Tier) { + storage[tier] = environment[tier] + } + environment.tiers.forEach { putInstance(it) } + return storage + } + + fun joined(environments: Iterable): Environment { + val commonStorage = TierInstanceStorage() + + fun putCapturingType(tier: Tier, environment: Environment) { + commonStorage[tier] = environment[tier] + } + + environments.forEach { environment -> + environment.tiers.forEach { tier -> + putCapturingType(tier, environment) + } + } + + return commonStorage + } + } +} \ No newline at end of file diff --git a/platform/ml-api/src/com/intellij/platform/ml/TierRequester.kt b/platform/ml-api/src/com/intellij/platform/ml/TierRequester.kt new file mode 100644 index 000000000000..192d06fef6a3 --- /dev/null +++ b/platform/ml-api/src/com/intellij/platform/ml/TierRequester.kt @@ -0,0 +1,26 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml + +import org.jetbrains.annotations.ApiStatus + +/** + * An interface that represents an object, that requires some additional tiers + * for proper functioning. + * + * If the requirements could not be satisfied, then the main object's functionality will not be run. + * Please address the interface's inheritors, such as [TierDescriptor], [EnvironmentExtender], + * and [com.intellij.platform.ml.impl.model.MLModel.Provider] for more details. + */ +@ApiStatus.Internal +interface TierRequester { + /** + * The tiers, that are required to use the object. 
+ */ + val requiredTiers: Set> + + companion object { + fun Iterable.fulfilledBy(environment: Environment): List { + return this.filter { it.requiredTiers.all { requiredTier -> requiredTier in environment } } + } + } +} diff --git a/platform/ml-impl/intellij.platform.ml.impl.iml b/platform/ml-impl/intellij.platform.ml.impl.iml index e593e614c870..9de0fd35792a 100644 --- a/platform/ml-impl/intellij.platform.ml.impl.iml +++ b/platform/ml-impl/intellij.platform.ml.impl.iml @@ -28,6 +28,7 @@ + @@ -44,5 +45,7 @@ + + \ No newline at end of file diff --git a/platform/ml-impl/resources/META-INF/ml.xml b/platform/ml-impl/resources/META-INF/ml.xml index 0bd706b54600..48e562829999 100644 --- a/platform/ml-impl/resources/META-INF/ml.xml +++ b/platform/ml-impl/resources/META-INF/ml.xml @@ -6,6 +6,9 @@ + + @@ -17,9 +20,14 @@ + interface="com.intellij.platform.ml.impl.turboComplete.SmartPipelineRunner" + dynamic="true"/> + - + diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/DescriptionComputer.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/DescriptionComputer.kt new file mode 100644 index 000000000000..93214fb332a2 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/DescriptionComputer.kt @@ -0,0 +1,46 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl + +import com.intellij.platform.ml.* +import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyBy +import org.jetbrains.annotations.ApiStatus + +/** + * Computes tiers descriptions, calling [TierDescriptor.describe], + * or caching the descriptions. + */ +@ApiStatus.Internal +interface DescriptionComputer { + /** + * + * @param tier The tier that all is being described. 
+ * @param descriptors All relevant descriptors, that can be called within [environment] + * (it is guaranteed that they all describe [tier] and their requirements are fulfilled by [environment]). + * @param environment An environment, that contains all tiers required to run any of the [descriptors]. + * @param usefulFeaturesFilter Accepts features, that are meaningful to compute. If the filter does not + * accept a feature, its computation will not make any difference later. + */ + fun computeDescription( + tier: Tier<*>, + descriptors: List, + environment: Environment, + usefulFeaturesFilter: FeatureFilter, + ): Map> +} + +/** + * Does not cache descriptors, primitively computes all descriptors all over again each time. + */ +@ApiStatus.Internal +class StateFreeDescriptionComputer : DescriptionComputer { + override fun computeDescription(tier: Tier<*>, + descriptors: List, + environment: Environment, + usefulFeaturesFilter: FeatureFilter): Map> { + return descriptors.associateWith { descriptor -> + require(descriptor.tier == tier) + require(descriptor.requiredTiers.all { it in environment.tiers }) + descriptor.describe(environment.accessibleSafelyBy(descriptor), usefulFeaturesFilter) + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/FeaturesSelector.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/FeaturesSelector.kt new file mode 100644 index 000000000000..3c2b8ccddc43 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/FeaturesSelector.kt @@ -0,0 +1,91 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl + +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.PerTier +import org.jetbrains.annotations.ApiStatus + +/** + * Determines a set of features, and can tell if a set of features is "complete" or not. 
+ * The Meaning of the "completeness" depends on the selector's usage. + * + * Selectors are used to determine if there are enough features to run an ML model, + * and to not compute redundant features, the ones that will not be used by the ML model, + * and are not desired to be logged. + * + * Two types of instances have the power to select features: + * - [com.intellij.platform.ml.impl.model.MLModel] to tell which features are known. + * - [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] to tell which features are not known by the ML model, + * but they are still must be computed and then logged. + */ +@ApiStatus.Internal +interface FeatureSelector { + /** + * @param availableFeatures A set of features that the selection will be built from. + * + * @return A set of features that are 'selected' and a completeness marker + * [Selection.Incomplete] is returned if there are some more features that are 'selected', + * but they are not present among [availableFeatures]. + * [Selection.Complete] is returned otherwise. + */ + fun select(availableFeatures: Set>): Selection + + /** + * @param featureDeclaration A single feature, that needs to be selected. + * + * @return If the feature belongs to the determined set of features. 
+ */ + fun select(featureDeclaration: FeatureDeclaration<*>): Boolean { + return select(setOf(featureDeclaration)).selectedFeatures.isNotEmpty() + } + + sealed class Selection(val selectedFeatures: Set>) { + class Complete(selectedFeatures: Set>) : Selection(selectedFeatures) + + open class Incomplete(selectedFeatures: Set>) : Selection(selectedFeatures) { + open val details: String = "Incomplete selection, only these were selected: $selectedFeatures" + } + + companion object { + val NOTHING = Complete(emptySet()) + } + } + + companion object { + val NOTHING = object : FeatureSelector { + override fun select(availableFeatures: Set>): Selection = Selection.NOTHING + } + + val EVERYTHING = object : FeatureSelector { + override fun select(availableFeatures: Set>): Selection = Selection.Complete(availableFeatures) + } + + infix fun FeatureSelector.or(other: FeatureSelector): FeatureSelector { + return object : FeatureSelector { + override fun select(availableFeatures: Set>): Selection { + val thisSelection = this@or.select(availableFeatures) + val otherSelection = other.select(availableFeatures) + val joinedSelection = (thisSelection.selectedFeatures + otherSelection.selectedFeatures).toSet() + return if (thisSelection is Selection.Incomplete || otherSelection is Selection.Incomplete) { + val incompleteSelection = if (thisSelection is Selection.Incomplete) + thisSelection + else + otherSelection as Selection.Incomplete + + object : Selection.Incomplete(joinedSelection) { + override val details: String + get() = incompleteSelection.details + } + } + else + Selection.Complete(joinedSelection) + } + } + } + + infix fun PerTier.or(other: PerTier): PerTier { + require(this.keys == other.keys) + return keys.associateWith { this.getValue(it) or other.getValue(it) } + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/MLModelUsageSession.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/MLModelUsageSession.kt new file mode 100644 index 
000000000000..4dcdee573afc --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/MLModelUsageSession.kt @@ -0,0 +1,69 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl + +import com.intellij.platform.ml.* +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.DescribedLevel +import com.intellij.platform.ml.impl.session.NestableStructureCollector +import com.intellij.platform.ml.impl.session.PredictionCollector +import com.intellij.platform.ml.impl.session.SessionTree +import org.jetbrains.annotations.ApiStatus + +/** + * A [SinglePrediction] performed by an ML model. + * The session's structure is collected by [collector], after the prediction is done or canceled. + */ +@ApiStatus.Internal +class MLModelPrediction, M : MLModel

, P : Any>( + private val mlModel: M, + private val collector: PredictionCollector, +) : SinglePrediction

{ + override fun predict(): P { + val prediction = mlModel.predict(collector.usableDescription) + collector.submitPrediction(prediction) + return prediction + } + + override fun cancelPrediction() { + collector.submitPrediction(null) + } +} + +/** + * A [NestableMLSession] of utilizing [mlModel]. + * The session's structure is collected by [collector] after [onLastNestedSessionCreated], + * and all nested sessions' structures are collected. + */ +@ApiStatus.Internal +class MLModelPredictionBranching, M : MLModel

, P : Any>( + private val mlModel: M, + private val collector: NestableStructureCollector +) : NestableMLSession

{ + override fun createNestedSession(levelMainEnvironment: Environment): Session

{ + val nestedLevelScheme = collector.levelPositioning.lowerTiers.first() + verifyTiersInMain(nestedLevelScheme.main, levelMainEnvironment.tiers) + val levelAdditionalTiers = nestedLevelScheme.additional + + return if (collector.levelPositioning.lowerTiers.size == 1) { + val nestedCollector = collector.nestPrediction(levelMainEnvironment, levelAdditionalTiers) + MLModelPrediction(mlModel, nestedCollector) + } + else { + assert(collector.levelPositioning.lowerTiers.size > 1) + val nestedCollector = collector.nestBranch(levelMainEnvironment, levelAdditionalTiers) + MLModelPredictionBranching(mlModel, nestedCollector) + } + } + + override fun onLastNestedSessionCreated() { + collector.onLastNestedCollectorCreated() + } +} + +private fun verifyTiersInMain(expected: Set>, actual: Set<*>) { + require(expected == actual) { + "Tier set in the main environment is not like it was declared. " + + "Declared $expected, " + + "but given $actual" + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/MLTask.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/MLTask.kt new file mode 100644 index 000000000000..ea25b50ad652 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/MLTask.kt @@ -0,0 +1,129 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl + +import com.intellij.openapi.extensions.ExtensionPointName +import com.intellij.platform.ml.* +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform +import com.intellij.platform.ml.impl.session.AdditionalTierScheme +import com.intellij.platform.ml.impl.session.Level +import com.intellij.platform.ml.impl.session.MainTierScheme +import org.jetbrains.annotations.ApiStatus + + +/** + * Is a declaration of a place in code, where a classical machine learning approach + * is desired to be applied. 
+ * + * The proper way to create new tasks is to create a static object-inheritor of this class + * (see inheritors of this class to find all implemented tasks in your project). + * + * @param name The unique name of an ML Task + * @param levels The main tiers of the task that will be provided within the task's application place + * @param predictionClass The class of an object, that will serve as "prediction" + * @param T The type of prediction + */ +@ApiStatus.Internal +abstract class MLTask protected constructor( + val name: String, + val levels: List>>, + val predictionClass: Class +) + +/** + * A method of approaching an ML task. + * Usually, it is inferencing an ML model and collecting logs. + * + * Each [MLTaskApproach] is initialized once by the corresponding [MLTaskApproachInitializer], + * then the [apiPlatform] is fixed. + * + * @see [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] for currently used approach. + */ +@ApiStatus.Internal +interface MLTaskApproach

{ + /** + * The task this approach is solving. + * Each approach is dedicated to one and only task, and it is aware of it. + */ + val task: MLTask

+ + /** + * The platform, this approach is called within, that was provided by [MLTaskApproachInitializer] + */ + val apiPlatform: MLApiPlatform + + /** + * A static declaration of the features, used in the approach. + */ + val approachDeclaration: Declaration + + /** + * Acquire the ML model and start the session. + * + * @return [Session.StartOutcome.Failure] if something went wrong during the start, [Session.StartOutcome.Success] + * which contains the started session otherwise. + */ + fun startSession(permanentSessionEnvironment: Environment): Session.StartOutcome

+ + data class Declaration( + val sessionFeatures: Map>>, + val levelsScheme: List + ) + + companion object { + fun

findMlApproach(task: MLTask

, apiPlatform: MLApiPlatform = ReplaceableIJPlatform): MLTaskApproach

{ + return apiPlatform.accessApproachFor(task) + } + + fun

startMLSession(task: MLTask

, + permanentSessionEnvironment: Environment, + apiPlatform: MLApiPlatform = ReplaceableIJPlatform): Session.StartOutcome

{ + val approach = findMlApproach(task, apiPlatform) + return approach.startSession(permanentSessionEnvironment) + } + + fun

MLTask

.startMLSession(permanentSessionEnvironment: Environment): Session.StartOutcome

{ + return startMLSession(this, permanentSessionEnvironment) + } + + fun

MLTask

.startMLSession(permanentTierInstances: Iterable>): Session.StartOutcome

{ + return this.startMLSession(Environment.of(permanentTierInstances)) + } + + fun

MLTask

.startMLSession(vararg permanentTierInstances: TierInstance<*>): Session.StartOutcome

{ + return this.startMLSession(Environment.of(*permanentTierInstances)) + } + } +} + +/** + * Initializes an [MLTaskApproach] + */ +@ApiStatus.Internal +interface MLTaskApproachInitializer

{ + /** + * The task, that the created [MLTaskApproach] is dedicated to solve. + */ + val task: MLTask

+ + /** + * Initializes the approach. + * It is called only once during the application's runtime. + * So it is crucial that this function will accept the [MLApiPlatform] you want it to. + * + * To access the API in order to build event validator statically, + * FUS uses the actual [com.intellij.platform.ml.impl.apiPlatform.IJPlatform], which could be problematic if you + * want to test FUS logs. + * So make sure that you will replace it with your test platform in + * time via [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.replacingWith]. + */ + fun initializeApproachWithin(apiPlatform: MLApiPlatform): MLTaskApproach

+ + companion object { + val EP_NAME = ExtensionPointName>("com.intellij.platform.ml.impl.approach") + } +} + +typealias LevelScheme = Level, PerTier> + +typealias LevelTiers = Level>, Set>> diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/CodeLikePrinter.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/CodeLikePrinter.kt new file mode 100644 index 000000000000..34ca84d2781c --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/CodeLikePrinter.kt @@ -0,0 +1,40 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.apiPlatform + +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.FeatureValueType +import org.jetbrains.annotations.ApiStatus + +/** + * Prints code-like features representation, when they have to be logged for an API user's + * convenience (so they can just copy-paste the logged features into their code). 
+ */ +@ApiStatus.Internal +class CodeLikePrinter { + private val > FeatureValueType.Enum.codeLikeType: String + get() = this.enumClass.name + + private fun FeatureValueType.makeCodeLikeString(name: String): String = when (this) { + FeatureValueType.Boolean -> "FeatureDeclaration.boolean(\"$name\")" + FeatureValueType.Class -> "FeatureDeclaration.aClass(\"$name\")" + FeatureValueType.Double -> "FeatureDeclaration.double(\"$name\")" + is FeatureValueType.Enum<*> -> "FeatureDeclaration.enum<${this.codeLikeType}>(\"$name\")" + FeatureValueType.Float -> "FeatureDeclaration.float(\"${name}\")" + FeatureValueType.Int -> "FeatureDeclaration.int(\"${name}\")" + FeatureValueType.Long -> "FeatureDeclaration.long(\"${name}\")" + is FeatureValueType.Nullable<*> -> "${this.baseType.makeCodeLikeString(name)}.nullable()" + is FeatureValueType.Categorical -> { + val possibleValuesSerialized = possibleValues.joinToString(", ") { "\"$it\"" } + "FeatureDeclaration.categorical(\"$name\", setOf(${possibleValuesSerialized}))" + } + FeatureValueType.Version -> "FeatureDeclaration.version(\"${name}\")" + } + + fun printCodeLikeString(featureDeclaration: FeatureDeclaration): String { + return featureDeclaration.type.makeCodeLikeString(featureDeclaration.name) + } + + fun printCodeLikeString(featureDeclarations: Collection>): String { + return featureDeclarations.joinToString(", ") { printCodeLikeString(it) } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/IJPlatform.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/IJPlatform.kt new file mode 100644 index 000000000000..2405379f1b5e --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/IJPlatform.kt @@ -0,0 +1,186 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.apiPlatform + +import com.intellij.openapi.diagnostic.thisLogger +import com.intellij.openapi.util.registry.Registry +import com.intellij.platform.ml.EnvironmentExtender +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.ObsoleteTierDescriptor +import com.intellij.platform.ml.TierDescriptor +import com.intellij.platform.ml.impl.MLTaskApproachInitializer +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.ExtensionController +import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.replacingWith +import com.intellij.platform.ml.impl.logger.MLEvent +import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener +import com.intellij.util.application +import com.intellij.util.messages.Topic +import org.jetbrains.annotations.ApiStatus +import org.jetbrains.annotations.NonNls + +@ApiStatus.Internal +fun interface MessagingProvider { + fun provide(collector: (T) -> Unit) + + companion object { + inline fun > createTopic(displayName: @NonNls String): Topic

{ + return Topic.create(displayName, P::class.java) + } + + fun > collect(topic: Topic

): List { + val collected = mutableListOf() + application.messageBus.syncPublisher(topic).provide { collected.add(it) } + return collected + } + } +} + +fun interface MLTaskListenerProvider : MessagingProvider { + companion object { + val TOPIC = MessagingProvider.createTopic("ml.task") + } +} + +fun interface MLEventProvider : MessagingProvider { + companion object { + val TOPIC = MessagingProvider.createTopic("ml.event") + } +} + +fun interface MLApiStartupListenerProvider : MessagingProvider { + companion object { + val TOPIC = MessagingProvider.createTopic("ml.startup") + } +} + +/** + * A representation of the "real-life" [MLApiPlatform], whose content is + * the content of the corresponding Extension Points. + * It is used at the API's entry point, unless it is not replaced by another. + * + * It shouldn't be used due to low testability. + * Use [ReplaceableIJPlatform] instead. + */ +@ApiStatus.Internal +private data object IJPlatform : MLApiPlatform() { + override val tierDescriptors: List + get() = TierDescriptor.EP_NAME.extensionList + + override val environmentExtenders: List> + get() = EnvironmentExtender.EP_NAME.extensionList + + override val taskApproaches: List> + get() = MLTaskApproachInitializer.EP_NAME.extensionList + + override val taskListeners: List + get() = MessagingProvider.collect(MLTaskListenerProvider.TOPIC) + + override val events: List + get() = MessagingProvider.collect(MLEventProvider.TOPIC) + + override val startupListeners: List + get() = MessagingProvider.collect(MLApiStartupListenerProvider.TOPIC) + + override fun addStartupListener(listener: MLApiStartupListener): ExtensionController { + val connection = application.messageBus.connect() + connection.subscribe(MLApiStartupListenerProvider.TOPIC, MLApiStartupListenerProvider { collector -> collector(listener) }) + return ExtensionController { connection.disconnect() } + } + + override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController { + val connection = 
application.messageBus.connect() + connection.subscribe(MLTaskListenerProvider.TOPIC, MLTaskListenerProvider { it(taskListener) }) + return ExtensionController { connection.disconnect() } + } + + override fun addEvent(event: MLEvent): ExtensionController { + val connection = application.messageBus.connect() + connection.subscribe(MLEventProvider.TOPIC, MLEventProvider { it(event) }) + return ExtensionController { connection.disconnect() } + } + + override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set) { + if (!Registry.`is`("ml.description.logMissing")) return + val printer = CodeLikePrinter() + val codeLikeMissingDeclaration = printer.printCodeLikeString(nonDeclaredFeatures.map { it.declaration }) + thisLogger().info("${descriptor::class.java} is missing declaration: setOf($codeLikeMissingDeclaration)") + } +} + +/** + * Also a "real-life" [MLApiPlatform], but it can be replaced with another one any time. + * + * We always want to test [com.intellij.platform.ml.impl.MLTaskApproach]es. + * But after they are initialized by [com.intellij.platform.ml.impl.MLTaskApproachInitializer], + * the passed [MLApiPlatform] could already spread all the way within the API. + * But the user-defined instances of the api could be overridden for testing sake. + * + * To replace all [TierDescriptor], [EnvironmentExtender] and [MLTaskApproachInitializer] + * to test your code, you may call [replacingWith] and pass the desired environment, + * that contains all the objects you need for your test. + */ +@ApiStatus.Internal +object ReplaceableIJPlatform : MLApiPlatform() { + private var replacement: MLApiPlatform? 
= null + + private val platform: MLApiPlatform + get() = replacement ?: IJPlatform + + + override val tierDescriptors: List + get() = platform.tierDescriptors + + override val environmentExtenders: List> + get() = platform.environmentExtenders + + override val taskApproaches: List> + get() = platform.taskApproaches + + + override val taskListeners: List + get() = platform.taskListeners + + override val events: List + get() = platform.events + + override val startupListeners: List + get() = platform.startupListeners + + + override fun addStartupListener(listener: MLApiStartupListener): ExtensionController { + return extend(listener) { platform -> platform.addStartupListener(listener) } + } + + override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController { + return extend(taskListener) { platform -> platform.addTaskListener(taskListener) } + } + + override fun addEvent(event: MLEvent): ExtensionController { + return extend(event) { platform -> platform.addEvent(event) } + } + + override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set) = + platform.manageNonDeclaredFeatures(descriptor, nonDeclaredFeatures) + + private fun extend(obj: T, method: (MLApiPlatform) -> ExtensionController): ExtensionController { + val initialPlatform = platform + method(initialPlatform) + return ExtensionController { + require(initialPlatform == platform) { + "$obj should be removed within the same platform it was added in." 
+ + "It was added in $initialPlatform, but removed from $platform" + } + } + } + + fun replacingWith(apiPlatform: MLApiPlatform, action: () -> T): T { + val oldApiPlatform = replacement + return try { + replacement = apiPlatform + action() + } + finally { + replacement = oldApiPlatform + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/MLApiPlatform.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/MLApiPlatform.kt new file mode 100644 index 000000000000..c411bc92ccdc --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/apiPlatform/MLApiPlatform.kt @@ -0,0 +1,264 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.apiPlatform + +import com.intellij.platform.ml.* +import com.intellij.platform.ml.impl.MLTask +import com.intellij.platform.ml.impl.MLTaskApproach +import com.intellij.platform.ml.impl.MLTaskApproachInitializer +import com.intellij.platform.ml.impl.logger.MLEvent +import com.intellij.platform.ml.impl.logger.MLEventsLogger +import com.intellij.platform.ml.impl.monitoring.* +import com.intellij.platform.ml.impl.monitoring.MLApproachListener.Companion.asJoinedListener +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.Companion.onAttemptedToStartSession +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.Companion.targetedApproaches +import org.jetbrains.annotations.ApiStatus + +/** + * Represents an environment, that provides extendable parts of the ML API. + * + * Each entity inside the API could access the platform, it is running within, + * as everything happens after [com.intellij.platform.ml.impl.MLTaskApproachInitializer.initializeApproachWithin], + * where the platform is acknowledged. + * + * All usages of the ij platform functionality (extension points, registry keys, etc.) shall be + * accessed via this class. 
+ */ +@ApiStatus.Internal +abstract class MLApiPlatform { + private val finishedInitialization = lazy { MLApiPlatformInitializationProcess() } + private var initializationStage: InitializationStage = InitializationStage.NotStarted + + /** + * Each [MLTaskApproach] is initialized only once during the application's lifetime. + * This function keeps track of all approaches that were initialized already, and initializes + * them when they are first needed. + */ + fun

accessApproachFor(task: MLTask

): MLTaskApproach

{ + return finishedInitialization.value.getApproachFor(task) + } + + + /** + * The extendable static state of the platform that must be fixed to create FUS group's validator. + * + * These values must be static, as they define the FUS scheme + */ + val staticState: StaticState + get() = StaticState(tierDescriptors, environmentExtenders, taskApproaches) + + /** + * The descriptors that are available in the platform. + * This value is interchangeable during the application runtime, + * see [staticState]. + */ + abstract val tierDescriptors: List + + /** + * The complete list of environment extenders, available in the platform. + * This value is interchangeable during the application runtime, + * see [staticState]. + */ + abstract val environmentExtenders: List> + + /** + * The complete list of the approaches for ML tasks, available in the platform. + * This value is interchangeable during the application runtime, + * see [staticState]. + */ + abstract val taskApproaches: List> + + + /** + * All the objects, that are listening execution of ML tasks. + * The collection is mutable, so new listeners could be added via [addTaskListener]. + * + * This value is mutable, new listeners could be added anytime. + */ + abstract val taskListeners: List + + /** + * Adds another provider for ML tasks' execution process monitoring dynamically. + * The event could be removed via the corresponding [ExtensionController.removeExtension] call. + * See [taskListeners]. + */ + abstract fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController + + /** + * ML events that will be written to FUS logs. + * As FUS is initialized only once, on the application's startup, they all must be registered + * before that via [addMLEventBeforeFusInitialized]. + * + * This value could be mutable, however, only during a short period of time: after the application's startup, + * and before FUS logs initialization. 
+ */ + abstract val events: List + + /** + * Adds another ML event dynamically. + * The event could be removed via the corresponding [ExtensionController.removeExtension] call. + */ + fun addMLEventBeforeFusInitialized(event: MLEvent): ExtensionController { + when (val stage = initializationStage) { + is InitializationStage.Failed -> + throw Exception("Initialization of ML Api Platform has failed, events could not be added.", stage.asException) + is InitializationStage.PotentiallySuccessful -> { + require(stage.order <= InitializationStage.InitializingApproaches.order) { + "FUS group initialization has already been started, not allowed to register more ML Events" + } + return addEvent(event) + } + } + } + + /** + * The complete list of the listeners that are listening to the process of an MLApiPlatform's initialization. + * The initialization is performed in [finishedInitialization]'s init block. + * + * This value could be mutable, so + * additional listeners could be added via [addStartupListener]. + * + * If a listener was added after a certain initialization stage, + * only callbacks of those stages will be triggered later that have not happened yet. + */ + abstract val startupListeners: List + + /** + * Adds another startup listener. + * The listener could be removed via the corresponding [ExtensionController.removeExtension] call. + */ + abstract fun addStartupListener(listener: MLApiStartupListener): ExtensionController + + + /** + * Declares how the computed but non-declared features will be handled. 
+ */ + abstract fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set) + + + internal abstract fun addEvent(event: MLEvent): ExtensionController + + fun interface ExtensionController { + fun removeExtension() + } + + data class StaticState( + val tierDescriptors: List, + val environmentExtenders: List>, + val taskApproaches: List>, + ) + + private sealed class InitializationStage(val callListener: (MLApiStartupProcessListener) -> Unit) { + sealed class PotentiallySuccessful(val order: Int, callListener: (MLApiStartupProcessListener) -> Unit) : InitializationStage(callListener) + + class Failed(lastStage: InitializationStage, nextStage: InitializationStage, exception: Throwable, callListener: (MLApiStartupProcessListener) -> Unit) : InitializationStage(callListener) { + val asException = Exception("Failed to proceed from the initialization stage $lastStage to $nextStage", exception) + } + + data object NotStarted : PotentiallySuccessful(0, {}) + data object InitializingApproaches : PotentiallySuccessful(1, { it.onStartedInitializingApproaches() }) + data class InitializingFUS(val initializedApproaches: Collection>) : PotentiallySuccessful(2, { + it.onStartedInitializingFus(initializedApproaches) + }) + + data object Finished : PotentiallySuccessful(3, { it.onFinished() }) + } + + private inner class MLApiPlatformInitializationProcess { + val approachPerTask: Map, MLTaskApproach<*>> + private val completeInitializersList: List> = taskApproaches.toMutableList() + + init { + require(initializationStage == InitializationStage.NotStarted) { "ML API Platform's initialization should not be run twice" } + + fun currentStartupListeners(): List = startupListeners.map { it.onBeforeStarted(this@MLApiPlatform) } + + fun proceedToNextStage(nextStage: InitializationStage.PotentiallySuccessful, action: () -> T): T { + return try { + action().also { + currentStartupListeners().forEach { nextStage.callListener(it) } + initializationStage = 
nextStage + } + } + catch (ex: Throwable) { + val failure = InitializationStage.Failed(initializationStage, nextStage, ex) { it.onFailed(ex) } + initializationStage = failure + throw failure.asException + } + } + + proceedToNextStage(InitializationStage.InitializingApproaches) {} + + val initializedApproachPerTask = mutableListOf>() + + approachPerTask = proceedToNextStage(InitializationStage.InitializingFUS(initializedApproachPerTask)) { + completeInitializersList.validate() + + fun initializeApproach(approachInitializer: MLTaskApproachInitializer) { + initializedApproachPerTask.add(InitializerAndApproach( + approachInitializer, + approachInitializer.initializeApproachWithin(this@MLApiPlatform) + )) + } + completeInitializersList.forEach { initializeApproach(it) } + initializedApproachPerTask.associate { it.initializer.task to it.approach } + } + + proceedToNextStage(InitializationStage.Finished) { + MLEventsLogger.Manager.ensureInitialized(okIfInitializing = true, this@MLApiPlatform) + } + } + + fun

getApproachFor(task: MLTask

): MLTaskApproach

{ + val taskApproach = requireNotNull(approachPerTask[task]) { + val mainMessage = "No approach for task $task was found" + val lateRegistrationMessage = getLateApproachRegistrationAssumption(task) + if (lateRegistrationMessage != null) "$mainMessage. $lateRegistrationMessage" else mainMessage + } + + @Suppress("UNCHECKED_CAST") + return taskApproach as MLTaskApproach

+ } + + private fun getLateApproachRegistrationAssumption(task: MLTask<*>): String? { + val currentInitializersList = taskApproaches.toMutableList() + if (completeInitializersList == currentInitializersList) return null + val taskApproaches = currentInitializersList.filter { it.task == task } + if (taskApproaches.isEmpty()) return null + require(taskApproaches.size == 1) { "More than one approach for task $task: $taskApproaches" } + return "Approach ${taskApproaches.first()} for task ${task.name} was registered after the ML API Platform was initialized" + } + + private fun List>.validate() { + val duplicateInitializerPerTask = this.groupBy { it.task }.filter { it.value.size > 1 } + require(duplicateInitializerPerTask.isEmpty()) { + "Found more than one approach for the following tasks: ${duplicateInitializerPerTask}" + } + } + } + + companion object { + fun MLApiPlatform.getDescriptorsOfTiers(tiers: Set>): PerTier> { + val descriptorsPerTier = tierDescriptors.groupBy { it.tier } + return tiers.associateWith { descriptorsPerTier[it] ?: emptyList() } + } + + fun MLApiPlatform.ensureApproachesInitialized() { + when (val stage = initializationStage) { + is InitializationStage.Failed -> throw Exception("Unable to ensure that approaches are initialized", stage.asException) + InitializationStage.NotStarted -> finishedInitialization.value + InitializationStage.InitializingApproaches -> throw Exception("Recursion detected while initializing approaches") + is InitializationStage.InitializingFUS -> return + InitializationStage.Finished -> return + } + } + + fun MLApiPlatform.getJoinedListenerForTask(taskApproach: MLTaskApproach

, + permanentSessionEnvironment: Environment): MLApproachListener { + val relevantGroupListeners = taskListeners.filter { taskApproach.javaClass in it.targetedApproaches } + val approachListeners = relevantGroupListeners.mapNotNull { + it.onAttemptedToStartSession(taskApproach, permanentSessionEnvironment) + } + return approachListeners.asJoinedListener() + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/AnalysisMethod.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/AnalysisMethod.kt new file mode 100644 index 000000000000..e156dc93a663 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/AnalysisMethod.kt @@ -0,0 +1,31 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.approach + +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.PerTier +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.AnalysedRootContainer +import com.intellij.platform.ml.impl.session.DescribedRootContainer +import org.jetbrains.annotations.ApiStatus +import java.util.concurrent.CompletableFuture + +/** + * Represents the method, that is utilized for the [LogDrivenModelInference]'s analysis. + */ +@ApiStatus.Internal +interface AnalysisMethod, P : Any> { + /** + * Static declaration of the features, that are used in the session tree's analysis. + */ + val structureAnalysisDeclaration: PerTier>> + + /** + * Static declaration of the session's entities, that are not tiers. + */ + val sessionAnalysisDeclaration: Map>> + + /** + * Perform the completed session's analysis. 
+ */ + fun analyseTree(treeRoot: DescribedRootContainer): CompletableFuture> +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/GroupedAnalysis.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/GroupedAnalysis.kt new file mode 100644 index 000000000000..708827a6b63c --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/GroupedAnalysis.kt @@ -0,0 +1,62 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.approach + +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.DescribedRootContainer +import com.intellij.platform.ml.impl.session.analysis.* +import org.jetbrains.annotations.ApiStatus +import java.util.concurrent.CompletableFuture + +/** + * The session's assembled analysis declaration. + */ +@ApiStatus.Internal +data class GroupedAnalysisDeclaration( + val structureAnalysis: StructureAnalysisDeclaration, + val mlModelAnalysis: Set> +) + +/** + * The session's assembled analysis itself. + */ +@ApiStatus.Internal +data class GroupedAnalysis, P : Any>( + val structureAnalysis: StructureAnalysis, + val mlModelAnalysis: Set +) + +/** + * Analyzes both structure and ML model. 
+ */ +@ApiStatus.Internal +class JoinedGroupedSessionAnalyser, P : Any>( + private val structureAnalysers: Collection>, + private val mlModelAnalysers: Collection>, +) : SessionAnalyser, M, P> { + override val analysisDeclaration = GroupedAnalysisDeclaration( + structureAnalysis = SessionStructureAnalysisJoiner().joinDeclarations(structureAnalysers.map { it.analysisDeclaration }), + mlModelAnalysis = MLModelAnalysisJoiner().joinDeclarations(mlModelAnalysers.map { it.analysisDeclaration }) + ) + + /** + * Runs the joined structure analysis and the joined ML model analysis on [sessionTreeRoot] + * and combines both results into one [GroupedAnalysis]. + * + * @return a future that completes with the grouped analysis, or completes exceptionally + * as soon as either of the two analyses fails. + */ + override fun analyse(sessionTreeRoot: DescribedRootContainer): CompletableFuture> { + val joinedStructureAnalyser = JoinedSessionAnalyser( + structureAnalysers, SessionStructureAnalysisJoiner() + ) + val joinedMLModelAnalyser = JoinedSessionAnalyser( + mlModelAnalysers, MLModelAnalysisJoiner() + ) + val structureAnalysis = joinedStructureAnalyser.analyse(sessionTreeRoot) + val mlModelAnalysis = joinedMLModelAnalyser.analyse(sessionTreeRoot) + + // thenCombine forwards a failure of either analysis to the returned future; completing a separate + // future from allOf(...).thenRun would leave it pending forever when one of the analyses fails. + return structureAnalysis.thenCombine(mlModelAnalysis) { structure, mlModel -> + GroupedAnalysis(structure, mlModel) + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/LanguageSpecific.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/LanguageSpecific.kt new file mode 100644 index 000000000000..329d9d0a11b7 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/LanguageSpecific.kt @@ -0,0 +1,36 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.approach + +import com.intellij.lang.Language +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.DescribedRootContainer +import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser +import org.jetbrains.annotations.ApiStatus +import java.util.concurrent.CompletableFuture + +/** + * Something, that is dedicated for one language only. + */ +@ApiStatus.Internal +interface LanguageSpecific { + val languageId: String +} + +/** + * The analyzer, that adds information about ML model's language to logs. + */ +@ApiStatus.Internal +class ModelLanguageAnalyser : MLModelAnalyser + where M : MLModel

, + M : LanguageSpecific { + + private val LANGUAGE_ID = FeatureDeclaration.categorical("language_id", Language.getRegisteredLanguages().map { it.id }.toSet()) + + override val analysisDeclaration = setOf(LANGUAGE_ID) + + override fun analyse(sessionTreeRoot: DescribedRootContainer): CompletableFuture> = CompletableFuture.completedFuture( + setOf(LANGUAGE_ID with sessionTreeRoot.root.languageId) + ) +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/LogDrivenModelInference.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/LogDrivenModelInference.kt new file mode 100644 index 000000000000..68d72236808f --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/LogDrivenModelInference.kt @@ -0,0 +1,231 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.approach + +import com.intellij.openapi.diagnostic.thisLogger +import com.intellij.platform.ml.* +import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyByOrNull +import com.intellij.platform.ml.impl.* +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getDescriptorsOfTiers +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getJoinedListenerForTask +import com.intellij.platform.ml.impl.environment.ExtendedEnvironment +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.monitoring.MLApproachListener +import com.intellij.platform.ml.impl.monitoring.MLSessionListener +import com.intellij.platform.ml.impl.session.* +import org.jetbrains.annotations.ApiStatus + +/** + * The main way to apply classical machine learning approaches: run the ML model, collect the logs, retrain the model, repeat. + * + * @param task The task that is solved by this approach. 
+ * @param apiPlatform The platform, that the approach will be running within. + */ +@ApiStatus.Internal +abstract class LogDrivenModelInference, P : Any>( + override val task: MLTask

, + override val apiPlatform: MLApiPlatform +) : MLTaskApproach

{ + /** + * The method that is used to analyze sessions. + * + * [StructureAndModelAnalysis] is the currently used analysis method + * that is dedicated to analyzing the session's tree-like structure, + * and the ML model. + */ + abstract val analysisMethod: AnalysisMethod + + /** + * Provides an ML model to use during session's lifetime. + */ + abstract val mlModelProvider: MLModel.Provider + + /** + * Declares features, that are not used by the ML model, but must be computed anyway, + * so they make it to logs. + * + * A feature cannot be simultaneously declared as "not used description" and as used by the [mlModelProvider]'s + * provided model. + * If a feature is not declared as "not used but still computed" or as "used by the model", then it will be computed. + * + * It must contain explicitly declared selectors for each tier used in [task], as well as in [additionallyDescribedTiers]. + */ + abstract val notUsedDescription: PerTier + + /** + * Performs description's computation. + * Could perform caching mechanisms to avoid recomputing features every time. + */ + abstract val descriptionComputer: DescriptionComputer + + /** + * Tiers that are not a part of the [task], but they could be described and passed to the ML model. + * + * The size of this list must correspond to the number of levels in the solved [task]. + */ + abstract val additionallyDescribedTiers: List>> + + private val levels: List by lazy { + (task.levels zip additionallyDescribedTiers).map { Level(it.first, it.second) } + } + + private val approachValidation: Unit by lazy { validateApproach() } + + override fun startSession(permanentSessionEnvironment: Environment): Session.StartOutcome

{ + return startSessionMonitoring(permanentSessionEnvironment) + } + + private fun startSessionMonitoring(permanentSessionEnvironment: Environment): Session.StartOutcome

{ + val approachListener = apiPlatform.getJoinedListenerForTask(this, permanentSessionEnvironment) + try { + return acquireModelAndStartSession(permanentSessionEnvironment, approachListener) + } + catch (e: Throwable) { + approachListener.onFailedToStartSessionWithException(e) + return Session.StartOutcome.UncaughtException(e) + } + } + + private fun acquireModelAndStartSession(permanentSessionEnvironment: Environment, + approachListener: MLApproachListener): Session.StartOutcome

{ + approachValidation + + val extendedPermanentSessionEnvironment = ExtendedEnvironment( + apiPlatform.environmentExtenders, + permanentSessionEnvironment, + mlModelProvider.requiredTiers + ) + + val mlModel: M = run { + val mlModelProviderEnvironment = extendedPermanentSessionEnvironment.accessibleSafelyByOrNull(mlModelProvider) + if (mlModelProviderEnvironment == null) { + val failure = InsufficientEnvironmentForModelProviderOutcome

(mlModelProvider.requiredTiers, + extendedPermanentSessionEnvironment.tiers) + approachListener.onFailedToStartSession(failure) + return failure + } + val nullableMlModel = mlModelProvider.provideModel(levels, mlModelProviderEnvironment) + if (nullableMlModel == null) { + val failure = ModelNotAcquiredOutcome

() + approachListener.onFailedToStartSession(failure) + return failure + } + nullableMlModel + } + + var sessionListener: MLSessionListener? = null + + val analyseThenLogStructure = SessionTreeHandler, M, P> { treeRoot -> + sessionListener?.onSessionDescriptionFinished(treeRoot) + analysisMethod.analyseTree(treeRoot).thenApplyAsync { analysedSession -> + sessionListener?.onSessionAnalysisFinished(analysedSession) + }.exceptionally { + thisLogger().error(it) + } + } + + val session = if (levels.size == 1) { + val collector = SolitaryLeafCollector( + apiPlatform, levels.first(), descriptionComputer, notUsedDescription, + permanentSessionEnvironment, levels.first().additional, mlModel + ) + collector.handleCollectedTree(analyseThenLogStructure) + MLModelPrediction(mlModel, collector) + } + else { + val collector = RootCollector( + apiPlatform, levels, descriptionComputer, notUsedDescription, + permanentSessionEnvironment, levels.first().additional, mlModel + ) + collector.handleCollectedTree(analyseThenLogStructure) + MLModelPredictionBranching(mlModel, collector) + } + + sessionListener = approachListener.onStartedSession(session) + + return Session.StartOutcome.Success(session) + } + + override val approachDeclaration: MLTaskApproach.Declaration + get() { + approachValidation + + return MLTaskApproach.Declaration( + sessionFeatures = analysisMethod.sessionAnalysisDeclaration, + levelsScheme = levels.map { levelTiers -> + Level( + buildMainTiersScheme(levelTiers.main, apiPlatform), + buildAdditionalTiersScheme(levelTiers.additional, apiPlatform), + ) + } + ) + } + + private fun validateApproach() { + require(task.levels.size == additionallyDescribedTiers.size) { + "Task $task has ${task.levels.size} levels, when 'additionallyDescribedTiers' has ${additionallyDescribedTiers.size}" + } + + require(levels.isNotEmpty()) { "Task must declare at least one level" } + + val maybeDuplicatedTaskTiers = levels.flatMap { it.main + it.additional } + val taskTiers = 
maybeDuplicatedTaskTiers.toSet() + + require(maybeDuplicatedTaskTiers.size == taskTiers.size) { + // `list - set` drops every occurrence of each element, so it would always print an empty list here; + // report the tiers that occur more than once instead. + val duplicatedTiers = maybeDuplicatedTaskTiers.groupingBy { it }.eachCount().filterValues { it > 1 }.keys + "There are duplicated tiers in the declaration: $duplicatedTiers" + } + require(notUsedDescription.keys == taskTiers) { + "Selectors for those and only those tiers must be represented in the 'notUsedDescription' that are present in the task. " + + "Missing: ${taskTiers - notUsedDescription.keys}, " + + "Redundant: ${notUsedDescription.keys - taskTiers}" + } + } + + /** + * Joins the feature declarations of the given descriptors into one set; + * for an [ObsoleteTierDescriptor] its partial declaration is taken. + */ + private fun buildTierDescriptionDeclaration(tierDescriptors: Collection): Set> { + return tierDescriptors.flatMap { + if (it is ObsoleteTierDescriptor) it.partialDescriptionDeclaration else it.descriptionDeclaration + }.toSet() + } + + /** + * Builds the scheme of the main [tiers]: description declared by the descriptors registered in [apiEnvironment], + * plus the structure analysis declared by [analysisMethod] (empty when nothing is declared for a tier). + */ + private fun buildMainTiersScheme(tiers: Set>, apiEnvironment: MLApiPlatform): PerTier { + val tiersDescriptors = apiEnvironment.getDescriptorsOfTiers(tiers) + + return tiers.associateWith { tier -> + val tierDescriptors = tiersDescriptors.getValue(tier) + val descriptionDeclaration = buildTierDescriptionDeclaration(tierDescriptors) + val analysisDeclaration = analysisMethod.structureAnalysisDeclaration[tier] ?: emptySet() + MainTierScheme(descriptionDeclaration, analysisDeclaration) + } + } + + /** + * Builds the scheme of the additional [tiers]: only the description declared by the descriptors + * registered in [apiEnvironment]. + */ + private fun buildAdditionalTiersScheme(tiers: Set>, apiEnvironment: MLApiPlatform): PerTier { + val tiersDescriptors = apiEnvironment.getDescriptorsOfTiers(tiers) + + return tiers.associateWith { tier -> + val tierDescriptors = tiersDescriptors.getValue(tier) + val descriptionDeclaration = buildTierDescriptionDeclaration(tierDescriptors) + AdditionalTierScheme(descriptionDeclaration) + } + } +} + +/** + * An exception that indicates that for some reason, it was not possible to provide an ML model + * when calling [com.intellij.platform.ml.impl.model.MLModel.Provider.provideModel]. + * The session's start is considered as failed. + */ +@ApiStatus.Internal +class ModelNotAcquiredOutcome

: Session.StartOutcome.Failure

() { + override val failureDetails: String = "ML Model was not provided" +} + +/** + * There were not enough tiers to satisfy [MLModel.Provider]'s requirements, so it could not provide the model. + */ +@ApiStatus.Internal +class InsufficientEnvironmentForModelProviderOutcome

( + expectedTiers: Set>, + existingTiers: Set> +) : Session.StartOutcome.Failure

() { + override val failureDetails: String = "ML Model could not be provided: environment is not sufficient. Missing: ${expectedTiers - existingTiers}" +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/StructureAndModelAnalysis.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/StructureAndModelAnalysis.kt new file mode 100644 index 000000000000..14cd0c5eece6 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/StructureAndModelAnalysis.kt @@ -0,0 +1,81 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.approach + +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.PerTierInstance +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.* +import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser +import com.intellij.platform.ml.impl.session.analysis.StructureAnalyser +import com.intellij.platform.ml.impl.session.analysis.StructureAnalysisDeclaration +import org.jetbrains.annotations.ApiStatus +import java.util.concurrent.CompletableFuture + +/** + * An analysis method, that asynchronously analyzes session structure after + * it finished, and the ML model that has been used during the session. + * + * @param structureAnalysers Analyzing session's structure - main tier instances. + * @param mlModelAnalysers Analyzing the ML model which was producing predictions during the session. + * @param sessionAnalysisKeyModel Key that will be used in logs, to write ML model's features into. 
+ */ +@ApiStatus.Internal +class StructureAndModelAnalysis, P : Any>( + structureAnalysers: Collection>, + mlModelAnalysers: Collection>, + private val sessionAnalysisKeyModel: String = DEFAULT_SESSION_KEY_ML_MODEL +) : AnalysisMethod { + private val groupedAnalyser = JoinedGroupedSessionAnalyser(structureAnalysers, mlModelAnalysers) + + override val structureAnalysisDeclaration: StructureAnalysisDeclaration + get() = groupedAnalyser.analysisDeclaration.structureAnalysis + + override val sessionAnalysisDeclaration: Map>> = mapOf( + sessionAnalysisKeyModel to groupedAnalyser.analysisDeclaration.mlModelAnalysis + ) + + override fun analyseTree(treeRoot: DescribedRootContainer): CompletableFuture> { + return groupedAnalyser.analyse(treeRoot).thenApply { + buildAnalysedSessionTree(treeRoot, it) as AnalysedRootContainer

+ } + } + + private fun buildAnalysedSessionTree(tree: DescribedSessionTree, analysis: GroupedAnalysis): AnalysedSessionTree

{ + val treeAnalysisPerInstance: PerTierInstance = tree.level.main.entries + .associate { (tierInstance, data) -> + tierInstance to AnalysedTierData(data.description, + analysis.structureAnalysis[tree]?.get(tierInstance.tier) ?: emptySet()) + } + + val analysedLevel = Level( + main = treeAnalysisPerInstance, + additional = tree.level.additional + ) + + return when (tree) { + is SessionTree.Branching -> { + SessionTree.Branching(analysedLevel, + tree.children.map { buildAnalysedSessionTree(it, analysis) }) + } + is SessionTree.Leaf -> { + SessionTree.Leaf(analysedLevel, tree.prediction) + } + is SessionTree.ComplexRoot -> { + SessionTree.ComplexRoot(mapOf(sessionAnalysisKeyModel to analysis.mlModelAnalysis), + analysedLevel, + tree.children.map { buildAnalysedSessionTree(it, analysis) } + ) + } + is SessionTree.SolitaryLeaf -> { + SessionTree.SolitaryLeaf(mapOf(sessionAnalysisKeyModel to analysis.mlModelAnalysis), + analysedLevel, + tree.prediction + ) + } + } + } + + companion object { + const val DEFAULT_SESSION_KEY_ML_MODEL: String = "ml_model" + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/Versioned.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/Versioned.kt new file mode 100644 index 000000000000..3be44a757368 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/approach/Versioned.kt @@ -0,0 +1,37 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.approach + +import com.intellij.openapi.util.Version +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.DescribedRootContainer +import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser +import org.jetbrains.annotations.ApiStatus +import java.util.concurrent.CompletableFuture + +/** + * Something, that has versions. + */ +@ApiStatus.Internal +interface Versioned { + val version: Version? +} + +/** + * Adds model's version to the ML logs. + */ +@ApiStatus.Internal +class ModelVersionAnalyser : MLModelAnalyser + where M : MLModel

, + M : Versioned { + companion object { + val VERSION = FeatureDeclaration.version("version").nullable() + } + + override val analysisDeclaration = setOf(VERSION) + + override fun analyse(sessionTreeRoot: DescribedRootContainer): CompletableFuture> = CompletableFuture.completedFuture( + setOf(VERSION with sessionTreeRoot.root.version) + ) +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/EnvironmentResolver.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/EnvironmentResolver.kt new file mode 100644 index 000000000000..ffaee813b55e --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/EnvironmentResolver.kt @@ -0,0 +1,35 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.environment + +import com.intellij.platform.ml.EnvironmentExtender +import com.intellij.platform.ml.Tier +import org.jetbrains.annotations.ApiStatus + +/** + * There was circle within available extenders' set's requirements. + * Meaning that to create some tier, we must eventually run the extender itself. + */ +@ApiStatus.Internal +class CircularRequirementException(val extensionPath: List>) : IllegalArgumentException() { + override val message: String = "A circular resolve path found among EnvironmentExtenders: ${serializePath()}" + + private fun serializePath(): String { + val extensions = extensionPath.map { extender -> "[$extender] -> ${extender.extendingTier.name}" } + return extensions.joinToString(" - ") + } +} + +/** + * An algorithm for resolving order of [EnvironmentExtender]s' execution. + */ +@ApiStatus.Internal +interface EnvironmentResolver { + /** + * @return The order, which guarantees that for each extender, all the requirements will be fulfilled by the previously runned + * extenders. 
+ * But it still could happen that an [EnvironmentExtender] will not return the tier it extends, then some subsequent + * extenders' requirements could not be satisfied. + * @throws CircularRequirementException + */ + fun resolve(extenderPerTier: Map, EnvironmentExtender<*>>): List> +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/ExtendedEnvironment.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/ExtendedEnvironment.kt new file mode 100644 index 000000000000..2f74d5f15166 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/ExtendedEnvironment.kt @@ -0,0 +1,139 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.environment + +import com.intellij.platform.ml.* +import com.intellij.platform.ml.EnvironmentExtender.Companion.extendTierInstance +import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyByOrNull +import org.jetbrains.annotations.ApiStatus + +/** + * An environment that is built in to fulfill a [TierRequester]'s requirements. + * + * When built, it accepts all available extenders, and it is trying to resolve their + * order to extend particular tiers. + * + * If there are some main tiers passed to build the extended environment, they will + * not be overridden. + * + * Among the available extenders, passed to the constructor, there could be some that + * extend the same tier. + * This could signify that an instance of the same tier can be mined from + * different sets of objects (requirements). + * But if there is more than one extender, that could potentially be run that will + * extend a tier of the same instance, then [IllegalArgumentException] will be thrown + * telling, that there is an ambiguity. 
+ * + * When we have determined which extenders could potentially be run, we try to determine + * the order with topological sort: [com.intellij.platform.ml.impl.environment.TopologicalSortingResolver]. + * If there is a circle in the requirements, it will throw [com.intellij.platform.ml.impl.environment.CircularRequirementException]. + */ +@ApiStatus.Internal +class ExtendedEnvironment : Environment { + private val storage: Environment + + /** + * @param environmentExtenders Extenders, that will be used to extend the [tiersToExtend]. + * @param mainEnvironment Tiers that are already determined and shall not be replaced. + * @param tiersToExtend Tiers that should be put to the extended environment via [environmentExtenders]. + * It could not be guaranteed that all desired tiers will be extended. + * + * @return An environment that contains all tiers from [mainEnvironment] plus + * some tiers from [tiersToExtend], if it was possible to extend them. + */ + constructor(environmentExtenders: List>, + mainEnvironment: Environment, + tiersToExtend: Set>) { + require(tiersToExtend.all { it !in mainEnvironment }) + val nonOverridingExtenders = environmentExtenders.filter { it.extendingTier !in mainEnvironment } + storage = buildExtendedEnvironment( + tiersToExtend + mainEnvironment.tiers, + nonOverridingExtenders + mainEnvironment.separateIntoExtenders() + ) + } + + /** + * @param environmentExtenders Extenders that will be utilized to build the extended environment. + * @param mainEnvironment An already existing environment, instances from which shall not be overridden. + * + * @return An environment that contains all tiers from [mainEnvironment] plus + * all tiers that it was possible to acquire via [environmentExtenders]. 
+ */ + constructor(environmentExtenders: List>, + mainEnvironment: Environment) { + val nonOverridingExtenders = environmentExtenders.filter { it.extendingTier !in mainEnvironment } + storage = buildExtendedEnvironment( + nonOverridingExtenders.map { it.extendingTier }.toSet() + mainEnvironment.tiers, + nonOverridingExtenders + mainEnvironment.separateIntoExtenders() + ) + } + + override val tiers: Set> + get() = storage.tiers + + override fun getInstance(tier: Tier): T { + return storage.getInstance(tier) + } + + companion object { + private val ENVIRONMENT_RESOLVER = TopologicalSortingResolver() + + /** + * Creates an [Environment] that contains those of the given [tiers] that were successfully extended by [extenders]. + * + * The extenders are run in the order given by [ENVIRONMENT_RESOLVER]. An extender is skipped when + * [accessibleSafelyByOrNull] yields no environment for it, and nothing is stored when + * [extendTierInstance] returns null. + */ + private fun buildExtendedEnvironment(tiers: Set>, + extenders: List>): Environment { + val validatedExtendersPerTier = validateExtenders(tiers, extenders) + val extensionOrder = ENVIRONMENT_RESOLVER.resolve(validatedExtendersPerTier) + val storage = TierInstanceStorage() + + extensionOrder.forEach { + val safelyAccessibleEnvironment = storage.accessibleSafelyByOrNull(it) ?: return@forEach + it.extendTierInstance(safelyAccessibleEnvironment)?.let { extendedTierInstance -> + storage.putTierInstance(extendedTierInstance) + } + } + return storage + } + + private fun validateExtenders(tiers: Set>, extenders: List>): Map, EnvironmentExtender<*>> { + val extendableTiers: Set> = extenders.map { it.extendingTier }.toSet() + + val runnableExtenders = extenders + .filter { desiredExtender -> + desiredExtender.requiredTiers.all { requirementForDesiredExtender -> requirementForDesiredExtender in extendableTiers } + } + + val ambiguouslyExtendableTiers: MutableList, List>>> = mutableListOf() + val extendersPerTier: Map, EnvironmentExtender<*>> = runnableExtenders + .groupBy { it.extendingTier } + .mapNotNull { (tier, tierExtenders) -> + if (tierExtenders.size > 1) { + ambiguouslyExtendableTiers.add(tier to tierExtenders) + null + } + else + tierExtenders.first() + }
.associateBy { it.extendingTier } + .filterKeys { it in tiers } + + require(ambiguouslyExtendableTiers.isEmpty()) { "Some tiers could be extended ambiguously: $ambiguouslyExtendableTiers" } + + return extendersPerTier + } + } +} + +private fun Environment.separateIntoExtenders(): List> { + class ContainingExtender(private val tier: Tier) : EnvironmentExtender { + override val extendingTier: Tier = tier + + override fun extend(environment: Environment): T { + return this@separateIntoExtenders[tier] + } + + override val requiredTiers: Set> = emptySet() + } + + return this.tiers.map { tier -> ContainingExtender(tier) } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/TopologicalSortingResolver.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/TopologicalSortingResolver.kt new file mode 100644 index 000000000000..857fdd553727 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/environment/TopologicalSortingResolver.kt @@ -0,0 +1,49 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.environment + +import com.intellij.platform.ml.EnvironmentExtender +import com.intellij.platform.ml.Tier +import org.jetbrains.annotations.ApiStatus + +private typealias Node = EnvironmentExtender<*> + +/** + * Resolves the order of extenders' execution using a depth-first topological sort over their required tiers. + */ +@ApiStatus.Internal +class TopologicalSortingResolver : EnvironmentResolver { + /** + * Orders the extenders so that each one is placed after the extenders of all its [EnvironmentExtender.requiredTiers]. + * + * Every required tier is looked up in [extenderPerTier] with `getValue`, so a requirement + * that has no extender in the map fails the lookup with an exception. + * + * @throws CircularRequirementException if an extender is reached again while it is still being resolved + */ + override fun resolve(extenderPerTier: Map, EnvironmentExtender<*>>): List> { + val graph: Map> = extenderPerTier.values + .associateWith { desiredExtender -> + desiredExtender.requiredTiers.map { requirementForDesiredExtender -> extenderPerTier.getValue(requirementForDesiredExtender) } + } + + val reverseTopologicalOrder: MutableList = mutableListOf() + val resolveStatus: MutableMap = mutableMapOf() + + fun Node.resolve(path: List) { + when (resolveStatus[this]) { + ResolveState.STARTED -> throw CircularRequirementException(path + this) + ResolveState.RESOLVED -> return + null -> { + resolveStatus[this] = ResolveState.STARTED + for (nextNode in graph.getValue(this)) { + nextNode.resolve(path + this) + } + resolveStatus[this] = ResolveState.RESOLVED + reverseTopologicalOrder.add(this) + } + } + } + + graph.keys.forEach { it.resolve(emptyList()) } + + return reverseTopologicalOrder + } + + private enum class ResolveState { + STARTED, + RESOLVED + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/FusSessionEventBuilder.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/FusSessionEventBuilder.kt new file mode 100644 index 000000000000..f2a13c183247 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/FusSessionEventBuilder.kt @@ -0,0 +1,49 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
+package com.intellij.platform.ml.impl.logger + +import com.intellij.internal.statistic.eventLog.events.EventPair +import com.intellij.internal.statistic.eventLog.events.ObjectDescription +import com.intellij.internal.statistic.eventLog.events.ObjectEventData +import com.intellij.platform.ml.impl.MLTaskApproach +import com.intellij.platform.ml.impl.session.AnalysedRootContainer +import com.intellij.platform.ml.impl.session.AnalysedSessionTree +import org.jetbrains.annotations.ApiStatus + +/** + * Represents FUS fields of a session's subtree. + */ +@ApiStatus.Internal +abstract class SessionFields

: ObjectDescription() { + fun buildObjectEventData(sessionStructure: AnalysedSessionTree

) = ObjectEventData(buildEventPairs(sessionStructure)) + + abstract fun buildEventPairs(sessionStructure: AnalysedSessionTree

): List> +} + +/** + * Represents a logging scheme for the FUS event. + * + * @param P The type of the ML task's prediction + */ +@ApiStatus.Internal +interface FusSessionEventBuilder

{ + /** + * Configuration of a [FusSessionEventBuilder], that builds it when accepts approach's declaration. + */ + interface FusScheme

{ + fun createEventBuilder(approachDeclaration: MLTaskApproach.Declaration): FusSessionEventBuilder

+ } + + /** + * Builds declaration of all features, that will be logged for the session tiers' description and analysis. + * It is required because FUS logs validators are built 'statically'. + */ + fun buildFusDeclaration(): SessionFields

+ + /** + * Builds a concrete FUS record, that contains fields that were built by [buildFusDeclaration]. + * + * @param sessionStructure A session tree that been already analyzed and is ready to be logged. + * @param sessionFields Session fields that were built by [buildFusDeclaration] earlier. + */ + fun buildRecord(sessionStructure: AnalysedRootContainer

, sessionFields: SessionFields

): Array> +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/InplaceFeaturesScheme.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/InplaceFeaturesScheme.kt new file mode 100644 index 000000000000..24c34cc656a7 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/InplaceFeaturesScheme.kt @@ -0,0 +1,390 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.logger + +import com.intellij.internal.statistic.eventLog.FeatureUsageData +import com.intellij.internal.statistic.eventLog.events.* +import com.intellij.platform.ml.* +import com.intellij.platform.ml.impl.* +import com.intellij.platform.ml.impl.session.* +import org.jetbrains.annotations.ApiStatus + +/** + * The currently used FUS logging scheme. + * Inplace means that the features are logged beside the tiers' instances and + * are not compressed in any way. + * + * An example of a FUS record could be found in the test resources: + * [testResources/ml_logs.js](community/platform/ml-impl/testResources/ml_logs.js) + */ +@ApiStatus.Internal +class InplaceFeaturesScheme

internal constructor( + private val predictionValidationRule: List, + private val predictionTransform: (P?) -> String, + private val approachDeclaration: MLTaskApproach.Declaration +) : FusSessionEventBuilder

{ + class FusScheme

( + private val predictionValidationRule: List, + private val predictionTransform: (P?) -> String + ) : FusSessionEventBuilder.FusScheme

{ + override fun createEventBuilder(approachDeclaration: MLTaskApproach.Declaration): FusSessionEventBuilder

= InplaceFeaturesScheme( + predictionValidationRule, + predictionTransform, + approachDeclaration + ) + + companion object { + val DOUBLE: FusScheme = FusScheme(listOf("{regexp#float}")) { it.toString() } + } + } + + override fun buildFusDeclaration(): SessionFields

{ + require(approachDeclaration.levelsScheme.isNotEmpty()) + return if (approachDeclaration.levelsScheme.size == 1) + PredictionSessionFields(approachDeclaration.levelsScheme.first(), predictionValidationRule, predictionTransform, + approachDeclaration.sessionFeatures) + else + NestableSessionFields(approachDeclaration.levelsScheme.first(), approachDeclaration.levelsScheme.drop(1), predictionValidationRule, + predictionTransform, approachDeclaration.sessionFeatures) + } + + override fun buildRecord(sessionStructure: AnalysedRootContainer

, sessionFields: SessionFields

): Array> { + return sessionFields.buildEventPairs(sessionStructure).toTypedArray() + } +} + +private class PredictionField( + override val name: String, + override val validationRule: List, + val transform: (T?) -> String +) : PrimitiveEventField() { + override fun addData(fuData: FeatureUsageData, value: T?) { + fuData.addData(name, transform(value)) + } +} + +private data class StringField(override val name: String, private val possibleValues: Set) : PrimitiveEventField() { + override fun addData(fuData: FeatureUsageData, value: String) { + fuData.addData(name, value) + } + + override val validationRule = listOf( + "{enum:${possibleValues.joinToString("|")}}" + ) +} + +private data class VersionField(override val name: String) : PrimitiveEventField() { + override val validationRule: List + get() = listOf("{regexp#version}") + + override fun addData(fuData: FeatureUsageData, value: String?) { + fuData.addVersionByString(value) + } +} + +private fun FeatureDeclaration<*>.toEventField(): EventField<*> { + return when (val valueType = type) { + is FeatureValueType.Enum<*> -> EnumEventField(name, valueType.enumClass, Enum<*>::name) + is FeatureValueType.Int -> IntEventField(name) + is FeatureValueType.Long -> LongEventField(name) + is FeatureValueType.Class -> ClassEventField(name) + is FeatureValueType.Boolean -> BooleanEventField(name) + is FeatureValueType.Double -> DoubleEventField(name) + is FeatureValueType.Float -> FloatEventField(name) + is FeatureValueType.Nullable -> FeatureDeclaration(name, valueType.baseType).toEventField() + is FeatureValueType.Categorical -> StringField(name, valueType.possibleValues) + is FeatureValueType.Version -> VersionField(name) + } +} + +private fun > Feature.Enum.toEventPair(): EventPair<*> { + return EnumEventField(declaration.name, valueType.enumClass, Enum<*>::name) with value +} + +private fun Feature.Nullable.toEventPair(): EventPair<*>? 
{ + return value?.let { + baseType.instantiate(this.declaration.name, it).toEventPair() + } +} + +private fun Feature.toEventPair(): EventPair<*>? { + return when (this) { + is Feature.TypedFeature<*> -> typedToEventPair() + } +} + +private fun Feature.TypedFeature.typedToEventPair(): EventPair<*>? { + return when (this) { + is Feature.Boolean -> BooleanEventField(declaration.name) with this.value + is Feature.Categorical -> StringField(declaration.name, this.valueType.possibleValues) with this.value + is Feature.Class -> ClassEventField(declaration.name) with this.value + is Feature.Double -> DoubleEventField(declaration.name) with this.value + is Feature.Enum<*> -> toEventPair() + is Feature.Float -> FloatEventField(declaration.name) with this.value + is Feature.Int -> IntEventField(declaration.name) with this.value + is Feature.Long -> LongEventField(declaration.name) with this.value + is Feature.Nullable<*> -> toEventPair() + is Feature.Version -> VersionField(declaration.name) with this.value.toString() + } +} + +private class FeatureSet(featuresDeclarations: Set>) : ObjectDescription() { + init { + for (featureDeclaration in featuresDeclarations) { + field(featureDeclaration.toEventField()) + } + } + + fun toObjectEventData(features: Set) = ObjectEventData(features.mapNotNull { it.toEventPair() }) +} + +private fun Set.toObjectEventData() = FeatureSet(this.map { it.declaration }.toSet()).toObjectEventData(this) + +private data class TierDescriptionFields( + val used: FeatureSet, + val notUsed: FeatureSet, +) : ObjectDescription() { + private val fieldUsed = ObjectEventField("used", used) + private val fieldNotUsed = ObjectEventField("not_used", notUsed) + private val fieldAmountUsedNonDeclaredFeatures = IntEventField("n_used_non_declared") + private val fieldAmountNotUsedNonDeclaredFeatures = IntEventField("n_not_used_non_declared") + + init { + field(fieldUsed) + field(fieldNotUsed) + field(fieldAmountUsedNonDeclaredFeatures) + 
+ field(fieldAmountNotUsedNonDeclaredFeatures) + } + + /** + * Builds the event pairs for one tier's description: the declared used and not used features, + * plus the amounts of the non-declared used and not used features. + * Each amount is logged under its own field, and only when it is non-zero. + */ + fun buildEventPairs(descriptionPartition: DescriptionPartition): List> { + val result = mutableListOf>( + fieldUsed with descriptionPartition.declared.used.toObjectEventData(), + fieldNotUsed with descriptionPartition.declared.notUsed.toObjectEventData(), + ) + descriptionPartition.nonDeclared.used.let { + if (it.isNotEmpty()) result += fieldAmountUsedNonDeclaredFeatures with it.size + } + descriptionPartition.nonDeclared.notUsed.let { + if (it.isNotEmpty()) result += fieldAmountNotUsedNonDeclaredFeatures with it.size + } + return result + } + + fun buildObjectEventData(descriptionPartition: DescriptionPartition) = ObjectEventData( + buildEventPairs(descriptionPartition) + ) +} + +private data class AdditionalTierFields(val description: TierDescriptionFields) : ObjectDescription() { + private val fieldInstanceId = IntEventField("id") + private val fieldDescription = ObjectEventField("description", description) + + constructor(descriptionFeatures: Set>) + : this(TierDescriptionFields(used = FeatureSet(descriptionFeatures), + notUsed = FeatureSet(descriptionFeatures))) + + init { + field(fieldInstanceId) + field(fieldDescription) + } + + fun buildObjectEventData(tierInstance: TierInstance<*>, + descriptionPartition: DescriptionPartition) = ObjectEventData( + fieldInstanceId with tierInstance.instance.hashCode(), + fieldDescription with this.description.buildObjectEventData(descriptionPartition), + ) +} + +private data class MainTierFields( + val description: TierDescriptionFields, + val analysis: FeatureSet, +) : ObjectDescription() { + private val fieldInstanceId = IntEventField("id") + private val fieldDescription = ObjectEventField("description", description) + private val fieldAnalysis = ObjectEventField("analysis", analysis) + + constructor(descriptionFeatures: Set>, analysisFeatures: Set>) + : this(TierDescriptionFields(used = FeatureSet(descriptionFeatures), + notUsed = FeatureSet(descriptionFeatures)),
FeatureSet(analysisFeatures)) + + init { + field(fieldInstanceId) + field(fieldDescription) + field(fieldAnalysis) + } + + fun buildObjectEventData(tierInstance: TierInstance<*>, + descriptionPartition: DescriptionPartition, + analysis: Set) = ObjectEventData( + fieldInstanceId with tierInstance.instance.hashCode(), + fieldDescription with this.description.buildObjectEventData(descriptionPartition), + fieldAnalysis with this.analysis.toObjectEventData(analysis) + ) +} + +private data class SessionAnalysisFields

( + val featuresPerKey: Map>> +) : SessionFields

() { + val fieldsPerKey: Map = featuresPerKey.entries.associate { (key, keyFeatures) -> + key to ObjectEventField(key, FeatureSet(keyFeatures)) + } + + init { + fieldsPerKey.values.forEach { field(it) } + } + + override fun buildEventPairs(sessionStructure: AnalysedSessionTree

): List> { + require(sessionStructure is SessionTree.RootContainer) + return sessionStructure.root.entries.map { (key, keyFeatures) -> + val keyFeaturesDeclaration = requireNotNull(featuresPerKey[key]) { + "Key $key was not declared as session features key, declared keys: ${featuresPerKey.keys}" + } + val objectEventField = fieldsPerKey.getValue(key) + val keyFeatureSet = FeatureSet(keyFeaturesDeclaration) + objectEventField with keyFeatureSet.toObjectEventData(keyFeatures) + } + } +} + +private class MainTierSet

(mainTierScheme: PerTier) : SessionFields

() { + val tiersDeclarations: PerTier = mainTierScheme.entries.associate { (tier, tierScheme) -> + tier to MainTierFields(tierScheme.description, tierScheme.analysis) + } + val fieldPerTier: PerTier = tiersDeclarations.entries.associate { (tier, tierFields) -> + tier to ObjectEventField(tier.name, tierFields) + } + + init { + fieldPerTier.values.forEach { field(it) } + } + + override fun buildEventPairs(sessionStructure: AnalysedSessionTree

): List> { + val level = sessionStructure.level.main + return level.entries.map { (tierInstance, data) -> + val tierField = requireNotNull(fieldPerTier[tierInstance.tier]) { + "Tier ${tierInstance.tier} is not allowed here: only ${fieldPerTier.keys} are registered" + } + val tierDeclaration = tiersDeclarations.getValue(tierInstance.tier) + tierField with tierDeclaration.buildObjectEventData(tierInstance, data.description, data.analysis) + } + } +} + +private class AdditionalTierSet

(additionalTierScheme: PerTier) : SessionFields

() { + val tiersDeclarations: PerTier = additionalTierScheme.entries.associate { (tier, tierScheme) -> + tier to AdditionalTierFields(tierScheme.description) + } + val fieldPerTier: PerTier = tiersDeclarations.entries.associate { (tier, tierFields) -> + tier to ObjectEventField(tier.name, tierFields) + } + + init { + fieldPerTier.values.forEach { field(it) } + } + + override fun buildEventPairs(sessionStructure: AnalysedSessionTree

): List> { + val level = sessionStructure.level.additional + return level.entries.map { (tierInstance, data) -> + val tierField = requireNotNull(fieldPerTier[tierInstance.tier]) { + "Tier ${tierInstance.tier} is not allowed here: only ${fieldPerTier.keys} are registered" + } + val tierDeclaration = tiersDeclarations.getValue(tierInstance.tier) + tierField with tierDeclaration.buildObjectEventData(tierInstance, data.description) + } + } +} + +private data class PredictionSessionFields

( + val declarationMainTierSet: MainTierSet

, + val declarationAdditionalTierSet: AdditionalTierSet

, + val predictionValidationRule: List, + val predictionTransform: (P?) -> String, + val sessionAnalysisFields: SessionAnalysisFields

? +) : SessionFields

() { + private val fieldMainInstances = ObjectEventField("main", declarationMainTierSet) + private val fieldAdditionalInstances = ObjectEventField("additional", declarationAdditionalTierSet) + private val fieldPrediction = PredictionField("prediction", predictionValidationRule, predictionTransform) + private val fieldSessionAnalysis = sessionAnalysisFields?.let { ObjectEventField("session", sessionAnalysisFields) } + + constructor(levelScheme: LevelScheme, + predictionValidationRule: List, + predictionTransform: (P?) -> String, + sessionAnalysisFields: Map>>?) + : this(MainTierSet(levelScheme.main), + AdditionalTierSet(levelScheme.additional), + predictionValidationRule, + predictionTransform, + sessionAnalysisFields?.let { SessionAnalysisFields(it) }) + + init { + field(fieldMainInstances) + field(fieldAdditionalInstances) + field(fieldPrediction) + fieldSessionAnalysis?.let { field(it) } + } + + override fun buildEventPairs(sessionStructure: AnalysedSessionTree

): List> { + require(sessionStructure is SessionTree.Leaf<*, *, P>) + val eventPairs = mutableListOf>( + fieldMainInstances with declarationMainTierSet.buildObjectEventData(sessionStructure), + fieldAdditionalInstances with declarationAdditionalTierSet.buildObjectEventData(sessionStructure), + fieldPrediction with sessionStructure.prediction, + ) + fieldSessionAnalysis?.let { + eventPairs += fieldSessionAnalysis with sessionAnalysisFields!!.buildObjectEventData(sessionStructure) + } + return eventPairs + } +} + +private data class NestableSessionFields

( + val declarationMainTierSet: MainTierSet

, + val declarationAdditionalTierSet: AdditionalTierSet

, + val declarationNestedSession: SessionFields

, + val sessionAnalysisFields: SessionAnalysisFields

? +) : SessionFields

() { + private val fieldMainInstances = ObjectEventField("main", declarationMainTierSet) + private val fieldAdditionalInstances = ObjectEventField("additional", declarationAdditionalTierSet) + private val fieldNestedSessions = ObjectListEventField("nested", declarationNestedSession) + private val fieldSessionAnalysis = sessionAnalysisFields?.let { ObjectEventField("session", sessionAnalysisFields) } + + constructor(levelScheme: LevelScheme, + deeperLevelsSchemes: List, + predictionValidationRule: List, + predictionTransform: (P?) -> String, + sessionAnalysisFields: Map>>?) + : this(MainTierSet(levelScheme.main), + AdditionalTierSet(levelScheme.additional), + if (deeperLevelsSchemes.size == 1) + PredictionSessionFields(deeperLevelsSchemes.first(), predictionValidationRule, predictionTransform, null) + else { + require(deeperLevelsSchemes.size > 1) + NestableSessionFields(deeperLevelsSchemes.first(), deeperLevelsSchemes.drop(1), predictionValidationRule, + predictionTransform, null) + }, + sessionAnalysisFields?.let { SessionAnalysisFields(it) } + ) + + init { + field(fieldMainInstances) + field(fieldAdditionalInstances) + field(fieldNestedSessions) + fieldSessionAnalysis?.let { field(it) } + } + + override fun buildEventPairs(sessionStructure: AnalysedSessionTree

): List> { + require(sessionStructure is SessionTree.ChildrenContainer) + val children = sessionStructure.children + val eventPairs = mutableListOf>( + fieldMainInstances with declarationMainTierSet.buildObjectEventData(sessionStructure), + fieldAdditionalInstances with declarationAdditionalTierSet.buildObjectEventData(sessionStructure), + fieldNestedSessions with children.map { nestedSession -> declarationNestedSession.buildObjectEventData(nestedSession) } + ) + + fieldSessionAnalysis?.let { + eventPairs += fieldSessionAnalysis with sessionAnalysisFields!!.buildObjectEventData(sessionStructure) + } + + return eventPairs + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLApiPlatformLogger.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLApiPlatformLogger.kt new file mode 100644 index 000000000000..9e99ff03e259 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLApiPlatformLogger.kt @@ -0,0 +1,36 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.logger + +import com.intellij.internal.statistic.eventLog.events.BooleanEventField +import com.intellij.internal.statistic.eventLog.events.ClassEventField +import com.intellij.internal.statistic.eventLog.events.EventField +import com.intellij.internal.statistic.eventLog.events.EventPair +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener +import com.intellij.platform.ml.impl.monitoring.MLApiStartupProcessListener +import org.jetbrains.annotations.ApiStatus + +/** + * Logs the outcome of the ML API platform's startup as the "startup" event: + * [SUCCESS] tells whether the startup has finished successfully, and [EXCEPTION] carries + * the class of the exception, if the startup has failed with one. + */ +@ApiStatus.Internal +class MLApiPlatformStartupLogger : EventIdRecordingMLEvent(), MLApiStartupListener { + companion object { + val SUCCESS = BooleanEventField("success") + val EXCEPTION = ClassEventField("exception") + } + + override val eventName: String = "startup" + + /** + * Both fields are declared, because a failed startup logs [EXCEPTION] alongside [SUCCESS], + * and a field can be logged only if it was registered for the event. + */ + override val declaration: Array> = arrayOf(SUCCESS, EXCEPTION) + + override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener { + val eventId = getEventId(apiPlatform) + return object : MLApiStartupProcessListener { + override fun onFinished() = eventId.log(SUCCESS with true) + + override fun onFailed(exception: Throwable?) { + val fields = mutableListOf>(SUCCESS with false) + exception?.let { fields += EXCEPTION with it.javaClass } + eventId.log(*fields.toTypedArray()) + } + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLEventsLogger.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLEventsLogger.kt new file mode 100644 index 000000000000..24fdf32b1f38 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLEventsLogger.kt @@ -0,0 +1,121 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
+package com.intellij.platform.ml.impl.logger + +import com.intellij.internal.statistic.eventLog.EventLogGroup +import com.intellij.internal.statistic.eventLog.events.EventField +import com.intellij.internal.statistic.eventLog.events.VarargEventId +import com.intellij.internal.statistic.service.fus.collectors.CounterUsagesCollector +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.ensureApproachesInitialized +import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform +import org.jetbrains.annotations.ApiStatus + +@ApiStatus.Internal +interface MLEvent { + val eventName: String + + val declaration: Array> + + fun onEventGroupInitialized(eventId: VarargEventId) +} + +@ApiStatus.Internal +abstract class EventIdRecordingMLEvent : MLEvent { + private var providedEventId: VarargEventId? = null + + protected fun getEventId(apiPlatform: MLApiPlatform): VarargEventId { + MLEventsLogger.Manager.ensureInitialized(okIfInitializing = false, apiPlatform) + return requireNotNull(providedEventId) + } + + final override fun onEventGroupInitialized(eventId: VarargEventId) { + providedEventId = eventId + } +} + +/** + * It logs ML sessions with tiers' descriptions and the analysis. + * Each session is logged after it is finished, and all analyzers [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser] + * have yielded the analysis. + * + * As the FUS logs' validators are initialized once during the application's start, + * before giving the [getGroup] for the FUS to use, we must first walk through all declared + * [com.intellij.platform.ml.impl.MLTaskApproach] in the platform and register them in the FUS scheme + * (in case they require FUS logging). 
+ */ +@ApiStatus.Internal +class MLEventsLogger : CounterUsagesCollector() { + override fun getGroup(): EventLogGroup = Initializer.GROUP.also { Manager.ensureInitialized(okIfInitializing = false) } + + object Manager { + private val defaultPlatform = ReplaceableIJPlatform + + internal fun ensureInitialized(okIfInitializing: Boolean, apiPlatform: MLApiPlatform = defaultPlatform) { + when (val state = Initializer.state) { + is State.FailedToInitialize -> throw Exception("ML Event Log already has failed to initialize", state.exception) + State.Initializing -> if (okIfInitializing) return else throw IllegalStateException("Initialization recursion") + State.NonInitialized -> Initializer.initializeGroup(apiPlatform) + is State.Initialized -> { + val currentApiPlatformState = apiPlatform.staticState + require(currentApiPlatformState == state.context.staticState) { + """ + FUS ML Logger was initialized from ${state.context.apiPlatform} with presumably immutable state ${state.context.staticState}, + but it is used from ${apiPlatform} with state ${currentApiPlatformState}, which differs from the initial state. + Hence, something that was expected to be logged will not be. 
+ """.trimIndent() + } + } + } + } + } + + internal data class InitializationContext( + val apiPlatform: MLApiPlatform, + val staticState: MLApiPlatform.StaticState + ) + + internal sealed interface State { + data object NonInitialized : State + + data object Initializing : State + + class FailedToInitialize(val exception: Throwable) : State + + class Initialized(val context: InitializationContext) : State + } + + internal object Initializer { + val GROUP = EventLogGroup("ml", 1) + + var state: State = State.NonInitialized + + fun initializeGroup(apiPlatform: MLApiPlatform) { + require(state == State.NonInitialized) + state = State.Initializing + try { + apiPlatform.ensureApproachesInitialized() + val apiPlatformState = apiPlatform.staticState + + apiPlatform.addDefaultEventsAndListeners() + + apiPlatform.events.forEach { event -> + val eventId = GROUP.registerVarargEvent(event.eventName, *event.declaration) + event.onEventGroupInitialized(eventId) + } + state = State.Initialized(InitializationContext(apiPlatform, apiPlatformState)) + } + catch (e: Throwable) { + state = State.FailedToInitialize(e) + } + (state as? State.FailedToInitialize)?.let { + throw Exception("Failed to initialize FUS ML Logger", it.exception) + } + } + + private fun MLApiPlatform.addDefaultEventsAndListeners() { + val startupLogger = MLApiPlatformStartupLogger() + addEvent(startupLogger) + addStartupListener(startupLogger) + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLSessionFailedLogger.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLSessionFailedLogger.kt new file mode 100644 index 000000000000..ad4aae17eb05 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLSessionFailedLogger.kt @@ -0,0 +1,98 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.logger + +import com.intellij.internal.statistic.eventLog.events.EventField +import com.intellij.internal.statistic.eventLog.events.ObjectEventData +import com.intellij.internal.statistic.eventLog.events.ObjectEventField +import com.intellij.platform.ml.Session +import com.intellij.platform.ml.impl.MLTaskApproach +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.monitoring.* +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy +import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser +import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser.Companion.declarationObjectDescription +import org.jetbrains.annotations.ApiStatus + +/** + * Logs to FUS information about an ML session, that has failed to start. + * + * A way to start logging, is to create a [FailedSessionLoggerRegister] and add it via [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.addStartupListener] + * + * @param taskApproach The approach that is monitored + * @param exceptionalAnalysers Analyzers, that are triggered, when [MLTaskApproach.startSession] fails with an unhandled exception + * @param normalFailureAnalysers Analyzers, that are triggered, when [MLTaskApproach.startSession] returns a [Session.StartOutcome.Failure] + */ +@ApiStatus.Internal +class MLSessionFailedLogger( + private val taskApproach: MLTaskApproach

, + private val exceptionalAnalysers: Collection>, + private val normalFailureAnalysers: Collection>>, + private val apiPlatform: MLApiPlatform, +) : EventIdRecordingMLEvent(), MLTaskGroupListener { + override val eventName: String = "${taskApproach.task.name}.failed" + + private val fields: Map = ( + exceptionalAnalysers.map { ObjectEventField(it.name, it.declarationObjectDescription) } + + normalFailureAnalysers.map { ObjectEventField(it.name, it.declarationObjectDescription) }) + .associateBy { it.name } + + override val declaration: Array> = fields.values.toTypedArray() + + override val approachListeners: Collection> + get() { + val eventId = getEventId(apiPlatform) + return listOf( + taskApproach.javaClass monitoredBy MLApproachInitializationListener { permanentSessionEnvironment -> + object : MLApproachListener { + override fun onFailedToStartSessionWithException(exception: Throwable) { + val analysis = exceptionalAnalysers.map { + val analyserField = fields.getValue(it.name) + analyserField with ObjectEventData(it.analyse(permanentSessionEnvironment, exception)) + } + eventId.log(*analysis.toTypedArray()) + } + + override fun onFailedToStartSession(failure: Session.StartOutcome.Failure

) { + val analysis = normalFailureAnalysers.map { + val analyserField = fields.getValue(it.name) + analyserField with ObjectEventData(it.analyse(permanentSessionEnvironment, failure)) + } + eventId.log(*analysis.toTypedArray()) + } + + override fun onStartedSession(session: Session

) = null + } + } + ) + } +} + +/** + * Registers failed sessions' logging as a separate FUS event. + * + * See [MLSessionFailedLogger] + */ +@ApiStatus.Internal +open class FailedSessionLoggerRegister( + private val targetApproachClass: Class>, + private val exceptionalAnalysers: Collection>, + private val normalFailureAnalysers: Collection>>, +) : MLApiStartupListener { + override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener { + return object : MLApiStartupProcessListener { + override fun onStartedInitializingFus(initializedApproaches: Collection>) { + @Suppress("UNCHECKED_CAST") + val targetInitializedApproach: MLTaskApproach

= requireNotNull( + initializedApproaches.find { it.approach.javaClass == targetApproachClass }) { + """ + Could not create logger for failed sessions for $targetApproachClass in platform $apiPlatform, because the corresponding approach was not initialized. + Initialized approaches: $initializedApproaches + """.trimIndent() + }.approach as MLTaskApproach

+ val finishedEventLogger = MLSessionFailedLogger(targetInitializedApproach, exceptionalAnalysers, normalFailureAnalysers, apiPlatform) + apiPlatform.addMLEventBeforeFusInitialized(finishedEventLogger) + apiPlatform.addTaskListener(finishedEventLogger) + } + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLSessionFinishedLogger.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLSessionFinishedLogger.kt new file mode 100644 index 000000000000..7f2d64ae31a8 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/logger/MLSessionFinishedLogger.kt @@ -0,0 +1,89 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.logger + +import com.intellij.internal.statistic.eventLog.events.EventField +import com.intellij.platform.ml.Environment +import com.intellij.platform.ml.Session +import com.intellij.platform.ml.impl.MLTaskApproach +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.monitoring.* +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy +import com.intellij.platform.ml.impl.session.AnalysedRootContainer +import com.intellij.platform.ml.impl.session.DescribedRootContainer +import org.jetbrains.annotations.ApiStatus + +/** + * Logs to FUS information about a finished ML session, tier instances' descriptions, analysis. + * + * A way to start logging, is to create a [FinishedSessionLoggerRegister] and add it with [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.addStartupListener]. + * If you are using a regression model, see [InplaceFeaturesScheme.FusScheme.Companion.DOUBLE]. 
+ * + * @param approach The approach whose sessions are monitored + * @param configuration The particular scheme that is used to serialize sessions + */ +@ApiStatus.Internal +class MLSessionFinishedLogger( + approach: MLTaskApproach

, + configuration: FusSessionEventBuilder.FusScheme

, + private val apiPlatform: MLApiPlatform, +) : EventIdRecordingMLEvent(), MLTaskGroupListener { + private val loggingScheme: FusSessionEventBuilder

= configuration.createEventBuilder(approach.approachDeclaration) + private val fusDeclaration: SessionFields

= loggingScheme.buildFusDeclaration() + override val declaration: Array> = fusDeclaration.getFields() + + override val eventName: String = "${approach.task.name}.finished" + + override val approachListeners: Collection> = listOf( + approach.javaClass monitoredBy InitializationLogger() + ) + + inner class InitializationLogger : MLApproachInitializationListener { + override fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener = ApproachLogger() + } + + inner class ApproachLogger : MLApproachListener { + override fun onFailedToStartSessionWithException(exception: Throwable) {} + + override fun onFailedToStartSession(failure: Session.StartOutcome.Failure

) {} + + override fun onStartedSession(session: Session

): MLSessionListener = SessionLogger() + } + + inner class SessionLogger : MLSessionListener { + override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer) {} + + override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer

) { + val eventId = getEventId(apiPlatform) + val fusEventData = loggingScheme.buildRecord(sessionTree, fusDeclaration) + eventId.log(*fusEventData) + } + } +} + +/** + * Registers successful ML sessions' logging as a separate FUS event. + * + * See [MLSessionFinishedLogger] + */ +@ApiStatus.Internal +open class FinishedSessionLoggerRegister( + private val targetApproachClass: Class>, + private val fusScheme: FusSessionEventBuilder.FusScheme

, +) : MLApiStartupListener { + override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener { + return object : MLApiStartupProcessListener { + override fun onStartedInitializingFus(initializedApproaches: Collection>) { + @Suppress("UNCHECKED_CAST") + val targetInitializedApproach: MLTaskApproach

= requireNotNull(initializedApproaches.find { it.approach.javaClass == targetApproachClass }) { + """ + Could not create logger for finished sessions of $targetApproachClass in platform $apiPlatform, because the corresponding approach was not initialized. + Initialized approaches: $initializedApproaches + """.trimIndent() + }.approach as MLTaskApproach

+ val finishedEventLogger = MLSessionFinishedLogger(targetInitializedApproach, fusScheme, apiPlatform) + apiPlatform.addMLEventBeforeFusInitialized(finishedEventLogger) + apiPlatform.addTaskListener(finishedEventLogger) + } + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/model/MLModel.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/model/MLModel.kt new file mode 100644 index 000000000000..89d45ae4d412 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/model/MLModel.kt @@ -0,0 +1,61 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.model + +import com.intellij.platform.ml.Environment +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.PerTier +import com.intellij.platform.ml.TierRequester +import com.intellij.platform.ml.impl.FeatureSelector +import com.intellij.platform.ml.impl.LevelTiers +import org.jetbrains.annotations.ApiStatus + +/** + * Performs a prediction based on the given features. + * + * @param P The prediction's type. + */ +@ApiStatus.Internal +interface MLModel

{ + /** + * Provides a model, that will be used during an ML session. + * + * It extends [TierRequester] interface which implies, that you can request + * additional tiers that will help you acquire the right model. + */ + interface Provider, P : Any> : TierRequester { + /** + * Provides an ML model from the task's environment, and the additional tiers + * declared via [requiredTiers]. + * If [requiredTiers] could not be fulfilled, then the session will not be started + * and [com.intellij.platform.ml.impl.approach.InsufficientEnvironmentForModelProviderOutcome] + * will be returned as the start's outcome. + * + * @param sessionTiers Contains the main as well as additional tiers that will be used during the session. + * @param environment Contains the all-embracing "permanent" tiers of an ML session - the ones that sit on the first position + * of an [com.intellij.platform.ml.impl.MLTask]'s declaration. + * Plus additional tiers that were requested in [requiredTiers]. + * extendedPermanentSessionEnvironment + */ + fun provideModel(sessionTiers: List, environment: Environment): M? + } + + /** + * Declares a set of features, that are known and can be used by the ML model. + * For each tier, it returns a selection. + * Selection will be 'complete' if and only if the model could be executed with it. + * + * The set of tiers it is aware of could not include some additional tiers, declared + * in [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]'s 'not used description'. + * Because old models could be not aware of the newly established tiers. + */ + val knownFeatures: PerTier + + /** + * Performs the prediction. + * + * @param features Features computed for this model to take into account. + * It is guaranteed that for each tier, this set of features is a 'complete' + * selection of features. 
+ */ + fun predict(features: PerTier>): P +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/model/RegressionModel.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/model/RegressionModel.kt new file mode 100644 index 000000000000..411fe86b133a --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/model/RegressionModel.kt @@ -0,0 +1,170 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.model + +import com.intellij.internal.ml.DecisionFunction +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.PerTier +import com.intellij.platform.ml.Tier +import com.intellij.platform.ml.impl.FeatureSelector +import com.intellij.platform.ml.impl.LevelTiers +import org.jetbrains.annotations.ApiStatus + +/** + * A wrapper for using legacy [DecisionFunction] as the new API's [MLModel]. 
+ */ +@ApiStatus.Internal +open class RegressionModel private constructor( + private val decisionFunction: DecisionFunction, + private val featuresTiers: Set>, + availableTiers: Set>, + private val featureSerialization: FeatureNameSerialization +) : MLModel { + constructor(decisionFunction: DecisionFunction, + featureSerialization: FeatureNameSerialization, + sessionTiers: List) : this( + decisionFunction = decisionFunction, + featuresTiers = decisionFunction.featuresOrder.map { + featureSerialization.deserialize(it.featureName, sessionTiers.flatten().associateBy { it.name }).first + }.toSet(), + availableTiers = sessionTiers.flatten(), + featureSerialization = featureSerialization + ) + + override val knownFeatures: PerTier = createFeatureSelectors( + DecisionFunctionWrapper(decisionFunction, availableTiers, featureSerialization), + featuresTiers + ) + + override fun predict(features: PerTier>): Double { + val array = DoubleArray(decisionFunction.featuresOrder.size) + val featurePerSerializedName = features + .flatMap { (tier, tierFeatures) -> tierFeatures.map { tier to it } } + .associate { (tier, feature) -> featureSerialization.serialize(tier, feature.declaration.name) to feature } + + require(features.keys == featuresTiers) { + "Given features tiers are ${features.keys}, but this model needs ${featuresTiers}" + } + + for (featureI in decisionFunction.featuresOrder.indices) { + val featureMapper = decisionFunction.featuresOrder[featureI] + val featureSerializedName = featureMapper.featureName + val featureValue = featureMapper.asArrayValue(featurePerSerializedName[featureSerializedName]?.value) + array[featureI] = featureValue + } + + return decisionFunction.predict(array) + } + + interface FeatureNameSerialization { + fun serialize(tier: Tier<*>, featureName: String): String + + fun deserialize(serializedFeatureName: String, availableTiersPerName: Map>): Pair, String> + } + + class DefaultSerialization : FeatureNameSerialization { + private val 
SERIALIZED_FEATURE_SEPARATOR = '/' + + override fun serialize(tier: Tier<*>, featureName: String): String { + return "${tier.name}${SERIALIZED_FEATURE_SEPARATOR}${featureName}" + } + + /** + * Splits [serializedFeatureName] at its last separator into the tier (looked up by name in [availableTiersPerName]) and the feature's name. + * + * @throws IllegalArgumentException if the name contains no separator, or the tier's name is not among the available tiers + */ + override fun deserialize(serializedFeatureName: String, availableTiersPerName: Map>): Pair, String> { + val indexOfLastSeparator = serializedFeatureName.indexOfLast { it == SERIALIZED_FEATURE_SEPARATOR } + require(indexOfLastSeparator >= 0) { "Feature name '$serializedFeatureName' does not contain tier's name" } + val featureTierName = serializedFeatureName.slice(0 until indexOfLastSeparator) + // The feature's name starts right after the separator: skipping it keeps deserialize() the exact inverse of serialize() + val featureName = serializedFeatureName.slice(indexOfLastSeparator + 1 until serializedFeatureName.length) + val featureTier = requireNotNull( + availableTiersPerName[featureTierName]) { "Serialized feature '$serializedFeatureName' has tier $featureTierName, but all available tiers are ${availableTiersPerName.keys}" } + return featureTier to featureName + } + } + + class SelectionMissingFeatures( + selectedFeatures: Set>, + missingFeatures: Set + ) : FeatureSelector.Selection.Incomplete(selectedFeatures) { + override val details: String = "Regression model requires more features to run. 
" + + "Missing: $missingFeatures, " + + "Has: $selectedFeatures" + } + + private class DecisionFunctionWrapper( + private val decisionFunction: DecisionFunction, + private val availableTiers: Set>, + private val featureNameSerialization: FeatureNameSerialization + ) { + private val availableTiersPerName: Map> = availableTiers.associateBy { it.name } + + fun getKnownFeatures(): PerTier> { + val knownFeaturesSerializedNames = decisionFunction.featuresOrder.map { it.featureName }.toSet() + return knownFeaturesSerializedNames + .map { featureNameSerialization.deserialize(it, availableTiersPerName) } + .groupBy({ it.first }, { it.second }) + .mapValues { it.value.toSet() } + } + + fun getRequiredFeaturesPerTier(): PerTier> { + val availableTiersPerName = availableTiers.associateBy { it.name } + val requiredFeaturesSerializedNames = decisionFunction.requiredFeatures.filterNotNull().toSet() + return requiredFeaturesSerializedNames + .map { serializedFeatureName -> featureNameSerialization.deserialize(serializedFeatureName, availableTiersPerName) } + .groupBy({ it.first }, { it.second }) + .mapValues { it.value.toSet() } + } + + fun getUnknownFeatures(tier: Tier<*>, featuresNames: Set): Set { + val featureNamePerSerializedName = featuresNames + .associateBy { featureNameSerialization.serialize(tier, it) } + + val unknownFeaturesSerializedNames = decisionFunction.getUnknownFeatures(featureNamePerSerializedName.keys).filterNotNull() + + return unknownFeaturesSerializedNames.map { + requireNotNull(featureNamePerSerializedName[it]) { "Decision function returned an unknown feature that was not given: '$it'" } + }.toSet() + } + } + + companion object { + private fun createFeatureSelectors(decisionFunction: DecisionFunctionWrapper, + featuresTiers: Set>): PerTier { + val requiredFeaturesPerTier = decisionFunction.getRequiredFeaturesPerTier() + + fun createFeatureSelector(tier: Tier<*>) = object : FeatureSelector { + init { + val knownFeatures = decisionFunction.getKnownFeatures() 
+ knownFeatures.forEach { (tier, tierFeatures) -> + val nonConsistentlyKnownFeatures = decisionFunction.getUnknownFeatures(tier, tierFeatures) + require(nonConsistentlyKnownFeatures.isEmpty()) { + "These features are known and unknown at the same time: $nonConsistentlyKnownFeatures" + } + } + } + + override fun select(availableFeatures: Set>): FeatureSelector.Selection { + val availableFeaturesPerName = availableFeatures.associateBy { it.name } + val availableFeaturesNames = availableFeatures.map { it.name }.toSet() + val unknownFeaturesNames = decisionFunction.getUnknownFeatures(tier, availableFeaturesNames) + val knownAvailableFeaturesNames = availableFeaturesNames - unknownFeaturesNames + val knownAvailableFeatures = knownAvailableFeaturesNames.map { availableFeaturesPerName.getValue(it) }.toSet() + val requiredFeaturesNames = requiredFeaturesPerTier[tier] ?: emptySet() + + return if (availableFeaturesNames.containsAll(requiredFeaturesNames)) + FeatureSelector.Selection.Complete(knownAvailableFeatures) + else + SelectionMissingFeatures(knownAvailableFeatures, requiredFeaturesNames - availableFeaturesNames) + } + + override fun select(featureDeclaration: FeatureDeclaration<*>): Boolean { + return decisionFunction.getUnknownFeatures(tier, setOf(featureDeclaration.name)).isEmpty() + } + } + + return featuresTiers.associateWith { createFeatureSelector(it) } + } + } +} + +private fun List.flatten(): Set> { + return this.flatMap { it.main + it.additional }.toSet() +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/monitoring/MLApiStartupProcessListener.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/monitoring/MLApiStartupProcessListener.kt new file mode 100644 index 000000000000..eef8c04755a4 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/monitoring/MLApiStartupProcessListener.kt @@ -0,0 +1,29 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.monitoring + +import com.intellij.platform.ml.impl.MLTaskApproach +import com.intellij.platform.ml.impl.MLTaskApproachInitializer +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import org.jetbrains.annotations.ApiStatus + +@ApiStatus.Internal +fun interface MLApiStartupListener { + fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener +} + +@ApiStatus.Internal +data class InitializerAndApproach( + val initializer: MLTaskApproachInitializer, + val approach: MLTaskApproach +) + +@ApiStatus.Internal +interface MLApiStartupProcessListener { + fun onStartedInitializingApproaches() {} + + fun onStartedInitializingFus(initializedApproaches: Collection>) {} + + fun onFinished() {} + + fun onFailed(exception: Throwable?) {} +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/monitoring/MLApproachListener.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/monitoring/MLApproachListener.kt new file mode 100644 index 000000000000..5020d357a91a --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/monitoring/MLApproachListener.kt @@ -0,0 +1,166 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.monitoring + +import com.intellij.platform.ml.Environment +import com.intellij.platform.ml.Session +import com.intellij.platform.ml.impl.MLTaskApproach +import com.intellij.platform.ml.impl.monitoring.MLApproachInitializationListener.Companion.asJoinedListener +import com.intellij.platform.ml.impl.monitoring.MLApproachListener.Companion.asJoinedListener +import com.intellij.platform.ml.impl.monitoring.MLSessionListener.Companion.asJoinedListener +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy +import com.intellij.platform.ml.impl.session.AnalysedRootContainer +import com.intellij.platform.ml.impl.session.DescribedRootContainer +import org.jetbrains.annotations.ApiStatus + +/** + * Provides listeners for a set of [com.intellij.platform.ml.impl.MLTaskApproach] + * + * Only [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] and the subclasses, that are + * calling [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] are monitored. + */ +@ApiStatus.Internal +interface MLTaskGroupListener { + /** + * For every approach the [MLTaskGroupListener] is interested in, this value provides a collection of + * [MLApproachInitializationListener] + * + * A convenient way to create this correspondence is to use the + * [com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy] infix function. 
+ */ + val approachListeners: Collection> + + /** + * A type-safe pair of approach's class and a set of listeners + * + * A proper way to create it is to use [monitoredBy] + */ + data class ApproachListeners internal constructor( + val taskApproach: Class>, + val approachListener: Collection> + ) { + companion object { + infix fun Class>.monitoredBy(approachListener: MLApproachInitializationListener) = ApproachListeners( + this, listOf(approachListener)) + + infix fun Class>.monitoredBy(approachListeners: Collection>) = ApproachListeners( + this, approachListeners) + } + } + + companion object { + internal val MLTaskGroupListener.targetedApproaches: Set>> + get() = approachListeners.map { it.taskApproach }.toSet() + + internal fun

MLTaskGroupListener.onAttemptedToStartSession(taskApproach: MLTaskApproach

, + permanentSessionEnvironment: Environment): MLApproachListener? { + @Suppress("UNCHECKED_CAST") + val approachListeners: List> = approachListeners + .filter { it.taskApproach == taskApproach.javaClass } + .flatMap { it.approachListener } as List> + return approachListeners.asJoinedListener().onAttemptedToStartSession(permanentSessionEnvironment) + } + } +} + +/** + * Listens to the attempt of starting new [Session] of the [MLTaskApproach], that this listener was put + * into correspondence to via [com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy] + */ +@ApiStatus.Internal +fun interface MLApproachInitializationListener { + /** + * Called each time, when [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] is invoked + * + * @return A listener, that will be monitoring how successful the start was. If it is not needed, null is returned. + */ + fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener? + + companion object { + fun Collection>.asJoinedListener(): MLApproachInitializationListener = + MLApproachInitializationListener { permanentSessionEnvironment -> + val approachListeners = this@asJoinedListener.mapNotNull { it.onAttemptedToStartSession(permanentSessionEnvironment) } + if (approachListeners.isEmpty()) null else approachListeners.asJoinedListener() + } + } +} + +/** + * Listens to the process of starting new [Session] of [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]. + */ +@ApiStatus.Internal +interface MLApproachListener { + /** + * Called if the session was not started, + * on exceptionally rare occasions, + * when the [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] failed with an exception + */ + fun onFailedToStartSessionWithException(exception: Throwable) + + /** + * Called if the session was not started, + * but the failure is 'ordinary'. 
+ */ + fun onFailedToStartSession(failure: Session.StartOutcome.Failure

) + + /** + * Called when a new [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]'s session was started successfully. + * + * @return A listener for tracking the session's progress, null if the session will not be tracked. + */ + fun onStartedSession(session: Session

): MLSessionListener? + + companion object { + fun Collection>.asJoinedListener(): MLApproachListener { + val approachListeners = this@asJoinedListener + + return object : MLApproachListener { + override fun onFailedToStartSessionWithException(exception: Throwable) = + approachListeners.forEach { it.onFailedToStartSessionWithException(exception) } + + override fun onFailedToStartSession(failure: Session.StartOutcome.Failure

) = approachListeners.forEach { + it.onFailedToStartSession(failure) + } + + override fun onStartedSession(session: Session

): MLSessionListener? { + val listeners = approachListeners.mapNotNull { it.onStartedSession(session) } + return if (listeners.isEmpty()) null else listeners.asJoinedListener() + } + } + } + } +} + +/** + * Listens to session events of a [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] + */ +@ApiStatus.Internal +interface MLSessionListener { + /** + * All tier instances were established (the tree will not be growing further), + * described, and predictions in the [sessionTree] were finished. + */ + fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer) + + /** + * Called only after [onSessionDescriptionFinished] + * + * All tree nodes were analyzed. + */ + fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer

) + + companion object { + fun Collection>.asJoinedListener(): MLSessionListener { + val sessionListeners = this@asJoinedListener + + return object : MLSessionListener { + override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer) = sessionListeners.forEach { + it.onSessionDescriptionFinished(sessionTree) + } + + override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer

) = sessionListeners.forEach { + it.onSessionAnalysisFinished(sessionTree) + } + } + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/LevelDescriptor.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/LevelDescriptor.kt new file mode 100644 index 000000000000..b377fda3a39f --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/LevelDescriptor.kt @@ -0,0 +1,182 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session + +import com.intellij.platform.ml.* +import com.intellij.platform.ml.ScopeEnvironment.Companion.restrictedBy +import com.intellij.platform.ml.TierRequester.Companion.fulfilledBy +import com.intellij.platform.ml.impl.DescriptionComputer +import com.intellij.platform.ml.impl.FeatureSelector +import com.intellij.platform.ml.impl.FeatureSelector.Companion.or +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getDescriptorsOfTiers +import com.intellij.platform.ml.impl.environment.ExtendedEnvironment +import org.jetbrains.annotations.ApiStatus + +@ApiStatus.Internal +data class LevelDescriptor( + val apiPlatform: MLApiPlatform, + val descriptionComputer: DescriptionComputer, + val usedFeaturesSelectors: PerTier, + val notUsedFeaturesSelectors: PerTier, +) { + fun describe( + upperLevels: List, + nextLevelMainEnvironment: Environment, + nextLevelAdditionalTiers: Set> + ): DescribedLevel { + val mainEnvironment = Environment.joined(listOf( + Environment.of(upperLevels.flatMap { it.main.keys }), + nextLevelMainEnvironment + )) + + val availableEnvironment = ExtendedEnvironment(apiPlatform.environmentExtenders, mainEnvironment) + val extendedEnvironment = availableEnvironment.restrictedBy(nextLevelMainEnvironment.tiers + nextLevelAdditionalTiers) + + val runnableDescriptorsPerTier = 
apiPlatform.getDescriptorsOfTiers(extendedEnvironment.tiers) + .mapValues { (_, descriptors) -> descriptors.fulfilledBy(availableEnvironment) } + + val extendedEnvironmentDescription = runnableDescriptorsPerTier + .mapValues { (tier, tierDescriptors) -> + describeTier(tier, tierDescriptors, availableEnvironment) + } + + val nextLevel = DescribedLevel( + main = nextLevelMainEnvironment.tierInstances.associateWith { mainTierInstance -> + DescribedTierData(extendedEnvironmentDescription.getValue(mainTierInstance.tier)) + }, + additional = extendedEnvironment.restrictedBy(nextLevelAdditionalTiers).tierInstances.associateWith { mainTierInstance -> + DescribedTierData(extendedEnvironmentDescription.getValue(mainTierInstance.tier)) + }, + ) + + return nextLevel + } + + private fun createFilterOfFeaturesToCompute(tier: Tier<*>, tierDescriptors: List): FeatureFilter { + if (tierDescriptors.any { it is ObsoleteTierDescriptor }) { + return FeatureFilter.ACCEPT_ALL + } + + val tierFeaturesSelector = (usedFeaturesSelectors[tier] ?: FeatureSelector.NOTHING) or notUsedFeaturesSelectors.getValue(tier) + val tierComputableFeatures = tierDescriptors.flatMap { it.descriptionDeclaration }.toSet() + val tierToComputeSelection = tierFeaturesSelector.select(tierComputableFeatures) + + if (tierToComputeSelection is FeatureSelector.Selection.Incomplete) { + throw IncompleteDescriptionException(tier, tierToComputeSelection.selectedFeatures, tierToComputeSelection.details) + } + + return FeatureFilter { it in tierToComputeSelection.selectedFeatures } + } + + // Filter assumes that the feature is either used or not used by the model + private fun createFilterOfUsedFeatures(tier: Tier<*>): FeatureFilter { + val usedFeaturesSelector = usedFeaturesSelectors[tier] ?: FeatureSelector.NOTHING + val notUsedFeaturesSelector = notUsedFeaturesSelectors.getValue(tier) + return FeatureFilter { + val featureIsUsed = usedFeaturesSelector.select(it) + val featureIsNotUsed = 
notUsedFeaturesSelector.select(it) + assert(featureIsUsed || featureIsNotUsed) { + "Feature $it of $tier must not have been computed. It is not used by the ML model or marked as not used" + } + require(featureIsUsed xor featureIsNotUsed) { + "${it} of $tier is used by the ML model, but marked as not used at the same time" + } + featureIsUsed + } + } + + private fun Set.splitByUsage(usableFeaturesFilter: FeatureFilter): Usage> { + val usedFeatures = this.filter { usableFeaturesFilter.accept(it.declaration) } + val notUsedFeatures = this.filter { !usableFeaturesFilter.accept(it.declaration) } + return Usage(usedFeatures.toSet(), notUsedFeatures.toSet()) + } + + private fun makeDescriptionPartition(descriptor: TierDescriptor, + computedDescription: Set, + usedFeaturesFilter: FeatureFilter): DescriptionPartition { + + val computedDescriptionDeclaration = computedDescription.map { it.declaration }.toSet() + + if (descriptor is ObsoleteTierDescriptor) { + val nonDeclaredDescription = computedDescription.filter { it.declaration !in descriptor.partialDescriptionDeclaration }.toSet() + apiPlatform.manageNonDeclaredFeatures(descriptor, nonDeclaredDescription) + } + else { + val notDeclaredFeaturesDeclarations = computedDescriptionDeclaration - descriptor.descriptionDeclaration + require(notDeclaredFeaturesDeclarations.isEmpty()) { + """ + $descriptor described environment with some features that were not declared: + ${notDeclaredFeaturesDeclarations.map { it.name }} + computed declaration: ${computedDescriptionDeclaration} + declared declaration: ${descriptor.descriptionDeclaration} + """.trimIndent() + } + } + + val maybePartialDescriptionDeclaration = if (descriptor is ObsoleteTierDescriptor) + descriptor.partialDescriptionDeclaration + else + descriptor.descriptionDeclaration + + val notComputedDescriptionDeclaration = maybePartialDescriptionDeclaration - computedDescriptionDeclaration + for (notComputedFeatureDeclaration in notComputedDescriptionDeclaration) { + 
require(!usedFeaturesFilter.accept(notComputedFeatureDeclaration)) { + "Feature ${notComputedFeatureDeclaration} was expected to be computed by $descriptor, " + + "because was declared and accepted by the feature filter. Computed declaration: $computedDescriptionDeclaration" + } + } + + val declaredFeatures = mutableSetOf() + val nonDeclaredFeatures = mutableSetOf() + + if (descriptor is ObsoleteTierDescriptor) { + computedDescription.forEach { + if (it.declaration in descriptor.partialDescriptionDeclaration) + declaredFeatures += it + else + nonDeclaredFeatures += it + } + } + else { + declaredFeatures.addAll(computedDescription) + } + + return Declaredness(declaredFeatures.splitByUsage(usedFeaturesFilter), nonDeclaredFeatures.splitByUsage(usedFeaturesFilter)) + } + + private fun describeTier(tier: Tier<*>, tierDescriptors: List, environment: Environment): DescriptionPartition { + val toComputeFilter = createFilterOfFeaturesToCompute(tier, tierDescriptors) + val usefulTierDescriptors = tierDescriptors.filter { it.couldBeUseful(toComputeFilter) } + + val description: Map> = descriptionComputer.computeDescription( + tier, + usefulTierDescriptors, + environment, + toComputeFilter + ) + + val usedByModelFilter = createFilterOfUsedFeatures(tier) + + val descriptionPartition = description.entries + .map { (tierDescriptor, computedDescription) -> + makeDescriptionPartition(tierDescriptor, computedDescription, usedByModelFilter) + } + .reduceOrNull { first, second -> + Declaredness( + declared = Usage(used = first.declared.used + second.declared.used, + notUsed = first.declared.notUsed + second.declared.notUsed), + nonDeclared = Usage(used = first.nonDeclared.used + second.nonDeclared.used, + notUsed = first.nonDeclared.notUsed + second.nonDeclared.notUsed) + ) + } ?: Declaredness(Usage(emptySet(), emptySet()), Usage(emptySet(), emptySet())) + + return descriptionPartition + } +} + +@ApiStatus.Internal +class IncompleteDescriptionException(tier: Tier<*>, + 
selectedFeatures: Set>, + missingFeaturesDetails: String) : Exception() { + override val message = "Computable description of tier $tier is not sufficient: $missingFeaturesDetails. Computed features: $selectedFeatures" +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionStructureCollector.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionStructureCollector.kt new file mode 100644 index 000000000000..d89b778d38a2 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionStructureCollector.kt @@ -0,0 +1,231 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session + +import com.intellij.platform.ml.* +import com.intellij.platform.ml.impl.DescriptionComputer +import com.intellij.platform.ml.impl.FeatureSelector +import com.intellij.platform.ml.impl.LevelTiers +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.model.MLModel +import org.jetbrains.annotations.ApiStatus +import java.util.concurrent.CompletableFuture + +@ApiStatus.Internal +class RootCollector, P : Any>( + apiPlatform: MLApiPlatform, + levelsTiers: List, + descriptionComputer: DescriptionComputer, + notUsedFeaturesSelectors: PerTier, + levelMainEnvironment: Environment, + levelAdditionalTiers: Set>, + private val mlModel: M +) : NestableStructureCollector, M, P>() { + override val levelDescriptor = LevelDescriptor(apiPlatform, descriptionComputer, mlModel.knownFeatures, notUsedFeaturesSelectors) + override val levelPositioning = LevelPositioning.superior(levelsTiers, levelDescriptor, levelMainEnvironment, levelAdditionalTiers) + + init { + validateSuperiorCollector(levelsTiers, levelMainEnvironment, levelAdditionalTiers, mlModel, notUsedFeaturesSelectors) + } + + override fun createTree(thisLevel: DescribedLevel, + collectedNestedStructureTrees: List>): 
SessionTree.ComplexRoot { + return SessionTree.ComplexRoot(mlModel, thisLevel, collectedNestedStructureTrees) + } +} + +@ApiStatus.Internal +class SolitaryLeafCollector, P : Any>( + apiPlatform: MLApiPlatform, + levelScheme: LevelTiers, + descriptionComputer: DescriptionComputer, + notUsedFeaturesSelectors: PerTier, + levelMainEnvironment: Environment, + levelAdditionalTiers: Set>, + private val mlModel: M +) : PredictionCollector, M, P>() { + override val levelDescriptor = LevelDescriptor(apiPlatform, descriptionComputer, mlModel.knownFeatures, notUsedFeaturesSelectors) + override val levelPositioning = LevelPositioning.superior(listOf(levelScheme), levelDescriptor, levelMainEnvironment, + levelAdditionalTiers) + + init { + validateSuperiorCollector(listOf(levelScheme), levelMainEnvironment, levelAdditionalTiers, mlModel, notUsedFeaturesSelectors) + } + + override fun createTree(thisLevel: DescribedLevel, prediction: P?): SessionTree.SolitaryLeaf { + return SessionTree.SolitaryLeaf(mlModel, levelPositioning.thisLevel, prediction) + } +} + +@ApiStatus.Internal +class BranchingCollector, P : Any>( + override val levelDescriptor: LevelDescriptor, + override val levelPositioning: LevelPositioning +) : NestableStructureCollector, M, P>() { + override fun createTree(thisLevel: DescribedLevel, + collectedNestedStructureTrees: List>): SessionTree.Branching { + return SessionTree.Branching(thisLevel, collectedNestedStructureTrees) + } +} + +@ApiStatus.Internal +abstract class PredictionCollector, M : MLModel

, P : Any> : StructureCollector() { + private var predictionSubmitted = false + private var submittedPrediction: P? = null + + abstract fun createTree(thisLevel: DescribedLevel, prediction: P?): T + + val usableDescription: PerTier> + get() = levelPositioning.levels.extractDescriptionForModel() + + fun submitPrediction(prediction: P?) { + require(!predictionSubmitted) /* a prediction (possibly null) may be submitted only once */ + submittedPrediction = prediction + predictionSubmitted = true + submitTreeToHandlers(createTree(levelPositioning.thisLevel, submittedPrediction)) + } + + private fun PerTierInstance.extractDescriptionForModel(): PerTier> { + return this.entries.associate { (tierInstance, data) -> + tierInstance.tier to data.description.declared.used + data.description.nonDeclared.used + } + } + + private fun DescribedLevel.extractDescriptionForModel(): PerTier> { + val mainDescription = this.main.extractDescriptionForModel() + val additionalDescription = this.additional.extractDescriptionForModel() + return listOf(mainDescription + additionalDescription).joinByUniqueTier() /* NOTE(review): the two maps are added together before listOf, so joinByUniqueTier() receives a one-element list here, while the Iterable overload below passes one map per level - confirm whether listOf(mainDescription, additionalDescription) was intended */ + } + + private fun Iterable.extractDescriptionForModel(): PerTier> { + return this.map { it.extractDescriptionForModel() }.joinByUniqueTier() + } +} + +@ApiStatus.Internal +class LeafCollector, P : Any>( + override val levelDescriptor: LevelDescriptor, + override val levelPositioning: LevelPositioning +) : PredictionCollector, M, P>() { + + override fun createTree(thisLevel: DescribedLevel, prediction: P?): SessionTree.Leaf { + return SessionTree.Leaf(levelPositioning.thisLevel, prediction) + } +} + +private fun , P : Any> validateSuperiorCollector(levelsTiers: List, + levelMainEnvironment: Environment, + levelAdditionalTiers: Set>, + mlModel: M, + notUsedDescriptionSelectors: PerTier) { + val allTiers = (levelsTiers.flatMap { it.main + it.additional } + levelMainEnvironment.tiers + levelAdditionalTiers).toSet() + val usedFeaturesSelectors = mlModel.knownFeatures + + require(allTiers.containsAll(usedFeaturesSelectors.keys)) { + "ML Model uses 
tiers that are not main or additional: ${usedFeaturesSelectors.keys - allTiers}" + } + require(notUsedDescriptionSelectors.keys == allTiers) { + """ + Not used description's tiers must be same as the task tiersAll tiers + Missing: ${allTiers - notUsedDescriptionSelectors.keys} + Redundant: ${notUsedDescriptionSelectors.keys - allTiers} + """.trimIndent() + } +} + +@ApiStatus.Internal +data class LevelPositioning( + val upperLevels: List, + val lowerTiers: List, + val thisLevel: DescribedLevel, +) { + val levels: List = upperLevels + thisLevel + + fun nestNextLevel( + levelDescriptor: LevelDescriptor, + nextLevelMainEnvironment: Environment, + nextLevelAdditionalTiers: Set> + ): LevelPositioning { + return LevelPositioning( + upperLevels = upperLevels + thisLevel, + lowerTiers = lowerTiers.drop(1), + thisLevel = levelDescriptor.describe(upperLevels, nextLevelMainEnvironment, nextLevelAdditionalTiers) + ) + } + + companion object { + fun superior(levelsTiers: List, + levelDescriptor: LevelDescriptor, + levelMainEnvironment: Environment, + levelAdditionalTiers: Set>): LevelPositioning { + return LevelPositioning(emptyList(), levelsTiers.drop(1), + levelDescriptor.describe(emptyList(), levelMainEnvironment, levelAdditionalTiers)) + } + } +} + +@ApiStatus.Internal +sealed class StructureCollector, M : MLModel

, P : Any> { + protected abstract val levelDescriptor: LevelDescriptor + abstract val levelPositioning: LevelPositioning + + private val sessionTreeHandlers: MutableList> = mutableListOf() + + fun handleCollectedTree(handler: SessionTreeHandler) { + sessionTreeHandlers.add(handler) + } + + protected fun submitTreeToHandlers(sessionTree: T) { + sessionTreeHandlers.forEach { it.handleTree(sessionTree) } + } +} + +@ApiStatus.Internal +abstract class NestableStructureCollector, M : MLModel

, P : Any> : StructureCollector() { + private val nestedSessionsStructures: MutableList>> = mutableListOf() + private var nestingFinished = false + + fun nestBranch(levelMainEnvironment: Environment, levelAdditionalTiers: Set>): BranchingCollector { + verifyNestedLevelEnvironment(levelMainEnvironment, levelAdditionalTiers) + return BranchingCollector(levelDescriptor, + levelPositioning.nestNextLevel(levelDescriptor, levelMainEnvironment, levelAdditionalTiers)) + .also { it.trackCollectedStructure() } + } + + fun nestPrediction(levelMainEnvironment: Environment, levelAdditionalTiers: Set>): LeafCollector { + verifyNestedLevelEnvironment(levelMainEnvironment, levelAdditionalTiers) + return LeafCollector(levelDescriptor, levelPositioning.nestNextLevel(levelDescriptor, levelMainEnvironment, levelAdditionalTiers)) + .also { it.trackCollectedStructure() } + } + + fun onLastNestedCollectorCreated() { + require(!nestingFinished) + nestingFinished = true + maybeSubmitStructure() + } + + private fun > StructureCollector.trackCollectedStructure() { + val collectedNestedTreeContainer = CompletableFuture>() + nestedSessionsStructures += collectedNestedTreeContainer + this.handleCollectedTree { + collectedNestedTreeContainer.complete(it) /* NOTE(review): this handler only completes the future; maybeSubmitStructure() is called solely from onLastNestedCollectorCreated(), so a nested tree delivered after that call does not trigger submission of this level - confirm callers guarantee the ordering */ + } + } + + private fun maybeSubmitStructure() { + val collectedNestedStructureTrees: List> = nestedSessionsStructures + .filter { it.isDone } + .map { it.get() } + + if (nestingFinished && collectedNestedStructureTrees.size == nestedSessionsStructures.size) { + val describedSessionTree = createTree(levelPositioning.thisLevel, collectedNestedStructureTrees) + submitTreeToHandlers(describedSessionTree) + } + } + + protected abstract fun createTree(thisLevel: DescribedLevel, collectedNestedStructureTrees: List>): T + + private fun verifyNestedLevelEnvironment(levelMainEnvironment: Environment, levelAdditionalTiers: Set>) { + require(levelPositioning.lowerTiers.first().main == levelMainEnvironment.tiers) /* NOTE(review): lowerTiers.first() throws NoSuchElementException when lowerTiers is empty, before require can report anything */ + 
require(levelPositioning.lowerTiers.first().additional == levelAdditionalTiers) + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionTree.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionTree.kt new file mode 100644 index 000000000000..6e74b751e200 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionTree.kt @@ -0,0 +1,275 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session + +import com.intellij.platform.ml.Environment +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.PerTierInstance +import org.jetbrains.annotations.ApiStatus + +/** + * A partition of [Feature]s that indicates that they could be either used or not used by an ML model. + */ +@ApiStatus.Internal +data class Usage( + val used: T, + val notUsed: T, +) + +/** + * A partition of [Feature]s that indicates that they could be either declared (by [com.intellij.platform.ml.TierDescriptor]) or not. + * + * Used to process [com.intellij.platform.ml.ObsoleteTierDescriptor]'s partial descriptions. + * It shall be removed when all obsolete descriptors will be transferred to the new API. + */ +@ApiStatus.Internal +data class Declaredness( + val declared: T, + val nonDeclared: T +) + +/** + * There are two characteristics of a feature as for now: whether it is declared (statically, in a tier descriptor), + * and whether it is used by the ML model. + * So, this container creates a partition for each category of features. 
+ */ +typealias DescriptionPartition = Declaredness>> + +/** + * A main tier has a description (that is used to run the ML model), and it also could contain analysis features + */ +@ApiStatus.Internal +data class MainTierScheme( + val description: Set>, + val analysis: Set> +) + +/** + * An additional tier is provided occasionally, and it has only description + */ +@ApiStatus.Internal +data class AdditionalTierScheme( + val description: Set> +) + +/** + * All the data, that a tier instance that has been described has + */ +@ApiStatus.Internal +data class DescribedTierData( + val description: DescriptionPartition, +) + +/** + * All the data, that a tier instance that has been described and analyzed has + */ +@ApiStatus.Internal +data class AnalysedTierData( + val description: DescriptionPartition, + val analysis: Set +) + +/** + * A template for a container that contains data of main tier instances, as well as additional instances. + * A Level is a collection of tiers that were declared on the same depth of an [com.intellij.platform.ml.impl.MLTask]'s declaration, + * plus additional tiers, declared by [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.additionallyDescribedTiers] + * on the corresponding level. + */ +@ApiStatus.Internal +data class Level(val main: M, val additional: A) + +typealias DescribedLevel = Level, PerTierInstance> + +typealias AnalysedLevel = Level, PerTierInstance> + +/** + * Tree-like ml session's structure. + * + * All trees leaves have the same depths. + * The depth corresponds to the number of levels in an [com.intellij.platform.ml.impl.MLTask]. + * And the tree's structure is built by calling [com.intellij.platform.ml.NestableMLSession.createNestedSession]. + * + * @param RootT Type of the data, that is stored in the root node. + * @param LevelT Type of the data, that is stored in each tree's node. + * @param PredictionT Type of the session's prediction. 
+ */ +@ApiStatus.Internal +sealed interface SessionTree { + /** + * Data, that is stored in each tree's node. + */ + val level: LevelT + + /** + * Accepts the [visitor], calling the corresponding interface's function. + */ + fun accept(visitor: Visitor): T + + /** + * Something that contains tree's root data. + * There are two such classes: a [SolitaryLeaf], and [ComplexRoot]. + */ + sealed interface RootContainer : SessionTree { + /** + * Data, that is stored only in the tree's root. + * ML model, for example. + */ + val root: RootT + } + + /** + * Something that has nested nodes. + * It could be either [ComplexRoot], or [Branching]. + * The number of children corresponds to number of calls of + * [com.intellij.platform.ml.NestableMLSession.createNestedSession]. + */ + sealed interface ChildrenContainer : SessionTree { + /** + * All nested trees, that were built by calling + * [com.intellij.platform.ml.NestableMLSession.createNestedSession] on this level. + */ + val children: List> + } + + /** + * Something that contains session's prediction. + * It is produced by [com.intellij.platform.ml.SinglePrediction], and the prediction could be either + * produced or canceled, hence the [prediction] is nullable. + */ + sealed interface PredictionContainer : SessionTree { + val prediction: PredictionT? + } + + /** + * Corresponds to an ML task session's structure, that had only one level, and could not have been nested. + * Hence, it contains root data and a prediction simultaneously. + */ + data class SolitaryLeaf( + override val root: RootT, + override val level: LevelT, + override val prediction: PredictionT? + ) : RootContainer, PredictionContainer { + override fun accept(visitor: Visitor): T { + return visitor.acceptSolitaryLeaf(this) + } + } + + /** + * Corresponds to an ML task session's structure, that had more than one level. 
+ */ + data class ComplexRoot( + override val root: RootT, + override val level: LevelT, + override val children: List> + ) : RootContainer, ChildrenContainer { + override fun accept(visitor: Visitor): T = visitor.acceptRoot(this) + } + + /** + * Corresponds to a node in an ML task session's structure, that had more than one level. + */ + data class Branching( + override val level: LevelT, + override val children: List> + ) : SessionTree, ChildrenContainer { + override fun accept(visitor: Visitor): T = visitor.acceptBranching(this) + } + + /** + * Corresponds to a leaf in an ML task session's structure, that had more than one level. + */ + data class Leaf( + override val level: LevelT, + override val prediction: PredictionT? + ) : SessionTree, PredictionContainer { + override fun accept(visitor: Visitor): T = visitor.acceptLeaf(this) + } + + /** + * Visits a [SessionTree]'s node. + * + * @param RootT Data's type, stored in the tree's root. + * @param LevelT Data's type, stored in each node. + * @param PredictionT Prediction's type in the ML session. + * @param T Data type that is returned by the visitor. + */ + interface Visitor { + fun acceptBranching(branching: Branching): T + + fun acceptLeaf(leaf: Leaf): T + + fun acceptRoot(root: ComplexRoot): T + + fun acceptSolitaryLeaf(solitaryLeaf: SolitaryLeaf): T + } + + /** + * Visits all tree's nodes on [levelIndex] depth. 
+ */ + abstract class LevelVisitor private constructor( + private val levelIndex: Int, + private val thisVisitorLevel: Int, + ) : Visitor { + constructor(levelIndex: Int) : this(levelIndex, 0) + + private inner class DeeperLevelVisitor : LevelVisitor(levelIndex, thisVisitorLevel + 1) { + override fun visitLevel(level: DescribedLevel, levelRoot: SessionTree) { + this@LevelVisitor.visitLevel(level, levelRoot) + } + } + + private fun maybeVisitLevel(level: DescribedLevel, + levelRoot: SessionTree): Boolean = + if (levelIndex == thisVisitorLevel) { + visitLevel(level, levelRoot) + true + } + else false + + final override fun acceptBranching(branching: Branching) { + if (maybeVisitLevel(branching.level, branching)) return + for (child in branching.children) { + child.accept(DeeperLevelVisitor()) + } + } + + final override fun acceptLeaf(leaf: Leaf) { + require(maybeVisitLevel(leaf.level, leaf)) { + "The deepest level in the session tree is $thisVisitorLevel, given level $levelIndex does not exist" + } + } + + final override fun acceptRoot(root: ComplexRoot) { + if (maybeVisitLevel(root.level, root)) return + for (child in root.children) { + child.accept(DeeperLevelVisitor()) + } + } + + final override fun acceptSolitaryLeaf(solitaryLeaf: SolitaryLeaf) { + require(maybeVisitLevel(solitaryLeaf.level, solitaryLeaf)) { + "The only level in the session tree is $thisVisitorLevel, given level $levelIndex does not exist" + } + } + + abstract fun visitLevel(level: DescribedLevel, levelRoot: SessionTree) + } +} + +typealias DescribedSessionTree = SessionTree + +typealias DescribedChildrenContainer = SessionTree.ChildrenContainer + +typealias DescribedRootContainer = SessionTree.RootContainer + +typealias SessionAnalysis = Map> + +typealias AnalysedSessionTree

= SessionTree + +typealias AnalysedRootContainer

= SessionTree.RootContainer + +val DescribedSessionTree.environment: Environment + get() = Environment.of(this.level.main.keys) + +val DescribedLevel.environment: Environment + get() = Environment.of(this.main.keys) diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionTreeHandler.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionTreeHandler.kt new file mode 100644 index 000000000000..239c4321b094 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/SessionTreeHandler.kt @@ -0,0 +1,9 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session + +import org.jetbrains.annotations.ApiStatus + +@ApiStatus.Internal +fun interface SessionTreeHandler, R, P> { + fun handleTree(sessionTree: T) +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/MLModelAnalysis.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/MLModelAnalysis.kt new file mode 100644 index 000000000000..966b17f69203 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/MLModelAnalysis.kt @@ -0,0 +1,25 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session.analysis + +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.impl.model.MLModel + +/** + * An analyzer which is dedicated to give features to the session's model. + * + * [SessionAnalyser.analysisDeclaration] returns a set of [FeatureDeclaration] - + * the features that the model will be described with. + * [SessionAnalyser.analyse] in this case returns a set of features - the model's analysis. 
+ */ +typealias MLModelAnalyser = SessionAnalyser>, Set, M, P> + +internal class MLModelAnalysisJoiner, P : Any> : AnalysisJoiner>, Set, M, P> { + override fun joinDeclarations(declarations: Iterable>>): Set> { + return declarations.flatten().toSet() + } + + override fun joinAnalysis(analysis: Iterable>): Set { + return analysis.flatten().toSet() + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/SessionAnalyser.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/SessionAnalyser.kt new file mode 100644 index 000000000000..65037098144a --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/SessionAnalyser.kt @@ -0,0 +1,67 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session.analysis + +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.DescribedRootContainer +import org.jetbrains.annotations.ApiStatus +import java.util.concurrent.CompletableFuture + +/** + * An interface for classes, that are analyzing the ML session after it was finished. + * + * Analysis is indefinitely long process (which is emphasized by the fact [analyse] returns a [CompletableFuture]). + * + * @param D Type of the analysis' declaration + * @param A Type of the analysis itself + * @param M Type of the model that has been utilized during the ML session + * @param P Type of the session's prediction + */ +@ApiStatus.Internal +interface SessionAnalyser, P : Any> { + /** + * Contains all features' declarations - [com.intellij.platform.ml.FeatureDeclaration], + * that are then used in the analysis as [analyse]'s return value. + */ + val analysisDeclaration: D + + /** + * Performs session tree's analysis. 
The analysis is performed asynchronously, + * so the function is not required to return the final result, but only a [CompletableFuture] + * that will be fulfilled when the analysis is finished. + */ + fun analyse(sessionTreeRoot: DescribedRootContainer): CompletableFuture +} + +/** + * Gathers analyses from different [SessionAnalyser]s to one. NOTE(review): [analyse] completes its result from a thenRun action on CompletableFuture.allOf(...); that action runs only on normal completion, so when any base analyser future fails the returned future is never completed - consider propagating the failure. + */ +internal class JoinedSessionAnalyser, P : Any>( + private val baseAnalysers: Collection>, + private val joiner: AnalysisJoiner +) : SessionAnalyser { + override fun analyse(sessionTreeRoot: DescribedRootContainer): CompletableFuture { + val scatteredAnalysis = mutableListOf>() + for (sessionAnalyser in baseAnalysers) { + scatteredAnalysis.add(sessionAnalyser.analyse(sessionTreeRoot)) + } + + val joinedAnalysis = CompletableFuture() + val eachScatteredAnalysisFinished = CompletableFuture.allOf(*scatteredAnalysis.toTypedArray()) + + eachScatteredAnalysisFinished.thenRun { + val completeAnalysis = scatteredAnalysis.mapNotNull { it.get() } + joinedAnalysis.complete(joiner.joinAnalysis(completeAnalysis)) + } + + return joinedAnalysis + } + + override val analysisDeclaration: D + get() = joiner.joinDeclarations(baseAnalysers.map { it.analysisDeclaration }) +} + +internal interface AnalysisJoiner, P : Any> { + fun joinDeclarations(declarations: Iterable): D + + fun joinAnalysis(analysis: Iterable): A +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/SessionStructureAnalyser.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/SessionStructureAnalyser.kt new file mode 100644 index 000000000000..5b2617384d75 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/SessionStructureAnalyser.kt @@ -0,0 +1,40 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl.session.analysis + +import com.intellij.platform.ml.Feature +import com.intellij.platform.ml.FeatureDeclaration +import com.intellij.platform.ml.PerTier +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.session.DescribedSessionTree +import com.intellij.platform.ml.mergePerTier + +typealias StructureAnalysis = Map, PerTier>> + +typealias StructureAnalysisDeclaration = PerTier>> + +/** + * An analyzer that gives analytical features to the session tree's nodes. + * + * [SessionAnalyser.analysisDeclaration] returns a set of [FeatureDeclaration] per each analyzed tier. + * [SessionAnalyser.analyse] returns a mapping from each analyzed tree's node to sets of features. + * + * For example, if you want to give some analysis to the session tree's root, then you will return + * mapOf(sessionTreeRoot to setOf(...)). + * + * @see com.intellij.platform.ml.impl.session.SessionTree.Visitor to learn how you could walk the tree's nodes + * @see com.intellij.platform.ml.impl.session.SessionTree.LevelVisitor to learn how you could visit the session's levels. 
+ */ +typealias StructureAnalyser = SessionAnalyser, M, P> + +internal class SessionStructureAnalysisJoiner, P : Any> : AnalysisJoiner, M, P> { + override fun joinAnalysis(analysis: Iterable>): StructureAnalysis { + return analysis + .flatMap { it.entries } + .groupBy({ it.key }, { it.value }) + .mapValues { it.value.mergePerTier { mutableSetOf() } } + } + + override fun joinDeclarations(declarations: Iterable): StructureAnalysisDeclaration { + return declarations.mergePerTier { mutableSetOf() } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/ShallowSessionAnalyser.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/ShallowSessionAnalyser.kt new file mode 100644 index 000000000000..021b94c2b91c --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/analysis/ShallowSessionAnalyser.kt @@ -0,0 +1,43 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session.analysis + +import com.intellij.internal.statistic.eventLog.events.EventField +import com.intellij.internal.statistic.eventLog.events.EventPair +import com.intellij.internal.statistic.eventLog.events.ObjectDescription +import com.intellij.platform.ml.Environment +import org.jetbrains.annotations.ApiStatus + +/** + * Analyzes not the ML session's tree, but some generic information. + * Could be used to analyze started, as well as failed to start sessions. + */ +@ApiStatus.Internal +interface ShallowSessionAnalyser { + /** + * Name of the analyzer. + */ + val name: String + + /** + * A complete static declaration of the fields, that will be written during analysis. + */ + val declaration: List> + + /** + * Analyze some generic information about an ML session. 
+ * + * @param permanentSessionEnvironment The environment that is available during the whole ML session + * @param data Some additional data, an insight about the place where the analyzer was called. + * Could be a [Throwable], or a reason why it was not possible to start the session. + */ + fun analyse(permanentSessionEnvironment: Environment, data: D): List> + + companion object { + val ShallowSessionAnalyser.declarationObjectDescription: ObjectDescription + get() = object : ObjectDescription() { + init { + declaration.forEach { field(it) } + } + } + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/session/sessionUsageUtil.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/sessionUsageUtil.kt new file mode 100644 index 000000000000..cca8e648116e --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/session/sessionUsageUtil.kt @@ -0,0 +1,77 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.session + +import com.intellij.platform.ml.* + +/** + * A wrapper for convenient usage of a [NestableMLSession]. + * + * This function assumes that this session is nestable. + */ +fun

Session

.withNestedSessions(useCreator: (NestableSessionWrapper

) -> Unit) { + val nestableMLSession = requireNotNull(this as? NestableMLSession

) + + val creator = object : NestableSessionWrapper

{ + override fun nestConsidering(levelEnvironment: Environment): Session

{ + return nestableMLSession.createNestedSession(levelEnvironment) + } + } + + try { + return useCreator(creator) + } + finally { + nestableMLSession.onLastNestedSessionCreated() + } +} + +/** + * A wrapper for convenient usage of a [SinglePrediction]. + * + * This function assumes that this session is [NestableMLSession], and the nested + * sessions' types are [SinglePrediction]. + */ +fun Session

.withPredictions(useModelWrapper: (ModelWrapper

) -> T): T { + val nestableMLSession = requireNotNull(this as? NestableMLSession

) + val predictor = object : ModelWrapper

{ + override fun predictConsidering(predictionEnvironment: Environment): P { + val predictionSession = nestableMLSession.createNestedSession(predictionEnvironment) + require(predictionSession is SinglePrediction

) + return predictionSession.predict() + } + + override fun consider(predictionEnvironment: Environment) { + val predictionSession = nestableMLSession.createNestedSession(predictionEnvironment) + require(predictionSession is SinglePrediction

) + predictionSession.cancelPrediction() + } + } + try { + return useModelWrapper(predictor) + } + finally { + nestableMLSession.onLastNestedSessionCreated() + } +} + +interface NestableSessionWrapper

{ + fun nestConsidering(levelEnvironment: Environment): Session

+ + companion object { + fun

NestableSessionWrapper

.nestConsidering(vararg levelTierInstances: TierInstance<*>): Session

{ + return this.nestConsidering(Environment.of(*levelTierInstances)) + } + } +} + +interface ModelWrapper

{ + fun predictConsidering(predictionEnvironment: Environment): P + + fun consider(predictionEnvironment: Environment) + + companion object { + fun

ModelWrapper

.predictConsidering(vararg predictionTierInstances: TierInstance<*>): P { + return this.predictConsidering(Environment.of(*predictionTierInstances)) + } + } +} \ No newline at end of file diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/CompletionKind.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/CompletionKind.kt index 9f3cdb1cf9fe..9f493aab08c2 100644 --- a/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/CompletionKind.kt +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/CompletionKind.kt @@ -1,6 +1,8 @@ // Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. package com.intellij.platform.ml.impl.turboComplete +import com.intellij.platform.ml.Tier + /** * Data class representing a kind of [SuggestionGenerator]'s suggestions. * @@ -8,4 +10,6 @@ package com.intellij.platform.ml.impl.turboComplete * The completion kind's name is defined statically, it should be unique among * the corresponding [KindVariety]. 
*/ -data class CompletionKind(val name: Enum<*>, val variety: KindVariety) \ No newline at end of file +data class CompletionKind(val name: Enum<*>, val variety: KindVariety) + +object TierCompletionKind : Tier() diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/SuggestionGeneratorExecutorProvider.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/SuggestionGeneratorExecutorProvider.kt index 771eef37615b..92045f715aaf 100644 --- a/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/SuggestionGeneratorExecutorProvider.kt +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/turboComplete/SuggestionGeneratorExecutorProvider.kt @@ -16,7 +16,7 @@ interface SuggestionGeneratorExecutorProvider { ): SuggestionGeneratorExecutor companion object { - private val EP_NAME: ExtensionPointName = + val EP_NAME: ExtensionPointName = ExtensionPointName("com.intellij.turboComplete.suggestionGeneratorExecutorProvider") fun hasAnyToCall(parameters: CompletionParameters): Boolean { diff --git a/platform/ml-impl/test/com/intellij/platform/ml/impl/Demo.kt b/platform/ml-impl/test/com/intellij/platform/ml/impl/Demo.kt new file mode 100644 index 000000000000..db9035617e4c --- /dev/null +++ b/platform/ml-impl/test/com/intellij/platform/ml/impl/Demo.kt @@ -0,0 +1,483 @@ +// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. 
+package com.intellij.platform.ml.impl + +import com.intellij.internal.statistic.FUCollectorTestCase +import com.intellij.internal.statistic.eventLog.events.ClassEventField +import com.intellij.internal.statistic.eventLog.events.EventField +import com.intellij.internal.statistic.eventLog.events.EventPair +import com.intellij.lang.Language +import com.intellij.openapi.fileTypes.PlainTextLanguage +import com.intellij.openapi.util.Version +import com.intellij.platform.ml.* +import com.intellij.platform.ml.impl.MLTaskApproach.Companion.startMLSession +import com.intellij.platform.ml.impl.apiPlatform.CodeLikePrinter +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform +import com.intellij.platform.ml.impl.approach.* +import com.intellij.platform.ml.impl.logger.FailedSessionLoggerRegister +import com.intellij.platform.ml.impl.logger.FinishedSessionLoggerRegister +import com.intellij.platform.ml.impl.logger.InplaceFeaturesScheme +import com.intellij.platform.ml.impl.logger.MLEvent +import com.intellij.platform.ml.impl.model.MLModel +import com.intellij.platform.ml.impl.monitoring.* +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy +import com.intellij.platform.ml.impl.session.* +import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser +import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser +import com.intellij.platform.ml.impl.session.analysis.StructureAnalyser +import com.intellij.platform.ml.impl.session.analysis.StructureAnalysis +import com.intellij.testFramework.fixtures.BasePlatformTestCase +import com.jetbrains.fus.reporting.model.lion3.LogEvent +import java.io.BufferedWriter +import java.io.FileWriter +import java.net.URI +import java.nio.file.Path +import java.util.concurrent.CompletableFuture +import java.util.concurrent.TimeUnit +import java.util.function.Consumer +import 
kotlin.io.path.div +import kotlin.random.Random + +enum class CompletionType { + SMART, + BASIC +} + +data class CompletionSession( + val language: Language, + val callOrder: Int, + val completionType: CompletionType? +) + +data class LookupImpl( + val foo: Boolean, + val index: Int +) + +data class LookupItem( + val lookupString: String, + val decoration: Map +) + +data class GitRepository( + val user: String, + val projectUri: URI, + val commits: List +) + +object TierCompletionSession : Tier() +object TierLookup : Tier() +object TierItem : Tier() + +object TierGit : Tier() + +class CompletionSessionFeatures1 : TierDescriptor { + companion object { + val CALL_ORDER = FeatureDeclaration.int("call_order") + val LANGUAGE_ID = FeatureDeclaration.categorical("language_id", Language.getRegisteredLanguages().map { it.id }.toSet()) + val GIT_USER = FeatureDeclaration.boolean("git_user_is_Glebanister") + val COMPLETION_TYPE = FeatureDeclaration.enum("completion_type").nullable() + } + + override val tier: Tier<*> = TierCompletionSession + + override val additionallyRequiredTiers: Set> = setOf(TierGit) + + override val descriptionDeclaration: Set> = setOf( + CALL_ORDER, LANGUAGE_ID, GIT_USER, COMPLETION_TYPE + ) + + override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set { + val completionSession = environment[TierCompletionSession] + val gitRepository = environment[TierGit] + return setOf( + CALL_ORDER with completionSession.callOrder, + LANGUAGE_ID with completionSession.language.id, + GIT_USER with (gitRepository.user == "Glebanister"), + COMPLETION_TYPE with completionSession.completionType + ) + } +} + +class ItemFeatures1 : TierDescriptor { + companion object { + val DECORATIONS = FeatureDeclaration.int("decorations") + val LENGTH = FeatureDeclaration.int("length") + } + + override val tier: Tier<*> = TierItem + + override val descriptionDeclaration: Set> = setOf( + DECORATIONS, LENGTH + ) + + override fun describe(environment: 
Environment, usefulFeaturesFilter: FeatureFilter): Set { + val item = environment[TierItem] + return setOf( + DECORATIONS with item.decoration.size, + LENGTH with item.lookupString.length + ) + } +} + +class GitFeatures1 : TierDescriptor { + companion object { + val N_COMMITS = FeatureDeclaration.int("n_commits") + val HAS_USER = FeatureDeclaration.boolean("has_user") + } + + override val tier: Tier<*> = TierGit + + override val descriptionDeclaration: Set> = setOf( + N_COMMITS, HAS_USER + ) + + override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter) = setOf( + N_COMMITS with environment[TierGit].commits.size, + HAS_USER with environment[TierGit].user.isNotEmpty() + ) +} + +class GitInformant : EnvironmentExtender { + override val extendingTier: Tier = TierGit + + override val requiredTiers: Set> = setOf() + + override fun extend(environment: Environment): GitRepository { + return GitRepository( + user = "Glebanister", + projectUri = URI.create("ssh://git@git.jetbrains.team/ij/intellij.git"), + commits = listOf( + "0e47200fa3bf029d7244745eacbf9d495de818c1", + "638fbc7840b85d6e34ef2320ca5b2c9ec2c4b23c", + "5a74104a0901b8faa4cc1f76736347cd33917041" + ) + ) + } +} + +class SomeStructureAnalyser> : StructureAnalyser { + companion object { + val SESSION_IS_GOOD = FeatureDeclaration.boolean("very_good_session") + val LOOKUP_INDEX = FeatureDeclaration.int("lookup_index") + } + + override fun analyse(sessionTreeRoot: DescribedRootContainer): CompletableFuture> { + val analysis = mutableMapOf, PerTier>>() + sessionTreeRoot.accept(LookupAnalyser(analysis)) + analysis[sessionTreeRoot] = mapOf( + TierCompletionSession to setOf(SESSION_IS_GOOD with true) + ) + + return CompletableFuture.supplyAsync { + // Pretend that analysis is taking some long time + TimeUnit.SECONDS.sleep(2) + analysis + } + } + + private class LookupAnalyser>( + private val analysis: MutableMap, PerTier>> + ) : SessionTree.LevelVisitor(levelIndex = 1) { + override fun 
visitLevel(level: DescribedLevel, levelRoot: DescribedSessionTree) { + val lookup = level.environment[TierLookup] + analysis[levelRoot] = mapOf(TierLookup to setOf(LOOKUP_INDEX with lookup.index)) + } + } + + override val analysisDeclaration: PerTier>> + get() = mapOf( + TierCompletionSession to setOf(SESSION_IS_GOOD), + TierLookup to setOf(LOOKUP_INDEX) + ) +} + +class RandomModelSeedAnalyser : MLModelAnalyser { + companion object { + val SEED = FeatureDeclaration.int("random_seed") + } + + override val analysisDeclaration: Set> = setOf( + SEED + ) + + override fun analyse(sessionTreeRoot: DescribedRootContainer): CompletableFuture> { + return CompletableFuture.supplyAsync { + // Pretend that analysis is taking some long time + TimeUnit.SECONDS.sleep(1) + setOf(SEED with sessionTreeRoot.root.seed) + } + } +} + +class RandomModel(val seed: Int) : MLModel, Versioned, LanguageSpecific { + private val generator = Random(seed) + + class Provider : MLModel.Provider { + override val requiredTiers: Set> = emptySet() + + override fun provideModel(sessionTiers: List, environment: Environment): RandomModel? 
{ + return if (Random.nextBoolean()) { + if (Random.nextBoolean()) RandomModel(1) else throw IllegalStateException() + } + else null + } + } + + override val knownFeatures: PerTier = mapOf( + TierCompletionSession to FeatureSelector.EVERYTHING, + TierLookup to FeatureSelector.EVERYTHING, + TierGit to FeatureSelector.EVERYTHING, + TierItem to FeatureSelector.EVERYTHING, + ) + + override fun predict(features: PerTier>): Double { + return generator.nextDouble() + } + + override val languageId: String = PlainTextLanguage.INSTANCE.id + + override val version: Version = Version(0, 0, 1) +} + +class SomeListener(private val name: String) : MLTaskGroupListener { + override val approachListeners = listOf( + MockTaskApproach::class.java monitoredBy InitializationListener() + ) + + private fun log(message: String) = println("[Listener $name says] $message") + + inner class InitializationListener : MLApproachInitializationListener { + override fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener { + log("attempted to initialize session") + return ApproachListener() + } + } + + inner class ApproachListener : MLApproachListener { + override fun onFailedToStartSessionWithException(exception: Throwable) { + log("failed to start session with exception: $exception") + } + + override fun onFailedToStartSession(failure: Session.StartOutcome.Failure) { + log("failed to start session with outcome: $failure") + } + + override fun onStartedSession(session: Session): MLSessionListener { + log("session was started successfully: $session") + return SessionListener() + } + } + + inner class SessionListener : MLSessionListener { + override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer) { + log("session successfully described: $sessionTree") + } + + override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer) { + log("session successfully analyzed: $sessionTree") + } + } +} + +object ExceptionLogger : 
ShallowSessionAnalyser { + private val THROWABLE_CLASS = ClassEventField("throwable_class") + + override val name: String = "exception" + + override val declaration: List> = listOf(THROWABLE_CLASS) + + override fun analyse(permanentSessionEnvironment: Environment, data: Throwable): List> { + return listOf(THROWABLE_CLASS with data.javaClass) + } +} + +object FailureLogger : ShallowSessionAnalyser> { + private val REASON = ClassEventField("reason") + + override val name: String = "normal_failure" + + override val declaration: List> = listOf(REASON) + + override fun analyse(permanentSessionEnvironment: Environment, + data: Session.StartOutcome.Failure): List> { + return listOf(REASON with data.javaClass) + } +} + +object ThisTestApiPlatform : TestApiPlatform() { + override val tierDescriptors = listOf( + CompletionSessionFeatures1(), + ItemFeatures1(), + GitFeatures1(), + ) + + override val environmentExtenders = listOf( + GitInformant(), + ) + + override val taskApproaches = listOf( + MockTaskApproach.Initializer() + ) + + + override val initialStartupListeners: List = listOf( + FinishedSessionLoggerRegister( + MockTaskApproach::class.java, + InplaceFeaturesScheme.FusScheme.DOUBLE + ), + FailedSessionLoggerRegister( + MockTaskApproach::class.java, + exceptionalAnalysers = listOf(ExceptionLogger), + normalFailureAnalysers = listOf(FailureLogger) + ) + ) + + override val initialTaskListeners: List = listOf( + SomeListener("Nika"), + SomeListener("Alex"), + ) + + override val initialEvents: List = listOf() + + + override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set) { + val printer = CodeLikePrinter() + println("$descriptor is missing the following declaration: ${printer.printCodeLikeString(nonDeclaredFeatures.map { it.declaration })}") + } +} + +object MockTask : MLTask( + name = "mock", + predictionClass = Double::class.java, + levels = listOf( + setOf(TierCompletionSession), + setOf(TierLookup), + setOf(TierItem) + ) +) 
+ +class MockTaskApproach( + apiPlatform: MLApiPlatform, + task: MLTask +) : LogDrivenModelInference(task, apiPlatform) { + + override val additionallyDescribedTiers: List>> = listOf( + setOf(TierGit), + setOf(), + setOf(), + ) + + override val analysisMethod: AnalysisMethod = StructureAndModelAnalysis( + structureAnalysers = listOf(SomeStructureAnalyser()), + mlModelAnalysers = listOf( + RandomModelSeedAnalyser(), + ModelVersionAnalyser(), + ModelLanguageAnalyser() + ) + ) + + override val mlModelProvider = RandomModel.Provider() + + override val notUsedDescription: PerTier = mapOf( + TierCompletionSession to FeatureSelector.NOTHING, + TierLookup to FeatureSelector.NOTHING, + TierItem to FeatureSelector.NOTHING, + TierGit to FeatureSelector.NOTHING + ) + + override val descriptionComputer: DescriptionComputer = StateFreeDescriptionComputer() + + class Initializer : MLTaskApproachInitializer { + override val task: MLTask = MockTask + override fun initializeApproachWithin(apiPlatform: MLApiPlatform) = MockTaskApproach(apiPlatform, task) + } +} + +class TestTask : BasePlatformTestCase() { + fun `test demo ml task`() { + // After the session is finished, it will be logged to community/platform/ml-impl/testResources/ml_logs.js + + val logs: MutableList>> = mutableListOf() + val collectLogs = Consumer { fusLog: LogEvent -> + logs.add(fusLog.event.id to fusLog.event.data) + } + + // TODO: Handle this as well (when it has been initialized before we had the control flow) + //MLEventLogger.Manager.ensureNotInitialized() + + ReplaceableIJPlatform.replacingWith(ThisTestApiPlatform) { + FUCollectorTestCase.listenForEvents("FUS", this.testRootDisposable, collectLogs) { + + repeat(3) { sessionIndex -> + + println("Demo session #$sessionIndex has started") + + val startOutcome = startMLSession(MockTask, Environment.of( + TierCompletionSession with CompletionSession( + language = PlainTextLanguage.INSTANCE, + callOrder = 1, + completionType = CompletionType.SMART + ) + )) + + val 
completionSession = startOutcome.session ?: return@repeat + + completionSession.withNestedSessions { lookupSessionCreator -> + + lookupSessionCreator.nestConsidering(Environment.of(TierLookup with LookupImpl(true, 1))) + .withPredictions { + it.predictConsidering(Environment.of(TierItem with LookupItem("hello", emptyMap()))) + it.predictConsidering(Environment.of(TierItem with LookupItem("world", emptyMap()))) + } + + lookupSessionCreator.nestConsidering(Environment.of(TierLookup with LookupImpl(false, 2))) + .withPredictions { + it.predictConsidering(Environment.of(TierItem with LookupItem("hello!!!", mapOf("bold" to true)))) + it.predictConsidering(Environment.of(TierItem with LookupItem("AAAAA!!", mapOf("strikethrough" to true)))) + it.consider(Environment.of(TierItem with LookupItem("AAAAAAAAAAAAAAAAA", mapOf("cursive" to true)))) + } + } + + // Wait until the analysis will be over + // TODO: Think if it is possible to track it via API + + Thread.sleep(3 * 1000) + + println("Demo session #$sessionIndex has finished") + } + } + } + + convertListToJsonAndWriteToFile(logs) + } +} + +fun mapToJson(map: Map): String { + val entrySet = map.entries.joinToString(", ") { + "\"${it.key}\": ${valueToJson(it.value)}" + } + return "{$entrySet}" +} + +fun listToJson(list: List): String = + list.joinToString(", ") { valueToJson(it) } + +@Suppress("UNCHECKED_CAST") +fun valueToJson(value: Any): String = + when (value) { + is String -> "\"$value\"" + is Map<*, *> -> mapToJson(value as Map) + is List<*> -> "[${listToJson(value as List)}]" + else -> value.toString() + } + +fun convertListToJsonAndWriteToFile(list: MutableList>>) { + // Prepare content to write to file + val logs = list.joinToString(",\n") { mapToJson(mapOf("eventId" to it.first, "data" to it.second)) } + val content = "let logs = [\n$logs\n];" + + // Write content to file + val filePath = Path.of(".") / "testResources" / "ml_logs.js" + BufferedWriter(FileWriter(filePath.toFile())).use { it.write(content) } +} 
diff --git a/platform/ml-impl/test/com/intellij/platform/ml/impl/TestApiPlatform.kt b/platform/ml-impl/test/com/intellij/platform/ml/impl/TestApiPlatform.kt new file mode 100644 index 000000000000..f7d1030b5e2e --- /dev/null +++ b/platform/ml-impl/test/com/intellij/platform/ml/impl/TestApiPlatform.kt @@ -0,0 +1,46 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl + +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform +import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.ExtensionController +import com.intellij.platform.ml.impl.logger.MLEvent +import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener +import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener + +abstract class TestApiPlatform : MLApiPlatform() { + private val dynamicTaskListeners: MutableList = mutableListOf() + private val dynamicStartupListeners: MutableList = mutableListOf() + private val dynamicEvents: MutableList = mutableListOf() + + abstract val initialTaskListeners: List + + abstract val initialStartupListeners: List + + abstract val initialEvents: List + + final override val events: List + get() = initialEvents + dynamicEvents + + final override val taskListeners: List + get() = initialTaskListeners + dynamicTaskListeners + + final override val startupListeners: List + get() = initialStartupListeners + dynamicStartupListeners + + override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController { + return extend(taskListener, dynamicTaskListeners) + } + + override fun addEvent(event: MLEvent): ExtensionController { + return extend(event, dynamicEvents) + } + + override fun addStartupListener(listener: MLApiStartupListener): ExtensionController { + return extend(listener, dynamicStartupListeners) + } + + private fun extend(obj: T, collection: MutableCollection): ExtensionController { + collection.add(obj) + return 
ExtensionController { collection.remove(obj) } + } +} diff --git a/platform/ml-impl/testResources/ml_logs.js b/platform/ml-impl/testResources/ml_logs.js new file mode 100644 index 000000000000..eec904b996c8 --- /dev/null +++ b/platform/ml-impl/testResources/ml_logs.js @@ -0,0 +1,6 @@ +let logs = [ +{"eventId": "startup", "data": {"success": true}}, +{"eventId": "mock.failed", "data": {"normal_failure": {"reason": "com.intellij.platform.ml.impl.approach.ModelNotAcquiredOutcome"}}}, +{"eventId": "mock.finished", "data": {"session": {"ml_model": {"random_seed": 1, "language_id": "TEXT", "version": "0.0"}}, "additional": {"TierGit": {"description": {"not_used": {}, "used": {"has_user": true, "n_commits": 3}}, "id": -401817549}}, "main": {"TierCompletionSession": {"description": {"not_used": {}, "used": {"language_id": "TEXT", "git_user_is_Glebanister": true, "completion_type": "SMART", "call_order": 1}}, "id": 1444967957, "analysis": {"very_good_session": true}}}, "nested": [{"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38162, "analysis": {"lookup_index": 1}}}, "nested": [{"additional": {}, "prediction": "0.1397272444266996", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -1220935314, "analysis": {}}}}, {"additional": {}, "prediction": "0.8772473577824259", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -782084434, "analysis": {}}}}]}, {"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38349, "analysis": {"lookup_index": 2}}}, "nested": [{"additional": {}, "prediction": "0.8739363143786573", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 8, "decorations": 1}}, "id": 1198136891, "analysis": {}}}}, {"additional": {}, "prediction": "0.9611620534809023", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 7, "decorations": 1}}, 
"id": 122089051, "analysis": {}}}}, {"additional": {}, "prediction": "null", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 17, "decorations": 1}}, "id": 1444769257, "analysis": {}}}}]}]}}, +{"eventId": "mock.finished", "data": {"session": {"ml_model": {"random_seed": 1, "language_id": "TEXT", "version": "0.0"}}, "additional": {"TierGit": {"description": {"not_used": {}, "used": {"has_user": true, "n_commits": 3}}, "id": -401817549}}, "main": {"TierCompletionSession": {"description": {"not_used": {}, "used": {"language_id": "TEXT", "git_user_is_Glebanister": true, "completion_type": "SMART", "call_order": 1}}, "id": 1444967957, "analysis": {"very_good_session": true}}}, "nested": [{"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38162, "analysis": {"lookup_index": 1}}}, "nested": [{"additional": {}, "prediction": "0.1397272444266996", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -1220935314, "analysis": {}}}}, {"additional": {}, "prediction": "0.8772473577824259", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -782084434, "analysis": {}}}}]}, {"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38349, "analysis": {"lookup_index": 2}}}, "nested": [{"additional": {}, "prediction": "0.8739363143786573", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 8, "decorations": 1}}, "id": 1198136891, "analysis": {}}}}, {"additional": {}, "prediction": "0.9611620534809023", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 7, "decorations": 1}}, "id": 122089051, "analysis": {}}}}, {"additional": {}, "prediction": "null", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 17, "decorations": 1}}, "id": 1444769257, "analysis": {}}}}]}]}} +]; \ No newline at end of file diff --git 
a/platform/ml-impl/testResources/mockModel/local_model.zip b/platform/ml-impl/testResources/mockModel/local_model.zip new file mode 100644 index 000000000000..0853b104be01 Binary files /dev/null and b/platform/ml-impl/testResources/mockModel/local_model.zip differ diff --git a/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml b/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml index 8de9309dd0eb..279a6dcdc70b 100644 --- a/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml +++ b/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml @@ -520,6 +520,16 @@ + + + +