diff --git a/go/vt/sqlparser/comments.go b/go/vt/sqlparser/comments.go index 497b95e5366..9d5df6d31f6 100644 --- a/go/vt/sqlparser/comments.go +++ b/go/vt/sqlparser/comments.go @@ -32,6 +32,8 @@ const ( DirectiveQueryTimeout = "QUERY_TIMEOUT_MS" // DirectiveScatterErrorsAsWarnings enables partial success scatter select queries DirectiveScatterErrorsAsWarnings = "SCATTER_ERRORS_AS_WARNINGS" + // DirectiveIgnoreMaxPayloadSize skips payload size validation when set. + DirectiveIgnoreMaxPayloadSize = "IGNORE_MAX_PAYLOAD_SIZE" ) func isNonSpace(r rune) bool { @@ -298,3 +300,24 @@ func SkipQueryPlanCacheDirective(stmt Statement) bool { } return false } + +// IgnoreMaxPayloadSizeDirective returns true if the max payload size override +// directive is set to true. +func IgnoreMaxPayloadSizeDirective(stmt Statement) bool { + switch stmt := stmt.(type) { + case *Select: + directives := ExtractCommentDirectives(stmt.Comments) + return directives.IsSet(DirectiveIgnoreMaxPayloadSize) + case *Insert: + directives := ExtractCommentDirectives(stmt.Comments) + return directives.IsSet(DirectiveIgnoreMaxPayloadSize) + case *Update: + directives := ExtractCommentDirectives(stmt.Comments) + return directives.IsSet(DirectiveIgnoreMaxPayloadSize) + case *Delete: + directives := ExtractCommentDirectives(stmt.Comments) + return directives.IsSet(DirectiveIgnoreMaxPayloadSize) + default: + return false + } +} diff --git a/go/vt/sqlparser/comments_test.go b/go/vt/sqlparser/comments_test.go index 3d875faf1cb..8ec2a0e1995 100644 --- a/go/vt/sqlparser/comments_test.go +++ b/go/vt/sqlparser/comments_test.go @@ -17,8 +17,11 @@ limitations under the License. package sqlparser import ( + "fmt" "reflect" "testing" + + "github.com/stretchr/testify/assert" ) func TestSplitComments(t *testing.T) { @@ -385,3 +388,22 @@ func TestSkipQueryPlanCacheDirective(t *testing.T) { t.Errorf("d.SkipQueryPlanCacheDirective(stmt) should be true") } } + +func TestIgnoreMaxPayloadSizeDirective(t *testing.T) { + testCases := []struct { + query string + expected bool + }{ + {"insert /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ into user(id) values (1), (2)", true}, + {"insert into user(id) values (1), (2)", false}, + {"update /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ users set name=1", true}, + {"select /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ * from users", true}, + {"delete /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ from users", true}, + } + + for _, test := range testCases { + stmt, _ := Parse(test.query) + got := IgnoreMaxPayloadSizeDirective(stmt) + assert.Equalf(t, test.expected, got, fmt.Sprintf("d.IgnoreMaxPayloadSizeDirective(stmt) returned %v but expected %v", got, test.expected)) + } +} diff --git a/go/vt/vtgate/executor.go b/go/vt/vtgate/executor.go index a3e5942c6f6..58bb05ee3ff 100644 --- a/go/vt/vtgate/executor.go +++ b/go/vt/vtgate/executor.go @@ -1283,10 +1283,14 @@ func (e *Executor) getPlan(vcursor *vcursorImpl, sql string, comments sqlparser. return nil, err } - // Normalize if possible and retry. query := sql statement := stmt bindVarNeeds := sqlparser.BindVarNeeds{} + if !sqlparser.IgnoreMaxPayloadSizeDirective(statement) && !isValidPayloadSize(query) { + return nil, vterrors.New(vtrpcpb.Code_RESOURCE_EXHAUSTED, "query payload size above threshold") + } + + // Normalize if possible and retry. if (e.normalize && sqlparser.CanNormalize(stmt)) || sqlparser.IsSetStatement(stmt) { parameterize := e.normalize // the public flag is called normalize result, err := sqlparser.PrepareAST(stmt, bindVars, "vtg", parameterize) @@ -1495,6 +1499,21 @@ func checkLikeOpt(likeOpt string, colNames []string) (string, error) { return "", nil } +// isValidPayloadSize validates whether a query payload is above the +// configured MaxPayloadSize threshold. The WarnPayloadSizeExceeded will increment +// if the payload size exceeds the warnPayloadSize. + +func isValidPayloadSize(query string) bool { + payloadSize := len(query) + if *maxPayloadSize > 0 && payloadSize > *maxPayloadSize { + return false + } + if *warnPayloadSize > 0 && payloadSize > *warnPayloadSize { + warnings.Add("WarnPayloadSizeExceeded", 1) + } + return true +} + // Prepare executes a prepare statements. func (e *Executor) Prepare(ctx context.Context, method string, safeSession *SafeSession, sql string, bindVars map[string]*querypb.BindVariable) (fld []*querypb.Field, err error) { logStats := NewLogStats(ctx, method, sql, bindVars) diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 9176022c940..f018d8e840e 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -1869,6 +1869,54 @@ func TestGenerateCharsetRows(t *testing.T) { } } +func TestExecutorMaxPayloadSizeExceeded(t *testing.T) { + saveMax := *maxPayloadSize + saveWarn := *warnPayloadSize + *maxPayloadSize = 10 + *warnPayloadSize = 5 + defer func() { + *maxPayloadSize = saveMax + *warnPayloadSize = saveWarn + }() + + executor, _, _, _ := createExecutorEnv() + session := NewSafeSession(&vtgatepb.Session{TargetString: "@master"}) + warningCount := warnings.Counts()["WarnPayloadSizeExceeded"] + testMaxPayloadSizeExceeded := []string{ + "select * from main1", + "select * from main1", + "insert into main1(id) values (1), (2)", + "update main1 set id=1", + "delete from main1 where id=1", + } + for _, query := range testMaxPayloadSizeExceeded { + _, err := executor.Execute(context.Background(), "TestExecutorMaxPayloadSizeExceeded", session, query, nil) + if err == nil { + assert.EqualError(t, err, "query payload size above threshold") + } + } + assert.Equal(t, warningCount, warnings.Counts()["WarnPayloadSizeExceeded"], "warnings count") + + testMaxPayloadSizeOverride := []string{ + "select /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ * from main1", + "insert /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ into main1(id) values (1), (2)", + "update /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ main1 set id=1", + "delete /*vt+ IGNORE_MAX_PAYLOAD_SIZE=1 */ from main1 where id=1", + } + for _, query := range testMaxPayloadSizeOverride { + _, err := executor.Execute(context.Background(), "TestExecutorMaxPayloadSizeWithOverride", session, query, nil) + assert.Equal(t, nil, err, "err should be nil") + } + assert.Equal(t, warningCount, warnings.Counts()["WarnPayloadSizeExceeded"], "warnings count") + + *maxPayloadSize = 1000 + for _, query := range testMaxPayloadSizeExceeded { + _, err := executor.Execute(context.Background(), "TestExecutorMaxPayloadSizeExceeded", session, query, nil) + assert.Equal(t, nil, err, "err should be nil") + } + assert.Equal(t, warningCount+4, warnings.Counts()["WarnPayloadSizeExceeded"], "warnings count") +} + func TestOlapSelectDatabase(t *testing.T) { executor, _, _, _ := createExecutorEnv() executor.normalize = true diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index cd04de6dbc1..781248f4874 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -67,6 +67,8 @@ var ( HealthCheckRetryDelay = flag.Duration("healthcheck_retry_delay", 2*time.Millisecond, "health check retry delay") // HealthCheckTimeout is the timeout on the RPC call to tablets HealthCheckTimeout = flag.Duration("healthcheck_timeout", time.Minute, "the health check timeout period") + maxPayloadSize = flag.Int("max_payload_size", 0, "The threshold for query payloads in bytes. A payload greater than this threshold will result in a failure to handle the query.") + warnPayloadSize = flag.Int("warn_payload_size", 0, "The warning threshold for query payloads in bytes. A payload greater than this threshold will cause the VtGateWarnings.WarnPayloadSizeExceeded counter to be incremented.") ) func getTxMode() vtgatepb.TransactionMode { @@ -194,7 +196,7 @@ func Init(ctx context.Context, serv srvtopo.Server, cell string, tabletTypesToWa _ = stats.NewRates("ErrorsByDbType", stats.CounterForDimension(errorCounts, "DbType"), 15, 1*time.Minute) _ = stats.NewRates("ErrorsByCode", stats.CounterForDimension(errorCounts, "Code"), 15, 1*time.Minute) - warnings = stats.NewCountersWithSingleLabel("VtGateWarnings", "Vtgate warnings", "type", "IgnoredSet", "ResultsExceeded") + warnings = stats.NewCountersWithSingleLabel("VtGateWarnings", "Vtgate warnings", "type", "IgnoredSet", "ResultsExceeded", "WarnPayloadSizeExceeded") servenv.OnRun(func() { for _, f := range RegisterVTGates {