Provide completion and resolve for from tensorflow.<...> statements (PY-33034)

GitOrigin-RevId: 5886d44f185f918cf0089249209495de3145b6d1
This commit is contained in:
Semyon Proshev
2019-07-09 12:52:33 +03:00
committed by intellij-monorepo-bot
parent c47d883ba3
commit 16a77f4042
5 changed files with 109 additions and 2 deletions

View File

@@ -8,6 +8,7 @@ import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiPackage;
import com.intellij.psi.util.QualifiedName;
import com.jetbrains.python.psi.resolve.PyQualifiedNameResolveContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
/**
@@ -16,7 +17,7 @@ import org.jetbrains.annotations.Nullable;
public class PyJavaImportResolver implements PyImportResolver {
@Override
@Nullable
public PsiElement resolveImportReference(QualifiedName name, PyQualifiedNameResolveContext context, boolean withRoots) {
public PsiElement resolveImportReference(@NotNull QualifiedName name, @NotNull PyQualifiedNameResolveContext context, boolean withRoots) {
String fqn = name.toString();
final JavaPsiFacade psiFacade = JavaPsiFacade.getInstance(context.getProject());
final PsiPackage aPackage = psiFacade.findPackage(fqn);

View File

@@ -19,6 +19,7 @@ import com.intellij.openapi.extensions.ExtensionPointName;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.QualifiedName;
import com.jetbrains.python.psi.resolve.PyQualifiedNameResolveContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
/**
@@ -28,5 +29,5 @@ public interface PyImportResolver {
ExtensionPointName<PyImportResolver> EP_NAME = ExtensionPointName.create("Pythonid.importResolver");
@Nullable
PsiElement resolveImportReference(QualifiedName name, PyQualifiedNameResolveContext context, boolean withRoots);
PsiElement resolveImportReference(@NotNull QualifiedName name, @NotNull PyQualifiedNameResolveContext context, boolean withRoots);
}

View File

@@ -807,6 +807,10 @@
<!-- PyQt -->
<typeProvider implementation="com.jetbrains.pyqt.PyQtTypeProvider"/>
<!-- TensorFlow -->
<importResolver implementation="com.jetbrains.tensorFlow.PyTensorFlowImportResolver"/>
<pyModuleMembersProvider implementation="com.jetbrains.tensorFlow.PyTensorFlowModuleMembersProvider"/>
<!-- Type from ancestors -->
<typeProvider implementation="com.jetbrains.python.codeInsight.typing.PyAncestorTypeProvider"/>

View File

@@ -0,0 +1,47 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.tensorFlow
import com.intellij.psi.PsiElement
import com.intellij.psi.util.QualifiedName
import com.jetbrains.python.psi.impl.PyImportResolver
import com.jetbrains.python.psi.resolve.PyQualifiedNameResolveContext
import com.jetbrains.python.psi.resolve.resolveQualifiedName
import org.jetbrains.annotations.ApiStatus
internal const val KERAS: String = "tensorflow.python.keras.api._v[__VERSION__].keras"
internal const val ESTIMATOR: String = "tensorflow_estimator.python.estimator.api._v[__VERSION__].estimator"
internal const val OTHERS: String = "tensorflow._api.v[__VERSION__]"
internal fun resolveInTensorFlow(qualifiedNameTemplate: String, context: PyQualifiedNameResolveContext): Sequence<List<PsiElement>> {
return sequenceOf("2", "1")
.map { qualifiedNameTemplate.replaceFirst("[__VERSION__]", it) }
.map { QualifiedName.fromDottedString(it) }
.map { resolveQualifiedName(it, context) }
}
internal fun takeFirstResolvedInTensorFlow(qualifiedNameTemplate: String, context: PyQualifiedNameResolveContext): PsiElement? {
return resolveInTensorFlow(qualifiedNameTemplate, context).mapNotNull { it.firstOrNull() }.firstOrNull()
}
@ApiStatus.Internal
@ApiStatus.NonExtendable
class PyTensorFlowImportResolver : PyImportResolver {
override fun resolveImportReference(name: QualifiedName, context: PyQualifiedNameResolveContext, withRoots: Boolean): PsiElement? {
// resolve `from tensorflow.<reference>` reference
// tensorflow submodules and subpackages are appended in runtime and have original location in other places
return when {
name.matchesPrefix(QualifiedName.fromComponents("tensorflow", "keras")) -> {
takeFirstResolvedInTensorFlow("$KERAS.${name.removeHead(2)}", context.copyWithoutForeign())
}
name.matchesPrefix(QualifiedName.fromComponents("tensorflow", "estimator")) -> {
takeFirstResolvedInTensorFlow("$ESTIMATOR.${name.removeHead(2)}", context.copyWithoutForeign())
}
name.firstComponent == "tensorflow" && name.componentCount >= 2 -> {
takeFirstResolvedInTensorFlow("$OTHERS.${name.removeHead(1)}", context.copyWithoutForeign())
}
else -> null
}
}
}

View File

@@ -0,0 +1,54 @@
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.tensorFlow
import com.intellij.psi.PsiDirectory
import com.intellij.psi.PsiElement
import com.jetbrains.python.codeInsight.PyCustomMember
import com.jetbrains.python.psi.PyFile
import com.jetbrains.python.psi.PyUtil
import com.jetbrains.python.psi.resolve.PointInImport
import com.jetbrains.python.psi.resolve.PyResolveContext
import com.jetbrains.python.psi.resolve.QualifiedNameFinder
import com.jetbrains.python.psi.resolve.fromFoothold
import com.jetbrains.python.psi.types.PyModuleMembersProvider
import com.jetbrains.python.psi.types.TypeEvalContext
import org.jetbrains.annotations.ApiStatus
@ApiStatus.Internal
@ApiStatus.NonExtendable
class PyTensorFlowModuleMembersProvider : PyModuleMembersProvider() {
override fun getMembers(module: PyFile, point: PointInImport, context: TypeEvalContext): Collection<PyCustomMember> {
if (point != PointInImport.AS_MODULE ||
QualifiedNameFinder.findShortestImportableQName(module).let { it == null || !it.matches("tensorflow") }) return emptyList()
// provide members for `from tensorflow.<caret>` reference
// provided modules and subpackages are appended in runtime and have original location in other places
val resolveContext = fromFoothold(module).copyWithoutForeign()
val result = mutableListOf<PyCustomMember>()
takeFirstResolvedInTensorFlow(KERAS, resolveContext)
?.let { PyUtil.turnDirIntoInit(it) }
?.let { result.add(PyCustomMember("keras", it)) }
takeFirstResolvedInTensorFlow(ESTIMATOR, resolveContext)
?.let { PyUtil.turnDirIntoInit(it) }
?.let { result.add(PyCustomMember("estimator", it)) }
resolveInTensorFlow(OTHERS, resolveContext)
.mapNotNull { it.filterIsInstance<PsiDirectory>().firstOrNull() }
.firstOrNull()
?.let { dir ->
dir.subdirectories.forEach { subdir ->
PyUtil.turnDirIntoInit(subdir)?.let { result.add(PyCustomMember(subdir.name, it)) }
}
}
return result
}
override fun resolveMember(module: PyFile, name: String, resolveContext: PyResolveContext): PsiElement? = null
override fun getMembersByQName(module: PyFile, qName: String, context: TypeEvalContext): Collection<PyCustomMember> = emptyList()
}