PY-39607 Add new features for ml completion

GitOrigin-RevId: caf60bd52a1872a2eb53f3617a6ffb51b195ff33
This commit is contained in:
andrey.matveev
2019-11-25 21:17:54 +07:00
committed by intellij-monorepo-bot
parent 60982e9081
commit c4950335ac
79 changed files with 1162 additions and 252 deletions

View File

@@ -35,5 +35,6 @@
"with": 34,
"yield": 35,
"print": 36,
"exec": 37
"exec": 37,
"__init__": 38
}

View File

@@ -451,8 +451,8 @@
<applicationService serviceImplementation="com.jetbrains.python.testing.PyTestFrameworkService"/>
<autoImportOptionsProvider instance="com.jetbrains.python.codeInsight.imports.PyAutoImportOptions"/>
<completion.ml.contextFeatures language="Python" implementationClass="com.jetbrains.python.codeInsight.mlcompletion.PyLocationFeatures"/>
<completion.ml.elementFeatures language="Python" implementationClass="com.jetbrains.python.codeInsight.mlcompletion.PyElementFeatures"/>
<completion.ml.contextFeatures language="Python" implementationClass="com.jetbrains.python.codeInsight.mlcompletion.PyContextFeatureProvider"/>
<completion.ml.elementFeatures language="Python" implementationClass="com.jetbrains.python.codeInsight.mlcompletion.PyElementFeatureProvider"/>
<completion.confidence language="Python" implementationClass="com.jetbrains.python.codeInsight.completion.PyCompletionConfidence"/>
<completion.ml.model implementation="com.jetbrains.python.codeInsight.mlcompletion.PythonMLRankingProvider"/>
<typedHandler implementation="com.jetbrains.python.console.completion.PythonConsoleAutopopupBlockingHandler" id="pydevBlockAutoPopup"

View File

@@ -0,0 +1,14 @@
// Copyright 2000-2020 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
internal class Counter<T> {
private val map = hashMapOf<T, Int>()
fun add(key: T?, cnt: Int = 1) {
if (key != null) map.merge(key, cnt, Integer::sum)
}
operator fun get(key: T?): Int? = map[key]
fun toMap(): Map<T, Int> = map
}

View File

@@ -0,0 +1,68 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import com.jetbrains.python.psi.PyArgumentList
import com.jetbrains.python.psi.PyKeywordArgument
import com.jetbrains.python.psi.PyReferenceExpression
object PyArgumentsCompletionFeatures {
data class ArgumentsContextCompletionFeatures(
val isInArguments: Boolean,
val isDirectlyInArgumentContext: Boolean,
val isIntoKeywordArgument: Boolean,
val argumentIndex: Int?,
val argumentsSize: Int?,
val haveNamedArgLeft: Boolean,
val haveNamedArgRight: Boolean)
fun getContextArgumentFeatures(locationPsi: PsiElement): ArgumentsContextCompletionFeatures {
val argListParent = PsiTreeUtil.getParentOfType(locationPsi, PyArgumentList::class.java)
val isInArguments = argListParent != null
val isDirectlyInArgumentContext = isDirectlyInArgumentsContext(locationPsi)
val isIntoKeywordArgument = PsiTreeUtil.getParentOfType(locationPsi, PyKeywordArgument::class.java) != null
var argumentIndex: Int? = null
var argumentsSize: Int? = null
var haveNamedArgLeft = false
var haveNamedArgRight = false
if (argListParent != null) {
val arguments = argListParent.arguments
argumentsSize = arguments.size
for (i in arguments.indices) {
if (PsiTreeUtil.isAncestor(arguments[i], locationPsi, false)) {
argumentIndex = i
}
else if (arguments[i] is PyKeywordArgument) {
if (argumentIndex == null) {
haveNamedArgLeft = true
}
else {
haveNamedArgRight = true
}
}
}
}
return ArgumentsContextCompletionFeatures(isInArguments,
isDirectlyInArgumentContext,
isIntoKeywordArgument,
argumentIndex,
argumentsSize,
haveNamedArgLeft,
haveNamedArgRight)
}
private fun isDirectlyInArgumentsContext(locationPsi: PsiElement): Boolean {
// for zero prefix
if (locationPsi.parent is PyArgumentList) return true
// for non-zero prefix
if (locationPsi.parent !is PyReferenceExpression) return false
if (locationPsi.parent.parent !is PyArgumentList) return false
return true
}
}

View File

@@ -0,0 +1,27 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.ml.CompletionEnvironment
import com.intellij.psi.util.PsiTreeUtil
import com.jetbrains.python.psi.PyClass
import com.jetbrains.python.psi.PyUtil
object PyClassCompletionFeatures {
data class ClassFeatures(val diffLinesWithClassDef: Int, val classHaveConstructor: Boolean)
fun getClassCompletionFeatures(environment: CompletionEnvironment): ClassFeatures? {
val parentClass = PsiTreeUtil.getParentOfType(environment.parameters.position, PyClass::class.java) ?: return null
val lookup = environment.lookup
val editor = lookup.topLevelEditor
val caretOffset = lookup.lookupStart
val logicalPosition = editor.offsetToLogicalPosition(caretOffset)
val lineno = logicalPosition.line
val classLogicalPosition = editor.offsetToLogicalPosition(parentClass.textOffset)
val classLineno = classLogicalPosition.line
val classHaveConstructor = parentClass.methods.any { PyUtil.isInitOrNewMethod(it) }
return ClassFeatures(lineno - classLineno, classHaveConstructor)
}
}

View File

@@ -1,32 +1,73 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.codeInsight.lookup.LookupElementPresentation
import com.intellij.openapi.module.ModuleUtilCore
import com.intellij.openapi.projectRoots.Sdk
import com.intellij.openapi.vfs.VirtualFile
import com.intellij.psi.PsiComment
import com.intellij.psi.PsiDirectory
import com.intellij.psi.PsiElement
import com.intellij.psi.PsiWhiteSpace
import com.intellij.psi.util.PsiTreeUtil
import com.intellij.psi.util.elementType
import com.jetbrains.python.PyTokenTypes
import com.jetbrains.python.codeInsight.mlcompletion.PyCompletionMlElementInfo
import com.jetbrains.python.codeInsight.mlcompletion.PyCompletionMlElementKind
import com.jetbrains.python.psi.*
import com.jetbrains.python.sdk.PythonSdkUtil
object PyCompletionFeatures {
fun isDirectlyInArgumentsContext(locationPsi: PsiElement): Boolean {
// for zero prefix
if (locationPsi.parent is PyArgumentList) return true
// for non-zero prefix
if (locationPsi.parent !is PyReferenceExpression) return false
if (locationPsi.parent.parent !is PyArgumentList) return false
return true
fun isDictKey(element: LookupElement): Boolean {
val presentation = LookupElementPresentation.renderElement(element)
return ("dict key" == presentation.typeText)
}
private fun isAfterColon(locationPsi: PsiElement): Boolean {
val prevVisibleLeaf = PsiTreeUtil.prevVisibleLeaf(locationPsi)
return (prevVisibleLeaf != null && prevVisibleLeaf.elementType == PyTokenTypes.COLON)
fun isTheSameFile(element: LookupElement, location: CompletionLocation): Boolean {
val psiFile = location.completionParameters.originalFile
val elementPsiFile = element.psiElement?.containingFile ?: return false
return psiFile == elementPsiFile
}
fun isTakesParameterSelf(element: LookupElement): Boolean {
val presentation = LookupElementPresentation.renderElement(element)
return presentation.tailText == "(self)"
}
enum class ElementNameUnderscoreType {NO_UNDERSCORE, TWO_START_END, TWO_START, ONE_START}
fun getElementNameUnderscoreType(name: String): ElementNameUnderscoreType {
return when {
name.startsWith("__") && name.endsWith("__") -> ElementNameUnderscoreType.TWO_START_END
name.startsWith("__") -> ElementNameUnderscoreType.TWO_START
name.startsWith("_") -> ElementNameUnderscoreType.ONE_START
else -> ElementNameUnderscoreType.NO_UNDERSCORE
}
}
fun isPsiElementIsPyFile(element: LookupElement) = element.psiElement is PyFile
fun isPsiElementIsPsiDirectory(element: LookupElement) = element.psiElement is PsiDirectory
data class ElementModuleCompletionFeatures(val isFromStdLib: Boolean, val canFindModule: Boolean)
fun getElementModuleCompletionFeatures(element: LookupElement): ElementModuleCompletionFeatures? {
val psiElement = element.psiElement ?: return null
var vFile: VirtualFile? = null
var sdk: Sdk? = null
val containingFile = psiElement.containingFile
if (psiElement is PsiDirectory) {
vFile = psiElement.virtualFile
sdk = PythonSdkUtil.findPythonSdk(psiElement)
}
else if (containingFile != null) {
vFile = containingFile.virtualFile
sdk = PythonSdkUtil.findPythonSdk(containingFile)
}
if (vFile != null) {
val isFromStdLib = PythonSdkUtil.isStdLib(vFile, sdk)
val canFindModule = ModuleUtilCore.findModuleForFile(vFile, psiElement.project) != null
return ElementModuleCompletionFeatures(isFromStdLib, canFindModule)
}
return null
}
fun isInCondition(locationPsi: PsiElement): Boolean {
@@ -90,11 +131,9 @@ object PyCompletionFeatures {
fun isInDocstring(element: PsiElement) = element.parent is StringLiteralExpression
val res = ArrayList<Int>()
val whitespaceElem = when {
isIndentElement(locationPsi) -> locationPsi
locationPsi.prevSibling != null && isIndentElement(locationPsi.prevSibling) -> locationPsi.prevSibling
else -> return res
}
val whitespaceElem = PsiTreeUtil.prevLeaf(locationPsi) ?: return res
if (!whitespaceElem.text.contains('\n')) return res
val caretIndent = getIndent(whitespaceElem)
var stepsCounter = 0
@@ -126,8 +165,7 @@ object PyCompletionFeatures {
when (kind) {
in arrayOf(PyCompletionMlElementKind.FUNCTION,
PyCompletionMlElementKind.TYPE_OR_CLASS,
PyCompletionMlElementKind.FROM_TARGET) ->
{
PyCompletionMlElementKind.FROM_TARGET) -> {
val statementList = PsiTreeUtil.getParentOfType(locationPsi, PyStatementList::class.java, PyFile::class.java) ?: return null
val children = PsiTreeUtil.collectElementsOfType(statementList, PyReferenceExpression::class.java)
return children.count { it.textOffset < locationPsi.textOffset && it.textMatches(lookupString) }
@@ -152,18 +190,20 @@ object PyCompletionFeatures {
}
}
fun getImportPopularityFeature(locationPsi: PsiElement, lookupString: String): Int? {
if (locationPsi.parent !is PyReferenceExpression) return null
if (locationPsi.parent.parent !is PyImportElement) return null
if (locationPsi.parent.parent.parent !is PyImportStatement) return null
return PyMlCompletionHelpers.importPopularity[lookupString]
}
fun getBuiltinPopularityFeature(lookupString: String, isBuiltins: Boolean): Int? =
if (isBuiltins) PyMlCompletionHelpers.builtinsPopularity[lookupString] else null
fun getKeywordId(lookupString: String): Int? = PyMlCompletionHelpers.getKeywordId(lookupString)
fun getPyLookupElementInfo(element: LookupElement): PyCompletionMlElementInfo? = element.getUserData(
PyCompletionMlElementInfo.key)
fun getPyLookupElementInfo(element: LookupElement): PyCompletionMlElementInfo? = element.getUserData(PyCompletionMlElementInfo.key)
fun getNumberOfQualifiersInExpresionFeature(element: PsiElement): Int {
if (element !is PyQualifiedExpression) return 1
return element.asQualifiedName()?.components?.size ?: 1
}
private fun isAfterColon(locationPsi: PsiElement): Boolean {
val prevVisibleLeaf = PsiTreeUtil.prevVisibleLeaf(locationPsi)
return (prevVisibleLeaf != null && prevVisibleLeaf.elementType == PyTokenTypes.COLON)
}
}

View File

@@ -0,0 +1,60 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.ml.CompletionEnvironment
import com.intellij.codeInsight.completion.ml.ContextFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.jetbrains.python.psi.types.TypeEvalContext
class PyContextFeatureProvider : ContextFeatureProvider {
override fun getName(): String = "python"
override fun calculateFeatures(environment: CompletionEnvironment): Map<String, MLFeatureValue> {
val result = HashMap<String, MLFeatureValue>()
val position = environment.parameters.position
val typeEvalContext = TypeEvalContext.codeInsightFallback(position.project)
result["is_in_condition"] = MLFeatureValue.binary(PyCompletionFeatures.isInCondition(position))
result["is_after_if_statement_without_else_branch"] = MLFeatureValue.binary(PyCompletionFeatures.isAfterIfStatementWithoutElseBranch(position))
result["is_in_for_statement"] = MLFeatureValue.binary(PyCompletionFeatures.isInForStatement(position))
result["num_of_prev_qualifiers"] = MLFeatureValue.numerical(PyCompletionFeatures.getNumberOfQualifiersInExpresionFeature(position))
val neighboursKws = PyCompletionFeatures.getPrevNeighboursKeywordIds(position)
if (neighboursKws.size > 0) result["prev_neighbour_keyword_1"] = MLFeatureValue.numerical(neighboursKws[0])
if (neighboursKws.size > 1) result["prev_neighbour_keyword_2"] = MLFeatureValue.numerical(neighboursKws[1])
val sameLineKws = PyCompletionFeatures.getPrevKeywordsIdsInTheSameLine(position)
if (sameLineKws.size > 0) result["prev_same_line_keyword_1"] = MLFeatureValue.numerical(sameLineKws[0])
if (sameLineKws.size > 1) result["prev_same_line_keyword_2"] = MLFeatureValue.numerical(sameLineKws[1])
val sameColumnKws = PyCompletionFeatures.getPrevKeywordsIdsInTheSameColumn(position)
if (sameColumnKws.size > 0) result["prev_same_column_keyword_1"] = MLFeatureValue.numerical(sameColumnKws[0])
if (sameColumnKws.size > 1) result["prev_same_column_keyword_2"] = MLFeatureValue.numerical(sameColumnKws[1])
with (PyArgumentsCompletionFeatures.getContextArgumentFeatures(position)) {
result["is_in_arguments"] = MLFeatureValue.binary(isInArguments)
result["is_directly_in_arguments_context"] = MLFeatureValue.binary(isDirectlyInArgumentContext)
result["is_into_keyword_arg"] = MLFeatureValue.binary(isIntoKeywordArgument)
result["have_named_arg_left"] = MLFeatureValue.binary(haveNamedArgLeft)
result["have_named_arg_right"] = MLFeatureValue.binary(haveNamedArgRight)
argumentIndex?.let { result["argument_index"] = MLFeatureValue.numerical(it) }
argumentsSize?.let { result["number_of_arguments_already"] = MLFeatureValue.numerical(it) }
}
PyReceiverMlCompletionFeatures.calculateReceiverElementInfo(environment, typeEvalContext)
PyNamesMatchingMlCompletionFeatures.calculateFunBodyNames(environment)
PyNamesMatchingMlCompletionFeatures.calculateSameLineLeftNames(environment).let { names ->
result["have_opening_round_bracket"] = MLFeatureValue.binary(PyParenthesesFeatures.haveOpeningRoundBracket(names))
result["have_opening_square_bracket"] = MLFeatureValue.binary(PyParenthesesFeatures.haveOpeningSquareBracket(names))
result["have_opening_brace"] = MLFeatureValue.binary(PyParenthesesFeatures.haveOpeningBrace(names))
}
PyClassCompletionFeatures.getClassCompletionFeatures(environment)?.let { with(it) {
result["diff_lines_with_class_def"] = MLFeatureValue.numerical(diffLinesWithClassDef)
result["containing_class_have_constructor"] = MLFeatureValue.binary(classHaveConstructor)
}}
return result
}
}

View File

@@ -0,0 +1,82 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.completion.ml.ContextFeatures
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.LookupElement
class PyElementFeatureProvider : ElementFeatureProvider {
override fun getName(): String = "python"
override fun calculateFeatures(element: LookupElement,
location: CompletionLocation,
contextFeatures: ContextFeatures): Map<String, MLFeatureValue> {
val result = HashMap<String, MLFeatureValue>()
val lookupString = element.lookupString
val locationPsi = location.completionParameters.position
PyCompletionFeatures.getPyLookupElementInfo(element)?.let { info ->
result["kind"] = MLFeatureValue.categorical(info.kind)
result["is_builtins"] = MLFeatureValue.binary(info.isBuiltins)
PyCompletionFeatures.getNumberOfOccurrencesInScope(info.kind, locationPsi, lookupString)?.let {
result["number_of_occurrences_in_scope"] = MLFeatureValue.numerical(it)
}
PyCompletionFeatures.getBuiltinPopularityFeature(lookupString, info.isBuiltins)?.let {
result["builtin_popularity"] = MLFeatureValue.numerical(it)
}
}
PyCompletionFeatures.getKeywordId(lookupString)?.let {
result["keyword_id"] = MLFeatureValue.numerical(it)
}
result["is_dict_key"] = MLFeatureValue.binary(PyCompletionFeatures.isDictKey(element))
result["is_the_same_file"] = MLFeatureValue.binary(PyCompletionFeatures.isTheSameFile(element, location))
result["is_takes_parameter_self"] = MLFeatureValue.binary(PyCompletionFeatures.isTakesParameterSelf(element))
result["underscore_type"] = MLFeatureValue.categorical(PyCompletionFeatures.getElementNameUnderscoreType(lookupString))
result["number_of_tokens"] = MLFeatureValue.numerical(PyNamesMatchingMlCompletionFeatures.getNumTokensFeature(lookupString))
result["element_is_py_file"] = MLFeatureValue.binary(PyCompletionFeatures.isPsiElementIsPyFile(element))
result["element_is_psi_directory"] = MLFeatureValue.binary(PyCompletionFeatures.isPsiElementIsPsiDirectory(element))
PyCompletionFeatures.getElementModuleCompletionFeatures(element)?.let { with(it) {
result["element_module_is_std_lib"] = MLFeatureValue.binary(isFromStdLib)
result["can_find_element_module"] = MLFeatureValue.binary(canFindModule)
}
}
PyImportCompletionFeatures.getImportPopularityFeature(locationPsi, lookupString)?.let {
result["import_popularity"] = MLFeatureValue.numerical(it)
}
PyImportCompletionFeatures.getElementImportPathFeatures(element, location)?.let { with (it) {
result["is_imported"] = MLFeatureValue.binary(isImported)
result["num_components_in_import_path"] = MLFeatureValue.numerical(numComponents)
result["num_private_components_in_import_path"] = MLFeatureValue.numerical(numPrivateComponents)
}}
PyNamesMatchingMlCompletionFeatures.getPyFunClassFileBodyMatchingFeatures(contextFeatures, element.lookupString)?.let { with(it) {
result["scope_num_names"] = MLFeatureValue.numerical(numScopeNames)
result["scope_num_different_names"] = MLFeatureValue.numerical(numScopeDifferentNames)
result["scope_num_matches"] = MLFeatureValue.numerical(sumMatches)
result["scope_num_tokens_matches"] = MLFeatureValue.numerical(sumTokensMatches)
}}
PyNamesMatchingMlCompletionFeatures.getPySameLineMatchingFeatures(contextFeatures, element.lookupString)?.let { with(it) {
result["same_line_num_names"] = MLFeatureValue.numerical(numScopeNames)
result["same_line_num_different_names"] = MLFeatureValue.numerical(numScopeDifferentNames)
result["same_line_num_matches"] = MLFeatureValue.numerical(sumMatches)
result["same_line_num_tokens_matches"] = MLFeatureValue.numerical(sumTokensMatches)
}}
PyNamesMatchingMlCompletionFeatures.getMatchingWithReceiverFeatures(contextFeatures, element)?.let { with(it) {
result["receiver_name_matches"] = MLFeatureValue.binary(matchesWithReceiver)
result["receiver_num_matched_tokens"] = MLFeatureValue.numerical(numMatchedTokens)
result["receiver_tokens_num"] = MLFeatureValue.numerical(receiverTokensNum)
}}
return result
}
}

View File

@@ -1,36 +0,0 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.completion.ml.ContextFeatures
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.LookupElement
class PyElementFeatures : ElementFeatureProvider {
override fun getName(): String = "python"
override fun calculateFeatures(element: LookupElement,
location: CompletionLocation,
contextFeatures: ContextFeatures): Map<String, MLFeatureValue> {
val result = HashMap<String, MLFeatureValue>()
val lookupString = element.lookupString
val locationPsi = location.completionParameters.position
PyCompletionFeatures.getPyLookupElementInfo(element)?.let { info ->
result["kind"] = MLFeatureValue.categorical(info.kind)
result["is_builtins"] = MLFeatureValue.binary(info.isBuiltins)
PyCompletionFeatures.getNumberOfOccurrencesInScope(info.kind, locationPsi, lookupString)?.let { occurrences ->
result["number_of_occurrences_in_scope"] = MLFeatureValue.float(occurrences)
}
PyCompletionFeatures.getBuiltinPopularityFeature(lookupString, info.isBuiltins)?.let { result["builtin_popularity"] = MLFeatureValue.float(it) }
}
PyCompletionFeatures.getImportPopularityFeature(locationPsi, lookupString)?.let { result["import_popularity"] = MLFeatureValue.float(it) }
PyCompletionFeatures.getKeywordId(lookupString)?.let { result["keyword_id"] = MLFeatureValue.float(it) }
return result
}
}

View File

@@ -0,0 +1,34 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.completion.hasImportsFrom
import com.jetbrains.python.psi.PyImportElement
import com.jetbrains.python.psi.PyImportStatement
import com.jetbrains.python.psi.PyReferenceExpression
import com.jetbrains.python.psi.resolve.QualifiedNameFinder
object PyImportCompletionFeatures {
data class ElementImportPathFeatures (val isImported: Boolean,
val numPrivateComponents: Int,
val numComponents: Int)
fun getElementImportPathFeatures(element: LookupElement, location: CompletionLocation): ElementImportPathFeatures? {
val psiElement = element.psiElement ?: return null
val importPath = QualifiedNameFinder.findShortestImportableQName(psiElement.containingFile) ?: return null
val caretLocationFile = location.completionParameters.originalFile
val isImported = hasImportsFrom(caretLocationFile, importPath)
val numComponents = importPath.componentCount
val numPrivateComponents = importPath.components.count{ it.startsWith("_") }
return ElementImportPathFeatures(isImported, numPrivateComponents, numComponents)
}
fun getImportPopularityFeature(locationPsi: PsiElement, lookupString: String): Int? {
if (locationPsi.parent !is PyReferenceExpression) return null
if (locationPsi.parent.parent !is PyImportElement) return null
if (locationPsi.parent.parent.parent !is PyImportStatement) return null
return PyMlCompletionHelpers.importPopularity[lookupString]
}
}

View File

@@ -1,34 +0,0 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.ml.ContextFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.Lookup
class PyLocationFeatures : ContextFeatureProvider {
override fun getName(): String = "python"
override fun calculateFeatures(lookup: Lookup): Map<String, MLFeatureValue> {
val result = HashMap<String, MLFeatureValue>()
val locationPsi = lookup.psiElement ?: return result
result["is_directly_in_arguments_context"] = MLFeatureValue.binary(PyCompletionFeatures.isDirectlyInArgumentsContext(locationPsi))
result["is_in_condition"] = MLFeatureValue.binary(PyCompletionFeatures.isInCondition(locationPsi))
result["is_after_if_statement_without_else_branch"] = MLFeatureValue.binary(PyCompletionFeatures.isAfterIfStatementWithoutElseBranch(locationPsi))
result["is_in_for_statement"] = MLFeatureValue.binary(PyCompletionFeatures.isInForStatement(locationPsi))
val neighboursKws = PyCompletionFeatures.getPrevNeighboursKeywordIds(locationPsi)
if (neighboursKws.size > 0) result["prev_neighbour_keyword_1"] = MLFeatureValue.float(neighboursKws[0])
if (neighboursKws.size > 1) result["prev_neighbour_keyword_2"] = MLFeatureValue.float(neighboursKws[1])
val sameLineKws = PyCompletionFeatures.getPrevKeywordsIdsInTheSameLine(locationPsi)
if (sameLineKws.size > 0) result["prev_same_line_keyword_1"] = MLFeatureValue.float(sameLineKws[0])
if (sameLineKws.size > 1) result["prev_same_line_keyword_2"] = MLFeatureValue.float(sameLineKws[1])
val sameColumnKws = PyCompletionFeatures.getPrevKeywordsIdsInTheSameColumn(locationPsi)
if (sameColumnKws.size > 0) result["prev_same_column_keyword_1"] = MLFeatureValue.float(sameColumnKws[0])
if (sameColumnKws.size > 1) result["prev_same_column_keyword_2"] = MLFeatureValue.float(sameColumnKws[1])
return result
}
}

View File

@@ -0,0 +1,192 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.CompletionUtilCore.DUMMY_IDENTIFIER_TRIMMED
import com.intellij.codeInsight.completion.ml.CompletionEnvironment
import com.intellij.codeInsight.completion.ml.ContextFeatures
import com.intellij.codeInsight.lookup.LookupElement
import com.intellij.openapi.util.Key
import com.intellij.openapi.util.text.StringUtil
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import com.jetbrains.python.psi.*
object PyNamesMatchingMlCompletionFeatures {
private val scopeNamesKey = Key<Map<String, Int>>("py.ml.completion.scope.names")
private val scopeTokensKey = Key<Map<String, Int>>("py.ml.completion.scope.tokens")
private val lineLeftNamesKey = Key<Map<String, Int>>("py.ml.completion.line.left.names")
private val lineLeftTokensKey = Key<Map<String, Int>>("py.ml.completion.line.left.tokens")
data class PyScopeMatchingFeatures(val sumMatches: Int,
val sumTokensMatches: Int,
val numScopeNames: Int,
val numScopeDifferentNames: Int)
fun getPyFunClassFileBodyMatchingFeatures(contextFeatures: ContextFeatures, lookupString: String): PyScopeMatchingFeatures? {
val names = contextFeatures.getUserData(scopeNamesKey) ?: return null
val tokens = contextFeatures.getUserData(scopeTokensKey) ?: return null
return getPyScopeMatchingFeatures(names, tokens, lookupString)
}
fun getPySameLineMatchingFeatures(contextFeatures: ContextFeatures, lookupString: String): PyScopeMatchingFeatures? {
val names = contextFeatures.getUserData(lineLeftNamesKey) ?: return null
val tokens = contextFeatures.getUserData(lineLeftTokensKey) ?: return null
return getPyScopeMatchingFeatures(names, tokens, lookupString)
}
fun getNumTokensFeature(elementName: String) = getTokens(elementName).size
data class MatchingWithReceiverFeatures(val matchesWithReceiver: Boolean,
val receiverTokensNum: Int,
val numMatchedTokens: Int)
fun getMatchingWithReceiverFeatures(contextFeatures: ContextFeatures, element: LookupElement): MatchingWithReceiverFeatures? {
val names = contextFeatures.getUserData(PyReceiverMlCompletionFeatures.receiverNamesKey) ?: return null
if (names.isEmpty()) return null
val matchesWithReceiver = names.any { it == element.lookupString }
val maxMatchedToken = names.maxBy { tokensMatched(element.lookupString, it) } ?: ""
val numMatchedTokens = tokensMatched(maxMatchedToken, element.lookupString)
val receiverTokensNum = getNumTokensFeature(maxMatchedToken)
return MatchingWithReceiverFeatures(matchesWithReceiver, receiverTokensNum, numMatchedTokens)
}
fun calculateFunBodyNames(environment: CompletionEnvironment) {
val position = environment.parameters.position
val scope = PsiTreeUtil.getParentOfType(position, PyFile::class.java, PyFunction::class.java, PyClass::class.java)
val names = collectUsedNames(scope)
environment.putUserData(scopeNamesKey, names)
environment.putUserData(scopeTokensKey, getTokensCounterMap(names).toMap())
}
fun calculateSameLineLeftNames(environment: CompletionEnvironment): Map<String, Int> {
val position = environment.parameters.position
var curElement = PsiTreeUtil.prevLeaf(position)
val names = Counter<String>()
while (curElement != null && !curElement.text.contains("\n")) {
val text = curElement.text
if (!StringUtil.isEmptyOrSpaces(text)) {
names.add(text)
}
curElement = PsiTreeUtil.prevLeaf(curElement)
}
environment.putUserData(lineLeftNamesKey, names.toMap())
environment.putUserData(lineLeftTokensKey, getTokensCounterMap(names.toMap()).toMap())
return names.toMap()
}
private fun getPyScopeMatchingFeatures(names: Map<String, Int>,
tokens: Map<String, Int>,
lookupString: String): PyScopeMatchingFeatures? {
val sumMatches = names[lookupString] ?: 0
val sumTokensMatches = tokensMatched(lookupString, tokens)
val total = names.toList().sumBy { it.second }
return PyScopeMatchingFeatures(sumMatches, sumTokensMatches, total, names.size)
}
private fun collectUsedNames(scope: PsiElement?): Map<String, Int> {
val variables = Counter<String>()
if (scope !is PyClass && scope !is PyFile && scope !is PyFunction) {
return variables.toMap()
}
val visitor = object : PyRecursiveElementVisitor() {
override fun visitPyTargetExpression(node: PyTargetExpression) {
variables.add(node.name)
}
override fun visitPyNamedParameter(node: PyNamedParameter) {
variables.add(node.name)
}
override fun visitPyReferenceExpression(node: PyReferenceExpression) {
if (!node.isQualified) {
variables.add(node.referencedName)
}
else {
super.visitPyReferenceExpression(node)
}
}
override fun visitPyFunction(node: PyFunction) {
variables.add(node.name)
}
override fun visitPyClass(node: PyClass) {
variables.add(node.name)
}
}
if (scope is PyFunction || scope is PyClass) {
scope.accept(visitor)
scope.acceptChildren(visitor)
}
else {
scope.acceptChildren(visitor)
}
return variables.toMap().filter { !it.key.contains(DUMMY_IDENTIFIER_TRIMMED) }
}
fun tokensMatched(firstName: String, secondName: String): Int {
val nameTokens = getTokens(firstName)
val elementNameTokens = getTokens(secondName)
return nameTokens.sumBy { token1 -> elementNameTokens.count { token2 -> token1 == token2 } }
}
fun tokensMatched(name: String, tokens: Map<String, Int>): Int {
val nameTokens = getTokens(name)
return nameTokens.sumBy { tokens[it] ?: 0 }
}
private fun getTokensCounterMap(names: Map<String, Int>): Counter<String> {
val result = Counter<String>()
names.forEach { (name, cnt) ->
val tokens = getTokens(name)
for (token in tokens) {
result.add(token, cnt)
}
}
return result
}
private fun getTokens(name: String): List<String> =
name
.split("_")
.asSequence()
.flatMap { splitByCamelCase(it).asSequence() }
.filter { it.isNotEmpty() }
.toList()
private fun processToken(token: String): String {
val lettersOnly = token.filter { it.isLetter() }
return if (lettersOnly.length > 3) {
when {
lettersOnly.endsWith("s") -> lettersOnly.substring(0 until lettersOnly.length - 1)
lettersOnly.endsWith("es") -> lettersOnly.substring(0 until lettersOnly.length - 2)
else -> lettersOnly
}
}
else lettersOnly
}
private fun splitByCamelCase(name: String): List<String> {
if (isAllLettersUpper(name)) return arrayListOf(processToken(name.toLowerCase()))
val result = ArrayList<String>()
var curToken = ""
for (ch in name) {
if (ch.isUpperCase()) {
if (curToken.isNotEmpty()) {
result.add(processToken(curToken))
curToken = ""
}
curToken += ch.toLowerCase()
}
else {
curToken += ch
}
}
if (curToken.isNotEmpty()) result.add(processToken(curToken))
return result
}
private fun isAllLettersUpper(name: String) = !name.any { it.isLetter() && it.isLowerCase() }
}

View File

@@ -0,0 +1,16 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
object PyParenthesesFeatures {
fun haveOpeningRoundBracket(names: Map<String, Int>) = haveOpeningBracket(names, "(", ")")
fun haveOpeningSquareBracket(names: Map<String, Int>) = haveOpeningBracket(names, "[", "]")
fun haveOpeningBrace(names: Map<String, Int>): Boolean = haveOpeningBracket(names, "{", "}")
private fun haveOpeningBracket(names: Map<String, Int>, openingBracket: String, closingBracket: String): Boolean {
val cntOpening = names.entries.sumBy { if (it.key == openingBracket) it.value else 0 }
val cntClosing = names.entries.sumBy { if (it.key == closingBracket) it.value else 0 }
return cntOpening > cntClosing
}
}

View File

@@ -0,0 +1,55 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.ml.CompletionEnvironment
import com.intellij.openapi.util.Key
import com.intellij.psi.PsiElement
import com.intellij.psi.util.PsiTreeUtil
import com.jetbrains.python.psi.PyAssignmentStatement
import com.jetbrains.python.psi.PyCallExpression
import com.jetbrains.python.psi.PyNamedParameter
import com.jetbrains.python.psi.PyTargetExpression
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.types.TypeEvalContext
object PyReceiverMlCompletionFeatures {
val receiverNamesKey = Key<List<String>>("py.ml.completion.receiver.name")
fun calculateReceiverElementInfo(environment: CompletionEnvironment, typeEvalContext: TypeEvalContext) {
val position = environment.parameters.position
val receivers = getReceivers(position, typeEvalContext)
if (receivers.isEmpty()) return
val names = receivers.mapNotNull { getNameOfReceiver(it) }
environment.putUserData(receiverNamesKey, names)
}
private fun getNameOfReceiver(element: PsiElement): String? {
return when (element) {
is PyNamedParameter -> element.name
is PyTargetExpression -> element.name
else -> element.text
}
}
private fun getReceivers(position: PsiElement, typeEvalContext: TypeEvalContext): List<PsiElement> {
val scope = PsiTreeUtil.getParentOfType(position, PyCallExpression::class.java, PyAssignmentStatement::class.java)
return when (scope) {
is PyCallExpression -> getReceivers(position, scope, typeEvalContext)
is PyAssignmentStatement -> getReceivers(position, scope)
else -> emptyList()
}
}
private fun getReceivers(position: PsiElement, call: PyCallExpression, typeEvalContext: TypeEvalContext): List<PsiElement> {
val resolveContext = PyResolveContext.defaultContext().withTypeEvalContext(typeEvalContext)
val mapArguments = call.multiMapArguments(resolveContext)
if (mapArguments.isEmpty()) return emptyList()
return mapArguments.mapNotNull { entry -> entry.mappedParameters[position.parent]?.parameter }
}
private fun getReceivers(position: PsiElement, assignment: PyAssignmentStatement): List<PsiElement> {
val mapping = assignment.targetsToValuesMapping
val result = mapping.find { it.second == position.parent }?.first
return if (result != null) arrayListOf(result) else emptyList()
}
}

View File

@@ -0,0 +1,8 @@
def foo(param1, param2, param3, param4):
pass
a = 23
b = "asdsa"
c = {}
d = [1, 2]
foo(<caret>)

View File

@@ -0,0 +1,9 @@
def foo(param1, param2, param3, param4):
pass
a = 23
b = "asdsa"
c = {}
d = [1, 2]
foo(a, param2=b, param3=<caret>, param4=d)

View File

@@ -0,0 +1,9 @@
def foo(param1, param2, param3, param4):
pass
a = 23
b = "asdsa"
c = {}
d = [1, 2]
foo(a, param2=b, <caret>, param4=d)

View File

@@ -0,0 +1,2 @@
ddd = {"dict_key": 42, "something_else": 23}
ddd[<caret>]

View File

@@ -0,0 +1,3 @@
k1 = 1
k2 = 22
dct = {k1: "1", k2: <caret>}

View File

@@ -0,0 +1,4 @@
k1 = 1
k2 = 22
dct = {k1: "1", k2: "22"}
dct[<caret>

View File

@@ -0,0 +1,5 @@
class Claaass:
def __init__(self):
pass
def <caret>

View File

@@ -0,0 +1,2 @@
class Claaass:
def <caret>

View File

@@ -0,0 +1,4 @@
if True:
if True:
pass
e<caret>

View File

@@ -0,0 +1 @@
print(<caret>)

View File

@@ -0,0 +1,14 @@
class MyClass(object):
def __init__(self):
pass
def foo(self):
pass
@staticmethod
def bar():
pass
a = MyClass()
a.<caret>

View File

@@ -1,2 +1,3 @@
a = min(1, 2)
b = m<caret>
b = min(2, 3)
c = m<caret>

View File

@@ -0,0 +1,6 @@
class MyError:
def __init__(self):
self.instance_var = 42
MyError.<caret>

View File

@@ -0,0 +1,7 @@
class MyWarning:
def __init__(self):
self.instance_var = 42
w = MyWarning()
w.<caret>

View File

@@ -0,0 +1,7 @@
class Clzz:
def __init__(self):
self.abaCaba = 22
clz = Clzz()
abaCaba = clz.<caret>

View File

@@ -0,0 +1,5 @@
def foo(someParam: int):
pass
someParam = 1
foo(<caret>)

View File

@@ -0,0 +1,5 @@
def foo(some_tokens_set: int):
pass
someToken2_sets1 = 1
foo(<caret>)

View File

@@ -0,0 +1,5 @@
def foo(some_tokens_set: int):
pass
SOME_TOKENS_SET = 1
foo(<caret>)

View File

@@ -0,0 +1,8 @@
def some_foo(some_param1, some_param2, some_param3):
pass
some_param1 = 1
some_param2 = 2
some_param3 = 3
some_foo(some_param1, some_param2, <caret>)

View File

@@ -0,0 +1,13 @@
class SomeClass:
def some_fun(some_param):
some_var_1 = 22
some_var_2 = 23
Some_var_1 = 22
someVar2 = 23
SOME_VAR_3 = 24
print(<caret>)
def some_fun_2(some_param):
some_var_1 = 22
some_var_2 = 23

View File

@@ -0,0 +1,12 @@
def some_fun(some_param):
some_var_1 = 22
some_var_2 = 23
some_var_1 = 22
some_var_2 = 23
SOME_VAR_3 = 24
print(<caret>)
def some_fun_2(some_param):
some_var_1 = 22
some_var_2 = 23

View File

@@ -0,0 +1,3 @@
def foo(someParam1, someParam2):
print(someParam1)
print(someP<caret>)

View File

@@ -0,0 +1,5 @@
abacaba = 42
def foo(abacaba):
print(abacaba)
print(<caret>)

View File

@@ -0,0 +1,9 @@
def foo_one(oneTwoThree):
one = 1
ones = 11 # the same as one (when tokenize)
two = 2
oneTwo = one + two
four_two = 42
oneFourThreeTwo = 1432
print(two_one_three)
print(<caret>)

View File

@@ -0,0 +1,11 @@
class MyClass(object):
def __init__(self):
self.__private_var = 42
self._private_var = 11
self.instance_var = 12
def foo(self):
self.<caret>
obj = MyClass()

View File

@@ -0,0 +1,12 @@
class MyClass(object):
def __init__(self):
self.__private_var = 42
self._priv = 11
self.instance_var = 12
def foo(self):
self.__private_var = 22
obj = MyClass()
obj.<caret>

View File

@@ -1,63 +0,0 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.lookup.LookupElement
import com.jetbrains.python.fixtures.PyTestCase
import com.jetbrains.python.psi.LanguageLevel
class Py3ElementFeaturesTest: PyTestCase() {
override fun setUp() {
super.setUp()
setLanguageLevel(LanguageLevel.PYTHON35)
}
override fun getTestDataPath(): String = super.getTestDataPath() + "/codeInsight/mlcompletion"
fun testNumberOfOccurrencesFunction() = doTestNumberOfOccurrences("min", PyCompletionMlElementKind.FUNCTION, true, 1)
fun testNumberOfOccurrencesClass() = doTestNumberOfOccurrences("MyClazz", PyCompletionMlElementKind.TYPE_OR_CLASS, false, 1)
fun testNumberOfOccurrencesNamedArgs1() = doTestNumberOfOccurrences("end=", PyCompletionMlElementKind.NAMED_ARG, false, 1)
fun testNumberOfOccurrencesNamedArgs2() = doTestNumberOfOccurrences("file=", PyCompletionMlElementKind.NAMED_ARG, false, 0)
fun testNumberOfOccurrencesPackagesOrModules() = doTestNumberOfOccurrences("collections", PyCompletionMlElementKind.PACKAGE_OR_MODULE, false, 1)
fun testKindNamedArg() = doTestElementInfo("sep=", PyCompletionMlElementKind.NAMED_ARG, false)
fun testClassBuiltins() = doTestElementInfo("Exception", PyCompletionMlElementKind.TYPE_OR_CLASS, true)
fun testClassNotBuiltins() = doTestElementInfo("MyClazz", PyCompletionMlElementKind.TYPE_OR_CLASS, false)
fun testFunctionBuiltins() = doTestElementInfo("max", PyCompletionMlElementKind.FUNCTION, true)
fun testFunctionNotBuiltins() = doTestElementInfo("my_not_builtins_function", PyCompletionMlElementKind.FUNCTION, false)
fun testKindPackageOrModule() = doTestElementInfo("sys", PyCompletionMlElementKind.PACKAGE_OR_MODULE, false)
fun testKindFromTarget1() = doTestElementInfo("local_variable", PyCompletionMlElementKind.FROM_TARGET, false)
fun testKindFromTarget2() = doTestElementInfo("as_target", PyCompletionMlElementKind.FROM_TARGET, false)
fun testKindKeyword() = doTestElementInfo("if", PyCompletionMlElementKind.KEYWORD, false)
private fun invokeCompletionAndGetLookupElement(elementName: String): LookupElement? {
myFixture.configureByFile(getTestName(true) + ".py")
myFixture.completeBasic()
val elements = myFixture.lookupElements!!
return elements.find { it.lookupString == elementName }
}
private fun doTestElementInfo(elementName: String, expectedKind: PyCompletionMlElementKind, expectedIsBuiltins: Boolean) {
val lookupElement = invokeCompletionAndGetLookupElement(elementName)!!
val info = PyCompletionFeatures.getPyLookupElementInfo(lookupElement)!!
assertEquals(expectedKind, info.kind)
assertEquals(expectedIsBuiltins, info.isBuiltins)
}
private fun doTestNumberOfOccurrences(elementName: String, expectedKind: PyCompletionMlElementKind, expectedIsBuiltins: Boolean, expectedNumberOfOccurrences: Int) {
doTestElementInfo(elementName, expectedKind, expectedIsBuiltins)
val num = PyCompletionFeatures.getNumberOfOccurrencesInScope(expectedKind, myFixture.lookup.psiElement!!, elementName)
assertEquals(expectedNumberOfOccurrences, num)
}
}

View File

@@ -1,84 +0,0 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.lookup.impl.LookupImpl
import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.mlcompletion.PyCompletionFeatures
import com.jetbrains.python.codeInsight.mlcompletion.PyMlCompletionHelpers
import com.jetbrains.python.fixtures.PyTestCase
import com.jetbrains.python.psi.LanguageLevel
class Py3LocationFeaturesTest: PyTestCase() {
override fun getTestDataPath(): String = super.getTestDataPath() + "/codeInsight/mlcompletion"
override fun setUp() {
super.setUp()
setLanguageLevel(LanguageLevel.PYTHON35)
}
fun testIsInCondition1() = doTestLocationBinaryFeature(PyCompletionFeatures::isInCondition, true)
fun testIsInCondition2() = doTestLocationBinaryFeature(PyCompletionFeatures::isInCondition, true)
fun testIsInCondition3() = doTestLocationBinaryFeature(PyCompletionFeatures::isInCondition, false)
fun testIsInCondition4() = doTestLocationBinaryFeature(PyCompletionFeatures::isInCondition, false)
fun testIsInCondition5() = doTestLocationBinaryFeature(PyCompletionFeatures::isInCondition, false)
fun testIsInCondition6() = doTestLocationBinaryFeature(PyCompletionFeatures::isInCondition, true)
fun testIsInFor1() = doTestLocationBinaryFeature(PyCompletionFeatures::isInForStatement, true)
fun testIsInFor2() = doTestLocationBinaryFeature(PyCompletionFeatures::isInForStatement, true)
fun testIsInFor3() = doTestLocationBinaryFeature(PyCompletionFeatures::isInForStatement, true)
fun testIsInFor4() = doTestLocationBinaryFeature(PyCompletionFeatures::isInForStatement, true)
fun testIsInFor5() = doTestLocationBinaryFeature(PyCompletionFeatures::isInForStatement, false)
fun testIsAfterIfWithoutElse1() = doTestLocationBinaryFeature(PyCompletionFeatures::isAfterIfStatementWithoutElseBranch, true)
fun testIsAfterIfWithoutElse2() = doTestLocationBinaryFeature(PyCompletionFeatures::isAfterIfStatementWithoutElseBranch, false)
fun testIsAfterIfWithoutElse3() = doTestLocationBinaryFeature(PyCompletionFeatures::isAfterIfStatementWithoutElseBranch, false)
fun testIsAfterIfWithoutElse4() = doTestLocationBinaryFeature(PyCompletionFeatures::isAfterIfStatementWithoutElseBranch, true)
fun testIsAfterIfWithoutElse5() = doTestLocationBinaryFeature(PyCompletionFeatures::isAfterIfStatementWithoutElseBranch, false)
fun testIsDirectlyInArgumentsContext1() = doTestLocationBinaryFeature(PyCompletionFeatures::isDirectlyInArgumentsContext, true)
fun testIsDirectlyInArgumentsContext2() = doTestLocationBinaryFeature(PyCompletionFeatures::isDirectlyInArgumentsContext, true)
fun testIsDirectlyInArgumentsContext3() = doTestLocationBinaryFeature(PyCompletionFeatures::isDirectlyInArgumentsContext, false)
fun testIsDirectlyInArgumentsContext4() = doTestLocationBinaryFeature(PyCompletionFeatures::isDirectlyInArgumentsContext, true)
fun testPrevNeighbourKeywords1() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevNeighboursKeywordIds, arrayListOf("in"))
fun testPrevNeighbourKeywords2() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevNeighboursKeywordIds, arrayListOf("in", "not"))
fun testSameLineKeywords1() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameLine, arrayListOf("in", "if"))
fun testSameLineKeywords2() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameLine, arrayListOf("in", "if"))
fun testSameColumnKeywords1() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameColumn, arrayListOf("elif", "if"))
fun testSameColumnKeywords2() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameColumn, arrayListOf("if"))
fun testSameColumnKeywords3() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameColumn, arrayListOf("def", "def"))
fun testSameColumnKeywords4() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameColumn, arrayListOf("def", "def"))
fun testSameColumnKeywords5() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameColumn, arrayListOf("if", "for"))
fun testSameColumnKeywords6() = doTestPrevKeywordsFeature(PyCompletionFeatures::getPrevKeywordsIdsInTheSameColumn, arrayListOf("if"))
private fun doTestPrevKeywordsFeature(f: (PsiElement, Int) -> ArrayList<Int>, expectedPrevKws: ArrayList<String>) {
val locationPsi = invokeCompletionAndGetLocationPsi()
val actualPrevKwsIds = f(locationPsi, 2)
checkPrevKeywordsEquals(expectedPrevKws, actualPrevKwsIds)
}
private fun checkPrevKeywordsEquals(expectedPrevKws: ArrayList<String>, actualKeywordsIds: ArrayList<Int>) {
assertEquals(expectedPrevKws.size, actualKeywordsIds.size)
actualKeywordsIds.forEachIndexed { index, id ->
assertEquals(id, PyMlCompletionHelpers.getKeywordId(expectedPrevKws[index]))
}
}
private fun doTestLocationBinaryFeature(f: (PsiElement) -> Boolean, expectedResult: Boolean) {
val locationPsi = invokeCompletionAndGetLocationPsi()
assertEquals(expectedResult, f(locationPsi))
}
private fun invokeCompletionAndGetLocationPsi(): PsiElement {
val lookup = invokeCompletionAndGetLookup()
return lookup.psiElement!!
}
private fun invokeCompletionAndGetLookup(): LookupImpl {
myFixture.configureByFile(getTestName(true) + ".py")
myFixture.completeBasic()
return myFixture.lookup as LookupImpl
}
}

View File

@@ -0,0 +1,32 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.CompletionLocation
import com.intellij.codeInsight.completion.ml.*
import com.intellij.codeInsight.lookup.LookupElement
class PyAdapterElementFeatureProvider(private val delegate: ElementFeatureProvider) : ElementFeatureProvider {
val features: MutableMap<LookupElement, Map<String, MLFeatureValue>> = hashMapOf()
override fun getName(): String = delegate.name
override fun calculateFeatures(element: LookupElement,
location: CompletionLocation,
contextFeatures: ContextFeatures): MutableMap<String, MLFeatureValue> {
val calculatedFeatures = delegate.calculateFeatures(element, location, contextFeatures)
features[element] = calculatedFeatures
return calculatedFeatures
}
}
class PyAdapterContextFeatureProvider(private val delegate: ContextFeatureProvider): ContextFeatureProvider {
val features: MutableMap<String, MLFeatureValue> = hashMapOf()
override fun getName(): String = delegate.name
override fun calculateFeatures(environment: CompletionEnvironment): Map<String, MLFeatureValue> {
val calculatedFeatures = delegate.calculateFeatures(environment)
features.putAll(calculatedFeatures)
return calculatedFeatures
}
}

View File

@@ -0,0 +1,335 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.codeInsight.mlcompletion
import com.intellij.codeInsight.completion.ml.ContextFeatureProvider
import com.intellij.codeInsight.completion.ml.ElementFeatureProvider
import com.intellij.codeInsight.completion.ml.MLFeatureValue
import com.intellij.codeInsight.lookup.LookupElement
import com.jetbrains.python.PythonLanguage
import com.jetbrains.python.fixtures.PyTestCase
import com.jetbrains.python.psi.LanguageLevel
import org.junit.Assert
class PyMlCompletionFeaturesTest: PyTestCase() {
override fun getTestDataPath(): String = super.getTestDataPath() + "/codeInsight/mlcompletion"
override fun setUp() {
super.setUp()
setLanguageLevel(LanguageLevel.PYTHON35)
}
// Context features
fun testIsInConditionSimpleIf() = doContextFeaturesTest(Pair("is_in_condition", MLFeatureValue.binary(true)))
fun testIsInConditionSimpleElif() = doContextFeaturesTest(Pair("is_in_condition", MLFeatureValue.binary(true)))
fun testIsInConditionIfBody() = doContextFeaturesTest(Pair("is_in_condition", MLFeatureValue.binary(false)))
fun testIsInConditionIfBodyNonZeroPrefix() = doContextFeaturesTest(Pair("is_in_condition", MLFeatureValue.binary(false)))
fun testIsInConditionArgumentContextOfCall() = doContextFeaturesTest(Pair("is_in_condition", MLFeatureValue.binary(false)))
fun testIsInConditionWhileConditionStatement() = doContextFeaturesTest(Pair("is_in_condition", MLFeatureValue.binary(true)))
fun testIsInForSimple() = doContextFeaturesTest(Pair("is_in_for_statement", MLFeatureValue.binary(true)))
fun testIsInForOneLetterPrefix() = doContextFeaturesTest(Pair("is_in_for_statement", MLFeatureValue.binary(true)))
fun testIsInForAfterIn() = doContextFeaturesTest(Pair("is_in_for_statement", MLFeatureValue.binary(true)))
fun testIsInForAfterInOneLetterPrefix() = doContextFeaturesTest(Pair("is_in_for_statement", MLFeatureValue.binary(true)))
fun testIsInForBody() = doContextFeaturesTest(Pair("is_in_for_statement", MLFeatureValue.binary(false)))
fun testIsAfterIfWithoutElseSimple() = doContextFeaturesTest(Pair("is_after_if_statement_without_else_branch", MLFeatureValue.binary(true)))
fun testIsAfterIfWithoutElseAfterElse() = doContextFeaturesTest(Pair("is_after_if_statement_without_else_branch", MLFeatureValue.binary(false)))
fun testIsAfterIfWithoutElseAfterSameLevelLine() = doContextFeaturesTest(Pair("is_after_if_statement_without_else_branch", MLFeatureValue.binary(false)))
fun testIsAfterIfWithoutElseAfterElifOneLetterPrefix() = doContextFeaturesTest(Pair("is_after_if_statement_without_else_branch", MLFeatureValue.binary(true)))
fun testIsAfterIfWithoutElseNestedIfAfterElse() = doContextFeaturesTest(Pair("is_after_if_statement_without_else_branch", MLFeatureValue.binary(false)))
fun testIsAfterIfWithoutElseNestedIfOneLetterPrefix() = doContextFeaturesTest(Pair("is_after_if_statement_without_else_branch", MLFeatureValue.binary(true)))
fun testIsDirectlyInArgumentsContextSimple() = doContextFeaturesTest(Pair("is_directly_in_arguments_context", MLFeatureValue.binary(true)))
fun testIsDirectlyInArgumentsContextSecondArgumentWithPrefix() = doContextFeaturesTest(Pair("is_directly_in_arguments_context", MLFeatureValue.binary(true)))
fun testIsDirectlyInArgumentsContextAfterNamedParameter() = doContextFeaturesTest(Pair("is_directly_in_arguments_context", MLFeatureValue.binary(false)))
fun testIsDirectlyInArgumentsContextInNestedCall() = doContextFeaturesTest(Pair("is_directly_in_arguments_context", MLFeatureValue.binary(true)))
fun testArgumentFeaturesFirstArg() = doContextFeaturesTest(Pair("is_in_arguments", MLFeatureValue.binary(true)),
Pair("is_directly_in_arguments_context", MLFeatureValue.binary(true)),
Pair("is_into_keyword_arg", MLFeatureValue.binary(false)),
Pair("have_named_arg_left", MLFeatureValue.binary(false)),
Pair("have_named_arg_right", MLFeatureValue.binary(false)),
Pair("argument_index", MLFeatureValue.numerical(0)),
Pair("number_of_arguments_already", MLFeatureValue.numerical(1)))
fun testArgumentFeaturesThirdArg() = doContextFeaturesTest(Pair("is_in_arguments", MLFeatureValue.binary(true)),
Pair("is_directly_in_arguments_context", MLFeatureValue.binary(true)),
Pair("is_into_keyword_arg", MLFeatureValue.binary(false)),
Pair("have_named_arg_left", MLFeatureValue.binary(true)),
Pair("have_named_arg_right", MLFeatureValue.binary(true)),
Pair("argument_index", MLFeatureValue.numerical(2)),
Pair("number_of_arguments_already", MLFeatureValue.numerical(4)))
fun testArgumentFeaturesInNamedArg() = doContextFeaturesTest(Pair("is_in_arguments", MLFeatureValue.binary(true)),
Pair("is_directly_in_arguments_context", MLFeatureValue.binary(false)),
Pair("is_into_keyword_arg", MLFeatureValue.binary(true)),
Pair("have_named_arg_left", MLFeatureValue.binary(true)),
Pair("have_named_arg_right", MLFeatureValue.binary(true)),
Pair("argument_index", MLFeatureValue.numerical(2)),
Pair("number_of_arguments_already", MLFeatureValue.numerical(4)))
fun testPrevNeighbourKeywordsIfSomethingIn() = doContextFeaturesTest(arrayListOf(Pair("prev_neighbour_keyword_1", MLFeatureValue.numerical(kwId("in")))),
arrayListOf("prev_neighbour_keyword_2"))
fun testPrevNeighbourKeywordsNotIn() = doContextFeaturesTest(Pair("prev_neighbour_keyword_1", MLFeatureValue.numerical(kwId("in"))),
Pair("prev_neighbour_keyword_2", MLFeatureValue.numerical(kwId("not"))))
fun testSameLineKeywordsIfSomethingIn() = doContextFeaturesTest(Pair("prev_same_line_keyword_1", MLFeatureValue.numerical(kwId("in"))),
Pair("prev_same_line_keyword_2", MLFeatureValue.numerical(kwId("if"))))
fun testSameLineKeywordsIfSomethingInWithPrevLine() = doContextFeaturesTest(Pair("prev_same_line_keyword_1", MLFeatureValue.numerical(kwId("in"))),
Pair("prev_same_line_keyword_2", MLFeatureValue.numerical(kwId("if"))))
fun testSameColumnKeywordsIfElif() = doContextFeaturesTest(Pair("prev_same_column_keyword_1", MLFeatureValue.numerical(kwId("elif"))),
Pair("prev_same_column_keyword_2", MLFeatureValue.numerical(kwId("if"))))
fun testSameColumnKeywordsIfSeparateLineIf() = doContextFeaturesTest(arrayListOf(Pair("prev_same_column_keyword_1", MLFeatureValue.numerical(kwId("if")))),
arrayListOf("prev_same_column_keyword_2"))
fun testSameColumnKeywordsDefDef() = doContextFeaturesTest(Pair("prev_same_column_keyword_1", MLFeatureValue.numerical(kwId("def"))),
Pair("prev_same_column_keyword_2", MLFeatureValue.numerical(kwId("def"))))
fun testSameColumnKeywordsDefDefIntoCalss() = doContextFeaturesTest(Pair("prev_same_column_keyword_1", MLFeatureValue.numerical(kwId("def"))),
Pair("prev_same_column_keyword_2", MLFeatureValue.numerical(kwId("def"))))
fun testSameColumnKeywordsIfFor() = doContextFeaturesTest(Pair("prev_same_column_keyword_1", MLFeatureValue.numerical(kwId("if"))),
Pair("prev_same_column_keyword_2", MLFeatureValue.numerical(kwId("for"))))
fun testSameColumnKeywordsForSeparateLineIf() = doContextFeaturesTest(arrayListOf(Pair("prev_same_column_keyword_1", MLFeatureValue.numerical(kwId("if")))),
arrayListOf("prev_same_column_keyword_2"))
fun testHaveOpeningRoundBracket() = doContextFeaturesTest(Pair("have_opening_round_bracket", MLFeatureValue.binary(true)))
fun testHaveOpeningSquareBracket() = doContextFeaturesTest(
Pair("have_opening_square_bracket", MLFeatureValue.binary(true)),
Pair("have_opening_round_bracket", MLFeatureValue.binary(false)))
fun testHaveOpeningBrace() = doContextFeaturesTest(Pair("have_opening_brace", MLFeatureValue.binary(true)))
fun testInsideClassConstructorPlace() = doContextFeaturesTest(Pair("containing_class_have_constructor", MLFeatureValue.binary(false)),
Pair("diff_lines_with_class_def", MLFeatureValue.numerical(1)))
fun testInsideClassAfterConstructor() = doContextFeaturesTest(Pair("containing_class_have_constructor", MLFeatureValue.binary(true)),
Pair("diff_lines_with_class_def", MLFeatureValue.numerical(4)))
// Element features
fun testDictKey() = doElementFeaturesTest("\"dict_key\"",
Pair("is_dict_key", MLFeatureValue.binary(true)),
Pair("underscore_type", MLFeatureValue.categorical(
PyCompletionFeatures.ElementNameUnderscoreType.NO_UNDERSCORE)))
fun testIsTakesParameterSelf() = doElementFeaturesTest(listOf(
Pair("foo", listOf(Pair("is_takes_parameter_self", MLFeatureValue.binary(true)))),
Pair("__init__", listOf(Pair("is_takes_parameter_self", MLFeatureValue.binary(true)))),
Pair("bar", listOf(Pair("is_takes_parameter_self", MLFeatureValue.binary(false))))))
fun testUnderscoreTypeTwoStartEnd() = doElementFeaturesTest("__init__",
Pair("underscore_type", MLFeatureValue.categorical(
PyCompletionFeatures.ElementNameUnderscoreType.TWO_START_END)))
fun testUnderscoreTypeTwoStart() = doElementFeaturesTest(listOf(
Pair("__private_var",
listOf(Pair("underscore_type", MLFeatureValue.categorical(PyCompletionFeatures.ElementNameUnderscoreType.TWO_START)))),
Pair("_private_var",
listOf(Pair("underscore_type", MLFeatureValue.categorical(PyCompletionFeatures.ElementNameUnderscoreType.ONE_START))))))
fun testNumberOfOccurrencesFunction() = doElementFeaturesTest("min",
Pair("number_of_occurrences_in_scope", MLFeatureValue.numerical(2)),
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.FUNCTION)),
Pair("is_builtins", MLFeatureValue.binary(true)))
fun testNumberOfOccurrencesClass() = doElementFeaturesTest("MyClazz",
Pair("number_of_occurrences_in_scope", MLFeatureValue.numerical(1)),
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.TYPE_OR_CLASS)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testNumberOfOccurrencesNamedArgsWithPrefix() = doElementFeaturesTest("end=",
Pair("number_of_occurrences_in_scope", MLFeatureValue.numerical(1)),
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.NAMED_ARG)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testNumberOfOccurrencesNamedArgsEmptyPrefix() = doElementFeaturesTest("file=",
Pair("number_of_occurrences_in_scope", MLFeatureValue.numerical(0)),
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.NAMED_ARG)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testNumberOfOccurrencesPackagesOrModules() = doElementFeaturesTest("collections",
Pair("number_of_occurrences_in_scope", MLFeatureValue.numerical(1)),
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.PACKAGE_OR_MODULE)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testKindNamedArg() = doElementFeaturesTest("sep=",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.NAMED_ARG)),
Pair("number_of_tokens", MLFeatureValue.numerical(1)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testClassBuiltins() = doElementFeaturesTest("Exception",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.TYPE_OR_CLASS)),
Pair("number_of_tokens", MLFeatureValue.numerical(1)),
Pair("is_builtins", MLFeatureValue.binary(true)))
fun testClassNotBuiltins() = doElementFeaturesTest("MyClazz",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.TYPE_OR_CLASS)),
Pair("number_of_tokens", MLFeatureValue.numerical(2)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testFunctionBuiltins() = doElementFeaturesTest("max",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.FUNCTION)),
Pair("is_builtins", MLFeatureValue.binary(true)))
fun testFunctionNotBuiltins() = doElementFeaturesTest("my_not_builtins_function",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.FUNCTION)),
Pair("number_of_tokens", MLFeatureValue.numerical(4)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testKindPackageOrModule() = doElementFeaturesTest("sys",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.PACKAGE_OR_MODULE)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testKindFromTargetAssignment() = doElementFeaturesTest("local_variable",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.FROM_TARGET)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testKindFromTargetAs() = doElementFeaturesTest("as_target",
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.FROM_TARGET)),
Pair("is_builtins", MLFeatureValue.binary(false)))
fun testKindKeyword() = doElementFeaturesTest("if",
arrayListOf(
Pair("kind", MLFeatureValue.categorical(PyCompletionMlElementKind.KEYWORD)),
Pair("keyword_id", MLFeatureValue.numerical(kwId("if"))),
Pair("is_builtins", MLFeatureValue.binary(false))),
arrayListOf("standard_type"))
fun testScopeMatchesSimple() = doElementFeaturesTest("abacaba",
Pair("scope_num_names", MLFeatureValue.numerical(5)),
Pair("scope_num_different_names", MLFeatureValue.numerical(3)),
Pair("scope_num_matches", MLFeatureValue.numerical(2)),
Pair("scope_num_tokens_matches", MLFeatureValue.numerical(2)))
fun testScopeMatchesTokens() = doElementFeaturesTest("oneFourThreeTwo",
Pair("scope_num_names", MLFeatureValue.numerical(13)),
Pair("scope_num_different_names", MLFeatureValue.numerical(10)),
Pair("scope_num_matches", MLFeatureValue.numerical(1)),
Pair("scope_num_tokens_matches", MLFeatureValue.numerical(20)))
fun testScopeMatchesNonEmptyPrefix() = doElementFeaturesTest("someParam1",
Pair("scope_num_names", MLFeatureValue.numerical(6)),
Pair("scope_num_different_names", MLFeatureValue.numerical(4)),
Pair("scope_num_matches", MLFeatureValue.numerical(2)),
Pair("scope_num_tokens_matches", MLFeatureValue.numerical(6)))
fun testScopeFileDontConsiderFunctionBodies() = doElementFeaturesTest("SOME_VAR_3",
Pair("scope_num_names", MLFeatureValue.numerical(6)),
Pair("scope_num_different_names", MLFeatureValue.numerical(6)),
Pair("scope_num_matches", MLFeatureValue.numerical(1)),
Pair("scope_num_tokens_matches", MLFeatureValue.numerical(8)))
fun testScopeClassDontConsiderFunctionBodies() = doElementFeaturesTest("SOME_VAR_3",
Pair("scope_num_names", MLFeatureValue.numerical(7)),
Pair("scope_num_different_names", MLFeatureValue.numerical(7)),
Pair("scope_num_matches", MLFeatureValue.numerical(1)),
Pair("scope_num_tokens_matches", MLFeatureValue.numerical(9)))
fun testSameLineMatchingSimple() = doElementFeaturesTest("some_param3",
Pair("same_line_num_names", MLFeatureValue.numerical(6)),
Pair("same_line_num_different_names", MLFeatureValue.numerical(5)),
Pair("same_line_num_matches", MLFeatureValue.numerical(0)),
Pair("same_line_num_tokens_matches", MLFeatureValue.numerical(5)))
fun testReceiverMatchesSimple() = doElementFeaturesTest("someParam",
Pair("receiver_name_matches", MLFeatureValue.binary(true)),
Pair("receiver_num_matched_tokens", MLFeatureValue.numerical(2)),
Pair("receiver_tokens_num", MLFeatureValue.numerical(2)))
fun testReceiverMatchesTokens() = doElementFeaturesTest("someToken2_sets1",
Pair("receiver_name_matches", MLFeatureValue.binary(false)),
Pair("receiver_num_matched_tokens", MLFeatureValue.numerical(3)),
Pair("receiver_tokens_num", MLFeatureValue.numerical(3)))
fun testReceiverMatchesAssignment() = doElementFeaturesTest("abaCaba",
Pair("receiver_name_matches", MLFeatureValue.binary(true)),
Pair("receiver_num_matched_tokens", MLFeatureValue.numerical(2)),
Pair("receiver_tokens_num", MLFeatureValue.numerical(2)))
fun testReceiverMatchesTokensUpperCase() = doElementFeaturesTest("SOME_TOKENS_SET",
Pair("receiver_name_matches", MLFeatureValue.binary(false)),
Pair("receiver_num_matched_tokens", MLFeatureValue.numerical(3)),
Pair("receiver_tokens_num", MLFeatureValue.numerical(3)))
private fun doContextFeaturesTest(vararg expected: Pair<String, MLFeatureValue>) = doContextFeaturesTest(listOf(*expected), emptyList())
private fun doContextFeaturesTest(expectedDefined: List<Pair<String, MLFeatureValue>>, expectedUndefined: List<String>) {
doWithInstalledProviders { contextFeaturesProvider, _ ->
invokeCompletion()
assertHasFeatures(contextFeaturesProvider.features, expectedDefined)
assertHasNotFeatures(contextFeaturesProvider.features, expectedUndefined)
}
}
private fun doElementFeaturesTest(checks: List<Pair<String, List<Pair<String, MLFeatureValue>>>>) {
checks.forEach {
doElementFeaturesTest(it.first, it.second, emptyList())
}
}
private fun doElementFeaturesTest(elementToSelect: String, vararg expected: Pair<String, MLFeatureValue>) {
doElementFeaturesTest(elementToSelect, arrayListOf(*expected), emptyList())
}
private fun doElementFeaturesTest(elementToSelect: String,
expectedDefined: List<Pair<String, MLFeatureValue>>,
expectedUndefined: List<String>) {
val selector: (LookupElement) -> Boolean = { it.lookupString == elementToSelect }
doElementFeaturesInternalTest(selector, expectedDefined, expectedUndefined)
}
private fun doElementFeaturesInternalTest(selector: (LookupElement) -> Boolean,
expectedDefined: List<Pair<String, MLFeatureValue>>,
expectedUndefined: List<String>) {
doWithInstalledProviders { _, elementFeaturesProvider ->
invokeCompletion()
val selected = myFixture.lookupElements!!.find(selector)
assertNotNull(selected)
val features = elementFeaturesProvider.features[selected]
assertNotNull(features)
assertHasFeatures(features!!, expectedDefined)
assertHasNotFeatures(features, expectedUndefined)
}
}
private fun doWithInstalledProviders(action: (contextFeaturesProvider: PyAdapterContextFeatureProvider,
elementFeaturesProvider: PyAdapterElementFeatureProvider) -> Unit) {
val contextFeaturesProvider = PyAdapterContextFeatureProvider(PyContextFeatureProvider())
val elementFeaturesProvider = PyAdapterElementFeatureProvider(PyElementFeatureProvider())
try {
ContextFeatureProvider.EP_NAME.addExplicitExtension(PythonLanguage.INSTANCE, contextFeaturesProvider)
ElementFeatureProvider.EP_NAME.addExplicitExtension(PythonLanguage.INSTANCE, elementFeaturesProvider)
action(contextFeaturesProvider, elementFeaturesProvider)
}
finally {
ContextFeatureProvider.EP_NAME.removeExplicitExtension(PythonLanguage.INSTANCE, contextFeaturesProvider)
ElementFeatureProvider.EP_NAME.removeExplicitExtension(PythonLanguage.INSTANCE, elementFeaturesProvider)
}
}
private fun assertHasFeatures(actual: Map<String, MLFeatureValue>,
expectedDefined: List<Pair<String, MLFeatureValue>>) {
for (pair in expectedDefined) {
Assert.assertTrue("Assert has feature: ${pair.first}", actual.containsKey(pair.first))
Assert.assertEquals("Check feature value: ${pair.first}", pair.second.toString(), actual[pair.first].toString())
}
}
private fun assertHasNotFeatures(actual: Map<String, MLFeatureValue>, expected: List<String>) {
for (value in expected) {
Assert.assertFalse("Assert has not feature: $value", actual.containsKey(value))
}
}
private fun invokeCompletion() {
myFixture.configureByFile(getTestName(true) + ".py")
myFixture.completeBasic()
}
private fun kwId(kw: String): Int {
return PyMlCompletionHelpers.getKeywordId(kw)!!
}
}

View File

@@ -36,5 +36,6 @@
<orderEntry type="module" module-name="intellij.xml.langInjection" scope="RUNTIME" />
<orderEntry type="module" module-name="intellij.python.langInjection" scope="RUNTIME" />
<orderEntry type="module" module-name="intellij.python.reStructuredText" scope="RUNTIME" />
<orderEntry type="module" module-name="intellij.statsCollector" scope="TEST" />
</component>
</module>
</module>