diff --git a/planner/core/casetest/correlated/main_test.go b/planner/core/casetest/correlated/main_test.go index 63e637f6da22d..eb16b2438039c 100644 --- a/planner/core/casetest/correlated/main_test.go +++ b/planner/core/casetest/correlated/main_test.go @@ -33,6 +33,7 @@ func TestMain(m *testing.M) { goleak.IgnoreTopFunction("gopkg.in/natefinch/lumberjack%2ev2.(*Logger).millRun"), goleak.IgnoreTopFunction("github.com/tikv/client-go/v2/txnkv/transaction.keepAlive"), goleak.IgnoreTopFunction("go.opencensus.io/stats/view.(*worker).start"), + goleak.IgnoreTopFunction("github.com/pingcap/tidb/pkg/statistics/handle/syncload.(*statsSyncLoad).SendLoadRequests.func1"), // For TestPlanStatsLoadTimeout } goleak.VerifyTestMain(m, opts...) } diff --git a/sessionctx/stmtctx/BUILD.bazel b/sessionctx/stmtctx/BUILD.bazel index 4d7be5fc6faf7..778d5a66aa870 100644 --- a/sessionctx/stmtctx/BUILD.bazel +++ b/sessionctx/stmtctx/BUILD.bazel @@ -23,6 +23,7 @@ go_library( "@com_github_tikv_client_go_v2//util", "@org_golang_x_exp//maps", "@org_golang_x_exp//slices", + "@org_golang_x_sync//singleflight", "@org_uber_go_atomic//:atomic", "@org_uber_go_zap//:zap", ], diff --git a/sessionctx/stmtctx/stmtctx.go b/sessionctx/stmtctx/stmtctx.go index 8e4b8813c9372..e34b83011ab5c 100644 --- a/sessionctx/stmtctx/stmtctx.go +++ b/sessionctx/stmtctx/stmtctx.go @@ -45,6 +45,7 @@ import ( "go.uber.org/zap" "golang.org/x/exp/maps" "golang.org/x/exp/slices" + "golang.org/x/sync/singleflight" ) const ( @@ -347,7 +348,7 @@ type StatementContext struct { // NeededItems stores the columns/indices whose stats are needed for planner. NeededItems []model.TableItemID // ResultCh to receive stats loading results - ResultCh chan StatsLoadResult + ResultCh []<-chan singleflight.Result // LoadStartTime is to record the load start time to calculate latency LoadStartTime time.Time } diff --git a/statistics/handle/handle_hist.go b/statistics/handle/handle_hist.go index 27b8de78b0a37..231a748cead9c 100644 --- a/statistics/handle/handle_hist.go +++ b/statistics/handle/handle_hist.go @@ -40,6 +40,8 @@ import ( // RetryCount is the max retry count for a sync load task. const RetryCount = 3 +var globalStatsSyncLoadSingleFlight singleflight.Group + type statsWrapper struct { col *statistics.Column idx *statistics.Index @@ -80,25 +82,26 @@ func (h *Handle) SendLoadRequests(sc *stmtctx.StatementContext, neededHistItems } sc.StatsLoad.Timeout = timeout sc.StatsLoad.NeededItems = remainedItems - sc.StatsLoad.ResultCh = make(chan stmtctx.StatsLoadResult, len(remainedItems)) - tasks := make([]*NeededItemTask, 0) + sc.StatsLoad.ResultCh = make([]<-chan singleflight.Result, 0, len(remainedItems)) for _, item := range remainedItems { - task := &NeededItemTask{ - TableItemID: item, - ToTimeout: time.Now().Local().Add(timeout), - ResultCh: sc.StatsLoad.ResultCh, - } - tasks = append(tasks, task) - } - timer := time.NewTimer(timeout) - defer timer.Stop() - for _, task := range tasks { - select { - case h.StatsLoad.NeededItemsCh <- task: - continue - case <-timer.C: - return errors.New("sync load stats channel is full and timeout sending task to channel") - } + localItem := item + resultCh := globalStatsSyncLoadSingleFlight.DoChan(localItem.Key(), func() (any, error) { + timer := time.NewTimer(timeout) + defer timer.Stop() + task := &NeededItemTask{ + TableItemID: localItem, + ToTimeout: time.Now().Local().Add(timeout), + ResultCh: make(chan stmtctx.StatsLoadResult, 1), + } + select { + case h.StatsLoad.NeededItemsCh <- task: + result := <-task.ResultCh + return result, nil + case <-timer.C: + return nil, errors.New("sync load stats channel is full and timeout sending task to channel") + } + }) + sc.StatsLoad.ResultCh = append(sc.StatsLoad.ResultCh, resultCh) } sc.StatsLoad.LoadStartTime = time.Now() return nil @@ -124,26 +127,34 @@ func (h *Handle) SyncWaitStatsLoad(sc *stmtctx.StatementContext) error { metrics.SyncLoadCounter.Inc() timer := time.NewTimer(sc.StatsLoad.Timeout) defer timer.Stop() - for { + for _, resultCh := range sc.StatsLoad.ResultCh { select { - case result, ok := <-sc.StatsLoad.ResultCh: - if ok { - if result.HasError() { - errorMsgs = append(errorMsgs, result.ErrorMsg()) - } - delete(resultCheckMap, result.Item) - if len(resultCheckMap) == 0 { - metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds())) - return nil - } - } else { + case result, ok := <-resultCh: + if !ok { return errors.New("sync load stats channel closed unexpectedly") } + // this error is from statsSyncLoad.SendLoadRequests which start to task and send task into worker, + // not the stats loading error + if result.Err != nil { + errorMsgs = append(errorMsgs, result.Err.Error()) + } else { + val := result.Val.(stmtctx.StatsLoadResult) + // this error is from the stats loading error + if val.HasError() { + errorMsgs = append(errorMsgs, val.ErrorMsg()) + } + delete(resultCheckMap, val.Item) + } case <-timer.C: metrics.SyncLoadTimeoutCounter.Inc() return errors.New("sync load stats timeout") } } + if len(resultCheckMap) == 0 { + metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds())) + return nil + } + return nil } // removeHistLoadedColumns removed having-hist columns based on neededColumns and statsCache. diff --git a/statistics/handle/handle_hist_test.go b/statistics/handle/handle_hist_test.go index 92f4d230cf819..6189281d99c42 100644 --- a/statistics/handle/handle_hist_test.go +++ b/statistics/handle/handle_hist_test.go @@ -208,17 +208,21 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { task1, err1 := h.HandleOneTask(nil, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) require.Error(t, err1) require.NotNil(t, task1) - select { - case <-stmtCtx1.StatsLoad.ResultCh: - t.Logf("stmtCtx1.ResultCh should not get anything") - t.FailNow() - case <-stmtCtx2.StatsLoad.ResultCh: - t.Logf("stmtCtx2.ResultCh should not get anything") - t.FailNow() - case <-task1.ResultCh: - t.Logf("task1.ResultCh should not get anything") - t.FailNow() - default: + for _, resultCh := range stmtCtx1.StatsLoad.ResultCh { + select { + case <-resultCh: + t.Logf("stmtCtx1.ResultCh should not get anything") + t.FailNow() + default: + } + } + for _, resultCh := range stmtCtx2.StatsLoad.ResultCh { + select { + case <-resultCh: + t.Logf("stmtCtx1.ResultCh should not get anything") + t.FailNow() + default: + } } require.NoError(t, failpoint.Disable(fp.failPath)) @@ -226,16 +230,18 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) { require.NoError(t, err3) require.Nil(t, task3) - task, err3 := h.HandleOneTask(nil, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) - require.NoError(t, err3) - require.Nil(t, task) - - rs1, ok1 := <-stmtCtx1.StatsLoad.ResultCh - require.True(t, ok1) - require.Equal(t, neededColumns[0], rs1.Item) - rs2, ok2 := <-stmtCtx2.StatsLoad.ResultCh - require.True(t, ok2) - require.Equal(t, neededColumns[0], rs2.Item) + for _, resultCh := range stmtCtx1.StatsLoad.ResultCh { + rs1, ok1 := <-resultCh + require.True(t, rs1.Shared) + require.True(t, ok1) + require.Equal(t, neededColumns[0].ID, rs1.Val.(stmtctx.StatsLoadResult).Item.ID) + } + for _, resultCh := range stmtCtx2.StatsLoad.ResultCh { + rs1, ok1 := <-resultCh + require.True(t, rs1.Shared) + require.True(t, ok1) + require.Equal(t, neededColumns[0].ID, rs1.Val.(stmtctx.StatsLoadResult).Item.ID) + } stat = h.GetTableStats(tableInfo) hg = stat.Columns[tableInfo.Columns[2].ID].Histogram @@ -312,11 +318,11 @@ func TestRetry(t *testing.T) { result, err1 := h.HandleOneTask(task1, readerCtx, testKit.Session().(sqlexec.RestrictedSQLExecutor), exitCh) require.NoError(t, err1) require.Nil(t, result) - select { - case <-task1.ResultCh: - default: - t.Logf("task1.ResultCh should get nothing") - t.FailNow() + for _, resultCh := range stmtCtx1.StatsLoad.ResultCh { + rs1, ok1 := <-resultCh + require.True(t, rs1.Shared) + require.True(t, ok1) + require.Error(t, rs1.Val.(stmtctx.StatsLoadResult).Error) } require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/statistics/handle/mockReadStatsForOneFail")) }