change PSI for nested comprehensions (PY-3030)

This commit is contained in:
Dmitry Jemerov
2011-03-23 16:30:14 +01:00
parent 00d5bc4721
commit 31491c42a2
13 changed files with 111 additions and 26 deletions

View File

@@ -144,9 +144,7 @@ public class ExpressionParsing extends Parsing {
nextToken();
break;
}
if (myBuilder.getTokenType() == PyTokenTypes.FOR_KEYWORD) {
expr.done(exprType);
expr = expr.precede();
if (atToken(PyTokenTypes.FOR_KEYWORD)) {
continue;
}
myBuilder.error(message("PARSE.expected.for.or.bracket"));

View File

@@ -0,0 +1,7 @@
package com.jetbrains.python.psi;
/**
* @author yole
*/
public interface ComprehensionComponent {
}

View File

@@ -5,7 +5,7 @@ package com.jetbrains.python.psi;
* User: dcheryasov
* Date: Jul 31, 2008
*/
public interface ComprhForComponent {
PyExpression getIteratorVariable();
PyExpression getIteratedList();
public interface ComprhForComponent extends ComprehensionComponent {
PyExpression getIteratorVariable();
PyExpression getIteratedList();
}

View File

@@ -7,7 +7,7 @@ import org.jetbrains.annotations.Nullable;
* User: dcheryasov
* Date: Jul 31, 2008
*/
public interface ComprhIfComponent {
public interface ComprhIfComponent extends ComprehensionComponent {
@Nullable
PyExpression getTest();
}

View File

@@ -7,6 +7,7 @@ import java.util.List;
*/
public interface PyComprehensionElement extends PyExpression, NameDefiner {
PyExpression getResultExpression();
List<ComprehensionComponent> getComponents();
List<ComprhForComponent> getForComponents();
List<ComprhIfComponent> getIfComponents();
}

View File

@@ -37,8 +37,18 @@ public abstract class PyComprehensionElementImpl extends PyElementImpl implement
* @return all "for components"
*/
public List<ComprhForComponent> getForComponents() {
final List<ComprhForComponent> list = new ArrayList<ComprhForComponent>(5);
visitComponents(new ComprehensionElementVisitor() {
@Override
void visitForComponent(ComprhForComponent component) {
list.add(component);
}
});
return list;
}
private void visitComponents(ComprehensionElementVisitor visitor) {
ASTNode node = getNode().getFirstChildNode();
List<ComprhForComponent> list = new ArrayList<ComprhForComponent>(5);
while (node != null) {
IElementType type = node.getElementType();
ASTNode next = getNextExpression(node);
@@ -48,7 +58,7 @@ public abstract class PyComprehensionElementImpl extends PyElementImpl implement
if (next2 == null) break;
final PyExpression variable = (PyExpression)next.getPsi();
final PyExpression iterated = (PyExpression)next2.getPsi();
list.add(new ComprhForComponent() {
visitor.visitForComponent(new ComprhForComponent() {
public PyExpression getIteratorVariable() {
return variable;
}
@@ -58,21 +68,9 @@ public abstract class PyComprehensionElementImpl extends PyElementImpl implement
}
});
}
node = node.getTreeNext();
}
return list;
}
public List<ComprhIfComponent> getIfComponents() {
ASTNode node = getNode().getFirstChildNode();
List<ComprhIfComponent> list = new ArrayList<ComprhIfComponent>(5);
while (node != null) {
IElementType type = node.getElementType();
ASTNode next = getNextExpression(node);
if (next == null) break;
if (type == PyTokenTypes.IF_KEYWORD) {
else if (type == PyTokenTypes.IF_KEYWORD) {
final PyExpression test = (PyExpression)next.getPsi();
list.add(new ComprhIfComponent() {
visitor.visitIfComponent(new ComprhIfComponent() {
public PyExpression getTest() {
return test;
}
@@ -80,6 +78,32 @@ public abstract class PyComprehensionElementImpl extends PyElementImpl implement
}
node = node.getTreeNext();
}
}
public List<ComprhIfComponent> getIfComponents() {
final List<ComprhIfComponent> list = new ArrayList<ComprhIfComponent>(5);
visitComponents(new ComprehensionElementVisitor() {
@Override
void visitIfComponent(ComprhIfComponent component) {
list.add(component);
}
});
return list;
}
public List<ComprehensionComponent> getComponents() {
final List<ComprehensionComponent> list = new ArrayList<ComprehensionComponent>(5);
visitComponents(new ComprehensionElementVisitor() {
@Override
void visitForComponent(ComprhForComponent component) {
list.add(component);
}
@Override
void visitIfComponent(ComprhIfComponent component) {
list.add(component);
}
});
return list;
}
@@ -117,4 +141,12 @@ public abstract class PyComprehensionElementImpl extends PyElementImpl implement
public boolean mustResolveOutside() {
return false;
}
abstract class ComprehensionElementVisitor {
void visitIfComponent(ComprhIfComponent component) {
}
void visitForComponent(ComprhForComponent component) {
}
}
}

View File

@@ -148,9 +148,7 @@ public class PyResolveUtil {
// maybe we're capped by a class? param lists are not capped though syntactically inside the function.
if (is_outside_param_list && refersFromMethodToClass(capFunction, seeker)) continue;
// names defined in a comprehension element are only visible inside it or the list comp expressions directly above it
if (seeker instanceof PyComprehensionElement &&
!(seeker.getParent() instanceof PyComprehensionElement) &&
!PsiTreeUtil.isAncestor(seeker, elt, false)) {
if (seeker instanceof PyComprehensionElement && !PsiTreeUtil.isAncestor(seeker, elt, false)) {
continue;
}
// check what we got

View File

@@ -0,0 +1 @@
gen_object = (xx for bar in lst1 for xx in bar)

View File

@@ -0,0 +1,32 @@
PyFile:NestedGenerators.py
PyAssignmentStatement
PyTargetExpression: gen_object
PsiElement(Py:IDENTIFIER)('gen_object')
PsiWhiteSpace(' ')
PsiElement(Py:EQ)('=')
PsiWhiteSpace(' ')
PyGeneratorExpression
PsiElement(Py:LPAR)('(')
PyReferenceExpression: xx
PsiElement(Py:IDENTIFIER)('xx')
PsiWhiteSpace(' ')
PsiElement(Py:FOR_KEYWORD)('for')
PsiWhiteSpace(' ')
PyTargetExpression: bar
PsiElement(Py:IDENTIFIER)('bar')
PsiWhiteSpace(' ')
PsiElement(Py:IN_KEYWORD)('in')
PsiWhiteSpace(' ')
PyReferenceExpression: lst1
PsiElement(Py:IDENTIFIER)('lst1')
PsiWhiteSpace(' ')
PsiElement(Py:FOR_KEYWORD)('for')
PsiWhiteSpace(' ')
PyTargetExpression: xx
PsiElement(Py:IDENTIFIER)('xx')
PsiWhiteSpace(' ')
PsiElement(Py:IN_KEYWORD)('in')
PsiWhiteSpace(' ')
PyReferenceExpression: bar
PsiElement(Py:IDENTIFIER)('bar')
PsiElement(Py:RPAR)(')')

View File

@@ -0,0 +1,4 @@
def somefunc():
xx, yy = 1, 1
gen_object = (xx for y<caret>y in lst1 for xx in yy)
print(xx, yy)

View File

@@ -0,0 +1,4 @@
def somefunc():
xx, yy = 1, 1
gen_object = (xx for bar in lst1 for xx in bar)
print(xx, yy)

View File

@@ -259,6 +259,10 @@ public class PythonParsingTest extends ParsingTestCase {
doTest();
}
public void testNestedGenerators() { // PY-3030
doTest();
}
public void doTest() {
doTest(LanguageLevel.PYTHON25);
}

View File

@@ -54,6 +54,10 @@ public class PyRenameTest extends PyLightFixtureTestCase {
doTest("bar");
}
public void testRenameLocalWithNestedGenerators() { // PY-3030
doTest("bar");
}
public void testUpdateAll() { // PY-986
doTest("bar");
}