diff --git a/internal/gogrep/gogrep.go b/internal/gogrep/gogrep.go index 6c05ab07..ea054f33 100644 --- a/internal/gogrep/gogrep.go +++ b/internal/gogrep/gogrep.go @@ -33,6 +33,22 @@ func (data MatchData) CapturedByName(name string) (ast.Node, bool) { return findNamed(data.Capture, name) } +type MatcherState struct { + Types *types.Info + + // node values recorded by name, excluding "_" (used only by the + // actual matching phase) + capture []CapturedNode + + pc int +} + +func NewMatcherState() MatcherState { + return MatcherState{ + capture: make([]CapturedNode, 0, 8), + } +} + type Pattern struct { m *matcher } @@ -46,8 +62,8 @@ func (p *Pattern) NodeTag() nodetag.Value { } // MatchNode calls cb if n matches a pattern. -func (p *Pattern) MatchNode(info *types.Info, n ast.Node, cb func(MatchData)) { - p.m.MatchNode(info, n, cb) +func (p *Pattern) MatchNode(state *MatcherState, n ast.Node, cb func(MatchData)) { + p.m.MatchNode(state, n, cb) } // Clone creates a pattern copy. @@ -55,7 +71,6 @@ func (p *Pattern) Clone() *Pattern { clone := *p clone.m = &matcher{} *clone.m = *p.m - clone.m.capture = make([]CapturedNode, 0, 8) return &clone } diff --git a/internal/gogrep/match.go b/internal/gogrep/match.go index ee14e13b..39b71c46 100644 --- a/internal/gogrep/match.go +++ b/internal/gogrep/match.go @@ -15,26 +15,18 @@ type matcher struct { prog *program insts []instruction - pc int - - // node values recorded by name, excluding "_" (used only by the - // actual matching phase) - capture []CapturedNode - - types *types.Info } func newMatcher(prog *program) *matcher { return &matcher{ - prog: prog, - insts: prog.insts, - capture: make([]CapturedNode, 0, 8), + prog: prog, + insts: prog.insts, } } -func (m *matcher) nextInst() instruction { - inst := m.insts[m.pc] - m.pc++ +func (m *matcher) nextInst(state *MatcherState) instruction { + inst := m.insts[state.pc] + state.pc++ return inst } @@ -46,69 +38,68 @@ func (m *matcher) ifaceValue(inst instruction) interface{} { return m.prog.ifaces[inst.valueIndex] } -func (m *matcher) MatchNode(info *types.Info, n ast.Node, accept func(MatchData)) { - m.pc = 0 - m.types = info - inst := m.nextInst() +func (m *matcher) MatchNode(state *MatcherState, n ast.Node, accept func(MatchData)) { + state.pc = 0 + inst := m.nextInst(state) switch inst.op { case opMultiStmt: switch n := n.(type) { case *ast.BlockStmt: - m.walkStmtSlice(n.List, accept) + m.walkStmtSlice(state, n.List, accept) case *ast.CaseClause: - m.walkStmtSlice(n.Body, accept) + m.walkStmtSlice(state, n.Body, accept) case *ast.CommClause: - m.walkStmtSlice(n.Body, accept) + m.walkStmtSlice(state, n.Body, accept) } case opMultiExpr: switch n := n.(type) { case *ast.CallExpr: - m.walkExprSlice(n.Args, accept) + m.walkExprSlice(state, n.Args, accept) case *ast.CompositeLit: - m.walkExprSlice(n.Elts, accept) + m.walkExprSlice(state, n.Elts, accept) case *ast.ReturnStmt: - m.walkExprSlice(n.Results, accept) + m.walkExprSlice(state, n.Results, accept) } case opMultiDecl: switch n := n.(type) { case *ast.File: - m.walkDeclSlice(n.Decls, accept) + m.walkDeclSlice(state, n.Decls, accept) } default: - m.capture = m.capture[:0] - if m.matchNodeWithInst(inst, n) { + state.capture = state.capture[:0] + if m.matchNodeWithInst(state, inst, n) { accept(MatchData{ - Capture: m.capture, + Capture: state.capture, Node: n, }) } } } -func (m *matcher) walkDeclSlice(decls []ast.Decl, accept func(MatchData)) { - m.walkNodeSlice(declSlice(decls), accept) +func (m *matcher) walkDeclSlice(state *MatcherState, decls []ast.Decl, accept func(MatchData)) { + m.walkNodeSlice(state, declSlice(decls), accept) } -func (m *matcher) walkExprSlice(exprs []ast.Expr, accept func(MatchData)) { - m.walkNodeSlice(ExprSlice(exprs), accept) +func (m *matcher) walkExprSlice(state *MatcherState, exprs []ast.Expr, accept func(MatchData)) { + m.walkNodeSlice(state, ExprSlice(exprs), accept) } -func (m *matcher) walkStmtSlice(stmts []ast.Stmt, accept func(MatchData)) { - m.walkNodeSlice(stmtSlice(stmts), accept) +func (m *matcher) walkStmtSlice(state *MatcherState, stmts []ast.Stmt, accept func(MatchData)) { + m.walkNodeSlice(state, stmtSlice(stmts), accept) } -func (m *matcher) walkNodeSlice(nodes NodeSlice, accept func(MatchData)) { +func (m *matcher) walkNodeSlice(state *MatcherState, nodes NodeSlice, accept func(MatchData)) { sliceLen := nodes.Len() from := 0 for { - m.pc = 1 // FIXME: this is a kludge - m.capture = m.capture[:0] - matched, offset := m.matchNodeList(nodes.slice(from, sliceLen), true) + state.pc = 1 // FIXME: this is a kludge + state.capture = state.capture[:0] + matched, offset := m.matchNodeList(state, nodes.slice(from, sliceLen), true) if matched == nil { break } accept(MatchData{ - Capture: m.capture, + Capture: state.capture, Node: matched, }) from += offset - 1 @@ -118,22 +109,22 @@ func (m *matcher) walkNodeSlice(nodes NodeSlice, accept func(MatchData)) { } } -func (m *matcher) matchNamed(name string, n ast.Node) bool { - prev, ok := findNamed(m.capture, name) +func (m *matcher) matchNamed(state *MatcherState, name string, n ast.Node) bool { + prev, ok := findNamed(state.capture, name) if !ok { // First occurrence, record value. - m.capture = append(m.capture, CapturedNode{Name: name, Node: n}) + state.capture = append(state.capture, CapturedNode{Name: name, Node: n}) return true } return equalNodes(prev, n) } -func (m *matcher) matchNamedField(name string, n ast.Node) bool { - prev, ok := findNamed(m.capture, name) +func (m *matcher) matchNamedField(state *MatcherState, name string, n ast.Node) bool { + prev, ok := findNamed(state.capture, name) if !ok { // First occurrence, record value. unwrapped := m.unwrapNode(n) - m.capture = append(m.capture, CapturedNode{Name: name, Node: unwrapped}) + state.capture = append(state.capture, CapturedNode{Name: name, Node: unwrapped}) return true } n = m.unwrapNode(n) @@ -154,7 +145,7 @@ func (m *matcher) unwrapNode(x ast.Node) ast.Node { return x } -func (m *matcher) matchNodeWithInst(inst instruction, n ast.Node) bool { +func (m *matcher) matchNodeWithInst(state *MatcherState, inst instruction, n ast.Node) bool { switch inst.op { case opNode: return n != nil @@ -162,15 +153,15 @@ func (m *matcher) matchNodeWithInst(inst instruction, n ast.Node) bool { return true case opNamedNode: - return n != nil && m.matchNamed(m.stringValue(inst), n) + return n != nil && m.matchNamed(state, m.stringValue(inst), n) case opNamedOptNode: - return m.matchNamed(m.stringValue(inst), n) + return m.matchNamed(state, m.stringValue(inst), n) case opFieldNode: n, ok := n.(*ast.FieldList) return ok && n != nil && len(n.List) == 1 && len(n.List[0].Names) == 0 case opNamedFieldNode: - return n != nil && m.matchNamedField(m.stringValue(inst), n) + return n != nil && m.matchNamedField(state, m.stringValue(inst), n) case opBasicLit: n, ok := n.(*ast.BasicLit) @@ -201,7 +192,7 @@ func (m *matcher) matchNodeWithInst(inst instruction, n ast.Node) bool { if !ok { return false } - obj := m.types.ObjectOf(n) + obj := state.Types.ObjectOf(n) if obj == nil { return false } @@ -212,290 +203,290 @@ func (m *matcher) matchNodeWithInst(inst instruction, n ast.Node) bool { case opBinaryExpr: n, ok := n.(*ast.BinaryExpr) return ok && n.Op == token.Token(inst.value) && - m.matchNode(n.X) && m.matchNode(n.Y) + m.matchNode(state, n.X) && m.matchNode(state, n.Y) case opUnaryExpr: n, ok := n.(*ast.UnaryExpr) - return ok && n.Op == token.Token(inst.value) && m.matchNode(n.X) + return ok && n.Op == token.Token(inst.value) && m.matchNode(state, n.X) case opStarExpr: n, ok := n.(*ast.StarExpr) - return ok && m.matchNode(n.X) + return ok && m.matchNode(state, n.X) case opVariadicCallExpr: n, ok := n.(*ast.CallExpr) - return ok && n.Ellipsis.IsValid() && m.matchNode(n.Fun) && m.matchArgList(n.Args) + return ok && n.Ellipsis.IsValid() && m.matchNode(state, n.Fun) && m.matchArgList(state, n.Args) case opNonVariadicCallExpr: n, ok := n.(*ast.CallExpr) - return ok && !n.Ellipsis.IsValid() && m.matchNode(n.Fun) && m.matchArgList(n.Args) + return ok && !n.Ellipsis.IsValid() && m.matchNode(state, n.Fun) && m.matchArgList(state, n.Args) case opCallExpr: n, ok := n.(*ast.CallExpr) - return ok && m.matchNode(n.Fun) && m.matchArgList(n.Args) + return ok && m.matchNode(state, n.Fun) && m.matchArgList(state, n.Args) case opSimpleSelectorExpr: n, ok := n.(*ast.SelectorExpr) - return ok && m.stringValue(inst) == n.Sel.Name && m.matchNode(n.X) + return ok && m.stringValue(inst) == n.Sel.Name && m.matchNode(state, n.X) case opSelectorExpr: n, ok := n.(*ast.SelectorExpr) - return ok && m.matchNode(n.Sel) && m.matchNode(n.X) + return ok && m.matchNode(state, n.Sel) && m.matchNode(state, n.X) case opTypeAssertExpr: n, ok := n.(*ast.TypeAssertExpr) - return ok && m.matchNode(n.X) && m.matchNode(n.Type) + return ok && m.matchNode(state, n.X) && m.matchNode(state, n.Type) case opTypeSwitchAssertExpr: n, ok := n.(*ast.TypeAssertExpr) - return ok && n.Type == nil && m.matchNode(n.X) + return ok && n.Type == nil && m.matchNode(state, n.X) case opSliceExpr: n, ok := n.(*ast.SliceExpr) - return ok && n.Low == nil && n.High == nil && m.matchNode(n.X) + return ok && n.Low == nil && n.High == nil && m.matchNode(state, n.X) case opSliceFromExpr: n, ok := n.(*ast.SliceExpr) return ok && n.High == nil && !n.Slice3 && - m.matchNode(n.X) && m.matchNode(n.Low) + m.matchNode(state, n.X) && m.matchNode(state, n.Low) case opSliceToExpr: n, ok := n.(*ast.SliceExpr) return ok && n.Low == nil && !n.Slice3 && - m.matchNode(n.X) && m.matchNode(n.High) + m.matchNode(state, n.X) && m.matchNode(state, n.High) case opSliceFromToExpr: n, ok := n.(*ast.SliceExpr) return ok && !n.Slice3 && - m.matchNode(n.X) && m.matchNode(n.Low) && m.matchNode(n.High) + m.matchNode(state, n.X) && m.matchNode(state, n.Low) && m.matchNode(state, n.High) case opSliceToCapExpr: n, ok := n.(*ast.SliceExpr) return ok && n.Low == nil && - m.matchNode(n.X) && m.matchNode(n.High) && m.matchNode(n.Max) + m.matchNode(state, n.X) && m.matchNode(state, n.High) && m.matchNode(state, n.Max) case opSliceFromToCapExpr: n, ok := n.(*ast.SliceExpr) - return ok && m.matchNode(n.X) && m.matchNode(n.Low) && m.matchNode(n.High) && m.matchNode(n.Max) + return ok && m.matchNode(state, n.X) && m.matchNode(state, n.Low) && m.matchNode(state, n.High) && m.matchNode(state, n.Max) case opIndexExpr: n, ok := n.(*ast.IndexExpr) - return ok && m.matchNode(n.X) && m.matchNode(n.Index) + return ok && m.matchNode(state, n.X) && m.matchNode(state, n.Index) case opKeyValueExpr: n, ok := n.(*ast.KeyValueExpr) - return ok && m.matchNode(n.Key) && m.matchNode(n.Value) + return ok && m.matchNode(state, n.Key) && m.matchNode(state, n.Value) case opParenExpr: n, ok := n.(*ast.ParenExpr) - return ok && m.matchNode(n.X) + return ok && m.matchNode(state, n.X) case opEllipsis: n, ok := n.(*ast.Ellipsis) return ok && n.Elt == nil case opTypedEllipsis: n, ok := n.(*ast.Ellipsis) - return ok && n.Elt != nil && m.matchNode(n.Elt) + return ok && n.Elt != nil && m.matchNode(state, n.Elt) case opSliceType: n, ok := n.(*ast.ArrayType) - return ok && n.Len == nil && m.matchNode(n.Elt) + return ok && n.Len == nil && m.matchNode(state, n.Elt) case opArrayType: n, ok := n.(*ast.ArrayType) - return ok && n.Len != nil && m.matchNode(n.Len) && m.matchNode(n.Elt) + return ok && n.Len != nil && m.matchNode(state, n.Len) && m.matchNode(state, n.Elt) case opMapType: n, ok := n.(*ast.MapType) - return ok && m.matchNode(n.Key) && m.matchNode(n.Value) + return ok && m.matchNode(state, n.Key) && m.matchNode(state, n.Value) case opChanType: n, ok := n.(*ast.ChanType) - return ok && ast.ChanDir(inst.value) == n.Dir && m.matchNode(n.Value) + return ok && ast.ChanDir(inst.value) == n.Dir && m.matchNode(state, n.Value) case opVoidFuncType: n, ok := n.(*ast.FuncType) - return ok && n.Results == nil && m.matchNode(n.Params) + return ok && n.Results == nil && m.matchNode(state, n.Params) case opFuncType: n, ok := n.(*ast.FuncType) - return ok && m.matchNode(n.Params) && m.matchNode(n.Results) + return ok && m.matchNode(state, n.Params) && m.matchNode(state, n.Results) case opStructType: n, ok := n.(*ast.StructType) - return ok && m.matchNode(n.Fields) + return ok && m.matchNode(state, n.Fields) case opInterfaceType: n, ok := n.(*ast.InterfaceType) - return ok && m.matchNode(n.Methods) + return ok && m.matchNode(state, n.Methods) case opCompositeLit: n, ok := n.(*ast.CompositeLit) - return ok && n.Type == nil && m.matchExprSlice(n.Elts) + return ok && n.Type == nil && m.matchExprSlice(state, n.Elts) case opTypedCompositeLit: n, ok := n.(*ast.CompositeLit) - return ok && n.Type != nil && m.matchNode(n.Type) && m.matchExprSlice(n.Elts) + return ok && n.Type != nil && m.matchNode(state, n.Type) && m.matchExprSlice(state, n.Elts) case opUnnamedField: n, ok := n.(*ast.Field) - return ok && len(n.Names) == 0 && m.matchNode(n.Type) + return ok && len(n.Names) == 0 && m.matchNode(state, n.Type) case opSimpleField: n, ok := n.(*ast.Field) - return ok && len(n.Names) == 1 && m.stringValue(inst) == n.Names[0].Name && m.matchNode(n.Type) + return ok && len(n.Names) == 1 && m.stringValue(inst) == n.Names[0].Name && m.matchNode(state, n.Type) case opField: n, ok := n.(*ast.Field) - return ok && len(n.Names) == 1 && m.matchNode(n.Names[0]) && m.matchNode(n.Type) + return ok && len(n.Names) == 1 && m.matchNode(state, n.Names[0]) && m.matchNode(state, n.Type) case opMultiField: n, ok := n.(*ast.Field) - return ok && len(n.Names) >= 2 && m.matchIdentSlice(n.Names) && m.matchNode(n.Type) + return ok && len(n.Names) >= 2 && m.matchIdentSlice(state, n.Names) && m.matchNode(state, n.Type) case opFieldList: // FieldList could be nil in places like function return types. n, ok := n.(*ast.FieldList) - return ok && n != nil && m.matchFieldSlice(n.List) + return ok && n != nil && m.matchFieldSlice(state, n.List) case opFuncLit: n, ok := n.(*ast.FuncLit) - return ok && m.matchNode(n.Type) && m.matchNode(n.Body) + return ok && m.matchNode(state, n.Type) && m.matchNode(state, n.Body) case opAssignStmt: n, ok := n.(*ast.AssignStmt) return ok && token.Token(inst.value) == n.Tok && - len(n.Lhs) == 1 && m.matchNode(n.Lhs[0]) && - len(n.Rhs) == 1 && m.matchNode(n.Rhs[0]) + len(n.Lhs) == 1 && m.matchNode(state, n.Lhs[0]) && + len(n.Rhs) == 1 && m.matchNode(state, n.Rhs[0]) case opMultiAssignStmt: n, ok := n.(*ast.AssignStmt) return ok && token.Token(inst.value) == n.Tok && - m.matchExprSlice(n.Lhs) && m.matchExprSlice(n.Rhs) + m.matchExprSlice(state, n.Lhs) && m.matchExprSlice(state, n.Rhs) case opExprStmt: n, ok := n.(*ast.ExprStmt) - return ok && m.matchNode(n.X) + return ok && m.matchNode(state, n.X) case opGoStmt: n, ok := n.(*ast.GoStmt) - return ok && m.matchNode(n.Call) + return ok && m.matchNode(state, n.Call) case opDeferStmt: n, ok := n.(*ast.DeferStmt) - return ok && m.matchNode(n.Call) + return ok && m.matchNode(state, n.Call) case opSendStmt: n, ok := n.(*ast.SendStmt) - return ok && m.matchNode(n.Chan) && m.matchNode(n.Value) + return ok && m.matchNode(state, n.Chan) && m.matchNode(state, n.Value) case opBlockStmt: n, ok := n.(*ast.BlockStmt) - return ok && m.matchStmtSlice(n.List) + return ok && m.matchStmtSlice(state, n.List) case opIfStmt: n, ok := n.(*ast.IfStmt) return ok && n.Init == nil && n.Else == nil && - m.matchNode(n.Cond) && m.matchNode(n.Body) + m.matchNode(state, n.Cond) && m.matchNode(state, n.Body) case opIfElseStmt: n, ok := n.(*ast.IfStmt) return ok && n.Init == nil && n.Else != nil && - m.matchNode(n.Cond) && m.matchNode(n.Body) && m.matchNode(n.Else) + m.matchNode(state, n.Cond) && m.matchNode(state, n.Body) && m.matchNode(state, n.Else) case opIfInitStmt: n, ok := n.(*ast.IfStmt) return ok && n.Else == nil && - m.matchNode(n.Init) && m.matchNode(n.Cond) && m.matchNode(n.Body) + m.matchNode(state, n.Init) && m.matchNode(state, n.Cond) && m.matchNode(state, n.Body) case opIfInitElseStmt: n, ok := n.(*ast.IfStmt) return ok && n.Else != nil && - m.matchNode(n.Init) && m.matchNode(n.Cond) && m.matchNode(n.Body) && m.matchNode(n.Else) + m.matchNode(state, n.Init) && m.matchNode(state, n.Cond) && m.matchNode(state, n.Body) && m.matchNode(state, n.Else) case opIfNamedOptStmt: n, ok := n.(*ast.IfStmt) - return ok && n.Else == nil && m.matchNode(n.Body) && - m.matchNamed(m.stringValue(inst), toStmtSlice(n.Cond, n.Init)) + return ok && n.Else == nil && m.matchNode(state, n.Body) && + m.matchNamed(state, m.stringValue(inst), toStmtSlice(n.Cond, n.Init)) case opIfNamedOptElseStmt: n, ok := n.(*ast.IfStmt) - return ok && n.Else != nil && m.matchNode(n.Body) && m.matchNode(n.Else) && - m.matchNamed(m.stringValue(inst), toStmtSlice(n.Cond, n.Init)) + return ok && n.Else != nil && m.matchNode(state, n.Body) && m.matchNode(state, n.Else) && + m.matchNamed(state, m.stringValue(inst), toStmtSlice(n.Cond, n.Init)) case opCaseClause: n, ok := n.(*ast.CaseClause) - return ok && n.List != nil && m.matchExprSlice(n.List) && m.matchStmtSlice(n.Body) + return ok && n.List != nil && m.matchExprSlice(state, n.List) && m.matchStmtSlice(state, n.Body) case opDefaultCaseClause: n, ok := n.(*ast.CaseClause) - return ok && n.List == nil && m.matchStmtSlice(n.Body) + return ok && n.List == nil && m.matchStmtSlice(state, n.Body) case opSwitchStmt: n, ok := n.(*ast.SwitchStmt) - return ok && n.Init == nil && n.Tag == nil && m.matchStmtSlice(n.Body.List) + return ok && n.Init == nil && n.Tag == nil && m.matchStmtSlice(state, n.Body.List) case opSwitchTagStmt: n, ok := n.(*ast.SwitchStmt) - return ok && n.Init == nil && m.matchNode(n.Tag) && m.matchStmtSlice(n.Body.List) + return ok && n.Init == nil && m.matchNode(state, n.Tag) && m.matchStmtSlice(state, n.Body.List) case opSwitchInitStmt: n, ok := n.(*ast.SwitchStmt) - return ok && n.Tag == nil && m.matchNode(n.Init) && m.matchStmtSlice(n.Body.List) + return ok && n.Tag == nil && m.matchNode(state, n.Init) && m.matchStmtSlice(state, n.Body.List) case opSwitchInitTagStmt: n, ok := n.(*ast.SwitchStmt) - return ok && m.matchNode(n.Init) && m.matchNode(n.Tag) && m.matchStmtSlice(n.Body.List) + return ok && m.matchNode(state, n.Init) && m.matchNode(state, n.Tag) && m.matchStmtSlice(state, n.Body.List) case opTypeSwitchStmt: n, ok := n.(*ast.TypeSwitchStmt) - return ok && n.Init == nil && m.matchNode(n.Assign) && m.matchStmtSlice(n.Body.List) + return ok && n.Init == nil && m.matchNode(state, n.Assign) && m.matchStmtSlice(state, n.Body.List) case opTypeSwitchInitStmt: n, ok := n.(*ast.TypeSwitchStmt) - return ok && m.matchNode(n.Init) && - m.matchNode(n.Assign) && m.matchStmtSlice(n.Body.List) + return ok && m.matchNode(state, n.Init) && + m.matchNode(state, n.Assign) && m.matchStmtSlice(state, n.Body.List) case opCommClause: n, ok := n.(*ast.CommClause) - return ok && n.Comm != nil && m.matchNode(n.Comm) && m.matchStmtSlice(n.Body) + return ok && n.Comm != nil && m.matchNode(state, n.Comm) && m.matchStmtSlice(state, n.Body) case opDefaultCommClause: n, ok := n.(*ast.CommClause) - return ok && n.Comm == nil && m.matchStmtSlice(n.Body) + return ok && n.Comm == nil && m.matchStmtSlice(state, n.Body) case opSelectStmt: n, ok := n.(*ast.SelectStmt) - return ok && m.matchStmtSlice(n.Body.List) + return ok && m.matchStmtSlice(state, n.Body.List) case opRangeStmt: n, ok := n.(*ast.RangeStmt) - return ok && n.Key == nil && n.Value == nil && m.matchNode(n.X) && m.matchNode(n.Body) + return ok && n.Key == nil && n.Value == nil && m.matchNode(state, n.X) && m.matchNode(state, n.Body) case opRangeKeyStmt: n, ok := n.(*ast.RangeStmt) return ok && n.Key != nil && n.Value == nil && token.Token(inst.value) == n.Tok && - m.matchNode(n.Key) && m.matchNode(n.X) && m.matchNode(n.Body) + m.matchNode(state, n.Key) && m.matchNode(state, n.X) && m.matchNode(state, n.Body) case opRangeKeyValueStmt: n, ok := n.(*ast.RangeStmt) return ok && n.Key != nil && n.Value != nil && token.Token(inst.value) == n.Tok && - m.matchNode(n.Key) && m.matchNode(n.Value) && m.matchNode(n.X) && m.matchNode(n.Body) + m.matchNode(state, n.Key) && m.matchNode(state, n.Value) && m.matchNode(state, n.X) && m.matchNode(state, n.Body) case opForStmt: n, ok := n.(*ast.ForStmt) return ok && n.Init == nil && n.Cond == nil && n.Post == nil && - m.matchNode(n.Body) + m.matchNode(state, n.Body) case opForPostStmt: n, ok := n.(*ast.ForStmt) return ok && n.Init == nil && n.Cond == nil && n.Post != nil && - m.matchNode(n.Post) && m.matchNode(n.Body) + m.matchNode(state, n.Post) && m.matchNode(state, n.Body) case opForCondStmt: n, ok := n.(*ast.ForStmt) return ok && n.Init == nil && n.Cond != nil && n.Post == nil && - m.matchNode(n.Cond) && m.matchNode(n.Body) + m.matchNode(state, n.Cond) && m.matchNode(state, n.Body) case opForCondPostStmt: n, ok := n.(*ast.ForStmt) return ok && n.Init == nil && n.Cond != nil && n.Post != nil && - m.matchNode(n.Cond) && m.matchNode(n.Post) && m.matchNode(n.Body) + m.matchNode(state, n.Cond) && m.matchNode(state, n.Post) && m.matchNode(state, n.Body) case opForInitStmt: n, ok := n.(*ast.ForStmt) return ok && n.Init != nil && n.Cond == nil && n.Post == nil && - m.matchNode(n.Init) && m.matchNode(n.Body) + m.matchNode(state, n.Init) && m.matchNode(state, n.Body) case opForInitPostStmt: n, ok := n.(*ast.ForStmt) return ok && n.Init != nil && n.Cond == nil && n.Post != nil && - m.matchNode(n.Init) && m.matchNode(n.Post) && m.matchNode(n.Body) + m.matchNode(state, n.Init) && m.matchNode(state, n.Post) && m.matchNode(state, n.Body) case opForInitCondStmt: n, ok := n.(*ast.ForStmt) return ok && n.Init != nil && n.Cond != nil && n.Post == nil && - m.matchNode(n.Init) && m.matchNode(n.Cond) && m.matchNode(n.Body) + m.matchNode(state, n.Init) && m.matchNode(state, n.Cond) && m.matchNode(state, n.Body) case opForInitCondPostStmt: n, ok := n.(*ast.ForStmt) - return ok && m.matchNode(n.Init) && m.matchNode(n.Cond) && m.matchNode(n.Post) && m.matchNode(n.Body) + return ok && m.matchNode(state, n.Init) && m.matchNode(state, n.Cond) && m.matchNode(state, n.Post) && m.matchNode(state, n.Body) case opIncDecStmt: n, ok := n.(*ast.IncDecStmt) - return ok && token.Token(inst.value) == n.Tok && m.matchNode(n.X) + return ok && token.Token(inst.value) == n.Tok && m.matchNode(state, n.X) case opReturnStmt: n, ok := n.(*ast.ReturnStmt) - return ok && m.matchExprSlice(n.Results) + return ok && m.matchExprSlice(state, n.Results) case opLabeledStmt: n, ok := n.(*ast.LabeledStmt) - return ok && m.matchNode(n.Label) && m.matchNode(n.Stmt) + return ok && m.matchNode(state, n.Label) && m.matchNode(state, n.Stmt) case opSimpleLabeledStmt: n, ok := n.(*ast.LabeledStmt) - return ok && m.stringValue(inst) == n.Label.Name && m.matchNode(n.Stmt) + return ok && m.stringValue(inst) == n.Label.Name && m.matchNode(state, n.Stmt) case opLabeledBranchStmt: n, ok := n.(*ast.BranchStmt) - return ok && n.Label != nil && token.Token(inst.value) == n.Tok && m.matchNode(n.Label) + return ok && n.Label != nil && token.Token(inst.value) == n.Tok && m.matchNode(state, n.Label) case opSimpleLabeledBranchStmt: n, ok := n.(*ast.BranchStmt) return ok && n.Label != nil && m.stringValue(inst) == n.Label.Name && token.Token(inst.value) == n.Tok @@ -510,124 +501,124 @@ func (m *matcher) matchNodeWithInst(inst instruction, n ast.Node) bool { case opFuncDecl: n, ok := n.(*ast.FuncDecl) return ok && n.Recv == nil && n.Body != nil && - m.matchNode(n.Name) && m.matchNode(n.Type) && m.matchNode(n.Body) + m.matchNode(state, n.Name) && m.matchNode(state, n.Type) && m.matchNode(state, n.Body) case opFuncProtoDecl: n, ok := n.(*ast.FuncDecl) return ok && n.Recv == nil && n.Body == nil && - m.matchNode(n.Name) && m.matchNode(n.Type) + m.matchNode(state, n.Name) && m.matchNode(state, n.Type) case opMethodDecl: n, ok := n.(*ast.FuncDecl) return ok && n.Recv != nil && n.Body != nil && - m.matchNode(n.Recv) && m.matchNode(n.Name) && m.matchNode(n.Type) && m.matchNode(n.Body) + m.matchNode(state, n.Recv) && m.matchNode(state, n.Name) && m.matchNode(state, n.Type) && m.matchNode(state, n.Body) case opMethodProtoDecl: n, ok := n.(*ast.FuncDecl) return ok && n.Recv != nil && n.Body == nil && - m.matchNode(n.Recv) && m.matchNode(n.Name) && m.matchNode(n.Type) + m.matchNode(state, n.Recv) && m.matchNode(state, n.Name) && m.matchNode(state, n.Type) case opValueSpec: n, ok := n.(*ast.ValueSpec) return ok && len(n.Values) == 0 && n.Type == nil && - len(n.Names) == 1 && m.matchNode(n.Names[0]) + len(n.Names) == 1 && m.matchNode(state, n.Names[0]) case opValueInitSpec: n, ok := n.(*ast.ValueSpec) return ok && len(n.Values) != 0 && n.Type == nil && - m.matchIdentSlice(n.Names) && m.matchExprSlice(n.Values) + m.matchIdentSlice(state, n.Names) && m.matchExprSlice(state, n.Values) case opTypedValueSpec: n, ok := n.(*ast.ValueSpec) return ok && len(n.Values) == 0 && n.Type != nil && - m.matchIdentSlice(n.Names) && m.matchNode(n.Type) + m.matchIdentSlice(state, n.Names) && m.matchNode(state, n.Type) case opTypedValueInitSpec: n, ok := n.(*ast.ValueSpec) return ok && len(n.Values) != 0 && n.Type != nil && - m.matchIdentSlice(n.Names) && m.matchNode(n.Type) && m.matchExprSlice(n.Values) + m.matchIdentSlice(state, n.Names) && m.matchNode(state, n.Type) && m.matchExprSlice(state, n.Values) case opTypeSpec: n, ok := n.(*ast.TypeSpec) - return ok && !n.Assign.IsValid() && m.matchNode(n.Name) && m.matchNode(n.Type) + return ok && !n.Assign.IsValid() && m.matchNode(state, n.Name) && m.matchNode(state, n.Type) case opTypeAliasSpec: n, ok := n.(*ast.TypeSpec) - return ok && n.Assign.IsValid() && m.matchNode(n.Name) && m.matchNode(n.Type) + return ok && n.Assign.IsValid() && m.matchNode(state, n.Name) && m.matchNode(state, n.Type) case opDeclStmt: n, ok := n.(*ast.DeclStmt) - return ok && m.matchNode(n.Decl) + return ok && m.matchNode(state, n.Decl) case opConstDecl: n, ok := n.(*ast.GenDecl) - return ok && n.Tok == token.CONST && m.matchSpecSlice(n.Specs) + return ok && n.Tok == token.CONST && m.matchSpecSlice(state, n.Specs) case opVarDecl: n, ok := n.(*ast.GenDecl) - return ok && n.Tok == token.VAR && m.matchSpecSlice(n.Specs) + return ok && n.Tok == token.VAR && m.matchSpecSlice(state, n.Specs) case opTypeDecl: n, ok := n.(*ast.GenDecl) - return ok && n.Tok == token.TYPE && m.matchSpecSlice(n.Specs) + return ok && n.Tok == token.TYPE && m.matchSpecSlice(state, n.Specs) case opEmptyPackage: n, ok := n.(*ast.File) - return ok && len(n.Imports) == 0 && len(n.Decls) == 0 && m.matchNode(n.Name) + return ok && len(n.Imports) == 0 && len(n.Decls) == 0 && m.matchNode(state, n.Name) default: panic(fmt.Sprintf("unexpected op %s", inst.op)) } } -func (m *matcher) matchNode(n ast.Node) bool { - return m.matchNodeWithInst(m.nextInst(), n) +func (m *matcher) matchNode(state *MatcherState, n ast.Node) bool { + return m.matchNodeWithInst(state, m.nextInst(state), n) } -func (m *matcher) matchArgList(exprs []ast.Expr) bool { - inst := m.nextInst() +func (m *matcher) matchArgList(state *MatcherState, exprs []ast.Expr) bool { + inst := m.nextInst(state) if inst.op != opSimpleArgList { - return m.matchExprSlice(exprs) + return m.matchExprSlice(state, exprs) } if len(exprs) != int(inst.value) { return false } for _, x := range exprs { - if !m.matchNode(x) { + if !m.matchNode(state, x) { return false } } return true } -func (m *matcher) matchStmtSlice(stmts []ast.Stmt) bool { - matched, _ := m.matchNodeList(stmtSlice(stmts), false) +func (m *matcher) matchStmtSlice(state *MatcherState, stmts []ast.Stmt) bool { + matched, _ := m.matchNodeList(state, stmtSlice(stmts), false) return matched != nil } -func (m *matcher) matchExprSlice(exprs []ast.Expr) bool { - matched, _ := m.matchNodeList(ExprSlice(exprs), false) +func (m *matcher) matchExprSlice(state *MatcherState, exprs []ast.Expr) bool { + matched, _ := m.matchNodeList(state, ExprSlice(exprs), false) return matched != nil } -func (m *matcher) matchFieldSlice(fields []*ast.Field) bool { - matched, _ := m.matchNodeList(fieldSlice(fields), false) +func (m *matcher) matchFieldSlice(state *MatcherState, fields []*ast.Field) bool { + matched, _ := m.matchNodeList(state, fieldSlice(fields), false) return matched != nil } -func (m *matcher) matchIdentSlice(idents []*ast.Ident) bool { - matched, _ := m.matchNodeList(identSlice(idents), false) +func (m *matcher) matchIdentSlice(state *MatcherState, idents []*ast.Ident) bool { + matched, _ := m.matchNodeList(state, identSlice(idents), false) return matched != nil } -func (m *matcher) matchSpecSlice(specs []ast.Spec) bool { - matched, _ := m.matchNodeList(specSlice(specs), false) +func (m *matcher) matchSpecSlice(state *MatcherState, specs []ast.Spec) bool { + matched, _ := m.matchNodeList(state, specSlice(specs), false) return matched != nil } // matchNodeList matches two lists of nodes. It uses a common algorithm to match // wildcard patterns with any number of nodes without recursion. -func (m *matcher) matchNodeList(nodes NodeSlice, partial bool) (ast.Node, int) { +func (m *matcher) matchNodeList(state *MatcherState, nodes NodeSlice, partial bool) (ast.Node, int) { sliceLen := nodes.Len() - inst := m.nextInst() + inst := m.nextInst(state) if inst.op == opEnd { if sliceLen == 0 { return nodes, 0 } return nil, -1 } - pcBase := m.pc + pcBase := state.pc pcNext := 0 j := 0 jNext := 0 @@ -652,14 +643,14 @@ func (m *matcher) matchNodeList(nodes NodeSlice, partial bool) (ast.Node, int) { if next > sliceLen { return // would be discarded anyway } - pcNext = m.pc - 1 + pcNext = state.pc - 1 jNext = next - stack = append(stack, restart{m.capture, pcNext, next, wildStart, wildName}) + stack = append(stack, restart{state.capture, pcNext, next, wildStart, wildName}) } pop := func() { j = jNext - m.pc = pcNext - m.capture = stack[len(stack)-1].matches + state.pc = pcNext + state.capture = stack[len(stack)-1].matches wildName = stack[len(stack)-1].wildName wildStart = stack[len(stack)-1].wildStart stack = stack[:len(stack)-1] @@ -678,9 +669,9 @@ func (m *matcher) matchNodeList(nodes NodeSlice, partial bool) (ast.Node, int) { case "", "_": return true } - return m.matchNamed(wildName, nodes.slice(wildStart, j)) + return m.matchNamed(state, wildName, nodes.slice(wildStart, j)) } - for ; inst.op != opEnd || j < sliceLen; inst = m.nextInst() { + for ; inst.op != opEnd || j < sliceLen; inst = m.nextInst(state) { if inst.op != opEnd { if inst.op == opNodeSeq || inst.op == opNamedNodeSeq { // keep track of where this wildcard @@ -700,13 +691,13 @@ func (m *matcher) matchNodeList(nodes NodeSlice, partial bool) (ast.Node, int) { push(j + 1) continue } - if partial && m.pc == pcBase { + if partial && state.pc == pcBase { // let "b; c" match "a; b; c" // (simulates a $*_ at the beginning) partialStart = j push(j + 1) } - if j < sliceLen && wouldMatch() && m.matchNodeWithInst(inst, nodes.At(j)) { + if j < sliceLen && wouldMatch() && m.matchNodeWithInst(state, inst, nodes.At(j)) { // ordinary match wildName = "" j++ @@ -718,7 +709,7 @@ func (m *matcher) matchNodeList(nodes NodeSlice, partial bool) (ast.Node, int) { break // let "b; c" match "b; c; d" } // mismatch, try to restart - if 0 < jNext && jNext <= sliceLen && (m.pc != pcNext || j != jNext) { + if 0 < jNext && jNext <= sliceLen && (state.pc != pcNext || j != jNext) { pop() continue } diff --git a/internal/gogrep/match_perf_test.go b/internal/gogrep/match_perf_test.go index 91bdd881..61a8ad14 100644 --- a/internal/gogrep/match_perf_test.go +++ b/internal/gogrep/match_perf_test.go @@ -162,6 +162,8 @@ func BenchmarkMatch(b *testing.B) { for i := range tests { test := tests[i] b.Run(test.name, func(b *testing.B) { + state := NewMatcherState() + fset := token.NewFileSet() pat, _, err := Compile(fset, test.pat, true) if err != nil { @@ -175,7 +177,7 @@ func BenchmarkMatch(b *testing.B) { } if !strings.HasPrefix(test.name, "fail") { matches := 0 - testAllMatches(pat, target, func(m MatchData) { + testAllMatches(pat, &state, target, func(m MatchData) { matches++ }) if matches == 0 { @@ -186,7 +188,7 @@ func BenchmarkMatch(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { - testAllMatches(pat, target, func(m MatchData) {}) + testAllMatches(pat, &state, target, func(m MatchData) {}) } }) } diff --git a/internal/gogrep/match_test.go b/internal/gogrep/match_test.go index 2dd49378..bee4427f 100644 --- a/internal/gogrep/match_test.go +++ b/internal/gogrep/match_test.go @@ -871,6 +871,7 @@ func TestMatch(t *testing.T) { for i := range tests { test := tests[i] t.Run(fmt.Sprintf("test%d", i), func(t *testing.T) { + state := NewMatcherState() fset := token.NewFileSet() testPattern := unwrapPattern(test.pat) pat, _, err := Compile(fset, testPattern, isStrict(test.pat)) @@ -884,7 +885,7 @@ func TestMatch(t *testing.T) { return } matches := 0 - testAllMatches(pat, target, func(m MatchData) { + testAllMatches(pat, &state, target, func(m MatchData) { matches++ }) if matches != test.numMatches { @@ -895,12 +896,12 @@ func TestMatch(t *testing.T) { } } -func testAllMatches(p *Pattern, target ast.Node, cb func(MatchData)) { +func testAllMatches(p *Pattern, state *MatcherState, target ast.Node, cb func(MatchData)) { visit := func(n ast.Node) bool { if n == nil { return false } - p.MatchNode(nil, n, cb) + p.MatchNode(state, n, cb) return true } ast.Inspect(target, visit) diff --git a/ruleguard/engine.go b/ruleguard/engine.go index 3937b690..8ff828a1 100644 --- a/ruleguard/engine.go +++ b/ruleguard/engine.go @@ -118,7 +118,7 @@ func (e *engine) Run(ctx *RunContext, f *ast.File) error { if e.ruleSet == nil { return errors.New("used Run() with an empty rule set; forgot to call Load() first?") } - rset := cloneRuleSet(e.ruleSet) + rset := e.ruleSet return newRulesRunner(ctx, e.state, rset).run(f) } diff --git a/ruleguard/gorule.go b/ruleguard/gorule.go index ab8fe841..cfc7b70d 100644 --- a/ruleguard/gorule.go +++ b/ruleguard/gorule.go @@ -98,14 +98,6 @@ func (params *filterParams) typeofNode(n ast.Node) types.Type { return types.Typ[types.Invalid] } -func cloneRuleSet(rset *goRuleSet) *goRuleSet { - out, err := mergeRuleSets([]*goRuleSet{rset}) - if err != nil { - panic(err) // Should never happen - } - return out -} - func mergeRuleSets(toMerge []*goRuleSet) (*goRuleSet, error) { out := &goRuleSet{ universal: &scopedGoRuleSet{}, diff --git a/ruleguard/runner.go b/ruleguard/runner.go index 6b0f027e..799e84c5 100644 --- a/ruleguard/runner.go +++ b/ruleguard/runner.go @@ -22,6 +22,8 @@ type rulesRunner struct { ctx *RunContext rules *goRuleSet + gogrepState gogrep.MatcherState + importer *goImporter filename string @@ -51,11 +53,14 @@ func newRulesRunner(ctx *RunContext, state *engineState, rules *goRuleSet) *rule debugImports: ctx.DebugImports, debugPrint: ctx.DebugPrint, }) + gogrepState := gogrep.NewMatcherState() + gogrepState.Types = ctx.Types rr := &rulesRunner{ - ctx: ctx, - importer: importer, - rules: rules, - nodePath: newNodePath(), + ctx: ctx, + importer: importer, + rules: rules, + gogrepState: gogrepState, + nodePath: newNodePath(), filterParams: filterParams{ env: state.env.GetEvalEnv(), importer: importer, @@ -203,7 +208,7 @@ func (rr *rulesRunner) runRules(n ast.Node) { tag := nodetag.FromNode(n) for _, rule := range rr.rules.universal.rulesByTag[tag] { matched := false - rule.pat.MatchNode(rr.ctx.Types, n, func(m gogrep.MatchData) { + rule.pat.MatchNode(&rr.gogrepState, n, func(m gogrep.MatchData) { matched = rr.handleMatch(rule, m) }) if matched {