PY-12425 Fixed: Completion breaks when implementing a callable object with the __call__ method

Fix callType calculating for class instance: resolve __call__ member and process all possible return types
This commit is contained in:
Semyon Proshev
2016-02-05 18:17:53 +03:00
parent 682390069d
commit 77356abd34
13 changed files with 187 additions and 3 deletions

View File

@@ -38,7 +38,10 @@ import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.PyBuiltinCache;
import com.jetbrains.python.psi.impl.PyResolveResultRater;
import com.jetbrains.python.psi.impl.ResolveResultList;
import com.jetbrains.python.psi.resolve.*;
import com.jetbrains.python.psi.resolve.CompletionVariantsProcessor;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.PyResolveProcessor;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.toolbox.Maybe;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -374,7 +377,52 @@ public class PyClassTypeImpl extends UserDataHolderBase implements PyClassType {
@Nullable
@Override
public PyType getCallType(@NotNull TypeEvalContext context, @NotNull PyCallSiteExpression callSite) {
return getReturnType(context);
if (!isDefinition()) {
final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context);
final List<? extends RatedResolveResult> resolveResults = resolveMember(PyNames.CALL, callSite, AccessDirection.READ, resolveContext);
if (resolveResults != null) {
final ArrayList<PyType> result = new ArrayList<PyType>();
for (RatedResolveResult resolveResult : resolveResults) {
result.addAll(
getPossibleReturnTypes(resolveResult.getElement(), context)
);
}
return PyUnionType.union(result);
}
}
return null;
}
@NotNull
private static List<PyType> getPossibleReturnTypes(@Nullable PsiElement element, @NotNull TypeEvalContext context) {
final ArrayList<PyType> result = new ArrayList<PyType>();
if (element instanceof PyTypedElement) {
final PyType elementType = context.getType((PyTypedElement)element);
result.addAll(getPossibleReturnTypes(elementType, context));
if (elementType instanceof PyUnionType) {
for (PyType type : ((PyUnionType)elementType).getMembers()) {
result.addAll(getPossibleReturnTypes(type, context));
}
}
}
return result;
}
@NotNull
private static List<PyType> getPossibleReturnTypes(@Nullable PyType type, @NotNull TypeEvalContext context) {
if (type instanceof PyCallableType) {
return Collections.singletonList(((PyCallableType)type).getReturnType(context));
}
return Collections.emptyList();
}
@Nullable
@@ -724,6 +772,5 @@ public class PyClassTypeImpl extends UserDataHolderBase implements PyClassType {
public boolean value(final PsiElement target) {
return (instance != target);
}
}
}

View File

@@ -0,0 +1,10 @@
class Foo(object):
bar = True
class FooMaker(object):
__call__ = tuple
fm = FooMaker()
f3 = fm()
f3.count()

View File

@@ -0,0 +1,10 @@
class Foo(object):
bar = True
class FooMaker(object):
__call__ = tuple
fm = FooMaker()
f3 = fm()
f3.c<caret>

View File

@@ -0,0 +1,11 @@
class Foo(object):
bar = True
class FooMaker(object):
def __call__(self):
return Foo()
fm = FooMaker()
f3 = fm()
f3.bar

View File

@@ -0,0 +1,11 @@
class Foo(object):
bar = True
class FooMaker(object):
def __call__(self):
return Foo()
fm = FooMaker()
f3 = fm()
f3.b<caret>

View File

@@ -0,0 +1,13 @@
class Foo(object):
bar = True
class FooMaker(object):
def foo(self):
return Foo()
__call__ = foo
fm = FooMaker()
f3 = fm()
f3.bar

View File

@@ -0,0 +1,13 @@
class Foo(object):
bar = True
class FooMaker(object):
def foo(self):
return Foo()
__call__ = foo
fm = FooMaker()
f3 = fm()
f3.b<caret>

View File

@@ -0,0 +1,14 @@
class Foo(object):
bar = True
class FooMakerAnc(object):
def __call__(self):
return Foo()
class FooMaker(FooMakerAnc):
pass
fm = FooMaker()
f3 = fm()
f3.bar

View File

@@ -0,0 +1,14 @@
class Foo(object):
bar = True
class FooMakerAnc(object):
def __call__(self):
return Foo()
class FooMaker(FooMakerAnc):
pass
fm = FooMaker()
f3 = fm()
f3.b<caret>

View File

@@ -0,0 +1,7 @@
class FooMaker(object):
pass
fm = FooMaker()
f3 = fm()
f3.bit_length()

View File

@@ -0,0 +1,7 @@
class FooMaker(object):
pass
fm = FooMaker()
f3 = fm()
f3.bi<caret>

View File

@@ -0,0 +1,2 @@
class FooMaker:
def __call__(self) -> int: ...

View File

@@ -946,6 +946,31 @@ public class PythonCompletionTest extends PyTestCase {
assertDoesntContain(variants, "_foo(self)");
}
// PY-12425
public void testInstanceFromDefinedCallAttr() {
doTest();
}
// PY-12425
public void testInstanceFromFunctionAssignedToCallAttr() {
doTest();
}
// PY-12425
public void testInstanceFromCallableAssignedToCallAttr() {
doTest();
}
// PY-12425
public void testInstanceFromInheritedCallAttr() {
doTest();
}
// PY-12425
public void testInstanceFromProvidedCallAttr() {
doMultiFileTest();
}
@Override
protected String getTestDataPath() {
return super.getTestDataPath() + "/completion";