fixed PY-9263 Move attribute to init method: add super class call when moving to not yet existing init

reused AddFieldQuickFix logic
This commit is contained in:
Ekaterina Tuzova
2013-03-25 15:58:12 +04:00
parent 24cef7726b
commit 097ed48b0a
7 changed files with 45 additions and 36 deletions

View File

@@ -98,8 +98,8 @@ public class AddFieldQuickFix implements LocalQuickFix {
}
@Nullable
public static PsiElement addFieldToInit(Project project, PyClass cls, String item_name, Function<String, PyStatement> callback) {
if (cls != null && item_name != null) {
public static PsiElement addFieldToInit(Project project, PyClass cls, String itemName, Function<String, PyStatement> callback) {
if (cls != null && itemName != null) {
PyFunction init = cls.findMethodByName(PyNames.INIT, false);
if (init != null) {
return appendToMethod(init, callback);
@@ -109,21 +109,23 @@ public class AddFieldQuickFix implements LocalQuickFix {
init = ancestor.findMethodByName(PyNames.INIT, false);
if (init != null) break;
}
PyFunction new_init = createInitMethod(project, cls, init);
if (new_init == null) {
PyFunction newInit = createInitMethod(project, cls, init);
if (newInit == null) {
return null;
}
appendToMethod(new_init, callback);
appendToMethod(newInit, callback);
PsiElement add_anchor = null;
PsiElement addAnchor = null;
PyFunction[] meths = cls.getMethods();
if (meths.length > 0) add_anchor = meths[0].getPrevSibling();
PyStatementList cls_content = cls.getStatementList();
new_init = (PyFunction) cls_content.addAfter(new_init, add_anchor);
if (meths.length > 0) addAnchor = meths[0].getPrevSibling();
PyStatementList clsContent = cls.getStatementList();
newInit = (PyFunction) clsContent.addAfter(newInit, addAnchor);
PyUtil.showBalloon(project, PyBundle.message("QFIX.added.constructor.$0.for.field.$1", cls.getName(), item_name), MessageType.INFO);
return new_init.getStatementList().getStatements()[0];
PyUtil.showBalloon(project, PyBundle.message("QFIX.added.constructor.$0.for.field.$1", cls.getName(), itemName), MessageType.INFO);
final PyStatementList statementList = newInit.getStatementList();
assert statementList != null;
return statementList.getStatements()[0];
//else // well, that can't be
}
}

View File

@@ -5,8 +5,9 @@ import com.intellij.codeInspection.ProblemDescriptor;
import com.intellij.openapi.project.Project;
import com.intellij.psi.PsiElement;
import com.intellij.psi.util.PsiTreeUtil;
import com.intellij.util.Function;
import com.intellij.util.FunctionUtil;
import com.jetbrains.python.PyBundle;
import com.jetbrains.python.PyNames;
import com.jetbrains.python.psi.*;
import org.jetbrains.annotations.NonNls;
import org.jetbrains.annotations.NotNull;
@@ -39,31 +40,11 @@ public class PyMoveAttributeToInitQuickFix implements LocalQuickFix {
final PyAssignmentStatement assignment = PsiTreeUtil.getParentOfType(element, PyAssignmentStatement.class);
if (containingClass == null || assignment == null) return;
final PsiElement copy = assignment.copy();
if (!addDefinition(copy, containingClass)) return;
final Function<String, PyStatement> callback = FunctionUtil.<String, PyStatement>constant(assignment);
AddFieldQuickFix.addFieldToInit(project, containingClass, ((PyTargetExpression)element).getName(), callback);
removeDefinition(assignment);
}
private static boolean addDefinition(PsiElement copy, PyClass containingClass) {
PyFunction init = containingClass.findMethodByName(PyNames.INIT, false);
if (init == null) {
final PyStatementList classStatementList = containingClass.getStatementList();
init = PyElementGenerator.getInstance(containingClass.getProject()).createFromText(LanguageLevel.forElement(containingClass),
PyFunction.class,
"def __init__(self):\n\t" +
copy.getText());
PyUtil.addElementToStatementList(init, classStatementList, true);
return true;
}
final PyStatementList statementList = init.getStatementList();
if (statementList == null) return false;
PyUtil.addElementToStatementList(copy, statementList, true);
return true;
}
private static boolean removeDefinition(PyAssignmentStatement assignment) {
final PyStatementList statementList = PsiTreeUtil.getParentOfType(assignment, PyStatementList.class);
if (statementList == null) return false;

View File

@@ -3,8 +3,8 @@ __author__ = 'ktisha'
class A:
def __init__(self):
self.b = 1
self._a = 1
self.b = 1
def foo(self):
pass

View File

@@ -0,0 +1,9 @@
__author__ = 'ktisha'
class Base(object):
def __init__(self):
self.param = 2
class Child(Base):
def f(self):
self.<caret>my = 2

View File

@@ -0,0 +1,13 @@
__author__ = 'ktisha'
class Base(object):
def __init__(self):
self.param = 2
class Child(Base):
def __init__(self):
super(Child, self).__init__()
self.my = 2
def f(self):
pass

View File

@@ -3,8 +3,8 @@ __author__ = 'ktisha'
class A:
def __init__(self):
self.b = 1
self._a = 1
self.b = 1
def foo(self):
c = 1

View File

@@ -30,4 +30,8 @@ public class PyMoveAttributeToInitQuickFixTest extends PyQuickFixTestCase {
doInspectionTest(PyAttributeOutsideInitInspection.class, PyBundle.message("QFIX.move.attribute"));
}
public void testAddSuperCall() {
doInspectionTest(PyAttributeOutsideInitInspection.class, PyBundle.message("QFIX.move.attribute"));
}
}