IJPL-149042 Introduce EmbeddingEntitiesIndexer

GitOrigin-RevId: f4a304dd7172911c602d4d06f011901ea8e28d66
This commit is contained in:
Liudmila Kornilova
2024-09-27 16:23:35 +02:00
committed by intellij-monorepo-bot
parent cc7687562d
commit 4d6044e054
6 changed files with 325 additions and 250 deletions

View File

@@ -3,12 +3,10 @@ package com.intellij.platform.ml.embeddings.files
import com.intellij.concurrency.ConcurrentCollectionFactory
import com.intellij.openapi.application.readAction
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.project.Project
import com.intellij.openapi.vfs.AsyncFileListener
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.openapi.vfs.newvfs.events.VFileEvent
import com.intellij.platform.ml.embeddings.indexer.FileBasedEmbeddingIndexer
import com.intellij.platform.ml.embeddings.settings.EmbeddingIndexSettingsImpl
import com.intellij.util.indexing.FileBasedIndex
import com.intellij.util.indexing.FileBasedIndexImpl
@@ -24,8 +22,7 @@ import java.util.concurrent.atomic.AtomicReference
import kotlin.time.Duration.Companion.milliseconds
@OptIn(FlowPreview::class)
@Service(Service.Level.APP)
class SemanticSearchFileChangeListener(cs: CoroutineScope) : AsyncFileListener {
class SemanticSearchFileChangeListener(cs: CoroutineScope, private val index: suspend (Project, List<VirtualFile>) -> Unit) : AsyncFileListener {
private val reindexRequest = MutableSharedFlow<Unit>(replay = 1, onBufferOverflow = BufferOverflow.DROP_OLDEST)
private val reindexQueue = AtomicReference(ConcurrentCollectionFactory.createConcurrentSet<VirtualFile>())
@@ -56,13 +53,9 @@ class SemanticSearchFileChangeListener(cs: CoroutineScope) : AsyncFileListener {
queue.flatMap { fileBasedIndex.getContainingProjects(it).map { project -> project to it } }
}.groupBy({ it.first }, { it.second })
for ((project, files) in projectToFiles) {
FileBasedEmbeddingIndexer.getInstance().indexFiles(project, files)
index(project, files)
}
// When we have a project:
// val files = IndexableFilesIndex.getInstance(project).run { queue.filter { shouldBeIndexed(it) } }
}
companion object {
fun getInstance(): SemanticSearchFileChangeListener = service()
}
}

View File

@@ -2,8 +2,6 @@
package com.intellij.platform.ml.embeddings.indexer
import com.intellij.openapi.Disposable
import com.intellij.openapi.application.readActionUndispatched
import com.intellij.openapi.application.smartReadAction
import com.intellij.openapi.components.Service
import com.intellij.openapi.components.service
import com.intellij.openapi.diagnostic.Logger
@@ -11,74 +9,43 @@ import com.intellij.openapi.diagnostic.debug
import com.intellij.openapi.progress.runBlockingMaybeCancellable
import com.intellij.openapi.project.Project
import com.intellij.openapi.project.waitForSmartMode
import com.intellij.openapi.roots.ProjectFileIndex
import com.intellij.openapi.util.Disposer
import com.intellij.openapi.util.registry.Registry
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.openapi.vfs.VirtualFileWithId
import com.intellij.openapi.vfs.isFile
import com.intellij.platform.diagnostic.telemetry.helpers.useWithScope
import com.intellij.platform.ide.progress.withBackgroundProgress
import com.intellij.platform.ml.embeddings.EmbeddingsBundle
import com.intellij.platform.ml.embeddings.files.SemanticSearchFileChangeListener
import com.intellij.platform.ml.embeddings.indexer.configuration.EmbeddingsConfiguration
import com.intellij.platform.ml.embeddings.indexer.entities.*
import com.intellij.platform.ml.embeddings.indexer.storage.EmbeddingsStorageManagerWrapper
import com.intellij.platform.ml.embeddings.jvm.indices.EntityId
import com.intellij.platform.ml.embeddings.indexer.configuration.EmbeddingsConfiguration.Companion.getStorageManagerWrapper
import com.intellij.platform.ml.embeddings.indexer.entities.IndexableEntity
import com.intellij.platform.ml.embeddings.indexer.searcher.EmbeddingEntitiesIndexer
import com.intellij.platform.ml.embeddings.indexer.searcher.index.IndexBasedEmbeddingEntitiesIndexer
import com.intellij.platform.ml.embeddings.indexer.searcher.vfs.VFSBasedEmbeddingEntitiesIndexer
import com.intellij.platform.ml.embeddings.logging.EmbeddingSearchLogger
import com.intellij.platform.ml.embeddings.settings.EmbeddingIndexSettings
import com.intellij.platform.ml.embeddings.settings.EmbeddingIndexSettingsImpl
import com.intellij.platform.ml.embeddings.utils.SEMANTIC_SEARCH_TRACER
import com.intellij.platform.ml.embeddings.utils.SemanticSearchCoroutineScope
import com.intellij.platform.util.coroutines.childScope
import com.intellij.psi.PsiManager
import com.intellij.psi.search.GlobalSearchScope
import com.intellij.util.TimeoutUtil
import com.intellij.util.indexing.FileBasedIndex
import com.intellij.util.indexing.ID
import kotlinx.coroutines.*
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.channels.ReceiveChannel
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.channelFlow
import kotlinx.coroutines.flow.toList
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.launch
import kotlinx.coroutines.sync.Mutex
import kotlinx.coroutines.sync.withLock
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
internal const val TOTAL_THREAD_LIMIT_FOR_INDEXING = 8
@Service(Service.Level.APP)
class FileBasedEmbeddingIndexer(private val cs: CoroutineScope) : Disposable {
private val indexingScope = cs.childScope("Embedding indexing scope")
private val isFileListenerAdded = AtomicBoolean(false)
private val indexerScope = cs.childScope("Embedding indexer scope")
private val indexedProjects = mutableSetOf<Project>()
private val indexingJobs = mutableMapOf<Project, Job>()
private val jobsMutex = Mutex()
private val entitiesIndexer: EmbeddingEntitiesIndexer = if (Registry.`is`("intellij.platform.ml.embeddings.use.file.based.index")) IndexBasedEmbeddingEntitiesIndexer(indexerScope)
else VFSBasedEmbeddingEntitiesIndexer(indexerScope).also { searcher -> Disposer.register(this, searcher) }
private val storageManagerWrappers = buildMap {
for (indexId in FILE_BASED_INDICES) {
put(indexId, EmbeddingsConfiguration.getStorageManagerWrapper(indexId))
}
}
private val filesLimit: Int?
get() {
return if (Registry.`is`("intellij.platform.ml.embeddings.index.files.use.limit")) {
Registry.intValue("intellij.platform.ml.embeddings.index.files.limit")
}
else null
}
@OptIn(ExperimentalCoroutinesApi::class)
private val indexingContext = Dispatchers.Default.limitedParallelism(TOTAL_THREAD_LIMIT)
@OptIn(ExperimentalCoroutinesApi::class)
private val filesIterationContext = Dispatchers.Default.limitedParallelism(FILE_WORKER_COUNT)
fun prepareForSearch(project: Project): Job = cs.launch {
if (isFileListenerAdded.compareAndSet(false, true)) addFileListener()
Disposer.register(project) {
runBlockingMaybeCancellable {
jobsMutex.withLock {
@@ -103,7 +70,6 @@ class FileBasedEmbeddingIndexer(private val cs: CoroutineScope) : Disposable {
}
suspend fun triggerIndexing(project: Project) {
if (isFileListenerAdded.compareAndSet(false, true)) addFileListener()
var shouldIndex = false
jobsMutex.withLock {
if (project !in indexedProjects) {
@@ -116,10 +82,6 @@ class FileBasedEmbeddingIndexer(private val cs: CoroutineScope) : Disposable {
}
}
private fun addFileListener() {
VirtualFileManager.getInstance().addAsyncFileListener(SemanticSearchFileChangeListener.getInstance(), this)
}
private suspend fun indexProject(project: Project) {
project.waitForSmartMode()
logger.debug { "Started full project embedding indexing" }
@@ -127,7 +89,10 @@ class FileBasedEmbeddingIndexer(private val cs: CoroutineScope) : Disposable {
startIndexingSession(project)
try {
val projectIndexingStartTime = System.nanoTime()
indexFiles(project, scanFiles(project).toList().sortedByDescending { it.name.length })
val settings = EmbeddingIndexSettingsImpl.getInstance()
if (settings.shouldIndexAnythingFileBased) {
entitiesIndexer.index(project, settings)
}
EmbeddingSearchLogger.indexingFinished(project, forActions = false, TimeoutUtil.getDurationMillis(projectIndexingStartTime))
} finally {
finishIndexingSession(project)
@@ -136,179 +101,6 @@ class FileBasedEmbeddingIndexer(private val cs: CoroutineScope) : Disposable {
logger.debug { "Finished full project embedding indexing" }
}
private fun scanFiles(project: Project): Flow<VirtualFile> {
val scanLimit = filesLimit?.let { it * 2 } // do not scan all files if there is a limit
var filteredFiles = 0
return channelFlow {
SEMANTIC_SEARCH_TRACER.spanBuilder(SCANNING_SPAN_NAME).useWithScope {
withBackgroundProgress(project, EmbeddingsBundle.getMessage("ml.embeddings.indices.scanning.label")) {
ProjectFileIndex.getInstance(project).iterateContent { file ->
if (file.isFile && file.isValid && file.isInLocalFileSystem) {
launch { send(file) }
filteredFiles += 1
}
scanLimit == null || filteredFiles < scanLimit
}
}
}
}
}
suspend fun indexFiles(project: Project, files: List<VirtualFile>) {
val settings = EmbeddingIndexSettingsImpl.getInstance()
if (!settings.shouldIndexAnythingFileBased) return
withContext(indexingScope.coroutineContext) {
withContext(indexingContext) {
val filesChannel = Channel<IndexableEntity>(capacity = BUFFER_SIZE)
val classesChannel = Channel<IndexableEntity>(capacity = BUFFER_SIZE)
val symbolsChannel = Channel<IndexableEntity>(capacity = BUFFER_SIZE)
suspend fun sendEntities(indexId: IndexId, channel: ReceiveChannel<IndexableEntity>) {
val entities = ArrayList<IndexableEntity>(BATCH_SIZE)
var index = 0
val wrapper = getStorageManagerWrapper(indexId)
for (entity in channel) {
if (entities.size < BATCH_SIZE) entities.add(entity) else entities[index] = entity
++index
if (index == BATCH_SIZE) {
wrapper.addAbsent(project, entities)
index = 0
}
}
if (entities.isNotEmpty()) {
wrapper.addAbsent(project, entities)
}
}
launch { sendEntities(IndexId.FILES, filesChannel) }
launch { sendEntities(IndexId.CLASSES, classesChannel) }
launch { sendEntities(IndexId.SYMBOLS, symbolsChannel) }
if (Registry.`is`("intellij.platform.ml.embeddings.use.file.based.index")) {
launchFetchingEntities(settings.shouldIndexFiles, FILE_NAME_EMBEDDING_INDEX_NAME, filesChannel, project) { entityId -> IndexableFile(entityId) }
launchFetchingEntities(settings.shouldIndexClasses, CLASS_NAME_EMBEDDING_INDEX_NAME, classesChannel, project) { entityId -> IndexableClass(entityId) }
launchFetchingEntities(settings.shouldIndexSymbols, SYMBOL_NAME_EMBEDDING_INDEX_NAME, symbolsChannel, project) { entityId -> IndexableSymbol(entityId) }
}
else {
indexFilesInProject(project, files, settings, filesChannel, classesChannel, symbolsChannel)
}
}
}
}
private fun CoroutineScope.launchFetchingEntities(shouldIndex: Boolean,
index: ID<EmbeddingKey, String>,
channel: Channel<IndexableEntity>,
project: Project,
toIndexableEntity: (EntityId) -> IndexableEntity) {
if (!shouldIndex) {
channel.close()
return
}
launch {
fetchEntities(index, channel, project) { key, name ->
LongIndexableEntity(key.toLong(), toIndexableEntity(EntityId(name)))
}
channel.close()
}
}
private suspend fun indexFilesInProject(
project: Project,
files: List<VirtualFile>,
settings: EmbeddingIndexSettingsImpl,
filesChannel: Channel<IndexableEntity>,
classesChannel: Channel<IndexableEntity>,
symbolsChannel: Channel<IndexableEntity>,
) {
val psiManager = PsiManager.getInstance(project)
val processedFiles = AtomicInteger(0)
val total: Int = filesLimit?.let { minOf(files.size, it) } ?: files.size
logger.debug { "Effective embedding indexing files limit: $total" }
withContext(filesIterationContext) {
val limit = filesLimit
repeat(FILE_WORKER_COUNT) { worker ->
var index = worker
launch {
while (index < files.size) {
if (limit != null && processedFiles.get() >= limit) return@launch
val file = files[index]
if (file.isFile && file.isValid && file.isInLocalFileSystem) {
processFile(file, psiManager, settings, filesChannel, classesChannel, symbolsChannel)
processedFiles.incrementAndGet()
}
else {
logger.debug { "File is not valid: ${file.name}" }
}
index += FILE_WORKER_COUNT
}
}
}
}
filesChannel.close()
classesChannel.close()
symbolsChannel.close()
}
private suspend fun fetchEntities(indexId: ID<EmbeddingKey, String>,
channel: Channel<IndexableEntity>,
project: Project,
nameToEntity: (Long, String) -> LongIndexableEntity) {
val fileBasedIndex = FileBasedIndex.getInstance()
val scope = GlobalSearchScope.projectScope(project)
val keys = smartReadAction(project) { fileBasedIndex.getAllKeys(indexId, project) }
val chunkSize = Registry.intValue("intellij.platform.ml.embeddings.file.based.index.processing.chunk.size")
keys.asSequence().chunked(chunkSize).forEach { chunk ->
chunk.forEach { key ->
val fileIdsAndNames = smartReadAction(project) {
val result = mutableListOf<Pair<Int, String>>()
fileBasedIndex.processValues(indexId, key, null, { virtualFile, name ->
if (virtualFile is VirtualFileWithId) {
result.add(Pair(virtualFile.id, name))
}
true
}, scope)
result
}
for ((fileId, name) in fileIdsAndNames) {
channel.send(nameToEntity(key.toLong(fileId), name))
}
}
}
}
private suspend fun processFile(
file: VirtualFile,
psiManager: PsiManager,
settings: EmbeddingIndexSettings,
filesChannel: Channel<IndexableEntity>,
classesChannel: Channel<IndexableEntity>,
symbolsChannel: Channel<IndexableEntity>,
) = coroutineScope {
if (settings.shouldIndexFiles) {
launch {
filesChannel.send(IndexableFile(file))
}
}
if (settings.shouldIndexClasses || settings.shouldIndexSymbols) {
val psiFile = readActionUndispatched { psiManager.findFile(file) } ?: return@coroutineScope
if (settings.shouldIndexClasses) {
launch {
readActionUndispatched { ClassesProvider.extractClasses(psiFile) }.forEach { classesChannel.send(it) }
}
}
if (settings.shouldIndexSymbols) {
launch {
readActionUndispatched { SymbolsProvider.extractSymbols(psiFile) }.forEach { symbolsChannel.send(it) }
}
}
}
}
private suspend fun startIndexingSession(project: Project) {
for (indexId in FILE_BASED_INDICES) {
getStorageManagerWrapper(indexId).startIndexingSession(project)
@@ -321,27 +113,60 @@ class FileBasedEmbeddingIndexer(private val cs: CoroutineScope) : Disposable {
}
}
private fun getStorageManagerWrapper(indexId: IndexId): EmbeddingsStorageManagerWrapper<*> {
return storageManagerWrappers[indexId] ?: throw IllegalArgumentException("$indexId is not supported for file-based indexing")
}
companion object {
fun getInstance(): FileBasedEmbeddingIndexer = service()
private const val TOTAL_THREAD_LIMIT = 8
private const val FILE_WORKER_COUNT = 4
private const val BATCH_SIZE = 128
private const val BUFFER_SIZE = BATCH_SIZE * 8
internal const val INDEXING_VERSION = "0.0.1"
private val FILE_BASED_INDICES = arrayOf(IndexId.FILES, IndexId.CLASSES, IndexId.SYMBOLS)
private val logger = Logger.getInstance(FileBasedEmbeddingIndexer::class.java)
private const val SCANNING_SPAN_NAME = "embeddingFilesScanning"
private const val INDEXING_SPAN_NAME = "embeddingIndexing"
}
override fun dispose() {}
}
}
// Number of entities sendEntities accumulates before passing them to the storage wrapper in one call.
private const val BATCH_SIZE = 128
// Capacity of each entity channel created by searchAndSendEntities (room for eight batches).
private const val BUFFER_SIZE = BATCH_SIZE * 8
/**
 * Drains [channel] and hands the received entities to the storage manager wrapper
 * of [indexId] in batches of at most [BATCH_SIZE] elements.
 *
 * The loop ends when [channel] is closed; whatever is left in the last, partially
 * filled batch is flushed afterwards.
 *
 * @param project project the entities belong to; forwarded to the wrapper's `addAbsent`
 * @param indexId selects the storage manager wrapper that receives the batches
 * @param channel source of entities; must eventually be closed by the producer side
 */
suspend fun sendEntities(project: Project, indexId: IndexId, channel: ReceiveChannel<IndexableEntity>) {
  val wrapper = getStorageManagerWrapper(indexId)
  var batch = ArrayList<IndexableEntity>(BATCH_SIZE)
  for (entity in channel) {
    batch.add(entity)
    if (batch.size == BATCH_SIZE) {
      wrapper.addAbsent(project, batch)
      // Start a fresh list instead of overwriting the one just handed over:
      // the previous implementation reused a single buffer, so the final flush
      // re-sent stale entities from the preceding batch (or the whole last batch
      // again when the total was a multiple of BATCH_SIZE).
      batch = ArrayList(BATCH_SIZE)
    }
  }
  // Only the genuine remainder is left here.
  if (batch.isNotEmpty()) {
    wrapper.addAbsent(project, batch)
  }
}
/**
 * Wires producers of indexable entities to the consumers that store them.
 *
 * For every entity kind enabled in [settings] (files, classes, symbols) a buffered channel
 * of capacity [BUFFER_SIZE] is created and a [sendEntities] consumer is launched on it.
 * [launchSearching] then receives the three channels (a channel is `null` when its kind is
 * disabled) and is expected to launch the producers in the scope it is invoked on.
 * Once that scope has completed, the channels are closed so the consumers can finish.
 *
 * The function returns after all consumers are done. Its inferred result is the value of the
 * last `close()` call; the callers visible in this change ignore it.
 */
internal suspend fun searchAndSendEntities(
  project: Project,
  settings: EmbeddingIndexSettings,
  launchSearching: CoroutineScope.(Channel<IndexableEntity>?, Channel<IndexableEntity>?, Channel<IndexableEntity>?) -> Unit,
) = coroutineScope {
  fun channelIf(enabled: Boolean): Channel<IndexableEntity>? =
    if (enabled) Channel<IndexableEntity>(capacity = BUFFER_SIZE) else null

  val filesChannel = channelIf(settings.shouldIndexFiles)
  val classesChannel = channelIf(settings.shouldIndexClasses)
  val symbolsChannel = channelIf(settings.shouldIndexSymbols)

  // Consumers: one per enabled entity kind.
  filesChannel?.let { source -> launch { sendEntities(project, IndexId.FILES, source) } }
  classesChannel?.let { source -> launch { sendEntities(project, IndexId.CLASSES, source) } }
  symbolsChannel?.let { source -> launch { sendEntities(project, IndexId.SYMBOLS, source) } }

  // Producers run in a nested scope, so this call suspends until every coroutine
  // started by launchSearching has finished.
  coroutineScope {
    this.launchSearching(filesChannel, classesChannel, symbolsChannel)
  }

  // No producer is left at this point: closing the channels ends the consumers' loops.
  filesChannel?.close()
  classesChannel?.close()
  symbolsChannel?.close()
}

View File

@@ -2,7 +2,6 @@
package com.intellij.platform.ml.embeddings.indexer.configuration
import com.intellij.openapi.util.registry.Registry
import com.intellij.platform.ml.embeddings.indexer.IndexId
import com.intellij.platform.ml.embeddings.indexer.keys.EmbeddingStorageKeyProvider
import com.intellij.platform.ml.embeddings.indexer.keys.IndexLongKeyProvider
import com.intellij.platform.ml.embeddings.indexer.storage.NativeServerTextEmbeddingsStorageManager
@@ -18,7 +17,7 @@ class NativeServerFileBasedIndexEmbeddingsConfiguration: EmbeddingsConfiguration
}
override fun isEnabled(): Boolean {
return Registry.Companion.`is`("intellij.platform.ml.embeddings.use.native.server") &&
Registry.Companion.`is`("intellij.platform.ml.embeddings.use.file.based.index")
return Registry.`is`("intellij.platform.ml.embeddings.use.native.server") &&
Registry.`is`("intellij.platform.ml.embeddings.use.file.based.index")
}
}

View File

@@ -0,0 +1,9 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.embeddings.indexer.searcher
import com.intellij.openapi.project.Project
import com.intellij.platform.ml.embeddings.settings.EmbeddingIndexSettings
/**
 * Strategy that finds the indexable entities of a project and submits them for embedding indexing.
 */
internal interface EmbeddingEntitiesIndexer {
/**
 * Indexes the entities of [project]; [settings] tells which entity kinds
 * (files, classes, symbols) are enabled.
 *
 * NOTE(review): the implementations added together with this interface suspend until
 * the indexing job they launch has completed — confirm this is part of the contract.
 */
suspend fun index(project: Project, settings: EmbeddingIndexSettings)
}

View File

@@ -0,0 +1,80 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.embeddings.indexer.searcher.index
import com.intellij.openapi.application.smartReadAction
import com.intellij.openapi.project.Project
import com.intellij.openapi.util.registry.Registry
import com.intellij.openapi.vfs.VirtualFileWithId
import com.intellij.platform.ml.embeddings.indexer.*
import com.intellij.platform.ml.embeddings.indexer.entities.*
import com.intellij.platform.ml.embeddings.indexer.searcher.EmbeddingEntitiesIndexer
import com.intellij.platform.ml.embeddings.jvm.indices.EntityId
import com.intellij.platform.ml.embeddings.settings.EmbeddingIndexSettings
import com.intellij.platform.util.coroutines.childScope
import com.intellij.psi.search.GlobalSearchScope
import com.intellij.util.indexing.FileBasedIndex
import com.intellij.util.indexing.ID
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Dispatchers
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.launch
/**
 * [EmbeddingEntitiesIndexer] that takes entity names from [FileBasedIndex] data
 * (the file-name, class-name and symbol-name embedding indices) instead of scanning files.
 *
 * All work is launched in a child scope of the supplied scope, on a dispatcher limited to
 * [TOTAL_THREAD_LIMIT_FOR_INDEXING] parallel tasks.
 */
internal class IndexBasedEmbeddingEntitiesIndexer(cs: CoroutineScope) : EmbeddingEntitiesIndexer {
  @OptIn(ExperimentalCoroutinesApi::class)
  private val indexingScope = cs.childScope("IndexBasedEmbeddingEntitiesIndexer indexing scope", Dispatchers.Default.limitedParallelism(TOTAL_THREAD_LIMIT_FOR_INDEXING))

  /**
   * Launches one fetching coroutine per enabled entity kind and suspends until the whole job is complete.
   *
   * NOTE(review): the job is launched in [indexingScope] and only joined here, so cancelling the
   * caller interrupts the wait but does not cancel the launched job — confirm this is intended.
   */
  override suspend fun index(project: Project, settings: EmbeddingIndexSettings) {
    val indexingJob = indexingScope.launch {
      searchAndSendEntities(project, settings) { files, classes, symbols ->
        // A null channel means the corresponding entity kind is disabled.
        if (files != null) {
          launchFetchingEntities(FILE_NAME_EMBEDDING_INDEX_NAME, files, project) { id -> IndexableFile(id) }
        }
        if (classes != null) {
          launchFetchingEntities(CLASS_NAME_EMBEDDING_INDEX_NAME, classes, project) { id -> IndexableClass(id) }
        }
        if (symbols != null) {
          launchFetchingEntities(SYMBOL_NAME_EMBEDDING_INDEX_NAME, symbols, project) { id -> IndexableSymbol(id) }
        }
      }
    }
    indexingJob.join()
  }

  /**
   * Starts a coroutine that sends every entity found in [index] to [channel]
   * and closes [channel] once fetching has finished.
   */
  private fun CoroutineScope.launchFetchingEntities(
    index: ID<EmbeddingKey, String>,
    channel: Channel<IndexableEntity>,
    project: Project,
    toIndexableEntity: (EntityId) -> IndexableEntity,
  ) {
    // Wraps the name-based entity together with its long identifier.
    val wrap: (Long, String) -> LongIndexableEntity = { longId, name ->
      LongIndexableEntity(longId, toIndexableEntity(EntityId(name)))
    }
    launch {
      fetchEntities(index, channel, project, wrap)
      channel.close()
    }
  }

  /**
   * Reads all keys of [indexId], then for each key collects the (file id, name) pairs stored
   * for files in the project scope and sends the resulting entities to [channel].
   *
   * Index access happens inside `smartReadAction`; sending to the channel happens outside of it.
   */
  private suspend fun fetchEntities(
    indexId: ID<EmbeddingKey, String>,
    channel: Channel<IndexableEntity>,
    project: Project,
    nameToEntity: (Long, String) -> LongIndexableEntity,
  ) {
    val fileBasedIndex = FileBasedIndex.getInstance()
    val scope = GlobalSearchScope.projectScope(project)
    val keys = smartReadAction(project) { fileBasedIndex.getAllKeys(indexId, project) }
    val chunkSize = Registry.intValue("intellij.platform.ml.embeddings.file.based.index.processing.chunk.size")
    // NOTE(review): keys are grouped into chunks, but each key still gets its own read action,
    // so the chunk size has no visible batching effect here — confirm whether a read action
    // per chunk was intended.
    for (chunk in keys.asSequence().chunked(chunkSize)) {
      for (key in chunk) {
        val fileIdsAndNames = smartReadAction(project) {
          buildList<Pair<Int, String>> {
            fileBasedIndex.processValues(indexId, key, null, { virtualFile, name ->
              // Only files with a persistent id can be turned into a long key.
              if (virtualFile is VirtualFileWithId) add(virtualFile.id to name)
              true
            }, scope)
          }
        }
        for ((fileId, name) in fileIdsAndNames) {
          channel.send(nameToEntity(key.toLong(fileId), name))
        }
      }
    }
  }
}

View File

@@ -0,0 +1,169 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.embeddings.indexer.searcher.vfs
import com.intellij.openapi.Disposable
import com.intellij.openapi.application.readActionUndispatched
import com.intellij.openapi.diagnostic.Logger
import com.intellij.openapi.diagnostic.debug
import com.intellij.openapi.project.Project
import com.intellij.openapi.roots.ProjectFileIndex
import com.intellij.openapi.util.registry.Registry
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.openapi.vfs.VirtualFileManager
import com.intellij.openapi.vfs.isFile
import com.intellij.platform.diagnostic.telemetry.helpers.useWithScope
import com.intellij.platform.ide.progress.withBackgroundProgress
import com.intellij.platform.ml.embeddings.EmbeddingsBundle
import com.intellij.platform.ml.embeddings.files.SemanticSearchFileChangeListener
import com.intellij.platform.ml.embeddings.indexer.ClassesProvider
import com.intellij.platform.ml.embeddings.indexer.SymbolsProvider
import com.intellij.platform.ml.embeddings.indexer.TOTAL_THREAD_LIMIT_FOR_INDEXING
import com.intellij.platform.ml.embeddings.indexer.entities.IndexableEntity
import com.intellij.platform.ml.embeddings.indexer.entities.IndexableFile
import com.intellij.platform.ml.embeddings.indexer.searchAndSendEntities
import com.intellij.platform.ml.embeddings.indexer.searcher.EmbeddingEntitiesIndexer
import com.intellij.platform.ml.embeddings.settings.EmbeddingIndexSettings
import com.intellij.platform.ml.embeddings.settings.EmbeddingIndexSettingsImpl
import com.intellij.platform.ml.embeddings.utils.SEMANTIC_SEARCH_TRACER
import com.intellij.platform.util.coroutines.childScope
import com.intellij.psi.PsiManager
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.Channel
import java.util.concurrent.atomic.AtomicBoolean
import java.util.concurrent.atomic.AtomicInteger
// Name of the tracer span opened around the project content scan in scanFiles.
private const val SCANNING_SPAN_NAME = "embeddingFilesScanning"
// Number of worker coroutines launched per search pass; also the parallelism limit of the file iteration dispatcher.
private const val FILE_WORKER_COUNT = 4
// File-level logger used by VFSBasedEmbeddingEntitiesIndexer.
private val logger = Logger.getInstance(VFSBasedEmbeddingEntitiesIndexer::class.java)
/**
 * [EmbeddingEntitiesIndexer] that finds entities by iterating project content files and,
 * when classes or symbols are enabled, extracting them from the files' PSI.
 *
 * On the first [index] call it also registers a [SemanticSearchFileChangeListener] whose
 * callback is [searchAndIndex]; the listener is registered with this instance as its parent disposable.
 */
internal class VFSBasedEmbeddingEntitiesIndexer(private val cs: CoroutineScope) : EmbeddingEntitiesIndexer, Disposable {
  // Guards the one-time registration of the file change listener.
  private val isFileListenerAdded = AtomicBoolean(false)

  @OptIn(ExperimentalCoroutinesApi::class)
  private val indexingScope = cs.childScope("VFSBasedEmbeddingEntitiesSearcher indexing scope", Dispatchers.Default.limitedParallelism(TOTAL_THREAD_LIMIT_FOR_INDEXING))

  // Maximum number of files to index, or null when the limit is switched off in the registry.
  private val filesLimit: Int?
    get() =
      if (Registry.`is`("intellij.platform.ml.embeddings.index.files.use.limit")) {
        Registry.intValue("intellij.platform.ml.embeddings.index.files.limit")
      }
      else {
        null
      }

  @OptIn(ExperimentalCoroutinesApi::class)
  private val filesIterationContext = Dispatchers.Default.limitedParallelism(FILE_WORKER_COUNT)

  /**
   * Scans the content files of [project], orders them by descending file-name length and indexes
   * the entity kinds enabled in [settings]. Suspends until the launched job is complete.
   *
   * NOTE(review): the job runs in [indexingScope] and is only joined, so cancelling the caller
   * does not cancel the job itself — confirm this is intended.
   */
  override suspend fun index(project: Project, settings: EmbeddingIndexSettings) {
    if (isFileListenerAdded.compareAndSet(false, true)) {
      addFileListener()
    }
    launchAndAwait(project, settings) {
      scanFiles(project).sortedByDescending { it.name.length }
    }
  }

  /** Creates the file change listener in its own child scope and registers it in the VFS. */
  private fun addFileListener() {
    val listenerScope = cs.childScope("Embedding file change listener scope")
    val listener = SemanticSearchFileChangeListener(listenerScope, ::searchAndIndex)
    VirtualFileManager.getInstance().addAsyncFileListener(listener, this)
  }

  /**
   * Indexes the given [files] of [project] using the current application-level settings.
   * Does nothing when no file-based entity kind is enabled.
   */
  suspend fun searchAndIndex(project: Project, files: List<VirtualFile>) {
    val settings = EmbeddingIndexSettingsImpl.getInstance()
    if (!settings.shouldIndexAnythingFileBased) return
    launchAndAwait(project, settings) { files }
  }

  /**
   * Runs one indexing pass inside [indexingScope] and waits for it: [provideFiles] is evaluated
   * inside the launched job, then the files are searched and the found entities are stored.
   */
  private suspend fun launchAndAwait(
    project: Project,
    settings: EmbeddingIndexSettings,
    provideFiles: suspend () -> List<VirtualFile>,
  ) {
    val job = indexingScope.launch {
      val files = provideFiles()
      searchAndSendEntities(project, settings) { filesChannel, classesChannel, symbolsChannel ->
        launch {
          search(project, files, filesChannel, classesChannel, symbolsChannel)
        }
      }
    }
    job.join()
  }

  /**
   * Distributes [files] over [FILE_WORKER_COUNT] workers (worker `w` handles indices
   * `w, w + FILE_WORKER_COUNT, ...`) and sends the entities of each eligible file to the
   * non-null channels. A worker stops once the shared counter of processed files has reached
   * the files limit, when a limit is configured.
   */
  private suspend fun search(
    project: Project,
    files: List<VirtualFile>,
    filesChannel: Channel<IndexableEntity>?,
    classesChannel: Channel<IndexableEntity>?,
    symbolsChannel: Channel<IndexableEntity>?,
  ) {
    val psiManager = PsiManager.getInstance(project)
    val processedFiles = AtomicInteger(0)
    val configuredLimit = filesLimit
    val total: Int = if (configuredLimit != null) minOf(files.size, configuredLimit) else files.size
    logger.debug { "Effective embedding indexing files limit: $total" }

    withContext(filesIterationContext) {
      val limit = filesLimit
      repeat(FILE_WORKER_COUNT) { worker ->
        launch {
          for (position in worker until files.size step FILE_WORKER_COUNT) {
            // The limit is re-checked before every file because the counter is shared by all workers.
            if (limit != null && processedFiles.get() >= limit) break
            val file = files[position]
            val eligible = file.isFile && file.isValid && file.isInLocalFileSystem
            if (!eligible) {
              logger.debug { "File is not valid: ${file.name}" }
              continue
            }
            processFile(file, psiManager, filesChannel, classesChannel, symbolsChannel)
            processedFiles.incrementAndGet()
          }
        }
      }
    }
  }

  /**
   * Collects valid local regular files from the project content under a background progress
   * and a tracer span. When a files limit is configured, iteration stops after twice that
   * many files have been collected.
   */
  private suspend fun scanFiles(project: Project): List<VirtualFile> {
    val scanLimit = filesLimit?.let { it * 2 } // do not scan all files if there is a limit
    return SEMANTIC_SEARCH_TRACER.spanBuilder(SCANNING_SPAN_NAME).useWithScope {
      withBackgroundProgress(project, EmbeddingsBundle.getMessage("ml.embeddings.indices.scanning.label")) {
        val collected = ArrayList<VirtualFile>()
        ProjectFileIndex.getInstance(project).iterateContent { candidate ->
          if (candidate.isFile && candidate.isValid && candidate.isInLocalFileSystem) {
            collected += candidate
          }
          // Returning false stops the iteration.
          scanLimit == null || collected.size < scanLimit
        }
        collected
      }
    }
  }

  /**
   * Sends the entities of a single [file] to the non-null channels: the file itself, and, if a
   * PSI file can be found for it, its classes and symbols. Each send runs in its own child
   * coroutine; the function returns when all of them are done.
   */
  private suspend fun processFile(
    file: VirtualFile,
    psiManager: PsiManager,
    filesChannel: Channel<IndexableEntity>?,
    classesChannel: Channel<IndexableEntity>?,
    symbolsChannel: Channel<IndexableEntity>?,
  ) = coroutineScope {
    if (filesChannel != null) {
      launch { filesChannel.send(IndexableFile(file)) }
    }

    // PSI is needed only for classes and symbols.
    if (classesChannel == null && symbolsChannel == null) return@coroutineScope
    val psiFile = readActionUndispatched { psiManager.findFile(file) } ?: return@coroutineScope

    if (classesChannel != null) {
      launch {
        val classes = readActionUndispatched { ClassesProvider.extractClasses(psiFile) }
        for (entity in classes) classesChannel.send(entity)
      }
    }
    if (symbolsChannel != null) {
      launch {
        val symbols = readActionUndispatched { SymbolsProvider.extractSymbols(psiFile) }
        for (entity in symbols) symbolsChannel.send(entity)
      }
    }
  }

  // Nothing to release here; this instance is passed as the parent disposable when the file listener is registered.
  override fun dispose() {}
}