From 8f40ea4365b5d1aa72389dc6cfb95c23b4baf99b Mon Sep 17 00:00:00 2001 From: Evgeny Abramov Date: Fri, 5 Jul 2024 17:34:28 +0300 Subject: [PATCH] (IJPL-156622) Do not fail on concurrent read and write access to embedding index structures I also perform less intermediate copying when working with structures of embedding indices. GitOrigin-RevId: cf86043330ec34e631c3ba16accd2a92633166b2 --- .../DiskSynchronizedEmbeddingSearchIndex.kt | 91 +++++++++++-------- .../search/indices/EmbeddingSearchIndex.kt | 38 +++++--- .../indices/InMemoryEmbeddingSearchIndex.kt | 46 ++++++---- .../FileBasedEmbeddingStoragesManager.kt | 5 +- .../search/utils/SuspendingReadWriteLock.kt | 68 ++++++-------- 5 files changed, 136 insertions(+), 112 deletions(-) diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/DiskSynchronizedEmbeddingSearchIndex.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/DiskSynchronizedEmbeddingSearchIndex.kt index da10a146c468..2a767fc8a002 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/DiskSynchronizedEmbeddingSearchIndex.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/DiskSynchronizedEmbeddingSearchIndex.kt @@ -3,6 +3,7 @@ package com.intellij.platform.ml.embeddings.search.indices import ai.grazie.emb.FloatTextEmbedding import com.intellij.concurrency.ConcurrentCollectionFactory +import com.intellij.openapi.application.ApplicationManager import com.intellij.openapi.diagnostic.Logger import com.intellij.openapi.diagnostic.debug import com.intellij.platform.ml.embeddings.search.indices.EntitySourceType.DEFAULT @@ -21,8 +22,8 @@ import java.nio.file.Path * Instead, they change only the corresponding sections in the file. */ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var limit: Int? = null) : EmbeddingSearchIndex { - private var indexToId: MutableMap = CollectionFactory.createSmallMemoryFootprintMap() - private var idToEntry: MutableMap = CollectionFactory.createSmallMemoryFootprintMap() + private val indexToId: MutableMap = CollectionFactory.createSmallMemoryFootprintMap() + private val idToEntry: MutableMap = CollectionFactory.createSmallMemoryFootprintMap() private val uncheckedIds: MutableSet = ConcurrentCollectionFactory.createConcurrentSet() var changed: Boolean = false @@ -44,7 +45,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim private data class IndexEntry( var index: Int, var count: Int, - val embedding: FloatTextEmbedding + val embedding: FloatTextEmbedding, ) override suspend fun getSize() = lock.read { idToEntry.size } @@ -79,43 +80,66 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim uncheckedIds.clear() } - override suspend fun addEntries(values: Iterable>, shouldCount: Boolean) = - lock.write { - for ((id, embedding) in values) { - ensureActive() - uncheckedIds.remove(id) - val entry = idToEntry.getOrPut(id) { - changed = true - if (limit != null && idToEntry.size >= limit!!) return@write - val index = idToEntry.size - indexToId[index] = id - IndexEntry(index = index, count = 0, embedding = embedding) - } - if (shouldCount || entry.count == 0) { - entry.count += 1 - } + override suspend fun addEntries( + values: Iterable>, + shouldCount: Boolean, + ) = lock.write { + for ((id, embedding) in values) { + ensureActive() + uncheckedIds.remove(id) + if (limit != null && idToEntry.size >= limit!!) break + val entry = idToEntry.getOrPut(id) { + changed = true + val index = idToEntry.size + indexToId[index] = id + IndexEntry(index = index, count = 0, embedding = embedding) + } + if (shouldCount || entry.count == 0) { + entry.count += 1 } } + } - override suspend fun saveToDisk() = lock.read { save() } + override suspend fun saveToDisk() = lock.read { + val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList() + val embeddings = ids.map { idToEntry[it]!!.embedding } + fileManager.saveIndex(ids, embeddings) + } override suspend fun loadFromDisk() = lock.write { + indexToId.clear() + idToEntry.clear() val (ids, embeddings) = fileManager.loadIndex() ?: return@write - val idToIndex = ids.withIndex().associate { it.value to it.index } - val idToEmbedding = (ids zip embeddings).toMap() - indexToId = CollectionFactory.createSmallMemoryFootprintMap(ids.withIndex().associate { it.index to it.value }) - idToEntry = CollectionFactory.createSmallMemoryFootprintMap( - ids.associateWith { IndexEntry(index = idToIndex[it]!!, count = 0, embedding = idToEmbedding[it]!!) } - ) + for ((index, id) in ids.withIndex()) { + val embedding = embeddings[index] + indexToId[index] = id + idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding) + } } override suspend fun offload() = lock.write { - indexToId = CollectionFactory.createSmallMemoryFootprintMap() - idToEntry = CollectionFactory.createSmallMemoryFootprintMap() + indexToId.clear() + idToEntry.clear() } - override suspend fun findClosest(searchEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List = lock.read { - return@read idToEntry.mapValues { it.value.embedding }.findClosest(searchEmbedding, topK, similarityThreshold) + override suspend fun findClosest( + searchEmbedding: FloatTextEmbedding, + topK: Int, + similarityThreshold: Double?, + ): List = lock.read { + for (i in 0..SEARCH_ATTEMPTS) { + try { + return@read idToEntry.asSequence() + .map { (id, indexEntry) -> id to indexEntry.embedding } + .findClosest(searchEmbedding, topK, similarityThreshold) + } + catch (e: Exception) { + ensureActive() + if (ApplicationManager.getApplication().isInternal) throw e + continue + } + } + emptyList() } override suspend fun streamFindClose(searchEmbedding: FloatTextEmbedding, similarityThreshold: Double?): Flow { @@ -140,12 +164,6 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim limit == null || idToEntry.size < limit!! } - private suspend fun save() { - val ids = idToEntry.toList().sortedBy { it.second.index }.map { it.first } - val embeddings = ids.map { idToEntry[it]!!.embedding } - fileManager.saveIndex(ids = ids, embeddings = embeddings) - } - suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) = lock.write { delete(id = id, shouldSaveIds = syncToDisk) } @@ -213,7 +231,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim } private suspend fun saveIds() { - fileManager.saveIds(idToEntry.toList().sortedBy { it.second.index }.map { it.first }) + fileManager.saveIds(idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList()) } override suspend fun clearBySourceType(sourceType: EntitySourceType) { @@ -224,6 +242,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim } companion object { + private const val SEARCH_ATTEMPTS = 5 private val logger = Logger.getInstance(DiskSynchronizedEmbeddingSearchIndex::class.java) } } \ No newline at end of file diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/EmbeddingSearchIndex.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/EmbeddingSearchIndex.kt index bf36c6edbd4d..d9a96772e9d9 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/EmbeddingSearchIndex.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/EmbeddingSearchIndex.kt @@ -3,7 +3,10 @@ package com.intellij.platform.ml.embeddings.search.indices import ai.grazie.emb.FloatTextEmbedding import com.intellij.platform.ml.embeddings.search.utils.ScoredText +import kotlinx.coroutines.coroutineScope +import kotlinx.coroutines.ensureActive import kotlinx.coroutines.flow.Flow +import java.util.PriorityQueue interface EmbeddingSearchIndex { var limit: Int? @@ -33,20 +36,27 @@ interface EmbeddingSearchIndex { suspend fun checkCanAddEntry(): Boolean } -internal fun Map.findClosest(searchEmbedding: FloatTextEmbedding, - topK: Int, similarityThreshold: Double?): List { - return asSequence() - .map { it.key to searchEmbedding.times(it.value) } - .filter { (_, similarity) -> if (similarityThreshold != null) similarity > similarityThreshold else true } - .sortedByDescending { (_, similarity) -> similarity } - .take(topK) - .map { (id, similarity) -> ScoredText(id.id, similarity.toDouble()) } - .toList() +internal suspend fun Sequence>.findClosest( + searchEmbedding: FloatTextEmbedding, + topK: Int, similarityThreshold: Double?, +): List = coroutineScope { + val closest = PriorityQueue(topK + 1, compareBy { it.similarity }) + + map { (id, embedding) -> ScoredText(id.id, searchEmbedding.times(embedding).toDouble()) } + .filter { similarityThreshold == null || it.similarity > similarityThreshold } + .forEach { + ensureActive() + closest.add(it) + if (closest.size > topK) closest.poll() + } + + closest.sortedByDescending { it.similarity } } -internal fun Sequence>.streamFindClose(queryEmbedding: FloatTextEmbedding, - similarityThreshold: Double?): Sequence { - return map { (id, embedding) -> id to queryEmbedding.times(embedding) } - .filter { similarityThreshold == null || it.second > similarityThreshold } - .map { (id, similarity) -> ScoredText(id.id, similarity.toDouble()) } +internal fun Sequence>.streamFindClose( + queryEmbedding: FloatTextEmbedding, + similarityThreshold: Double?, +): Sequence { + return map { (id, embedding) -> ScoredText(id.id, queryEmbedding.times(embedding).toDouble()) } + .filter { similarityThreshold == null || it.similarity > similarityThreshold } } \ No newline at end of file diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/InMemoryEmbeddingSearchIndex.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/InMemoryEmbeddingSearchIndex.kt index 4d64b86efcab..6eff8b273e2c 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/InMemoryEmbeddingSearchIndex.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/indices/InMemoryEmbeddingSearchIndex.kt @@ -16,7 +16,7 @@ import java.nio.file.Path * Can be persisted to disk. */ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) : EmbeddingSearchIndex { - private var idToEmbedding: MutableMap = CollectionFactory.createSmallMemoryFootprintMap() + private val idToEmbedding: MutableMap = CollectionFactory.createSmallMemoryFootprintMap() private val uncheckedIds: MutableSet = ConcurrentCollectionFactory.createConcurrentSet() private val lock = SuspendingReadWriteLock() @@ -27,7 +27,9 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) override suspend fun setLimit(value: Int?) = lock.write { // Shrink index if necessary: if (value != null && value < idToEmbedding.size) { - idToEmbedding = idToEmbedding.toList().take(value).toMap().toMutableMap() + val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList() + idToEmbedding.clear() + idToEmbedding.putAll(remaining) } limit = value } @@ -55,29 +57,37 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) uncheckedIds.clear() } - override suspend fun addEntries(values: Iterable>, shouldCount: Boolean) = - lock.write { - if (limit != null) { - val list = values.toList() - list.forEach { uncheckedIds.remove(it.first) } - idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size))) - } - else { - idToEmbedding.putAll(values) - } + override suspend fun addEntries( + values: Iterable>, + shouldCount: Boolean, + ) = lock.write { + if (limit != null) { + val list = values.toList() + list.forEach { uncheckedIds.remove(it.first) } + idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size))) } + else { + idToEmbedding.putAll(values) + } + } override suspend fun saveToDisk() = lock.read { save() } override suspend fun loadFromDisk() = lock.write { val (ids, embeddings) = fileManager.loadIndex() ?: return@write - idToEmbedding = (ids zip embeddings).toMap().toMutableMap() + idToEmbedding.clear() + idToEmbedding.putAll(ids zip embeddings) } - override suspend fun offload() = idToEmbedding.clear() + override suspend fun offload() = lock.write { idToEmbedding.clear() } - override suspend fun findClosest(searchEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List = lock.read { - idToEmbedding.findClosest(searchEmbedding, topK, similarityThreshold) + override suspend fun findClosest( + searchEmbedding: FloatTextEmbedding, + topK: Int, similarityThreshold: Double?, + ): List = lock.read { + idToEmbedding.asSequence() + .map { (id, embedding) -> id to embedding } + .findClosest(searchEmbedding, topK, similarityThreshold) } override suspend fun streamFindClose(searchEmbedding: FloatTextEmbedding, similarityThreshold: Double?): Flow { @@ -103,7 +113,7 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) } private suspend fun save() { - val (ids, embeddings) = idToEmbedding.toList().unzip() - fileManager.saveIndex(ids = ids, embeddings = embeddings) + val (ids, embeddings) = idToEmbedding.asSequence().map { (id, embedding) -> id to embedding }.unzip() + fileManager.saveIndex(ids, embeddings) } } \ No newline at end of file diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/services/FileBasedEmbeddingStoragesManager.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/services/FileBasedEmbeddingStoragesManager.kt index 26b99a87880d..7959be83ce27 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/services/FileBasedEmbeddingStoragesManager.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/services/FileBasedEmbeddingStoragesManager.kt @@ -147,11 +147,8 @@ class FileBasedEmbeddingStoragesManager(private val project: Project, private va indexFiles(scanFiles().toList().sortedByDescending { it.name.length }) EmbeddingSearchLogger.indexingFinished(project, forActions = false, TimeoutUtil.getDurationMillis(projectIndexingStartTime)) } - catch (e: CancellationException) { - logger.debug { "Full project embedding indexing was cancelled" } - throw e - } finally { + ensureActive() if (isFirstIndexing) { onFirstIndexingFinish() isFirstIndexing = false diff --git a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/utils/SuspendingReadWriteLock.kt b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/utils/SuspendingReadWriteLock.kt index 3ad4a0b26880..3782e7c02933 100644 --- a/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/utils/SuspendingReadWriteLock.kt +++ b/platform/ml-embeddings/src/com/intellij/platform/ml/embeddings/search/utils/SuspendingReadWriteLock.kt @@ -2,54 +2,42 @@ package com.intellij.platform.ml.embeddings.search.utils import kotlinx.coroutines.CoroutineScope +import kotlinx.coroutines.NonCancellable import kotlinx.coroutines.coroutineScope -import kotlinx.coroutines.sync.Semaphore -import kotlinx.coroutines.sync.withPermit +import kotlinx.coroutines.sync.Mutex +import kotlinx.coroutines.sync.withLock +import kotlinx.coroutines.withContext class SuspendingReadWriteLock { - private val readLightswitch = Lightswitch() - private val roomEmpty = Semaphore(1) - - suspend fun read(action: suspend CoroutineScope.() -> T): T = coroutineScope { - readLightswitch.lock(roomEmpty) - return@coroutineScope try { - action() - } - finally { - readLightswitch.unlock(roomEmpty) - } - } - - suspend fun write(action: suspend CoroutineScope.() -> T): T = coroutineScope { - roomEmpty.acquire() - return@coroutineScope try { - action() - } - finally { - roomEmpty.release() - } - } -} - -private class Lightswitch { + private val roomEmpty = Mutex() private var counter = 0 - private val mutex = Semaphore(1) + private val counterMutex = Mutex() - suspend fun lock(semaphore: Semaphore) { - mutex.withPermit { - counter += 1 - if (counter == 1) { - semaphore.acquire() + suspend fun read( + action: suspend CoroutineScope.() -> T, + ): T = coroutineScope { + counterMutex.withLock { + if (++counter == 1) { + roomEmpty.lock() + } + } + try { + action() + } + finally { + withContext(NonCancellable) { + counterMutex.withLock { + if (--counter == 0) { + roomEmpty.unlock() + } + } } } } - suspend fun unlock(semaphore: Semaphore) { - mutex.withPermit { - counter -= 1 - if (counter == 0) { - semaphore.release() - } - } + suspend fun write( + action: suspend CoroutineScope.() -> T, + ): T = coroutineScope { + roomEmpty.withLock { action() } } } \ No newline at end of file