mirror of
https://gitflic.ru/project/openide/openide.git
synced 2025-12-15 11:53:49 +07:00
PY-20710 Support 'Generator' typing class
Check YieldType of yield expressions in PyTypeCheckerInspection Check that (Async)Generator is used in (async) function Check that in 'yield from' sync Generator is used Convert PyMakeFunctionReturnTypeQuickFix into PsiUpdateModCommandAction Infer Generator type for lambdas When getting function type from annotation, do not convert Generator to AsyncGenerator Introduce GeneratorTypeDescriptor to simplify working with generator annotations Merge-request: IJ-MR-146521 Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com> (cherry picked from commit b3b8182168c5224f0e03f54d443171ccf6ca7b89) IJ-MR-146521 GitOrigin-RevId: a95670d7e2787015bcf162637ea6d7bfb47a312a
This commit is contained in:
committed by
intellij-monorepo-bot
parent
362a0344a7
commit
70fe60b4c8
@@ -1,6 +1,7 @@
|
|||||||
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
// Copyright 2000-2021 JetBrains s.r.o. Use of this source code is governed by the Apache 2.0 license that can be found in the LICENSE file.
|
||||||
package com.jetbrains.python.psi;
|
package com.jetbrains.python.psi;
|
||||||
|
|
||||||
|
import com.intellij.openapi.util.Pair;
|
||||||
import com.intellij.psi.PsiNameIdentifierOwner;
|
import com.intellij.psi.PsiNameIdentifierOwner;
|
||||||
import com.intellij.psi.StubBasedPsiElement;
|
import com.intellij.psi.StubBasedPsiElement;
|
||||||
import com.intellij.util.ArrayFactory;
|
import com.intellij.util.ArrayFactory;
|
||||||
|
|||||||
@@ -2,6 +2,9 @@
|
|||||||
package com.jetbrains.python.psi;
|
package com.jetbrains.python.psi;
|
||||||
|
|
||||||
import com.jetbrains.python.ast.PyAstYieldExpression;
|
import com.jetbrains.python.ast.PyAstYieldExpression;
|
||||||
|
import com.jetbrains.python.psi.types.PyType;
|
||||||
|
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||||
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.jetbrains.annotations.Nullable;
|
import org.jetbrains.annotations.Nullable;
|
||||||
|
|
||||||
|
|
||||||
@@ -11,4 +14,17 @@ public interface PyYieldExpression extends PyAstYieldExpression, PyExpression {
|
|||||||
default PyExpression getExpression() {
|
default PyExpression getExpression() {
|
||||||
return (PyExpression)PyAstYieldExpression.super.getExpression();
|
return (PyExpression)PyAstYieldExpression.super.getExpression();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return For {@code yield}, returns type of its expression. For {@code yield from} - YieldType of the delegate
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
PyType getYieldType(@NotNull TypeEvalContext context);
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return If containing function is annotated with Generator (or AsyncGenerator), returns SendType from annotation.
|
||||||
|
* Otherwise, Any for {@code yield} and SendType of the delegate for {@code yield from}
|
||||||
|
*/
|
||||||
|
@Nullable
|
||||||
|
PyType getSendType(@NotNull TypeEvalContext context);
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1056,6 +1056,9 @@ QFIX.ignore.shadowed.built.in.name=Ignore shadowed built-in name "{0}"
|
|||||||
# PyTypeCheckerInspection
|
# PyTypeCheckerInspection
|
||||||
INSP.NAME.type.checker=Incorrect type
|
INSP.NAME.type.checker=Incorrect type
|
||||||
INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead
|
INSP.type.checker.expected.type.got.type.instead=Expected type ''{0}'', got ''{1}'' instead
|
||||||
|
INSP.type.checker.yield.type.mismatch=Expected yield type ''{0}'', got ''{1}'' instead
|
||||||
|
INSP.type.checker.yield.from.send.type.mismatch=Expected send type ''{0}'', got ''{1}'' instead
|
||||||
|
INSP.type.checker.yield.from.async.generator=Cannot yield from ''{0}'', use 'async for' instead
|
||||||
INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}''
|
INSP.type.checker.typed.dict.extra.key=Extra key ''{0}'' for TypedDict ''{1}''
|
||||||
INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2}
|
INSP.type.checker.typed.dict.missing.keys=TypedDict ''{0}'' has missing {1,choice,1#key|2#keys}: {2}
|
||||||
INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None''
|
INSP.type.checker.returning.type.has.implicit.return=Function returning ''{0}'' has implicit ''return None''
|
||||||
|
|||||||
@@ -353,7 +353,11 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
|||||||
if (returnTypeAnnotation != null) {
|
if (returnTypeAnnotation != null) {
|
||||||
final Ref<PyType> typeRef = getType(returnTypeAnnotation, context);
|
final Ref<PyType> typeRef = getType(returnTypeAnnotation, context);
|
||||||
if (typeRef != null) {
|
if (typeRef != null) {
|
||||||
return Ref.create(toAsyncIfNeeded(function, typeRef.get()));
|
// Do not use toAsyncIfNeeded, as it also converts Generators. Here we do not need it.
|
||||||
|
if (function.isAsync() && function.isAsyncAllowed() && !function.isGenerator()) {
|
||||||
|
return Ref.create(wrapInCoroutineType(typeRef.get(), function));
|
||||||
|
}
|
||||||
|
return typeRef;
|
||||||
}
|
}
|
||||||
// Don't rely on other type providers if a type hint is present, but cannot be resolved.
|
// Don't rely on other type providers if a type hint is present, but cannot be resolved.
|
||||||
return Ref.create();
|
return Ref.create();
|
||||||
@@ -2156,8 +2160,9 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
|||||||
if (!function.isGenerator()) {
|
if (!function.isGenerator()) {
|
||||||
return wrapInCoroutineType(returnType, function);
|
return wrapInCoroutineType(returnType, function);
|
||||||
}
|
}
|
||||||
else if (returnType instanceof PyCollectionType && isGenerator(returnType)) {
|
var desc = GeneratorTypeDescriptor.create(returnType);
|
||||||
return wrapInAsyncGeneratorType(((PyCollectionType)returnType).getIteratedItemType(), function);
|
if (desc != null) {
|
||||||
|
return desc.withAsync(true).toPyType(function);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -2184,15 +2189,77 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType returnType, @NotNull PsiElement anchor) {
|
public static PyType wrapInGeneratorType(@Nullable PyType elementType, @Nullable PyType sendType, @Nullable PyType returnType, @NotNull PsiElement anchor) {
|
||||||
final PyClass generator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(GENERATOR, anchor);
|
final PyClass generator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(GENERATOR, anchor);
|
||||||
return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, null, returnType)) : null;
|
return generator != null ? new PyCollectionTypeImpl(generator, false, Arrays.asList(elementType, sendType, returnType)) : null;
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
public record GeneratorTypeDescriptor(
|
||||||
private static PyType wrapInAsyncGeneratorType(@Nullable PyType elementType, @NotNull PsiElement anchor) {
|
String className,
|
||||||
final PyClass asyncGenerator = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(ASYNC_GENERATOR, anchor);
|
PyType yieldType, // if YieldType is not specified, it is AnyType
|
||||||
return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null;
|
PyType sendType, // if SendType is not specified, it is PyNoneType
|
||||||
|
PyType returnType // if ReturnType is not specified, it is PyNoneType
|
||||||
|
) {
|
||||||
|
|
||||||
|
private static final List<String> SYNC_TYPES = List.of(GENERATOR, "typing.Iterable", "typing.Iterator");
|
||||||
|
private static final List<String> ASYNC_TYPES = List.of(ASYNC_GENERATOR, "typing.AsyncIterable", "typing.AsyncIterator");
|
||||||
|
|
||||||
|
public static @Nullable GeneratorTypeDescriptor create(@Nullable PyType type) {
|
||||||
|
final PyClassType classType = as(type, PyClassType.class);
|
||||||
|
final PyCollectionType genericType = as(type, PyCollectionType.class);
|
||||||
|
if (classType == null) return null;
|
||||||
|
|
||||||
|
final String qName = classType.getClassQName();
|
||||||
|
if (!SYNC_TYPES.contains(qName) && !ASYNC_TYPES.contains(qName)) return null;
|
||||||
|
|
||||||
|
PyType yieldType = null;
|
||||||
|
PyType sendType = PyNoneType.INSTANCE;
|
||||||
|
PyType returnType = PyNoneType.INSTANCE;
|
||||||
|
|
||||||
|
if (genericType != null) {
|
||||||
|
yieldType = ContainerUtil.getOrElse(genericType.getElementTypes(), 0, yieldType);
|
||||||
|
if (GENERATOR.equals(qName) || ASYNC_GENERATOR.equals(qName)) {
|
||||||
|
sendType = ContainerUtil.getOrElse(genericType.getElementTypes(), 1, sendType);
|
||||||
|
}
|
||||||
|
if (GENERATOR.equals(qName)) {
|
||||||
|
returnType = ContainerUtil.getOrElse(genericType.getElementTypes(), 2, returnType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return new GeneratorTypeDescriptor(qName, yieldType, sendType, returnType);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isAsync() {
|
||||||
|
return ASYNC_TYPES.contains(className);
|
||||||
|
}
|
||||||
|
|
||||||
|
public GeneratorTypeDescriptor withAsync(boolean async) {
|
||||||
|
if (async) {
|
||||||
|
var idx = SYNC_TYPES.indexOf(className);
|
||||||
|
if (idx == -1) return this;
|
||||||
|
return new GeneratorTypeDescriptor(ASYNC_TYPES.get(idx), yieldType, sendType, PyNoneType.INSTANCE);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
var idx = ASYNC_TYPES.indexOf(className);
|
||||||
|
if (idx == -1) return this;
|
||||||
|
return new GeneratorTypeDescriptor(SYNC_TYPES.get(idx), yieldType, sendType, returnType);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public @Nullable PyType toPyType(@NotNull PsiElement anchor) {
|
||||||
|
final PyClass classType = PyPsiFacade.getInstance(anchor.getProject()).createClassByQName(className, anchor);
|
||||||
|
final List<PyType> generics;
|
||||||
|
if (GENERATOR.equals(className)) {
|
||||||
|
generics = Arrays.asList(yieldType, sendType, returnType);
|
||||||
|
}
|
||||||
|
else if (ASYNC_GENERATOR.equals(className)) {
|
||||||
|
generics = Arrays.asList(yieldType, sendType);
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
generics = Collections.singletonList(yieldType);
|
||||||
|
}
|
||||||
|
|
||||||
|
return classType != null ? new PyCollectionTypeImpl(classType, false, generics) : null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ import com.jetbrains.python.codeInsight.controlflow.ScopeOwner;
|
|||||||
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
|
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
|
||||||
import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;
|
import com.jetbrains.python.codeInsight.typing.PyProtocolsKt;
|
||||||
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
|
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
|
||||||
|
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider.GeneratorTypeDescriptor;
|
||||||
import com.jetbrains.python.documentation.PythonDocumentationProvider;
|
import com.jetbrains.python.documentation.PythonDocumentationProvider;
|
||||||
import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix;
|
import com.jetbrains.python.inspections.quickfix.PyMakeFunctionReturnTypeQuickFix;
|
||||||
import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix;
|
import com.jetbrains.python.inspections.quickfix.PyMakeReturnsExplicitFix;
|
||||||
@@ -92,7 +93,7 @@ public class PyTypeCheckerInspection extends PyInspection {
|
|||||||
PyAnnotation annotation = function.getAnnotation();
|
PyAnnotation annotation = function.getAnnotation();
|
||||||
String typeCommentAnnotation = function.getTypeCommentAnnotation();
|
String typeCommentAnnotation = function.getTypeCommentAnnotation();
|
||||||
if (annotation != null || typeCommentAnnotation != null) {
|
if (annotation != null || typeCommentAnnotation != null) {
|
||||||
PyType expected = getExpectedReturnType(function, myTypeEvalContext);
|
PyType expected = getExpectedReturnStatementType(function, myTypeEvalContext);
|
||||||
if (expected == null) return;
|
if (expected == null) return;
|
||||||
|
|
||||||
// We cannot just match annotated and inferred types, as we cannot promote inferred to Literal
|
// We cannot just match annotated and inferred types, as we cannot promote inferred to Literal
|
||||||
@@ -124,14 +125,80 @@ public class PyTypeCheckerInspection extends PyInspection {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Override
|
||||||
public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
|
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
|
||||||
final PyType returnType = typeEvalContext.getReturnType(function);
|
ScopeOwner owner = ScopeUtil.getScopeOwner(node);
|
||||||
|
if (!(owner instanceof PyFunction function)) return;
|
||||||
|
|
||||||
|
final PyAnnotation annotation = function.getAnnotation();
|
||||||
|
final String typeCommentAnnotation = function.getTypeCommentAnnotation();
|
||||||
|
if (annotation == null && typeCommentAnnotation == null) return;
|
||||||
|
|
||||||
if (function.isAsync() || function.isGenerator()) {
|
final PyType fullReturnType = myTypeEvalContext.getReturnType(function);
|
||||||
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType));
|
if (fullReturnType == null) return; // fullReturnType is Any
|
||||||
|
|
||||||
|
final var generatorDesc = GeneratorTypeDescriptor.create(fullReturnType);
|
||||||
|
if (generatorDesc == null) {
|
||||||
|
// expected type is not Iterable, Iterator, Generator or similar
|
||||||
|
final PyType actual = function.getInferredReturnType(myTypeEvalContext);
|
||||||
|
String expectedName = PythonDocumentationProvider.getVerboseTypeName(fullReturnType, myTypeEvalContext);
|
||||||
|
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
|
||||||
|
getHolder()
|
||||||
|
.problem(node, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
|
||||||
|
.fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext))
|
||||||
|
.register();
|
||||||
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
final PyType expectedYieldType = generatorDesc.yieldType();
|
||||||
|
final PyType expectedSendType = generatorDesc.sendType();
|
||||||
|
|
||||||
|
final PyType thisYieldType = node.getYieldType(myTypeEvalContext);
|
||||||
|
|
||||||
|
final PyExpression yieldExpr = node.getExpression();
|
||||||
|
|
||||||
|
if (!PyTypeChecker.match(expectedYieldType, thisYieldType, myTypeEvalContext)) {
|
||||||
|
String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedYieldType, myTypeEvalContext);
|
||||||
|
String actualName = PythonDocumentationProvider.getTypeName(thisYieldType, myTypeEvalContext);
|
||||||
|
getHolder()
|
||||||
|
.problem(yieldExpr != null ? yieldExpr : node, PyPsiBundle.message("INSP.type.checker.yield.type.mismatch", expectedName, actualName))
|
||||||
|
.fix(new PyMakeFunctionReturnTypeQuickFix(function, myTypeEvalContext))
|
||||||
|
.register();
|
||||||
|
}
|
||||||
|
|
||||||
|
if (yieldExpr != null && node.isDelegating()) {
|
||||||
|
final PyType delegateType = myTypeEvalContext.getType(yieldExpr);
|
||||||
|
var delegateDesc = GeneratorTypeDescriptor.create(delegateType);
|
||||||
|
if (delegateDesc == null) return;
|
||||||
|
|
||||||
|
if (delegateDesc.isAsync()) {
|
||||||
|
String delegateName = PythonDocumentationProvider.getTypeName(delegateType, myTypeEvalContext);
|
||||||
|
registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.async.generator", delegateName, delegateName));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reversed because SendType is contravariant
|
||||||
|
if (!PyTypeChecker.match(delegateDesc.sendType(), expectedSendType, myTypeEvalContext)) {
|
||||||
|
String expectedName = PythonDocumentationProvider.getVerboseTypeName(expectedSendType, myTypeEvalContext);
|
||||||
|
String actualName = PythonDocumentationProvider.getTypeName(delegateDesc.sendType(), myTypeEvalContext);
|
||||||
|
registerProblem(yieldExpr, PyPsiBundle.message("INSP.type.checker.yield.from.send.type.mismatch", expectedName, actualName));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Nullable
|
||||||
|
public static PyType getExpectedReturnStatementType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
|
||||||
|
final PyType returnType = typeEvalContext.getReturnType(function);
|
||||||
|
if (function.isGenerator()) {
|
||||||
|
final var generatorDesc = GeneratorTypeDescriptor.create(returnType);
|
||||||
|
if (generatorDesc != null) {
|
||||||
|
return generatorDesc.returnType();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (function.isAsync()) {
|
||||||
|
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType));
|
||||||
|
}
|
||||||
return returnType;
|
return returnType;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -238,7 +305,7 @@ public class PyTypeCheckerInspection extends PyInspection {
|
|||||||
final PyAnnotation annotation = node.getAnnotation();
|
final PyAnnotation annotation = node.getAnnotation();
|
||||||
final String typeCommentAnnotation = node.getTypeCommentAnnotation();
|
final String typeCommentAnnotation = node.getTypeCommentAnnotation();
|
||||||
if (annotation != null || typeCommentAnnotation != null) {
|
if (annotation != null || typeCommentAnnotation != null) {
|
||||||
final PyType expected = getExpectedReturnType(node, myTypeEvalContext);
|
final PyType expected = getExpectedReturnStatementType(node, myTypeEvalContext);
|
||||||
final boolean returnsNone = expected instanceof PyNoneType;
|
final boolean returnsNone = expected instanceof PyNoneType;
|
||||||
final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext);
|
final boolean returnsOptional = PyTypeChecker.match(expected, PyNoneType.INSTANCE, myTypeEvalContext);
|
||||||
|
|
||||||
@@ -263,6 +330,24 @@ public class PyTypeCheckerInspection extends PyInspection {
|
|||||||
registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(),
|
registerProblem(annotation != null ? annotation.getValue() : node.getTypeComment(),
|
||||||
PyPsiBundle.message("INSP.type.checker.init.should.return.none"));
|
PyPsiBundle.message("INSP.type.checker.init.should.return.none"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (node.isGenerator()) {
|
||||||
|
boolean shouldBeAsync = node.isAsync() && node.isAsyncAllowed();
|
||||||
|
final PyType annotatedType = myTypeEvalContext.getReturnType(node);
|
||||||
|
final var generatorDesc = GeneratorTypeDescriptor.create(annotatedType);
|
||||||
|
if (generatorDesc != null && generatorDesc.isAsync() != shouldBeAsync) {
|
||||||
|
final PyType inferredType = node.getInferredReturnType(myTypeEvalContext);
|
||||||
|
String expectedName = PythonDocumentationProvider.getVerboseTypeName(inferredType, myTypeEvalContext);
|
||||||
|
String actualName = PythonDocumentationProvider.getTypeName(annotatedType, myTypeEvalContext);
|
||||||
|
final PsiElement annotationValue = annotation != null ? annotation.getValue() : node.getTypeComment();
|
||||||
|
if (annotationValue != null) {
|
||||||
|
getHolder()
|
||||||
|
.problem(annotationValue, PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName))
|
||||||
|
.fix(new PyMakeFunctionReturnTypeQuickFix(node, myTypeEvalContext))
|
||||||
|
.register();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,50 +15,48 @@
|
|||||||
*/
|
*/
|
||||||
package com.jetbrains.python.inspections.quickfix;
|
package com.jetbrains.python.inspections.quickfix;
|
||||||
|
|
||||||
import com.intellij.codeInsight.intention.FileModifier;
|
import com.intellij.modcommand.ActionContext;
|
||||||
import com.intellij.codeInspection.LocalQuickFix;
|
import com.intellij.modcommand.ModPsiUpdater;
|
||||||
import com.intellij.codeInspection.ProblemDescriptor;
|
import com.intellij.modcommand.Presentation;
|
||||||
import com.intellij.openapi.project.Project;
|
import com.intellij.modcommand.PsiUpdateModCommandAction;
|
||||||
import com.intellij.psi.*;
|
import com.intellij.psi.*;
|
||||||
import com.jetbrains.python.PyPsiBundle;
|
import com.jetbrains.python.PyPsiBundle;
|
||||||
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil;
|
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil;
|
||||||
|
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
|
||||||
import com.jetbrains.python.documentation.PythonDocumentationProvider;
|
import com.jetbrains.python.documentation.PythonDocumentationProvider;
|
||||||
import com.jetbrains.python.psi.*;
|
import com.jetbrains.python.psi.*;
|
||||||
import com.jetbrains.python.psi.types.PyType;
|
import com.jetbrains.python.psi.types.PyType;
|
||||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||||
import com.jetbrains.python.refactoring.PyRefactoringUtil;
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.jetbrains.annotations.Nullable;
|
import org.jetbrains.annotations.Nullable;
|
||||||
|
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
|
|
||||||
|
|
||||||
public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
|
public class PyMakeFunctionReturnTypeQuickFix extends PsiUpdateModCommandAction<PyFunction> {
|
||||||
private final @NotNull SmartPsiElementPointer<PyFunction> myFunction;
|
|
||||||
private final String myReturnTypeName;
|
private final String myReturnTypeName;
|
||||||
|
|
||||||
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||||
this(function, getReturnTypeName(function, context));
|
super(function);
|
||||||
}
|
myReturnTypeName = getReturnTypeName(function, context);
|
||||||
|
|
||||||
private PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @NotNull String returnTypeName) {
|
|
||||||
SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
|
|
||||||
myFunction = manager.createSmartPsiElementPointer(function);
|
|
||||||
myReturnTypeName = returnTypeName;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@NotNull
|
@NotNull
|
||||||
private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
private static String getReturnTypeName(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||||
final PyType type = function.getInferredReturnType(context);
|
PyType type = function.getInferredReturnType(context);
|
||||||
|
if (function.isAsync()) {
|
||||||
|
var unwrappedType = PyTypingTypeProvider.unwrapCoroutineReturnType(type);
|
||||||
|
if (unwrappedType != null) {
|
||||||
|
type = unwrappedType.get();
|
||||||
|
}
|
||||||
|
}
|
||||||
return PythonDocumentationProvider.getTypeHint(type, context);
|
return PythonDocumentationProvider.getTypeHint(type, context);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@NotNull
|
@Nullable
|
||||||
public String getName() {
|
protected Presentation getPresentation(@NotNull ActionContext context, @NotNull PyFunction function) {
|
||||||
PyFunction function = myFunction.getElement();
|
return Presentation.of(PyPsiBundle.message("QFIX.make.function.return.type", function.getName(), myReturnTypeName));
|
||||||
String functionName = function != null ? function.getName() : "function";
|
|
||||||
return PyPsiBundle.message("QFIX.make.function.return.type", functionName, myReturnTypeName);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -68,11 +66,8 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
|
protected void invoke(@NotNull ActionContext context, @NotNull PyFunction function, @NotNull ModPsiUpdater updater) {
|
||||||
final PyFunction function = myFunction.getElement();
|
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(function.getProject());
|
||||||
if (function == null) return;
|
|
||||||
|
|
||||||
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
|
|
||||||
|
|
||||||
boolean shouldAddImports = false;
|
boolean shouldAddImports = false;
|
||||||
|
|
||||||
@@ -96,28 +91,18 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (shouldAddImports) {
|
if (shouldAddImports) {
|
||||||
addImportsForTypeAnnotations(TypeEvalContext.userInitiated(project, function.getContainingFile()));
|
addImportsForTypeAnnotations(function);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void addImportsForTypeAnnotations(@NotNull TypeEvalContext context) {
|
private static void addImportsForTypeAnnotations(@NotNull PyFunction function) {
|
||||||
PyFunction function = myFunction.getElement();
|
|
||||||
if (function == null) return;
|
|
||||||
PsiFile file = function.getContainingFile();
|
PsiFile file = function.getContainingFile();
|
||||||
if (file == null) return;
|
if (file == null) return;
|
||||||
|
TypeEvalContext context = TypeEvalContext.userInitiated(function.getProject(), file);
|
||||||
|
|
||||||
PyType typeForImports = function.getInferredReturnType(context);
|
PyType typeForImports = function.getInferredReturnType(context);
|
||||||
if (typeForImports != null) {
|
if (typeForImports != null) {
|
||||||
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file);
|
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), context, file);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
|
||||||
public @Nullable FileModifier getFileModifierForPreview(@NotNull PsiFile target) {
|
|
||||||
PyFunction function = PyRefactoringUtil.findSameElementForPreview(myFunction, target);
|
|
||||||
if (function == null) {
|
|
||||||
return null;
|
|
||||||
}
|
|
||||||
return new PyMakeFunctionReturnTypeQuickFix(function, myReturnTypeName);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import com.intellij.codeInsight.controlflow.Instruction;
|
|||||||
import com.intellij.lang.ASTNode;
|
import com.intellij.lang.ASTNode;
|
||||||
import com.intellij.navigation.ItemPresentation;
|
import com.intellij.navigation.ItemPresentation;
|
||||||
import com.intellij.openapi.util.Key;
|
import com.intellij.openapi.util.Key;
|
||||||
|
import com.intellij.openapi.util.Pair;
|
||||||
import com.intellij.openapi.util.Ref;
|
import com.intellij.openapi.util.Ref;
|
||||||
import com.intellij.openapi.vfs.VirtualFile;
|
import com.intellij.openapi.vfs.VirtualFile;
|
||||||
import com.intellij.psi.PsiElement;
|
import com.intellij.psi.PsiElement;
|
||||||
@@ -16,10 +17,8 @@ import com.intellij.psi.stubs.IStubElementType;
|
|||||||
import com.intellij.psi.stubs.StubElement;
|
import com.intellij.psi.stubs.StubElement;
|
||||||
import com.intellij.psi.util.*;
|
import com.intellij.psi.util.*;
|
||||||
import com.intellij.ui.IconManager;
|
import com.intellij.ui.IconManager;
|
||||||
import com.intellij.util.ArrayUtil;
|
|
||||||
import com.intellij.util.IncorrectOperationException;
|
import com.intellij.util.IncorrectOperationException;
|
||||||
import com.intellij.util.PlatformIcons;
|
import com.intellij.util.PlatformIcons;
|
||||||
import com.intellij.util.containers.ContainerUtil;
|
|
||||||
import com.intellij.util.containers.JBIterable;
|
import com.intellij.util.containers.JBIterable;
|
||||||
import com.jetbrains.python.PyNames;
|
import com.jetbrains.python.PyNames;
|
||||||
import com.jetbrains.python.PyStubElementTypes;
|
import com.jetbrains.python.PyStubElementTypes;
|
||||||
@@ -38,13 +37,16 @@ import com.jetbrains.python.psi.stubs.PyFunctionStub;
|
|||||||
import com.jetbrains.python.psi.stubs.PyTargetExpressionStub;
|
import com.jetbrains.python.psi.stubs.PyTargetExpressionStub;
|
||||||
import com.jetbrains.python.psi.types.*;
|
import com.jetbrains.python.psi.types.*;
|
||||||
import com.jetbrains.python.sdk.PythonSdkUtil;
|
import com.jetbrains.python.sdk.PythonSdkUtil;
|
||||||
|
import one.util.streamex.StreamEx;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.jetbrains.annotations.Nullable;
|
import org.jetbrains.annotations.Nullable;
|
||||||
|
|
||||||
import javax.swing.*;
|
import javax.swing.*;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
import java.util.stream.Stream;
|
||||||
|
|
||||||
import static com.intellij.openapi.util.text.StringUtil.notNullize;
|
import static com.intellij.openapi.util.text.StringUtil.notNullize;
|
||||||
|
import static com.intellij.util.containers.ContainerUtil.*;
|
||||||
import static com.jetbrains.python.ast.PyAstFunction.Modifier.CLASSMETHOD;
|
import static com.jetbrains.python.ast.PyAstFunction.Modifier.CLASSMETHOD;
|
||||||
import static com.jetbrains.python.ast.PyAstFunction.Modifier.STATICMETHOD;
|
import static com.jetbrains.python.ast.PyAstFunction.Modifier.STATICMETHOD;
|
||||||
import static com.jetbrains.python.psi.PyUtil.as;
|
import static com.jetbrains.python.psi.PyUtil.as;
|
||||||
@@ -130,7 +132,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
|
|||||||
.filter(PyCallableType.class::isInstance)
|
.filter(PyCallableType.class::isInstance)
|
||||||
.map(PyCallableType.class::cast)
|
.map(PyCallableType.class::cast)
|
||||||
.map(callableType -> callableType.getParameters(context))
|
.map(callableType -> callableType.getParameters(context))
|
||||||
.orElseGet(() -> ContainerUtil.map(getParameterList().getParameters(), PyCallableParameterImpl::psi));
|
.orElseGet(() -> map(getParameterList().getParameters(), PyCallableParameterImpl::psi));
|
||||||
}
|
}
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
@@ -164,12 +166,13 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
|
|||||||
public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) {
|
public @Nullable PyType getInferredReturnType(@NotNull TypeEvalContext context) {
|
||||||
PyType inferredType = null;
|
PyType inferredType = null;
|
||||||
if (context.allowReturnTypes(this)) {
|
if (context.allowReturnTypes(this)) {
|
||||||
final Ref<? extends PyType> yieldTypeRef = getYieldStatementType(context);
|
final PyType returnType = getReturnStatementType(context);
|
||||||
if (yieldTypeRef != null) {
|
final Pair<PyType, PyType> yieldSendTypePair = getYieldExpressionType(context);
|
||||||
inferredType = yieldTypeRef.get();
|
if (yieldSendTypePair != null) {
|
||||||
|
inferredType = PyTypingTypeProvider.wrapInGeneratorType(yieldSendTypePair.first, yieldSendTypePair.second, returnType, this);
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
inferredType = getReturnStatementType(context);
|
inferredType = returnType;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(PyTypingTypeProvider.toAsyncIfNeeded(this, inferredType));
|
return PyTypingTypeProvider.removeNarrowedTypeIfNeeded(PyTypingTypeProvider.toAsyncIfNeeded(this, inferredType));
|
||||||
@@ -189,7 +192,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
|
|||||||
final Map<PyExpression, PyCallableParameter> mappedExplicitParameters = fullMapping.getMappedParameters();
|
final Map<PyExpression, PyCallableParameter> mappedExplicitParameters = fullMapping.getMappedParameters();
|
||||||
|
|
||||||
final Map<PyExpression, PyCallableParameter> allMappedParameters = new LinkedHashMap<>();
|
final Map<PyExpression, PyCallableParameter> allMappedParameters = new LinkedHashMap<>();
|
||||||
final PyCallableParameter firstImplicit = ContainerUtil.getFirstItem(fullMapping.getImplicitParameters());
|
final PyCallableParameter firstImplicit = getFirstItem(fullMapping.getImplicitParameters());
|
||||||
if (receiver != null && firstImplicit != null) {
|
if (receiver != null && firstImplicit != null) {
|
||||||
allMappedParameters.put(receiver, firstImplicit);
|
allMappedParameters.put(receiver, firstImplicit);
|
||||||
}
|
}
|
||||||
@@ -282,7 +285,7 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
|
|||||||
else if (allowCoroutineOrGenerator &&
|
else if (allowCoroutineOrGenerator &&
|
||||||
returnType instanceof PyCollectionType &&
|
returnType instanceof PyCollectionType &&
|
||||||
PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType) != null) {
|
PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType) != null) {
|
||||||
final List<PyType> replacedElementTypes = ContainerUtil.map(
|
final List<PyType> replacedElementTypes = map(
|
||||||
((PyCollectionType)returnType).getElementTypes(),
|
((PyCollectionType)returnType).getElementTypes(),
|
||||||
type -> replaceSelf(type, receiver, context, false)
|
type -> replaceSelf(type, receiver, context, false)
|
||||||
);
|
);
|
||||||
@@ -308,42 +311,41 @@ public class PyFunctionImpl extends PyBaseElementImpl<PyFunctionStub> implements
|
|||||||
}
|
}
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public static class YieldCollector extends PyRecursiveElementVisitor {
|
||||||
|
public List<PyYieldExpression> getYieldExpressions() {
|
||||||
|
return myYieldExpressions;
|
||||||
|
}
|
||||||
|
|
||||||
private @Nullable Ref<? extends PyType> getYieldStatementType(final @NotNull TypeEvalContext context) {
|
final private List<PyYieldExpression> myYieldExpressions = new ArrayList<>();
|
||||||
final PyBuiltinCache cache = PyBuiltinCache.getInstance(this);
|
|
||||||
|
@Override
|
||||||
|
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
|
||||||
|
myYieldExpressions.add(node);
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visitPyFunction(@NotNull PyFunction node) {
|
||||||
|
// Ignore nested functions
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public void visitPyLambdaExpression(@NotNull PyLambdaExpression node) {
|
||||||
|
// Ignore nested lambdas
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @return pair of YieldType and SendType. Null when there are no yields
|
||||||
|
*/
|
||||||
|
private @Nullable Pair<PyType, PyType> getYieldExpressionType(final @NotNull TypeEvalContext context) {
|
||||||
final PyStatementList statements = getStatementList();
|
final PyStatementList statements = getStatementList();
|
||||||
final Set<PyType> types = new LinkedHashSet<>();
|
final YieldCollector visitor = new YieldCollector();
|
||||||
statements.accept(new PyRecursiveElementVisitor() {
|
statements.accept(visitor);
|
||||||
@Override
|
final List<PyType> yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context));
|
||||||
public void visitPyYieldExpression(@NotNull PyYieldExpression node) {
|
final List<PyType> sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context));
|
||||||
final PyExpression expr = node.getExpression();
|
if (!yieldTypes.isEmpty()) {
|
||||||
final PyType type = expr != null ? context.getType(expr) : null;
|
return Pair.create(PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes));
|
||||||
|
|
||||||
if (node.isDelegating()) {
|
|
||||||
if (type instanceof PyCollectionType) {
|
|
||||||
types.add(((PyCollectionType)type).getIteratedItemType());
|
|
||||||
}
|
|
||||||
else if (ArrayUtil.contains(type, cache.getListType(), cache.getDictType(), cache.getSetType(), cache.getTupleType())) {
|
|
||||||
types.add(null);
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
types.add(type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
else {
|
|
||||||
types.add(type);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
@Override
|
|
||||||
public void visitPyFunction(@NotNull PyFunction node) {
|
|
||||||
// Ignore nested functions
|
|
||||||
}
|
|
||||||
});
|
|
||||||
if (!types.isEmpty()) {
|
|
||||||
final PyType elementType = PyUnionType.union(types);
|
|
||||||
final PyType returnType = getReturnStatementType(context);
|
|
||||||
return Ref.create(PyTypingTypeProvider.wrapInGeneratorType(elementType, returnType, this));
|
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -42,7 +42,7 @@ public class PyGeneratorExpressionImpl extends PyComprehensionElementImpl implem
|
|||||||
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
|
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
|
||||||
final PyExpression resultExpr = getResultExpression();
|
final PyExpression resultExpr = getResultExpression();
|
||||||
if (resultExpr != null) {
|
if (resultExpr != null) {
|
||||||
return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), PyNoneType.INSTANCE, this);
|
return PyTypingTypeProvider.wrapInGeneratorType(context.getType(resultExpr), null, PyNoneType.INSTANCE, this);
|
||||||
}
|
}
|
||||||
return null;
|
return null;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,18 +4,19 @@ package com.jetbrains.python.psi.impl;
|
|||||||
import com.intellij.lang.ASTNode;
|
import com.intellij.lang.ASTNode;
|
||||||
import com.intellij.util.containers.ContainerUtil;
|
import com.intellij.util.containers.ContainerUtil;
|
||||||
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
|
import com.jetbrains.python.codeInsight.controlflow.ControlFlowCache;
|
||||||
import com.jetbrains.python.psi.PyCallSiteExpression;
|
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
|
||||||
import com.jetbrains.python.psi.PyElementVisitor;
|
import com.jetbrains.python.psi.*;
|
||||||
import com.jetbrains.python.psi.PyExpression;
|
|
||||||
import com.jetbrains.python.psi.PyLambdaExpression;
|
|
||||||
import com.jetbrains.python.psi.types.*;
|
import com.jetbrains.python.psi.types.*;
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
import org.jetbrains.annotations.Nullable;
|
import org.jetbrains.annotations.Nullable;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Optional;
|
import java.util.Optional;
|
||||||
|
|
||||||
|
import static com.intellij.util.containers.ContainerUtil.map;
|
||||||
|
|
||||||
|
|
||||||
public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExpression {
|
public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExpression {
|
||||||
public PyLambdaExpressionImpl(ASTNode astNode) {
|
public PyLambdaExpressionImpl(ASTNode astNode) {
|
||||||
@@ -56,7 +57,19 @@ public class PyLambdaExpressionImpl extends PyElementImpl implements PyLambdaExp
|
|||||||
@Override
|
@Override
|
||||||
public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
|
public PyType getReturnType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
|
||||||
final PyExpression body = getBody();
|
final PyExpression body = getBody();
|
||||||
return body != null ? context.getType(body) : null;
|
if (body == null) return null;
|
||||||
|
|
||||||
|
final PyFunctionImpl.YieldCollector visitor = new PyFunctionImpl.YieldCollector();
|
||||||
|
body.accept(visitor);
|
||||||
|
|
||||||
|
final List<PyType> yieldTypes = map(visitor.getYieldExpressions(), it -> it.getYieldType(context));
|
||||||
|
final List<PyType> sendTypes = map(visitor.getYieldExpressions(), it -> it.getSendType(context));
|
||||||
|
|
||||||
|
if (!yieldTypes.isEmpty()) {
|
||||||
|
return PyTypingTypeProvider.wrapInGeneratorType(
|
||||||
|
PyUnionType.union(yieldTypes), PyUnionType.union(sendTypes), context.getType(body), this);
|
||||||
|
}
|
||||||
|
return context.getType(body);
|
||||||
}
|
}
|
||||||
|
|
||||||
@Nullable
|
@Nullable
|
||||||
|
|||||||
@@ -127,16 +127,7 @@ public class PyTargetExpressionImpl extends PyBaseElementImpl<PyTargetExpression
|
|||||||
}
|
}
|
||||||
final PsiElement parent = PsiTreeUtil.skipParentsOfType(this, PyParenthesizedExpression.class);
|
final PsiElement parent = PsiTreeUtil.skipParentsOfType(this, PyParenthesizedExpression.class);
|
||||||
if (parent instanceof PyAssignmentStatement assignmentStatement) {
|
if (parent instanceof PyAssignmentStatement assignmentStatement) {
|
||||||
PyExpression assignedValue = assignmentStatement.getAssignedValue();
|
return context.getType(assignmentStatement.getAssignedValue());
|
||||||
if (assignedValue instanceof PyParenthesizedExpression) {
|
|
||||||
assignedValue = ((PyParenthesizedExpression)assignedValue).getContainedExpression();
|
|
||||||
}
|
|
||||||
if (assignedValue != null) {
|
|
||||||
if (assignedValue instanceof PyYieldExpression assignedYield) {
|
|
||||||
return assignedYield.isDelegating() ? context.getType(assignedValue) : null;
|
|
||||||
}
|
|
||||||
return context.getType(assignedValue);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
if (parent instanceof PyTupleExpression || parent instanceof PyListLiteralExpression) {
|
if (parent instanceof PyTupleExpression || parent instanceof PyListLiteralExpression) {
|
||||||
PsiElement nextParent =
|
PsiElement nextParent =
|
||||||
|
|||||||
@@ -2,15 +2,16 @@
|
|||||||
package com.jetbrains.python.psi.impl;
|
package com.jetbrains.python.psi.impl;
|
||||||
|
|
||||||
import com.intellij.lang.ASTNode;
|
import com.intellij.lang.ASTNode;
|
||||||
import com.intellij.openapi.util.Ref;
|
import com.intellij.util.ArrayUtil;
|
||||||
|
import com.jetbrains.python.codeInsight.dataflow.scope.ScopeUtil;
|
||||||
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
|
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
|
||||||
import com.jetbrains.python.psi.PyElementVisitor;
|
import com.jetbrains.python.psi.PyElementVisitor;
|
||||||
import com.jetbrains.python.psi.PyExpression;
|
import com.jetbrains.python.psi.PyExpression;
|
||||||
|
import com.jetbrains.python.psi.PyFunction;
|
||||||
import com.jetbrains.python.psi.PyYieldExpression;
|
import com.jetbrains.python.psi.PyYieldExpression;
|
||||||
import com.jetbrains.python.psi.types.PyNoneType;
|
import com.jetbrains.python.psi.types.*;
|
||||||
import com.jetbrains.python.psi.types.PyType;
|
|
||||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
|
||||||
import org.jetbrains.annotations.NotNull;
|
import org.jetbrains.annotations.NotNull;
|
||||||
|
import org.jetbrains.annotations.Nullable;
|
||||||
|
|
||||||
|
|
||||||
public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpression {
|
public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpression {
|
||||||
@@ -25,12 +26,52 @@ public class PyYieldExpressionImpl extends PyElementImpl implements PyYieldExpre
|
|||||||
|
|
||||||
@Override
|
@Override
|
||||||
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
|
public PyType getType(@NotNull TypeEvalContext context, @NotNull TypeEvalContext.Key key) {
|
||||||
final PyExpression e = getExpression();
|
|
||||||
final PyType type = e != null ? context.getType(e) : null;
|
|
||||||
if (isDelegating()) {
|
if (isDelegating()) {
|
||||||
final Ref<PyType> generatorElementType = PyTypingTypeProvider.coroutineOrGeneratorElementType(type);
|
final PyExpression e = getExpression();
|
||||||
return generatorElementType == null ? PyNoneType.INSTANCE : generatorElementType.get();
|
final PyType type = e != null ? context.getType(e) : null;
|
||||||
|
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type);
|
||||||
|
if (generatorDesc != null) {
|
||||||
|
return generatorDesc.returnType();
|
||||||
|
}
|
||||||
|
return PyNoneType.INSTANCE;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
return getSendType(context);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @Nullable PyType getYieldType(@NotNull TypeEvalContext context) {
|
||||||
|
final PyExpression expr = getExpression();
|
||||||
|
final PyType type = expr != null ? context.getType(expr) : PyNoneType.INSTANCE;
|
||||||
|
|
||||||
|
if (isDelegating()) {
|
||||||
|
return PyTargetExpressionImpl.getIterationType(type, expr, this, context);
|
||||||
}
|
}
|
||||||
return type;
|
return type;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public @Nullable PyType getSendType(@NotNull TypeEvalContext context) {
|
||||||
|
if (ScopeUtil.getScopeOwner(this) instanceof PyFunction function) {
|
||||||
|
if (function.getAnnotation() != null || function.getTypeCommentAnnotation() != null) {
|
||||||
|
var returnType = context.getReturnType(function);
|
||||||
|
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(returnType);
|
||||||
|
if (generatorDesc != null) {
|
||||||
|
return generatorDesc.sendType();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isDelegating()) {
|
||||||
|
final PyExpression e = getExpression();
|
||||||
|
final PyType type = e != null ? context.getType(e) : null;
|
||||||
|
var generatorDesc = PyTypingTypeProvider.GeneratorTypeDescriptor.create(type);
|
||||||
|
if (generatorDesc != null) {
|
||||||
|
return generatorDesc.sendType();
|
||||||
|
}
|
||||||
|
return PyNoneType.INSTANCE;
|
||||||
|
}
|
||||||
|
return null;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -50,4 +50,8 @@ def l(x) -> <warning descr="Expected type 'int', got 'int | None' instead">int</
|
|||||||
def m(x) -> None:
|
def m(x) -> None:
|
||||||
"""Does not display warning about implicit return, because annotated '-> None' """
|
"""Does not display warning about implicit return, because annotated '-> None' """
|
||||||
if x:
|
if x:
|
||||||
return
|
return
|
||||||
|
|
||||||
|
def n() -> Generator[int, Any, str]:
|
||||||
|
yield 13
|
||||||
|
return 42
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
from typing import Generator, Iterable, Iterator, AsyncIterable, AsyncIterator, AsyncGenerator
|
||||||
|
|
||||||
|
# Fix incorrect YieldType
|
||||||
|
def a() -> Iterable[str]:
|
||||||
|
yield <warning descr="Expected yield type 'str', got 'int' instead">42</warning>
|
||||||
|
|
||||||
|
def b() -> Iterator[str]:
|
||||||
|
yield <warning descr="Expected yield type 'str', got 'int' instead">42</warning>
|
||||||
|
|
||||||
|
def c() -> Generator[str, Any, int]:
|
||||||
|
yield <warning descr="Expected yield type 'str', got 'int' instead">13</warning>
|
||||||
|
return 42
|
||||||
|
|
||||||
|
def c() -> Generator[int, Any, str]:
|
||||||
|
yield 13
|
||||||
|
return <warning descr="Expected type 'str', got 'int' instead">42</warning>
|
||||||
|
|
||||||
|
# Suggest AsyncGenerator
|
||||||
|
async def d() -> <warning descr="Expected type 'AsyncGenerator[int, None]', got 'Iterable[int]' instead">Iterable[int]</warning>:
|
||||||
|
yield 42
|
||||||
|
|
||||||
|
async def e() -> <warning descr="Expected type 'AsyncGenerator[int, None]', got 'Iterator[int]' instead">Iterator[int]</warning>:
|
||||||
|
yield 42
|
||||||
|
|
||||||
|
async def f() -> <warning descr="Expected type 'AsyncGenerator[int, str]', got 'Generator[int, str, None]' instead">Generator[int, str, None]</warning>:
|
||||||
|
yield 13
|
||||||
|
|
||||||
|
# Suggest sync Generator
|
||||||
|
def g() -> <warning descr="Expected type 'Generator[int, None, None]', got 'AsyncIterable[int]' instead">AsyncIterable[int]</warning>:
|
||||||
|
yield 42
|
||||||
|
|
||||||
|
def h() -> <warning descr="Expected type 'Generator[int, None, None]', got 'AsyncIterator[int]' instead">AsyncIterator[int]</warning>:
|
||||||
|
yield 42
|
||||||
|
|
||||||
|
def i() -> <warning descr="Expected type 'Generator[int, str, None]', got 'AsyncGenerator[int, str]' instead">AsyncGenerator[int, str]</warning>:
|
||||||
|
yield 13
|
||||||
|
|
||||||
|
def j() -> Iterator[int]:
|
||||||
|
yield from j()
|
||||||
|
|
||||||
|
def k() -> Iterator[str]:
|
||||||
|
yield from <warning descr="Expected yield type 'str', got 'int' instead">j()</warning>
|
||||||
|
yield from <warning descr="Expected yield type 'str', got 'int' instead">[1]</warning>
|
||||||
|
|
||||||
|
def l() -> Generator[None, int, None]:
|
||||||
|
x: float = yield
|
||||||
|
|
||||||
|
def m() -> Generator[None, float, None]:
|
||||||
|
yield from <warning descr="Expected send type 'float', got 'int' instead">l()</warning>
|
||||||
|
|
||||||
|
def n() -> Generator[None, float, None]:
|
||||||
|
x: float = yield
|
||||||
|
|
||||||
|
def o() -> Generator[None, int, None]:
|
||||||
|
yield from n()
|
||||||
|
|
||||||
|
def p() -> Generator[int, None, None]:
|
||||||
|
yield from [1, 2]
|
||||||
|
yield from [3, 4]
|
||||||
|
|
||||||
|
def q() -> int:
|
||||||
|
x = lambda: (yield "str")
|
||||||
|
return 42
|
||||||
|
|
||||||
|
async def r() -> AsyncGenerator[int]:
|
||||||
|
yield 42
|
||||||
|
|
||||||
|
def s() -> Generator[int]:
|
||||||
|
yield from <warning descr="Cannot yield from 'AsyncGenerator[int, None]', use async for instead">r()</warning>
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
def gen() -> Generator[int, bool, str]:
|
||||||
|
b: bool = yield <warning descr="Expected yield type 'int', got 'str' instead"><caret>"str"</warning>
|
||||||
|
return <warning descr="Expected type 'str', got 'int' instead">42</warning>
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
def gen() -> Generator[str, bool, int]:
|
||||||
|
b: bool = yield "str"
|
||||||
|
return 42
|
||||||
@@ -0,0 +1,4 @@
|
|||||||
|
async def gen() -> str:
|
||||||
|
b: bool = <warning descr="Expected type 'str', got 'AsyncGenerator[str | float, Any]' instead"><caret>yield "str"</warning>
|
||||||
|
if b:
|
||||||
|
b = <warning descr="Expected type 'str', got 'AsyncGenerator[str | float, Any]' instead">yield 3.14</warning>
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
|
||||||
|
async def gen() -> AsyncGenerator[str | float, Any]:
|
||||||
|
b: bool = <caret>yield "str"
|
||||||
|
if b:
|
||||||
|
b = yield 3.14
|
||||||
@@ -0,0 +1,5 @@
|
|||||||
|
def f(x) -> <warning descr="Function returning 'int | None' has implicit return">int | None<caret></warning>:
|
||||||
|
if x == 1:
|
||||||
|
return 42
|
||||||
|
elif x == 2:
|
||||||
|
<warning descr="Function returning 'int | None' has implicit return">return</warning>
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
def f(x) -> int | None:
|
||||||
|
if x == 1:
|
||||||
|
return 42
|
||||||
|
elif x == 2:
|
||||||
|
return None
|
||||||
|
return None
|
||||||
@@ -17,6 +17,66 @@ import java.util.Map;
|
|||||||
public class Py3TypeTest extends PyTestCase {
|
public class Py3TypeTest extends PyTestCase {
|
||||||
public static final String TEST_DIRECTORY = "/types/";
|
public static final String TEST_DIRECTORY = "/types/";
|
||||||
|
|
||||||
|
// PY-20710
|
||||||
|
public void testLambdaGenerator() {
|
||||||
|
doTest("Generator[int, Any, Any]", """
|
||||||
|
expr = (lambda: (yield 1))()
|
||||||
|
""");
|
||||||
|
}
|
||||||
|
|
||||||
|
// PY-20710
|
||||||
|
public void testGeneratorDelegatingToLambdaGenerator() {
|
||||||
|
doTest("Generator[int, Any, str]", """
|
||||||
|
def g():
|
||||||
|
yield from (lambda: (yield 1))()
|
||||||
|
return "foo"
|
||||||
|
expr = g()
|
||||||
|
""");
|
||||||
|
}
|
||||||
|
|
||||||
|
// PY-20710
|
||||||
|
public void testYieldExpressionTypeFromGeneratorSendTypeHint() {
|
||||||
|
doTest("int", """
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
def g() -> Generator[str, int, None]:
|
||||||
|
expr = yield "foo"
|
||||||
|
""");
|
||||||
|
}
|
||||||
|
|
||||||
|
// PY-20710
|
||||||
|
public void testYieldFromExpressionTypeFromGeneratorReturnTypeHint() {
|
||||||
|
doTest("int", """
|
||||||
|
from typing import Generator, Any
|
||||||
|
|
||||||
|
def delegate() -> Generator[None, Any, int]:
|
||||||
|
yield
|
||||||
|
return 42
|
||||||
|
|
||||||
|
def g():
|
||||||
|
expr = yield from delegate()
|
||||||
|
""");
|
||||||
|
}
|
||||||
|
|
||||||
|
// PY-20710
|
||||||
|
public void testYieldFromLambda() {
|
||||||
|
doTest("Generator[int | str, str | Any, bool]",
|
||||||
|
"""
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
def gen1() -> Generator[int, str, bool]:
|
||||||
|
yield 42
|
||||||
|
return True
|
||||||
|
|
||||||
|
def gen2():
|
||||||
|
yield "str"
|
||||||
|
return True
|
||||||
|
|
||||||
|
l = lambda: (yield from gen1()) or (yield from gen2())
|
||||||
|
expr = l()
|
||||||
|
""");
|
||||||
|
}
|
||||||
|
|
||||||
// PY-6702
|
// PY-6702
|
||||||
public void testYieldFromType() {
|
public void testYieldFromType() {
|
||||||
doTest("str | int | float",
|
doTest("str | int | float",
|
||||||
|
|||||||
@@ -102,6 +102,10 @@ public class Py3TypeCheckerInspectionTest extends PyInspectionTestCase {
|
|||||||
public void testFunctionReturnTypePy3() {
|
public void testFunctionReturnTypePy3() {
|
||||||
doTest();
|
doTest();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void testFunctionYieldTypePy3() {
|
||||||
|
doTest();
|
||||||
|
}
|
||||||
|
|
||||||
// PY-20770
|
// PY-20770
|
||||||
public void testAsyncForOverAsyncGenerator() {
|
public void testAsyncForOverAsyncGenerator() {
|
||||||
|
|||||||
@@ -64,4 +64,18 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase {
|
|||||||
PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"),
|
PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"),
|
||||||
LanguageLevel.getLatest());
|
LanguageLevel.getLatest());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// PY-20710
|
||||||
|
public void testChangeGenerator() {
|
||||||
|
doQuickFixTest(PyTypeCheckerInspection.class,
|
||||||
|
PyPsiBundle.message("QFIX.make.function.return.type", "gen", "Generator[str, bool, int]"),
|
||||||
|
LanguageLevel.getLatest());
|
||||||
|
}
|
||||||
|
|
||||||
|
// PY-20710
|
||||||
|
public void testMakeGenerator() {
|
||||||
|
doQuickFixTest(PyTypeCheckerInspection.class,
|
||||||
|
PyPsiBundle.message("QFIX.make.function.return.type", "gen", "AsyncGenerator[str | float, Any]"),
|
||||||
|
LanguageLevel.getLatest());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user