PY-20710 Support 'Generator' typing class

Check YieldType of yield expressions in PyTypeCheckerInspection
Check that (Async)Generator is used in (async) function
Check that in 'yield from' sync Generator is used
Convert PyMakeFunctionReturnTypeQuickFix into PsiUpdateModCommandAction
Infer Generator type for lambdas
When getting function type from annotation, do not convert Generator to AsyncGenerator
Introduce GeneratorTypeDescriptor to simplify working with generator annotations


Merge-request: IJ-MR-146521
Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com>

(cherry picked from commit b3b8182168c5224f0e03f54d443171ccf6ca7b89)

IJ-MR-146521

GitOrigin-RevId: a95670d7e2787015bcf162637ea6d7bfb47a312a
This commit is contained in:
Aleksandr.Govenko
2024-12-04 15:22:15 +00:00
committed by intellij-monorepo-bot
parent 362a0344a7
commit 70fe60b4c8
22 changed files with 509 additions and 122 deletions

View File

@@ -1,6 +1,7 @@
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file. // Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
package com.jetbrains.python.psi; package com.jetbrains.python.psi;
import com.intellij.openapi.util.Pair;
import com.intellij.psi.PsiNameIdentifierOwner; import com.intellij.psi.PsiNameIdentifierOwner;
import com.intellij.psi.StubBasedPsiElement; import com.intellij.psi.StubBasedPsiElement;
import com.intellij.util.ArrayFactory; import com.intellij.util.ArrayFactory;

View File

@@ -2,6 +2,9 @@
package com.jetbrains.python.psi; package com.jetbrains.python.psi;
import com.jetbrains.python.ast.PyAstYieldExpression; import com.jetbrains.python.ast.PyAstYieldExpression;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
@@ -11,4 +14,17 @@ public interface PyYieldExpression extends PyAstYieldExpression, PyExpression {
default PyExpression getExpression() { default PyExpression getExpression() {
return (PyExpression)PyAstYieldExpression.super.getExpression(); return (PyExpression)PyAstYieldExpression.super.getExpression();
} }
/**
* @return For {@code yield}, returns type of its expression. For {@code yield from} - YieldType of the delegate
*/
@Nullable
PyType getYieldType(@NotNull TypeEvalContext context);
/**
* @return If containing function is annotated with Generator (or AsyncGenerator), returns SendType from annotation.
* Otherwise, Any for {@code yield} and SendType of the delegate for {@code yield from}
*/
@Nullable
PyType getSendType(@NotNull TypeEvalContext context);
} }

View File

@@ -1056,6 +1056,9 @@ QFIX.ignore.shadowed.built.in.name=Ignore shadowed built-in name "{0}"
# PyTypeCheckerInspection # PyTypeCheckerInspection
INSP.NAME.type.checker=Incorrect type INSP.NAME.type.checker=Incorrect type
INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead
INSP.type.checker.yield.type.mismatch=Expected yield type ''{0}'', got ''{1}'' instead
INSP.type.checker.yield.from.send.type.mismatch=Expected send type ''{0}'', got ''{1}'' instead
INSP.type.checker.yield.from.async.generator=Cannot yield from ''{0}'', use 'async for' instead
INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}'' INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}''
INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2} INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2}
INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None'' INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None''

View File

@@ -353,7 +353,11 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
if (returnTypeAnnotation != null) { if (returnTypeAnnotation != null) {
final Ref<PyType> typeRef = getType(returnTypeAnnotation, context); final Ref<PyType> typeRef = getType(returnTypeAnnotation, context);
if (typeRef != null) { if (typeRef != null) {
return Ref.create(toAsyncIfNeeded(function, typeRef.get())); // Do not use toAsyncIfNeeded, as it also converts Generators. Here we do not need it.
if (function.isAsync() && function.isAsyncAllowed() && !function.isGenerator()) {
return Ref.create(wrapInCoroutineType(typeRef.get(), function));
}
return typeRef;
} }
// Don't rely on other type providers if a type hint is present, but cannot be resolved. // Don't rely on other type providers if a type hint is present, but cannot be resolved.
return Ref.create(); return Ref.create();
@@ -2156,8 +2160,9 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
if (!function.isGenerator()) { if (!function.isGenerator()) {
return wrapInCoroutineType(returnType, function); return wrapInCoroutineType(returnType, function);
} }
else if (returnType instanceof PyCollectionType && isGenerator(returnType)) { var desc = GeneratorTypeDescriptor.create(returnType);
return wrapInAsyncGeneratorType(((PyCollectionType)returnType).getIteratedItemType(), function); if (desc != null) {
return desc.withAsync(true).toPyType(function);
} }
} }
@@ -2184,15 +2189,77 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
} }
@Nullable @Nullable
public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType returnType, @NotNull PsiElement anchor) { public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType sendType, @Nullable PyType returnType, @NotNull PsiElement anchor) {
final PyClass generator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(GENERATOR, anchor); final PyClass generator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(GENERATOR, anchor);
return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, null, returnType)) : null; return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, sendType, returnType)) : null;
} }
@Nullable public record GeneratorTypeDescriptor(
private static PyType wrapInAsyncGeneratorType(@Nullable PyType elementType, @NotNull PsiElement anchor) { String className,
final PyClass asyncGenerator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(ASYNC_GENERATOR, anchor); PyType yieldType, // if YieldType is not specified, it is AnyType
return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null; PyType sendType, // if SendType is not specified, it is PyNoneType
PyType returnType // if ReturnType is not specified, it is PyNoneType
) {
private static final List<String> SYNC_TYPES = List.of(GENERATOR, "typing.Iterable", "typing.Iterator");
private static final List<String> ASYNC_TYPES = List.of(ASYNC_GENERATOR, "typing.AsyncIterable", "typing.AsyncIterator");
public static @Nullable GeneratorTypeDescriptor create(@Nullable PyType type) {
final PyClassType classType = as(type, PyClassType.class);
final PyCollectionType genericType = as(type, PyCollectionType.class);
if (classType == null) return null;
final String qName = classType.getClassQName();
if (!SYNC_TYPES.contains(qName) && !ASYNC_TYPES.contains(qName)) return null;
PyType yieldType = null;
PyType sendType = PyNoneType.INSTANCE;
PyType returnType = PyNoneType.INSTANCE;
if (genericType != null) {
yieldType = ContainerUtil.getOrElse(genericType.getElementTypes(), 0, yieldType);
if (GENERATOR.equals(qName) || ASYNC_GENERATOR.equals(qName)) {
sendType = ContainerUtil.getOrElse(genericType.getElementTypes(), 1, sendType);
}
if (GENERATOR.equals(qName)) {
returnType = ContainerUtil.getOrElse(genericType.getElementTypes(), 2, returnType);
}
}
return new GeneratorTypeDescriptor(qName, yieldType, sendType, returnType);
}
public boolean isAsync() {
return ASYNC_TYPES.contains(className);
}
public GeneratorTypeDescriptor withAsync(boolean async) {
if (async) {
var idx = SYNC_TYPES.indexOf(className);
if (idx == -1) return this;
return new GeneratorTypeDescriptor(ASYNC_TYPES.get(idx), yieldType, sendType, PyNoneType.INSTANCE);
}
else {
var idx = ASYNC_TYPES.indexOf(className);
if (idx == -1) return this;
return new GeneratorTypeDescriptor(SYNC_TYPES.get(idx), yieldType, sendType, returnType);
}
}
public @Nullable PyType toPyType(@NotNull PsiElement anchor) {
final PyClass classType = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(className, anchor);
final List<PyType> generics;
if (GENERATOR.equals(className)) {
generics = Arrays.asList(yieldType, sendType, returnType);
}
else if (ASYNC_GENERATOR.equals(className)) {
generics = Arrays.asList(yieldType, sendType);
}
else {
generics = Collections.singletonList(yieldType);
}
return classType != null ? new PyCollectionTypeImpl(classType, false, generics) : null;
}
} }
@Nullable @Nullable

View File

@@ -16,6 +16,7 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil; import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.codeInsight.typing.PyProtocolsKt; import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.GeneratorTypeDescriptor;
import com.jetbrains.python.documentation.PythonDocumentationProvider; import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix; import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix;
import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix; import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix;
@@ -92,7 +93,7 @@ public class PyTypeCheckerInspection extends PyInspection {
PyAnnotation annotation = function.getAnnotation(); PyAnnotation annotation = function.getAnnotation();
String typeCommentAnnotation = function.getTypeCommentAnnotation(); String typeCommentAnnotation = function.getTypeCommentAnnotation();
if (annotation != null || typeCommentAnnotation != null) { if (annotation != null || typeCommentAnnotation != null) {
PyType expected = getExpectedReturnType(function, myTypeEvalContext); PyType expected = getExpectedReturnStatementType(function, myTypeEvalContext);
if (expected == null) return; if (expected == null) return;
// We cannot just match annotated and inferred types, as we cannot promote inferred to Literal // We cannot just match annotated and inferred types, as we cannot promote inferred to Literal
@@ -124,14 +125,80 @@ public class PyTypeCheckerInspection extends PyInspection {
} }
} }
@Nullable @Override
public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) { public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
final PyType returnType = typeEvalContext.getReturnType(function); ScopeOwner owner = ScopeUtil.getScopeOwner(node);
if (!(owner instanceof PyFunction function)) return;
final PyAnnotation annotation = function.getAnnotation();
final String typeCommentAnnotation = function.getTypeCommentAnnotation();
if (annotation == null && typeCommentAnnotation == null) return;
if (function.isAsync() || function.isGenerator()) { final PyType fullReturnType = myTypeEvalContext.getReturnType(function);
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType)); if (fullReturnType == null) return; // fullReturnType is Any
final var generatorDesc = GeneratorTypeDescriptor.create(fullReturnType);
if (generatorDesc == null) {
// expected type is not Iterable, Iterator, Generator or similar
final PyType actual = function.getInferredReturnType(myTypeEvalContext);
String expectedName = PythonDocumentationProvider.getVerboseTypeName(fullReturnType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
getHolder()
.problem(node, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext))
.register();
return;
} }
final PyType expectedYieldType = generatorDesc.yieldType();
final PyType expectedSendType = generatorDesc.sendType();
final PyType thisYieldType = node.getYieldType(myTypeEvalContext);
final PyExpression yieldExpr = node.getExpression();
if (!PyTypeChecker.match(expectedYieldType, thisYieldType, myTypeEvalContext)) {
String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedYieldType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(thisYieldType, myTypeEvalContext);
getHolder()
.problem(yieldExpr != null ? yieldExpr : node, PyPsiBundle.message("INSP.type.checker.yield.type.mismatch", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext))
.register();
}
if (yieldExpr != null && node.isDelegating()) {
final PyType delegateType = myTypeEvalContext.getType(yieldExpr);
var delegateDesc = GeneratorTypeDescriptor.create(delegateType);
if (delegateDesc == null) return;
if (delegateDesc.isAsync()) {
String delegateName = PythonDocumentationProvider.getTypeName(delegateType, myTypeEvalContext);
registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.async.generator", delegateName, delegateName));
return;
}
// Reversed because SendType is contravariant
if (!PyTypeChecker.match(delegateDesc.sendType(), expectedSendType, myTypeEvalContext)) {
String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedSendType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(delegateDesc.sendType(), myTypeEvalContext);
registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.send.type.mismatch", expectedName, actualName));
}
}
}
@Nullable
public static PyType getExpectedReturnStatementType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
final PyType returnType = typeEvalContext.getReturnType(function);
if (function.isGenerator()) {
final var generatorDesc = GeneratorTypeDescriptor.create(returnType);
if (generatorDesc != null) {
return generatorDesc.returnType();
}
}
if (function.isAsync()) {
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType));
}
return returnType; return returnType;
} }
@@ -238,7 +305,7 @@ public class PyTypeCheckerInspection extends PyInspection {
final PyAnnotation annotation = node.getAnnotation(); final PyAnnotation annotation = node.getAnnotation();
final String typeCommentAnnotation = node.getTypeCommentAnnotation(); final String typeCommentAnnotation = node.getTypeCommentAnnotation();
if (annotation != null || typeCommentAnnotation != null) { if (annotation != null || typeCommentAnnotation != null) {
final PyType expected = getExpectedReturnType(node, myTypeEvalContext); final PyType expected = getExpectedReturnStatementType(node, myTypeEvalContext);
final boolean returnsNone = expected instanceof PyNoneType; final boolean returnsNone = expected instanceof PyNoneType;
final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext); final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext);
@@ -263,6 +330,24 @@ public class PyTypeCheckerInspection extends PyInspection {
registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(), registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(),
PyPsiBundle.message("INSP.type.checker.init.should.return.none")); PyPsiBundle.message("INSP.type.checker.init.should.return.none"));
} }
if (node.isGenerator()) {
boolean shouldBeAsync = node.isAsync() && node.isAsyncAllowed();
final PyType annotatedType = myTypeEvalContext.getReturnType(node);
final var generatorDesc = GeneratorTypeDescriptor.create(annotatedType);
if (generatorDesc != null && generatorDesc.isAsync() != shouldBeAsync) {
final PyType inferredType = node.getInferredReturnType(myTypeEvalContext);
String expectedName = PythonDocumentationProvider.getVerboseTypeName(inferredType, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(annotatedType, myTypeEvalContext);
final PsiElement annotationValue = annotation != null ? annotation.getValue() : node.getTypeComment();
if (annotationValue != null) {
getHolder()
.problem(annotationValue, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
.fix(new PyMakeFunctionReturnTypeQuickFix(node, myTypeEvalContext))
.register();
}
}
}
} }
} }

View File

@@ -15,50 +15,48 @@
*/ */
package com.jetbrains.python.inspections.quickfix; package com.jetbrains.python.inspections.quickfix;
import com.intellij.codeInsight.intention.FileModifier; import com.intellij.modcommand.ActionContext;
import com.intellij.codeInspection.LocalQuickFix; import com.intellij.modcommand.ModPsiUpdater;
import com.intellij.codeInspection.ProblemDescriptor; import com.intellij.modcommand.Presentation;
import com.intellij.openapi.project.Project; import com.intellij.modcommand.PsiUpdateModCommandAction;
import com.intellij.psi.*; import com.intellij.psi.*;
import com.jetbrains.python.PyPsiBundle; import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil; import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.documentation.PythonDocumentationProvider; import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.PyType; import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext; import com.jetbrains.python.psi.types.TypeEvalContext;
import com.jetbrains.python.refactoring.PyRefactoringUtil;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import java.util.List; import java.util.List;
public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix { public class PyMakeFunctionReturnTypeQuickFix extends PsiUpdateModCommandAction<PyFunction> {
private final @NotNull SmartPsiElementPointer<PyFunction> myFunction;
private final String myReturnTypeName; private final String myReturnTypeName;
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) { public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
this(function, getReturnTypeName(function, context)); super(function);
} myReturnTypeName = getReturnTypeName(function, context);
private PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull String returnTypeName) {
SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
myFunction = manager.createSmartPsiElementPointer(function);
myReturnTypeName = returnTypeName;
} }
@NotNull @NotNull
private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) { private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
final PyType type = function.getInferredReturnType(context); PyType type = function.getInferredReturnType(context);
if (function.isAsync()) {
var unwrappedType = PyTypingTypeProvider.unwrapCoroutineReturnType(type);
if (unwrappedType != null) {
type = unwrappedType.get();
}
}
return PythonDocumentationProvider.getTypeHint(type, context); return PythonDocumentationProvider.getTypeHint(type, context);
} }
@Override @Override
@NotNull @Nullable
public String getName() { protected Presentation getPresentation(@NotNull ActionContext context, @NotNull PyFunction function) {
PyFunction function = myFunction.getElement(); return Presentation.of(PyPsiBundle.message("QFIX.make.function.return.type", function.getName(), myReturnTypeName));
String functionName = function != null ? function.getName() : "function";
return PyPsiBundle.message("QFIX.make.function.return.type", functionName, myReturnTypeName);
} }
@Override @Override
@@ -68,11 +66,8 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
} }
@Override @Override
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) { protected void invoke(@NotNull ActionContext context, @NotNull PyFunction function, @NotNull ModPsiUpdater updater) {
final PyFunction function = myFunction.getElement(); final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(function.getProject());
if (function == null) return;
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
boolean shouldAddImports = false; boolean shouldAddImports = false;
@@ -96,28 +91,18 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
} }
if (shouldAddImports) { if (shouldAddImports) {
addImportsForTypeAnnotations(TypeEvalContext.userInitiated(project, function.getContainingFile())); addImportsForTypeAnnotations(function);
} }
} }
private void addImportsForTypeAnnotations(@NotNull TypeEvalContext context) { private static void addImportsForTypeAnnotations(@NotNull PyFunction function) {
PyFunction function = myFunction.getElement();
if (function == null) return;
PsiFile file = function.getContainingFile(); PsiFile file = function.getContainingFile();
if (file == null) return; if (file == null) return;
TypeEvalContext context = TypeEvalContext.userInitiated(function.getProject(), file);
PyType typeForImports = function.getInferredReturnType(context); PyType typeForImports = function.getInferredReturnType(context);
if (typeForImports != null) { if (typeForImports != null) {
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file); PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file);
} }
} }
@Override
public @Nullable FileModifier getFileModifierForPreview(@NotNull PsiFile target) {
PyFunction function = PyRefactoringUtil.findSameElementForPreview(myFunction, target);
if (function == null) {
return null;
}
return new PyMakeFunctionReturnTypeQuickFix(function, myReturnTypeName);
}
} }

View File

@@ -6,6 +6,7 @@ import com.intellij.codeInsight.controlflow.Instruction;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.navigation.ItemPresentation; import com.intellij.navigation.ItemPresentation;
import com.intellij.openapi.util.Key; import com.intellij.openapi.util.Key;
import com.intellij.openapi.util.Pair;
import com.intellij.openapi.util.Ref; import com.intellij.openapi.util.Ref;
import com.intellij.openapi.vfs.VirtualFile; import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiElement; import com.intellij.psi.PsiElement;
@@ -16,10 +17,8 @@ import com.intellij.psi.stubs.IStubElementType;
import com.intellij.psi.stubs.StubElement; import com.intellij.psi.stubs.StubElement;
import com.intellij.psi.util.*; import com.intellij.psi.util.*;
import com.intellij.ui.IconManager; import com.intellij.ui.IconManager;
import com.intellij.util.ArrayUtil;
import com.intellij.util.IncorrectOperationException; import com.intellij.util.IncorrectOperationException;
import com.intellij.util.PlatformIcons; import com.intellij.util.PlatformIcons;
import com.intellij.util.containers.ContainerUtil;
import com.intellij.util.containers.JBIterable; import com.intellij.util.containers.JBIterable;
import com.jetbrains.python.PyNames; import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyStubElementTypes; import com.jetbrains.python.PyStubElementTypes;
@@ -38,13 +37,16 @@ import com.jetbrains.python.psi.stubs.PyFunctionStub;
import com.jetbrains.python.psi.stubs.PyTargetExpressionStub; import com.jetbrains.python.psi.stubs.PyTargetExpressionStub;
import com.jetbrains.python.psi.types.*; import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.sdk.PythonSdkUtil; import com.jetbrains.python.sdk.PythonSdkUtil;
import one.util.streamex.StreamEx;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import javax.swing.*; import javax.swing.*;
import java.util.*; import java.util.*;
import java.util.stream.Stream;
import static com.intellij.openapi.util.text.StringUtil.notNullize; import static com.intellij.openapi.util.text.StringUtil.notNullize;
import static com.intellij.util.containers.ContainerUtil.*;
import static com.jetbrains.python.ast.PyAstFunction.Modifier.CLASSMETHOD; import static com.jetbrains.python.ast.PyAstFunction.Modifier.CLASSMETHOD;
import static com.jetbrains.python.ast.PyAstFunction.Modifier.STATICMETHOD; import static com.jetbrains.python.ast.PyAstFunction.Modifier.STATICMETHOD;
import static com.jetbrains.python.psi.PyUtil.as; import static com.jetbrains.python.psi.PyUtil.as;
@@ -130,7 +132,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
.filter(PyCallableType.class::isInstance) .filter(PyCallableType.class::isInstance)
.map(PyCallableType.class::cast) .map(PyCallableType.class::cast)
.map(callableType -> callableType.getParameters(context)) .map(callableType -> callableType.getParameters(context))
.orElseGet(() -> ContainerUtil.map(getParameterList().getParameters(), PyCallableParameterImpl::psi)); .orElseGet(() -> map(getParameterList().getParameters(), PyCallableParameterImpl::psi));
} }
@Override @Override
@@ -164,12 +166,13 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) { public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) {
PyType inferredType = null; PyType inferredType = null;
if (context.allowReturnTypes(this)) { if (context.allowReturnTypes(this)) {
final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(context); final PyType returnType = getReturnStatementType(context);
if (yieldTypeRef != null) { final Pair<PyType, PyType> yieldSendTypePair = getYieldExpressionType(context);
inferredType = yieldTypeRef.get(); if (yieldSendTypePair != null) {
inferredType = PyTypingTypeProvider.wrapInGeneratorType(yieldSendTypePair.first, yieldSendTypePair.second, returnType, this);
} }
else { else {
inferredType = getReturnStatementType(context); inferredType = returnType;
} }
} }
return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(PyTypingTypeProvider.toAsyncIfNeeded(this, inferredType)); return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(PyTypingTypeProvider.toAsyncIfNeeded(this, inferredType));
@@ -189,7 +192,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
final Map<PyExpression, PyCallableParameter> mappedExplicitParameters = fullMapping.getMappedParameters(); final Map<PyExpression, PyCallableParameter> mappedExplicitParameters = fullMapping.getMappedParameters();
final Map<PyExpression, PyCallableParameter> allMappedParameters = new LinkedHashMap<>(); final Map<PyExpression, PyCallableParameter> allMappedParameters = new LinkedHashMap<>();
final PyCallableParameter firstImplicit = ContainerUtil.getFirstItem(fullMapping.getImplicitParameters()); final PyCallableParameter firstImplicit = getFirstItem(fullMapping.getImplicitParameters());
if (receiver != null && firstImplicit != null) { if (receiver != null && firstImplicit != null) {
allMappedParameters.put(receiver, firstImplicit); allMappedParameters.put(receiver, firstImplicit);
} }
@@ -282,7 +285,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
else if (allowCoroutineOrGenerator && else if (allowCoroutineOrGenerator &&
returnType instanceof PyCollectionType && returnType instanceof PyCollectionType &&
PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType) != null) { PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType) != null) {
final List<PyType> replacedElementTypes = ContainerUtil.map( final List<PyType> replacedElementTypes = map(
((PyCollectionType)returnType).getElementTypes(), ((PyCollectionType)returnType).getElementTypes(),
type -> replaceSelf(type, receiver, context, false) type -> replaceSelf(type, receiver, context, false)
); );
@@ -308,42 +311,41 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
} }
return false; return false;
} }
public static class YieldCollector extends PyRecursiveElementVisitor {
public List<PyYieldExpression> getYieldExpressions() {
return myYieldExpressions;
}
private @Nullable Ref<? extends PyType> getYieldStatementType(final @NotNull TypeEvalContext context) { final private List<PyYieldExpression> myYieldExpressions = new ArrayList<>();
final PyBuiltinCache cache = PyBuiltinCache.getInstance(this);
@Override
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
myYieldExpressions.add(node);
}
@Override
public void visitPyFunction(@NotNull PyFunction node) {
// Ignore nested functions
}
@Override
public void visitPyLambdaExpression(@NotNull PyLambdaExpression node) {
// Ignore nested lambdas
}
}
/**
* @return pair of YieldType and SendType. Null when there are no yields
*/
private @Nullable Pair<PyType, PyType> getYieldExpressionType(final @NotNull TypeEvalContext context) {
final PyStatementList statements = getStatementList(); final PyStatementList statements = getStatementList();
final Set<PyType> types = new LinkedHashSet<>(); final YieldCollector visitor = new YieldCollector();
statements.accept(new PyRecursiveElementVisitor() { statements.accept(visitor);
@Override final List<PyType> yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context));
public void visitPyYieldExpression(@NotNull PyYieldExpression node) { final List<PyType> sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context));
final PyExpression expr = node.getExpression(); if (!yieldTypes.isEmpty()) {
final PyType type = expr != null ? context.getType(expr) : null; return Pair.create(PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes));
if (node.isDelegating()) {
if (type instanceof PyCollectionType) {
types.add(((PyCollectionType)type).getIteratedItemType());
}
else if (ArrayUtil.contains(type, cache.getListType(), cache.getDictType(), cache.getSetType(), cache.getTupleType())) {
types.add(null);
}
else {
types.add(type);
}
}
else {
types.add(type);
}
}
@Override
public void visitPyFunction(@NotNull PyFunction node) {
// Ignore nested functions
}
});
if (!types.isEmpty()) {
final PyType elementType = PyUnionType.union(types);
final PyType returnType = getReturnStatementType(context);
return Ref.create(PyTypingTypeProvider.wrapInGeneratorType(elementType, returnType, this));
} }
return null; return null;
} }

View File

@@ -42,7 +42,7 @@ public class PyGeneratorExpressionImpl extends PyComprehensionElementImpl implem
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
final PyExpression resultExpr = getResultExpression(); final PyExpression resultExpr = getResultExpression();
if (resultExpr != null) { if (resultExpr != null) {
return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), PyNoneType.INSTANCE, this); return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), null, PyNoneType.INSTANCE, this);
} }
return null; return null;
} }

View File

@@ -4,18 +4,19 @@ package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.util.containers.ContainerUtil; import com.intellij.util.containers.ContainerUtil;
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache; import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
import com.jetbrains.python.psi.PyCallSiteExpression; import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.PyLambdaExpression;
import com.jetbrains.python.psi.types.*; import com.jetbrains.python.psi.types.*;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Optional; import java.util.Optional;
import static com.intellij.util.containers.ContainerUtil.map;
public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExpression { public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExpression {
public PyLambdaExpressionImpl(ASTNode astNode) { public PyLambdaExpressionImpl(ASTNode astNode) {
@@ -56,7 +57,19 @@ public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExp
@Override @Override
public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
final PyExpression body = getBody(); final PyExpression body = getBody();
return body != null ? context.getType(body) : null; if (body == null) return null;
final PyFunctionImpl.YieldCollector visitor = new PyFunctionImpl.YieldCollector();
body.accept(visitor);
final List<PyType> yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context));
final List<PyType> sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context));
if (!yieldTypes.isEmpty()) {
return PyTypingTypeProvider.wrapInGeneratorType(
PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes), context.getType(body), this);
}
return context.getType(body);
} }
@Nullable @Nullable

View File

@@ -127,16 +127,7 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
} }
final PsiElement parent = PsiTreeUtil.skipParentsOfType(this, PyParenthesizedExpression.class); final PsiElement parent = PsiTreeUtil.skipParentsOfType(this, PyParenthesizedExpression.class);
if (parent instanceof PyAssignmentStatement assignmentStatement) { if (parent instanceof PyAssignmentStatement assignmentStatement) {
PyExpression assignedValue = assignmentStatement.getAssignedValue(); return context.getType(assignmentStatement.getAssignedValue());
if (assignedValue instanceof PyParenthesizedExpression) {
assignedValue = ((PyParenthesizedExpression)assignedValue).getContainedExpression();
}
if (assignedValue != null) {
if (assignedValue instanceof PyYieldExpression assignedYield) {
return assignedYield.isDelegating() ? context.getType(assignedValue) : null;
}
return context.getType(assignedValue);
}
} }
if (parent instanceof PyTupleExpression || parent instanceof PyListLiteralExpression) { if (parent instanceof PyTupleExpression || parent instanceof PyListLiteralExpression) {
PsiElement nextParent = PsiElement nextParent =

View File

@@ -2,15 +2,16 @@
package com.jetbrains.python.psi.impl; package com.jetbrains.python.psi.impl;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.openapi.util.Ref; import com.intellij.util.ArrayUtil;
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider; import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.psi.PyElementVisitor; import com.jetbrains.python.psi.PyElementVisitor;
import com.jetbrains.python.psi.PyExpression; import com.jetbrains.python.psi.PyExpression;
import com.jetbrains.python.psi.PyFunction;
import com.jetbrains.python.psi.PyYieldExpression; import com.jetbrains.python.psi.PyYieldExpression;
import com.jetbrains.python.psi.types.PyNoneType; import com.jetbrains.python.psi.types.*;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpression { public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpression {
@@ -25,12 +26,52 @@ public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpre
@Override @Override
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) { public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
final PyExpression e = getExpression();
final PyType type = e != null ? context.getType(e) : null;
if (isDelegating()) { if (isDelegating()) {
final Ref<PyType> generatorElementType = PyTypingTypeProvider.coroutineOrGeneratorElementType(type); final PyExpression e = getExpression();
return generatorElementType == null ? PyNoneType.INSTANCE : generatorElementType.get(); final PyType type = e != null ? context.getType(e) : null;
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type);
if (generatorDesc != null) {
return generatorDesc.returnType();
}
return PyNoneType.INSTANCE;
}
else {
return getSendType(context);
}
}
@Override
public @Nullable PyType getYieldType(@NotNull TypeEvalContext context) {
final PyExpression expr = getExpression();
final PyType type = expr != null ? context.getType(expr) : PyNoneType.INSTANCE;
if (isDelegating()) {
return PyTargetExpressionImpl.getIterationType(type, expr, this, context);
} }
return type; return type;
} }
@Override
public @Nullable PyType getSendType(@NotNull TypeEvalContext context) {
if (ScopeUtil.getScopeOwner(this) instanceof PyFunction function) {
if (function.getAnnotation() != null || function.getTypeCommentAnnotation() != null) {
var returnType = context.getReturnType(function);
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(returnType);
if (generatorDesc != null) {
return generatorDesc.sendType();
}
}
}
if (isDelegating()) {
final PyExpression e = getExpression();
final PyType type = e != null ? context.getType(e) : null;
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type);
if (generatorDesc != null) {
return generatorDesc.sendType();
}
return PyNoneType.INSTANCE;
}
return null;
}
} }

View File

@@ -50,4 +50,8 @@ def l(x) -> <warning descr="Expected type 'int', got 'int | None' instead">int</
def m(x) -> None: def m(x) -> None:
"""Does not display warning about implicit return, because annotated '-> None' """ """Does not display warning about implicit return, because annotated '-> None' """
if x: if x:
return return
def n() -> Generator[int, Any, str]:
yield 13
return 42

View File

@@ -0,0 +1,69 @@
from typing import Generator, Iterable, Iterator, AsyncIterable, AsyncIterator, AsyncGenerator
# Fix incorrect YieldType
def a() -> Iterable[str]:
yield <warning descr="Expected yield type 'str', got 'int' instead">42</warning>
def b() -> Iterator[str]:
yield <warning descr="Expected yield type 'str', got 'int' instead">42</warning>
def c() -> Generator[str, Any, int]:
yield <warning descr="Expected yield type 'str', got 'int' instead">13</warning>
return 42
def c() -> Generator[int, Any, str]:
yield 13
return <warning descr="Expected type 'str', got 'int' instead">42</warning>
# Suggest AsyncGenerator
async def d() -> <warning descr="Expected type 'AsyncGenerator[int, None]', got 'Iterable[int]' instead">Iterable[int]</warning>:
yield 42
async def e() -> <warning descr="Expected type 'AsyncGenerator[int, None]', got 'Iterator[int]' instead">Iterator[int]</warning>:
yield 42
async def f() -> <warning descr="Expected type 'AsyncGenerator[int, str]', got 'Generator[int, str, None]' instead">Generator[int, str, None]</warning>:
yield 13
# Suggest sync Generator
def g() -> <warning descr="Expected type 'Generator[int, None, None]', got 'AsyncIterable[int]' instead">AsyncIterable[int]</warning>:
yield 42
def h() -> <warning descr="Expected type 'Generator[int, None, None]', got 'AsyncIterator[int]' instead">AsyncIterator[int]</warning>:
yield 42
def i() -> <warning descr="Expected type 'Generator[int, str, None]', got 'AsyncGenerator[int, str]' instead">AsyncGenerator[int, str]</warning>:
yield 13
def j() -> Iterator[int]:
yield from j()
def k() -> Iterator[str]:
yield from <warning descr="Expected yield type 'str', got 'int' instead">j()</warning>
yield from <warning descr="Expected yield type 'str', got 'int' instead">[1]</warning>
def l() -> Generator[None, int, None]:
x: float = yield
def m() -> Generator[None, float, None]:
yield from <warning descr="Expected send type 'float', got 'int' instead">l()</warning>
def n() -> Generator[None, float, None]:
x: float = yield
def o() -> Generator[None, int, None]:
yield from n()
def p() -> Generator[int, None, None]:
yield from [1, 2]
yield from [3, 4]
def q() -> int:
x = lambda: (yield "str")
return 42
async def r() -> AsyncGenerator[int]:
yield 42
def s() -> Generator[int]:
yield from <warning descr="Cannot yield from 'AsyncGenerator[int, None]', use async for instead">r()</warning>

View File

@@ -0,0 +1,5 @@
from typing import Generator
def gen() -> Generator[int, bool, str]:
b: bool = yield <warning descr="Expected yield type 'int', got 'str' instead"><caret>"str"</warning>
return <warning descr="Expected type 'str', got 'int' instead">42</warning>

View File

@@ -0,0 +1,5 @@
from typing import Generator
def gen() -> Generator[str, bool, int]:
b: bool = yield "str"
return 42

View File

@@ -0,0 +1,4 @@
async def gen() -> str:
b: bool = <warning descr="Expected type 'str', got 'AsyncGenerator[str | float, Any]' instead"><caret>yield "str"</warning>
if b:
b = <warning descr="Expected type 'str', got 'AsyncGenerator[str | float, Any]' instead">yield 3.14</warning>

View File

@@ -0,0 +1,7 @@
from typing import Any, AsyncGenerator
async def gen() -> AsyncGenerator[str | float, Any]:
b: bool = <caret>yield "str"
if b:
b = yield 3.14

View File

@@ -0,0 +1,5 @@
def f(x) -> <warning descr="Function returning 'int | None' has implicit return">int | None<caret></warning>:
if x == 1:
return 42
elif x == 2:
<warning descr="Function returning 'int | None' has implicit return">return</warning>

View File

@@ -0,0 +1,6 @@
def f(x) -> int | None:
if x == 1:
return 42
elif x == 2:
return None
return None

View File

@@ -17,6 +17,66 @@ import java.util.Map;
public class Py3TypeTest extends PyTestCase { public class Py3TypeTest extends PyTestCase {
public static final String TEST_DIRECTORY = "/types/"; public static final String TEST_DIRECTORY = "/types/";
// PY-20710
public void testLambdaGenerator() {
doTest("Generator[int, Any, Any]", """
expr = (lambda: (yield 1))()
""");
}
// PY-20710
public void testGeneratorDelegatingToLambdaGenerator() {
doTest("Generator[int, Any, str]", """
def g():
yield from (lambda: (yield 1))()
return "foo"
expr = g()
""");
}
// PY-20710
public void testYieldExpressionTypeFromGeneratorSendTypeHint() {
doTest("int", """
from typing import Generator
def g() -> Generator[str, int, None]:
expr = yield "foo"
""");
}
// PY-20710
public void testYieldFromExpressionTypeFromGeneratorReturnTypeHint() {
doTest("int", """
from typing import Generator, Any
def delegate() -> Generator[None, Any, int]:
yield
return 42
def g():
expr = yield from delegate()
""");
}
// PY-20710
public void testYieldFromLambda() {
doTest("Generator[int | str, str | Any, bool]",
"""
from typing import Generator
def gen1() -> Generator[int, str, bool]:
yield 42
return True
def gen2():
yield "str"
return True
l = lambda: (yield from gen1()) or (yield from gen2())
expr = l()
""");
}
// PY-6702 // PY-6702
public void testYieldFromType() { public void testYieldFromType() {
doTest("str | int | float", doTest("str | int | float",

View File

@@ -102,6 +102,10 @@ public class Py3TypeCheckerInspectionTest extends PyInspectionTestCase {
public void testFunctionReturnTypePy3() { public void testFunctionReturnTypePy3() {
doTest(); doTest();
} }
public void testFunctionYieldTypePy3() {
doTest();
}
// PY-20770 // PY-20770
public void testAsyncForOverAsyncGenerator() { public void testAsyncForOverAsyncGenerator() {

View File

@@ -64,4 +64,18 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase {
PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"), PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"),
LanguageLevel.getLatest()); LanguageLevel.getLatest());
} }
// PY-20710
public void testChangeGenerator() {
doQuickFixTest(PyTypeCheckerInspection.class,
PyPsiBundle.message("QFIX.make.function.return.type", "gen", "Generator[str, bool, int]"),
LanguageLevel.getLatest());
}
// PY-20710
public void testMakeGenerator() {
doQuickFixTest(PyTypeCheckerInspection.class,
PyPsiBundle.message("QFIX.make.function.return.type", "gen", "AsyncGenerator[str | float, Any]"),
LanguageLevel.getLatest());
}
} }