Skip to content

Commit 9f2680c

Browse files
committed
go/analysis/passes/loopclosure: refactor to add a visitor type replacing forEachLastStmt; no external behavior change
1 parent aa9f4b2 commit 9f2680c

File tree

4 files changed

+500
-97
lines changed

4 files changed

+500
-97
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
// Copyright 2022 The Go Authors. All rights reserved.
2+
// Use of this source code is governed by a BSD-style
3+
// license that can be found in the LICENSE file.
4+
5+
package loopclosure
6+
7+
import "go/ast"
8+
9+
// visitor recursively visits statements.
10+
//
11+
// last is called for the last non-compound statement in an input
12+
// statement list or any recursively visited compound statement bodies.
13+
// all and skipLast allow clients to modify the behavior of visitor or
14+
// track their own client state.
15+
//
16+
// visitor is passed by value, and it is valid for clients
17+
// to return a modified or new visitor in an invocation of all,
18+
// such as to clear or set last, all, or skipLast.
19+
//
20+
// last will be set to nil by default when visiting a statement that is not
21+
// the last statement in a list of statements. If all always
22+
// returns an unmodified visitor, last will remain nil for statements
23+
// that are nested within a statement that was not a last statement.
24+
type visitor struct {
25+
// If non-nil, last is called for the last non-compound statement.
26+
last func(v visitor, stmt ast.Stmt)
27+
28+
// If non-nil, all is called for all statements, whether or not they are last.
29+
// push is true before visiting any children, and false after visiting any children.
30+
// It returns the visitor that will be used to visit any children.
31+
// all with push true is called before last, and all with push false is called after last.
32+
all func(v visitor, stmt ast.Stmt, push bool) visitor
33+
34+
// If non-nil, skipLast is called for candidate last statements.
35+
// If skipLast returns true, the statement is not considered a last statement.
36+
skipLast func(v visitor, stmt ast.Stmt) bool
37+
38+
// TODO: consider a state variable. Might be mildly convenient for a client to
39+
// avoid managing their own stack, but there are currently no such clients.
40+
// TODO: consider a parents stack
41+
}
42+
43+
// visit calls v.last on each "last" statement in a list of statements.
44+
// "Last" is defined recursively. For example, if the last statement is
45+
// a switch statement, then each switch case is also visited to examine
46+
// its last statements.
47+
func (v visitor) visit(stmts []ast.Stmt) {
48+
if len(stmts) == 0 {
49+
return
50+
}
51+
52+
lastStmt := len(stmts) - 1
53+
if v.skipLast != nil {
54+
// Find which statement in stmts will be considered the last statement.
55+
// We allow lastStmt to go to -1, which means no statement will be considered last.
56+
for ; lastStmt >= 0; lastStmt-- {
57+
if !v.skipLast(v, stmts[lastStmt]) {
58+
break
59+
}
60+
}
61+
}
62+
for i, stmt := range stmts {
63+
// Copy the visitor so that it can be modified independently per iteration.
64+
vv := v
65+
66+
if i != lastStmt {
67+
// Clear last so that last by default it is not called in this branch of recursion
68+
// (even if a visited is last statement of some body lower in the AST tree).
69+
// It is only by default because the client can set visitor.last
70+
// themselves if they desire to start treating a branch of the recursion as
71+
// candidates for last statements again.
72+
vv.last = nil
73+
}
74+
75+
if vv.all != nil {
76+
vv = vv.all(vv, stmt, true) // push
77+
}
78+
79+
switch s := stmt.(type) {
80+
case *ast.IfStmt:
81+
loop:
82+
for {
83+
vv.visit(s.Body.List)
84+
switch e := s.Else.(type) {
85+
case *ast.BlockStmt:
86+
vv.visit(e.List)
87+
break loop
88+
case *ast.IfStmt:
89+
s = e
90+
case nil:
91+
break loop
92+
}
93+
}
94+
case *ast.ForStmt:
95+
vv.visit(s.Body.List)
96+
case *ast.RangeStmt:
97+
vv.visit(s.Body.List)
98+
case *ast.SwitchStmt:
99+
for _, c := range s.Body.List {
100+
cc := c.(*ast.CaseClause)
101+
vv.visit(cc.Body)
102+
}
103+
case *ast.TypeSwitchStmt:
104+
for _, c := range s.Body.List {
105+
cc := c.(*ast.CaseClause)
106+
vv.visit(cc.Body)
107+
}
108+
case *ast.SelectStmt:
109+
for _, c := range s.Body.List {
110+
cc := c.(*ast.CommClause)
111+
vv.visit(cc.Body)
112+
}
113+
default:
114+
if i == lastStmt && vv.last != nil {
115+
vv.last(vv, s)
116+
}
117+
}
118+
119+
if vv.all != nil {
120+
// We use vv here to ensure the pop is symmetric with the vv push above
121+
vv.all(vv, stmt, false) // pop
122+
}
123+
}
124+
}

0 commit comments

Comments
 (0)