PY-20710 Support 'Generator' typing class

Check YieldType of yield expressions in PyTypeCheckerInspection
Check that (Async)Generator is used in (async) function
Check that in 'yield from' sync Generator is used
Convert PyMakeFunctionReturnTypeQuickFix into PsiUpdateModCommandAction
Infer Generator type for lambdas
When getting function type from annotation, do not convert Generator to AsyncGenerator
Introduce GeneratorTypeDescriptor to simplify working with generator annotations


Merge-request: IJ-MR-146521
Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com>

(cherry picked from commit b3b8182168c5224f0e03f54d443171ccf6ca7b89)

IJ-MR-146521

GitOrigin-RevId: a95670d7e2787015bcf162637ea6d7bfb47a312a
This commit is contained in:
Aleksandr.Govenko
2024-12-04 15:22:15 +00:00
committed by intellij-monorepo-bot
parent 362a0344a7
commit 70fe60b4c8
22 changed files with 509 additions and 122 deletions

View File

@@ -1,6 +1,7 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.psi;
import com.intellij.openapi.util.Pair;
import com.intellij.psi.PsiNameIdentifierOwner;
import com.intellij.psi.StubBasedPsiElement;
import com.intellij.util.ArrayFactory;

View File

@@ -2,6 +2,9 @@
package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstYieldExpression;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
@@ -11,4 +14,17 @@ public interface PyYieldExpression extends PyAstYieldExpression, PyExpression {
default PyExpression getExpression() {
return (PyExpression)PyAstYieldExpression.super.getExpression();
}
/**
* @return For {@code yield}, returns type of its expression. For {@code yield from} - YieldType of the delegate
*/
@Nullable
PyType getYieldType(@NotNull TypeEvalContext context);
/**
* @return If containing function is annotated with Generator (or AsyncGenerator), returns SendType from annotation.
* Otherwise, Any for {@code yield} and SendType of the delegate for {@code yield from}
*/
@Nullable
PyType getSendType(@NotNull TypeEvalContext context);
}

View File

@@ -1056,6 +1056,9 @@ QFIX.ignore.shadowed.built.in.name=Ignore shadowed built-in name "{0}"
# PyTypeCheckerInspection
INSP.NAME.type.checker=Incorrect type
INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead
INSP.type.checker.yield.type.mismatch=Expected yield type ''{0}'', got ''{1}'' instead
INSP.type.checker.yield.from.send.type.mismatch=Expected send type ''{0}'', got ''{1}'' instead
INSP.type.checker.yield.from.async.generator=Cannot yield from ''{0}'', use 'async for' instead
INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}''
INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2}
INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None''

View File

@@ -353,7 +353,11 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
if (returnTypeAnnotation != null) {
final Ref<PyType> typeRef = getType(returnTypeAnnotation, context);
if (typeRef != null) {
return Ref.create(toAsyncIfNeeded(function, typeRef.get()));
// Do not use toAsyncIfNeeded, as it also converts Generators. Here we do not need it.
if (function.isAsync() && function.isAsyncAllowed() && !function.isGenerator()) {
return Ref.create(wrapInCoroutineType(typeRef.get(), function));
}
return typeRef;
}
// Don't rely on other type providers if a type hint is present, but cannot be resolved.
return Ref.create();
@@ -2156,8 +2160,9 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
if (!function.isGenerator()) {
return wrapInCoroutineType(returnType, function);
}
else if (returnType instanceof PyCollectionType && isGenerator(returnType)) {
return wrapInAsyncGeneratorType(((PyCollectionType)returnType).getIteratedItemType(), function);
var desc = GeneratorTypeDescriptor.create(returnType);
if (desc != null) {
return desc.withAsync(true).toPyType(function);
}
}
@@ -2184,15 +2189,77 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
}
@Nullable
public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType returnType, @NotNull PsiElement anchor) {
public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType sendType, @Nullable PyType returnType, @NotNull PsiElement anchor) {
final PyClass generator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(GENERATOR, anchor);
return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, null, returnType)) : null;
return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, sendType, returnType)) : null;
}
@Nullable
private static PyType wrapInAsyncGeneratorType(@Nullable PyType elementType, @NotNull PsiElement anchor) {
final PyClass asyncGenerator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(ASYNC_GENERATOR, anchor);
return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null;
public record GeneratorTypeDescriptor(
String className,
PyType yieldType, // if YieldType is not specified, it is AnyType
PyType sendType, // if SendType is not specified, it is PyNoneType
PyType returnType // if ReturnType is not specified, it is PyNoneType
) {
private static final List<String> SYNC_TYPES = List.of(GENERATOR, "typing.Iterable", "typing.Iterator");
private static final List<String> ASYNC_TYPES = List.of(ASYNC_GENERATOR, "typing.AsyncIterable", "typing.AsyncIterator");
public static @Nullable GeneratorTypeDescriptor create(@Nullable PyType type) {
final PyClassType classType = as(type, PyClassType.class);
final PyCollectionType genericType = as(type, PyCollectionType.class);
if (classType == null) return null;
final String qName = classType.getClassQName();
if (!SYNC_TYPES.contains(qName) && !ASYNC_TYPES.contains(qName)) return null;
PyType yieldType = null;
PyType sendType = PyNoneType.INSTANCE;
PyType returnType = PyNoneType.INSTANCE;
if (genericType != null) {
yieldType = ContainerUtil.getOrElse(genericType.getElementTypes(), 0, yieldType);
if (GENERATOR.equals(qName) || ASYNC_GENERATOR.equals(qName)) {
sendType = ContainerUtil.getOrElse(genericType.getElementTypes(), 1, sendType);
}
if (GENERATOR.equals(qName)) {
returnType = ContainerUtil.getOrElse(genericType.getElementTypes(), 2, returnType);
}
}
return new GeneratorTypeDescriptor(qName, yieldType, sendType, returnType);
}
public boolean isAsync() {
return ASYNC_TYPES.contains(className);
}
public GeneratorTypeDescriptor withAsync(boolean async) {
if (async) {
var idx = SYNC_TYPES.indexOf(className);
if (idx == -1) return this;
return new GeneratorTypeDescriptor(ASYNC_TYPES.get(idx), yieldType, sendType, PyNoneType.INSTANCE);
}
else {
var idx = ASYNC_TYPES.indexOf(className);
if (idx == -1) return this;
return new GeneratorTypeDescriptor(SYNC_TYPES.get(idx), yieldType, sendType, returnType);
}
}
public @Nullable PyType toPyType(@NotNull PsiElement anchor) {
final PyClass classType = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(className, anchor);
final List<PyType> generics;
if (GENERATOR.equals(className)) {
generics = Arrays.asList(yieldType, sendType, returnType);
}
else if (ASYNC_GENERATOR.equals(className)) {
generics = Arrays.asList(yieldType, sendType);
}
else {
generics = Collections.singletonList(yieldType);
}
return classType != null ? new PyCollectionTypeImpl(classType, false, generics) : null;
}
}
@Nullable

View File

@@ -16,6 +16,7 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.GeneratorTypeDescriptor;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix;
import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix;
@@ -92,7 +93,7 @@ public class PyTypeCheckerInspection extends PyInspection {
PyAnnotation annotation = function.getAnnotation();
String typeCommentAnnotation = function.getTypeCommentAnnotation();
if (annotation != null || typeCommentAnnotation != null) {
PyType expected = getExpectedReturnType(function, myTypeEvalContext);
PyType expected = getExpectedReturnStatementType(function, myTypeEvalContext);
if (expected == null) return;
// We cannot just match annotated and inferred types, as we cannot promote inferred to Literal
@@ -124,14 +125,80 @@ public class PyTypeCheckerInspection extends PyInspection {
}
}
@Nullable
public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
final PyType returnType = typeEvalContext.getReturnType(function);
@Override
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
ScopeOwner owner = ScopeUtil.getScopeOwner(node);
if (!(owner instanceof PyFunction function)) return;
final PyAnnotation annotation = function.getAnnotation();
final String typeCommentAnnotation = function.getTypeCommentAnnotation();
if (annotation == null && typeCommentAnnotation == null) return;
if (function.isAsync() || function.isGenerator()) {
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType));
final PyType fullReturnType = myTypeEvalContext.getReturnType(function);
if (fullReturnType == null) return; // fullReturnType is Any
final var generatorDesc = GeneratorTypeDescriptor.create(fullReturnType);
if (generatorDesc == null) {
// expected type is not Iterable, Iterator, Generator or similar
final PyType actual = function.getInferredReturnType(myTypeEvalContext);
String expectedName = PythonDocumentationProvider.getVerboseTypeName(fullReturnType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
getHolder()
.problem(node, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext))
.register();
return;
}
final PyType expectedYieldType = generatorDesc.yieldType();
final PyType expectedSendType = generatorDesc.sendType();
final PyType thisYieldType = node.getYieldType(myTypeEvalContext);
final PyExpression yieldExpr = node.getExpression();
if (!PyTypeChecker.match(expectedYieldType, thisYieldType, myTypeEvalContext)) {
String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedYieldType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(thisYieldType, myTypeEvalContext);
getHolder()
.problem(yieldExpr != null ? yieldExpr : node, PyPsiBundle.message("INSP.type.checker.yield.type.mismatch", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext))
.register();
}
if (yieldExpr != null && node.isDelegating()) {
final PyType delegateType = myTypeEvalContext.getType(yieldExpr);
var delegateDesc = GeneratorTypeDescriptor.create(delegateType);
if (delegateDesc == null) return;
if (delegateDesc.isAsync()) {
String delegateName = PythonDocumentationProvider.getTypeName(delegateType, myTypeEvalContext);
registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.async.generator", delegateName, delegateName));
return;
}
// Reversed because SendType is contravariant
if (!PyTypeChecker.match(delegateDesc.sendType(), expectedSendType, myTypeEvalContext)) {
String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedSendType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(delegateDesc.sendType(), myTypeEvalContext);
registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.send.type.mismatch", expectedName, actualName));
}
}
}
@Nullable
public static PyType getExpectedReturnStatementType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
final PyType returnType = typeEvalContext.getReturnType(function);
if (function.isGenerator()) {
final var generatorDesc = GeneratorTypeDescriptor.create(returnType);
if (generatorDesc != null) {
return generatorDesc.returnType();
}
}
if (function.isAsync()) {
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType));
}
return returnType;
}
@@ -238,7 +305,7 @@ public class PyTypeCheckerInspection extends PyInspection {
final PyAnnotation annotation = node.getAnnotation();
final String typeCommentAnnotation = node.getTypeCommentAnnotation();
if (annotation != null || typeCommentAnnotation != null) {
final PyType expected = getExpectedReturnType(node, myTypeEvalContext);
final PyType expected = getExpectedReturnStatementType(node, myTypeEvalContext);
final boolean returnsNone = expected instanceof PyNoneType;
final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext);
@@ -263,6 +330,24 @@ public class PyTypeCheckerInspection extends PyInspection {
registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(),
PyPsiBundle.message("INSP.type.checker.init.should.return.none"));
}
if (node.isGenerator()) {
boolean shouldBeAsync = node.isAsync() && node.isAsyncAllowed();
final PyType annotatedType = myTypeEvalContext.getReturnType(node);
final var generatorDesc = GeneratorTypeDescriptor.create(annotatedType);
if (generatorDesc != null && generatorDesc.isAsync() != shouldBeAsync) {
final PyType inferredType = node.getInferredReturnType(myTypeEvalContext);
String expectedName = PythonDocumentationProvider.getVerboseTypeName(inferredType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(annotatedType, myTypeEvalContext);
final PsiElement annotationValue = annotation != null ? annotation.getValue() : node.getTypeComment();
if (annotationValue != null) {
getHolder()
.problem(annotationValue, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(node, myTypeEvalContext))
.register();
}
}
}
}
}

View File

@@ -15,50 +15,48 @@
*/
package com.jetbrains.python.inspections.quickfix;
import com.intellij.codeInsight.intention.FileModifier;
import com.intellij.codeInspection.LocalQuickFix;
import com.intellij.codeInspection.ProblemDescriptor;
import com.intellij.openapi.project.Project;
import com.intellij.modcommand.ActionContext;
import com.intellij.modcommand.ModPsiUpdater;
import com.intellij.modcommand.Presentation;
import com.intellij.modcommand.PsiUpdateModCommandAction;
import com.intellij.psi.*;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import com.jetbrains.python.refactoring.PyRefactoringUtil;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
private final @NotNull SmartPsiElementPointer<PyFunction> myFunction;
public class PyMakeFunctionReturnTypeQuickFix extends PsiUpdateModCommandAction<PyFunction> {
private final String myReturnTypeName;
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
this(function, getReturnTypeName(function, context));
}
private PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull String returnTypeName) {
SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
myFunction = manager.createSmartPsiElementPointer(function);
myReturnTypeName = returnTypeName;
super(function);
myReturnTypeName = getReturnTypeName(function, context);
}
@NotNull
private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
final PyType type = function.getInferredReturnType(context);
PyType type = function.getInferredReturnType(context);
if (function.isAsync()) {
var unwrappedType = PyTypingTypeProvider.unwrapCoroutineReturnType(type);
if (unwrappedType != null) {
type = unwrappedType.get();
}
}
return PythonDocumentationProvider.getTypeHint(type, context);
}
@Override
@NotNull
public String getName() {
PyFunction function = myFunction.getElement();
String functionName = function != null ? function.getName() : "function";
return PyPsiBundle.message("QFIX.make.function.return.type", functionName, myReturnTypeName);
@Nullable
protected Presentation getPresentation(@NotNull ActionContext context, @NotNull PyFunction function) {
return Presentation.of(PyPsiBundle.message("QFIX.make.function.return.type", function.getName(), myReturnTypeName));
}
@Override
@@ -68,11 +66,8 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
}
@Override
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
final PyFunction function = myFunction.getElement();
if (function == null) return;
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
protected void invoke(@NotNull ActionContext context, @NotNull PyFunction function, @NotNull ModPsiUpdater updater) {
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(function.getProject());
boolean shouldAddImports = false;
@@ -96,28 +91,18 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
}
if (shouldAddImports) {
addImportsForTypeAnnotations(TypeEvalContext.userInitiated(project, function.getContainingFile()));
addImportsForTypeAnnotations(function);
}
}
private void addImportsForTypeAnnotations(@NotNull TypeEvalContext context) {
PyFunction function = myFunction.getElement();
if (function == null) return;
private static void addImportsForTypeAnnotations(@NotNull PyFunction function) {
PsiFile file = function.getContainingFile();
if (file == null) return;
TypeEvalContext context = TypeEvalContext.userInitiated(function.getProject(), file);
PyType typeForImports = function.getInferredReturnType(context);
if (typeForImports != null) {
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file);
}
}
@Override
public @Nullable FileModifier getFileModifierForPreview(@NotNull PsiFile target) {
PyFunction function = PyRefactoringUtil.findSameElementForPreview(myFunction, target);
if (function == null) {
return null;
}
return new PyMakeFunctionReturnTypeQuickFix(function, myReturnTypeName);
}
}

View File

@@ -6,6 +6,7 @@ import com.intellij.codeInsight.controlflow.Instruction;
import com.intellij.lang.ASTNode;
import com.intellij.navigation.ItemPresentation;
import com.intellij.openapi.util.Key;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.Ref;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiElement;
@@ -16,10 +17,8 @@ import com.intellij.psi.stubs.IStubElementType;
import com.intellij.psi.stubs.StubElement;
import com.intellij.psi.util.*;
import com.intellij.ui.IconManager;
import com.intellij.util.ArrayUtil;
import com.intellij.util.IncorrectOperationException;
import com.intellij.util.PlatformIcons;
import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.containers.JBIterable;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyStubElementTypes;
@@ -38,13 +37,16 @@ import com.jetbrains.python.psi.stubs.PyFunctionStub;
import com.jetbrains.python.psi.stubs.PyTargetExpressionStub;
import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.sdk.PythonSdkUtil;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import javax.swing.*;
import java.util.*;
import java.util.stream.Stream;
import static com.intellij.openapi.util.text.StringUtil.notNullize;
import static com.intellij.util.containers.ContainerUtil.*;
import static com.jetbrains.python.ast.PyAstFunction.Modifier.CLASSMETHOD;
import static com.jetbrains.python.ast.PyAstFunction.Modifier.STATICMETHOD;
import static com.jetbrains.python.psi.PyUtil.as;
@@ -130,7 +132,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
.filter(PyCallableType.class::isInstance)
.map(PyCallableType.class::cast)
.map(callableType -> callableType.getParameters(context))
.orElseGet(() -> ContainerUtil.map(getParameterList().getParameters(), PyCallableParameterImpl::psi));
.orElseGet(() -> map(getParameterList().getParameters(), PyCallableParameterImpl::psi));
}
@Override
@@ -164,12 +166,13 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) {
PyType inferredType = null;
if (context.allowReturnTypes(this)) {
final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(context);
if (yieldTypeRef != null) {
inferredType = yieldTypeRef.get();
final PyType returnType = getReturnStatementType(context);
final Pair<PyType, PyType> yieldSendTypePair = getYieldExpressionType(context);
if (yieldSendTypePair != null) {
inferredType = PyTypingTypeProvider.wrapInGeneratorType(yieldSendTypePair.first, yieldSendTypePair.second, returnType, this);
}
else {
inferredType = getReturnStatementType(context);
inferredType = returnType;
}
}
return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(PyTypingTypeProvider.toAsyncIfNeeded(this, inferredType));
@@ -189,7 +192,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
final Map<PyExpression, PyCallableParameter> mappedExplicitParameters = fullMapping.getMappedParameters();
final Map<PyExpression, PyCallableParameter> allMappedParameters = new LinkedHashMap<>();
final PyCallableParameter firstImplicit = ContainerUtil.getFirstItem(fullMapping.getImplicitParameters());
final PyCallableParameter firstImplicit = getFirstItem(fullMapping.getImplicitParameters());
if (receiver != null && firstImplicit != null) {
allMappedParameters.put(receiver, firstImplicit);
}
@@ -282,7 +285,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
else if (allowCoroutineOrGenerator &&
returnType instanceof PyCollectionType &&
PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType) != null) {
final List<PyType> replacedElementTypes = ContainerUtil.map(
final List<PyType> replacedElementTypes = map(
((PyCollectionType)returnType).getElementTypes(),
type -> replaceSelf(type, receiver, context, false)
);
@@ -308,42 +311,41 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
}
return false;
}
public static class YieldCollector extends PyRecursiveElementVisitor {
public List<PyYieldExpression> getYieldExpressions() {
return myYieldExpressions;
}
private @Nullable Ref<? extends PyType> getYieldStatementType(final @NotNull TypeEvalContext context) {
final PyBuiltinCache cache = PyBuiltinCache.getInstance(this);
final private List<PyYieldExpression> myYieldExpressions = new ArrayList<>();
@Override
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
myYieldExpressions.add(node);
}
@Override
public void visitPyFunction(@NotNull PyFunction node) {
// Ignore nested functions
}
@Override
public void visitPyLambdaExpression(@NotNull PyLambdaExpression node) {
// Ignore nested lambdas
}
}
/**
* @return pair of YieldType and SendType. Null when there are no yields
*/
private @Nullable Pair<PyType, PyType> getYieldExpressionType(final @NotNull TypeEvalContext context) {
final PyStatementList statements = getStatementList();
final Set<PyType> types = new LinkedHashSet<>();
statements.accept(new PyRecursiveElementVisitor() {
@Override
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
final PyExpression expr = node.getExpression();
final PyType type = expr != null ? context.getType(expr) : null;
if (node.isDelegating()) {
if (type instanceof PyCollectionType) {
types.add(((PyCollectionType)type).getIteratedItemType());
}
else if (ArrayUtil.contains(type, cache.getListType(), cache.getDictType(), cache.getSetType(), cache.getTupleType())) {
types.add(null);
}
else {
types.add(type);
}
}
else {
types.add(type);
}
}
@Override
public void visitPyFunction(@NotNull PyFunction node) {
// Ignore nested functions
}
});
if (!types.isEmpty()) {
final PyType elementType = PyUnionType.union(types);
final PyType returnType = getReturnStatementType(context);
return Ref.create(PyTypingTypeProvider.wrapInGeneratorType(elementType, returnType, this));
final YieldCollector visitor = new YieldCollector();
statements.accept(visitor);
final List<PyType> yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context));
final List<PyType> sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context));
if (!yieldTypes.isEmpty()) {
return Pair.create(PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes));
}
return null;
}

View File

@@ -42,7 +42,7 @@ public class PyGeneratorExpressionImpl extends PyComprehensionElementImpl implem
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
final PyExpression resultExpr = getResultExpression();
if (resultExpr != null) {
return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), PyNoneType.INSTANCE, this);
return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), null, PyNoneType.INSTANCE, this);
}
return null;
}

View File

@@ -4,18 +4,19 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
import com.jetbrains.python.psi.PyCallSiteExpression;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.PyLambdaExpression;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import static com.intellij.util.containers.ContainerUtil.map;
public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExpression {
public PyLambdaExpressionImpl(ASTNode astNode) {
@@ -56,7 +57,19 @@ public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExp
@Override
public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
final PyExpression body = getBody();
return body != null ? context.getType(body) : null;
if (body == null) return null;
final PyFunctionImpl.YieldCollector visitor = new PyFunctionImpl.YieldCollector();
body.accept(visitor);
final List<PyType> yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context));
final List<PyType> sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context));
if (!yieldTypes.isEmpty()) {
return PyTypingTypeProvider.wrapInGeneratorType(
PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes), context.getType(body), this);
}
return context.getType(body);
}
@Nullable

View File

@@ -127,16 +127,7 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
}
final PsiElement parent = PsiTreeUtil.skipParentsOfType(this, PyParenthesizedExpression.class);
if (parent instanceof PyAssignmentStatement assignmentStatement) {
PyExpression assignedValue = assignmentStatement.getAssignedValue();
if (assignedValue instanceof PyParenthesizedExpression) {
assignedValue = ((PyParenthesizedExpression)assignedValue).getContainedExpression();
}
if (assignedValue != null) {
if (assignedValue instanceof PyYieldExpression assignedYield) {
return assignedYield.isDelegating() ? context.getType(assignedValue) : null;
}
return context.getType(assignedValue);
}
return context.getType(assignmentStatement.getAssignedValue());
}
if (parent instanceof PyTupleExpression || parent instanceof PyListLiteralExpression) {
PsiElement nextParent =

View File

@@ -2,15 +2,16 @@
package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Ref;
import com.intellij.util.ArrayUtil;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.PyFunction;
import com.jetbrains.python.psi.PyYieldExpression;
import com.jetbrains.python.psi.types.PyNoneType;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpression {
@@ -25,12 +26,52 @@ public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpre
@Override
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
final PyExpression e = getExpression();
final PyType type = e != null ? context.getType(e) : null;
if (isDelegating()) {
final Ref<PyType> generatorElementType = PyTypingTypeProvider.coroutineOrGeneratorElementType(type);
return generatorElementType == null ? PyNoneType.INSTANCE : generatorElementType.get();
final PyExpression e = getExpression();
final PyType type = e != null ? context.getType(e) : null;
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type);
if (generatorDesc != null) {
return generatorDesc.returnType();
}
return PyNoneType.INSTANCE;
}
else {
return getSendType(context);
}
}
@Override
public @Nullable PyType getYieldType(@NotNull TypeEvalContext context) {
final PyExpression expr = getExpression();
final PyType type = expr != null ? context.getType(expr) : PyNoneType.INSTANCE;
if (isDelegating()) {
return PyTargetExpressionImpl.getIterationType(type, expr, this, context);
}
return type;
}
@Override
public @Nullable PyType getSendType(@NotNull TypeEvalContext context) {
if (ScopeUtil.getScopeOwner(this) instanceof PyFunction function) {
if (function.getAnnotation() != null || function.getTypeCommentAnnotation() != null) {
var returnType = context.getReturnType(function);
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(returnType);
if (generatorDesc != null) {
return generatorDesc.sendType();
}
}
}
if (isDelegating()) {
final PyExpression e = getExpression();
final PyType type = e != null ? context.getType(e) : null;
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type);
if (generatorDesc != null) {
return generatorDesc.sendType();
}
return PyNoneType.INSTANCE;
}
return null;
}
}

View File

@@ -50,4 +50,8 @@ def l(x) -> <warning descr="Expected type 'int', got 'int | None' instead">int</
def m(x) -> None:
"""Does not display warning about implicit return, because annotated '-> None' """
if x:
return
return
def n() -> Generator[int, Any, str]:
yield 13
return 42

View File

@@ -0,0 +1,69 @@
from typing import Generator, Iterable, Iterator, AsyncIterable, AsyncIterator, AsyncGenerator
# Fix incorrect YieldType
def a() -> Iterable[str]:
yield <warning descr="Expected yield type 'str', got 'int' instead">42</warning>
def b() -> Iterator[str]:
yield <warning descr="Expected yield type 'str', got 'int' instead">42</warning>
def c() -> Generator[str, Any, int]:
yield <warning descr="Expected yield type 'str', got 'int' instead">13</warning>
return 42
def c() -> Generator[int, Any, str]:
yield 13
return <warning descr="Expected type 'str', got 'int' instead">42</warning>
# Suggest AsyncGenerator
async def d() -> <warning descr="Expected type 'AsyncGenerator[int, None]', got 'Iterable[int]' instead">Iterable[int]</warning>:
yield 42
async def e() -> <warning descr="Expected type 'AsyncGenerator[int, None]', got 'Iterator[int]' instead">Iterator[int]</warning>:
yield 42
async def f() -> <warning descr="Expected type 'AsyncGenerator[int, str]', got 'Generator[int, str, None]' instead">Generator[int, str, None]</warning>:
yield 13
# Suggest sync Generator
def g() -> <warning descr="Expected type 'Generator[int, None, None]', got 'AsyncIterable[int]' instead">AsyncIterable[int]</warning>:
yield 42
def h() -> <warning descr="Expected type 'Generator[int, None, None]', got 'AsyncIterator[int]' instead">AsyncIterator[int]</warning>:
yield 42
def i() -> <warning descr="Expected type 'Generator[int, str, None]', got 'AsyncGenerator[int, str]' instead">AsyncGenerator[int, str]</warning>:
yield 13
def j() -> Iterator[int]:
yield from j()
def k() -> Iterator[str]:
yield from <warning descr="Expected yield type 'str', got 'int' instead">j()</warning>
yield from <warning descr="Expected yield type 'str', got 'int' instead">[1]</warning>
def l() -> Generator[None, int, None]:
x: float = yield
def m() -> Generator[None, float, None]:
yield from <warning descr="Expected send type 'float', got 'int' instead">l()</warning>
def n() -> Generator[None, float, None]:
x: float = yield
def o() -> Generator[None, int, None]:
yield from n()
def p() -> Generator[int, None, None]:
yield from [1, 2]
yield from [3, 4]
def q() -> int:
x = lambda: (yield "str")
return 42
async def r() -> AsyncGenerator[int]:
yield 42
def s() -> Generator[int]:
yield from <warning descr="Cannot yield from 'AsyncGenerator[int, None]', use async for instead">r()</warning>

View File

@@ -0,0 +1,5 @@
from typing import Generator
def gen() -> Generator[int, bool, str]:
b: bool = yield <warning descr="Expected yield type 'int', got 'str' instead"><caret>"str"</warning>
return <warning descr="Expected type 'str', got 'int' instead">42</warning>

View File

@@ -0,0 +1,5 @@
from typing import Generator
def gen() -> Generator[str, bool, int]:
b: bool = yield "str"
return 42

View File

@@ -0,0 +1,4 @@
async def gen() -> str:
b: bool = <warning descr="Expected type 'str', got 'AsyncGenerator[str | float, Any]' instead"><caret>yield "str"</warning>
if b:
b = <warning descr="Expected type 'str', got 'AsyncGenerator[str | float, Any]' instead">yield 3.14</warning>

View File

@@ -0,0 +1,7 @@
from typing import Any, AsyncGenerator
async def gen() -> AsyncGenerator[str | float, Any]:
b: bool = <caret>yield "str"
if b:
b = yield 3.14

View File

@@ -0,0 +1,5 @@
def f(x) -> <warning descr="Function returning 'int | None' has implicit return">int | None<caret></warning>:
if x == 1:
return 42
elif x == 2:
<warning descr="Function returning 'int | None' has implicit return">return</warning>

View File

@@ -0,0 +1,6 @@
def f(x) -> int | None:
if x == 1:
return 42
elif x == 2:
return None
return None

View File

@@ -17,6 +17,66 @@ import java.util.Map;
public class Py3TypeTest extends PyTestCase {
public static final String TEST_DIRECTORY = "/types/";
// PY-20710
public void testLambdaGenerator() {
doTest("Generator[int, Any, Any]", """
expr = (lambda: (yield 1))()
""");
}
// PY-20710
public void testGeneratorDelegatingToLambdaGenerator() {
doTest("Generator[int, Any, str]", """
def g():
yield from (lambda: (yield 1))()
return "foo"
expr = g()
""");
}
// PY-20710
public void testYieldExpressionTypeFromGeneratorSendTypeHint() {
doTest("int", """
from typing import Generator
def g() -> Generator[str, int, None]:
expr = yield "foo"
""");
}
// PY-20710
public void testYieldFromExpressionTypeFromGeneratorReturnTypeHint() {
doTest("int", """
from typing import Generator, Any
def delegate() -> Generator[None, Any, int]:
yield
return 42
def g():
expr = yield from delegate()
""");
}
// PY-20710
public void testYieldFromLambda() {
doTest("Generator[int | str, str | Any, bool]",
"""
from typing import Generator
def gen1() -> Generator[int, str, bool]:
yield 42
return True
def gen2():
yield "str"
return True
l = lambda: (yield from gen1()) or (yield from gen2())
expr = l()
""");
}
// PY-6702
public void testYieldFromType() {
doTest("str | int | float",

View File

@@ -102,6 +102,10 @@ public class Py3TypeCheckerInspectionTest extends PyInspectionTestCase {
public void testFunctionReturnTypePy3() {
doTest();
}
public void testFunctionYieldTypePy3() {
doTest();
}
// PY-20770
public void testAsyncForOverAsyncGenerator() {

View File

@@ -64,4 +64,18 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase {
PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"),
LanguageLevel.getLatest());
}
// PY-20710
public void testChangeGenerator() {
doQuickFixTest(PyTypeCheckerInspection.class,
PyPsiBundle.message("QFIX.make.function.return.type", "gen", "Generator[str, bool, int]"),
LanguageLevel.getLatest());
}
// PY-20710
public void testMakeGenerator() {
doQuickFixTest(PyTypeCheckerInspection.class,
PyPsiBundle.message("QFIX.make.function.return.type", "gen", "AsyncGenerator[str | float, Any]"),
LanguageLevel.getLatest());
}
}