PY-78413 No warning for awaiting a normal function if target in other module

- special case for awaiting a call to an imported, untyped, non-async function

(cherry picked from commit 8eec47ca4560ae0577e8c6157ef533952436b3b4)

IJ-MR-168288

GitOrigin-RevId: 5ae73b7b663e94996116706a775b9fed683f331e
This commit is contained in:
Marcus Mews
2025-07-21 09:11:19 +00:00
committed by intellij-monorepo-bot
parent c7a4d99372
commit 1e1dc85113
9 changed files with 137 additions and 9 deletions

View File

@@ -907,6 +907,7 @@ INSP.argument.equals.to.default=Argument equals to the default parameter value
#PyAsyncCallInspection
INSP.NAME.coroutine.is.not.awaited=Coroutine ''{0}'' is not awaited
INSP.async.call=Missing `await` syntax in coroutine calls
INSP.await.call.on.imported.untyped.function=Function ''{0}'' neither declared as ''async'' nor with ''Awaitable'' as return type
# PyAttributeOutsideInitInspection
INSP.NAME.attribute.outside.init=An instance attribute is defined outside `__init__`

View File

@@ -134,9 +134,15 @@ public abstract class PyUnresolvedReferencesVisitor extends PyInspectionVisitor
if (unresolved) {
boolean ignoreUnresolved = ignoreUnresolved(node, reference) || !evaluateVersionsForElement(node).contains(myVersion);
if (!ignoreUnresolved) {
final HighlightSeverity severity = reference instanceof PsiReferenceEx
HighlightSeverity severity = reference instanceof PsiReferenceEx
? ((PsiReferenceEx)reference).getUnresolvedHighlightSeverity(myTypeEvalContext)
: HighlightSeverity.ERROR;
if (severity == null) {
if (isAwaitCallToImportedNonAsyncFunction(reference)) {
// special case: type of prefixExpression.getQualifier() is null but we want to check whether the called function is async
severity = HighlightSeverity.WEAK_WARNING;
}
}
if (severity == null) return;
registerUnresolvedReferenceProblem(node, reference, severity);
}
@@ -148,6 +154,27 @@ public abstract class PyUnresolvedReferencesVisitor extends PyInspectionVisitor
}
}
private boolean isAwaitCallToImportedNonAsyncFunction(@NotNull PsiReference reference) {
if (reference.getElement() instanceof PyPrefixExpression prefixExpression
&& PyNames.DUNDER_AWAIT.equals(prefixExpression.getOperator().getSpecialMethodName())
&& getReferenceQualifier(reference) instanceof PyCallExpression callExpression) {
@NotNull List<@NotNull PyCallable> callees =
callExpression.multiResolveCalleeFunction(PyResolveContext.defaultContext(myTypeEvalContext));
if (callees.isEmpty()) {
return false;
}
for (PyCallable callee : callees) {
if (callee instanceof PyFunction pyFunction && pyFunction.isAsync()) {
return false;
}
}
return true; // no signature is declared async -> warning
}
return false;
}
private void registerUnresolvedReferenceProblem(@NotNull PyElement node, final @NotNull PsiReference reference,
@NotNull HighlightSeverity severity) {
if (reference instanceof DocStringTypeReference) {
@@ -263,6 +290,14 @@ public abstract class PyUnresolvedReferencesVisitor extends PyInspectionVisitor
}
markedQualified = true;
}
else {
if (isAwaitCallToImportedNonAsyncFunction(reference)) {
description = PyPsiBundle.message("INSP.await.call.on.imported.untyped.function", qualifier.getText());
node = qualifier; // show warning on the function call
rangeInElement = TextRange.create(0, qualifier.getTextRange().getLength());
markedQualified = true;
}
}
}
if (!markedQualified) {
description = PyPsiBundle.message("INSP.unresolved.refs.unresolved.reference", refText);

View File

@@ -0,0 +1,17 @@
from b import fun_async, fun_non_async
async def expect_no_warning():
await fun_async()
async def expect_new_warning():
await <warning descr="Function 'fun_non_async()' neither declared as 'async' nor with 'Awaitable' as return type">fun_non_async()</warning>
def local_fun_non_async():
pass
async def expect_warning():
<warning descr="Class 'None' does not define '__await__', so the 'await' operator cannot be used on its instances">await</warning> local_fun_non_async()

View File

@@ -0,0 +1,7 @@
async def fun_async():
return 3
def fun_non_async():
return 3

View File

@@ -0,0 +1,5 @@
from b import overloaded_fun_async_with_implicit_return_type
async def expect_no_warning():
await overloaded_fun_async_with_implicit_return_type(1)

View File

@@ -0,0 +1,15 @@
from typing import overload
@overload
async def overloaded_fun_async_with_implicit_return_type(arg0: str):
...
@overload
def overloaded_fun_async_with_implicit_return_type(arg0: int):
...
def overloaded_fun_async_with_implicit_return_type(arg0):
return 3

View File

@@ -0,0 +1,18 @@
from b import fun_awaitable_imported, MyAwaitable
async def expect_false_positive_warning():
await <warning descr="Function 'fun_awaitable_imported()' neither declared as 'async' nor with 'Awaitable' as return type">fun_awaitable_imported()</warning>
async def expect_pass_1():
await MyAwaitable()
def fun_awaitable_local():
return MyAwaitable()
async def expect_pass_2():
await fun_awaitable_local()

View File

@@ -0,0 +1,9 @@
class MyAwaitable:
def __await__(self):
yield from []
return "done"
def fun_awaitable_imported():
return MyAwaitable()

View File

@@ -815,14 +815,14 @@ public class PyUnresolvedReferencesInspectionTest extends PyInspectionTestCase {
runWithLanguageLevel(
LanguageLevel.getLatest(),
() -> doTestByText("""
def foo(cls):
return cls
@foo
class Bar2(object):
def __init__(self):
print(self.<warning descr="Unresolved attribute reference 'hello' for class 'Bar2'">hello</warning>)
def foo(cls):
return cls
@foo
class Bar2(object):
def __init__(self):
print(self.<warning descr="Unresolved attribute reference 'hello' for class 'Bar2'">hello</warning>)
""")
);
}
@@ -891,6 +891,27 @@ public class PyUnresolvedReferencesInspectionTest extends PyInspectionTestCase {
});
}
// PY-78413
public void testAsyncAwaitWarningOnImportedFun() {
runWithLanguageLevel(LanguageLevel.getLatest(), () -> {
doMultiFileTest();
});
}
// PY-78413
public void testAsyncAwaitWarningOnImportedFunReturnAwaitable() {
runWithLanguageLevel(LanguageLevel.getLatest(), () -> {
doMultiFileTest();
});
}
// PY-78413
public void testAsyncAwaitWarningOnImportedFunOverloaded() {
runWithLanguageLevel(LanguageLevel.getLatest(), () -> {
doMultiFileTest();
});
}
@NotNull
@Override
protected Class<? extends PyInspection> getInspectionClass() {