[ml tools] Add dependency on ml api library

[ml tools] Fix missing internal api status

[ml tools] Fix minor issues

[ml tools] Fix merge issues

[ml tools] Add licenses

[ml tools] Add dependency on ml api library

[ml tools] Add dependency on ml api library


Merge-request: IJ-MR-141881
Merged-by: Gleb Marin <Gleb.Marin@jetbrains.com>

GitOrigin-RevId: cffb6cb5170b5f105adfe589085e43b9cb21bb31
This commit is contained in:
Gleb Marin
2024-08-16 11:25:26 +00:00
committed by intellij-monorepo-bot
parent 06b8c97ffa
commit 747dd44b26
15 changed files with 681 additions and 0 deletions

View File

@@ -0,0 +1,20 @@
<component name="libraryTable">
<library name="jetbrains.mlapi.extension" type="repository">
<properties include-transitive-deps="false" maven-id="com.jetbrains.mlapi:extension:32">
<verification>
<artifact url="file://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/extension/32/extension-32.jar">
<sha256sum>39e57d44df6db53d7336ce8b5ab3e6486b3f0beef16e7c792cbe7a0cf3e4d50b</sha256sum>
</artifact>
</verification>
</properties>
<CLASSES>
<root url="jar://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/extension/32/extension-32.jar!/" />
</CLASSES>
<JAVADOC>
<root url="jar://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/extension/32/extension-32-javadoc.jar!/" />
</JAVADOC>
<SOURCES>
<root url="jar://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/extension/32/extension-32-sources.jar!/" />
</SOURCES>
</library>
</component>

View File

@@ -0,0 +1,20 @@
<component name="libraryTable">
<library name="jetbrains.mlapi.usage" type="repository">
<properties include-transitive-deps="false" maven-id="com.jetbrains.mlapi:usage:32">
<verification>
<artifact url="file://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/usage/32/usage-32.jar">
<sha256sum>714206436b4e8d619d322cba1db73c8e0908ccaf292fa3278a2c85b64512fe8a</sha256sum>
</artifact>
</verification>
</properties>
<CLASSES>
<root url="jar://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/usage/32/usage-32.jar!/" />
</CLASSES>
<JAVADOC>
<root url="jar://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/usage/32/usage-32-javadoc.jar!/" />
</JAVADOC>
<SOURCES>
<root url="jar://$MAVEN_REPOSITORY$/com/jetbrains/mlapi/usage/32/usage-32-sources.jar!/" />
</SOURCES>
</library>
</component>

View File

@@ -1204,6 +1204,8 @@ object CommunityLibraryLicenses {
jetbrainsLibrary("jetbrains.fleet.rpc"),
jetbrainsLibrary("jetbrains.fleet.rpc.server"),
jetbrainsLibrary("jetbrains.intellij.deps.rwmutex.idea"),
jetbrainsLibrary("jetbrains.mlapi.extension"),
jetbrainsLibrary("jetbrains.mlapi.usage"),
jetbrainsLibrary("jshell-frontend"),
jetbrainsLibrary("jvm-native-trusted-roots"),
jetbrainsLibrary("kotlin-gradle-plugin-idea"),

View File

@@ -11,5 +11,6 @@
<orderEntry type="library" name="kotlin-stdlib" level="project" />
<orderEntry type="library" name="kotlinx-coroutines-core" level="project" />
<orderEntry type="module" module-name="intellij.platform.core" />
<orderEntry type="library" exported="" name="jetbrains.mlapi.extension" level="project" />
</component>
</module>

View File

@@ -48,5 +48,6 @@
<orderEntry type="module" module-name="intellij.platform.ml" />
<orderEntry type="module" module-name="intellij.platform.statistics" />
<orderEntry type="library" name="kotlin-reflect" level="project" />
<orderEntry type="library" exported="" name="jetbrains.mlapi.usage" level="project" />
</component>
</module>

View File

@@ -0,0 +1,90 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools
import com.intellij.openapi.components.Service
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.diagnostic.debug
import com.intellij.util.application
import com.intellij.util.messages.Topic
import com.jetbrains.ml.platform.MLApiPlatform.ExtensionController
import org.jetbrains.annotations.ApiStatus
@ApiStatus.Internal
@Service(Service.Level.APP)
class IJPlatform : com.jetbrains.ml.platform.MLApiPlatform(
featureProviders = EP_NAME_FEATURE_PROVIDER.extensionList,
mlUnitProviders = emptyList(),
) {
override val taskListeners: Map<String, List<com.jetbrains.ml.monitoring.MLTaskListenerTyped<*, *>>>
get() = KeyedMessagingProvider.collect(MLTaskListenerTyped.TOPIC)
override fun addTaskListener(taskId: String, taskListener: com.jetbrains.ml.monitoring.MLTaskListenerTyped<*, *>): ExtensionController {
val connection = application.messageBus.connect()
fun <M : com.jetbrains.ml.model.MLModel<P>, P : Any> capturingType(taskListenerTyped: com.jetbrains.ml.monitoring.MLTaskListenerTyped<M, P>) {
connection.subscribe(MLTaskListenerTyped.TOPIC, object : MessageBusMLTaskListenerProvider<M, P> {
override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskListenerTyped<M, P>, String) -> Unit) = collector(taskListenerTyped, taskId)
})
}
capturingType(taskListener)
return ExtensionController { connection.disconnect() }
}
override val loggingListeners: Map<String, List<com.jetbrains.ml.monitoring.MLTaskLoggingListener>>
get() = KeyedMessagingProvider.collect(MLTaskLoggingListener.TOPIC)
override fun addLoggingListener(taskId: String, loggingListener: com.jetbrains.ml.monitoring.MLTaskLoggingListener): ExtensionController {
val connection = application.messageBus.connect()
connection.subscribe(MLTaskLoggingListener.TOPIC, object : MessageBusMLTaskLoggingListenerProvider {
override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskLoggingListener, String) -> Unit) = collector(loggingListener, taskId)
})
return ExtensionController { connection.disconnect() }
}
override val systemLoggerBuilder: com.jetbrains.ml.platform.SystemLoggerBuilder = object : com.jetbrains.ml.platform.SystemLoggerBuilder {
override fun build(clazz: Class<*>): com.jetbrains.ml.platform.SystemLogger {
return IJSystemLogger(Logger.getInstance(clazz))
}
override fun build(name: String): com.jetbrains.ml.platform.SystemLogger {
return IJSystemLogger(Logger.getInstance(name))
}
private inner class IJSystemLogger(private val baseLogger: Logger) : com.jetbrains.ml.platform.SystemLogger {
override fun info(data: () -> String) = baseLogger.info(data())
override fun warn(data: () -> String) = baseLogger.warn(data())
override fun debug(data: () -> String) = baseLogger.debug { data() }
override fun error(e: Throwable) = baseLogger.error(e)
}
}
}
@ApiStatus.Internal
sealed interface KeyedMessagingProvider<T> {
fun provide(collector: (T, String) -> Unit)
companion object {
fun <P : KeyedMessagingProvider<*>, T> collect(topic: Topic<P>): Map<String, List<T>> {
val collected = mutableMapOf<String, MutableList<T>>()
application.messageBus.syncPublisher(topic).provide { it, keyOfIt ->
@Suppress("UNCHECKED_CAST")
collected.getOrPut(keyOfIt) { mutableListOf() }.add(it as T)
}
return collected
}
}
}
@ApiStatus.Internal
interface MessageBusMLTaskListenerProvider<M : com.jetbrains.ml.model.MLModel<P>, P : Any> : KeyedMessagingProvider<com.jetbrains.ml.monitoring.MLTaskListenerTyped<M, P>>
@ApiStatus.Internal
interface MessageBusMLTaskLoggingListenerProvider : KeyedMessagingProvider<com.jetbrains.ml.monitoring.MLTaskLoggingListener>

View File

@@ -0,0 +1,7 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools
import com.intellij.openapi.extensions.ExtensionPointName
import com.jetbrains.ml.FeatureProvider
internal val EP_NAME_FEATURE_PROVIDER: ExtensionPointName<FeatureProvider> = ExtensionPointName.create("com.intellij.platform.ml.featureProvider")

View File

@@ -0,0 +1,31 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools
import com.intellij.util.messages.Topic
import org.jetbrains.annotations.ApiStatus
@ApiStatus.Internal
interface MLTaskListenerTyped<M : com.jetbrains.ml.model.MLModel<P>, P : Any> : com.jetbrains.ml.monitoring.MLTaskListenerTyped<M, P>, MessageBusMLTaskListenerProvider<M, P> {
val task: com.jetbrains.ml.MLTask<M, P>
override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskListenerTyped<M, P>, String) -> Unit) {
collector(this, task.id)
}
companion object {
val TOPIC: Topic<MessageBusMLTaskListenerProvider<*, *>> = Topic.create<MessageBusMLTaskListenerProvider<*, *>>("ml.task", MessageBusMLTaskListenerProvider::class.java)
}
}
@ApiStatus.Internal
interface MLTaskLoggingListener : com.jetbrains.ml.monitoring.MLTaskLoggingListener, MessageBusMLTaskLoggingListenerProvider {
val task: com.jetbrains.ml.MLTask<*, *>
override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskLoggingListener, String) -> Unit) {
collector(this, task.id)
}
companion object {
val TOPIC = Topic.create<MessageBusMLTaskLoggingListenerProvider>("ml.logging", MessageBusMLTaskLoggingListenerProvider::class.java)
}
}

View File

@@ -0,0 +1,21 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools
import com.intellij.internal.statistic.eventLog.EventLogGroup
import com.intellij.openapi.Disposable
import com.intellij.openapi.observable.util.whenDisposed
import com.intellij.platform.ml.impl.tools.logs.IntelliJFusEventRegister
import com.jetbrains.ml.MLTask
import com.jetbrains.ml.model.MLModel
import org.jetbrains.annotations.ApiStatus
@ApiStatus.Internal
fun <M : MLModel<P>, P : Any> EventLogGroup.registerMLTaskLogging(
task: MLTask<M, P>,
parentDisposable: Disposable,
eventPrefix: String = task.id,
) {
val componentRegister = IntelliJFusEventRegister(this)
val listenerController = componentRegister.registerLogging(task, eventPrefix)
parentDisposable.whenDisposed { listenerController.remove() }
}

View File

@@ -0,0 +1,257 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools.logs
import com.intellij.internal.statistic.eventLog.FeatureUsageData
import com.intellij.internal.statistic.eventLog.events.PrimitiveEventField
import com.intellij.internal.statistic.eventLog.events.VarargEventId
import com.intellij.lang.Language
import com.intellij.openapi.util.Version
import com.intellij.platform.ml.impl.tools.logs.ConverterObjectDescription.Companion.asIJObjectDescription
import com.intellij.platform.ml.impl.tools.logs.ConverterOfEnum.Companion.toIJConverter
import com.intellij.platform.ml.impl.tools.logs.IJEventPairConverter.Companion.typedBuild
import com.jetbrains.ml.logs.schema.EventField
import com.jetbrains.ml.logs.schema.EventPair
import org.jetbrains.annotations.ApiStatus
import com.intellij.internal.statistic.eventLog.EventLogGroup as IJEventLogGroup
import com.intellij.internal.statistic.eventLog.events.BooleanEventField as IJBooleanEventField
import com.intellij.internal.statistic.eventLog.events.ClassEventField as IJClassEventField
import com.intellij.internal.statistic.eventLog.events.DoubleEventField as IJDoubleEventField
import com.intellij.internal.statistic.eventLog.events.EnumEventField as IJEnumEventField
import com.intellij.internal.statistic.eventLog.events.EventField as IJEventField
import com.intellij.internal.statistic.eventLog.events.EventFields as IJEventFields
import com.intellij.internal.statistic.eventLog.events.EventPair as IJEventPair
import com.intellij.internal.statistic.eventLog.events.FloatEventField as IJFloatEventField1
import com.intellij.internal.statistic.eventLog.events.IntEventField as IJIntEventField
import com.intellij.internal.statistic.eventLog.events.LongEventField as IJLongEventField
import com.intellij.internal.statistic.eventLog.events.ObjectDescription as IJObjectDescription
import com.intellij.internal.statistic.eventLog.events.ObjectEventData as IJObjectEventData
import com.intellij.internal.statistic.eventLog.events.ObjectEventField as IJObjectEventField
import com.intellij.internal.statistic.eventLog.events.ObjectListEventField as IJObjectListEventField
import com.intellij.internal.statistic.eventLog.events.StringEventField as IJStringEventField
import com.jetbrains.ml.logs.schema.BooleanEventField as MLBooleanEventField
import com.jetbrains.ml.logs.schema.ClassEventField as MLClassEventField
import com.jetbrains.ml.logs.schema.DoubleEventField as MLDoubleEventField
import com.jetbrains.ml.logs.schema.EnumEventField as MLEnumEventField
import com.jetbrains.ml.logs.schema.EventField as MLEventField
import com.jetbrains.ml.logs.schema.EventPair as MLEventPair
import com.jetbrains.ml.logs.schema.FloatEventField as MLFloatEventField
import com.jetbrains.ml.logs.schema.IntEventField as MLIntEventField
import com.jetbrains.ml.logs.schema.LongEventField as MLLongEventField
import com.jetbrains.ml.logs.schema.ObjectDescription as MLObjectDescription
import com.jetbrains.ml.logs.schema.ObjectEventData as MLObjectEventData
import com.jetbrains.ml.logs.schema.ObjectEventField as MLObjectEventField
import com.jetbrains.ml.logs.schema.ObjectListEventField as MLObjectListEventField
import com.jetbrains.ml.logs.schema.StringEventField as MLStringEventField
@ApiStatus.Internal
class IntelliJFusEventRegister(private val baseEventGroup: IJEventLogGroup) : com.jetbrains.ml.logs.FusEventRegister {
private class Logger(
private val varargEventId: VarargEventId,
private val objectDescription: ConverterObjectDescription
) : com.jetbrains.ml.logs.FusEventLogger {
override fun log(eventPairs: List<MLEventPair<*>>) {
val ijEventPairs = objectDescription.buildEventPairs(eventPairs)
varargEventId.log(*ijEventPairs.toTypedArray())
}
}
override fun registerEvent(name: String, eventFields: List<EventField<*>>): com.jetbrains.ml.logs.FusEventLogger {
val objectDescription = ConverterObjectDescription(MLObjectDescription(eventFields))
val varargEventId = baseEventGroup.registerVarargEvent(name, null, *objectDescription.getFields())
return Logger(varargEventId, objectDescription)
}
}
@Suppress("UNCHECKED_CAST")
private fun <L> createConverter(mlEventField: MLEventField<L>): IJEventPairConverter<L, *> = when (mlEventField) {
is MLObjectEventField -> ConverterOfObject(
mlEventField.name,
mlEventField.lazyDescription,
mlEventField.objectDescription
) as IJEventPairConverter<L, *>
is MLBooleanEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJBooleanEventField(n, d) } as IJEventPairConverter<L, *>
is MLIntEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJIntEventField(n, d) } as IJEventPairConverter<L, *>
is MLLongEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJLongEventField(n, d) } as IJEventPairConverter<L, *>
is MLFloatEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJFloatEventField1(n, d) } as IJEventPairConverter<L, *>
is MLEnumEventField<*> -> mlEventField.toIJConverter() as IJEventPairConverter<L, *>
is MLClassEventField -> ConverterOfClass(mlEventField) as IJEventPairConverter<L, *>
is MLObjectListEventField -> ConvertObjectList(mlEventField) as IJEventPairConverter<L, *>
is MLDoubleEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJDoubleEventField(n, d) } as IJEventPairConverter<L, *>
is VersionEventField -> ConverterOfVersion(mlEventField) as IJEventPairConverter<L, *>
is LanguageEventField -> ConverterOfLanguage(mlEventField) as IJEventPairConverter<L, *>
is MLStringEventField -> ConverterOfString(mlEventField) as IJEventPairConverter<L, *>
is IJSpecificEventField<*> -> {
when (mlEventField) {
is IJCustomEventField -> ConverterOfCustom(mlEventField)
is LanguageEventField -> ConverterOfLanguage(mlEventField) as IJEventPairConverter<L, *>
is VersionEventField -> ConverterOfVersion(mlEventField) as IJEventPairConverter<L, *>
}
}
else -> throw IllegalArgumentException(
"""
Conversion of ${mlEventField.javaClass.simpleName} is not possible.
If you want to create your own field, you must add an inheritor of
${IJCustomEventField::class.qualifiedName}
""".trimIndent()
)
}
private class ConverterOfCustom<T>(mlEventField: IJCustomEventField<T>) : IJEventPairConverter<T, T> {
override val ijEventField: IJEventField<T> = mlEventField.baseIJEventField
override fun buildEventPair(mlEventPair: EventPair<T>): IJEventPair<T> {
return ijEventField with mlEventPair.data
}
}
private class ConverterOfString(mlEventField: MLStringEventField) : IJEventPairConverter<String, String?> {
override val ijEventField: IJEventField<String?> = IJStringEventField.ValidatedByAllowedValues(
mlEventField.name,
allowedValues = mlEventField.possibleValues,
description = mlEventField.lazyDescription()
)
override fun buildEventPair(mlEventPair: EventPair<String>): IJEventPair<String?> {
return ijEventField with mlEventPair.data
}
}
private class ConverterOfLanguage(mlEventField: LanguageEventField) : IJEventPairConverter<Language, Language?> {
override val ijEventField: IJEventField<Language?> = IJEventFields.Language(mlEventField.name, mlEventField.lazyDescription())
override fun buildEventPair(mlEventPair: MLEventPair<Language>): IJEventPair<Language?> {
return ijEventField with mlEventPair.data
}
}
private class ConverterOfVersion(mlEventField: com.intellij.platform.ml.impl.tools.logs.VersionEventField) : IJEventPairConverter<Version, Version?> {
private class VersionEventField(override val name: String, override val description: String?) : PrimitiveEventField<Version?>() {
override val validationRule: List<String>
get() = listOf("{regexp#version}")
override fun addData(fuData: FeatureUsageData, value: Version?) {
fuData.addVersion(value)
}
}
override val ijEventField: IJEventField<Version?> = VersionEventField(mlEventField.name, mlEventField.lazyDescription())
override fun buildEventPair(mlEventPair: MLEventPair<Version>): IJEventPair<Version?> {
return ijEventField with mlEventPair.data
}
}
private class ConvertObjectList(mlEventField: MLObjectListEventField) :
IJEventPairConverter<List<MLObjectEventData>, List<IJObjectEventData>> {
private val innerObjectConverter = ConverterOfObject(mlEventField.name, mlEventField.lazyDescription, mlEventField.internalObjectDescription)
// FIXME: description is not passed
override val ijEventField: IJEventField<List<IJObjectEventData>> = IJObjectListEventField(
mlEventField.name,
innerObjectConverter.ijObjectDescription
)
override fun buildEventPair(mlEventPair: MLEventPair<List<MLObjectEventData>>): IJEventPair<List<IJObjectEventData>> {
return ijEventField with mlEventPair.data.map { innerObjectFieldsValues ->
innerObjectConverter.buildObjectEventData(innerObjectFieldsValues)
}
}
}
private class ConverterOfEnum<T : Enum<*>>(mlEnumField: MLEnumEventField<T>) : IJEventPairConverter<T, T> {
override val ijEventField: IJEventField<T> = IJEnumEventField(mlEnumField.name, mlEnumField.enumClass, mlEnumField.transform)
override fun buildEventPair(mlEventPair: MLEventPair<T>): IJEventPair<T> {
return ijEventField with mlEventPair.data
}
companion object {
fun <T : Enum<*>> MLEnumEventField<T>.toIJConverter(): ConverterOfEnum<T> {
return ConverterOfEnum(this)
}
}
}
private interface IJEventPairConverter<M, I> {
val ijEventField: IJEventField<I>
fun buildEventPair(mlEventPair: MLEventPair<M>): IJEventPair<I>
companion object {
fun <M, I> IJEventPairConverter<M, I>.typedBuild(mlEventPair: MLEventPair<*>): IJEventPair<I> {
@Suppress("UNCHECKED_CAST")
return buildEventPair(mlEventPair as MLEventPair<M>)
}
}
}
private class ConverterObjectDescription(mlObjectDescription: MLObjectDescription) : IJObjectDescription() {
private val toIJConverters: Map<MLEventField<*>, IJEventPairConverter<*, *>> = mlObjectDescription.getFields().associateWith { mlField ->
val converter = createConverter(mlField)
field(converter.ijEventField)
converter
}
fun buildEventPairs(mlEventPairs: List<MLEventPair<*>>): List<IJEventPair<*>> {
return mlEventPairs.map { mlEventPair ->
require(mlEventPair.field in toIJConverters) {
"""
Field ${mlEventPair.field} (name: ${mlEventPair.field.name}) was not found among
the registered ones: ${toIJConverters.keys.map { it.name }}
""".trimIndent()
}
val converter = requireNotNull(toIJConverters[mlEventPair.field])
converter.typedBuild(mlEventPair)
}
}
fun buildObjectEventData(mlObject: MLObjectEventData): IJObjectEventData {
return IJObjectEventData(buildEventPairs(mlObject.values))
}
companion object {
fun MLObjectDescription.asIJObjectDescription() = ConverterObjectDescription(this)
}
}
private class ConverterOfObject(
name: String,
lazyDescription: () -> String,
mlObjectDescription: MLObjectDescription,
) : IJEventPairConverter<MLObjectEventData, IJObjectEventData> {
val ijObjectDescription = mlObjectDescription.asIJObjectDescription()
override val ijEventField: IJEventField<IJObjectEventData> = IJObjectEventField(name, lazyDescription(), ijObjectDescription)
fun buildObjectEventData(mlObject: MLObjectEventData): IJObjectEventData {
return ijObjectDescription.buildObjectEventData(mlObject)
}
override fun buildEventPair(mlEventPair: MLEventPair<MLObjectEventData>): IJEventPair<IJObjectEventData> {
return ijEventField with buildObjectEventData(mlEventPair.data)
}
}
private class ConverterOfPrimitiveType<T>(
mlEventField: MLEventField<T>,
createIJField: (String, String?) -> IJEventField<T>
) : IJEventPairConverter<T, T> {
override val ijEventField: IJEventField<T> = createIJField(mlEventField.name, mlEventField.lazyDescription())
override fun buildEventPair(mlEventPair: MLEventPair<T>): IJEventPair<T> {
return ijEventField with mlEventPair.data
}
}
private class ConverterOfClass(
mlEventField: MLClassEventField,
) : IJEventPairConverter<Class<*>, Class<*>?> {
override val ijEventField: IJEventField<Class<*>?> = IJClassEventField(mlEventField.name, mlEventField.lazyDescription())
override fun buildEventPair(mlEventPair: MLEventPair<Class<*>>): IJEventPair<Class<*>?> {
return ijEventField with mlEventPair.data
}
}

View File

@@ -0,0 +1,32 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools.logs
import com.intellij.lang.Language
import com.jetbrains.ml.logs.schema.EventField
import org.jetbrains.annotations.ApiStatus
/**
* Something, that is dedicated for one language only.
*/
@ApiStatus.Internal
interface LanguageSpecific {
val language: Language
}
/**
* The analyzer, that adds information about ML model's language to logs.
*/
@ApiStatus.Internal
class ModelLanguageAnalyser<M, P : Any> : com.jetbrains.ml.analysis.MLTaskAnalyserTyped<M, P>
where M : com.jetbrains.ml.model.MLModel<P>,
M : LanguageSpecific {
private val LANGUAGE = LanguageEventField("model_language") { "The programming language the ML model is trained for" }
override fun startMLSessionAnalysis(sessionInfo: com.jetbrains.ml.session.MLSessionInfo<M, P>) = object : com.jetbrains.ml.analysis.MLSessionAnalyserTyped<M, P> {
override suspend fun analyseSession(tree: com.jetbrains.ml.tree.MLTree.ATopNode<M, P>?) = buildList {
tree?.mlModel?.let { add(LANGUAGE with it.language) }
}
}
override val declaration: List<EventField<*>> = listOf(LANGUAGE)
}

View File

@@ -0,0 +1,34 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools.logs
import com.intellij.openapi.util.Version
import com.jetbrains.ml.logs.schema.EventField
import org.jetbrains.annotations.ApiStatus
/**
* Something, that has versions.
*/
@ApiStatus.Internal
interface Versioned {
val version: Version?
}
/**
* Adds model's version to the ML logs.
*/
@ApiStatus.Internal
class ModelVersionAnalyser<M, P : Any> : com.jetbrains.ml.analysis.MLTaskAnalyserTyped<M, P>
where M : com.jetbrains.ml.model.MLModel<P>,
M : Versioned {
companion object {
private val VERSION = VersionEventField("model_version") { "Version of the ML model" }
}
override fun startMLSessionAnalysis(sessionInfo: com.jetbrains.ml.session.MLSessionInfo<M, P>) = object : com.jetbrains.ml.analysis.MLSessionAnalyserTyped<M, P> {
override suspend fun analyseSession(tree: com.jetbrains.ml.tree.MLTree.ATopNode<M, P>?) = buildList {
tree?.mlModel?.version?.let { add(VERSION with it) }
}
}
override val declaration: List<EventField<*>> = listOf(VERSION)
}

View File

@@ -0,0 +1,20 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools.logs
import com.intellij.lang.Language
import com.intellij.openapi.util.Version
import com.jetbrains.ml.logs.schema.CustomRuleEventField
import org.jetbrains.annotations.ApiStatus
import com.intellij.internal.statistic.eventLog.events.EventField as IJEventField
@ApiStatus.Internal
sealed class IJSpecificEventField<T>(name: String, lazyDescription: () -> String) : CustomRuleEventField<T>(name, lazyDescription)
@ApiStatus.Internal
class VersionEventField(name: String, lazyDescription: () -> String) : IJSpecificEventField<Version>(name, lazyDescription)
@ApiStatus.Internal
class LanguageEventField(name: String, lazyDescription: () -> String) : IJSpecificEventField<Language>(name, lazyDescription)
@ApiStatus.Internal
open class IJCustomEventField<T>(val baseIJEventField: IJEventField<T>) : IJSpecificEventField<T>(baseIJEventField.name, { baseIJEventField.description ?: "" })

View File

@@ -0,0 +1,142 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.impl.tools.model
import com.intellij.internal.ml.DecisionFunction
import com.jetbrains.ml.*
import com.jetbrains.ml.model.MLModel
import org.jetbrains.annotations.ApiStatus
/**
* A wrapper for using legacy [DecisionFunction] ML API`s [MLModel].
*/
@ApiStatus.Internal
open class RegressionModel private constructor(
private val decisionFunction: DecisionFunction,
private val modelUnits: Set<MLUnit<*>>,
globalUnits: Set<MLUnit<*>>,
private val featureSerialization: FeatureNameSerialization,
) : MLModel<Double> {
constructor(
decisionFunction: DecisionFunction,
featureSerialization: FeatureNameSerialization,
globalUnits: Set<MLUnit<*>>,
) : this(
decisionFunction = decisionFunction,
modelUnits = decisionFunction.featuresOrder.map { featureMapper ->
featureSerialization.deserialize(featureMapper.featureName, globalUnits.associateBy { it.name }).first
}.toSet(),
globalUnits = globalUnits,
featureSerialization = featureSerialization
)
override val knownFeatures: Map<MLUnit<*>, FeatureFilter> = createKnownFeatures(
DecisionFunctionWrapper(decisionFunction, globalUnits, featureSerialization),
modelUnits
)
override fun predict(taskUnits: List<MLUnitsMap>, contexts: List<Any?>, features: Map<MLUnit<*>, List<Feature>>): Double {
val array = DoubleArray(decisionFunction.featuresOrder.size)
val featurePerSerializedName: Map<String, Feature> = features
.flatMap { (unit, unitFeatures) -> unitFeatures.map { unit to it } }
.associate { (unit, feature) -> featureSerialization.serialize(unit, feature.declaration.name) to feature }
require(features.keys == modelUnits) {
"Given features units are ${features.keys}, but this model needs ${modelUnits}"
}
for (featureI in decisionFunction.featuresOrder.indices) {
val featureMapper = decisionFunction.featuresOrder[featureI]
val featureSerializedName = featureMapper.featureName
val featureValue = featureMapper.asArrayValue(featurePerSerializedName[featureSerializedName]?.value)
array[featureI] = featureValue
}
return decisionFunction.predict(array)
}
interface FeatureNameSerialization {
fun serialize(mlUnit: MLUnit<*>, featureName: String): String
fun deserialize(serializedFeatureName: String, unitsPerName: Map<String, MLUnit<*>>): Pair<MLUnit<*>, String>
}
object DefaultFeatureSerialization : FeatureNameSerialization {
private const val SERIALIZED_FEATURE_SEPARATOR = '/'
override fun serialize(mlUnit: MLUnit<*>, featureName: String): String {
return mlUnit.name + SERIALIZED_FEATURE_SEPARATOR + featureName
}
override fun deserialize(serializedFeatureName: String, unitsPerName: Map<String, MLUnit<*>>): Pair<MLUnit<*>, String> {
val indexOfLastSeparator = serializedFeatureName.indexOfLast { it == SERIALIZED_FEATURE_SEPARATOR }
require(indexOfLastSeparator >= 0) { "Feature name '$serializedFeatureName' does not contain ML unit's name" }
val unitName = serializedFeatureName.slice(0 until indexOfLastSeparator)
val featureName = serializedFeatureName.slice(indexOfLastSeparator until serializedFeatureName.length)
val featureUnit = requireNotNull(unitsPerName[unitName]) {
"""
Serialized feature '$serializedFeatureName' has tier $unitName,
but all available tiers are ${unitsPerName.keys}
""".trimIndent()
}
return featureUnit to featureName
}
}
private class DecisionFunctionWrapper(
private val decisionFunction: DecisionFunction,
globalUnits: Set<MLUnit<*>>,
private val featureNameSerialization: FeatureNameSerialization,
) {
private val availableUnitsPerName: Map<String, MLUnit<*>> = globalUnits.associateBy { it.name }
val knownFeatures: Map<MLUnit<*>, Set<String>> = run {
val knownFeaturesSerializedNames = decisionFunction.featuresOrder.map { it.featureName }.toSet()
knownFeaturesSerializedNames
.map { featureNameSerialization.deserialize(it, availableUnitsPerName) }
.groupBy({ it.first }, { it.second })
.mapValues { it.value.toSet() }
}
fun getUnknownFeatures(unit: MLUnit<*>, featuresNames: Set<String>): Set<String> {
val knownFeatures = knownFeatures[unit] ?: return featuresNames
return featuresNames.filterNot { it in knownFeatures }.toSet()
}
}
companion object {
private fun createKnownFeatures(
decisionFunction: DecisionFunctionWrapper,
modelUnits: Set<MLUnit<*>>,
): Map<MLUnit<*>, FeatureFilter> {
fun createFeatureSelector(unit: MLUnit<*>) = object : FeatureFilter {
init {
val knownFeatures = decisionFunction.knownFeatures
knownFeatures.forEach { (unit, features) ->
val nonConsistentlyKnownFeatures = decisionFunction.getUnknownFeatures(unit, features)
require(nonConsistentlyKnownFeatures.isEmpty()) {
"These features are known and unknown at the same time: $nonConsistentlyKnownFeatures"
}
}
}
override fun accept(featureDeclarations: Set<FeatureDeclaration<*>>): Set<FeatureDeclaration<*>> {
val availableFeaturesPerName = featureDeclarations.associateBy { it.name }
val availableFeaturesNames = featureDeclarations.map { it.name }.toSet()
val unknownFeaturesNames = decisionFunction.getUnknownFeatures(unit, availableFeaturesNames)
val knownAvailableFeaturesNames = availableFeaturesNames - unknownFeaturesNames
val knownAvailableFeatures = knownAvailableFeaturesNames.map { availableFeaturesPerName.getValue(it) }.toSet()
return knownAvailableFeatures
}
override fun accept(featureDeclaration: FeatureDeclaration<*>): Boolean {
val unknown = decisionFunction.getUnknownFeatures(unit, setOf(featureDeclaration.name))
return unknown.isEmpty()
}
}
return modelUnits.associateWith { createFeatureSelector(it) }
}
}
}

View File

@@ -541,6 +541,9 @@
<extensionPoint qualifiedName="com.intellij.platform.ml.taskListener"
interface="com.intellij.platform.ml.monitoring.MLTaskGroupListener"
dynamic="true"/>
<extensionPoint qualifiedName="com.intellij.platform.ml.featureProvider"
interface="com.jetbrains.ml.FeatureProvider"
dynamic="true"/>
<extensionPoint name="defender.config" interface="com.intellij.diagnostic.WindowsDefenderChecker$Extension" dynamic="true" />
<extensionPoint name="authorizationProvider" interface="com.intellij.ide.impl.AuthorizationProvider" dynamic="true" />