Update PyTypeProvider.getReferenceType to return Ref (PY-28052)

This commit is contained in:
Semyon Proshev
2018-04-06 18:13:21 +03:00
parent 198f44c9af
commit 36fb23d4ab
16 changed files with 96 additions and 70 deletions

View File

@@ -15,16 +15,17 @@
*/
package com.jetbrains.python.psi.impl;
import com.intellij.openapi.module.ModuleUtil;
import com.intellij.openapi.module.Module;
import com.intellij.openapi.module.ModuleUtilCore;
import com.intellij.openapi.util.Ref;
import com.intellij.psi.*;
import com.intellij.util.Processor;
import com.jetbrains.python.psi.PyFunction;
import com.jetbrains.python.psi.PyNamedParameter;
import com.jetbrains.python.psi.PyParameterList;
import com.jetbrains.python.psi.search.PySuperMethodsSearch;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.PyTypeProviderBase;
import com.jetbrains.python.psi.types.PyTypeUtil;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -36,26 +37,27 @@ import java.util.List;
* @author yole
*/
public class PyJavaTypeProvider extends PyTypeProviderBase {
@Override
@Nullable
public PyType getReferenceType(@NotNull final PsiElement referenceTarget, TypeEvalContext context, @Nullable PsiElement anchor) {
public Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
if (referenceTarget instanceof PsiClass) {
return new PyJavaClassType((PsiClass) referenceTarget, true);
return Ref.create(new PyJavaClassType((PsiClass)referenceTarget, true));
}
if (referenceTarget instanceof PsiPackage) {
return new PyJavaPackageType((PsiPackage) referenceTarget, anchor == null ? null : ModuleUtil.findModuleForPsiElement(anchor));
final Module module = anchor == null ? null : ModuleUtilCore.findModuleForPsiElement(anchor);
return Ref.create(new PyJavaPackageType((PsiPackage)referenceTarget, module));
}
if (referenceTarget instanceof PsiMethod) {
PsiMethod method = (PsiMethod) referenceTarget;
return new PyJavaMethodType(method);
return Ref.create(new PyJavaMethodType((PsiMethod)referenceTarget));
}
if (referenceTarget instanceof PsiField) {
return asPyType(((PsiField)referenceTarget).getType());
return PyTypeUtil.notNullToRef(asPyType(((PsiField)referenceTarget).getType()));
}
return null;
}
@Nullable
public static PyType asPyType(PsiType type) {
public static PyType asPyType(@Nullable PsiType type) {
if (type instanceof PsiClassType) {
final PsiClassType classType = (PsiClassType)type;
final PsiClass psiClass = classType.resolve();
@@ -66,6 +68,7 @@ public class PyJavaTypeProvider extends PyTypeProviderBase {
return null;
}
@Override
public Ref<PyType> getParameterType(@NotNull final PyNamedParameter param,
@NotNull final PyFunction func,
@NotNull TypeEvalContext context) {

View File

@@ -36,7 +36,7 @@ public interface PyTypeProvider {
PyType getReferenceExpressionType(@NotNull PyReferenceExpression referenceExpression, @NotNull TypeEvalContext context);
@Nullable
PyType getReferenceType(@NotNull PsiElement referenceTarget, TypeEvalContext context, @Nullable PsiElement anchor);
Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor);
@Nullable
Ref<PyType> getParameterType(@NotNull PyNamedParameter param, @NotNull PyFunction func, @NotNull TypeEvalContext context);

View File

@@ -47,7 +47,7 @@ public class PyTypeProviderBase implements PyTypeProvider {
}
@Override
public PyType getReferenceType(@NotNull PsiElement referenceTarget, TypeEvalContext context, @Nullable PsiElement anchor) {
public Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
return null;
}

View File

@@ -418,16 +418,16 @@ public class NumpyDocStringTypeProvider extends PyTypeProviderBase {
}
@Override
public PyType getReferenceType(@NotNull PsiElement referenceTarget, TypeEvalContext context, @Nullable PsiElement anchor) {
public Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
if (referenceTarget instanceof PyFunction) {
if (NumpyUfuncs.isUFunc(((PyFunction)referenceTarget).getName()) && isInsideNumPy(referenceTarget)) {
// we intentionally looking here for the user stub class
final PyClass uFuncClass = PyPsiFacade.getInstance(referenceTarget.getProject()).findClass("numpy.core.ufunc");
if (uFuncClass != null) {
return new PyClassTypeImpl(uFuncClass, false);
return Ref.create(new PyClassTypeImpl(uFuncClass, false));
}
}
}
return super.getReferenceType(referenceTarget, context, anchor);
return null;
}
}

View File

@@ -15,16 +15,20 @@
*/
package com.jetbrains.python.codeInsight.stdlib
import com.intellij.openapi.util.Ref
import com.intellij.psi.PsiElement
import com.jetbrains.python.psi.impl.PyOverridingTypeProvider
import com.jetbrains.python.psi.types.PyType
import com.jetbrains.python.psi.types.PyTypeProviderBase
import com.jetbrains.python.psi.types.PyTypeUtil
import com.jetbrains.python.psi.types.TypeEvalContext
class PyNamedTuplesOverridingTypeProvider : PyTypeProviderBase(), PyOverridingTypeProvider {
override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): PyType? {
return PyNamedTuplesTypeProvider.getNamedTupleTypeForResolvedCallee(referenceTarget, context, anchor) ?:
PyNamedTuplesTypeProvider.getNamedTupleReplaceType(referenceTarget, context, anchor)
override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): Ref<PyType>? {
val type = PyNamedTuplesTypeProvider.getNamedTupleTypeForResolvedCallee(referenceTarget, context, anchor)
?: PyNamedTuplesTypeProvider.getNamedTupleReplaceType(referenceTarget, context, anchor)
return PyTypeUtil.notNullToRef(type)
}
}

View File

@@ -24,8 +24,8 @@ private typealias ImmutableNTFields = Map<String, PyNamedTupleType.FieldTypeAndD
class PyNamedTuplesTypeProvider : PyTypeProviderBase() {
override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): PyType? {
return getNamedTupleTypeForResolvedCallee(referenceTarget, context, anchor)
override fun getReferenceType(referenceTarget: PsiElement, context: TypeEvalContext, anchor: PsiElement?): Ref<PyType>? {
return PyTypeUtil.notNullToRef(getNamedTupleTypeForResolvedCallee(referenceTarget, context, anchor))
}
override fun getReferenceExpressionType(referenceExpression: PyReferenceExpression, context: TypeEvalContext): PyType? {

View File

@@ -44,14 +44,14 @@ public class PyStdlibTypeProvider extends PyTypeProviderBase {
}
@Override
public PyType getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
public Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
PyType type = getBaseStringType(referenceTarget);
if (type != null) {
return type;
return Ref.create(type);
}
type = getEnumType(referenceTarget, context, anchor);
if (type != null) {
return type;
return Ref.create(type);
}
return null;
}

View File

@@ -18,7 +18,6 @@ import com.intellij.psi.util.CachedValuesManager;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.psi.util.QualifiedName;
import com.intellij.util.ArrayUtil;
import com.intellij.util.Query;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyCustomType;
import com.jetbrains.python.PyNames;
@@ -40,7 +39,6 @@ import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.PyResolveImportUtil;
import com.jetbrains.python.psi.resolve.PyResolveUtil;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.psi.search.PySuperMethodsSearch;
import com.jetbrains.python.psi.stubs.PyClassStub;
import com.jetbrains.python.psi.stubs.PyTargetExpressionStub;
import com.jetbrains.python.psi.stubs.PyTypingNewTypeStub;
@@ -434,35 +432,35 @@ public class PyTypingTypeProvider extends PyTypeProviderBase {
}
@Override
public PyType getReferenceType(@NotNull PsiElement referenceTarget, TypeEvalContext context, @Nullable PsiElement anchor) {
public Ref<PyType> getReferenceType(@NotNull PsiElement referenceTarget, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
if (referenceTarget instanceof PyTargetExpression) {
final PyTargetExpression target = (PyTargetExpression)referenceTarget;
// Depends on typing.Generic defined as a target expression
if (GENERIC.equals(target.getQualifiedName())) {
return createTypingGenericType(target);
return Ref.create(createTypingGenericType(target));
}
// Depends on typing.Protocol defined as a target expression
if (PROTOCOL.equals(target.getQualifiedName())) {
return createTypingProtocolType(target);
return Ref.create(createTypingProtocolType(target));
}
// Depends on typing.Callable defined as a target expression
if (CALLABLE.equals(target.getQualifiedName())) {
return createTypingCallableType(referenceTarget);
return Ref.create(createTypingCallableType(referenceTarget));
}
final PyType collection = getCollection(target, context);
if (collection instanceof PyInstantiableType) {
return ((PyInstantiableType)collection).toClass();
return Ref.create(((PyInstantiableType)collection).toClass());
}
final PyType newType = getNewTypeCreationForTarget(target, context);
if (newType != null) {
return newType;
return Ref.create(newType);
}
final Ref<PyType> annotatedType = getTypeFromTargetExpressionAnnotation(target, context);
if (annotatedType != null) {
return annotatedType.get();
return annotatedType;
}
final String name = target.getReferencedName();
@@ -496,13 +494,12 @@ public class PyTypingTypeProvider extends PyTypeProviderBase {
if (classAttrs == null) {
return null;
}
final Ref<PyType> combined = StreamEx.of(classAttrs)
.map(RatedResolveResult::getElement)
.select(PyTargetExpression.class)
.filter(x -> ScopeUtil.getScopeOwner(x) instanceof PyClass)
.map(x -> getTypeFromTargetExpressionAnnotation(x, context))
.collect(PyTypeUtil.toUnionFromRef());
return Ref.deref(combined);
return StreamEx.of(classAttrs)
.map(RatedResolveResult::getElement)
.select(PyTargetExpression.class)
.filter(x -> ScopeUtil.getScopeOwner(x) instanceof PyClass)
.map(x -> getTypeFromTargetExpressionAnnotation(x, context))
.collect(PyTypeUtil.toUnionFromRef());
}
}
else {
@@ -523,7 +520,6 @@ public class PyTypingTypeProvider extends PyTypeProviderBase {
.map(x -> getTypeFromTargetExpressionAnnotation(x, context))
.nonNull()
.findFirst()
.map(Ref::get)
.orElse(null);
}
}

View File

@@ -23,6 +23,7 @@ import com.jetbrains.python.psi.PyNamedParameter;
import com.jetbrains.python.psi.PyTargetExpression;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.PyTypeProviderBase;
import com.jetbrains.python.psi.types.PyTypeUtil;
import com.jetbrains.python.psi.types.TypeEvalContext;
import com.jetbrains.python.pyi.PyiUtil;
import org.jetbrains.annotations.NotNull;
@@ -67,14 +68,14 @@ public class PyUserSkeletonsTypeProvider extends PyTypeProviderBase {
}
@Override
public PyType getReferenceType(@NotNull PsiElement target, TypeEvalContext context, @Nullable PsiElement anchor) {
public Ref<PyType> getReferenceType(@NotNull PsiElement target, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
if (PyiUtil.isInsideStub(target)) {
return null;
}
if (target instanceof PyTargetExpression) {
final PyTargetExpression targetSkeleton = PyUserSkeletonsUtil.getUserSkeletonWithContext((PyTargetExpression)target, context);
if (targetSkeleton != null) {
return context.getType(targetSkeleton);
return PyTypeUtil.notNullToRef(context.getType(targetSkeleton));
}
}
return null;

View File

@@ -517,9 +517,9 @@ public class PyCallExpressionHelper {
@Nullable
private static Ref<? extends PyType> getCallTargetReturnType(@NotNull PyCallExpression call, @NotNull PsiElement target,
@NotNull TypeEvalContext context) {
final PyType providedOverridingType = PyReferenceExpressionImpl.getReferenceTypeFromOverridingProviders(target, context, call);
if (providedOverridingType instanceof PyCallableType) {
return Ref.create(((PyCallableType)providedOverridingType).getCallType(context, call));
final Ref<PyType> providedOverridingType = PyReferenceExpressionImpl.getReferenceTypeFromOverridingProviders(target, context, call);
if (providedOverridingType != null && providedOverridingType.get() instanceof PyCallableType) {
return Ref.create(((PyCallableType)providedOverridingType.get()).getCallType(context, call));
}
PyClass cls = null;
@@ -565,9 +565,9 @@ public class PyCallExpressionHelper {
if (cls != null) {
return Ref.create(new PyClassTypeImpl(cls, false));
}
final PyType providedType = PyReferenceExpressionImpl.getReferenceTypeFromProviders(target, context, call);
if (providedType instanceof PyCallableType) {
return Ref.create(((PyCallableType)providedType).getCallType(context, call));
final Ref<PyType> providedType = PyReferenceExpressionImpl.getReferenceTypeFromProviders(target, context, call);
if (providedType != null && providedType.get() instanceof PyCallableType) {
return Ref.create(((PyCallableType)providedType.get()).getCallType(context, call));
}
final Ref<PyType> propertyCallType = getPropertyCallType(call, target, context);
if (propertyCallType != null) {

View File

@@ -431,9 +431,9 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
@NotNull TypeEvalContext context,
@NotNull PyReferenceExpression anchor) {
if (!(target instanceof PyTargetExpression)) { // PyTargetExpression will ask about its type itself
final PyType pyType = getReferenceTypeFromProviders(target, context, anchor);
final Ref<PyType> pyType = getReferenceTypeFromProviders(target, context, anchor);
if (pyType != null) {
return pyType;
return pyType.get();
}
}
if (target instanceof PyTargetExpression) {
@@ -530,9 +530,9 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
}
@Nullable
public static PyType getReferenceTypeFromOverridingProviders(@NotNull PsiElement target,
@NotNull TypeEvalContext context,
@Nullable PsiElement anchor) {
public static Ref<PyType> getReferenceTypeFromOverridingProviders(@NotNull PsiElement target,
@NotNull TypeEvalContext context,
@Nullable PsiElement anchor) {
return StreamEx
.of(Extensions.getExtensions(PyTypeProvider.EP_NAME))
.select(PyOverridingTypeProvider.class)
@@ -542,11 +542,11 @@ public class PyReferenceExpressionImpl extends PyElementImpl implements PyRefere
}
@Nullable
public static PyType getReferenceTypeFromProviders(@NotNull PsiElement target,
@NotNull TypeEvalContext context,
@Nullable PsiElement anchor) {
public static Ref<PyType> getReferenceTypeFromProviders(@NotNull PsiElement target,
@NotNull TypeEvalContext context,
@Nullable PsiElement anchor) {
for (PyTypeProvider provider : Extensions.getExtensions(PyTypeProvider.EP_NAME)) {
final PyType result = provider.getReferenceType(target, context, anchor);
final Ref<PyType> result = provider.getReferenceType(target, context, anchor);
if (result != null) {
return result;
}

View File

@@ -128,9 +128,9 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
// imported via __all__
return null;
}
final PyType pyType = PyReferenceExpressionImpl.getReferenceTypeFromProviders(this, context, null);
final Ref<PyType> pyType = PyReferenceExpressionImpl.getReferenceTypeFromProviders(this, context, null);
if (pyType != null) {
return pyType;
return pyType.get();
}
PyType type = getTypeFromDocString();
if (type != null) {

View File

@@ -9,6 +9,7 @@ import com.intellij.openapi.fileTypes.FileTypeManager;
import com.intellij.openapi.module.Module;
import com.intellij.openapi.module.ModuleUtilCore;
import com.intellij.openapi.projectRoots.Sdk;
import com.intellij.openapi.util.Ref;
import com.intellij.openapi.util.io.FileUtil;
import com.intellij.psi.PsiDirectory;
import com.intellij.psi.PsiElement;
@@ -37,12 +38,7 @@ public class ResolveImportUtil {
private ResolveImportUtil() {
}
private static final ThreadLocal<Set<String>> ourBeingImported = new ThreadLocal<Set<String>>() {
@Override
protected Set<String> initialValue() {
return new HashSet<>();
}
};
private static final ThreadLocal<Set<String>> ourBeingImported = ThreadLocal.withInitial(() -> new HashSet<>());
public static boolean isAbsoluteImportEnabledFor(PsiElement foothold) {
if (foothold != null) {
@@ -360,9 +356,9 @@ public class ResolveImportUtil {
private static List<RatedResolveResult> resolveMemberFromReferenceTypeProviders(@NotNull PsiElement parent,
@NotNull String referencedName) {
final PyResolveContext resolveContext = PyResolveContext.defaultContext();
final PyType refType = PyReferenceExpressionImpl.getReferenceTypeFromProviders(parent, resolveContext.getTypeEvalContext(), null);
if (refType != null) {
final List<? extends RatedResolveResult> result = refType.resolveMember(referencedName, null, AccessDirection.READ, resolveContext);
final Ref<PyType> refType = PyReferenceExpressionImpl.getReferenceTypeFromProviders(parent, resolveContext.getTypeEvalContext(), null);
if (refType != null && !refType.isNull()) {
final List<? extends RatedResolveResult> result = refType.get().resolveMember(referencedName, null, AccessDirection.READ, resolveContext);
if (result != null) {
return Lists.newArrayList(result);
}

View File

@@ -21,10 +21,14 @@ import com.intellij.openapi.util.UserDataHolder;
import com.intellij.psi.PsiElement;
import com.jetbrains.python.psi.impl.PyBuiltinCache;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.Contract;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collector;
import java.util.stream.Collectors;
@@ -114,6 +118,12 @@ public final class PyTypeUtil {
return StreamEx.of(type);
}
@Nullable
@Contract("null -> null; !null -> !null")
public static Ref<PyType> notNullToRef(@Nullable PyType type) {
return type == null ? null : Ref.create(type);
}
/**
* Returns a collector that combines a stream of {@code Ref<PyType>} back into a single {@code Ref<PyType>}
* using {@link PyUnionType#union(PyType, PyType)}.

View File

@@ -130,11 +130,11 @@ public class PyiTypeProvider extends PyTypeProviderBase {
}
@Override
public PyType getReferenceType(@NotNull PsiElement target, TypeEvalContext context, @Nullable PsiElement anchor) {
public Ref<PyType> getReferenceType(@NotNull PsiElement target, @NotNull TypeEvalContext context, @Nullable PsiElement anchor) {
if (target instanceof PyTargetExpression) {
final PsiElement pythonStub = PyiUtil.getPythonStub((PyTargetExpression)target);
if (pythonStub instanceof PyTypedElement) {
return context.getType((PyTypedElement)pythonStub);
return PyTypeUtil.notNullToRef(context.getType((PyTypedElement)pythonStub));
}
}
return null;

View File

@@ -3202,6 +3202,22 @@ public class PyTypeTest extends PyTestCase {
" expr = y");
}
// PY-28052
public void testClassAttributeAnnotatedAsAny() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> doTest("Any",
"from typing import Any\n" +
"\n" +
"\n" +
"class MyClass:\n" +
" arbitrary: Any = 42\n" +
"\n" +
"\n" +
"expr = MyClass().arbitrary")
);
}
private static List<TypeEvalContext> getTypeEvalContexts(@NotNull PyExpression element) {
return ImmutableList.of(TypeEvalContext.codeAnalysis(element.getProject(), element.getContainingFile()).withTracing(),
TypeEvalContext.userInitiated(element.getProject(), element.getContainingFile()).withTracing());