Skip to content

Commit

Permalink
Merge pull request #7903 from planetscale/stream-ms-fix
Browse files Browse the repository at this point in the history
Memory Sort to close the goroutines when callback returns error
  • Loading branch information
systay authored Apr 22, 2021
2 parents b765bce + 6ccee19 commit ae3c297
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 6 deletions.
5 changes: 2 additions & 3 deletions go/vt/vtgate/engine/merge_sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func (ms *MergeSort) StreamExecute(vcursor VCursor, bindVars map[string]*querypb

handles := make([]*streamHandle, len(ms.Primitives))
for i, input := range ms.Primitives {
handles[i] = runOneStream(vcursor, input, bindVars, wantfields)
handles[i] = runOneStream(ctx, vcursor, input, bindVars, wantfields)
// Need fields only from first handle, if wantfields was true.
wantfields = false
}
Expand Down Expand Up @@ -182,12 +182,11 @@ type streamHandle struct {
}

// runOnestream starts a streaming query on one shard, and returns a streamHandle for it.
func runOneStream(vcursor VCursor, input StreamExecutor, bindVars map[string]*querypb.BindVariable, wantfields bool) *streamHandle {
func runOneStream(ctx context.Context, vcursor VCursor, input StreamExecutor, bindVars map[string]*querypb.BindVariable, wantfields bool) *streamHandle {
handle := &streamHandle{
fields: make(chan []*querypb.Field, 1),
row: make(chan []sqltypes.Value, 10),
}
ctx := vcursor.Context()

go func() {
defer close(handle.fields)
Expand Down
35 changes: 35 additions & 0 deletions go/vt/vtgate/executor_select_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ package vtgate

import (
"fmt"
"runtime"
"strings"
"testing"
"time"

"vitess.io/vitess/go/cache"
"vitess.io/vitess/go/test/utils"
Expand Down Expand Up @@ -2336,3 +2338,36 @@ func TestSelectFromInformationSchema(t *testing.T) {
require.NoError(t, err)
assert.Equal(t, sbc1.StringQueries(), []string{"select * from INFORMATION_SCHEMA.`TABLES` where TABLE_SCHEMA = :__vtschemaname"})
}

func TestStreamOrderByLimitWithMultipleResults(t *testing.T) {
// Special setup: Don't use createLegacyExecutorEnv.
cell := "aa"
hc := discovery.NewFakeHealthCheck()
s := createSandbox("TestExecutor")
s.VSchema = executorVSchema
getSandbox(KsTestUnsharded).VSchema = unshardedVSchema
serv := new(sandboxTopo)
resolver := newTestResolver(hc, serv, cell)
shards := []string{"-20", "20-40", "40-60", "60-80", "80-a0", "a0-c0", "c0-e0", "e0-"}
count := 1
for _, shard := range shards {
sbc := hc.AddTestTablet(cell, shard, 1, "TestExecutor", shard, topodatapb.TabletType_MASTER, true, 1, nil)
sbc.SetResults([]*sqltypes.Result{
sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count, count)),
sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col|weight_string(id)", "int32|int32|varchar"), fmt.Sprintf("%d|%d|NULL", count+10, count)),
})
count++
}
executor := NewExecutor(context.Background(), serv, cell, resolver, true, false, testBufferSize, cache.DefaultConfig)
before := runtime.NumGoroutine()

query := "select id, col from user order by id limit 2"
gotResult, err := executorStream(executor, query)
require.NoError(t, err)

wantResult := sqltypes.MakeTestResult(sqltypes.MakeTestFields("id|col", "int32|int32"), "1|1", "2|2")
utils.MustMatch(t, wantResult, gotResult)
// some sleep to close all goroutines.
time.Sleep(100 * time.Millisecond)
assert.GreaterOrEqual(t, before, runtime.NumGoroutine(), "left open goroutines lingering")
}
25 changes: 22 additions & 3 deletions go/vt/vttablet/sandboxconn/sandboxconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -207,10 +207,25 @@ func (sbc *SandboxConn) StreamExecute(ctx context.Context, target *querypb.Targe
return err
}
parse, _ := sqlparser.Parse(query)
nextRs := sbc.getNextResult(parse)
sbc.sExecMu.Unlock()

return callback(nextRs)
if sbc.results == nil {
nextRs := sbc.getNextResult(parse)
sbc.sExecMu.Unlock()
return callback(nextRs)
}

for len(sbc.results) > 0 {
nextRs := sbc.getNextResult(parse)
sbc.sExecMu.Unlock()
err := callback(nextRs)
if err != nil {
return err
}
sbc.sExecMu.Lock()
}

sbc.sExecMu.Unlock()
return nil
}

// Begin is part of the QueryService interface.
Expand Down Expand Up @@ -581,6 +596,10 @@ func (sbc *SandboxConn) setTxReservedID(transactionID int64, reservedID int64) {
sbc.txIDToRID[transactionID] = reservedID
}

func (sbc *SandboxConn) ResultsAllFetched() bool {
return len(sbc.results) == 0
}

func (sbc *SandboxConn) getTxReservedID(txID int64) int64 {
sbc.mapMu.Lock()
defer sbc.mapMu.Unlock()
Expand Down

0 comments on commit ae3c297

Please sign in to comment.