diff --git a/internal/mvdan.cc/gogrep/gogrep.go b/internal/mvdan.cc/gogrep/gogrep.go index bb7aa458..6a899348 100644 --- a/internal/mvdan.cc/gogrep/gogrep.go +++ b/internal/mvdan.cc/gogrep/gogrep.go @@ -10,7 +10,11 @@ type ExprList = exprList // Parse creates a gogrep pattern out of a given string expression. func Parse(fset *token.FileSet, expr string, strict bool) (*Pattern, error) { - m := matcher{fset: fset, strict: strict} + m := matcher{ + fset: fset, + strict: strict, + capture: make([]CapturedNode, 0, 8), + } node, err := m.parseExpr(expr) if err != nil { return nil, err @@ -26,8 +30,17 @@ type Pattern struct { // MatchData describes a successful pattern match. type MatchData struct { - Node ast.Node - Values map[string]ast.Node + Node ast.Node + Capture []CapturedNode +} + +type CapturedNode struct { + Name string + Node ast.Node +} + +func (data MatchData) CapturedByName(name string) (ast.Node, bool) { + return findNamed(data.Capture, name) } // Clone creates a pattern copy. @@ -35,17 +48,17 @@ func (p *Pattern) Clone() *Pattern { clone := *p clone.m = &matcher{} *clone.m = *p.m - clone.m.values = make(map[string]ast.Node) + clone.m.capture = make([]CapturedNode, 0, 8) return &clone } // MatchNode calls cb if n matches a pattern. func (p *Pattern) MatchNode(n ast.Node, cb func(MatchData)) { - p.m.values = map[string]ast.Node{} + p.m.capture = p.m.capture[:0] if p.m.node(p.Expr, n) { cb(MatchData{ - Values: p.m.values, - Node: n, + Capture: p.m.capture, + Node: n, }) } } @@ -62,14 +75,14 @@ func (p *Pattern) matchNodeList(pattern, list nodeList, cb func(MatchData)) { listLen := list.len() from := 0 for { - p.m.values = map[string]ast.Node{} + p.m.capture = p.m.capture[:0] matched, offset := p.m.nodes(pattern, list.slice(from, listLen), true) if matched == nil { break } cb(MatchData{ - Values: p.m.values, - Node: matched, + Capture: p.m.capture, + Node: matched, }) from += offset - 1 if from >= listLen { diff --git a/internal/mvdan.cc/gogrep/gogrep_perf_test.go b/internal/mvdan.cc/gogrep/gogrep_perf_test.go new file mode 100644 index 00000000..9449631b --- /dev/null +++ b/internal/mvdan.cc/gogrep/gogrep_perf_test.go @@ -0,0 +1,76 @@ +package gogrep + +import ( + "go/token" + "testing" +) + +func BenchmarkMatch(b *testing.B) { + tests := []struct { + name string + pat string + input string + }{ + { + name: `simpleLit`, + pat: `true`, + input: `true`, + }, + { + name: `capture1`, + pat: `+$x`, + input: `+50`, + }, + { + name: `capture2`, + pat: `$x + $y`, + input: `x + 4`, + }, + { + name: `capture8`, + pat: `f($x1, $x2, $x3, $x4, $x5, $x6, $x7, $x8)`, + input: `f(1, 2, 3, 4, 5, 6, 7, 8)`, + }, + { + name: `capture2same`, + pat: `$x + $x`, + input: `a + a`, + }, + { + name: `capture8same`, + pat: `f($x, $x, $x, $x, $x, $x, $x, $x)`, + input: `f(1, 1, 1, 1, 1, 1, 1, 1)`, + }, + { + name: `captureBacktrackLeft`, + pat: `f($*xs, $y)`, + input: `f(1, 2, 3, 4, 5, 6)`, + }, + { + name: `captureBacktrackRight`, + pat: `f($x, $*ys)`, + input: `f(1, 2, 3, 4, 5, 6)`, + }, + } + + for i := range tests { + test := tests[i] + b.Run(test.name, func(b *testing.B) { + fset := token.NewFileSet() + pat, err := Parse(fset, test.pat, true) + if err != nil { + b.Errorf("parse `%s`: %v", test.pat, err) + return + } + target := testParseNode(b, test.input) + if err != nil { + b.Errorf("parse target `%s`: %v", test.input, err) + return + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + testAllMatches(pat, target, func(m MatchData) {}) + } + }) + } +} diff --git a/internal/mvdan.cc/gogrep/gogrep_test.go b/internal/mvdan.cc/gogrep/gogrep_test.go index 3403f2c5..e9ab07c7 100644 --- a/internal/mvdan.cc/gogrep/gogrep_test.go +++ b/internal/mvdan.cc/gogrep/gogrep_test.go @@ -149,8 +149,8 @@ func TestCapture(t *testing.T) { } capture := vars{} pat.MatchNode(target, func(m MatchData) { - for k, n := range m.Values { - capture[k] = sprintNode(n) + for _, c := range m.Capture { + capture[c.Name] = sprintNode(c.Node) } }) if diff := cmp.Diff(capture, test.capture); diff != "" { @@ -623,7 +623,7 @@ func testAllMatches(p *Pattern, target ast.Node, cb func(MatchData)) { }) } -func testParseNode(t *testing.T, s string) ast.Node { +func testParseNode(t testing.TB, s string) ast.Node { if strings.HasPrefix(s, "package ") { file, err := parser.ParseFile(token.NewFileSet(), "string", s, 0) if err != nil { diff --git a/internal/mvdan.cc/gogrep/match.go b/internal/mvdan.cc/gogrep/match.go index 86efd7e1..44d6b7df 100644 --- a/internal/mvdan.cc/gogrep/match.go +++ b/internal/mvdan.cc/gogrep/match.go @@ -15,7 +15,7 @@ type matcher struct { // node values recorded by name, excluding "_" (used only by the // actual matching phase) - values map[string]ast.Node + capture []CapturedNode strict bool } @@ -25,12 +25,10 @@ type varInfo struct { Any bool } -func valsCopy(values map[string]ast.Node) map[string]ast.Node { - v2 := make(map[string]ast.Node, len(values)) - for k, v := range values { - v2[k] = v - } - return v2 +func captureCopy(capture []CapturedNode) []CapturedNode { + copied := make([]CapturedNode, len(capture)) + copy(copied, capture) + return copied } // optNode is like node, but for those nodes that can be nil and are not @@ -88,10 +86,10 @@ func (m *matcher) node(expr, node ast.Node) bool { // values are discarded, matches anything return true } - prev, ok := m.values[info.Name] + prev, ok := findNamed(m.capture, info.Name) if !ok { // first occurrence, record value - m.values[info.Name] = node + m.capture = append(m.capture, CapturedNode{Name: info.Name, Node: node}) return true } // multiple uses must match @@ -396,7 +394,7 @@ func (m *matcher) nodes(ns1, ns2 nodeList, partial bool) (ast.Node, int) { // with a different "any of" match while discarding any matches // we found while trying it. type restart struct { - matches map[string]ast.Node + matches []CapturedNode next1, next2 int } // We need to stack these because otherwise some edge cases @@ -409,12 +407,12 @@ func (m *matcher) nodes(ns1, ns2 nodeList, partial bool) (ast.Node, int) { if n2 > ns2len { return // would be discarded anyway } - stack = append(stack, restart{valsCopy(m.values), n1, n2}) + stack = append(stack, restart{captureCopy(m.capture), n1, n2}) next1, next2 = n1, n2 } pop := func() { i1, i2 = next1, next2 - m.values = stack[len(stack)-1].matches + m.capture = stack[len(stack)-1].matches stack = stack[:len(stack)-1] next1, next2 = 0, 0 if len(stack) > 0 { @@ -434,11 +432,11 @@ func (m *matcher) nodes(ns1, ns2 nodeList, partial bool) (ast.Node, int) { } list := ns2.slice(wildStart, i2) // check that it matches any nodes found elsewhere - prev, ok := m.values[wildName] + prev, ok := findNamed(m.capture, wildName) if ok && !m.node(prev, list) { return false } - m.values[wildName] = list + m.capture = append(m.capture, CapturedNode{Name: wildName, Node: list}) return true } for i1 < ns1len || i2 < ns2len { @@ -642,3 +640,12 @@ func literalValue(lit *ast.BasicLit) interface{} { } return nil } + +func findNamed(capture []CapturedNode, name string) (ast.Node, bool) { + for _, c := range capture { + if c.Name == name { + return c.Node, true + } + } + return nil, false +} diff --git a/ruleguard/filters.go b/ruleguard/filters.go index 3816405a..5f2c7858 100644 --- a/ruleguard/filters.go +++ b/ruleguard/filters.go @@ -224,7 +224,8 @@ func makeTextFilter(src, varname string, op token.Token, rhsVarname string) filt return func(params *filterParams) matchFilterResult { s1 := params.nodeText(params.subExpr(varname)) lhsValue := constant.MakeString(string(s1)) - s2 := params.nodeText(params.values[rhsVarname]) + n, _ := params.match.CapturedByName(rhsVarname) + s2 := params.nodeText(n) rhsValue := constant.MakeString(string(s2)) if constant.Compare(lhsValue, op, rhsValue) { return filterSuccess diff --git a/ruleguard/gorule.go b/ruleguard/gorule.go index 0e8058a1..e271e8a1 100644 --- a/ruleguard/gorule.go +++ b/ruleguard/gorule.go @@ -53,7 +53,7 @@ type filterParams struct { importer *goImporter - values map[string]ast.Node + match gogrep.MatchData nodeText func(n ast.Node) []byte @@ -62,7 +62,8 @@ type filterParams struct { } func (params *filterParams) subExpr(name string) ast.Expr { - switch n := params.values[name].(type) { + n, _ := params.match.CapturedByName(name) + switch n := n.(type) { case ast.Expr: return n case *ast.ExprStmt: diff --git a/ruleguard/ruleguard_test.go b/ruleguard/ruleguard_test.go index a01be33e..702ee5a9 100644 --- a/ruleguard/ruleguard_test.go +++ b/ruleguard/ruleguard_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/google/go-cmp/cmp" + "github.com/quasilyte/go-ruleguard/internal/mvdan.cc/gogrep" ) func TestRenderMessage(t *testing.T) { @@ -89,12 +90,19 @@ func TestRenderMessage(t *testing.T) { Fset: token.NewFileSet(), } for _, test := range tests { - nodes := make(map[string]ast.Node, len(test.vars)) - for _, v := range test.vars { - nodes[v] = &ast.Ident{Name: v + "var"} + capture := make([]gogrep.CapturedNode, len(test.vars)) + for i, v := range test.vars { + capture[i] = gogrep.CapturedNode{ + Name: v, + Node: &ast.Ident{Name: v + "var"}, + } } - have := rr.renderMessage(test.msg, &ast.Ident{Name: "dd"}, nodes, false) + m := gogrep.MatchData{ + Node: &ast.Ident{Name: "dd"}, + Capture: capture, + } + have := rr.renderMessage(test.msg, m, false) if diff := cmp.Diff(have, test.want); diff != "" { t.Errorf("render %s %v:\n(+want -have)\n%s", test.msg, test.vars, diff) } diff --git a/ruleguard/runner.go b/ruleguard/runner.go index 8801d1f4..fcde0194 100644 --- a/ruleguard/runner.go +++ b/ruleguard/runner.go @@ -26,9 +26,6 @@ type rulesRunner struct { filename string src []byte - // A slice that is used to do a nodes keys sorting in renderMessage(). - sortScratch []string - filterParams filterParams } @@ -47,7 +44,6 @@ func newRulesRunner(ctx *RunContext, state *engineState, rules *goRuleSet) *rule importer: importer, ctx: ctx, }, - sortScratch: make([]string, 0, 8), } rr.filterParams.nodeText = rr.nodeText return rr @@ -167,21 +163,15 @@ func (rr *rulesRunner) reject(rule goRule, reason string, m gogrep.MatchData) { rr.ctx.DebugPrint(fmt.Sprintf("%s:%d: [%s:%d] rejected by %s", pos.Filename, pos.Line, filepath.Base(rule.filename), rule.line, reason)) - type namedNode struct { - name string - node ast.Node - } - values := make([]namedNode, 0, len(m.Values)) - for name, node := range m.Values { - values = append(values, namedNode{name: name, node: node}) - } + values := make([]gogrep.CapturedNode, len(m.Capture)) + copy(values, m.Capture) sort.Slice(values, func(i, j int) bool { - return values[i].name < values[j].name + return values[i].Name < values[j].Name }) for _, v := range values { - name := v.name - node := v.node + name := v.Name + node := v.Node var expr ast.Expr switch node := node.(type) { case ast.Expr: @@ -204,7 +194,7 @@ func (rr *rulesRunner) reject(rule goRule, reason string, m gogrep.MatchData) { func (rr *rulesRunner) handleMatch(rule goRule, m gogrep.MatchData) bool { if rule.filter.fn != nil { - rr.filterParams.values = m.Values + rr.filterParams.match = m filterResult := rule.filter.fn(&rr.filterParams) if !filterResult.Matched() { rr.reject(rule, filterResult.RejectReason(), m) @@ -212,15 +202,15 @@ func (rr *rulesRunner) handleMatch(rule goRule, m gogrep.MatchData) bool { } } - message := rr.renderMessage(rule.msg, m.Node, m.Values, true) + message := rr.renderMessage(rule.msg, m, true) node := m.Node if rule.location != "" { - node = m.Values[rule.location] + node, _ = m.CapturedByName(rule.location) } var suggestion *Suggestion if rule.suggestion != "" { suggestion = &Suggestion{ - Replacement: []byte(rr.renderMessage(rule.suggestion, m.Node, m.Values, false)), + Replacement: []byte(rr.renderMessage(rule.suggestion, m, false)), From: node.Pos(), To: node.End(), } @@ -245,27 +235,25 @@ func (rr *rulesRunner) collectImports(f *ast.File) { } } -func (rr *rulesRunner) renderMessage(msg string, n ast.Node, nodes map[string]ast.Node, truncate bool) string { +func (rr *rulesRunner) renderMessage(msg string, m gogrep.MatchData, truncate bool) string { var buf strings.Builder if strings.Contains(msg, "$$") { - buf.Write(rr.nodeText(n)) + buf.Write(rr.nodeText(m.Node)) msg = strings.ReplaceAll(msg, "$$", buf.String()) } - if len(nodes) == 0 { + if len(m.Capture) == 0 { return msg } - rr.sortScratch = rr.sortScratch[:0] - for name := range nodes { - rr.sortScratch = append(rr.sortScratch, name) - } - sort.Slice(rr.sortScratch, func(i, j int) bool { - return len(rr.sortScratch[i]) > len(rr.sortScratch[j]) + capture := make([]gogrep.CapturedNode, len(m.Capture)) + copy(capture, m.Capture) + sort.Slice(capture, func(i, j int) bool { + return len(capture[i].Name) > len(capture[j].Name) }) - for _, name := range rr.sortScratch { - n := nodes[name] - key := "$" + name + for _, c := range capture { + n := c.Node + key := "$" + c.Name if !strings.Contains(msg, key) { continue }