mirror of
https://gitflic.ru/project/openide/openide.git
synced 2026-04-21 22:11:40 +07:00
PY-27128 PY-48466 Fix problem with imports in PyMakeFunctionReturnTypeQuickFix
(cherry picked from commit 9320c2e708afa2ddd025f60b0c122132387493b7) IJ-MR-10124 GitOrigin-RevId: 219e646482999e8b81b77aaf52eb24574381af3d
This commit is contained in:
committed by
intellij-monorepo-bot
parent
fe812b5f46
commit
903f6e11ca
@@ -74,20 +74,20 @@ public class PyTypeCheckerInspection extends PyInspection {
|
||||
|
||||
@Override
|
||||
public void visitPyReturnStatement(@NotNull PyReturnStatement node) {
|
||||
final ScopeOwner owner = ScopeUtil.getScopeOwner(node);
|
||||
ScopeOwner owner = ScopeUtil.getScopeOwner(node);
|
||||
if (owner instanceof PyFunction) {
|
||||
final PyFunction function = (PyFunction)owner;
|
||||
final PyAnnotation annotation = function.getAnnotation();
|
||||
final String typeCommentAnnotation = function.getTypeCommentAnnotation();
|
||||
PyFunction function = (PyFunction)owner;
|
||||
PyAnnotation annotation = function.getAnnotation();
|
||||
String typeCommentAnnotation = function.getTypeCommentAnnotation();
|
||||
if (annotation != null || typeCommentAnnotation != null) {
|
||||
final PyExpression returnExpr = node.getExpression();
|
||||
final PyType expected = getExpectedReturnType(function);
|
||||
final PyType actual = returnExpr != null ? tryPromotingType(returnExpr, expected) : PyNoneType.INSTANCE;
|
||||
PyExpression returnExpr = node.getExpression();
|
||||
PyType expected = getExpectedReturnType(function);
|
||||
PyType actual = returnExpr != null ? tryPromotingType(returnExpr, expected) : PyNoneType.INSTANCE;
|
||||
if (!PyTypeChecker.match(expected, actual, myTypeEvalContext)) {
|
||||
final String expectedName = PythonDocumentationProvider.getTypeName(expected, myTypeEvalContext);
|
||||
final String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
|
||||
PyMakeFunctionReturnTypeQuickFix localQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, actualName, myTypeEvalContext);
|
||||
PyMakeFunctionReturnTypeQuickFix globalQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, null, myTypeEvalContext);
|
||||
String expectedName = PythonDocumentationProvider.getTypeName(expected, myTypeEvalContext);
|
||||
String actualName = PythonDocumentationProvider.getTypeName(actual, myTypeEvalContext);
|
||||
var localQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, returnExpr, actual, myTypeEvalContext);
|
||||
var globalQuickFix = new PyMakeFunctionReturnTypeQuickFix(function, returnExpr, null, myTypeEvalContext);
|
||||
registerProblem(returnExpr != null ? returnExpr : node,
|
||||
PyPsiBundle.message("INSP.type.checker.expected.type.got.type.instead", expectedName, actualName),
|
||||
localQuickFix, globalQuickFix);
|
||||
@@ -98,7 +98,12 @@ public class PyTypeCheckerInspection extends PyInspection {
|
||||
|
||||
@Nullable
|
||||
private PyType getExpectedReturnType(@NotNull PyFunction function) {
|
||||
final PyType returnType = myTypeEvalContext.getReturnType(function);
|
||||
return getExpectedReturnType(function, myTypeEvalContext);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static PyType getExpectedReturnType(@NotNull PyFunction function, @NotNull TypeEvalContext typeEvalContext) {
|
||||
final PyType returnType = typeEvalContext.getReturnType(function);
|
||||
|
||||
if (function.isAsync() || function.isGenerator()) {
|
||||
return Ref.deref(PyTypingTypeProvider.coroutineOrGeneratorElementType(returnType));
|
||||
@@ -107,6 +112,13 @@ public class PyTypeCheckerInspection extends PyInspection {
|
||||
return returnType;
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static PyType getActualReturnType(@NotNull PyFunction function, @Nullable PyExpression returnExpr,
|
||||
@NotNull TypeEvalContext context) {
|
||||
PyType returnTypeExpected = getExpectedReturnType(function, context);
|
||||
return returnExpr != null ? tryPromotingType(returnExpr, returnTypeExpected, context) : PyNoneType.INSTANCE;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void visitPyTargetExpression(@NotNull PyTargetExpression node) {
|
||||
// TODO: Check types in class-level assignments
|
||||
@@ -125,9 +137,14 @@ public class PyTypeCheckerInspection extends PyInspection {
|
||||
|
||||
@Nullable
|
||||
private PyType tryPromotingType(@NotNull PyExpression value, @Nullable PyType expected) {
|
||||
final PyType promotedToLiteral = PyLiteralType.Companion.promoteToLiteral(value, expected, myTypeEvalContext, null);
|
||||
return tryPromotingType(value, expected, myTypeEvalContext);
|
||||
}
|
||||
|
||||
@Nullable
|
||||
public static PyType tryPromotingType(@NotNull PyExpression value, @Nullable PyType expected, @NotNull TypeEvalContext context) {
|
||||
final PyType promotedToLiteral = PyLiteralType.Companion.promoteToLiteral(value, expected, context, null);
|
||||
if (promotedToLiteral != null) return promotedToLiteral;
|
||||
return myTypeEvalContext.getType(value);
|
||||
return context.getType(value);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@@ -18,33 +18,49 @@ package com.jetbrains.python.inspections.quickfix;
|
||||
import com.intellij.codeInspection.LocalQuickFix;
|
||||
import com.intellij.codeInspection.ProblemDescriptor;
|
||||
import com.intellij.openapi.project.Project;
|
||||
import com.intellij.psi.PsiComment;
|
||||
import com.intellij.psi.SmartPointerManager;
|
||||
import com.intellij.psi.SmartPsiElementPointer;
|
||||
import com.intellij.psi.*;
|
||||
import com.jetbrains.python.PyPsiBundle;
|
||||
import com.jetbrains.python.codeInsight.intentions.PyTypeHintGenerationUtil;
|
||||
import com.jetbrains.python.documentation.PythonDocumentationProvider;
|
||||
import com.jetbrains.python.psi.*;
|
||||
import com.jetbrains.python.psi.types.PyType;
|
||||
import com.jetbrains.python.psi.types.TypeEvalContext;
|
||||
import org.jetbrains.annotations.NotNull;
|
||||
import org.jetbrains.annotations.Nullable;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import static com.jetbrains.python.inspections.PyTypeCheckerInspection.Visitor.getActualReturnType;
|
||||
|
||||
/**
|
||||
* @author lada
|
||||
*/
|
||||
public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
|
||||
private final SmartPsiElementPointer<PyFunction> myFunction;
|
||||
private final SmartPsiElementPointer<PyExpression> myReturnExpr;
|
||||
private final SmartPsiElementPointer<PyAnnotation> myAnnotation;
|
||||
private final SmartPsiElementPointer<PsiComment> myTypeCommentAnnotation;
|
||||
private final String myReturnTypeName;
|
||||
private final boolean myHaveSuggestedType;
|
||||
|
||||
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function, @Nullable String returnTypeName, @NotNull TypeEvalContext context) {
|
||||
final SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
|
||||
public PyMakeFunctionReturnTypeQuickFix(@NotNull PyFunction function,
|
||||
@Nullable PyExpression returnExpr,
|
||||
@Nullable PyType returnTypeSuggested,
|
||||
@NotNull TypeEvalContext context) {
|
||||
SmartPointerManager manager = SmartPointerManager.getInstance(function.getProject());
|
||||
myFunction = manager.createSmartPsiElementPointer(function);
|
||||
myReturnExpr = returnExpr != null ? manager.createSmartPsiElementPointer(returnExpr) : null;
|
||||
|
||||
PyAnnotation annotation = function.getAnnotation();
|
||||
myAnnotation = annotation != null ? manager.createSmartPsiElementPointer(annotation) : null;
|
||||
|
||||
PsiComment typeCommentAnnotation = function.getTypeComment();
|
||||
myTypeCommentAnnotation = typeCommentAnnotation != null ? manager.createSmartPsiElementPointer(typeCommentAnnotation) : null;
|
||||
myReturnTypeName = (returnTypeName == null) ? PythonDocumentationProvider.getTypeName(function.getReturnStatementType(context), context) : returnTypeName;
|
||||
|
||||
myHaveSuggestedType = returnTypeSuggested != null;
|
||||
PyType returnType = myHaveSuggestedType ? returnTypeSuggested : function.getReturnStatementType(context);
|
||||
|
||||
myReturnTypeName = PythonDocumentationProvider.getTypeHint(returnType, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
@@ -65,22 +81,49 @@ public class PyMakeFunctionReturnTypeQuickFix implements LocalQuickFix {
|
||||
public void applyFix(@NotNull Project project, @NotNull ProblemDescriptor descriptor) {
|
||||
PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
|
||||
if (myAnnotation != null) {
|
||||
final PyAnnotation annotation = myAnnotation.getElement();
|
||||
PyAnnotation annotation = myAnnotation.getElement();
|
||||
if (annotation != null) {
|
||||
final PyExpression annotationExpr = annotation.getValue();
|
||||
PyExpression annotationExpr = annotation.getValue();
|
||||
if (annotationExpr == null) return;
|
||||
annotationExpr.replace(elementGenerator.createExpressionFromText(LanguageLevel.PYTHON34, myReturnTypeName));
|
||||
PsiElement newElement =
|
||||
annotationExpr.replace(elementGenerator.createExpressionFromText(LanguageLevel.PYTHON34, myReturnTypeName));
|
||||
addImportsForTypeAnnotations(newElement);
|
||||
}
|
||||
}
|
||||
else if (myTypeCommentAnnotation != null) {
|
||||
final PsiComment typeComment = myTypeCommentAnnotation.getElement();
|
||||
PsiComment typeComment = myTypeCommentAnnotation.getElement();
|
||||
if (typeComment != null) {
|
||||
final StringBuilder typeCommentAnnotation = new StringBuilder(typeComment.getText());
|
||||
StringBuilder typeCommentAnnotation = new StringBuilder(typeComment.getText());
|
||||
typeCommentAnnotation.delete(typeCommentAnnotation.indexOf("->"), typeCommentAnnotation.length());
|
||||
typeCommentAnnotation.append("-> ").append(myReturnTypeName);
|
||||
final PsiComment newTypeComment = elementGenerator.createFromText(LanguageLevel.PYTHON27, PsiComment.class, typeCommentAnnotation.toString());
|
||||
typeComment.replace(newTypeComment);
|
||||
PsiComment newTypeComment =
|
||||
elementGenerator.createFromText(LanguageLevel.PYTHON27, PsiComment.class, typeCommentAnnotation.toString());
|
||||
PsiElement newElement = typeComment.replace(newTypeComment);
|
||||
addImportsForTypeAnnotations(newElement);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private void addImportsForTypeAnnotations(@NotNull PsiElement element) {
|
||||
PsiFile file = element.getContainingFile();
|
||||
if (file == null) return;
|
||||
PyFunction function = myFunction.getElement();
|
||||
if (function == null) return;
|
||||
Project project = element.getProject();
|
||||
TypeEvalContext typeEvalContext = TypeEvalContext.userInitiated(project, file);
|
||||
PyType typeForImports = getTypeForImports(function, typeEvalContext);
|
||||
if (typeForImports != null) {
|
||||
PyTypeHintGenerationUtil.addImportsForTypeAnnotations(List.of(typeForImports), typeEvalContext, file);
|
||||
}
|
||||
}
|
||||
|
||||
private @Nullable PyType getTypeForImports(@NotNull PyFunction function, @NotNull TypeEvalContext context) {
|
||||
PyType returnTypeActual = getActualReturnType(function, myReturnExpr != null ? myReturnExpr.getElement() : null, context);
|
||||
if (myHaveSuggestedType && returnTypeActual != null) {
|
||||
return returnTypeActual;
|
||||
}
|
||||
else {
|
||||
return function.getReturnStatementType(context);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,8 @@
|
||||
import my
|
||||
|
||||
|
||||
def foo(a) -> my.X:
|
||||
if a:
|
||||
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
|
||||
else:
|
||||
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Type
|
||||
|
||||
import my
|
||||
from my import X
|
||||
|
||||
|
||||
def foo(a) -> Type[X]:
|
||||
if a:
|
||||
return my.X<caret>
|
||||
else:
|
||||
return my.Y
|
||||
@@ -0,0 +1,5 @@
|
||||
class X():
|
||||
pass
|
||||
|
||||
class Y(X):
|
||||
pass
|
||||
@@ -0,0 +1,8 @@
|
||||
import my
|
||||
|
||||
|
||||
def foo(a) -> my.X:
|
||||
if a:
|
||||
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
|
||||
else:
|
||||
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Type
|
||||
|
||||
import my
|
||||
from my import X, Y
|
||||
|
||||
|
||||
def foo(a) -> Type[X | Y]:
|
||||
if a:
|
||||
return my.X
|
||||
else:
|
||||
return my.Y
|
||||
@@ -0,0 +1,5 @@
|
||||
class X():
|
||||
pass
|
||||
|
||||
class Y():
|
||||
pass
|
||||
@@ -0,0 +1,2 @@
|
||||
def func() -> int:
|
||||
return <warning descr="Expected type 'int', got '(x: Any) -> int' instead">lambda x: 42<caret></warning>
|
||||
@@ -0,0 +1,5 @@
|
||||
from typing import Callable, Any
|
||||
|
||||
|
||||
def func() -> Callable[[Any], int]:
|
||||
return lambda x: 42<caret>
|
||||
@@ -0,0 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import my
|
||||
|
||||
|
||||
def foo(a) -> my.X:
|
||||
if a:
|
||||
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
|
||||
else:
|
||||
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>
|
||||
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Type
|
||||
|
||||
import my
|
||||
from my import X, Y
|
||||
|
||||
|
||||
def foo(a) -> Type[X | Y]:
|
||||
if a:
|
||||
return my.X
|
||||
else:
|
||||
return my.Y
|
||||
@@ -0,0 +1,5 @@
|
||||
class X():
|
||||
pass
|
||||
|
||||
class Y():
|
||||
pass
|
||||
@@ -0,0 +1,8 @@
|
||||
import my
|
||||
|
||||
|
||||
def foo(a) -> my.X:
|
||||
if a:
|
||||
return <warning descr="Expected type 'X', got 'Type[X]' instead">my.X<caret></warning>
|
||||
else:
|
||||
return <warning descr="Expected type 'X', got 'Type[Y]' instead">my.Y</warning>
|
||||
@@ -0,0 +1,11 @@
|
||||
from typing import Union, Type
|
||||
|
||||
import my
|
||||
from my import X, Y
|
||||
|
||||
|
||||
def foo(a) -> Type[Union[X, Y]]:
|
||||
if a:
|
||||
return my.X
|
||||
else:
|
||||
return my.Y
|
||||
@@ -0,0 +1,5 @@
|
||||
class X():
|
||||
pass
|
||||
|
||||
class Y():
|
||||
pass
|
||||
@@ -74,4 +74,15 @@ public abstract class PyQuickFixTestCase extends PyTestCase {
|
||||
myFixture.checkHighlighting(true, false, false);
|
||||
assertEmpty(myFixture.filterAvailableIntentions(hint));
|
||||
}
|
||||
|
||||
protected void doMultiFileTest(@NotNull Class inspectionClass, @NotNull String hint) {
|
||||
myFixture.copyDirectoryToProject(getTestName(true), "");
|
||||
myFixture.enableInspections(inspectionClass);
|
||||
myFixture.configureByFile("main.py");
|
||||
myFixture.checkHighlighting(true, false, false);
|
||||
|
||||
final IntentionAction intentionAction = myFixture.findSingleIntention(hint);
|
||||
myFixture.launchAction(intentionAction);
|
||||
myFixture.checkResultByFile(getTestName(true) + "/main_after.py", true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -33,4 +33,35 @@ public class PyMakeFunctionReturnTypeQuickFixTest extends PyQuickFixTestCase {
|
||||
public void testPy3OneReturn() {
|
||||
doQuickFixTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "f", "int"), LanguageLevel.PYTHON34);
|
||||
}
|
||||
|
||||
// PY-27128
|
||||
public void testPy39UnionTypeImports() {
|
||||
runWithLanguageLevel(LanguageLevel.PYTHON39, () -> {
|
||||
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[Union[X, Y]]"));
|
||||
});
|
||||
}
|
||||
|
||||
// PY-27128
|
||||
public void testPy39BitwiseOrUnionFromFutureAnnotationsUnionTypeImports() {
|
||||
runWithLanguageLevel(LanguageLevel.PYTHON39, () -> {
|
||||
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X | Y]"));
|
||||
});
|
||||
}
|
||||
|
||||
// PY-27128
|
||||
public void testBitwiseOrUnionTypeImports() {
|
||||
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X | Y]"));
|
||||
}
|
||||
|
||||
// PY-27128
|
||||
public void testAncestorAndInheritorReturn() {
|
||||
doMultiFileTest(PyTypeCheckerInspection.class, PyPsiBundle.message("QFIX.make.function.return.type", "foo", "Type[X]"));
|
||||
}
|
||||
|
||||
// PY-27128 PY-48466
|
||||
public void testLambda() {
|
||||
doQuickFixTest(PyTypeCheckerInspection.class,
|
||||
PyPsiBundle.message("QFIX.make.function.return.type", "func", "Callable[[Any], int]"),
|
||||
LanguageLevel.getLatest());
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user