PY-27128 PY-48466 Fix problem with imports in PyMakeFunctionReturnTypeQuickFix

(cherry picked from commit 9320c2e708afa2ddd025f60b0c122132387493b7)

IJ-MR-10124

GitOrigin-RevId: 219e646482999e8b81b77aaf52eb24574381af3d
This commit is contained in:
andrey.matveev
2021-06-15 12:48:19 +07:00
committed by intellij-monorepo-bot
parent fe812b5f46
commit 903f6e11ca
18 changed files with 236 additions and 27 deletions

View File

@@ -74,20 +74,20 @@ public class PyTypeCheckerInspection extends PyInspection {
@Override
public void visitPyReturnStatement(@NotNull PyReturnStatement node) {
final ScopeOwner owner = ScopeUtil.getScopeOwner(node);
ScopeOwner owner = ScopeUtil.getScopeOwner(node);
if (owner instanceof PyFunction) {
final PyFunction function = (PyFunction)owner;
final PyAnnotation annotation = function.getAnnotation();
final String typeCommentAnnotation = function.getTypeCommentAnnotation();
PyFunction function = (PyFunction)owner;
PyAnnotation annotation = function.getAnnotation();
String typeCommentAnnotation = function.getTypeCommentAnnotation();
if (annotation != null || typeCommentAnnotation != null) {
final PyExpression returnExpr = node.getExpression();
final PyType expected = getExpectedReturnType(function);
final PyType actual = returnExpr != null ? tryPromotingType(returnExpr, expected) : PyNoneType.INSTANCE;
PyExpression returnExpr = node.getExpression();
PyType expected = getExpectedReturnType(function);
PyType actual = returnExpr != null ? tryPromotingType(returnExpr, expected) : PyNoneType.INSTANCE;
if (!PyTypeChecker.match(expected, actual, myTypeEvalContext)) {
final String expectedName = PythonDocumentationProvider.getTypeName(expected, myTypeEvalContext);
final String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
PyMakeFunctionReturnTypeQuickFix localQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, actualName, myTypeEvalContext);
PyMakeFunctionReturnTypeQuickFix globalQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, null, myTypeEvalContext);
String expectedName = PythonDocumentationProvider.getTypeName(expected, myTypeEvalContext);
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
var localQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, returnExpr, actual, myTypeEvalContext);
var globalQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, returnExpr, null, myTypeEvalContext);
registerProblem(returnExpr != null ? returnExpr : node,
PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName),
localQuickFix, globalQuickFix);
@@ -98,7 +98,12 @@ public class PyTypeCheckerInspection extends PyInspection {
@Nullable
private PyType getExpectedReturnType(@NotNull PyFunction function) {
final PyType returnType = myTypeEvalContext.getReturnType(function);
return getExpectedReturnType(function, myTypeEvalContext);
}
@Nullable
public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
final PyType returnType = typeEvalContext.getReturnType(function);
if (function.isAsync() || function.isGenerator()) {
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType));
@@ -107,6 +112,13 @@ public class PyTypeCheckerInspection extends PyInspection {
return returnType;
}
@Nullable
public static PyType getActualReturnType(@NotNull PyFunction function, @Nullable PyExpression returnExpr,
@NotNull TypeEvalContext context) {
PyType returnTypeExpected = getExpectedReturnType(function, context);
return returnExpr != null ? tryPromotingType(returnExpr, returnTypeExpected, context) : PyNoneType.INSTANCE;
}
@Override
public void visitPyTargetExpression(@NotNull PyTargetExpression node) {
// TODO: Check types in class-level assignments
@@ -125,9 +137,14 @@ public class PyTypeCheckerInspection extends PyInspection {
@Nullable
private PyType tryPromotingType(@NotNull PyExpression value, @Nullable PyType expected) {
final PyType promotedToLiteral = PyLiteralType.Companion.promoteToLiteral(value, expected, myTypeEvalContext, null);
return tryPromotingType(value, expected, myTypeEvalContext);
}
@Nullable
public static PyType tryPromotingType(@NotNull PyExpression value, @Nullable PyType expected, @NotNull TypeEvalContext context) {
final PyType promotedToLiteral = PyLiteralType.Companion.promoteToLiteral(value, expected, context, null);
if (promotedToLiteral != null) return promotedToLiteral;
return myTypeEvalContext.getType(value);
return context.getType(value);
}
@Override

View File

@@ -18,33 +18,49 @@ package com.jetbrains.python.inspections.quickfix;
import com.intellij.codeInspection.LocalQuickFix;
import com.intellij.codeInspection.ProblemDescriptor;
import com.intellij.openapi.project.Project;
import com.intellij.psi.PsiComment;
import com.intellij.psi.SmartPointerManager;
import com.intellij.psi.SmartPsiElementPointer;
import com.intellij.psi.*;
import com.jetbrains.python.PyPsiBundle;
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil;
import com.jetbrains.python.documentation.PythonDocumentationProvider;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.List;
import static com.jetbrains.python.inspections.PyTypeCheckerInspection.Visitor.getActualReturnType;
/**
* @author lada
*/
public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
private final SmartPsiElementPointer<PyFunction> myFunction;
private final SmartPsiElementPointer<PyExpression> myReturnExpr;
private final SmartPsiElementPointer<PyAnnotation> myAnnotation;
private final SmartPsiElementPointer<PsiComment> myTypeCommentAnnotation;
private final String myReturnTypeName;
private final boolean myHaveSuggestedType;
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @Nullable String returnTypeName, @NotNull TypeEvalContext context) {
final SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function,
@Nullable PyExpression returnExpr,
@Nullable PyType returnTypeSuggested,
@NotNull TypeEvalContext context) {
SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
myFunction = manager.createSmartPsiElementPointer(function);
myReturnExpr = returnExpr != null ? manager.createSmartPsiElementPointer(returnExpr) : null;
PyAnnotation annotation = function.getAnnotation();
myAnnotation = annotation != null ? manager.createSmartPsiElementPointer(annotation) : null;
PsiComment typeCommentAnnotation = function.getTypeComment();
myTypeCommentAnnotation = typeCommentAnnotation != null ? manager.createSmartPsiElementPointer(typeCommentAnnotation) : null;
myReturnTypeName = (returnTypeName == null) ? PythonDocumentationProvider.getTypeName(function.getReturnStatementType(context), context) : returnTypeName;
myHaveSuggestedType = returnTypeSuggested != null;
PyType returnType = myHaveSuggestedType ? returnTypeSuggested : function.getReturnStatementType(context);
myReturnTypeName = PythonDocumentationProvider.getTypeHint(returnType, context);
}
@Override
@@ -65,22 +81,49 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
if (myAnnotation != null) {
final PyAnnotation annotation = myAnnotation.getElement();
PyAnnotation annotation = myAnnotation.getElement();
if (annotation != null) {
final PyExpression annotationExpr = annotation.getValue();
PyExpression annotationExpr = annotation.getValue();
if (annotationExpr == null) return;
annotationExpr.replace(elementGenerator.createExpressionFromText(LanguageLevel.PYTHON34, myReturnTypeName));
PsiElement newElement =
annotationExpr.replace(elementGenerator.createExpressionFromText(LanguageLevel.PYTHON34, myReturnTypeName));
addImportsForTypeAnnotations(newElement);
}
}
else if (myTypeCommentAnnotation != null) {
final PsiComment typeComment = myTypeCommentAnnotation.getElement();
PsiComment typeComment = myTypeCommentAnnotation.getElement();
if (typeComment != null) {
final StringBuilder typeCommentAnnotation = new StringBuilder(typeComment.getText());
StringBuilder typeCommentAnnotation = new StringBuilder(typeComment.getText());
typeCommentAnnotation.delete(typeCommentAnnotation.indexOf("->"), typeCommentAnnotation.length());
typeCommentAnnotation.append("-> ").append(myReturnTypeName);
final PsiComment newTypeComment = elementGenerator.createFromText(LanguageLevel.PYTHON27, PsiComment.class, typeCommentAnnotation.toString());
typeComment.replace(newTypeComment);
PsiComment newTypeComment =
elementGenerator.createFromText(LanguageLevel.PYTHON27, PsiComment.class, typeCommentAnnotation.toString());
PsiElement newElement = typeComment.replace(newTypeComment);
addImportsForTypeAnnotations(newElement);
}
}
}
private void addImportsForTypeAnnotations(@NotNull PsiElement element) {
PsiFile file = element.getContainingFile();
if (file == null) return;
PyFunction function = myFunction.getElement();
if (function == null) return;
Project project = element.getProject();
TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(project, file);
PyType typeForImports = getTypeForImports(function, typeEvalContext);
if (typeForImports != null) {
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), typeEvalContext, file);
}
}
private @Nullable PyType getTypeForImports(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
PyType returnTypeActual = getActualReturnType(function, myReturnExpr != null ? myReturnExpr.getElement() : null, context);
if (myHaveSuggestedType && returnTypeActual != null) {
return returnTypeActual;
}
else {
return function.getReturnStatementType(context);
}
}
}

View File

@@ -0,0 +1,8 @@
import my
def foo(a) -> my.X:
if a:
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
else:
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>

View File

@@ -0,0 +1,11 @@
from typing import Type
import my
from my import X
def foo(a) -> Type[X]:
if a:
return my.X<caret>
else:
return my.Y

View File

@@ -0,0 +1,5 @@
class X():
pass
class Y(X):
pass

View File

@@ -0,0 +1,8 @@
import my
def foo(a) -> my.X:
if a:
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
else:
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>

View File

@@ -0,0 +1,11 @@
from typing import Type
import my
from my import X, Y
def foo(a) -> Type[X | Y]:
if a:
return my.X
else:
return my.Y

View File

@@ -0,0 +1,5 @@
class X():
pass
class Y():
pass

View File

@@ -0,0 +1,2 @@
def func() -> int:
return <warning descr="Expected type 'int', got '(x: Any) -> int' instead">lambda x: 42<caret></warning>

View File

@@ -0,0 +1,5 @@
from typing import Callable, Any
def func() -> Callable[[Any], int]:
return lambda x: 42<caret>

View File

@@ -0,0 +1,10 @@
from __future__ import annotations
import my
def foo(a) -> my.X:
if a:
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
else:
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>

View File

@@ -0,0 +1,13 @@
from __future__ import annotations
from typing import Type
import my
from my import X, Y
def foo(a) -> Type[X | Y]:
if a:
return my.X
else:
return my.Y

View File

@@ -0,0 +1,8 @@
import my
def foo(a) -> my.X:
if a:
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
else:
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>

View File

@@ -0,0 +1,11 @@
from typing import Union, Type
import my
from my import X, Y
def foo(a) -> Type[Union[X, Y]]:
if a:
return my.X
else:
return my.Y

View File

@@ -0,0 +1,5 @@
class X():
pass
class Y():
pass

View File

@@ -74,4 +74,15 @@ public abstract class PyQuickFixTestCase extends PyTestCase {
myFixture.checkHighlighting(true, false, false);
assertEmpty(myFixture.filterAvailableIntentions(hint));
}
protected void doMultiFileTest(@NotNull Class inspectionClass, @NotNull String hint) {
myFixture.copyDirectoryToProject(getTestName(true), "");
myFixture.enableInspections(inspectionClass);
myFixture.configureByFile("main.py");
myFixture.checkHighlighting(true, false, false);
final IntentionAction intentionAction = myFixture.findSingleIntention(hint);
myFixture.launchAction(intentionAction);
myFixture.checkResultByFile(getTestName(true) + "/main_after.py", true);
}
}

View File

@@ -33,4 +33,35 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase {
public void testPy3OneReturn() {
doQuickFixTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "f", "int"), LanguageLevel.PYTHON34);
}
// PY-27128
public void testPy39UnionTypeImports() {
runWithLanguageLevel(LanguageLevel.PYTHON39, () -> {
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[Union[X, Y]]"));
});
}
// PY-27128
public void testPy39BitwiseOrUnionFromFutureAnnotationsUnionTypeImports() {
runWithLanguageLevel(LanguageLevel.PYTHON39, () -> {
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X | Y]"));
});
}
// PY-27128
public void testBitwiseOrUnionTypeImports() {
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X | Y]"));
}
// PY-27128
public void testAncestorAndInheritorReturn() {
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X]"));
}
// PY-27128 PY-48466
public void testLambda() {
doQuickFixTest(PyTypeCheckerInspection.class,
PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"),
LanguageLevel.getLatest());
}
}