From 14e79eccf28be0180052c663a36cd66a4145873b Mon Sep 17 00:00:00 2001 From: Sugu Sougoumarane Date: Thu, 13 Jun 2019 22:13:07 -0700 Subject: [PATCH] vtgate sql: enforce a max row count limit I've been claiming that vtgate had a max row count limit. It turns out that I never actually implemented it. So, here's the implementation. Signed-off-by: Sugu Sougoumarane --- go/vt/vtgate/engine/fake_vcursor_test.go | 6 ++ go/vt/vtgate/engine/join.go | 3 + go/vt/vtgate/engine/join_test.go | 72 ++++++++++++++++++++++-- go/vt/vtgate/engine/primitive.go | 3 + go/vt/vtgate/scatter_conn.go | 24 ++++++-- go/vt/vtgate/scatter_conn_test.go | 51 +++++++++++++++++ go/vt/vtgate/vcursor_impl.go | 5 ++ go/vt/vtgate/vtgate.go | 1 + 8 files changed, 155 insertions(+), 10 deletions(-) diff --git a/go/vt/vtgate/engine/fake_vcursor_test.go b/go/vt/vtgate/engine/fake_vcursor_test.go index 0e498012125..ba7c5b30ada 100644 --- a/go/vt/vtgate/engine/fake_vcursor_test.go +++ b/go/vt/vtgate/engine/fake_vcursor_test.go @@ -35,6 +35,8 @@ import ( vtgatepb "vitess.io/vitess/go/vt/proto/vtgate" ) +var testMaxMemoryRows = 100 + // noopVCursor is used to build other vcursors. type noopVCursor struct { } @@ -43,6 +45,10 @@ func (t noopVCursor) Context() context.Context { return context.Background() } +func (t noopVCursor) MaxMemoryRows() int { + return testMaxMemoryRows +} + func (t noopVCursor) SetContextTimeout(timeout time.Duration) context.CancelFunc { return func() {} } diff --git a/go/vt/vtgate/engine/join.go b/go/vt/vtgate/engine/join.go index e6f3e6e629b..6519050991b 100644 --- a/go/vt/vtgate/engine/join.go +++ b/go/vt/vtgate/engine/join.go @@ -88,6 +88,9 @@ func (jn *Join) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariab } else { result.RowsAffected += uint64(len(rresult.Rows)) } + if len(result.Rows) > vcursor.MaxMemoryRows() { + return nil, fmt.Errorf("in-memory row count exceeded allowed limit of %d", vcursor.MaxMemoryRows()) + } } return result, nil } diff --git a/go/vt/vtgate/engine/join_test.go b/go/vt/vtgate/engine/join_test.go index 6936cbb2a59..26bacfe7c60 100644 --- a/go/vt/vtgate/engine/join_test.go +++ b/go/vt/vtgate/engine/join_test.go @@ -74,7 +74,7 @@ func TestJoinExecute(t *testing.T) { "bv": 1, }, } - r, err := jn.Execute(nil, bv, true) + r, err := jn.Execute(noopVCursor{}, bv, true) if err != nil { t.Fatal(err) } @@ -101,7 +101,7 @@ func TestJoinExecute(t *testing.T) { leftPrim.rewind() rightPrim.rewind() jn.Opcode = LeftJoin - r, err = jn.Execute(nil, bv, true) + r, err = jn.Execute(noopVCursor{}, bv, true) if err != nil { t.Fatal(err) } @@ -126,6 +126,66 @@ func TestJoinExecute(t *testing.T) { )) } +func TestJoinExecuteMaxMemoryRows(t *testing.T) { + save := testMaxMemoryRows + testMaxMemoryRows = 3 + defer func() { testMaxMemoryRows = save }() + + leftPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + ), + }, + } + rightFields := sqltypes.MakeTestFields( + "col4|col5|col6", + "int64|varchar|varchar", + ) + rightPrim := &fakePrimitive{ + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + rightFields, + "4|d|dd", + ), + sqltypes.MakeTestResult( + rightFields, + ), + sqltypes.MakeTestResult( + rightFields, + "5|e|ee", + "6|f|ff", + "7|g|gg", + ), + }, + } + bv := map[string]*querypb.BindVariable{ + "a": sqltypes.Int64BindVariable(10), + } + + // Normal join + jn := &Join{ + Opcode: NormalJoin, + Left: leftPrim, + Right: rightPrim, + Cols: []int{-1, -2, 1, 2}, + Vars: map[string]int{ + "bv": 1, + }, + } + _, err := jn.Execute(noopVCursor{}, bv, true) + want := "in-memory row count exceeded allowed limit of 3" + if err == nil || err.Error() != want { + t.Errorf("Execute(): %v, want %v", err, want) + } +} + func TestJoinExecuteNoResult(t *testing.T) { leftPrim := &fakePrimitive{ results: []*sqltypes.Result{ @@ -158,7 +218,7 @@ func TestJoinExecuteNoResult(t *testing.T) { "bv": 1, }, } - r, err := jn.Execute(nil, map[string]*querypb.BindVariable{}, true) + r, err := jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true) if err != nil { t.Fatal(err) } @@ -188,7 +248,7 @@ func TestJoinExecuteErrors(t *testing.T) { Opcode: NormalJoin, Left: leftPrim, } - _, err := jn.Execute(nil, map[string]*querypb.BindVariable{}, true) + _, err := jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true) expectError(t, "jn.Execute", err, "left err") // Error on right query @@ -218,7 +278,7 @@ func TestJoinExecuteErrors(t *testing.T) { "bv": 1, }, } - _, err = jn.Execute(nil, map[string]*querypb.BindVariable{}, true) + _, err = jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true) expectError(t, "jn.Execute", err, "right err") // Error on right getfields @@ -245,7 +305,7 @@ func TestJoinExecuteErrors(t *testing.T) { "bv": 1, }, } - _, err = jn.Execute(nil, map[string]*querypb.BindVariable{}, true) + _, err = jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true) expectError(t, "jn.Execute", err, "right err") } diff --git a/go/vt/vtgate/engine/primitive.go b/go/vt/vtgate/engine/primitive.go index d95d0c9d3c9..cf3a6cc0ca3 100644 --- a/go/vt/vtgate/engine/primitive.go +++ b/go/vt/vtgate/engine/primitive.go @@ -45,6 +45,9 @@ type VCursor interface { // Context returns the context of the current request. Context() context.Context + // MaxMemoryRows returns the maxMemoryRows flag value. + MaxMemoryRows() int + // SetContextTimeout updates the context and sets a timeout. SetContextTimeout(timeout time.Duration) context.CancelFunc diff --git a/go/vt/vtgate/scatter_conn.go b/go/vt/vtgate/scatter_conn.go index fb08ed0e0a1..3d0b5c71a2d 100644 --- a/go/vt/vtgate/scatter_conn.go +++ b/go/vt/vtgate/scatter_conn.go @@ -157,9 +157,17 @@ func (stc *ScatterConn) Execute( mu.Lock() defer mu.Unlock() - qr.AppendResult(innerqr) + // Don't append more rows if row count is exceeded. + if len(qr.Rows) <= *maxMemoryRows { + qr.AppendResult(innerqr) + } return transactionID, nil - }) + }, + ) + + if len(qr.Rows) > *maxMemoryRows { + return nil, vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "in-memory row count exceeded allowed limit of %d", *maxMemoryRows) + } return qr, allErrors.AggrError(vterrors.Aggregate) } @@ -215,9 +223,17 @@ func (stc *ScatterConn) ExecuteMultiShard( mu.Lock() defer mu.Unlock() - qr.AppendResult(innerqr) + // Don't append more rows if row count is exceeded. + if len(qr.Rows) <= *maxMemoryRows { + qr.AppendResult(innerqr) + } return transactionID, nil - }) + }, + ) + + if len(qr.Rows) > *maxMemoryRows { + return nil, []error{vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "in-memory row count exceeded allowed limit of %d", *maxMemoryRows)} + } return qr, allErrors.GetErrors() } diff --git a/go/vt/vtgate/scatter_conn_test.go b/go/vt/vtgate/scatter_conn_test.go index 3b46bd0bf68..a07161f07f0 100644 --- a/go/vt/vtgate/scatter_conn_test.go +++ b/go/vt/vtgate/scatter_conn_test.go @@ -248,6 +248,57 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s } } +func TestMaxMemoryRows(t *testing.T) { + save := *maxMemoryRows + *maxMemoryRows = 3 + defer func() { *maxMemoryRows = save }() + + createSandbox("TestMaxMemoryRows") + hc := discovery.NewFakeHealthCheck() + sc := newTestScatterConn(hc, new(sandboxTopo), "aa") + sbc0 := hc.AddTestTablet("aa", "0", 1, "TestMaxMemoryRows", "0", topodatapb.TabletType_REPLICA, true, 1, nil) + sbc1 := hc.AddTestTablet("aa", "1", 1, "TestMaxMemoryRows", "1", topodatapb.TabletType_REPLICA, true, 1, nil) + + tworows := &sqltypes.Result{ + Rows: [][]sqltypes.Value{{ + sqltypes.NewInt64(1), + }, { + sqltypes.NewInt64(1), + }}, + RowsAffected: 1, + InsertID: 1, + } + sbc0.SetResults([]*sqltypes.Result{tworows, tworows}) + sbc1.SetResults([]*sqltypes.Result{tworows, tworows}) + + res := srvtopo.NewResolver(&sandboxTopo{}, sc.gateway, "aa") + rss, _, err := res.ResolveDestinations(context.Background(), "TestMaxMemoryRows", topodatapb.TabletType_REPLICA, nil, + []key.Destination{key.DestinationShard("0"), key.DestinationShard("1")}) + if err != nil { + t.Fatalf("ResolveDestination(0) failed: %v", err) + } + session := NewSafeSession(&vtgatepb.Session{InTransaction: true}) + + _, err = sc.Execute(context.Background(), "query1", nil, rss, topodatapb.TabletType_REPLICA, session, true, nil) + want := "in-memory row count exceeded allowed limit of 3" + if err == nil || err.Error() != want { + t.Errorf("Execute(): %v, want %v", err, want) + } + + queries := []*querypb.BoundQuery{{ + Sql: "query1", + BindVariables: map[string]*querypb.BindVariable{}, + }, { + Sql: "query1", + BindVariables: map[string]*querypb.BindVariable{}, + }} + _, errs := sc.ExecuteMultiShard(context.Background(), rss, queries, topodatapb.TabletType_REPLICA, session, false, false) + err = errs[0] + if err == nil || err.Error() != want { + t.Errorf("Execute(): %v, want %v", err, want) + } +} + func TestMultiExecs(t *testing.T) { createSandbox("TestMultiExecs") hc := discovery.NewFakeHealthCheck() diff --git a/go/vt/vtgate/vcursor_impl.go b/go/vt/vtgate/vcursor_impl.go index fb9e96f9c62..8da8069b3a9 100644 --- a/go/vt/vtgate/vcursor_impl.go +++ b/go/vt/vtgate/vcursor_impl.go @@ -75,6 +75,11 @@ func (vc *vcursorImpl) Context() context.Context { return vc.ctx } +// MaxMemoryRows returns the maxMemoryRows flag value. +func (vc *vcursorImpl) MaxMemoryRows() int { + return *maxMemoryRows +} + // SetContextTimeout updates context and sets a timeout. func (vc *vcursorImpl) SetContextTimeout(timeout time.Duration) context.CancelFunc { ctx, cancel := context.WithTimeout(vc.ctx, timeout) diff --git a/go/vt/vtgate/vtgate.go b/go/vt/vtgate/vtgate.go index 8a7d1e21cbf..b52b8335f1e 100644 --- a/go/vt/vtgate/vtgate.go +++ b/go/vt/vtgate/vtgate.go @@ -60,6 +60,7 @@ var ( streamBufferSize = flag.Int("stream_buffer_size", 32*1024, "the number of bytes sent from vtgate for each stream call. It's recommended to keep this value in sync with vttablet's query-server-config-stream-buffer-size.") queryPlanCacheSize = flag.Int64("gate_query_cache_size", 10000, "gate server query cache size, maximum number of queries to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache.") disableLocalGateway = flag.Bool("disable_local_gateway", false, "if specified, this process will not route any queries to local tablets in the local cell") + maxMemoryRows = flag.Int("max_memory_rows", 30000, "Maximum number of rows that will be held in memory for intermediate results as well as the final result.") ) func getTxMode() vtgatepb.TransactionMode {