From be3ce273d465d0023d4cb50d76c3ba48971849de Mon Sep 17 00:00:00 2001 From: Harshit Gangal Date: Wed, 5 May 2021 11:15:53 +0530 Subject: [PATCH] merge sort stream fix Signed-off-by: Harshit Gangal --- go/vt/vtgate/engine/merge_sort.go | 5 ++-- go/vt/vtgate/executor_select_test.go | 36 +++++++++++++++++++++++ go/vt/vttablet/sandboxconn/sandboxconn.go | 18 ++++++++++-- 3 files changed, 54 insertions(+), 5 deletions(-) diff --git a/go/vt/vtgate/engine/merge_sort.go b/go/vt/vtgate/engine/merge_sort.go index 6aa01e4d4d4..59d3c1ab3bb 100644 --- a/go/vt/vtgate/engine/merge_sort.go +++ b/go/vt/vtgate/engine/merge_sort.go @@ -81,7 +81,7 @@ func (ms *MergeSort) StreamExecute(vcursor VCursor, bindVars map[string]*querypb handles := make([]*streamHandle, len(ms.Primitives)) for i, input := range ms.Primitives { - handles[i] = runOneStream(vcursor, input, bindVars, wantfields) + handles[i] = runOneStream(ctx, vcursor, input, bindVars, wantfields) // Need fields only from first handle, if wantfields was true. wantfields = false } @@ -183,12 +183,11 @@ type streamHandle struct { } // runOnestream starts a streaming query on one shard, and returns a streamHandle for it. -func runOneStream(vcursor VCursor, input StreamExecutor, bindVars map[string]*querypb.BindVariable, wantfields bool) *streamHandle { +func runOneStream(ctx context.Context, vcursor VCursor, input StreamExecutor, bindVars map[string]*querypb.BindVariable, wantfields bool) *streamHandle { handle := &streamHandle{ fields: make(chan []*querypb.Field, 1), row: make(chan []sqltypes.Value, 10), } - ctx := vcursor.Context() go func() { defer close(handle.fields) diff --git a/go/vt/vtgate/executor_select_test.go b/go/vt/vtgate/executor_select_test.go index d5783b070ce..50eab386464 100644 --- a/go/vt/vtgate/executor_select_test.go +++ b/go/vt/vtgate/executor_select_test.go @@ -19,8 +19,10 @@ package vtgate import ( "fmt" "reflect" + "runtime" "strings" "testing" + "time" "vitess.io/vitess/go/test/utils" @@ -2420,3 +2422,37 @@ func TestSelectFromInformationSchema(t *testing.T) { require.NoError(t, err) assert.Equal(t, sbc1.StringQueries(), []string{"select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA = :__vtschemaname"}) } + +func TestStreamOrderByLimitWithMultipleResults(t *testing.T) { + // Special setup: Don't use createLegacyExecutorEnv. + cell := "aa" + hc := discovery.NewFakeHealthCheck() + s := createSandbox("TestExecutor") + s.VSchema = executorVSchema + getSandbox(KsTestUnsharded).VSchema = unshardedVSchema + serv := new(sandboxTopo) + resolver := newTestResolver(hc, serv, cell) + shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"} + count := 1 + for _, shard := range shards { + sbc := hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_MASTER, true, 1, nil) + sbc.SetResults([]*sqltypes.Result{ + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"), fmt.Sprintf("%d|%d", count, count)), + sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"), fmt.Sprintf("%d|%d", count+10, count)), + }) + count++ + } + executor := NewExecutor(context.Background(), serv, cell, resolver, false, testBufferSize, testCacheSize) + before := runtime.NumGoroutine() + + query := "select id, col from user order by id limit 2" + gotResult, err := executorStream(executor, query) + require.NoError(t, err) + + wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"), "1|1", "2|2") + wantResult.RowsAffected = 0 + utils.MustMatch(t, wantResult, gotResult) + // some sleep to close all goroutines. + time.Sleep(100 * time.Millisecond) + assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering") +} diff --git a/go/vt/vttablet/sandboxconn/sandboxconn.go b/go/vt/vttablet/sandboxconn/sandboxconn.go index 823cb312de1..89bb04eaf1f 100644 --- a/go/vt/vttablet/sandboxconn/sandboxconn.go +++ b/go/vt/vttablet/sandboxconn/sandboxconn.go @@ -198,10 +198,24 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe sbc.sExecMu.Unlock() return err } - nextRs := sbc.getNextResult() + if sbc.results == nil { + nextRs := sbc.getNextResult() + sbc.sExecMu.Unlock() + return callback(nextRs) + } + + for len(sbc.results) > 0 { + nextRs := sbc.getNextResult() + sbc.sExecMu.Unlock() + err := callback(nextRs) + if err != nil { + return err + } + sbc.sExecMu.Lock() + } sbc.sExecMu.Unlock() - return callback(nextRs) + return nil } // Begin is part of the QueryService interface.