PY-20100 Refactor PyImportOptimizer according to comments in IDEA-CR-37663

First of all, removed duplication and slightly simplified convoluted
transformImportStatements() method.

Also, extended existing test to include handling of comments when splitting
"from" imports.
This commit is contained in:
Mikhail Golubev
2018-10-09 21:27:25 +03:00
parent 3a7dddbb34
commit 0eb24a8b98
3 changed files with 124 additions and 107 deletions

View File

@@ -103,6 +103,8 @@ public class PyImportOptimizer implements ImportOptimizer {
// Contains trailing and nested comments of modified (split and joined) imports
private final MultiMap<PyImportStatementBase, PsiComment> myNewImportToInnerComments = MultiMap.create();
private final List<PsiComment> myDanglingComments = new ArrayList<>();
private final PyElementGenerator myGenerator;
private final LanguageLevel myLangLevel;
private ImportSorter(@NotNull PyFile file) {
myFile = file;
@@ -112,6 +114,8 @@ public class PyImportOptimizer implements ImportOptimizer {
for (ImportPriority priority : ImportPriority.values()) {
myGroups.put(priority, new ArrayList<>());
}
myGenerator = PyElementGenerator.getInstance(myFile.getProject());
myLangLevel = LanguageLevel.forElement(myFile);
}
@NotNull
@@ -141,7 +145,7 @@ public class PyImportOptimizer implements ImportOptimizer {
boolean hasTransformedImports = false;
for (ImportPriority priority : ImportPriority.values()) {
final List<PyImportStatementBase> original = myGroups.get(priority);
final List<PyImportStatementBase> transformed = transformImportStatements(original);
final List<PyImportStatementBase> transformed = transformImportsInGroup(original);
hasTransformedImports |= !original.equals(transformed);
myGroups.put(priority, transformed);
}
@@ -167,119 +171,128 @@ public class PyImportOptimizer implements ImportOptimizer {
}
@NotNull
private List<PyImportStatementBase> transformImportStatements(@NotNull List<PyImportStatementBase> imports) {
private List<PyImportStatementBase> transformImportsInGroup(@NotNull List<PyImportStatementBase> imports) {
final List<PyImportStatementBase> result = new ArrayList<>();
final PyElementGenerator generator = PyElementGenerator.getInstance(myFile.getProject());
final LanguageLevel langLevel = LanguageLevel.forElement(myFile);
for (PyImportStatementBase statement : imports) {
if (statement instanceof PyImportStatement) {
final PyImportStatement importStatement = (PyImportStatement)statement;
final PyImportElement[] importElements = importStatement.getImportElements();
// Split combined imports like "import foo, bar as b"
if (importElements.length > 1) {
final List<PyImportStatement> newImports =
ContainerUtil.map(importElements, e -> generator.createImportStatement(langLevel, e.getText(), null));
final PyImportStatement topmostImport;
if (myPySettings.OPTIMIZE_IMPORTS_SORT_IMPORTS) {
topmostImport = Collections.min(newImports, AddImportHelper.getSameGroupImportsComparator(myFile));
}
else {
topmostImport = newImports.get(0);
}
myNewImportToLineComments.putValues(topmostImport, myOldImportToLineComments.get(statement));
myNewImportToInnerComments.putValues(topmostImport, myOldImportToInnerComments.get(statement));
result.addAll(newImports);
}
else {
myNewImportToLineComments.putValues(statement, myOldImportToLineComments.get(statement));
result.add(importStatement);
}
transformPlainImport(result, (PyImportStatement)statement);
}
else if (statement instanceof PyFromImportStatement) {
final PyFromImportStatement fromImport = (PyFromImportStatement)statement;
if (fromImport.isStarImport()) {
myNewImportToLineComments.putValues(fromImport, myOldImportToLineComments.get(fromImport));
result.add(fromImport);
continue;
}
final String source = getNormalizedFromImportSource(fromImport);
final PyImportElement[] importedFromNames = fromImport.getImportElements();
final List<PyImportElement> newFromImportNames = new ArrayList<>();
final Comparator<PyImportElement> fromNamesComparator = getFromNamesComparator();
// We can neither sort, nor combine star imports
final Collection<PyFromImportStatement> sameSourceImports = myOldFromImportBySources.get(source);
if (sameSourceImports.isEmpty()) {
continue;
}
// Keep existing parentheses if we only re-order names inside the import
boolean forceParentheses = sameSourceImports.size() == 1 && fromImport.getLeftParen() != null;
// Join multiple "from" imports with the same source, like "from module import foo; from module import bar as b"
final boolean shouldJoinImports = myPySettings.OPTIMIZE_IMPORTS_JOIN_FROM_IMPORTS_WITH_SAME_SOURCE && sameSourceImports.size() > 1;
final boolean shouldSplitImport = myPySettings.OPTIMIZE_IMPORTS_ALWAYS_SPLIT_FROM_IMPORTS && importedFromNames.length > 1;
if (shouldJoinImports) {
for (PyFromImportStatement sameSourceImport : sameSourceImports) {
ContainerUtil.addAll(newFromImportNames, sameSourceImport.getImportElements());
}
// Remember that we have checked imports with this source already
myOldFromImportBySources.remove(source);
}
else if (!shouldSplitImport && myPySettings.OPTIMIZE_IMPORTS_SORT_NAMES_IN_FROM_IMPORTS) {
if (!Ordering.from(fromNamesComparator).isOrdered(Arrays.asList(importedFromNames))) {
ContainerUtil.addAll(newFromImportNames, importedFromNames);
}
}
final boolean shouldGenerateNewFromImport = !newFromImportNames.isEmpty();
if (shouldGenerateNewFromImport) {
if (myPySettings.OPTIMIZE_IMPORTS_SORT_NAMES_IN_FROM_IMPORTS) {
Collections.sort(newFromImportNames, fromNamesComparator);
}
String importedNames = StringUtil.join(newFromImportNames, ImportSorter::getNormalizedImportElementText, ", ");
if (forceParentheses) {
importedNames = "(" + importedNames + ")";
}
final PyFromImportStatement combinedImport = generator.createFromImportStatement(langLevel, source, importedNames, null);
ContainerUtil.map2LinkedSet(newFromImportNames, e -> (PyImportStatementBase)e.getParent()).forEach(affected -> {
myNewImportToLineComments.putValues(combinedImport, myOldImportToLineComments.get(affected));
myNewImportToInnerComments.putValues(combinedImport, myOldImportToInnerComments.get(affected));
});
result.add(combinedImport);
}
else if (shouldSplitImport) {
final List<PyFromImportStatement> newFromImports = ContainerUtil.map(importedFromNames, importElem -> {
final String name = Objects.toString(importElem.getImportedQName(), "");
final String alias = importElem.getAsName();
return generator.createFromImportStatement(langLevel, source, name, alias);
});
PyFromImportStatement topmostImport;
if (myPySettings.OPTIMIZE_IMPORTS_SORT_IMPORTS) {
topmostImport = Collections.min(newFromImports, AddImportHelper.getSameGroupImportsComparator(myFile));
}
else {
topmostImport = newFromImports.get(0);
}
myNewImportToLineComments.putValues(topmostImport, myOldImportToLineComments.get(fromImport));
myNewImportToInnerComments.putValues(topmostImport, myOldImportToInnerComments.get(fromImport));
result.addAll(newFromImports);
}
else {
myNewImportToLineComments.putValues(fromImport, myOldImportToLineComments.get(fromImport));
result.add(fromImport);
}
transformFromImport(result, (PyFromImportStatement)statement);
}
}
return result;
}
private void transformPlainImport(@NotNull List<PyImportStatementBase> result, @NotNull PyImportStatement importStatement) {
final PyImportElement[] importElements = importStatement.getImportElements();
// Split combined imports like "import foo, bar as b"
if (importElements.length > 1) {
final List<PyImportStatement> newImports =
ContainerUtil.map(importElements, e -> myGenerator.createImportStatement(myLangLevel, e.getText(), null));
replaceOneImportWithSeveral(result, importStatement, newImports);
}
else {
addImportAsIs(result, importStatement);
}
}
private void transformFromImport(@NotNull List<PyImportStatementBase> result, @NotNull PyFromImportStatement fromImport) {
// We can neither sort, nor combine star imports
if (fromImport.isStarImport()) {
addImportAsIs(result, fromImport);
return;
}
final String source = getNormalizedFromImportSource(fromImport);
final PyImportElement[] importedFromNames = fromImport.getImportElements();
final List<PyImportElement> newFromImportNames = new ArrayList<>();
final Comparator<PyImportElement> fromNamesComparator = getFromNamesComparator();
final Collection<PyFromImportStatement> sameSourceImports = myOldFromImportBySources.get(source);
if (sameSourceImports.isEmpty()) {
return;
}
// Keep existing parentheses if we only re-order names inside the import
boolean forceParentheses = sameSourceImports.size() == 1 && fromImport.getLeftParen() != null;
// Join multiple "from" imports with the same source, like "from module import foo; from module import bar as b"
final boolean shouldJoinImports = myPySettings.OPTIMIZE_IMPORTS_JOIN_FROM_IMPORTS_WITH_SAME_SOURCE && sameSourceImports.size() > 1;
final boolean shouldSplitImport = myPySettings.OPTIMIZE_IMPORTS_ALWAYS_SPLIT_FROM_IMPORTS && importedFromNames.length > 1;
if (shouldJoinImports) {
for (PyFromImportStatement sameSourceImport : sameSourceImports) {
ContainerUtil.addAll(newFromImportNames, sameSourceImport.getImportElements());
}
// Remember that we have checked imports with this source already
myOldFromImportBySources.remove(source);
}
else if (!shouldSplitImport && myPySettings.OPTIMIZE_IMPORTS_SORT_NAMES_IN_FROM_IMPORTS) {
if (!Ordering.from(fromNamesComparator).isOrdered(Arrays.asList(importedFromNames))) {
ContainerUtil.addAll(newFromImportNames, importedFromNames);
}
}
final boolean shouldGenerateNewFromImport = !newFromImportNames.isEmpty();
if (shouldGenerateNewFromImport) {
if (myPySettings.OPTIMIZE_IMPORTS_SORT_NAMES_IN_FROM_IMPORTS) {
Collections.sort(newFromImportNames, fromNamesComparator);
}
String importedNames = StringUtil.join(newFromImportNames, ImportSorter::getNormalizedImportElementText, ", ");
if (forceParentheses) {
importedNames = "(" + importedNames + ")";
}
final PyFromImportStatement combinedImport = myGenerator.createFromImportStatement(myLangLevel, source, importedNames, null);
final Set<PyImportStatementBase> oldImports = ContainerUtil.map2LinkedSet(newFromImportNames,
e -> (PyImportStatementBase)e.getParent());
replaceSeveralImportsWithOne(result, oldImports, combinedImport);
}
else if (shouldSplitImport) {
final List<PyFromImportStatement> newFromImports = ContainerUtil.map(importedFromNames, importElem -> {
final String name = Objects.toString(importElem.getImportedQName(), "");
final String alias = importElem.getAsName();
return myGenerator.createFromImportStatement(myLangLevel, source, name, alias);
});
replaceOneImportWithSeveral(result, fromImport, newFromImports);
}
else {
addImportAsIs(result, fromImport);
}
}
private void replaceSeveralImportsWithOne(@NotNull List<PyImportStatementBase> result,
@NotNull Collection<? extends PyImportStatementBase> oldImports,
@NotNull PyFromImportStatement newImport) {
for (PyImportStatementBase replaced : oldImports) {
myNewImportToLineComments.putValues(newImport, myOldImportToLineComments.get(replaced));
myNewImportToInnerComments.putValues(newImport, myOldImportToInnerComments.get(replaced));
}
result.add(newImport);
}
private void replaceOneImportWithSeveral(@NotNull List<PyImportStatementBase> result,
@NotNull PyImportStatementBase oldImport,
@NotNull Collection<? extends PyImportStatementBase> newImports) {
final PyImportStatementBase topmostImport;
if (myPySettings.OPTIMIZE_IMPORTS_SORT_IMPORTS) {
topmostImport = Collections.min(newImports, AddImportHelper.getSameGroupImportsComparator(myFile));
}
else {
topmostImport = ContainerUtil.getFirstItem(newImports);
}
myNewImportToLineComments.putValues(topmostImport, myOldImportToLineComments.get(oldImport));
myNewImportToInnerComments.putValues(topmostImport, myOldImportToInnerComments.get(oldImport));
result.addAll(newImports);
}
private void addImportAsIs(@NotNull List<PyImportStatementBase> result, @NotNull PyImportStatementBase oldImport) {
myNewImportToLineComments.putValues(oldImport, myOldImportToLineComments.get(oldImport));
result.add(oldImport);
}
@NotNull
private static String getNormalizedImportElementText(@NotNull PyImportElement element) {
// Remove comments, line feeds and backslashes
@@ -373,8 +386,7 @@ public class PyImportOptimizer implements ImportOptimizer {
}
final Project project = anchor.getProject();
final PyElementGenerator generator = PyElementGenerator.getInstance(project);
final PyFile file = (PyFile)generator.createDummyFile(LanguageLevel.forElement(anchor), content.toString());
final PyFile file = (PyFile)myGenerator.createDummyFile(myLangLevel, content.toString());
final PyFile reformattedFile = (PyFile)CodeStyleManager.getInstance(project).reformat(file);
final List<PyImportStatementBase> newImportBlock = reformattedFile.getImportBlock();
assert newImportBlock != null;

View File

@@ -1,5 +1,6 @@
# line comment
from mod import bar
from mod import baz
from mod import baz # inner comment
from mod import foo
print(foo, bar, baz)

View File

@@ -1,4 +1,8 @@
from mod import foo, baz
# line comment
from mod import (
foo, # inner comment
baz
)
from mod import bar
print(foo, bar, baz)