mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-03-22 15:19:59 +07:00
(IJPL-156622) Do not fail on concurrent read and write access to embedding index structures
I also perform less intermediate copying when working with structures of embedding indices. GitOrigin-RevId: cf86043330ec34e631c3ba16accd2a92633166b2
This commit is contained in:
committed by
intellij-monorepo-bot
parent
1ce4687f96
commit
8f40ea4365
@@ -3,6 +3,7 @@ package com.intellij.platform.ml.embeddings.search.indices
|
||||
|
||||
import ai.grazie.emb.FloatTextEmbedding
|
||||
import com.intellij.concurrency.ConcurrentCollectionFactory
|
||||
import com.intellij.openapi.application.ApplicationManager
|
||||
import com.intellij.openapi.diagnostic.Logger
|
||||
import com.intellij.openapi.diagnostic.debug
|
||||
import com.intellij.platform.ml.embeddings.search.indices.EntitySourceType.DEFAULT
|
||||
@@ -21,8 +22,8 @@ import java.nio.file.Path
|
||||
* Instead, they change only the corresponding sections in the file.
|
||||
*/
|
||||
open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var limit: Int? = null) : EmbeddingSearchIndex {
|
||||
private var indexToId: MutableMap<Int, EntityId> = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
private var idToEntry: MutableMap<EntityId, IndexEntry> = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
private val indexToId: MutableMap<Int, EntityId> = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
private val idToEntry: MutableMap<EntityId, IndexEntry> = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
private val uncheckedIds: MutableSet<EntityId> = ConcurrentCollectionFactory.createConcurrentSet()
|
||||
var changed: Boolean = false
|
||||
|
||||
@@ -44,7 +45,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
|
||||
private data class IndexEntry(
|
||||
var index: Int,
|
||||
var count: Int,
|
||||
val embedding: FloatTextEmbedding
|
||||
val embedding: FloatTextEmbedding,
|
||||
)
|
||||
|
||||
override suspend fun getSize() = lock.read { idToEntry.size }
|
||||
@@ -79,43 +80,66 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
|
||||
uncheckedIds.clear()
|
||||
}
|
||||
|
||||
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>, shouldCount: Boolean) =
|
||||
lock.write {
|
||||
for ((id, embedding) in values) {
|
||||
ensureActive()
|
||||
uncheckedIds.remove(id)
|
||||
val entry = idToEntry.getOrPut(id) {
|
||||
changed = true
|
||||
if (limit != null && idToEntry.size >= limit!!) return@write
|
||||
val index = idToEntry.size
|
||||
indexToId[index] = id
|
||||
IndexEntry(index = index, count = 0, embedding = embedding)
|
||||
}
|
||||
if (shouldCount || entry.count == 0) {
|
||||
entry.count += 1
|
||||
}
|
||||
override suspend fun addEntries(
|
||||
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
|
||||
shouldCount: Boolean,
|
||||
) = lock.write {
|
||||
for ((id, embedding) in values) {
|
||||
ensureActive()
|
||||
uncheckedIds.remove(id)
|
||||
if (limit != null && idToEntry.size >= limit!!) break
|
||||
val entry = idToEntry.getOrPut(id) {
|
||||
changed = true
|
||||
val index = idToEntry.size
|
||||
indexToId[index] = id
|
||||
IndexEntry(index = index, count = 0, embedding = embedding)
|
||||
}
|
||||
if (shouldCount || entry.count == 0) {
|
||||
entry.count += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun saveToDisk() = lock.read { save() }
|
||||
override suspend fun saveToDisk() = lock.read {
|
||||
val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList()
|
||||
val embeddings = ids.map { idToEntry[it]!!.embedding }
|
||||
fileManager.saveIndex(ids, embeddings)
|
||||
}
|
||||
|
||||
override suspend fun loadFromDisk() = lock.write {
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
|
||||
val idToIndex = ids.withIndex().associate { it.value to it.index }
|
||||
val idToEmbedding = (ids zip embeddings).toMap()
|
||||
indexToId = CollectionFactory.createSmallMemoryFootprintMap(ids.withIndex().associate { it.index to it.value })
|
||||
idToEntry = CollectionFactory.createSmallMemoryFootprintMap(
|
||||
ids.associateWith { IndexEntry(index = idToIndex[it]!!, count = 0, embedding = idToEmbedding[it]!!) }
|
||||
)
|
||||
for ((index, id) in ids.withIndex()) {
|
||||
val embedding = embeddings[index]
|
||||
indexToId[index] = id
|
||||
idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun offload() = lock.write {
|
||||
indexToId = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
idToEntry = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
}
|
||||
|
||||
override suspend fun findClosest(searchEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List<ScoredText> = lock.read {
|
||||
return@read idToEntry.mapValues { it.value.embedding }.findClosest(searchEmbedding, topK, similarityThreshold)
|
||||
override suspend fun findClosest(
|
||||
searchEmbedding: FloatTextEmbedding,
|
||||
topK: Int,
|
||||
similarityThreshold: Double?,
|
||||
): List<ScoredText> = lock.read {
|
||||
for (i in 0..SEARCH_ATTEMPTS) {
|
||||
try {
|
||||
return@read idToEntry.asSequence()
|
||||
.map { (id, indexEntry) -> id to indexEntry.embedding }
|
||||
.findClosest(searchEmbedding, topK, similarityThreshold)
|
||||
}
|
||||
catch (e: Exception) {
|
||||
ensureActive()
|
||||
if (ApplicationManager.getApplication().isInternal) throw e
|
||||
continue
|
||||
}
|
||||
}
|
||||
emptyList()
|
||||
}
|
||||
|
||||
override suspend fun streamFindClose(searchEmbedding: FloatTextEmbedding, similarityThreshold: Double?): Flow<ScoredText> {
|
||||
@@ -140,12 +164,6 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
|
||||
limit == null || idToEntry.size < limit!!
|
||||
}
|
||||
|
||||
private suspend fun save() {
|
||||
val ids = idToEntry.toList().sortedBy { it.second.index }.map { it.first }
|
||||
val embeddings = ids.map { idToEntry[it]!!.embedding }
|
||||
fileManager.saveIndex(ids = ids, embeddings = embeddings)
|
||||
}
|
||||
|
||||
suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) = lock.write {
|
||||
delete(id = id, shouldSaveIds = syncToDisk)
|
||||
}
|
||||
@@ -213,7 +231,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
|
||||
}
|
||||
|
||||
private suspend fun saveIds() {
|
||||
fileManager.saveIds(idToEntry.toList().sortedBy { it.second.index }.map { it.first })
|
||||
fileManager.saveIds(idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList())
|
||||
}
|
||||
|
||||
override suspend fun clearBySourceType(sourceType: EntitySourceType) {
|
||||
@@ -224,6 +242,7 @@ open class DiskSynchronizedEmbeddingSearchIndex(val root: Path, override var lim
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val SEARCH_ATTEMPTS = 5
|
||||
private val logger = Logger.getInstance(DiskSynchronizedEmbeddingSearchIndex::class.java)
|
||||
}
|
||||
}
|
||||
@@ -3,7 +3,10 @@ package com.intellij.platform.ml.embeddings.search.indices
|
||||
|
||||
import ai.grazie.emb.FloatTextEmbedding
|
||||
import com.intellij.platform.ml.embeddings.search.utils.ScoredText
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.ensureActive
|
||||
import kotlinx.coroutines.flow.Flow
|
||||
import java.util.PriorityQueue
|
||||
|
||||
interface EmbeddingSearchIndex {
|
||||
var limit: Int?
|
||||
@@ -33,20 +36,27 @@ interface EmbeddingSearchIndex {
|
||||
suspend fun checkCanAddEntry(): Boolean
|
||||
}
|
||||
|
||||
internal fun Map<EntityId, FloatTextEmbedding>.findClosest(searchEmbedding: FloatTextEmbedding,
|
||||
topK: Int, similarityThreshold: Double?): List<ScoredText> {
|
||||
return asSequence()
|
||||
.map { it.key to searchEmbedding.times(it.value) }
|
||||
.filter { (_, similarity) -> if (similarityThreshold != null) similarity > similarityThreshold else true }
|
||||
.sortedByDescending { (_, similarity) -> similarity }
|
||||
.take(topK)
|
||||
.map { (id, similarity) -> ScoredText(id.id, similarity.toDouble()) }
|
||||
.toList()
|
||||
internal suspend fun Sequence<Pair<EntityId, FloatTextEmbedding>>.findClosest(
|
||||
searchEmbedding: FloatTextEmbedding,
|
||||
topK: Int, similarityThreshold: Double?,
|
||||
): List<ScoredText> = coroutineScope {
|
||||
val closest = PriorityQueue<ScoredText>(topK + 1, compareBy { it.similarity })
|
||||
|
||||
map { (id, embedding) -> ScoredText(id.id, searchEmbedding.times(embedding).toDouble()) }
|
||||
.filter { similarityThreshold == null || it.similarity > similarityThreshold }
|
||||
.forEach {
|
||||
ensureActive()
|
||||
closest.add(it)
|
||||
if (closest.size > topK) closest.poll()
|
||||
}
|
||||
|
||||
closest.sortedByDescending { it.similarity }
|
||||
}
|
||||
|
||||
internal fun Sequence<Pair<EntityId, FloatTextEmbedding>>.streamFindClose(queryEmbedding: FloatTextEmbedding,
|
||||
similarityThreshold: Double?): Sequence<ScoredText> {
|
||||
return map { (id, embedding) -> id to queryEmbedding.times(embedding) }
|
||||
.filter { similarityThreshold == null || it.second > similarityThreshold }
|
||||
.map { (id, similarity) -> ScoredText(id.id, similarity.toDouble()) }
|
||||
internal fun Sequence<Pair<EntityId, FloatTextEmbedding>>.streamFindClose(
|
||||
queryEmbedding: FloatTextEmbedding,
|
||||
similarityThreshold: Double?,
|
||||
): Sequence<ScoredText> {
|
||||
return map { (id, embedding) -> ScoredText(id.id, queryEmbedding.times(embedding).toDouble()) }
|
||||
.filter { similarityThreshold == null || it.similarity > similarityThreshold }
|
||||
}
|
||||
@@ -16,7 +16,7 @@ import java.nio.file.Path
|
||||
* Can be persisted to disk.
|
||||
*/
|
||||
class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null) : EmbeddingSearchIndex {
|
||||
private var idToEmbedding: MutableMap<EntityId, FloatTextEmbedding> = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
private val idToEmbedding: MutableMap<EntityId, FloatTextEmbedding> = CollectionFactory.createSmallMemoryFootprintMap()
|
||||
private val uncheckedIds: MutableSet<EntityId> = ConcurrentCollectionFactory.createConcurrentSet()
|
||||
private val lock = SuspendingReadWriteLock()
|
||||
|
||||
@@ -27,7 +27,9 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
|
||||
override suspend fun setLimit(value: Int?) = lock.write {
|
||||
// Shrink index if necessary:
|
||||
if (value != null && value < idToEmbedding.size) {
|
||||
idToEmbedding = idToEmbedding.toList().take(value).toMap().toMutableMap()
|
||||
val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList()
|
||||
idToEmbedding.clear()
|
||||
idToEmbedding.putAll(remaining)
|
||||
}
|
||||
limit = value
|
||||
}
|
||||
@@ -55,29 +57,37 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
|
||||
uncheckedIds.clear()
|
||||
}
|
||||
|
||||
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>, shouldCount: Boolean) =
|
||||
lock.write {
|
||||
if (limit != null) {
|
||||
val list = values.toList()
|
||||
list.forEach { uncheckedIds.remove(it.first) }
|
||||
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
|
||||
}
|
||||
else {
|
||||
idToEmbedding.putAll(values)
|
||||
}
|
||||
override suspend fun addEntries(
|
||||
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
|
||||
shouldCount: Boolean,
|
||||
) = lock.write {
|
||||
if (limit != null) {
|
||||
val list = values.toList()
|
||||
list.forEach { uncheckedIds.remove(it.first) }
|
||||
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
|
||||
}
|
||||
else {
|
||||
idToEmbedding.putAll(values)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun saveToDisk() = lock.read { save() }
|
||||
|
||||
override suspend fun loadFromDisk() = lock.write {
|
||||
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
|
||||
idToEmbedding = (ids zip embeddings).toMap().toMutableMap()
|
||||
idToEmbedding.clear()
|
||||
idToEmbedding.putAll(ids zip embeddings)
|
||||
}
|
||||
|
||||
override suspend fun offload() = idToEmbedding.clear()
|
||||
override suspend fun offload() = lock.write { idToEmbedding.clear() }
|
||||
|
||||
override suspend fun findClosest(searchEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List<ScoredText> = lock.read {
|
||||
idToEmbedding.findClosest(searchEmbedding, topK, similarityThreshold)
|
||||
override suspend fun findClosest(
|
||||
searchEmbedding: FloatTextEmbedding,
|
||||
topK: Int, similarityThreshold: Double?,
|
||||
): List<ScoredText> = lock.read {
|
||||
idToEmbedding.asSequence()
|
||||
.map { (id, embedding) -> id to embedding }
|
||||
.findClosest(searchEmbedding, topK, similarityThreshold)
|
||||
}
|
||||
|
||||
override suspend fun streamFindClose(searchEmbedding: FloatTextEmbedding, similarityThreshold: Double?): Flow<ScoredText> {
|
||||
@@ -103,7 +113,7 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
|
||||
}
|
||||
|
||||
private suspend fun save() {
|
||||
val (ids, embeddings) = idToEmbedding.toList().unzip()
|
||||
fileManager.saveIndex(ids = ids, embeddings = embeddings)
|
||||
val (ids, embeddings) = idToEmbedding.asSequence().map { (id, embedding) -> id to embedding }.unzip()
|
||||
fileManager.saveIndex(ids, embeddings)
|
||||
}
|
||||
}
|
||||
@@ -147,11 +147,8 @@ class FileBasedEmbeddingStoragesManager(private val project: Project, private va
|
||||
indexFiles(scanFiles().toList().sortedByDescending { it.name.length })
|
||||
EmbeddingSearchLogger.indexingFinished(project, forActions = false, TimeoutUtil.getDurationMillis(projectIndexingStartTime))
|
||||
}
|
||||
catch (e: CancellationException) {
|
||||
logger.debug { "Full project embedding indexing was cancelled" }
|
||||
throw e
|
||||
}
|
||||
finally {
|
||||
ensureActive()
|
||||
if (isFirstIndexing) {
|
||||
onFirstIndexingFinish()
|
||||
isFirstIndexing = false
|
||||
|
||||
@@ -2,54 +2,42 @@
|
||||
package com.intellij.platform.ml.embeddings.search.utils
|
||||
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.NonCancellable
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.sync.Semaphore
|
||||
import kotlinx.coroutines.sync.withPermit
|
||||
import kotlinx.coroutines.sync.Mutex
|
||||
import kotlinx.coroutines.sync.withLock
|
||||
import kotlinx.coroutines.withContext
|
||||
|
||||
class SuspendingReadWriteLock {
|
||||
private val readLightswitch = Lightswitch()
|
||||
private val roomEmpty = Semaphore(1)
|
||||
|
||||
suspend fun <T> read(action: suspend CoroutineScope.() -> T): T = coroutineScope {
|
||||
readLightswitch.lock(roomEmpty)
|
||||
return@coroutineScope try {
|
||||
action()
|
||||
}
|
||||
finally {
|
||||
readLightswitch.unlock(roomEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun <T> write(action: suspend CoroutineScope.() -> T): T = coroutineScope {
|
||||
roomEmpty.acquire()
|
||||
return@coroutineScope try {
|
||||
action()
|
||||
}
|
||||
finally {
|
||||
roomEmpty.release()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private class Lightswitch {
|
||||
private val roomEmpty = Mutex()
|
||||
private var counter = 0
|
||||
private val mutex = Semaphore(1)
|
||||
private val counterMutex = Mutex()
|
||||
|
||||
suspend fun lock(semaphore: Semaphore) {
|
||||
mutex.withPermit {
|
||||
counter += 1
|
||||
if (counter == 1) {
|
||||
semaphore.acquire()
|
||||
suspend fun <T> read(
|
||||
action: suspend CoroutineScope.() -> T,
|
||||
): T = coroutineScope {
|
||||
counterMutex.withLock {
|
||||
if (++counter == 1) {
|
||||
roomEmpty.lock()
|
||||
}
|
||||
}
|
||||
try {
|
||||
action()
|
||||
}
|
||||
finally {
|
||||
withContext(NonCancellable) {
|
||||
counterMutex.withLock {
|
||||
if (--counter == 0) {
|
||||
roomEmpty.unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun unlock(semaphore: Semaphore) {
|
||||
mutex.withPermit {
|
||||
counter -= 1
|
||||
if (counter == 0) {
|
||||
semaphore.release()
|
||||
}
|
||||
}
|
||||
suspend fun <T> write(
|
||||
action: suspend CoroutineScope.() -> T,
|
||||
): T = coroutineScope {
|
||||
roomEmpty.withLock { action() }
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user