Skip to content

Commit

Permalink
Merge pull request #4930 from planetscale/ss-max-rows
Browse files Browse the repository at this point in the history
vtgate sql: enforce a max row count limit
  • Loading branch information
sougou authored Jun 16, 2019
2 parents 40818a5 + 14e79ec commit 640588c
Show file tree
Hide file tree
Showing 8 changed files with 155 additions and 10 deletions.
6 changes: 6 additions & 0 deletions go/vt/vtgate/engine/fake_vcursor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ import (
vtgatepb "vitess.io/vitess/go/vt/proto/vtgate"
)

var testMaxMemoryRows = 100

// noopVCursor is used to build other vcursors.
type noopVCursor struct {
}
Expand All @@ -43,6 +45,10 @@ func (t noopVCursor) Context() context.Context {
return context.Background()
}

func (t noopVCursor) MaxMemoryRows() int {
return testMaxMemoryRows
}

func (t noopVCursor) SetContextTimeout(timeout time.Duration) context.CancelFunc {
return func() {}
}
Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/join.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ func (jn *Join) Execute(vcursor VCursor, bindVars map[string]*querypb.BindVariab
} else {
result.RowsAffected += uint64(len(rresult.Rows))
}
if len(result.Rows) > vcursor.MaxMemoryRows() {
return nil, fmt.Errorf("in-memory row count exceeded allowed limit of %d", vcursor.MaxMemoryRows())
}
}
return result, nil
}
Expand Down
72 changes: 66 additions & 6 deletions go/vt/vtgate/engine/join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ func TestJoinExecute(t *testing.T) {
"bv": 1,
},
}
r, err := jn.Execute(nil, bv, true)
r, err := jn.Execute(noopVCursor{}, bv, true)
if err != nil {
t.Fatal(err)
}
Expand All @@ -101,7 +101,7 @@ func TestJoinExecute(t *testing.T) {
leftPrim.rewind()
rightPrim.rewind()
jn.Opcode = LeftJoin
r, err = jn.Execute(nil, bv, true)
r, err = jn.Execute(noopVCursor{}, bv, true)
if err != nil {
t.Fatal(err)
}
Expand All @@ -126,6 +126,66 @@ func TestJoinExecute(t *testing.T) {
))
}

func TestJoinExecuteMaxMemoryRows(t *testing.T) {
save := testMaxMemoryRows
testMaxMemoryRows = 3
defer func() { testMaxMemoryRows = save }()

leftPrim := &fakePrimitive{
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
),
},
}
rightFields := sqltypes.MakeTestFields(
"col4|col5|col6",
"int64|varchar|varchar",
)
rightPrim := &fakePrimitive{
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
rightFields,
"4|d|dd",
),
sqltypes.MakeTestResult(
rightFields,
),
sqltypes.MakeTestResult(
rightFields,
"5|e|ee",
"6|f|ff",
"7|g|gg",
),
},
}
bv := map[string]*querypb.BindVariable{
"a": sqltypes.Int64BindVariable(10),
}

// Normal join
jn := &Join{
Opcode: NormalJoin,
Left: leftPrim,
Right: rightPrim,
Cols: []int{-1, -2, 1, 2},
Vars: map[string]int{
"bv": 1,
},
}
_, err := jn.Execute(noopVCursor{}, bv, true)
want := "in-memory row count exceeded allowed limit of 3"
if err == nil || err.Error() != want {
t.Errorf("Execute(): %v, want %v", err, want)
}
}

func TestJoinExecuteNoResult(t *testing.T) {
leftPrim := &fakePrimitive{
results: []*sqltypes.Result{
Expand Down Expand Up @@ -158,7 +218,7 @@ func TestJoinExecuteNoResult(t *testing.T) {
"bv": 1,
},
}
r, err := jn.Execute(nil, map[string]*querypb.BindVariable{}, true)
r, err := jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -188,7 +248,7 @@ func TestJoinExecuteErrors(t *testing.T) {
Opcode: NormalJoin,
Left: leftPrim,
}
_, err := jn.Execute(nil, map[string]*querypb.BindVariable{}, true)
_, err := jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true)
expectError(t, "jn.Execute", err, "left err")

// Error on right query
Expand Down Expand Up @@ -218,7 +278,7 @@ func TestJoinExecuteErrors(t *testing.T) {
"bv": 1,
},
}
_, err = jn.Execute(nil, map[string]*querypb.BindVariable{}, true)
_, err = jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true)
expectError(t, "jn.Execute", err, "right err")

// Error on right getfields
Expand All @@ -245,7 +305,7 @@ func TestJoinExecuteErrors(t *testing.T) {
"bv": 1,
},
}
_, err = jn.Execute(nil, map[string]*querypb.BindVariable{}, true)
_, err = jn.Execute(noopVCursor{}, map[string]*querypb.BindVariable{}, true)
expectError(t, "jn.Execute", err, "right err")
}

Expand Down
3 changes: 3 additions & 0 deletions go/vt/vtgate/engine/primitive.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ type VCursor interface {
// Context returns the context of the current request.
Context() context.Context

// MaxMemoryRows returns the maxMemoryRows flag value.
MaxMemoryRows() int

// SetContextTimeout updates the context and sets a timeout.
SetContextTimeout(timeout time.Duration) context.CancelFunc

Expand Down
24 changes: 20 additions & 4 deletions go/vt/vtgate/scatter_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -157,9 +157,17 @@ func (stc *ScatterConn) Execute(

mu.Lock()
defer mu.Unlock()
qr.AppendResult(innerqr)
// Don't append more rows if row count is exceeded.
if len(qr.Rows) <= *maxMemoryRows {
qr.AppendResult(innerqr)
}
return transactionID, nil
})
},
)

if len(qr.Rows) > *maxMemoryRows {
return nil, vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "in-memory row count exceeded allowed limit of %d", *maxMemoryRows)
}

return qr, allErrors.AggrError(vterrors.Aggregate)
}
Expand Down Expand Up @@ -215,9 +223,17 @@ func (stc *ScatterConn) ExecuteMultiShard(

mu.Lock()
defer mu.Unlock()
qr.AppendResult(innerqr)
// Don't append more rows if row count is exceeded.
if len(qr.Rows) <= *maxMemoryRows {
qr.AppendResult(innerqr)
}
return transactionID, nil
})
},
)

if len(qr.Rows) > *maxMemoryRows {
return nil, []error{vterrors.Errorf(vtrpcpb.Code_RESOURCE_EXHAUSTED, "in-memory row count exceeded allowed limit of %d", *maxMemoryRows)}
}

return qr, allErrors.GetErrors()
}
Expand Down
51 changes: 51 additions & 0 deletions go/vt/vtgate/scatter_conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -248,6 +248,57 @@ func testScatterConnGeneric(t *testing.T, name string, f func(sc *ScatterConn, s
}
}

func TestMaxMemoryRows(t *testing.T) {
save := *maxMemoryRows
*maxMemoryRows = 3
defer func() { *maxMemoryRows = save }()

createSandbox("TestMaxMemoryRows")
hc := discovery.NewFakeHealthCheck()
sc := newTestScatterConn(hc, new(sandboxTopo), "aa")
sbc0 := hc.AddTestTablet("aa", "0", 1, "TestMaxMemoryRows", "0", topodatapb.TabletType_REPLICA, true, 1, nil)
sbc1 := hc.AddTestTablet("aa", "1", 1, "TestMaxMemoryRows", "1", topodatapb.TabletType_REPLICA, true, 1, nil)

tworows := &sqltypes.Result{
Rows: [][]sqltypes.Value{{
sqltypes.NewInt64(1),
}, {
sqltypes.NewInt64(1),
}},
RowsAffected: 1,
InsertID: 1,
}
sbc0.SetResults([]*sqltypes.Result{tworows, tworows})
sbc1.SetResults([]*sqltypes.Result{tworows, tworows})

res := srvtopo.NewResolver(&sandboxTopo{}, sc.gateway, "aa")
rss, _, err := res.ResolveDestinations(context.Background(), "TestMaxMemoryRows", topodatapb.TabletType_REPLICA, nil,
[]key.Destination{key.DestinationShard("0"), key.DestinationShard("1")})
if err != nil {
t.Fatalf("ResolveDestination(0) failed: %v", err)
}
session := NewSafeSession(&vtgatepb.Session{InTransaction: true})

_, err = sc.Execute(context.Background(), "query1", nil, rss, topodatapb.TabletType_REPLICA, session, true, nil)
want := "in-memory row count exceeded allowed limit of 3"
if err == nil || err.Error() != want {
t.Errorf("Execute(): %v, want %v", err, want)
}

queries := []*querypb.BoundQuery{{
Sql: "query1",
BindVariables: map[string]*querypb.BindVariable{},
}, {
Sql: "query1",
BindVariables: map[string]*querypb.BindVariable{},
}}
_, errs := sc.ExecuteMultiShard(context.Background(), rss, queries, topodatapb.TabletType_REPLICA, session, false, false)
err = errs[0]
if err == nil || err.Error() != want {
t.Errorf("Execute(): %v, want %v", err, want)
}
}

func TestMultiExecs(t *testing.T) {
createSandbox("TestMultiExecs")
hc := discovery.NewFakeHealthCheck()
Expand Down
5 changes: 5 additions & 0 deletions go/vt/vtgate/vcursor_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ func (vc *vcursorImpl) Context() context.Context {
return vc.ctx
}

// MaxMemoryRows returns the maxMemoryRows flag value.
func (vc *vcursorImpl) MaxMemoryRows() int {
return *maxMemoryRows
}

// SetContextTimeout updates context and sets a timeout.
func (vc *vcursorImpl) SetContextTimeout(timeout time.Duration) context.CancelFunc {
ctx, cancel := context.WithTimeout(vc.ctx, timeout)
Expand Down
1 change: 1 addition & 0 deletions go/vt/vtgate/vtgate.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ var (
streamBufferSize = flag.Int("stream_buffer_size", 32*1024, "the number of bytes sent from vtgate for each stream call. It's recommended to keep this value in sync with vttablet's query-server-config-stream-buffer-size.")
queryPlanCacheSize = flag.Int64("gate_query_cache_size", 10000, "gate server query cache size, maximum number of queries to be cached. vtgate analyzes every incoming query and generate a query plan, these plans are being cached in a lru cache. This config controls the capacity of the lru cache.")
disableLocalGateway = flag.Bool("disable_local_gateway", false, "if specified, this process will not route any queries to local tablets in the local cell")
maxMemoryRows = flag.Int("max_memory_rows", 30000, "Maximum number of rows that will be held in memory for intermediate results as well as the final result.")
)

func getTxMode() vtgatepb.TransactionMode {
Expand Down

0 comments on commit 640588c

Please sign in to comment.