diff --git a/pkg/planner/core/casetest/planstats/plan_stats_test.go b/pkg/planner/core/casetest/planstats/plan_stats_test.go
index 3e24920644f7d..1698895d7ca6b 100644
--- a/pkg/planner/core/casetest/planstats/plan_stats_test.go
+++ b/pkg/planner/core/casetest/planstats/plan_stats_test.go
@@ -284,7 +284,7 @@ func TestPlanStatsLoadTimeout(t *testing.T) {
 	tk.MustExec("set global tidb_stats_load_pseudo_timeout=true")
 	require.NoError(t, failpoint.Enable("github.com/pingcap/tidb/pkg/executor/assertSyncStatsFailed", `return(true)`))
 	tk.MustExec(sql) // not fail sql for timeout when pseudo=true
-	failpoint.Disable("github.com/pingcap/tidb/pkg/executor/assertSyncStatsFailed")
+	require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/pkg/executor/assertSyncStatsFailed"))
 	plan, _, err := planner.Optimize(context.TODO(), ctx, stmt, is)
 	require.NoError(t, err) // not fail sql for timeout when pseudo=true
diff --git a/pkg/sessionctx/stmtctx/BUILD.bazel b/pkg/sessionctx/stmtctx/BUILD.bazel
index 7595601c9dd19..b244a670d0204 100644
--- a/pkg/sessionctx/stmtctx/BUILD.bazel
+++ b/pkg/sessionctx/stmtctx/BUILD.bazel
@@ -27,6 +27,7 @@ go_library(
         "@com_github_tikv_client_go_v2//tikvrpc",
         "@com_github_tikv_client_go_v2//util",
         "@org_golang_x_exp//maps",
+        "@org_golang_x_sync//singleflight",
         "@org_uber_go_atomic//:atomic",
         "@org_uber_go_zap//:zap",
     ],
diff --git a/pkg/sessionctx/stmtctx/stmtctx.go b/pkg/sessionctx/stmtctx/stmtctx.go
index 127b338277015..972e0056d1671 100644
--- a/pkg/sessionctx/stmtctx/stmtctx.go
+++ b/pkg/sessionctx/stmtctx/stmtctx.go
@@ -50,6 +50,7 @@ import (
 	atomic2 "go.uber.org/atomic"
 	"go.uber.org/zap"
 	"golang.org/x/exp/maps"
+	"golang.org/x/sync/singleflight"
 )
 
 const (
@@ -363,7 +364,7 @@ type StatementContext struct {
 		// NeededItems stores the columns/indices whose stats are needed for planner.
 		NeededItems []model.TableItemID
 		// ResultCh to receive stats loading results
-		ResultCh chan StatsLoadResult
+		ResultCh []<-chan singleflight.Result
 		// LoadStartTime is to record the load start time to calculate latency
 		LoadStartTime time.Time
 	}
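Note: the ResultCh change above leans on the x/sync/singleflight contract, so the statement context now keeps one receive-only channel per needed item instead of one shared channel. A minimal standalone sketch of that contract (made-up key, not TiDB code): DoChan coalesces concurrent callers that share a key into one execution of fn, and each caller's channel delivers exactly one singleflight.Result whose Val/Err come from that shared execution and whose Shared flag reports whether the result served more than one caller.

package main

import (
	"fmt"

	"golang.org/x/sync/singleflight"
)

func main() {
	var g singleflight.Group
	// Both calls use the same key, so while the first fn is still in flight
	// the second caller coalesces onto it instead of running its own fn.
	ch1 := g.DoChan("table1#col2", func() (any, error) {
		return "histogram for col2", nil // stand-in for a stats load
	})
	ch2 := g.DoChan("table1#col2", func() (any, error) {
		return "usually never runs", nil // only runs if the first fn already finished
	})
	r1, r2 := <-ch1, <-ch2
	// Shared is true for a result that was delivered to more than one caller.
	fmt.Println(r1.Val, r1.Err, r1.Shared)
	fmt.Println(r2.Val, r2.Err, r2.Shared)
}

Because singleflight erases the value type behind Val any, the waiting side below has to type-assert result.Val back to stmtctx.StatsLoadResult.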
diff --git a/pkg/statistics/handle/BUILD.bazel b/pkg/statistics/handle/BUILD.bazel
index 3f67a50020219..4ed865fdfaced 100644
--- a/pkg/statistics/handle/BUILD.bazel
+++ b/pkg/statistics/handle/BUILD.bazel
@@ -36,6 +36,7 @@ go_library(
        "//pkg/types",
        "//pkg/util",
        "//pkg/util/chunk",
+       "//pkg/util/intest",
        "//pkg/util/logutil",
        "@com_github_pingcap_errors//:errors",
        "@com_github_pingcap_failpoint//:failpoint",
diff --git a/pkg/statistics/handle/handle_hist.go b/pkg/statistics/handle/handle_hist.go
index 6f0bd0b1b9b3b..3a55bcc1e9dcd 100644
--- a/pkg/statistics/handle/handle_hist.go
+++ b/pkg/statistics/handle/handle_hist.go
@@ -33,6 +33,7 @@ import (
 	utilstats "github.com/pingcap/tidb/pkg/statistics/handle/util"
 	"github.com/pingcap/tidb/pkg/types"
 	"github.com/pingcap/tidb/pkg/util"
+	"github.com/pingcap/tidb/pkg/util/intest"
 	"github.com/pingcap/tidb/pkg/util/logutil"
 	"go.uber.org/zap"
 	"golang.org/x/sync/singleflight"
 )
@@ -41,6 +42,8 @@ import (
 // RetryCount is the max retry count for a sync load task.
 const RetryCount = 3
 
+var globalStatsSyncLoadSingleFlight singleflight.Group
+
 type statsWrapper struct {
 	col *statistics.Column
 	idx *statistics.Index
 }
@@ -81,25 +84,27 @@ func (h *Handle) SendLoadRequests(sc *stmtctx.StatementContext, neededHistItems
 	}
 	sc.StatsLoad.Timeout = timeout
 	sc.StatsLoad.NeededItems = remainedItems
-	sc.StatsLoad.ResultCh = make(chan stmtctx.StatsLoadResult, len(remainedItems))
-	tasks := make([]*NeededItemTask, 0)
+	sc.StatsLoad.ResultCh = make([]<-chan singleflight.Result, 0, len(remainedItems))
 	for _, item := range remainedItems {
-		task := &NeededItemTask{
-			TableItemID: item,
-			ToTimeout:   time.Now().Local().Add(timeout),
-			ResultCh:    sc.StatsLoad.ResultCh,
-		}
-		tasks = append(tasks, task)
-	}
-	timer := time.NewTimer(timeout)
-	defer timer.Stop()
-	for _, task := range tasks {
-		select {
-		case h.StatsLoad.NeededItemsCh <- task:
-			continue
-		case <-timer.C:
-			return errors.New("sync load stats channel is full and timeout sending task to channel")
-		}
+		localItem := item
+		resultCh := globalStatsSyncLoadSingleFlight.DoChan(localItem.Key(), func() (any, error) {
+			timer := time.NewTimer(timeout)
+			defer timer.Stop()
+			task := &NeededItemTask{
+				TableItemID: localItem,
+				ToTimeout:   time.Now().Local().Add(timeout),
+				ResultCh:    make(chan stmtctx.StatsLoadResult, 1),
+			}
+			select {
+			case h.StatsLoad.NeededItemsCh <- task:
+				result, ok := <-task.ResultCh
+				intest.Assert(ok, "task.ResultCh cannot be closed")
+				return result, nil
+			case <-timer.C:
+				return nil, errors.New("sync load stats channel is full and timeout sending task to channel")
+			}
+		})
+		sc.StatsLoad.ResultCh = append(sc.StatsLoad.ResultCh, resultCh)
 	}
 	sc.StatsLoad.LoadStartTime = time.Now()
 	return nil
@@ -125,25 +130,34 @@ func (*Handle) SyncWaitStatsLoad(sc *stmtctx.StatementContext) error {
 	metrics.SyncLoadCounter.Inc()
 	timer := time.NewTimer(sc.StatsLoad.Timeout)
 	defer timer.Stop()
-	for {
+	for _, resultCh := range sc.StatsLoad.ResultCh {
 		select {
-		case result, ok := <-sc.StatsLoad.ResultCh:
+		case result, ok := <-resultCh:
 			if !ok {
 				return errors.New("sync load stats channel closed unexpectedly")
 			}
-			if result.HasError() {
-				errorMsgs = append(errorMsgs, result.ErrorMsg())
-			}
-			delete(resultCheckMap, result.Item)
-			if len(resultCheckMap) == 0 {
-				metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds()))
-				return nil
+			// this error comes from SendLoadRequests, which builds the task and sends it to the worker;
+			// it is not a stats loading error
+			if result.Err != nil {
+				errorMsgs = append(errorMsgs, result.Err.Error())
+			} else {
+				val := result.Val.(stmtctx.StatsLoadResult)
+				// this error is the stats loading error itself
+				if val.HasError() {
+					errorMsgs = append(errorMsgs, val.ErrorMsg())
+				}
+				delete(resultCheckMap, val.Item)
 			}
 		case <-timer.C:
 			metrics.SyncLoadTimeoutCounter.Inc()
 			return errors.New("sync load stats timeout")
 		}
 	}
+	if len(resultCheckMap) == 0 {
+		metrics.SyncLoadHistogram.Observe(float64(time.Since(sc.StatsLoad.LoadStartTime).Milliseconds()))
+		return nil
+	}
+	return nil
 }
 
 // removeHistLoadedColumns removed having-hist columns based on neededColumns and statsCache.
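The two hunks above split the responsibilities: SendLoadRequests fans out one DoChan call per needed item and records the resulting channels, while SyncWaitStatsLoad drains them under a single shared timer. A condensed sketch of that fan-out/drain shape (hypothetical keys and loader body, not the TiDB implementation; the real fn enqueues a NeededItemTask and blocks on its per-task ResultCh):

package main

import (
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/singleflight"
)

var group singleflight.Group

func loadAll(keys []string, timeout time.Duration) error {
	chs := make([]<-chan singleflight.Result, 0, len(keys))
	for _, key := range keys {
		key := key // capture per iteration, as the patch does with localItem
		chs = append(chs, group.DoChan(key, func() (any, error) {
			time.Sleep(10 * time.Millisecond) // stand-in for the worker round trip
			return "stats:" + key, nil
		}))
	}
	// One timer bounds the whole wait, mirroring sc.StatsLoad.Timeout.
	timer := time.NewTimer(timeout)
	defer timer.Stop()
	for _, ch := range chs {
		select {
		case r := <-ch:
			if r.Err != nil {
				return r.Err // submission failed (e.g. the worker queue was full)
			}
			fmt.Println("loaded", r.Val)
		case <-timer.C:
			return errors.New("sync load stats timeout")
		}
	}
	return nil
}

func main() {
	if err := loadAll([]string{"t1/col1", "t1/col2"}, time.Second); err != nil {
		fmt.Println(err)
	}
}

Because the timer is shared across iterations, the total wait is bounded by one timeout rather than timeout × items, which preserves the statement-level timeout semantics of the old single-channel loop.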
@@ -230,33 +244,17 @@ func (h *Handle) HandleOneTask(sctx sessionctx.Context, lastTask *NeededItemTask
 		task = lastTask
 	}
 	result := stmtctx.StatsLoadResult{Item: task.TableItemID}
-	resultChan := h.StatsLoad.Singleflight.DoChan(task.TableItemID.Key(), func() (any, error) {
-		err := h.handleOneItemTask(task)
-		return nil, err
-	})
-	timeout := time.Until(task.ToTimeout)
-	select {
-	case sr := <-resultChan:
-		// sr.Val is always nil.
-		if sr.Err == nil {
-			task.ResultCh <- result
-			return nil, nil
-		}
-		if !isVaildForRetry(task) {
-			result.Error = sr.Err
-			task.ResultCh <- result
-			return nil, nil
-		}
-		return task, sr.Err
-	case <-time.After(timeout):
-		if !isVaildForRetry(task) {
-			result.Error = errors.New("stats loading timeout")
-			task.ResultCh <- result
-			return nil, nil
-		}
-		task.ToTimeout.Add(time.Duration(sctx.GetSessionVars().StatsLoadSyncWait.Load()) * time.Microsecond)
-		return task, nil
+	err = h.handleOneItemTask(task)
+	if err == nil {
+		task.ResultCh <- result
+		return nil, nil
+	}
+	if !isVaildForRetry(task) {
+		result.Error = err
+		task.ResultCh <- result
+		return nil, nil
 	}
+	return task, err
 }
 
 func isVaildForRetry(task *NeededItemTask) bool {
diff --git a/pkg/statistics/handle/handle_hist_test.go b/pkg/statistics/handle/handle_hist_test.go
index 6d0d43a1d87c5..c8c78b7df0aa5 100644
--- a/pkg/statistics/handle/handle_hist_test.go
+++ b/pkg/statistics/handle/handle_hist_test.go
@@ -207,13 +207,23 @@ func TestConcurrentLoadHistWithPanicAndFail(t *testing.T) {
 	task1, err1 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, exitCh)
 	require.Error(t, err1)
 	require.NotNil(t, task1)
+	for _, resultCh := range stmtCtx1.StatsLoad.ResultCh {
+		select {
+		case <-resultCh:
+			t.Logf("stmtCtx1.ResultCh should not get anything")
+			t.FailNow()
+		default:
+		}
+	}
+	for _, resultCh := range stmtCtx2.StatsLoad.ResultCh {
+		select {
+		case <-resultCh:
+			t.Logf("stmtCtx2.ResultCh should not get anything")
+			t.FailNow()
+		default:
+		}
+	}
 	select {
-	case <-stmtCtx1.StatsLoad.ResultCh:
-		t.Logf("stmtCtx1.ResultCh should not get anything")
-		t.FailNow()
-	case <-stmtCtx2.StatsLoad.ResultCh:
-		t.Logf("stmtCtx2.ResultCh should not get anything")
-		t.FailNow()
 	case <-task1.ResultCh:
 		t.Logf("task1.ResultCh should not get anything")
 		t.FailNow()
@@ -225,16 +235,18 @@
 	require.NoError(t, err3)
 	require.Nil(t, task3)
 
-	task, err3 := h.HandleOneTask(testKit.Session().(sessionctx.Context), nil, exitCh)
-	require.NoError(t, err3)
-	require.Nil(t, task)
-
-	rs1, ok1 := <-stmtCtx1.StatsLoad.ResultCh
-	require.True(t, ok1)
-	require.Equal(t, neededColumns[0], rs1.Item)
-	rs2, ok2 := <-stmtCtx2.StatsLoad.ResultCh
-	require.True(t, ok2)
-	require.Equal(t, neededColumns[0], rs2.Item)
+	for _, resultCh := range stmtCtx1.StatsLoad.ResultCh {
+		rs1, ok1 := <-resultCh
+		require.True(t, rs1.Shared)
+		require.True(t, ok1)
+		require.Equal(t, neededColumns[0], rs1.Val.(stmtctx.StatsLoadResult).Item)
+	}
+	for _, resultCh := range stmtCtx2.StatsLoad.ResultCh {
+		rs1, ok1 := <-resultCh
+		require.True(t, rs1.Shared)
+		require.True(t, ok1)
+		require.Equal(t, neededColumns[0], rs1.Val.(stmtctx.StatsLoadResult).Item)
+	}
 
 	stat = h.GetTableStats(tableInfo)
 	hg = stat.Columns[tableInfo.Columns[2].ID].Histogram
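The updated assertions above probe each per-item channel without blocking: a select with an empty default. Sketched below as a hypothetical helper (not part of the patch). Note the probe is destructive, since receiving consumes the channel's single result; the test tolerates that because it fails immediately on any receipt.

package probe

import "golang.org/x/sync/singleflight"

// noResultYet reports whether ch has delivered nothing so far. The empty
// default branch makes the probe non-blocking, so a test can assert
// "no result yet" without hanging.
func noResultYet(ch <-chan singleflight.Result) bool {
	select {
	case <-ch:
		return false // a buffered result (or a channel close) was already ready
	default:
		return true
	}
}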
@@ -310,11 +322,11 @@
 	result, err1 := h.HandleOneTask(testKit.Session().(sessionctx.Context), task1, exitCh)
 	require.NoError(t, err1)
 	require.Nil(t, result)
-	select {
-	case <-task1.ResultCh:
-	default:
-		t.Logf("task1.ResultCh should get nothing")
-		t.FailNow()
+	for _, resultCh := range stmtCtx1.StatsLoad.ResultCh {
+		rs1, ok1 := <-resultCh
+		require.True(t, rs1.Shared)
+		require.True(t, ok1)
+		require.Error(t, rs1.Val.(stmtctx.StatsLoadResult).Error)
 	}
 	task1.Retry = 0
 	for i := 0; i < handle.RetryCount*5; i++ {
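With the inner singleflight removed from HandleOneTask, its retry contract is simpler: handleOneItemTask runs synchronously, and a non-nil returned task tells the caller to feed the same task back in, which is what TestRetry exercises against handle.RetryCount. A hedged sketch of a driver loop under that contract; every name here is a stand-in, not the Handle API.

package main

import (
	"errors"
	"time"
)

// neededItemTask is a cut-down stand-in for NeededItemTask; only the retry
// counter matters for this sketch.
type neededItemTask struct{ retry int64 }

const retryCount = 3 // analogue of handle.RetryCount

var errLoadFailed = errors.New("simulated stats load failure")

// handleOneTask is a hypothetical stand-in for Handle.HandleOneTask under the
// new contract: on failure it returns the same task (non-nil) while retries
// remain, and nil once the task succeeded or was given up on.
func handleOneTask(last *neededItemTask) (*neededItemTask, error) {
	task := last
	if task == nil {
		task = &neededItemTask{} // in TiDB this would be drained from NeededItemsCh
	}
	if task.retry < retryCount {
		task.retry++
		return task, errLoadFailed // caller should retry this same task
	}
	return nil, nil // success, or failure already reported on the result channel
}

// worker mirrors the drive-loop shape: keep feeding the returned task back in.
func worker(exit <-chan struct{}) {
	var last *neededItemTask
	for {
		select {
		case <-exit:
			return
		default:
		}
		task, err := handleOneTask(last)
		last = task // nil on success; non-nil requests a retry next iteration
		if err != nil {
			time.Sleep(time.Millisecond) // brief backoff before retrying
		}
	}
}

func main() {
	exit := make(chan struct{})
	close(exit)
	worker(exit) // returns immediately here; real workers run until shutdown
}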