Skip to content

Commit

Permalink
executor: enable parallel sort (#53537)
Browse files Browse the repository at this point in the history
close #53536
  • Loading branch information
xzhangxian1008 authored Jun 6, 2024
1 parent dcadcde commit 793530a
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 61 deletions.
22 changes: 9 additions & 13 deletions pkg/executor/sortexec/parallel_sort_spill_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import (
var hardLimit1 = int64(100000)
var hardLimit2 = hardLimit1 * 10

func oneSpillCase(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
func oneSpillCase(t *testing.T, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
if exe == nil {
exe = buildSortExec(sortCase, dataSource)
}
Expand Down Expand Up @@ -60,15 +60,15 @@ func inMemoryThenSpill(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec,
require.True(t, checkCorrectness(schema, exe, dataSource, resultChunks))
}

func failpointNoMemoryDataTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
// failpointNoMemoryDataTest exercises the sort executor under failpoint
// injection with no hard memory limit (hardLimit 0 passed to
// executeInFailpoint), so no spill is expected to trigger.
// When exe is nil a fresh executor is built from sortCase and dataSource;
// callers pass a non-nil exe to verify the executor is reusable across runs.
func failpointNoMemoryDataTest(t *testing.T, exe *sortexec.SortExec, sortCase *testutil.SortCase, dataSource *testutil.MockDataSource) {
	if exe == nil {
		exe = buildSortExec(sortCase, dataSource)
	}
	dataSource.PrepareChunks()
	executeInFailpoint(t, exe, 0, nil)
}

func failpointDataInMemoryThenSpillTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, schema *expression.Schema, dataSource *testutil.MockDataSource) {
func failpointDataInMemoryThenSpillTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sortCase *testutil.SortCase, dataSource *testutil.MockDataSource) {
if exe == nil {
exe = buildSortExec(sortCase, dataSource)
}
Expand All @@ -91,15 +91,13 @@ func TestParallelSortSpillDisk(t *testing.T) {
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true

schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
for i := 0; i < 10; i++ {
oneSpillCase(t, ctx, nil, sortCase, schema, dataSource)
oneSpillCase(t, ctx, exe, sortCase, schema, dataSource)
oneSpillCase(t, nil, sortCase, schema, dataSource)
oneSpillCase(t, exe, sortCase, schema, dataSource)
}

ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
Expand Down Expand Up @@ -129,21 +127,19 @@ func TestParallelSortSpillDiskFailpoint(t *testing.T) {
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true

schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
for i := 0; i < 20; i++ {
failpointNoMemoryDataTest(t, ctx, nil, sortCase, schema, dataSource)
failpointNoMemoryDataTest(t, ctx, exe, sortCase, schema, dataSource)
failpointNoMemoryDataTest(t, nil, sortCase, dataSource)
failpointNoMemoryDataTest(t, exe, sortCase, dataSource)
}

ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
for i := 0; i < 20; i++ {
failpointDataInMemoryThenSpillTest(t, ctx, nil, sortCase, schema, dataSource)
failpointDataInMemoryThenSpillTest(t, ctx, exe, sortCase, schema, dataSource)
failpointDataInMemoryThenSpillTest(t, ctx, nil, sortCase, dataSource)
failpointDataInMemoryThenSpillTest(t, ctx, exe, sortCase, dataSource)
}
}
6 changes: 0 additions & 6 deletions pkg/executor/sortexec/parallel_sort_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,6 @@ func executeInFailpoint(t *testing.T, exe *sortexec.SortExec, hardLimit int64, t
tmpCtx := context.Background()
err := exe.Open(tmpCtx)
require.NoError(t, err)
exe.IsUnparallel = false
exe.InitInParallelModeForTest()

goRoutineWaiter := sync.WaitGroup{}
goRoutineWaiter.Add(1)
Expand Down Expand Up @@ -85,8 +83,6 @@ func parallelSortTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, s
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true

if exe == nil {
exe = buildSortExec(sortCase, dataSource)
Expand All @@ -105,8 +101,6 @@ func failpointTest(t *testing.T, ctx *mock.Context, exe *sortexec.SortExec, sort
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = true
if exe == nil {
exe = buildSortExec(sortCase, dataSource)
}
Expand Down
8 changes: 8 additions & 0 deletions pkg/executor/sortexec/parallel_sort_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,14 @@ func newParallelSortWorker(
}
}

// reset drops all per-run state of the worker — the batched input rows, the
// locally sorted rows, the iterator over them, and the merger — and zeroes
// the worker's memory tracker so the released memory is no longer accounted
// for. It leaves the worker ready for reuse in a subsequent run.
func (p *parallelSortWorker) reset() {
	p.batchRows = nil
	p.localSortedRows = nil
	p.sortedRowsIter = nil
	p.merger = nil
	p.memTracker.ReplaceBytesUsed(0)
}

func (p *parallelSortWorker) injectFailPointForParallelSortWorker(triggerFactor int32) {
injectParallelSortRandomFail(triggerFactor)
failpoint.Inject("SlowSomeWorkers", func(val failpoint.Value) {
Expand Down
56 changes: 29 additions & 27 deletions pkg/executor/sortexec/sort.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"sync/atomic"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/expression"
Expand Down Expand Up @@ -55,6 +56,7 @@ type SortExec struct {
memTracker *memory.Tracker
diskTracker *disk.Tracker

// TODO delete this variable in the future and remove the unparallel sort
IsUnparallel bool

finishCh chan struct{}
Expand Down Expand Up @@ -124,11 +126,9 @@ func (e *SortExec) Close() error {
// will use `e.Parallel.workers` and `e.Parallel.merger`.
channel.Clear(e.Parallel.resultChannel)
for i := range e.Parallel.workers {
e.Parallel.workers[i].batchRows = nil
e.Parallel.workers[i].localSortedRows = nil
e.Parallel.workers[i].sortedRowsIter = nil
e.Parallel.workers[i].merger = nil
e.Parallel.workers[i].memTracker.ReplaceBytesUsed(0)
if e.Parallel.workers[i] != nil {
e.Parallel.workers[i].reset()
}
}
e.Parallel.merger = nil
if e.Parallel.spillAction != nil {
Expand Down Expand Up @@ -160,7 +160,7 @@ func (e *SortExec) Open(ctx context.Context) error {
e.diskTracker.AttachTo(e.Ctx().GetSessionVars().StmtCtx.DiskTracker)
}

e.IsUnparallel = true
e.IsUnparallel = false
if e.IsUnparallel {
e.Unparallel.Idx = 0
e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0]
Expand All @@ -185,24 +185,10 @@ func (e *SortExec) Open(ctx context.Context) error {
return exec.Open(ctx, e.Children(0))
}

// InitInParallelModeForTest is a function for test
// After system variable is added, we can delete this function
func (e *SortExec) InitInParallelModeForTest() {
e.Parallel.workers = make([]*parallelSortWorker, e.Ctx().GetSessionVars().ExecutorConcurrency)
e.Parallel.chunkChannel = make(chan *chunkWithMemoryUsage, e.Ctx().GetSessionVars().ExecutorConcurrency)
e.Parallel.fetcherAndWorkerSyncer = &sync.WaitGroup{}
e.Parallel.sortedRowsIters = make([]*chunk.Iterator4Slice, len(e.Parallel.workers))
e.Parallel.resultChannel = make(chan rowWithError, e.MaxChunkSize())
e.Parallel.closeSync = make(chan struct{})
e.Parallel.merger = newMultiWayMerger(&memorySource{sortedRowsIters: e.Parallel.sortedRowsIters}, e.lessRow)
e.Parallel.spillHelper = newParallelSortSpillHelper(e, exec.RetTypes(e), e.finishCh, e.lessRow, e.Parallel.resultChannel)
e.Parallel.spillAction = newParallelSortSpillDiskAction(e.Parallel.spillHelper)
for i := range e.Parallel.sortedRowsIters {
e.Parallel.sortedRowsIters[i] = chunk.NewIterator4Slice(nil)
}
if e.enableTmpStorageOnOOM {
e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.Parallel.spillAction)
}
// InitUnparallelModeForTest resets the executor's unparallel-sort state
// (partition cursor back to 0, partition slice emptied while keeping its
// capacity). It exists only for unit tests that force the unparallel path.
func (e *SortExec) InitUnparallelModeForTest() {
	e.Unparallel.Idx = 0
	e.Unparallel.sortPartitions = e.Unparallel.sortPartitions[:0]
}

// Next implements the Executor Next interface.
Expand Down Expand Up @@ -272,9 +258,13 @@ func (e *SortExec) InitInParallelModeForTest() {
*/
func (e *SortExec) Next(ctx context.Context, req *chunk.Chunk) error {
if e.fetched.CompareAndSwap(false, true) {
e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
if err != nil {
return err
}

e.buildKeyColumns()
err := e.fetchChunks(ctx)
err = e.fetchChunks(ctx)
if err != nil {
return err
}
Expand Down Expand Up @@ -710,6 +700,14 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) {
e.Parallel.resultChannel <- rowWithError{err: err}
}

failpoint.Inject("SignalCheckpointForSort", func(val failpoint.Value) {
if val.(bool) {
if e.Ctx().GetSessionVars().ConnectionID == 123456 {
e.Ctx().GetSessionVars().MemTracker.Killer.SendKillSignal(sqlkiller.QueryMemoryExceeded)
}
}
})

// We must place it after the spill as workers will process its received
// chunks after channel is closed and this will cause data race.
close(e.Parallel.chunkChannel)
Expand Down Expand Up @@ -753,12 +751,16 @@ func (e *SortExec) fetchChunksFromChild(ctx context.Context) {
}
}

func (e *SortExec) initCompareFuncs(ctx expression.EvalContext) {
// initCompareFuncs resolves one chunk.CompareFunc per ORDER BY item and
// caches them in e.keyCmpFuncs. It returns an error as soon as it meets an
// item whose evaluated type has no registered compare function.
func (e *SortExec) initCompareFuncs(ctx expression.EvalContext) error {
	e.keyCmpFuncs = make([]chunk.CompareFunc, len(e.ByItems))
	for idx, item := range e.ByItems {
		tp := item.Expr.GetType(ctx)
		fn := chunk.GetCompareFunc(tp)
		if fn == nil {
			return errors.Errorf("Sort executor not supports type %s", types.TypeStr(tp.GetType()))
		}
		e.keyCmpFuncs[idx] = fn
	}
	return nil
}

func (e *SortExec) buildKeyColumns() {
Expand Down
20 changes: 6 additions & 14 deletions pkg/executor/sortexec/sort_spill_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ func executeSortExecutor(t *testing.T, exe *sortexec.SortExec, isParallelSort bo
tmpCtx := context.Background()
err := exe.Open(tmpCtx)
require.NoError(t, err)
if isParallelSort {
exe.IsUnparallel = false
exe.InitInParallelModeForTest()
if !isParallelSort {
exe.IsUnparallel = true
exe.InitUnparallelModeForTest()
}

resultChunks := make([]*chunk.Chunk, 0)
Expand All @@ -199,9 +199,9 @@ func executeSortExecutorAndManullyTriggerSpill(t *testing.T, exe *sortexec.SortE
tmpCtx := context.Background()
err := exe.Open(tmpCtx)
require.NoError(t, err)
if isParallelSort {
exe.IsUnparallel = false
exe.InitInParallelModeForTest()
if !isParallelSort {
exe.IsUnparallel = true
exe.InitUnparallelModeForTest()
}

resultChunks := make([]*chunk.Chunk, 0)
Expand Down Expand Up @@ -239,8 +239,6 @@ func onePartitionAndAllDataInMemoryCase(t *testing.T, ctx *mock.Context, sortCas
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 1048576)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand All @@ -262,8 +260,6 @@ func onePartitionAndAllDataInDiskCase(t *testing.T, ctx *mock.Context, sortCase
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 50000)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand Down Expand Up @@ -292,8 +288,6 @@ func multiPartitionCase(t *testing.T, ctx *mock.Context, sortCase *testutil.Sort
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, 10000)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand Down Expand Up @@ -333,8 +327,6 @@ func inMemoryThenSpillCase(t *testing.T, ctx *mock.Context, sortCase *testutil.S
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)
// TODO use variable to choose parallel mode after system variable is added
// ctx.GetSessionVars().EnableParallelSort = false
schema := expression.NewSchema(sortCase.Columns()...)
dataSource := buildDataSource(sortCase, schema)
exe := buildSortExec(sortCase, dataSource)
Expand Down
6 changes: 5 additions & 1 deletion pkg/executor/sortexec/topn.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,11 @@ func (e *TopNExec) fetchChunks(ctx context.Context) error {
}

func (e *TopNExec) loadChunksUntilTotalLimit(ctx context.Context) error {
e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
err := e.initCompareFuncs(e.Ctx().GetExprCtx().GetEvalCtx())
if err != nil {
return err
}

e.buildKeyColumns()
e.chkHeap.init(e, e.memTracker, e.Limit.Offset+e.Limit.Count, int(e.Limit.Offset), e.greaterRow, e.RetFieldTypes())
for uint64(e.chkHeap.rowChunks.Len()) < e.chkHeap.totalLimit {
Expand Down
6 changes: 6 additions & 0 deletions pkg/util/chunk/compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ func GetCompareFunc(tp *types.FieldType) CompareFunc {
return cmpBit
case mysql.TypeJSON:
return cmpJSON
case mysql.TypeNull:
return cmpNullConst
}
return nil
}
Expand Down Expand Up @@ -169,6 +171,10 @@ func cmpJSON(l Row, lCol int, r Row, rCol int) int {
return types.CompareBinaryJSON(lJ, rJ)
}

// cmpNullConst is the compare function for columns of type NULL: every value
// is NULL, so any two rows compare equal and the result is always 0.
func cmpNullConst(_ Row, _ int, _ Row, _ int) int {
	return 0
}

// Compare compares the value with ad.
// We assume that the collation information of the column is the same with the datum.
func Compare(row Row, colIdx int, ad *types.Datum) int {
Expand Down

0 comments on commit 793530a

Please sign in to comment.