mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-04-19 04:51:24 +07:00
Provide completion and resolve for from tensorflow.<...> statements (PY-33034)
GitOrigin-RevId: 5886d44f185f918cf0089249209495de3145b6d1
This commit is contained in:
committed by
intellij-monorepo-bot
parent
c47d883ba3
commit
16a77f4042
@@ -8,6 +8,7 @@ import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.PsiPackage;
|
||||
import com.intellij.psi.util.QualifiedName;
|
||||
import com.jetbrains.python.psi.resolve.PyQualifiedNameResolveContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
/**
|
||||
@@ -16,7 +17,7 @@ import org.jetbrains.annotations.Nullable;
|
||||
public class PyJavaImportResolver implements PyImportResolver {
|
||||
@Override
|
||||
@Nullable
|
||||
public PsiElement resolveImportReference(QualifiedName name, PyQualifiedNameResolveContext context, boolean withRoots) {
|
||||
public PsiElement resolveImportReference(@NotNull QualifiedName name, @NotNull PyQualifiedNameResolveContext context, boolean withRoots) {
|
||||
String fqn = name.toString();
|
||||
final JavaPsiFacade psiFacade = JavaPsiFacade.getInstance(context.getProject());
|
||||
final PsiPackage aPackage = psiFacade.findPackage(fqn);
|
||||
|
||||
@@ -19,6 +19,7 @@ import com.intellij.openapi.extensions.ExtensionPointName;
|
||||
import com.intellij.psi.PsiElement;
|
||||
import com.intellij.psi.util.QualifiedName;
|
||||
import com.jetbrains.python.psi.resolve.PyQualifiedNameResolveContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
/**
|
||||
@@ -28,5 +29,5 @@ public interface PyImportResolver {
|
||||
ExtensionPointName<PyImportResolver> EP_NAME = ExtensionPointName.create("Pythonid.importResolver");
|
||||
|
||||
@Nullable
|
||||
PsiElement resolveImportReference(QualifiedName name, PyQualifiedNameResolveContext context, boolean withRoots);
|
||||
PsiElement resolveImportReference(@NotNull QualifiedName name, @NotNull PyQualifiedNameResolveContext context, boolean withRoots);
|
||||
}
|
||||
|
||||
@@ -807,6 +807,10 @@
|
||||
<!-- PyQt -->
|
||||
<typeProvider implementation="com.jetbrains.pyqt.PyQtTypeProvider"/>
|
||||
|
||||
<!-- TensorFlow -->
|
||||
<importResolver implementation="com.jetbrains.tensorFlow.PyTensorFlowImportResolver"/>
|
||||
<pyModuleMembersProvider implementation="com.jetbrains.tensorFlow.PyTensorFlowModuleMembersProvider"/>
|
||||
|
||||
<!-- Type from ancestors -->
|
||||
<typeProvider implementation="com.jetbrains.python.codeInsight.typing.PyAncestorTypeProvider"/>
|
||||
|
||||
|
||||
@@ -0,0 +1,47 @@
|
||||
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.tensorFlow
|
||||
|
||||
import com.intellij.psi.PsiElement
|
||||
import com.intellij.psi.util.QualifiedName
|
||||
import com.jetbrains.python.psi.impl.PyImportResolver
|
||||
import com.jetbrains.python.psi.resolve.PyQualifiedNameResolveContext
|
||||
import com.jetbrains.python.psi.resolve.resolveQualifiedName
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
internal const val KERAS: String = "tensorflow.python.keras.api._v[__VERSION__].keras"
|
||||
internal const val ESTIMATOR: String = "tensorflow_estimator.python.estimator.api._v[__VERSION__].estimator"
|
||||
internal const val OTHERS: String = "tensorflow._api.v[__VERSION__]"
|
||||
|
||||
internal fun resolveInTensorFlow(qualifiedNameTemplate: String, context: PyQualifiedNameResolveContext): Sequence<List<PsiElement>> {
|
||||
return sequenceOf("2", "1")
|
||||
.map { qualifiedNameTemplate.replaceFirst("[__VERSION__]", it) }
|
||||
.map { QualifiedName.fromDottedString(it) }
|
||||
.map { resolveQualifiedName(it, context) }
|
||||
}
|
||||
|
||||
internal fun takeFirstResolvedInTensorFlow(qualifiedNameTemplate: String, context: PyQualifiedNameResolveContext): PsiElement? {
|
||||
return resolveInTensorFlow(qualifiedNameTemplate, context).mapNotNull { it.firstOrNull() }.firstOrNull()
|
||||
}
|
||||
|
||||
@ApiStatus.Internal
|
||||
@ApiStatus.NonExtendable
|
||||
class PyTensorFlowImportResolver : PyImportResolver {
|
||||
|
||||
override fun resolveImportReference(name: QualifiedName, context: PyQualifiedNameResolveContext, withRoots: Boolean): PsiElement? {
|
||||
// resolve `from tensorflow.<reference>` reference
|
||||
// tensorflow submodules and subpackages are appended in runtime and have original location in other places
|
||||
|
||||
return when {
|
||||
name.matchesPrefix(QualifiedName.fromComponents("tensorflow", "keras")) -> {
|
||||
takeFirstResolvedInTensorFlow("$KERAS.${name.removeHead(2)}", context.copyWithoutForeign())
|
||||
}
|
||||
name.matchesPrefix(QualifiedName.fromComponents("tensorflow", "estimator")) -> {
|
||||
takeFirstResolvedInTensorFlow("$ESTIMATOR.${name.removeHead(2)}", context.copyWithoutForeign())
|
||||
}
|
||||
name.firstComponent == "tensorflow" && name.componentCount >= 2 -> {
|
||||
takeFirstResolvedInTensorFlow("$OTHERS.${name.removeHead(1)}", context.copyWithoutForeign())
|
||||
}
|
||||
else -> null
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,54 @@
|
||||
// Copyright 2000-2019 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||
package com.jetbrains.tensorFlow
|
||||
|
||||
import com.intellij.psi.PsiDirectory
|
||||
import com.intellij.psi.PsiElement
|
||||
import com.jetbrains.python.codeInsight.PyCustomMember
|
||||
import com.jetbrains.python.psi.PyFile
|
||||
import com.jetbrains.python.psi.PyUtil
|
||||
import com.jetbrains.python.psi.resolve.PointInImport
|
||||
import com.jetbrains.python.psi.resolve.PyResolveContext
|
||||
import com.jetbrains.python.psi.resolve.QualifiedNameFinder
|
||||
import com.jetbrains.python.psi.resolve.fromFoothold
|
||||
import com.jetbrains.python.psi.types.PyModuleMembersProvider
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext
|
||||
import org.jetbrains.annotations.ApiStatus
|
||||
|
||||
@ApiStatus.Internal
|
||||
@ApiStatus.NonExtendable
|
||||
class PyTensorFlowModuleMembersProvider : PyModuleMembersProvider() {
|
||||
|
||||
override fun getMembers(module: PyFile, point: PointInImport, context: TypeEvalContext): Collection<PyCustomMember> {
|
||||
if (point != PointInImport.AS_MODULE ||
|
||||
QualifiedNameFinder.findShortestImportableQName(module).let { it == null || !it.matches("tensorflow") }) return emptyList()
|
||||
|
||||
// provide members for `from tensorflow.<caret>` reference
|
||||
// provided modules and subpackages are appended in runtime and have original location in other places
|
||||
|
||||
val resolveContext = fromFoothold(module).copyWithoutForeign()
|
||||
val result = mutableListOf<PyCustomMember>()
|
||||
|
||||
takeFirstResolvedInTensorFlow(KERAS, resolveContext)
|
||||
?.let { PyUtil.turnDirIntoInit(it) }
|
||||
?.let { result.add(PyCustomMember("keras", it)) }
|
||||
|
||||
takeFirstResolvedInTensorFlow(ESTIMATOR, resolveContext)
|
||||
?.let { PyUtil.turnDirIntoInit(it) }
|
||||
?.let { result.add(PyCustomMember("estimator", it)) }
|
||||
|
||||
resolveInTensorFlow(OTHERS, resolveContext)
|
||||
.mapNotNull { it.filterIsInstance<PsiDirectory>().firstOrNull() }
|
||||
.firstOrNull()
|
||||
?.let { dir ->
|
||||
dir.subdirectories.forEach { subdir ->
|
||||
PyUtil.turnDirIntoInit(subdir)?.let { result.add(PyCustomMember(subdir.name, it)) }
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
override fun resolveMember(module: PyFile, name: String, resolveContext: PyResolveContext): PsiElement? = null
|
||||
|
||||
override fun getMembersByQName(module: PyFile, qName: String, context: TypeEvalContext): Collection<PyCustomMember> = emptyList()
|
||||
}
|
||||
Reference in New Issue
Block a user