PY-55548 Use actual return type for "Specify return type using annotation"

For async functions, unwrap return type from Awaitable or Coroutine


Merge-request: IJ-MR-146295
Merged-by: Aleksandr Govenko <aleksandr.govenko@jetbrains.com>

(cherry picked from commit 9fe8d02a9d8bb584b9d6972ce999912bd93875e6)

IJ-MR-146295

GitOrigin-RevId: 9bad4877a069268a2d0181cac70b9a0d399cb5e6
This commit is contained in:
Aleksandr.Govenko
2024-10-28 15:41:56 +00:00
committed by intellij-monorepo-bot
parent 3bf01bc0c5
commit d5f9bf8de0
24 changed files with 91 additions and 25 deletions

View File

@@ -114,7 +114,7 @@ public final class PyAnnotateTypesIntention extends PyBaseIntentionAction {
replacementTextBuilder.append(") -> "); replacementTextBuilder.append(") -> ");
String returnType = SpecifyTypeInPy3AnnotationsIntention.returnType(function); String returnType = SpecifyTypeInPy3AnnotationsIntention.returnType(function).getAnnotationText();
templates.add(Pair.create(replacementTextBuilder.length(), returnType)); templates.add(Pair.create(replacementTextBuilder.length(), returnType));
replacementTextBuilder.append(returnType); replacementTextBuilder.append(returnType);

View File

@@ -22,6 +22,7 @@ import com.intellij.codeInsight.template.TemplateBuilderFactory;
import com.intellij.openapi.application.WriteAction; import com.intellij.openapi.application.WriteAction;
import com.intellij.openapi.editor.Editor; import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.project.Project; import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.Ref;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.psi.PsiElement; import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile; import com.intellij.psi.PsiFile;
@@ -32,11 +33,16 @@ import com.jetbrains.python.PyNames;
import com.jetbrains.python.PyPsiBundle; import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.PyTokenTypes; import com.jetbrains.python.PyTokenTypes;
import com.jetbrains.python.PythonUiService; import com.jetbrains.python.PythonUiService;
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil.AnnotationInfo;
import com.jetbrains.python.codeInsight.typing.PyTypingTypeProvider;
import com.jetbrains.python.debugger.PySignature; import com.jetbrains.python.debugger.PySignature;
import com.jetbrains.python.debugger.PySignatureCacheManager; import com.jetbrains.python.debugger.PySignatureCacheManager;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.psi.*; import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.impl.ParamHelper; import com.jetbrains.python.psi.impl.ParamHelper;
import com.jetbrains.python.psi.impl.PyPsiUtils; import com.jetbrains.python.psi.impl.PyPsiUtils;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
/** /**
@@ -144,24 +150,34 @@ public final class SpecifyTypeInPy3AnnotationsIntention extends TypeIntention {
} }
static String returnType(@NotNull PyFunction function) { static @NotNull AnnotationInfo returnType(@NotNull PyFunction function) {
if (function.getAnnotation() != null && function.getAnnotation().getValue() != null) { if (function.getAnnotation() != null && function.getAnnotation().getValue() != null) {
return function.getAnnotation().getValue().getText(); return new AnnotationInfo(function.getAnnotation().getValue().getText());
} }
final PySignature signature = PySignatureCacheManager.getInstance(function.getProject()).findSignature(function); final PySignature signature = PySignatureCacheManager.getInstance(function.getProject()).findSignature(function);
if (signature != null) { if (signature != null) {
final String qualifiedName = signature.getReturnTypeQualifiedName(); final String qualifiedName = signature.getReturnTypeQualifiedName();
if (qualifiedName != null) return qualifiedName; if (qualifiedName != null) return new AnnotationInfo(qualifiedName);
} }
return PyNames.OBJECT; final TypeEvalContext context = TypeEvalContext.userInitiated(function.getProject(), function.getContainingFile());
PyType inferredType = context.getReturnType(function);
if (function.isAsync()) {
inferredType = Ref.deref(PyTypingTypeProvider.unwrapCoroutineReturnType(inferredType));
}
return new AnnotationInfo(PythonDocumentationProvider.getTypeHint(inferredType, context), inferredType);
} }
public static PyExpression annotateReturnType(Project project, PyFunction function, boolean createTemplate) { public static PyExpression annotateReturnType(Project project, PyFunction function, boolean createTemplate) {
String returnType = returnType(function); AnnotationInfo returnTypeAnnotation = returnType(function);
final String annotationText = "-> " + returnType; final String returnTypeText = returnTypeAnnotation.getAnnotationText();
final String annotationText = "-> " + returnTypeText;
final PsiFile file = function.getContainingFile();
final TypeEvalContext context = TypeEvalContext.userInitiated(project, file);
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(returnTypeAnnotation.getTypes(), context, file);
PyFunction annotatedFunction = PyUtil.updateDocumentUnblockedAndCommitted(function, document -> { PyFunction annotatedFunction = PyUtil.updateDocumentUnblockedAndCommitted(function, document -> {
final PyAnnotation oldAnnotation = function.getAnnotation(); final PyAnnotation oldAnnotation = function.getAnnotation();
@@ -196,7 +212,7 @@ public final class SpecifyTypeInPy3AnnotationsIntention extends TypeIntention {
final int offset = annotationValue.getTextOffset(); final int offset = annotationValue.getTextOffset();
final TemplateBuilder builder = TemplateBuilderFactory.getInstance().createTemplateBuilder(annotationValue); final TemplateBuilder builder = TemplateBuilderFactory.getInstance().createTemplateBuilder(annotationValue);
builder.replaceRange(TextRange.create(0, returnType.length()), returnType); builder.replaceRange(TextRange.create(0, returnTypeText.length()), returnTypeText);
final Editor targetEditor = PythonUiService.getInstance().openTextEditor(project, annotatedFunction.getContainingFile().getVirtualFile(), offset); final Editor targetEditor = PythonUiService.getInstance().openTextEditor(project, annotatedFunction.getContainingFile().getVirtualFile(), offset);
if (targetEditor != null) { if (targetEditor != null) {
builder.run(targetEditor, true); builder.run(targetEditor, true);

View File

@@ -58,6 +58,7 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
public static final String GENERATOR = "typing.Generator"; public static final String GENERATOR = "typing.Generator";
public static final String ASYNC_GENERATOR = "typing.AsyncGenerator"; public static final String ASYNC_GENERATOR = "typing.AsyncGenerator";
public static final String COROUTINE = "typing.Coroutine"; public static final String COROUTINE = "typing.Coroutine";
public static final String AWAITABLE = "typing.Awaitable";
public static final String NAMEDTUPLE = "typing.NamedTuple"; public static final String NAMEDTUPLE = "typing.NamedTuple";
public static final String TYPED_DICT = "typing.TypedDict"; public static final String TYPED_DICT = "typing.TypedDict";
public static final String TYPED_DICT_EXT = "typing_extensions.TypedDict"; public static final String TYPED_DICT_EXT = "typing_extensions.TypedDict";
@@ -2199,6 +2200,25 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null; return asyncGenerator != null ? new PyCollectionTypeImpl(asyncGenerator, false, Arrays.asList(elementType, null)) : null;
} }
@Nullable
public static Ref<PyType> unwrapCoroutineReturnType(@Nullable PyType coroutineType) {
final PyCollectionType genericType = as(coroutineType, PyCollectionType.class);
if (genericType != null) {
var qName = genericType.getClassQName();
if (AWAITABLE.equals(qName)) {
return Ref.create(ContainerUtil.getOrElse(genericType.getElementTypes(), 0, null));
}
if (COROUTINE.equals(qName)) {
return Ref.create(ContainerUtil.getOrElse(genericType.getElementTypes(), 2, null));
}
}
return null;
}
@Nullable @Nullable
public static Ref<PyType> coroutineOrGeneratorElementType(@Nullable PyType coroutineOrGeneratorType) { public static Ref<PyType> coroutineOrGeneratorElementType(@Nullable PyType coroutineOrGeneratorType) {
final PyCollectionType genericType = as(coroutineOrGeneratorType, PyCollectionType.class); final PyCollectionType genericType = as(coroutineOrGeneratorType, PyCollectionType.class);
@@ -2207,7 +2227,7 @@ public final class PyTypingTypeProvider extends PyTypeProviderWithCustomContext<
if (genericType != null && classType != null) { if (genericType != null && classType != null) {
var qName = classType.getClassQName(); var qName = classType.getClassQName();
if ("typing.Awaitable".equals(qName)) { if (AWAITABLE.equals(qName)) {
return Ref.create(ContainerUtil.getOrElse(genericType.getElementTypes(), 0, null)); return Ref.create(ContainerUtil.getOrElse(genericType.getElementTypes(), 0, null));
} }

View File

@@ -1,2 +1,2 @@
def foo(x: object, y: object) -> object: def foo(x: object, y: object) -> None:
pass pass

View File

@@ -1,4 +1,4 @@
def foo(x: object, y: object) -> object: def foo(x: object, y: object) -> None:
pass pass

View File

@@ -1,2 +1,2 @@
def foo(x: object, y: object) -> object: def foo(x: object, y: object) -> None:
pass pass

View File

@@ -1,2 +1,2 @@
def foo(**x: object<caret>) -> object: def foo(**x: object<caret>) -> None:
pass pass

View File

@@ -1,2 +1,2 @@
def foo(x: bool<caret>, y: bool) -> object: def foo(x: bool<caret>, y: bool) -> str:
return "42" return "42"

View File

@@ -1,2 +1,2 @@
def foo(*x: object<caret>, **y: object) -> object: def foo(*x: object<caret>, **y: object) -> None:
pass pass

View File

@@ -3,7 +3,7 @@ class MyClass:
pass pass
def method(self, x): def method(self, x):
# type: (object) -> object # type: (object) -> None
pass pass

View File

@@ -1,3 +1,3 @@
def foo(x, y): def foo(x, y):
# type: (object, object) -> object # type: (object, object) -> None
pass pass

View File

@@ -0,0 +1,5 @@
async def bar() -> int:
return 42
def fo<caret>o(x, y):
return bar()

View File

@@ -0,0 +1,8 @@
from typing import Any, Coroutine
async def bar() -> int:
return 42
def foo(x, y) -> Coroutine[Any, Any, int]:
return bar()

View File

@@ -1,2 +1,2 @@
def foo(x, y) -> object: def foo(x, y) -> None:
pass pass

View File

@@ -1,4 +1,4 @@
def foo(x, y) -> object: def foo(x, y) -> None:
pass pass

View File

@@ -1,2 +1,2 @@
def foo(x, y) -> object: def foo(x, y) -> None:
pass pass

View File

@@ -0,0 +1,2 @@
async def fo<caret>o(x, y):
return 42

View File

@@ -0,0 +1,2 @@
async def foo(x, y) -> int:
return 42

View File

@@ -1,4 +1,4 @@
def my_func(p1=1) -> object: def my_func(p1=1) -> int:
return p1 return p1
d = my_func(1) d = my_func(1)

View File

@@ -1,4 +1,4 @@
def my_func(p1=1) -> object: def my_func(p1=1) -> int:
return p1 return p1
d = my_func(1) d = my_func(1)

View File

@@ -1,4 +1,4 @@
def foo() -> object: def foo() -> None:
@decorator @decorator
def bar(): def bar():
pass pass

View File

@@ -1,4 +1,4 @@
def foo() -> object: def foo() -> None:
@decorator @decorator
def bar(): def bar():
pass pass

View File

@@ -1,4 +1,7 @@
def g(x) -> object: from typing import Any
def g(x) -> Any:
return x return x

View File

@@ -22,6 +22,16 @@ public class SpecifyTypeInPy3AnnotationsIntentionTest extends PyIntentionTestCas
doTestParam(); doTestParam();
} }
// PY-55548
public void testUnwrapsTypesInAsyncFunctions() {
doTestReturnType();
}
// PY-55548
public void testAddsImportsWhenNeeded() {
doTestReturnType();
}
// PY-31369 // PY-31369
public void testAnnotatedParameterNoIntention() { public void testAnnotatedParameterNoIntention() {
doNegativeTest(PyPsiBundle.message("INTN.specify.type.in.annotation")); doNegativeTest(PyPsiBundle.message("INTN.specify.type.in.annotation"));