optimize imports sorts them according to PEP-8 (PY-2367)

This commit is contained in:
Dmitry Jemerov
2012-10-02 13:52:55 +02:00
parent 543d2a61c5
commit 6bb86279d1
16 changed files with 225 additions and 48 deletions

View File

@@ -56,9 +56,9 @@ public abstract class PyElementGenerator {
@NotNull @NotNull
public abstract PyCallExpression createCallExpression(final LanguageLevel langLevel, String functionName); public abstract PyCallExpression createCallExpression(final LanguageLevel langLevel, String functionName);
public abstract PyImportStatement createImportStatementFromText(String text); public abstract PyImportStatement createImportStatementFromText(final LanguageLevel languageLevel, String text);
public abstract PyImportElement createImportElement(String name); public abstract PyImportElement createImportElement(final LanguageLevel languageLevel, String name);
@NotNull @NotNull
public abstract <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text); public abstract <T> T createFromText(LanguageLevel langLevel, Class<T> aClass, final String text);

View File

@@ -62,4 +62,9 @@ public interface PyFile extends PyElement, PsiFile, PyDocStringOwner, ScopeOwner
* @return the deprecation message or null if the function is not deprecated. * @return the deprecation message or null if the function is not deprecated.
*/ */
String getDeprecationMessage(); String getDeprecationMessage();
/**
* Returns the sequential list of import statements in the beginning of the file.
*/
List<PyImportStatementBase> getImportBlock();
} }

View File

@@ -10,11 +10,12 @@ import com.intellij.openapi.util.TextRange;
import com.intellij.openapi.util.text.LineTokenizer; import com.intellij.openapi.util.text.LineTokenizer;
import com.intellij.openapi.util.text.StringUtil; import com.intellij.openapi.util.text.StringUtil;
import com.intellij.psi.PsiElement; import com.intellij.psi.PsiElement;
import com.intellij.psi.TokenType;
import com.intellij.psi.tree.IElementType; import com.intellij.psi.tree.IElementType;
import com.jetbrains.python.psi.PyFile; import com.jetbrains.python.psi.PyFile;
import com.jetbrains.python.psi.PyFileElementType; import com.jetbrains.python.psi.PyFileElementType;
import com.jetbrains.python.psi.PyImportStatementBase;
import com.jetbrains.python.psi.PyStringLiteralExpression; import com.jetbrains.python.psi.PyStringLiteralExpression;
import com.jetbrains.python.psi.impl.PyFileImpl;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable; import org.jetbrains.annotations.Nullable;
@@ -35,24 +36,12 @@ public class PythonFoldingBuilder extends CustomFoldingBuilder implements DumbAw
private static void appendDescriptors(ASTNode node, List<FoldingDescriptor> descriptors) { private static void appendDescriptors(ASTNode node, List<FoldingDescriptor> descriptors) {
if (node.getElementType() instanceof PyFileElementType) { if (node.getElementType() instanceof PyFileElementType) {
ASTNode firstImport = node.getFirstChildNode(); final List<PyImportStatementBase> imports = ((PyFile)node.getPsi()).getImportBlock();
while(firstImport != null && !isImport(firstImport, false)) { if (imports.size() > 1) {
firstImport = firstImport.getTreeNext(); final PyImportStatementBase firstImport = imports.get(0);
} final PyImportStatementBase lastImport = imports.get(imports.size()-1);
if (firstImport != null) { descriptors.add(new FoldingDescriptor(firstImport, new TextRange(firstImport.getTextRange().getStartOffset(),
ASTNode lastImport = firstImport.getTreeNext(); lastImport.getTextRange().getEndOffset())));
while(lastImport != null && isImport(lastImport.getTreeNext(), true)) {
lastImport = lastImport.getTreeNext();
}
if (lastImport != null) {
while (lastImport.getElementType() == TokenType.WHITE_SPACE) {
lastImport = lastImport.getTreePrev();
}
if (isImport(lastImport, false) && firstImport != lastImport) {
descriptors.add(new FoldingDescriptor(firstImport, new TextRange(firstImport.getStartOffset(),
lastImport.getTextRange().getEndOffset())));
}
}
} }
} }
else if (node.getElementType() == PyElementTypes.STATEMENT_LIST) { else if (node.getElementType() == PyElementTypes.STATEMENT_LIST) {
@@ -116,18 +105,9 @@ public class PythonFoldingBuilder extends CustomFoldingBuilder implements DumbAw
return null; return null;
} }
private static boolean isImport(ASTNode node, boolean orWhitespace) {
if (node == null) return false;
IElementType elementType = node.getElementType();
if (orWhitespace && elementType == TokenType.WHITE_SPACE) {
return true;
}
return elementType == PyElementTypes.IMPORT_STATEMENT || elementType == PyElementTypes.FROM_IMPORT_STATEMENT;
}
@Override @Override
protected String getLanguagePlaceholderText(@NotNull ASTNode node, @NotNull TextRange range) { protected String getLanguagePlaceholderText(@NotNull ASTNode node, @NotNull TextRange range) {
if (isImport(node, false)) { if (PyFileImpl.isImport(node, false)) {
return "import ..."; return "import ...";
} }
if (node.getElementType() == PyElementTypes.STRING_LITERAL_EXPRESSION) { if (node.getElementType() == PyElementTypes.STRING_LITERAL_EXPRESSION) {
@@ -143,7 +123,7 @@ public class PythonFoldingBuilder extends CustomFoldingBuilder implements DumbAw
@Override @Override
protected boolean isRegionCollapsedByDefault(@NotNull ASTNode node) { protected boolean isRegionCollapsedByDefault(@NotNull ASTNode node) {
if (isImport(node, false)) { if (PyFileImpl.isImport(node, false)) {
return CodeFoldingSettings.getInstance().COLLAPSE_IMPORTS; return CodeFoldingSettings.getInstance().COLLAPSE_IMPORTS;
} }
if (node.getElementType() == PyElementTypes.STRING_LITERAL_EXPRESSION) { if (node.getElementType() == PyElementTypes.STRING_LITERAL_EXPRESSION) {

View File

@@ -147,8 +147,9 @@ public class AddImportHelper {
} }
} }
final PyImportStatement importNodeToInsert = PyElementGenerator.getInstance(file.getProject()).createImportStatementFromText( final PyElementGenerator generator = PyElementGenerator.getInstance(file.getProject());
"import " + name + as_clause); final LanguageLevel languageLevel = LanguageLevel.forElement(file);
final PyImportStatement importNodeToInsert = generator.createImportStatementFromText(languageLevel, "import " + name + as_clause);
try { try {
file.addBefore(importNodeToInsert, getInsertPosition(file, name, priority)); file.addBefore(importNodeToInsert, getInsertPosition(file, name, priority));
} }
@@ -167,15 +168,15 @@ public class AddImportHelper {
* @param asName optional name for 'as' clause * @param asName optional name for 'as' clause
*/ */
public static void addImportFromStatement(PsiFile file, String from, String name, @Nullable String asName, ImportPriority priority) { public static void addImportFromStatement(PsiFile file, String from, String name, @Nullable String asName, ImportPriority priority) {
String as_clause; String asClause;
if (asName == null) { if (asName == null) {
as_clause = ""; asClause = "";
} }
else { else {
as_clause = " as " + asName; asClause = " as " + asName;
} }
final PyFromImportStatement importNodeToInsert = PyElementGenerator.getInstance(file.getProject()).createFromText( final PyFromImportStatement importNodeToInsert = PyElementGenerator.getInstance(file.getProject()).createFromText(
LanguageLevel.getDefault(), PyFromImportStatement.class, "from " + from + " import " + name + as_clause); LanguageLevel.forElement(file), PyFromImportStatement.class, "from " + from + " import " + name + asClause);
try { try {
file.addBefore(importNodeToInsert, getInsertPosition(file, from, priority)); file.addBefore(importNodeToInsert, getInsertPosition(file, from, priority));
} }
@@ -201,7 +202,8 @@ public class AddImportHelper {
return false; return false;
} }
} }
PyImportElement importElement = PyElementGenerator.getInstance(file.getProject()).createImportElement(name); final PyElementGenerator generator = PyElementGenerator.getInstance(file.getProject());
PyImportElement importElement = generator.createImportElement(LanguageLevel.forElement(file), name);
existingImport.add(importElement); existingImport.add(importElement);
return true; return true;
} }

View File

@@ -144,7 +144,7 @@ public class ImportFromExistingAction implements QuestionAction {
PsiElement parent = src.getParent(); PsiElement parent = src.getParent();
if (parent instanceof PyFromImportStatement) { if (parent instanceof PyFromImportStatement) {
// add another import element right after the one we got // add another import element right after the one we got
PsiElement newImportElement = gen.createImportElement(myName); PsiElement newImportElement = gen.createImportElement(LanguageLevel.getDefault(), myName);
parent.add(newImportElement); parent.add(newImportElement);
} }
else { // just 'import' else { // just 'import'

View File

@@ -2,13 +2,18 @@ package com.jetbrains.python.codeInsight.imports;
import com.intellij.codeInspection.LocalInspectionToolSession; import com.intellij.codeInspection.LocalInspectionToolSession;
import com.intellij.lang.ImportOptimizer; import com.intellij.lang.ImportOptimizer;
import com.intellij.psi.PsiElement;
import com.intellij.psi.PsiFile; import com.intellij.psi.PsiFile;
import com.intellij.psi.PsiFileSystemItem;
import com.jetbrains.python.formatter.PyBlock;
import com.jetbrains.python.inspections.PyUnresolvedReferencesInspection; import com.jetbrains.python.inspections.PyUnresolvedReferencesInspection;
import com.jetbrains.python.psi.PyElement; import com.jetbrains.python.psi.*;
import com.jetbrains.python.psi.PyRecursiveElementVisitor;
import org.jetbrains.annotations.NotNull; import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import java.util.ArrayList;
import java.util.Collections; import java.util.Collections;
import java.util.List;
/** /**
* @author yole * @author yole
@@ -19,7 +24,7 @@ public class PyImportOptimizer implements ImportOptimizer {
} }
@NotNull @NotNull
public Runnable processFile(@NotNull PsiFile file) { public Runnable processFile(@NotNull final PsiFile file) {
final LocalInspectionToolSession session = new LocalInspectionToolSession(file, 0, file.getTextLength()); final LocalInspectionToolSession session = new LocalInspectionToolSession(file, 0, file.getTextLength());
final PyUnresolvedReferencesInspection.Visitor visitor = new PyUnresolvedReferencesInspection.Visitor(null, final PyUnresolvedReferencesInspection.Visitor visitor = new PyUnresolvedReferencesInspection.Visitor(null,
session, session,
@@ -34,7 +39,105 @@ public class PyImportOptimizer implements ImportOptimizer {
return new Runnable() { return new Runnable() {
public void run() { public void run() {
visitor.optimizeImports(); visitor.optimizeImports();
if (file instanceof PyFile) {
new ImportSorter((PyFile) file).run();
}
} }
}; };
} }
private static class ImportSorter {
private final PyFile myFile;
private final List<PyImportStatementBase> myBuiltinImports = new ArrayList<PyImportStatementBase>();
private final List<PyImportStatementBase> myThirdPartyImports = new ArrayList<PyImportStatementBase>();
private final List<PyImportStatementBase> myProjectImports = new ArrayList<PyImportStatementBase>();
private final List<PyImportStatementBase> myImportBlock;
private final PyElementGenerator myGenerator;
private boolean myMissorted = false;
private ImportSorter(PyFile file) {
myFile = file;
myImportBlock = myFile.getImportBlock();
myGenerator = PyElementGenerator.getInstance(myFile.getProject());
}
public void run() {
if (myImportBlock.isEmpty()) {
return;
}
LanguageLevel langLevel = LanguageLevel.forElement(myFile);
for (PyImportStatementBase importStatement : myImportBlock) {
if (importStatement instanceof PyImportStatement && importStatement.getImportElements().length > 1) {
for (PyImportElement importElement : importStatement.getImportElements()) {
myMissorted = true;
PsiElement toImport = importElement.resolve();
final PyImportStatement splitImport = myGenerator.createImportStatementFromText(langLevel, "import " + importElement.getText());
prioritize(splitImport, toImport);
}
}
else {
PsiElement toImport;
if (importStatement instanceof PyFromImportStatement) {
toImport = ((PyFromImportStatement) importStatement).resolveImportSource();
}
else {
toImport = importStatement.getImportElements()[0].resolve();
}
prioritize(importStatement, toImport);
}
}
if (myMissorted) {
applyResults();
}
}
private void prioritize(PyImportStatementBase importStatement, @Nullable PsiElement toImport) {
if (toImport != null && !(toImport instanceof PsiFileSystemItem)) {
toImport = toImport.getContainingFile();
}
final AddImportHelper.ImportPriority priority = toImport == null
? AddImportHelper.ImportPriority.PROJECT
: AddImportHelper.getImportPriority(myFile, (PsiFileSystemItem)toImport);
if (priority == AddImportHelper.ImportPriority.BUILTIN) {
myBuiltinImports.add(importStatement);
if (!myThirdPartyImports.isEmpty() || !myProjectImports.isEmpty()) {
myMissorted = true;
}
}
else if (priority == AddImportHelper.ImportPriority.THIRD_PARTY) {
myThirdPartyImports.add(importStatement);
if (!myProjectImports.isEmpty()) {
myMissorted = true;
}
}
else {
myProjectImports.add(importStatement);
}
}
private void applyResults() {
markGroupBegin(myThirdPartyImports);
markGroupBegin(myProjectImports);
addImports(myBuiltinImports);
addImports(myThirdPartyImports);
addImports(myProjectImports);
PsiElement lastElement = myImportBlock.get(myImportBlock.size()-1);
myFile.deleteChildRange(myImportBlock.get(0), lastElement);
for (PyImportStatementBase anImport : myBuiltinImports) {
anImport.putCopyableUserData(PyBlock.IMPORT_GROUP_BEGIN, null);
}
}
private static void markGroupBegin(List<PyImportStatementBase> imports) {
if (imports.size() > 0) {
imports.get(0).putCopyableUserData(PyBlock.IMPORT_GROUP_BEGIN, true);
}
}
private void addImports(final List<PyImportStatementBase> imports) {
for (PyImportStatementBase newImport: imports) {
myFile.addBefore(newImport, myImportBlock.get(0));
}
}
}
} }

View File

@@ -3,6 +3,7 @@ package com.jetbrains.python.formatter;
import com.intellij.formatting.*; import com.intellij.formatting.*;
import com.intellij.lang.ASTNode; import com.intellij.lang.ASTNode;
import com.intellij.openapi.editor.Document; import com.intellij.openapi.editor.Document;
import com.intellij.openapi.util.Key;
import com.intellij.openapi.util.TextRange; import com.intellij.openapi.util.TextRange;
import com.intellij.psi.*; import com.intellij.psi.*;
import com.intellij.psi.impl.source.tree.TreeUtil; import com.intellij.psi.impl.source.tree.TreeUtil;
@@ -38,6 +39,8 @@ public class PyBlock implements ASTBlock {
private Alignment myChildAlignment; private Alignment myChildAlignment;
private static final boolean DUMP_FORMATTING_BLOCKS = false; private static final boolean DUMP_FORMATTING_BLOCKS = false;
public static final Key<Boolean> IMPORT_GROUP_BEGIN = Key.create("com.jetbrains.python.formatter.importGroupBegin");
private static final TokenSet ourListElementTypes = TokenSet.create(PyElementTypes.LIST_LITERAL_EXPRESSION, private static final TokenSet ourListElementTypes = TokenSet.create(PyElementTypes.LIST_LITERAL_EXPRESSION,
PyElementTypes.LIST_COMP_EXPRESSION, PyElementTypes.LIST_COMP_EXPRESSION,
PyElementTypes.DICT_COMP_EXPRESSION, PyElementTypes.DICT_COMP_EXPRESSION,
@@ -330,6 +333,14 @@ public class PyBlock implements ASTBlock {
@Nullable @Nullable
public Spacing getSpacing(Block child1, @NotNull Block child2) { public Spacing getSpacing(Block child1, @NotNull Block child2) {
if (child1 instanceof ASTBlock && child2 instanceof ASTBlock) {
final PsiElement psi1 = ((ASTBlock)child1).getNode().getPsi();
final PsiElement psi2 = ((ASTBlock)child2).getNode().getPsi();
if (psi1 instanceof PyImportStatementBase && psi2 instanceof PyImportStatementBase &&
psi2.getCopyableUserData(IMPORT_GROUP_BEGIN) != null) {
return Spacing.createSpacing(0, 0, 2, true, 1);
}
}
return myContext.getSpacingBuilder().getSpacing(this, child1, child2); return myContext.getSpacingBuilder().getSpacing(this, child1, child2);
} }

View File

@@ -220,14 +220,15 @@ public class PyElementGeneratorImpl extends PyElementGenerator {
throw new IllegalArgumentException("Invalid call expression text " + functionName); throw new IllegalArgumentException("Invalid call expression text " + functionName);
} }
public PyImportStatement createImportStatementFromText(final String text) { public PyImportStatement createImportStatementFromText(final LanguageLevel languageLevel,
final PsiFile dummyFile = createDummyFile(LanguageLevel.getDefault(), text); final String text) {
final PsiFile dummyFile = createDummyFile(languageLevel, text);
return (PyImportStatement)dummyFile.getFirstChild(); return (PyImportStatement)dummyFile.getFirstChild();
} }
@Override @Override
public PyImportElement createImportElement(String name) { public PyImportElement createImportElement(final LanguageLevel languageLevel, String name) {
return createFromText(LanguageLevel.getDefault(), PyImportElement.class, "from foo import " + name, new int[]{0, 6}); return createFromText(languageLevel, PyImportElement.class, "from foo import " + name, new int[]{0, 6});
} }
static final int[] FROM_ROOT = new int[]{0}; static final int[] FROM_ROOT = new int[]{0};

View File

@@ -1,6 +1,7 @@
package com.jetbrains.python.psi.impl; package com.jetbrains.python.psi.impl;
import com.intellij.extapi.psi.PsiFileBase; import com.intellij.extapi.psi.PsiFileBase;
import com.intellij.lang.ASTNode;
import com.intellij.lang.Language; import com.intellij.lang.Language;
import com.intellij.openapi.fileTypes.FileType; import com.intellij.openapi.fileTypes.FileType;
import com.intellij.openapi.util.Key; import com.intellij.openapi.util.Key;
@@ -9,6 +10,7 @@ import com.intellij.psi.*;
import com.intellij.psi.scope.PsiScopeProcessor; import com.intellij.psi.scope.PsiScopeProcessor;
import com.intellij.psi.stubs.StubElement; import com.intellij.psi.stubs.StubElement;
import com.intellij.psi.templateLanguages.TemplateLanguageFileViewProvider; import com.intellij.psi.templateLanguages.TemplateLanguageFileViewProvider;
import com.intellij.psi.tree.IElementType;
import com.intellij.psi.util.PsiModificationTracker; import com.intellij.psi.util.PsiModificationTracker;
import com.intellij.reference.SoftReference; import com.intellij.reference.SoftReference;
import com.intellij.util.IncorrectOperationException; import com.intellij.util.IncorrectOperationException;
@@ -661,6 +663,26 @@ public class PyFileImpl extends PsiFileBase implements PyFile, PyExpression {
return extractDeprecationMessage(); return extractDeprecationMessage();
} }
@Override
public List<PyImportStatementBase> getImportBlock() {
List<PyImportStatementBase> result = new ArrayList<PyImportStatementBase>();
ASTNode firstImport = getNode().getFirstChildNode();
while(firstImport != null && !isImport(firstImport, false)) {
firstImport = firstImport.getTreeNext();
}
if (firstImport != null) {
result.add(firstImport.getPsi(PyImportStatementBase.class));
ASTNode lastImport = firstImport.getTreeNext();
while(lastImport != null && isImport(lastImport.getTreeNext(), true)) {
if (isImport(lastImport, false)) {
result.add(lastImport.getPsi(PyImportStatementBase.class));
}
lastImport = lastImport.getTreeNext();
}
}
return result;
}
public String extractDeprecationMessage() { public String extractDeprecationMessage() {
return PyFunctionImpl.extractDeprecationMessage(getStatements()); return PyFunctionImpl.extractDeprecationMessage(getStatements());
} }
@@ -722,4 +744,13 @@ public class PyFileImpl extends PsiFileBase implements PyFile, PyExpression {
return new ArrayList<String>(); return new ArrayList<String>();
} }
} }
public static boolean isImport(ASTNode node, boolean orWhitespace) {
if (node == null) return false;
IElementType elementType = node.getElementType();
if (orWhitespace && elementType == TokenType.WHITE_SPACE) {
return true;
}
return elementType == PyElementTypes.IMPORT_STATEMENT || elementType == PyElementTypes.FROM_IMPORT_STATEMENT;
}
} }

View File

@@ -0,0 +1,7 @@
"""This is a module-level docstring."""
<fold text='import ...'>import os
import sys</fold>
print os.path
print sys.name

View File

@@ -0,0 +1,9 @@
import sys
import datetime
import foo
from bar import *
sys.path
datetime.datetime

View File

@@ -0,0 +1,7 @@
import foo
import sys
from bar import *
import datetime
sys.path
datetime.datetime

View File

@@ -0,0 +1,5 @@
import sys
import datetime
sys.path
datetime.time

View File

@@ -0,0 +1,4 @@
import sys, datetime
sys.path
datetime.time

View File

@@ -21,4 +21,8 @@ public class PyFoldingTest extends PyTestCase {
public void testCustomFolding() { public void testCustomFolding() {
doTest(); doTest();
} }
public void testImportBlock() {
doTest();
}
} }

View File

@@ -38,7 +38,15 @@ public class PyOptimizeImportsTest extends PyTestCase {
public void testSuppressed() { // PY-5228 public void testSuppressed() { // PY-5228
doTest(); doTest();
} }
public void testSplit() {
doTest();
}
public void testOrder() {
doTest();
}
private void doTest() { private void doTest() {
myFixture.configureByFile("optimizeImports/" + getTestName(true) + ".py"); myFixture.configureByFile("optimizeImports/" + getTestName(true) + ".py");