1
2
3
4
5 package astutil
6
7 import (
8 "fmt"
9 "go/ast"
10 "reflect"
11 "sort"
12 )
13
14
15
16
17
18
19
// An ApplyFunc is invoked by Apply for each node n, even if n is nil
// (for example an absent optional child), before and/or after the node's
// children. It receives a Cursor describing the current node and offering
// operations on it.
//
// The boolean result controls the traversal; see Apply for details.
type ApplyFunc func(*Cursor) bool
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42 func Apply(root ast.Node, pre, post ApplyFunc) (result ast.Node) {
43 parent := &struct{ ast.Node }{root}
44 defer func() {
45 if r := recover(); r != nil && r != abort {
46 panic(r)
47 }
48 result = parent.Node
49 }()
50 a := &application{pre: pre, post: post}
51 a.apply(parent, "Node", nil, root)
52 return
53 }
54
// abort is the sentinel value that apply panics with when a post callback
// returns false; Apply recovers exactly this value to end the traversal
// early and re-panics with anything else. It is a freshly allocated,
// unexported pointer, so no callback can panic with an equal value.
var abort = new(int)
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// A Cursor describes a node encountered during Apply. Information about the
// node and its parent is available from the Node, Parent, Name, and Index
// methods. At the time a callback is invoked, with p the parent node and f
// its field named c.Name():
//
//	p.f            == c.Node()  if c.Index() <  0
//	p.f[c.Index()] == c.Node()  if c.Index() >= 0
//
// except that for files of an *ast.Package, c.Name() is the key of the file
// in the package's Files map. The methods Replace, Delete, InsertBefore, and
// InsertAfter change the AST without disrupting Apply; they do not update
// the value returned by Node.
//
// A single Cursor value is reused (saved and restored) for every node of a
// traversal, so callbacks must not retain the *Cursor after returning.
type Cursor struct {
	parent ast.Node  // node (or Apply's root holder) whose field holds the current node
	name   string    // name of that field, or the file name when parent is an *ast.Package
	iter   *iterator // position within the field when it is a slice; nil otherwise
	node   ast.Node  // the current node; nil when the field holds no node
}
80
81
82 func (c *Cursor) Node() ast.Node { return c.node }
83
84
85 func (c *Cursor) Parent() ast.Node { return c.parent }
86
87
88
89
90 func (c *Cursor) Name() string { return c.name }
91
92
93
94
95
96 func (c *Cursor) Index() int {
97 if c.iter != nil {
98 return c.iter.index
99 }
100 return -1
101 }
102
103
104 func (c *Cursor) field() reflect.Value {
105 return reflect.Indirect(reflect.ValueOf(c.parent)).FieldByName(c.name)
106 }
107
108
109
110 func (c *Cursor) Replace(n ast.Node) {
111 if _, ok := c.node.(*ast.File); ok {
112 file, ok := n.(*ast.File)
113 if !ok {
114 panic("attempt to replace *ast.File with non-*ast.File")
115 }
116 c.parent.(*ast.Package).Files[c.name] = file
117 return
118 }
119
120 v := c.field()
121 if i := c.Index(); i >= 0 {
122 v = v.Index(i)
123 }
124 v.Set(reflect.ValueOf(n))
125 }
126
127
128
129
130
131 func (c *Cursor) Delete() {
132 if _, ok := c.node.(*ast.File); ok {
133 delete(c.parent.(*ast.Package).Files, c.name)
134 return
135 }
136
137 i := c.Index()
138 if i < 0 {
139 panic("Delete node not contained in slice")
140 }
141 v := c.field()
142 l := v.Len()
143 reflect.Copy(v.Slice(i, l), v.Slice(i+1, l))
144 v.Index(l - 1).Set(reflect.Zero(v.Type().Elem()))
145 v.SetLen(l - 1)
146 c.iter.step--
147 }
148
149
150
151
152 func (c *Cursor) InsertAfter(n ast.Node) {
153 i := c.Index()
154 if i < 0 {
155 panic("InsertAfter node not contained in slice")
156 }
157 v := c.field()
158 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
159 l := v.Len()
160 reflect.Copy(v.Slice(i+2, l), v.Slice(i+1, l))
161 v.Index(i + 1).Set(reflect.ValueOf(n))
162 c.iter.step++
163 }
164
165
166
167
168 func (c *Cursor) InsertBefore(n ast.Node) {
169 i := c.Index()
170 if i < 0 {
171 panic("InsertBefore node not contained in slice")
172 }
173 v := c.field()
174 v.Set(reflect.Append(v, reflect.Zero(v.Type().Elem())))
175 l := v.Len()
176 reflect.Copy(v.Slice(i+1, l), v.Slice(i, l))
177 v.Index(i).Set(reflect.ValueOf(n))
178 c.iter.index++
179 }
180
181
182 type application struct {
183 pre, post ApplyFunc
184 cursor Cursor
185 iter iterator
186 }
187
// apply visits node n, which is held in the field name of parent. When iter
// is non-nil, n sits at position iter.index of that slice field; iter is nil
// for non-slice fields. apply calls a.pre, then walks n's children in a
// fixed order that depends on n's dynamic type, then calls a.post.
//
// If a.pre returns false, n's children are skipped, a.post is not called for
// n, and the cursor is restored before returning. If a.post returns false,
// apply panics with abort, which Apply recovers to stop the whole traversal.
//
// The children walked are those of the n passed in; a replacement installed
// by a.pre via Cursor.Replace is not walked. apply panics on node types it
// does not recognize.
func (a *application) apply(parent ast.Node, name string, iter *iterator, n ast.Node) {
	// Normalize a typed nil pointer (e.g. a nil *ast.BlockStmt stored in an
	// ast.Node) to an untyped nil, so callbacks see a nil Node and the switch
	// below takes the nil case.
	if v := reflect.ValueOf(n); v.Kind() == reflect.Pointer && v.IsNil() {
		n = nil
	}

	// Reuse a.cursor instead of creating a Cursor per call. Its previous
	// contents belong to the caller's visit, so they are saved here and
	// restored before apply returns.
	saved := a.cursor
	a.cursor.parent = parent
	a.cursor.name = name
	a.cursor.iter = iter
	a.cursor.node = n

	if a.pre != nil && !a.pre(&a.cursor) {
		a.cursor = saved
		return
	}

	// Walk the children of n. Single child nodes go through apply, slice
	// fields through applyList. Each nested call restores a.cursor, so it
	// describes n again by the time post runs.
	switch n := n.(type) {
	case nil:
		// Absent node: nothing to walk.

	// Comments and fields
	case *ast.Comment:
		// Leaf node: no children.

	case *ast.CommentGroup:
		// Defensive: typed nils were already normalized to nil above.
		if n != nil {
			a.applyList(n, "List")
		}

	case *ast.Field:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.FieldList:
		a.applyList(n, "List")

	// Expressions
	case *ast.BadExpr, *ast.Ident, *ast.BasicLit:
		// Leaf nodes: no children.

	case *ast.Ellipsis:
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.FuncLit:
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CompositeLit:
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Elts")

	case *ast.ParenExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.SelectorExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Sel", nil, n.Sel)

	case *ast.IndexExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Index", nil, n.Index)

	case *ast.IndexListExpr:
		a.apply(n, "X", nil, n.X)
		a.applyList(n, "Indices")

	case *ast.SliceExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Low", nil, n.Low)
		a.apply(n, "High", nil, n.High)
		a.apply(n, "Max", nil, n.Max)

	case *ast.TypeAssertExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Type", nil, n.Type)

	case *ast.CallExpr:
		a.apply(n, "Fun", nil, n.Fun)
		a.applyList(n, "Args")

	case *ast.StarExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.UnaryExpr:
		a.apply(n, "X", nil, n.X)

	case *ast.BinaryExpr:
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Y", nil, n.Y)

	case *ast.KeyValueExpr:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	// Types
	case *ast.ArrayType:
		a.apply(n, "Len", nil, n.Len)
		a.apply(n, "Elt", nil, n.Elt)

	case *ast.StructType:
		a.apply(n, "Fields", nil, n.Fields)

	case *ast.FuncType:
		// TypeParams is visited only when present, so callbacks are not
		// invoked for an absent type-parameter list (presumably to keep
		// traversal of non-generic code unchanged).
		if tparams := n.TypeParams; tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Params", nil, n.Params)
		a.apply(n, "Results", nil, n.Results)

	case *ast.InterfaceType:
		a.apply(n, "Methods", nil, n.Methods)

	case *ast.MapType:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)

	case *ast.ChanType:
		a.apply(n, "Value", nil, n.Value)

	// Statements
	case *ast.BadStmt:
		// Leaf node: no children.

	case *ast.DeclStmt:
		a.apply(n, "Decl", nil, n.Decl)

	case *ast.EmptyStmt:
		// Leaf node: no children.

	case *ast.LabeledStmt:
		a.apply(n, "Label", nil, n.Label)
		a.apply(n, "Stmt", nil, n.Stmt)

	case *ast.ExprStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.SendStmt:
		a.apply(n, "Chan", nil, n.Chan)
		a.apply(n, "Value", nil, n.Value)

	case *ast.IncDecStmt:
		a.apply(n, "X", nil, n.X)

	case *ast.AssignStmt:
		a.applyList(n, "Lhs")
		a.applyList(n, "Rhs")

	case *ast.GoStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.DeferStmt:
		a.apply(n, "Call", nil, n.Call)

	case *ast.ReturnStmt:
		a.applyList(n, "Results")

	case *ast.BranchStmt:
		a.apply(n, "Label", nil, n.Label)

	case *ast.BlockStmt:
		a.applyList(n, "List")

	case *ast.IfStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Body", nil, n.Body)
		a.apply(n, "Else", nil, n.Else)

	case *ast.CaseClause:
		a.applyList(n, "List")
		a.applyList(n, "Body")

	case *ast.SwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Tag", nil, n.Tag)
		a.apply(n, "Body", nil, n.Body)

	case *ast.TypeSwitchStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Assign", nil, n.Assign)
		a.apply(n, "Body", nil, n.Body)

	case *ast.CommClause:
		a.apply(n, "Comm", nil, n.Comm)
		a.applyList(n, "Body")

	case *ast.SelectStmt:
		a.apply(n, "Body", nil, n.Body)

	case *ast.ForStmt:
		a.apply(n, "Init", nil, n.Init)
		a.apply(n, "Cond", nil, n.Cond)
		a.apply(n, "Post", nil, n.Post)
		a.apply(n, "Body", nil, n.Body)

	case *ast.RangeStmt:
		a.apply(n, "Key", nil, n.Key)
		a.apply(n, "Value", nil, n.Value)
		a.apply(n, "X", nil, n.X)
		a.apply(n, "Body", nil, n.Body)

	// Declarations
	case *ast.ImportSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Path", nil, n.Path)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.ValueSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Names")
		a.apply(n, "Type", nil, n.Type)
		a.applyList(n, "Values")
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.TypeSpec:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		// As for FuncType, an absent type-parameter list is not visited.
		if tparams := n.TypeParams; tparams != nil {
			a.apply(n, "TypeParams", nil, tparams)
		}
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Comment", nil, n.Comment)

	case *ast.BadDecl:
		// Leaf node: no children.

	case *ast.GenDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.applyList(n, "Specs")

	case *ast.FuncDecl:
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Recv", nil, n.Recv)
		a.apply(n, "Name", nil, n.Name)
		a.apply(n, "Type", nil, n.Type)
		a.apply(n, "Body", nil, n.Body)

	// Files and packages
	case *ast.File:
		// Only Doc, Name and Decls are walked. Other fields such as Imports
		// and Comments are not visited (presumably because they alias nodes
		// already reachable through Decls and Doc/Comment fields).
		a.apply(n, "Doc", nil, n.Doc)
		a.apply(n, "Name", nil, n.Name)
		a.applyList(n, "Decls")

	case *ast.Package:
		// Visit files in sorted file-name order so the walk is deterministic
		// despite map iteration order. The file name is used as the cursor's
		// Name. A file deleted earlier in the walk yields a nil lookup, which
		// is visited as a nil node.
		var names []string
		for name := range n.Files {
			names = append(names, name)
		}
		sort.Strings(names)
		for _, name := range names {
			a.apply(n, name, nil, n.Files[name])
		}

	default:
		panic(fmt.Sprintf("Apply: unexpected node type %T", n))
	}

	// A false result from post aborts the entire traversal. The cursor is not
	// restored on this path; Apply recovers the panic and no further
	// callbacks run.
	if a.post != nil && !a.post(&a.cursor) {
		panic(abort)
	}

	a.cursor = saved
}
462
463
// iterator records applyList's position within a slice field. index is the
// position of the element currently being visited. step is what gets added
// to index once that visit finishes: normally 1, lowered by Cursor.Delete
// and raised by Cursor.InsertAfter. Cursor.InsertBefore advances index
// instead. Together these keep the walk from skipping or revisiting
// elements.
type iterator struct {
	index int
	step  int
}
467
468 func (a *application) applyList(parent ast.Node, name string) {
469
470 saved := a.iter
471 a.iter.index = 0
472 for {
473
474 v := reflect.Indirect(reflect.ValueOf(parent)).FieldByName(name)
475 if a.iter.index >= v.Len() {
476 break
477 }
478
479
480 var x ast.Node
481 if e := v.Index(a.iter.index); e.IsValid() {
482 x = e.Interface().(ast.Node)
483 }
484
485 a.iter.step = 1
486 a.apply(parent, name, &a.iter, x)
487 a.iter.index += a.iter.step
488 }
489 a.iter = saved
490 }
491
View as plain text