fixed PY-7085 Specify type for reference using annotation: missing intention for specifying annotation for return values

This commit is contained in:
Ekaterina Tuzova
2012-08-01 17:34:11 +04:00
parent aa77742a86
commit 655b8f5913
4 changed files with 67 additions and 5 deletions

View File

@@ -6,6 +6,7 @@ import com.intellij.codeInsight.template.*;
import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.TextRange;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiReference;
import com.intellij.psi.util.PsiTreeUtil;
@@ -13,6 +14,7 @@ import com.intellij.util.IncorrectOperationException;
import com.jetbrains.python.PyBundle;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.resolve.PyResolveContext;
import com.jetbrains.python.psi.types.PyReturnTypeReference;
import com.jetbrains.python.psi.types.PyType;
import com.jetbrains.python.psi.types.TypeEvalContext;
@@ -67,13 +69,24 @@ public class SpecifyTypeInPy3AnnotationsIntention implements IntentionAction {
}
if (pyFunction != null) {
PyParameter parameter = null;
final PsiElement resolvedReference = reference != null?reference.resolve() : null;
if (problemElement instanceof PyParameter)
parameter = (PyParameter)problemElement;
else if (reference != null && reference.resolve() instanceof PyParameter)
parameter = (PyParameter)reference.resolve();
else if (resolvedReference instanceof PyParameter)
parameter = (PyParameter)resolvedReference;
if (parameter instanceof PyNamedParameter && (((PyNamedParameter)parameter).getAnnotation() != null ||
parameter.getDefaultValue() != null)) return false;
return true;
if (parameter != null)
return true;
else {
if (resolvedReference instanceof PyTargetExpression) {
final PyExpression assignedValue = ((PyTargetExpression)resolvedReference).findAssignedValue();
if (assignedValue instanceof PyCallExpression) {
final Callable callable = ((PyCallExpression)assignedValue).resolveCalleeFunction(PyResolveContext.defaultContext());
if (callable instanceof PyFunction && ((PyFunction)callable).getAnnotation() == null) return true;
}
}
}
}
}
return false;
@@ -96,10 +109,13 @@ public class SpecifyTypeInPy3AnnotationsIntention implements IntentionAction {
PyElementGenerator elementGenerator = PyElementGenerator.getInstance(project);
PyParameter parameter = null;
final PsiElement resolvedReference = reference != null? reference.resolve() : null;
if (problemElement instanceof PyParameter)
parameter = (PyParameter)problemElement;
else if (reference!= null && reference.resolve() instanceof PyParameter) {
parameter = (PyParameter)reference.resolve();
else {
if (resolvedReference instanceof PyParameter) {
parameter = (PyParameter)resolvedReference;
}
}
if (parameter != null && name != null) {
final PyFunction function =
@@ -116,6 +132,34 @@ public class SpecifyTypeInPy3AnnotationsIntention implements IntentionAction {
Template template = ((TemplateBuilderImpl)builder).buildInlineTemplate();
TemplateManager.getInstance(project).startTemplate(editor, template);
}
else { //return type
if (resolvedReference instanceof PyTargetExpression) {
final PyExpression assignedValue = ((PyTargetExpression)resolvedReference).findAssignedValue();
if (assignedValue instanceof PyCallExpression) {
Callable callable = ((PyCallExpression)assignedValue).resolveCalleeFunction(PyResolveContext.defaultContext());
if (callable instanceof PyFunction && ((PyFunction)callable).getAnnotation() == null) {
final String functionSignature = "def " + callable.getName() + callable.getParameterList().getText();
final PyFunction function = elementGenerator.createFromText(LanguageLevel.forElement(problemElement), PyFunction.class,
functionSignature +
" -> object:\n\t" +
((PyFunction)callable).getStatementList().getText());
callable = (PyFunction)callable.replace(function);
callable = CodeInsightUtilBase.forcePsiPostprocessAndRestoreElement(callable);
final PyExpression value = ((PyFunction)callable).getAnnotation().getValue();
final int offset = value.getTextOffset();
editor.getCaretModel().moveToOffset(offset);
final TemplateBuilder builder = TemplateBuilderFactory.getInstance().
createTemplateBuilder(value);
builder.replaceRange(TextRange.create(0, PyNames.OBJECT.length()), PyNames.OBJECT);
Template template = ((TemplateBuilderImpl)builder).buildInlineTemplate();
TemplateManager.getInstance(project).startTemplate(editor, template);
}
}
}
}
}
}

View File

@@ -0,0 +1,7 @@
def g(x) -> object:
return x
def f(x):
y = g(x.keys())
return y.startswith('foo')

View File

@@ -0,0 +1,7 @@
def g(x):
return x
def f(x):
y = g(x.keys())
return y<caret>.startswith('foo')

View File

@@ -241,6 +241,10 @@ public class PyIntentionTest extends PyTestCase {
doTest(PyBundle.message("INTN.specify.type.in.annotation"), LanguageLevel.PYTHON32);
}
public void testReturnTypeInPy3Annotation() { //PY-7085
doTest(PyBundle.message("INTN.specify.type.in.annotation"), LanguageLevel.PYTHON32);
}
public void testTypeAssertion() {
doTest(PyBundle.message("INTN.insert.assertion"));
}