From 70fe60b4c8df8bced8b4dff6d2f1069557766924 Mon Sep 17 00:00:00 2001 From: "Aleksandr.Govenko" Date: Wed, 4 Dec 2024 15:22:15 +0000 Subject: [PATCH] PY-20710 Support 'Generator' typing class Check YieldType of yield expressions in PyTypeCheckerInspection Check that (Async)Generator is used in (async) function Check that in 'yield from' sync Generator is used Convert PyMakeFunctionReturnTypeQuickFix into PsiUpdateModCommandAction Infer Generator type for lambdas When getting function type from annotation, do not convert Generator to AsyncGenerator Introduce GeneratorTypeDescriptor to simplify working with generator annotations Merge-request: IJ-MR-146521 Merged-by: Aleksandr Govenko (cherry picked from commit b3b8182168c5224f0e03f54d443171ccf6ca7b89) IJ-MR-146521 GitOrigin-RevId: a95670d7e2787015bcf162637ea6d7bfb47a312a --- .../com/jetbrains/python/psi/PyFunction.java | 1 + .../python/psi/PyYieldExpression.java | 16 +++ .../resources/messages/PyPsiBundle.properties | 3 + .../typing/PyTypingTypeProvider.java | 85 ++++++++++++++-- .../inspections/PyTypeCheckerInspection.java | 99 +++++++++++++++++-- .../PyMakeFunctionReturnTypeQuickFix.java | 61 +++++------- .../python/psi/impl/PyFunctionImpl.java | 88 +++++++++-------- .../psi/impl/PyGeneratorExpressionImpl.java | 2 +- .../psi/impl/PyLambdaExpressionImpl.java | 23 ++++- .../psi/impl/PyTargetExpressionImpl.java | 11 +-- .../psi/impl/PyYieldExpressionImpl.java | 57 +++++++++-- .../FunctionReturnTypePy3.py | 6 +- .../FunctionYieldTypePy3.py | 69 +++++++++++++ .../changeGenerator.py | 5 + .../changeGenerator_after.py | 5 + .../makeGenerator.py | 4 + .../makeGenerator_after.py | 7 ++ .../addReturnsFromAnnotation.py | 5 + .../addReturnsFromAnnotation_after.py | 6 ++ .../com/jetbrains/python/Py3TypeTest.java | 60 +++++++++++ .../Py3TypeCheckerInspectionTest.java | 4 + .../PyMakeFunctionReturnTypeQuickFixTest.java | 14 +++ 22 files changed, 509 insertions(+), 122 deletions(-) create mode 100644 python/testData/inspections/PyTypeCheckerInspection/FunctionYieldTypePy3.py create mode 100644 python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator.py create mode 100644 python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator_after.py create mode 100644 python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator.py create mode 100644 python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator_after.py create mode 100644 python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation.py create mode 100644 python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation_after.py diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java index 1e6b52f93017..d2b765708a20 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java @@ -1,6 +1,7 @@ // Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. package com.jetbrains.python.psi; +import com.intellij.openapi.util.Pair; import com.intellij.psi.PsiNameIdentifierOwner; import com.intellij.psi.StubBasedPsiElement; import com.intellij.util.ArrayFactory; diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java index 0abc04ea6c82..ce8169ae4cdf 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java @@ -2,6 +2,9 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstYieldExpression; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -11,4 +14,17 @@ public interface PyYieldExpression extends PyAstYieldExpression, PyExpression { default PyExpression getExpression() { return (PyExpression)PyAstYieldExpression.super.getExpression(); } + + /** + * @return For {@code yield}, returns type of its expression. For {@code yield from} - YieldType of the delegate + */ + @Nullable + PyType getYieldType(@NotNull TypeEvalContext context); + + /** + * @return If containing function is annotated with Generator (or AsyncGenerator), returns SendType from annotation. + * Otherwise, Any for {@code yield} and SendType of the delegate for {@code yield from} + */ + @Nullable + PyType getSendType(@NotNull TypeEvalContext context); } diff --git a/python/python-psi-impl/resources/messages/PyPsiBundle.properties b/python/python-psi-impl/resources/messages/PyPsiBundle.properties index 27c8bb468241..eaec23dc3bf1 100644 --- a/python/python-psi-impl/resources/messages/PyPsiBundle.properties +++ b/python/python-psi-impl/resources/messages/PyPsiBundle.properties @@ -1056,6 +1056,9 @@ QFIX.ignore.shadowed.built.in.name=Ignore shadowed built-in name "{0}" # PyTypeCheckerInspection INSP.NAME.type.checker=Incorrect type INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead +INSP.type.checker.yield.type.mismatch=Expected yield type ''{0}'', got ''{1}'' instead +INSP.type.checker.yield.from.send.type.mismatch=Expected send type ''{0}'', got ''{1}'' instead +INSP.type.checker.yield.from.async.generator=Cannot yield from ''{0}'', use 'async for' instead INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}'' INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2} INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None'' diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java index c2c13b17571b..1673371d75dc 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java @@ -353,7 +353,11 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< if (returnTypeAnnotation != null) { final Ref typeRef = getType(returnTypeAnnotation, context); if (typeRef != null) { - return Ref.create(toAsyncIfNeeded(function, typeRef.get())); + // Do not use toAsyncIfNeeded, as it also converts Generators. Here we do not need it. + if (function.isAsync() && function.isAsyncAllowed() && !function.isGenerator()) { + return Ref.create(wrapInCoroutineType(typeRef.get(), function)); + } + return typeRef; } // Don't rely on other type providers if a type hint is present, but cannot be resolved. return Ref.create(); @@ -2156,8 +2160,9 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< if (!function.isGenerator()) { return wrapInCoroutineType(returnType, function); } - else if (returnType instanceof PyCollectionType && isGenerator(returnType)) { - return wrapInAsyncGeneratorType(((PyCollectionType)returnType).getIteratedItemType(), function); + var desc = GeneratorTypeDescriptor.create(returnType); + if (desc != null) { + return desc.withAsync(true).toPyType(function); } } @@ -2184,15 +2189,77 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< } @Nullable - public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType returnType, @NotNull PsiElement anchor) { + public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType sendType, @Nullable PyType returnType, @NotNull PsiElement anchor) { final PyClass generator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(GENERATOR, anchor); - return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, null, returnType)) : null; + return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, sendType, returnType)) : null; } - @Nullable - private static PyType wrapInAsyncGeneratorType(@Nullable PyType elementType, @NotNull PsiElement anchor) { - final PyClass asyncGenerator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(ASYNC_GENERATOR, anchor); - return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null; + public record GeneratorTypeDescriptor( + String className, + PyType yieldType, // if YieldType is not specified, it is AnyType + PyType sendType, // if SendType is not specified, it is PyNoneType + PyType returnType // if ReturnType is not specified, it is PyNoneType + ) { + + private static final List SYNC_TYPES = List.of(GENERATOR, "typing.Iterable", "typing.Iterator"); + private static final List ASYNC_TYPES = List.of(ASYNC_GENERATOR, "typing.AsyncIterable", "typing.AsyncIterator"); + + public static @Nullable GeneratorTypeDescriptor create(@Nullable PyType type) { + final PyClassType classType = as(type, PyClassType.class); + final PyCollectionType genericType = as(type, PyCollectionType.class); + if (classType == null) return null; + + final String qName = classType.getClassQName(); + if (!SYNC_TYPES.contains(qName) && !ASYNC_TYPES.contains(qName)) return null; + + PyType yieldType = null; + PyType sendType = PyNoneType.INSTANCE; + PyType returnType = PyNoneType.INSTANCE; + + if (genericType != null) { + yieldType = ContainerUtil.getOrElse(genericType.getElementTypes(), 0, yieldType); + if (GENERATOR.equals(qName) || ASYNC_GENERATOR.equals(qName)) { + sendType = ContainerUtil.getOrElse(genericType.getElementTypes(), 1, sendType); + } + if (GENERATOR.equals(qName)) { + returnType = ContainerUtil.getOrElse(genericType.getElementTypes(), 2, returnType); + } + } + return new GeneratorTypeDescriptor(qName, yieldType, sendType, returnType); + } + + public boolean isAsync() { + return ASYNC_TYPES.contains(className); + } + + public GeneratorTypeDescriptor withAsync(boolean async) { + if (async) { + var idx = SYNC_TYPES.indexOf(className); + if (idx == -1) return this; + return new GeneratorTypeDescriptor(ASYNC_TYPES.get(idx), yieldType, sendType, PyNoneType.INSTANCE); + } + else { + var idx = ASYNC_TYPES.indexOf(className); + if (idx == -1) return this; + return new GeneratorTypeDescriptor(SYNC_TYPES.get(idx), yieldType, sendType, returnType); + } + } + + public @Nullable PyType toPyType(@NotNull PsiElement anchor) { + final PyClass classType = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(className, anchor); + final List generics; + if (GENERATOR.equals(className)) { + generics = Arrays.asList(yieldType, sendType, returnType); + } + else if (ASYNC_GENERATOR.equals(className)) { + generics = Arrays.asList(yieldType, sendType); + } + else { + generics = Collections.singletonList(yieldType); + } + + return classType != null ? new PyCollectionTypeImpl(classType, false, generics) : null; + } } @Nullable diff --git a/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java b/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java index 27ec4e90d1d1..3bd925ad34ff 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java +++ b/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java @@ -16,6 +16,7 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil; import com.jetbrains.python.codeInsight.typing.PyProtocolsKt; import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; +import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.GeneratorTypeDescriptor; import com.jetbrains.python.documentation.PythonDocumentationProvider; import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix; import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix; @@ -92,7 +93,7 @@ public class PyTypeCheckerInspection extends PyInspection { PyAnnotation annotation = function.getAnnotation(); String typeCommentAnnotation = function.getTypeCommentAnnotation(); if (annotation != null || typeCommentAnnotation != null) { - PyType expected = getExpectedReturnType(function, myTypeEvalContext); + PyType expected = getExpectedReturnStatementType(function, myTypeEvalContext); if (expected == null) return; // We cannot just match annotated and inferred types, as we cannot promote inferred to Literal @@ -124,14 +125,80 @@ public class PyTypeCheckerInspection extends PyInspection { } } - @Nullable - public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) { - final PyType returnType = typeEvalContext.getReturnType(function); + @Override + public void visitPyYieldExpression(@NotNull PyYieldExpression node) { + ScopeOwner owner = ScopeUtil.getScopeOwner(node); + if (!(owner instanceof PyFunction function)) return; + + final PyAnnotation annotation = function.getAnnotation(); + final String typeCommentAnnotation = function.getTypeCommentAnnotation(); + if (annotation == null && typeCommentAnnotation == null) return; - if (function.isAsync() || function.isGenerator()) { - return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType)); + final PyType fullReturnType = myTypeEvalContext.getReturnType(function); + if (fullReturnType == null) return; // fullReturnType is Any + + final var generatorDesc = GeneratorTypeDescriptor.create(fullReturnType); + if (generatorDesc == null) { + // expected type is not Iterable, Iterator, Generator or similar + final PyType actual = function.getInferredReturnType(myTypeEvalContext); + String expectedName = PythonDocumentationProvider.getVerboseTypeName(fullReturnType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext); + getHolder() + .problem(node, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName)) + .fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext)) + .register(); + return; } + final PyType expectedYieldType = generatorDesc.yieldType(); + final PyType expectedSendType = generatorDesc.sendType(); + + final PyType thisYieldType = node.getYieldType(myTypeEvalContext); + + final PyExpression yieldExpr = node.getExpression(); + + if (!PyTypeChecker.match(expectedYieldType, thisYieldType, myTypeEvalContext)) { + String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedYieldType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(thisYieldType, myTypeEvalContext); + getHolder() + .problem(yieldExpr != null ? yieldExpr : node, PyPsiBundle.message("INSP.type.checker.yield.type.mismatch", expectedName, actualName)) + .fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext)) + .register(); + } + + if (yieldExpr != null && node.isDelegating()) { + final PyType delegateType = myTypeEvalContext.getType(yieldExpr); + var delegateDesc = GeneratorTypeDescriptor.create(delegateType); + if (delegateDesc == null) return; + + if (delegateDesc.isAsync()) { + String delegateName = PythonDocumentationProvider.getTypeName(delegateType, myTypeEvalContext); + registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.async.generator", delegateName, delegateName)); + return; + } + + // Reversed because SendType is contravariant + if (!PyTypeChecker.match(delegateDesc.sendType(), expectedSendType, myTypeEvalContext)) { + String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedSendType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(delegateDesc.sendType(), myTypeEvalContext); + registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.send.type.mismatch", expectedName, actualName)); + } + } + } + + + @Nullable + public static PyType getExpectedReturnStatementType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) { + final PyType returnType = typeEvalContext.getReturnType(function); + if (function.isGenerator()) { + final var generatorDesc = GeneratorTypeDescriptor.create(returnType); + if (generatorDesc != null) { + return generatorDesc.returnType(); + } + } + if (function.isAsync()) { + return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType)); + } return returnType; } @@ -238,7 +305,7 @@ public class PyTypeCheckerInspection extends PyInspection { final PyAnnotation annotation = node.getAnnotation(); final String typeCommentAnnotation = node.getTypeCommentAnnotation(); if (annotation != null || typeCommentAnnotation != null) { - final PyType expected = getExpectedReturnType(node, myTypeEvalContext); + final PyType expected = getExpectedReturnStatementType(node, myTypeEvalContext); final boolean returnsNone = expected instanceof PyNoneType; final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext); @@ -263,6 +330,24 @@ public class PyTypeCheckerInspection extends PyInspection { registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(), PyPsiBundle.message("INSP.type.checker.init.should.return.none")); } + + if (node.isGenerator()) { + boolean shouldBeAsync = node.isAsync() && node.isAsyncAllowed(); + final PyType annotatedType = myTypeEvalContext.getReturnType(node); + final var generatorDesc = GeneratorTypeDescriptor.create(annotatedType); + if (generatorDesc != null && generatorDesc.isAsync() != shouldBeAsync) { + final PyType inferredType = node.getInferredReturnType(myTypeEvalContext); + String expectedName = PythonDocumentationProvider.getVerboseTypeName(inferredType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(annotatedType, myTypeEvalContext); + final PsiElement annotationValue = annotation != null ? annotation.getValue() : node.getTypeComment(); + if (annotationValue != null) { + getHolder() + .problem(annotationValue, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName)) + .fix(new PyMakeFunctionReturnTypeQuickFix(node, myTypeEvalContext)) + .register(); + } + } + } } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java b/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java index 89593dd53fce..c8b5a6b9b969 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java +++ b/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java @@ -15,50 +15,48 @@ */ package com.jetbrains.python.inspections.quickfix; -import com.intellij.codeInsight.intention.FileModifier; -import com.intellij.codeInspection.LocalQuickFix; -import com.intellij.codeInspection.ProblemDescriptor; -import com.intellij.openapi.project.Project; +import com.intellij.modcommand.ActionContext; +import com.intellij.modcommand.ModPsiUpdater; +import com.intellij.modcommand.Presentation; +import com.intellij.modcommand.PsiUpdateModCommandAction; import com.intellij.psi.*; import com.jetbrains.python.PyPsiBundle; import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil; +import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; import com.jetbrains.python.documentation.PythonDocumentationProvider; import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import com.jetbrains.python.refactoring.PyRefactoringUtil; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import java.util.List; -public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix { - private final @NotNull SmartPsiElementPointer myFunction; +public class PyMakeFunctionReturnTypeQuickFix extends PsiUpdateModCommandAction { private final String myReturnTypeName; public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) { - this(function, getReturnTypeName(function, context)); - } - - private PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull String returnTypeName) { - SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject()); - myFunction = manager.createSmartPsiElementPointer(function); - myReturnTypeName = returnTypeName; + super(function); + myReturnTypeName = getReturnTypeName(function, context); } @NotNull private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) { - final PyType type = function.getInferredReturnType(context); + PyType type = function.getInferredReturnType(context); + if (function.isAsync()) { + var unwrappedType = PyTypingTypeProvider.unwrapCoroutineReturnType(type); + if (unwrappedType != null) { + type = unwrappedType.get(); + } + } return PythonDocumentationProvider.getTypeHint(type, context); } @Override - @NotNull - public String getName() { - PyFunction function = myFunction.getElement(); - String functionName = function != null ? function.getName() : "function"; - return PyPsiBundle.message("QFIX.make.function.return.type", functionName, myReturnTypeName); + @Nullable + protected Presentation getPresentation(@NotNull ActionContext context, @NotNull PyFunction function) { + return Presentation.of(PyPsiBundle.message("QFIX.make.function.return.type", function.getName(), myReturnTypeName)); } @Override @@ -68,11 +66,8 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix { } @Override - public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) { - final PyFunction function = myFunction.getElement(); - if (function == null) return; - - final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project); + protected void invoke(@NotNull ActionContext context, @NotNull PyFunction function, @NotNull ModPsiUpdater updater) { + final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(function.getProject()); boolean shouldAddImports = false; @@ -96,28 +91,18 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix { } if (shouldAddImports) { - addImportsForTypeAnnotations(TypeEvalContext.userInitiated(project, function.getContainingFile())); + addImportsForTypeAnnotations(function); } } - private void addImportsForTypeAnnotations(@NotNull TypeEvalContext context) { - PyFunction function = myFunction.getElement(); - if (function == null) return; + private static void addImportsForTypeAnnotations(@NotNull PyFunction function) { PsiFile file = function.getContainingFile(); if (file == null) return; + TypeEvalContext context = TypeEvalContext.userInitiated(function.getProject(), file); PyType typeForImports = function.getInferredReturnType(context); if (typeForImports != null) { PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file); } } - - @Override - public @Nullable FileModifier getFileModifierForPreview(@NotNull PsiFile target) { - PyFunction function = PyRefactoringUtil.findSameElementForPreview(myFunction, target); - if (function == null) { - return null; - } - return new PyMakeFunctionReturnTypeQuickFix(function, myReturnTypeName); - } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java index 9922e96c0428..466641fdfc8e 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java @@ -6,6 +6,7 @@ import com.intellij.codeInsight.controlflow.Instruction; import com.intellij.lang.ASTNode; import com.intellij.navigation.ItemPresentation; import com.intellij.openapi.util.Key; +import com.intellij.openapi.util.Pair; import com.intellij.openapi.util.Ref; import com.intellij.openapi.vfs.VirtualFile; import com.intellij.psi.PsiElement; @@ -16,10 +17,8 @@ import com.intellij.psi.stubs.IStubElementType; import com.intellij.psi.stubs.StubElement; import com.intellij.psi.util.*; import com.intellij.ui.IconManager; -import com.intellij.util.ArrayUtil; import com.intellij.util.IncorrectOperationException; import com.intellij.util.PlatformIcons; -import com.intellij.util.containers.ContainerUtil; import com.intellij.util.containers.JBIterable; import com.jetbrains.python.PyNames; import com.jetbrains.python.PyStubElementTypes; @@ -38,13 +37,16 @@ import com.jetbrains.python.psi.stubs.PyFunctionStub; import com.jetbrains.python.psi.stubs.PyTargetExpressionStub; import com.jetbrains.python.psi.types.*; import com.jetbrains.python.sdk.PythonSdkUtil; +import one.util.streamex.StreamEx; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import javax.swing.*; import java.util.*; +import java.util.stream.Stream; import static com.intellij.openapi.util.text.StringUtil.notNullize; +import static com.intellij.util.containers.ContainerUtil.*; import static com.jetbrains.python.ast.PyAstFunction.Modifier.CLASSMETHOD; import static com.jetbrains.python.ast.PyAstFunction.Modifier.STATICMETHOD; import static com.jetbrains.python.psi.PyUtil.as; @@ -130,7 +132,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements .filter(PyCallableType.class::isInstance) .map(PyCallableType.class::cast) .map(callableType -> callableType.getParameters(context)) - .orElseGet(() -> ContainerUtil.map(getParameterList().getParameters(), PyCallableParameterImpl::psi)); + .orElseGet(() -> map(getParameterList().getParameters(), PyCallableParameterImpl::psi)); } @Override @@ -164,12 +166,13 @@ public class PyFunctionImpl extends PyBaseElementImpl implements public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) { PyType inferredType = null; if (context.allowReturnTypes(this)) { - final Ref yieldTypeRef = getYieldStatementType(context); - if (yieldTypeRef != null) { - inferredType = yieldTypeRef.get(); + final PyType returnType = getReturnStatementType(context); + final Pair yieldSendTypePair = getYieldExpressionType(context); + if (yieldSendTypePair != null) { + inferredType = PyTypingTypeProvider.wrapInGeneratorType(yieldSendTypePair.first, yieldSendTypePair.second, returnType, this); } else { - inferredType = getReturnStatementType(context); + inferredType = returnType; } } return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(PyTypingTypeProvider.toAsyncIfNeeded(this, inferredType)); @@ -189,7 +192,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements final Map mappedExplicitParameters = fullMapping.getMappedParameters(); final Map allMappedParameters = new LinkedHashMap<>(); - final PyCallableParameter firstImplicit = ContainerUtil.getFirstItem(fullMapping.getImplicitParameters()); + final PyCallableParameter firstImplicit = getFirstItem(fullMapping.getImplicitParameters()); if (receiver != null && firstImplicit != null) { allMappedParameters.put(receiver, firstImplicit); } @@ -282,7 +285,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements else if (allowCoroutineOrGenerator && returnType instanceof PyCollectionType && PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType) != null) { - final List replacedElementTypes = ContainerUtil.map( + final List replacedElementTypes = map( ((PyCollectionType)returnType).getElementTypes(), type -> replaceSelf(type, receiver, context, false) ); @@ -308,42 +311,41 @@ public class PyFunctionImpl extends PyBaseElementImpl implements } return false; } + + public static class YieldCollector extends PyRecursiveElementVisitor { + public List getYieldExpressions() { + return myYieldExpressions; + } - private @Nullable Ref getYieldStatementType(final @NotNull TypeEvalContext context) { - final PyBuiltinCache cache = PyBuiltinCache.getInstance(this); + final private List myYieldExpressions = new ArrayList<>(); + + @Override + public void visitPyYieldExpression(@NotNull PyYieldExpression node) { + myYieldExpressions.add(node); + } + + @Override + public void visitPyFunction(@NotNull PyFunction node) { + // Ignore nested functions + } + + @Override + public void visitPyLambdaExpression(@NotNull PyLambdaExpression node) { + // Ignore nested lambdas + } + } + + /** + * @return pair of YieldType and SendType. Null when there are no yields + */ + private @Nullable Pair getYieldExpressionType(final @NotNull TypeEvalContext context) { final PyStatementList statements = getStatementList(); - final Set types = new LinkedHashSet<>(); - statements.accept(new PyRecursiveElementVisitor() { - @Override - public void visitPyYieldExpression(@NotNull PyYieldExpression node) { - final PyExpression expr = node.getExpression(); - final PyType type = expr != null ? context.getType(expr) : null; - - if (node.isDelegating()) { - if (type instanceof PyCollectionType) { - types.add(((PyCollectionType)type).getIteratedItemType()); - } - else if (ArrayUtil.contains(type, cache.getListType(), cache.getDictType(), cache.getSetType(), cache.getTupleType())) { - types.add(null); - } - else { - types.add(type); - } - } - else { - types.add(type); - } - } - - @Override - public void visitPyFunction(@NotNull PyFunction node) { - // Ignore nested functions - } - }); - if (!types.isEmpty()) { - final PyType elementType = PyUnionType.union(types); - final PyType returnType = getReturnStatementType(context); - return Ref.create(PyTypingTypeProvider.wrapInGeneratorType(elementType, returnType, this)); + final YieldCollector visitor = new YieldCollector(); + statements.accept(visitor); + final List yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context)); + final List sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context)); + if (!yieldTypes.isEmpty()) { + return Pair.create(PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes)); } return null; } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java index 2cf6f8cfce76..cabb721c5659 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java @@ -42,7 +42,7 @@ public class PyGeneratorExpressionImpl extends PyComprehensionElementImpl implem public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { final PyExpression resultExpr = getResultExpression(); if (resultExpr != null) { - return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), PyNoneType.INSTANCE, this); + return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), null, PyNoneType.INSTANCE, this); } return null; } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java index 3c1a1616d6d5..f0df28c485b6 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java @@ -4,18 +4,19 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache; -import com.jetbrains.python.psi.PyCallSiteExpression; -import com.jetbrains.python.psi.PyElementVisitor; -import com.jetbrains.python.psi.PyExpression; -import com.jetbrains.python.psi.PyLambdaExpression; +import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; +import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; +import static com.intellij.util.containers.ContainerUtil.map; + public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExpression { public PyLambdaExpressionImpl(ASTNode astNode) { @@ -56,7 +57,19 @@ public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExp @Override public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { final PyExpression body = getBody(); - return body != null ? context.getType(body) : null; + if (body == null) return null; + + final PyFunctionImpl.YieldCollector visitor = new PyFunctionImpl.YieldCollector(); + body.accept(visitor); + + final List yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context)); + final List sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context)); + + if (!yieldTypes.isEmpty()) { + return PyTypingTypeProvider.wrapInGeneratorType( + PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes), context.getType(body), this); + } + return context.getType(body); } @Nullable diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java index a7758cae9cc6..2d6f951d5b04 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java @@ -127,16 +127,7 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl generatorElementType = PyTypingTypeProvider.coroutineOrGeneratorElementType(type); - return generatorElementType == null ? PyNoneType.INSTANCE : generatorElementType.get(); + final PyExpression e = getExpression(); + final PyType type = e != null ? context.getType(e) : null; + var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type); + if (generatorDesc != null) { + return generatorDesc.returnType(); + } + return PyNoneType.INSTANCE; + } + else { + return getSendType(context); + } + } + + @Override + public @Nullable PyType getYieldType(@NotNull TypeEvalContext context) { + final PyExpression expr = getExpression(); + final PyType type = expr != null ? context.getType(expr) : PyNoneType.INSTANCE; + + if (isDelegating()) { + return PyTargetExpressionImpl.getIterationType(type, expr, this, context); } return type; } + + @Override + public @Nullable PyType getSendType(@NotNull TypeEvalContext context) { + if (ScopeUtil.getScopeOwner(this) instanceof PyFunction function) { + if (function.getAnnotation() != null || function.getTypeCommentAnnotation() != null) { + var returnType = context.getReturnType(function); + var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(returnType); + if (generatorDesc != null) { + return generatorDesc.sendType(); + } + } + } + + if (isDelegating()) { + final PyExpression e = getExpression(); + final PyType type = e != null ? context.getType(e) : null; + var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type); + if (generatorDesc != null) { + return generatorDesc.sendType(); + } + return PyNoneType.INSTANCE; + } + return null; + } } diff --git a/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py b/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py index e456889ddd2d..4ef353a6fcd4 100644 --- a/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py +++ b/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py @@ -50,4 +50,8 @@ def l(x) -> int None: """Does not display warning about implicit return, because annotated '-> None' """ if x: - return \ No newline at end of file + return + +def n() -> Generator[int, Any, str]: + yield 13 + return 42 \ No newline at end of file diff --git a/python/testData/inspections/PyTypeCheckerInspection/FunctionYieldTypePy3.py b/python/testData/inspections/PyTypeCheckerInspection/FunctionYieldTypePy3.py new file mode 100644 index 000000000000..4225943c576d --- /dev/null +++ b/python/testData/inspections/PyTypeCheckerInspection/FunctionYieldTypePy3.py @@ -0,0 +1,69 @@ +from typing import Generator, Iterable, Iterator, AsyncIterable, AsyncIterator, AsyncGenerator + +# Fix incorrect YieldType +def a() -> Iterable[str]: + yield 42 + +def b() -> Iterator[str]: + yield 42 + +def c() -> Generator[str, Any, int]: + yield 13 + return 42 + +def c() -> Generator[int, Any, str]: + yield 13 + return 42 + +# Suggest AsyncGenerator +async def d() -> Iterable[int]: + yield 42 + +async def e() -> Iterator[int]: + yield 42 + +async def f() -> Generator[int, str, None]: + yield 13 + +# Suggest sync Generator +def g() -> AsyncIterable[int]: + yield 42 + +def h() -> AsyncIterator[int]: + yield 42 + +def i() -> AsyncGenerator[int, str]: + yield 13 + +def j() -> Iterator[int]: + yield from j() + +def k() -> Iterator[str]: + yield from j() + yield from [1] + +def l() -> Generator[None, int, None]: + x: float = yield + +def m() -> Generator[None, float, None]: + yield from l() + +def n() -> Generator[None, float, None]: + x: float = yield + +def o() -> Generator[None, int, None]: + yield from n() + +def p() -> Generator[int, None, None]: + yield from [1, 2] + yield from [3, 4] + +def q() -> int: + x = lambda: (yield "str") + return 42 + +async def r() -> AsyncGenerator[int]: + yield 42 + +def s() -> Generator[int]: + yield from r() \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator.py new file mode 100644 index 000000000000..2385f3dc80b0 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator.py @@ -0,0 +1,5 @@ +from typing import Generator + +def gen() -> Generator[int, bool, str]: + b: bool = yield "str" + return 42 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator_after.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator_after.py new file mode 100644 index 000000000000..e35a271bb9f3 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator_after.py @@ -0,0 +1,5 @@ +from typing import Generator + +def gen() -> Generator[str, bool, int]: + b: bool = yield "str" + return 42 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator.py new file mode 100644 index 000000000000..2368acb78148 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator.py @@ -0,0 +1,4 @@ +async def gen() -> str: + b: bool = yield "str" + if b: + b = yield 3.14 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator_after.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator_after.py new file mode 100644 index 000000000000..2389737db0b8 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator_after.py @@ -0,0 +1,7 @@ +from typing import Any, AsyncGenerator + + +async def gen() -> AsyncGenerator[str | float, Any]: + b: bool = yield "str" + if b: + b = yield 3.14 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation.py b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation.py new file mode 100644 index 000000000000..077dc7909e56 --- /dev/null +++ b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation.py @@ -0,0 +1,5 @@ +def f(x) -> int | None: + if x == 1: + return 42 + elif x == 2: + return \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation_after.py b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation_after.py new file mode 100644 index 000000000000..9a4cd15bef3b --- /dev/null +++ b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation_after.py @@ -0,0 +1,6 @@ +def f(x) -> int | None: + if x == 1: + return 42 + elif x == 2: + return None + return None \ No newline at end of file diff --git a/python/testSrc/com/jetbrains/python/Py3TypeTest.java b/python/testSrc/com/jetbrains/python/Py3TypeTest.java index e259b57e2e16..2c4772601648 100644 --- a/python/testSrc/com/jetbrains/python/Py3TypeTest.java +++ b/python/testSrc/com/jetbrains/python/Py3TypeTest.java @@ -17,6 +17,66 @@ import java.util.Map; public class Py3TypeTest extends PyTestCase { public static final String TEST_DIRECTORY = "/types/"; + // PY-20710 + public void testLambdaGenerator() { + doTest("Generator[int, Any, Any]", """ + expr = (lambda: (yield 1))() + """); + } + + // PY-20710 + public void testGeneratorDelegatingToLambdaGenerator() { + doTest("Generator[int, Any, str]", """ + def g(): + yield from (lambda: (yield 1))() + return "foo" + expr = g() + """); + } + + // PY-20710 + public void testYieldExpressionTypeFromGeneratorSendTypeHint() { + doTest("int", """ + from typing import Generator + + def g() -> Generator[str, int, None]: + expr = yield "foo" + """); + } + + // PY-20710 + public void testYieldFromExpressionTypeFromGeneratorReturnTypeHint() { + doTest("int", """ + from typing import Generator, Any + + def delegate() -> Generator[None, Any, int]: + yield + return 42 + + def g(): + expr = yield from delegate() + """); + } + + // PY-20710 + public void testYieldFromLambda() { + doTest("Generator[int | str, str | Any, bool]", + """ + from typing import Generator + + def gen1() -> Generator[int, str, bool]: + yield 42 + return True + + def gen2(): + yield "str" + return True + + l = lambda: (yield from gen1()) or (yield from gen2()) + expr = l() + """); + } + // PY-6702 public void testYieldFromType() { doTest("str | int | float", diff --git a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java index a5ba8d6c0398..ac1a7adc0f35 100644 --- a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java +++ b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java @@ -102,6 +102,10 @@ public class Py3TypeCheckerInspectionTest extends PyInspectionTestCase { public void testFunctionReturnTypePy3() { doTest(); } + + public void testFunctionYieldTypePy3() { + doTest(); + } // PY-20770 public void testAsyncForOverAsyncGenerator() { diff --git a/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java b/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java index 3194c6a55df9..e03e625dbb06 100644 --- a/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java +++ b/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java @@ -64,4 +64,18 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase { PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"), LanguageLevel.getLatest()); } + + // PY-20710 + public void testChangeGenerator() { + doQuickFixTest(PyTypeCheckerInspection.class, + PyPsiBundle.message("QFIX.make.function.return.type", "gen", "Generator[str, bool, int]"), + LanguageLevel.getLatest()); + } + + // PY-20710 + public void testMakeGenerator() { + doQuickFixTest(PyTypeCheckerInspection.class, + PyPsiBundle.message("QFIX.make.function.return.type", "gen", "AsyncGenerator[str | float, Any]"), + LanguageLevel.getLatest()); + } }