LAB-29: recent token index when it's close to max value

GitOrigin-RevId: 5a86aa8f20038a28a2f5c5b4e226d840348dd904
This commit is contained in:
Svetlana.Zemlyanskaya
2021-01-28 18:09:10 +01:00
committed by intellij-monorepo-bot
parent 8ff83d9972
commit d97699872b
3 changed files with 110 additions and 28 deletions

View File

@@ -24,7 +24,7 @@ class VocabularyWithLimit(var maxVocabularySize: Int,
private val counter: AtomicInteger = AtomicInteger(1)
val recent: NGramRecentTokens = NGramRecentTokens()
val recent: NGramRecentTokens = NGramRecentTokens(maxSequenceSize)
val recentSequence: NGramRecentTokensSequence = NGramRecentTokensSequence(maxSequenceSize, nGramOrder, sequenceInitialSize)
/**
@@ -128,33 +128,38 @@ class VocabularyWithLimit(var maxVocabularySize: Int,
}
/**
* Stores recent tokens (recent) with an id of its last appearance (recentIdx).
* Stores recent tokens (recent) with an id of its last appearance
*
* Last token appearance is used to find a minimum sequence which has to be forgotten together with the token.
*/
class NGramRecentTokens : Externalizable {
private val nextTokenIdx: AtomicInteger = AtomicInteger(1)
class NGramRecentTokens(private val maxSequenceSize: Int) : Externalizable {
private var maxTokenIndex: Int = Int.MAX_VALUE - 1
private val nextTokenIndex: AtomicInteger = AtomicInteger(1)
private val recent: LinkedHashMap<String, Int> = LinkedHashMap()
@TestOnly
fun getNextTokenIndex(): Int = nextTokenIdx.get()
@TestOnly
fun getRecentTokens(): List<Pair<String, Int>> {
val recentTokens = arrayListOf<Pair<String, Int>>()
for ((key, value) in recent) {
recentTokens.add(key to value)
}
return recentTokens
}
@Synchronized
fun update(token: String) {
if (recent.containsKey(token)) {
recent.remove(token)
}
recent[token] = nextTokenIdx.getAndIncrement()
recent[token] = nextTokenIndex.getAndIncrement()
if (nextTokenIndex.get() > maxTokenIndex) {
// shift token index to avoid token index overflow
resetTokenIndex()
}
}
// Re-bases all stored token indices so that nextTokenIndex stays small,
// avoiding Int overflow once the counter approaches maxTokenIndex (see update()).
private fun resetTokenIndex() {
// Oldest index still covered by the retained sequence window of maxSequenceSize tokens.
val first = nextTokenIndex.get() - maxSequenceSize
var newLast = 0
for ((key, value) in recent) {
// Shift each token's index down; indices older than the window clamp to 0.
val newIdx = max(value - first + 1, 0)
recent[key] = newIdx
newLast = max(newLast, newIdx)
}
// Resume numbering immediately after the largest re-based index.
nextTokenIndex.set(newLast + 1)
}
@Synchronized
@@ -168,7 +173,7 @@ class NGramRecentTokens : Externalizable {
fun contains(token: String): Boolean = recent.contains(token)
@Synchronized
fun lastIndex(): Int = nextTokenIdx.get() - 1
fun lastIndex(): Int = nextTokenIndex.get() - 1
@Synchronized
fun size(): Int = recent.size
@@ -176,7 +181,7 @@ class NGramRecentTokens : Externalizable {
@Synchronized
@Throws(IOException::class)
override fun writeExternal(out: ObjectOutput) {
out.writeInt(nextTokenIdx.get())
out.writeInt(nextTokenIndex.get())
out.writeInt(recent.size)
for (entry in recent) {
out.writeObject(entry.key)
@@ -187,11 +192,25 @@ class NGramRecentTokens : Externalizable {
@Synchronized
@Throws(IOException::class)
override fun readExternal(ins: ObjectInput) {
nextTokenIdx.set(ins.readInt())
nextTokenIndex.set(ins.readInt())
val recentSize = ins.readInt()
for (i in 0 until recentSize) {
recent[ins.readObject() as String] = ins.readInt()
}
}
@TestOnly
fun setMaxTokenIndex(newMax: Int) {
maxTokenIndex = newMax
}
@TestOnly
fun getRecentTokens(): List<Pair<String, Int>> {
val recentTokens = arrayListOf<Pair<String, Int>>()
for ((key, value) in recent) {
recentTokens.add(key to value)
}
return recentTokens
}
}

View File

@@ -12,12 +12,14 @@ import com.intellij.internal.ml.ngram.VocabularyWithLimit
class FilePredictionNGramTest : FilePredictionHistoryBaseTest() {
private fun doTestNGramBase(openedFiles: List<String>, nGramLength: Int, maxSequenceLength: Int, vocabularyLimit: Int,
private fun doTestNGramBase(openedFiles: List<String>, nGramLength: Int,
maxSequenceLength: Int, vocabularyLimit: Int, maxIdx: Int? = null,
expectedInternalState: FilePredictionRunnerAssertion,
assertion: (FileHistoryManager) -> Unit) {
val state = FilePredictionHistoryState()
val model = JMModel(counter = ArrayTrieCounter(), order = nGramLength, lambda = 1.0)
val vocabulary = VocabularyWithLimit(vocabularyLimit, nGramLength, maxSequenceLength, 2)
maxIdx?.let { vocabulary.recent.setMaxTokenIndex(it) }
val runner = NGramIncrementalModelRunner(nGramLength, 1.0, model, vocabulary)
val manager = FileHistoryManager(runner, state, vocabularyLimit)
try {
@@ -37,9 +39,10 @@ class FilePredictionNGramTest : FilePredictionHistoryBaseTest() {
nGramOrder: Int,
vocabularyLimit: Int = 3,
maxSequenceLength: Int = 10000,
maxIdx: Int? = null,
expectedInternalState: FilePredictionRunnerAssertion,
expected: List<Pair<String, NextFileProbability>> = emptyList()) {
doTestNGramBase(openedFiles, nGramOrder, maxSequenceLength, vocabularyLimit, expectedInternalState) { manager ->
doTestNGramBase(openedFiles, nGramOrder, maxSequenceLength, vocabularyLimit, maxIdx, expectedInternalState) { manager ->
var total = 0.0
val actual = manager.calcNGramFeatures(expected.map { it.first })
for (expectedEntry in expected) {
@@ -56,9 +59,10 @@ class FilePredictionNGramTest : FilePredictionHistoryBaseTest() {
nGramOrder: Int,
vocabularyLimit: Int = 3,
maxSequenceLength: Int = 10000,
maxIdx: Int? = null,
expectedInternalState: FilePredictionRunnerAssertion,
expected: List<Pair<String, Double>> = emptyList()) {
doTestNGramBase(openedFiles, nGramOrder, maxSequenceLength, vocabularyLimit, expectedInternalState) { manager ->
doTestNGramBase(openedFiles, nGramOrder, maxSequenceLength, vocabularyLimit, maxIdx, expectedInternalState) { manager ->
var total = 0.0
val actual = manager.calcNGramFeatures(expected.map { it.first })
for (expectedEntry in expected) {
@@ -74,33 +78,37 @@ class FilePredictionNGramTest : FilePredictionHistoryBaseTest() {
private fun doTestBiGramMle(openedFiles: List<String>,
vocabularyLimit: Int = 3,
maxSequenceLength: Int = 10000,
maxIdx: Int? = null,
expectedInternalState: FilePredictionRunnerAssertion,
expected: List<Pair<String, Double>> = emptyList()) {
doTestNGramMle(openedFiles, 2, vocabularyLimit, maxSequenceLength, expectedInternalState, expected)
doTestNGramMle(openedFiles, 2, vocabularyLimit, maxSequenceLength, maxIdx, expectedInternalState, expected)
}
private fun doTestTriGramMle(openedFiles: List<String>,
vocabularyLimit: Int = 3,
maxSequenceLength: Int = 10000,
maxIdx: Int? = null,
expectedInternalState: FilePredictionRunnerAssertion,
expected: List<Pair<String, Double>> = emptyList()) {
doTestNGramMle(openedFiles, 3, vocabularyLimit, maxSequenceLength, expectedInternalState, expected)
doTestNGramMle(openedFiles, 3, vocabularyLimit, maxSequenceLength, maxIdx, expectedInternalState, expected)
}
private fun doTestUniGram(openedFiles: List<String>,
vocabularyLimit: Int = 3,
maxSequenceLength: Int = 10000,
maxIdx: Int? = null,
expectedInternalState: FilePredictionRunnerAssertion,
expected: List<Pair<String, NextFileProbability>> = emptyList()) {
doTestNGram(openedFiles, 1, vocabularyLimit, maxSequenceLength, expectedInternalState, expected)
doTestNGram(openedFiles, 1, vocabularyLimit, maxSequenceLength, maxIdx, expectedInternalState, expected)
}
private fun doTestBiGram(openedFiles: List<String>,
vocabularyLimit: Int = 3,
maxSequenceLength: Int = 10000,
maxIdx: Int? = null,
expectedInternalState: FilePredictionRunnerAssertion,
expected: List<Pair<String, NextFileProbability>> = emptyList()) {
doTestNGram(openedFiles, 2, vocabularyLimit, maxSequenceLength, expectedInternalState, expected)
doTestNGram(openedFiles, 2, vocabularyLimit, maxSequenceLength, maxIdx, expectedInternalState, expected)
}
fun `test unigram with all unique files`() {
@@ -483,6 +491,32 @@ class FilePredictionNGramTest : FilePredictionHistoryBaseTest() {
)
}
fun `test bigram mle with forgotten all unique tokens and index reset`() {
val state = FilePredictionRunnerAssertion()
.withVocabulary(8)
.withFileSequence(listOf(7, 8, 9, 10))
.withRecentFiles(7, listOf("c", "d", "e", "f", "g", "h", "i", "j"), listOf(0, 0, 1, 2, 3, 4, 5, 6))
doTestBiGramMle(
openedFiles = listOf("a", "b", "c", "d", "e", "f", "g", "h", "i", "j"),
vocabularyLimit = 8, maxSequenceLength = 4, maxIdx = 8,
expectedInternalState = state,
expected = listOf(
"a" to 0.0,
"b" to 0.0,
"c" to 0.0,
"d" to 0.0,
"e" to 0.0,
"f" to 0.0,
"g" to 0.25,
"h" to 0.25,
"i" to 0.25,
"j" to 0.25,
"x" to 0.0
)
)
}
fun `test bigram mle short sequence with all unique tokens`() {
val state = FilePredictionRunnerAssertion()
.withVocabulary(3)
@@ -984,6 +1018,35 @@ class FilePredictionNGramTest : FilePredictionHistoryBaseTest() {
)
}
fun `test trigram mle for repeated symbol with sequence limit and index reset`() {
val state = FilePredictionRunnerAssertion()
.withVocabulary(4)
.withFileSequence(listOf(
3, 3, 4, 4, 4, 4, 4
))
.withRecentFiles(
14,
listOf("a", "b", "c", "d"),
listOf(0, 5, 8, 13)
)
doTestTriGramMle(
openedFiles = listOf(
"a", "b", "b", "b", "a", "a", "a", "b", "b", "b",
"b", "b", "c", "c", "c", "d", "d", "d", "d", "d"
),
vocabularyLimit = 4, maxSequenceLength = 7, maxIdx = 14,
expectedInternalState = state,
expected = listOf(
"a" to 0.0,
"b" to 0.0,
"c" to 0.0,
"d" to 1.0,
"x" to 0.0
)
)
}
fun `test trigram mle for repeated symbol with sequence and vocabulary limit`() {
val state = FilePredictionRunnerAssertion()
.withVocabulary(4)

View File

@@ -53,7 +53,7 @@ internal class FilePredictionRunnerAssertion {
if (withRecentFiles) {
val recent = vocabulary.recent
TestCase.assertEquals("Next file sequence index is different from expected", nextFileSequenceIdx, recent.getNextTokenIndex())
TestCase.assertEquals("Next file sequence index is different from expected", nextFileSequenceIdx, recent.lastIndex() + 1)
val tokens = recent.getRecentTokens()
TestCase.assertEquals("Recent files are different from expected", recentFiles, tokens.map { it.first })