diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/artifacts/LocalArtifactsManager.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/artifacts/LocalArtifactsManager.kt index 50c806e71952..1a37ef16fc94 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/artifacts/LocalArtifactsManager.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/artifacts/LocalArtifactsManager.kt @@ -22,7 +22,6 @@ import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsMana import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsManager.Companion.getArchitectureId import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsManager.Companion.getOsId import com.intellij.platform.ml.embeddings.indexer.FileBasedEmbeddingIndexer.Companion.INDEXING_VERSION -import com.intellij.platform.ml.embeddings.jvm.models.CustomRootDataLoader import com.intellij.util.concurrency.annotations.RequiresBackgroundThread import com.intellij.util.download.DownloadableFileService import com.intellij.util.io.ZipUtil @@ -55,8 +54,6 @@ class LocalArtifactsManager { // TODO: the list may be changed depending on the project size private val availableModels = listOf(ModelArtifact.SmallModelArtifact) - fun getCustomRootDataLoader() = CustomRootDataLoader(modelsRoot) - @RequiresBackgroundThread suspend fun downloadArtifactsIfNecessary( project: Project? = null, @@ -90,9 +87,7 @@ class LocalArtifactsManager { // TODO: provide model id in arguments fun getModelArtifact(): ModelArtifact = availableModels.first() - fun getServerArtifact() = NativeServerArtifact - - fun checkArtifactsPresent(): Boolean = availableModels.all { it.checkPresent() } + fun getServerArtifact(): NativeServerArtifact = NativeServerArtifact private fun downloadArtifacts(artifacts: List) { try { @@ -136,9 +131,9 @@ class LocalArtifactsManager { } companion object { - const val SEMANTIC_SEARCH_RESOURCES_DIR_NAME = "semantic-search" // TODO: move to common constants + const val SEMANTIC_SEARCH_RESOURCES_DIR_NAME: String = "semantic-search" // TODO: move to common constants - const val INDICES_DIR_NAME = "indices" + const val INDICES_DIR_NAME: String = "indices" private const val MODELS_DIR_NAME = "models" private const val SERVER_DIR_NAME = "server" @@ -204,7 +199,7 @@ sealed class ModelArtifact( ) : DownloadableArtifact { data object SmallModelArtifact : ModelArtifact("small", "dan_100k_optimized.onnx", "bert-base-uncased.txt") - final override val archiveName = "$name.zip" + final override val archiveName: String = "$name.zip" override val downloadLink: String = listOf(CDN_LINK_BASE, MODEL_VERSION, archiveName).joinToString(separator = "/") override val destination: Path = LocalArtifactsManager.modelsRoot / name diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/NativeServerStartupArguments.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/NativeServerStartupArguments.kt index bb9bb05db442..ddec179d8a4f 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/NativeServerStartupArguments.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/NativeServerStartupArguments.kt @@ -9,7 +9,7 @@ enum class EmbeddingDistanceMetric(val label: String) { COSINE("cos"), SQUARED_EUCLIDEAN("l2sq"); - override fun toString() = label + override fun toString(): String = label } enum class EmbeddingQuantization(val label: String) { @@ -18,7 +18,7 @@ enum class EmbeddingQuantization(val label: String) { INT8("i8"), BINARY("b1x8"); - override fun toString() = label + override fun toString(): String = label } data class NativeServerStartupArguments( diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/listeners/ServerDiagnosticsListener.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/listeners/ServerDiagnosticsListener.kt index 17dee776eb8c..3a802ce331dd 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/listeners/ServerDiagnosticsListener.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/external/client/listeners/ServerDiagnosticsListener.kt @@ -62,6 +62,6 @@ class ServerDiagnosticsListener : ProcessListener { } companion object { - const val MAX_HISTORY_SIZE = 1_000 + const val MAX_HISTORY_SIZE: Int = 1_000 } } \ No newline at end of file diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/files/package-info.java b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/files/package-info.java deleted file mode 100644 index 1a6facfa7515..000000000000 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/files/package-info.java +++ /dev/null @@ -1,5 +0,0 @@ -// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. -@ApiStatus.Internal -package com.intellij.platform.ml.embeddings.files; - -import org.jetbrains.annotations.ApiStatus; \ No newline at end of file diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/entities/IndexableFile.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/entities/IndexableFile.kt index a45d8b32ed2b..7b7c66ffd443 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/entities/IndexableFile.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/entities/IndexableFile.kt @@ -8,5 +8,5 @@ import com.intellij.platform.ml.embeddings.utils.splitIdentifierIntoTokens class IndexableFile(override val id: EntityId) : IndexableEntity { constructor(file: VirtualFile) : this(EntityId(file.name)) - override val indexableRepresentation by lazy { splitIdentifierIntoTokens(FileUtilRt.getNameWithoutExtension(id.id)) } + override val indexableRepresentation: String by lazy { splitIdentifierIntoTokens(FileUtilRt.getNameWithoutExtension(id.id)) } } \ No newline at end of file diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/keys/IntegerStorageKeyProvider.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/keys/IntegerStorageKeyProvider.kt index bd597923b644..6180ec420cbe 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/keys/IntegerStorageKeyProvider.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/keys/IntegerStorageKeyProvider.kt @@ -53,8 +53,8 @@ class IntegerStorageKeyProvider : EmbeddingStorageKeyProvider, Disposable } companion object { - private const val ENUMERATOR_FOLDER = "enumerator" - private const val ENUMERATOR_FILE = "ids.enum" + private const val ENUMERATOR_FOLDER: String = "enumerator" + private const val ENUMERATOR_FILE: String = "ids.enum" fun getInstance(): IntegerStorageKeyProvider = service() } diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/storage/EmbeddingsStorageManagerWrapper.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/storage/EmbeddingsStorageManagerWrapper.kt index 4495c205b794..c0ea469de904 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/storage/EmbeddingsStorageManagerWrapper.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/indexer/storage/EmbeddingsStorageManagerWrapper.kt @@ -60,7 +60,7 @@ class EmbeddingsStorageManagerWrapper( return storageManager.getStorageStats(project, indexId) } - fun getBatchSize() = storageManager.getBatchSize() + fun getBatchSize(): Int = storageManager.getBatchSize() companion object { private const val INDEXABLE_REPRESENTATION_CHAR_LIMIT = 64 diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/artifacts/KInferenceLocalArtifactsManager.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/artifacts/KInferenceLocalArtifactsManager.kt index 96bbcb730f33..d5d2f4aa30dc 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/artifacts/KInferenceLocalArtifactsManager.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/artifacts/KInferenceLocalArtifactsManager.kt @@ -49,29 +49,31 @@ class KInferenceLocalArtifactsManager { root.toPath().listDirectoryEntries().filter { it.name != MODEL_ARTIFACTS_DIR }.forEach { it.delete(recursively = true) } } - fun getCustomRootDataLoader() = CustomRootDataLoader(modelArtifactsRoot.toPath()) + fun getCustomRootDataLoader(): CustomRootDataLoader = CustomRootDataLoader(modelArtifactsRoot.toPath()) @RequiresBackgroundThread suspend fun downloadArtifactsIfNecessary(project: Project? = null, - retryIfCanceled: Boolean = true) = withContext(downloadContext) { - if (!checkArtifactsPresent() && !ApplicationManager.getApplication().isUnitTestMode && (retryIfCanceled || !downloadCanceled)) { - logger.debug("Semantic search artifacts are not present, starting the download...") - if (project != null) { - withBackgroundProgress(project, ARTIFACTS_DOWNLOAD_TASK_NAME) { - try { - coroutineToIndicator { // platform code relies on the existence of indicator - downloadArtifacts() + retryIfCanceled: Boolean = true) { + withContext(downloadContext) { + if (!checkArtifactsPresent() && !ApplicationManager.getApplication().isUnitTestMode && (retryIfCanceled || !downloadCanceled)) { + logger.debug("Semantic search artifacts are not present, starting the download...") + if (project != null) { + withBackgroundProgress(project, ARTIFACTS_DOWNLOAD_TASK_NAME) { + try { + coroutineToIndicator { // platform code relies on the existence of indicator + downloadArtifacts() + } + } + catch (e: CancellationException) { + logger.debug("Artifacts downloading was canceled") + downloadCanceled = true + throw e } } - catch (e: CancellationException) { - logger.debug("Artifacts downloading was canceled") - downloadCanceled = true - throw e - } } - } - else { - downloadArtifacts() + else { + downloadArtifacts() + } } } } @@ -107,7 +109,7 @@ class KInferenceLocalArtifactsManager { } companion object { - const val SEMANTIC_SEARCH_RESOURCES_DIR = "semantic-search" + const val SEMANTIC_SEARCH_RESOURCES_DIR: String = "semantic-search" private val ARTIFACTS_DOWNLOAD_TASK_NAME get() = EmbeddingsBundle.getMessage("ml.embeddings.artifacts.download.name") diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EmbeddingSearchIndex.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EmbeddingSearchIndex.kt index 31d0e575aad7..b3efd4ef5fbb 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EmbeddingSearchIndex.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EmbeddingSearchIndex.kt @@ -17,7 +17,7 @@ interface EmbeddingSearchIndex { suspend fun contains(id: EntityId): Boolean suspend fun lookup(id: EntityId): FloatTextEmbedding? suspend fun clear() - suspend fun clearBySourceType(sourceType: EntitySourceType) = Unit + suspend fun clearBySourceType(sourceType: EntitySourceType) {} suspend fun remove(id: EntityId) suspend fun onIndexingStart() diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EntitySourceType.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EntitySourceType.kt index 73fe6ca2cdc3..ad2f9fd215ab 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EntitySourceType.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/EntitySourceType.kt @@ -8,5 +8,5 @@ enum class EntitySourceType { EXTERNAL; @JsonValue - fun value() = ordinal + fun value(): Int = ordinal } \ No newline at end of file diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/InMemoryEmbeddingSearchIndex.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/InMemoryEmbeddingSearchIndex.kt index b30b0972017a..865c22209d13 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/InMemoryEmbeddingSearchIndex.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/InMemoryEmbeddingSearchIndex.kt @@ -22,16 +22,18 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) private val fileManager = LocalEmbeddingIndexFileManager(root) - override suspend fun getSize() = lock.read { idToEmbedding.size } + override suspend fun getSize(): Int = lock.read { idToEmbedding.size } - override suspend fun setLimit(value: Int?) = lock.write { - // Shrink index if necessary: - if (value != null && value < idToEmbedding.size) { - val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList() - idToEmbedding.clear() - idToEmbedding.putAll(remaining) + override suspend fun setLimit(value: Int?) { + lock.write { + // Shrink index if necessary: + if (value != null && value < idToEmbedding.size) { + val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList() + idToEmbedding.clear() + idToEmbedding.putAll(remaining) + } + limit = value } - limit = value } override suspend fun contains(id: EntityId): Boolean = lock.read { @@ -40,9 +42,11 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) override suspend fun lookup(id: EntityId): FloatTextEmbedding? = lock.read { idToEmbedding[id] } - override suspend fun clear() = lock.write { - idToEmbedding.clear() - uncheckedIds.clear() + override suspend fun clear() { + lock.write { + idToEmbedding.clear() + uncheckedIds.clear() + } } override suspend fun remove(id: EntityId) { @@ -58,34 +62,44 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) } } - override suspend fun onIndexingFinish() = lock.write { - uncheckedIds.forEach { idToEmbedding.remove(it) } - uncheckedIds.clear() + override suspend fun onIndexingFinish() { + lock.write { + uncheckedIds.forEach { idToEmbedding.remove(it) } + uncheckedIds.clear() + } } override suspend fun addEntries( values: Iterable>, shouldCount: Boolean, - ) = lock.write { - if (limit != null) { - val list = values.toList() - list.forEach { uncheckedIds.remove(it.first) } - idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size))) - } - else { - idToEmbedding.putAll(values) + ) { + lock.write { + if (limit != null) { + val list = values.toList() + list.forEach { uncheckedIds.remove(it.first) } + idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size))) + } + else { + idToEmbedding.putAll(values) + } } } - override suspend fun saveToDisk() = lock.read { save() } - - override suspend fun loadFromDisk() = lock.write { - val (ids, embeddings) = fileManager.loadIndex() ?: return@write - idToEmbedding.clear() - idToEmbedding.putAll(ids zip embeddings) + override suspend fun saveToDisk() { + lock.read { save() } } - override suspend fun offload() = lock.write { idToEmbedding.clear() } + override suspend fun loadFromDisk() { + lock.write { + val (ids, embeddings) = fileManager.loadIndex() ?: return@write + idToEmbedding.clear() + idToEmbedding.putAll(ids zip embeddings) + } + } + + override suspend fun offload() { + lock.write { idToEmbedding.clear() } + } override suspend fun findClosest( searchEmbedding: FloatTextEmbedding, @@ -108,7 +122,7 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) } } - override suspend fun estimateMemoryUsage() = fileManager.embeddingSizeInBytes.toLong() * getSize() + override suspend fun estimateMemoryUsage(): Long = fileManager.embeddingSizeInBytes.toLong() * getSize() override fun estimateLimitByMemory(memory: Long): Int { return (memory / fileManager.embeddingSizeInBytes).toInt() diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/IndexPersistedEventsCounter.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/IndexPersistedEventsCounter.kt index f4dd54b2921a..e91f2df861fe 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/IndexPersistedEventsCounter.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/IndexPersistedEventsCounter.kt @@ -10,7 +10,7 @@ import com.intellij.platform.ml.embeddings.indexer.IndexId */ interface IndexPersistedEventsCounter { companion object { - val EP_NAME = ProjectExtensionPointName("com.intellij.platform.ml.embeddings.indexPersistedEventsCounter") + val EP_NAME: ProjectExtensionPointName = ProjectExtensionPointName("com.intellij.platform.ml.embeddings.indexPersistedEventsCounter") } suspend fun sendPersistedCount(indexId: IndexId, project: Project) diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/LocalEmbeddingIndexFileManager.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/LocalEmbeddingIndexFileManager.kt index 19a7d8940b49..73eaa3bb1b14 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/LocalEmbeddingIndexFileManager.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/LocalEmbeddingIndexFileManager.kt @@ -9,7 +9,6 @@ import com.fasterxml.jackson.databind.SerializationFeature import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper import com.fasterxml.jackson.module.kotlin.readValue import com.intellij.platform.ml.embeddings.jvm.utils.SuspendingReadWriteLock -import com.intellij.util.io.outputStream import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.coroutineScope import kotlinx.coroutines.ensureActive @@ -20,6 +19,7 @@ import java.nio.file.Files import java.nio.file.Path import kotlin.io.path.exists import kotlin.io.path.inputStream +import kotlin.io.path.outputStream class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = DEFAULT_DIMENSIONS) { private val lock = SuspendingReadWriteLock() @@ -36,7 +36,7 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D private val embeddingsPath get() = rootPath.resolve(EMBEDDINGS_FILENAME) - val embeddingSizeInBytes = dimensions * EMBEDDING_ELEMENT_SIZE + val embeddingSizeInBytes: Int = dimensions * EMBEDDING_ELEMENT_SIZE /** Provides reading access to the embedding vector at the specified index * without reading the whole file into memory @@ -55,12 +55,14 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D /** Provides writing access to embedding vector at the specified index * without writing the other vectors */ - suspend fun set(index: Int, embedding: FloatTextEmbedding) = lock.write { - RandomAccessFile(embeddingsPath.toFile(), "rw").use { output -> - output.seek(getIndexOffset(index)) - val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE) - embedding.values.forEach { - output.write(buffer.putFloat(0, it).array()) + suspend fun set(index: Int, embedding: FloatTextEmbedding) { + lock.write { + RandomAccessFile(embeddingsPath.toFile(), "rw").use { output -> + output.seek(getIndexOffset(index)) + val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE) + embedding.values.forEach { + output.write(buffer.putFloat(0, it).array()) + } } } } @@ -69,23 +71,25 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D * Removes the embedding vector at the specified index. * To do so, replaces this vector with the last vector in the file and shrinks the file size. */ - suspend fun removeAtIndex(index: Int) = lock.write { - RandomAccessFile(embeddingsPath.toFile(), "rw").use { file -> - if (file.length() < embeddingSizeInBytes) return@write - if (file.length() - embeddingSizeInBytes != getIndexOffset(index)) { - file.seek(file.length() - embeddingSizeInBytes) - val array = ByteArray(EMBEDDING_ELEMENT_SIZE) - val embedding = FloatTextEmbedding(FloatArray(dimensions) { - file.read(array) - ByteBuffer.wrap(array).getFloat() - }) - file.seek(getIndexOffset(index)) - val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE) - embedding.values.forEach { - file.write(buffer.putFloat(0, it).array()) + suspend fun removeAtIndex(index: Int) { + lock.write { + RandomAccessFile(embeddingsPath.toFile(), "rw").use { file -> + if (file.length() < embeddingSizeInBytes) return@write + if (file.length() - embeddingSizeInBytes != getIndexOffset(index)) { + file.seek(file.length() - embeddingSizeInBytes) + val array = ByteArray(EMBEDDING_ELEMENT_SIZE) + val embedding = FloatTextEmbedding(FloatArray(dimensions) { + file.read(array) + ByteBuffer.wrap(array).getFloat() + }) + file.seek(getIndexOffset(index)) + val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE) + embedding.values.forEach { + file.write(buffer.putFloat(0, it).array()) + } } + file.setLength(file.length() - embeddingSizeInBytes) } - file.setLength(file.length() - embeddingSizeInBytes) } } @@ -111,16 +115,18 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D } LoadedIndex(ids, embeddings) } - catch (e: JsonProcessingException) { + catch (_: JsonProcessingException) { return@read null } } } - suspend fun saveIds(ids: List) = lock.write { - withNotEnoughSpaceCheck { - idsPath.outputStream().buffered().use { output -> - mapper.writer(prettyPrinter).writeValue(output, ids) + suspend fun saveIds(ids: List) { + lock.write { + withNotEnoughSpaceCheck { + idsPath.outputStream().buffered().use { output -> + mapper.writer(prettyPrinter).writeValue(output, ids) + } } } } @@ -169,8 +175,8 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D } companion object { - const val DEFAULT_DIMENSIONS = 128 - const val EMBEDDING_ELEMENT_SIZE = 4 + const val DEFAULT_DIMENSIONS: Int = 128 + const val EMBEDDING_ELEMENT_SIZE: Int = 4 private const val IDS_FILENAME = "ids.json" private const val SOURCE_TYPES_FILENAME = "sourceTypes.json" diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/VanillaEmbeddingSearchIndex.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/VanillaEmbeddingSearchIndex.kt index 39ea7373d999..2111215cdddf 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/VanillaEmbeddingSearchIndex.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/indices/VanillaEmbeddingSearchIndex.kt @@ -31,15 +31,17 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int? private val fileManager = LocalEmbeddingIndexFileManager(root) - override suspend fun setLimit(value: Int?) = lock.write { - if (value != null) { - // Shrink index if necessary: - while (idToEntry.size > value) { - delete(indexToId[idToEntry.size - 1]!!, all = true, shouldSaveIds = false) + override suspend fun setLimit(value: Int?) { + lock.write { + if (value != null) { + // Shrink index if necessary: + while (idToEntry.size > value) { + delete(indexToId[idToEntry.size - 1]!!, all = true, shouldSaveIds = false) + } + saveIds() } - saveIds() + limit = value } - limit = value } private data class IndexEntry( @@ -48,19 +50,19 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int? val embedding: FloatTextEmbedding, ) - override suspend fun getSize() = lock.read { idToEntry.size } + override suspend fun getSize(): Int = lock.read { idToEntry.size } - override suspend fun contains(id: EntityId): Boolean = lock.read { - id in idToEntry - } + override suspend fun contains(id: EntityId): Boolean = lock.read { id in idToEntry } override suspend fun lookup(id: EntityId): FloatTextEmbedding? = lock.read { idToEntry[id]?.embedding } - override suspend fun clear() = lock.write { - indexToId.clear() - idToEntry.clear() - uncheckedIds.clear() - changed = false + override suspend fun clear() { + lock.write { + indexToId.clear() + idToEntry.clear() + uncheckedIds.clear() + changed = false + } } override suspend fun onIndexingStart() { @@ -71,57 +73,67 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int? } } - override suspend fun onIndexingFinish() = lock.write { - if (uncheckedIds.size > 0) changed = true - logger.debug { "Deleted ${uncheckedIds.size} unchecked ids" } - uncheckedIds.forEach { - delete(it, all = true, shouldSaveIds = false) + override suspend fun onIndexingFinish() { + lock.write { + if (uncheckedIds.isNotEmpty()) changed = true + logger.debug { "Deleted ${uncheckedIds.size} unchecked ids" } + uncheckedIds.forEach { + delete(it, all = true, shouldSaveIds = false) + } + uncheckedIds.clear() } - uncheckedIds.clear() } override suspend fun addEntries( values: Iterable>, shouldCount: Boolean, - ) = lock.write { - for ((id, embedding) in values) { - checkCancelled() - uncheckedIds.remove(id) - if (limit != null && idToEntry.size >= limit!!) break - val entry = idToEntry.getOrPut(id) { - changed = true - val index = idToEntry.size + ) { + lock.write { + for ((id, embedding) in values) { + checkCancelled() + uncheckedIds.remove(id) + if (limit != null && idToEntry.size >= limit!!) break + val entry = idToEntry.getOrPut(id) { + changed = true + val index = idToEntry.size + indexToId[index] = id + IndexEntry(index = index, count = 0, embedding = embedding) + } + if (shouldCount || entry.count == 0) { + entry.count += 1 + } + } + } + } + + override suspend fun saveToDisk() { + lock.read { + val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList() + val embeddings = ids.map { idToEntry[it]!!.embedding } + fileManager.saveIndex(ids, embeddings) + } + } + + override suspend fun loadFromDisk() { + lock.write { + indexToId.clear() + idToEntry.clear() + val (ids, embeddings) = fileManager.loadIndex() ?: return@write + for ((index, id) in ids.withIndex()) { + val embedding = embeddings[index] indexToId[index] = id - IndexEntry(index = index, count = 0, embedding = embedding) - } - if (shouldCount || entry.count == 0) { - entry.count += 1 + idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding) } } } - override suspend fun saveToDisk() = lock.read { - val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList() - val embeddings = ids.map { idToEntry[it]!!.embedding } - fileManager.saveIndex(ids, embeddings) - } - - override suspend fun loadFromDisk() = lock.write { - indexToId.clear() - idToEntry.clear() - val (ids, embeddings) = fileManager.loadIndex() ?: return@write - for ((index, id) in ids.withIndex()) { - val embedding = embeddings[index] - indexToId[index] = id - idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding) + override suspend fun offload() { + lock.write { + indexToId.clear() + idToEntry.clear() } } - override suspend fun offload() = lock.write { - indexToId.clear() - idToEntry.clear() - } - override suspend fun findClosest( searchEmbedding: FloatTextEmbedding, topK: Int, @@ -154,7 +166,7 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int? } } - override suspend fun estimateMemoryUsage() = fileManager.embeddingSizeInBytes.toLong() * getSize() + override suspend fun estimateMemoryUsage(): Long = fileManager.embeddingSizeInBytes.toLong() * getSize() override fun estimateLimitByMemory(memory: Long): Int { return (memory / fileManager.embeddingSizeInBytes).toInt() @@ -164,32 +176,32 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int? limit == null || idToEntry.size < limit!! } - suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) = lock.write { - delete(id = id, shouldSaveIds = syncToDisk) - } - - suspend fun addEntry(id: EntityId, embedding: FloatTextEmbedding) = lock.write { - uncheckedIds.remove(id) - add(id = id, embedding = embedding) + suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) { + lock.write { + delete(id = id, shouldSaveIds = syncToDisk) + } } /* Optimization for consequent delete and add operations */ - suspend fun updateEntry(id: EntityId, newId: EntityId, embedding: FloatTextEmbedding) = lock.write { - if (id !in idToEntry) return@write - if (idToEntry[id]!!.count == 1 && newId !in idToEntry) { - val index = idToEntry[id]!!.index - fileManager.set(index, embedding) + @Suppress("unused") + suspend fun updateEntry(id: EntityId, newId: EntityId, embedding: FloatTextEmbedding) { + lock.write { + if (id !in idToEntry) return@write + if (idToEntry[id]!!.count == 1 && newId !in idToEntry) { + val index = idToEntry[id]!!.index + fileManager.set(index, embedding) - idToEntry.remove(id) - idToEntry[newId] = IndexEntry(index = index, count = 1, embedding = embedding) - indexToId[index] = newId + idToEntry.remove(id) + idToEntry[newId] = IndexEntry(index = index, count = 1, embedding = embedding) + indexToId[index] = newId - saveIds() - } - else { - // Do not apply optimization - delete(id) - add(id = newId, embedding = embedding) + saveIds() + } + else { + // Do not apply optimization + delete(id) + add(id = newId, embedding = embedding) + } } } diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingNetwork.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingNetwork.kt index e0530672dc8b..72fe1fd503de 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingNetwork.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingNetwork.kt @@ -30,7 +30,7 @@ class LocalEmbeddingNetwork( return LocalEmbeddingNetwork(KIEngine.loadModel(data), maxLen ?: DEFAULT_MAX_LEN) } - const val DEFAULT_MAX_LEN = 512 + const val DEFAULT_MAX_LEN: Int = 512 } @Suppress("Unused") // useful for older versions of kinference where operators for mean pooling are not supported diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingServiceLoader.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingServiceLoader.kt index 4463655fe756..74112d076bcf 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingServiceLoader.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/models/LocalEmbeddingServiceLoader.kt @@ -16,7 +16,7 @@ class LocalEmbeddingServiceLoader { val network = loadNetwork(loader) val encoder = loadTextEncoder(loader) return LocalEmbeddingService(network, encoder) - } catch (e: EOFException) { + } catch (_: EOFException) { return null } } @@ -33,8 +33,8 @@ class LocalEmbeddingServiceLoader { } companion object { - const val MODEL_NAME = "dan-bert-tiny" - const val MODEL_FILENAME = "dan_optimized_fp16.onnx" + const val MODEL_NAME: String = "dan-bert-tiny" + const val MODEL_FILENAME: String = "dan_optimized_fp16.onnx" } } diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/AbstractEmbeddingsStorageWrapper.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/AbstractEmbeddingsStorageWrapper.kt index 7221e76f13d7..b49ddda5f72a 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/AbstractEmbeddingsStorageWrapper.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/AbstractEmbeddingsStorageWrapper.kt @@ -167,6 +167,6 @@ abstract class AbstractEmbeddingsStorageWrapper( private val OFFLOAD_TIMEOUT = 10.seconds private val logger = Logger.getInstance(AbstractEmbeddingsStorageWrapper::class.java) - const val OLD_API_DIR_NAME = "old-api" + const val OLD_API_DIR_NAME: String = "old-api" } } diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/ActionEmbeddingsStorageWrapper.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/ActionEmbeddingsStorageWrapper.kt index 2e8e4b3a6612..931962b9f1a0 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/ActionEmbeddingsStorageWrapper.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/jvm/wrappers/ActionEmbeddingsStorageWrapper.kt @@ -30,7 +30,9 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper { / SEMANTIC_SEARCH_RESOURCES_DIR_NAME / OLD_API_DIR_NAME / INDICES_DIR_NAME / INDEX_DIR ) - override suspend fun addEntries(values: Iterable>) = index.addEntries(values) + override suspend fun addEntries(values: Iterable>) { + index.addEntries(values) + } override suspend fun removeEntries(keys: List) { for (key in keys) { @@ -38,7 +40,9 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper { } } - override suspend fun clear() = index.clear() + override suspend fun clear() { + index.clear() + } @RequiresBackgroundThread override suspend fun searchNeighbours(queryEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List> { @@ -60,9 +64,13 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper { return index.streamFindClose(embedding, similarityThreshold) } - override suspend fun startIndexingSession() = index.onIndexingStart() + override suspend fun startIndexingSession() { + index.onIndexingStart() + } - override suspend fun finishIndexingSession() = index.onIndexingFinish() + override suspend fun finishIndexingSession() { + index.onIndexingFinish() + } override suspend fun getSize(): Int = index.getSize() diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/settings/EmbeddingIndexSettingsImpl.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/settings/EmbeddingIndexSettingsImpl.kt index 150e62241b18..a73ef2ca323b 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/settings/EmbeddingIndexSettingsImpl.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/settings/EmbeddingIndexSettingsImpl.kt @@ -24,14 +24,18 @@ class EmbeddingIndexSettingsImpl : EmbeddingIndexSettings { private val mutex = ReentrantReadWriteLock() private val clientSettings = mutableListOf() - fun registerClientSettings(settings: EmbeddingIndexSettings) = mutex.write { - if (settings in clientSettings) return@write - clientSettings.add(settings) + fun registerClientSettings(settings: EmbeddingIndexSettings) { + mutex.write { + if (settings in clientSettings) return@write + clientSettings.add(settings) + } } @Suppress("unused") - fun unregisterClientSettings(settings: EmbeddingIndexSettings) = mutex.write { - clientSettings.remove(settings) + fun unregisterClientSettings(settings: EmbeddingIndexSettings) { + mutex.write { + clientSettings.remove(settings) + } } companion object { diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/SemanticSearchCoroutineScope.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/SemanticSearchCoroutineScope.kt index 127dd2d54bb5..92926feab151 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/SemanticSearchCoroutineScope.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/SemanticSearchCoroutineScope.kt @@ -12,7 +12,7 @@ class SemanticSearchCoroutineScope(private val cs: CoroutineScope) : Disposable override fun dispose() {} companion object { - fun getScope(project: Project) = project.service().cs + fun getScope(project: Project): CoroutineScope = project.service().cs fun getInstance(project: Project): SemanticSearchCoroutineScope = project.service() } diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/TracingUtils.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/TracingUtils.kt index 47a23c513802..8aa21f8be1b0 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/TracingUtils.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/utils/TracingUtils.kt @@ -1,9 +1,10 @@ // Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license. package com.intellij.platform.ml.embeddings.utils +import com.intellij.platform.diagnostic.telemetry.IJTracer import com.intellij.platform.diagnostic.telemetry.Scope import com.intellij.platform.diagnostic.telemetry.TelemetryManager import org.jetbrains.annotations.ApiStatus @ApiStatus.Internal -val SEMANTIC_SEARCH_TRACER = TelemetryManager.getInstance().getTracer(Scope("semanticSearch")) \ No newline at end of file +val SEMANTIC_SEARCH_TRACER: IJTracer = TelemetryManager.getInstance().getTracer(Scope("semanticSearch")) \ No newline at end of file diff --git a/plugins/search-everywhere-ml/semantics/src/com/intellij/searchEverywhereMl/semantics/utils/InvalidTokenNotificationManager.kt b/plugins/search-everywhere-ml/semantics/src/com/intellij/searchEverywhereMl/semantics/utils/InvalidTokenNotificationManager.kt index adb941b90aca..a017f1e6633c 100644 --- a/plugins/search-everywhere-ml/semantics/src/com/intellij/searchEverywhereMl/semantics/utils/InvalidTokenNotificationManager.kt +++ b/plugins/search-everywhere-ml/semantics/src/com/intellij/searchEverywhereMl/semantics/utils/InvalidTokenNotificationManager.kt @@ -44,7 +44,7 @@ class InvalidTokenNotificationManager { } companion object { - private const val NOTIFICATION_GROUP_ID = "Semantic search" + private const val NOTIFICATION_GROUP_ID: String = "Semantic search" fun getInstance(): InvalidTokenNotificationManager = service() }