(IJPL-156622) Do not fail on concurrent read and write access to embedding index structures

I also perform less intermediate copying when working with structures of embedding indices.

GitOrigin-RevId: cf86043330ec34e631c3ba16accd2a92633166b2
This commit is contained in:
Evgeny Abramov
2024-07-05 17:34:28 +03:00
committed by intellij-monorepo-bot
parent 1ce4687f96
commit 8f40ea4365
5 changed files with 136 additions and 112 deletions

View File

@@ -3,6 +3,7 @@ package com.intellij.platform.ml.embeddings.search.indices
import ai.grazie.emb.FloatTextEmbedding
import com.intellij.concurrency.ConcurrentCollectionFactory
import com.intellij.openapi.application.ApplicationManager
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.diagnostic.debug
import com.intellij.platform.ml.embeddings.search.indices.EntitySourceType.DEFAULT
@@ -21,8 +22,8 @@ import java.nio.file.Path
* Instead, they change only the corresponding sections in the file.
*/
open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var limit: Int? = null) : EmbeddingSearchIndex {
private var indexToId: MutableMap<Int, EntityId> = CollectionFactory.createSmallMemoryFootprintMap()
private var idToEntry: MutableMap<EntityId, IndexEntry> = CollectionFactory.createSmallMemoryFootprintMap()
private val indexToId: MutableMap<Int, EntityId> = CollectionFactory.createSmallMemoryFootprintMap()
private val idToEntry: MutableMap<EntityId, IndexEntry> = CollectionFactory.createSmallMemoryFootprintMap()
private val uncheckedIds: MutableSet<EntityId> = ConcurrentCollectionFactory.createConcurrentSet()
var changed: Boolean = false
@@ -44,7 +45,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
private data class IndexEntry(
var index: Int,
var count: Int,
val embedding: FloatTextEmbedding
val embedding: FloatTextEmbedding,
)
override suspend fun getSize() = lock.read { idToEntry.size }
@@ -79,43 +80,66 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
uncheckedIds.clear()
}
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>, shouldCount: Boolean) =
lock.write {
for ((id, embedding) in values) {
ensureActive()
uncheckedIds.remove(id)
val entry = idToEntry.getOrPut(id) {
changed = true
if (limit != null && idToEntry.size >= limit!!) return@write
val index = idToEntry.size
indexToId[index] = id
IndexEntry(index = index, count = 0, embedding = embedding)
}
if (shouldCount || entry.count == 0) {
entry.count += 1
}
override suspend fun addEntries(
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
shouldCount: Boolean,
) = lock.write {
for ((id, embedding) in values) {
ensureActive()
uncheckedIds.remove(id)
if (limit != null && idToEntry.size >= limit!!) break
val entry = idToEntry.getOrPut(id) {
changed = true
val index = idToEntry.size
indexToId[index] = id
IndexEntry(index = index, count = 0, embedding = embedding)
}
if (shouldCount || entry.count == 0) {
entry.count += 1
}
}
}
override suspend fun saveToDisk() = lock.read { save() }
override suspend fun saveToDisk() = lock.read {
val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList()
val embeddings = ids.map { idToEntry[it]!!.embedding }
fileManager.saveIndex(ids, embeddings)
}
override suspend fun loadFromDisk() = lock.write {
indexToId.clear()
idToEntry.clear()
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
val idToIndex = ids.withIndex().associate { it.value to it.index }
val idToEmbedding = (ids zip embeddings).toMap()
indexToId = CollectionFactory.createSmallMemoryFootprintMap(ids.withIndex().associate { it.index to it.value })
idToEntry = CollectionFactory.createSmallMemoryFootprintMap(
ids.associateWith { IndexEntry(index = idToIndex[it]!!, count = 0, embedding = idToEmbedding[it]!!) }
)
for ((index, id) in ids.withIndex()) {
val embedding = embeddings[index]
indexToId[index] = id
idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding)
}
}
override suspend fun offload() = lock.write {
indexToId = CollectionFactory.createSmallMemoryFootprintMap()
idToEntry = CollectionFactory.createSmallMemoryFootprintMap()
indexToId.clear()
idToEntry.clear()
}
override suspend fun findClosest(searchEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List<ScoredText> = lock.read {
return@read idToEntry.mapValues { it.value.embedding }.findClosest(searchEmbedding, topK, similarityThreshold)
override suspend fun findClosest(
searchEmbedding: FloatTextEmbedding,
topK: Int,
similarityThreshold: Double?,
): List<ScoredText> = lock.read {
for (i in 0..SEARCH_ATTEMPTS) {
try {
return@read idToEntry.asSequence()
.map { (id, indexEntry) -> id to indexEntry.embedding }
.findClosest(searchEmbedding, topK, similarityThreshold)
}
catch (e: Exception) {
ensureActive()
if (ApplicationManager.getApplication().isInternal) throw e
continue
}
}
emptyList()
}
override suspend fun streamFindClose(searchEmbedding: FloatTextEmbedding, similarityThreshold: Double?): Flow<ScoredText> {
@@ -140,12 +164,6 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
limit == null || idToEntry.size < limit!!
}
private suspend fun save() {
val ids = idToEntry.toList().sortedBy { it.second.index }.map { it.first }
val embeddings = ids.map { idToEntry[it]!!.embedding }
fileManager.saveIndex(ids = ids, embeddings = embeddings)
}
suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) = lock.write {
delete(id = id, shouldSaveIds = syncToDisk)
}
@@ -213,7 +231,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
}
private suspend fun saveIds() {
fileManager.saveIds(idToEntry.toList().sortedBy { it.second.index }.map { it.first })
fileManager.saveIds(idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList())
}
override suspend fun clearBySourceType(sourceType: EntitySourceType) {
@@ -224,6 +242,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
}
companion object {
private const val SEARCH_ATTEMPTS = 5
private val logger = Logger.getInstance(DiskSynchronizedEmbeddingSearchIndex::class.java)
}
}

View File

@@ -3,7 +3,10 @@ package com.intellij.platform.ml.embeddings.search.indices
import ai.grazie.emb.FloatTextEmbedding
import com.intellij.platform.ml.embeddings.search.utils.ScoredText
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.ensureActive
import kotlinx.coroutines.flow.Flow
import java.util.PriorityQueue
interface EmbeddingSearchIndex {
var limit: Int?
@@ -33,20 +36,27 @@ interface EmbeddingSearchIndex {
suspend fun checkCanAddEntry(): Boolean
}
internal fun Map<EntityId, FloatTextEmbedding>.findClosest(searchEmbedding: FloatTextEmbedding,
topK: Int, similarityThreshold: Double?): List<ScoredText> {
return asSequence()
.map { it.key to searchEmbedding.times(it.value) }
.filter { (_, similarity) -> if (similarityThreshold != null) similarity > similarityThreshold else true }
.sortedByDescending { (_, similarity) -> similarity }
.take(topK)
.map { (id, similarity) -> ScoredText(id.id, similarity.toDouble()) }
.toList()
internal suspend fun Sequence<Pair<EntityId, FloatTextEmbedding>>.findClosest(
searchEmbedding: FloatTextEmbedding,
topK: Int, similarityThreshold: Double?,
): List<ScoredText> = coroutineScope {
val closest = PriorityQueue<ScoredText>(topK + 1, compareBy { it.similarity })
map { (id, embedding) -> ScoredText(id.id, searchEmbedding.times(embedding).toDouble()) }
.filter { similarityThreshold == null || it.similarity > similarityThreshold }
.forEach {
ensureActive()
closest.add(it)
if (closest.size > topK) closest.poll()
}
closest.sortedByDescending { it.similarity }
}
internal fun Sequence<Pair<EntityId, FloatTextEmbedding>>.streamFindClose(queryEmbedding: FloatTextEmbedding,
similarityThreshold: Double?): Sequence<ScoredText> {
return map { (id, embedding) -> id to queryEmbedding.times(embedding) }
.filter { similarityThreshold == null || it.second > similarityThreshold }
.map { (id, similarity) -> ScoredText(id.id, similarity.toDouble()) }
internal fun Sequence<Pair<EntityId, FloatTextEmbedding>>.streamFindClose(
queryEmbedding: FloatTextEmbedding,
similarityThreshold: Double?,
): Sequence<ScoredText> {
return map { (id, embedding) -> ScoredText(id.id, queryEmbedding.times(embedding).toDouble()) }
.filter { similarityThreshold == null || it.similarity > similarityThreshold }
}

View File

@@ -16,7 +16,7 @@ import java.nio.file.Path
* Can be persisted to disk.
*/
class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) : EmbeddingSearchIndex {
private var idToEmbedding: MutableMap<EntityId, FloatTextEmbedding> = CollectionFactory.createSmallMemoryFootprintMap()
private val idToEmbedding: MutableMap<EntityId, FloatTextEmbedding> = CollectionFactory.createSmallMemoryFootprintMap()
private val uncheckedIds: MutableSet<EntityId> = ConcurrentCollectionFactory.createConcurrentSet()
private val lock = SuspendingReadWriteLock()
@@ -27,7 +27,9 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
override suspend fun setLimit(value: Int?) = lock.write {
// Shrink index if necessary:
if (value != null && value < idToEmbedding.size) {
idToEmbedding = idToEmbedding.toList().take(value).toMap().toMutableMap()
val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList()
idToEmbedding.clear()
idToEmbedding.putAll(remaining)
}
limit = value
}
@@ -55,29 +57,37 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
uncheckedIds.clear()
}
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>, shouldCount: Boolean) =
lock.write {
if (limit != null) {
val list = values.toList()
list.forEach { uncheckedIds.remove(it.first) }
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
}
else {
idToEmbedding.putAll(values)
}
override suspend fun addEntries(
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
shouldCount: Boolean,
) = lock.write {
if (limit != null) {
val list = values.toList()
list.forEach { uncheckedIds.remove(it.first) }
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
}
else {
idToEmbedding.putAll(values)
}
}
override suspend fun saveToDisk() = lock.read { save() }
override suspend fun loadFromDisk() = lock.write {
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
idToEmbedding = (ids zip embeddings).toMap().toMutableMap()
idToEmbedding.clear()
idToEmbedding.putAll(ids zip embeddings)
}
override suspend fun offload() = idToEmbedding.clear()
override suspend fun offload() = lock.write { idToEmbedding.clear() }
override suspend fun findClosest(searchEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List<ScoredText> = lock.read {
idToEmbedding.findClosest(searchEmbedding, topK, similarityThreshold)
override suspend fun findClosest(
searchEmbedding: FloatTextEmbedding,
topK: Int, similarityThreshold: Double?,
): List<ScoredText> = lock.read {
idToEmbedding.asSequence()
.map { (id, embedding) -> id to embedding }
.findClosest(searchEmbedding, topK, similarityThreshold)
}
override suspend fun streamFindClose(searchEmbedding: FloatTextEmbedding, similarityThreshold: Double?): Flow<ScoredText> {
@@ -103,7 +113,7 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
}
private suspend fun save() {
val (ids, embeddings) = idToEmbedding.toList().unzip()
fileManager.saveIndex(ids = ids, embeddings = embeddings)
val (ids, embeddings) = idToEmbedding.asSequence().map { (id, embedding) -> id to embedding }.unzip()
fileManager.saveIndex(ids, embeddings)
}
}

View File

@@ -147,11 +147,8 @@ class FileBasedEmbeddingStoragesManager(private val project: Project, private va
indexFiles(scanFiles().toList().sortedByDescending { it.name.length })
EmbeddingSearchLogger.indexingFinished(project, forActions = false, TimeoutUtil.getDurationMillis(projectIndexingStartTime))
}
catch (e: CancellationException) {
logger.debug { "Full project embedding indexing was cancelled" }
throw e
}
finally {
ensureActive()
if (isFirstIndexing) {
onFirstIndexingFinish()
isFirstIndexing = false

View File

@@ -2,54 +2,42 @@
package com.intellij.platform.ml.embeddings.search.utils
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.NonCancellable
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.sync.Semaphore
import kotlinx.coroutines.sync.withPermit
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import kotlinx.coroutines.withContext
class SuspendingReadWriteLock {
private val readLightswitch = Lightswitch()
private val roomEmpty = Semaphore(1)
suspend fun <T> read(action: suspend CoroutineScope.() -> T): T = coroutineScope {
readLightswitch.lock(roomEmpty)
return@coroutineScope try {
action()
}
finally {
readLightswitch.unlock(roomEmpty)
}
}
suspend fun <T> write(action: suspend CoroutineScope.() -> T): T = coroutineScope {
roomEmpty.acquire()
return@coroutineScope try {
action()
}
finally {
roomEmpty.release()
}
}
}
private class Lightswitch {
private val roomEmpty = Mutex()
private var counter = 0
private val mutex = Semaphore(1)
private val counterMutex = Mutex()
suspend fun lock(semaphore: Semaphore) {
mutex.withPermit {
counter += 1
if (counter == 1) {
semaphore.acquire()
suspend fun <T> read(
action: suspend CoroutineScope.() -> T,
): T = coroutineScope {
counterMutex.withLock {
if (++counter == 1) {
roomEmpty.lock()
}
}
try {
action()
}
finally {
withContext(NonCancellable) {
counterMutex.withLock {
if (--counter == 0) {
roomEmpty.unlock()
}
}
}
}
}
suspend fun unlock(semaphore: Semaphore) {
mutex.withPermit {
counter -= 1
if (counter == 0) {
semaphore.release()
}
}
suspend fun <T> write(
action: suspend CoroutineScope.() -> T,
): T = coroutineScope {
roomEmpty.withLock { action() }
}
}