From 747dd44b263f36824fbe85240fa2f673ca0091bb Mon Sep 17 00:00:00 2001 From: Gleb Marin Date: Fri, 16 Aug 2024 11:25:26 +0000 Subject: [PATCH] [ml tools] Add dependency on ml api library [ml tools] Fix missing internal api status [ml tools] Fix minor issues [ml tools] Fix merge issues [ml tools] Add licenses [ml tools] Add dependency on ml api library [ml tools] Add dependency on ml api library Merge-request: IJ-MR-141881 Merged-by: Gleb Marin GitOrigin-RevId: cffb6cb5170b5f105adfe589085e43b9cb21bb31 --- .idea/libraries/jetbrains_mlapi_extension.xml | 20 ++ .idea/libraries/jetbrains_mlapi_usage.xml | 20 ++ .../build/CommunityLibraryLicenses.kt | 2 + platform/ml-api/intellij.platform.ml.iml | 1 + .../ml-impl/intellij.platform.ml.impl.iml | 1 + .../platform/ml/impl/tools/IJPlatform.kt | 90 ++++++ .../platform/ml/impl/tools/extensionPoints.kt | 7 + .../platform/ml/impl/tools/listeners.kt | 31 +++ .../platform/ml/impl/tools/logging.kt | 21 ++ .../tools/logs/IntelliJFusEventRegister.kt | 257 ++++++++++++++++++ .../ml/impl/tools/logs/LanguageSpecific.kt | 32 +++ .../platform/ml/impl/tools/logs/Versioned.kt | 34 +++ .../ml/impl/tools/logs/eventFields.kt | 20 ++ .../ml/impl/tools/model/RegressionModel.kt | 142 ++++++++++ .../src/META-INF/PlatformExtensionPoints.xml | 3 + 15 files changed, 681 insertions(+) create mode 100644 .idea/libraries/jetbrains_mlapi_extension.xml create mode 100644 .idea/libraries/jetbrains_mlapi_usage.xml create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/IJPlatform.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/extensionPoints.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/listeners.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logging.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/IntelliJFusEventRegister.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/LanguageSpecific.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/Versioned.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/eventFields.kt create mode 100644 platform/ml-impl/src/com/intellij/platform/ml/impl/tools/model/RegressionModel.kt diff --git a/.idea/libraries/jetbrains_mlapi_extension.xml b/.idea/libraries/jetbrains_mlapi_extension.xml new file mode 100644 index 000000000000..a8a4e90d0dfa --- /dev/null +++ b/.idea/libraries/jetbrains_mlapi_extension.xml @@ -0,0 +1,20 @@ + + + + + + 39e57d44df6db53d7336ce8b5ab3e6486b3f0beef16e7c792cbe7a0cf3e4d50b + + + + + + + + + + + + + + \ No newline at end of file diff --git a/.idea/libraries/jetbrains_mlapi_usage.xml b/.idea/libraries/jetbrains_mlapi_usage.xml new file mode 100644 index 000000000000..f469b3d5af21 --- /dev/null +++ b/.idea/libraries/jetbrains_mlapi_usage.xml @@ -0,0 +1,20 @@ + + + + + + 714206436b4e8d619d322cba1db73c8e0908ccaf292fa3278a2c85b64512fe8a + + + + + + + + + + + + + + \ No newline at end of file diff --git a/platform/build-scripts/src/org/jetbrains/intellij/build/CommunityLibraryLicenses.kt b/platform/build-scripts/src/org/jetbrains/intellij/build/CommunityLibraryLicenses.kt index 526a82e72401..15501f13de72 100644 --- a/platform/build-scripts/src/org/jetbrains/intellij/build/CommunityLibraryLicenses.kt +++ b/platform/build-scripts/src/org/jetbrains/intellij/build/CommunityLibraryLicenses.kt @@ -1204,6 +1204,8 @@ object CommunityLibraryLicenses { jetbrainsLibrary("jetbrains.fleet.rpc"), jetbrainsLibrary("jetbrains.fleet.rpc.server"), jetbrainsLibrary("jetbrains.intellij.deps.rwmutex.idea"), + jetbrainsLibrary("jetbrains.mlapi.extension"), + jetbrainsLibrary("jetbrains.mlapi.usage"), jetbrainsLibrary("jshell-frontend"), jetbrainsLibrary("jvm-native-trusted-roots"), jetbrainsLibrary("kotlin-gradle-plugin-idea"), diff --git a/platform/ml-api/intellij.platform.ml.iml b/platform/ml-api/intellij.platform.ml.iml index e27bd7b7467d..4bf4ebf195b9 100644 --- a/platform/ml-api/intellij.platform.ml.iml +++ b/platform/ml-api/intellij.platform.ml.iml @@ -11,5 +11,6 @@ + \ No newline at end of file diff --git a/platform/ml-impl/intellij.platform.ml.impl.iml b/platform/ml-impl/intellij.platform.ml.impl.iml index 9ede63710bce..3b031b7cabd1 100644 --- a/platform/ml-impl/intellij.platform.ml.impl.iml +++ b/platform/ml-impl/intellij.platform.ml.impl.iml @@ -48,5 +48,6 @@ + \ No newline at end of file diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/IJPlatform.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/IJPlatform.kt new file mode 100644 index 000000000000..bed57622bbb1 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/IJPlatform.kt @@ -0,0 +1,90 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools + +import com.intellij.openapi.components.Service +import com.intellij.openapi.diagnostic.Logger +import com.intellij.openapi.diagnostic.debug +import com.intellij.util.application +import com.intellij.util.messages.Topic +import com.jetbrains.ml.platform.MLApiPlatform.ExtensionController +import org.jetbrains.annotations.ApiStatus + + +@ApiStatus.Internal +@Service(Service.Level.APP) +class IJPlatform : com.jetbrains.ml.platform.MLApiPlatform( + featureProviders = EP_NAME_FEATURE_PROVIDER.extensionList, + mlUnitProviders = emptyList(), +) { + override val taskListeners: Map>> + get() = KeyedMessagingProvider.collect(MLTaskListenerTyped.TOPIC) + + override fun addTaskListener(taskId: String, taskListener: com.jetbrains.ml.monitoring.MLTaskListenerTyped<*, *>): ExtensionController { + val connection = application.messageBus.connect() + + fun , P : Any> capturingType(taskListenerTyped: com.jetbrains.ml.monitoring.MLTaskListenerTyped) { + connection.subscribe(MLTaskListenerTyped.TOPIC, object : MessageBusMLTaskListenerProvider { + override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskListenerTyped, String) -> Unit) = collector(taskListenerTyped, taskId) + }) + } + + capturingType(taskListener) + + return ExtensionController { connection.disconnect() } + } + + override val loggingListeners: Map> + get() = KeyedMessagingProvider.collect(MLTaskLoggingListener.TOPIC) + + override fun addLoggingListener(taskId: String, loggingListener: com.jetbrains.ml.monitoring.MLTaskLoggingListener): ExtensionController { + val connection = application.messageBus.connect() + connection.subscribe(MLTaskLoggingListener.TOPIC, object : MessageBusMLTaskLoggingListenerProvider { + override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskLoggingListener, String) -> Unit) = collector(loggingListener, taskId) + }) + + return ExtensionController { connection.disconnect() } + } + + override val systemLoggerBuilder: com.jetbrains.ml.platform.SystemLoggerBuilder = object : com.jetbrains.ml.platform.SystemLoggerBuilder { + override fun build(clazz: Class<*>): com.jetbrains.ml.platform.SystemLogger { + return IJSystemLogger(Logger.getInstance(clazz)) + } + + override fun build(name: String): com.jetbrains.ml.platform.SystemLogger { + return IJSystemLogger(Logger.getInstance(name)) + } + + private inner class IJSystemLogger(private val baseLogger: Logger) : com.jetbrains.ml.platform.SystemLogger { + override fun info(data: () -> String) = baseLogger.info(data()) + + override fun warn(data: () -> String) = baseLogger.warn(data()) + + override fun debug(data: () -> String) = baseLogger.debug { data() } + + override fun error(e: Throwable) = baseLogger.error(e) + } + } +} + + +@ApiStatus.Internal +sealed interface KeyedMessagingProvider { + fun provide(collector: (T, String) -> Unit) + + companion object { + fun

, T> collect(topic: Topic

): Map> { + val collected = mutableMapOf>() + application.messageBus.syncPublisher(topic).provide { it, keyOfIt -> + @Suppress("UNCHECKED_CAST") + collected.getOrPut(keyOfIt) { mutableListOf() }.add(it as T) + } + return collected + } + } +} + +@ApiStatus.Internal +interface MessageBusMLTaskListenerProvider, P : Any> : KeyedMessagingProvider> + +@ApiStatus.Internal +interface MessageBusMLTaskLoggingListenerProvider : KeyedMessagingProvider diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/extensionPoints.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/extensionPoints.kt new file mode 100644 index 000000000000..7c1e48f8daf9 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/extensionPoints.kt @@ -0,0 +1,7 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools + +import com.intellij.openapi.extensions.ExtensionPointName +import com.jetbrains.ml.FeatureProvider + +internal val EP_NAME_FEATURE_PROVIDER: ExtensionPointName = ExtensionPointName.create("com.intellij.platform.ml.featureProvider") diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/listeners.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/listeners.kt new file mode 100644 index 000000000000..4f771b9f57ad --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/listeners.kt @@ -0,0 +1,31 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools + +import com.intellij.util.messages.Topic +import org.jetbrains.annotations.ApiStatus + +@ApiStatus.Internal +interface MLTaskListenerTyped, P : Any> : com.jetbrains.ml.monitoring.MLTaskListenerTyped, MessageBusMLTaskListenerProvider { + val task: com.jetbrains.ml.MLTask + + override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskListenerTyped, String) -> Unit) { + collector(this, task.id) + } + + companion object { + val TOPIC: Topic> = Topic.create>("ml.task", MessageBusMLTaskListenerProvider::class.java) + } +} + +@ApiStatus.Internal +interface MLTaskLoggingListener : com.jetbrains.ml.monitoring.MLTaskLoggingListener, MessageBusMLTaskLoggingListenerProvider { + val task: com.jetbrains.ml.MLTask<*, *> + + override fun provide(collector: (com.jetbrains.ml.monitoring.MLTaskLoggingListener, String) -> Unit) { + collector(this, task.id) + } + + companion object { + val TOPIC = Topic.create("ml.logging", MessageBusMLTaskLoggingListenerProvider::class.java) + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logging.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logging.kt new file mode 100644 index 000000000000..41d0b8359105 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logging.kt @@ -0,0 +1,21 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools + +import com.intellij.internal.statistic.eventLog.EventLogGroup +import com.intellij.openapi.Disposable +import com.intellij.openapi.observable.util.whenDisposed +import com.intellij.platform.ml.impl.tools.logs.IntelliJFusEventRegister +import com.jetbrains.ml.MLTask +import com.jetbrains.ml.model.MLModel +import org.jetbrains.annotations.ApiStatus + +@ApiStatus.Internal +fun , P : Any> EventLogGroup.registerMLTaskLogging( + task: MLTask, + parentDisposable: Disposable, + eventPrefix: String = task.id, +) { + val componentRegister = IntelliJFusEventRegister(this) + val listenerController = componentRegister.registerLogging(task, eventPrefix) + parentDisposable.whenDisposed { listenerController.remove() } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/IntelliJFusEventRegister.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/IntelliJFusEventRegister.kt new file mode 100644 index 000000000000..9bd0a5955117 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/IntelliJFusEventRegister.kt @@ -0,0 +1,257 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools.logs + +import com.intellij.internal.statistic.eventLog.FeatureUsageData +import com.intellij.internal.statistic.eventLog.events.PrimitiveEventField +import com.intellij.internal.statistic.eventLog.events.VarargEventId +import com.intellij.lang.Language +import com.intellij.openapi.util.Version +import com.intellij.platform.ml.impl.tools.logs.ConverterObjectDescription.Companion.asIJObjectDescription +import com.intellij.platform.ml.impl.tools.logs.ConverterOfEnum.Companion.toIJConverter +import com.intellij.platform.ml.impl.tools.logs.IJEventPairConverter.Companion.typedBuild +import com.jetbrains.ml.logs.schema.EventField +import com.jetbrains.ml.logs.schema.EventPair +import org.jetbrains.annotations.ApiStatus +import com.intellij.internal.statistic.eventLog.EventLogGroup as IJEventLogGroup +import com.intellij.internal.statistic.eventLog.events.BooleanEventField as IJBooleanEventField +import com.intellij.internal.statistic.eventLog.events.ClassEventField as IJClassEventField +import com.intellij.internal.statistic.eventLog.events.DoubleEventField as IJDoubleEventField +import com.intellij.internal.statistic.eventLog.events.EnumEventField as IJEnumEventField +import com.intellij.internal.statistic.eventLog.events.EventField as IJEventField +import com.intellij.internal.statistic.eventLog.events.EventFields as IJEventFields +import com.intellij.internal.statistic.eventLog.events.EventPair as IJEventPair +import com.intellij.internal.statistic.eventLog.events.FloatEventField as IJFloatEventField1 +import com.intellij.internal.statistic.eventLog.events.IntEventField as IJIntEventField +import com.intellij.internal.statistic.eventLog.events.LongEventField as IJLongEventField +import com.intellij.internal.statistic.eventLog.events.ObjectDescription as IJObjectDescription +import com.intellij.internal.statistic.eventLog.events.ObjectEventData as IJObjectEventData +import com.intellij.internal.statistic.eventLog.events.ObjectEventField as IJObjectEventField +import com.intellij.internal.statistic.eventLog.events.ObjectListEventField as IJObjectListEventField +import com.intellij.internal.statistic.eventLog.events.StringEventField as IJStringEventField +import com.jetbrains.ml.logs.schema.BooleanEventField as MLBooleanEventField +import com.jetbrains.ml.logs.schema.ClassEventField as MLClassEventField +import com.jetbrains.ml.logs.schema.DoubleEventField as MLDoubleEventField +import com.jetbrains.ml.logs.schema.EnumEventField as MLEnumEventField +import com.jetbrains.ml.logs.schema.EventField as MLEventField +import com.jetbrains.ml.logs.schema.EventPair as MLEventPair +import com.jetbrains.ml.logs.schema.FloatEventField as MLFloatEventField +import com.jetbrains.ml.logs.schema.IntEventField as MLIntEventField +import com.jetbrains.ml.logs.schema.LongEventField as MLLongEventField +import com.jetbrains.ml.logs.schema.ObjectDescription as MLObjectDescription +import com.jetbrains.ml.logs.schema.ObjectEventData as MLObjectEventData +import com.jetbrains.ml.logs.schema.ObjectEventField as MLObjectEventField +import com.jetbrains.ml.logs.schema.ObjectListEventField as MLObjectListEventField +import com.jetbrains.ml.logs.schema.StringEventField as MLStringEventField + + +@ApiStatus.Internal +class IntelliJFusEventRegister(private val baseEventGroup: IJEventLogGroup) : com.jetbrains.ml.logs.FusEventRegister { + private class Logger( + private val varargEventId: VarargEventId, + private val objectDescription: ConverterObjectDescription + ) : com.jetbrains.ml.logs.FusEventLogger { + override fun log(eventPairs: List>) { + val ijEventPairs = objectDescription.buildEventPairs(eventPairs) + varargEventId.log(*ijEventPairs.toTypedArray()) + } + } + + override fun registerEvent(name: String, eventFields: List>): com.jetbrains.ml.logs.FusEventLogger { + val objectDescription = ConverterObjectDescription(MLObjectDescription(eventFields)) + val varargEventId = baseEventGroup.registerVarargEvent(name, null, *objectDescription.getFields()) + return Logger(varargEventId, objectDescription) + } +} + +@Suppress("UNCHECKED_CAST") +private fun createConverter(mlEventField: MLEventField): IJEventPairConverter = when (mlEventField) { + is MLObjectEventField -> ConverterOfObject( + mlEventField.name, + mlEventField.lazyDescription, + mlEventField.objectDescription + ) as IJEventPairConverter + is MLBooleanEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJBooleanEventField(n, d) } as IJEventPairConverter + is MLIntEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJIntEventField(n, d) } as IJEventPairConverter + is MLLongEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJLongEventField(n, d) } as IJEventPairConverter + is MLFloatEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJFloatEventField1(n, d) } as IJEventPairConverter + is MLEnumEventField<*> -> mlEventField.toIJConverter() as IJEventPairConverter + is MLClassEventField -> ConverterOfClass(mlEventField) as IJEventPairConverter + is MLObjectListEventField -> ConvertObjectList(mlEventField) as IJEventPairConverter + is MLDoubleEventField -> ConverterOfPrimitiveType(mlEventField) { n, d -> IJDoubleEventField(n, d) } as IJEventPairConverter + is VersionEventField -> ConverterOfVersion(mlEventField) as IJEventPairConverter + is LanguageEventField -> ConverterOfLanguage(mlEventField) as IJEventPairConverter + is MLStringEventField -> ConverterOfString(mlEventField) as IJEventPairConverter + + is IJSpecificEventField<*> -> { + when (mlEventField) { + is IJCustomEventField -> ConverterOfCustom(mlEventField) + is LanguageEventField -> ConverterOfLanguage(mlEventField) as IJEventPairConverter + is VersionEventField -> ConverterOfVersion(mlEventField) as IJEventPairConverter + } + } + + else -> throw IllegalArgumentException( + """ + Conversion of ${mlEventField.javaClass.simpleName} is not possible. + If you want to create your own field, you must add an inheritor of + ${IJCustomEventField::class.qualifiedName} + """.trimIndent() + ) +} + +private class ConverterOfCustom(mlEventField: IJCustomEventField) : IJEventPairConverter { + override val ijEventField: IJEventField = mlEventField.baseIJEventField + + override fun buildEventPair(mlEventPair: EventPair): IJEventPair { + return ijEventField with mlEventPair.data + } +} + +private class ConverterOfString(mlEventField: MLStringEventField) : IJEventPairConverter { + override val ijEventField: IJEventField = IJStringEventField.ValidatedByAllowedValues( + mlEventField.name, + allowedValues = mlEventField.possibleValues, + description = mlEventField.lazyDescription() + ) + + override fun buildEventPair(mlEventPair: EventPair): IJEventPair { + return ijEventField with mlEventPair.data + } +} + +private class ConverterOfLanguage(mlEventField: LanguageEventField) : IJEventPairConverter { + override val ijEventField: IJEventField = IJEventFields.Language(mlEventField.name, mlEventField.lazyDescription()) + + override fun buildEventPair(mlEventPair: MLEventPair): IJEventPair { + return ijEventField with mlEventPair.data + } +} + +private class ConverterOfVersion(mlEventField: com.intellij.platform.ml.impl.tools.logs.VersionEventField) : IJEventPairConverter { + private class VersionEventField(override val name: String, override val description: String?) : PrimitiveEventField() { + override val validationRule: List + get() = listOf("{regexp#version}") + + override fun addData(fuData: FeatureUsageData, value: Version?) { + fuData.addVersion(value) + } + } + + override val ijEventField: IJEventField = VersionEventField(mlEventField.name, mlEventField.lazyDescription()) + + override fun buildEventPair(mlEventPair: MLEventPair): IJEventPair { + return ijEventField with mlEventPair.data + } +} + +private class ConvertObjectList(mlEventField: MLObjectListEventField) : + IJEventPairConverter, List> { + private val innerObjectConverter = ConverterOfObject(mlEventField.name, mlEventField.lazyDescription, mlEventField.internalObjectDescription) + + // FIXME: description is not passed + override val ijEventField: IJEventField> = IJObjectListEventField( + mlEventField.name, + innerObjectConverter.ijObjectDescription + ) + + override fun buildEventPair(mlEventPair: MLEventPair>): IJEventPair> { + return ijEventField with mlEventPair.data.map { innerObjectFieldsValues -> + innerObjectConverter.buildObjectEventData(innerObjectFieldsValues) + } + } +} + +private class ConverterOfEnum>(mlEnumField: MLEnumEventField) : IJEventPairConverter { + override val ijEventField: IJEventField = IJEnumEventField(mlEnumField.name, mlEnumField.enumClass, mlEnumField.transform) + + override fun buildEventPair(mlEventPair: MLEventPair): IJEventPair { + return ijEventField with mlEventPair.data + } + + companion object { + fun > MLEnumEventField.toIJConverter(): ConverterOfEnum { + return ConverterOfEnum(this) + } + } +} + +private interface IJEventPairConverter { + val ijEventField: IJEventField + + fun buildEventPair(mlEventPair: MLEventPair): IJEventPair + + companion object { + fun IJEventPairConverter.typedBuild(mlEventPair: MLEventPair<*>): IJEventPair { + @Suppress("UNCHECKED_CAST") + return buildEventPair(mlEventPair as MLEventPair) + } + } +} + +private class ConverterObjectDescription(mlObjectDescription: MLObjectDescription) : IJObjectDescription() { + private val toIJConverters: Map, IJEventPairConverter<*, *>> = mlObjectDescription.getFields().associateWith { mlField -> + val converter = createConverter(mlField) + field(converter.ijEventField) + converter + } + + fun buildEventPairs(mlEventPairs: List>): List> { + return mlEventPairs.map { mlEventPair -> + require(mlEventPair.field in toIJConverters) { + """ + Field ${mlEventPair.field} (name: ${mlEventPair.field.name}) was not found among + the registered ones: ${toIJConverters.keys.map { it.name }} + """.trimIndent() + } + val converter = requireNotNull(toIJConverters[mlEventPair.field]) + converter.typedBuild(mlEventPair) + } + } + + fun buildObjectEventData(mlObject: MLObjectEventData): IJObjectEventData { + return IJObjectEventData(buildEventPairs(mlObject.values)) + } + + companion object { + fun MLObjectDescription.asIJObjectDescription() = ConverterObjectDescription(this) + } +} + +private class ConverterOfObject( + name: String, + lazyDescription: () -> String, + mlObjectDescription: MLObjectDescription, +) : IJEventPairConverter { + val ijObjectDescription = mlObjectDescription.asIJObjectDescription() + + override val ijEventField: IJEventField = IJObjectEventField(name, lazyDescription(), ijObjectDescription) + + fun buildObjectEventData(mlObject: MLObjectEventData): IJObjectEventData { + return ijObjectDescription.buildObjectEventData(mlObject) + } + + override fun buildEventPair(mlEventPair: MLEventPair): IJEventPair { + return ijEventField with buildObjectEventData(mlEventPair.data) + } +} + +private class ConverterOfPrimitiveType( + mlEventField: MLEventField, + createIJField: (String, String?) -> IJEventField +) : IJEventPairConverter { + override val ijEventField: IJEventField = createIJField(mlEventField.name, mlEventField.lazyDescription()) + + override fun buildEventPair(mlEventPair: MLEventPair): IJEventPair { + return ijEventField with mlEventPair.data + } +} + +private class ConverterOfClass( + mlEventField: MLClassEventField, +) : IJEventPairConverter, Class<*>?> { + override val ijEventField: IJEventField?> = IJClassEventField(mlEventField.name, mlEventField.lazyDescription()) + + override fun buildEventPair(mlEventPair: MLEventPair>): IJEventPair?> { + return ijEventField with mlEventPair.data + } +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/LanguageSpecific.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/LanguageSpecific.kt new file mode 100644 index 000000000000..40238772e170 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/LanguageSpecific.kt @@ -0,0 +1,32 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools.logs + +import com.intellij.lang.Language +import com.jetbrains.ml.logs.schema.EventField +import org.jetbrains.annotations.ApiStatus + +/** + * Something, that is dedicated for one language only. + */ +@ApiStatus.Internal +interface LanguageSpecific { + val language: Language +} + +/** + * The analyzer, that adds information about ML model's language to logs. + */ +@ApiStatus.Internal +class ModelLanguageAnalyser : com.jetbrains.ml.analysis.MLTaskAnalyserTyped + where M : com.jetbrains.ml.model.MLModel

, + M : LanguageSpecific { + private val LANGUAGE = LanguageEventField("model_language") { "The programming language the ML model is trained for" } + + override fun startMLSessionAnalysis(sessionInfo: com.jetbrains.ml.session.MLSessionInfo) = object : com.jetbrains.ml.analysis.MLSessionAnalyserTyped { + override suspend fun analyseSession(tree: com.jetbrains.ml.tree.MLTree.ATopNode?) = buildList { + tree?.mlModel?.let { add(LANGUAGE with it.language) } + } + } + + override val declaration: List> = listOf(LANGUAGE) +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/Versioned.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/Versioned.kt new file mode 100644 index 000000000000..703a68042908 --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/Versioned.kt @@ -0,0 +1,34 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools.logs + +import com.intellij.openapi.util.Version +import com.jetbrains.ml.logs.schema.EventField +import org.jetbrains.annotations.ApiStatus + +/** + * Something, that has versions. + */ +@ApiStatus.Internal +interface Versioned { + val version: Version? +} + +/** + * Adds model's version to the ML logs. + */ +@ApiStatus.Internal +class ModelVersionAnalyser : com.jetbrains.ml.analysis.MLTaskAnalyserTyped + where M : com.jetbrains.ml.model.MLModel

, + M : Versioned { + companion object { + private val VERSION = VersionEventField("model_version") { "Version of the ML model" } + } + + override fun startMLSessionAnalysis(sessionInfo: com.jetbrains.ml.session.MLSessionInfo) = object : com.jetbrains.ml.analysis.MLSessionAnalyserTyped { + override suspend fun analyseSession(tree: com.jetbrains.ml.tree.MLTree.ATopNode?) = buildList { + tree?.mlModel?.version?.let { add(VERSION with it) } + } + } + + override val declaration: List> = listOf(VERSION) +} diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/eventFields.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/eventFields.kt new file mode 100644 index 000000000000..15021f87bfdd --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/logs/eventFields.kt @@ -0,0 +1,20 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools.logs + +import com.intellij.lang.Language +import com.intellij.openapi.util.Version +import com.jetbrains.ml.logs.schema.CustomRuleEventField +import org.jetbrains.annotations.ApiStatus +import com.intellij.internal.statistic.eventLog.events.EventField as IJEventField + +@ApiStatus.Internal +sealed class IJSpecificEventField(name: String, lazyDescription: () -> String) : CustomRuleEventField(name, lazyDescription) + +@ApiStatus.Internal +class VersionEventField(name: String, lazyDescription: () -> String) : IJSpecificEventField(name, lazyDescription) + +@ApiStatus.Internal +class LanguageEventField(name: String, lazyDescription: () -> String) : IJSpecificEventField(name, lazyDescription) + +@ApiStatus.Internal +open class IJCustomEventField(val baseIJEventField: IJEventField) : IJSpecificEventField(baseIJEventField.name, { baseIJEventField.description ?: "" }) diff --git a/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/model/RegressionModel.kt b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/model/RegressionModel.kt new file mode 100644 index 000000000000..8ed44312757e --- /dev/null +++ b/platform/ml-impl/src/com/intellij/platform/ml/impl/tools/model/RegressionModel.kt @@ -0,0 +1,142 @@ +// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. +package com.intellij.platform.ml.impl.tools.model + +import com.intellij.internal.ml.DecisionFunction +import com.jetbrains.ml.* +import com.jetbrains.ml.model.MLModel +import org.jetbrains.annotations.ApiStatus + + +/** + * A wrapper for using legacy [DecisionFunction] ML API`s [MLModel]. + */ +@ApiStatus.Internal +open class RegressionModel private constructor( + private val decisionFunction: DecisionFunction, + private val modelUnits: Set>, + globalUnits: Set>, + private val featureSerialization: FeatureNameSerialization, +) : MLModel { + constructor( + decisionFunction: DecisionFunction, + featureSerialization: FeatureNameSerialization, + globalUnits: Set>, + ) : this( + decisionFunction = decisionFunction, + modelUnits = decisionFunction.featuresOrder.map { featureMapper -> + featureSerialization.deserialize(featureMapper.featureName, globalUnits.associateBy { it.name }).first + }.toSet(), + globalUnits = globalUnits, + featureSerialization = featureSerialization + ) + + override val knownFeatures: Map, FeatureFilter> = createKnownFeatures( + DecisionFunctionWrapper(decisionFunction, globalUnits, featureSerialization), + modelUnits + ) + + override fun predict(taskUnits: List, contexts: List, features: Map, List>): Double { + val array = DoubleArray(decisionFunction.featuresOrder.size) + val featurePerSerializedName: Map = features + .flatMap { (unit, unitFeatures) -> unitFeatures.map { unit to it } } + .associate { (unit, feature) -> featureSerialization.serialize(unit, feature.declaration.name) to feature } + + require(features.keys == modelUnits) { + "Given features units are ${features.keys}, but this model needs ${modelUnits}" + } + + for (featureI in decisionFunction.featuresOrder.indices) { + val featureMapper = decisionFunction.featuresOrder[featureI] + val featureSerializedName = featureMapper.featureName + val featureValue = featureMapper.asArrayValue(featurePerSerializedName[featureSerializedName]?.value) + array[featureI] = featureValue + } + + return decisionFunction.predict(array) + } + + interface FeatureNameSerialization { + fun serialize(mlUnit: MLUnit<*>, featureName: String): String + + fun deserialize(serializedFeatureName: String, unitsPerName: Map>): Pair, String> + } + + object DefaultFeatureSerialization : FeatureNameSerialization { + private const val SERIALIZED_FEATURE_SEPARATOR = '/' + + override fun serialize(mlUnit: MLUnit<*>, featureName: String): String { + return mlUnit.name + SERIALIZED_FEATURE_SEPARATOR + featureName + } + + override fun deserialize(serializedFeatureName: String, unitsPerName: Map>): Pair, String> { + val indexOfLastSeparator = serializedFeatureName.indexOfLast { it == SERIALIZED_FEATURE_SEPARATOR } + require(indexOfLastSeparator >= 0) { "Feature name '$serializedFeatureName' does not contain ML unit's name" } + val unitName = serializedFeatureName.slice(0 until indexOfLastSeparator) + val featureName = serializedFeatureName.slice(indexOfLastSeparator until serializedFeatureName.length) + val featureUnit = requireNotNull(unitsPerName[unitName]) { + """ + Serialized feature '$serializedFeatureName' has tier $unitName, + but all available tiers are ${unitsPerName.keys} + """.trimIndent() + } + return featureUnit to featureName + } + } + + private class DecisionFunctionWrapper( + private val decisionFunction: DecisionFunction, + globalUnits: Set>, + private val featureNameSerialization: FeatureNameSerialization, + ) { + private val availableUnitsPerName: Map> = globalUnits.associateBy { it.name } + + val knownFeatures: Map, Set> = run { + val knownFeaturesSerializedNames = decisionFunction.featuresOrder.map { it.featureName }.toSet() + knownFeaturesSerializedNames + .map { featureNameSerialization.deserialize(it, availableUnitsPerName) } + .groupBy({ it.first }, { it.second }) + .mapValues { it.value.toSet() } + } + + fun getUnknownFeatures(unit: MLUnit<*>, featuresNames: Set): Set { + val knownFeatures = knownFeatures[unit] ?: return featuresNames + return featuresNames.filterNot { it in knownFeatures }.toSet() + } + } + + companion object { + private fun createKnownFeatures( + decisionFunction: DecisionFunctionWrapper, + modelUnits: Set>, + ): Map, FeatureFilter> { + fun createFeatureSelector(unit: MLUnit<*>) = object : FeatureFilter { + init { + val knownFeatures = decisionFunction.knownFeatures + knownFeatures.forEach { (unit, features) -> + val nonConsistentlyKnownFeatures = decisionFunction.getUnknownFeatures(unit, features) + require(nonConsistentlyKnownFeatures.isEmpty()) { + "These features are known and unknown at the same time: $nonConsistentlyKnownFeatures" + } + } + } + + override fun accept(featureDeclarations: Set>): Set> { + val availableFeaturesPerName = featureDeclarations.associateBy { it.name } + val availableFeaturesNames = featureDeclarations.map { it.name }.toSet() + val unknownFeaturesNames = decisionFunction.getUnknownFeatures(unit, availableFeaturesNames) + val knownAvailableFeaturesNames = availableFeaturesNames - unknownFeaturesNames + val knownAvailableFeatures = knownAvailableFeaturesNames.map { availableFeaturesPerName.getValue(it) }.toSet() + + return knownAvailableFeatures + } + + override fun accept(featureDeclaration: FeatureDeclaration<*>): Boolean { + val unknown = decisionFunction.getUnknownFeatures(unit, setOf(featureDeclaration.name)) + return unknown.isEmpty() + } + } + + return modelUnits.associateWith { createFeatureSelector(it) } + } + } +} diff --git a/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml b/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml index 29fc5ea1eb9b..73d4da7a610d 100644 --- a/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml +++ b/platform/platform-resources/src/META-INF/PlatformExtensionPoints.xml @@ -541,6 +541,9 @@ +