Fixed unnecessary cls argument when extracting @classmethod (PY-6624)

This commit is contained in:
Andrey Vlasovskikh
2012-05-23 19:36:51 +04:00
parent 9544579b0e
commit f4f66b0f32
3 changed files with 32 additions and 5 deletions

View File

@@ -59,8 +59,8 @@ public class PyCodeFragmentUtil {
for (PsiElement element : filterElementsInScope(getInputElements(subGraph, graph), owner)) {
final String name = getName(element);
if (name != null) {
// Ignore "self", it is generated automatically when extracting any method fragment
if (PyPsiUtils.isMethodContext(element) && "self".equals(name)) {
// Ignore "self" and "cls", they are generated automatically when extracting any method fragment
if (resolvesToBoundMethodParameter(element)) {
continue;
}
if (globalWrites.contains(name)) {
@@ -82,6 +82,33 @@ public class PyCodeFragmentUtil {
return new PyCodeFragment(inputNames, outputNames, globalWrites, subGraphAnalysis.returns > 0);
}
private static boolean resolvesToBoundMethodParameter(@NotNull PsiElement element) {
if (PyPsiUtils.isMethodContext(element)) {
final PyFunction function = PsiTreeUtil.getParentOfType(element, PyFunction.class);
if (function != null) {
final PsiReference reference = element.getReference();
if (reference != null) {
final PsiElement resolved = reference.resolve();
if (resolved instanceof PyParameter) {
final PyParameterList parameterList = PsiTreeUtil.getParentOfType(resolved, PyParameterList.class);
if (parameterList != null) {
final PyParameter[] parameters = parameterList.getParameters();
if (parameters.length > 0) {
if (resolved == parameters[0]) {
final PyFunction.Modifier modifier = function.getModifier();
if (modifier == null || modifier == PyFunction.Modifier.CLASSMETHOD) {
return true;
}
}
}
}
}
}
}
}
return false;
}
@Nullable
private static String getName(@NotNull PsiElement element) {
if (element instanceof PsiNamedElement) {

View File

@@ -1,8 +1,8 @@
class C:
@classmethod
def baz(cls):
print "hello world"
print('foo', cls)
@classmethod
def foo(cls):
cls.baz()
cls.baz()

View File

@@ -1,4 +1,4 @@
class C:
@classmethod
def foo(cls):
<selection>print "hello world"</selection>
<selection>print('foo', cls)</selection>