1
2
3
4
5 package astutil
6
7
8
9 import (
10 "fmt"
11 "go/ast"
12 "go/token"
13 "sort"
14 )
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
// PathEnclosingInterval returns the path of AST nodes enclosing the source
// interval [start, end), and reports whether the interval matched exactly.
//
// The path is ordered innermost first: path[0] is the deepest node found
// to contain the interval and path[len(path)-1] is root. The path is never
// empty; it always contains at least root.
//
// exact is true when the search ends on a node whose extent [Pos, End)
// equals the interval (after the interval has been intersected with each
// enclosing node), or on a single synthetic token produced by childrenOf.
// Such tokens are never added to path. exact is false when:
//   - the interval overlaps several children of path[0];
//   - the interval lies in the whitespace between two of them;
//   - the interval does not intersect root at all, in which case path is
//     just [root].
//
// If start > end the two are swapped. An empty interval (start == end) is
// widened to [start, start+1).
//
// The whitespace between two adjacent children counts as part of both
// children's "augmented" intervals. So an interval that extends into
// neighbouring whitespace can still descend into a single child.
func PathEnclosingInterval(root *ast.File, start, end token.Pos) (path []ast.Node, exact bool) {
	// visit appends node to path and then tries to descend into the single
	// child whose augmented interval contains [start, end). It returns the
	// exactness of the final match.
	//
	// It is called on root only when [start, end) intersects root. It is
	// called on a child only when that child's augmented interval contains
	// [start, end).
	var visit func(node ast.Node) bool
	visit = func(node ast.Node) bool {
		path = append(path, node)

		nodePos := node.Pos()
		nodeEnd := node.End()

		// Intersect [start, end) with node's own interval. start and end
		// are captured variables, so this narrowing also applies to every
		// deeper level of the recursion.
		if start < nodePos {
			start = nodePos
		}
		if end > nodeEnd {
			end = nodeEnd
		}

		// Look for the sole child that contains [start, end).
		children := childrenOf(node)
		l := len(children)
		for i, child := range children {
			// [childPos, childEnd) is the child's own interval.
			childPos := child.Pos()
			childEnd := child.End()

			// [augPos, augEnd) is the child's interval widened to include
			// the whitespace back to the previous child's End and forward
			// to the next child's Pos.
			augPos := childPos
			augEnd := childEnd
			if i > 0 {
				augPos = children[i-1].End()
			}
			if i < l-1 {
				nextChildPos := children[i+1].Pos()
				// The interval lies entirely in the gap between this
				// child and the next one: no child contains it.
				if start >= augEnd && end <= nextChildPos {
					return false
				}
				augEnd = nextChildPos
			}

			// The augmented child contains the whole interval.
			if augPos <= start && end <= augEnd {
				// A synthetic token is a leaf. Matching it counts as exact,
				// and it is not recorded in path.
				if is[tokenNode](child) {
					return true
				}

				// childrenOf lists a FuncDecl's type parameters, parameters
				// and results directly, eliding the FuncType that owns them.
				// Re-insert the FuncType into the path when descending into
				// one of those field lists. The receiver is not part of the
				// FuncType, so it is excluded.
				if decl, ok := node.(*ast.FuncDecl); ok {
					if fields, ok := child.(*ast.FieldList); ok && fields != decl.Recv {
						path = append(path, decl.Type)
					}
				}

				return visit(child)
			}

			// The interval starts before this child ends but extends past
			// its augmented interval, so it spans several children.
			if start < childEnd && end > augEnd {
				break
			}
		}

		// No single child contains the interval, so node is the result.
		// It is exact only if the interval coincides with node's extent.
		// This check must come after the child loop. Otherwise a node whose
		// sole child has the same extent would stop the descent too early.
		if start == nodePos && end == nodeEnd {
			return true
		}

		return false
	}

	// Normalize to a non-decreasing interval.
	if start > end {
		start, end = end, start
	}

	if start < root.End() && end > root.Pos() {
		// Treat an empty interval as the one-byte interval starting there.
		if start == end {
			end = start + 1
		}
		exact = visit(root)

		// visit built the path root-first; reverse it to innermost-first.
		for i, l := 0, len(path); i < l/2; i++ {
			path[i], path[l-1-i] = path[l-1-i], path[i]
		}
	} else {
		// The interval lies entirely before root.Pos() or at/after
		// root.End(). The result still includes the file node.
		path = append(path, root)
	}

	return
}
172
173
174
175
// A tokenNode is a synthetic ast.Node standing for a single lexical token
// (keyword, operator or punctuation) that occupies [pos, end). childrenOf
// adds tokenNodes alongside real AST children so that PathEnclosingInterval
// can recognise an interval that falls exactly on such a token.
type tokenNode struct {
	pos token.Pos // position of the token's first character
	end token.Pos // position immediately after the token
}

// Pos returns the position of the token's first character.
func (n tokenNode) Pos() token.Pos {
	return n.pos
}

// End returns the position immediately after the token.
func (n tokenNode) End() token.Pos {
	return n.end
}

// tok returns a tokenNode covering the size bytes that start at pos.
// The parameter is deliberately not named len, so that it does not shadow
// the builtin.
func tok(pos token.Pos, size int) ast.Node {
	return tokenNode{pos: pos, end: pos + token.Pos(size)}
}
192
193
194
195
196 func childrenOf(n ast.Node) []ast.Node {
197 var children []ast.Node
198
199
200 ast.Inspect(n, func(node ast.Node) bool {
201 if node == n {
202 return true
203 }
204 if node != nil {
205 children = append(children, node)
206 }
207 return false
208 })
209
210
211
212
213
214 switch n := n.(type) {
215 case *ast.ArrayType:
216 children = append(children,
217 tok(n.Lbrack, len("[")),
218 tok(n.Elt.End(), len("]")))
219
220 case *ast.AssignStmt:
221 children = append(children,
222 tok(n.TokPos, len(n.Tok.String())))
223
224 case *ast.BasicLit:
225 children = append(children,
226 tok(n.ValuePos, len(n.Value)))
227
228 case *ast.BinaryExpr:
229 children = append(children, tok(n.OpPos, len(n.Op.String())))
230
231 case *ast.BlockStmt:
232 if n.Lbrace.IsValid() {
233 children = append(children, tok(n.Lbrace, len("{")))
234 }
235 if n.Rbrace.IsValid() {
236 children = append(children, tok(n.Rbrace, len("}")))
237 }
238
239 case *ast.BranchStmt:
240 children = append(children,
241 tok(n.TokPos, len(n.Tok.String())))
242
243 case *ast.CallExpr:
244 children = append(children,
245 tok(n.Lparen, len("(")),
246 tok(n.Rparen, len(")")))
247 if n.Ellipsis != 0 {
248 children = append(children, tok(n.Ellipsis, len("...")))
249 }
250
251 case *ast.CaseClause:
252 if n.List == nil {
253 children = append(children,
254 tok(n.Case, len("default")))
255 } else {
256 children = append(children,
257 tok(n.Case, len("case")))
258 }
259 children = append(children, tok(n.Colon, len(":")))
260
261 case *ast.ChanType:
262 switch n.Dir {
263 case ast.RECV:
264 children = append(children, tok(n.Begin, len("<-chan")))
265 case ast.SEND:
266 children = append(children, tok(n.Begin, len("chan<-")))
267 case ast.RECV | ast.SEND:
268 children = append(children, tok(n.Begin, len("chan")))
269 }
270
271 case *ast.CommClause:
272 if n.Comm == nil {
273 children = append(children,
274 tok(n.Case, len("default")))
275 } else {
276 children = append(children,
277 tok(n.Case, len("case")))
278 }
279 children = append(children, tok(n.Colon, len(":")))
280
281 case *ast.Comment:
282
283
284 case *ast.CommentGroup:
285
286
287 case *ast.CompositeLit:
288 children = append(children,
289 tok(n.Lbrace, len("{")),
290 tok(n.Rbrace, len("{")))
291
292 case *ast.DeclStmt:
293
294
295 case *ast.DeferStmt:
296 children = append(children,
297 tok(n.Defer, len("defer")))
298
299 case *ast.Ellipsis:
300 children = append(children,
301 tok(n.Ellipsis, len("...")))
302
303 case *ast.EmptyStmt:
304
305
306 case *ast.ExprStmt:
307
308
309 case *ast.Field:
310
311
312 case *ast.FieldList:
313 if n.Opening.IsValid() {
314 children = append(children, tok(n.Opening, len("(")))
315 }
316 if n.Closing.IsValid() {
317 children = append(children, tok(n.Closing, len(")")))
318 }
319
320 case *ast.File:
321
322 children = append(children,
323 tok(n.Package, len("package")))
324
325 case *ast.ForStmt:
326 children = append(children,
327 tok(n.For, len("for")))
328
329 case *ast.FuncDecl:
330
331
332
333
334
335
336
337
338
339
340
341 children = nil
342 children = append(children, tok(n.Type.Func, len("func")))
343 if n.Recv != nil {
344 children = append(children, n.Recv)
345 }
346 children = append(children, n.Name)
347 if tparams := n.Type.TypeParams; tparams != nil {
348 children = append(children, tparams)
349 }
350 if n.Type.Params != nil {
351 children = append(children, n.Type.Params)
352 }
353 if n.Type.Results != nil {
354 children = append(children, n.Type.Results)
355 }
356 if n.Body != nil {
357 children = append(children, n.Body)
358 }
359
360 case *ast.FuncLit:
361
362
363 case *ast.FuncType:
364 if n.Func != 0 {
365 children = append(children,
366 tok(n.Func, len("func")))
367 }
368
369 case *ast.GenDecl:
370 children = append(children,
371 tok(n.TokPos, len(n.Tok.String())))
372 if n.Lparen != 0 {
373 children = append(children,
374 tok(n.Lparen, len("(")),
375 tok(n.Rparen, len(")")))
376 }
377
378 case *ast.GoStmt:
379 children = append(children,
380 tok(n.Go, len("go")))
381
382 case *ast.Ident:
383 children = append(children,
384 tok(n.NamePos, len(n.Name)))
385
386 case *ast.IfStmt:
387 children = append(children,
388 tok(n.If, len("if")))
389
390 case *ast.ImportSpec:
391
392
393 case *ast.IncDecStmt:
394 children = append(children,
395 tok(n.TokPos, len(n.Tok.String())))
396
397 case *ast.IndexExpr:
398 children = append(children,
399 tok(n.Lbrack, len("[")),
400 tok(n.Rbrack, len("]")))
401
402 case *ast.IndexListExpr:
403 children = append(children,
404 tok(n.Lbrack, len("[")),
405 tok(n.Rbrack, len("]")))
406
407 case *ast.InterfaceType:
408 children = append(children,
409 tok(n.Interface, len("interface")))
410
411 case *ast.KeyValueExpr:
412 children = append(children,
413 tok(n.Colon, len(":")))
414
415 case *ast.LabeledStmt:
416 children = append(children,
417 tok(n.Colon, len(":")))
418
419 case *ast.MapType:
420 children = append(children,
421 tok(n.Map, len("map")))
422
423 case *ast.ParenExpr:
424 children = append(children,
425 tok(n.Lparen, len("(")),
426 tok(n.Rparen, len(")")))
427
428 case *ast.RangeStmt:
429 children = append(children,
430 tok(n.For, len("for")),
431 tok(n.TokPos, len(n.Tok.String())))
432
433 case *ast.ReturnStmt:
434 children = append(children,
435 tok(n.Return, len("return")))
436
437 case *ast.SelectStmt:
438 children = append(children,
439 tok(n.Select, len("select")))
440
441 case *ast.SelectorExpr:
442
443
444 case *ast.SendStmt:
445 children = append(children,
446 tok(n.Arrow, len("<-")))
447
448 case *ast.SliceExpr:
449 children = append(children,
450 tok(n.Lbrack, len("[")),
451 tok(n.Rbrack, len("]")))
452
453 case *ast.StarExpr:
454 children = append(children, tok(n.Star, len("*")))
455
456 case *ast.StructType:
457 children = append(children, tok(n.Struct, len("struct")))
458
459 case *ast.SwitchStmt:
460 children = append(children, tok(n.Switch, len("switch")))
461
462 case *ast.TypeAssertExpr:
463 children = append(children,
464 tok(n.Lparen-1, len(".")),
465 tok(n.Lparen, len("(")),
466 tok(n.Rparen, len(")")))
467
468 case *ast.TypeSpec:
469
470
471 case *ast.TypeSwitchStmt:
472 children = append(children, tok(n.Switch, len("switch")))
473
474 case *ast.UnaryExpr:
475 children = append(children, tok(n.OpPos, len(n.Op.String())))
476
477 case *ast.ValueSpec:
478
479
480 case *ast.BadDecl, *ast.BadExpr, *ast.BadStmt:
481
482 }
483
484
485
486
487
488 sort.Sort(byPos(children))
489
490 return children
491 }
492
// byPos implements sort.Interface over a slice of nodes, ordering them by
// ascending start position (Pos). Only Pos is compared, so two nodes that
// start at the same position compare equal.
type byPos []ast.Node

// Len returns the number of nodes in the slice.
func (nodes byPos) Len() int { return len(nodes) }

// Less reports whether node i starts strictly before node j.
func (nodes byPos) Less(i, j int) bool {
	first, second := nodes[i].Pos(), nodes[j].Pos()
	return second > first
}

// Swap exchanges the nodes at indices i and j.
func (nodes byPos) Swap(i, j int) {
	tmp := nodes[i]
	nodes[i] = nodes[j]
	nodes[j] = tmp
}
504
505
506
507
508
509
510
// NodeDescription returns a short English phrase describing the syntax
// node n, for example "if statement", "binary + operation" or
// "parenthesized identifier".
//
// It panics if n is nil or its dynamic type is not one of the go/ast node
// types handled below. It also panics for a *ast.BranchStmt or *ast.GenDecl
// whose token is not one of the keywords listed for it.
func NodeDescription(n ast.Node) string {
	switch x := n.(type) {
	// Node kinds whose description depends on the node's contents.
	case *ast.BinaryExpr:
		return fmt.Sprintf("binary %s operation", x.Op)
	case *ast.UnaryExpr:
		return fmt.Sprintf("unary %s operation", x.Op)
	case *ast.ParenExpr:
		return "parenthesized " + NodeDescription(x.X)
	case *ast.DeclStmt:
		return NodeDescription(x.Decl) + " statement"
	case *ast.CallExpr:
		// With exactly one argument and no "...", the call is
		// syntactically indistinguishable from a conversion T(x).
		if len(x.Args) != 1 || x.Ellipsis.IsValid() {
			return "function call"
		}
		return "function call (or conversion)"
	case *ast.IncDecStmt:
		if x.Tok != token.INC {
			return "decrement statement"
		}
		return "increment statement"
	case *ast.BranchStmt:
		switch x.Tok {
		case token.BREAK:
			return "break statement"
		case token.CONTINUE:
			return "continue statement"
		case token.GOTO:
			return "goto statement"
		case token.FALLTHROUGH:
			return "fall-through statement"
		}
	case *ast.GenDecl:
		switch x.Tok {
		case token.IMPORT:
			return "import declaration"
		case token.CONST:
			return "constant declaration"
		case token.TYPE:
			return "type declaration"
		case token.VAR:
			return "variable declaration"
		}

	// Node kinds with a fixed description.
	case *ast.ArrayType:
		return "array type"
	case *ast.AssignStmt:
		return "assignment"
	case *ast.BadDecl:
		return "bad declaration"
	case *ast.BadExpr:
		return "bad expression"
	case *ast.BadStmt:
		return "bad statement"
	case *ast.BasicLit:
		return "basic literal"
	case *ast.BlockStmt:
		return "block"
	case *ast.CaseClause:
		return "case clause"
	case *ast.ChanType:
		return "channel type"
	case *ast.CommClause:
		return "communication clause"
	case *ast.Comment:
		return "comment"
	case *ast.CommentGroup:
		return "comment group"
	case *ast.CompositeLit:
		return "composite literal"
	case *ast.DeferStmt:
		return "defer statement"
	case *ast.Ellipsis:
		return "ellipsis"
	case *ast.EmptyStmt:
		return "empty statement"
	case *ast.ExprStmt:
		return "expression statement"
	case *ast.Field:
		// A lone Field does not say whether it is a struct field, an
		// interface method or a parameter, so all three are named.
		return "field/method/parameter"
	case *ast.FieldList:
		return "field/method/parameter list"
	case *ast.File:
		return "source file"
	case *ast.ForStmt:
		return "for loop"
	case *ast.FuncDecl:
		return "function declaration"
	case *ast.FuncLit:
		return "function literal"
	case *ast.FuncType:
		return "function type"
	case *ast.GoStmt:
		return "go statement"
	case *ast.Ident:
		return "identifier"
	case *ast.IfStmt:
		return "if statement"
	case *ast.ImportSpec:
		return "import specification"
	case *ast.IndexExpr:
		return "index expression"
	case *ast.IndexListExpr:
		return "index list expression"
	case *ast.InterfaceType:
		return "interface type"
	case *ast.KeyValueExpr:
		return "key/value association"
	case *ast.LabeledStmt:
		return "statement label"
	case *ast.MapType:
		return "map type"
	case *ast.Package:
		return "package"
	case *ast.RangeStmt:
		return "range loop"
	case *ast.ReturnStmt:
		return "return statement"
	case *ast.SelectStmt:
		return "select statement"
	case *ast.SelectorExpr:
		return "selector"
	case *ast.SendStmt:
		return "channel send"
	case *ast.SliceExpr:
		return "slice expression"
	case *ast.StarExpr:
		return "*-operation"
	case *ast.StructType:
		return "struct type"
	case *ast.SwitchStmt:
		return "switch statement"
	case *ast.TypeAssertExpr:
		return "type assertion"
	case *ast.TypeSpec:
		return "type specification"
	case *ast.TypeSwitchStmt:
		return "type switch"
	case *ast.ValueSpec:
		return "value specification"
	}
	panic(fmt.Sprintf("unexpected node type: %T", n))
}
659
// is reports whether x holds a value of type T. When T is an interface
// type, it reports whether x's dynamic value implements T. A nil x never
// matches.
func is[T any](x any) bool {
	switch x.(type) {
	case T:
		return true
	default:
		return false
	}
}
664
View as plain text