Franck Pommereau

fixed renaming inside comprehensions, added binding in expressions

...@@ -66,9 +66,33 @@ def unparse (st) : ...@@ -66,9 +66,33 @@ def unparse (st) :
66 class Renamer (ast.NodeTransformer) : 66 class Renamer (ast.NodeTransformer) :
67 def __init__ (self, map_names) : 67 def __init__ (self, map_names) :
68 ast.NodeTransformer.__init__(self) 68 ast.NodeTransformer.__init__(self)
69 - self.map = map_names 69 + self.map = [map_names]
70 + def visit_ListComp (self, node) :
71 + bind = self.map[-1].copy()
72 + for comp in node.generators :
73 + for name in getvars(comp.target) :
74 + if name in bind :
75 + del bind[name]
76 + self.map.append(bind)
77 + node.elt = self.visit(node.elt)
78 + self.map.pop(-1)
79 + return node
80 + def visit_SetComp (self, node) :
81 + return self.visit_ListComp(node)
82 + def visit_DictComp (self, node) :
83 + bind = self.map[-1].copy()
84 + for comp in node.generators :
85 + for name in getvars(comp.target) :
86 + if name in bind :
87 + del bind[name]
88 + self.map.append(bind)
89 + node.key = self.visit(node.key)
90 + node.value = self.visit(node.value)
91 + self.map.pop(-1)
92 + return node
70 def visit_Name (self, node) : 93 def visit_Name (self, node) :
71 - return ast.copy_location(ast.Name(id=self.map.get(node.id, node.id), 94 + return ast.copy_location(ast.Name(id=self.map[-1].get(node.id,
95 + node.id),
72 ctx=ast.Load()), node) 96 ctx=ast.Load()), node)
73 97
74 def rename (expr, map={}, **ren) : 98 def rename (expr, map={}, **ren) :
...@@ -77,12 +101,39 @@ def rename (expr, map={}, **ren) : ...@@ -77,12 +101,39 @@ def rename (expr, map={}, **ren) :
77 '((t + y) < z)' 101 '((t + y) < z)'
78 >>> rename('x+y<z+f(3,t)', f='g', t='z', z='t') 102 >>> rename('x+y<z+f(3,t)', f='g', t='z', z='t')
79 '((x + y) < (t + g(3, z)))' 103 '((x + y) < (t + g(3, z)))'
104 + >>> rename('[x+y for x in range(3)]', x='z')
105 + '[(x + y) for x in range(3)]'
106 + >>> rename('[x+y for x in range(3)]', y='z')
107 + '[(x + z) for x in range(3)]'
80 """ 108 """
81 map_names = dict(map) 109 map_names = dict(map)
82 map_names.update(ren) 110 map_names.update(ren)
83 transf = Renamer(map_names) 111 transf = Renamer(map_names)
84 return unparse(transf.visit(ast.parse(expr))) 112 return unparse(transf.visit(ast.parse(expr)))
85 113
114 +class Binder (Renamer) :
115 + def visit_Name (self, node) :
116 + if node.id in self.map[-1] :
117 + return self.map[-1][node.id]
118 + else :
119 + return node
120 +
121 +def bind (expr, map={}, **ren) :
122 + """
123 + >>> bind('x+y<z', x=ast.Num(n=2))
124 + '((2 + y) < z)'
125 + >>> bind('x+y<z', y=ast.Num(n=2))
126 + '((x + 2) < z)'
127 + >>> bind('[x+y for x in range(3)]', x=ast.Num(n=2))
128 + '[(x + y) for x in range(3)]'
129 + >>> bind('[x+y for x in range(3)]', y=ast.Num(n=2))
130 + '[(x + 2) for x in range(3)]'
131 + """
132 + map_names = dict(map)
133 + map_names.update(ren)
134 + transf = Binder(map_names)
135 + return unparse(transf.visit(ast.parse(expr)))
136 +
86 if __name__ == "__main__" : 137 if __name__ == "__main__" :
87 import doctest 138 import doctest
88 doctest.testmod() 139 doctest.testmod()
......
1 from snakes.lang.ctlstar.parser import parse 1 from snakes.lang.ctlstar.parser import parse
2 from snakes.lang.ctlstar import asdl as ast 2 from snakes.lang.ctlstar import asdl as ast
3 -from snakes.lang import getvars, unparse 3 +from snakes.lang import getvars, bind
4 import _ast 4 import _ast
5 5
6 class SpecError (Exception) : 6 class SpecError (Exception) :
...@@ -19,40 +19,6 @@ def astcopy (node) : ...@@ -19,40 +19,6 @@ def astcopy (node) :
19 attr[name] = astcopy(value) 19 attr[name] = astcopy(value)
20 return node.__class__(**attr) 20 return node.__class__(**attr)
21 21
22 -class Binder (ast.NodeTransformer) :
23 - def __init__ (self, bind) :
24 - ast.NodeTransformer.__init__(self)
25 - self.bind = [bind]
26 - def visit (self, node) :
27 - return ast.NodeTransformer.visit(self, astcopy(node))
28 - def visit_ListComp (self, node) :
29 - """
30 - >>> tree = ast.parse('x+y+[x+y+z for x, y in l]')
31 - >>> unparse(Binder({'x':ast.Name('hello')}).visit(tree))
32 - '((hello + y) + [((x + y) + z) for (x, y) in l])'
33 - >>> unparse(Binder({'y':ast.Name('hello')}).visit(tree))
34 - '((x + hello) + [((x + y) + z) for (x, y) in l])'
35 - >>> unparse(Binder({'z':ast.Name('hello')}).visit(tree))
36 - '((x + y) + [((x + y) + hello) for (x, y) in l])'
37 - """
38 - bind = self.bind[-1].copy()
39 - for comp in node.generators :
40 - for name in getvars(comp.target) :
41 - if name in bind :
42 - del bind[name]
43 - self.bind.append(bind)
44 - node.elt = self.visit(node.elt)
45 - self.bind.pop(-1)
46 - return node
47 - def visit_Name (self, node) :
48 - if node.id in self.bind[-1] :
49 - return astcopy(self.bind[-1][node.id])
50 - else :
51 - return astcopy(node)
52 -
53 -def bind (node, ctx) :
54 - return Binder(ctx).visit(node)
55 -
56 class Builder (object) : 22 class Builder (object) :
57 def __init__ (self, spec) : 23 def __init__ (self, spec) :
58 self.spec = spec 24 self.spec = spec
......