mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-15 02:59:33 +07:00
ML in SE: trigger embedding indexing lazily only if any client uses embedding search
GitOrigin-RevId: f41373ec2a819837073f255fa063e24f03d0096e
This commit is contained in:
committed by
intellij-monorepo-bot
parent
292926ec7b
commit
9a03794d45
@@ -10,7 +10,9 @@ import com.intellij.openapi.components.Service
|
||||
import com.intellij.openapi.components.service
|
||||
import com.intellij.openapi.components.serviceAsync
|
||||
import com.intellij.openapi.diagnostic.logger
|
||||
import com.intellij.openapi.progress.blockingContext
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.project.ProjectManager
|
||||
import com.intellij.platform.ide.progress.withBackgroundProgress
|
||||
import com.intellij.platform.ml.embeddings.EmbeddingsBundle
|
||||
import com.intellij.platform.ml.embeddings.search.indices.InMemoryEmbeddingSearchIndex
|
||||
@@ -21,6 +23,7 @@ import com.intellij.platform.ml.embeddings.utils.generateEmbedding
|
||||
import com.intellij.util.concurrency.annotations.RequiresBackgroundThread
|
||||
import kotlinx.coroutines.*
|
||||
import java.io.File
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
import java.util.concurrent.atomic.AtomicReference
|
||||
|
||||
/**
|
||||
@@ -37,28 +40,37 @@ class ActionEmbeddingsStorage(private val cs: CoroutineScope) : EmbeddingsStorag
|
||||
.resolve(INDEX_DIR).toPath()
|
||||
)
|
||||
|
||||
private val isIndexingTriggered = AtomicBoolean(false)
|
||||
|
||||
private val indexSetupJob = AtomicReference<Job>(null)
|
||||
|
||||
private val setupTitle
|
||||
get() = EmbeddingsBundle.getMessage("ml.embeddings.indices.actions.generation.label")
|
||||
|
||||
fun prepareForSearch(project: Project) = SemanticSearchCoroutineScope.getScope(project).launch {
|
||||
fun prepareForSearch(project: Project? = null) = cs.launch {
|
||||
val reportProject = project ?: blockingContext { ProjectManager.getInstance().openProjects.firstOrNull() }
|
||||
isIndexingTriggered.compareAndSet(false, true)
|
||||
if (!ApplicationManager.getApplication().isUnitTestMode) {
|
||||
// In unit tests you have to manually download artifacts when needed
|
||||
serviceAsync<LocalArtifactsManager>().downloadArtifactsIfNecessary(project, retryIfCanceled = false)
|
||||
serviceAsync<LocalArtifactsManager>().downloadArtifactsIfNecessary(reportProject, retryIfCanceled = false)
|
||||
}
|
||||
index.loadFromDisk()
|
||||
generateEmbeddingsIfNecessary(project)
|
||||
generateEmbeddingsIfNecessary(reportProject)
|
||||
}
|
||||
|
||||
fun tryStopGeneratingEmbeddings() = indexSetupJob.getAndSet(null)?.cancel()
|
||||
|
||||
/* Thread-safe job for updating embeddings. Consequent call stops the previous execution */
|
||||
@RequiresBackgroundThread
|
||||
suspend fun generateEmbeddingsIfNecessary(project: Project) = coroutineScope {
|
||||
suspend fun generateEmbeddingsIfNecessary(project: Project?) = coroutineScope {
|
||||
val backgroundable = ActionEmbeddingsStorageSetup(index, indexSetupJob)
|
||||
try {
|
||||
withBackgroundProgress(project, setupTitle) {
|
||||
if (project != null) {
|
||||
withBackgroundProgress(project, setupTitle) {
|
||||
backgroundable.run()
|
||||
}
|
||||
}
|
||||
else {
|
||||
backgroundable.run()
|
||||
}
|
||||
}
|
||||
@@ -74,6 +86,7 @@ class ActionEmbeddingsStorage(private val cs: CoroutineScope) : EmbeddingsStorag
|
||||
@RequiresBackgroundThread
|
||||
override suspend fun searchNeighbours(text: String, topK: Int, similarityThreshold: Double?): List<ScoredText> {
|
||||
if (index.size == 0) return emptyList()
|
||||
triggerIndexing() // trigger indexing on first search usage
|
||||
val embedding = generateEmbedding(text) ?: return emptyList()
|
||||
return index.findClosest(searchEmbedding = embedding, topK = topK, similarityThreshold = similarityThreshold)
|
||||
}
|
||||
@@ -81,10 +94,17 @@ class ActionEmbeddingsStorage(private val cs: CoroutineScope) : EmbeddingsStorag
|
||||
@RequiresBackgroundThread
|
||||
suspend fun streamSearchNeighbours(text: String, similarityThreshold: Double? = null): Sequence<ScoredText> {
|
||||
if (index.size == 0) return emptySequence()
|
||||
triggerIndexing() // trigger indexing on first search usage
|
||||
val embedding = generateEmbedding(text) ?: return emptySequence()
|
||||
return index.streamFindClose(embedding, similarityThreshold)
|
||||
}
|
||||
|
||||
private fun triggerIndexing() {
|
||||
if (isIndexingTriggered.compareAndSet(false, true)) {
|
||||
prepareForSearch()
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val INDEX_DIR = "actions"
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ abstract class DiskSynchronizedEmbeddingsStorage<T : IndexableEntity>(val projec
|
||||
@RequiresBackgroundThread
|
||||
override suspend fun searchNeighbours(text: String, topK: Int, similarityThreshold: Double?): List<ScoredText> {
|
||||
if (index.size == 0) return emptyList()
|
||||
FileBasedEmbeddingStoragesManager.getInstance(project).triggerIndexing()
|
||||
val embedding = generateEmbedding(text) ?: return emptyList()
|
||||
return index.findClosest(embedding, topK, similarityThreshold)
|
||||
}
|
||||
@@ -23,6 +24,7 @@ abstract class DiskSynchronizedEmbeddingsStorage<T : IndexableEntity>(val projec
|
||||
@RequiresBackgroundThread
|
||||
suspend fun streamSearchNeighbours(text: String, similarityThreshold: Double? = null): Sequence<ScoredText> {
|
||||
if (index.size == 0) return emptySequence()
|
||||
FileBasedEmbeddingStoragesManager.getInstance(project).triggerIndexing()
|
||||
val embedding = generateEmbedding(text) ?: return emptySequence()
|
||||
return index.streamFindClose(embedding, similarityThreshold)
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import com.intellij.openapi.roots.ProjectRootManager
|
||||
import com.intellij.openapi.util.registry.Registry
|
||||
import com.intellij.openapi.vfs.VfsUtilCore
|
||||
import com.intellij.openapi.vfs.VirtualFile
|
||||
import com.intellij.openapi.vfs.VirtualFileManager
|
||||
import com.intellij.openapi.vfs.isFile
|
||||
import com.intellij.platform.diagnostic.telemetry.helpers.useWithScope
|
||||
import com.intellij.platform.ide.progress.withBackgroundProgress
|
||||
@@ -25,11 +26,13 @@ import com.intellij.platform.util.progress.reportProgress
|
||||
import com.intellij.psi.PsiManager
|
||||
import kotlinx.coroutines.*
|
||||
import kotlinx.coroutines.flow.*
|
||||
import java.util.concurrent.atomic.AtomicBoolean
|
||||
|
||||
@Service(Service.Level.PROJECT)
|
||||
class FileBasedEmbeddingStoragesManager(private val project: Project, private val cs: CoroutineScope) {
|
||||
private val indexingScope = cs.namedChildScope("Embedding indexing scope")
|
||||
private var isFirstIndexing = true
|
||||
private val isIndexingTriggered = AtomicBoolean(false)
|
||||
|
||||
private val filesLimit: Int?
|
||||
get() {
|
||||
@@ -40,6 +43,7 @@ class FileBasedEmbeddingStoragesManager(private val project: Project, private va
|
||||
}
|
||||
|
||||
fun prepareForSearch() = cs.launch {
|
||||
if (isIndexingTriggered.compareAndSet(false, true)) addFileListener()
|
||||
indexingScope.coroutineContext.cancelChildren()
|
||||
withContext(indexingScope.coroutineContext) {
|
||||
if (!ApplicationManager.getApplication().isUnitTestMode) {
|
||||
@@ -50,6 +54,20 @@ class FileBasedEmbeddingStoragesManager(private val project: Project, private va
|
||||
}
|
||||
}
|
||||
|
||||
fun triggerIndexing() {
|
||||
if (isIndexingTriggered.compareAndSet(false, true)) {
|
||||
addFileListener()
|
||||
prepareForSearch()
|
||||
}
|
||||
}
|
||||
|
||||
private fun addFileListener() {
|
||||
VirtualFileManager.getInstance().addAsyncFileListener(
|
||||
SemanticSearchFileChangeListener.getInstance(project),
|
||||
SemanticSearchCoroutineScope.getInstance(project)
|
||||
)
|
||||
}
|
||||
|
||||
private suspend fun loadRequirements() {
|
||||
withContext(Dispatchers.IO) {
|
||||
if (!ApplicationManager.getApplication().isUnitTestMode) {
|
||||
|
||||
@@ -3,7 +3,6 @@ package com.intellij.searchEverywhereMl.semantics
|
||||
import com.intellij.openapi.components.serviceAsync
|
||||
import com.intellij.openapi.project.Project
|
||||
import com.intellij.openapi.startup.ProjectActivity
|
||||
import com.intellij.openapi.vfs.VirtualFileManager
|
||||
import com.intellij.platform.ml.embeddings.search.services.*
|
||||
import com.intellij.searchEverywhereMl.semantics.settings.SearchEverywhereSemanticSettings
|
||||
|
||||
@@ -15,10 +14,6 @@ private class SemanticSearchInitializer : ProjectActivity {
|
||||
override suspend fun execute(project: Project) {
|
||||
val searchEverywhereSemanticSettings = serviceAsync<SearchEverywhereSemanticSettings>()
|
||||
|
||||
if (searchEverywhereSemanticSettings.enabledInActionsTab) {
|
||||
ActionEmbeddingsStorage.getInstance().prepareForSearch(project)
|
||||
}
|
||||
|
||||
EmbeddingIndexSettingsImpl.getInstance(project).registerClientSettings(
|
||||
object : EmbeddingIndexSettings {
|
||||
override val shouldIndexFiles: Boolean
|
||||
@@ -29,12 +24,5 @@ private class SemanticSearchInitializer : ProjectActivity {
|
||||
get() = searchEverywhereSemanticSettings.enabledInSymbolsTab
|
||||
}
|
||||
)
|
||||
|
||||
FileBasedEmbeddingStoragesManager.getInstance(project).prepareForSearch()
|
||||
|
||||
VirtualFileManager.getInstance().addAsyncFileListener(
|
||||
SemanticSearchFileChangeListener.getInstance(project),
|
||||
SemanticSearchCoroutineScope.getInstance(project)
|
||||
)
|
||||
}
|
||||
}
|
||||
@@ -19,6 +19,7 @@ import com.intellij.testFramework.utils.editor.saveToDisk
|
||||
import com.intellij.util.TimeoutUtil
|
||||
import org.jetbrains.kotlin.psi.KtClass
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
class SemanticClassSearchTest : SemanticSearchBaseTestCase() {
|
||||
private val storage
|
||||
@@ -46,8 +47,12 @@ class SemanticClassSearchTest : SemanticSearchBaseTestCase() {
|
||||
assertEquals(1, storage.index.size)
|
||||
}
|
||||
|
||||
fun `test search everywhere contributor`() = runTest {
|
||||
fun `test search everywhere contributor`() = runTest(
|
||||
timeout = 45.seconds // increased timeout because of a bug in class index
|
||||
) {
|
||||
setupTest("java/IndexProjectAction.java", "kotlin/ProjectIndexingTask.kt", "java/ScoresFileManager.java")
|
||||
assertEquals(3, storage.index.size)
|
||||
|
||||
val searchEverywhereUI = SearchEverywhereUI(project, listOf(SemanticClassSearchEverywhereContributor(createEvent())),
|
||||
{ _ -> null }, null)
|
||||
val elements = PlatformTestUtil.waitForFuture(searchEverywhereUI.findElementsForPattern("index project job"))
|
||||
|
||||
@@ -20,6 +20,7 @@ import com.intellij.testFramework.utils.vfs.deleteRecursively
|
||||
import com.intellij.util.TimeoutUtil
|
||||
import kotlinx.coroutines.test.runTest
|
||||
import org.jetbrains.kotlin.psi.KtFunction
|
||||
import kotlin.time.Duration.Companion.seconds
|
||||
|
||||
|
||||
class SemanticSymbolSearchTest : SemanticSearchBaseTestCase() {
|
||||
@@ -48,7 +49,9 @@ class SemanticSymbolSearchTest : SemanticSearchBaseTestCase() {
|
||||
assertEquals(1, storage.index.size)
|
||||
}
|
||||
|
||||
fun `test search everywhere contributor`() = runTest {
|
||||
fun `test search everywhere contributor`() = runTest(
|
||||
timeout = 45.seconds // increased timeout because of a bug in symbol index
|
||||
) {
|
||||
setupTest("java/ProjectIndexingTask.java", "kotlin/ScoresFileManager.kt")
|
||||
val searchEverywhereUI = SearchEverywhereUI(project, listOf(SemanticSymbolSearchEverywhereContributor(createEvent())),
|
||||
{ _ -> null }, null)
|
||||
|
||||
Reference in New Issue
Block a user