(IJPL-149042) Comply with explicit return type inspection

GitOrigin-RevId: 27617fcc43971968d4b50a242913806eb1bcf224
This commit is contained in:
Evgeny Abramov
2024-10-07 23:28:41 +03:00
committed by intellij-monorepo-bot
parent 1b76e04151
commit 331b93804a
22 changed files with 231 additions and 194 deletions

View File

@@ -22,7 +22,6 @@ import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsMana
import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsManager.Companion.getArchitectureId
import com.intellij.platform.ml.embeddings.external.artifacts.LocalArtifactsManager.Companion.getOsId
import com.intellij.platform.ml.embeddings.indexer.FileBasedEmbeddingIndexer.Companion.INDEXING_VERSION
import com.intellij.platform.ml.embeddings.jvm.models.CustomRootDataLoader
import com.intellij.util.concurrency.annotations.RequiresBackgroundThread
import com.intellij.util.download.DownloadableFileService
import com.intellij.util.io.ZipUtil
@@ -55,8 +54,6 @@ class LocalArtifactsManager {
// TODO: the list may be changed depending on the project size
private val availableModels = listOf(ModelArtifact.SmallModelArtifact)
fun getCustomRootDataLoader() = CustomRootDataLoader(modelsRoot)
@RequiresBackgroundThread
suspend fun downloadArtifactsIfNecessary(
project: Project? = null,
@@ -90,9 +87,7 @@ class LocalArtifactsManager {
// TODO: provide model id in arguments
fun getModelArtifact(): ModelArtifact = availableModels.first()
fun getServerArtifact() = NativeServerArtifact
fun checkArtifactsPresent(): Boolean = availableModels.all { it.checkPresent() }
fun getServerArtifact(): NativeServerArtifact = NativeServerArtifact
private fun downloadArtifacts(artifacts: List<DownloadableArtifact>) {
try {
@@ -136,9 +131,9 @@ class LocalArtifactsManager {
}
companion object {
const val SEMANTIC_SEARCH_RESOURCES_DIR_NAME = "semantic-search" // TODO: move to common constants
const val SEMANTIC_SEARCH_RESOURCES_DIR_NAME: String = "semantic-search" // TODO: move to common constants
const val INDICES_DIR_NAME = "indices"
const val INDICES_DIR_NAME: String = "indices"
private const val MODELS_DIR_NAME = "models"
private const val SERVER_DIR_NAME = "server"
@@ -204,7 +199,7 @@ sealed class ModelArtifact(
) : DownloadableArtifact {
data object SmallModelArtifact : ModelArtifact("small", "dan_100k_optimized.onnx", "bert-base-uncased.txt")
final override val archiveName = "$name.zip"
final override val archiveName: String = "$name.zip"
override val downloadLink: String = listOf(CDN_LINK_BASE, MODEL_VERSION, archiveName).joinToString(separator = "/")
override val destination: Path = LocalArtifactsManager.modelsRoot / name

View File

@@ -9,7 +9,7 @@ enum class EmbeddingDistanceMetric(val label: String) {
COSINE("cos"),
SQUARED_EUCLIDEAN("l2sq");
override fun toString() = label
override fun toString(): String = label
}
enum class EmbeddingQuantization(val label: String) {
@@ -18,7 +18,7 @@ enum class EmbeddingQuantization(val label: String) {
INT8("i8"),
BINARY("b1x8");
override fun toString() = label
override fun toString(): String = label
}
data class NativeServerStartupArguments(

View File

@@ -62,6 +62,6 @@ class ServerDiagnosticsListener : ProcessListener {
}
companion object {
const val MAX_HISTORY_SIZE = 1_000
const val MAX_HISTORY_SIZE: Int = 1_000
}
}

View File

@@ -1,5 +0,0 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
@ApiStatus.Internal
package com.intellij.platform.ml.embeddings.files;
import org.jetbrains.annotations.ApiStatus;

View File

@@ -8,5 +8,5 @@ import com.intellij.platform.ml.embeddings.utils.splitIdentifierIntoTokens
class IndexableFile(override val id: EntityId) : IndexableEntity {
constructor(file: VirtualFile) : this(EntityId(file.name))
override val indexableRepresentation by lazy { splitIdentifierIntoTokens(FileUtilRt.getNameWithoutExtension(id.id)) }
override val indexableRepresentation: String by lazy { splitIdentifierIntoTokens(FileUtilRt.getNameWithoutExtension(id.id)) }
}

View File

@@ -53,8 +53,8 @@ class IntegerStorageKeyProvider : EmbeddingStorageKeyProvider<Long>, Disposable
}
companion object {
private const val ENUMERATOR_FOLDER = "enumerator"
private const val ENUMERATOR_FILE = "ids.enum"
private const val ENUMERATOR_FOLDER: String = "enumerator"
private const val ENUMERATOR_FILE: String = "ids.enum"
fun getInstance(): IntegerStorageKeyProvider = service()
}

View File

@@ -60,7 +60,7 @@ class EmbeddingsStorageManagerWrapper<KeyT>(
return storageManager.getStorageStats(project, indexId)
}
fun getBatchSize() = storageManager.getBatchSize()
fun getBatchSize(): Int = storageManager.getBatchSize()
companion object {
private const val INDEXABLE_REPRESENTATION_CHAR_LIMIT = 64

View File

@@ -49,29 +49,31 @@ class KInferenceLocalArtifactsManager {
root.toPath().listDirectoryEntries().filter { it.name != MODEL_ARTIFACTS_DIR }.forEach { it.delete(recursively = true) }
}
fun getCustomRootDataLoader() = CustomRootDataLoader(modelArtifactsRoot.toPath())
fun getCustomRootDataLoader(): CustomRootDataLoader = CustomRootDataLoader(modelArtifactsRoot.toPath())
@RequiresBackgroundThread
suspend fun downloadArtifactsIfNecessary(project: Project? = null,
retryIfCanceled: Boolean = true) = withContext(downloadContext) {
if (!checkArtifactsPresent() && !ApplicationManager.getApplication().isUnitTestMode && (retryIfCanceled || !downloadCanceled)) {
logger.debug("Semantic search artifacts are not present, starting the download...")
if (project != null) {
withBackgroundProgress(project, ARTIFACTS_DOWNLOAD_TASK_NAME) {
try {
coroutineToIndicator { // platform code relies on the existence of indicator
downloadArtifacts()
retryIfCanceled: Boolean = true) {
withContext(downloadContext) {
if (!checkArtifactsPresent() && !ApplicationManager.getApplication().isUnitTestMode && (retryIfCanceled || !downloadCanceled)) {
logger.debug("Semantic search artifacts are not present, starting the download...")
if (project != null) {
withBackgroundProgress(project, ARTIFACTS_DOWNLOAD_TASK_NAME) {
try {
coroutineToIndicator { // platform code relies on the existence of indicator
downloadArtifacts()
}
}
catch (e: CancellationException) {
logger.debug("Artifacts downloading was canceled")
downloadCanceled = true
throw e
}
}
catch (e: CancellationException) {
logger.debug("Artifacts downloading was canceled")
downloadCanceled = true
throw e
}
}
}
else {
downloadArtifacts()
else {
downloadArtifacts()
}
}
}
}
@@ -107,7 +109,7 @@ class KInferenceLocalArtifactsManager {
}
companion object {
const val SEMANTIC_SEARCH_RESOURCES_DIR = "semantic-search"
const val SEMANTIC_SEARCH_RESOURCES_DIR: String = "semantic-search"
private val ARTIFACTS_DOWNLOAD_TASK_NAME
get() = EmbeddingsBundle.getMessage("ml.embeddings.artifacts.download.name")

View File

@@ -17,7 +17,7 @@ interface EmbeddingSearchIndex {
suspend fun contains(id: EntityId): Boolean
suspend fun lookup(id: EntityId): FloatTextEmbedding?
suspend fun clear()
suspend fun clearBySourceType(sourceType: EntitySourceType) = Unit
suspend fun clearBySourceType(sourceType: EntitySourceType) {}
suspend fun remove(id: EntityId)
suspend fun onIndexingStart()

View File

@@ -8,5 +8,5 @@ enum class EntitySourceType {
EXTERNAL;
@JsonValue
fun value() = ordinal
fun value(): Int = ordinal
}

View File

@@ -22,16 +22,18 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
private val fileManager = LocalEmbeddingIndexFileManager(root)
override suspend fun getSize() = lock.read { idToEmbedding.size }
override suspend fun getSize(): Int = lock.read { idToEmbedding.size }
override suspend fun setLimit(value: Int?) = lock.write {
// Shrink index if necessary:
if (value != null && value < idToEmbedding.size) {
val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList()
idToEmbedding.clear()
idToEmbedding.putAll(remaining)
override suspend fun setLimit(value: Int?) {
lock.write {
// Shrink index if necessary:
if (value != null && value < idToEmbedding.size) {
val remaining = idToEmbedding.asSequence().take(value).map { it.toPair() }.toList()
idToEmbedding.clear()
idToEmbedding.putAll(remaining)
}
limit = value
}
limit = value
}
override suspend fun contains(id: EntityId): Boolean = lock.read {
@@ -40,9 +42,11 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
override suspend fun lookup(id: EntityId): FloatTextEmbedding? = lock.read { idToEmbedding[id] }
override suspend fun clear() = lock.write {
idToEmbedding.clear()
uncheckedIds.clear()
override suspend fun clear() {
lock.write {
idToEmbedding.clear()
uncheckedIds.clear()
}
}
override suspend fun remove(id: EntityId) {
@@ -58,34 +62,44 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
}
}
override suspend fun onIndexingFinish() = lock.write {
uncheckedIds.forEach { idToEmbedding.remove(it) }
uncheckedIds.clear()
override suspend fun onIndexingFinish() {
lock.write {
uncheckedIds.forEach { idToEmbedding.remove(it) }
uncheckedIds.clear()
}
}
override suspend fun addEntries(
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
shouldCount: Boolean,
) = lock.write {
if (limit != null) {
val list = values.toList()
list.forEach { uncheckedIds.remove(it.first) }
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
}
else {
idToEmbedding.putAll(values)
) {
lock.write {
if (limit != null) {
val list = values.toList()
list.forEach { uncheckedIds.remove(it.first) }
idToEmbedding.putAll(list.take(minOf(limit!! - idToEmbedding.size, list.size)))
}
else {
idToEmbedding.putAll(values)
}
}
}
override suspend fun saveToDisk() = lock.read { save() }
override suspend fun loadFromDisk() = lock.write {
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
idToEmbedding.clear()
idToEmbedding.putAll(ids zip embeddings)
override suspend fun saveToDisk() {
lock.read { save() }
}
override suspend fun offload() = lock.write { idToEmbedding.clear() }
override suspend fun loadFromDisk() {
lock.write {
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
idToEmbedding.clear()
idToEmbedding.putAll(ids zip embeddings)
}
}
override suspend fun offload() {
lock.write { idToEmbedding.clear() }
}
override suspend fun findClosest(
searchEmbedding: FloatTextEmbedding,
@@ -108,7 +122,7 @@ class InMemoryEmbeddingSearchIndex(root: Path, override var limit: Int? = null)
}
}
override suspend fun estimateMemoryUsage() = fileManager.embeddingSizeInBytes.toLong() * getSize()
override suspend fun estimateMemoryUsage(): Long = fileManager.embeddingSizeInBytes.toLong() * getSize()
override fun estimateLimitByMemory(memory: Long): Int {
return (memory / fileManager.embeddingSizeInBytes).toInt()

View File

@@ -10,7 +10,7 @@ import com.intellij.platform.ml.embeddings.indexer.IndexId
*/
interface IndexPersistedEventsCounter {
companion object {
val EP_NAME = ProjectExtensionPointName<IndexPersistedEventsCounter>("com.intellij.platform.ml.embeddings.indexPersistedEventsCounter")
val EP_NAME: ProjectExtensionPointName<IndexPersistedEventsCounter> = ProjectExtensionPointName<IndexPersistedEventsCounter>("com.intellij.platform.ml.embeddings.indexPersistedEventsCounter")
}
suspend fun sendPersistedCount(indexId: IndexId, project: Project)

View File

@@ -9,7 +9,6 @@ import com.fasterxml.jackson.databind.SerializationFeature
import com.fasterxml.jackson.module.kotlin.jacksonObjectMapper
import com.fasterxml.jackson.module.kotlin.readValue
import com.intellij.platform.ml.embeddings.jvm.utils.SuspendingReadWriteLock
import com.intellij.util.io.outputStream
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.ensureActive
@@ -20,6 +19,7 @@ import java.nio.file.Files
import java.nio.file.Path
import kotlin.io.path.exists
import kotlin.io.path.inputStream
import kotlin.io.path.outputStream
class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = DEFAULT_DIMENSIONS) {
private val lock = SuspendingReadWriteLock()
@@ -36,7 +36,7 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
private val embeddingsPath
get() = rootPath.resolve(EMBEDDINGS_FILENAME)
val embeddingSizeInBytes = dimensions * EMBEDDING_ELEMENT_SIZE
val embeddingSizeInBytes: Int = dimensions * EMBEDDING_ELEMENT_SIZE
/** Provides reading access to the embedding vector at the specified index
* without reading the whole file into memory
@@ -55,12 +55,14 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
/** Provides writing access to embedding vector at the specified index
* without writing the other vectors
*/
suspend fun set(index: Int, embedding: FloatTextEmbedding) = lock.write {
RandomAccessFile(embeddingsPath.toFile(), "rw").use { output ->
output.seek(getIndexOffset(index))
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
embedding.values.forEach {
output.write(buffer.putFloat(0, it).array())
suspend fun set(index: Int, embedding: FloatTextEmbedding) {
lock.write {
RandomAccessFile(embeddingsPath.toFile(), "rw").use { output ->
output.seek(getIndexOffset(index))
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
embedding.values.forEach {
output.write(buffer.putFloat(0, it).array())
}
}
}
}
@@ -69,23 +71,25 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
* Removes the embedding vector at the specified index.
* To do so, replaces this vector with the last vector in the file and shrinks the file size.
*/
suspend fun removeAtIndex(index: Int) = lock.write {
RandomAccessFile(embeddingsPath.toFile(), "rw").use { file ->
if (file.length() < embeddingSizeInBytes) return@write
if (file.length() - embeddingSizeInBytes != getIndexOffset(index)) {
file.seek(file.length() - embeddingSizeInBytes)
val array = ByteArray(EMBEDDING_ELEMENT_SIZE)
val embedding = FloatTextEmbedding(FloatArray(dimensions) {
file.read(array)
ByteBuffer.wrap(array).getFloat()
})
file.seek(getIndexOffset(index))
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
embedding.values.forEach {
file.write(buffer.putFloat(0, it).array())
suspend fun removeAtIndex(index: Int) {
lock.write {
RandomAccessFile(embeddingsPath.toFile(), "rw").use { file ->
if (file.length() < embeddingSizeInBytes) return@write
if (file.length() - embeddingSizeInBytes != getIndexOffset(index)) {
file.seek(file.length() - embeddingSizeInBytes)
val array = ByteArray(EMBEDDING_ELEMENT_SIZE)
val embedding = FloatTextEmbedding(FloatArray(dimensions) {
file.read(array)
ByteBuffer.wrap(array).getFloat()
})
file.seek(getIndexOffset(index))
val buffer = ByteBuffer.allocate(EMBEDDING_ELEMENT_SIZE)
embedding.values.forEach {
file.write(buffer.putFloat(0, it).array())
}
}
file.setLength(file.length() - embeddingSizeInBytes)
}
file.setLength(file.length() - embeddingSizeInBytes)
}
}
@@ -111,16 +115,18 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
}
LoadedIndex(ids, embeddings)
}
catch (e: JsonProcessingException) {
catch (_: JsonProcessingException) {
return@read null
}
}
}
suspend fun saveIds(ids: List<EntityId>) = lock.write {
withNotEnoughSpaceCheck {
idsPath.outputStream().buffered().use { output ->
mapper.writer(prettyPrinter).writeValue(output, ids)
suspend fun saveIds(ids: List<EntityId>) {
lock.write {
withNotEnoughSpaceCheck {
idsPath.outputStream().buffered().use { output ->
mapper.writer(prettyPrinter).writeValue(output, ids)
}
}
}
}
@@ -169,8 +175,8 @@ class LocalEmbeddingIndexFileManager(root: Path, private val dimensions: Int = D
}
companion object {
const val DEFAULT_DIMENSIONS = 128
const val EMBEDDING_ELEMENT_SIZE = 4
const val DEFAULT_DIMENSIONS: Int = 128
const val EMBEDDING_ELEMENT_SIZE: Int = 4
private const val IDS_FILENAME = "ids.json"
private const val SOURCE_TYPES_FILENAME = "sourceTypes.json"

View File

@@ -31,15 +31,17 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
private val fileManager = LocalEmbeddingIndexFileManager(root)
override suspend fun setLimit(value: Int?) = lock.write {
if (value != null) {
// Shrink index if necessary:
while (idToEntry.size > value) {
delete(indexToId[idToEntry.size - 1]!!, all = true, shouldSaveIds = false)
override suspend fun setLimit(value: Int?) {
lock.write {
if (value != null) {
// Shrink index if necessary:
while (idToEntry.size > value) {
delete(indexToId[idToEntry.size - 1]!!, all = true, shouldSaveIds = false)
}
saveIds()
}
saveIds()
limit = value
}
limit = value
}
private data class IndexEntry(
@@ -48,19 +50,19 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
val embedding: FloatTextEmbedding,
)
override suspend fun getSize() = lock.read { idToEntry.size }
override suspend fun getSize(): Int = lock.read { idToEntry.size }
override suspend fun contains(id: EntityId): Boolean = lock.read {
id in idToEntry
}
override suspend fun contains(id: EntityId): Boolean = lock.read { id in idToEntry }
override suspend fun lookup(id: EntityId): FloatTextEmbedding? = lock.read { idToEntry[id]?.embedding }
override suspend fun clear() = lock.write {
indexToId.clear()
idToEntry.clear()
uncheckedIds.clear()
changed = false
override suspend fun clear() {
lock.write {
indexToId.clear()
idToEntry.clear()
uncheckedIds.clear()
changed = false
}
}
override suspend fun onIndexingStart() {
@@ -71,57 +73,67 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
}
}
override suspend fun onIndexingFinish() = lock.write {
if (uncheckedIds.size > 0) changed = true
logger.debug { "Deleted ${uncheckedIds.size} unchecked ids" }
uncheckedIds.forEach {
delete(it, all = true, shouldSaveIds = false)
override suspend fun onIndexingFinish() {
lock.write {
if (uncheckedIds.isNotEmpty()) changed = true
logger.debug { "Deleted ${uncheckedIds.size} unchecked ids" }
uncheckedIds.forEach {
delete(it, all = true, shouldSaveIds = false)
}
uncheckedIds.clear()
}
uncheckedIds.clear()
}
override suspend fun addEntries(
values: Iterable<Pair<EntityId, FloatTextEmbedding>>,
shouldCount: Boolean,
) = lock.write {
for ((id, embedding) in values) {
checkCancelled()
uncheckedIds.remove(id)
if (limit != null && idToEntry.size >= limit!!) break
val entry = idToEntry.getOrPut(id) {
changed = true
val index = idToEntry.size
) {
lock.write {
for ((id, embedding) in values) {
checkCancelled()
uncheckedIds.remove(id)
if (limit != null && idToEntry.size >= limit!!) break
val entry = idToEntry.getOrPut(id) {
changed = true
val index = idToEntry.size
indexToId[index] = id
IndexEntry(index = index, count = 0, embedding = embedding)
}
if (shouldCount || entry.count == 0) {
entry.count += 1
}
}
}
}
override suspend fun saveToDisk() {
lock.read {
val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList()
val embeddings = ids.map { idToEntry[it]!!.embedding }
fileManager.saveIndex(ids, embeddings)
}
}
override suspend fun loadFromDisk() {
lock.write {
indexToId.clear()
idToEntry.clear()
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
for ((index, id) in ids.withIndex()) {
val embedding = embeddings[index]
indexToId[index] = id
IndexEntry(index = index, count = 0, embedding = embedding)
}
if (shouldCount || entry.count == 0) {
entry.count += 1
idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding)
}
}
}
override suspend fun saveToDisk() = lock.read {
val ids = idToEntry.asSequence().sortedBy { it.value.index }.map { it.key }.toList()
val embeddings = ids.map { idToEntry[it]!!.embedding }
fileManager.saveIndex(ids, embeddings)
}
override suspend fun loadFromDisk() = lock.write {
indexToId.clear()
idToEntry.clear()
val (ids, embeddings) = fileManager.loadIndex() ?: return@write
for ((index, id) in ids.withIndex()) {
val embedding = embeddings[index]
indexToId[index] = id
idToEntry[id] = IndexEntry(index = index, count = 0, embedding = embedding)
override suspend fun offload() {
lock.write {
indexToId.clear()
idToEntry.clear()
}
}
override suspend fun offload() = lock.write {
indexToId.clear()
idToEntry.clear()
}
override suspend fun findClosest(
searchEmbedding: FloatTextEmbedding,
topK: Int,
@@ -154,7 +166,7 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
}
}
override suspend fun estimateMemoryUsage() = fileManager.embeddingSizeInBytes.toLong() * getSize()
override suspend fun estimateMemoryUsage(): Long = fileManager.embeddingSizeInBytes.toLong() * getSize()
override fun estimateLimitByMemory(memory: Long): Int {
return (memory / fileManager.embeddingSizeInBytes).toInt()
@@ -164,32 +176,32 @@ open class VanillaEmbeddingSearchIndex(val root: Path, override var limit: Int?
limit == null || idToEntry.size < limit!!
}
suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) = lock.write {
delete(id = id, shouldSaveIds = syncToDisk)
}
suspend fun addEntry(id: EntityId, embedding: FloatTextEmbedding) = lock.write {
uncheckedIds.remove(id)
add(id = id, embedding = embedding)
suspend fun deleteEntry(id: EntityId, syncToDisk: Boolean) {
lock.write {
delete(id = id, shouldSaveIds = syncToDisk)
}
}
/* Optimization for consequent delete and add operations */
suspend fun updateEntry(id: EntityId, newId: EntityId, embedding: FloatTextEmbedding) = lock.write {
if (id !in idToEntry) return@write
if (idToEntry[id]!!.count == 1 && newId !in idToEntry) {
val index = idToEntry[id]!!.index
fileManager.set(index, embedding)
@Suppress("unused")
suspend fun updateEntry(id: EntityId, newId: EntityId, embedding: FloatTextEmbedding) {
lock.write {
if (id !in idToEntry) return@write
if (idToEntry[id]!!.count == 1 && newId !in idToEntry) {
val index = idToEntry[id]!!.index
fileManager.set(index, embedding)
idToEntry.remove(id)
idToEntry[newId] = IndexEntry(index = index, count = 1, embedding = embedding)
indexToId[index] = newId
idToEntry.remove(id)
idToEntry[newId] = IndexEntry(index = index, count = 1, embedding = embedding)
indexToId[index] = newId
saveIds()
}
else {
// Do not apply optimization
delete(id)
add(id = newId, embedding = embedding)
saveIds()
}
else {
// Do not apply optimization
delete(id)
add(id = newId, embedding = embedding)
}
}
}

View File

@@ -30,7 +30,7 @@ class LocalEmbeddingNetwork(
return LocalEmbeddingNetwork(KIEngine.loadModel(data), maxLen ?: DEFAULT_MAX_LEN)
}
const val DEFAULT_MAX_LEN = 512
const val DEFAULT_MAX_LEN: Int = 512
}
@Suppress("Unused") // useful for older versions of kinference where operators for mean pooling are not supported

View File

@@ -16,7 +16,7 @@ class LocalEmbeddingServiceLoader {
val network = loadNetwork(loader)
val encoder = loadTextEncoder(loader)
return LocalEmbeddingService(network, encoder)
} catch (e: EOFException) {
} catch (_: EOFException) {
return null
}
}
@@ -33,8 +33,8 @@ class LocalEmbeddingServiceLoader {
}
companion object {
const val MODEL_NAME = "dan-bert-tiny"
const val MODEL_FILENAME = "dan_optimized_fp16.onnx"
const val MODEL_NAME: String = "dan-bert-tiny"
const val MODEL_FILENAME: String = "dan_optimized_fp16.onnx"
}
}

View File

@@ -167,6 +167,6 @@ abstract class AbstractEmbeddingsStorageWrapper(
private val OFFLOAD_TIMEOUT = 10.seconds
private val logger = Logger.getInstance(AbstractEmbeddingsStorageWrapper::class.java)
const val OLD_API_DIR_NAME = "old-api"
const val OLD_API_DIR_NAME: String = "old-api"
}
}

View File

@@ -30,7 +30,9 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper {
/ SEMANTIC_SEARCH_RESOURCES_DIR_NAME / OLD_API_DIR_NAME / INDICES_DIR_NAME / INDEX_DIR
)
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>) = index.addEntries(values)
override suspend fun addEntries(values: Iterable<Pair<EntityId, FloatTextEmbedding>>) {
index.addEntries(values)
}
override suspend fun removeEntries(keys: List<EntityId>) {
for (key in keys) {
@@ -38,7 +40,9 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper {
}
}
override suspend fun clear() = index.clear()
override suspend fun clear() {
index.clear()
}
@RequiresBackgroundThread
override suspend fun searchNeighbours(queryEmbedding: FloatTextEmbedding, topK: Int, similarityThreshold: Double?): List<ScoredKey<EntityId>> {
@@ -60,9 +64,13 @@ class ActionEmbeddingsStorageWrapper : EmbeddingsStorageWrapper {
return index.streamFindClose(embedding, similarityThreshold)
}
override suspend fun startIndexingSession() = index.onIndexingStart()
override suspend fun startIndexingSession() {
index.onIndexingStart()
}
override suspend fun finishIndexingSession() = index.onIndexingFinish()
override suspend fun finishIndexingSession() {
index.onIndexingFinish()
}
override suspend fun getSize(): Int = index.getSize()

View File

@@ -24,14 +24,18 @@ class EmbeddingIndexSettingsImpl : EmbeddingIndexSettings {
private val mutex = ReentrantReadWriteLock()
private val clientSettings = mutableListOf<EmbeddingIndexSettings>()
fun registerClientSettings(settings: EmbeddingIndexSettings) = mutex.write {
if (settings in clientSettings) return@write
clientSettings.add(settings)
fun registerClientSettings(settings: EmbeddingIndexSettings) {
mutex.write {
if (settings in clientSettings) return@write
clientSettings.add(settings)
}
}
@Suppress("unused")
fun unregisterClientSettings(settings: EmbeddingIndexSettings) = mutex.write {
clientSettings.remove(settings)
fun unregisterClientSettings(settings: EmbeddingIndexSettings) {
mutex.write {
clientSettings.remove(settings)
}
}
companion object {

View File

@@ -12,7 +12,7 @@ class SemanticSearchCoroutineScope(private val cs: CoroutineScope) : Disposable
override fun dispose() {}
companion object {
fun getScope(project: Project) = project.service<SemanticSearchCoroutineScope>().cs
fun getScope(project: Project): CoroutineScope = project.service<SemanticSearchCoroutineScope>().cs
fun getInstance(project: Project): SemanticSearchCoroutineScope = project.service()
}

View File

@@ -1,9 +1,10 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.intellij.platform.ml.embeddings.utils
import com.intellij.platform.diagnostic.telemetry.IJTracer
import com.intellij.platform.diagnostic.telemetry.Scope
import com.intellij.platform.diagnostic.telemetry.TelemetryManager
import org.jetbrains.annotations.ApiStatus
@ApiStatus.Internal
val SEMANTIC_SEARCH_TRACER = TelemetryManager.getInstance().getTracer(Scope("semanticSearch"))
val SEMANTIC_SEARCH_TRACER: IJTracer = TelemetryManager.getInstance().getTracer(Scope("semanticSearch"))

View File

@@ -44,7 +44,7 @@ class InvalidTokenNotificationManager {
}
companion object {
private const val NOTIFICATION_GROUP_ID = "Semantic search"
private const val NOTIFICATION_GROUP_ID: String = "Semantic search"
fun getInstance(): InvalidTokenNotificationManager = service()
}