ML in SE: trigger embedding indexing lazily only if any client uses embedding search

GitOrigin-RevId: f41373ec2a819837073f255fa063e24f03d0096e
This commit is contained in:
Evgeny Abramov
2024-02-07 16:55:36 +02:00
committed by intellij-monorepo-bot
parent 292926ec7b
commit 9a03794d45
6 changed files with 55 additions and 19 deletions

View File

@@ -10,7 +10,9 @@ import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.components.serviceAsync
import com.intellij.openapi.diagnostic.logger
import com.intellij.openapi.progress.blockingContext
import com.intellij.openapi.project.Project
import com.intellij.openapi.project.ProjectManager
import com.intellij.platform.ide.progress.withBackgroundProgress
import com.intellij.platform.ml.embeddings.EmbeddingsBundle
import com.intellij.platform.ml.embeddings.search.indices.InMemoryEmbeddingSearchIndex
@@ -21,6 +23,7 @@ import com.intellij.platform.ml.embeddings.utils.generateEmbedding
import com.intellij.util.concurrency.annotations.RequiresBackgroundThread
import kotlinx.coroutines.*
import java.io.File
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicReference
/**
@@ -37,28 +40,37 @@ class ActionEmbeddingsStorage(private val cs: CoroutineScope) : EmbeddingsStorag
.resolve(INDEX_DIR).toPath()
)
private val isIndexingTriggered = AtomicBoolean(false)
private val indexSetupJob = AtomicReference<Job>(null)
private val setupTitle
get() = EmbeddingsBundle.getMessage("ml.embeddings.indices.actions.generation.label")
fun prepareForSearch(project: Project) = SemanticSearchCoroutineScope.getScope(project).launch {
fun prepareForSearch(project: Project? = null) = cs.launch {
val reportProject = project ?: blockingContext { ProjectManager.getInstance().openProjects.firstOrNull() }
isIndexingTriggered.compareAndSet(false, true)
if (!ApplicationManager.getApplication().isUnitTestMode) {
// In unit tests you have to manually download artifacts when needed
serviceAsync<LocalArtifactsManager>().downloadArtifactsIfNecessary(project, retryIfCanceled = false)
serviceAsync<LocalArtifactsManager>().downloadArtifactsIfNecessary(reportProject, retryIfCanceled = false)
}
index.loadFromDisk()
generateEmbeddingsIfNecessary(project)
generateEmbeddingsIfNecessary(reportProject)
}
fun tryStopGeneratingEmbeddings() = indexSetupJob.getAndSet(null)?.cancel()
/* Thread-safe job for updating embeddings. Consequent call stops the previous execution */
@RequiresBackgroundThread
suspend fun generateEmbeddingsIfNecessary(project: Project) = coroutineScope {
suspend fun generateEmbeddingsIfNecessary(project: Project?) = coroutineScope {
val backgroundable = ActionEmbeddingsStorageSetup(index, indexSetupJob)
try {
withBackgroundProgress(project, setupTitle) {
if (project != null) {
withBackgroundProgress(project, setupTitle) {
backgroundable.run()
}
}
else {
backgroundable.run()
}
}
@@ -74,6 +86,7 @@ class ActionEmbeddingsStorage(private val cs: CoroutineScope) : EmbeddingsStorag
@RequiresBackgroundThread
override suspend fun searchNeighbours(text: String, topK: Int, similarityThreshold: Double?): List<ScoredText> {
if (index.size == 0) return emptyList()
triggerIndexing() // trigger indexing on first search usage
val embedding = generateEmbedding(text) ?: return emptyList()
return index.findClosest(searchEmbedding = embedding, topK = topK, similarityThreshold = similarityThreshold)
}
@@ -81,10 +94,17 @@ class ActionEmbeddingsStorage(private val cs: CoroutineScope) : EmbeddingsStorag
@RequiresBackgroundThread
suspend fun streamSearchNeighbours(text: String, similarityThreshold: Double? = null): Sequence<ScoredText> {
if (index.size == 0) return emptySequence()
triggerIndexing() // trigger indexing on first search usage
val embedding = generateEmbedding(text) ?: return emptySequence()
return index.streamFindClose(embedding, similarityThreshold)
}
private fun triggerIndexing() {
if (isIndexingTriggered.compareAndSet(false, true)) {
prepareForSearch()
}
}
companion object {
private const val INDEX_DIR = "actions"

View File

@@ -16,6 +16,7 @@ abstract class DiskSynchronizedEmbeddingsStorage<T : IndexableEntity>(val projec
@RequiresBackgroundThread
override suspend fun searchNeighbours(text: String, topK: Int, similarityThreshold: Double?): List<ScoredText> {
if (index.size == 0) return emptyList()
FileBasedEmbeddingStoragesManager.getInstance(project).triggerIndexing()
val embedding = generateEmbedding(text) ?: return emptyList()
return index.findClosest(embedding, topK, similarityThreshold)
}
@@ -23,6 +24,7 @@ abstract class DiskSynchronizedEmbeddingsStorage<T : IndexableEntity>(val projec
@RequiresBackgroundThread
suspend fun streamSearchNeighbours(text: String, similarityThreshold: Double? = null): Sequence<ScoredText> {
if (index.size == 0) return emptySequence()
FileBasedEmbeddingStoragesManager.getInstance(project).triggerIndexing()
val embedding = generateEmbedding(text) ?: return emptySequence()
return index.streamFindClose(embedding, similarityThreshold)
}

View File

@@ -13,6 +13,7 @@ import com.intellij.openapi.roots.ProjectRootManager
import com.intellij.openapi.util.registry.Registry
import com.intellij.openapi.vfs.VfsUtilCore
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.openapi.vfs.isFile
import com.intellij.platform.diagnostic.telemetry.helpers.useWithScope
import com.intellij.platform.ide.progress.withBackgroundProgress
@@ -25,11 +26,13 @@ import com.intellij.platform.util.progress.reportProgress
import com.intellij.psi.PsiManager
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*
import java.util.concurrent.atomic.AtomicBoolean
@Service(Service.Level.PROJECT)
class FileBasedEmbeddingStoragesManager(private val project: Project, private val cs: CoroutineScope) {
private val indexingScope = cs.namedChildScope("Embedding indexing scope")
private var isFirstIndexing = true
private val isIndexingTriggered = AtomicBoolean(false)
private val filesLimit: Int?
get() {
@@ -40,6 +43,7 @@ class FileBasedEmbeddingStoragesManager(private val project: Project, private va
}
fun prepareForSearch() = cs.launch {
if (isIndexingTriggered.compareAndSet(false, true)) addFileListener()
indexingScope.coroutineContext.cancelChildren()
withContext(indexingScope.coroutineContext) {
if (!ApplicationManager.getApplication().isUnitTestMode) {
@@ -50,6 +54,20 @@ class FileBasedEmbeddingStoragesManager(private val project: Project, private va
}
}
fun triggerIndexing() {
if (isIndexingTriggered.compareAndSet(false, true)) {
addFileListener()
prepareForSearch()
}
}
private fun addFileListener() {
VirtualFileManager.getInstance().addAsyncFileListener(
SemanticSearchFileChangeListener.getInstance(project),
SemanticSearchCoroutineScope.getInstance(project)
)
}
private suspend fun loadRequirements() {
withContext(Dispatchers.IO) {
if (!ApplicationManager.getApplication().isUnitTestMode) {

View File

@@ -3,7 +3,6 @@ package com.intellij.searchEverywhereMl.semantics
import com.intellij.openapi.components.serviceAsync
import com.intellij.openapi.project.Project
import com.intellij.openapi.startup.ProjectActivity
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.platform.ml.embeddings.search.services.*
import com.intellij.searchEverywhereMl.semantics.settings.SearchEverywhereSemanticSettings
@@ -15,10 +14,6 @@ private class SemanticSearchInitializer : ProjectActivity {
override suspend fun execute(project: Project) {
val searchEverywhereSemanticSettings = serviceAsync<SearchEverywhereSemanticSettings>()
if (searchEverywhereSemanticSettings.enabledInActionsTab) {
ActionEmbeddingsStorage.getInstance().prepareForSearch(project)
}
EmbeddingIndexSettingsImpl.getInstance(project).registerClientSettings(
object : EmbeddingIndexSettings {
override val shouldIndexFiles: Boolean
@@ -29,12 +24,5 @@ private class SemanticSearchInitializer : ProjectActivity {
get() = searchEverywhereSemanticSettings.enabledInSymbolsTab
}
)
FileBasedEmbeddingStoragesManager.getInstance(project).prepareForSearch()
VirtualFileManager.getInstance().addAsyncFileListener(
SemanticSearchFileChangeListener.getInstance(project),
SemanticSearchCoroutineScope.getInstance(project)
)
}
}

View File

@@ -19,6 +19,7 @@ import com.intellij.testFramework.utils.editor.saveToDisk
import com.intellij.util.TimeoutUtil
import org.jetbrains.kotlin.psi.KtClass
import kotlinx.coroutines.test.runTest
import kotlin.time.Duration.Companion.seconds
class SemanticClassSearchTest : SemanticSearchBaseTestCase() {
private val storage
@@ -46,8 +47,12 @@ class SemanticClassSearchTest : SemanticSearchBaseTestCase() {
assertEquals(1, storage.index.size)
}
fun `test search everywhere contributor`() = runTest {
fun `test search everywhere contributor`() = runTest(
timeout = 45.seconds // increased timeout because of a bug in class index
) {
setupTest("java/IndexProjectAction.java", "kotlin/ProjectIndexingTask.kt", "java/ScoresFileManager.java")
assertEquals(3, storage.index.size)
val searchEverywhereUI = SearchEverywhereUI(project, listOf(SemanticClassSearchEverywhereContributor(createEvent())),
{ _ -> null }, null)
val elements = PlatformTestUtil.waitForFuture(searchEverywhereUI.findElementsForPattern("index project job"))

View File

@@ -20,6 +20,7 @@ import com.intellij.testFramework.utils.vfs.deleteRecursively
import com.intellij.util.TimeoutUtil
import kotlinx.coroutines.test.runTest
import org.jetbrains.kotlin.psi.KtFunction
import kotlin.time.Duration.Companion.seconds
class SemanticSymbolSearchTest : SemanticSearchBaseTestCase() {
@@ -48,7 +49,9 @@ class SemanticSymbolSearchTest : SemanticSearchBaseTestCase() {
assertEquals(1, storage.index.size)
}
fun `test search everywhere contributor`() = runTest {
fun `test search everywhere contributor`() = runTest(
timeout = 45.seconds // increased timeout because of a bug in symbol index
) {
setupTest("java/ProjectIndexingTask.java", "kotlin/ScoresFileManager.kt")
val searchEverywhereUI = SearchEverywhereUI(project, listOf(SemanticSymbolSearchEverywhereContributor(createEvent())),
{ _ -> null }, null)