PY-20611 Missing warning about functions implicitly returning None when return type is not Optional

Updated PyFunction to account for implicit 'return None' statements when inferring return statement types.

It affected return type inference of PyFunction.

Fixed a failing test related to formatted strings.

Added a quick fix to make all return statements explicit.

Updated the CFG to include PyPassStatements, enabling detection of exit points in empty functions.

Simplified PyMakeFunctionReturnTypeQuickFix to independently infer function types and handle required imports. Currently, it does not support specifying custom suggested types.



Merge-request: IJ-MR-148719
Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com>

(cherry picked from commit 9f58961f9eb70e4f9dbba7359f5aafdfd392b7e2)

IJ-MR-148719

GitOrigin-RevId: 68ef5c4a1cc0fcaffd750cc0713250a106136643
This commit is contained in:
Aleksandr.Govenko
2024-11-26 17:02:37 +00:00
committed by intellij-monorepo-bot
parent 137b1d2b13
commit 4dd41ee9f5
32 changed files with 430 additions and 301 deletions

View File

@@ -101,6 +101,10 @@ open class PyAstElementVisitor : PsiElementVisitor() {
open fun visitPyReturnStatement(node: PyAstReturnStatement) {
visitPyStatement(node)
}
open fun visitPyPassStatement(node: PyAstPassStatement) {
visitPyStatement(node)
}
open fun visitPyYieldExpression(node: PyAstYieldExpression) {
visitPyExpression(node)

View File

@@ -6,4 +6,8 @@ import org.jetbrains.annotations.ApiStatus;
@ApiStatus.Experimental
public interface PyAstPassStatement extends PyAstStatement {
@Override
default void acceptPyVisitor(PyAstElementVisitor pyVisitor) {
pyVisitor.visitPyPassStatement(this);
}
}

View File

@@ -274,6 +274,10 @@ public class PyElementVisitor extends PsiElementVisitor {
visitPyElement(node);
}
public void visitPyPassStatement(@NotNull PyPassStatement node) {
visitPyStatement(node);
}
public void visitPyNoneLiteralExpression(@NotNull PyNoneLiteralExpression node) {
visitPyElement(node);
}

View File

@@ -8,9 +8,11 @@ import com.jetbrains.python.PyNames;
import com.jetbrains.python.ast.*;
import com.jetbrains.python.ast.impl.PyPsiUtilsCore;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.psi.impl.PyTypeProvider;
import com.jetbrains.python.psi.stubs.PyFunctionStub;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.ApiStatus;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -28,9 +30,38 @@ public interface PyFunction extends PyAstFunction, StubBasedPsiElement<PyFunctio
PyFunction[] EMPTY_ARRAY = new PyFunction[0];
ArrayFactory<PyFunction> ARRAY_FACTORY = count -> count == 0 ? EMPTY_ARRAY : new PyFunction[count];
/**
* Infers function's return type by analyzing <b>only return statements</b> (including implicit returns) in its control flow.
* Does not consider yield statements or return type annotations.
*
* @see PyFunction#getInferredReturnType(TypeEvalContext)
*/
@Nullable
PyType getReturnStatementType(@NotNull TypeEvalContext context);
/**
* Infers function's return type by analyzing <b>return statements</b> (including implicit returns) and <b>yield expression</b>.
* In contrast with {@link TypeEvalContext#getReturnType(PyCallable)} does not consider
* return type annotations or any other {@link PyTypeProvider}.
*
* @apiNote Does not cache the result.
*/
@ApiStatus.Internal
@Nullable
PyType getInferredReturnType(@NotNull TypeEvalContext context);
/**
* Returns a list of all function exit points that can return a value.
* This includes explicit 'return' statements and statements that can complete
* normally with an implicit 'return None', excluding statements that raise exceptions.
*
* @see PyFunction#getReturnStatementType(TypeEvalContext)
* @return List of exit point statements, in control flow order
*/
@ApiStatus.Internal
@NotNull
List<PyStatement> getReturnPoints(@NotNull TypeEvalContext context);
/**
* Checks whether the function contains a yield expression in its body.
*/

View File

@@ -460,6 +460,9 @@ QFIX.remove.decorator=Remove decorator
QFIX.NAME.make.function.return.type=Make function return inferred type
QFIX.make.function.return.type=Make ''{0}'' return ''{1}''
# PyMakeReturnsExplicitQuickFix
QFIX.NAME.make.return.stmts.explicit=Make 'return None' statements explicit
# Add method quick-fix
QFIX.NAME.add.method.to.class=Add method to class
QFIX.add.method.to.class=Add method {0}() to class {1}
@@ -1055,7 +1058,7 @@ INSP.NAME.type.checker=Incorrect type
INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead
INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}''
INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2}
INSP.type.checker.expected.to.return.type.got.no.return=Expected to return ''{0}'', got no return
INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None''
INSP.type.checker.init.should.return.none=__init__ should return None
INSP.type.checker.type.does.not.have.expected.attribute=Type ''{0}'' doesn''t have expected {1,choice,1#attribute|2#attributes} {2}
INSP.type.checker.only.concrete.class.can.be.used.where.matched.protocol.expected=Only a concrete class can be used where ''{0}'' (matched generic type ''{1}'') protocol is expected

View File

@@ -628,7 +628,7 @@ public final class PyStringFormatInspection extends PyInspection {
allForSure = allForSure && elementsCount != -1;
maxNumber = Math.max(maxNumber, elementsCount);
}
else {
else if (!(member instanceof PyNoneType)) {
allForSure = false;
}
}

View File

@@ -18,6 +18,7 @@ import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix;
import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.types.*;
@@ -27,6 +28,7 @@ import org.jetbrains.annotations.Nullable;
import java.util.*;
import static com.intellij.util.containers.ContainerUtil.exists;
import static com.jetbrains.python.psi.PyUtil.as;
import static com.jetbrains.python.psi.impl.PyCallExpressionHelper.*;
@@ -44,10 +46,17 @@ public class PyTypeCheckerInspection extends PyInspection {
}
public static class Visitor extends PyInspectionVisitor {
public Visitor(@Nullable ProblemsHolder holder, @NotNull TypeEvalContext context) {
public Visitor(@NotNull ProblemsHolder holder, @NotNull TypeEvalContext context) {
super(holder, context);
}
@Override
protected @NotNull ProblemsHolder getHolder() {
var holder = super.getHolder();
assert holder != null;
return holder;
}
// TODO: Visit decorators with arguments
@Override
public void visitPyCallExpression(@NotNull PyCallExpression node) {
@@ -83,32 +92,38 @@ public class PyTypeCheckerInspection extends PyInspection {
PyAnnotation annotation = function.getAnnotation();
String typeCommentAnnotation = function.getTypeCommentAnnotation();
if (annotation != null || typeCommentAnnotation != null) {
PyType expected = getExpectedReturnType(function, myTypeEvalContext);
if (expected == null) return;
// We cannot just match annotated and inferred types, as we cannot promote inferred to Literal
PyExpression returnExpr = node.getExpression();
PyType expected = getExpectedReturnType(function);
if (returnExpr == null && !(expected instanceof PyNoneType) && PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext)) {
final String expectedName = PythonDocumentationProvider.getVerboseTypeName(expected, myTypeEvalContext);
getHolder()
.problem(node, PyPsiBundle.message("INSP.type.checker.returning.type.has.implicit.return", expectedName))
.fix(new PyMakeReturnsExplicitFix(function))
.register();
return;
}
PyType actual = returnExpr != null ? tryPromotingType(returnExpr, expected) : PyNoneType.INSTANCE;
if (expected != null && actual instanceof PyTypedDictType) {
if (actual instanceof PyTypedDictType) {
if (reportTypedDictProblems(expected, (PyTypedDictType)actual, returnExpr)) return;
}
if (!PyTypeChecker.match(expected, actual, myTypeEvalContext)) {
String expectedName = PythonDocumentationProvider.getVerboseTypeName(expected, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
var localQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, returnExpr, actual, myTypeEvalContext);
var globalQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, returnExpr, null, myTypeEvalContext);
registerProblem(returnExpr != null ? returnExpr : node,
PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName),
localQuickFix, globalQuickFix);
final String expectedName = PythonDocumentationProvider.getVerboseTypeName(expected, myTypeEvalContext);
final String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
getHolder()
.problem(returnExpr != null ? returnExpr : node, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext))
.register();
}
}
}
}
@Nullable
private PyType getExpectedReturnType(@NotNull PyFunction function) {
return getExpectedReturnType(function, myTypeEvalContext);
}
@Nullable
public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
final PyType returnType = typeEvalContext.getReturnType(function);
@@ -120,13 +135,6 @@ public class PyTypeCheckerInspection extends PyInspection {
return returnType;
}
@Nullable
public static PyType getActualReturnType(@NotNull PyFunction function, @Nullable PyExpression returnExpr,
@NotNull TypeEvalContext context) {
PyType returnTypeExpected = getExpectedReturnType(function, context);
return returnExpr != null ? tryPromotingType(returnExpr, returnTypeExpected, context) : PyNoneType.INSTANCE;
}
@Override
public void visitPyTargetExpression(@NotNull PyTargetExpression node) {
// TODO: Check types in class-level assignments
@@ -230,21 +238,28 @@ public class PyTypeCheckerInspection extends PyInspection {
final PyAnnotation annotation = node.getAnnotation();
final String typeCommentAnnotation = node.getTypeCommentAnnotation();
if (annotation != null || typeCommentAnnotation != null) {
if (!PyUtil.isEmptyFunction(node)) {
final ReturnVisitor visitor = new ReturnVisitor(node);
node.getStatementList().accept(visitor);
if (!visitor.myHasReturns) {
final PyType expected = getExpectedReturnType(node);
final String expectedName = PythonDocumentationProvider.getTypeName(expected, myTypeEvalContext);
if (expected != null && !(expected instanceof PyNoneType)) {
registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(),
PyPsiBundle.message("INSP.type.checker.expected.to.return.type.got.no.return", expectedName));
final PyType expected = getExpectedReturnType(node, myTypeEvalContext);
final boolean returnsNone = expected instanceof PyNoneType;
final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext);
if (expected != null && !returnsOptional && !PyUtil.isEmptyFunction(node)) {
final List<PyStatement> returnPoints = node.getReturnPoints(myTypeEvalContext);
final boolean hasImplicitReturns = exists(returnPoints, it -> !(it instanceof PyReturnStatement));
if (hasImplicitReturns) {
final String expectedName = PythonDocumentationProvider.getVerboseTypeName(expected, myTypeEvalContext);
final String actualName = PythonDocumentationProvider.getTypeName(node.getReturnStatementType(myTypeEvalContext), myTypeEvalContext);
final PsiElement annotationValue = annotation != null ? annotation.getValue() : node.getTypeComment();
if (annotationValue != null) {
getHolder()
.problem(annotationValue, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(node, myTypeEvalContext))
.register();
}
}
}
if (PyUtil.isInitMethod(node) && !(getExpectedReturnType(node) instanceof PyNoneType
|| PyTypingTypeProvider.isNoReturn(node, myTypeEvalContext))) {
if (PyUtil.isInitMethod(node) && !(returnsNone || PyTypingTypeProvider.isNoReturn(node, myTypeEvalContext))) {
registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(),
PyPsiBundle.message("INSP.type.checker.init.should.return.none"));
}
@@ -260,29 +275,6 @@ public class PyTypeCheckerInspection extends PyInspection {
}
}
private static class ReturnVisitor extends PyRecursiveElementVisitor {
private final PyFunction myFunction;
private boolean myHasReturns = false;
ReturnVisitor(PyFunction function) {
myFunction = function;
}
@Override
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
if (ScopeUtil.getScopeOwner(node) == myFunction) {
myHasReturns = true;
}
}
@Override
public void visitPyReturnStatement(@NotNull PyReturnStatement node) {
if (ScopeUtil.getScopeOwner(node) == myFunction) {
myHasReturns = true;
}
}
}
private void checkCallSite(@NotNull PyCallSiteExpression callSite) {
final List<AnalyzeCalleeResults> calleesResults = StreamEx
.of(mapArguments(callSite, getResolveContext()))
@@ -514,7 +506,7 @@ public class PyTypeCheckerInspection extends PyInspection {
}
private static boolean matchedCalleeResultsExist(@NotNull List<AnalyzeCalleeResults> calleesResults) {
return ContainerUtil.exists(calleesResults, calleeResults ->
return exists(calleesResults, calleeResults ->
ContainerUtil.all(calleeResults.getResults(), AnalyzeArgumentResult::isMatched) &&
calleeResults.getUnmatchedArguments().isEmpty() &&
calleeResults.getUnmatchedParameters().isEmpty() &&

View File

@@ -32,52 +32,27 @@ import org.jetbrains.annotations.Nullable;
import java.util.List;
import static com.jetbrains.python.inspections.PyTypeCheckerInspection.Visitor.getActualReturnType;
/**
* @author lada
*/
public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
private final @NotNull SmartPsiElementPointer<PyFunction> myFunction;
private final @Nullable SmartPsiElementPointer<PyExpression> myReturnExpr;
private final @Nullable SmartPsiElementPointer<PyAnnotation> myAnnotation;
private final @Nullable SmartPsiElementPointer<PsiComment> myTypeCommentAnnotation;
private final String myReturnTypeName;
private final boolean myHaveSuggestedType;
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function,
@Nullable PyExpression returnExpr,
@Nullable PyType suggestedReturnType,
@NotNull TypeEvalContext context) {
this(function,
returnExpr,
function.getAnnotation(),
function.getTypeComment(),
suggestedReturnType != null,
getReturnTypeName(function, suggestedReturnType, context));
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
this(function, getReturnTypeName(function, context));
}
private PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull String returnTypeName) {
SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
myFunction = manager.createSmartPsiElementPointer(function);
myReturnTypeName = returnTypeName;
}
@NotNull
private static String getReturnTypeName(@NotNull PyFunction function, @Nullable PyType returnType, @NotNull TypeEvalContext context) {
PyType type = returnType != null ? returnType : function.getReturnStatementType(context);
private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
final PyType type = function.getInferredReturnType(context);
return PythonDocumentationProvider.getTypeHint(type, context);
}
private PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function,
@Nullable PyExpression returnExpr,
@Nullable PyAnnotation annotation,
@Nullable PsiComment typeComment,
boolean returnTypeSuggested,
@NotNull String returnTypeName) {
SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
myFunction = manager.createSmartPsiElementPointer(function);
myReturnExpr = returnExpr != null ? manager.createSmartPsiElementPointer(returnExpr) : null;
myAnnotation = annotation != null ? manager.createSmartPsiElementPointer(annotation) : null;
myTypeCommentAnnotation = typeComment != null ? manager.createSmartPsiElementPointer(typeComment) : null;
myHaveSuggestedType = returnTypeSuggested;
myReturnTypeName = returnTypeName;
}
@Override
@NotNull
public String getName() {
@@ -94,51 +69,46 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
@Override
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
if (myAnnotation != null) {
PyAnnotation annotation = myAnnotation.getElement();
if (annotation != null) {
PyExpression annotationExpr = annotation.getValue();
if (annotationExpr == null) return;
PsiElement newElement =
annotationExpr.replace(elementGenerator.createExpressionFromText(LanguageLevel.PYTHON34, myReturnTypeName));
addImportsForTypeAnnotations(newElement);
final PyFunction function = myFunction.getElement();
if (function == null) return;
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
boolean shouldAddImports = false;
PyAnnotation annotation = function.getAnnotation();
if (annotation != null) {
PyExpression annotationExpr = annotation.getValue();
if (annotationExpr != null) {
annotationExpr.replace(elementGenerator.createExpressionFromText(LanguageLevel.PYTHON34, myReturnTypeName));
shouldAddImports = true;
}
}
else if (myTypeCommentAnnotation != null) {
PsiComment typeComment = myTypeCommentAnnotation.getElement();
if (typeComment != null) {
StringBuilder typeCommentAnnotation = new StringBuilder(typeComment.getText());
typeCommentAnnotation.delete(typeCommentAnnotation.indexOf("->"), typeCommentAnnotation.length());
typeCommentAnnotation.append("-> ").append(myReturnTypeName);
PsiComment newTypeComment =
elementGenerator.createFromText(LanguageLevel.PYTHON27, PsiComment.class, typeCommentAnnotation.toString());
PsiElement newElement = typeComment.replace(newTypeComment);
addImportsForTypeAnnotations(newElement);
}
PsiComment typeComment = function.getTypeComment();
if (typeComment != null) {
StringBuilder typeCommentAnnotation = new StringBuilder(typeComment.getText());
typeCommentAnnotation.delete(typeCommentAnnotation.indexOf("->"), typeCommentAnnotation.length());
typeCommentAnnotation.append("-> ").append(myReturnTypeName);
typeComment.replace(
elementGenerator.createFromText(LanguageLevel.PYTHON27, PsiComment.class, typeCommentAnnotation.toString()));
shouldAddImports = true;
}
if (shouldAddImports) {
addImportsForTypeAnnotations(TypeEvalContext.userInitiated(project, function.getContainingFile()));
}
}
private void addImportsForTypeAnnotations(@NotNull PsiElement element) {
PsiFile file = element.getContainingFile();
if (file == null) return;
private void addImportsForTypeAnnotations(@NotNull TypeEvalContext context) {
PyFunction function = myFunction.getElement();
if (function == null) return;
Project project = element.getProject();
TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(project, file);
PyType typeForImports = getTypeForImports(function, typeEvalContext);
PsiFile file = function.getContainingFile();
if (file == null) return;
PyType typeForImports = function.getInferredReturnType(context);
if (typeForImports != null) {
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), typeEvalContext, file);
}
}
private @Nullable PyType getTypeForImports(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
PyType returnTypeActual = getActualReturnType(function, myReturnExpr != null ? myReturnExpr.getElement() : null, context);
if (myHaveSuggestedType && returnTypeActual != null) {
return returnTypeActual;
}
else {
return function.getReturnStatementType(context);
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file);
}
}
@@ -148,18 +118,6 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
if (function == null) {
return null;
}
@Nullable PyExpression returnExpr = PyRefactoringUtil.findSameElementForPreview(myReturnExpr, target);
if (myReturnExpr != null && returnExpr == null) {
return null;
}
@Nullable PyAnnotation annotation = PyRefactoringUtil.findSameElementForPreview(myAnnotation, target);
if (myAnnotation != null && annotation == null) {
return null;
}
@Nullable PsiComment typeComment = PyRefactoringUtil.findSameElementForPreview(myTypeCommentAnnotation, target);
if (myTypeCommentAnnotation != null && typeComment == null) {
return null;
}
return new PyMakeFunctionReturnTypeQuickFix(function, returnExpr, annotation, typeComment, myHaveSuggestedType, myReturnTypeName);
return new PyMakeFunctionReturnTypeQuickFix(function, myReturnTypeName);
}
}

View File

@@ -0,0 +1,46 @@
package com.jetbrains.python.inspections.quickfix;
import com.intellij.modcommand.ActionContext;
import com.intellij.modcommand.ModPsiUpdater;
import com.intellij.modcommand.PsiUpdateModCommandAction;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
/**
* Appends missing {@code return None}, and transforms {@code return} into {@code return None}.
*/
public class PyMakeReturnsExplicitFix extends PsiUpdateModCommandAction<PyFunction> {
public PyMakeReturnsExplicitFix(@NotNull PyFunction function) {
super(function);
}
@Override
protected void invoke(@NotNull ActionContext context, @NotNull PyFunction element, @NotNull ModPsiUpdater updater) {
var returnPoints = element.getReturnPoints(TypeEvalContext.userInitiated(element.getProject(), element.getContainingFile()));
for (var point : returnPoints) {
makeExplicit(point);
}
}
@Override
public @NotNull String getFamilyName() {
return PyPsiBundle.message("QFIX.NAME.make.return.stmts.explicit");
}
private static void makeExplicit(@NotNull PyStatement stmt) {
PyElementGenerator elementGenerator = PyElementGenerator.getInstance(stmt.getProject());
LanguageLevel languageLevel = LanguageLevel.forElement(stmt);
var returnStmt = elementGenerator.createFromText(languageLevel, PyReturnStatement.class, "return None");
if ((stmt instanceof PyReturnStatement ret && ret.getExpression() == null) || (stmt instanceof PyPassStatement)) {
stmt.replace(returnStmt);
}
else if (!(stmt instanceof PyReturnStatement)) {
stmt.getParent().addAfter(returnStmt, stmt);
}
}
}

View File

@@ -1,6 +1,8 @@
// Copyright 2000-2018 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.psi.impl;
import com.intellij.codeInsight.controlflow.ControlFlowUtil;
import com.intellij.codeInsight.controlflow.Instruction;
import com.intellij.lang.ASTNode;
import com.intellij.navigation.ItemPresentation;
import com.intellij.openapi.util.Key;
@@ -21,6 +23,7 @@ import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.containers.JBIterable;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyStubElementTypes;
import com.jetbrains.python.codeInsight.controlflow.CallInstruction;
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
@@ -153,7 +156,12 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(derefType(returnTypeRef, typeProvider));
}
}
return getInferredReturnType(context);
}
@Override
public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) {
PyType inferredType = null;
if (context.allowReturnTypes(this)) {
final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(context);
@@ -342,16 +350,54 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
@Override
public @Nullable PyType getReturnStatementType(@NotNull TypeEvalContext context) {
final ReturnVisitor visitor = new ReturnVisitor(this, context);
final PyStatementList statements = getStatementList();
statements.accept(visitor);
if ((isGeneratedStub() || PyKnownDecoratorUtil.hasAbstractDecorator(this, context)) && !visitor.myHasReturns) {
return PyUtil.getNullableParameterizedCachedValue(this, context, (it) -> getReturnStatementTypeNoCache(it));
}
private @Nullable PyType getReturnStatementTypeNoCache(@NotNull TypeEvalContext context) {
final List<PyStatement> returnPoints = getReturnPoints(context);
final List<PyType> types = new ArrayList<>();
boolean hasReturn = false;
for (var point : returnPoints) {
if (point instanceof PyReturnStatement returnStatement) {
hasReturn = true;
final PyExpression expr = returnStatement.getExpression();
types.add(expr != null ? context.getType(expr) : PyNoneType.INSTANCE);
}
else {
types.add(PyNoneType.INSTANCE);
}
}
if ((isGeneratedStub() || PyKnownDecoratorUtil.hasAbstractDecorator(this, context)) && !hasReturn) {
if (PyUtil.isInitMethod(this)) {
return PyNoneType.INSTANCE;
}
return null;
}
return visitor.result();
return PyUnionType.union(types);
}
@Override
public @NotNull List<PyStatement> getReturnPoints(@NotNull TypeEvalContext context) {
final Instruction[] flow = ControlFlowCache.getControlFlow(this).getInstructions();
final List<PyStatement> returnPoints = new ArrayList<>();
ControlFlowUtil.iteratePrev(flow.length-1, flow, instruction -> {
if (instruction instanceof CallInstruction ci && ci.isNoReturnCall(context)) {
return ControlFlowUtil.Operation.CONTINUE;
}
final PsiElement element = instruction.getElement();
if (!(element instanceof PyStatement statement)) {
return ControlFlowUtil.Operation.NEXT;
}
if (element instanceof PyRaiseStatement) {
return ControlFlowUtil.Operation.CONTINUE;
}
returnPoints.add(statement);
return ControlFlowUtil.Operation.CONTINUE;
});
return returnPoints;
}
@Override
@@ -412,44 +458,6 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
return false;
}
private static final class ReturnVisitor extends PyRecursiveElementVisitor {
private final PyFunction myFunction;
private final TypeEvalContext myContext;
private PyType myResult = null;
private boolean myHasReturns = false;
private boolean myHasRaises = false;
private ReturnVisitor(PyFunction function, final TypeEvalContext context) {
myFunction = function;
myContext = context;
}
@Override
public void visitPyReturnStatement(@NotNull PyReturnStatement node) {
if (ScopeUtil.getScopeOwner(node) == myFunction) {
final PyExpression expr = node.getExpression();
PyType returnType = expr == null ? PyNoneType.INSTANCE : myContext.getType(expr);
if (!myHasReturns) {
myResult = returnType;
myHasReturns = true;
}
else {
myResult = PyUnionType.union(myResult, returnType);
}
}
}
@Override
public void visitPyRaiseStatement(@NotNull PyRaiseStatement node) {
myHasRaises = true;
}
@Nullable
PyType result() {
return myHasReturns || myHasRaises ? myResult : PyNoneType.INSTANCE;
}
}
@Override
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyFunction(this);

View File

@@ -16,6 +16,7 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyPassStatement;
@@ -23,4 +24,9 @@ public class PyPassStatementImpl extends PyElementImpl implements PyPassStatemen
public PyPassStatementImpl(ASTNode astNode) {
super(astNode);
}
@Override
protected void acceptPyVisitor(PyElementVisitor pyVisitor) {
pyVisitor.visitPyPassStatement(this);
}
}

View File

@@ -1,17 +1,24 @@
0(1) element: null
1(2,3) element: PyIfStatement
1(2,4) element: PyIfStatement
2(3) element: PyStatementList. Condition: 0j:true
3(4,5) element: PyIfStatement
4(6) element: PyStatementList. Condition: 1j:true
5(6) element: PyStatementList. Condition: 1j:false
6(7,8) element: PyIfStatement
7(10) element: PyStatementList. Condition: 2j:true
8(9,10) element: PyIfPartElif. Condition: 2j:false
9(10) element: PyStatementList. Condition: 3j:true
10(11,12) element: PyIfStatement
11(16) element: PyStatementList. Condition: 4j:true
12(13,14) element: PyIfPartElif. Condition: 4j:false
13(16) element: PyStatementList. Condition: 5j:true
14(15) element: PyStatementList. Condition: 5j:false
15(16) element: PyExpressionStatement
16() element: null
3(4) element: PyPassStatement
4(5,7) element: PyIfStatement
5(6) element: PyStatementList. Condition: 1j:true
6(9) element: PyPassStatement
7(8) element: PyStatementList. Condition: 1j:false
8(9) element: PyPassStatement
9(10,12) element: PyIfStatement
10(11) element: PyStatementList. Condition: 2j:true
11(15) element: PyPassStatement
12(13,15) element: PyIfPartElif. Condition: 2j:false
13(14) element: PyStatementList. Condition: 3j:true
14(15) element: PyPassStatement
15(16,18) element: PyIfStatement
16(17) element: PyStatementList. Condition: 4j:true
17(23) element: PyPassStatement
18(19,21) element: PyIfPartElif. Condition: 4j:false
19(20) element: PyStatementList. Condition: 5j:true
20(23) element: PyPassStatement
21(22) element: PyStatementList. Condition: 5j:false
22(23) element: PyExpressionStatement
23() element: null

View File

@@ -1,23 +1,26 @@
0(1) element: null
1(2) element: PyIfStatement
2(3,5) READ ACCESS: c
2(3,6) READ ACCESS: c
3(4) element: PyStatementList. Condition: c:true
4(11) ASSERTTYPE ACCESS: c
5(6) element: PyIfPartElif. Condition: c:false
6(11) READ ACCESS: False
7(8) element: PyStatementList. Condition: False:true
8(9) ASSERTTYPE ACCESS: False
9(10) element: PyAssignmentStatement
10(11) WRITE ACCESS: a
11(12) element: PyIfStatement
12(13,15) READ ACCESS: d
13(14) element: PyStatementList. Condition: d:true
14(22) ASSERTTYPE ACCESS: d
15(16) element: PyIfPartElif. Condition: d:false
16(21) READ ACCESS: False
17(18) element: PyStatementList. Condition: False:true
18(19) ASSERTTYPE ACCESS: False
19(20) element: PyAssignmentStatement
20(22) WRITE ACCESS: b
21(22) element: PyStatementList. Condition: False:false
22() element: null
4(5) ASSERTTYPE ACCESS: c
5(12) element: PyPassStatement
6(7) element: PyIfPartElif. Condition: c:false
7(12) READ ACCESS: False
8(9) element: PyStatementList. Condition: False:true
9(10) ASSERTTYPE ACCESS: False
10(11) element: PyAssignmentStatement
11(12) WRITE ACCESS: a
12(13) element: PyIfStatement
13(14,17) READ ACCESS: d
14(15) element: PyStatementList. Condition: d:true
15(16) ASSERTTYPE ACCESS: d
16(25) element: PyPassStatement
17(18) element: PyIfPartElif. Condition: d:false
18(23) READ ACCESS: False
19(20) element: PyStatementList. Condition: False:true
20(21) ASSERTTYPE ACCESS: False
21(22) element: PyAssignmentStatement
22(25) WRITE ACCESS: b
23(24) element: PyStatementList. Condition: False:false
24(25) element: PyPassStatement
25() element: null

View File

@@ -1,21 +1,25 @@
0(1) element: null
1(2) element: PyIfStatement
2(3,5) READ ACCESS: c
2(3,6) READ ACCESS: c
3(4) element: PyStatementList. Condition: c:true
4(9) ASSERTTYPE ACCESS: c
5(6) element: PyIfPartElif. Condition: c:false
6(7) READ ACCESS: True
7(8) element: PyStatementList. Condition: True:true
8(9) ASSERTTYPE ACCESS: True
9(10) element: PyIfStatement
10(11,13) READ ACCESS: d
11(12) element: PyStatementList. Condition: d:true
12(20) ASSERTTYPE ACCESS: d
13(14) element: PyIfPartElif. Condition: d:false
14(15) READ ACCESS: True
15(16) element: PyStatementList. Condition: True:true
16(20) ASSERTTYPE ACCESS: True
17(18) element: PyStatementList. Condition: True:false
18(19) element: PyAssignmentStatement
19(20) WRITE ACCESS: e
20() element: null
4(5) ASSERTTYPE ACCESS: c
5(11) element: PyPassStatement
6(7) element: PyIfPartElif. Condition: c:false
7(8) READ ACCESS: True
8(9) element: PyStatementList. Condition: True:true
9(10) ASSERTTYPE ACCESS: True
10(11) element: PyPassStatement
11(12) element: PyIfStatement
12(13,16) READ ACCESS: d
13(14) element: PyStatementList. Condition: d:true
14(15) ASSERTTYPE ACCESS: d
15(24) element: PyPassStatement
16(17) element: PyIfPartElif. Condition: d:false
17(18) READ ACCESS: True
18(19) element: PyStatementList. Condition: True:true
19(20) ASSERTTYPE ACCESS: True
20(24) element: PyPassStatement
21(22) element: PyStatementList. Condition: True:false
22(23) element: PyAssignmentStatement
23(24) WRITE ACCESS: e
24() element: null

View File

@@ -10,17 +10,20 @@
9(10) element: PyStatementList. Condition: False:true
10(11) ASSERTTYPE ACCESS: False
11(12) element: PyAssignmentStatement
12(14) WRITE ACCESS: b
12(15) WRITE ACCESS: b
13(14) element: PyStatementList. Condition: False:false
14(15) element: PyIfStatement
15(20) READ ACCESS: False
16(17) element: PyStatementList. Condition: False:true
17(18) ASSERTTYPE ACCESS: False
18(19) element: PyAssignmentStatement
19(25) WRITE ACCESS: c
20(21) element: PyIfPartElif. Condition: False:false
21(22,24) READ ACCESS: d
22(23) element: PyStatementList. Condition: d:true
23(25) ASSERTTYPE ACCESS: d
24(25) element: PyStatementList. Condition: d:false
25() element: null
14(15) element: PyPassStatement
15(16) element: PyIfStatement
16(21) READ ACCESS: False
17(18) element: PyStatementList. Condition: False:true
18(19) ASSERTTYPE ACCESS: False
19(20) element: PyAssignmentStatement
20(28) WRITE ACCESS: c
21(22) element: PyIfPartElif. Condition: False:false
22(23,26) READ ACCESS: d
23(24) element: PyStatementList. Condition: d:true
24(25) ASSERTTYPE ACCESS: d
25(28) element: PyPassStatement
26(27) element: PyStatementList. Condition: d:false
27(28) element: PyPassStatement
28() element: null

View File

@@ -3,24 +3,27 @@
2(3) READ ACCESS: True
3(4) element: PyStatementList. Condition: True:true
4(5) ASSERTTYPE ACCESS: True
5(6) element: PyIfStatement
6(7) READ ACCESS: True
7(8) element: PyStatementList. Condition: True:true
8(12) ASSERTTYPE ACCESS: True
9(10) element: PyStatementList. Condition: True:false
10(11) element: PyAssignmentStatement
11(12) WRITE ACCESS: b
12(13) element: PyIfStatement
13(14) READ ACCESS: True
14(15) element: PyStatementList. Condition: True:true
15(25) ASSERTTYPE ACCESS: True
16(17) element: PyIfPartElif. Condition: True:false
17(18) READ ACCESS: c
18(19) element: PyStatementList. Condition: c:true
19(20) ASSERTTYPE ACCESS: c
20(21) element: PyAssignmentStatement
21(25) WRITE ACCESS: d
22(23) element: PyStatementList. Condition: c:false
5(6) element: PyPassStatement
6(7) element: PyIfStatement
7(8) READ ACCESS: True
8(9) element: PyStatementList. Condition: True:true
9(10) ASSERTTYPE ACCESS: True
10(14) element: PyPassStatement
11(12) element: PyStatementList. Condition: True:false
12(13) element: PyAssignmentStatement
13(14) WRITE ACCESS: b
14(15) element: PyIfStatement
15(16) READ ACCESS: True
16(17) element: PyStatementList. Condition: True:true
17(18) ASSERTTYPE ACCESS: True
18(28) element: PyPassStatement
19(20) element: PyIfPartElif. Condition: True:false
20(21) READ ACCESS: c
21(22) element: PyStatementList. Condition: c:true
22(23) ASSERTTYPE ACCESS: c
23(24) element: PyAssignmentStatement
24(25) WRITE ACCESS: e
25() element: null
24(28) WRITE ACCESS: d
25(26) element: PyStatementList. Condition: c:false
26(27) element: PyAssignmentStatement
27(28) WRITE ACCESS: e
28() element: null

View File

@@ -0,0 +1 @@
pass

View File

@@ -0,0 +1,3 @@
0(1) element: null
1(2) element: PyPassStatement
2() element: null

View File

@@ -25,6 +25,7 @@
24(25,28) element: PyCallExpression: checkit
25(26) element: PyStatementList. Condition: checkit(x):true
26(27) element: PyPrintStatement
27(29) READ ACCESS: x
27(30) READ ACCESS: x
28(29) element: PyStatementList. Condition: checkit(x):false
29() element: null
29(30) element: PyPassStatement
30() element: null

View File

@@ -6,4 +6,5 @@
5(6) element: PySubscriptionExpression
6(7) READ ACCESS: BaseClass
7(8) READ ACCESS: T
8() element: null
8(9) element: PyPassStatement
9() element: null

View File

@@ -7,4 +7,5 @@
6(7) WRITE ACCESS: a
7(8) READ ACCESS: U
8(9) WRITE ACCESS: b
9() element: null
9(10) element: PyPassStatement
10() element: null

View File

@@ -11,13 +11,14 @@
10(11) READ ACCESS: var
11(12) READ ACCESS: A
12(13,14) element: PyCallExpression: isinstance
13(21) element: null. Condition: isinstance(var, A):false
13(22) element: null. Condition: isinstance(var, A):false
14(15) element: null. Condition: isinstance(var, A):true
15(16) ASSERTTYPE ACCESS: var
16(17,18) READ ACCESS: var
17(21) element: null. Condition: var:false
17(22) element: null. Condition: var:false
18(19) element: null. Condition: var:true
19(20) element: PyStatementList. Condition: isinstance(var, A) and var:true
20(22) ASSERTTYPE ACCESS: var
21(22) ASSERTTYPE ACCESS: var
22() element: null
20(21) ASSERTTYPE ACCESS: var
21(23) element: PyPassStatement
22(23) ASSERTTYPE ACCESS: var
23() element: null

View File

@@ -7,7 +7,7 @@
6(7) READ ACCESS: isinstance
7(8) READ ACCESS: var
8(9) READ ACCESS: A
9(10,27) element: PyCallExpression: isinstance
9(10,28) element: PyCallExpression: isinstance
10(11) element: PyStatementList. Condition: isinstance(var, A):true
11(12) ASSERTTYPE ACCESS: var
12(13) element: PyIfStatement
@@ -20,10 +20,11 @@
19(20) element: null. Condition: isinstance(var, B):false
20(21) ASSERTTYPE ACCESS: var
21(22,23) READ ACCESS: var
22(26) element: null. Condition: var:false
22(27) element: null. Condition: var:false
23(24) element: null. Condition: var:true
24(25) element: PyStatementList. Condition: isinstance(var, B) or var:true
25(28) ASSERTTYPE ACCESS: var
26(28) ASSERTTYPE ACCESS: var
27(28) ASSERTTYPE ACCESS: var
28() element: null
25(26) ASSERTTYPE ACCESS: var
26(29) element: PyPassStatement
27(29) ASSERTTYPE ACCESS: var
28(29) ASSERTTYPE ACCESS: var
29() element: null

View File

@@ -1,7 +1,8 @@
0(1) element: null
1(2) element: PyWhileStatement
2(3) READ ACCESS: True
3(1) element: PyStatementList. Condition: True:true
4(5) element: PyElsePart. Condition: True:false
5(6) element: PyPrintStatement
6() element: null
3(4) element: PyStatementList. Condition: True:true
4(1) element: PyPassStatement
5(6) element: PyElsePart. Condition: True:false
6(7) element: PyPrintStatement
7() element: null

View File

@@ -22,7 +22,7 @@ def f() -> Optional[str]:
elif x == 0:
return 'abc'
else:
return
<warning descr="Function returning 'str | None' has implicit 'return None'">return</warning>
def g(x) -> int:
if x:
@@ -36,9 +36,18 @@ def h(x) -> int:
def i() -> Union[int, str]:
pass
def j(x) -> <warning descr="Expected to return 'int | str', got no return">Union[int, str]</warning>:
def j(x) -> <warning descr="Expected type 'int | str', got 'None' instead">Union[int, str]</warning>:
x = 42
def k() -> None:
if True:
pass
pass
def l(x) -> <warning descr="Expected type 'int', got 'int | None' instead">int</warning>:
if x == 1:
return 42
def m(x) -> None:
"""Does not display warning about implicit return, because annotated '-> None' """
if x:
return

View File

@@ -40,9 +40,9 @@ def g() -> Point:
else:
<warning descr="Expected type 'Point', got 'None' instead">return</warning>
def h(x) -> <warning descr="Expected to return 'Point', got no return">Point</warning>:
def h(x) -> <warning descr="Expected type 'Point', got 'None' instead">Point</warning>:
x = 42
def i() -> <warning descr="Expected to return 'Point', got no return">Point</warning>:
def i() -> <warning descr="Expected type 'Point', got 'None' instead">Point</warning>:
if True:
pass

View File

@@ -1,10 +1,10 @@
from typing import Type
import my
from my import X
from my import X, Y
def foo(a) -> Type[X]:
def foo(a) -> Type[X | Y]:
if a:
return my.X<caret>
else:

View File

@@ -0,0 +1,7 @@
def f(x) -> int | None:
if x == 1:
return 42
elif x == 2:
<warning descr="Function returning 'int | None' has implicit 'return None'">return<caret></warning>
elif x == 3:
pass

View File

@@ -0,0 +1,8 @@
def f(x) -> int | None:
if x == 1:
return 42
elif x == 2:
return None
elif x == 3:
return None
return None

View File

@@ -42,6 +42,10 @@ public class PyControlFlowBuilderTest extends LightMarkedTestCase {
doTest();
}
public void testPass() {
doTest();
}
public void testFile() {
doTest();
}

View File

@@ -55,7 +55,7 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase {
// PY-27128
public void testAncestorAndInheritorReturn() {
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X]"));
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X | Y]"));
}
// PY-27128 PY-48466

View File

@@ -0,0 +1,15 @@
// Copyright 2000-2024 JetBrains s.r.o. and contributors. Use of this source code is governed by the Apache 2.0 license.
package com.jetbrains.python.quickFixes;
import com.intellij.testFramework.TestDataPath;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.PyQuickFixTestCase;
import com.jetbrains.python.inspections.PyTypeCheckerInspection;
@TestDataPath("$CONTENT_ROOT/../testData/quickFixes/PyMakeReturnsExplicitFixTest/")
public class PyMakeReturnsExplicitFixTest extends PyQuickFixTestCase {
public void testAddReturnsFromReturnStmt() {
doQuickFixTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.NAME.make.return.stmts.explicit"));
}
}