diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java index 1e6b52f93017..d2b765708a20 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyFunction.java @@ -1,6 +1,7 @@ // Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. package com.jetbrains.python.psi; +import com.intellij.openapi.util.Pair; import com.intellij.psi.PsiNameIdentifierOwner; import com.intellij.psi.StubBasedPsiElement; import com.intellij.util.ArrayFactory; diff --git a/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java b/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java index 0abc04ea6c82..ce8169ae4cdf 100644 --- a/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java +++ b/python/python-psi-api/src/com/jetbrains/python/psi/PyYieldExpression.java @@ -2,6 +2,9 @@ package com.jetbrains.python.psi; import com.jetbrains.python.ast.PyAstYieldExpression; +import com.jetbrains.python.psi.types.PyType; +import com.jetbrains.python.psi.types.TypeEvalContext; +import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; @@ -11,4 +14,17 @@ public interface PyYieldExpression extends PyAstYieldExpression, PyExpression { default PyExpression getExpression() { return (PyExpression)PyAstYieldExpression.super.getExpression(); } + + /** + * @return For {@code yield}, returns type of its expression. For {@code yield from} - YieldType of the delegate + */ + @Nullable + PyType getYieldType(@NotNull TypeEvalContext context); + + /** + * @return If containing function is annotated with Generator (or AsyncGenerator), returns SendType from annotation. + * Otherwise, Any for {@code yield} and SendType of the delegate for {@code yield from} + */ + @Nullable + PyType getSendType(@NotNull TypeEvalContext context); } diff --git a/python/python-psi-impl/resources/messages/PyPsiBundle.properties b/python/python-psi-impl/resources/messages/PyPsiBundle.properties index 27c8bb468241..eaec23dc3bf1 100644 --- a/python/python-psi-impl/resources/messages/PyPsiBundle.properties +++ b/python/python-psi-impl/resources/messages/PyPsiBundle.properties @@ -1056,6 +1056,9 @@ QFIX.ignore.shadowed.built.in.name=Ignore shadowed built-in name "{0}" # PyTypeCheckerInspection INSP.NAME.type.checker=Incorrect type INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead +INSP.type.checker.yield.type.mismatch=Expected yield type ''{0}'', got ''{1}'' instead +INSP.type.checker.yield.from.send.type.mismatch=Expected send type ''{0}'', got ''{1}'' instead +INSP.type.checker.yield.from.async.generator=Cannot yield from ''{0}'', use 'async for' instead INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}'' INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2} INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None'' diff --git a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java index c2c13b17571b..1673371d75dc 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java +++ b/python/python-psi-impl/src/com/jetbrains/python/codeInsight/typing/PyTypingTypeProvider.java @@ -353,7 +353,11 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< if (returnTypeAnnotation != null) { final Ref typeRef = getType(returnTypeAnnotation, context); if (typeRef != null) { - return Ref.create(toAsyncIfNeeded(function, typeRef.get())); + // Do not use toAsyncIfNeeded, as it also converts Generators. Here we do not need it. + if (function.isAsync() && function.isAsyncAllowed() && !function.isGenerator()) { + return Ref.create(wrapInCoroutineType(typeRef.get(), function)); + } + return typeRef; } // Don't rely on other type providers if a type hint is present, but cannot be resolved. return Ref.create(); @@ -2156,8 +2160,9 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< if (!function.isGenerator()) { return wrapInCoroutineType(returnType, function); } - else if (returnType instanceof PyCollectionType && isGenerator(returnType)) { - return wrapInAsyncGeneratorType(((PyCollectionType)returnType).getIteratedItemType(), function); + var desc = GeneratorTypeDescriptor.create(returnType); + if (desc != null) { + return desc.withAsync(true).toPyType(function); } } @@ -2184,15 +2189,77 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext< } @Nullable - public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType returnType, @NotNull PsiElement anchor) { + public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType sendType, @Nullable PyType returnType, @NotNull PsiElement anchor) { final PyClass generator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(GENERATOR, anchor); - return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, null, returnType)) : null; + return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, sendType, returnType)) : null; } - @Nullable - private static PyType wrapInAsyncGeneratorType(@Nullable PyType elementType, @NotNull PsiElement anchor) { - final PyClass asyncGenerator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(ASYNC_GENERATOR, anchor); - return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null; + public record GeneratorTypeDescriptor( + String className, + PyType yieldType, // if YieldType is not specified, it is AnyType + PyType sendType, // if SendType is not specified, it is PyNoneType + PyType returnType // if ReturnType is not specified, it is PyNoneType + ) { + + private static final List SYNC_TYPES = List.of(GENERATOR, "typing.Iterable", "typing.Iterator"); + private static final List ASYNC_TYPES = List.of(ASYNC_GENERATOR, "typing.AsyncIterable", "typing.AsyncIterator"); + + public static @Nullable GeneratorTypeDescriptor create(@Nullable PyType type) { + final PyClassType classType = as(type, PyClassType.class); + final PyCollectionType genericType = as(type, PyCollectionType.class); + if (classType == null) return null; + + final String qName = classType.getClassQName(); + if (!SYNC_TYPES.contains(qName) && !ASYNC_TYPES.contains(qName)) return null; + + PyType yieldType = null; + PyType sendType = PyNoneType.INSTANCE; + PyType returnType = PyNoneType.INSTANCE; + + if (genericType != null) { + yieldType = ContainerUtil.getOrElse(genericType.getElementTypes(), 0, yieldType); + if (GENERATOR.equals(qName) || ASYNC_GENERATOR.equals(qName)) { + sendType = ContainerUtil.getOrElse(genericType.getElementTypes(), 1, sendType); + } + if (GENERATOR.equals(qName)) { + returnType = ContainerUtil.getOrElse(genericType.getElementTypes(), 2, returnType); + } + } + return new GeneratorTypeDescriptor(qName, yieldType, sendType, returnType); + } + + public boolean isAsync() { + return ASYNC_TYPES.contains(className); + } + + public GeneratorTypeDescriptor withAsync(boolean async) { + if (async) { + var idx = SYNC_TYPES.indexOf(className); + if (idx == -1) return this; + return new GeneratorTypeDescriptor(ASYNC_TYPES.get(idx), yieldType, sendType, PyNoneType.INSTANCE); + } + else { + var idx = ASYNC_TYPES.indexOf(className); + if (idx == -1) return this; + return new GeneratorTypeDescriptor(SYNC_TYPES.get(idx), yieldType, sendType, returnType); + } + } + + public @Nullable PyType toPyType(@NotNull PsiElement anchor) { + final PyClass classType = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(className, anchor); + final List generics; + if (GENERATOR.equals(className)) { + generics = Arrays.asList(yieldType, sendType, returnType); + } + else if (ASYNC_GENERATOR.equals(className)) { + generics = Arrays.asList(yieldType, sendType); + } + else { + generics = Collections.singletonList(yieldType); + } + + return classType != null ? new PyCollectionTypeImpl(classType, false, generics) : null; + } } @Nullable diff --git a/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java b/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java index 27ec4e90d1d1..3bd925ad34ff 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java +++ b/python/python-psi-impl/src/com/jetbrains/python/inspections/PyTypeCheckerInspection.java @@ -16,6 +16,7 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner; import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil; import com.jetbrains.python.codeInsight.typing.PyProtocolsKt; import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; +import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.GeneratorTypeDescriptor; import com.jetbrains.python.documentation.PythonDocumentationProvider; import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix; import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix; @@ -92,7 +93,7 @@ public class PyTypeCheckerInspection extends PyInspection { PyAnnotation annotation = function.getAnnotation(); String typeCommentAnnotation = function.getTypeCommentAnnotation(); if (annotation != null || typeCommentAnnotation != null) { - PyType expected = getExpectedReturnType(function, myTypeEvalContext); + PyType expected = getExpectedReturnStatementType(function, myTypeEvalContext); if (expected == null) return; // We cannot just match annotated and inferred types, as we cannot promote inferred to Literal @@ -124,14 +125,80 @@ public class PyTypeCheckerInspection extends PyInspection { } } - @Nullable - public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) { - final PyType returnType = typeEvalContext.getReturnType(function); + @Override + public void visitPyYieldExpression(@NotNull PyYieldExpression node) { + ScopeOwner owner = ScopeUtil.getScopeOwner(node); + if (!(owner instanceof PyFunction function)) return; + + final PyAnnotation annotation = function.getAnnotation(); + final String typeCommentAnnotation = function.getTypeCommentAnnotation(); + if (annotation == null && typeCommentAnnotation == null) return; - if (function.isAsync() || function.isGenerator()) { - return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType)); + final PyType fullReturnType = myTypeEvalContext.getReturnType(function); + if (fullReturnType == null) return; // fullReturnType is Any + + final var generatorDesc = GeneratorTypeDescriptor.create(fullReturnType); + if (generatorDesc == null) { + // expected type is not Iterable, Iterator, Generator or similar + final PyType actual = function.getInferredReturnType(myTypeEvalContext); + String expectedName = PythonDocumentationProvider.getVerboseTypeName(fullReturnType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext); + getHolder() + .problem(node, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName)) + .fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext)) + .register(); + return; } + final PyType expectedYieldType = generatorDesc.yieldType(); + final PyType expectedSendType = generatorDesc.sendType(); + + final PyType thisYieldType = node.getYieldType(myTypeEvalContext); + + final PyExpression yieldExpr = node.getExpression(); + + if (!PyTypeChecker.match(expectedYieldType, thisYieldType, myTypeEvalContext)) { + String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedYieldType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(thisYieldType, myTypeEvalContext); + getHolder() + .problem(yieldExpr != null ? yieldExpr : node, PyPsiBundle.message("INSP.type.checker.yield.type.mismatch", expectedName, actualName)) + .fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext)) + .register(); + } + + if (yieldExpr != null && node.isDelegating()) { + final PyType delegateType = myTypeEvalContext.getType(yieldExpr); + var delegateDesc = GeneratorTypeDescriptor.create(delegateType); + if (delegateDesc == null) return; + + if (delegateDesc.isAsync()) { + String delegateName = PythonDocumentationProvider.getTypeName(delegateType, myTypeEvalContext); + registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.async.generator", delegateName, delegateName)); + return; + } + + // Reversed because SendType is contravariant + if (!PyTypeChecker.match(delegateDesc.sendType(), expectedSendType, myTypeEvalContext)) { + String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedSendType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(delegateDesc.sendType(), myTypeEvalContext); + registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.send.type.mismatch", expectedName, actualName)); + } + } + } + + + @Nullable + public static PyType getExpectedReturnStatementType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) { + final PyType returnType = typeEvalContext.getReturnType(function); + if (function.isGenerator()) { + final var generatorDesc = GeneratorTypeDescriptor.create(returnType); + if (generatorDesc != null) { + return generatorDesc.returnType(); + } + } + if (function.isAsync()) { + return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType)); + } return returnType; } @@ -238,7 +305,7 @@ public class PyTypeCheckerInspection extends PyInspection { final PyAnnotation annotation = node.getAnnotation(); final String typeCommentAnnotation = node.getTypeCommentAnnotation(); if (annotation != null || typeCommentAnnotation != null) { - final PyType expected = getExpectedReturnType(node, myTypeEvalContext); + final PyType expected = getExpectedReturnStatementType(node, myTypeEvalContext); final boolean returnsNone = expected instanceof PyNoneType; final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext); @@ -263,6 +330,24 @@ public class PyTypeCheckerInspection extends PyInspection { registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(), PyPsiBundle.message("INSP.type.checker.init.should.return.none")); } + + if (node.isGenerator()) { + boolean shouldBeAsync = node.isAsync() && node.isAsyncAllowed(); + final PyType annotatedType = myTypeEvalContext.getReturnType(node); + final var generatorDesc = GeneratorTypeDescriptor.create(annotatedType); + if (generatorDesc != null && generatorDesc.isAsync() != shouldBeAsync) { + final PyType inferredType = node.getInferredReturnType(myTypeEvalContext); + String expectedName = PythonDocumentationProvider.getVerboseTypeName(inferredType, myTypeEvalContext); + String actualName = PythonDocumentationProvider.getTypeName(annotatedType, myTypeEvalContext); + final PsiElement annotationValue = annotation != null ? annotation.getValue() : node.getTypeComment(); + if (annotationValue != null) { + getHolder() + .problem(annotationValue, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName)) + .fix(new PyMakeFunctionReturnTypeQuickFix(node, myTypeEvalContext)) + .register(); + } + } + } } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java b/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java index 89593dd53fce..c8b5a6b9b969 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java +++ b/python/python-psi-impl/src/com/jetbrains/python/inspections/quickfix/PyMakeFunctionReturnTypeQuickFix.java @@ -15,50 +15,48 @@ */ package com.jetbrains.python.inspections.quickfix; -import com.intellij.codeInsight.intention.FileModifier; -import com.intellij.codeInspection.LocalQuickFix; -import com.intellij.codeInspection.ProblemDescriptor; -import com.intellij.openapi.project.Project; +import com.intellij.modcommand.ActionContext; +import com.intellij.modcommand.ModPsiUpdater; +import com.intellij.modcommand.Presentation; +import com.intellij.modcommand.PsiUpdateModCommandAction; import com.intellij.psi.*; import com.jetbrains.python.PyPsiBundle; import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil; +import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; import com.jetbrains.python.documentation.PythonDocumentationProvider; import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.TypeEvalContext; -import com.jetbrains.python.refactoring.PyRefactoringUtil; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import java.util.List; -public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix { - private final @NotNull SmartPsiElementPointer myFunction; +public class PyMakeFunctionReturnTypeQuickFix extends PsiUpdateModCommandAction { private final String myReturnTypeName; public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) { - this(function, getReturnTypeName(function, context)); - } - - private PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull String returnTypeName) { - SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject()); - myFunction = manager.createSmartPsiElementPointer(function); - myReturnTypeName = returnTypeName; + super(function); + myReturnTypeName = getReturnTypeName(function, context); } @NotNull private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) { - final PyType type = function.getInferredReturnType(context); + PyType type = function.getInferredReturnType(context); + if (function.isAsync()) { + var unwrappedType = PyTypingTypeProvider.unwrapCoroutineReturnType(type); + if (unwrappedType != null) { + type = unwrappedType.get(); + } + } return PythonDocumentationProvider.getTypeHint(type, context); } @Override - @NotNull - public String getName() { - PyFunction function = myFunction.getElement(); - String functionName = function != null ? function.getName() : "function"; - return PyPsiBundle.message("QFIX.make.function.return.type", functionName, myReturnTypeName); + @Nullable + protected Presentation getPresentation(@NotNull ActionContext context, @NotNull PyFunction function) { + return Presentation.of(PyPsiBundle.message("QFIX.make.function.return.type", function.getName(), myReturnTypeName)); } @Override @@ -68,11 +66,8 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix { } @Override - public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) { - final PyFunction function = myFunction.getElement(); - if (function == null) return; - - final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project); + protected void invoke(@NotNull ActionContext context, @NotNull PyFunction function, @NotNull ModPsiUpdater updater) { + final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(function.getProject()); boolean shouldAddImports = false; @@ -96,28 +91,18 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix { } if (shouldAddImports) { - addImportsForTypeAnnotations(TypeEvalContext.userInitiated(project, function.getContainingFile())); + addImportsForTypeAnnotations(function); } } - private void addImportsForTypeAnnotations(@NotNull TypeEvalContext context) { - PyFunction function = myFunction.getElement(); - if (function == null) return; + private static void addImportsForTypeAnnotations(@NotNull PyFunction function) { PsiFile file = function.getContainingFile(); if (file == null) return; + TypeEvalContext context = TypeEvalContext.userInitiated(function.getProject(), file); PyType typeForImports = function.getInferredReturnType(context); if (typeForImports != null) { PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file); } } - - @Override - public @Nullable FileModifier getFileModifierForPreview(@NotNull PsiFile target) { - PyFunction function = PyRefactoringUtil.findSameElementForPreview(myFunction, target); - if (function == null) { - return null; - } - return new PyMakeFunctionReturnTypeQuickFix(function, myReturnTypeName); - } } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java index 9922e96c0428..466641fdfc8e 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyFunctionImpl.java @@ -6,6 +6,7 @@ import com.intellij.codeInsight.controlflow.Instruction; import com.intellij.lang.ASTNode; import com.intellij.navigation.ItemPresentation; import com.intellij.openapi.util.Key; +import com.intellij.openapi.util.Pair; import com.intellij.openapi.util.Ref; import com.intellij.openapi.vfs.VirtualFile; import com.intellij.psi.PsiElement; @@ -16,10 +17,8 @@ import com.intellij.psi.stubs.IStubElementType; import com.intellij.psi.stubs.StubElement; import com.intellij.psi.util.*; import com.intellij.ui.IconManager; -import com.intellij.util.ArrayUtil; import com.intellij.util.IncorrectOperationException; import com.intellij.util.PlatformIcons; -import com.intellij.util.containers.ContainerUtil; import com.intellij.util.containers.JBIterable; import com.jetbrains.python.PyNames; import com.jetbrains.python.PyStubElementTypes; @@ -38,13 +37,16 @@ import com.jetbrains.python.psi.stubs.PyFunctionStub; import com.jetbrains.python.psi.stubs.PyTargetExpressionStub; import com.jetbrains.python.psi.types.*; import com.jetbrains.python.sdk.PythonSdkUtil; +import one.util.streamex.StreamEx; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; import javax.swing.*; import java.util.*; +import java.util.stream.Stream; import static com.intellij.openapi.util.text.StringUtil.notNullize; +import static com.intellij.util.containers.ContainerUtil.*; import static com.jetbrains.python.ast.PyAstFunction.Modifier.CLASSMETHOD; import static com.jetbrains.python.ast.PyAstFunction.Modifier.STATICMETHOD; import static com.jetbrains.python.psi.PyUtil.as; @@ -130,7 +132,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements .filter(PyCallableType.class::isInstance) .map(PyCallableType.class::cast) .map(callableType -> callableType.getParameters(context)) - .orElseGet(() -> ContainerUtil.map(getParameterList().getParameters(), PyCallableParameterImpl::psi)); + .orElseGet(() -> map(getParameterList().getParameters(), PyCallableParameterImpl::psi)); } @Override @@ -164,12 +166,13 @@ public class PyFunctionImpl extends PyBaseElementImpl implements public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) { PyType inferredType = null; if (context.allowReturnTypes(this)) { - final Ref yieldTypeRef = getYieldStatementType(context); - if (yieldTypeRef != null) { - inferredType = yieldTypeRef.get(); + final PyType returnType = getReturnStatementType(context); + final Pair yieldSendTypePair = getYieldExpressionType(context); + if (yieldSendTypePair != null) { + inferredType = PyTypingTypeProvider.wrapInGeneratorType(yieldSendTypePair.first, yieldSendTypePair.second, returnType, this); } else { - inferredType = getReturnStatementType(context); + inferredType = returnType; } } return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(PyTypingTypeProvider.toAsyncIfNeeded(this, inferredType)); @@ -189,7 +192,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements final Map mappedExplicitParameters = fullMapping.getMappedParameters(); final Map allMappedParameters = new LinkedHashMap<>(); - final PyCallableParameter firstImplicit = ContainerUtil.getFirstItem(fullMapping.getImplicitParameters()); + final PyCallableParameter firstImplicit = getFirstItem(fullMapping.getImplicitParameters()); if (receiver != null && firstImplicit != null) { allMappedParameters.put(receiver, firstImplicit); } @@ -282,7 +285,7 @@ public class PyFunctionImpl extends PyBaseElementImpl implements else if (allowCoroutineOrGenerator && returnType instanceof PyCollectionType && PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType) != null) { - final List replacedElementTypes = ContainerUtil.map( + final List replacedElementTypes = map( ((PyCollectionType)returnType).getElementTypes(), type -> replaceSelf(type, receiver, context, false) ); @@ -308,42 +311,41 @@ public class PyFunctionImpl extends PyBaseElementImpl implements } return false; } + + public static class YieldCollector extends PyRecursiveElementVisitor { + public List getYieldExpressions() { + return myYieldExpressions; + } - private @Nullable Ref getYieldStatementType(final @NotNull TypeEvalContext context) { - final PyBuiltinCache cache = PyBuiltinCache.getInstance(this); + final private List myYieldExpressions = new ArrayList<>(); + + @Override + public void visitPyYieldExpression(@NotNull PyYieldExpression node) { + myYieldExpressions.add(node); + } + + @Override + public void visitPyFunction(@NotNull PyFunction node) { + // Ignore nested functions + } + + @Override + public void visitPyLambdaExpression(@NotNull PyLambdaExpression node) { + // Ignore nested lambdas + } + } + + /** + * @return pair of YieldType and SendType. Null when there are no yields + */ + private @Nullable Pair getYieldExpressionType(final @NotNull TypeEvalContext context) { final PyStatementList statements = getStatementList(); - final Set types = new LinkedHashSet<>(); - statements.accept(new PyRecursiveElementVisitor() { - @Override - public void visitPyYieldExpression(@NotNull PyYieldExpression node) { - final PyExpression expr = node.getExpression(); - final PyType type = expr != null ? context.getType(expr) : null; - - if (node.isDelegating()) { - if (type instanceof PyCollectionType) { - types.add(((PyCollectionType)type).getIteratedItemType()); - } - else if (ArrayUtil.contains(type, cache.getListType(), cache.getDictType(), cache.getSetType(), cache.getTupleType())) { - types.add(null); - } - else { - types.add(type); - } - } - else { - types.add(type); - } - } - - @Override - public void visitPyFunction(@NotNull PyFunction node) { - // Ignore nested functions - } - }); - if (!types.isEmpty()) { - final PyType elementType = PyUnionType.union(types); - final PyType returnType = getReturnStatementType(context); - return Ref.create(PyTypingTypeProvider.wrapInGeneratorType(elementType, returnType, this)); + final YieldCollector visitor = new YieldCollector(); + statements.accept(visitor); + final List yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context)); + final List sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context)); + if (!yieldTypes.isEmpty()) { + return Pair.create(PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes)); } return null; } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java index 2cf6f8cfce76..cabb721c5659 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyGeneratorExpressionImpl.java @@ -42,7 +42,7 @@ public class PyGeneratorExpressionImpl extends PyComprehensionElementImpl implem public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { final PyExpression resultExpr = getResultExpression(); if (resultExpr != null) { - return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), PyNoneType.INSTANCE, this); + return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), null, PyNoneType.INSTANCE, this); } return null; } diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java index 3c1a1616d6d5..f0df28c485b6 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyLambdaExpressionImpl.java @@ -4,18 +4,19 @@ package com.jetbrains.python.psi.impl; import com.intellij.lang.ASTNode; import com.intellij.util.containers.ContainerUtil; import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache; -import com.jetbrains.python.psi.PyCallSiteExpression; -import com.jetbrains.python.psi.PyElementVisitor; -import com.jetbrains.python.psi.PyExpression; -import com.jetbrains.python.psi.PyLambdaExpression; +import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; +import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.types.*; import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.Nullable; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Optional; +import static com.intellij.util.containers.ContainerUtil.map; + public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExpression { public PyLambdaExpressionImpl(ASTNode astNode) { @@ -56,7 +57,19 @@ public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExp @Override public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { final PyExpression body = getBody(); - return body != null ? context.getType(body) : null; + if (body == null) return null; + + final PyFunctionImpl.YieldCollector visitor = new PyFunctionImpl.YieldCollector(); + body.accept(visitor); + + final List yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context)); + final List sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context)); + + if (!yieldTypes.isEmpty()) { + return PyTypingTypeProvider.wrapInGeneratorType( + PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes), context.getType(body), this); + } + return context.getType(body); } @Nullable diff --git a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java index a7758cae9cc6..2d6f951d5b04 100644 --- a/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java +++ b/python/python-psi-impl/src/com/jetbrains/python/psi/impl/PyTargetExpressionImpl.java @@ -127,16 +127,7 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl generatorElementType = PyTypingTypeProvider.coroutineOrGeneratorElementType(type); - return generatorElementType == null ? PyNoneType.INSTANCE : generatorElementType.get(); + final PyExpression e = getExpression(); + final PyType type = e != null ? context.getType(e) : null; + var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type); + if (generatorDesc != null) { + return generatorDesc.returnType(); + } + return PyNoneType.INSTANCE; + } + else { + return getSendType(context); + } + } + + @Override + public @Nullable PyType getYieldType(@NotNull TypeEvalContext context) { + final PyExpression expr = getExpression(); + final PyType type = expr != null ? context.getType(expr) : PyNoneType.INSTANCE; + + if (isDelegating()) { + return PyTargetExpressionImpl.getIterationType(type, expr, this, context); } return type; } + + @Override + public @Nullable PyType getSendType(@NotNull TypeEvalContext context) { + if (ScopeUtil.getScopeOwner(this) instanceof PyFunction function) { + if (function.getAnnotation() != null || function.getTypeCommentAnnotation() != null) { + var returnType = context.getReturnType(function); + var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(returnType); + if (generatorDesc != null) { + return generatorDesc.sendType(); + } + } + } + + if (isDelegating()) { + final PyExpression e = getExpression(); + final PyType type = e != null ? context.getType(e) : null; + var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type); + if (generatorDesc != null) { + return generatorDesc.sendType(); + } + return PyNoneType.INSTANCE; + } + return null; + } } diff --git a/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py b/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py index e456889ddd2d..4ef353a6fcd4 100644 --- a/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py +++ b/python/testData/inspections/PyTypeCheckerInspection/FunctionReturnTypePy3.py @@ -50,4 +50,8 @@ def l(x) -> int None: """Does not display warning about implicit return, because annotated '-> None' """ if x: - return \ No newline at end of file + return + +def n() -> Generator[int, Any, str]: + yield 13 + return 42 \ No newline at end of file diff --git a/python/testData/inspections/PyTypeCheckerInspection/FunctionYieldTypePy3.py b/python/testData/inspections/PyTypeCheckerInspection/FunctionYieldTypePy3.py new file mode 100644 index 000000000000..4225943c576d --- /dev/null +++ b/python/testData/inspections/PyTypeCheckerInspection/FunctionYieldTypePy3.py @@ -0,0 +1,69 @@ +from typing import Generator, Iterable, Iterator, AsyncIterable, AsyncIterator, AsyncGenerator + +# Fix incorrect YieldType +def a() -> Iterable[str]: + yield 42 + +def b() -> Iterator[str]: + yield 42 + +def c() -> Generator[str, Any, int]: + yield 13 + return 42 + +def c() -> Generator[int, Any, str]: + yield 13 + return 42 + +# Suggest AsyncGenerator +async def d() -> Iterable[int]: + yield 42 + +async def e() -> Iterator[int]: + yield 42 + +async def f() -> Generator[int, str, None]: + yield 13 + +# Suggest sync Generator +def g() -> AsyncIterable[int]: + yield 42 + +def h() -> AsyncIterator[int]: + yield 42 + +def i() -> AsyncGenerator[int, str]: + yield 13 + +def j() -> Iterator[int]: + yield from j() + +def k() -> Iterator[str]: + yield from j() + yield from [1] + +def l() -> Generator[None, int, None]: + x: float = yield + +def m() -> Generator[None, float, None]: + yield from l() + +def n() -> Generator[None, float, None]: + x: float = yield + +def o() -> Generator[None, int, None]: + yield from n() + +def p() -> Generator[int, None, None]: + yield from [1, 2] + yield from [3, 4] + +def q() -> int: + x = lambda: (yield "str") + return 42 + +async def r() -> AsyncGenerator[int]: + yield 42 + +def s() -> Generator[int]: + yield from r() \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator.py new file mode 100644 index 000000000000..2385f3dc80b0 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator.py @@ -0,0 +1,5 @@ +from typing import Generator + +def gen() -> Generator[int, bool, str]: + b: bool = yield "str" + return 42 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator_after.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator_after.py new file mode 100644 index 000000000000..e35a271bb9f3 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/changeGenerator_after.py @@ -0,0 +1,5 @@ +from typing import Generator + +def gen() -> Generator[str, bool, int]: + b: bool = yield "str" + return 42 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator.py new file mode 100644 index 000000000000..2368acb78148 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator.py @@ -0,0 +1,4 @@ +async def gen() -> str: + b: bool = yield "str" + if b: + b = yield 3.14 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator_after.py b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator_after.py new file mode 100644 index 000000000000..2389737db0b8 --- /dev/null +++ b/python/testData/quickFixes/PyMakeFunctionReturnTypeQuickFixTest/makeGenerator_after.py @@ -0,0 +1,7 @@ +from typing import Any, AsyncGenerator + + +async def gen() -> AsyncGenerator[str | float, Any]: + b: bool = yield "str" + if b: + b = yield 3.14 \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation.py b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation.py new file mode 100644 index 000000000000..077dc7909e56 --- /dev/null +++ b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation.py @@ -0,0 +1,5 @@ +def f(x) -> int | None: + if x == 1: + return 42 + elif x == 2: + return \ No newline at end of file diff --git a/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation_after.py b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation_after.py new file mode 100644 index 000000000000..9a4cd15bef3b --- /dev/null +++ b/python/testData/quickFixes/PyMakeReturnsExplicitFixTest/addReturnsFromAnnotation_after.py @@ -0,0 +1,6 @@ +def f(x) -> int | None: + if x == 1: + return 42 + elif x == 2: + return None + return None \ No newline at end of file diff --git a/python/testSrc/com/jetbrains/python/Py3TypeTest.java b/python/testSrc/com/jetbrains/python/Py3TypeTest.java index e259b57e2e16..2c4772601648 100644 --- a/python/testSrc/com/jetbrains/python/Py3TypeTest.java +++ b/python/testSrc/com/jetbrains/python/Py3TypeTest.java @@ -17,6 +17,66 @@ import java.util.Map; public class Py3TypeTest extends PyTestCase { public static final String TEST_DIRECTORY = "/types/"; + // PY-20710 + public void testLambdaGenerator() { + doTest("Generator[int, Any, Any]", """ + expr = (lambda: (yield 1))() + """); + } + + // PY-20710 + public void testGeneratorDelegatingToLambdaGenerator() { + doTest("Generator[int, Any, str]", """ + def g(): + yield from (lambda: (yield 1))() + return "foo" + expr = g() + """); + } + + // PY-20710 + public void testYieldExpressionTypeFromGeneratorSendTypeHint() { + doTest("int", """ + from typing import Generator + + def g() -> Generator[str, int, None]: + expr = yield "foo" + """); + } + + // PY-20710 + public void testYieldFromExpressionTypeFromGeneratorReturnTypeHint() { + doTest("int", """ + from typing import Generator, Any + + def delegate() -> Generator[None, Any, int]: + yield + return 42 + + def g(): + expr = yield from delegate() + """); + } + + // PY-20710 + public void testYieldFromLambda() { + doTest("Generator[int | str, str | Any, bool]", + """ + from typing import Generator + + def gen1() -> Generator[int, str, bool]: + yield 42 + return True + + def gen2(): + yield "str" + return True + + l = lambda: (yield from gen1()) or (yield from gen2()) + expr = l() + """); + } + // PY-6702 public void testYieldFromType() { doTest("str | int | float", diff --git a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java index a5ba8d6c0398..ac1a7adc0f35 100644 --- a/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java +++ b/python/testSrc/com/jetbrains/python/inspections/Py3TypeCheckerInspectionTest.java @@ -102,6 +102,10 @@ public class Py3TypeCheckerInspectionTest extends PyInspectionTestCase { public void testFunctionReturnTypePy3() { doTest(); } + + public void testFunctionYieldTypePy3() { + doTest(); + } // PY-20770 public void testAsyncForOverAsyncGenerator() { diff --git a/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java b/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java index 3194c6a55df9..e03e625dbb06 100644 --- a/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java +++ b/python/testSrc/com/jetbrains/python/quickFixes/PyMakeFunctionReturnTypeQuickFixTest.java @@ -64,4 +64,18 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase { PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"), LanguageLevel.getLatest()); } + + // PY-20710 + public void testChangeGenerator() { + doQuickFixTest(PyTypeCheckerInspection.class, + PyPsiBundle.message("QFIX.make.function.return.type", "gen", "Generator[str, bool, int]"), + LanguageLevel.getLatest()); + } + + // PY-20710 + public void testMakeGenerator() { + doQuickFixTest(PyTypeCheckerInspection.class, + PyPsiBundle.message("QFIX.make.function.return.type", "gen", "AsyncGenerator[str | float, Any]"), + LanguageLevel.getLatest()); + } }