PY-72232 Report non-context-managers in with statement

GitOrigin-RevId: 623ae38649b33ad2730c977f4c5e36a1f82224c7
This commit is contained in:
Mikhail Golubev
2025-05-30 23:32:54 +03:00
committed by intellij-monorepo-bot
parent eb7051be6d
commit 552c41d53a
5 changed files with 96 additions and 0 deletions

View File

@@ -157,6 +157,8 @@ public final @NonNls class PyNames {
public static final String AWAITABLE = "Awaitable";
public static final String ASYNC_ITERABLE = "AsyncIterable";
public static final String ABSTRACT_CONTEXT_MANAGER = "AbstractContextManager";
public static final String ABSTRACT_ASYNC_CONTEXT_MANAGER = "AbstractAsyncContextManager";
public static final String ABC_NUMBER = "Number";
public static final String ABC_COMPLEX = "Complex";

View File

@@ -97,6 +97,13 @@ public class PyTypeCheckerInspection extends PyInspection {
checkIteratedValue(node.getForPart().getSource(), node.isAsync());
}
@Override
public void visitPyWithStatement(@NotNull PyWithStatement node) {
for (PyWithItem withItem : node.getWithItems()) {
checkContextManagerValue(withItem.getExpression(), node.isAsync());
}
}
@Override
public void visitPyReturnStatement(@NotNull PyReturnStatement node) {
ScopeOwner owner = ScopeUtil.getScopeOwner(node);
@@ -403,6 +410,22 @@ public class PyTypeCheckerInspection extends PyInspection {
}
}
private void checkContextManagerValue(@Nullable PyExpression iteratedValue, boolean isAsync) {
if (iteratedValue != null) {
final PyType type = myTypeEvalContext.getType(iteratedValue);
final String contextManagerClassName = isAsync ? PyNames.ABSTRACT_ASYNC_CONTEXT_MANAGER : PyNames.ABSTRACT_CONTEXT_MANAGER;
if (type != null &&
!PyTypeChecker.isUnknown(type, myTypeEvalContext) &&
!PyABCUtil.isSubtype(type, contextManagerClassName, myTypeEvalContext)) {
final String typeName = PythonDocumentationProvider.getTypeName(type, myTypeEvalContext);
String qualifiedName = "contextlib." + contextManagerClassName;
registerProblem(iteratedValue, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", qualifiedName, typeName));
}
}
}
private @Nullable AnalyzeCalleeResults analyzeCallee(@NotNull PyCallSiteExpression callSite,
@NotNull PyCallExpression.PyArgumentsMapping mapping) {
final PyCallableType callableType = mapping.getCallableType();

View File

@@ -97,6 +97,12 @@ public final class PyABCUtil {
if (PyNames.AWAITABLE.equals(superClassName)) {
return hasMethod(subClass, PyNames.DUNDER_AWAIT, inherited, context);
}
if (PyNames.ABSTRACT_CONTEXT_MANAGER.equals(superClassName)) {
return hasMethod(subClass, PyNames.ENTER, inherited, context) && hasMethod(subClass, PyNames.EXIT, inherited, context);
}
if (PyNames.ABSTRACT_ASYNC_CONTEXT_MANAGER.equals(superClassName)) {
return hasMethod(subClass, PyNames.AENTER, inherited, context) && hasMethod(subClass, PyNames.AEXIT, inherited, context);
}
return false;
}

View File

@@ -0,0 +1,60 @@
from contextlib import contextmanager, asynccontextmanager
class CustomContextManager:
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
return
class CustomContextManager2(CustomContextManager):
pass
class CustomAsyncContextManager:
async def __aenter__(self):
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
return
@contextmanager
def generator_cm():
yield
@asynccontextmanager
async def generator_acm():
yield
class NonContextManager:
pass
with (
CustomContextManager(),
CustomContextManager2(),
<warning descr="Expected type 'contextlib.AbstractContextManager', got 'CustomAsyncContextManager' instead">CustomAsyncContextManager()</warning>,
<warning descr="Expected type 'contextlib.AbstractContextManager', got 'NonContextManager' instead">NonContextManager()</warning>,
generator_cm(),
<warning descr="Expected type 'contextlib.AbstractContextManager', got '_AsyncGeneratorContextManager[None]' instead">generator_acm()</warning>,
<warning descr="Expected type 'contextlib.AbstractContextManager', got 'object' instead">object()</warning>
):
pass
async def f():
async with (
<warning descr="Expected type 'contextlib.AbstractAsyncContextManager', got 'CustomContextManager' instead">CustomContextManager()</warning>,
<warning descr="Expected type 'contextlib.AbstractAsyncContextManager', got 'CustomContextManager2' instead">CustomContextManager2()</warning>,
CustomAsyncContextManager(),
<warning descr="Expected type 'contextlib.AbstractAsyncContextManager', got 'NonContextManager' instead">NonContextManager()</warning>,
<warning descr="Expected type 'contextlib.AbstractAsyncContextManager', got '_GeneratorContextManager[None]' instead">generator_cm()</warning>,
generator_acm(),
<warning descr="Expected type 'contextlib.AbstractAsyncContextManager', got 'object' instead">object()</warning>
):
pass

View File

@@ -29,6 +29,11 @@ public class Py3TypeCheckerInspectionTest extends PyInspectionTestCase {
doTest();
}
// PY-72232
public void testWithItemNonContextManager() {
doTest();
}
// PY-10660
public void testStructUnpackPy3() {
doMultiFileTest();