diff --git a/go.mod b/go.mod index 0f123bcf694..9f72934af39 100644 --- a/go.mod +++ b/go.mod @@ -25,6 +25,7 @@ require ( github.com/cyberdelia/go-metrics-graphite v0.0.0-20161219230853-39f87cc3b432 github.com/dave/jennifer v1.4.1 github.com/evanphx/json-patch v4.5.0+incompatible + github.com/fsnotify/fsnotify v1.4.9 github.com/go-martini/martini v0.0.0-20170121215854-22fa46961aab github.com/go-sql-driver/mysql v1.5.1-0.20210202043019-fe2230a8b20c github.com/gogo/protobuf v1.3.1 diff --git a/go.sum b/go.sum index 552aec1c7cd..4b9b554eb6e 100644 --- a/go.sum +++ b/go.sum @@ -182,6 +182,8 @@ github.com/felixge/httpsnoop v1.0.1 h1:lvB5Jl89CsZtGIWuTcDM1E/vkVs49/Ml7JJe07l8S github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.4.7 h1:IXs+QLmnXW2CcXuY+8Mzv/fWEsPGWxqefPtCP5CnV9I= github.com/fsnotify/fsnotify v1.4.7/go.mod h1:jwhsz4b93w/PPRr/qN1Yymfu8t87LnFCMoQvtojpjFo= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/fsnotify/fsnotify v1.4.9/go.mod h1:znqG4EE+3YCdAaPaxE2ZRY/06pZUdp0tY4IgpuI1SZQ= github.com/ghodss/yaml v0.0.0-20150909031657-73d445a93680/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/ghodss/yaml v1.0.0 h1:wQHKEahhL6wmXdzwWG11gIVCkOv05bNOh+Rxn0yngAk= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= @@ -783,6 +785,7 @@ golang.org/x/sys v0.0.0-20190624142023-c5567b49c5d0/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20190826190057-c7b8b68b1456/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190922100055-0a153f010e69/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20190924154521-2837fb4f24fe/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191005200804-aed5e4c7ecf9/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191008105621-543471e840be/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= diff --git a/go/cmd/rulesctl/cmd/add.go b/go/cmd/rulesctl/cmd/add.go index 76e80999360..e83e494dc0a 100644 --- a/go/cmd/rulesctl/cmd/add.go +++ b/go/cmd/rulesctl/cmd/add.go @@ -13,13 +13,15 @@ import ( ) var ( - addOptDryrun bool - addOptName string - addOptDescription string - addOptAction string - addOptPlans []string - addOptTables []string - addOptQueryRE string + addOptDryrun bool + addOptName string + addOptDescription string + addOptAction string + addOptPlans []string + addOptTables []string + addOptQueryRE string + addOptLeadingCommentRE string + addOptTrailingCommentRE string // TODO: other stuff, bind vars etc ) @@ -36,8 +38,20 @@ func runAdd(cmd *cobra.Command, args []string) { rule.AddTableCond(t) } - if err := rule.SetQueryCond(addOptQueryRE); err != nil { - log.Fatalf("Query condition invalid '%v': %v", addOptQueryRE, err) + if addOptQueryRE != "" { + if err := rule.SetQueryCond(addOptQueryRE); err != nil { + log.Fatalf("Query condition invalid '%v': %v", addOptQueryRE, err) + } + } + if addOptLeadingCommentRE != "" { + if err := rule.SetLeadingCommentCond(addOptLeadingCommentRE); err != nil { + log.Fatalf("Leading comment condition invalid '%v': %v", addOptLeadingCommentRE, err) + } + } + if addOptTrailingCommentRE != "" { + if err := rule.SetTrailingCommentCond(addOptTrailingCommentRE); err != nil { + log.Fatalf("Trailing comment condition invalid '%v': %v", addOptTrailingCommentRE, err) + } } var rules *vtrules.Rules @@ -141,6 +155,16 @@ func Add() *cobra.Command { "query", "q", "", "A regexp that will be applied to a query in order to determine if it matches") + addCmd.Flags().StringVarP( + &addOptLeadingCommentRE, + "leading-comment", "l", + "", + "A regexp that will be applied to comments prefacing a SQL statement") + addCmd.Flags().StringVarP( + &addOptTrailingCommentRE, + "trailing-comment", "r", + "", + "A regexp that will be applied to comments after a SQL statement") for _, f := range []string{"name", "action"} { addCmd.MarkFlagRequired(f) diff --git a/go/vt/vttablet/customrule/filecustomrule/filecustomrule.go b/go/vt/vttablet/customrule/filecustomrule/filecustomrule.go index 4ad7b64da45..eb52f04d7f8 100644 --- a/go/vt/vttablet/customrule/filecustomrule/filecustomrule.go +++ b/go/vt/vttablet/customrule/filecustomrule/filecustomrule.go @@ -1,5 +1,5 @@ /* -Copyright 2019 The Vitess Authors. +Copyright 2021 The Vitess Authors. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. @@ -20,9 +20,13 @@ package filecustomrule import ( "flag" "io/ioutil" + "path" "time" + "github.com/fsnotify/fsnotify" + "vitess.io/vitess/go/vt/log" + "vitess.io/vitess/go/vt/servenv" "vitess.io/vitess/go/vt/vttablet/tabletserver" "vitess.io/vitess/go/vt/vttablet/tabletserver/rules" ) @@ -32,6 +36,8 @@ var ( fileCustomRule = NewFileCustomRule() // Commandline flag to specify rule path fileRulePath = flag.String("filecustomrules", "", "file based custom rule path") + + fileRuleShouldWatch = flag.Bool("filecustomrules_watch", false, "set up a watch on the target file and reload query rules when it changes") ) // FileCustomRule is an implementation of CustomRuleManager, it reads custom query @@ -102,6 +108,45 @@ func ActivateFileCustomRules(qsc tabletserver.Controller) { if *fileRulePath != "" { qsc.RegisterQueryRuleSource(FileCustomRuleSource) fileCustomRule.Open(qsc, *fileRulePath) + + if *fileRuleShouldWatch { + baseDir := path.Dir(*fileRulePath) + ruleFileName := path.Base(*fileRulePath) + + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Fatalf("Unable create new fsnotify watcher: %v", err) + } + servenv.OnTerm(func() { watcher.Close() }) + + go func(tsc tabletserver.Controller) { + for { + select { + case evt, ok := <-watcher.Events: + if !ok { + return + } + if path.Base(evt.Name) != ruleFileName { + continue + } + if err := fileCustomRule.Open(tsc, *fileRulePath); err != nil { + log.Infof("Failed to load custom rules from %q: %v", *fileRulePath, err) + } else { + log.Infof("Loaded custom rules from %q", *fileRulePath) + } + case err, ok := <-watcher.Errors: + if !ok { + return + } + log.Errorf("Error watching %v: %v", *fileRulePath, err) + } + } + }(qsc) + + if err = watcher.Add(baseDir); err != nil { + log.Fatalf("Unable to set up watcher for %v + %v: %v", baseDir, ruleFileName, err) + } + } } } diff --git a/go/vt/vttablet/tabletserver/query_executor.go b/go/vt/vttablet/tabletserver/query_executor.go index 40aa2c6e65e..17ecb13ce79 100644 --- a/go/vt/vttablet/tabletserver/query_executor.go +++ b/go/vt/vttablet/tabletserver/query_executor.go @@ -382,7 +382,7 @@ func (qre *QueryExecutor) checkPermissions() error { remoteAddr = ci.RemoteAddr() username = ci.Username() } - action, desc := qre.plan.Rules.GetAction(remoteAddr, username, qre.bindVars) + action, desc := qre.plan.Rules.GetAction(remoteAddr, username, qre.bindVars, qre.marginComments) switch action { case rules.QRFail: return vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "disallowed due to rule: %s", desc) diff --git a/go/vt/vttablet/tabletserver/rules/cached_size.go b/go/vt/vttablet/tabletserver/rules/cached_size.go index 61cd23586c2..60403475105 100644 --- a/go/vt/vttablet/tabletserver/rules/cached_size.go +++ b/go/vt/vttablet/tabletserver/rules/cached_size.go @@ -43,7 +43,7 @@ func (cached *Rule) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(184) + size += int64(232) } // field Description string size += int64(len(cached.Description)) @@ -55,6 +55,10 @@ func (cached *Rule) CachedSize(alloc bool) int64 { size += cached.user.CachedSize(false) // field query vitess.io/vitess/go/vt/vttablet/tabletserver/rules.namedRegexp size += cached.query.CachedSize(false) + // field leadingComment vitess.io/vitess/go/vt/vttablet/tabletserver/rules.namedRegexp + size += cached.leadingComment.CachedSize(false) + // field trailingComment vitess.io/vitess/go/vt/vttablet/tabletserver/rules.namedRegexp + size += cached.trailingComment.CachedSize(false) // field plans []vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder.PlanType { size += int64(cap(cached.plans)) * int64(8) diff --git a/go/vt/vttablet/tabletserver/rules/rules.go b/go/vt/vttablet/tabletserver/rules/rules.go index 3cca04cee31..5ba71791030 100644 --- a/go/vt/vttablet/tabletserver/rules/rules.go +++ b/go/vt/vttablet/tabletserver/rules/rules.go @@ -27,6 +27,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/evalengine" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" @@ -166,9 +167,14 @@ func (qrs *Rules) FilterByPlan(query string, planid planbuilder.PlanType, tableN } // GetAction runs the input against the rules engine and returns the action to be performed. -func (qrs *Rules) GetAction(ip, user string, bindVars map[string]*querypb.BindVariable) (action Action, desc string) { +func (qrs *Rules) GetAction( + ip, + user string, + bindVars map[string]*querypb.BindVariable, + marginComments sqlparser.MarginComments, +) (action Action, desc string) { for _, qr := range qrs.rules { - if act := qr.GetAction(ip, user, bindVars); act != QRContinue { + if act := qr.GetAction(ip, user, bindVars, marginComments); act != QRContinue { return act, qr.Description } } @@ -192,7 +198,7 @@ type Rule struct { // All defined conditions must match for the rule to fire (AND). // Regexp conditions. nil conditions are ignored (TRUE). - requestIP, user, query namedRegexp + requestIP, user, query, leadingComment, trailingComment namedRegexp // Any matched plan will make this condition true (OR) plans []planbuilder.PlanType @@ -241,6 +247,8 @@ func (qr *Rule) Equal(other *Rule) bool { qr.requestIP.Equal(other.requestIP) && qr.user.Equal(other.user) && qr.query.Equal(other.query) && + qr.leadingComment.Equal(other.leadingComment) && + qr.trailingComment.Equal(other.trailingComment) && reflect.DeepEqual(qr.plans, other.plans) && reflect.DeepEqual(qr.tableNames, other.tableNames) && reflect.DeepEqual(qr.bindVarConds, other.bindVarConds) && @@ -250,12 +258,14 @@ func (qr *Rule) Equal(other *Rule) bool { // Copy performs a deep copy of a Rule. func (qr *Rule) Copy() (newqr *Rule) { newqr = &Rule{ - Description: qr.Description, - Name: qr.Name, - requestIP: qr.requestIP, - user: qr.user, - query: qr.query, - act: qr.act, + Description: qr.Description, + Name: qr.Name, + requestIP: qr.requestIP, + user: qr.user, + query: qr.query, + leadingComment: qr.leadingComment, + trailingComment: qr.trailingComment, + act: qr.act, } if qr.plans != nil { newqr.plans = make([]planbuilder.PlanType, len(qr.plans)) @@ -286,6 +296,12 @@ func (qr *Rule) MarshalJSON() ([]byte, error) { if qr.query.Regexp != nil { safeEncode(b, `,"Query":`, qr.query) } + if qr.leadingComment.Regexp != nil { + safeEncode(b, `,"LeadingComment":`, qr.leadingComment) + } + if qr.trailingComment.Regexp != nil { + safeEncode(b, `,"TrailingComment":`, qr.trailingComment) + } if qr.plans != nil { safeEncode(b, `,"Plans":`, qr.plans) } @@ -339,6 +355,20 @@ func (qr *Rule) SetQueryCond(pattern string) (err error) { return } +// SetLeadingCommentCond adds a regular expression condition for a leading query comment. +func (qr *Rule) SetLeadingCommentCond(pattern string) (err error) { + qr.leadingComment.name = pattern + qr.leadingComment.Regexp, err = regexp.Compile(makeExact(pattern)) + return +} + +// SetTrailingCommentCond adds a regular expression condition for a trailing query comment. +func (qr *Rule) SetTrailingCommentCond(pattern string) (err error) { + qr.trailingComment.name = pattern + qr.trailingComment.Regexp, err = regexp.Compile(makeExact(pattern)) + return +} + // makeExact forces a full string match for the regex instead of substring func makeExact(pattern string) string { return fmt.Sprintf("^%s$", pattern) @@ -418,13 +448,26 @@ func (qr *Rule) FilterByPlan(query string, planid planbuilder.PlanType, tableNam } newqr = qr.Copy() newqr.query = namedRegexp{} + // Note we explicitly don't remove the leading/trailing comments as they + // must be evaluated at execution time. newqr.plans = nil newqr.tableNames = nil return newqr } // GetAction returns the action for a single rule. -func (qr *Rule) GetAction(ip, user string, bindVars map[string]*querypb.BindVariable) Action { +func (qr *Rule) GetAction( + ip, + user string, + bindVars map[string]*querypb.BindVariable, + marginComments sqlparser.MarginComments, +) Action { + if !reMatch(qr.leadingComment.Regexp, marginComments.Leading) { + return QRContinue + } + if !reMatch(qr.trailingComment.Regexp, marginComments.Trailing) { + return QRContinue + } if !reMatch(qr.requestIP.Regexp, ip) { return QRContinue } @@ -791,7 +834,7 @@ func BuildQueryRule(ruleInfo map[string]interface{}) (qr *Rule, err error) { var lv []interface{} var ok bool switch k { - case "Name", "Description", "RequestIP", "User", "Query", "Action": + case "Name", "Description", "RequestIP", "User", "Query", "Action", "LeadingComment", "TrailingComment": sv, ok = v.(string) if !ok { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "want string for %s", k) @@ -824,6 +867,16 @@ func BuildQueryRule(ruleInfo map[string]interface{}) (qr *Rule, err error) { if err != nil { return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set Query condition: %v", sv) } + case "LeadingComment": + err = qr.SetLeadingCommentCond(sv) + if err != nil { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set LeadingComment condition: %v", sv) + } + case "TrailingComment": + err = qr.SetTrailingCommentCond(sv) + if err != nil { + return nil, vterrors.Errorf(vtrpcpb.Code_INVALID_ARGUMENT, "could not set TrailingComment condition: %v", sv) + } case "Plans": for _, p := range lv { pv, ok := p.(string) diff --git a/go/vt/vttablet/tabletserver/rules/rules_test.go b/go/vt/vttablet/tabletserver/rules/rules_test.go index 76bddebc10f..951f4146789 100644 --- a/go/vt/vttablet/tabletserver/rules/rules_test.go +++ b/go/vt/vttablet/tabletserver/rules/rules_test.go @@ -25,6 +25,7 @@ import ( "testing" "vitess.io/vitess/go/sqltypes" + "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vttablet/tabletserver/planbuilder" @@ -496,31 +497,68 @@ func TestAction(t *testing.T) { bv := make(map[string]*querypb.BindVariable) bv["a"] = sqltypes.Uint64BindVariable(0) - action, desc := qrs.GetAction("123", "user1", bv) + + mc := sqlparser.MarginComments{ + Leading: "some comments leading the query", + Trailing: "other trailing comments", + } + + action, desc := qrs.GetAction("123", "user1", bv, mc) if action != QRFail { t.Errorf("want fail") } if desc != "rule 1" { t.Errorf("want rule 1, got %s", desc) } - action, desc = qrs.GetAction("1234", "user", bv) + action, desc = qrs.GetAction("1234", "user", bv, mc) if action != QRFailRetry { t.Errorf("want fail_retry") } if desc != "rule 2" { t.Errorf("want rule 2, got %s", desc) } - action, _ = qrs.GetAction("1234", "user1", bv) + action, _ = qrs.GetAction("1234", "user1", bv, mc) if action != QRContinue { t.Errorf("want continue") } + bv["a"] = sqltypes.Uint64BindVariable(1) - action, desc = qrs.GetAction("1234", "user1", bv) + action, desc = qrs.GetAction("1234", "user1", bv, mc) if action != QRFail { t.Errorf("want fail") } if desc != "rule 3" { - t.Errorf("want rule 2, got %s", desc) + t.Errorf("want rule 3, got %s", desc) + } + + // reset bound variable 'a' to 0 so it doesn't match rule 3 + bv["a"] = sqltypes.Uint64BindVariable(0) + + qr4 := NewQueryRule("rule 4", "r4", QRFail) + qr4.SetTrailingCommentCond(".*trailing.*") + + newQrs := qrs.Copy() + newQrs.Add(qr4) + + action, desc = newQrs.GetAction("1234", "user1", bv, mc) + if action != QRFail { + t.Errorf("want fail") + } + if desc != "rule 4" { + t.Errorf("want rule 4, got %s", desc) + } + + qr5 := NewQueryRule("rule 5", "r4", QRFail) + qr5.SetLeadingCommentCond(".*leading.*") + + newQrs = qrs.Copy() + newQrs.Add(qr5) + action, desc = newQrs.GetAction("1234", "user1", bv, mc) + if action != QRFail { + t.Errorf("want fail") + } + if desc != "rule 5" { + t.Errorf("want rule 5, got %s", desc) } }