[ml] ML-2246 Create a prototype of the common ML API

Fix recursion issue while initializing FUS

Fix project structure

Add ml api initialization logger

Add Api.Internal annotations

Rename a field for better understanding

Add some missing documentation

fixup! Add the API's code documentation
fixup! Add the API's code documentation
fixup! Add the API's code documentation
fixup! Add the API's code documentation
Add analysis of the not started sessions

Add API startup listener & move logger out of approach initialization

Implement API listeners

Add the API's code documentation

Refactor level storage format

Remove redundant changes

Update demo

Fix non logging nested sessions issue

Secure fus logger's initialization

Access the API only via MLApiPlatform

Add model's type & add model analysers

Add double quotes to the code-like categorical feature representation

Allow logging strings from a particular set

In logs, split non-declared features by usage

Acknowledge non declared features in logs

Add tier to logs' tier id

Add environment resolve logic

Update performance task declaration

Enable testing FUS logs

Split MLTask and MLTaskApproach

Add informative failure messages

Replace feature set requirements with feature selectors

Make feature a class

Extend demo a little bit

Remove unchecked cast suppressing

Add type-safe environment initialisation

Add type-safe feature declaration

Add example lookup analysis

Think better about beautiful names

Add item tier && attempt to beautify nested sessions creation

Update demo

Attempt to create another ml api

Merge-request: IJ-MR-125165
Merged-by: Gleb Marin <Gleb.Marin@jetbrains.com>

GitOrigin-RevId: 2aeed3030ab9a5f43e51783fd95c36cda99a18ab
This commit is contained in:
Gleb.Marin
2024-02-02 13:11:00 +00:00
committed by intellij-monorepo-bot
parent d3682b6383
commit ef4328d797
60 changed files with 5074 additions and 4 deletions

View File

@@ -3,6 +3,7 @@
<component name="NewModuleRootManager" inherit-compiler-output="true">
<exclude-output />
<content url="file://$MODULE_DIR$">
<sourceFolder url="file://$MODULE_DIR$/resources" type="java-resource" />
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
</content>
<orderEntry type="inheritedJdk" />

View File

@@ -0,0 +1,2 @@
<idea-plugin>
</idea-plugin>

View File

@@ -0,0 +1,69 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * Represents an environment that is being assembled to be described by [TierDescriptor]s,
 * to acquire a new ML Model, or for another reason.
 *
 * An environment maps tiers to their instances.
 */
@ApiStatus.Internal
interface Environment {
  /**
   * The set of tiers that the environment contains.
   */
  val tiers: Set<Tier<*>>

  /**
   * @return the instance that corresponds to the given [tier].
   * @throws IllegalArgumentException if the [tier] is not present
   */
  fun <T : Any> getInstance(tier: Tier<T>): T

  /**
   * The set of tier instances that are present in the environment.
   * The default implementation rebuilds the set from [tiers] on every access.
   */
  val tierInstances: Set<TierInstance<*>>
    get() = tiers.map { getTierInstance(it) }.toSet()

  /**
   * @return the instance of the [tier], wrapped into the [TierInstance] class.
   * @throws IllegalArgumentException if the tier is not present.
   */
  fun <T : Any> getTierInstance(tier: Tier<T>): TierInstance<T> = TierInstance(tier, getInstance(tier))

  /**
   * @return whether the [tier] is present in the environment
   */
  operator fun contains(tier: Tier<*>): Boolean = tier in tiers

  companion object {
    /**
     * @return An environment that contains all tiers of all given environments.
     * @throws IllegalArgumentException if there is a tier that is present in more than one
     * of the given environments.
     */
    fun joined(environments: Iterable<Environment>): Environment = TierInstanceStorage.joined(environments)

    /**
     * @return An environment without any tiers.
     */
    fun empty(): Environment = joined(emptySet())

    /**
     * Builds an environment that contains all the given tier instances.
     * @throws IllegalArgumentException if there is more than one instance of a particular tier.
     */
    fun of(entries: Iterable<TierInstance<*>>): Environment {
      val storage = TierInstanceStorage()

      // A local generic function captures the star-projected type of every entry,
      // so that the tier and its instance are stored type-safely.
      fun <T : Any> store(entry: TierInstance<T>) {
        storage[entry.tier] = entry.instance
      }

      for (entry in entries) {
        store(entry)
      }
      return storage
    }

    /**
     * Vararg overload of [of].
     */
    fun of(vararg entries: TierInstance<*>): Environment = of(entries.toList())
  }
}
/**
 * Indexing shortcut for [Environment.getInstance], which allows writing `environment[tier]`.
 */
operator fun <T : Any> Environment.get(tier: Tier<T>): T {
  return getInstance(tier)
}

View File

@@ -0,0 +1,53 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import com.intellij.openapi.extensions.ExtensionPointName
import org.jetbrains.annotations.ApiStatus
/**
 * Supplies an additional tier on top of the main ones, which results in an "extended" environment.
 *
 * When an [EnvironmentExtender] needs other tiers to build or access the [extendingTier],
 * it "orders" them by declaring the [requiredTiers].
 * The extender is called if and only if all of these requirements are satisfied.
 *
 * Additional tiers are later described by [TierDescriptor]s. The description is used for the model's
 * inference and then logged. However, additional tiers cannot be analyzed via
 * [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser], because they are not a part
 * of the ML Task, and they could be absent in the session.
 *
 * An "extended environment" (i.e., the one that contains the main as well as the additional tiers) is built
 * each time a [TierRequester] performs the desired action. For example, when [EnvironmentExtender.extend] or
 * [com.intellij.platform.ml.impl.model.MLModel.Provider.provideModel] is called.
 *
 * Every extender has requirements and also produces a tier, so the order in which the available
 * extenders run must be determined (resolved) first.
 * To learn more about that, address the documentation of
 * [com.intellij.platform.ml.impl.environment.ExtendedEnvironment].
 */
@ApiStatus.Internal
interface EnvironmentExtender<T : Any> : TierRequester {
  /**
   * The tier whose instances this extender provides.
   */
  val extendingTier: Tier<T>

  /**
   * Produces an instance of the [extendingTier] out of the given [environment].
   *
   * @param environment includes the tiers requested in [requiredTiers]
   * @return the created instance, or null
   */
  fun extend(environment: Environment): T?

  companion object {
    val EP_NAME: ExtensionPointName<EnvironmentExtender<*>> = ExtensionPointName("com.intellij.platform.ml.environmentExtender")

    /**
     * Calls [extend] and pairs a non-null result with the [extendingTier].
     *
     * @return the resulting [TierInstance], or null if [extend] returned null
     */
    fun <T : Any> EnvironmentExtender<T>.extendTierInstance(environment: Environment): TierInstance<T>? {
      val instance = extend(environment) ?: return null
      return extendingTier with instance
    }
  }
}

View File

@@ -0,0 +1,90 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * A feature of a tier that has been given a value. It is either a
 * description feature (provided by a [TierDescriptor]), or an
 * analysis feature (provided by a [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser]).
 *
 * If you need another type of feature that is not supported yet, contact the ML API developers.
 *
 * Note that the nested classes shadow the Kotlin built-in types of the same names,
 * which is why the built-ins are referenced with the `kotlin.` prefix inside this class.
 */
@ApiStatus.Internal
sealed class Feature {
  /**
   * A statically defined declaration of this feature, that includes all the information, except the value.
   */
  abstract val declaration: FeatureDeclaration<*>

  /**
   * The computed value, which could be null
   */
  abstract val value: Any?

  /**
   * Two features are equal when both their declarations and their values are equal.
   */
  override fun equals(other: Any?): kotlin.Boolean =
    other is Feature && other.declaration == this.declaration && other.value == this.value

  /**
   * Combines the hash codes of the [declaration] and of the [value] with XOR.
   */
  override fun hashCode(): kotlin.Int = declaration.hashCode() xor value.hashCode()

  override fun toString(): String = "Feature{declaration=$declaration, value=$value}"

  /**
   * A base for the features that know the static type of their value.
   * The [declaration] is assembled from the [name] and the [valueType] each time it is requested.
   */
  sealed class TypedFeature<T>(
    val name: String,
    override val value: T,
  ) : Feature() {
    abstract val valueType: FeatureValueType<T>

    override val declaration: FeatureDeclaration<*>
      get() = FeatureDeclaration(name, valueType)
  }

  class Boolean(name: String, value: kotlin.Boolean) : TypedFeature<kotlin.Boolean>(name, value) {
    override val valueType = FeatureValueType.Boolean
  }

  class Int(name: String, value: kotlin.Int) : TypedFeature<kotlin.Int>(name, value) {
    override val valueType = FeatureValueType.Int
  }

  class Long(name: String, value: kotlin.Long) : TypedFeature<kotlin.Long>(name, value) {
    override val valueType = FeatureValueType.Long
  }

  class Float(name: String, value: kotlin.Float) : TypedFeature<kotlin.Float>(name, value) {
    override val valueType = FeatureValueType.Float
  }

  class Double(name: String, value: kotlin.Double) : TypedFeature<kotlin.Double>(name, value) {
    override val valueType = FeatureValueType.Double
  }

  class Enum<T : kotlin.Enum<*>>(name: String, value: T) : TypedFeature<T>(name, value) {
    override val valueType = FeatureValueType.Enum(value.javaClass)
  }

  class Categorical(name: String, value: String, possibleValues: Set<String>)
    : TypedFeature<String>(name, value) {
    override val valueType = FeatureValueType.Categorical(possibleValues)
  }

  class Class(name: String, value: java.lang.Class<*>) : TypedFeature<java.lang.Class<*>>(name, value) {
    override val valueType = FeatureValueType.Class
  }

  class Version(name: String, value: com.intellij.openapi.util.Version)
    : TypedFeature<com.intellij.openapi.util.Version>(name, value) {
    override val valueType = FeatureValueType.Version
  }

  /**
   * A feature whose value is allowed to be null; [baseType] is the type of a non-null value.
   */
  class Nullable<T>(name: String, value: T?, val baseType: FeatureValueType<T>)
    : TypedFeature<T?>(name, value) {
    override val valueType = FeatureValueType.Nullable(baseType)
  }
}

View File

@@ -0,0 +1,62 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * Represents declaration of a Tier's feature.
 *
 * All tiers that end up in the ML Model's inference must be declared statically
 * (i.e. via Extension Points), as FUS logs' validator must be aware of every
 * feature that will be logged.
 *
 * @param name The feature's name that is unique for this tier. Only letters, digits and underscores are accepted.
 * @param type The feature's type.
 * @param T The type of the value, with which the [type] can be instantiated ([FeatureValueType.instantiate]).
 *
 * @throws IllegalArgumentException on construction, if the [name] contains any other symbol
 */
@ApiStatus.Internal
data class FeatureDeclaration<T>(
  val name: String,
  val type: FeatureValueType<T>
) {
  init {
    val consistsOfAllowedSymbols = name.all { symbol -> symbol == '_' || symbol.isLetterOrDigit() }
    require(consistsOfAllowedSymbols) {
      "Invalid feature name '$name': it shall not contain special symbols"
    }
  }

  /**
   * Shortcut for the feature's instantiation: delegates to [FeatureValueType.instantiate] of the [type].
   */
  infix fun with(value: T): Feature = type.instantiate(name, value)

  /**
   * Signifies that the feature can be instantiated with null values.
   *
   * @throws IllegalArgumentException if the declaration is already nullable
   */
  fun nullable(): FeatureDeclaration<T?> {
    val isAlreadyNullable = type is FeatureValueType.Nullable<*>
    require(!isAlreadyNullable) { "Repeated declaration as 'nullable'" }
    return FeatureDeclaration(name, FeatureValueType.Nullable(type))
  }

  /**
   * Factory functions, one per supported [FeatureValueType].
   */
  companion object {
    fun boolean(name: String) = FeatureDeclaration(name, FeatureValueType.Boolean)

    fun int(name: String) = FeatureDeclaration(name, FeatureValueType.Int)

    fun long(name: String) = FeatureDeclaration(name, FeatureValueType.Long)

    fun float(name: String) = FeatureDeclaration(name, FeatureValueType.Float)

    fun double(name: String) = FeatureDeclaration(name, FeatureValueType.Double)

    inline fun <reified T : Enum<*>> enum(name: String) = FeatureDeclaration(name, FeatureValueType.Enum(T::class.java))

    fun categorical(name: String, possibleValues: Set<String>) = FeatureDeclaration(name, FeatureValueType.Categorical(possibleValues))

    fun aClass(name: String) = FeatureDeclaration(name, FeatureValueType.Class)

    fun version(name: String) = FeatureDeclaration(name, FeatureValueType.Version)
  }
}

View File

@@ -0,0 +1,18 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * A predicate over feature declarations: it tells whether a feature belongs to a feature set.
 * A feature filter functions within a particular tier.
 */
@ApiStatus.Internal
fun interface FeatureFilter {
  /**
   * @return true if the feature with the given [featureDeclaration] matches the filter
   */
  fun accept(featureDeclaration: FeatureDeclaration<*>): Boolean

  companion object {
    /** A filter that does not accept any feature. */
    val REJECT_ALL = FeatureFilter { _ -> false }

    /** A filter that accepts every feature. */
    val ACCEPT_ALL = FeatureFilter { _ -> true }
  }
}

View File

@@ -0,0 +1,83 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * A type of [Feature].
 *
 * Each type knows how to create a [Feature] of the matching kind via [instantiate].
 *
 * If you need to use another type of feature to use in your ML model or in analysis,
 * consider contacting the ML API developers.
 *
 * Note that the nested types shadow the Kotlin built-in types of the same names,
 * which is why the built-ins are referenced with their fully qualified names inside this class.
 */
@ApiStatus.Internal
sealed class FeatureValueType<T> {
  /**
   * Creates a feature of this type.
   *
   * @param name The name of the created feature.
   * @param value The value of the created feature.
   */
  abstract fun instantiate(name: String, value: T): Feature

  /**
   * A type whose values may be null; [baseType] is the type of the non-null values.
   */
  data class Nullable<T>(val baseType: FeatureValueType<T>) : FeatureValueType<T?>() {
    override fun instantiate(name: String, value: T?): Feature {
      return Feature.Nullable(name, value, baseType)
    }
  }

  data class Enum<T : kotlin.Enum<*>>(val enumClass: java.lang.Class<T>) : FeatureValueType<T>() {
    override fun instantiate(name: String, value: T): Feature {
      return Feature.Enum(name, value)
    }
  }

  /**
   * A string type, whose values are restricted to the [possibleValues].
   */
  data class Categorical(val possibleValues: Set<String>) : FeatureValueType<String>() {
    /**
     * @throws IllegalArgumentException if the [value] is not one of the [possibleValues].
     */
    override fun instantiate(name: String, value: String): Feature {
      require(value in possibleValues) {
        // The hint is searched for by the rejected value (not by the feature's name),
        // so a value that differs only by case is suggested to the user.
        val caseNonMatchingValue = possibleValues.find { it.equals(value, ignoreCase = true) }
        "Feature $name cannot be assigned to value $value, " +
        "all possible values are $possibleValues. " +
        "Possible match (but case does not match): $caseNonMatchingValue"
      }
      return Feature.Categorical(name, value, possibleValues)
    }
  }

  object Int : FeatureValueType<kotlin.Int>() {
    override fun instantiate(name: String, value: kotlin.Int): Feature {
      return Feature.Int(name, value)
    }
  }

  object Double : FeatureValueType<kotlin.Double>() {
    override fun instantiate(name: String, value: kotlin.Double): Feature {
      return Feature.Double(name, value)
    }
  }

  object Float : FeatureValueType<kotlin.Float>() {
    override fun instantiate(name: String, value: kotlin.Float): Feature {
      return Feature.Float(name, value)
    }
  }

  object Long : FeatureValueType<kotlin.Long>() {
    override fun instantiate(name: String, value: kotlin.Long): Feature {
      return Feature.Long(name, value)
    }
  }

  object Class : FeatureValueType<java.lang.Class<*>>() {
    override fun instantiate(name: String, value: java.lang.Class<*>): Feature {
      return Feature.Class(name, value)
    }
  }

  object Boolean : FeatureValueType<kotlin.Boolean>() {
    override fun instantiate(name: String, value: kotlin.Boolean): Feature {
      return Feature.Boolean(name, value)
    }
  }

  object Version : FeatureValueType<com.intellij.openapi.util.Version>() {
    override fun instantiate(name: String, value: com.intellij.openapi.util.Version): Feature {
      return Feature.Version(name, value)
    }
  }

  /**
   * @return the simple name of the runtime class
   */
  override fun toString(): String = this.javaClass.simpleName
}

View File

@@ -0,0 +1,22 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
* A mutable version of the [Environment]: new tier instances can be added to it.
*/
@ApiStatus.Internal
interface MutableEnvironment : Environment {
/**
* Adds a new tier instance to the existing ones.
*
* @param tier The tier, whose instance is added.
* @param instance The instance to be stored for the [tier].
* @throws IllegalArgumentException if there is already an instance of such [tier] in the environment.
*/
fun <T : Any> putTierInstance(tier: Tier<T>, instance: T)
}
/**
 * Indexed-assignment shortcut for [MutableEnvironment.putTierInstance],
 * which allows writing `environment[tier] = instance`.
 */
operator fun <T : Any> MutableEnvironment.set(tier: Tier<T>, instance: T) {
  putTierInstance(tier, instance)
}
/**
 * Adds the tier and the instance held by the [tierInstance] to the environment.
 */
fun <T : Any> MutableEnvironment.putTierInstance(tierInstance: TierInstance<T>) {
  val (tier, instance) = tierInstance
  putTierInstance(tier, instance)
}

View File

@@ -0,0 +1,38 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * A [TierDescriptor] that is not fully aware of the features it describes with.
 * It exists to make a smooth transition from the old forms of feature providers to the new API.
 *
 * It will be removed eventually.
 */
@Deprecated(
  message = "ObsoleteTierDescriptors' features are not logged, as they are missing a feature declaration",
  replaceWith = ReplaceWith("TierDescriptor", "com.intellij.platform.ml"),
  level = DeprecationLevel.WARNING
)
@ApiStatus.Internal
interface ObsoleteTierDescriptor : TierDescriptor {
  /**
   * Always throws.
   *
   * The case when a [TierDescriptor] is an [ObsoleteTierDescriptor] is handled in the API
   * individually each time. And in those times, we cannot rely on the [descriptionDeclaration],
   * because it may not be correct.
   */
  override val descriptionDeclaration: Set<FeatureDeclaration<*>>
    get() {
      throw IllegalAccessError("Obsolete descriptor does not provide a description declaration")
    }

  /**
   * The part of the declaration that is already known, empty by default.
   * If there is a feature that is computed but not declared, then it can be logged,
   * so you can add it to the declaration.
   *
   * Turn on the ml.description.logMissing registry key to log the missing features.
   */
  val partialDescriptionDeclaration: Set<FeatureDeclaration<*>>
    get() = emptySet()

  /**
   * An obsolete descriptor always reports itself as potentially useful, whatever the filter is.
   */
  override fun couldBeUseful(usefulFeaturesFilter: FeatureFilter): Boolean = true
}

View File

@@ -0,0 +1,40 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * An environment, that restricts access to the [baseEnvironment] with the [scope].
 * It is created before passing an [Environment] to a [TierDescriptor], so it will
 * not be accessing non-declared tiers.
 *
 * Instances are created via the extension functions declared in the companion object.
 */
@ApiStatus.Internal
class ScopeEnvironment private constructor(
  private val baseEnvironment: Environment,
  private val scope: Set<Tier<*>>
) : Environment {
  override val tiers: Set<Tier<*>> = scope

  /**
   * @return the [tier]'s instance, taken from the base environment.
   * @throws IllegalArgumentException if the [tier] is outside of the allowed scope.
   */
  override fun <T : Any> getInstance(tier: Tier<T>): T {
    require(tier in scope) { "$tier was not supposed to be used, allowed scope: $scope" }
    return baseEnvironment.getInstance(tier)
  }

  companion object {
    /**
     * @return an environment that gives access only to those tiers of the [scope],
     * that are present in this environment. Tiers of the [scope] that are absent are ignored.
     */
    fun Environment.restrictedBy(scope: Set<Tier<*>>): ScopeEnvironment {
      return ScopeEnvironment(this, scope.intersect(this.tiers))
    }

    /**
     * @return an environment that gives access to exactly the tiers of the [scope].
     * @throws IllegalArgumentException if some tier of the [scope] is not present in this environment.
     */
    fun Environment.narrowedTo(scope: Set<Tier<*>>): ScopeEnvironment {
      val missingTiers = scope.filter { it !in this }
      require(missingTiers.isEmpty()) {
        "Cannot narrow the environment to $scope: tiers $missingTiers are not present, available tiers: ${this.tiers}"
      }
      return ScopeEnvironment(this, scope)
    }

    /**
     * @return an environment narrowed to the [requester]'s required tiers,
     * or null if at least one of the required tiers is not present in this environment.
     */
    fun Environment.accessibleSafelyByOrNull(requester: TierRequester): ScopeEnvironment? {
      if (requester.requiredTiers.any { it !in this }) {
        return null
      }
      return this.accessibleSafelyBy(requester)
    }

    /**
     * @return an environment narrowed to the [requester]'s required tiers.
     * @throws IllegalArgumentException if some of the required tiers is not present in this environment.
     */
    fun Environment.accessibleSafelyBy(requester: TierRequester): ScopeEnvironment = this.narrowedTo(requester.requiredTiers)
  }
}

View File

@@ -0,0 +1,100 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import com.intellij.platform.ml.Session.StartOutcome.Failure
import org.jetbrains.annotations.ApiStatus
/**
 * A period of making predictions with an ML model.
 *
 * The session's depth corresponds to the number of levels in a [com.intellij.platform.ml.impl.MLTask].
 *
 * This is a builder of the tree-like session structure.
 * To learn how sessions work, please address this interface's implementations.
 */
@ApiStatus.Internal
sealed interface Session<P : Any> {
  /**
   * The result of an attempt to start an ML session.
   *
   * Being unable to start an ML session is considered to be a non-exceptional situation.
   * In that case a [Failure] is returned, and the user could handle it.
   */
  sealed interface StartOutcome<P : Any> {
    /**
     * A ready-to-use ml session in case the start was successful, null otherwise
     */
    val session: Session<P>?

    /**
     * @return the started session, if the start was successful.
     * @throws Throwable the one built by [Failure.asThrowable], if the start has failed.
     */
    fun requireSuccess(): Session<P> {
      return when (this) {
        is Success -> this.session
        is Failure -> throw this.asThrowable()
      }
    }

    /**
     * Indicates that nothing went wrong, and the start was successful
     */
    class Success<P : Any>(override val session: Session<P>) : StartOutcome<P>

    /**
     * Indicates that there was some issue during the start.
     * The problem could be identified more precisely by looking at
     * the [Failure]'s class.
     */
    open class Failure<P : Any> : StartOutcome<P> {
      override val session: Session<P>? = null

      /**
       * A description of the failure, used as the message of the default [asThrowable].
       */
      open val failureDetails: String
        get() = "Unable to start ml session, failure: $this"

      /**
       * @return a throwable that represents this failure
       */
      open fun asThrowable(): Throwable = Exception(failureDetails)
    }

    /**
     * A failure caused by the [exception]; the latter becomes the cause of [asThrowable]'s result.
     */
    class UncaughtException<P : Any>(val exception: Throwable) : Failure<P>() {
      override fun asThrowable(): Throwable = Exception("An unexpected exception", exception)
    }
  }
}
/**
* A session that holds other (nested) sessions.
*/
interface NestableMLSession<P : Any> : Session<P> {
/**
* Starts another nested session within this one. The nested session will inherit
* this session's features.
*
* @param levelMainEnvironment The main session's environment that contains main ML task's tiers.
*
* @return Either [NestableMLSession] or [SinglePrediction], depending on whether the
* last ML task's level has been reached.
*/
fun createNestedSession(levelMainEnvironment: Environment): Session<P>
/**
* Declares that no more nested sessions will be created from this moment on.
* Calling this function is mandatory.
*/
fun onLastNestedSessionCreated()
}
/**
* A session that is dedicated to creating one prediction at most.
*/
interface SinglePrediction<P : Any> : Session<P> {
/**
* Calls the ML model's inference and produces the prediction.
* On one object, exactly one function must be called during its lifetime: this or [cancelPrediction]
*
* @return the prediction
*/
fun predict(): P
/**
* Declares that the model's inference will not be called.
* It must be called in case it's decided that a prediction is not needed.
*/
fun cancelPrediction()
}

View File

@@ -0,0 +1,76 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * A category of your application's objects that could be used to run the ML API.
 *
 * The main purpose of tiers is to be the sources for features production: they pass more information
 * to an ML model, so that it gives the most precise prediction.
 * Tiers are described by [TierDescriptor]s.
 * When it comes to description, there are two categories of tiers:
 *
 * - Main tiers:
 * the ones that are declared in an [com.intellij.platform.ml.impl.MLTask] and provided at the problem's application point,
 * when [com.intellij.platform.ml.NestableMLSession.createNestedSession] is called.
 * They exist in each session.
 *
 * - Additional tiers: the ones that are declared in [com.intellij.platform.ml.impl.approach.LogDrivenModelInference].
 * They are provided by various [EnvironmentExtender]s occasionally when an [com.intellij.platform.ml.impl.environment.ExtendedEnvironment]
 * is created, and they could be absent, as it was not possible to satisfy the extenders' requirements, or a tier instance
 * was simply not present this time.
 * So we can't rely on their existence.
 *
 * On top of that, tiers could serve as helper objects for other user-defined objects, such as [TierDescriptor]s
 * and [EnvironmentExtender]s.
 * The listed interfaces are [TierRequester]s, which means that they require other tiers for proper functioning.
 * You could create additional [EnvironmentExtender]s to create new tiers, or to define other ways to instantiate
 * existing tiers, but from other sources.
 */
@ApiStatus.Internal
abstract class Tier<T : Any> {
  /**
   * A unique name of a tier (among other tiers in your application).
   * The simple class name is used by default.
   */
  open val name: String
    get() = javaClass.simpleName

  /**
   * Pairs this tier with the given [instance].
   */
  infix fun with(instance: T): TierInstance<T> = TierInstance(this, instance)

  /**
   * @return the tier's [name]
   */
  override fun toString(): String {
    return name
  }
}
/**
* A helper class to type-safely handle pairs of [tier] and the corresponding [instance].
*
* @param tier The tier that the [instance] belongs to.
* @param instance The object that represents the [tier].
*/
data class TierInstance<T : Any>(val tier: Tier<T>, val instance: T)
// A mapping that holds one value of type T for each tier.
typealias PerTier<T> = Map<Tier<*>, T>
// A mapping that holds one value of type T for each tier instance.
typealias PerTierInstance<T> = Map<TierInstance<*>, T>
/**
 * Joins several per-tier mappings into one.
 *
 * @return a mapping that contains the entries of all the mappings of the receiver.
 * @throws IllegalArgumentException if a tier is present in more than one mapping.
 */
fun <T> Iterable<PerTier<T>>.joinByUniqueTier(): PerTier<T> {
  val joinedPerTier = mutableMapOf<Tier<*>, T>()
  for (perTierMapping in this) {
    for ((tier, value) in perTierMapping) {
      require(tier !in joinedPerTier) {
        "Tier $tier is present in more than one mapping, so the mappings cannot be joined by unique tier"
      }
      joinedPerTier[tier] = value
    }
  }
  return joinedPerTier
}
/**
 * Merges several per-tier mappings of collections into one, uniting the collections that belong to the same tier.
 *
 * @param createCollection Creates an empty collection, that will accumulate the values of one tier.
 * @return a mapping from each encountered tier to the collection of all values given for that tier.
 * @throws IllegalArgumentException if a value that is about to be added for a tier
 * is already contained in the collection that has been accumulated for that tier.
 */
fun <T, CI : Iterable<T>, CO : MutableCollection<T>> Iterable<PerTier<CI>>.mergePerTier(createCollection: () -> CO): PerTier<CO> {
  val joinedPerTier = mutableMapOf<Tier<*>, CO>()
  for (perTierMapping in this) {
    for ((tier, anotherCollection) in perTierMapping) {
      val existingCollection: Collection<T> = joinedPerTier[tier] ?: emptyList()
      // Collected as a list (and not looked up with firstOrNull), because T could be nullable
      val repeatedValues = anotherCollection.filter { it in existingCollection }
      require(repeatedValues.isEmpty()) {
        "Values $repeatedValues are already present for tier $tier, so the collections cannot be merged"
      }
      joinedPerTier.computeIfAbsent(tier) { createCollection() }.addAll(anotherCollection)
    }
  }
  return joinedPerTier
}

View File

@@ -0,0 +1,81 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import com.intellij.openapi.extensions.ExtensionPointName
import org.jetbrains.annotations.ApiStatus
/**
 * Provides features for a particular [tier].
 *
 * A descriptor is also a [TierRequester], which implies that [additionallyRequiredTiers] could be defined,
 * in case additional objects are required to describe the [tier]'s instance.
 * If so, an attempt will be made to create an [com.intellij.platform.ml.impl.environment.ExtendedEnvironment],
 * and the extended environment that contains the described [tier] and the [additionallyRequiredTiers]
 * will be passed to the [describe] function.
 *
 * @see ObsoleteTierDescriptor if you are looking for an opportunity to smoothly transfer your old feature providers to this API.
 */
@ApiStatus.Internal
interface TierDescriptor : TierRequester {
  /**
   * The tier that this descriptor is describing (giving features to).
   */
  val tier: Tier<*>

  /**
   * Tiers that are required additionally to perform the description. Empty by default.
   *
   * Such tiers' instances could be created via existing or additionally
   * created [EnvironmentExtender]s.
   */
  val additionallyRequiredTiers: Set<Tier<*>>
    get() = emptySet()

  /**
   * Declares a requirement and ensures that [describe] will be called if and only if
   * the environment contains both the [tier] and the [additionallyRequiredTiers].
   */
  override val requiredTiers: Set<Tier<*>>
    get() = additionallyRequiredTiers + tier

  /**
   * All features that could ever be used in the description.
   *
   * _Important_: Do not change features' names.
   * Names serve as features identifiers, and they are used to pass to ML models.
   */
  val descriptionDeclaration: Set<FeatureDeclaration<*>>

  /**
   * Computes the [tier]'s features, which are declared in the [descriptionDeclaration].
   *
   * If the [additionallyRequiredTiers] could not be satisfied, then this descriptor is not called at all.
   * Be aware that _each_ feature given in [descriptionDeclaration] must be calculated here.
   * If it is possible that a feature is not computable within particular circumstances, then
   * you could declare your feature as nullable: [FeatureDeclaration.nullable].
   *
   * @param environment Contains [tier] and [additionallyRequiredTiers].
   * @param usefulFeaturesFilter Accepts features, that could make any difference to compute this time.
   * A feature is considered to be useful if an ML model is aware of the feature,
   * or it is explicitly said that "ML model is not aware of this feature, but it must be logged"
   * (@see [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]).
   *
   * @throws IllegalArgumentException If a feature is missing: it is accepted by [usefulFeaturesFilter]
   * and it is declared, but not present in the result.
   * @throws IllegalArgumentException If a redundant feature was computed, that was not declared.
   */
  fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set<Feature>

  /**
   * Tells if the descriptor could generate any useful features at all.
   * By default, it is so when at least one declared feature is accepted by the filter.
   *
   * @param usefulFeaturesFilter Accepts features, that make sense to calculate at the time of this invocation.
   * If the filter does not accept a feature, it means that its computation will not make any difference.
   */
  fun couldBeUseful(usefulFeaturesFilter: FeatureFilter): Boolean =
    descriptionDeclaration.any { declaration -> usefulFeaturesFilter.accept(declaration) }

  companion object {
    val EP_NAME: ExtensionPointName<TierDescriptor> = ExtensionPointName.create("com.intellij.platform.ml.descriptor")
  }
}

View File

@@ -0,0 +1,65 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * An object that contains tier instances, and it could be extended.
 *
 * The instances are kept in a map from a tier to its instance. The map is untyped,
 * so the instances are cast back to the tier's type when they are retrieved.
 */
@ApiStatus.Internal
class TierInstanceStorage : MutableEnvironment {
  private val instances: MutableMap<Tier<*>, Any> = mutableMapOf()

  /**
   * The keys of the backing map.
   */
  override val tiers: Set<Tier<*>>
    get() = instances.keys

  override val tierInstances: Set<TierInstance<*>>
    get() = instances
      .map { (tier, tierInstance) -> tier withUnsafe tierInstance }
      .toSet()

  // The cast is unchecked: the only place that writes to the map is putTierInstance,
  // and its signature guarantees that an instance matches the type of its tier.
  private infix fun <T : Any, P : Any> Tier<T>.withUnsafe(value: P): TierInstance<T> {
    @Suppress("UNCHECKED_CAST")
    return this.with(value as T)
  }

  /**
   * @return the instance that is stored for the [tier].
   * @throws IllegalArgumentException if the [tier] is not present in the storage.
   */
  override fun <T : Any> getInstance(tier: Tier<T>): T {
    val tierInstance = requireNotNull(instances[tier]) {
      "Tier $tier is not present in the storage, available tiers: ${instances.keys}"
    }
    @Suppress("UNCHECKED_CAST")
    return tierInstance as T
  }

  /**
   * Stores the [instance] for the [tier].
   * @throws IllegalArgumentException if there is already an instance of the [tier] in the storage.
   */
  override fun <T : Any> putTierInstance(tier: Tier<T>, instance: T) {
    require(tier !in this) {
      "Tier $tier is already registered in the storage. Old value: '${instances[tier]}', new value: '$instance'"
    }
    instances[tier] = instance
  }

  companion object {
    /**
     * @return a new storage that contains all tier instances of the [environment].
     */
    fun copyOf(environment: Environment): TierInstanceStorage {
      val storage = TierInstanceStorage()

      // A local generic function captures the star-projected type of every tier
      fun <T : Any> putInstance(tier: Tier<T>) {
        storage[tier] = environment[tier]
      }
      environment.tiers.forEach { putInstance(it) }
      return storage
    }

    /**
     * @return a new storage that contains the tier instances of all the [environments].
     * @throws IllegalArgumentException if a tier is present in more than one environment.
     */
    fun joined(environments: Iterable<Environment>): Environment {
      val commonStorage = TierInstanceStorage()

      // A local generic function captures the star-projected type of every tier
      fun <T : Any> putCapturingType(tier: Tier<T>, environment: Environment) {
        commonStorage[tier] = environment[tier]
      }
      environments.forEach { environment ->
        environment.tiers.forEach { tier ->
          putCapturingType(tier, environment)
        }
      }
      return commonStorage
    }
  }
}

View File

@@ -0,0 +1,26 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml
import org.jetbrains.annotations.ApiStatus
/**
 * Represents an object that needs some additional tiers to function properly.
 *
 * If the requirements could not be satisfied, then the main object's functionality will not be run.
 * Please address the interface's inheritors, such as [TierDescriptor], [EnvironmentExtender],
 * and [com.intellij.platform.ml.impl.model.MLModel.Provider] for more details.
 */
@ApiStatus.Internal
interface TierRequester {
  /**
   * The tiers that are required to use the object.
   */
  val requiredTiers: Set<Tier<*>>

  companion object {
    /**
     * @return those requesters, all of whose required tiers are present in the [environment]
     */
    fun <T : TierRequester> Iterable<T>.fulfilledBy(environment: Environment): List<T> {
      return filter { requester ->
        requester.requiredTiers.all { it in environment }
      }
    }
  }
}

View File

@@ -28,6 +28,7 @@
<sourceFolder url="file://$MODULE_DIR$/src" isTestSource="false" />
<sourceFolder url="file://$MODULE_DIR$/test" isTestSource="true" />
<sourceFolder url="file://$MODULE_DIR$/resources" type="java-resource" />
<sourceFolder url="file://$MODULE_DIR$/testResources" type="java-test-resource" />
</content>
<orderEntry type="inheritedJdk" />
<orderEntry type="sourceFolder" forTests="false" />
@@ -44,5 +45,7 @@
<orderEntry type="module" module-name="intellij.platform.util.text.matching" />
<orderEntry type="module" module-name="intellij.platform.lang.impl" />
<orderEntry type="library" name="kotlinx-serialization-core" level="project" />
<orderEntry type="module" module-name="intellij.platform.ml" />
<orderEntry type="module" module-name="intellij.platform.statistics" />
</component>
</module>

View File

@@ -6,6 +6,9 @@
<projectService serviceInterface="com.intellij.internal.statistic.local.LanguageUsageStatisticsProvider"
serviceImplementation="com.intellij.internal.statistic.local.LanguageUsageStatisticsProviderImpl"/>
<statistics.counterUsagesCollector implementationClass="com.intellij.platform.ml.impl.logger.MLEventsLogger"/>
<registryKey defaultValue="false" description="Log features that were computed but missing in an ObsoleteTierDescriptor's declaration"
key="ml.description.logMissing"/>
</extensions>
<projectListeners>
@@ -17,9 +20,14 @@
<extensionPoints>
<extensionPoint qualifiedName="com.intellij.platform.ml.impl.turboComplete.smartPipelineRunner"
interface="com.intellij.platform.ml.impl.turboComplete.SmartPipelineRunner" dynamic="true"/>
interface="com.intellij.platform.ml.impl.turboComplete.SmartPipelineRunner"
dynamic="true"/>
<extensionPoint qualifiedName="com.intellij.platform.ml.impl.approach"
interface="com.intellij.platform.ml.impl.MLTaskApproachInitializer"
dynamic="true"/>
<extensionPoint name="mlCompletionCorrectnessSupporter" beanClass="com.intellij.platform.ml.impl.correctness.MLCompletionCorrectnessSupporterEP" dynamic="true">
<extensionPoint name="mlCompletionCorrectnessSupporter"
beanClass="com.intellij.platform.ml.impl.correctness.MLCompletionCorrectnessSupporterEP" dynamic="true">
<with attribute="implementationClass" implements="com.intellij.platform.ml.impl.correctness.MLCompletionCorrectnessSupporter"/>
</extensionPoint>
</extensionPoints>

View File

@@ -0,0 +1,46 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl
import com.intellij.platform.ml.*
import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyBy
import org.jetbrains.annotations.ApiStatus
/**
* Computes descriptions of tiers, either by calling [TierDescriptor.describe]
* or by reusing cached descriptions.
*/
@ApiStatus.Internal
interface DescriptionComputer {
/**
* Computes the features that each of the [descriptors] yields for [tier].
*
* @param tier The tier that is being described.
* @param descriptors All relevant descriptors, that can be called within [environment]
* (it is guaranteed that they all describe [tier] and their requirements are fulfilled by [environment]).
* @param environment An environment, that contains all tiers required to run any of the [descriptors].
* @param usefulFeaturesFilter Accepts features, that are meaningful to compute. If the filter does not
* accept a feature, its computation will not make any difference later.
* @return The computed features, grouped by the descriptor that produced them.
*/
fun computeDescription(
tier: Tier<*>,
descriptors: List<TierDescriptor>,
environment: Environment,
usefulFeaturesFilter: FeatureFilter,
): Map<TierDescriptor, Set<Feature>>
}
/**
 * A [DescriptionComputer] without any state: nothing is cached,
 * every call runs all the given descriptors all over again.
 */
@ApiStatus.Internal
class StateFreeDescriptionComputer : DescriptionComputer {
  /**
   * Calls [TierDescriptor.describe] on each of [descriptors], giving it access only to the part of
   * [environment] that is safely accessible by that descriptor.
   *
   * @throws IllegalArgumentException if a descriptor describes a tier other than [tier],
   * or if it requires a tier that is absent from [environment].
   */
  override fun computeDescription(
    tier: Tier<*>,
    descriptors: List<TierDescriptor>,
    environment: Environment,
    usefulFeaturesFilter: FeatureFilter,
  ): Map<TierDescriptor, Set<Feature>> {
    return descriptors.associateWith { descriptor ->
      require(descriptor.tier == tier) {
        "Descriptor $descriptor describes ${descriptor.tier}, but it was asked to describe $tier"
      }
      require(descriptor.requiredTiers.all { it in environment.tiers }) {
        // The missing tiers are computed lazily, only when the requirement has actually failed.
        val missingTiers = descriptor.requiredTiers.filterNot { it in environment.tiers }
        "Descriptor $descriptor requires tiers that are absent from the environment: $missingTiers"
      }
      descriptor.describe(environment.accessibleSafelyBy(descriptor), usefulFeaturesFilter)
    }
  }
}

View File

@@ -0,0 +1,91 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.PerTier
import org.jetbrains.annotations.ApiStatus
/**
 * Determines a set of features, and can tell if a set of features is "complete" or not.
 * The meaning of "completeness" depends on the selector's usage.
 *
 * Selectors are used to determine if there are enough features to run an ML model,
 * and to avoid computing redundant features: the ones that will not be used by the ML model
 * and are not desired to be logged.
 *
 * Two types of instances have the power to select features:
 *  - [com.intellij.platform.ml.impl.model.MLModel] to tell which features are known.
 *  - [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] to tell which features are not known by the ML model,
 *  but still must be computed and then logged.
 */
@ApiStatus.Internal
interface FeatureSelector {
  /**
   * @param availableFeatures A set of features that the selection will be built from.
   *
   * @return A set of features that are 'selected' and a completeness marker.
   * [Selection.Incomplete] is returned if there are some more features that are 'selected',
   * but they are not present among [availableFeatures].
   * [Selection.Complete] is returned otherwise.
   */
  fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection

  /**
   * @param featureDeclaration A single feature, that needs to be selected.
   *
   * @return Whether the feature belongs to the determined set of features.
   */
  fun select(featureDeclaration: FeatureDeclaration<*>): Boolean {
    return select(setOf(featureDeclaration)).selectedFeatures.isNotEmpty()
  }

  /**
   * The result of a selection: the features that were picked, plus the completeness marker
   * expressed by the concrete subclass.
   */
  sealed class Selection(val selectedFeatures: Set<FeatureDeclaration<*>>) {
    class Complete(selectedFeatures: Set<FeatureDeclaration<*>>) : Selection(selectedFeatures)

    open class Incomplete(selectedFeatures: Set<FeatureDeclaration<*>>) : Selection(selectedFeatures) {
      /** A human-readable explanation of why the selection is incomplete. */
      open val details: String = "Incomplete selection, only these were selected: $selectedFeatures"
    }

    companion object {
      val NOTHING = Complete(emptySet())
    }
  }

  companion object {
    /** Selects no features, and the selection is always complete. */
    val NOTHING = object : FeatureSelector {
      override fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection = Selection.NOTHING
    }

    /** Selects every available feature, and the selection is always complete. */
    val EVERYTHING = object : FeatureSelector {
      override fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection = Selection.Complete(availableFeatures)
    }

    /**
     * Joins two selectors: the resulting selection is the union of both selections.
     * It is incomplete if at least one of the two selections is incomplete; in that case
     * the details of every incomplete selection are reported, so none of them is lost.
     */
    infix fun FeatureSelector.or(other: FeatureSelector): FeatureSelector {
      return object : FeatureSelector {
        override fun select(availableFeatures: Set<FeatureDeclaration<*>>): Selection {
          val thisSelection = this@or.select(availableFeatures)
          val otherSelection = other.select(availableFeatures)
          // Set + Set already yields a Set, no additional copy is needed.
          val joinedSelection = thisSelection.selectedFeatures + otherSelection.selectedFeatures
          val incompleteSelections = listOf(thisSelection, otherSelection).filterIsInstance<Selection.Incomplete>()
          if (incompleteSelections.isEmpty()) return Selection.Complete(joinedSelection)
          return object : Selection.Incomplete(joinedSelection) {
            override val details: String
              get() = incompleteSelections.joinToString("; ") { it.details }
          }
        }
      }
    }

    /**
     * Joins the selectors tier-wise.
     *
     * @throws IllegalArgumentException if the two mappings are not declared for the same set of tiers.
     */
    infix fun PerTier<FeatureSelector>.or(other: PerTier<FeatureSelector>): PerTier<FeatureSelector> {
      require(this.keys == other.keys) {
        "Selectors must be declared for the same tiers, but got ${this.keys} and ${other.keys}"
      }
      return keys.associateWith { this.getValue(it) or other.getValue(it) }
    }
  }
}

View File

@@ -0,0 +1,69 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl
import com.intellij.platform.ml.*
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.DescribedLevel
import com.intellij.platform.ml.impl.session.NestableStructureCollector
import com.intellij.platform.ml.impl.session.PredictionCollector
import com.intellij.platform.ml.impl.session.SessionTree
import org.jetbrains.annotations.ApiStatus
/**
 * A [SinglePrediction] that is made by an ML model.
 * The outcome (a prediction, or `null` on cancellation) is handed to [collector],
 * which collects the session's structure.
 */
@ApiStatus.Internal
class MLModelPrediction<T : SessionTree.PredictionContainer<M, DescribedLevel, P>, M : MLModel<P>, P : Any>(
  private val mlModel: M,
  private val collector: PredictionCollector<T, M, P>,
) : SinglePrediction<P> {
  /** Runs the model on the collector's usable description and submits the result before returning it. */
  override fun predict(): P =
    mlModel.predict(collector.usableDescription).also { made -> collector.submitPrediction(made) }

  /** Reports to the collector that no prediction will be made. */
  override fun cancelPrediction() {
    collector.submitPrediction(null)
  }
}
/**
 * A [NestableMLSession] that utilizes [mlModel].
 * The session's structure is collected by [collector] after [onLastNestedSessionCreated] was called,
 * and the structures of all nested sessions are collected.
 */
@ApiStatus.Internal
class MLModelPredictionBranching<T : SessionTree.ChildrenContainer<M, DescribedLevel, P>, M : MLModel<P>, P : Any>(
  private val mlModel: M,
  private val collector: NestableStructureCollector<T, M, P>
) : NestableMLSession<P> {
  /**
   * Creates a session on the next level: a single prediction when only one lower level is left,
   * another branching session otherwise.
   */
  override fun createNestedSession(levelMainEnvironment: Environment): Session<P> {
    val lowerTiers = collector.levelPositioning.lowerTiers
    val nestedLevelScheme = lowerTiers.first()
    verifyTiersInMain(nestedLevelScheme.main, levelMainEnvironment.tiers)
    val levelAdditionalTiers = nestedLevelScheme.additional

    if (lowerTiers.size == 1) {
      val predictionCollector = collector.nestPrediction(levelMainEnvironment, levelAdditionalTiers)
      return MLModelPrediction(mlModel, predictionCollector)
    }

    assert(lowerTiers.size > 1)
    val branchCollector = collector.nestBranch(levelMainEnvironment, levelAdditionalTiers)
    return MLModelPredictionBranching(mlModel, branchCollector)
  }

  override fun onLastNestedSessionCreated() {
    collector.onLastNestedCollectorCreated()
  }
}
/**
 * Ensures that the tiers given in a level's main environment are exactly the declared ones.
 *
 * @throws IllegalArgumentException if [expected] and [actual] differ.
 */
private fun verifyTiersInMain(expected: Set<Tier<*>>, actual: Set<*>) {
  if (expected == actual) return
  throw IllegalArgumentException(
    "Tier set in the main environment is not like it was declared. Declared $expected, but given $actual"
  )
}

View File

@@ -0,0 +1,129 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl
import com.intellij.openapi.extensions.ExtensionPointName
import com.intellij.platform.ml.*
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform
import com.intellij.platform.ml.impl.session.AdditionalTierScheme
import com.intellij.platform.ml.impl.session.Level
import com.intellij.platform.ml.impl.session.MainTierScheme
import org.jetbrains.annotations.ApiStatus
/**
* A declaration of a place in code, where a classical machine learning approach
* is desired to be applied.
*
* The proper way to create a new task is to create a static object-inheritor of this class
* (see inheritors of this class to find all implemented tasks in your project).
* The constructor is protected, so a task can only be created through inheritance.
*
* @param name The unique name of an ML Task
* @param levels The main tiers of the task that will be provided within the task's application place,
* one set of tiers per level
* @param predictionClass The class of an object, that will serve as "prediction"
* @param T The type of prediction
*/
@ApiStatus.Internal
abstract class MLTask<T : Any> protected constructor(
val name: String,
val levels: List<Set<Tier<*>>>,
val predictionClass: Class<T>
)
/**
 * A way of approaching an ML task.
 * Usually, it means inferencing an ML model and collecting logs.
 *
 * Each [MLTaskApproach] is initialized once by the corresponding [MLTaskApproachInitializer],
 * and from then on the [apiPlatform] is fixed.
 *
 * @see [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] for currently used approach.
 */
@ApiStatus.Internal
interface MLTaskApproach<P : Any> {
  /**
   * The task that is solved by this approach.
   * An approach is dedicated to exactly one task, and it is aware of it.
   */
  val task: MLTask<P>

  /**
   * The platform this approach is called within; it was provided by [MLTaskApproachInitializer].
   */
  val apiPlatform: MLApiPlatform

  /**
   * A static declaration of the features that are used in the approach.
   */
  val approachDeclaration: Declaration

  /**
   * Acquires the ML model and starts the session.
   *
   * @return [Session.StartOutcome.Failure] if something went wrong during the start,
   * otherwise [Session.StartOutcome.Success] that contains the started session.
   */
  fun startSession(permanentSessionEnvironment: Environment): Session.StartOutcome<P>

  data class Declaration(
    val sessionFeatures: Map<String, Set<FeatureDeclaration<*>>>,
    val levelsScheme: List<LevelScheme>
  )

  companion object {
    /** Looks up the approach for [task] within [apiPlatform]. */
    fun <P : Any> findMlApproach(task: MLTask<P>, apiPlatform: MLApiPlatform = ReplaceableIJPlatform): MLTaskApproach<P> =
      apiPlatform.accessApproachFor(task)

    /** Finds the approach for [task] within [apiPlatform] and starts a session with it. */
    fun <P : Any> startMLSession(
      task: MLTask<P>,
      permanentSessionEnvironment: Environment,
      apiPlatform: MLApiPlatform = ReplaceableIJPlatform
    ): Session.StartOutcome<P> =
      findMlApproach(task, apiPlatform).startSession(permanentSessionEnvironment)

    /** Starts a session of this task within the default platform. */
    fun <P : Any> MLTask<P>.startMLSession(permanentSessionEnvironment: Environment): Session.StartOutcome<P> =
      startMLSession(this, permanentSessionEnvironment)

    /** Starts a session of this task, building the permanent environment from [permanentTierInstances]. */
    fun <P : Any> MLTask<P>.startMLSession(permanentTierInstances: Iterable<TierInstance<*>>): Session.StartOutcome<P> =
      this.startMLSession(Environment.of(permanentTierInstances))

    /** Starts a session of this task, building the permanent environment from [permanentTierInstances]. */
    fun <P : Any> MLTask<P>.startMLSession(vararg permanentTierInstances: TierInstance<*>): Session.StartOutcome<P> =
      this.startMLSession(Environment.of(*permanentTierInstances))
  }
}
/**
* Initializes an [MLTaskApproach].
* Initializers are registered via the extension point named by [EP_NAME].
*/
@ApiStatus.Internal
interface MLTaskApproachInitializer<P : Any> {
/**
* The task, that the created [MLTaskApproach] is dedicated to solve.
*/
val task: MLTask<P>
/**
* Initializes the approach.
* It is called only once during the application's runtime.
* So it is crucial that this function receives the [MLApiPlatform] you want it to.
*
* To access the API in order to build the event validator statically,
* FUS uses the actual [com.intellij.platform.ml.impl.apiPlatform.IJPlatform], which could be problematic if you
* want to test FUS logs.
* So make sure that you replace it with your test platform in
* time via [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.replacingWith].
*/
fun initializeApproachWithin(apiPlatform: MLApiPlatform): MLTaskApproach<P>
companion object {
// The extension point through which the approach initializers are registered.
val EP_NAME = ExtensionPointName<MLTaskApproachInitializer<*>>("com.intellij.platform.ml.impl.approach")
}
}
// A Level whose two parts are per-tier schemes: MainTierScheme and AdditionalTierScheme.
// NOTE(review): the roles of Level's type arguments (main / additional) are presumed from the names - confirm against Level.
typealias LevelScheme = Level<PerTier<MainTierScheme>, PerTier<AdditionalTierScheme>>
// A Level whose two parts are plain sets of tiers.
// NOTE(review): the roles of Level's type arguments (main / additional) are presumed from the names - confirm against Level.
typealias LevelTiers = Level<Set<Tier<*>>, Set<Tier<*>>>

View File

@@ -0,0 +1,40 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.apiPlatform
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.FeatureValueType
import org.jetbrains.annotations.ApiStatus
/**
 * Renders feature declarations as code-like strings, for the cases when they are logged
 * for an API user's convenience (so the logged features can be copy-pasted into code).
 */
@ApiStatus.Internal
class CodeLikePrinter {
  private val <T : Enum<*>> FeatureValueType.Enum<T>.codeLikeType: String
    get() = this.enumClass.name

  /** Builds `FeatureDeclaration.<factory>("<name>")`. */
  private fun factoryCall(factory: String, name: String): String = "FeatureDeclaration.$factory(\"$name\")"

  private fun <T> FeatureValueType<T>.makeCodeLikeString(name: String): String = when (this) {
    FeatureValueType.Boolean -> factoryCall("boolean", name)
    FeatureValueType.Class -> factoryCall("aClass", name)
    FeatureValueType.Double -> factoryCall("double", name)
    is FeatureValueType.Enum<*> -> factoryCall("enum<${this.codeLikeType}>", name)
    FeatureValueType.Float -> factoryCall("float", name)
    FeatureValueType.Int -> factoryCall("int", name)
    FeatureValueType.Long -> factoryCall("long", name)
    // A nullable type is printed as its base type's declaration with a suffix.
    is FeatureValueType.Nullable<*> -> this.baseType.makeCodeLikeString(name) + ".nullable()"
    is FeatureValueType.Categorical -> {
      val quotedValues = possibleValues.joinToString(", ") { value -> "\"$value\"" }
      "FeatureDeclaration.categorical(\"$name\", setOf($quotedValues))"
    }
    FeatureValueType.Version -> factoryCall("version", name)
  }

  /** Prints a single declaration. */
  fun <T> printCodeLikeString(featureDeclaration: FeatureDeclaration<T>): String =
    featureDeclaration.type.makeCodeLikeString(featureDeclaration.name)

  /** Prints all the declarations, separated with ", ". */
  fun printCodeLikeString(featureDeclarations: Collection<FeatureDeclaration<*>>): String =
    featureDeclarations.joinToString(", ") { declaration -> printCodeLikeString(declaration) }
}

View File

@@ -0,0 +1,186 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.apiPlatform
import com.intellij.openapi.diagnostic.thisLogger
import com.intellij.openapi.util.registry.Registry
import com.intellij.platform.ml.EnvironmentExtender
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.ObsoleteTierDescriptor
import com.intellij.platform.ml.TierDescriptor
import com.intellij.platform.ml.impl.MLTaskApproachInitializer
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.ExtensionController
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.replacingWith
import com.intellij.platform.ml.impl.logger.MLEvent
import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener
import com.intellij.util.application
import com.intellij.util.messages.Topic
import org.jetbrains.annotations.ApiStatus
import org.jetbrains.annotations.NonNls
/**
 * A message bus listener that, when asked, passes the objects it holds to the given collector.
 * It allows gathering objects of type [T] from all subscribers of a topic.
 */
@ApiStatus.Internal
fun interface MessagingProvider<T> {
  fun provide(collector: (T) -> Unit)

  companion object {
    /** Creates a topic, whose listeners are providers of type [P]. */
    inline fun <T, reified P : MessagingProvider<T>> createTopic(displayName: @NonNls String): Topic<P> =
      Topic.create(displayName, P::class.java)

    /** Synchronously asks every subscriber of [topic] to provide its objects, and returns all of them. */
    fun <T, P : MessagingProvider<T>> collect(topic: Topic<P>): List<T> {
      val publisher = application.messageBus.syncPublisher(topic)
      return mutableListOf<T>().also { gathered ->
        publisher.provide { gathered.add(it) }
      }
    }
  }
}
/**
* A [MessagingProvider] of [MLTaskGroupListener]s.
* Subscribers of [TOPIC] pass their listeners to the collector given to [provide].
*/
fun interface MLTaskListenerProvider : MessagingProvider<MLTaskGroupListener> {
companion object {
// The topic is created with the display name "ml.task".
val TOPIC = MessagingProvider.createTopic<MLTaskGroupListener, MLTaskListenerProvider>("ml.task")
}
}
/**
* A [MessagingProvider] of [MLEvent]s.
* Subscribers of [TOPIC] pass their events to the collector given to [provide].
*/
fun interface MLEventProvider : MessagingProvider<MLEvent> {
companion object {
// The topic is created with the display name "ml.event".
val TOPIC = MessagingProvider.createTopic<MLEvent, MLEventProvider>("ml.event")
}
}
/**
* A [MessagingProvider] of [MLApiStartupListener]s.
* Subscribers of [TOPIC] pass their listeners to the collector given to [provide].
*/
fun interface MLApiStartupListenerProvider : MessagingProvider<MLApiStartupListener> {
companion object {
// The topic is created with the display name "ml.startup".
val TOPIC = MessagingProvider.createTopic<MLApiStartupListener, MLApiStartupListenerProvider>("ml.startup")
}
}
/**
 * The "real-life" [MLApiPlatform]: its content is the content of the corresponding
 * extension points and message bus topics.
 * It is used at the API's entry point, unless it has been replaced by another platform.
 *
 * It shouldn't be used directly due to low testability.
 * Use [ReplaceableIJPlatform] instead.
 */
@ApiStatus.Internal
private data object IJPlatform : MLApiPlatform() {
  override val tierDescriptors: List<TierDescriptor>
    get() = TierDescriptor.EP_NAME.extensionList

  override val environmentExtenders: List<EnvironmentExtender<*>>
    get() = EnvironmentExtender.EP_NAME.extensionList

  override val taskApproaches: List<MLTaskApproachInitializer<*>>
    get() = MLTaskApproachInitializer.EP_NAME.extensionList

  override val taskListeners: List<MLTaskGroupListener>
    get() = MessagingProvider.collect(MLTaskListenerProvider.TOPIC)

  override val events: List<MLEvent>
    get() = MessagingProvider.collect(MLEventProvider.TOPIC)

  override val startupListeners: List<MLApiStartupListener>
    get() = MessagingProvider.collect(MLApiStartupListenerProvider.TOPIC)

  override fun addStartupListener(listener: MLApiStartupListener): ExtensionController =
    subscribe(MLApiStartupListenerProvider.TOPIC, MLApiStartupListenerProvider { collector -> collector(listener) })

  override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController =
    subscribe(MLTaskListenerProvider.TOPIC, MLTaskListenerProvider { collector -> collector(taskListener) })

  override fun addEvent(event: MLEvent): ExtensionController =
    subscribe(MLEventProvider.TOPIC, MLEventProvider { collector -> collector(event) })

  /** Logs the features that were computed but not declared, if the registry key allows it. */
  override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>) {
    if (Registry.`is`("ml.description.logMissing")) {
      val missingDeclarations = nonDeclaredFeatures.map { it.declaration }
      val codeLikeMissingDeclaration = CodeLikePrinter().printCodeLikeString(missingDeclarations)
      thisLogger().info("${descriptor::class.java} is missing declaration: setOf($codeLikeMissingDeclaration)")
    }
  }

  /**
   * Opens a new message bus connection, subscribes [handler] to [topic] through it,
   * and returns a controller that disconnects that connection.
   */
  private fun <L : Any> subscribe(topic: Topic<L>, handler: L): ExtensionController {
    val connection = application.messageBus.connect()
    connection.subscribe(topic, handler)
    return ExtensionController { connection.disconnect() }
  }
}
/**
 * Also a "real-life" [MLApiPlatform], but it can be replaced with another one any time.
 *
 * We always want to test [com.intellij.platform.ml.impl.MLTaskApproach]es.
 * But after they are initialized by [com.intellij.platform.ml.impl.MLTaskApproachInitializer],
 * the passed [MLApiPlatform] could already spread all the way within the API.
 * But the user-defined instances of the api could be overridden for testing sake.
 *
 * To replace all [TierDescriptor], [EnvironmentExtender] and [MLTaskApproachInitializer]
 * to test your code, you may call [replacingWith] and pass the desired environment,
 * that contains all the objects you need for your test.
 */
@ApiStatus.Internal
object ReplaceableIJPlatform : MLApiPlatform() {
  private var replacement: MLApiPlatform? = null

  // Every call is delegated to the replacement if there is one, and to IJPlatform otherwise.
  private val platform: MLApiPlatform
    get() = replacement ?: IJPlatform

  override val tierDescriptors: List<TierDescriptor>
    get() = platform.tierDescriptors

  override val environmentExtenders: List<EnvironmentExtender<*>>
    get() = platform.environmentExtenders

  override val taskApproaches: List<MLTaskApproachInitializer<*>>
    get() = platform.taskApproaches

  override val taskListeners: List<MLTaskGroupListener>
    get() = platform.taskListeners

  override val events: List<MLEvent>
    get() = platform.events

  override val startupListeners: List<MLApiStartupListener>
    get() = platform.startupListeners

  override fun addStartupListener(listener: MLApiStartupListener): ExtensionController {
    return extend(listener) { platform -> platform.addStartupListener(listener) }
  }

  override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController {
    return extend(taskListener) { platform -> platform.addTaskListener(taskListener) }
  }

  override fun addEvent(event: MLEvent): ExtensionController {
    return extend(event) { platform -> platform.addEvent(event) }
  }

  override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>) =
    platform.manageNonDeclaredFeatures(descriptor, nonDeclaredFeatures)

  /**
   * Adds [obj] to the currently active platform via [method].
   *
   * @return A controller that removes [obj] from the platform it was added to.
   * The removal fails with [IllegalArgumentException] if the active platform has changed in the meantime.
   */
  private fun <T> extend(obj: T, method: (MLApiPlatform) -> ExtensionController): ExtensionController {
    val initialPlatform = platform
    // The underlying controller must be kept: it is the only way to actually remove the extension.
    val underlyingController = method(initialPlatform)
    return ExtensionController {
      require(initialPlatform == platform) {
        "$obj should be removed within the same platform it was added in. " +
        "It was added in $initialPlatform, but removed from $platform"
      }
      underlyingController.removeExtension()
    }
  }

  /**
   * Runs [action] while [apiPlatform] is used instead of the current platform.
   * The previous replacement (if any) is restored afterwards, even if [action] throws.
   */
  fun <T> replacingWith(apiPlatform: MLApiPlatform, action: () -> T): T {
    val oldApiPlatform = replacement
    return try {
      replacement = apiPlatform
      action()
    }
    finally {
      replacement = oldApiPlatform
    }
  }
}

View File

@@ -0,0 +1,264 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.apiPlatform
import com.intellij.platform.ml.*
import com.intellij.platform.ml.impl.MLTask
import com.intellij.platform.ml.impl.MLTaskApproach
import com.intellij.platform.ml.impl.MLTaskApproachInitializer
import com.intellij.platform.ml.impl.logger.MLEvent
import com.intellij.platform.ml.impl.logger.MLEventsLogger
import com.intellij.platform.ml.impl.monitoring.*
import com.intellij.platform.ml.impl.monitoring.MLApproachListener.Companion.asJoinedListener
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.Companion.onAttemptedToStartSession
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.Companion.targetedApproaches
import org.jetbrains.annotations.ApiStatus
/**
* Represents an environment, that provides extendable parts of the ML API.
*
* Each entity inside the API could access the platform, it is running within,
* as everything happens after [com.intellij.platform.ml.impl.MLTaskApproachInitializer.initializeApproachWithin],
* where the platform is acknowledged.
*
* All usages of the ij platform functionality (extension points, registry keys, etc.) shall be
* accessed via this class.
*/
@ApiStatus.Internal
abstract class MLApiPlatform {
private val finishedInitialization = lazy { MLApiPlatformInitializationProcess() }
private var initializationStage: InitializationStage = InitializationStage.NotStarted
/**
* Each [MLTaskApproach] is initialized only once during the application's lifetime.
* This function keeps track of all approaches that were initialized already, and initializes
* them when they are first needed.
*/
fun <P : Any> accessApproachFor(task: MLTask<P>): MLTaskApproach<P> {
return finishedInitialization.value.getApproachFor(task)
}
/**
* The extendable static state of the platform that must be fixed to create FUS group's validator.
*
* These values must be static, as they define the FUS scheme
*/
val staticState: StaticState
get() = StaticState(tierDescriptors, environmentExtenders, taskApproaches)
/**
* The descriptors that are available in the platform.
* This value is interchangeable during the application runtime,
* see [staticState].
*/
abstract val tierDescriptors: List<TierDescriptor>
/**
* The complete list of environment extenders, available in the platform.
* This value is interchangeable during the application runtime,
* see [staticState].
*/
abstract val environmentExtenders: List<EnvironmentExtender<*>>
/**
* The complete list of the approaches for ML tasks, available in the platform.
* This value is interchangeable during the application runtime,
* see [staticState].
*/
abstract val taskApproaches: List<MLTaskApproachInitializer<*>>
/**
* All the objects, that are listening execution of ML tasks.
* The collection is mutable, so new listeners could be added via [addTaskListener].
*
* This value is mutable, new listeners could be added anytime.
*/
abstract val taskListeners: List<MLTaskGroupListener>
/**
* Adds another provider for ML tasks' execution process monitoring dynamically.
* The event could be removed via the corresponding [ExtensionController.removeExtension] call.
* See [taskListeners].
*/
abstract fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController
/**
* ML events that will be written to FUS logs.
* As FUS is initialized only once, on the application's startup, they all must be registered
* before that via [addMLEventBeforeFusInitialized].
*
* This value could be mutable, however, only during a short period of time: after the application's startup,
* and before FUS logs initialization.
*/
abstract val events: List<MLEvent>
/**
* Adds another ML event dynamically.
* The event could be removed via the corresponding [ExtensionController.removeExtension] call.
*/
fun addMLEventBeforeFusInitialized(event: MLEvent): ExtensionController {
when (val stage = initializationStage) {
is InitializationStage.Failed ->
throw Exception("Initialization of ML Api Platform has failed, events could not be added.", stage.asException)
is InitializationStage.PotentiallySuccessful -> {
require(stage.order <= InitializationStage.InitializingApproaches.order) {
"FUS group initialization has already been started, not allowed to register more ML Events"
}
return addEvent(event)
}
}
}
/**
* The complete list of the listeners that are listening to the process of an MLApiPlatform's initialization.
* The initialization is performed in [finishedInitialization]'s init block.
*
* This value could be mutable, so
* additional listeners could be added via [addStartupListener].
*
* If a listener was added after a certain initialization stage,
* only callbacks of those stages will be triggered later that have not happened yet.
*/
abstract val startupListeners: List<MLApiStartupListener>
/**
* Adds another startup listener.
* The listener could be removed via the corresponding [ExtensionController.removeExtension] call.
*/
abstract fun addStartupListener(listener: MLApiStartupListener): ExtensionController
/**
* Declares how the computed but non-declared features will be handled.
*/
abstract fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>)
internal abstract fun addEvent(event: MLEvent): ExtensionController
fun interface ExtensionController {
fun removeExtension()
}
data class StaticState(
val tierDescriptors: List<TierDescriptor>,
val environmentExtenders: List<EnvironmentExtender<*>>,
val taskApproaches: List<MLTaskApproachInitializer<*>>,
)
private sealed class InitializationStage(val callListener: (MLApiStartupProcessListener) -> Unit) {
sealed class PotentiallySuccessful(val order: Int, callListener: (MLApiStartupProcessListener) -> Unit) : InitializationStage(callListener)
class Failed(lastStage: InitializationStage, nextStage: InitializationStage, exception: Throwable, callListener: (MLApiStartupProcessListener) -> Unit) : InitializationStage(callListener) {
val asException = Exception("Failed to proceed from the initialization stage $lastStage to $nextStage", exception)
}
data object NotStarted : PotentiallySuccessful(0, {})
data object InitializingApproaches : PotentiallySuccessful(1, { it.onStartedInitializingApproaches() })
data class InitializingFUS(val initializedApproaches: Collection<InitializerAndApproach<*>>) : PotentiallySuccessful(2, {
it.onStartedInitializingFus(initializedApproaches)
})
data object Finished : PotentiallySuccessful(3, { it.onFinished() })
}
private inner class MLApiPlatformInitializationProcess {
val approachPerTask: Map<MLTask<*>, MLTaskApproach<*>>
private val completeInitializersList: List<MLTaskApproachInitializer<*>> = taskApproaches.toMutableList()
init {
require(initializationStage == InitializationStage.NotStarted) { "ML API Platform's initialization should not be run twice" }
fun currentStartupListeners(): List<MLApiStartupProcessListener> = startupListeners.map { it.onBeforeStarted(this@MLApiPlatform) }
fun <T> proceedToNextStage(nextStage: InitializationStage.PotentiallySuccessful, action: () -> T): T {
return try {
action().also {
currentStartupListeners().forEach { nextStage.callListener(it) }
initializationStage = nextStage
}
}
catch (ex: Throwable) {
val failure = InitializationStage.Failed(initializationStage, nextStage, ex) { it.onFailed(ex) }
initializationStage = failure
throw failure.asException
}
}
proceedToNextStage(InitializationStage.InitializingApproaches) {}
val initializedApproachPerTask = mutableListOf<InitializerAndApproach<*>>()
approachPerTask = proceedToNextStage(InitializationStage.InitializingFUS(initializedApproachPerTask)) {
completeInitializersList.validate()
fun <T : Any> initializeApproach(approachInitializer: MLTaskApproachInitializer<T>) {
initializedApproachPerTask.add(InitializerAndApproach(
approachInitializer,
approachInitializer.initializeApproachWithin(this@MLApiPlatform)
))
}
completeInitializersList.forEach { initializeApproach(it) }
initializedApproachPerTask.associate { it.initializer.task to it.approach }
}
proceedToNextStage(InitializationStage.Finished) {
MLEventsLogger.Manager.ensureInitialized(okIfInitializing = true, this@MLApiPlatform)
}
}
fun <P : Any> getApproachFor(task: MLTask<P>): MLTaskApproach<P> {
val taskApproach = requireNotNull(approachPerTask[task]) {
val mainMessage = "No approach for task $task was found"
val lateRegistrationMessage = getLateApproachRegistrationAssumption(task)
if (lateRegistrationMessage != null) "$mainMessage. $lateRegistrationMessage" else mainMessage
}
@Suppress("UNCHECKED_CAST")
return taskApproach as MLTaskApproach<P>
}
private fun getLateApproachRegistrationAssumption(task: MLTask<*>): String? {
val currentInitializersList = taskApproaches.toMutableList()
if (completeInitializersList == currentInitializersList) return null
val taskApproaches = currentInitializersList.filter { it.task == task }
if (taskApproaches.isEmpty()) return null
require(taskApproaches.size == 1) { "More than one approach for task $task: $taskApproaches" }
return "Approach ${taskApproaches.first()} for task ${task.name} was registered after the ML API Platform was initialized"
}
/**
 * Checks that no two initializers in the list target the same task.
 *
 * @throws IllegalArgumentException if some task has more than one initializer.
 */
private fun List<MLTaskApproachInitializer<*>>.validate() {
  val initializersOfOverloadedTasks = groupBy { it.task }.filterValues { it.size > 1 }
  require(initializersOfOverloadedTasks.isEmpty()) {
    "Found more than one approach for the following tasks: $initializersOfOverloadedTasks"
  }
}
}
companion object {
  /**
   * Collects the platform's tier descriptors for each of the requested [tiers].
   *
   * @return A mapping with exactly the requested tiers as keys; a tier without descriptors maps to an empty list.
   */
  fun MLApiPlatform.getDescriptorsOfTiers(tiers: Set<Tier<*>>): PerTier<List<TierDescriptor>> {
    val availableDescriptors = tierDescriptors.groupBy { it.tier }
    return tiers.associateWith { tier -> availableDescriptors[tier].orEmpty() }
  }

  /**
   * Makes sure that the approaches' initialization has passed, depending on the current initialization stage.
   *
   * If initialization has not started, it is triggered by reading `finishedInitialization.value`.
   *
   * @throws Exception if the initialization has failed, or if this is called while the approaches are being initialized.
   */
  fun MLApiPlatform.ensureApproachesInitialized() {
    when (val currentStage = initializationStage) {
      InitializationStage.NotStarted -> finishedInitialization.value
      InitializationStage.InitializingApproaches -> throw Exception("Recursion detected while initializing approaches")
      is InitializationStage.Failed -> throw Exception("Unable to ensure that approaches are initialized", currentStage.asException)
      is InitializationStage.InitializingFUS, InitializationStage.Finished -> return
    }
  }

  /**
   * Asks every task listener that targets the class of [taskApproach] for an approach listener,
   * and joins the provided ones into a single listener.
   *
   * @param taskApproach The approach that is attempting to start a session.
   * @param permanentSessionEnvironment The environment passed to each listener's `onAttemptedToStartSession`.
   */
  fun <R, P : Any> MLApiPlatform.getJoinedListenerForTask(taskApproach: MLTaskApproach<P>,
                                                          permanentSessionEnvironment: Environment): MLApproachListener<R, P> {
    val listenersOfThisApproach = taskListeners
      .filter { taskApproach.javaClass in it.targetedApproaches }
      .mapNotNull { groupListener ->
        groupListener.onAttemptedToStartSession<P, R>(taskApproach, permanentSessionEnvironment)
      }
    return listenersOfThisApproach.asJoinedListener()
  }
}
}

View File

@@ -0,0 +1,31 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.approach
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.PerTier
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
import com.intellij.platform.ml.impl.session.DescribedRootContainer
import org.jetbrains.annotations.ApiStatus
import java.util.concurrent.CompletableFuture
/**
* Represents the method that is utilized for the [LogDrivenModelInference]'s analysis.
*
* It statically declares which features the analysis can produce, and performs the analysis
* of a described session tree.
*
* @param M Type of the ML model that the analysed [DescribedRootContainer] is parameterized with.
* @param P Type parameter shared by [MLModel] and by the produced [AnalysedRootContainer].
*/
@ApiStatus.Internal
interface AnalysisMethod<M : MLModel<P>, P : Any> {
/**
* Static declaration of the features that are used in the session tree's analysis, given per tier.
*/
val structureAnalysisDeclaration: PerTier<Set<FeatureDeclaration<*>>>
/**
* Static declaration of the session's entities that are not tiers.
* Each set of feature declarations is stored under a string key.
*/
val sessionAnalysisDeclaration: Map<String, Set<FeatureDeclaration<*>>>
/**
* Performs the completed session's analysis.
*
* @param treeRoot Root of the described session tree to analyse.
* @return A future of the analysed session tree's root.
*/
fun analyseTree(treeRoot: DescribedRootContainer<M, P>): CompletableFuture<AnalysedRootContainer<P>>
}

View File

@@ -0,0 +1,62 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.approach
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.DescribedRootContainer
import com.intellij.platform.ml.impl.session.analysis.*
import org.jetbrains.annotations.ApiStatus
import java.util.concurrent.CompletableFuture
/**
* The session's assembled analysis declaration.
*
* @property structureAnalysis Declaration of the features produced by the structure analysis.
* @property mlModelAnalysis Declaration of the features produced by the ML model's analysis.
*/
@ApiStatus.Internal
data class GroupedAnalysisDeclaration(
val structureAnalysis: StructureAnalysisDeclaration,
val mlModelAnalysis: Set<FeatureDeclaration<*>>
)
/**
* The session's assembled analysis itself.
*
* @property structureAnalysis Result of the structure analysis.
* @property mlModelAnalysis Features produced by the ML model's analysis.
*/
@ApiStatus.Internal
data class GroupedAnalysis<M : MLModel<P>, P : Any>(
val structureAnalysis: StructureAnalysis<M, P>,
val mlModelAnalysis: Set<Feature>
)
/**
 * Analyzes both structure and ML model.
 *
 * The given structure analysers and ML model analysers are joined group-wise, both groups are run
 * on the session tree, and the two results are combined into one [GroupedAnalysis].
 *
 * @param structureAnalysers Analysers whose results make the [GroupedAnalysis.structureAnalysis].
 * @param mlModelAnalysers Analysers whose results make the [GroupedAnalysis.mlModelAnalysis].
 */
@ApiStatus.Internal
class JoinedGroupedSessionAnalyser<M : MLModel<P>, P : Any>(
  private val structureAnalysers: Collection<StructureAnalyser<M, P>>,
  private val mlModelAnalysers: Collection<MLModelAnalyser<M, P>>,
) : SessionAnalyser<GroupedAnalysisDeclaration, GroupedAnalysis<M, P>, M, P> {
  /**
   * The joined declarations of all structure analysers and of all ML model analysers.
   */
  override val analysisDeclaration = GroupedAnalysisDeclaration(
    structureAnalysis = SessionStructureAnalysisJoiner<M, P>().joinDeclarations(structureAnalysers.map { it.analysisDeclaration }),
    mlModelAnalysis = MLModelAnalysisJoiner<M, P>().joinDeclarations(mlModelAnalysers.map { it.analysisDeclaration })
  )

  /**
   * Starts the structure and the ML model analyses of [sessionTreeRoot] and combines their results.
   *
   * @return A future that completes with the grouped analysis when both analyses have completed.
   * If either of them completes exceptionally, the returned future completes exceptionally as well.
   */
  override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<GroupedAnalysis<M, P>> {
    val joinedStructureAnalyser = JoinedSessionAnalyser(
      structureAnalysers, SessionStructureAnalysisJoiner()
    )
    val joinedMLModelAnalyser = JoinedSessionAnalyser(
      mlModelAnalysers, MLModelAnalysisJoiner()
    )
    val structureAnalysis = joinedStructureAnalyser.analyse(sessionTreeRoot)
    val mlModelAnalysis = joinedMLModelAnalyser.analyse(sessionTreeRoot)
    // 'thenCombine' is used rather than 'allOf(...).thenRun { ... }' with a manually completed future:
    // 'thenRun' is skipped on an exceptional completion, which would leave the returned future incomplete forever.
    return structureAnalysis.thenCombine(mlModelAnalysis) { structure, mlModel ->
      GroupedAnalysis(structure, mlModel)
    }
  }
}

View File

@@ -0,0 +1,36 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.approach
import com.intellij.lang.Language
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.DescribedRootContainer
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
import org.jetbrains.annotations.ApiStatus
import java.util.concurrent.CompletableFuture
/**
* Something that is dedicated to one language only.
*/
@ApiStatus.Internal
interface LanguageSpecific {
/**
* Identifier of the language this object is dedicated to.
*/
val languageId: String
}
/**
 * The analyser that adds information about the ML model's language to logs.
 *
 * The possible values of the logged feature are the ids of the languages
 * that are registered at the moment when the analyser is created.
 */
@ApiStatus.Internal
class ModelLanguageAnalyser<M, P : Any> : MLModelAnalyser<M, P>
  where M : MLModel<P>,
        M : LanguageSpecific {
  private val languageIdDeclaration = FeatureDeclaration.categorical(
    "language_id",
    Language.getRegisteredLanguages().map { it.id }.toSet()
  )

  override val analysisDeclaration = setOf(languageIdDeclaration)

  /**
   * @return An already completed future with the language id of the session tree's root.
   */
  override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<Set<Feature>> {
    val modelLanguageId = sessionTreeRoot.root.languageId
    return CompletableFuture.completedFuture(setOf(languageIdDeclaration with modelLanguageId))
  }
}

View File

@@ -0,0 +1,231 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.approach
import com.intellij.openapi.diagnostic.thisLogger
import com.intellij.platform.ml.*
import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyByOrNull
import com.intellij.platform.ml.impl.*
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getDescriptorsOfTiers
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getJoinedListenerForTask
import com.intellij.platform.ml.impl.environment.ExtendedEnvironment
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.monitoring.MLApproachListener
import com.intellij.platform.ml.impl.monitoring.MLSessionListener
import com.intellij.platform.ml.impl.session.*
import org.jetbrains.annotations.ApiStatus
/**
 * The main way to apply classical machine learning approaches: run the ML model, collect the logs, retrain the model, repeat.
 *
 * @param task The task that is solved by this approach.
 * @param apiPlatform The platform, that the approach will be running within.
 */
@ApiStatus.Internal
abstract class LogDrivenModelInference<M : MLModel<P>, P : Any>(
  override val task: MLTask<P>,
  override val apiPlatform: MLApiPlatform
) : MLTaskApproach<P> {
  /**
   * The method that is used to analyze sessions.
   *
   * [StructureAndModelAnalysis] is currently used analysis method
   * that is dedicated to analyze session's tree-like structure,
   * and the ML model.
   */
  abstract val analysisMethod: AnalysisMethod<M, P>

  /**
   * Provides an ML model to use during session's lifetime.
   */
  abstract val mlModelProvider: MLModel.Provider<M, P>

  /**
   * Declares features, that are not used by the ML model, but must be computed anyway,
   * so they make it to logs.
   *
   * A feature cannot be simultaneously declared as "not used description" and as used by the [mlModelProvider]'s
   * provided model.
   * If a feature is not declared as "not used but still computed" or as "used by the model", then it will be computed.
   * NOTE(review): the previous sentence presumably should read "will not be computed" - confirm.
   *
   * It must contain explicitly declared selectors for each tier used in [task], as well as in [additionallyDescribedTiers].
   */
  abstract val notUsedDescription: PerTier<FeatureSelector>

  /**
   * Performs description's computation.
   * Could perform caching mechanisms to avoid recomputing features every time.
   */
  abstract val descriptionComputer: DescriptionComputer

  /**
   * Tiers that do not make a part of the [task], but they could be described and passed to the ML model.
   *
   * The size of this list must correspond to the number of levels in the solved [task].
   */
  abstract val additionallyDescribedTiers: List<Set<Tier<*>>>

  // Main tiers of each task's level, paired with the additionally described tiers of the same level.
  private val levels: List<LevelTiers> by lazy {
    (task.levels zip additionallyDescribedTiers).map { Level(it.first, it.second) }
  }

  // Lazy, so the validation runs once, on the first access.
  private val approachValidation: Unit by lazy { validateApproach() }

  /**
   * Starts a session, reporting every failure to the listeners of this approach.
   *
   * @return The start outcome. Exceptions are not propagated: they are wrapped into
   * [Session.StartOutcome.UncaughtException].
   */
  override fun startSession(permanentSessionEnvironment: Environment): Session.StartOutcome<P> {
    return startSessionMonitoring(permanentSessionEnvironment)
  }

  private fun startSessionMonitoring(permanentSessionEnvironment: Environment): Session.StartOutcome<P> {
    val approachListener = apiPlatform.getJoinedListenerForTask<M, P>(this, permanentSessionEnvironment)
    try {
      return acquireModelAndStartSession(permanentSessionEnvironment, approachListener)
    }
    catch (e: Throwable) {
      approachListener.onFailedToStartSessionWithException(e)
      return Session.StartOutcome.UncaughtException(e)
    }
  }

  /**
   * Validates the approach, acquires the ML model, and creates the session.
   *
   * Returns a failure outcome (after notifying [approachListener]) when the extended environment
   * is not sufficient for [mlModelProvider], or when the provider has not provided a model.
   */
  private fun acquireModelAndStartSession(permanentSessionEnvironment: Environment,
                                          approachListener: MLApproachListener<M, P>): Session.StartOutcome<P> {
    approachValidation

    val extendedPermanentSessionEnvironment = ExtendedEnvironment(
      apiPlatform.environmentExtenders,
      permanentSessionEnvironment,
      mlModelProvider.requiredTiers
    )

    val mlModel: M = run {
      val mlModelProviderEnvironment = extendedPermanentSessionEnvironment.accessibleSafelyByOrNull(mlModelProvider)
      if (mlModelProviderEnvironment == null) {
        val failure = InsufficientEnvironmentForModelProviderOutcome<P>(mlModelProvider.requiredTiers,
                                                                        extendedPermanentSessionEnvironment.tiers)
        approachListener.onFailedToStartSession(failure)
        return failure
      }
      val nullableMlModel = mlModelProvider.provideModel(levels, mlModelProviderEnvironment)
      if (nullableMlModel == null) {
        val failure = ModelNotAcquiredOutcome<P>()
        approachListener.onFailedToStartSession(failure)
        return failure
      }
      nullableMlModel
    }

    // Assigned only after the session is created (see below), because the listener is given by 'onStartedSession'.
    var sessionListener: MLSessionListener<M, P>? = null

    val analyseThenLogStructure = SessionTreeHandler<DescribedRootContainer<M, P>, M, P> { treeRoot ->
      sessionListener?.onSessionDescriptionFinished(treeRoot)
      analysisMethod.analyseTree(treeRoot).thenApplyAsync { analysedSession ->
        sessionListener?.onSessionAnalysisFinished(analysedSession)
      }.exceptionally {
        thisLogger().error(it)
      }
    }

    val session = if (levels.size == 1) {
      val collector = SolitaryLeafCollector(
        apiPlatform, levels.first(), descriptionComputer, notUsedDescription,
        permanentSessionEnvironment, levels.first().additional, mlModel
      )
      collector.handleCollectedTree(analyseThenLogStructure)
      MLModelPrediction(mlModel, collector)
    }
    else {
      val collector = RootCollector(
        apiPlatform, levels, descriptionComputer, notUsedDescription,
        permanentSessionEnvironment, levels.first().additional, mlModel
      )
      collector.handleCollectedTree(analyseThenLogStructure)
      MLModelPredictionBranching(mlModel, collector)
    }

    sessionListener = approachListener.onStartedSession(session)

    return Session.StartOutcome.Success(session)
  }

  /**
   * The declaration of everything that this approach describes and analyses, built anew on every access.
   * The approach is validated first.
   */
  override val approachDeclaration: MLTaskApproach.Declaration
    get() {
      approachValidation
      return MLTaskApproach.Declaration(
        sessionFeatures = analysisMethod.sessionAnalysisDeclaration,
        levelsScheme = levels.map { levelTiers ->
          Level(
            buildMainTiersScheme(levelTiers.main, apiPlatform),
            buildAdditionalTiersScheme(levelTiers.additional, apiPlatform),
          )
        }
      )
    }

  /**
   * Checks the consistency of the approach's declaration.
   *
   * @throws IllegalArgumentException if the number of task's levels differs from the size of [additionallyDescribedTiers],
   * if there are no levels, if a tier is declared more than once, or if [notUsedDescription] does not contain
   * selectors for exactly the declared tiers.
   */
  private fun validateApproach() {
    require(task.levels.size == additionallyDescribedTiers.size) {
      "Task $task has ${task.levels.size} levels, when 'additionallyDescribedTiers' has ${additionallyDescribedTiers.size}"
    }

    require(levels.isNotEmpty()) { "Task must declare at least one level" }

    // NOTE(review): if 'main' and 'additional' are both sets, their sum is a set too, so a tier declared
    // as both main and additional within the same level would not be noticed here - confirm whether that is intended.
    val maybeDuplicatedTaskTiers = levels.flatMap { it.main + it.additional }
    val taskTiers = maybeDuplicatedTaskTiers.toSet()

    require(maybeDuplicatedTaskTiers.size == taskTiers.size) {
      // 'list - set' would remove every occurrence and always produce an empty list,
      // so the duplicates are found by counting the occurrences instead.
      val duplicatedTiers = maybeDuplicatedTaskTiers.groupingBy { it }.eachCount().filterValues { it > 1 }.keys
      "There are duplicated tiers in the declaration: $duplicatedTiers"
    }

    require(notUsedDescription.keys == taskTiers) {
      "Selectors for those and only those tiers must be represented in the 'notUsedDescription' that are present in the task. " +
      "Missing: ${taskTiers - notUsedDescription.keys}, " +
      "Redundant: ${notUsedDescription.keys - taskTiers}"
    }
  }

  // Joins the declarations of the given descriptors; an obsolete descriptor contributes its partial declaration.
  private fun buildTierDescriptionDeclaration(tierDescriptors: Collection<TierDescriptor>): Set<FeatureDeclaration<*>> {
    return tierDescriptors.flatMap {
      if (it is ObsoleteTierDescriptor) it.partialDescriptionDeclaration else it.descriptionDeclaration
    }.toSet()
  }

  // For each main tier: the description's declaration, plus the declaration of the tier's analysis (empty, if there is none).
  private fun buildMainTiersScheme(tiers: Set<Tier<*>>, apiEnvironment: MLApiPlatform): PerTier<MainTierScheme> {
    val tiersDescriptors = apiEnvironment.getDescriptorsOfTiers(tiers)

    return tiers.associateWith { tier ->
      val tierDescriptors = tiersDescriptors.getValue(tier)
      val descriptionDeclaration = buildTierDescriptionDeclaration(tierDescriptors)
      val analysisDeclaration = analysisMethod.structureAnalysisDeclaration[tier] ?: emptySet()
      MainTierScheme(descriptionDeclaration, analysisDeclaration)
    }
  }

  // For each additional tier: the description's declaration only.
  private fun buildAdditionalTiersScheme(tiers: Set<Tier<*>>, apiEnvironment: MLApiPlatform): PerTier<AdditionalTierScheme> {
    val tiersDescriptors = apiEnvironment.getDescriptorsOfTiers(tiers)

    return tiers.associateWith { tier ->
      val tierDescriptors = tiersDescriptors.getValue(tier)
      val descriptionDeclaration = buildTierDescriptionDeclaration(tierDescriptors)
      AdditionalTierScheme(descriptionDeclaration)
    }
  }
}
/**
* A failed start outcome that indicates that for some reason, it was not possible to provide an ML model
* when calling [com.intellij.platform.ml.impl.model.MLModel.Provider.provideModel].
* The session's start is considered as failed.
*/
@ApiStatus.Internal
class ModelNotAcquiredOutcome<P : Any> : Session.StartOutcome.Failure<P>() {
override val failureDetails: String = "ML Model was not provided"
}
/**
 * There were not enough tiers to satisfy [MLModel.Provider]'s requirements, so it could not provide the model.
 *
 * @param expectedTiers Tiers that were required.
 * @param existingTiers Tiers that were available; the expected tiers absent from here are reported as missing.
 */
@ApiStatus.Internal
class InsufficientEnvironmentForModelProviderOutcome<P : Any>(
  expectedTiers: Set<Tier<*>>,
  existingTiers: Set<Tier<*>>
) : Session.StartOutcome.Failure<P>() {
  override val failureDetails: String = run {
    val missingTiers = expectedTiers - existingTiers
    "ML Model could not be provided: environment is not sufficient. Missing: $missingTiers"
  }
}

View File

@@ -0,0 +1,81 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.approach
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.PerTierInstance
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.*
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
import com.intellij.platform.ml.impl.session.analysis.StructureAnalyser
import com.intellij.platform.ml.impl.session.analysis.StructureAnalysisDeclaration
import org.jetbrains.annotations.ApiStatus
import java.util.concurrent.CompletableFuture
/**
 * An analysis method that analyzes the session's structure and the ML model
 * that has been used during the session.
 *
 * @param structureAnalysers Analyzing session's structure - main tier instances.
 * @param mlModelAnalysers Analyzing the ML model which was producing predictions during the session.
 * @param sessionAnalysisKeyModel Key that will be used in logs, to write ML model's features into.
 */
@ApiStatus.Internal
class StructureAndModelAnalysis<M : MLModel<P>, P : Any>(
  structureAnalysers: Collection<StructureAnalyser<M, P>>,
  mlModelAnalysers: Collection<MLModelAnalyser<M, P>>,
  private val sessionAnalysisKeyModel: String = DEFAULT_SESSION_KEY_ML_MODEL
) : AnalysisMethod<M, P> {
  private val groupedAnalyser = JoinedGroupedSessionAnalyser(structureAnalysers, mlModelAnalysers)

  override val structureAnalysisDeclaration: StructureAnalysisDeclaration
    get() = groupedAnalyser.analysisDeclaration.structureAnalysis

  override val sessionAnalysisDeclaration: Map<String, Set<FeatureDeclaration<*>>> = mapOf(
    sessionAnalysisKeyModel to groupedAnalyser.analysisDeclaration.mlModelAnalysis
  )

  /**
   * Runs the grouped analysis on [treeRoot], then rebuilds the tree with the analysis put into it.
   */
  override fun analyseTree(treeRoot: DescribedRootContainer<M, P>): CompletableFuture<AnalysedRootContainer<P>> {
    val futureAnalysis = groupedAnalyser.analyse(treeRoot)
    return futureAnalysis.thenApply { groupedAnalysis ->
      buildAnalysedSessionTree(treeRoot, groupedAnalysis) as AnalysedRootContainer<P>
    }
  }

  /**
   * Recursively converts a described tree into an analysed one.
   * The ML model's analysis is attached to the root nodes only, under [sessionAnalysisKeyModel].
   */
  private fun buildAnalysedSessionTree(tree: DescribedSessionTree<M, P>, analysis: GroupedAnalysis<M, P>): AnalysedSessionTree<P> {
    val analysedLevel = Level(
      main = analyseMainTiers(tree, analysis),
      additional = tree.level.additional
    )

    return when (tree) {
      is SessionTree.Branching<M, DescribedLevel, P> -> {
        val analysedChildren = tree.children.map { child -> buildAnalysedSessionTree(child, analysis) }
        SessionTree.Branching(analysedLevel, analysedChildren)
      }
      is SessionTree.Leaf<M, DescribedLevel, P> -> {
        SessionTree.Leaf(analysedLevel, tree.prediction)
      }
      is SessionTree.ComplexRoot -> {
        SessionTree.ComplexRoot(
          mapOf(sessionAnalysisKeyModel to analysis.mlModelAnalysis),
          analysedLevel,
          tree.children.map { child -> buildAnalysedSessionTree(child, analysis) }
        )
      }
      is SessionTree.SolitaryLeaf -> {
        SessionTree.SolitaryLeaf(
          mapOf(sessionAnalysisKeyModel to analysis.mlModelAnalysis),
          analysedLevel,
          tree.prediction
        )
      }
    }
  }

  /**
   * Pairs each main tier instance of the [tree]'s level with its description and with the structure analysis
   * found for the instance's tier (an empty set, when there is none).
   */
  private fun analyseMainTiers(tree: DescribedSessionTree<M, P>, analysis: GroupedAnalysis<M, P>): PerTierInstance<AnalysedTierData> {
    return tree.level.main.entries.associate { (tierInstance, tierData) ->
      val tierAnalysis = analysis.structureAnalysis[tree]?.get(tierInstance.tier) ?: emptySet()
      tierInstance to AnalysedTierData(tierData.description, tierAnalysis)
    }
  }

  companion object {
    const val DEFAULT_SESSION_KEY_ML_MODEL: String = "ml_model"
  }
}

View File

@@ -0,0 +1,37 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.approach
import com.intellij.openapi.util.Version
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.DescribedRootContainer
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
import org.jetbrains.annotations.ApiStatus
import java.util.concurrent.CompletableFuture
/**
* Something that has a version.
*/
@ApiStatus.Internal
interface Versioned {
/**
* The version of this object, or `null` if there is none.
*/
val version: Version?
}
/**
 * Adds the model's version to the ML logs.
 */
@ApiStatus.Internal
class ModelVersionAnalyser<M, P : Any> : MLModelAnalyser<M, P>
  where M : MLModel<P>,
        M : Versioned {
  override val analysisDeclaration = setOf(VERSION)

  /**
   * @return An already completed future with the version of the session tree's root.
   */
  override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<Set<Feature>> {
    val modelVersion = sessionTreeRoot.root.version
    return CompletableFuture.completedFuture(setOf(VERSION with modelVersion))
  }

  companion object {
    val VERSION = FeatureDeclaration.version("version").nullable()
  }
}

View File

@@ -0,0 +1,35 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.environment
import com.intellij.platform.ml.EnvironmentExtender
import com.intellij.platform.ml.Tier
import org.jetbrains.annotations.ApiStatus
/**
 * There was a circle within the requirements of the available extenders' set.
 * Meaning that to create some tier, we must eventually run the extender itself.
 *
 * @param extensionPath The path of extenders that forms the circle; it is serialized into the [message].
 */
@ApiStatus.Internal
class CircularRequirementException(val extensionPath: List<EnvironmentExtender<*>>) : IllegalArgumentException() {
  override val message: String = extensionPath.joinToString(
    separator = " - ",
    prefix = "A circular resolve path found among EnvironmentExtenders: "
  ) { extender -> "[$extender] -> ${extender.extendingTier.name}" }
}
/**
* An algorithm for resolving the order of [EnvironmentExtender]s' execution.
*/
@ApiStatus.Internal
interface EnvironmentResolver {
/**
* @param extenderPerTier The extender chosen for each tier.
* @return The order which guarantees that for each extender, all the requirements will be fulfilled by the previously run
* extenders.
* But it still could happen that an [EnvironmentExtender] will not return the tier it extends, then some subsequent
* extenders' requirements could not be satisfied.
* @throws CircularRequirementException
*/
fun resolve(extenderPerTier: Map<Tier<*>, EnvironmentExtender<*>>): List<EnvironmentExtender<*>>
}

View File

@@ -0,0 +1,139 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.environment
import com.intellij.platform.ml.*
import com.intellij.platform.ml.EnvironmentExtender.Companion.extendTierInstance
import com.intellij.platform.ml.ScopeEnvironment.Companion.accessibleSafelyByOrNull
import org.jetbrains.annotations.ApiStatus
/**
 * An environment that is built to fulfill a [TierRequester]'s requirements.
 *
 * When built, it accepts all available extenders, and it is trying to resolve their
 * order to extend particular tiers.
 *
 * If there are some main tiers passed to build the extended environment, they will
 * not be overridden.
 *
 * Among the available extenders, passed to the constructor, there could be some that
 * extend the same tier.
 * This could signify that an instance of the same tier can be mined from
 * different sets of objects (requirements).
 * But if there is more than one extender, that could potentially be run that will
 * extend a tier of the same instance, then [IllegalArgumentException] will be thrown
 * telling, that there is an ambiguity.
 *
 * When we have determined which extenders could potentially be run, we try to determine
 * the order with topological sort: [com.intellij.platform.ml.impl.environment.TopologicalSortingResolver].
 * If there is a circle in the requirements, it will throw [com.intellij.platform.ml.impl.environment.CircularRequirementException].
 */
@ApiStatus.Internal
class ExtendedEnvironment : Environment {
  private val storage: Environment

  /**
   * @param environmentExtenders Extenders, that will be used to extend the [tiersToExtend].
   * @param mainEnvironment Tiers that are already determined and shall not be replaced.
   * @param tiersToExtend Tiers that should be put to the extended environment via [environmentExtenders].
   * It could not be guaranteed that all desired tiers will be extended.
   *
   * @return An environment that contains all tiers from [mainEnvironment] plus
   * some tiers from [tiersToExtend], if it was possible to extend them.
   * @throws IllegalArgumentException if some of [tiersToExtend] is already present in [mainEnvironment].
   */
  constructor(environmentExtenders: List<EnvironmentExtender<*>>,
              mainEnvironment: Environment,
              tiersToExtend: Set<Tier<*>>) {
    val alreadyPresentTiers = tiersToExtend.filter { it in mainEnvironment }
    require(alreadyPresentTiers.isEmpty()) {
      "Tiers to extend must not be present in the main environment, but these are: $alreadyPresentTiers"
    }
    val nonOverridingExtenders = environmentExtenders.filter { it.extendingTier !in mainEnvironment }
    storage = buildExtendedEnvironment(
      tiersToExtend + mainEnvironment.tiers,
      nonOverridingExtenders + mainEnvironment.separateIntoExtenders()
    )
  }

  /**
   * @param environmentExtenders Extenders that will be utilized to build the extended environment.
   * @param mainEnvironment An already existing environment, instances from which shall not be overridden.
   *
   * @return An environment that contains all tiers from [mainEnvironment] plus
   * all tiers that it was possible to acquire via [environmentExtenders].
   */
  constructor(environmentExtenders: List<EnvironmentExtender<*>>,
              mainEnvironment: Environment) {
    val nonOverridingExtenders = environmentExtenders.filter { it.extendingTier !in mainEnvironment }
    storage = buildExtendedEnvironment(
      nonOverridingExtenders.map { it.extendingTier }.toSet() + mainEnvironment.tiers,
      nonOverridingExtenders + mainEnvironment.separateIntoExtenders()
    )
  }

  override val tiers: Set<Tier<*>>
    get() = storage.tiers

  override fun <T : Any> getInstance(tier: Tier<T>): T {
    return storage.getInstance(tier)
  }

  companion object {
    private val ENVIRONMENT_RESOLVER = TopologicalSortingResolver()

    /**
     * Creates an [Environment] that contains tiers from [tiers], that were successfully extended by [extenders]
     */
    private fun buildExtendedEnvironment(tiers: Set<Tier<*>>,
                                         extenders: List<EnvironmentExtender<*>>): Environment {
      val validatedExtendersPerTier = validateExtenders(tiers, extenders)
      val extensionOrder = ENVIRONMENT_RESOLVER.resolve(validatedExtendersPerTier)

      val storage = TierInstanceStorage()

      for (extender in extensionOrder) {
        // An extender is skipped when the storage built so far cannot be safely accessed by it.
        val safelyAccessibleEnvironment = storage.accessibleSafelyByOrNull(extender) ?: continue
        extender.extendTierInstance(safelyAccessibleEnvironment)?.let { extendedTierInstance ->
          storage.putTierInstance(extendedTierInstance)
        }
      }

      return storage
    }

    /**
     * Chooses one extender per tier among [extenders].
     *
     * Only those extenders are considered whose every requirement is extended by some of the [extenders].
     * The result is restricted to the requested [tiers].
     *
     * @throws IllegalArgumentException if more than one of the considered extenders extends the same tier.
     */
    private fun validateExtenders(tiers: Set<Tier<*>>, extenders: List<EnvironmentExtender<*>>): Map<Tier<*>, EnvironmentExtender<*>> {
      val extendableTiers: Set<Tier<*>> = extenders.map { it.extendingTier }.toSet()

      val runnableExtenders = extenders
        .filter { desiredExtender ->
          desiredExtender.requiredTiers.all { requirementForDesiredExtender -> requirementForDesiredExtender in extendableTiers }
        }

      val ambiguouslyExtendableTiers: MutableList<Pair<Tier<*>, List<EnvironmentExtender<*>>>> = mutableListOf()

      val extendersPerTier: Map<Tier<*>, EnvironmentExtender<*>> = runnableExtenders
        .groupBy { it.extendingTier }
        .mapNotNull { (tier, tierExtenders) ->
          if (tierExtenders.size > 1) {
            ambiguouslyExtendableTiers.add(tier to tierExtenders)
            null
          }
          else
            tierExtenders.first()
        }
        .associateBy { it.extendingTier }
        .filterKeys { it in tiers }

      require(ambiguouslyExtendableTiers.isEmpty()) { "Some tiers could be extended ambiguously: $ambiguouslyExtendableTiers" }

      return extendersPerTier
    }
  }
}
/**
 * Represents each tier of this environment as an extender without requirements,
 * that gives out the tier's instance taken from this environment.
 */
private fun Environment.separateIntoExtenders(): List<EnvironmentExtender<*>> {
  val sourceEnvironment = this

  class ContainingExtender<T : Any>(override val extendingTier: Tier<T>) : EnvironmentExtender<T> {
    override val requiredTiers: Set<Tier<*>> = emptySet()

    // The passed environment is not used: the instance is always taken from the source environment.
    override fun extend(environment: Environment): T {
      return sourceEnvironment[extendingTier]
    }
  }

  return tiers.map { tier -> ContainingExtender(tier) }
}

View File

@@ -0,0 +1,49 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.environment
import com.intellij.platform.ml.EnvironmentExtender
import com.intellij.platform.ml.Tier
import org.jetbrains.annotations.ApiStatus
private typealias Node = EnvironmentExtender<*>

/**
 * Resolves the order using a depth-first topological sort:
 * an extender is put to the order after all the extenders of its required tiers.
 */
@ApiStatus.Internal
class TopologicalSortingResolver : EnvironmentResolver {
  /**
   * @throws CircularRequirementException if an extender is reached again while its own requirements are being resolved.
   */
  override fun resolve(extenderPerTier: Map<Tier<*>, EnvironmentExtender<*>>): List<EnvironmentExtender<*>> {
    // For each extender: the extenders of the tiers it requires.
    val requirementsOf: Map<Node, List<Node>> = extenderPerTier.values.associateWith { extender ->
      extender.requiredTiers.map { requiredTier -> extenderPerTier.getValue(requiredTier) }
    }

    val resolvedOrder: MutableList<Node> = mutableListOf()
    val stateOf: MutableMap<Node, ResolveState> = mutableMapOf()

    fun visit(node: Node, pathToNode: List<Node>) {
      val state = stateOf[node]
      if (state == ResolveState.RESOLVED) return
      if (state == ResolveState.STARTED) throw CircularRequirementException(pathToNode + node)

      stateOf[node] = ResolveState.STARTED
      val pathThroughNode = pathToNode + node
      requirementsOf.getValue(node).forEach { requirement -> visit(requirement, pathThroughNode) }
      stateOf[node] = ResolveState.RESOLVED
      resolvedOrder.add(node)
    }

    requirementsOf.keys.forEach { visit(it, emptyList()) }

    return resolvedOrder
  }

  private enum class ResolveState {
    STARTED,
    RESOLVED
  }
}

View File

@@ -0,0 +1,49 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.logger
import com.intellij.internal.statistic.eventLog.events.EventPair
import com.intellij.internal.statistic.eventLog.events.ObjectDescription
import com.intellij.internal.statistic.eventLog.events.ObjectEventData
import com.intellij.platform.ml.impl.MLTaskApproach
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
import com.intellij.platform.ml.impl.session.AnalysedSessionTree
import org.jetbrains.annotations.ApiStatus
/**
 * Represents FUS fields of a session's subtree.
 */
@ApiStatus.Internal
abstract class SessionFields<P : Any> : ObjectDescription() {
  /**
   * Wraps the event pairs built by [buildEventPairs] into an [ObjectEventData].
   */
  fun buildObjectEventData(sessionStructure: AnalysedSessionTree<P>): ObjectEventData {
    val eventPairs = buildEventPairs(sessionStructure)
    return ObjectEventData(eventPairs)
  }

  /**
   * Builds the event pairs for the given session's subtree.
   */
  abstract fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>>
}
/**
* Represents a logging scheme for the FUS event.
*
* @param P The type of the ML task's prediction
*/
@ApiStatus.Internal
interface FusSessionEventBuilder<P : Any> {
/**
* Configuration of a [FusSessionEventBuilder]: it creates the builder when it is given an approach's declaration.
*/
interface FusScheme<P : Any> {
/**
* @param approachDeclaration The declaration of the approach whose sessions will be logged.
* @return The event builder for the given declaration.
*/
fun createEventBuilder(approachDeclaration: MLTaskApproach.Declaration): FusSessionEventBuilder<P>
}
/**
* Builds the declaration of all features that will be logged for the session tiers' description and analysis.
* It is required because FUS logs validators are built 'statically'.
*/
fun buildFusDeclaration(): SessionFields<P>
/**
* Builds a concrete FUS record, that contains fields that were built by [buildFusDeclaration].
*
* @param sessionStructure A session tree that has already been analyzed and is ready to be logged.
* @param sessionFields Session fields that were built by [buildFusDeclaration] earlier.
*/
fun buildRecord(sessionStructure: AnalysedRootContainer<P>, sessionFields: SessionFields<P>): Array<EventPair<*>>
}

View File

@@ -0,0 +1,390 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.logger
import com.intellij.internal.statistic.eventLog.FeatureUsageData
import com.intellij.internal.statistic.eventLog.events.*
import com.intellij.platform.ml.*
import com.intellij.platform.ml.impl.*
import com.intellij.platform.ml.impl.session.*
import org.jetbrains.annotations.ApiStatus
/**
 * The currently used FUS logging scheme.
 * Inplace means that the features are logged beside the tiers' instances and
 * are not compressed in any way.
 *
 * An example of a FUS record could be found in the test resources:
 * [testResources/ml_logs.js](community/platform/ml-impl/testResources/ml_logs.js)
 */
@ApiStatus.Internal
class InplaceFeaturesScheme<P : Any> internal constructor(
  private val predictionValidationRule: List<String>,
  private val predictionTransform: (P?) -> String,
  private val approachDeclaration: MLTaskApproach.Declaration
) : FusSessionEventBuilder<P> {
  /**
   * Creates [InplaceFeaturesScheme]s with the given validation rule and transform of the prediction.
   */
  class FusScheme<P : Any>(
    private val predictionValidationRule: List<String>,
    private val predictionTransform: (P?) -> String
  ) : FusSessionEventBuilder.FusScheme<P> {
    override fun createEventBuilder(approachDeclaration: MLTaskApproach.Declaration): FusSessionEventBuilder<P> {
      return InplaceFeaturesScheme(predictionValidationRule, predictionTransform, approachDeclaration)
    }

    companion object {
      val DOUBLE: FusScheme<Double> = FusScheme(listOf("{regexp#float}")) { prediction -> prediction.toString() }
    }
  }

  /**
   * Builds the fields for a single-level declaration, or the nestable fields when there are more levels.
   *
   * @throws IllegalArgumentException if the approach's declaration has no levels.
   */
  override fun buildFusDeclaration(): SessionFields<P> {
    val levelsScheme = approachDeclaration.levelsScheme
    require(levelsScheme.isNotEmpty())
    val rootLevelScheme = levelsScheme.first()
    return if (levelsScheme.size > 1) {
      NestableSessionFields(rootLevelScheme, levelsScheme.drop(1), predictionValidationRule,
                            predictionTransform, approachDeclaration.sessionFeatures)
    }
    else {
      PredictionSessionFields(rootLevelScheme, predictionValidationRule, predictionTransform,
                              approachDeclaration.sessionFeatures)
    }
  }

  override fun buildRecord(sessionStructure: AnalysedRootContainer<P>, sessionFields: SessionFields<P>): Array<EventPair<*>> {
    val eventPairs = sessionFields.buildEventPairs(sessionStructure)
    return eventPairs.toTypedArray()
  }
}
/**
 * FUS field that holds a prediction: the value is passed through [transform]
 * and the resulting string is stored under [name].
 */
private class PredictionField<T : Any>(
  override val name: String,
  override val validationRule: List<String>,
  val transform: (T?) -> String
) : PrimitiveEventField<T?>() {
  override fun addData(fuData: FeatureUsageData, value: T?) {
    val serializedPrediction = transform(value)
    fuData.addData(name, serializedPrediction)
  }
}
/**
 * FUS string field whose validation rule is an enumeration of [possibleValues].
 * The value itself is written as is.
 */
private data class StringField(override val name: String, private val possibleValues: Set<String>) : PrimitiveEventField<String>() {
  // Produces a single rule of the form "{enum:a|b|c}"
  override val validationRule = listOf(
    possibleValues.joinToString(separator = "|", prefix = "{enum:", postfix = "}")
  )

  override fun addData(fuData: FeatureUsageData, value: String) {
    fuData.addData(name, value)
  }
}
/**
 * FUS field for a version given as a string, validated with the `{regexp#version}` rule.
 */
private data class VersionField(override val name: String) : PrimitiveEventField<String?>() {
  override val validationRule: List<String>
    get() {
      return listOf("{regexp#version}")
    }

  // NOTE(review): [name] is not passed to addVersionByString, so the key under which the value
  // is stored is decided by FeatureUsageData - confirm it matches this field's name.
  override fun addData(fuData: FeatureUsageData, value: String?) {
    fuData.addVersionByString(value)
  }
}
/**
 * Maps a feature's declaration to the FUS field that is able to hold the feature's value.
 * A nullable feature is declared with the field of its base type.
 */
private fun FeatureDeclaration<*>.toEventField(): EventField<*> = when (val featureType = type) {
  is FeatureValueType.Enum<*> -> EnumEventField(name, featureType.enumClass, Enum<*>::name)
  is FeatureValueType.Int -> IntEventField(name)
  is FeatureValueType.Long -> LongEventField(name)
  is FeatureValueType.Class -> ClassEventField(name)
  is FeatureValueType.Boolean -> BooleanEventField(name)
  is FeatureValueType.Double -> DoubleEventField(name)
  is FeatureValueType.Float -> FloatEventField(name)
  // Unwrap the nullable type and declare the field by the underlying one
  is FeatureValueType.Nullable -> FeatureDeclaration(name, featureType.baseType).toEventField()
  is FeatureValueType.Categorical -> StringField(name, featureType.possibleValues)
  is FeatureValueType.Version -> VersionField(name)
}
/** Pairs the enum feature's value with an enum field that is named after the feature and logs the constant's name. */
private fun <T : Enum<*>> Feature.Enum<T>.toEventPair(): EventPair<*> {
  val enumField = EnumEventField(declaration.name, valueType.enumClass, Enum<*>::name)
  return enumField with value
}
/**
 * Converts a nullable feature to an event pair of its base type.
 *
 * @return `null` when the feature has no value, so that nothing is logged for it
 */
private fun <T> Feature.Nullable<T>.toEventPair(): EventPair<*>? {
  val presentValue = value ?: return null
  return baseType.instantiate(declaration.name, presentValue).toEventPair()
}
/** Converts a feature to an event pair by dispatching to the typed conversion; the result may be `null`. */
private fun Feature.toEventPair(): EventPair<*>? = when (this) {
  is Feature.TypedFeature<*> -> typedToEventPair()
}
/**
 * Pairs the feature's value with a FUS field of the matching type, named after the feature.
 *
 * @return `null` only for a nullable feature that has no value
 */
private fun <T> Feature.TypedFeature<T>.typedToEventPair(): EventPair<*>? {
  val fieldName = declaration.name
  return when (this) {
    is Feature.Boolean -> BooleanEventField(fieldName) with this.value
    is Feature.Categorical -> StringField(fieldName, this.valueType.possibleValues) with this.value
    is Feature.Class -> ClassEventField(fieldName) with this.value
    is Feature.Double -> DoubleEventField(fieldName) with this.value
    is Feature.Enum<*> -> toEventPair()
    is Feature.Float -> FloatEventField(fieldName) with this.value
    is Feature.Int -> IntEventField(fieldName) with this.value
    is Feature.Long -> LongEventField(fieldName) with this.value
    is Feature.Nullable<*> -> toEventPair()
    // The version is logged in its string representation
    is Feature.Version -> VersionField(fieldName) with this.value.toString()
  }
}
/**
 * Object description that contains one FUS field per declared feature.
 */
private class FeatureSet(featuresDeclarations: Set<FeatureDeclaration<*>>) : ObjectDescription() {
  init {
    featuresDeclarations.forEach { declaration -> field(declaration.toEventField()) }
  }

  /** Builds the object's data from the given features; features that convert to `null` are omitted. */
  fun toObjectEventData(features: Set<Feature>): ObjectEventData {
    val eventPairs = features.mapNotNull { feature -> feature.toEventPair() }
    return ObjectEventData(eventPairs)
  }
}
/** Serializes the features with a [FeatureSet] that is declared from exactly these features' declarations. */
private fun Set<Feature>.toObjectEventData(): ObjectEventData {
  val ownDeclarations = this.map { it.declaration }.toSet()
  return FeatureSet(ownDeclarations).toObjectEventData(this)
}
/**
 * FUS description of a tier's description features, split by whether they were used or not.
 *
 * Declared features are logged as objects ("used" / "not_used"), and for the non-declared ones
 * only their amounts are logged ("n_used_non_declared" / "n_not_used_non_declared").
 *
 * @param used Declaration of the features that could be logged as used
 * @param notUsed Declaration of the features that could be logged as not used
 */
private data class TierDescriptionFields(
  val used: FeatureSet,
  val notUsed: FeatureSet,
) : ObjectDescription() {
  private val fieldUsed = ObjectEventField("used", used)
  private val fieldNotUsed = ObjectEventField("not_used", notUsed)
  private val fieldAmountUsedNonDeclaredFeatures = IntEventField("n_used_non_declared")
  private val fieldAmountNotUsedNonDeclaredFeatures = IntEventField("n_not_used_non_declared")

  init {
    field(fieldUsed)
    field(fieldNotUsed)
    field(fieldAmountUsedNonDeclaredFeatures)
    field(fieldAmountNotUsedNonDeclaredFeatures)
  }

  /**
   * Builds the event pairs for one tier instance's description.
   * The amounts of the non-declared features are added only when they are not zero.
   */
  fun buildEventPairs(descriptionPartition: DescriptionPartition): List<EventPair<*>> {
    val result = mutableListOf<EventPair<*>>(
      fieldUsed with descriptionPartition.declared.used.toObjectEventData(),
      fieldNotUsed with descriptionPartition.declared.notUsed.toObjectEventData(),
    )
    val usedNonDeclared = descriptionPartition.nonDeclared.used
    if (usedNonDeclared.isNotEmpty()) {
      result += fieldAmountUsedNonDeclaredFeatures with usedNonDeclared.size
    }
    val notUsedNonDeclared = descriptionPartition.nonDeclared.notUsed
    if (notUsedNonDeclared.isNotEmpty()) {
      // The not used amount goes to its own field, it must not be mixed with the used one
      result += fieldAmountNotUsedNonDeclaredFeatures with notUsedNonDeclared.size
    }
    return result
  }

  /** Same as [buildEventPairs], but wrapped into an [ObjectEventData]. */
  fun buildObjectEventData(descriptionPartition: DescriptionPartition) = ObjectEventData(
    buildEventPairs(descriptionPartition)
  )
}
/**
 * FUS description of an additional tier's instance: its id and its description features.
 *
 * @param description Declaration of the tier's description fields
 */
private data class AdditionalTierFields(val description: TierDescriptionFields) : ObjectDescription() {
  private val fieldInstanceId = IntEventField("id")
  private val fieldDescription = ObjectEventField("description", description)

  /** Declares both the used and the not used features from the same set of [descriptionFeatures]. */
  constructor(descriptionFeatures: Set<FeatureDeclaration<*>>) : this(
    TierDescriptionFields(
      used = FeatureSet(descriptionFeatures),
      notUsed = FeatureSet(descriptionFeatures)
    )
  )

  init {
    field(fieldInstanceId)
    field(fieldDescription)
  }

  /** Builds the data of one tier instance; the hash code of the instance is logged as its id. */
  fun buildObjectEventData(tierInstance: TierInstance<*>, descriptionPartition: DescriptionPartition): ObjectEventData {
    val instanceIdPair = fieldInstanceId with tierInstance.instance.hashCode()
    val descriptionPair = fieldDescription with this.description.buildObjectEventData(descriptionPartition)
    return ObjectEventData(instanceIdPair, descriptionPair)
  }
}
/**
 * FUS description of a main tier's instance: its id, its description features and its analysis.
 *
 * @param description Declaration of the tier's description fields
 * @param analysis Declaration of the features the tier's analysis could contain
 */
private data class MainTierFields(
  val description: TierDescriptionFields,
  val analysis: FeatureSet,
) : ObjectDescription() {
  private val fieldInstanceId = IntEventField("id")
  private val fieldDescription = ObjectEventField("description", description)
  private val fieldAnalysis = ObjectEventField("analysis", analysis)

  /** Declares both the used and the not used description features from the same set of [descriptionFeatures]. */
  constructor(descriptionFeatures: Set<FeatureDeclaration<*>>, analysisFeatures: Set<FeatureDeclaration<*>>) : this(
    TierDescriptionFields(
      used = FeatureSet(descriptionFeatures),
      notUsed = FeatureSet(descriptionFeatures)
    ),
    FeatureSet(analysisFeatures)
  )

  init {
    field(fieldInstanceId)
    field(fieldDescription)
    field(fieldAnalysis)
  }

  /** Builds the data of one tier instance; the hash code of the instance is logged as its id. */
  fun buildObjectEventData(tierInstance: TierInstance<*>,
                           descriptionPartition: DescriptionPartition,
                           analysis: Set<Feature>): ObjectEventData {
    val instanceIdPair = fieldInstanceId with tierInstance.instance.hashCode()
    val descriptionPair = fieldDescription with this.description.buildObjectEventData(descriptionPartition)
    val analysisPair = fieldAnalysis with this.analysis.toObjectEventData(analysis)
    return ObjectEventData(instanceIdPair, descriptionPair, analysisPair)
  }
}
/**
 * FUS description of the session-level analysis: one object field per declared key.
 *
 * @param featuresPerKey For each key, the features that could be logged under it
 */
private data class SessionAnalysisFields<P : Any>(
  val featuresPerKey: Map<String, Set<FeatureDeclaration<*>>>
) : SessionFields<P>() {
  val fieldsPerKey: Map<String, ObjectEventField> = featuresPerKey.mapValues { (key, keyFeatures) ->
    ObjectEventField(key, FeatureSet(keyFeatures))
  }

  init {
    for (keyField in fieldsPerKey.values) {
      field(keyField)
    }
  }

  /**
   * Serializes the analysis stored in the session's root.
   *
   * @throws IllegalArgumentException if [sessionStructure] is not a root container,
   * or if the analysis contains a key that was not declared in [featuresPerKey]
   */
  override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
    require(sessionStructure is SessionTree.RootContainer<SessionAnalysis, AnalysedLevel, P>)
    return sessionStructure.root.entries.map { (key, analysedFeatures) ->
      val declaredFeatures = requireNotNull(featuresPerKey[key]) {
        "Key $key was not declared as session features key, declared keys: ${featuresPerKey.keys}"
      }
      val keyField = fieldsPerKey.getValue(key)
      keyField with FeatureSet(declaredFeatures).toObjectEventData(analysedFeatures)
    }
  }
}
/**
 * FUS description of a level's main tiers: one object field per tier, named after the tier.
 *
 * @param mainTierScheme For each main tier, its description's and analysis' features
 */
private class MainTierSet<P : Any>(mainTierScheme: PerTier<MainTierScheme>) : SessionFields<P>() {
  val tiersDeclarations: PerTier<MainTierFields> = mainTierScheme.entries.associate { (tier, tierScheme) ->
    tier to MainTierFields(tierScheme.description, tierScheme.analysis)
  }
  val fieldPerTier: PerTier<ObjectEventField> = tiersDeclarations.entries.associate { (tier, tierFields) ->
    tier to ObjectEventField(tier.name, tierFields)
  }

  init {
    fieldPerTier.values.forEach { field(it) }
  }

  /**
   * Serializes the main tiers' instances of the session's current level.
   *
   * @throws IllegalArgumentException if the level contains an instance of a tier that was not declared in the scheme
   */
  override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
    val level = sessionStructure.level.main
    return level.entries.map { (tierInstance, data) ->
      val tierField = requireNotNull(fieldPerTier[tierInstance.tier]) {
        "Tier ${tierInstance.tier} is not allowed here: only ${fieldPerTier.keys} are registered"
      }
      // Declarations and fields are built from the same keys, so the declaration is present here
      val tierDeclaration = tiersDeclarations.getValue(tierInstance.tier)
      tierField with tierDeclaration.buildObjectEventData(tierInstance, data.description, data.analysis)
    }
  }
}
/**
 * FUS description of a level's additional tiers: one object field per tier, named after the tier.
 *
 * @param additionalTierScheme For each additional tier, its description's features
 */
private class AdditionalTierSet<P : Any>(additionalTierScheme: PerTier<AdditionalTierScheme>) : SessionFields<P>() {
  val tiersDeclarations: PerTier<AdditionalTierFields> = additionalTierScheme.entries.associate { (tier, tierScheme) ->
    tier to AdditionalTierFields(tierScheme.description)
  }
  val fieldPerTier: PerTier<ObjectEventField> = tiersDeclarations.entries.associate { (tier, tierFields) ->
    tier to ObjectEventField(tier.name, tierFields)
  }

  init {
    fieldPerTier.values.forEach { field(it) }
  }

  /**
   * Serializes the additional tiers' instances of the session's current level.
   *
   * @throws IllegalArgumentException if the level contains an instance of a tier that was not declared in the scheme
   */
  override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
    val level = sessionStructure.level.additional
    return level.entries.map { (tierInstance, data) ->
      val tierField = requireNotNull(fieldPerTier[tierInstance.tier]) {
        "Tier ${tierInstance.tier} is not allowed here: only ${fieldPerTier.keys} are registered"
      }
      // Declarations and fields are built from the same keys, so the declaration is present here
      val tierDeclaration = tiersDeclarations.getValue(tierInstance.tier)
      tierField with tierDeclaration.buildObjectEventData(tierInstance, data.description)
    }
  }
}
/**
 * FUS description of the deepest session level - the one that carries a prediction.
 * It consists of "main" and "additional" tiers, the "prediction",
 * and an optional "session" object with the session's analysis.
 *
 * @param declarationMainTierSet Declaration of the level's main tiers
 * @param declarationAdditionalTierSet Declaration of the level's additional tiers
 * @param predictionValidationRule FUS validation rule of the "prediction" field
 * @param predictionTransform Turns a (possibly absent) prediction into the logged string
 * @param sessionAnalysisFields Declaration of the session's analysis, or `null` if it is not logged at this level
 */
private data class PredictionSessionFields<P : Any>(
  val declarationMainTierSet: MainTierSet<P>,
  val declarationAdditionalTierSet: AdditionalTierSet<P>,
  val predictionValidationRule: List<String>,
  val predictionTransform: (P?) -> String,
  val sessionAnalysisFields: SessionAnalysisFields<P>?
) : SessionFields<P>() {
  private val fieldMainInstances = ObjectEventField("main", declarationMainTierSet)
  private val fieldAdditionalInstances = ObjectEventField("additional", declarationAdditionalTierSet)
  private val fieldPrediction = PredictionField("prediction", predictionValidationRule, predictionTransform)
  private val fieldSessionAnalysis = sessionAnalysisFields?.let { ObjectEventField("session", sessionAnalysisFields) }

  /** Builds the tiers' declarations from the [levelScheme], and the session's analysis declaration if it is given. */
  constructor(
    levelScheme: LevelScheme,
    predictionValidationRule: List<String>,
    predictionTransform: (P?) -> String,
    sessionAnalysisFields: Map<String, Set<FeatureDeclaration<*>>>?
  ) : this(
    MainTierSet(levelScheme.main),
    AdditionalTierSet(levelScheme.additional),
    predictionValidationRule,
    predictionTransform,
    sessionAnalysisFields?.let { SessionAnalysisFields(it) }
  )

  init {
    field(fieldMainInstances)
    field(fieldAdditionalInstances)
    field(fieldPrediction)
    if (fieldSessionAnalysis != null) {
      field(fieldSessionAnalysis)
    }
  }

  /**
   * Serializes a leaf of the session tree.
   *
   * @throws IllegalArgumentException if [sessionStructure] is not a leaf
   */
  override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
    require(sessionStructure is SessionTree.Leaf<*, *, P>)
    val eventPairs = mutableListOf<EventPair<*>>(
      fieldMainInstances with declarationMainTierSet.buildObjectEventData(sessionStructure),
      fieldAdditionalInstances with declarationAdditionalTierSet.buildObjectEventData(sessionStructure),
      fieldPrediction with sessionStructure.prediction,
    )
    // The field exists if and only if the analysis' declaration was given
    if (fieldSessionAnalysis != null && sessionAnalysisFields != null) {
      eventPairs += fieldSessionAnalysis with sessionAnalysisFields.buildObjectEventData(sessionStructure)
    }
    return eventPairs
  }
}
/**
 * FUS description of a session level that contains nested sessions.
 * It consists of "main" and "additional" tiers, the list of "nested" sessions,
 * and an optional "session" object with the session's analysis.
 *
 * @param declarationMainTierSet Declaration of the level's main tiers
 * @param declarationAdditionalTierSet Declaration of the level's additional tiers
 * @param declarationNestedSession Declaration of each nested session
 * @param sessionAnalysisFields Declaration of the session's analysis, or `null` if it is not logged at this level
 */
private data class NestableSessionFields<P : Any>(
  val declarationMainTierSet: MainTierSet<P>,
  val declarationAdditionalTierSet: AdditionalTierSet<P>,
  val declarationNestedSession: SessionFields<P>,
  val sessionAnalysisFields: SessionAnalysisFields<P>?
) : SessionFields<P>() {
  private val fieldMainInstances = ObjectEventField("main", declarationMainTierSet)
  private val fieldAdditionalInstances = ObjectEventField("additional", declarationAdditionalTierSet)
  private val fieldNestedSessions = ObjectListEventField("nested", declarationNestedSession)
  private val fieldSessionAnalysis = sessionAnalysisFields?.let { ObjectEventField("session", sessionAnalysisFields) }

  /**
   * Builds the declaration of this level and, recursively, of all the deeper ones.
   * The deepest level becomes [PredictionSessionFields]; deeper levels are declared without session analysis.
   *
   * @throws IllegalArgumentException if [deeperLevelsSchemes] is empty
   */
  constructor(
    levelScheme: LevelScheme,
    deeperLevelsSchemes: List<LevelScheme>,
    predictionValidationRule: List<String>,
    predictionTransform: (P?) -> String,
    sessionAnalysisFields: Map<String, Set<FeatureDeclaration<*>>>?
  ) : this(
    MainTierSet(levelScheme.main),
    AdditionalTierSet(levelScheme.additional),
    when (deeperLevelsSchemes.size) {
      1 -> PredictionSessionFields(deeperLevelsSchemes.first(), predictionValidationRule, predictionTransform, null)
      else -> {
        require(deeperLevelsSchemes.size > 1)
        NestableSessionFields(deeperLevelsSchemes.first(), deeperLevelsSchemes.drop(1), predictionValidationRule,
                              predictionTransform, null)
      }
    },
    sessionAnalysisFields?.let { SessionAnalysisFields(it) }
  )

  init {
    field(fieldMainInstances)
    field(fieldAdditionalInstances)
    field(fieldNestedSessions)
    if (fieldSessionAnalysis != null) {
      field(fieldSessionAnalysis)
    }
  }

  /**
   * Serializes a node of the session tree together with all its children.
   *
   * @throws IllegalArgumentException if [sessionStructure] does not contain children
   */
  override fun buildEventPairs(sessionStructure: AnalysedSessionTree<P>): List<EventPair<*>> {
    require(sessionStructure is SessionTree.ChildrenContainer)
    val children = sessionStructure.children
    val eventPairs = mutableListOf<EventPair<*>>(
      fieldMainInstances with declarationMainTierSet.buildObjectEventData(sessionStructure),
      fieldAdditionalInstances with declarationAdditionalTierSet.buildObjectEventData(sessionStructure),
      fieldNestedSessions with children.map { nestedSession -> declarationNestedSession.buildObjectEventData(nestedSession) }
    )
    // The field exists if and only if the analysis' declaration was given
    if (fieldSessionAnalysis != null && sessionAnalysisFields != null) {
      eventPairs += fieldSessionAnalysis with sessionAnalysisFields.buildObjectEventData(sessionStructure)
    }
    return eventPairs
  }
}

View File

@@ -0,0 +1,36 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.logger
import com.intellij.internal.statistic.eventLog.events.BooleanEventField
import com.intellij.internal.statistic.eventLog.events.ClassEventField
import com.intellij.internal.statistic.eventLog.events.EventField
import com.intellij.internal.statistic.eventLog.events.EventPair
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener
import com.intellij.platform.ml.impl.monitoring.MLApiStartupProcessListener
import org.jetbrains.annotations.ApiStatus
/**
 * Logs the "startup" event: whether the ML API platform has started successfully,
 * and the class of the exception if it has not.
 */
@ApiStatus.Internal
class MLApiPlatformStartupLogger : EventIdRecordingMLEvent(), MLApiStartupListener {
  override val eventName: String = "startup"

  // Every field that is passed to `log` must be declared, including the optional EXCEPTION
  override val declaration: Array<EventField<*>> = arrayOf(SUCCESS, EXCEPTION)

  /**
   * Resolves the event's id and returns a listener that logs the startup's outcome.
   */
  override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener {
    val eventId = getEventId(apiPlatform)
    return object : MLApiStartupProcessListener {
      override fun onFinished() = eventId.log(SUCCESS with true)

      override fun onFailed(exception: Throwable?) {
        val fields = mutableListOf<EventPair<*>>(SUCCESS with false)
        // The exception's class is logged only when the exception is known
        exception?.let { fields += EXCEPTION with it.javaClass }
        eventId.log(*fields.toTypedArray())
      }
    }
  }

  companion object {
    val SUCCESS = BooleanEventField("success")
    val EXCEPTION = ClassEventField("exception")
  }
}

View File

@@ -0,0 +1,121 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.logger
import com.intellij.internal.statistic.eventLog.EventLogGroup
import com.intellij.internal.statistic.eventLog.events.EventField
import com.intellij.internal.statistic.eventLog.events.VarargEventId
import com.intellij.internal.statistic.service.fus.collectors.CounterUsagesCollector
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.ensureApproachesInitialized
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform
import org.jetbrains.annotations.ApiStatus
/**
* An event of the ML FUS group.
*
* During the group's initialization in [MLEventsLogger], each event is registered
* by its [eventName] and [declaration], and then it is notified via [onEventGroupInitialized].
*/
@ApiStatus.Internal
interface MLEvent {
/** Name under which the event is registered in the event group. */
val eventName: String
/** Fields the event is registered with. */
val declaration: Array<EventField<*>>
/**
* Called after the event has been registered in the group.
*
* @param eventId The id that was created for this event during the registration
*/
fun onEventGroupInitialized(eventId: VarargEventId)
}
/**
 * An [MLEvent] that remembers the event id it was given during the group's initialization
 * and provides it to the inheritors via [getEventId].
 */
@ApiStatus.Internal
abstract class EventIdRecordingMLEvent : MLEvent {
  private var providedEventId: VarargEventId? = null

  final override fun onEventGroupInitialized(eventId: VarargEventId) {
    providedEventId = eventId
  }

  /**
   * Makes sure that the ML event group is initialized, and returns the id recorded for this event.
   *
   * @throws IllegalArgumentException if no event id has been recorded
   */
  protected fun getEventId(apiPlatform: MLApiPlatform): VarargEventId {
    MLEventsLogger.Manager.ensureInitialized(okIfInitializing = false, apiPlatform)
    val recordedEventId = providedEventId
    return requireNotNull(recordedEventId)
  }
}
/**
 * Logs ML sessions with tiers' descriptions and the analysis.
 * Each session is logged after it is finished, and all analysers [com.intellij.platform.ml.impl.session.analysis.SessionAnalyser]
 * have yielded the analysis.
 *
 * As the FUS logs' validators are initialized once during the application's start,
 * before giving the [getGroup] for the FUS to use, we must first walk through all declared
 * [com.intellij.platform.ml.impl.MLTaskApproach] in the platform and register them in the FUS scheme
 * (in case they require FUS logging).
 */
@ApiStatus.Internal
class MLEventsLogger : CounterUsagesCollector() {
  /** Returns the ML event group, initializing it first if that has not happened yet. */
  override fun getGroup(): EventLogGroup {
    val group = Initializer.GROUP
    Manager.ensureInitialized(okIfInitializing = false)
    return group
  }

  object Manager {
    private val defaultPlatform = ReplaceableIJPlatform

    /**
     * Initializes the event group from [apiPlatform] if it is not initialized yet.
     * If it is, verifies that the platform's static state equals the one the group was initialized with.
     *
     * @param okIfInitializing Whether a call made while the initialization is in progress should return silently
     * @throws IllegalStateException if the initialization is in progress and [okIfInitializing] is `false`
     * @throws Exception if the initialization fails, or has already failed before
     */
    internal fun ensureInitialized(okIfInitializing: Boolean, apiPlatform: MLApiPlatform = defaultPlatform) {
      when (val state = Initializer.state) {
        State.NonInitialized -> Initializer.initializeGroup(apiPlatform)
        State.Initializing -> {
          if (!okIfInitializing) throw IllegalStateException("Initialization recursion")
        }
        is State.FailedToInitialize -> throw Exception("ML Event Log already has failed to initialize", state.exception)
        is State.Initialized -> {
          val currentApiPlatformState = apiPlatform.staticState
          require(currentApiPlatformState == state.context.staticState) {
            """
FUS ML Logger was initialized from ${state.context.apiPlatform} with presumably immutable state ${state.context.staticState},
but it is used from ${apiPlatform} with state ${currentApiPlatformState}, which differs from the initial state.
Hence, something that was expected to be logged will not be.
""".trimIndent()
          }
        }
      }
    }
  }

  /** The platform the group was initialized from, and the platform's static state at that moment. */
  internal data class InitializationContext(
    val apiPlatform: MLApiPlatform,
    val staticState: MLApiPlatform.StaticState
  )

  /** Stages of the event group's initialization. */
  internal sealed interface State {
    data object NonInitialized : State
    data object Initializing : State
    class FailedToInitialize(val exception: Throwable) : State
    class Initialized(val context: InitializationContext) : State
  }

  internal object Initializer {
    val GROUP = EventLogGroup("ml", 1)
    var state: State = State.NonInitialized

    /**
     * Registers all events of [apiPlatform] in [GROUP] and hands each event its id.
     * On any failure, the failure is remembered in [state] and rethrown wrapped into an [Exception].
     *
     * @throws IllegalArgumentException if the initialization has already been attempted
     */
    fun initializeGroup(apiPlatform: MLApiPlatform) {
      require(state == State.NonInitialized)
      state = State.Initializing
      try {
        apiPlatform.ensureApproachesInitialized()
        // The state is captured before the default events are added to the platform
        val apiPlatformState = apiPlatform.staticState
        apiPlatform.addDefaultEventsAndListeners()
        for (event in apiPlatform.events) {
          val eventId = GROUP.registerVarargEvent(event.eventName, *event.declaration)
          event.onEventGroupInitialized(eventId)
        }
        state = State.Initialized(InitializationContext(apiPlatform, apiPlatformState))
      }
      catch (e: Throwable) {
        state = State.FailedToInitialize(e)
        throw Exception("Failed to initialize FUS ML Logger", e)
      }
    }

    /** Adds the startup logger to the platform both as an event and as a startup listener. */
    private fun MLApiPlatform.addDefaultEventsAndListeners() {
      MLApiPlatformStartupLogger().also { startupLogger ->
        addEvent(startupLogger)
        addStartupListener(startupLogger)
      }
    }
  }
}

View File

@@ -0,0 +1,98 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.logger
import com.intellij.internal.statistic.eventLog.events.EventField
import com.intellij.internal.statistic.eventLog.events.ObjectEventData
import com.intellij.internal.statistic.eventLog.events.ObjectEventField
import com.intellij.platform.ml.Session
import com.intellij.platform.ml.impl.MLTaskApproach
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.monitoring.*
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser
import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser.Companion.declarationObjectDescription
import org.jetbrains.annotations.ApiStatus
/**
 * Logs to FUS information about an ML session that has failed to start.
 *
 * A way to start logging is to create a [FailedSessionLoggerRegister] and add it via [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.addStartupListener]
 *
 * @param taskApproach The approach that is monitored
 * @param exceptionalAnalysers Analysers that are triggered when [MLTaskApproach.startSession] fails with an unhandled exception
 * @param normalFailureAnalysers Analysers that are triggered when [MLTaskApproach.startSession] returns a [Session.StartOutcome.Failure]
 * @param apiPlatform The platform that is used to resolve the event's id
 */
@ApiStatus.Internal
class MLSessionFailedLogger<M, P : Any>(
  private val taskApproach: MLTaskApproach<P>,
  private val exceptionalAnalysers: Collection<ShallowSessionAnalyser<Throwable>>,
  private val normalFailureAnalysers: Collection<ShallowSessionAnalyser<Session.StartOutcome.Failure<P>>>,
  private val apiPlatform: MLApiPlatform,
) : EventIdRecordingMLEvent(), MLTaskGroupListener {
  override val eventName: String = "${taskApproach.task.name}.failed"

  // One object field per analyser, addressed by the analyser's name
  private val fields: Map<String, ObjectEventField> = run {
    val exceptionalFields = exceptionalAnalysers.map { analyser ->
      ObjectEventField(analyser.name, analyser.declarationObjectDescription)
    }
    val normalFailureFields = normalFailureAnalysers.map { analyser ->
      ObjectEventField(analyser.name, analyser.declarationObjectDescription)
    }
    (exceptionalFields + normalFailureFields).associateBy { analyserField -> analyserField.name }
  }

  override val declaration: Array<EventField<*>> = fields.values.toTypedArray()

  /**
   * Listeners of the monitored approach. They run the corresponding analysers
   * on each failed attempt to start a session, and log the analysis.
   * The event's id is resolved each time this property is read.
   */
  override val approachListeners: Collection<MLTaskGroupListener.ApproachListeners<*, *>>
    get() {
      val eventId = getEventId(apiPlatform)
      return listOf(
        taskApproach.javaClass monitoredBy MLApproachInitializationListener { permanentSessionEnvironment ->
          object : MLApproachListener<M, P> {
            // Successfully started sessions are not monitored by this logger
            override fun onStartedSession(session: Session<P>) = null

            override fun onFailedToStartSessionWithException(exception: Throwable) {
              val analysisPairs = exceptionalAnalysers.map { analyser ->
                fields.getValue(analyser.name) with ObjectEventData(analyser.analyse(permanentSessionEnvironment, exception))
              }
              eventId.log(*analysisPairs.toTypedArray())
            }

            override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>) {
              val analysisPairs = normalFailureAnalysers.map { analyser ->
                fields.getValue(analyser.name) with ObjectEventData(analyser.analyse(permanentSessionEnvironment, failure))
              }
              eventId.log(*analysisPairs.toTypedArray())
            }
          }
        }
      )
    }
}
/**
 * Registers failed sessions' logging as a separate FUS event.
 *
 * See [MLSessionFailedLogger]
 *
 * @param targetApproachClass Class of the approach whose failed sessions are logged
 * @param exceptionalAnalysers Analysers that are passed to the created [MLSessionFailedLogger]
 * @param normalFailureAnalysers Analysers that are passed to the created [MLSessionFailedLogger]
 */
@ApiStatus.Internal
open class FailedSessionLoggerRegister<M, P : Any>(
  private val targetApproachClass: Class<out MLTaskApproach<P>>,
  private val exceptionalAnalysers: Collection<ShallowSessionAnalyser<Throwable>>,
  private val normalFailureAnalysers: Collection<ShallowSessionAnalyser<Session.StartOutcome.Failure<P>>>,
) : MLApiStartupListener {
  override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener {
    return object : MLApiStartupProcessListener {
      /**
       * Finds the initialized target approach, and adds the failed sessions' logger
       * to the platform both as an event and as a task listener.
       *
       * @throws IllegalArgumentException if the target approach is not among [initializedApproaches]
       */
      override fun onStartedInitializingFus(initializedApproaches: Collection<InitializerAndApproach<*>>) {
        val matchingEntry = initializedApproaches.find { it.approach.javaClass == targetApproachClass }
        val initializedEntry = requireNotNull(matchingEntry) {
          """
Could not create logger for failed sessions for $targetApproachClass in platform $apiPlatform, because the corresponding approach was not initialized.
Initialized approaches: $initializedApproaches
""".trimIndent()
        }
        @Suppress("UNCHECKED_CAST")
        val targetInitializedApproach: MLTaskApproach<P> = initializedEntry.approach as MLTaskApproach<P>
        val failedEventLogger = MLSessionFailedLogger<M, P>(targetInitializedApproach, exceptionalAnalysers, normalFailureAnalysers,
                                                            apiPlatform)
        apiPlatform.addMLEventBeforeFusInitialized(failedEventLogger)
        apiPlatform.addTaskListener(failedEventLogger)
      }
    }
  }
}

View File

@@ -0,0 +1,89 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.logger
import com.intellij.internal.statistic.eventLog.events.EventField
import com.intellij.platform.ml.Environment
import com.intellij.platform.ml.Session
import com.intellij.platform.ml.impl.MLTaskApproach
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.monitoring.*
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
import com.intellij.platform.ml.impl.session.DescribedRootContainer
import org.jetbrains.annotations.ApiStatus
/**
 * Logs to FUS information about a finished ML session: tier instances' descriptions and the analysis.
 *
 * A way to start logging is to create a [FinishedSessionLoggerRegister] and add it with [com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform.addStartupListener].
 * If you are using a regression model, see [InplaceFeaturesScheme.FusScheme.Companion.DOUBLE].
 *
 * @param approach The approach whose sessions are monitored
 * @param configuration The particular scheme that is used to serialize sessions
 * @param apiPlatform The platform that is used to resolve the event's id
 */
@ApiStatus.Internal
class MLSessionFinishedLogger<R, P : Any>(
  approach: MLTaskApproach<P>,
  configuration: FusSessionEventBuilder.FusScheme<P>,
  private val apiPlatform: MLApiPlatform,
) : EventIdRecordingMLEvent(), MLTaskGroupListener {
  private val loggingScheme: FusSessionEventBuilder<P> = configuration.createEventBuilder(approach.approachDeclaration)
  private val fusDeclaration: SessionFields<P> = loggingScheme.buildFusDeclaration()
  override val declaration: Array<EventField<*>> = fusDeclaration.getFields()
  override val eventName: String = "${approach.task.name}.finished"
  override val approachListeners: Collection<MLTaskGroupListener.ApproachListeners<*, *>> = listOf(
    approach.javaClass monitoredBy InitializationLogger()
  )

  /** Creates an [ApproachLogger] for each attempt to start a session. */
  inner class InitializationLogger : MLApproachInitializationListener<R, P> {
    override fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener<R, P> {
      return ApproachLogger()
    }
  }

  /** Ignores the failures to start, and creates a [SessionLogger] for each started session. */
  inner class ApproachLogger : MLApproachListener<R, P> {
    override fun onFailedToStartSessionWithException(exception: Throwable) = Unit

    override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>) = Unit

    override fun onStartedSession(session: Session<P>): MLSessionListener<R, P> {
      return SessionLogger()
    }
  }

  /** Logs the session once its analysis is finished. */
  inner class SessionLogger : MLSessionListener<R, P> {
    override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<R, P>) = Unit

    override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<P>) {
      // The id is resolved first, so that the group's initialization is ensured before the record is built
      val eventId = getEventId(apiPlatform)
      eventId.log(*loggingScheme.buildRecord(sessionTree, fusDeclaration))
    }
  }
}
/**
 * Registers successful ML sessions' logging as a separate FUS event.
 *
 * See [MLSessionFinishedLogger]
 *
 * @param targetApproachClass Class of the approach whose finished sessions are logged
 * @param fusScheme The scheme that is passed to the created [MLSessionFinishedLogger]
 */
@ApiStatus.Internal
open class FinishedSessionLoggerRegister<M, P : Any>(
  private val targetApproachClass: Class<out MLTaskApproach<P>>,
  private val fusScheme: FusSessionEventBuilder.FusScheme<P>,
) : MLApiStartupListener {
  override fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener {
    return object : MLApiStartupProcessListener {
      /**
       * Finds the initialized target approach, and adds the finished sessions' logger
       * to the platform both as an event and as a task listener.
       *
       * @throws IllegalArgumentException if the target approach is not among [initializedApproaches]
       */
      override fun onStartedInitializingFus(initializedApproaches: Collection<InitializerAndApproach<*>>) {
        val matchingEntry = initializedApproaches.find { it.approach.javaClass == targetApproachClass }
        val initializedEntry = requireNotNull(matchingEntry) {
          """
Could not create logger for finished sessions of $targetApproachClass in platform $apiPlatform, because the corresponding approach was not initialized.
Initialized approaches: $initializedApproaches
""".trimIndent()
        }
        @Suppress("UNCHECKED_CAST")
        val targetInitializedApproach: MLTaskApproach<P> = initializedEntry.approach as MLTaskApproach<P>
        val finishedEventLogger = MLSessionFinishedLogger<M, P>(targetInitializedApproach, fusScheme, apiPlatform)
        apiPlatform.addMLEventBeforeFusInitialized(finishedEventLogger)
        apiPlatform.addTaskListener(finishedEventLogger)
      }
    }
  }
}

View File

@@ -0,0 +1,61 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.model
import com.intellij.platform.ml.Environment
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.PerTier
import com.intellij.platform.ml.TierRequester
import com.intellij.platform.ml.impl.FeatureSelector
import com.intellij.platform.ml.impl.LevelTiers
import org.jetbrains.annotations.ApiStatus
/**
* Performs a prediction based on the given features.
*
* @param P The prediction's type.
*/
@ApiStatus.Internal
interface MLModel<P : Any> {
/**
* Provides a model that will be used during an ML session.
*
* It extends the [TierRequester] interface, which implies that you can request
* additional tiers that will help you acquire the right model.
*/
interface Provider<M : MLModel<P>, P : Any> : TierRequester {
/**
* Provides an ML model from the task's environment and the additional tiers
* declared via [requiredTiers].
* If [requiredTiers] could not be fulfilled, then the session will not be started,
* and [com.intellij.platform.ml.impl.approach.InsufficientEnvironmentForModelProviderOutcome]
* will be returned as the start's outcome.
*
* @param sessionTiers Contains the main as well as additional tiers that will be used during the session.
* @param environment Contains the all-embracing "permanent" tiers of an ML session - the ones that sit on the first position
* of an [com.intellij.platform.ml.impl.MLTask]'s declaration,
* plus the additional tiers that were requested in [requiredTiers].
*
* NOTE(review): the meaning of a `null` result is not described here - confirm how callers treat it.
*/
fun provideModel(sessionTiers: List<LevelTiers>, environment: Environment): M?
}
/**
* Declares a set of features that are known and can be used by the ML model.
* For each tier, it returns a selection.
* Selection will be 'complete' if and only if the model could be executed with it.
*
* The set of tiers it is aware of may lack some additional tiers declared
* in [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]'s 'not used description',
* because old models may be unaware of newly established tiers.
*/
val knownFeatures: PerTier<FeatureSelector>
/**
* Performs the prediction.
*
* @param features Features computed for this model to take into account.
* It is guaranteed that for each tier, this set of features is a 'complete'
* selection of features.
*/
fun predict(features: PerTier<Set<Feature>>): P
}

View File

@@ -0,0 +1,170 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.model
import com.intellij.internal.ml.DecisionFunction
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.PerTier
import com.intellij.platform.ml.Tier
import com.intellij.platform.ml.impl.FeatureSelector
import com.intellij.platform.ml.impl.LevelTiers
import org.jetbrains.annotations.ApiStatus
/**
* A wrapper for using legacy [DecisionFunction] as the new API's [MLModel].
*/
@ApiStatus.Internal
open class RegressionModel private constructor(
private val decisionFunction: DecisionFunction,
private val featuresTiers: Set<Tier<*>>,
availableTiers: Set<Tier<*>>,
private val featureSerialization: FeatureNameSerialization
) : MLModel<Double> {
/**
 * Creates the model for the given session tiers.
 * The tiers the model takes features from are determined by deserializing
 * the names of the features the [decisionFunction] is aware of.
 */
constructor(decisionFunction: DecisionFunction,
            featureSerialization: FeatureNameSerialization,
            sessionTiers: List<LevelTiers>) : this(
  decisionFunction = decisionFunction,
  // The tiers' lookup does not depend on a particular feature, so it is built once for all of them
  featuresTiers = sessionTiers.flatten().associateBy { it.name }.let { tiersPerName ->
    decisionFunction.featuresOrder.map { featureMapper ->
      featureSerialization.deserialize(featureMapper.featureName, tiersPerName).first
    }.toSet()
  },
  availableTiers = sessionTiers.flatten(),
  featureSerialization = featureSerialization
)
// The selectors are created from the decision function's features and the tiers those features belong to
override val knownFeatures: PerTier<FeatureSelector> = run {
  val wrappedDecisionFunction = DecisionFunctionWrapper(decisionFunction, availableTiers, featureSerialization)
  createFeatureSelectors(wrappedDecisionFunction, featuresTiers)
}
/**
 * Puts the features' values into an array, in the order the decision function declares,
 * and runs the decision function on it.
 * A feature that is absent from [features] is passed to its mapper as `null`.
 *
 * @throws IllegalArgumentException if the tiers of [features] differ from the tiers this model needs
 */
override fun predict(features: PerTier<Set<Feature>>): Double {
  val featureMappers = decisionFunction.featuresOrder
  val modelInput = DoubleArray(featureMappers.size)

  val featurePerSerializedName = mutableMapOf<String, Feature>()
  for ((tier, tierFeatures) in features) {
    for (feature in tierFeatures) {
      featurePerSerializedName[featureSerialization.serialize(tier, feature.declaration.name)] = feature
    }
  }

  require(features.keys == featuresTiers) {
    "Given features tiers are ${features.keys}, but this model needs ${featuresTiers}"
  }

  featureMappers.forEachIndexed { position, featureMapper ->
    val feature = featurePerSerializedName[featureMapper.featureName]
    modelInput[position] = featureMapper.asArrayValue(feature?.value)
  }
  return decisionFunction.predict(modelInput)
}
interface FeatureNameSerialization {
fun serialize(tier: Tier<*>, featureName: String): String
fun deserialize(serializedFeatureName: String, availableTiersPerName: Map<String, Tier<*>>): Pair<Tier<*>, String>
}
class DefaultSerialization : FeatureNameSerialization {
private val SERIALIZED_FEATURE_SEPARATOR = '/'
override fun serialize(tier: Tier<*>, featureName: String): String {
return "${tier.name}${SERIALIZED_FEATURE_SEPARATOR}${featureName}"
}
override fun deserialize(serializedFeatureName: String, availableTiersPerName: Map<String, Tier<*>>): Pair<Tier<*>, String> {
val indexOfLastSeparator = serializedFeatureName.indexOfLast { it == SERIALIZED_FEATURE_SEPARATOR }
require(indexOfLastSeparator >= 0) { "Feature name '$serializedFeatureName' does not contain tier's name" }
val featureTierName = serializedFeatureName.slice(0 until indexOfLastSeparator)
val featureName = serializedFeatureName.slice(indexOfLastSeparator until serializedFeatureName.length)
val featureTier = requireNotNull(
availableTiersPerName[featureTierName]) { "Serialized feature '$serializedFeatureName' has tier $featureTierName, but all available tiers are ${availableTiersPerName.keys}" }
return featureTier to featureName
}
}
class SelectionMissingFeatures(
selectedFeatures: Set<FeatureDeclaration<*>>,
missingFeatures: Set<String>
) : FeatureSelector.Selection.Incomplete(selectedFeatures) {
override val details: String = "Regression model requires more features to run. " +
"Missing: $missingFeatures, " +
"Has: $selectedFeatures"
}
private class DecisionFunctionWrapper(
private val decisionFunction: DecisionFunction,
private val availableTiers: Set<Tier<*>>,
private val featureNameSerialization: FeatureNameSerialization
) {
private val availableTiersPerName: Map<String, Tier<*>> = availableTiers.associateBy { it.name }
fun getKnownFeatures(): PerTier<Set<String>> {
val knownFeaturesSerializedNames = decisionFunction.featuresOrder.map { it.featureName }.toSet()
return knownFeaturesSerializedNames
.map { featureNameSerialization.deserialize(it, availableTiersPerName) }
.groupBy({ it.first }, { it.second })
.mapValues { it.value.toSet() }
}
fun getRequiredFeaturesPerTier(): PerTier<Set<String>> {
val availableTiersPerName = availableTiers.associateBy { it.name }
val requiredFeaturesSerializedNames = decisionFunction.requiredFeatures.filterNotNull().toSet()
return requiredFeaturesSerializedNames
.map { serializedFeatureName -> featureNameSerialization.deserialize(serializedFeatureName, availableTiersPerName) }
.groupBy({ it.first }, { it.second })
.mapValues { it.value.toSet() }
}
fun getUnknownFeatures(tier: Tier<*>, featuresNames: Set<String>): Set<String> {
val featureNamePerSerializedName = featuresNames
.associateBy { featureNameSerialization.serialize(tier, it) }
val unknownFeaturesSerializedNames = decisionFunction.getUnknownFeatures(featureNamePerSerializedName.keys).filterNotNull()
return unknownFeaturesSerializedNames.map {
requireNotNull(featureNamePerSerializedName[it]) { "Decision function returned an unknown feature that was not given: '$it'" }
}.toSet()
}
}
companion object {
private fun createFeatureSelectors(decisionFunction: DecisionFunctionWrapper,
featuresTiers: Set<Tier<*>>): PerTier<FeatureSelector> {
val requiredFeaturesPerTier = decisionFunction.getRequiredFeaturesPerTier()
fun createFeatureSelector(tier: Tier<*>) = object : FeatureSelector {
init {
val knownFeatures = decisionFunction.getKnownFeatures()
knownFeatures.forEach { (tier, tierFeatures) ->
val nonConsistentlyKnownFeatures = decisionFunction.getUnknownFeatures(tier, tierFeatures)
require(nonConsistentlyKnownFeatures.isEmpty()) {
"These features are known and unknown at the same time: $nonConsistentlyKnownFeatures"
}
}
}
override fun select(availableFeatures: Set<FeatureDeclaration<*>>): FeatureSelector.Selection {
val availableFeaturesPerName = availableFeatures.associateBy { it.name }
val availableFeaturesNames = availableFeatures.map { it.name }.toSet()
val unknownFeaturesNames = decisionFunction.getUnknownFeatures(tier, availableFeaturesNames)
val knownAvailableFeaturesNames = availableFeaturesNames - unknownFeaturesNames
val knownAvailableFeatures = knownAvailableFeaturesNames.map { availableFeaturesPerName.getValue(it) }.toSet()
val requiredFeaturesNames = requiredFeaturesPerTier[tier] ?: emptySet()
return if (availableFeaturesNames.containsAll(requiredFeaturesNames))
FeatureSelector.Selection.Complete(knownAvailableFeatures)
else
SelectionMissingFeatures(knownAvailableFeatures, requiredFeaturesNames - availableFeaturesNames)
}
override fun select(featureDeclaration: FeatureDeclaration<*>): Boolean {
return decisionFunction.getUnknownFeatures(tier, setOf(featureDeclaration.name)).isEmpty()
}
}
return featuresTiers.associateWith { createFeatureSelector(it) }
}
}
}
/** Collects the main and the additional tiers of all the levels into one set. */
private fun List<LevelTiers>.flatten(): Set<Tier<*>> {
  val allTiers = linkedSetOf<Tier<*>>()
  for (levelTiers in this) {
    allTiers.addAll(levelTiers.main + levelTiers.additional)
  }
  return allTiers
}

View File

@@ -0,0 +1,29 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.monitoring
import com.intellij.platform.ml.impl.MLTaskApproach
import com.intellij.platform.ml.impl.MLTaskApproachInitializer
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import org.jetbrains.annotations.ApiStatus
/**
 * A listener of the ML API's startup.
 *
 * For a startup it produces a [MLApiStartupProcessListener], which receives the notifications about the startup's stages.
 */
@ApiStatus.Internal
fun interface MLApiStartupListener {
/**
 * NOTE(review): judging by the name, it is called before the startup of [apiPlatform] begins;
 * the caller is not visible from this file — confirm.
 *
 * @param apiPlatform The platform, whose API is being started.
 * @return The listener of this particular startup.
 */
fun onBeforeStarted(apiPlatform: MLApiPlatform): MLApiStartupProcessListener
}
/**
 * A pair of an approach's initializer and an approach, both parameterized with the same type [T].
 *
 * @property initializer The initializer of the approach.
 * NOTE(review): presumably it is the initializer that has produced [approach] — confirm against the code that creates the pair.
 * @property approach The approach itself.
 */
@ApiStatus.Internal
data class InitializerAndApproach<T : Any>(
val initializer: MLTaskApproachInitializer<T>,
val approach: MLTaskApproach<T>
)
/**
 * Receives notifications about the stages of one ML API's startup.
 *
 * Every callback has an empty default implementation, so an implementor overrides only the ones it needs.
 *
 * NOTE(review): the moments, when the callbacks are invoked, are defined by the caller that is not visible
 * from this file; the descriptions below are inferred from the callbacks' names — confirm.
 */
@ApiStatus.Internal
interface MLApiStartupProcessListener {
/** Presumably called when the initialization of the approaches has started. */
fun onStartedInitializingApproaches() {}
/**
 * Presumably called when the initialization of FUS has started.
 *
 * @param initializedApproaches The approaches, paired with their initializers.
 */
fun onStartedInitializingFus(initializedApproaches: Collection<InitializerAndApproach<*>>) {}
/** Presumably called when the startup has finished. */
fun onFinished() {}
/**
 * Presumably called when the startup has failed.
 *
 * @param exception The failure's cause, it may be null.
 */
fun onFailed(exception: Throwable?) {}
}

View File

@@ -0,0 +1,166 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.monitoring
import com.intellij.platform.ml.Environment
import com.intellij.platform.ml.Session
import com.intellij.platform.ml.impl.MLTaskApproach
import com.intellij.platform.ml.impl.monitoring.MLApproachInitializationListener.Companion.asJoinedListener
import com.intellij.platform.ml.impl.monitoring.MLApproachListener.Companion.asJoinedListener
import com.intellij.platform.ml.impl.monitoring.MLSessionListener.Companion.asJoinedListener
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
import com.intellij.platform.ml.impl.session.AnalysedRootContainer
import com.intellij.platform.ml.impl.session.DescribedRootContainer
import org.jetbrains.annotations.ApiStatus
/**
 * Provides listeners for a set of [com.intellij.platform.ml.impl.MLTaskApproach]
 *
 * Only [com.intellij.platform.ml.impl.approach.LogDrivenModelInference] and the subclasses, that are
 * calling [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] are monitored.
 */
@ApiStatus.Internal
interface MLTaskGroupListener {
  /**
   * For every approach this [MLTaskGroupListener] is interested in, contains the approach's class
   * together with a collection of [MLApproachInitializationListener]s.
   *
   * A convenient way to create such an entry is the
   * [com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy] infix function.
   */
  val approachListeners: Collection<ApproachListeners<*, *>>

  /**
   * A type-safe pair of an approach's class and the listeners of that approach.
   *
   * The constructor is internal: instances are created with [monitoredBy].
   */
  data class ApproachListeners<R, P : Any> internal constructor(
    val taskApproach: Class<out MLTaskApproach<P>>,
    val approachListener: Collection<MLApproachInitializationListener<R, P>>
  ) {
    companion object {
      /** Pairs the approach's class with a single listener. */
      infix fun <R, P : Any> Class<out MLTaskApproach<P>>.monitoredBy(
        approachListener: MLApproachInitializationListener<R, P>
      ): ApproachListeners<R, P> {
        return ApproachListeners(this, listOf(approachListener))
      }

      /** Pairs the approach's class with several listeners. */
      infix fun <R, P : Any> Class<out MLTaskApproach<P>>.monitoredBy(
        approachListeners: Collection<MLApproachInitializationListener<R, P>>
      ): ApproachListeners<R, P> {
        return ApproachListeners(this, approachListeners)
      }
    }
  }

  companion object {
    /** Classes of all the approaches, mentioned in [approachListeners]. */
    internal val MLTaskGroupListener.targetedApproaches: Set<Class<out MLTaskApproach<*>>>
      get() {
        val approachesClasses = approachListeners.map { it.taskApproach }
        return approachesClasses.toSet()
      }

    /**
     * Notifies the listeners registered for exactly the class of [taskApproach], that a session start was attempted.
     *
     * @return The joined listener of the session's start, null if none of the listeners returned one.
     */
    internal fun <P : Any, R> MLTaskGroupListener.onAttemptedToStartSession(taskApproach: MLTaskApproach<P>,
                                                                            permanentSessionEnvironment: Environment): MLApproachListener<R, P>? {
      // The entries are matched by the approach's class, so the listeners' type parameters correspond to the approach's ones.
      @Suppress("UNCHECKED_CAST")
      val matchingListeners: List<MLApproachInitializationListener<R, P>> = approachListeners
        .filter { it.taskApproach == taskApproach.javaClass }
        .flatMap { it.approachListener } as List<MLApproachInitializationListener<R, P>>
      val joinedListener = matchingListeners.asJoinedListener()
      return joinedListener.onAttemptedToStartSession(permanentSessionEnvironment)
    }
  }
}
/**
 * Listens to the attempts of starting a new [Session] of the [MLTaskApproach], that this listener was put
 * into correspondence to via [com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy]
 */
@ApiStatus.Internal
fun interface MLApproachInitializationListener<R, P : Any> {
  /**
   * Called each time, when [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] is invoked
   *
   * @return A listener, that will be monitoring how successful the start was. If it is not needed, null is returned.
   */
  fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener<R, P>?

  companion object {
    /**
     * Makes one listener out of many: it forwards the event to every listener of the collection,
     * and joins the non-null listeners they return. If all of them returned null, null is returned.
     */
    fun <R, P : Any> Collection<MLApproachInitializationListener<R, P>>.asJoinedListener(): MLApproachInitializationListener<R, P> {
      val initializationListeners = this
      return MLApproachInitializationListener { permanentSessionEnvironment ->
        val startListeners = initializationListeners.mapNotNull { it.onAttemptedToStartSession(permanentSessionEnvironment) }
        if (startListeners.isNotEmpty()) startListeners.asJoinedListener() else null
      }
    }
  }
}
/**
 * Listens to the process of starting new [Session] of [com.intellij.platform.ml.impl.approach.LogDrivenModelInference].
 */
@ApiStatus.Internal
interface MLApproachListener<R, P : Any> {
  /**
   * Called if the session was not started,
   * on exceptionally rare occasions,
   * when the [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.startSession] failed with an exception
   */
  fun onFailedToStartSessionWithException(exception: Throwable)

  /**
   * Called if the session was not started,
   * but the failure is 'ordinary'.
   */
  fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>)

  /**
   * Called when a new [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]'s session was started successfully.
   *
   * @return A listener for tracking the session's progress, null if the session will not be tracked.
   */
  fun onStartedSession(session: Session<P>): MLSessionListener<R, P>?

  companion object {
    /**
     * Makes one listener out of many: every event is forwarded to each listener of the collection.
     * The session listeners they return are joined as well, null is returned if there are none.
     */
    fun <R, P : Any> Collection<MLApproachListener<R, P>>.asJoinedListener(): MLApproachListener<R, P> {
      val joinedListeners = this
      return object : MLApproachListener<R, P> {
        override fun onFailedToStartSessionWithException(exception: Throwable) {
          for (listener in joinedListeners) {
            listener.onFailedToStartSessionWithException(exception)
          }
        }

        override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<P>) {
          for (listener in joinedListeners) {
            listener.onFailedToStartSession(failure)
          }
        }

        override fun onStartedSession(session: Session<P>): MLSessionListener<R, P>? {
          val sessionListeners = joinedListeners.mapNotNull { it.onStartedSession(session) }
          return if (sessionListeners.isNotEmpty()) sessionListeners.asJoinedListener() else null
        }
      }
    }
  }
}
/**
 * Listens to session events of a [com.intellij.platform.ml.impl.approach.LogDrivenModelInference]
 */
@ApiStatus.Internal
interface MLSessionListener<R, P : Any> {
  /**
   * All tier instances were established (the tree will not be growing further),
   * described, and predictions in the [sessionTree] were finished.
   */
  fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<R, P>)

  /**
   * Called only after [onSessionDescriptionFinished]
   *
   * All tree nodes were analyzed.
   */
  fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<P>)

  companion object {
    /** Makes one listener out of many: every event is forwarded to each listener of the collection. */
    fun <R, P : Any> Collection<MLSessionListener<R, P>>.asJoinedListener(): MLSessionListener<R, P> {
      val joinedListeners = this
      return object : MLSessionListener<R, P> {
        override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<R, P>) {
          for (listener in joinedListeners) {
            listener.onSessionDescriptionFinished(sessionTree)
          }
        }

        override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<P>) {
          for (listener in joinedListeners) {
            listener.onSessionAnalysisFinished(sessionTree)
          }
        }
      }
    }
  }
}

View File

@@ -0,0 +1,182 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session
import com.intellij.platform.ml.*
import com.intellij.platform.ml.ScopeEnvironment.Companion.restrictedBy
import com.intellij.platform.ml.TierRequester.Companion.fulfilledBy
import com.intellij.platform.ml.impl.DescriptionComputer
import com.intellij.platform.ml.impl.FeatureSelector
import com.intellij.platform.ml.impl.FeatureSelector.Companion.or
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.Companion.getDescriptorsOfTiers
import com.intellij.platform.ml.impl.environment.ExtendedEnvironment
import org.jetbrains.annotations.ApiStatus
/**
 * Describes the tiers of a session's level: finds the tiers' descriptors, computes their features,
 * and partitions the features by whether they were declared, and by whether they are used by the ML model.
 *
 * @property apiPlatform Provides the environment extenders and the tiers' descriptors.
 * @property descriptionComputer Runs the descriptors.
 * @property usedFeaturesSelectors Per tier, selects the features that are used by the ML model.
 * @property notUsedFeaturesSelectors Per tier, selects the features that are computed, but not used by the ML model.
 */
@ApiStatus.Internal
data class LevelDescriptor(
  val apiPlatform: MLApiPlatform,
  val descriptionComputer: DescriptionComputer,
  val usedFeaturesSelectors: PerTier<FeatureSelector>,
  val notUsedFeaturesSelectors: PerTier<FeatureSelector>,
) {
  /**
   * Describes the next level's main tier instances, and the additional tiers.
   *
   * @param upperLevels Already described levels: their main tier instances are put to the environment of the description.
   * @param nextLevelMainEnvironment Main tier instances of the level to describe.
   * @param nextLevelAdditionalTiers Tiers to describe in addition to the main ones, they are taken from the extended environment.
   */
  fun describe(
    upperLevels: List<DescribedLevel>,
    nextLevelMainEnvironment: Environment,
    nextLevelAdditionalTiers: Set<Tier<*>>
  ): DescribedLevel {
    val upperMainEnvironment = Environment.of(upperLevels.flatMap { it.main.keys })
    val mainEnvironment = Environment.joined(listOf(upperMainEnvironment, nextLevelMainEnvironment))
    val availableEnvironment = ExtendedEnvironment(apiPlatform.environmentExtenders, mainEnvironment)
    val extendedEnvironment = availableEnvironment.restrictedBy(nextLevelMainEnvironment.tiers + nextLevelAdditionalTiers)

    // Firstly, pick the descriptors that could be run for every tier; then run them, tier by tier.
    val descriptionPerTier = apiPlatform.getDescriptorsOfTiers(extendedEnvironment.tiers)
      .mapValues { (_, descriptors) -> descriptors.fulfilledBy(availableEnvironment) }
      .mapValues { (tier, tierDescriptors) -> describeTier(tier, tierDescriptors, availableEnvironment) }

    val mainData = nextLevelMainEnvironment.tierInstances.associateWith { mainTierInstance ->
      DescribedTierData(descriptionPerTier.getValue(mainTierInstance.tier))
    }
    val additionalData = extendedEnvironment.restrictedBy(nextLevelAdditionalTiers).tierInstances.associateWith { additionalTierInstance ->
      DescribedTierData(descriptionPerTier.getValue(additionalTierInstance.tier))
    }
    return DescribedLevel(main = mainData, additional = additionalData)
  }

  /**
   * Creates the filter of the features that must be computed for [tier]: the used, and the not used ones together.
   * If there is an obsolete descriptor among [tierDescriptors], then everything is accepted.
   *
   * @throws IncompleteDescriptionException If the selection among the computable features is incomplete.
   */
  private fun createFilterOfFeaturesToCompute(tier: Tier<*>, tierDescriptors: List<TierDescriptor>): FeatureFilter {
    val hasObsoleteDescriptor = tierDescriptors.any { it is ObsoleteTierDescriptor }
    if (hasObsoleteDescriptor) return FeatureFilter.ACCEPT_ALL

    val usedFeaturesSelector = usedFeaturesSelectors[tier] ?: FeatureSelector.NOTHING
    val toComputeSelector = usedFeaturesSelector or notUsedFeaturesSelectors.getValue(tier)
    val computableFeatures = tierDescriptors.flatMap { it.descriptionDeclaration }.toSet()
    val toComputeSelection = toComputeSelector.select(computableFeatures)
    if (toComputeSelection is FeatureSelector.Selection.Incomplete) {
      throw IncompleteDescriptionException(tier, toComputeSelection.selectedFeatures, toComputeSelection.details)
    }
    return FeatureFilter { featureDeclaration -> featureDeclaration in toComputeSelection.selectedFeatures }
  }

  // Filter assumes that the feature is either used or not used by the model
  private fun createFilterOfUsedFeatures(tier: Tier<*>): FeatureFilter {
    val usedSelector = usedFeaturesSelectors[tier] ?: FeatureSelector.NOTHING
    val notUsedSelector = notUsedFeaturesSelectors.getValue(tier)
    return FeatureFilter { featureDeclaration ->
      val featureIsUsed = usedSelector.select(featureDeclaration)
      val featureIsNotUsed = notUsedSelector.select(featureDeclaration)
      assert(featureIsUsed || featureIsNotUsed) {
        "Feature $featureDeclaration of $tier must not have been computed. It is not used by the ML model or marked as not used"
      }
      require(featureIsUsed xor featureIsNotUsed) {
        "${featureDeclaration} of $tier is used by the ML model, but marked as not used at the same time"
      }
      featureIsUsed
    }
  }

  /** Splits the features to the ones accepted by [usableFeaturesFilter] (used), and the rest (not used). */
  private fun Set<Feature>.splitByUsage(usableFeaturesFilter: FeatureFilter): Usage<Set<Feature>> {
    val (used, notUsed) = partition { usableFeaturesFilter.accept(it.declaration) }
    return Usage(used.toSet(), notUsed.toSet())
  }

  /**
   * Validates the description computed by [descriptor], and partitions it by declaredness and usage.
   *
   * Only an [ObsoleteTierDescriptor] is allowed to compute features it has not declared:
   * they are passed to [MLApiPlatform.manageNonDeclaredFeatures]. For any other descriptor it is an error.
   */
  private fun makeDescriptionPartition(descriptor: TierDescriptor,
                                       computedDescription: Set<Feature>,
                                       usedFeaturesFilter: FeatureFilter): DescriptionPartition {
    val computedDeclarations = computedDescription.map { it.declaration }.toSet()

    when (descriptor) {
      is ObsoleteTierDescriptor -> {
        val nonDeclaredDescription = computedDescription.filter { it.declaration !in descriptor.partialDescriptionDeclaration }.toSet()
        apiPlatform.manageNonDeclaredFeatures(descriptor, nonDeclaredDescription)
      }
      else -> {
        val notDeclaredDeclarations = computedDeclarations - descriptor.descriptionDeclaration
        require(notDeclaredDeclarations.isEmpty()) {
          """
            $descriptor described environment with some features that were not declared:
            ${notDeclaredDeclarations.map { it.name }}
            computed declaration: ${computedDeclarations}
            declared declaration: ${descriptor.descriptionDeclaration}
          """.trimIndent()
        }
      }
    }

    // Every declared feature, that the model uses, must have been computed.
    val maybePartialDeclarations = when (descriptor) {
      is ObsoleteTierDescriptor -> descriptor.partialDescriptionDeclaration
      else -> descriptor.descriptionDeclaration
    }
    (maybePartialDeclarations - computedDeclarations).forEach { notComputedDeclaration ->
      require(!usedFeaturesFilter.accept(notComputedDeclaration)) {
        "Feature ${notComputedDeclaration} was expected to be computed by $descriptor, " +
        "because was declared and accepted by the feature filter. Computed declaration: $computedDeclarations"
      }
    }

    val (declaredFeatures, nonDeclaredFeatures) = when (descriptor) {
      is ObsoleteTierDescriptor -> computedDescription.partition { it.declaration in descriptor.partialDescriptionDeclaration }
      else -> computedDescription.toList() to emptyList<Feature>()
    }
    return Declaredness(
      declaredFeatures.toSet().splitByUsage(usedFeaturesFilter),
      nonDeclaredFeatures.toSet().splitByUsage(usedFeaturesFilter)
    )
  }

  /** Unites the corresponding parts of two partitions. */
  private fun DescriptionPartition.mergedWith(other: DescriptionPartition): DescriptionPartition {
    return Declaredness(
      declared = Usage(used = this.declared.used + other.declared.used,
                       notUsed = this.declared.notUsed + other.declared.notUsed),
      nonDeclared = Usage(used = this.nonDeclared.used + other.nonDeclared.used,
                          notUsed = this.nonDeclared.notUsed + other.nonDeclared.notUsed)
    )
  }

  /**
   * Runs the descriptors of [tier] that could be useful for the features to compute,
   * and unites the partitions of all the computed descriptions.
   */
  private fun describeTier(tier: Tier<*>, tierDescriptors: List<TierDescriptor>, environment: Environment): DescriptionPartition {
    val toComputeFilter = createFilterOfFeaturesToCompute(tier, tierDescriptors)
    val usefulTierDescriptors = tierDescriptors.filter { it.couldBeUseful(toComputeFilter) }
    val description: Map<TierDescriptor, Set<Feature>> = descriptionComputer.computeDescription(
      tier,
      usefulTierDescriptors,
      environment,
      toComputeFilter
    )

    val usedByModelFilter = createFilterOfUsedFeatures(tier)
    val partitions = description.entries.map { (tierDescriptor, computedDescription) ->
      makeDescriptionPartition(tierDescriptor, computedDescription, usedByModelFilter)
    }
    // Nothing was computed: the partition is empty.
    return partitions.reduceOrNull { first, second -> first.mergedWith(second) }
           ?: Declaredness(Usage(emptySet(), emptySet()), Usage(emptySet(), emptySet()))
  }
}
/**
 * Thrown when the features, that could be computed for a tier, do not make a complete selection.
 *
 * @param tier The tier, whose description is not sufficient.
 * @param selectedFeatures The features that were selected.
 * @param missingFeaturesDetails The explanation of why the selection is incomplete.
 */
@ApiStatus.Internal
class IncompleteDescriptionException(
  tier: Tier<*>,
  selectedFeatures: Set<FeatureDeclaration<*>>,
  missingFeaturesDetails: String
) : Exception() {
  override val message: String =
    "Computable description of tier $tier is not sufficient: $missingFeaturesDetails. " +
    "Computed features: $selectedFeatures"
}

View File

@@ -0,0 +1,231 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session
import com.intellij.platform.ml.*
import com.intellij.platform.ml.impl.DescriptionComputer
import com.intellij.platform.ml.impl.FeatureSelector
import com.intellij.platform.ml.impl.LevelTiers
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.model.MLModel
import org.jetbrains.annotations.ApiStatus
import java.util.concurrent.CompletableFuture
/**
 * Collects the root of a session's tree, that has nested levels.
 *
 * On creation, it describes the first level, and then validates the tiers of the model and of the not used features' selectors.
 */
@ApiStatus.Internal
class RootCollector<M : MLModel<P>, P : Any>(
  apiPlatform: MLApiPlatform,
  levelsTiers: List<LevelTiers>,
  descriptionComputer: DescriptionComputer,
  notUsedFeaturesSelectors: PerTier<FeatureSelector>,
  levelMainEnvironment: Environment,
  levelAdditionalTiers: Set<Tier<*>>,
  private val mlModel: M
) : NestableStructureCollector<SessionTree.ComplexRoot<M, DescribedLevel, P>, M, P>() {
  override val levelDescriptor = LevelDescriptor(apiPlatform, descriptionComputer, mlModel.knownFeatures, notUsedFeaturesSelectors)

  override val levelPositioning = LevelPositioning.superior(levelsTiers, levelDescriptor, levelMainEnvironment, levelAdditionalTiers)

  init {
    validateSuperiorCollector(levelsTiers, levelMainEnvironment, levelAdditionalTiers, mlModel, notUsedFeaturesSelectors)
  }

  /** Creates the root node, that carries the ML model, this level's description, and the nested trees. */
  override fun createTree(
    thisLevel: DescribedLevel,
    collectedNestedStructureTrees: List<DescribedSessionTree<M, P>>
  ): SessionTree.ComplexRoot<M, DescribedLevel, P> =
    SessionTree.ComplexRoot(mlModel, thisLevel, collectedNestedStructureTrees)
}
/**
 * Collects a session's tree, that consists of the only node: the root, which is a leaf with a prediction at the same time.
 *
 * On creation, it describes the only level, and then validates the tiers of the model and of the not used features' selectors.
 */
@ApiStatus.Internal
class SolitaryLeafCollector<M : MLModel<P>, P : Any>(
  apiPlatform: MLApiPlatform,
  levelScheme: LevelTiers,
  descriptionComputer: DescriptionComputer,
  notUsedFeaturesSelectors: PerTier<FeatureSelector>,
  levelMainEnvironment: Environment,
  levelAdditionalTiers: Set<Tier<*>>,
  private val mlModel: M
) : PredictionCollector<SessionTree.SolitaryLeaf<M, DescribedLevel, P>, M, P>() {
  override val levelDescriptor = LevelDescriptor(apiPlatform, descriptionComputer, mlModel.knownFeatures, notUsedFeaturesSelectors)

  override val levelPositioning = LevelPositioning.superior(listOf(levelScheme), levelDescriptor, levelMainEnvironment,
                                                            levelAdditionalTiers)

  init {
    validateSuperiorCollector(listOf(levelScheme), levelMainEnvironment, levelAdditionalTiers, mlModel, notUsedFeaturesSelectors)
  }

  /**
   * Creates the tree's only node out of the given level and the prediction.
   *
   * The level is taken from [thisLevel], like the other collectors of this file do,
   * rather than being re-read from [levelPositioning].
   */
  override fun createTree(thisLevel: DescribedLevel, prediction: P?): SessionTree.SolitaryLeaf<M, DescribedLevel, P> {
    return SessionTree.SolitaryLeaf(mlModel, thisLevel, prediction)
  }
}
/**
 * Collects an intermediate node of a session's tree: it has nested trees, but it is not the root.
 */
@ApiStatus.Internal
class BranchingCollector<M : MLModel<P>, P : Any>(
  override val levelDescriptor: LevelDescriptor,
  override val levelPositioning: LevelPositioning
) : NestableStructureCollector<SessionTree.Branching<M, DescribedLevel, P>, M, P>() {
  /** Creates the branching node out of this level's description and the nested trees. */
  override fun createTree(
    thisLevel: DescribedLevel,
    collectedNestedStructureTrees: List<DescribedSessionTree<M, P>>
  ): SessionTree.Branching<M, DescribedLevel, P> =
    SessionTree.Branching(thisLevel, collectedNestedStructureTrees)
}
/**
 * Collects a node of a session's tree that contains a prediction.
 *
 * The tree is created and passed to the handlers, when the prediction is submitted via [submitPrediction].
 */
@ApiStatus.Internal
abstract class PredictionCollector<T : SessionTree.PredictionContainer<M, DescribedLevel, P>, M : MLModel<P>, P : Any> : StructureCollector<T, M, P>() {
  private var predictionSubmitted = false
  private var submittedPrediction: P? = null

  /** Creates the tree's node out of the level's description, and the prediction (that may be null). */
  abstract fun createTree(thisLevel: DescribedLevel, prediction: P?): T

  /**
   * The features of this level and of all the upper ones, that are used by the ML model
   * (both the declared, and the non-declared ones).
   */
  val usableDescription: PerTier<Set<Feature>>
    get() = levelPositioning.levels.extractDescriptionForModel()

  /**
   * Saves the prediction, creates the tree and passes it to the handlers.
   *
   * @throws IllegalArgumentException If a prediction has already been submitted.
   */
  fun submitPrediction(prediction: P?) {
    require(!predictionSubmitted)
    predictionSubmitted = true
    submittedPrediction = prediction
    val collectedTree = createTree(levelPositioning.thisLevel, submittedPrediction)
    submitTreeToHandlers(collectedTree)
  }

  private fun PerTierInstance<DescribedTierData>.extractDescriptionForModel(): PerTier<Set<Feature>> {
    return entries.associate { (tierInstance, tierData) ->
      val description = tierData.description
      tierInstance.tier to (description.declared.used + description.nonDeclared.used)
    }
  }

  private fun DescribedLevel.extractDescriptionForModel(): PerTier<Set<Feature>> {
    val mainDescription = main.extractDescriptionForModel()
    val additionalDescription = additional.extractDescriptionForModel()
    // NOTE(review): the two maps are merged with `+` before the join, so the join receives a single map,
    // and a tier present in both of them is silently taken from the additional one.
    // Possibly `listOf(mainDescription, additionalDescription)` was meant — confirm against joinByUniqueTier.
    val levelDescription = mainDescription + additionalDescription
    return listOf(levelDescription).joinByUniqueTier()
  }

  private fun Iterable<DescribedLevel>.extractDescriptionForModel(): PerTier<Set<Feature>> {
    val descriptionPerLevel = map { level -> level.extractDescriptionForModel() }
    return descriptionPerLevel.joinByUniqueTier()
  }
}
/**
 * Collects a leaf of a session's tree: a node with a prediction, that is not the root.
 */
@ApiStatus.Internal
class LeafCollector<M : MLModel<P>, P : Any>(
  override val levelDescriptor: LevelDescriptor,
  override val levelPositioning: LevelPositioning
) : PredictionCollector<SessionTree.Leaf<M, DescribedLevel, P>, M, P>() {
  /**
   * Creates the leaf out of the given level and the prediction.
   *
   * The level is taken from [thisLevel], like the other collectors of this file do,
   * rather than being re-read from [levelPositioning].
   */
  override fun createTree(thisLevel: DescribedLevel, prediction: P?): SessionTree.Leaf<M, DescribedLevel, P> {
    return SessionTree.Leaf(thisLevel, prediction)
  }
}
/**
 * Validates the configuration of a collector of a tree's root.
 *
 * "All tiers" are the main and the additional tiers of every level, the tiers of [levelMainEnvironment],
 * and [levelAdditionalTiers]. It is required that
 *  - every tier, whose features the model knows, is among all tiers;
 *  - [notUsedDescriptionSelectors] are given for exactly all tiers.
 *
 * @throws IllegalArgumentException If one of the requirements is not met.
 */
private fun <M : MLModel<P>, P : Any> validateSuperiorCollector(levelsTiers: List<LevelTiers>,
                                                                levelMainEnvironment: Environment,
                                                                levelAdditionalTiers: Set<Tier<*>>,
                                                                mlModel: M,
                                                                notUsedDescriptionSelectors: PerTier<FeatureSelector>) {
  val allTiers = (levelsTiers.flatMap { it.main + it.additional } + levelMainEnvironment.tiers + levelAdditionalTiers).toSet()
  val usedFeaturesSelectors = mlModel.knownFeatures

  require(allTiers.containsAll(usedFeaturesSelectors.keys)) {
    "ML Model uses tiers that are not main or additional: ${usedFeaturesSelectors.keys - allTiers}"
  }
  require(notUsedDescriptionSelectors.keys == allTiers) {
    """
      Not used description's tiers must be the same as the task's tiers
      All tiers: ${allTiers}
      Missing: ${allTiers - notUsedDescriptionSelectors.keys}
      Redundant: ${notUsedDescriptionSelectors.keys - allTiers}
    """.trimIndent()
  }
}
/**
 * The position of a level within a session's tree.
 *
 * @property upperLevels Described levels above this one, starting from the root's level.
 * @property lowerTiers Tiers of the levels below this one, that have not been described yet.
 * @property thisLevel The description of this level.
 */
@ApiStatus.Internal
data class LevelPositioning(
  val upperLevels: List<DescribedLevel>,
  val lowerTiers: List<LevelTiers>,
  val thisLevel: DescribedLevel,
) {
  /** All the described levels, from the root's level down to this one inclusively. */
  val levels: List<DescribedLevel> = upperLevels + thisLevel

  /**
   * Describes the next level, and creates its positioning.
   *
   * For the nested level, the upper levels are [levels]: [upperLevels] together with [thisLevel].
   * The same list is passed to [LevelDescriptor.describe], so that the tier instances of this level
   * are available while the nested level is described.
   */
  fun nestNextLevel(
    levelDescriptor: LevelDescriptor,
    nextLevelMainEnvironment: Environment,
    nextLevelAdditionalTiers: Set<Tier<*>>
  ): LevelPositioning {
    val nestedUpperLevels = levels
    return LevelPositioning(
      upperLevels = nestedUpperLevels,
      lowerTiers = lowerTiers.drop(1),
      thisLevel = levelDescriptor.describe(nestedUpperLevels, nextLevelMainEnvironment, nextLevelAdditionalTiers)
    )
  }

  companion object {
    /**
     * Describes the topmost level, and creates its positioning: it has no upper levels,
     * and the tiers of all the levels but the first one are the lower ones.
     */
    fun superior(levelsTiers: List<LevelTiers>,
                 levelDescriptor: LevelDescriptor,
                 levelMainEnvironment: Environment,
                 levelAdditionalTiers: Set<Tier<*>>): LevelPositioning {
      return LevelPositioning(emptyList(), levelsTiers.drop(1),
                              levelDescriptor.describe(emptyList(), levelMainEnvironment, levelAdditionalTiers))
    }
  }
}
/**
 * The base of the collectors of a session tree's nodes.
 *
 * A collector keeps a list of handlers; when the collected tree is ready, it is passed to each of them
 * in the order of registration.
 */
@ApiStatus.Internal
sealed class StructureCollector<T : DescribedSessionTree<M, P>, M : MLModel<P>, P : Any> {
  protected abstract val levelDescriptor: LevelDescriptor

  abstract val levelPositioning: LevelPositioning

  private val sessionTreeHandlers: MutableList<SessionTreeHandler<in T, M, P>> = mutableListOf()

  /** Registers the handler, that will receive the collected tree. */
  fun handleCollectedTree(handler: SessionTreeHandler<in T, M, P>) {
    sessionTreeHandlers += handler
  }

  /** Passes the collected tree to the handlers that are registered at the moment of the call. */
  protected fun submitTreeToHandlers(sessionTree: T) {
    for (handler in sessionTreeHandlers) {
      handler.handleTree(sessionTree)
    }
  }
}
/**
 * Collects a node of a session's tree, that has nested nodes.
 *
 * The node's tree is created and passed to the handlers as soon as both conditions are met:
 *  - [onLastNestedCollectorCreated] was called, so no more nested collectors will be created;
 *  - every nested collector has submitted its tree.
 *
 * The conditions are checked in [onLastNestedCollectorCreated], and each time a nested collector submits its tree,
 * so the order in which they become true does not matter.
 */
@ApiStatus.Internal
abstract class NestableStructureCollector<T : DescribedChildrenContainer<M, P>, M : MLModel<P>, P : Any> : StructureCollector<T, M, P>() {
  private val nestedSessionsStructures: MutableList<CompletableFuture<DescribedSessionTree<M, P>>> = mutableListOf()
  private var nestingFinished = false
  private var structureSubmitted = false

  /**
   * Creates a collector of a nested node, that will have nested nodes itself.
   *
   * @throws IllegalArgumentException If the given tiers are not the ones expected on the next level.
   */
  fun nestBranch(levelMainEnvironment: Environment, levelAdditionalTiers: Set<Tier<*>>): BranchingCollector<M, P> {
    verifyNestedLevelEnvironment(levelMainEnvironment, levelAdditionalTiers)
    val nestedPositioning = levelPositioning.nestNextLevel(levelDescriptor, levelMainEnvironment, levelAdditionalTiers)
    return BranchingCollector<M, P>(levelDescriptor, nestedPositioning)
      .also { it.trackCollectedStructure() }
  }

  /**
   * Creates a collector of a nested node, that will contain a prediction.
   *
   * @throws IllegalArgumentException If the given tiers are not the ones expected on the next level.
   */
  fun nestPrediction(levelMainEnvironment: Environment, levelAdditionalTiers: Set<Tier<*>>): LeafCollector<M, P> {
    verifyNestedLevelEnvironment(levelMainEnvironment, levelAdditionalTiers)
    val nestedPositioning = levelPositioning.nestNextLevel(levelDescriptor, levelMainEnvironment, levelAdditionalTiers)
    return LeafCollector<M, P>(levelDescriptor, nestedPositioning)
      .also { it.trackCollectedStructure() }
  }

  /**
   * Declares that no more nested collectors will be created.
   *
   * @throws IllegalArgumentException If it was already called.
   */
  fun onLastNestedCollectorCreated() {
    require(!nestingFinished) { "The nesting has already been finished" }
    nestingFinished = true
    maybeSubmitStructure()
  }

  private fun <K : DescribedSessionTree<M, P>> StructureCollector<K, M, P>.trackCollectedStructure() {
    val collectedNestedTreeContainer = CompletableFuture<DescribedSessionTree<M, P>>()
    nestedSessionsStructures += collectedNestedTreeContainer
    this.handleCollectedTree {
      collectedNestedTreeContainer.complete(it)
      // The nested tree could have been collected after the nesting was finished:
      // then it is the last chance to submit this node's tree.
      maybeSubmitStructure()
    }
  }

  /** Creates and submits the tree, if the nesting is finished, and all the nested trees are collected. */
  private fun maybeSubmitStructure() {
    if (structureSubmitted || !nestingFinished) return
    if (nestedSessionsStructures.any { !it.isDone }) return
    structureSubmitted = true
    val collectedNestedStructureTrees: List<DescribedSessionTree<M, P>> = nestedSessionsStructures.map { it.get() }
    val describedSessionTree = createTree(levelPositioning.thisLevel, collectedNestedStructureTrees)
    submitTreeToHandlers(describedSessionTree)
  }

  /** Creates the node's tree out of this level's description, and the trees of the nested nodes. */
  protected abstract fun createTree(thisLevel: DescribedLevel, collectedNestedStructureTrees: List<DescribedSessionTree<M, P>>): T

  private fun verifyNestedLevelEnvironment(levelMainEnvironment: Environment, levelAdditionalTiers: Set<Tier<*>>) {
    val nextLevelTiers = requireNotNull(levelPositioning.lowerTiers.firstOrNull()) {
      "There are no lower levels left, so it is not possible to nest one more level"
    }
    require(nextLevelTiers.main == levelMainEnvironment.tiers) {
      "The nested level's main tiers must be ${nextLevelTiers.main}, but the given environment contains ${levelMainEnvironment.tiers}"
    }
    require(nextLevelTiers.additional == levelAdditionalTiers) {
      "The nested level's additional tiers must be ${nextLevelTiers.additional}, but the given ones are $levelAdditionalTiers"
    }
  }
}

View File

@@ -0,0 +1,275 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session
import com.intellij.platform.ml.Environment
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.PerTierInstance
import org.jetbrains.annotations.ApiStatus
/**
 * A partition of [Feature]s that indicates that they could be either used or not used by an ML model.
 *
 * @param T The type of each of the two parts.
 * @property used The part that is used by the ML model.
 * @property notUsed The part that is not used by the ML model.
 */
@ApiStatus.Internal
data class Usage<T>(
val used: T,
val notUsed: T,
)
/**
 * A partition of [Feature]s that indicates that they could be either declared (by [com.intellij.platform.ml.TierDescriptor]) or not.
 *
 * Used to process [com.intellij.platform.ml.ObsoleteTierDescriptor]'s partial descriptions.
 * It shall be removed when all obsolete descriptors have been transferred to the new API.
 *
 * @param T The type of each of the two parts.
 * @property declared The part that was declared.
 * @property nonDeclared The part that was not declared.
 */
@ApiStatus.Internal
data class Declaredness<T>(
val declared: T,
val nonDeclared: T
)
/**
 * Currently, a feature has two characteristics: whether it is declared (statically, in a tier descriptor),
 * and whether it is used by the ML model.
 * So, this container holds a separate set of features for each combination of the two:
 * firstly the features are split by [Declaredness], and then each part is split by [Usage].
 */
typealias DescriptionPartition = Declaredness<Usage<Set<Feature>>>
/**
 * A main tier has a description (that is used to run the ML model), and it also could contain analysis features
 *
 * @property description Declarations of the description's features.
 * @property analysis Declarations of the analysis features.
 */
@ApiStatus.Internal
data class MainTierScheme(
val description: Set<FeatureDeclaration<*>>,
val analysis: Set<FeatureDeclaration<*>>
)
/**
 * An additional tier is provided occasionally, and it has only description
 *
 * @property description Declarations of the description's features.
 */
@ApiStatus.Internal
data class AdditionalTierScheme(
val description: Set<FeatureDeclaration<*>>
)
/**
 * All the data that a described tier instance has.
 *
 * @property description The description's features, partitioned by declaredness and usage.
 */
@ApiStatus.Internal
data class DescribedTierData(
val description: DescriptionPartition,
)
/**
 * All the data that a described and analyzed tier instance has.
 *
 * @property description The description's features, partitioned by declaredness and usage.
 * @property analysis The analysis features.
 */
@ApiStatus.Internal
data class AnalysedTierData(
val description: DescriptionPartition,
val analysis: Set<Feature>
)
/**
 * A template for a container that contains data of main tier instances, as well as additional instances.
 * A Level is a collection of tiers that were declared on the same depth of an [com.intellij.platform.ml.impl.MLTask]'s declaration,
 * plus additional tiers, declared by [com.intellij.platform.ml.impl.approach.LogDrivenModelInference.additionallyDescribedTiers]
 * on the corresponding level.
 *
 * @param M The type of the main tier instances' data.
 * @param A The type of the additional tier instances' data.
 * @property main Data of the main tier instances.
 * @property additional Data of the additional tier instances.
 */
@ApiStatus.Internal
data class Level<M, A>(val main: M, val additional: A)
// A level whose main and additional tier instances carry described data only.
typealias DescribedLevel = Level<PerTierInstance<DescribedTierData>, PerTierInstance<DescribedTierData>>
// A level whose main tier instances carry described and analysed data, while the additional ones carry described data only.
typealias AnalysedLevel = Level<PerTierInstance<AnalysedTierData>, PerTierInstance<DescribedTierData>>
/**
 * Tree-like structure of an ML session.
 *
 * All leaves of the tree have the same depth.
 * The depth corresponds to the number of levels in an [com.intellij.platform.ml.impl.MLTask].
 * The tree's structure is built by calling [com.intellij.platform.ml.NestableMLSession.createNestedSession].
 *
 * @param RootT Type of the data that is stored in the root node.
 * @param LevelT Type of the data that is stored in each node of the tree.
 * @param PredictionT Type of the session's prediction.
 */
@ApiStatus.Internal
sealed interface SessionTree<RootT, LevelT, PredictionT> {
/**
 * Data that is stored in each node of the tree.
 */
val level: LevelT
/**
 * Accepts the [visitor] by calling the visitor's function that corresponds to this node's kind,
 * and returns the value produced by that function.
 */
fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T
/**
 * Something that contains the tree's root data.
 * There are two such classes: a [SolitaryLeaf], and a [ComplexRoot].
 */
sealed interface RootContainer<RootT, LevelT, PredictionT> : SessionTree<RootT, LevelT, PredictionT> {
/**
 * Data that is stored only in the tree's root.
 * An ML model, for example.
 */
val root: RootT
}
/**
 * Something that has nested nodes.
 * It could be either a [ComplexRoot], or a [Branching].
 * The number of children corresponds to the number of calls of
 * [com.intellij.platform.ml.NestableMLSession.createNestedSession].
 */
sealed interface ChildrenContainer<RootT, LevelT, PredictionT> : SessionTree<RootT, LevelT, PredictionT> {
/**
 * All nested trees that were built by calling
 * [com.intellij.platform.ml.NestableMLSession.createNestedSession] on this level.
 */
val children: List<SessionTree<RootT, LevelT, PredictionT>>
}
/**
 * Something that contains the session's prediction.
 * It is produced by [com.intellij.platform.ml.SinglePrediction], and the prediction could be either
 * produced or canceled, hence the [prediction] is nullable.
 */
sealed interface PredictionContainer<RootT, LevelT, PredictionT> : SessionTree<RootT, LevelT, PredictionT> {
val prediction: PredictionT?
}
/**
 * Corresponds to the structure of an ML task session that had only one level, and could not have been nested.
 * Hence, it contains root data and a prediction simultaneously.
 */
data class SolitaryLeaf<RootT, LevelT, PredictionT>(
override val root: RootT,
override val level: LevelT,
override val prediction: PredictionT?
) : RootContainer<RootT, LevelT, PredictionT>, PredictionContainer<RootT, LevelT, PredictionT> {
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T {
return visitor.acceptSolitaryLeaf(this)
}
}
/**
 * Corresponds to the root of an ML task session's structure that had more than one level.
 */
data class ComplexRoot<RootT, LevelT, PredictionT>(
override val root: RootT,
override val level: LevelT,
override val children: List<SessionTree<RootT, LevelT, PredictionT>>
) : RootContainer<RootT, LevelT, PredictionT>, ChildrenContainer<RootT, LevelT, PredictionT> {
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T = visitor.acceptRoot(this)
}
/**
 * Corresponds to an intermediate node in an ML task session's structure that had more than one level.
 */
data class Branching<RootT, LevelT, PredictionT>(
override val level: LevelT,
override val children: List<SessionTree<RootT, LevelT, PredictionT>>
) : SessionTree<RootT, LevelT, PredictionT>, ChildrenContainer<RootT, LevelT, PredictionT> {
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T = visitor.acceptBranching(this)
}
/**
 * Corresponds to a leaf in an ML task session's structure that had more than one level.
 */
data class Leaf<RootT, LevelT, PredictionT>(
override val level: LevelT,
override val prediction: PredictionT?
) : SessionTree<RootT, LevelT, PredictionT>, PredictionContainer<RootT, LevelT, PredictionT> {
override fun <T> accept(visitor: Visitor<RootT, LevelT, PredictionT, T>): T = visitor.acceptLeaf(this)
}
/**
 * Visits a [SessionTree]'s node.
 * There is one function per node kind; a node's [SessionTree.accept] calls the matching one.
 *
 * @param RootT Data's type, stored in the tree's root.
 * @param LevelT Data's type, stored in each node.
 * @param PredictionT Prediction's type in the ML session.
 * @param T Data type that is returned by the visitor.
 */
interface Visitor<RootT, LevelT, PredictionT, out T> {
fun acceptBranching(branching: Branching<RootT, LevelT, PredictionT>): T
fun acceptLeaf(leaf: Leaf<RootT, LevelT, PredictionT>): T
fun acceptRoot(root: ComplexRoot<RootT, LevelT, PredictionT>): T
fun acceptSolitaryLeaf(solitaryLeaf: SolitaryLeaf<RootT, LevelT, PredictionT>): T
}
/**
 * Visits all tree's nodes on [levelIndex] depth, calling [visitLevel] for each of them.
 *
 * The node that accepts a visitor created via the public constructor is considered to be on depth 0,
 * and every descent to a node's children increases the depth by one.
 * The subtree below a visited node is not traversed.
 * If a leaf is reached on a depth other than [levelIndex], an [IllegalArgumentException] is thrown.
 */
abstract class LevelVisitor<RootT, PredictionT : Any> private constructor(
// Depth of the nodes that must be passed to visitLevel
private val levelIndex: Int,
// Depth of the nodes that this particular visitor instance is accepted by
private val thisVisitorLevel: Int,
) : Visitor<RootT, DescribedLevel, PredictionT, Unit> {
constructor(levelIndex: Int) : this(levelIndex, 0)
// A visitor for the nodes one level deeper: it keeps the target depth
// and forwards visitLevel to the enclosing visitor
private inner class DeeperLevelVisitor : LevelVisitor<RootT, PredictionT>(levelIndex, thisVisitorLevel + 1) {
override fun visitLevel(level: DescribedLevel, levelRoot: SessionTree<RootT, DescribedLevel, PredictionT>) {
this@LevelVisitor.visitLevel(level, levelRoot)
}
}
// Calls visitLevel when this visitor is on the target depth; returns whether the call was made
private fun maybeVisitLevel(level: DescribedLevel,
levelRoot: SessionTree<RootT, DescribedLevel, PredictionT>): Boolean =
if (levelIndex == thisVisitorLevel) {
visitLevel(level, levelRoot)
true
}
else false
final override fun acceptBranching(branching: Branching<RootT, DescribedLevel, PredictionT>) {
if (maybeVisitLevel(branching.level, branching)) return
// The target depth is not reached yet: descend to the children
for (child in branching.children) {
child.accept(DeeperLevelVisitor())
}
}
final override fun acceptLeaf(leaf: Leaf<RootT, DescribedLevel, PredictionT>) {
// A leaf has no children, so it must be located on the target depth
require(maybeVisitLevel(leaf.level, leaf)) {
"The deepest level in the session tree is $thisVisitorLevel, given level $levelIndex does not exist"
}
}
final override fun acceptRoot(root: ComplexRoot<RootT, DescribedLevel, PredictionT>) {
if (maybeVisitLevel(root.level, root)) return
// The target depth is not reached yet: descend to the children
for (child in root.children) {
child.accept(DeeperLevelVisitor())
}
}
final override fun acceptSolitaryLeaf(solitaryLeaf: SolitaryLeaf<RootT, DescribedLevel, PredictionT>) {
// A solitary leaf has no children, so it must be located on the target depth
require(maybeVisitLevel(solitaryLeaf.level, solitaryLeaf)) {
"The only level in the session tree is $thisVisitorLevel, given level $levelIndex does not exist"
}
}
/**
 * Called for each node that is located on the [levelIndex] depth.
 *
 * @param level The data stored in the visited node.
 * @param levelRoot The visited node itself.
 */
abstract fun visitLevel(level: DescribedLevel, levelRoot: SessionTree<RootT, DescribedLevel, PredictionT>)
}
}
// A session tree whose nodes store described levels.
typealias DescribedSessionTree<R, P> = SessionTree<R, DescribedLevel, P>
// A node of a described session tree that has nested nodes.
typealias DescribedChildrenContainer<R, P> = SessionTree.ChildrenContainer<R, DescribedLevel, P>
// A root node of a described session tree.
typealias DescribedRootContainer<R, P> = SessionTree.RootContainer<R, DescribedLevel, P>
// Sets of analysis features, keyed by a string.
typealias SessionAnalysis = Map<String, Set<Feature>>
// A session tree whose nodes store analysed levels, and whose root data is a session analysis.
typealias AnalysedSessionTree<P> = SessionTree<SessionAnalysis, AnalysedLevel, P>
// A root node of an analysed session tree.
typealias AnalysedRootContainer<P> = SessionTree.RootContainer<SessionAnalysis, AnalysedLevel, P>
/**
 * An [Environment] built from the keys of the main tier instances of this node's level.
 */
val <R, P> DescribedSessionTree<R, P>.environment: Environment
  get() {
    val mainTierInstances = level.main.keys
    return Environment.of(mainTierInstances)
  }

/**
 * An [Environment] built from the keys of the main tier instances of this level.
 */
val DescribedLevel.environment: Environment
  get() {
    val mainTierInstances = main.keys
    return Environment.of(mainTierInstances)
  }

View File

@@ -0,0 +1,9 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session
import org.jetbrains.annotations.ApiStatus
/**
 * A single-method handler of a described session tree.
 *
 * @param T Type of the handled tree.
 * @param R Type of the data that is stored in the tree's root.
 * @param P Type of the session's prediction.
 */
@ApiStatus.Internal
fun interface SessionTreeHandler<T : DescribedSessionTree<R, P>, R, P> {
/**
 * Handles the given [sessionTree].
 */
fun handleTree(sessionTree: T)
}

View File

@@ -0,0 +1,25 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session.analysis
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.impl.model.MLModel
/**
 * An analyzer that is dedicated to giving features to the session's model.
 *
 * [SessionAnalyser.analysisDeclaration] returns a set of [FeatureDeclaration]s -
 * the declarations of the features that the model will be described with.
 * [SessionAnalyser.analyse] in this case returns a set of features - the model's analysis.
 */
typealias MLModelAnalyser<M, P> = SessionAnalyser<Set<FeatureDeclaration<*>>, Set<Feature>, M, P>
/**
 * Joins the declarations and the analyses of several [MLModelAnalyser]s
 * by uniting the given sets into one.
 */
internal class MLModelAnalysisJoiner<M : MLModel<P>, P : Any> : AnalysisJoiner<Set<FeatureDeclaration<*>>, Set<Feature>, M, P> {
  /** Unites all the given sets of declarations. */
  override fun joinDeclarations(declarations: Iterable<Set<FeatureDeclaration<*>>>): Set<FeatureDeclaration<*>> =
    declarations.flatMap { declarationSet -> declarationSet }.toSet()

  /** Unites all the given sets of features. */
  override fun joinAnalysis(analysis: Iterable<Set<Feature>>): Set<Feature> =
    analysis.flatMap { featureSet -> featureSet }.toSet()
}

View File

@@ -0,0 +1,67 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session.analysis
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.DescribedRootContainer
import org.jetbrains.annotations.ApiStatus
import java.util.concurrent.CompletableFuture
/**
 * An interface for classes that analyze an ML session after it has finished.
 *
 * Analysis may take an indefinitely long time (which is emphasized by the fact that [analyse]
 * returns a [CompletableFuture]).
 *
 * @param D Type of the analysis' declaration
 * @param A Type of the analysis itself
 * @param M Type of the model that has been utilized during the ML session
 * @param P Type of the session's prediction
 */
@ApiStatus.Internal
interface SessionAnalyser<D, A, M : MLModel<P>, P : Any> {
/**
 * Contains the declarations of all features - [com.intellij.platform.ml.FeatureDeclaration],
 * that are then used in the analysis as [analyse]'s return value.
 */
val analysisDeclaration: D
/**
 * Performs the analysis of the session tree. The analysis is performed asynchronously,
 * so the function is not required to return the final result, but only a [CompletableFuture]
 * that will be fulfilled when the analysis is finished.
 *
 * @param sessionTreeRoot The root of the described session tree to analyze.
 */
fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<A>
}
/**
 * Gathers analyses from different [SessionAnalyser]s to one.
 *
 * @param baseAnalysers The analysers whose declarations and analyses are joined.
 * @param joiner Defines how several declarations (or analyses) are merged into one.
 */
internal class JoinedSessionAnalyser<D, A, M : MLModel<P>, P : Any>(
  private val baseAnalysers: Collection<SessionAnalyser<D, A, M, P>>,
  private val joiner: AnalysisJoiner<D, A, M, P>
) : SessionAnalyser<D, A, M, P> {
  /**
   * The declarations of all [baseAnalysers], merged by the [joiner].
   * It is recomputed on each access.
   */
  override val analysisDeclaration: D
    get() = joiner.joinDeclarations(baseAnalysers.map { it.analysisDeclaration })

  /**
   * Starts the analysis of every base analyser and returns a future that is completed with the joined
   * result once all of them are finished. Analyses that were completed with `null` are not passed to the joiner.
   *
   * If any of the base analyses fails, or the [joiner] throws, the returned future is completed exceptionally,
   * so the failure reaches the caller instead of leaving the future incomplete forever.
   */
  override fun analyse(sessionTreeRoot: DescribedRootContainer<M, P>): CompletableFuture<A> {
    val scatteredAnalysis: List<CompletableFuture<A>> = baseAnalysers.map { it.analyse(sessionTreeRoot) }
    return CompletableFuture.allOf(*scatteredAnalysis.toTypedArray())
      .thenApply<A> { _ ->
        // Every future is already completed successfully at this point, so join() does not block
        val completeAnalysis = scatteredAnalysis.mapNotNull { finishedAnalysis -> finishedAnalysis.join() }
        joiner.joinAnalysis(completeAnalysis)
      }
  }
}
/**
 * Defines how the declarations and the analyses of several [SessionAnalyser]s are merged into one.
 *
 * @param D Type of the analysis' declaration
 * @param A Type of the analysis itself
 * @param M Type of the model that has been utilized during the ML session
 * @param P Type of the session's prediction
 */
internal interface AnalysisJoiner<D, A, M : MLModel<P>, P : Any> {
// Merges several analysis declarations into one
fun joinDeclarations(declarations: Iterable<D>): D
// Merges several analyses into one
fun joinAnalysis(analysis: Iterable<A>): A
}

View File

@@ -0,0 +1,40 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session.analysis
import com.intellij.platform.ml.Feature
import com.intellij.platform.ml.FeatureDeclaration
import com.intellij.platform.ml.PerTier
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.session.DescribedSessionTree
import com.intellij.platform.ml.mergePerTier
// A mapping from a session tree's node to the analysis features of that node, given per tier.
typealias StructureAnalysis<M, P> = Map<DescribedSessionTree<M, P>, PerTier<Set<Feature>>>
// Declarations of the analysis features, given per tier.
typealias StructureAnalysisDeclaration = PerTier<Set<FeatureDeclaration<*>>>
/**
 * An analyzer that gives analytical features to the session tree's nodes.
 *
 * [SessionAnalyser.analysisDeclaration] returns a set of [FeatureDeclaration]s per each analyzed tier.
 * [SessionAnalyser.analyse] returns a mapping from each analyzed tree's node to sets of features, given per tier.
 *
 * For example, if you want to give some analysis to the session tree's root, then you will return
 * mapOf(sessionTreeRoot to mapOf(tier to setOf(...))).
 *
 * @see com.intellij.platform.ml.impl.session.SessionTree.Visitor to learn how you could walk the tree's nodes
 * @see com.intellij.platform.ml.impl.session.SessionTree.LevelVisitor to learn how you could visit the session's levels.
 */
typealias StructureAnalyser<M, P> = SessionAnalyser<StructureAnalysisDeclaration, StructureAnalysis<M, P>, M, P>
/**
 * Joins the declarations and the analyses of several [StructureAnalyser]s.
 */
internal class SessionStructureAnalysisJoiner<M : MLModel<P>, P : Any> : AnalysisJoiner<StructureAnalysisDeclaration, StructureAnalysis<M, P>, M, P> {
  /**
   * Merges the per-tier declarations with [mergePerTier].
   */
  override fun joinDeclarations(declarations: Iterable<StructureAnalysisDeclaration>): StructureAnalysisDeclaration =
    declarations.mergePerTier { mutableSetOf() }

  /**
   * Collects everything that the analyses have said about each tree's node,
   * and merges the per-tier features of each node with [mergePerTier].
   */
  override fun joinAnalysis(analysis: Iterable<StructureAnalysis<M, P>>): StructureAnalysis<M, P> =
    analysis
      .flatMap { nodeAnalysis -> nodeAnalysis.entries }
      .groupBy({ nodeEntry -> nodeEntry.key }, { nodeEntry -> nodeEntry.value })
      .mapValues { scatteredNodeAnalysis -> scatteredNodeAnalysis.value.mergePerTier { mutableSetOf() } }
}

View File

@@ -0,0 +1,43 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session.analysis
import com.intellij.internal.statistic.eventLog.events.EventField
import com.intellij.internal.statistic.eventLog.events.EventPair
import com.intellij.internal.statistic.eventLog.events.ObjectDescription
import com.intellij.platform.ml.Environment
import org.jetbrains.annotations.ApiStatus
/**
 * Analyzes some generic information about an ML session, rather than the session's tree.
 * Could be used to analyze the sessions that were started, as well as the ones that failed to start.
 *
 * @param D Type of the additional data that is passed to [analyse].
 */
@ApiStatus.Internal
interface ShallowSessionAnalyser<D> {
  /**
   * Name of the analyzer.
   */
  val name: String

  /**
   * A complete static declaration of the fields that will be written during the analysis.
   */
  val declaration: List<EventField<*>>

  /**
   * Analyze some generic information about an ML session.
   *
   * @param permanentSessionEnvironment The environment that is available during the whole ML session
   * @param data Some additional data, an insight about the place where the analyzer was called.
   * Could be a [Throwable], or a reason why it was not possible to start the session.
   */
  fun analyse(permanentSessionEnvironment: Environment, data: D): List<EventPair<*>>

  companion object {
    /**
     * An [ObjectDescription] that has every field of the analyser's [declaration] registered in it.
     * A new description is created on each access.
     */
    val <F> ShallowSessionAnalyser<F>.declarationObjectDescription: ObjectDescription
      get() {
        return object : ObjectDescription() {
          init {
            for (declaredField in declaration) {
              field(declaredField)
            }
          }
        }
      }
  }
}

View File

@@ -0,0 +1,77 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.session
import com.intellij.platform.ml.*
/**
 * A wrapper for convenient usage of a [NestableMLSession].
 *
 * The [useCreator] callback receives a [NestableSessionWrapper] that creates a nested session per call.
 * After the callback has finished (normally or with an exception),
 * [NestableMLSession.onLastNestedSessionCreated] is called.
 *
 * This function assumes that this session is nestable.
 *
 * @throws IllegalArgumentException if this session is not a [NestableMLSession]
 */
fun <P : Any> Session<P>.withNestedSessions(useCreator: (NestableSessionWrapper<P>) -> Unit) {
  val nestableMLSession = requireNotNull(this as? NestableMLSession<P>) {
    "Nested sessions can be created only within a NestableMLSession, but this session is ${this.javaClass.name}"
  }
  val creator = object : NestableSessionWrapper<P> {
    override fun nestConsidering(levelEnvironment: Environment): Session<P> {
      return nestableMLSession.createNestedSession(levelEnvironment)
    }
  }
  try {
    useCreator(creator)
  }
  finally {
    // Must be reported even if the callback has failed
    nestableMLSession.onLastNestedSessionCreated()
  }
}
/**
 * A wrapper for convenient usage of a [SinglePrediction].
 *
 * The [useModelWrapper] callback receives a [ModelWrapper]; each call of its functions creates a nested
 * session, and either makes a prediction in it or cancels it.
 * After the callback has finished (normally or with an exception),
 * [NestableMLSession.onLastNestedSessionCreated] is called.
 *
 * This function assumes that this session is [NestableMLSession], and the nested
 * sessions' types are [SinglePrediction].
 *
 * @return the value returned by [useModelWrapper]
 * @throws IllegalArgumentException if this session is not a [NestableMLSession],
 * or if a created nested session is not a [SinglePrediction]
 */
fun <T, P : Any> Session<P>.withPredictions(useModelWrapper: (ModelWrapper<P>) -> T): T {
  val nestableMLSession = requireNotNull(this as? NestableMLSession<P>) {
    "Predictions can be made only within a NestableMLSession, but this session is ${this.javaClass.name}"
  }
  val predictor = object : ModelWrapper<P> {
    // Creates a nested session for the given environment and verifies that it can make a prediction
    private fun startPredictionSession(predictionEnvironment: Environment): SinglePrediction<P> {
      val predictionSession = nestableMLSession.createNestedSession(predictionEnvironment)
      require(predictionSession is SinglePrediction<P>) {
        "The nested session was expected to be a SinglePrediction, but it is ${predictionSession.javaClass.name}"
      }
      return predictionSession
    }

    override fun predictConsidering(predictionEnvironment: Environment): P {
      return startPredictionSession(predictionEnvironment).predict()
    }

    override fun consider(predictionEnvironment: Environment) {
      startPredictionSession(predictionEnvironment).cancelPrediction()
    }
  }
  try {
    return useModelWrapper(predictor)
  }
  finally {
    // Must be reported even if the callback has failed
    nestableMLSession.onLastNestedSessionCreated()
  }
}
/**
 * Creates nested sessions. An instance is passed to the callback of [withNestedSessions].
 */
interface NestableSessionWrapper<P : Any> {
  /**
   * Creates a nested session for the given [levelEnvironment].
   */
  fun nestConsidering(levelEnvironment: Environment): Session<P>

  companion object {
    /**
     * Creates a nested session for the environment that is built from [levelTierInstances].
     */
    fun <P : Any> NestableSessionWrapper<P>.nestConsidering(vararg levelTierInstances: TierInstance<*>): Session<P> =
      this.nestConsidering(Environment.of(*levelTierInstances))
  }
}
/**
 * Makes predictions for given environments. An instance is passed to the callback of [withPredictions].
 */
interface ModelWrapper<P : Any> {
  /**
   * Returns the prediction for the given [predictionEnvironment].
   */
  fun predictConsidering(predictionEnvironment: Environment): P

  /**
   * Takes the given [predictionEnvironment] into account without returning a prediction.
   * The implementation created by [withPredictions] creates a nested session and cancels its prediction.
   */
  fun consider(predictionEnvironment: Environment)

  companion object {
    /**
     * Returns the prediction for the environment that is built from [predictionTierInstances].
     */
    fun <P : Any> ModelWrapper<P>.predictConsidering(vararg predictionTierInstances: TierInstance<*>): P =
      this.predictConsidering(Environment.of(*predictionTierInstances))
  }
}

View File

@@ -1,6 +1,8 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.turboComplete
import com.intellij.platform.ml.Tier
/**
* Data class representing a kind of [SuggestionGenerator]'s suggestions.
*
@@ -8,4 +10,6 @@ package com.intellij.platform.ml.impl.turboComplete
* The completion kind's name is defined statically, it should be unique among
* the corresponding [KindVariety].
*/
data class CompletionKind(val name: Enum<*>, val variety: KindVariety)
data class CompletionKind(val name: Enum<*>, val variety: KindVariety)
/** The [Tier] whose instances are [CompletionKind]s. */
object TierCompletionKind : Tier<CompletionKind>()

View File

@@ -16,7 +16,7 @@ interface SuggestionGeneratorExecutorProvider {
): SuggestionGeneratorExecutor
companion object {
private val EP_NAME: ExtensionPointName<SuggestionGeneratorExecutorProvider> =
val EP_NAME: ExtensionPointName<SuggestionGeneratorExecutorProvider> =
ExtensionPointName("com.intellij.turboComplete.suggestionGeneratorExecutorProvider")
fun hasAnyToCall(parameters: CompletionParameters): Boolean {

View File

@@ -0,0 +1,483 @@
// Copyright 2000-2023 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl
import com.intellij.internal.statistic.FUCollectorTestCase
import com.intellij.internal.statistic.eventLog.events.ClassEventField
import com.intellij.internal.statistic.eventLog.events.EventField
import com.intellij.internal.statistic.eventLog.events.EventPair
import com.intellij.lang.Language
import com.intellij.openapi.fileTypes.PlainTextLanguage
import com.intellij.openapi.util.Version
import com.intellij.platform.ml.*
import com.intellij.platform.ml.impl.MLTaskApproach.Companion.startMLSession
import com.intellij.platform.ml.impl.apiPlatform.CodeLikePrinter
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.apiPlatform.ReplaceableIJPlatform
import com.intellij.platform.ml.impl.approach.*
import com.intellij.platform.ml.impl.logger.FailedSessionLoggerRegister
import com.intellij.platform.ml.impl.logger.FinishedSessionLoggerRegister
import com.intellij.platform.ml.impl.logger.InplaceFeaturesScheme
import com.intellij.platform.ml.impl.logger.MLEvent
import com.intellij.platform.ml.impl.model.MLModel
import com.intellij.platform.ml.impl.monitoring.*
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener.ApproachListeners.Companion.monitoredBy
import com.intellij.platform.ml.impl.session.*
import com.intellij.platform.ml.impl.session.analysis.MLModelAnalyser
import com.intellij.platform.ml.impl.session.analysis.ShallowSessionAnalyser
import com.intellij.platform.ml.impl.session.analysis.StructureAnalyser
import com.intellij.platform.ml.impl.session.analysis.StructureAnalysis
import com.intellij.testFramework.fixtures.BasePlatformTestCase
import com.jetbrains.fus.reporting.model.lion3.LogEvent
import java.io.BufferedWriter
import java.io.FileWriter
import java.net.URI
import java.nio.file.Path
import java.util.concurrent.CompletableFuture
import java.util.concurrent.TimeUnit
import java.util.function.Consumer
import kotlin.io.path.div
import kotlin.random.Random
// Demo enum: the type of a mock completion session (see CompletionSession.completionType).
enum class CompletionType {
SMART,
BASIC
}
// Demo data: an instance of the TierCompletionSession tier.
data class CompletionSession(
val language: Language,
val callOrder: Int,
// Nullable: described by the nullable COMPLETION_TYPE feature
val completionType: CompletionType?
)
// Demo data: an instance of the TierLookup tier.
data class LookupImpl(
val foo: Boolean,
val index: Int
)
// Demo data: an instance of the TierItem tier.
data class LookupItem(
val lookupString: String,
val decoration: Map<Any, Any>
)
// Demo data: an instance of the TierGit tier.
data class GitRepository(
val user: String,
val projectUri: URI,
val commits: List<String>
)
// The tiers of the demo; each one is parameterized with the type of its instances.
object TierCompletionSession : Tier<CompletionSession>()
object TierLookup : Tier<LookupImpl>()
object TierItem : Tier<LookupItem>()
object TierGit : Tier<GitRepository>()
/**
 * Describes [TierCompletionSession]. It additionally requires [TierGit] to compute one of the features.
 */
class CompletionSessionFeatures1 : TierDescriptor {
  override val tier: Tier<*> = TierCompletionSession

  override val additionallyRequiredTiers: Set<Tier<*>> = setOf(TierGit)

  override val descriptionDeclaration: Set<FeatureDeclaration<*>> = setOf(CALL_ORDER, LANGUAGE_ID, GIT_USER, COMPLETION_TYPE)

  /**
   * Computes every declared feature from the completion session and the git repository of the [environment].
   * The [usefulFeaturesFilter] is not consulted.
   */
  override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set<Feature> {
    val session = environment[TierCompletionSession]
    val repository = environment[TierGit]
    val callOrder = CALL_ORDER with session.callOrder
    val languageId = LANGUAGE_ID with session.language.id
    val gitUser = GIT_USER with (repository.user == "Glebanister")
    val completionType = COMPLETION_TYPE with session.completionType
    return setOf(callOrder, languageId, gitUser, completionType)
  }

  companion object {
    val CALL_ORDER = FeatureDeclaration.int("call_order")
    val LANGUAGE_ID = FeatureDeclaration.categorical("language_id", Language.getRegisteredLanguages().map { it.id }.toSet())
    val GIT_USER = FeatureDeclaration.boolean("git_user_is_Glebanister")
    val COMPLETION_TYPE = FeatureDeclaration.enum<CompletionType>("completion_type").nullable()
  }
}
/**
 * Describes [TierItem] with the number of the item's decorations and the length of its lookup string.
 */
class ItemFeatures1 : TierDescriptor {
  override val tier: Tier<*> = TierItem

  override val descriptionDeclaration: Set<FeatureDeclaration<*>> = setOf(DECORATIONS, LENGTH)

  /**
   * Computes both declared features from the item of the [environment].
   * The [usefulFeaturesFilter] is not consulted.
   */
  override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter): Set<Feature> {
    val lookupItem = environment[TierItem]
    val decorations = DECORATIONS with lookupItem.decoration.size
    val length = LENGTH with lookupItem.lookupString.length
    return setOf(decorations, length)
  }

  companion object {
    val DECORATIONS = FeatureDeclaration.int("decorations")
    val LENGTH = FeatureDeclaration.int("length")
  }
}
/**
 * Describes [TierGit] with the number of commits and with the presence of a user name.
 */
class GitFeatures1 : TierDescriptor {
  override val tier: Tier<*> = TierGit

  override val descriptionDeclaration: Set<FeatureDeclaration<*>> = setOf(N_COMMITS, HAS_USER)

  /**
   * Computes both declared features from the git repository of the [environment].
   * The [usefulFeaturesFilter] is not consulted.
   */
  override fun describe(environment: Environment, usefulFeaturesFilter: FeatureFilter) = setOf(
    N_COMMITS.with(environment[TierGit].commits.size),
    HAS_USER.with(environment[TierGit].user.isNotEmpty()),
  )

  companion object {
    val N_COMMITS = FeatureDeclaration.int("n_commits")
    val HAS_USER = FeatureDeclaration.boolean("has_user")
  }
}
/**
 * Extends an environment with a hard-coded [GitRepository] as the [TierGit] instance.
 * It does not require any other tiers.
 */
class GitInformant : EnvironmentExtender<GitRepository> {
  override val extendingTier: Tier<GitRepository> = TierGit

  override val requiredTiers: Set<Tier<*>> = setOf()

  /**
   * Returns the same mock repository regardless of the given [environment].
   */
  override fun extend(environment: Environment): GitRepository {
    val mockCommits = listOf(
      "0e47200fa3bf029d7244745eacbf9d495de818c1",
      "638fbc7840b85d6e34ef2320ca5b2c9ec2c4b23c",
      "5a74104a0901b8faa4cc1f76736347cd33917041"
    )
    val mockUri = URI.create("ssh://git@git.jetbrains.team/ij/intellij.git")
    return GitRepository(user = "Glebanister", projectUri = mockUri, commits = mockCommits)
  }
}
/**
 * A demo structure analyser: it marks the session's root as "good",
 * and gives each node of the level 1 the index of its lookup.
 */
class SomeStructureAnalyser<M : MLModel<Double>> : StructureAnalyser<M, Double> {
  override val analysisDeclaration: PerTier<Set<FeatureDeclaration<*>>>
    get() = mapOf(
      TierCompletionSession to setOf(SESSION_IS_GOOD),
      TierLookup to setOf(LOOKUP_INDEX)
    )

  /**
   * Builds the analysis synchronously, and returns a future that provides it after a delay of 2 seconds.
   */
  override fun analyse(sessionTreeRoot: DescribedRootContainer<M, Double>): CompletableFuture<StructureAnalysis<M, Double>> {
    val featuresPerNode = mutableMapOf<DescribedSessionTree<M, Double>, PerTier<Set<Feature>>>()
    // First the lookup level is analysed, then the root's analysis is written
    sessionTreeRoot.accept(LookupAnalyser(featuresPerNode))
    featuresPerNode[sessionTreeRoot] = mapOf(
      TierCompletionSession to setOf(SESSION_IS_GOOD with true)
    )
    return CompletableFuture.supplyAsync {
      // Pretend that analysis is taking some long time
      TimeUnit.SECONDS.sleep(2)
      featuresPerNode
    }
  }

  /**
   * Writes the lookup's index to [analysis] for every node on the level 1.
   */
  private class LookupAnalyser<M : MLModel<Double>>(
    private val analysis: MutableMap<DescribedSessionTree<M, Double>, PerTier<Set<Feature>>>
  ) : SessionTree.LevelVisitor<M, Double>(levelIndex = 1) {
    override fun visitLevel(level: DescribedLevel, levelRoot: DescribedSessionTree<M, Double>) {
      val lookupIndex = level.environment[TierLookup].index
      analysis[levelRoot] = mapOf(TierLookup to setOf(LOOKUP_INDEX with lookupIndex))
    }
  }

  companion object {
    val SESSION_IS_GOOD = FeatureDeclaration.boolean("very_good_session")
    val LOOKUP_INDEX = FeatureDeclaration.int("lookup_index")
  }
}
/**
 * A demo model analyser: it reports the seed of the [RandomModel] that is stored in the session tree's root.
 */
class RandomModelSeedAnalyser : MLModelAnalyser<RandomModel, Double> {
  override val analysisDeclaration: Set<FeatureDeclaration<*>> = setOf(SEED)

  /**
   * Returns a future that provides the seed feature after a delay of 1 second.
   */
  override fun analyse(sessionTreeRoot: DescribedRootContainer<RandomModel, Double>): CompletableFuture<Set<Feature>> =
    CompletableFuture.supplyAsync {
      // Pretend that analysis is taking some long time
      TimeUnit.SECONDS.sleep(1)
      val model = sessionTreeRoot.root
      setOf(SEED with model.seed)
    }

  companion object {
    val SEED = FeatureDeclaration.int("random_seed")
  }
}
/**
 * A demo model that predicts pseudo-random numbers, generated from the given [seed].
 */
class RandomModel(val seed: Int) : MLModel<Double>, Versioned, LanguageSpecific {
  private val generator = Random(seed)

  override val languageId: String = PlainTextLanguage.INSTANCE.id

  override val version: Version = Version(0, 0, 1)

  // The model declares that it knows every feature of each of the demo's tiers
  override val knownFeatures: PerTier<FeatureSelector> = mapOf(
    TierCompletionSession to FeatureSelector.EVERYTHING,
    TierLookup to FeatureSelector.EVERYTHING,
    TierGit to FeatureSelector.EVERYTHING,
    TierItem to FeatureSelector.EVERYTHING,
  )

  /**
   * Returns the next number of the generator; the [features] are ignored.
   */
  override fun predict(features: PerTier<Set<Feature>>): Double = generator.nextDouble()

  /**
   * Provides the model unreliably on purpose: depending on two random flips,
   * it returns `null`, throws an [IllegalStateException], or returns a model with the seed 1.
   */
  class Provider : MLModel.Provider<RandomModel, Double> {
    override val requiredTiers: Set<Tier<*>> = emptySet()

    override fun provideModel(sessionTiers: List<LevelTiers>, environment: Environment): RandomModel? {
      val modelIsAvailable = Random.nextBoolean()
      if (!modelIsAvailable) return null
      val modelIsHealthy = Random.nextBoolean()
      if (!modelIsHealthy) throw IllegalStateException()
      return RandomModel(1)
    }
  }
}
/**
 * A demo listener that monitors [MockTaskApproach] and prints every event it is notified about,
 * prefixed with the listener's [name].
 */
class SomeListener(private val name: String) : MLTaskGroupListener {
override val approachListeners = listOf(
MockTaskApproach::class.java monitoredBy InitializationListener()
)
// Prints the message to the standard output, prefixed with the listener's name
private fun log(message: String) = println("[Listener $name says] $message")
// Notified about an attempt to start a session; provides the listener for the attempt's outcome
inner class InitializationListener : MLApproachInitializationListener<RandomModel, Double> {
override fun onAttemptedToStartSession(permanentSessionEnvironment: Environment): MLApproachListener<RandomModel, Double> {
log("attempted to initialize session")
return ApproachListener()
}
}
// Notified about the outcome of the start; provides the session's listener when the start has succeeded
inner class ApproachListener : MLApproachListener<RandomModel, Double> {
override fun onFailedToStartSessionWithException(exception: Throwable) {
log("failed to start session with exception: $exception")
}
override fun onFailedToStartSession(failure: Session.StartOutcome.Failure<Double>) {
log("failed to start session with outcome: $failure")
}
override fun onStartedSession(session: Session<Double>): MLSessionListener<RandomModel, Double> {
log("session was started successfully: $session")
return SessionListener()
}
}
// Notified when the session's description and analysis are finished
inner class SessionListener : MLSessionListener<RandomModel, Double> {
override fun onSessionDescriptionFinished(sessionTree: DescribedRootContainer<RandomModel, Double>) {
log("session successfully described: $sessionTree")
}
override fun onSessionAnalysisFinished(sessionTree: AnalysedRootContainer<Double>) {
log("session successfully analyzed: $sessionTree")
}
}
}
/**
 * A shallow analyser that reports the class of the given [Throwable].
 */
object ExceptionLogger : ShallowSessionAnalyser<Throwable> {
  private val THROWABLE_CLASS = ClassEventField("throwable_class")

  override val name: String = "exception"

  override val declaration: List<EventField<*>> = listOf(THROWABLE_CLASS)

  /**
   * Returns the only pair: the class of [data]. The environment is not used.
   */
  override fun analyse(permanentSessionEnvironment: Environment, data: Throwable): List<EventPair<*>> {
    val throwableClass = THROWABLE_CLASS with data.javaClass
    return listOf(throwableClass)
  }
}
/**
 * A shallow analyser that reports the class of the given start failure.
 */
object FailureLogger : ShallowSessionAnalyser<Session.StartOutcome.Failure<Double>> {
  private val REASON = ClassEventField("reason")

  override val name: String = "normal_failure"

  override val declaration: List<EventField<*>> = listOf(REASON)

  /**
   * Returns the only pair: the class of [data]. The environment is not used.
   */
  override fun analyse(
    permanentSessionEnvironment: Environment,
    data: Session.StartOutcome.Failure<Double>,
  ): List<EventPair<*>> {
    val failureClass = REASON with data.javaClass
    return listOf(failureClass)
  }
}
/**
 * The API platform of the demo: it lists the demo's descriptors, extenders, approaches, and listeners.
 */
object ThisTestApiPlatform : TestApiPlatform() {
override val tierDescriptors = listOf(
CompletionSessionFeatures1(),
ItemFeatures1(),
GitFeatures1(),
)
override val environmentExtenders = listOf(
GitInformant(),
)
override val taskApproaches = listOf(
MockTaskApproach.Initializer()
)
// Registers the loggers of the finished and of the failed sessions for MockTaskApproach
override val initialStartupListeners: List<MLApiStartupListener> = listOf(
FinishedSessionLoggerRegister<RandomModel, Double>(
MockTaskApproach::class.java,
InplaceFeaturesScheme.FusScheme.DOUBLE
),
FailedSessionLoggerRegister<RandomModel, Double>(
MockTaskApproach::class.java,
exceptionalAnalysers = listOf(ExceptionLogger),
normalFailureAnalysers = listOf(FailureLogger)
)
)
override val initialTaskListeners: List<MLTaskGroupListener> = listOf(
SomeListener("Nika"),
SomeListener("Alex"),
)
override val initialEvents: List<MLEvent> = listOf()
// Prints the declarations of the features that the descriptor has computed but has not declared
override fun manageNonDeclaredFeatures(descriptor: ObsoleteTierDescriptor, nonDeclaredFeatures: Set<Feature>) {
val printer = CodeLikePrinter()
println("$descriptor is missing the following declaration: ${printer.printCodeLikeString(nonDeclaredFeatures.map { it.declaration })}")
}
}
/**
 * The demo task: it predicts a [Double] and has three levels -
 * a completion session, a lookup, and an item.
 */
object MockTask : MLTask<Double>(
name = "mock",
predictionClass = Double::class.java,
levels = listOf(
setOf(TierCompletionSession),
setOf(TierLookup),
setOf(TierItem)
)
)
/**
 * The demo approach to [MockTask]: a [LogDrivenModelInference] that runs a [RandomModel].
 */
class MockTaskApproach(
apiPlatform: MLApiPlatform,
task: MLTask<Double>
) : LogDrivenModelInference<RandomModel, Double>(task, apiPlatform) {
// One set per level: TierGit is additionally described on the first level, nothing on the other two
override val additionallyDescribedTiers: List<Set<Tier<*>>> = listOf(
setOf(TierGit),
setOf(),
setOf(),
)
override val analysisMethod: AnalysisMethod<RandomModel, Double> = StructureAndModelAnalysis(
structureAnalysers = listOf(SomeStructureAnalyser()),
mlModelAnalysers = listOf(
RandomModelSeedAnalyser(),
ModelVersionAnalyser(),
ModelLanguageAnalyser()
)
)
override val mlModelProvider = RandomModel.Provider()
// FeatureSelector.NOTHING is selected for each of the demo's tiers
override val notUsedDescription: PerTier<FeatureSelector> = mapOf(
TierCompletionSession to FeatureSelector.NOTHING,
TierLookup to FeatureSelector.NOTHING,
TierItem to FeatureSelector.NOTHING,
TierGit to FeatureSelector.NOTHING
)
override val descriptionComputer: DescriptionComputer = StateFreeDescriptionComputer()
// Creates the approach for MockTask within the given API platform
class Initializer : MLTaskApproachInitializer<Double> {
override val task: MLTask<Double> = MockTask
override fun initializeApproachWithin(apiPlatform: MLApiPlatform) = MockTaskApproach(apiPlatform, task)
}
}
/**
 * A demo of the ML API: it runs three mock sessions of [MockTask] within [ThisTestApiPlatform],
 * collects the FUS events that were logged meanwhile, and dumps them to a file.
 */
class TestTask : BasePlatformTestCase() {
fun `test demo ml task`() {
// After the sessions are finished, the collected logs are written to testResources/ml_logs.js,
// resolved against the working directory (see convertListToJsonAndWriteToFile)
// NOTE(review): presumably that is community/platform/ml-impl/testResources/ml_logs.js - confirm the working directory
val logs: MutableList<Pair<String, Map<String, Any>>> = mutableListOf()
val collectLogs = Consumer { fusLog: LogEvent ->
logs.add(fusLog.event.id to fusLog.event.data)
}
// TODO: Handle this as well (when it has been initialized before we had the control flow)
//MLEventLogger.Manager.ensureNotInitialized()
ReplaceableIJPlatform.replacingWith(ThisTestApiPlatform) {
FUCollectorTestCase.listenForEvents("FUS", this.testRootDisposable, collectLogs) {
repeat(3) { sessionIndex ->
println("Demo session #$sessionIndex has started")
val startOutcome = startMLSession(MockTask, Environment.of(
TierCompletionSession with CompletionSession(
language = PlainTextLanguage.INSTANCE,
callOrder = 1,
completionType = CompletionType.SMART
)
))
// The start could fail (RandomModel.Provider is unreliable on purpose): skip such an iteration
val completionSession = startOutcome.session ?: return@repeat
completionSession.withNestedSessions { lookupSessionCreator ->
// The first lookup: two items, a prediction is made for each of them
lookupSessionCreator.nestConsidering(Environment.of(TierLookup with LookupImpl(true, 1)))
.withPredictions {
it.predictConsidering(Environment.of(TierItem with LookupItem("hello", emptyMap())))
it.predictConsidering(Environment.of(TierItem with LookupItem("world", emptyMap())))
}
// The second lookup: two predictions, and one item whose prediction is canceled
lookupSessionCreator.nestConsidering(Environment.of(TierLookup with LookupImpl(false, 2)))
.withPredictions {
it.predictConsidering(Environment.of(TierItem with LookupItem("hello!!!", mapOf("bold" to true))))
it.predictConsidering(Environment.of(TierItem with LookupItem("AAAAA!!", mapOf("strikethrough" to true))))
it.consider(Environment.of(TierItem with LookupItem("AAAAAAAAAAAAAAAAA", mapOf("cursive" to true))))
}
}
// Wait until the analysis is over
// TODO: Think if it is possible to track it via API
Thread.sleep(3 * 1000)
println("Demo session #$sessionIndex has finished")
}
}
}
convertListToJsonAndWriteToFile(logs)
}
}
/**
 * Escapes [text] so that it can be placed between double quotes in a JSON document.
 *
 * Double quotes, backslashes and control characters (below U+0020) are replaced with their
 * JSON escape sequences; every other character is copied unchanged.
 */
private fun escapeJsonString(text: String): String = buildString(text.length) {
    for (char in text) {
        when (char) {
            '"' -> append("\\\"")
            '\\' -> append("\\\\")
            '\n' -> append("\\n")
            '\r' -> append("\\r")
            '\t' -> append("\\t")
            '\b' -> append("\\b")
            '\u000C' -> append("\\f")
            else -> {
                if (char < ' ') {
                    // Remaining control characters have no short form and must be written as \uXXXX
                    append("\\u")
                    append(char.code.toString(16).padStart(4, '0'))
                }
                else {
                    append(char)
                }
            }
        }
    }
}

/**
 * Renders [map] as a JSON object: `{"key": value, ...}`.
 *
 * Keys are escaped and quoted, values are rendered with [valueToJson].
 * Entries appear in the iteration order of [map].
 */
fun mapToJson(map: Map<String, Any>): String {
    val entrySet = map.entries.joinToString(", ") { (key, value) ->
        "\"${escapeJsonString(key)}\": ${valueToJson(value)}"
    }
    return "{$entrySet}"
}

/**
 * Renders the elements of [list] with [valueToJson] and joins them with `", "`.
 *
 * The surrounding square brackets are NOT added here; [valueToJson] adds them for list values.
 */
fun listToJson(list: List<Any>): String =
    list.joinToString(", ") { valueToJson(it) }

/**
 * Renders a single value as JSON.
 *
 * - [String] becomes a quoted string with special characters escaped,
 * - [Map] becomes a JSON object (see [mapToJson]),
 * - [List] becomes a JSON array (see [listToJson]),
 * - anything else is written via `toString()` (suitable for numbers and booleans).
 */
@Suppress("UNCHECKED_CAST")
fun valueToJson(value: Any): String =
    when (value) {
        is String -> "\"${escapeJsonString(value)}\""
        // Generic arguments are erased at runtime, hence the unchecked casts
        is Map<*, *> -> mapToJson(value as Map<String, Any>)
        is List<*> -> "[${listToJson(value as List<Any>)}]"
        else -> value.toString()
    }
/**
 * Serializes the collected events into a JavaScript snippet and stores it in `./testResources/ml_logs.js`
 * (the path is resolved against the current working directory).
 *
 * Each pair of [list] is rendered as the JSON object `{"eventId": <first>, "data": <second>}`; the objects
 * are put one per line into a `let logs = [ ... ];` declaration. An existing file is overwritten.
 *
 * The file is always written as UTF-8, so the result does not depend on the platform default charset.
 *
 * @param list the events to dump, as (event id, event data) pairs
 */
fun convertListToJsonAndWriteToFile(list: List<Pair<String, Map<String, Any>>>) {
    // Prepare content to write to file
    val logs = list.joinToString(",\n") { (eventId, data) ->
        mapToJson(mapOf("eventId" to eventId, "data" to data))
    }
    val content = "let logs = [\n$logs\n];"

    // Write content to file; Files.writeString encodes as UTF-8 and truncates an existing file
    val filePath = Path.of(".") / "testResources" / "ml_logs.js"
    java.nio.file.Files.writeString(filePath, content)
}

View File

@@ -0,0 +1,46 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform
import com.intellij.platform.ml.impl.apiPlatform.MLApiPlatform.ExtensionController
import com.intellij.platform.ml.impl.logger.MLEvent
import com.intellij.platform.ml.impl.monitoring.MLApiStartupListener
import com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener
/**
 * Base [MLApiPlatform] for tests.
 *
 * Subclasses provide the initial listeners and events; anything registered later through the
 * `add*` functions is kept in memory and can be withdrawn via the returned [ExtensionController].
 */
abstract class TestApiPlatform : MLApiPlatform() {
    /** Task listeners that are present from the start. */
    abstract val initialTaskListeners: List<MLTaskGroupListener>

    /** Startup listeners that are present from the start. */
    abstract val initialStartupListeners: List<MLApiStartupListener>

    /** Events that are present from the start. */
    abstract val initialEvents: List<MLEvent>

    // Extensions registered after construction through the add* functions
    private val addedEvents = mutableListOf<MLEvent>()
    private val addedTaskListeners = mutableListOf<MLTaskGroupListener>()
    private val addedStartupListeners = mutableListOf<MLApiStartupListener>()

    /** Initial events followed by the dynamically added ones; a new list is built on each access. */
    final override val events: List<MLEvent>
        get() = initialEvents.plus(addedEvents)

    /** Initial task listeners followed by the dynamically added ones; a new list is built on each access. */
    final override val taskListeners: List<MLTaskGroupListener>
        get() = initialTaskListeners.plus(addedTaskListeners)

    /** Initial startup listeners followed by the dynamically added ones; a new list is built on each access. */
    final override val startupListeners: List<MLApiStartupListener>
        get() = initialStartupListeners.plus(addedStartupListeners)

    override fun addTaskListener(taskListener: MLTaskGroupListener): ExtensionController =
        register(taskListener, addedTaskListeners)

    override fun addEvent(event: MLEvent): ExtensionController =
        register(event, addedEvents)

    override fun addStartupListener(listener: MLApiStartupListener): ExtensionController =
        register(listener, addedStartupListeners)

    /**
     * Puts [element] into [registry] and returns a controller that removes it from there again.
     */
    private fun <T> register(element: T, registry: MutableCollection<T>): ExtensionController {
        registry.add(element)
        return ExtensionController { registry.remove(element) }
    }
}

View File

@@ -0,0 +1,6 @@
let logs = [
{"eventId": "startup", "data": {"success": true}},
{"eventId": "mock.failed", "data": {"normal_failure": {"reason": "com.intellij.platform.ml.impl.approach.ModelNotAcquiredOutcome"}}},
{"eventId": "mock.finished", "data": {"session": {"ml_model": {"random_seed": 1, "language_id": "TEXT", "version": "0.0"}}, "additional": {"TierGit": {"description": {"not_used": {}, "used": {"has_user": true, "n_commits": 3}}, "id": -401817549}}, "main": {"TierCompletionSession": {"description": {"not_used": {}, "used": {"language_id": "TEXT", "git_user_is_Glebanister": true, "completion_type": "SMART", "call_order": 1}}, "id": 1444967957, "analysis": {"very_good_session": true}}}, "nested": [{"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38162, "analysis": {"lookup_index": 1}}}, "nested": [{"additional": {}, "prediction": "0.1397272444266996", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -1220935314, "analysis": {}}}}, {"additional": {}, "prediction": "0.8772473577824259", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -782084434, "analysis": {}}}}]}, {"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38349, "analysis": {"lookup_index": 2}}}, "nested": [{"additional": {}, "prediction": "0.8739363143786573", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 8, "decorations": 1}}, "id": 1198136891, "analysis": {}}}}, {"additional": {}, "prediction": "0.9611620534809023", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 7, "decorations": 1}}, "id": 122089051, "analysis": {}}}}, {"additional": {}, "prediction": "null", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 17, "decorations": 1}}, "id": 1444769257, "analysis": {}}}}]}]}},
{"eventId": "mock.finished", "data": {"session": {"ml_model": {"random_seed": 1, "language_id": "TEXT", "version": "0.0"}}, "additional": {"TierGit": {"description": {"not_used": {}, "used": {"has_user": true, "n_commits": 3}}, "id": -401817549}}, "main": {"TierCompletionSession": {"description": {"not_used": {}, "used": {"language_id": "TEXT", "git_user_is_Glebanister": true, "completion_type": "SMART", "call_order": 1}}, "id": 1444967957, "analysis": {"very_good_session": true}}}, "nested": [{"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38162, "analysis": {"lookup_index": 1}}}, "nested": [{"additional": {}, "prediction": "0.1397272444266996", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -1220935314, "analysis": {}}}}, {"additional": {}, "prediction": "0.8772473577824259", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 5, "decorations": 0}}, "id": -782084434, "analysis": {}}}}]}, {"additional": {}, "main": {"TierLookup": {"description": {"not_used": {}, "used": {}}, "id": 38349, "analysis": {"lookup_index": 2}}}, "nested": [{"additional": {}, "prediction": "0.8739363143786573", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 8, "decorations": 1}}, "id": 1198136891, "analysis": {}}}}, {"additional": {}, "prediction": "0.9611620534809023", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 7, "decorations": 1}}, "id": 122089051, "analysis": {}}}}, {"additional": {}, "prediction": "null", "main": {"TierItem": {"description": {"not_used": {}, "used": {"length": 17, "decorations": 1}}, "id": 1444769257, "analysis": {}}}}]}]}}
];

View File

@@ -520,6 +520,16 @@
<with attribute="implementationClass" implements="com.intellij.internal.ml.MLFeatureProvider"/>
</extensionPoint>
<!--
  Extension points under the com.intellij.platform.ml prefix, all declared dynamic:
    environmentExtender - implementations of com.intellij.platform.ml.EnvironmentExtender
    descriptor          - implementations of com.intellij.platform.ml.TierDescriptor
    taskListener        - implementations of com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener
-->
<extensionPoint qualifiedName="com.intellij.platform.ml.environmentExtender"
interface="com.intellij.platform.ml.EnvironmentExtender"
dynamic="true"/>
<extensionPoint qualifiedName="com.intellij.platform.ml.descriptor"
interface="com.intellij.platform.ml.TierDescriptor"
dynamic="true"/>
<extensionPoint qualifiedName="com.intellij.platform.ml.taskListener"
interface="com.intellij.platform.ml.impl.monitoring.MLTaskGroupListener"
dynamic="true"/>
<extensionPoint name="defender.config" interface="com.intellij.diagnostic.WindowsDefenderChecker$Extension" dynamic="true" />
<extensionPoint name="authorizationProvider" interface="com.intellij.ide.impl.AuthorizationProvider" dynamic="true" />