PY-1695 python imports reworked again

This commit is contained in:
Dennis Ushakov
2010-09-13 22:37:46 +04:00
parent 30c0a0c50b
commit 26090aed43
9 changed files with 105 additions and 97 deletions

View File

@@ -1,6 +1,7 @@
package com.jetbrains.python.actions;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.*;
import com.intellij.util.IncorrectOperationException;
import com.jetbrains.python.PythonDocStringFinder;
@@ -114,7 +115,9 @@ public class AddImportHelper {
public static void addImport(final PsiNamedElement target, final PsiFile file, final PyElement element) {
final boolean useQualified = !PyCodeInsightSettings.getInstance().PREFER_FROM_IMPORT;
final String path = ResolveImportUtil.findShortestImportableName(element, target.getContainingFile().getVirtualFile());
if (useQualified) {
if (target instanceof PsiFileSystemItem) {
addImportStatement(file, path, null);
} else if (useQualified) {
addImportStatement(file, path, null);
final PyElementGenerator elementGenerator = PyElementGenerator.getInstance(file.getProject());
element.replace(elementGenerator.createExpressionFromText(path + "." + target.getName()));

View File

@@ -3,12 +3,10 @@ package com.jetbrains.python.refactoring.classes;
import com.intellij.openapi.diagnostic.Logger;
import com.intellij.openapi.project.Project;
import com.intellij.openapi.util.Comparing;
import com.intellij.openapi.util.Key;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiFileFactory;
import com.intellij.psi.PsiPolyVariantReference;
import com.intellij.util.containers.*;
import com.intellij.psi.*;
import com.intellij.psi.util.PsiTreeUtil;
import com.jetbrains.python.PythonFileType;
import com.jetbrains.python.actions.AddImportHelper;
import com.jetbrains.python.codeInsight.PyCodeInsightSettings;
@@ -19,7 +17,6 @@ import com.jetbrains.python.psi.resolve.ResolveImportUtil;
import org.jetbrains.annotations.Nullable;
import java.util.*;
import java.util.HashSet;
/**
* @author Dennis.Ushakov
@@ -106,6 +103,7 @@ public class PyClassRefactoringUtil {
public static void moveMethods(List<PyFunction> methods, PyClass superClass) {
if (methods.size() == 0) return;
rememberNamedReferences(methods);
final PyElement[] elements = methods.toArray(new PyElement[methods.size()]);
addMethods(superClass, elements, true);
removeMethodsWithComments(elements);
@@ -130,98 +128,95 @@ public class PyClassRefactoringUtil {
public static void addMethods(final PyClass superClass, final PyElement[] elements, final boolean up) {
if (elements.length == 0) return;
final Project project = superClass.getProject();
final String text = prepareClassText(superClass, elements, up, false, null);
if (text == null) return;
final PyClass newClass = PyElementGenerator.getInstance(project).createFromText(LanguageLevel.getDefault(), PyClass.class, text);
final PyStatementList statements = superClass.getStatementList();
final PyStatementList newStatements = newClass.getStatementList();
if (statements.getStatements().length != 0) {
for (PyElement newStatement : newStatements.getStatements()) {
if (newStatement instanceof PyExpressionStatement && newStatement.getFirstChild() instanceof PyStringLiteralExpression) continue;
final PsiElement anchor = statements.add(newStatement);
final Set<PsiElement> comments = PyUtil.getComments(newStatement);
for (PsiElement comment : comments) {
statements.addBefore(comment, anchor);
for (PyElement newStatement : elements) {
if (up && newStatement instanceof PyFunction) {
final String name = newStatement.getName();
if (name != null && superClass.findMethodByName(name, false) != null) {
continue;
}
}
} else {
statements.replace(newStatements);
if (newStatement instanceof PyExpressionStatement && newStatement.getFirstChild() instanceof PyStringLiteralExpression) continue;
final PsiElement anchor = statements.add(newStatement);
restoreReferences((PyElement)anchor);
final Set<PsiElement> comments = PyUtil.getComments(newStatement);
for (PsiElement comment : comments) {
statements.addBefore(comment, anchor);
}
}
PyPsiUtils.removeRedundantPass(statements);
}
private static void restoreReferences(PyElement newStatement) {
newStatement.acceptChildren(new PyRecursiveElementVisitor() {
@Override
public void visitPyReferenceExpression(PyReferenceExpression node) {
super.visitPyReferenceExpression(node);
restoreReference(node);
}
});
}
private static void restoreReference(final PyReferenceExpression node) {
PsiNamedElement target = node.getCopyableUserData(ENCODED_IMPORT);
if (target instanceof PsiDirectory) {
target = (PsiNamedElement)PyUtil.turnDirIntoInit(target);
}
if (target == null) return;
if (PyBuiltinCache.getInstance(target).hasInBuiltins(target)) return;
if (PsiTreeUtil.isAncestor(node.getContainingFile(), target, false)) return;
AddImportHelper.addImport(target, node.getContainingFile(), node);
node.putCopyableUserData(ENCODED_IMPORT, null);
}
public static void insertImport(PyClass target, Collection<PyClass> newClasses) {
for (PyClass newClass : newClasses) {
insertImport(target, newClass);
}
}
@Nullable
public static String prepareClassText(PyClass superClass, PyElement[] elements, boolean up, boolean ignoreNoChanges, final String preparedClassName) {
PsiElement sibling = superClass.getPrevSibling();
final String white = sibling != null ? "\n" + sibling.getText() + " ": "\n ";
final StringBuilder builder = new StringBuilder("class ");
if (preparedClassName != null) {
builder.append(preparedClassName).append(":");
} else {
builder.append("Foo").append(":");
}
boolean hasChanges = false;
for (PyElement element : elements) {
final String name = element.getName();
if (name != null && (up || superClass.findMethodByName(name, false) == null)) {
final Set<PsiElement> comments = PyUtil.getComments(element);
for (PsiElement comment : comments) {
builder.append(white).append(comment.getText());
}
builder.append(white).append(element.getText()).append("\n");
hasChanges = true;
}
}
if (ignoreNoChanges && !hasChanges) {
builder.append(white).append("pass");
}
return ignoreNoChanges || hasChanges ? builder.toString() : null;
}
public static void insertImport(PyClass target, PyClass newClass) {
private static void insertImport(PyClass target, PyClass newClass) {
if (PyBuiltinCache.getInstance(newClass).hasInBuiltins(newClass)) return;
final PsiFile newFile = newClass.getContainingFile();
final VirtualFile vFile = newFile.getVirtualFile();
assert vFile != null;
final PsiFile file = target.getContainingFile();
if (newFile == file) return;
if (!PyCodeInsightSettings.getInstance().PREFER_FROM_IMPORT) {
final String name = newClass.getQualifiedName();
AddImportHelper.addImportStatement(file, name, null);
final String importableName = ResolveImportUtil.findShortestImportableName(target, vFile);
if (!PyCodeInsightSettings.getInstance().PREFER_FROM_IMPORT || newClass instanceof PyFile) {
if (newClass instanceof PyFile) {
AddImportHelper.addImportStatement(file, importableName, null);
} else {
final String name = newClass.getName();
AddImportHelper.addImportStatement(file, importableName + "." + name, null);
}
} else {
AddImportHelper.addImportFrom(file, ResolveImportUtil.findShortestImportableName(target, vFile), newClass.getName());
AddImportHelper.addImportFrom(file, importableName, newClass.getName());
}
}
public static Set<PyClass> rememberClassReferences(final List<PyFunction> methods, final Collection<PyClass> extraClasses) {
final HashSet<PyClass> result = new HashSet<PyClass>(extraClasses);
private static void rememberNamedReferences(final List<PyFunction> methods) {
for (PyFunction method : methods) {
method.acceptChildren(new PyRecursiveElementVisitor() {
@Override
public void visitPyReferenceExpression(PyReferenceExpression node) {
super.visitPyReferenceExpression(node);
final PsiPolyVariantReference ref = node.getReference();
final PsiElement target = ref.resolve();
if (target instanceof PyClass) {
result.add((PyClass)target);
}
rememberReference(node);
}
});
}
return result;
}
public static void restoreImports(final PyClass target, final Set<PyClass> rememberedSet) {
for (PyClass clazz : rememberedSet) {
insertImport(target, clazz);
}
}
public static void restoreImports(final PyClass target, final PyClass origin, final Set<PyClass> rememberedSet) {
if (target.getContainingFile() != origin.getContainingFile()) {
restoreImports(target, rememberedSet);
private static final Key<PsiNamedElement> ENCODED_IMPORT = Key.create("PyEncodedImport");
private static void rememberReference(PyReferenceExpression node) {
// we will remember reference in deepest node
if (node.getQualifier() instanceof PyReferenceExpression) return;
final PsiPolyVariantReference ref = node.getReference();
final PsiElement target = ref.resolve();
if (target instanceof PsiNamedElement) {
node.putCopyableUserData(ENCODED_IMPORT, (PsiNamedElement)target);
}
}
}

View File

@@ -9,10 +9,7 @@ import com.intellij.openapi.util.Ref;
import com.intellij.openapi.vfs.VfsUtil;
import com.intellij.openapi.vfs.VirtualFile;
import com.intellij.openapi.vfs.VirtualFileManager;
import com.intellij.psi.PsiDirectory;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiManager;
import com.intellij.psi.*;
import com.intellij.refactoring.RefactoringBundle;
import com.intellij.util.PathUtil;
import com.jetbrains.python.PyNames;
@@ -63,18 +60,19 @@ public class PyExtractSuperclassHelper {
public void run() {
ApplicationManager.getApplication().runWriteAction(new Runnable() {
public void run() {
final Set<PyClass> rememberedSet = PyClassRefactoringUtil.rememberClassReferences(methods, extractedClasses);
final PyElement[] elements = methods.toArray(new PyElement[methods.size()]);
final String text = PyClassRefactoringUtil.prepareClassText(clazz, elements, true, true, superBaseName) + "\n";
final PyClass newClass = PyElementGenerator.getInstance(project).createFromText(LanguageLevel.getDefault(), PyClass.class, text);
final String text = "class " + superBaseName + ":\n pass" + "\n";
PyClass newClass = PyElementGenerator.getInstance(project).createFromText(LanguageLevel.getDefault(), PyClass.class, text);
newClass = placeNewClass(project, newClass, clazz, targetFile);
newClassRef.set(newClass);
PyClassRefactoringUtil.moveMethods(methods, newClass);
PyClassRefactoringUtil.moveSuperclasses(clazz, superClasses, newClass);
PyClassRefactoringUtil.addSuperclasses(project, clazz, null, Collections.singleton(superBaseName));
PyClassRefactoringUtil.insertImport(newClass, extractedClasses);
if (elements.length > 0) {
PyPsiUtils.removeElements(elements);
}
PyClassRefactoringUtil.insertPassIfNeeded(clazz);
placeNewClass(project, newClass, clazz, targetFile, rememberedSet);
}
});
}
@@ -82,12 +80,11 @@ public class PyExtractSuperclassHelper {
return newClassRef.get();
}
private static void placeNewClass(Project project, PyClass newClass, PyClass clazz, String targetFile, Set<PyClass> rememberedSet) {
private static PyClass placeNewClass(Project project, PyClass newClass, PyClass clazz, String targetFile) {
VirtualFile file = VirtualFileManager.getInstance().findFileByUrl(ApplicationManagerEx.getApplicationEx().isUnitTestMode() ? targetFile : VfsUtil.pathToUrl(targetFile));
// file is the same as the source
if (file == clazz.getContainingFile().getVirtualFile()) {
PyPsiUtils.addBeforeInParent(clazz, newClass, newClass.getNextSibling());
return;
return (PyClass)clazz.getParent().addBefore(newClass, clazz);
}
PsiFile psiFile = null;
@@ -125,8 +122,8 @@ public class PyExtractSuperclassHelper {
LOG.assertTrue(psiFile != null);
newClass = (PyClass)psiFile.add(newClass);
PyClassRefactoringUtil.insertImport(clazz, newClass);
PyClassRefactoringUtil.restoreImports(newClass, rememberedSet);
PyClassRefactoringUtil.insertImport(clazz, Collections.singleton(newClass));
return newClass;
}
private static String constructFilename(PyClass newClass) {

View File

@@ -32,7 +32,7 @@ public class PyPullUpHelper {
}
else LOG.error("unmatched member class " + element.getClass());
}
final Set<PyClass> rememberedSet = PyClassRefactoringUtil.rememberClassReferences(methods, extractedClasses);
CommandProcessor.getInstance().executeCommand(clazz.getProject(), new Runnable() {
public void run() {
ApplicationManager.getApplication().runWriteAction(new Runnable() {
@@ -42,10 +42,8 @@ public class PyPullUpHelper {
// move superclasses declarations
PyClassRefactoringUtil.moveSuperclasses(clazz, superClasses, superClass);
PyClassRefactoringUtil.insertImport(superClass, extractedClasses);
PyClassRefactoringUtil.insertPassIfNeeded(clazz);
PyClassRefactoringUtil.restoreImports(superClass, clazz, rememberedSet);
}
});
}

View File

@@ -77,7 +77,6 @@ public class PyPushDownProcessor extends BaseRefactoringProcessor {
}
else LOG.error("unmatched member class " + element.getClass());
}
final Set<PyClass> rememberedSet = PyClassRefactoringUtil.rememberClassReferences(methods, extractedClasses);
final PyElement[] elements = methods.toArray(new PyElement[methods.size()]);
final List<PyExpression> superClassesElements = PyClassRefactoringUtil.removeAndGetSuperClasses(myClass, superClasses);
@@ -86,7 +85,7 @@ public class PyPushDownProcessor extends BaseRefactoringProcessor {
final PyClass targetClass = (PyClass)usage.getElement();
PyClassRefactoringUtil.addMethods(targetClass, elements, false);
PyClassRefactoringUtil.addSuperclasses(myClass.getProject(), targetClass, superClassesElements, superClasses);
PyClassRefactoringUtil.restoreImports(targetClass, myClass, rememberedSet);
PyClassRefactoringUtil.insertImport(targetClass, extractedClasses);
}
if (methods.size() != 0) {

View File

@@ -0,0 +1,5 @@
import os
from refactoring.extractsuperclass.suppa import Suppa
class A(Suppa):
pass

View File

@@ -0,0 +1,5 @@
import os
class A(object):
def foo(self):
os.stat_result.n_fields()

View File

@@ -7,7 +7,7 @@ class B(A):
def meth_b2(self):
pass
def meth_a1(self, name={}):
def meth_a1(self, name = {}):
pass
def meth_a2(self):
@@ -17,7 +17,7 @@ class D(A):
def meth_d1(self):
pass
def meth_a1(self, name={}):
def meth_a1(self, name = {}):
pass
def meth_a2(self):

View File

@@ -13,14 +13,18 @@ import java.util.List;
*/
public class PyExtractSuperclassTest extends PyClassRefactoringTest {
public void testSimple() throws Exception {
doHelperTest("Foo", "Suppa", null, ".foo");
doHelperTest("Foo", "Suppa", null, true, ".foo");
}
public void testWithSuper() throws Exception {
doHelperTest("Foo", "Suppa", null, ".foo");
doHelperTest("Foo", "Suppa", null, true, ".foo");
}
private void doHelperTest(final String className, final String superclassName, final String expectedError, final String... membersName) throws Exception {
public void testWithImport() throws Exception {
doHelperTest("A", "Suppa", null, false, ".foo");
}
private void doHelperTest(final String className, final String superclassName, final String expectedError, final boolean sameFile, final String... membersName) throws Exception {
try {
String baseName = "/refactoring/extractsuperclass/" + getTestName(true);
myFixture.configureByFile(baseName + ".before.py");
@@ -36,7 +40,9 @@ public class PyExtractSuperclassTest extends PyClassRefactoringTest {
@Override
protected void run() throws Throwable {
//noinspection ConstantConditions
PyExtractSuperclassHelper.extractSuperclass(clazz, members, superclassName, myFixture.getFile().getVirtualFile().getUrl());
final String url = sameFile ? myFixture.getFile().getVirtualFile().getUrl() :
myFixture.getFile().getVirtualFile().getParent().getUrl();
PyExtractSuperclassHelper.extractSuperclass(clazz, members, superclassName, url);
}
}.execute();
myFixture.checkResultByFile(baseName + ".after.py");