fixed renaming inside comprehensions, added binding in expressions
Showing
2 changed files
with
54 additions
and
37 deletions
... | @@ -66,9 +66,33 @@ def unparse (st) : | ... | @@ -66,9 +66,33 @@ def unparse (st) : |
66 | class Renamer (ast.NodeTransformer) : | 66 | class Renamer (ast.NodeTransformer) : |
67 | def __init__ (self, map_names) : | 67 | def __init__ (self, map_names) : |
68 | ast.NodeTransformer.__init__(self) | 68 | ast.NodeTransformer.__init__(self) |
69 | - self.map = map_names | 69 | + self.map = [map_names] |
70 | + def visit_ListComp (self, node) : | ||
71 | + bind = self.map[-1].copy() | ||
72 | + for comp in node.generators : | ||
73 | + for name in getvars(comp.target) : | ||
74 | + if name in bind : | ||
75 | + del bind[name] | ||
76 | + self.map.append(bind) | ||
77 | + node.elt = self.visit(node.elt) | ||
78 | + self.map.pop(-1) | ||
79 | + return node | ||
80 | + def visit_SetComp (self, node) : | ||
81 | + return self.visit_ListComp(node) | ||
82 | + def visit_DictComp (self, node) : | ||
83 | + bind = self.map[-1].copy() | ||
84 | + for comp in node.generators : | ||
85 | + for name in getvars(comp.target) : | ||
86 | + if name in bind : | ||
87 | + del bind[name] | ||
88 | + self.map.append(bind) | ||
89 | + node.key = self.visit(node.key) | ||
90 | + node.value = self.visit(node.value) | ||
91 | + self.map.pop(-1) | ||
92 | + return node | ||
70 | def visit_Name (self, node) : | 93 | def visit_Name (self, node) : |
71 | - return ast.copy_location(ast.Name(id=self.map.get(node.id, node.id), | 94 | + return ast.copy_location(ast.Name(id=self.map[-1].get(node.id, |
95 | + node.id), | ||
72 | ctx=ast.Load()), node) | 96 | ctx=ast.Load()), node) |
73 | 97 | ||
74 | def rename (expr, map={}, **ren) : | 98 | def rename (expr, map={}, **ren) : |
... | @@ -77,12 +101,39 @@ def rename (expr, map={}, **ren) : | ... | @@ -77,12 +101,39 @@ def rename (expr, map={}, **ren) : |
77 | '((t + y) < z)' | 101 | '((t + y) < z)' |
78 | >>> rename('x+y<z+f(3,t)', f='g', t='z', z='t') | 102 | >>> rename('x+y<z+f(3,t)', f='g', t='z', z='t') |
79 | '((x + y) < (t + g(3, z)))' | 103 | '((x + y) < (t + g(3, z)))' |
104 | + >>> rename('[x+y for x in range(3)]', x='z') | ||
105 | + '[(x + y) for x in range(3)]' | ||
106 | + >>> rename('[x+y for x in range(3)]', y='z') | ||
107 | + '[(x + z) for x in range(3)]' | ||
80 | """ | 108 | """ |
81 | map_names = dict(map) | 109 | map_names = dict(map) |
82 | map_names.update(ren) | 110 | map_names.update(ren) |
83 | transf = Renamer(map_names) | 111 | transf = Renamer(map_names) |
84 | return unparse(transf.visit(ast.parse(expr))) | 112 | return unparse(transf.visit(ast.parse(expr))) |
85 | 113 | ||
114 | +class Binder (Renamer) : | ||
115 | + def visit_Name (self, node) : | ||
116 | + if node.id in self.map[-1] : | ||
117 | + return self.map[-1][node.id] | ||
118 | + else : | ||
119 | + return node | ||
120 | + | ||
121 | +def bind (expr, map={}, **ren) : | ||
122 | + """ | ||
123 | + >>> bind('x+y<z', x=ast.Num(n=2)) | ||
124 | + '((2 + y) < z)' | ||
125 | + >>> bind('x+y<z', y=ast.Num(n=2)) | ||
126 | + '((x + 2) < z)' | ||
127 | + >>> bind('[x+y for x in range(3)]', x=ast.Num(n=2)) | ||
128 | + '[(x + y) for x in range(3)]' | ||
129 | + >>> bind('[x+y for x in range(3)]', y=ast.Num(n=2)) | ||
130 | + '[(x + 2) for x in range(3)]' | ||
131 | + """ | ||
132 | + map_names = dict(map) | ||
133 | + map_names.update(ren) | ||
134 | + transf = Binder(map_names) | ||
135 | + return unparse(transf.visit(ast.parse(expr))) | ||
136 | + | ||
86 | if __name__ == "__main__" : | 137 | if __name__ == "__main__" : |
87 | import doctest | 138 | import doctest |
88 | doctest.testmod() | 139 | doctest.testmod() | ... | ... |
1 | from snakes.lang.ctlstar.parser import parse | 1 | from snakes.lang.ctlstar.parser import parse |
2 | from snakes.lang.ctlstar import asdl as ast | 2 | from snakes.lang.ctlstar import asdl as ast |
3 | -from snakes.lang import getvars, unparse | 3 | +from snakes.lang import getvars, bind |
4 | import _ast | 4 | import _ast |
5 | 5 | ||
6 | class SpecError (Exception) : | 6 | class SpecError (Exception) : |
... | @@ -19,40 +19,6 @@ def astcopy (node) : | ... | @@ -19,40 +19,6 @@ def astcopy (node) : |
19 | attr[name] = astcopy(value) | 19 | attr[name] = astcopy(value) |
20 | return node.__class__(**attr) | 20 | return node.__class__(**attr) |
21 | 21 | ||
22 | -class Binder (ast.NodeTransformer) : | ||
23 | - def __init__ (self, bind) : | ||
24 | - ast.NodeTransformer.__init__(self) | ||
25 | - self.bind = [bind] | ||
26 | - def visit (self, node) : | ||
27 | - return ast.NodeTransformer.visit(self, astcopy(node)) | ||
28 | - def visit_ListComp (self, node) : | ||
29 | - """ | ||
30 | - >>> tree = ast.parse('x+y+[x+y+z for x, y in l]') | ||
31 | - >>> unparse(Binder({'x':ast.Name('hello')}).visit(tree)) | ||
32 | - '((hello + y) + [((x + y) + z) for (x, y) in l])' | ||
33 | - >>> unparse(Binder({'y':ast.Name('hello')}).visit(tree)) | ||
34 | - '((x + hello) + [((x + y) + z) for (x, y) in l])' | ||
35 | - >>> unparse(Binder({'z':ast.Name('hello')}).visit(tree)) | ||
36 | - '((x + y) + [((x + y) + hello) for (x, y) in l])' | ||
37 | - """ | ||
38 | - bind = self.bind[-1].copy() | ||
39 | - for comp in node.generators : | ||
40 | - for name in getvars(comp.target) : | ||
41 | - if name in bind : | ||
42 | - del bind[name] | ||
43 | - self.bind.append(bind) | ||
44 | - node.elt = self.visit(node.elt) | ||
45 | - self.bind.pop(-1) | ||
46 | - return node | ||
47 | - def visit_Name (self, node) : | ||
48 | - if node.id in self.bind[-1] : | ||
49 | - return astcopy(self.bind[-1][node.id]) | ||
50 | - else : | ||
51 | - return astcopy(node) | ||
52 | - | ||
53 | -def bind (node, ctx) : | ||
54 | - return Binder(ctx).visit(node) | ||
55 | - | ||
56 | class Builder (object) : | 22 | class Builder (object) : |
57 | def __init__ (self, spec) : | 23 | def __init__ (self, spec) : |
58 | self.spec = spec | 24 | self.spec = spec | ... | ... |
-
Please register or login to post a comment