executor: check kill signal for topn and parallel sort spill #56238

Merged: 11 commits, Oct 8, 2024
pkg/executor/sortexec/BUILD.bazel (4 changes: 3 additions & 1 deletion)
@@ -27,12 +27,14 @@ go_library(
"//pkg/util",
"//pkg/util/channel",
"//pkg/util/chunk",
"//pkg/util/dbterror/exeerrors",
"//pkg/util/disk",
"//pkg/util/logutil",
"//pkg/util/memory",
"//pkg/util/sqlkiller",
"@com_github_pingcap_errors//:errors",
"@com_github_pingcap_failpoint//:failpoint",
"@com_github_stretchr_testify//require",
"@org_uber_go_zap//:zap",
],
)
@@ -42,7 +44,7 @@ go_test(
timeout = "short",
srcs = ["sort_test.go"],
flaky = True,
- shard_count = 16,
+ shard_count = 17,
deps = [
"//pkg/config",
"//pkg/sessionctx/variable",
pkg/executor/sortexec/parallel_sort_spill_helper.go (11 changes: 6 additions & 5 deletions)
@@ -108,6 +108,12 @@ func (p *parallelSortSpillHelper) spill() (err error) {
}
}()

+ p.setInSpilling()
+
+ // Spill is done, broadcast to wake up all sleeping goroutines
+ defer p.cond.Broadcast()
+ defer p.setNotSpilled()

select {
case <-p.finishCh:
return nil
@@ -138,11 +144,6 @@
}

workerWaiter.Wait()
- p.setInSpilling()
-
- // Spill is done, broadcast to wake up all sleeping goroutines
- defer p.cond.Broadcast()
- defer p.setNotSpilled()

totalRows := 0
for i := range sortedRowsIters {
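Note on the two hunks above: the `setInSpilling()` call and the deferred `cond.Broadcast()`/`setNotSpilled()` pair moved from after `workerWaiter.Wait()` to the top of `spill()`. The helper is therefore already in the spilling state while it waits on `finishCh` and on the sort workers, so a goroutine that consults that state anywhere in that window sleeps until the whole spill ends, and the deferred calls reset the state and wake sleepers on every return path, the early `return nil` on `finishCh` included. A minimal sketch of this state/condition-variable pattern; every type here is a simplified stand-in, not the TiDB implementation:

```go
package main

import (
	"fmt"
	"sync"
)

const (
	notSpilled = iota
	inSpilling
)

// spillHelper is a simplified stand-in for parallelSortSpillHelper.
type spillHelper struct {
	mu     sync.Mutex
	status int
	cond   *sync.Cond
}

func newSpillHelper() *spillHelper {
	s := &spillHelper{}
	s.cond = sync.NewCond(&s.mu)
	return s
}

// spill enters the in-spilling state before any waiting or disk I/O, so a
// concurrent goroutine that checks the state at any point during the spill
// blocks until the whole operation has finished.
func (s *spillHelper) spill() error {
	s.mu.Lock()
	s.status = inSpilling
	s.mu.Unlock()

	// Runs on every return path, early returns and errors included: reset
	// the state under the lock (avoiding a lost wakeup), then wake everyone.
	defer func() {
		s.mu.Lock()
		s.status = notSpilled
		s.mu.Unlock()
		s.cond.Broadcast()
	}()

	// ... wait for sort workers, merge sorted rows, write them to disk ...
	return nil
}

// waitIfSpilling blocks the caller until no spill is in progress.
func (s *spillHelper) waitIfSpilling() {
	s.mu.Lock()
	defer s.mu.Unlock()
	for s.status == inSpilling {
		s.cond.Wait()
	}
}

func main() {
	s := newSpillHelper()
	_ = s.spill()
	s.waitIfSpilling() // returns immediately; the spill has already completed
	fmt.Println("spill done, state reset")
}
```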
pkg/executor/sortexec/parallel_sort_worker.go (43 changes: 39 additions & 4 deletions)
@@ -20,9 +20,11 @@ import (
"sync"
"time"

"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
)

// SignalCheckpointForSort indicates the number of row comparisons after which a signal detection will be triggered.
@@ -50,6 +52,8 @@ type parallelSortWorker struct {
chunkIters []*chunk.Iterator4Chunk
rowNumInChunkIters int
merger *multiWayMerger

sqlKiller *sqlkiller.SQLKiller
}

func newParallelSortWorker(
@@ -62,7 +66,9 @@
memTracker *memory.Tracker,
sortedRowsIter *chunk.Iterator4Slice,
maxChunkSize int,
- spillHelper *parallelSortSpillHelper) *parallelSortWorker {
+ spillHelper *parallelSortSpillHelper,
+ sqlKiller *sqlkiller.SQLKiller,
+ ) *parallelSortWorker {
return &parallelSortWorker{
workerIDForTest: workerIDForTest,
lessRowFunc: lessRowFunc,
@@ -75,6 +81,7 @@
sortedRowsIter: sortedRowsIter,
maxSortedRowsLimit: maxChunkSize * 30,
spillHelper: spillHelper,
sqlKiller: sqlKiller,
}
}

@@ -112,7 +119,32 @@ func (p *parallelSortWorker) multiWayMergeLocalSortedRows() ([]chunk.Row, error)
return nil, err
}

loopCnt := uint64(0)

for {
var err error
failpoint.Inject("ParallelSortRandomFail", func(val failpoint.Value) {
if val.(bool) {
randNum := rand.Int31n(10000)
if randNum < 2 {
err = errors.NewNoStackErrorf("failpoint error")
}
}
})

if err != nil {
return nil, err
}

if loopCnt%100 == 0 && p.sqlKiller != nil {
err := p.sqlKiller.HandleSignal()
if err != nil {
return nil, err
}
}

loopCnt++

// It's impossible to return an error here as rows are in memory
row, _ := p.merger.next()
if row.IsEmpty() {
@@ -202,9 +234,12 @@ func (p *parallelSortWorker) fetchChunksAndSortImpl() bool {
}

func (p *parallelSortWorker) keyColumnsLess(i, j chunk.Row) int {
- if p.timesOfRowCompare >= SignalCheckpointForSort {
- // Trigger Consume for checking the NeedKill signal
- p.memTracker.Consume(1)
+ if p.timesOfRowCompare >= SignalCheckpointForSort && p.sqlKiller != nil {
+ err := p.sqlKiller.HandleSignal()
+ if err != nil {
+ panic(err)
+ }

p.timesOfRowCompare = 0
}

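The additions above share one pattern: poll the kill signal at an amortized rate rather than on every row. `multiWayMergeLocalSortedRows` now checks once per 100 merge iterations, and `keyColumnsLess` checks once every `SignalCheckpointForSort` comparisons, replacing the old trick of calling `p.memTracker.Consume(1)` so the memory tracker would notice the NeedKill flag. `keyColumnsLess` has to surface the error with `panic(err)` because a sort comparator has no way to return one; the panic is presumably recovered further up the worker's stack. The `ParallelSortRandomFail` failpoint injection is test-only fault injection, orthogonal to the kill check. A small self-contained sketch of the amortized check; `killer` here is a hypothetical stand-in for `sqlkiller.SQLKiller`, not its real API:

```go
package main

import (
	"errors"
	"fmt"
	"sync/atomic"
)

// killer is a hypothetical stand-in for sqlkiller.SQLKiller.
type killer struct{ killed atomic.Bool }

// HandleSignal mimics the shape of the real method: it returns a non-nil
// error once the query has been killed, nil otherwise.
func (k *killer) HandleSignal() error {
	if k.killed.Load() {
		return errors.New("query interrupted")
	}
	return nil
}

const checkInterval = 100 // consult the signal once per 100 iterations

// mergeLoop stands in for the worker's multi-way merge loop.
func mergeLoop(k *killer, rows []int) (sum int, err error) {
	for loopCnt, row := range rows {
		// Amortized kill check: cheap enough to leave in a hot loop.
		if loopCnt%checkInterval == 0 && k != nil {
			if err := k.HandleSignal(); err != nil {
				return 0, err // abort the merge promptly on KILL
			}
		}
		sum += row // stand-in for the per-row merge work
	}
	return sum, nil
}

func main() {
	k := &killer{}
	rows := make([]int, 1000)
	k.killed.Store(true) // simulate a KILL arriving before the merge
	if _, err := mergeLoop(k, rows); err != nil {
		fmt.Println("merge aborted:", err)
	}
}
```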
pkg/executor/sortexec/sort.go (2 changes: 1 addition & 1 deletion)
@@ -667,7 +667,7 @@ func (e *SortExec) fetchChunksParallel(ctx context.Context) error {
fetcherWaiter := util.WaitGroupWrapper{}

for i := range e.Parallel.workers {
- e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper)
+ e.Parallel.workers[i] = newParallelSortWorker(i, e.lessRow, e.Parallel.chunkChannel, e.Parallel.fetcherAndWorkerSyncer, e.Parallel.resultChannel, e.finishCh, e.memTracker, e.Parallel.sortedRowsIters[i], e.MaxChunkSize(), e.Parallel.spillHelper, &e.Ctx().GetSessionVars().SQLKiller)
worker := e.Parallel.workers[i]
workersWaiter.Run(func() {
worker.run()
pkg/executor/sortexec/sort_spill.go (4 changes: 0 additions & 4 deletions)
@@ -115,10 +115,6 @@ func (s *parallelSortSpillAction) actionImpl(t *memory.Tracker) bool {
}

if t.CheckExceed() && s.spillHelper.isNotSpilledNoLock() && hasEnoughDataToSpill(s.spillHelper.sortExec.memTracker, t) {
- // Ideally, all goroutines entering this action should wait for the finish of spill once
- // spill is triggered(we consider spill is triggered when the `needSpill` has been set).
- // However, out of some reasons, we have to directly return before the finish of
- // sort operation executed in spill as sort will retrigger the action and lead to dead lock.
s.spillHelper.setNeedSpillNoLock()
s.spillHelper.bytesConsumed.Store(t.BytesConsumed())
s.spillHelper.bytesLimit.Store(t.GetBytesLimit())
pkg/executor/sortexec/topn.go (3 changes: 2 additions & 1 deletion)
@@ -109,11 +109,12 @@ func (e *TopNExec) Open(ctx context.Context) error {
exec.RetTypes(e),
workers,
e.Concurrency,
+ &e.Ctx().GetSessionVars().SQLKiller,
)
e.spillAction = &topNSpillAction{spillHelper: e.spillHelper}
e.Ctx().GetSessionVars().MemTracker.FallbackOldAndSetNewAction(e.spillAction)
} else {
- e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0)
+ e.spillHelper = newTopNSpillerHelper(e, nil, nil, nil, nil, nil, nil, 0, nil)
}

return exec.Open(ctx, e.Children(0))
pkg/executor/sortexec/topn_chunk_heap.go (24 changes: 24 additions & 0 deletions)
@@ -16,11 +16,16 @@ package sortexec

import (
"container/heap"
"context"
"testing"

"github.com/pingcap/tidb/pkg/executor/internal/exec"
"github.com/pingcap/tidb/pkg/types"
"github.com/pingcap/tidb/pkg/util/chunk"
"github.com/pingcap/tidb/pkg/util/dbterror/exeerrors"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
"github.com/stretchr/testify/require"
)

// topNChunkHeap implements heap.Interface.
@@ -153,3 +158,22 @@ func (h *topNChunkHeap) Pop() any {
func (h *topNChunkHeap) Swap(i, j int) {
h.rowPtrs[i], h.rowPtrs[j] = h.rowPtrs[j], h.rowPtrs[i]
}

// TestKillSignalInTopN is for test purposes only
func TestKillSignalInTopN(t *testing.T, topnExec *TopNExec) {
ctx := context.Background()
err := topnExec.Open(ctx)
require.NoError(t, err)

chkHeap := &topNChunkHeap{}
// Offset of heap in worker should be 0, as we need to spill all data
chkHeap.init(topnExec, topnExec.memTracker, topnExec.Limit.Offset+topnExec.Limit.Count, 0, topnExec.greaterRow, topnExec.RetFieldTypes())
srcChk := exec.TryNewCacheChunk(topnExec.Children(0))
err = exec.Next(ctx, topnExec.Children(0), srcChk)
require.NoError(t, err)
chkHeap.rowChunks.Add(srcChk)

topnExec.Ctx().GetSessionVars().SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted)
err = topnExec.spillHelper.spillHeap(chkHeap)
require.Error(t, err, exeerrors.ErrQueryInterrupted.GenWithStackByArgs())
}
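A side note on placement: `TestKillSignalInTopN` sits in a regular source file yet imports `testing`, `require`, and `exeerrors`, which is why those packages appear in the new `go_library` deps in BUILD.bazel above; presumably it is exported from the package so that the external test in topn_spill_test.go (`TestIssue54541` below) can drive the unexported `topNChunkHeap` and `spillHelper` directly.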
pkg/executor/sortexec/topn_spill.go (12 changes: 12 additions & 0 deletions)
@@ -26,6 +26,7 @@ import (
"github.com/pingcap/tidb/pkg/util/disk"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/memory"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
"go.uber.org/zap"
)

@@ -47,6 +48,8 @@ type topNSpillHelper struct {

bytesConsumed atomic.Int64
bytesLimit atomic.Int64

sqlKiller *sqlkiller.SQLKiller
}

func newTopNSpillerHelper(
@@ -58,6 +61,7 @@
fieldTypes []*types.FieldType,
workers []*topNWorker,
concurrencyNum int,
sqlKiller *sqlkiller.SQLKiller,
) *topNSpillHelper {
lock := sync.Mutex{}
tmpSpillChunksChan := make(chan *chunk.Chunk, concurrencyNum)
@@ -78,6 +82,7 @@
workers: workers,
bytesConsumed: atomic.Int64{},
bytesLimit: atomic.Int64{},
sqlKiller: sqlKiller,
}
}

@@ -209,6 +214,13 @@ func (t *topNSpillHelper) spillHeap(chkHeap *topNChunkHeap) error {

rowPtrNum := chkHeap.Len()
for ; chkHeap.idx < rowPtrNum; chkHeap.idx++ {
if chkHeap.idx%100 == 0 && t.sqlKiller != nil {
err := t.sqlKiller.HandleSignal()
if err != nil {
return err
}
}

if tmpSpillChunk.IsFull() {
err := t.spillTmpSpillChunk(inDisk, tmpSpillChunk)
if err != nil {
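`spillHeap` gets the same amortized treatment, keyed on `chkHeap.idx`: a KILL arriving mid-spill now aborts the drain loop instead of going unnoticed until the whole heap is on disk. A runnable sketch of that loop shape, interleaving the periodic check with chunk flushes; every type below is a hypothetical stand-in for the chunk and killer types used here:

```go
package main

import (
	"errors"
	"fmt"
)

// tmpChunk is a hypothetical stand-in for the temporary spill chunk.
type tmpChunk struct {
	rows     int
	capacity int
	flushed  int // rows written to "disk" so far
}

func (c *tmpChunk) isFull() bool { return c.rows == c.capacity }

// flush stands in for spillTmpSpillChunk: write the chunk out, then reset it.
func (c *tmpChunk) flush() {
	c.flushed += c.rows
	c.rows = 0
}

// checker is a hypothetical stand-in for sqlkiller.SQLKiller.
type checker struct{ killed bool }

func (k *checker) HandleSignal() error {
	if k.killed {
		return errors.New("query interrupted")
	}
	return nil
}

// spillRows mirrors the drain loop in spillHeap: poll the kill signal every
// 100 rows, flush full chunks, and flush the trailing partial chunk at the end.
func spillRows(totalRows int, c *tmpChunk, k *checker) error {
	for idx := 0; idx < totalRows; idx++ {
		if idx%100 == 0 && k != nil {
			if err := k.HandleSignal(); err != nil {
				return err // abort mid-spill on KILL
			}
		}
		if c.isFull() {
			c.flush()
		}
		c.rows++ // stand-in for appending the heap row to the chunk
	}
	c.flush()
	return nil
}

func main() {
	c := &tmpChunk{capacity: 32}
	if err := spillRows(1000, c, &checker{}); err != nil {
		fmt.Println("spill aborted:", err)
		return
	}
	fmt.Println("rows spilled:", c.flushed) // 1000
}
```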
pkg/executor/sortexec/topn_spill_test.go (22 changes: 22 additions & 0 deletions)
@@ -493,6 +493,28 @@ func TestIssue54206(t *testing.T) {
tk.MustQuery("select t1.a+t1.b as result from t1 left join t2 on 1 = 0 order by result limit 1;")
}

func TestIssue54541(t *testing.T) {
totalRowNum := 30
sortexec.SetSmallSpillChunkSizeForTest()
ctx := mock.NewContext()
topNCase := &testutil.SortCase{Rows: totalRowNum, OrderByIdx: []int{0, 1}, Ndvs: []int{0, 0}, Ctx: ctx}

ctx.GetSessionVars().InitChunkSize = 32
ctx.GetSessionVars().MaxChunkSize = 32
ctx.GetSessionVars().MemTracker = memory.NewTracker(memory.LabelForSQLText, hardLimit2)
ctx.GetSessionVars().StmtCtx.MemTracker = memory.NewTracker(memory.LabelForSQLText, -1)
ctx.GetSessionVars().StmtCtx.MemTracker.AttachTo(ctx.GetSessionVars().MemTracker)

offset := uint64(totalRowNum / 10)
count := uint64(totalRowNum / 3)

schema := expression.NewSchema(topNCase.Columns()...)
dataSource := buildDataSource(topNCase, schema)
exe := buildTopNExec(topNCase, dataSource, offset, count)

sortexec.TestKillSignalInTopN(t, exe)
}

func TestTopNFallBackAction(t *testing.T) {
sortexec.SetSmallSpillChunkSizeForTest()
ctx := mock.NewContext()