PY-55548 Use actual return type for "Specify return type using annotation"

For async functions, unwrap return type from Awaitable or Coroutine


Merge-request: IJ-MR-146295
Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com>

(cherry picked from commit 9fe8d02a9d8bb584b9d6972ce999912bd93875e6)

IJ-MR-146295

GitOrigin-RevId: 9bad4877a069268a2d0181cac70b9a0d399cb5e6
This commit is contained in:
Aleksandr.Govenko
2024-10-28 15:41:56 +00:00
committed by intellij-monorepo-bot
parent 3bf01bc0c5
commit d5f9bf8de0
24 changed files with 91 additions and 25 deletions

View File

@@ -114,7 +114,7 @@ public final class PyAnnotateTypesIntention extends PyBaseIntentionAction {
replacementTextBuilder.append(") -> ");
String returnType = SpecifyTypeInPy3AnnotationsIntention.returnType(function);
String returnType = SpecifyTypeInPy3AnnotationsIntention.returnType(function).getAnnotationText();
templates.add(Pair.create(replacementTextBuilder.length(), returnType));
replacementTextBuilder.append(returnType);

View File

@@ -22,6 +22,7 @@ import com.intellij.codeInsight.template.TemplateBuilderFactory;
import com.intellij.openapi.application.WriteAction;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.Ref;
import com.intellij.openapi.util.TextRange;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
@@ -32,11 +33,16 @@ import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.PythonUiService;
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil.AnnotationInfo;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.debugger.PySignature;
import com.jetbrains.python.debugger.PySignatureCacheManager;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.ParamHelper;
import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
/**
@@ -144,24 +150,34 @@ public final class SpecifyTypeInPy3AnnotationsIntention extends TypeIntention {
}
static String returnType(@NotNull PyFunction function) {
static @NotNull AnnotationInfo returnType(@NotNull PyFunction function) {
if (function.getAnnotation() != null && function.getAnnotation().getValue() != null) {
return function.getAnnotation().getValue().getText();
return new AnnotationInfo(function.getAnnotation().getValue().getText());
}
final PySignature signature = PySignatureCacheManager.getInstance(function.getProject()).findSignature(function);
if (signature != null) {
final String qualifiedName = signature.getReturnTypeQualifiedName();
if (qualifiedName != null) return qualifiedName;
if (qualifiedName != null) return new AnnotationInfo(qualifiedName);
}
return PyNames.OBJECT;
final TypeEvalContext context = TypeEvalContext.userInitiated(function.getProject(), function.getContainingFile());
PyType inferredType = context.getReturnType(function);
if (function.isAsync()) {
inferredType = Ref.deref(PyTypingTypeProvider.unwrapCoroutineReturnType(inferredType));
}
return new AnnotationInfo(PythonDocumentationProvider.getTypeHint(inferredType, context), inferredType);
}
public static PyExpression annotateReturnType(Project project, PyFunction function, boolean createTemplate) {
String returnType = returnType(function);
AnnotationInfo returnTypeAnnotation = returnType(function);
final String annotationText = "-> " + returnType;
final String returnTypeText = returnTypeAnnotation.getAnnotationText();
final String annotationText = "-> " + returnTypeText;
final PsiFile file = function.getContainingFile();
final TypeEvalContext context = TypeEvalContext.userInitiated(project, file);
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(returnTypeAnnotation.getTypes(), context, file);
PyFunction annotatedFunction = PyUtil.updateDocumentUnblockedAndCommitted(function, document -> {
final PyAnnotation oldAnnotation = function.getAnnotation();
@@ -196,7 +212,7 @@ public final class SpecifyTypeInPy3AnnotationsIntention extends TypeIntention {
final int offset = annotationValue.getTextOffset();
final TemplateBuilder builder = TemplateBuilderFactory.getInstance().createTemplateBuilder(annotationValue);
builder.replaceRange(TextRange.create(0, returnType.length()), returnType);
builder.replaceRange(TextRange.create(0, returnTypeText.length()), returnTypeText);
final Editor targetEditor = PythonUiService.getInstance().openTextEditor(project, annotatedFunction.getContainingFile().getVirtualFile(), offset);
if (targetEditor != null) {
builder.run(targetEditor, true);

View File

@@ -58,6 +58,7 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
public static final String GENERATOR = "typing.Generator";
public static final String ASYNC_GENERATOR = "typing.AsyncGenerator";
public static final String COROUTINE = "typing.Coroutine";
public static final String AWAITABLE = "typing.Awaitable";
public static final String NAMEDTUPLE = "typing.NamedTuple";
public static final String TYPED_DICT = "typing.TypedDict";
public static final String TYPED_DICT_EXT = "typing_extensions.TypedDict";
@@ -2199,6 +2200,25 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null;
}
@Nullable
public static Ref<PyType> unwrapCoroutineReturnType(@Nullable PyType coroutineType) {
final PyCollectionType genericType = as(coroutineType, PyCollectionType.class);
if (genericType != null) {
var qName = genericType.getClassQName();
if (AWAITABLE.equals(qName)) {
return Ref.create(ContainerUtil.getOrElse(genericType.getElementTypes(), 0, null));
}
if (COROUTINE.equals(qName)) {
return Ref.create(ContainerUtil.getOrElse(genericType.getElementTypes(), 2, null));
}
}
return null;
}
@Nullable
public static Ref<PyType> coroutineOrGeneratorElementType(@Nullable PyType coroutineOrGeneratorType) {
final PyCollectionType genericType = as(coroutineOrGeneratorType, PyCollectionType.class);
@@ -2207,7 +2227,7 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
if (genericType != null && classType != null) {
var qName = classType.getClassQName();
if ("typing.Awaitable".equals(qName)) {
if (AWAITABLE.equals(qName)) {
return Ref.create(ContainerUtil.getOrElse(genericType.getElementTypes(), 0, null));
}

View File

@@ -1,2 +1,2 @@
def foo(x: object, y: object) -> object:
def foo(x: object, y: object) -> None:
pass

View File

@@ -1,4 +1,4 @@
def foo(x: object, y: object) -> object:
def foo(x: object, y: object) -> None:
pass

View File

@@ -1,2 +1,2 @@
def foo(x: object, y: object) -> object:
def foo(x: object, y: object) -> None:
pass

View File

@@ -1,2 +1,2 @@
def foo(**x: object<caret>) -> object:
def foo(**x: object<caret>) -> None:
pass

View File

@@ -1,2 +1,2 @@
def foo(x: bool<caret>, y: bool) -> object:
def foo(x: bool<caret>, y: bool) -> str:
return "42"

View File

@@ -1,2 +1,2 @@
def foo(*x: object<caret>, **y: object) -> object:
def foo(*x: object<caret>, **y: object) -> None:
pass

View File

@@ -3,7 +3,7 @@ class MyClass:
pass
def method(self, x):
# type: (object) -> object
# type: (object) -> None
pass

View File

@@ -1,3 +1,3 @@
def foo(x, y):
# type: (object, object) -> object
# type: (object, object) -> None
pass

View File

@@ -0,0 +1,5 @@
async def bar() -> int:
return 42
def fo<caret>o(x, y):
return bar()

View File

@@ -0,0 +1,8 @@
from typing import Any, Coroutine
async def bar() -> int:
return 42
def foo(x, y) -> Coroutine[Any, Any, int]:
return bar()

View File

@@ -1,2 +1,2 @@
def foo(x, y) -> object:
def foo(x, y) -> None:
pass

View File

@@ -1,4 +1,4 @@
def foo(x, y) -> object:
def foo(x, y) -> None:
pass

View File

@@ -1,2 +1,2 @@
def foo(x, y) -> object:
def foo(x, y) -> None:
pass

View File

@@ -0,0 +1,2 @@
async def fo<caret>o(x, y):
return 42

View File

@@ -0,0 +1,2 @@
async def foo(x, y) -> int:
return 42

View File

@@ -1,4 +1,4 @@
def my_func(p1=1) -> object:
def my_func(p1=1) -> int:
return p1
d = my_func(1)

View File

@@ -1,4 +1,4 @@
def my_func(p1=1) -> object:
def my_func(p1=1) -> int:
return p1
d = my_func(1)

View File

@@ -1,4 +1,4 @@
def foo() -> object:
def foo() -> None:
@decorator
def bar():
pass

View File

@@ -1,4 +1,4 @@
def foo() -> object:
def foo() -> None:
@decorator
def bar():
pass

View File

@@ -1,4 +1,7 @@
def g(x) -> object:
from typing import Any
def g(x) -> Any:
return x

View File

@@ -22,6 +22,16 @@ public class SpecifyTypeInPy3AnnotationsIntentionTest extends PyIntentionTestCas
doTestParam();
}
// PY-55548
public void testUnwrapsTypesInAsyncFunctions() {
doTestReturnType();
}
// PY-55548
public void testAddsImportsWhenNeeded() {
doTestReturnType();
}
// PY-31369
public void testAnnotatedParameterNoIntention() {
doNegativeTest(PyPsiBundle.message("INTN.specify.type.in.annotation"));