From fdff1e6cc2d07c3f11c223d72fb60bf7350ba3bb Mon Sep 17 00:00:00 2001
From: Yuanjia Zhang
Date: Tue, 19 Sep 2023 14:19:42 +0800
Subject: [PATCH] planner: use the session pool to execute SQLs in
 statshandler (#47065)

ref pingcap/tidb#46905
---
 statistics/handle/ddl.go                      |  38 +++++--
 statistics/handle/dump.go                     |   8 +-
 statistics/handle/handle.go                   |  76 +++++++------
 statistics/handle/handle_hist.go              |   2 +-
 statistics/handle/historical_stats_handler.go |  15 ++-
 statistics/handle/update.go                   | 106 +++++++++++++-----
 statistics/handle/updatetest/update_test.go   |   8 +-
 statistics/interact_with_storage.go           |  15 ++-
 8 files changed, 180 insertions(+), 88 deletions(-)

diff --git a/statistics/handle/ddl.go b/statistics/handle/ddl.go
index 236ec147cf291..fc69fb67ec92d 100644
--- a/statistics/handle/ddl.go
+++ b/statistics/handle/ddl.go
@@ -37,21 +37,30 @@ import (
 func (h *Handle) HandleDDLEvent(t *util.Event) error {
 	switch t.Tp {
 	case model.ActionCreateTable, model.ActionTruncateTable:
-		ids := h.getInitStateTableIDs(t.TableInfo)
+		ids, err := h.getInitStateTableIDs(t.TableInfo)
+		if err != nil {
+			return err
+		}
 		for _, id := range ids {
 			if err := h.insertTableStats2KV(t.TableInfo, id); err != nil {
 				return err
 			}
 		}
 	case model.ActionDropTable:
-		ids := h.getInitStateTableIDs(t.TableInfo)
+		ids, err := h.getInitStateTableIDs(t.TableInfo)
+		if err != nil {
+			return err
+		}
 		for _, id := range ids {
 			if err := h.resetTableStats2KVForDrop(id); err != nil {
 				return err
 			}
 		}
 	case model.ActionAddColumn, model.ActionModifyColumn:
-		ids := h.getInitStateTableIDs(t.TableInfo)
+		ids, err := h.getInitStateTableIDs(t.TableInfo)
+		if err != nil {
+			return err
+		}
 		for _, id := range ids {
 			if err := h.insertColStats2KV(id, t.ColumnInfos); err != nil {
 				return err
@@ -64,8 +73,11 @@ func (h *Handle) HandleDDLEvent(t *util.Event) error {
 			}
 		}
 	case model.ActionDropTablePartition:
-		pruneMode := h.CurrentPruneMode()
-		if pruneMode == variable.Dynamic && t.PartInfo != nil {
+		pruneMode, err := h.GetCurrentPruneMode()
+		if err != nil {
+			return err
+		}
+		if variable.PartitionPruneMode(pruneMode) == variable.Dynamic && t.PartInfo != nil {
 			if err := h.updateGlobalStats(t.TableInfo); err != nil {
 				return err
 			}
@@ -189,7 +201,7 @@ func (h *Handle) updateGlobalStats(tblInfo *model.TableInfo) error {
 			opts[ast.AnalyzeOptNumBuckets] = uint64(globalColStatsBucketNum)
 		}
 		// Generate the new column global-stats
-		newColGlobalStats, err := h.mergePartitionStats2GlobalStats(h.mu.ctx, opts, is, tblInfo, 0, nil, nil)
+		newColGlobalStats, err := h.mergePartitionStats2GlobalStats(opts, is, tblInfo, 0, nil, nil)
 		if err != nil {
 			return err
 		}
@@ -228,7 +240,7 @@ func (h *Handle) updateGlobalStats(tblInfo *model.TableInfo) error {
 		if globalIdxStatsBucketNum != 0 {
 			opts[ast.AnalyzeOptNumBuckets] = uint64(globalIdxStatsBucketNum)
 		}
-		newIndexGlobalStats, err := h.mergePartitionStats2GlobalStats(h.mu.ctx, opts, is, tblInfo, 1, []int64{idx.ID}, nil)
+		newIndexGlobalStats, err := h.mergePartitionStats2GlobalStats(opts, is, tblInfo, 1, []int64{idx.ID}, nil)
 		if err != nil {
 			return err
 		}
@@ -276,19 +288,23 @@ func (h *Handle) changeGlobalStatsID(from, to int64) (err error) {
 	return nil
 }
 
-func (h *Handle) getInitStateTableIDs(tblInfo *model.TableInfo) (ids []int64) {
+func (h *Handle) getInitStateTableIDs(tblInfo *model.TableInfo) (ids []int64, err error) {
 	pi := tblInfo.GetPartitionInfo()
 	if pi == nil {
-		return []int64{tblInfo.ID}
+		return []int64{tblInfo.ID}, nil
 	}
 	ids = make([]int64, 0, len(pi.Definitions)+1)
 	for _, def := range pi.Definitions {
 		ids = append(ids, def.ID)
 	}
-	if h.CurrentPruneMode() == variable.Dynamic {
+	pruneMode, err := h.GetCurrentPruneMode()
+	if err != nil {
+		return nil, err
+	}
+	if variable.PartitionPruneMode(pruneMode) == variable.Dynamic {
 		ids = append(ids, tblInfo.ID)
 	}
-	return ids
+	return ids, nil
 }
 
 // DDLEventCh returns ddl events channel in handle.
diff --git a/statistics/handle/dump.go b/statistics/handle/dump.go
index ba240a606757f..f27c5e7395517 100644
--- a/statistics/handle/dump.go
+++ b/statistics/handle/dump.go
@@ -203,9 +203,11 @@ func (h *Handle) DumpHistoricalStatsBySnapshot(
 
 // DumpStatsToJSONBySnapshot dumps statistic to json.
 func (h *Handle) DumpStatsToJSONBySnapshot(dbName string, tableInfo *model.TableInfo, snapshot uint64, dumpPartitionStats bool) (*JSONTable, error) {
-	h.mu.Lock()
-	isDynamicMode := variable.PartitionPruneMode(h.mu.ctx.GetSessionVars().PartitionPruneMode.Load()) == variable.Dynamic
-	h.mu.Unlock()
+	pruneMode, err := h.GetCurrentPruneMode()
+	if err != nil {
+		return nil, err
+	}
+	isDynamicMode := variable.PartitionPruneMode(pruneMode) == variable.Dynamic
 	pi := tableInfo.GetPartitionInfo()
 	if pi == nil {
 		return h.tableStatsToJSON(dbName, tableInfo, tableInfo.ID, snapshot)
diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go
index 5a6011c7c81ba..151bb2664e2dd 100644
--- a/statistics/handle/handle.go
+++ b/statistics/handle/handle.go
@@ -131,11 +131,15 @@ func (h *Handle) execRestrictedSQL(ctx context.Context, sql string, params ...in
 
 func (h *Handle) execRestrictedSQLWithStatsVer(ctx context.Context, statsVer int, procTrackID uint64, analyzeSnapshot bool, sql string, params ...interface{}) ([]chunk.Row, []*ast.ResultField, error) {
 	ctx = kv.WithInternalSourceType(ctx, kv.InternalTxnStats)
+	pruneMode, err := h.GetCurrentPruneMode()
+	if err != nil {
+		return nil, nil, err
+	}
 	return h.withRestrictedSQLExecutor(ctx, func(ctx context.Context, exec sqlexec.RestrictedSQLExecutor) ([]chunk.Row, []*ast.ResultField, error) {
 		optFuncs := []sqlexec.OptionFuncAlias{
 			execOptionForAnalyze[statsVer],
 			sqlexec.GetAnalyzeSnapshotOption(analyzeSnapshot),
-			sqlexec.GetPartitionPruneModeOption(string(h.CurrentPruneMode())),
+			sqlexec.GetPartitionPruneModeOption(pruneMode),
 			sqlexec.ExecOptionUseCurSession,
 			sqlexec.ExecOptionWithSysProcTrack(procTrackID, h.sysProcTracker.Track, h.sysProcTracker.UnTrack),
 		}
@@ -320,13 +324,6 @@ func (h *Handle) Update(is infoschema.InfoSchema) error {
 	return nil
 }
 
-// UpdateSessionVar updates the necessary session variables for the stats reader.
-func (h *Handle) UpdateSessionVar() error {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	return UpdateSCtxVarsForStats(h.mu.ctx)
-}
-
 // UpdateSCtxVarsForStats updates all necessary variables that may affect the behavior of statistics.
 func UpdateSCtxVarsForStats(sctx sessionctx.Context) error {
 	// analyzer version
@@ -401,10 +398,16 @@ func (h *Handle) loadTablePartitionStats(tableInfo *model.TableInfo, partitionDe
 }
 
 // MergePartitionStats2GlobalStatsByTableID merge the partition-level stats to global-level stats based on the tableInfo.
-func (h *Handle) mergePartitionStats2GlobalStats(sc sessionctx.Context,
-	opts map[ast.AnalyzeOptionType]uint64, is infoschema.InfoSchema, globalTableInfo *model.TableInfo,
-	isIndex int, histIDs []int64,
+func (h *Handle) mergePartitionStats2GlobalStats(opts map[ast.AnalyzeOptionType]uint64,
+	is infoschema.InfoSchema, globalTableInfo *model.TableInfo, isIndex int, histIDs []int64,
 	allPartitionStats map[int64]*statistics.Table) (globalStats *globalstats.GlobalStats, err error) {
+	se, err := h.pool.Get()
+	if err != nil {
+		return nil, err
+	}
+	defer h.pool.Put(se)
+	sc := se.(sessionctx.Context)
+
 	if err := UpdateSCtxVarsForStats(sc); err != nil {
 		return nil, err
 	}
@@ -1157,20 +1160,17 @@ func (h *Handle) statsMetaByTableIDFromStorage(tableID int64, snapshot uint64) (
 }
 
 func (h *Handle) getGlobalStatsReader(snapshot uint64) (reader *statistics.StatsReader, err error) {
-	h.mu.Lock()
-	defer func() {
-		if r := recover(); r != nil {
-			err = fmt.Errorf("getGlobalStatsReader panic %v", r)
-		}
-		if err != nil {
-			h.mu.Unlock()
-		}
-	}()
-	return statistics.GetStatsReader(snapshot, h.mu.ctx.(sqlexec.RestrictedSQLExecutor))
+	se, err := h.pool.Get()
+	if err != nil {
+		return nil, err
+	}
+	exec := se.(sqlexec.RestrictedSQLExecutor)
+	return statistics.GetStatsReader(snapshot, exec, func() {
+		h.pool.Put(se)
+	})
 }
 
-func (h *Handle) releaseGlobalStatsReader(reader *statistics.StatsReader) error {
-	defer h.mu.Unlock()
+func (*Handle) releaseGlobalStatsReader(reader *statistics.StatsReader) error {
 	return reader.Close()
 }
 
@@ -1423,9 +1423,16 @@ func (h *Handle) fillExtStatsCorrVals(item *statistics.ExtendedStatsItem, cols [
 		item.ScalarVals = 0
 		return item
 	}
-	h.mu.Lock()
-	sc := h.mu.ctx.GetSessionVars().StmtCtx
-	h.mu.Unlock()
+
+	se, seErr := h.pool.Get()
+	if seErr != nil {
+		logutil.BgLogger().Error("fail to get session", zap.String("category", "stats"), zap.Error(seErr))
+		return nil
+	}
+	defer h.pool.Put(se)
+	sctx := se.(sessionctx.Context)
+	sc := sctx.GetSessionVars().StmtCtx
+
 	var err error
 	samplesX, err = statistics.SortSampleItems(sc, samplesX)
 	if err != nil {
@@ -1523,11 +1530,6 @@ func (h *Handle) SaveExtendedStatsToStorage(tableID int64, extStats *statistics.
 	return nil
 }
 
-// CurrentPruneMode indicates whether tbl support runtime prune for table and first partition id.
-func (h *Handle) CurrentPruneMode() variable.PartitionPruneMode {
-	return variable.PartitionPruneMode(h.mu.ctx.GetSessionVars().PartitionPruneMode.Load())
-}
-
 // RefreshVars uses to pull PartitionPruneMethod vars from kv storage.
 func (h *Handle) RefreshVars() error {
 	h.mu.Lock()
@@ -1737,12 +1739,16 @@ func (h *Handle) RecordHistoricalStatsToStorage(dbName string, tableInfo *model.
 
 // CheckHistoricalStatsEnable is used to check whether TiDBEnableHistoricalStats is enabled.
 func (h *Handle) CheckHistoricalStatsEnable() (enable bool, err error) {
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	if err := UpdateSCtxVarsForStats(h.mu.ctx); err != nil {
+	se, err := h.pool.Get()
+	if err != nil {
+		return false, err
+	}
+	defer h.pool.Put(se)
+	sctx := se.(sessionctx.Context)
+	if err := UpdateSCtxVarsForStats(sctx); err != nil {
 		return false, err
 	}
-	return h.mu.ctx.GetSessionVars().EnableHistoricalStats, nil
+	return sctx.GetSessionVars().EnableHistoricalStats, nil
 }
 
 // InsertAnalyzeJob inserts analyze job into mysql.analyze_jobs and gets job ID for further updating job.
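Every helper rewritten in handle.go above follows the same borrow/return discipline: take a session with h.pool.Get(), schedule its return with a deferred h.pool.Put(se), and assert it to sessionctx.Context before use. The snippet below is a minimal sketch of that pattern only; the sessionPool interface, withPooledSession helper, and fakePool are stand-ins invented for illustration and are not part of this patch or of TiDB.

package main

import (
	"errors"
	"fmt"
)

// sessionPool mirrors just the Get/Put surface the patch relies on (h.pool.Get / h.pool.Put).
type sessionPool interface {
	Get() (any, error)
	Put(any)
}

// withPooledSession borrows a session, guarantees it is returned, and hands it to fn.
// This is the shape of every helper rewritten above.
func withPooledSession(pool sessionPool, fn func(se any) error) error {
	se, err := pool.Get()
	if err != nil {
		return err
	}
	defer pool.Put(se) // the defer is what releases the session on every early error return
	return fn(se)
}

// fakePool is a toy pool used only to make the sketch runnable.
type fakePool struct{ free []any }

func (p *fakePool) Get() (any, error) {
	if len(p.free) == 0 {
		return nil, errors.New("pool exhausted")
	}
	se := p.free[len(p.free)-1]
	p.free = p.free[:len(p.free)-1]
	return se, nil
}

func (p *fakePool) Put(se any) { p.free = append(p.free, se) }

func main() {
	pool := &fakePool{free: []any{"session-1"}}
	err := withPooledSession(pool, func(se any) error {
		fmt.Println("using", se) // the real code asserts se.(sessionctx.Context) here
		return nil
	})
	fmt.Println("err:", err, "sessions back in pool:", len(pool.free))
}

Keeping the Put in a defer, as the diff does everywhere, is the design choice that makes the mutex-free version safe: the session goes back to the pool on every return path.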
diff --git a/statistics/handle/handle_hist.go b/statistics/handle/handle_hist.go
index 4e0da9f21a6c7..f6d14f8e8fe2b 100644
--- a/statistics/handle/handle_hist.go
+++ b/statistics/handle/handle_hist.go
@@ -302,7 +302,7 @@ func (h *Handle) loadFreshStatsReader(readerCtx *StatsReaderContext, ctx sqlexec
 		}
 	}
 	for {
-		newReader, err := statistics.GetStatsReader(0, ctx)
+		newReader, err := statistics.GetStatsReader(0, ctx, nil)
 		if err == nil {
 			readerCtx.reader = newReader
 			readerCtx.createdTime = time.Now()
diff --git a/statistics/handle/historical_stats_handler.go b/statistics/handle/historical_stats_handler.go
index acd4ae1b68ab0..771533e22e21c 100644
--- a/statistics/handle/historical_stats_handler.go
+++ b/statistics/handle/historical_stats_handler.go
@@ -90,14 +90,23 @@ func (h *Handle) recordHistoricalStatsMeta(tableID int64, version uint64, source
 	if !tbl.IsInitialized() {
 		return
 	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
-	err := recordHistoricalStatsMeta(h.mu.ctx, tableID, version, source)
+	se, err := h.pool.Get()
 	if err != nil {
 		logutil.BgLogger().Error("record historical stats meta failed",
 			zap.Int64("table-id", tableID),
 			zap.Uint64("version", version),
 			zap.String("source", source),
 			zap.Error(err))
+		return
+	}
+	defer h.pool.Put(se)
+	sctx := se.(sessionctx.Context)
+	if err := recordHistoricalStatsMeta(sctx, tableID, version, source); err != nil {
+		logutil.BgLogger().Error("record historical stats meta failed",
+			zap.Int64("table-id", tableID),
+			zap.Uint64("version", version),
+			zap.String("source", source),
+			zap.Error(err))
+		return
 	}
 }
diff --git a/statistics/handle/update.go b/statistics/handle/update.go
index c94dfcac24035..b449261b8cf02 100644
--- a/statistics/handle/update.go
+++ b/statistics/handle/update.go
@@ -31,6 +31,7 @@ import (
 	"github.com/pingcap/tidb/kv"
 	"github.com/pingcap/tidb/metrics"
 	"github.com/pingcap/tidb/parser/model"
+	"github.com/pingcap/tidb/sessionctx"
 	"github.com/pingcap/tidb/sessionctx/variable"
 	"github.com/pingcap/tidb/statistics"
 	"github.com/pingcap/tidb/statistics/handle/cache"
@@ -443,11 +444,14 @@ func (h *Handle) DumpStatsDeltaToKV(mode dumpMode) error {
 	defer func() {
 		h.tableDelta.merge(deltaMap)
 	}()
-	is := func() infoschema.InfoSchema {
-		h.mu.Lock()
-		defer h.mu.Unlock()
-		return h.mu.ctx.GetDomainInfoSchema().(infoschema.InfoSchema)
-	}()
+
+	se, err := h.pool.Get()
+	if err != nil {
+		return err
+	}
+	defer h.pool.Put(se)
+	sctx := se.(sessionctx.Context)
+	is := sctx.GetDomainInfoSchema().(infoschema.InfoSchema)
 	currentTime := time.Now()
 	for id, item := range deltaMap {
 		if !h.needDumpStatsDelta(is, mode, id, item, currentTime) {
@@ -486,10 +490,15 @@ func (h *Handle) dumpTableStatCountToKV(is infoschema.InfoSchema, physicalTableI
 	if delta.Count == 0 {
 		return true, nil
 	}
-	h.mu.Lock()
-	defer h.mu.Unlock()
+
+	se, err := h.pool.Get()
+	if err != nil {
+		return false, err
+	}
+	defer h.pool.Put(se)
+	exec := se.(sqlexec.SQLExecutor)
+	sctx := se.(sessionctx.Context)
 	ctx := kv.WithInternalSourceType(context.Background(), kv.InternalTxnStats)
-	exec := h.mu.ctx.(sqlexec.SQLExecutor)
 	_, err = exec.ExecuteInternal(ctx, "begin")
 	if err != nil {
 		return false, errors.Trace(err)
@@ -498,11 +507,10 @@ func (h *Handle) dumpTableStatCountToKV(is infoschema.InfoSchema, physicalTableI
 		err = finishTransaction(ctx, exec, err)
 	}()
 
-	txn, err := h.mu.ctx.Txn(true)
+	statsVersion, err = getSessionTxnStartTS(se)
 	if err != nil {
 		return false, errors.Trace(err)
 	}
-	statsVersion = txn.StartTS()
 
 	tbl, _, _ := is.FindTableByPartitionID(physicalTableID)
 	// Check if the table and its partitions are locked.
@@ -535,12 +543,12 @@ func (h *Handle) dumpTableStatCountToKV(is infoschema.InfoSchema, physicalTableI
 			physicalTableID, tableOrPartitionLocked); err != nil {
 			return
 		}
-		affectedRows += h.mu.ctx.GetSessionVars().StmtCtx.AffectedRows()
+		affectedRows += sctx.GetSessionVars().StmtCtx.AffectedRows()
 		// If it's a partitioned table and its global-stats exists, update its count and modify_count as well.
 		if err = updateStatsMeta(ctx, exec, statsVersion, delta, tableID, isTableLocked); err != nil {
 			return
 		}
-		affectedRows += h.mu.ctx.GetSessionVars().StmtCtx.AffectedRows()
+		affectedRows += sctx.GetSessionVars().StmtCtx.AffectedRows()
 	} else {
 		// This is a non-partitioned table.
 		// Check if it's locked.
@@ -552,7 +560,7 @@ func (h *Handle) dumpTableStatCountToKV(is infoschema.InfoSchema, physicalTableI
 			physicalTableID, isTableLocked); err != nil {
 			return
 		}
-		affectedRows += h.mu.ctx.GetSessionVars().StmtCtx.AffectedRows()
+		affectedRows += sctx.GetSessionVars().StmtCtx.AffectedRows()
 	}
 
 	updated = affectedRows > 0
@@ -759,11 +767,6 @@ func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) (analyzed bool) {
 			logutil.BgLogger().Error("HandleAutoAnalyze panicked", zap.Any("error", r), zap.Stack("stack"))
 		}
 	}()
-	err := h.UpdateSessionVar()
-	if err != nil {
-		logutil.BgLogger().Error("update analyze version for auto analyze session failed", zap.String("category", "stats"), zap.Error(err))
-		return false
-	}
 	dbs := is.AllSchemaNames()
 	parameters := h.getAutoAnalyzeParameters()
 	autoAnalyzeRatio := parseAutoAnalyzeRatio(parameters[variable.TiDBAutoAnalyzeRatio])
@@ -775,10 +778,20 @@ func (h *Handle) HandleAutoAnalyze(is infoschema.InfoSchema) (analyzed bool) {
 	if !timeutil.WithinDayTimePeriod(start, end, time.Now()) {
 		return false
 	}
-	h.mu.Lock()
-	pruneMode := variable.PartitionPruneMode(h.mu.ctx.GetSessionVars().PartitionPruneMode.Load())
-	analyzeSnapshot := h.mu.ctx.GetSessionVars().EnableAnalyzeSnapshot
-	h.mu.Unlock()
+
+	se, err := h.pool.Get()
+	if err != nil {
+		logutil.BgLogger().Error("get session from session pool failed", zap.String("category", "stats"), zap.Error(err))
+		return false
+	}
+	defer h.pool.Put(se)
+	sctx := se.(sessionctx.Context)
+	if err := UpdateSCtxVarsForStats(sctx); err != nil {
+		logutil.BgLogger().Error("update session variables for stats failed", zap.String("category", "stats"), zap.Error(err))
+		return false
+	}
+	pruneMode := variable.PartitionPruneMode(sctx.GetSessionVars().PartitionPruneMode.Load())
+	analyzeSnapshot := sctx.GetSessionVars().EnableAnalyzeSnapshot
 	rd := rand.New(rand.NewSource(time.Now().UnixNano())) // #nosec G404
 	rd.Shuffle(len(dbs), func(i, j int) {
 		dbs[i], dbs[j] = dbs[j], dbs[i]
@@ -875,7 +888,11 @@ func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics
 			return false
 		}
 		logutil.BgLogger().Info("auto analyze triggered", zap.String("category", "stats"), zap.String("sql", escaped), zap.String("reason", reason))
-		tableStatsVer := h.mu.ctx.GetSessionVars().AnalyzeVersion
+		tableStatsVer, err := h.getCurrentAnalyzeVersion()
+		if err != nil {
+			logutil.BgLogger().Error("fail to get analyze version", zap.String("category", "stats"), zap.Error(err))
+			return false
+		}
 		statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer)
 		h.execAutoAnalyze(tableStatsVer, analyzeSnapshot, sql, params...)
 		return true
@@ -889,7 +906,11 @@ func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics
 			return false
 		}
 		logutil.BgLogger().Info("auto analyze for unanalyzed", zap.String("category", "stats"), zap.String("sql", escaped))
-		tableStatsVer := h.mu.ctx.GetSessionVars().AnalyzeVersion
+		tableStatsVer, err := h.getCurrentAnalyzeVersion()
+		if err != nil {
+			logutil.BgLogger().Error("fail to get analyze version", zap.String("category", "stats"), zap.Error(err))
+			return false
+		}
 		statistics.CheckAnalyzeVerOnTable(statsTbl, &tableStatsVer)
 		h.execAutoAnalyze(tableStatsVer, analyzeSnapshot, sqlWithIdx, paramsWithIdx...)
 		return true
@@ -898,10 +919,41 @@ func (h *Handle) autoAnalyzeTable(tblInfo *model.TableInfo, statsTbl *statistics
 	return false
 }
 
+func (h *Handle) getCurrentAnalyzeVersion() (int, error) {
+	se, err := h.pool.Get()
+	if err != nil {
+		return 0, err
+	}
+	defer h.pool.Put(se)
+	sctx := se.(sessionctx.Context)
+	if err := UpdateSCtxVarsForStats(sctx); err != nil {
+		return 0, err
+	}
+	return sctx.GetSessionVars().AnalyzeVersion, nil
+}
+
+// GetCurrentPruneMode returns the current partition prune mode.
+func (h *Handle) GetCurrentPruneMode() (string, error) {
+	se, err := h.pool.Get()
+	if err != nil {
+		return "", err
+	}
+	defer h.pool.Put(se)
+	sctx := se.(sessionctx.Context)
+	if err := UpdateSCtxVarsForStats(sctx); err != nil {
+		return "", err
+	}
+	return sctx.GetSessionVars().PartitionPruneMode.Load(), nil
+}
+
 func (h *Handle) autoAnalyzePartitionTableInDynamicMode(tblInfo *model.TableInfo, partitionDefs []model.PartitionDefinition, db string, ratio float64, analyzeSnapshot bool) bool {
-	h.mu.RLock()
-	tableStatsVer := h.mu.ctx.GetSessionVars().AnalyzeVersion
-	h.mu.RUnlock()
+	tableStatsVer, err := h.getCurrentAnalyzeVersion()
+	if err != nil {
+		logutil.BgLogger().Info("fail to get analyze version", zap.String("category", "stats"),
+			zap.String("table", tblInfo.Name.String()),
+			zap.Error(err))
+		return false
+	}
 	analyzePartitionBatchSize := int(variable.AutoAnalyzePartitionBatchSize.Load())
 	partitionNames := make([]interface{}, 0, len(partitionDefs))
 	for _, def := range partitionDefs {
diff --git a/statistics/handle/updatetest/update_test.go b/statistics/handle/updatetest/update_test.go
index 072a48af28cd9..ceb255abe6a13 100644
--- a/statistics/handle/updatetest/update_test.go
+++ b/statistics/handle/updatetest/update_test.go
@@ -294,7 +294,9 @@ func TestTxnWithFailure(t *testing.T) {
 func TestUpdatePartition(t *testing.T) {
 	store, dom := testkit.CreateMockStoreAndDomain(t)
 	testKit := testkit.NewTestKit(t, store)
-	testKit.MustQuery("select @@tidb_partition_prune_mode").Check(testkit.Rows(string(dom.StatsHandle().CurrentPruneMode())))
+	pruneMode, err := dom.StatsHandle().GetCurrentPruneMode()
+	require.NoError(t, err)
+	testKit.MustQuery("select @@tidb_partition_prune_mode").Check(testkit.Rows(pruneMode))
 	testKit.MustExec("use test")
 	testkit.WithPruneMode(testKit, variable.Static, func() {
 		err := dom.StatsHandle().RefreshVars()
@@ -609,8 +611,6 @@ func TestAutoAnalyzeOnChangeAnalyzeVer(t *testing.T) {
 	require.NoError(t, err)
 	require.NoError(t, h.DumpStatsDeltaToKV(handle.DumpAll))
 	is := do.InfoSchema()
-	err = h.UpdateSessionVar()
-	require.NoError(t, err)
 	require.NoError(t, h.Update(is))
 	// Auto analyze when global ver is 1.
 	h.HandleAutoAnalyze(is)
@@ -626,8 +626,6 @@ func TestAutoAnalyzeOnChangeAnalyzeVer(t *testing.T) {
 		require.Equal(t, int64(1), idx.GetStatsVer())
 	}
 	tk.MustExec("set @@global.tidb_analyze_version = 2")
-	err = h.UpdateSessionVar()
-	require.NoError(t, err)
 	tk.MustExec("insert into t values(1), (2), (3), (4)")
 	require.NoError(t, h.DumpStatsDeltaToKV(handle.DumpAll))
 	require.NoError(t, h.Update(is))
diff --git a/statistics/interact_with_storage.go b/statistics/interact_with_storage.go
index 9e9f29ab427a3..36c5b058e8187 100644
--- a/statistics/interact_with_storage.go
+++ b/statistics/interact_with_storage.go
@@ -45,18 +45,19 @@ import (
 // 2. StatsReader is not thread-safe. Different goroutines cannot call (*StatsReader).Read concurrently.
 type StatsReader struct {
 	ctx      sqlexec.RestrictedSQLExecutor
+	release  func() // a callback function to release all resources held by this reader.
 	snapshot uint64
 }
 
 // GetStatsReader returns a StatsReader.
-func GetStatsReader(snapshot uint64, exec sqlexec.RestrictedSQLExecutor) (reader *StatsReader, err error) {
+func GetStatsReader(snapshot uint64, exec sqlexec.RestrictedSQLExecutor, releaseFunc func()) (reader *StatsReader, err error) {
 	failpoint.Inject("mockGetStatsReaderFail", func(val failpoint.Value) {
 		if val.(bool) {
 			failpoint.Return(nil, errors.New("gofail genStatsReader error"))
 		}
 	})
 	if snapshot > 0 {
-		return &StatsReader{ctx: exec, snapshot: snapshot}, nil
+		return &StatsReader{ctx: exec, snapshot: snapshot, release: releaseFunc}, nil
 	}
 	defer func() {
 		if r := recover(); r != nil {
@@ -69,7 +70,7 @@ func GetStatsReader(snapshot uint64, exec sqlexec.RestrictedSQLExecutor) (reader
 	if err != nil {
 		return nil, err
 	}
-	return &StatsReader{ctx: exec}, nil
+	return &StatsReader{ctx: exec, release: releaseFunc}, nil
 }
 
 // Read is a thin wrapper reading statistics from storage by sql command.
@@ -88,6 +89,14 @@ func (sr *StatsReader) IsHistory() bool {
 
 // Close closes the StatsReader.
 func (sr *StatsReader) Close() error {
+	defer func() {
+		if sr.release != nil {
+			sr.release()
+		}
+		sr.release = nil
+		sr.ctx = nil
+	}()
+
 	if sr.IsHistory() || sr.ctx == nil {
 		return nil
 	}
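The interact_with_storage.go change gives StatsReader a release callback so that the session borrowed in getGlobalStatsReader goes back to the pool exactly when the reader is closed. Below is a minimal, self-contained sketch of that contract; the statsReader and newStatsReader names here are illustrative stand-ins, not the real TiDB types.

package main

import "fmt"

// statsReader models only the release-callback behaviour added by the patch.
type statsReader struct {
	release func() // returns the borrowed session to its pool
}

func newStatsReader(release func()) *statsReader {
	return &statsReader{release: release}
}

// Close runs the release callback once and then drops it, so a second Close is a no-op.
func (r *statsReader) Close() error {
	if r.release != nil {
		r.release()
		r.release = nil
	}
	return nil
}

func main() {
	released := 0
	r := newStatsReader(func() { released++ }) // in the patch this closure is h.pool.Put(se)
	_ = r.Close()
	_ = r.Close() // released only once
	fmt.Println("release calls:", released)
}

Nil-ing the callback inside Close, as the patch does in its deferred block, keeps repeated Close calls from double-returning the session.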