mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-03-22 15:19:59 +07:00
(IJPL-149042) Comply with explicit return type inspection
GitOrigin-RevId: 27617fcc43971968d4b50a242913806eb1bcf224
This commit is contained in:
committed by
intellij-monorepo-bot
parent
1b76e04151
commit
331b93804a
@@ -22,7 +22,6 @@ import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsMana
|
||||
import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsManager.Companion.getArchitectureId
|
||||
import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsManager.Companion.getOsId
|
||||
import com.intellij.platform.ml.embeddings.indexer.FileBasedEmbeddingIndexer.Companion.INDEXING_VERSION
|
||||
import com.intellij.platform.ml.embeddings.jvm.models.CustomRootDataLoader
|
||||
import com.intellij.util.concurrency.annotations.RequiresBackgroundThread
|
||||
import com.intellij.util.download.DownloadableFileService
|
||||
import com.intellij.util.io.ZipUtil
|
||||
@@ -55,8 +54,6 @@ class LocalArtifactsManager {
|
||||
// TODO: the list may be changed depending on the project size
|
||||
private val availableModels = listOf(ModelArtifact.SmallModelArtifact)
|
||||
|
||||
fun getCustomRootDataLoader() = CustomRootDataLoader(modelsRoot)
|
||||
|
||||
@RequiresBackgroundThread
|
||||
suspend fun downloadArtifactsIfNecessary(
|
||||
project: Project? = null,
|
||||
@@ -90,9 +87,7 @@ class LocalArtifactsManager {
|
||||
|
||||
// TODO: provide model id in arguments
|
||||
fun getModelArtifact(): ModelArtifact = availableModels.first()
|
||||
fun getServerArtifact() = NativeServerArtifact
|
||||
|
||||
fun checkArtifactsPresent(): Boolean = availableModels.all { it.checkPresent() }
|
||||
fun getServerArtifact(): NativeServerArtifact = NativeServerArtifact
|
||||
|
||||
private fun downloadArtifacts(artifacts: List<DownloadableArtifact>) {
|
||||
try {
|
||||
@@ -136,9 +131,9 @@ class LocalArtifactsManager {
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val SEMANTIC_SEARCH_RESOURCES_DIR_NAME = "semantic-search" // TODO: move to common constants
|
||||
const val SEMANTIC_SEARCH_RESOURCES_DIR_NAME: String = "semantic-search" // TODO: move to common constants
|
||||
|
||||
const val INDICES_DIR_NAME = "indices"
|
||||
const val INDICES_DIR_NAME: String = "indices"
|
||||
private const val MODELS_DIR_NAME = "models"
|
||||
private const val SERVER_DIR_NAME = "server"
|
||||
|
||||
@@ -204,7 +199,7 @@ sealed class ModelArtifact(
|
||||
) : DownloadableArtifact {
|
||||
data object SmallModelArtifact : ModelArtifact("small", "dan_100k_optimized.onnx", "bert-base-uncased.txt")
|
||||
|
||||
final override val archiveName = "$name.zip"
|
||||
final override val archiveName: String = "$name.zip"
|
||||
override val downloadLink: String = listOf(CDN_LINK_BASE, MODEL_VERSION, archiveName).joinToString(separator = "/")
|
||||
override val destination: Path = LocalArtifactsManager.modelsRoot / name
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ enum class EmbeddingDistanceMetric(val label: String) {
|
||||
COSINE("cos"),
|
||||
SQUARED_EUCLIDEAN("l2sq");
|
||||
|
||||
override fun toString() = label
|
||||
override fun toString(): String = label
|
||||
}
|
||||
|
||||
enum class EmbeddingQuantization(val label: String) {
|
||||
@@ -18,7 +18,7 @@ enum class EmbeddingQuantization(val label: String) {
|
||||
INT8("i8"),
|
||||
BINARY("b1x8");
|
||||
|
||||
override fun toString() = label
|
||||
override fun toString(): String = label
|
||||
}
|
||||
|
||||
data class NativeServerStartupArguments(
|
||||
|
||||
@@ -62,6 +62,6 @@ class ServerDiagnosticsListener : ProcessListener {
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val MAX_HISTORY_SIZE = 1_000
|
||||
const val MAX_HISTORY_SIZE: Int = 1_000
|
||||
}
|
||||
}
|
||||
@@ -1,5 +0,0 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
@ApiStatus.Internal
|
||||
package com.intellij.platform.ml.embeddings.files;
|
||||
|
||||
import org.jetbrains.annotations.ApiStatus;
|
||||
@@ -8,5 +8,5 @@ import com.intellij.platform.ml.embeddings.utils.splitIdentifierIntoTokens
|
||||
|
||||
class IndexableFile(override val id: EntityId) : IndexableEntity {
|
||||
constructor(file: VirtualFile) : this(EntityId(file.name))
|
||||
override val indexableRepresentation by lazy { splitIdentifierIntoTokens(FileUtilRt.getNameWithoutExtension(id.id)) }
|
||||
override val indexableRepresentation: String by lazy { splitIdentifierIntoTokens(FileUtilRt.getNameWithoutExtension(id.id)) }
|
||||
}
|
||||
@@ -53,8 +53,8 @@ class IntegerStorageKeyProvider : EmbeddingStorageKeyProvider<Long>, Disposable
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val ENUMERATOR_FOLDER = "enumerator"
|
||||
private const val ENUMERATOR_FILE = "ids.enum"
|
||||
private const val ENUMERATOR_FOLDER: String = "enumerator"
|
||||
private const val ENUMERATOR_FILE: String = "ids.enum"
|
||||
|
||||
fun getInstance(): IntegerStorageKeyProvider = service()
|
||||
}
|
||||
|
||||
@@ -60,7 +60,7 @@ class EmbeddingsStorageManagerWrapper<KeyT>(
|
||||
return storageManager.getStorageStats(project, indexId)
|
||||
}
|
||||
|
||||
fun getBatchSize() = storageManager.getBatchSize()
|
||||
fun getBatchSize(): Int = storageManager.getBatchSize()
|
||||
|
||||
companion object {
|
||||
private const val INDEXABLE_REPRESENTATION_CHAR_LIMIT = 64
|
||||
|
||||
@@ -49,29 +49,31 @@ class KInferenceLocalArtifactsManager {
|
||||
root.toPath().listDirectoryEntries().filter { it.name != MODEL_ARTIFACTS_DIR }.forEach { it.delete(recursively = true) }
|
||||
}
|
||||
|
||||
fun getCustomRootDataLoader() = CustomRootDataLoader(modelArtifactsRoot.toPath())
|
||||
fun getCustomRootDataLoader(): CustomRootDataLoader = CustomRootDataLoader(modelArtifactsRoot.toPath())
|
||||
|
||||
@RequiresBackgroundThread
|
||||
suspend fun downloadArtifactsIfNecessary(project: Project? = null,
|
||||
retryIfCanceled: Boolean = true) = withContext(downloadContext) {
|
||||
if (!checkArtifactsPresent() && !ApplicationManager.getApplication().isUnitTestMode && (retryIfCanceled || !downloadCanceled)) {
|
||||
logger.debug("Semantic search artifacts are not present, starting the download...")
|
||||
if (project != null) {
|
||||
withBackgroundProgress(project, ARTIFACTS_DOWNLOAD_TASK_NAME) {
|
||||
try {
|
||||
coroutineToIndicator { // platform code relies on the existence of indicator
|
||||
downloadArtifacts()
|
||||
retryIfCanceled: Boolean = true) {
|
||||
withContext(downloadContext) {
|
||||
if (!checkArtifactsPresent() && !ApplicationManager.getApplication().isUnitTestMode && (retryIfCanceled || !downloadCanceled)) {
|
||||
logger.debug("Semantic search artifacts are not present, starting the download...")
|
||||
if (project != null) {
|
||||
withBackgroundProgress(project, ARTIFACTS_DOWNLOAD_TASK_NAME) {
|
||||
try {
|
||||
coroutineToIndicator { // platform code relies on the existence of indicator
|
||||
downloadArtifacts()
|
||||
}
|
||||
}
|
||||
catch (e: CancellationException) {
|
||||
logger.debug("Artifacts downloading was canceled")
|
||||
downloadCanceled = true
|
||||
throw e
|
||||
}
|
||||
}
|
||||
catch (e: CancellationException) {
|
||||
logger.debug("Artifacts downloading was canceled")
|
||||
downloadCanceled = true
|
||||
throw e
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
downloadArtifacts()
|
||||
else {
|
||||
downloadArtifacts()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -107,7 +109,7 @@ class KInferenceLocalArtifactsManager {
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val SEMANTIC_SEARCH_RESOURCES_DIR = "semantic-search"
|
||||
const val SEMANTIC_SEARCH_RESOURCES_DIR: String = "semantic-search"
|
||||
|
||||
private val ARTIFACTS_DOWNLOAD_TASK_NAME
|
||||
get() = EmbeddingsBundle.getMessage("ml.embeddings.artifacts.download.name")
|
||||
|
||||
@@ -17,7 +17,7 @@ interface EmbeddingSearchIndex {
|
||||
suspend fun contains(id: EntityId): Boolean
|
||||
suspend fun lookup(id: EntityId): FloatTextEmbedding?
|
||||
suspend fun clear()
|
||||
suspend fun clearBySourceType(sourceType: EntitySourceType) = Unit
|
||||
suspend fun clearBySourceType(sourceType: EntitySourceType) {}
|
||||
suspend fun remove(id: EntityId)
|
||||
|
||||
suspend fun onIndexingStart()
|
||||
|
||||
@@ -8,5 +8,5 @@ enum class EntitySourceType {
|
||||
EXTERNAL;
|
||||
|
||||
@JsonValue
|
||||
fun value() = ordinal
|
||||
fun value(): Int = ordinal
|
||||
}
|
||||
@@ -22,16 +22,18 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
|
||||
|
||||
private val fileManager = LocalEmbeddingIndexFileManager(root)
|
||||
|
||||
override suspend fun getSize() = lock.read { idToEmbedding.size }
|
||||
override suspend fun getSize(): Int = lock.read { idToEmbedding.size }
|
||||
|
||||
override suspend fun setLimit(value: Int?) = lock.write {
|
||||
// Shrink index if necessary:
|
||||
if (value != null && value < idToEmbedding.size) {
|
||||
val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList()
|
||||
idToEmbedding.clear()
|
||||
idToEmbedding.putAll(remaining)
|
||||
override suspend fun setLimit(value: Int?) {
|
||||
lock.write {
|
||||
// Shrink index if necessary:
|
||||
if (value != null && value < idToEmbedding.size) {
|
||||
val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList()
|
||||
idToEmbedding.clear()
|
||||
idToEmbedding.putAll(remaining)
|
||||
}
|
||||
limit = value
|
||||
}
|
||||
limit = value
|
||||
}
|
||||
|
||||
override suspend fun contains(id: EntityId): Boolean = lock.read {
|
||||
@@ -40,9 +42,11 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
|
||||
|
||||
override suspend fun lookup(id: EntityId): FloatTextEmbedding? = lock.read { idToEmbedding[id] }
|
||||
|
||||
override suspend fun clear() = lock.write {
|
||||
idToEmbedding.clear()
|
||||
uncheckedIds.clear()
|
||||
override suspend fun clear() {
|
||||
lock.write {
|
||||
idToEmbedding.clear()
|
||||
uncheckedIds.clear()
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun remove(id: EntityId) {
|
||||
@@ -58,34 +62,44 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun onIndexingFinish() = lock.write {
|
||||
uncheckedIds.forEach { idToEmbedding.remove(it) }
|
||||
uncheckedIds.clear()
|
||||
override suspend fun onIndexingFinish() {
|
||||
lock.write {
|
||||
uncheckedIds.forEach { idToEmbedding.remove(it) }
|
||||
uncheckedIds.clear()
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun addEntries(
|
||||
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
|
||||
shouldCount: Boolean,
|
||||
) = lock.write {
|
||||
if (limit != null) {
|
||||
val list = values.toList()
|
||||
list.forEach { uncheckedIds.remove(it.first) }
|
||||
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
|
||||
}
|
||||
else {
|
||||
idToEmbedding.putAll(values)
|
||||
) {
|
||||
lock.write {
|
||||
if (limit != null) {
|
||||
val list = values.toList()
|
||||
list.forEach { uncheckedIds.remove(it.first) }
|
||||
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
|
||||
}
|
||||
else {
|
||||
idToEmbedding.putAll(values)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun saveToDisk() = lock.read { save() }
|
||||
|
||||
override suspend fun loadFromDisk() = lock.write {
|
||||
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
|
||||
idToEmbedding.clear()
|
||||
idToEmbedding.putAll(ids zip embeddings)
|
||||
override suspend fun saveToDisk() {
|
||||
lock.read { save() }
|
||||
}
|
||||
|
||||
override suspend fun offload() = lock.write { idToEmbedding.clear() }
|
||||
override suspend fun loadFromDisk() {
|
||||
lock.write {
|
||||
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
|
||||
idToEmbedding.clear()
|
||||
idToEmbedding.putAll(ids zip embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun offload() {
|
||||
lock.write { idToEmbedding.clear() }
|
||||
}
|
||||
|
||||
override suspend fun findClosest(
|
||||
searchEmbedding: FloatTextEmbedding,
|
||||
@@ -108,7 +122,7 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun estimateMemoryUsage() = fileManager.embeddingSizeInBytes.toLong() * getSize()
|
||||
override suspend fun estimateMemoryUsage(): Long = fileManager.embeddingSizeInBytes.toLong() * getSize()
|
||||
|
||||
override fun estimateLimitByMemory(memory: Long): Int {
|
||||
return (memory / fileManager.embeddingSizeInBytes).toInt()
|
||||
|
||||
@@ -10,7 +10,7 @@ import com.intellij.platform.ml.embeddings.indexer.IndexId
|
||||
*/
|
||||
interface IndexPersistedEventsCounter {
|
||||
companion object {
|
||||
val EP_NAME = ProjectExtensionPointName<IndexPersistedEventsCounter>("com.intellij.platform.ml.embeddings.indexPersistedEventsCounter")
|
||||
val EP_NAME: ProjectExtensionPointName<IndexPersistedEventsCounter> = ProjectExtensionPointName<IndexPersistedEventsCounter>("com.intellij.platform.ml.embeddings.indexPersistedEventsCounter")
|
||||
}
|
||||
|
||||
suspend fun sendPersistedCount(indexId: IndexId, project: Project)
|
||||
|
||||
@@ -9,7 +9,6 @@ import com.fasterxml.jackson.databind.SerializationFeature
|
||||
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
|
||||
import com.fasterxml.jackson.module.kotlin.readValue
|
||||
import com.intellij.platform.ml.embeddings.jvm.utils.SuspendingReadWriteLock
|
||||
import com.intellij.util.io.outputStream
|
||||
import kotlinx.coroutines.CoroutineScope
|
||||
import kotlinx.coroutines.coroutineScope
|
||||
import kotlinx.coroutines.ensureActive
|
||||
@@ -20,6 +19,7 @@ import java.nio.file.Files
|
||||
import java.nio.file.Path
|
||||
import kotlin.io.path.exists
|
||||
import kotlin.io.path.inputStream
|
||||
import kotlin.io.path.outputStream
|
||||
|
||||
class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = DEFAULT_DIMENSIONS) {
|
||||
private val lock = SuspendingReadWriteLock()
|
||||
@@ -36,7 +36,7 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
|
||||
private val embeddingsPath
|
||||
get() = rootPath.resolve(EMBEDDINGS_FILENAME)
|
||||
|
||||
val embeddingSizeInBytes = dimensions * EMBEDDING_ELEMENT_SIZE
|
||||
val embeddingSizeInBytes: Int = dimensions * EMBEDDING_ELEMENT_SIZE
|
||||
|
||||
/** Provides reading access to the embedding vector at the specified index
|
||||
* without reading the whole file into memory
|
||||
@@ -55,12 +55,14 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
|
||||
/** Provides writing access to embedding vector at the specified index
|
||||
* without writing the other vectors
|
||||
*/
|
||||
suspend fun set(index: Int, embedding: FloatTextEmbedding) = lock.write {
|
||||
RandomAccessFile(embeddingsPath.toFile(), "rw").use { output ->
|
||||
output.seek(getIndexOffset(index))
|
||||
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
|
||||
embedding.values.forEach {
|
||||
output.write(buffer.putFloat(0, it).array())
|
||||
suspend fun set(index: Int, embedding: FloatTextEmbedding) {
|
||||
lock.write {
|
||||
RandomAccessFile(embeddingsPath.toFile(), "rw").use { output ->
|
||||
output.seek(getIndexOffset(index))
|
||||
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
|
||||
embedding.values.forEach {
|
||||
output.write(buffer.putFloat(0, it).array())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -69,23 +71,25 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
|
||||
* Removes the embedding vector at the specified index.
|
||||
* To do so, replaces this vector with the last vector in the file and shrinks the file size.
|
||||
*/
|
||||
suspend fun removeAtIndex(index: Int) = lock.write {
|
||||
RandomAccessFile(embeddingsPath.toFile(), "rw").use { file ->
|
||||
if (file.length() < embeddingSizeInBytes) return@write
|
||||
if (file.length() - embeddingSizeInBytes != getIndexOffset(index)) {
|
||||
file.seek(file.length() - embeddingSizeInBytes)
|
||||
val array = ByteArray(EMBEDDING_ELEMENT_SIZE)
|
||||
val embedding = FloatTextEmbedding(FloatArray(dimensions) {
|
||||
file.read(array)
|
||||
ByteBuffer.wrap(array).getFloat()
|
||||
})
|
||||
file.seek(getIndexOffset(index))
|
||||
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
|
||||
embedding.values.forEach {
|
||||
file.write(buffer.putFloat(0, it).array())
|
||||
suspend fun removeAtIndex(index: Int) {
|
||||
lock.write {
|
||||
RandomAccessFile(embeddingsPath.toFile(), "rw").use { file ->
|
||||
if (file.length() < embeddingSizeInBytes) return@write
|
||||
if (file.length() - embeddingSizeInBytes != getIndexOffset(index)) {
|
||||
file.seek(file.length() - embeddingSizeInBytes)
|
||||
val array = ByteArray(EMBEDDING_ELEMENT_SIZE)
|
||||
val embedding = FloatTextEmbedding(FloatArray(dimensions) {
|
||||
file.read(array)
|
||||
ByteBuffer.wrap(array).getFloat()
|
||||
})
|
||||
file.seek(getIndexOffset(index))
|
||||
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
|
||||
embedding.values.forEach {
|
||||
file.write(buffer.putFloat(0, it).array())
|
||||
}
|
||||
}
|
||||
file.setLength(file.length() - embeddingSizeInBytes)
|
||||
}
|
||||
file.setLength(file.length() - embeddingSizeInBytes)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,16 +115,18 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
|
||||
}
|
||||
LoadedIndex(ids, embeddings)
|
||||
}
|
||||
catch (e: JsonProcessingException) {
|
||||
catch (_: JsonProcessingException) {
|
||||
return@read null
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
suspend fun saveIds(ids: List<EntityId>) = lock.write {
|
||||
withNotEnoughSpaceCheck {
|
||||
idsPath.outputStream().buffered().use { output ->
|
||||
mapper.writer(prettyPrinter).writeValue(output, ids)
|
||||
suspend fun saveIds(ids: List<EntityId>) {
|
||||
lock.write {
|
||||
withNotEnoughSpaceCheck {
|
||||
idsPath.outputStream().buffered().use { output ->
|
||||
mapper.writer(prettyPrinter).writeValue(output, ids)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -169,8 +175,8 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val DEFAULT_DIMENSIONS = 128
|
||||
const val EMBEDDING_ELEMENT_SIZE = 4
|
||||
const val DEFAULT_DIMENSIONS: Int = 128
|
||||
const val EMBEDDING_ELEMENT_SIZE: Int = 4
|
||||
|
||||
private const val IDS_FILENAME = "ids.json"
|
||||
private const val SOURCE_TYPES_FILENAME = "sourceTypes.json"
|
||||
|
||||
@@ -31,15 +31,17 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
|
||||
|
||||
private val fileManager = LocalEmbeddingIndexFileManager(root)
|
||||
|
||||
override suspend fun setLimit(value: Int?) = lock.write {
|
||||
if (value != null) {
|
||||
// Shrink index if necessary:
|
||||
while (idToEntry.size > value) {
|
||||
delete(indexToId[idToEntry.size - 1]!!, all = true, shouldSaveIds = false)
|
||||
override suspend fun setLimit(value: Int?) {
|
||||
lock.write {
|
||||
if (value != null) {
|
||||
// Shrink index if necessary:
|
||||
while (idToEntry.size > value) {
|
||||
delete(indexToId[idToEntry.size - 1]!!, all = true, shouldSaveIds = false)
|
||||
}
|
||||
saveIds()
|
||||
}
|
||||
saveIds()
|
||||
limit = value
|
||||
}
|
||||
limit = value
|
||||
}
|
||||
|
||||
private data class IndexEntry(
|
||||
@@ -48,19 +50,19 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
|
||||
val embedding: FloatTextEmbedding,
|
||||
)
|
||||
|
||||
override suspend fun getSize() = lock.read { idToEntry.size }
|
||||
override suspend fun getSize(): Int = lock.read { idToEntry.size }
|
||||
|
||||
override suspend fun contains(id: EntityId): Boolean = lock.read {
|
||||
id in idToEntry
|
||||
}
|
||||
override suspend fun contains(id: EntityId): Boolean = lock.read { id in idToEntry }
|
||||
|
||||
override suspend fun lookup(id: EntityId): FloatTextEmbedding? = lock.read { idToEntry[id]?.embedding }
|
||||
|
||||
override suspend fun clear() = lock.write {
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
uncheckedIds.clear()
|
||||
changed = false
|
||||
override suspend fun clear() {
|
||||
lock.write {
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
uncheckedIds.clear()
|
||||
changed = false
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun onIndexingStart() {
|
||||
@@ -71,57 +73,67 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun onIndexingFinish() = lock.write {
|
||||
if (uncheckedIds.size > 0) changed = true
|
||||
logger.debug { "Deleted ${uncheckedIds.size} unchecked ids" }
|
||||
uncheckedIds.forEach {
|
||||
delete(it, all = true, shouldSaveIds = false)
|
||||
override suspend fun onIndexingFinish() {
|
||||
lock.write {
|
||||
if (uncheckedIds.isNotEmpty()) changed = true
|
||||
logger.debug { "Deleted ${uncheckedIds.size} unchecked ids" }
|
||||
uncheckedIds.forEach {
|
||||
delete(it, all = true, shouldSaveIds = false)
|
||||
}
|
||||
uncheckedIds.clear()
|
||||
}
|
||||
uncheckedIds.clear()
|
||||
}
|
||||
|
||||
override suspend fun addEntries(
|
||||
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
|
||||
shouldCount: Boolean,
|
||||
) = lock.write {
|
||||
for ((id, embedding) in values) {
|
||||
checkCancelled()
|
||||
uncheckedIds.remove(id)
|
||||
if (limit != null && idToEntry.size >= limit!!) break
|
||||
val entry = idToEntry.getOrPut(id) {
|
||||
changed = true
|
||||
val index = idToEntry.size
|
||||
) {
|
||||
lock.write {
|
||||
for ((id, embedding) in values) {
|
||||
checkCancelled()
|
||||
uncheckedIds.remove(id)
|
||||
if (limit != null && idToEntry.size >= limit!!) break
|
||||
val entry = idToEntry.getOrPut(id) {
|
||||
changed = true
|
||||
val index = idToEntry.size
|
||||
indexToId[index] = id
|
||||
IndexEntry(index = index, count = 0, embedding = embedding)
|
||||
}
|
||||
if (shouldCount || entry.count == 0) {
|
||||
entry.count += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun saveToDisk() {
|
||||
lock.read {
|
||||
val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList()
|
||||
val embeddings = ids.map { idToEntry[it]!!.embedding }
|
||||
fileManager.saveIndex(ids, embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun loadFromDisk() {
|
||||
lock.write {
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
|
||||
for ((index, id) in ids.withIndex()) {
|
||||
val embedding = embeddings[index]
|
||||
indexToId[index] = id
|
||||
IndexEntry(index = index, count = 0, embedding = embedding)
|
||||
}
|
||||
if (shouldCount || entry.count == 0) {
|
||||
entry.count += 1
|
||||
idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun saveToDisk() = lock.read {
|
||||
val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList()
|
||||
val embeddings = ids.map { idToEntry[it]!!.embedding }
|
||||
fileManager.saveIndex(ids, embeddings)
|
||||
}
|
||||
|
||||
override suspend fun loadFromDisk() = lock.write {
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
|
||||
for ((index, id) in ids.withIndex()) {
|
||||
val embedding = embeddings[index]
|
||||
indexToId[index] = id
|
||||
idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding)
|
||||
override suspend fun offload() {
|
||||
lock.write {
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun offload() = lock.write {
|
||||
indexToId.clear()
|
||||
idToEntry.clear()
|
||||
}
|
||||
|
||||
override suspend fun findClosest(
|
||||
searchEmbedding: FloatTextEmbedding,
|
||||
topK: Int,
|
||||
@@ -154,7 +166,7 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun estimateMemoryUsage() = fileManager.embeddingSizeInBytes.toLong() * getSize()
|
||||
override suspend fun estimateMemoryUsage(): Long = fileManager.embeddingSizeInBytes.toLong() * getSize()
|
||||
|
||||
override fun estimateLimitByMemory(memory: Long): Int {
|
||||
return (memory / fileManager.embeddingSizeInBytes).toInt()
|
||||
@@ -164,32 +176,32 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
|
||||
limit == null || idToEntry.size < limit!!
|
||||
}
|
||||
|
||||
suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) = lock.write {
|
||||
delete(id = id, shouldSaveIds = syncToDisk)
|
||||
}
|
||||
|
||||
suspend fun addEntry(id: EntityId, embedding: FloatTextEmbedding) = lock.write {
|
||||
uncheckedIds.remove(id)
|
||||
add(id = id, embedding = embedding)
|
||||
suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) {
|
||||
lock.write {
|
||||
delete(id = id, shouldSaveIds = syncToDisk)
|
||||
}
|
||||
}
|
||||
|
||||
/* Optimization for consequent delete and add operations */
|
||||
suspend fun updateEntry(id: EntityId, newId: EntityId, embedding: FloatTextEmbedding) = lock.write {
|
||||
if (id !in idToEntry) return@write
|
||||
if (idToEntry[id]!!.count == 1 && newId !in idToEntry) {
|
||||
val index = idToEntry[id]!!.index
|
||||
fileManager.set(index, embedding)
|
||||
@Suppress("unused")
|
||||
suspend fun updateEntry(id: EntityId, newId: EntityId, embedding: FloatTextEmbedding) {
|
||||
lock.write {
|
||||
if (id !in idToEntry) return@write
|
||||
if (idToEntry[id]!!.count == 1 && newId !in idToEntry) {
|
||||
val index = idToEntry[id]!!.index
|
||||
fileManager.set(index, embedding)
|
||||
|
||||
idToEntry.remove(id)
|
||||
idToEntry[newId] = IndexEntry(index = index, count = 1, embedding = embedding)
|
||||
indexToId[index] = newId
|
||||
idToEntry.remove(id)
|
||||
idToEntry[newId] = IndexEntry(index = index, count = 1, embedding = embedding)
|
||||
indexToId[index] = newId
|
||||
|
||||
saveIds()
|
||||
}
|
||||
else {
|
||||
// Do not apply optimization
|
||||
delete(id)
|
||||
add(id = newId, embedding = embedding)
|
||||
saveIds()
|
||||
}
|
||||
else {
|
||||
// Do not apply optimization
|
||||
delete(id)
|
||||
add(id = newId, embedding = embedding)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class LocalEmbeddingNetwork(
|
||||
return LocalEmbeddingNetwork(KIEngine.loadModel(data), maxLen ?: DEFAULT_MAX_LEN)
|
||||
}
|
||||
|
||||
const val DEFAULT_MAX_LEN = 512
|
||||
const val DEFAULT_MAX_LEN: Int = 512
|
||||
}
|
||||
|
||||
@Suppress("Unused") // useful for older versions of kinference where operators for mean pooling are not supported
|
||||
|
||||
@@ -16,7 +16,7 @@ class LocalEmbeddingServiceLoader {
|
||||
val network = loadNetwork(loader)
|
||||
val encoder = loadTextEncoder(loader)
|
||||
return LocalEmbeddingService(network, encoder)
|
||||
} catch (e: EOFException) {
|
||||
} catch (_: EOFException) {
|
||||
return null
|
||||
}
|
||||
}
|
||||
@@ -33,8 +33,8 @@ class LocalEmbeddingServiceLoader {
|
||||
}
|
||||
|
||||
companion object {
|
||||
const val MODEL_NAME = "dan-bert-tiny"
|
||||
const val MODEL_FILENAME = "dan_optimized_fp16.onnx"
|
||||
const val MODEL_NAME: String = "dan-bert-tiny"
|
||||
const val MODEL_FILENAME: String = "dan_optimized_fp16.onnx"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -167,6 +167,6 @@ abstract class AbstractEmbeddingsStorageWrapper(
|
||||
private val OFFLOAD_TIMEOUT = 10.seconds
|
||||
private val logger = Logger.getInstance(AbstractEmbeddingsStorageWrapper::class.java)
|
||||
|
||||
const val OLD_API_DIR_NAME = "old-api"
|
||||
const val OLD_API_DIR_NAME: String = "old-api"
|
||||
}
|
||||
}
|
||||
|
||||
@@ -30,7 +30,9 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper {
|
||||
/ SEMANTIC_SEARCH_RESOURCES_DIR_NAME / OLD_API_DIR_NAME / INDICES_DIR_NAME / INDEX_DIR
|
||||
)
|
||||
|
||||
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>) = index.addEntries(values)
|
||||
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>) {
|
||||
index.addEntries(values)
|
||||
}
|
||||
|
||||
override suspend fun removeEntries(keys: List<EntityId>) {
|
||||
for (key in keys) {
|
||||
@@ -38,7 +40,9 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper {
|
||||
}
|
||||
}
|
||||
|
||||
override suspend fun clear() = index.clear()
|
||||
override suspend fun clear() {
|
||||
index.clear()
|
||||
}
|
||||
|
||||
@RequiresBackgroundThread
|
||||
override suspend fun searchNeighbours(queryEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List<ScoredKey<EntityId>> {
|
||||
@@ -60,9 +64,13 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper {
|
||||
return index.streamFindClose(embedding, similarityThreshold)
|
||||
}
|
||||
|
||||
override suspend fun startIndexingSession() = index.onIndexingStart()
|
||||
override suspend fun startIndexingSession() {
|
||||
index.onIndexingStart()
|
||||
}
|
||||
|
||||
override suspend fun finishIndexingSession() = index.onIndexingFinish()
|
||||
override suspend fun finishIndexingSession() {
|
||||
index.onIndexingFinish()
|
||||
}
|
||||
|
||||
override suspend fun getSize(): Int = index.getSize()
|
||||
|
||||
|
||||
@@ -24,14 +24,18 @@ class EmbeddingIndexSettingsImpl : EmbeddingIndexSettings {
|
||||
private val mutex = ReentrantReadWriteLock()
|
||||
private val clientSettings = mutableListOf<EmbeddingIndexSettings>()
|
||||
|
||||
fun registerClientSettings(settings: EmbeddingIndexSettings) = mutex.write {
|
||||
if (settings in clientSettings) return@write
|
||||
clientSettings.add(settings)
|
||||
fun registerClientSettings(settings: EmbeddingIndexSettings) {
|
||||
mutex.write {
|
||||
if (settings in clientSettings) return@write
|
||||
clientSettings.add(settings)
|
||||
}
|
||||
}
|
||||
|
||||
@Suppress("unused")
|
||||
fun unregisterClientSettings(settings: EmbeddingIndexSettings) = mutex.write {
|
||||
clientSettings.remove(settings)
|
||||
fun unregisterClientSettings(settings: EmbeddingIndexSettings) {
|
||||
mutex.write {
|
||||
clientSettings.remove(settings)
|
||||
}
|
||||
}
|
||||
|
||||
companion object {
|
||||
|
||||
@@ -12,7 +12,7 @@ class SemanticSearchCoroutineScope(private val cs: CoroutineScope) : Disposable
|
||||
override fun dispose() {}
|
||||
|
||||
companion object {
|
||||
fun getScope(project: Project) = project.service<SemanticSearchCoroutineScope>().cs
|
||||
fun getScope(project: Project): CoroutineScope = project.service<SemanticSearchCoroutineScope>().cs
|
||||
|
||||
fun getInstance(project: Project): SemanticSearchCoroutineScope = project.service()
|
||||
}
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
|
||||
package com.intellij.platform.ml.embeddings.utils
|
||||
|
||||
import com.intellij.platform.diagnostic.telemetry.IJTracer
|
||||
import com.intellij.platform.diagnostic.telemetry.Scope
|
||||
import com.intellij.platform.diagnostic.telemetry.TelemetryManager
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
@ApiStatus.Internal
|
||||
val SEMANTIC_SEARCH_TRACER = TelemetryManager.getInstance().getTracer(Scope("semanticSearch"))
|
||||
val SEMANTIC_SEARCH_TRACER: IJTracer = TelemetryManager.getInstance().getTracer(Scope("semanticSearch"))
|
||||
@@ -44,7 +44,7 @@ class InvalidTokenNotificationManager {
|
||||
}
|
||||
|
||||
companion object {
|
||||
private const val NOTIFICATION_GROUP_ID = "Semantic search"
|
||||
private const val NOTIFICATION_GROUP_ID: String = "Semantic search"
|
||||
|
||||
fun getInstance(): InvalidTokenNotificationManager = service()
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user