|
| 1 | +// Copyright 2022 The Go Authors. All rights reserved. |
| 2 | +// Use of this source code is governed by a BSD-style |
| 3 | +// license that can be found in the LICENSE file. |
| 4 | + |
| 5 | +package loopclosure |
| 6 | + |
| 7 | +import "go/ast" |
| 8 | + |
| 9 | +// visitor recursively visits statements. |
| 10 | +// |
| 11 | +// last is called for the last non-compound statement in an input |
| 12 | +// statement list or any recursively visited compound statement bodies. |
| 13 | +// all and skipLast allow clients to modify the behavior of visitor or |
| 14 | +// track their own client state. |
| 15 | +// |
| 16 | +// visitor is passed by value, and it is valid for clients |
| 17 | +// to return a modified or new visitor in an invocation of all, |
| 18 | +// such as to clear or set last, all, or skipLast. |
| 19 | +// |
| 20 | +// last will be set to nil by default when visiting a statement that is not |
| 21 | +// the last statement in a list of statements. If all always |
| 22 | +// returns an unmodified visitor, last will remain nil for statements |
| 23 | +// that are nested within a statement that was not a last statement. |
| 24 | +type visitor struct { |
| 25 | + // If non-nil, last is called for the last non-compound statement. |
| 26 | + last func(v visitor, stmt ast.Stmt) |
| 27 | + |
| 28 | + // If non-nil, all is called for all statements, whether or not they are last. |
| 29 | + // push is true before visiting any children, and false after visiting any children. |
| 30 | + // It returns the visitor that will be used to visit any children. |
| 31 | + // all with push true is called before last, and all with push false is called after last. |
| 32 | + all func(v visitor, stmt ast.Stmt, push bool) visitor |
| 33 | + |
| 34 | + // If non-nil, skipLast is called for candidate last statements. |
| 35 | + // If skipLast returns true, the statement is not considered a last statement. |
| 36 | + skipLast func(v visitor, stmt ast.Stmt) bool |
| 37 | + |
| 38 | + // TODO: consider a state variable. Might be mildly convenient for a client to |
| 39 | + // avoid managing their own stack, but there are currently no such clients. |
| 40 | + // TODO: consider a parents stack |
| 41 | +} |
| 42 | + |
| 43 | +// visit calls v.last on each "last" statement in a list of statements. |
| 44 | +// "Last" is defined recursively. For example, if the last statement is |
| 45 | +// a switch statement, then each switch case is also visited to examine |
| 46 | +// its last statements. |
| 47 | +func (v visitor) visit(stmts []ast.Stmt) { |
| 48 | + if len(stmts) == 0 { |
| 49 | + return |
| 50 | + } |
| 51 | + |
| 52 | + lastStmt := len(stmts) - 1 |
| 53 | + if v.skipLast != nil { |
| 54 | + // Find which statement in stmts will be considered the last statement. |
| 55 | + // We allow lastStmt to go to -1, which means no statement will be considered last. |
| 56 | + for ; lastStmt >= 0; lastStmt-- { |
| 57 | + if !v.skipLast(v, stmts[lastStmt]) { |
| 58 | + break |
| 59 | + } |
| 60 | + } |
| 61 | + } |
| 62 | + for i, stmt := range stmts { |
| 63 | + // Copy the visitor so that it can be modified independently per iteration. |
| 64 | + vv := v |
| 65 | + |
| 66 | + if i != lastStmt { |
| 67 | + // Clear last so that last by default it is not called in this branch of recursion |
| 68 | + // (even if a visited is last statement of some body lower in the AST tree). |
| 69 | + // It is only by default because the client can set visitor.last |
| 70 | + // themselves if they desire to start treating a branch of the recursion as |
| 71 | + // candidates for last statements again. |
| 72 | + vv.last = nil |
| 73 | + } |
| 74 | + |
| 75 | + if vv.all != nil { |
| 76 | + vv = vv.all(vv, stmt, true) // push |
| 77 | + } |
| 78 | + |
| 79 | + switch s := stmt.(type) { |
| 80 | + case *ast.IfStmt: |
| 81 | + loop: |
| 82 | + for { |
| 83 | + vv.visit(s.Body.List) |
| 84 | + switch e := s.Else.(type) { |
| 85 | + case *ast.BlockStmt: |
| 86 | + vv.visit(e.List) |
| 87 | + break loop |
| 88 | + case *ast.IfStmt: |
| 89 | + s = e |
| 90 | + case nil: |
| 91 | + break loop |
| 92 | + } |
| 93 | + } |
| 94 | + case *ast.ForStmt: |
| 95 | + vv.visit(s.Body.List) |
| 96 | + case *ast.RangeStmt: |
| 97 | + vv.visit(s.Body.List) |
| 98 | + case *ast.SwitchStmt: |
| 99 | + for _, c := range s.Body.List { |
| 100 | + cc := c.(*ast.CaseClause) |
| 101 | + vv.visit(cc.Body) |
| 102 | + } |
| 103 | + case *ast.TypeSwitchStmt: |
| 104 | + for _, c := range s.Body.List { |
| 105 | + cc := c.(*ast.CaseClause) |
| 106 | + vv.visit(cc.Body) |
| 107 | + } |
| 108 | + case *ast.SelectStmt: |
| 109 | + for _, c := range s.Body.List { |
| 110 | + cc := c.(*ast.CommClause) |
| 111 | + vv.visit(cc.Body) |
| 112 | + } |
| 113 | + default: |
| 114 | + if i == lastStmt && vv.last != nil { |
| 115 | + vv.last(vv, s) |
| 116 | + } |
| 117 | + } |
| 118 | + |
| 119 | + if vv.all != nil { |
| 120 | + // We use vv here to ensure the pop is symmetric with the vv push above |
| 121 | + vv.all(vv, stmt, false) // pop |
| 122 | + } |
| 123 | + } |
| 124 | +} |
0 commit comments