PY-11127 Do not generate call of base method if it raises NotImplementedError

This commit is contained in:
Mikhail Golubev
2014-10-30 20:39:13 +03:00
parent 59c5ef46fe
commit 9574f68583
4 changed files with 45 additions and 1 deletions

View File

@@ -28,6 +28,7 @@ import com.intellij.openapi.editor.Editor;
import com.intellij.openapi.editor.ScrollType;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.ui.DialogWrapper;
import com.intellij.openapi.util.Comparing;
import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiDocumentManager;
@@ -222,7 +223,7 @@ public class PyOverrideImplementUtil {
}
}
if (PyNames.FAKE_OLD_BASE.equals(baseClass.getName()) || implement) {
if (PyNames.FAKE_OLD_BASE.equals(baseClass.getName()) || raisesNotImplementedError(baseFunction) || implement) {
statementBody.append(PyNames.PASS);
}
else {
@@ -263,6 +264,26 @@ public class PyOverrideImplementUtil {
return pyFunctionBuilder;
}
private static boolean raisesNotImplementedError(@NotNull PyFunction function) {
for (PyStatement statement : function.getStatementList().getStatements()) {
if (!(statement instanceof PyRaiseStatement)) {
continue;
}
final PyRaiseStatement raiseStatement = (PyRaiseStatement)statement;
final PyExpression[] expressions = raiseStatement.getExpressions();
if (expressions.length > 0) {
final PyExpression expression = expressions[0];
if (expression instanceof PyCallExpression) {
final PyExpression callee = ((PyCallExpression)expression).getCallee();
if (callee instanceof PyReferenceExpression && Comparing.equal(callee.getName(), "NotImplementedError")) {
return true;
}
}
}
}
return false;
}
// TODO find a better place for this logic
private static String getReferenceText(PyClass fromClass, PyClass toClass) {
final PyExpression[] superClassExpressions = fromClass.getSuperClassExpressions();

View File

@@ -0,0 +1,8 @@
class A:
def m(self):
"""Abstract method."""
raise NotImplementedError('Should not be called directly')
class B(A):
pass

View File

@@ -0,0 +1,9 @@
class A:
def m(self):
"""Abstract method."""
raise NotImplementedError('Should not be called directly')
class B(A):
def m(self):
pass

View File

@@ -121,6 +121,12 @@ public class PyOverrideTest extends PyTestCase {
myFixture.checkResultByFile("override/" + getTestName(true) + "_after.py", true);
}
// PY-11127
public void testOverriddenMethodRaisesNotImplementedError() {
doTest();
}
public void testPy3k() {
doTest3k();
}