PY-22971 Fixed: Support @typing.overload in regular Python files, not only in Python stubs

Update PyReferenceImpl to resolve to the latest definition of the callable and prepend it with overloads.
Update PyParameterInfoHandler and PyTypeChecker to take into account overloads only if they exist.
This commit is contained in:
Semyon Proshev
2017-05-12 19:13:32 +03:00
committed by Semyon Proshev
parent d3f1eda306
commit f0ab4154ef
18 changed files with 326 additions and 4 deletions

View File

@@ -72,8 +72,10 @@ public class PyParameterInfoHandler implements ParameterInfoHandler<PyArgumentLi
final TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(argumentList.getProject(), argumentList.getContainingFile());
final PyResolveContext resolveContext = PyResolveContext.noImplicits().withRemote().withTypeEvalContext(typeEvalContext);
final Object[] items = PyUtil.filterTopPriorityResults(call.multiResolveRatedCallee(resolveContext))
.stream()
final List<PyCallExpression.PyRatedMarkedCallee> ratedMarkedCallees =
PyUtil.filterTopPriorityResults(call.multiResolveRatedCallee(resolveContext));
final Object[] items = PyCallExpressionHelper.forEveryScopeTakeOverloadsOtherwiseImplementations(ratedMarkedCallees, typeEvalContext)
.map(ratedMarkedCallee -> Pair.createNonNull(call, ratedMarkedCallee.getMarkedCallee()))
.toArray();

View File

@@ -25,12 +25,14 @@ import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.ObjectUtils;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.resolve.QualifiedRatedResolveResult;
import com.jetbrains.python.psi.resolve.QualifiedResolveResult;
import com.jetbrains.python.psi.resolve.RatedResolveResult;
import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.pyi.PyiTypeProvider;
import com.jetbrains.python.toolbox.Maybe;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
@@ -38,6 +40,7 @@ import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Functions common to different implementors of PyCallExpression, with different base classes.
@@ -835,6 +838,47 @@ public class PyCallExpressionHelper {
tupleMappedParameters);
}
@NotNull
public static Stream<PyCallExpression.PyRatedMarkedCallee> forEveryScopeTakeOverloadsOtherwiseImplementations(@NotNull List<PyCallExpression.PyRatedMarkedCallee> callees,
@NotNull TypeEvalContext context) {
if (!containsOverloadsAndImplementations(callees, context)) {
return callees.stream();
}
return StreamEx
.of(callees)
.groupingBy(callee -> ScopeUtil.getScopeOwner(callee.getElement()))
.values()
.stream()
.flatMap(oneScopeCallees -> takeOverloadsOtherwiseImplementations(oneScopeCallees, context));
}
private static boolean containsOverloadsAndImplementations(@NotNull List<PyCallExpression.PyRatedMarkedCallee> callees,
@NotNull TypeEvalContext context) {
boolean containsOverloads = false;
boolean containsImplementations = false;
for (PyCallExpression.PyRatedMarkedCallee callee : callees) {
final boolean overload = PyiTypeProvider.isOverload(callee.getElement(), context);
containsOverloads |= overload;
containsImplementations |= !overload;
if (containsOverloads && containsImplementations) return true;
}
return false;
}
@NotNull
private static Stream<PyCallExpression.PyRatedMarkedCallee> takeOverloadsOtherwiseImplementations(@NotNull List<PyCallExpression.PyRatedMarkedCallee> callees,
@NotNull TypeEvalContext context) {
if (!containsOverloadsAndImplementations(callees, context)) {
return callees.stream();
}
return callees.stream().filter(callee -> PyiTypeProvider.isOverload(callee.getElement(), context));
}
public static class ArgumentMappingResults {
@NotNull private final Map<PyExpression, PyNamedParameter> myMappedParameters;
@NotNull private final List<PyParameter> myUnmappedParameters;

View File

@@ -33,6 +33,7 @@ import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.IncorrectOperationException;
import com.intellij.util.PlatformIcons;
import com.intellij.util.ProcessingContext;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
@@ -44,8 +45,10 @@ import com.jetbrains.python.psi.resolve.*;
import com.jetbrains.python.psi.types.PyModuleType;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import com.jetbrains.python.pyi.PyiTypeProvider;
import com.jetbrains.python.pyi.PyiUtil;
import com.jetbrains.python.refactoring.PyDefUseUtil;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -269,6 +272,16 @@ public class PyReferenceImpl implements PsiReferenceEx, PsiPolyVariantReference
// TODO: Use the results from the processor as a cache for resolving to latest defs
final ResolveResultList latestDefs = resolveToLatestDefs(instructions, realContext, referencedName, typeEvalContext);
if (!latestDefs.isEmpty()) {
if (ContainerUtil.exists(latestDefs, result -> result.getElement() instanceof PyCallable)) {
return StreamEx
.of(processor.getResults().keySet())
.select(PyCallable.class)
.filter(callable -> PyiTypeProvider.isOverload(callable, typeEvalContext))
.map(callable -> new RatedResolveResult(getRate(callable, typeEvalContext), callable))
.append(latestDefs)
.toList();
}
return latestDefs;
}
else if (resolvedOwner instanceof PyClass || instructions.isEmpty() && allInOwnScopeComprehensions(resolvedElements)) {

View File

@@ -624,8 +624,12 @@ public class PyTypeChecker {
private static List<PyCallable> multiResolveCallee(@NotNull PyCallSiteExpression callSite, @NotNull TypeEvalContext context) {
final PyResolveContext resolveContext = PyResolveContext.noImplicits().withTypeEvalContext(context);
if (callSite instanceof PyCallExpression) {
final List<PyCallExpression.PyRatedCallee> ratedCallees = ((PyCallExpression)callSite).multiResolveRatedCalleeFunction(resolveContext);
return ContainerUtil.map(PyUtil.filterTopPriorityResults(ratedCallees), PyCallExpression.PyRatedCallee::getElement);
final List<PyCallExpression.PyRatedMarkedCallee> ratedMarkedCallees =
PyUtil.filterTopPriorityResults(((PyCallExpression)callSite).multiResolveRatedCallee(resolveContext));
return forEveryScopeTakeOverloadsOtherwiseImplementations(ratedMarkedCallees, context)
.map(PyCallExpression.PyRatedMarkedCallee::getElement)
.collect(Collectors.toList());
}
else if (callSite instanceof PySubscriptionExpression || callSite instanceof PyBinaryExpression) {
final List<PyCallable> results = new ArrayList<>();

View File

@@ -0,0 +1,24 @@
from typing import overload
class A:
@overload
def foo(self, value: None) -> None:
pass
@overload
def foo(self, value: int) -> str:
pass
@overload
def foo(self, value: str) -> str:
pass
def foo(self, value):
return None
A().foo(None)
A().foo(5)
A().foo("5")
A().foo(<warning descr="Unexpected type(s):(A)Possible types:(None)(int)(str)">A()</warning>)

View File

@@ -0,0 +1,6 @@
import b
b.A().foo(None)
b.A().foo(5)
b.A().foo("5")
b.A().foo(A())

View File

@@ -0,0 +1,18 @@
from typing import overload
class A:
@overload
def foo(self, value: None) -> None:
pass
@overload
def foo(self, value: int) -> str:
pass
@overload
def foo(self, value: str) -> str:
pass
def foo(self, value):
return None

View File

@@ -0,0 +1,6 @@
import b
b.foo(None)
b.foo(5)
b.foo("5")
b.foo(A())

View File

@@ -0,0 +1,17 @@
from typing import overload
@overload
def foo(value: None) -> None:
pass
@overload
def foo(value: int) -> str:
pass
@overload
def foo(value: str) -> str:
pass
def foo(value):
return None

View File

@@ -0,0 +1,23 @@
from typing import overload
@overload
def foo(value: None) -> None:
pass
@overload
def foo(value: int) -> str:
pass
@overload
def foo(value: str) -> str:
pass
def foo(value):
return None
foo(None)
foo(5)
foo("5")
foo(<warning descr="Unexpected type(s):(object)Possible types:(None)(int)(str)">object()</warning>)

View File

@@ -0,0 +1,21 @@
from typing import overload
class A:
@overload
def foo(self, value: None) -> None:
pass
@overload
def foo(self, value: int) -> str:
pass
@overload
def foo(self, value: str) -> str:
pass
def foo(self, value):
return None
A().foo(<arg1>)

View File

@@ -0,0 +1,3 @@
import b
b.A().foo(<arg1>)

View File

@@ -0,0 +1,18 @@
from typing import overload
class A:
@overload
def foo(self, value: None) -> None:
pass
@overload
def foo(self, value: int) -> str:
pass
@overload
def foo(self, value: str) -> str:
pass
def foo(self, value):
return None

View File

@@ -0,0 +1,3 @@
import b
b.foo(<arg1>)

View File

@@ -0,0 +1,17 @@
from typing import overload
@overload
def foo(value: None) -> None:
pass
@overload
def foo(value: int) -> str:
pass
@overload
def foo(value: str) -> str:
pass
def foo(value):
return None

View File

@@ -0,0 +1,20 @@
from typing import overload
@overload
def foo(value: None) -> None:
pass
@overload
def foo(value: int) -> str:
pass
@overload
def foo(value: str) -> str:
pass
def foo(value):
return None
foo(<arg1>)

View File

@@ -544,6 +544,68 @@ public class PyParameterInfoTest extends LightMarkedTestCase {
);
}
// PY-22971
public void testTopLevelOverloadsAndImplementation() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
final int offset = loadTest(1).get("<arg1>").getTextOffset();
final List<String> texts = Arrays.asList("value: None", "value: int", "value: str");
final List<String[]> highlighted = Arrays.asList(new String[]{"value: None"}, new String[]{"value: int"}, new String[]{"value: str"});
feignCtrlP(offset).check(texts, highlighted, Arrays.asList(ArrayUtil.EMPTY_STRING_ARRAY, ArrayUtil.EMPTY_STRING_ARRAY, ArrayUtil.EMPTY_STRING_ARRAY));
}
);
}
// PY-22971
public void testOverloadsAndImplementationInClass() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
final int offset = loadTest(1).get("<arg1>").getTextOffset();
final List<String> texts = Arrays.asList("self: A, value: None", "self: A, value: int", "self: A, value: str");
final List<String[]> highlighted = Arrays.asList(new String[]{"value: None"}, new String[]{"value: int"}, new String[]{"value: str"});
final List<String[]> disabled = Arrays.asList(new String[]{"self: A, "}, new String[]{"self: A, "}, new String[]{"self: A, "});
feignCtrlP(offset).check(texts, highlighted, disabled);
}
);
}
// PY-22971
public void testOverloadsAndImplementationInImportedModule() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
final int offset = loadMultiFileTest(1).get("<arg1>").getTextOffset();
final List<String> texts = Arrays.asList("value: str", "value: int", "value: None");
final List<String[]> highlighted = Arrays.asList(new String[]{"value: str"}, new String[]{"value: int"}, new String[]{"value: None"});
feignCtrlP(offset).check(texts, highlighted, Arrays.asList(ArrayUtil.EMPTY_STRING_ARRAY, ArrayUtil.EMPTY_STRING_ARRAY, ArrayUtil.EMPTY_STRING_ARRAY));
}
);
}
// PY-22971
public void testOverloadsAndImplementationInImportedClass() {
runWithLanguageLevel(
LanguageLevel.PYTHON35,
() -> {
final int offset = loadMultiFileTest(1).get("<arg1>").getTextOffset();
final List<String> texts = Arrays.asList("self: A, value: None", "self: A, value: int", "self: A, value: str");
final List<String[]> highlighted = Arrays.asList(new String[]{"value: None"}, new String[]{"value: int"}, new String[]{"value: str"});
final List<String[]> disabled = Arrays.asList(new String[]{"self: A, "}, new String[]{"self: A, "}, new String[]{"self: A, "});
feignCtrlP(offset).check(texts, highlighted, disabled);
}
);
}
/**
* Imitates pressing of Ctrl+P; fails if results are not as expected.
* @param offset offset of 'cursor' where Ctrl+P is pressed.

View File

@@ -16,6 +16,7 @@
package com.jetbrains.python.inspections;
import com.jetbrains.python.fixtures.PyTestCase;
import com.jetbrains.python.psi.LanguageLevel;
/**
* @author vlan
@@ -386,4 +387,24 @@ public class PyTypeCheckerInspectionTest extends PyTestCase {
public void testChainedComparisons() {
doTest();
}
// PY-22971
public void testTopLevelOverloadsAndImplementation() {
runWithLanguageLevel(LanguageLevel.PYTHON35, this::doTest);
}
// PY-22971
public void testOverloadsAndImplementationInClass() {
runWithLanguageLevel(LanguageLevel.PYTHON35, this::doTest);
}
// PY-22971
public void testOverloadsAndImplementationInImportedModule() {
runWithLanguageLevel(LanguageLevel.PYTHON35, this::doMultiFileTest);
}
// PY-22971
public void testOverloadsAndImplementationInImportedClass() {
runWithLanguageLevel(LanguageLevel.PYTHON35, this::doMultiFileTest);
}
}