From 256f3e05321cb80ece065d1b23451c200437ba13 Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 13 Dec 2021 15:18:51 +0800 Subject: [PATCH 1/5] refactor argument --- planner/core/find_best_task.go | 17 +++-- planner/core/logical_plans.go | 9 +-- planner/core/rule_partition_processor.go | 6 +- planner/core/stats.go | 2 +- statistics/handle/ddl_serial_test.go | 14 ++-- statistics/handle/handle_test.go | 13 ++-- statistics/handle/update.go | 40 ++++++++--- statistics/handle/update_test.go | 2 +- statistics/histogram.go | 53 ++++++++------- statistics/histogram_test.go | 4 +- statistics/selectivity.go | 22 +++---- statistics/selectivity_serial_test.go | 40 +++++------ statistics/statistics_test.go | 73 ++++++++++---------- statistics/table.go | 84 +++++++++++++----------- util/ranger/types.go | 4 +- 15 files changed, 204 insertions(+), 179 deletions(-) diff --git a/planner/core/find_best_task.go b/planner/core/find_best_task.go index c23614e7c5935..59a182f9f8e7c 100644 --- a/planner/core/find_best_task.go +++ b/planner/core/find_best_task.go @@ -27,7 +27,7 @@ import ( "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/planner/property" "github.com/pingcap/tidb/planner/util" - "github.com/pingcap/tidb/sessionctx/stmtctx" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/types" tidbutil "github.com/pingcap/tidb/util" @@ -1478,7 +1478,7 @@ func getMostCorrCol4Handle(exprs []expression.Expression, histColl *statistics.T } // getColumnRangeCounts estimates row count for each range respectively. -func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) { +func getColumnRangeCounts(sctx sessionctx.Context, colID int64, ranges []*ranger.Range, histColl *statistics.HistColl, idxID int64) ([]float64, bool) { var err error var count float64 rangeCounts := make([]float64, len(ranges)) @@ -1488,13 +1488,13 @@ func getColumnRangeCounts(sc *stmtctx.StatementContext, colID int64, ranges []*r if idxHist == nil || idxHist.IsInvalid(false) { return nil, false } - count, err = histColl.GetRowCountByIndexRanges(sc, idxID, []*ranger.Range{ran}) + count, err = histColl.GetRowCountByIndexRanges(sctx, idxID, []*ranger.Range{ran}) } else { colHist, ok := histColl.Columns[colID] - if !ok || colHist.IsInvalid(sc, false) { + if !ok || colHist.IsInvalid(sctx, false) { return nil, false } - count, err = histColl.GetRowCountByColumnRanges(sc, colID, []*ranger.Range{ran}) + count, err = histColl.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{ran}) } if err != nil { return nil, false @@ -1564,7 +1564,6 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre if len(accessConds) == 0 { return 0, false, corr } - sc := ds.ctx.GetSessionVars().StmtCtx ranges, err := ranger.BuildColumnRange(accessConds, ds.ctx, col.RetType, types.UnspecifiedLength) if len(ranges) == 0 || err != nil { return 0, err == nil, corr @@ -1573,7 +1572,7 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre if !idxExists { idxID = -1 } - rangeCounts, ok := getColumnRangeCounts(sc, colID, ranges, ds.tableStats.HistColl, idxID) + rangeCounts, ok := getColumnRangeCounts(ds.ctx, colID, ranges, ds.tableStats.HistColl, idxID) if !ok { return 0, false, corr } @@ -1583,9 +1582,9 @@ func (ds *DataSource) crossEstimateRowCount(path *util.AccessPath, conds []expre } var rangeCount float64 if idxExists { - rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, idxID, convertedRanges) + rangeCount, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, idxID, convertedRanges) } else { - rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(sc, colID, convertedRanges) + rangeCount, err = ds.tableStats.HistColl.GetRowCountByColumnRanges(ds.ctx, colID, convertedRanges) } if err != nil { return 0, false, corr diff --git a/planner/core/logical_plans.go b/planner/core/logical_plans.go index 212f10d65346a..5fe0426b5c15b 100644 --- a/planner/core/logical_plans.go +++ b/planner/core/logical_plans.go @@ -726,7 +726,6 @@ func (ds *DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co if len(conds) == 0 { return nil } - sc := ds.ctx.GetSessionVars().StmtCtx if len(path.IdxCols) != 0 { res, err := ranger.DetachCondAndBuildRangeForIndex(ds.ctx, conds, path.IdxCols, path.IdxColLens) if err != nil { @@ -744,7 +743,7 @@ func (ds *DataSource) deriveCommonHandleTablePathStats(path *util.AccessPath, co path.ConstCols[i] = res.ColumnValues[i] != nil } } - path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges) + path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges) if err != nil { return err } @@ -785,7 +784,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres return ds.deriveCommonHandleTablePathStats(path, conds, isIm) } var err error - sc := ds.ctx.GetSessionVars().StmtCtx path.CountAfterAccess = float64(ds.statisticTable.Count) path.TableFilters = conds var pkCol *expression.Column @@ -848,7 +846,7 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres if err != nil { return err } - path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(sc, pkCol.ID, path.Ranges) + path.CountAfterAccess, err = ds.statisticTable.GetRowCountByIntColumnRanges(ds.ctx, pkCol.ID, path.Ranges) // If the `CountAfterAccess` is less than `stats.RowCount`, there must be some inconsistent stats info. // We prefer the `stats.RowCount` because it could use more stats info to calculate the selectivity. if path.CountAfterAccess < ds.stats.RowCount && !isIm { @@ -858,7 +856,6 @@ func (ds *DataSource) deriveTablePathStats(path *util.AccessPath, conds []expres } func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Expression) error { - sc := ds.ctx.GetSessionVars().StmtCtx path.Ranges = ranger.FullRange() path.CountAfterAccess = float64(ds.statisticTable.Count) path.IdxCols, path.IdxColLens = expression.IndexInfo2PrefixCols(ds.Columns, ds.schema.Columns, path.Index) @@ -900,7 +897,7 @@ func (ds *DataSource) fillIndexPath(path *util.AccessPath, conds []expression.Ex path.ConstCols[i] = res.ColumnValues[i] != nil } } - path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(sc, path.Index.ID, path.Ranges) + path.CountAfterAccess, err = ds.tableStats.HistColl.GetRowCountByIndexRanges(ds.ctx, path.Index.ID, path.Ranges) if err != nil { return err } diff --git a/planner/core/rule_partition_processor.go b/planner/core/rule_partition_processor.go index 7c3bbb565c69d..bb57b0fac33da 100644 --- a/planner/core/rule_partition_processor.go +++ b/planner/core/rule_partition_processor.go @@ -140,7 +140,7 @@ func (s *partitionProcessor) findUsedPartitions(ctx sessionctx.Context, tbl tabl ranges := detachedResult.Ranges used := make([]int, 0, len(ranges)) for _, r := range ranges { - if r.IsPointNullable(ctx.GetSessionVars().StmtCtx) { + if r.IsPointNullable(ctx) { if !r.HighVal[0].IsNull() { if len(r.HighVal) != len(partIdx) { used = []int{-1} @@ -473,7 +473,7 @@ func (l *listPartitionPruner) locateColumnPartitionsByCondition(cond expression. return nil, true, nil } var locations []tables.ListPartitionLocation - if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) { + if r.IsPointNullable(l.ctx) { location, err := colPrune.LocatePartition(sc, r.HighVal[0]) if types.ErrOverflow.Equal(err) { return nil, true, nil // return full-scan if over-flow @@ -555,7 +555,7 @@ func (l *listPartitionPruner) findUsedListPartitions(conds []expression.Expressi } used := make(map[int]struct{}, len(ranges)) for _, r := range ranges { - if r.IsPointNullable(l.ctx.GetSessionVars().StmtCtx) { + if r.IsPointNullable(l.ctx) { if len(r.HighVal) != len(exprCols) { return l.fullRange, nil } diff --git a/planner/core/stats.go b/planner/core/stats.go index 14a6a11a2c2d4..2e7fd14a67b8d 100644 --- a/planner/core/stats.go +++ b/planner/core/stats.go @@ -253,7 +253,7 @@ func (ds *DataSource) deriveStatsByFilter(conds expression.CNFExprs, filledPaths } stats := ds.tableStats.Scale(selectivity) if ds.ctx.GetSessionVars().OptimizerSelectivityLevel >= 1 { - stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx.GetSessionVars().StmtCtx, nodes) + stats.HistColl = stats.HistColl.NewHistCollBySelectivity(ds.ctx, nodes) } return stats } diff --git a/statistics/handle/ddl_serial_test.go b/statistics/handle/ddl_serial_test.go index 76121694338df..91a3a244cb17d 100644 --- a/statistics/handle/ddl_serial_test.go +++ b/statistics/handle/ddl_serial_test.go @@ -18,10 +18,10 @@ import ( "testing" "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/types" + "github.com/pingcap/tidb/util/mock" "github.com/stretchr/testify/require" ) @@ -51,10 +51,10 @@ func TestDDLAfterLoad(t *testing.T) { require.NoError(t, err) tableInfo = tbl.Meta() - sc := new(stmtctx.StatementContext) - count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID) + sctx := mock.NewContext() + count := statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[0].ID) require.Equal(t, 0.0, count) - count = statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID) + count = statsTbl.ColumnGreaterRowCount(sctx, types.NewDatum(recordCount+1), tableInfo.Columns[2].ID) require.Equal(t, 333, int(count)) } @@ -131,11 +131,11 @@ func TestDDLHistogram(t *testing.T) { tableInfo = tbl.Meta() statsTbl = do.StatsHandle().GetTableStats(tableInfo) require.False(t, statsTbl.Pseudo) - sc := new(stmtctx.StatementContext) - count, err := statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(0), tableInfo.Columns[3].ID) + sctx := mock.NewContext() + count, err := statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(0), tableInfo.Columns[3].ID) require.NoError(t, err) require.Equal(t, float64(2), count) - count, err = statsTbl.ColumnEqualRowCount(sc, types.NewIntDatum(1), tableInfo.Columns[3].ID) + count, err = statsTbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1), tableInfo.Columns[3].ID) require.NoError(t, err) require.Equal(t, float64(0), count) diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 23b2de4333af8..7498985580afd 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -17,6 +17,7 @@ package handle_test import ( "bytes" "fmt" + "github.com/pingcap/tidb/util/mock" "math" "strings" "testing" @@ -32,7 +33,6 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/session" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/statistics/handle" @@ -267,8 +267,7 @@ func (s *testStatsSuite) TestEmptyTable(c *C) { c.Assert(err, IsNil) tableInfo := tbl.Meta() statsTbl := do.StatsHandle().GetTableStats(tableInfo) - sc := new(stmtctx.StatementContext) - count := statsTbl.ColumnGreaterRowCount(sc, types.NewDatum(1), tableInfo.Columns[0].ID) + count := statsTbl.ColumnGreaterRowCount(mock.NewContext(), types.NewDatum(1), tableInfo.Columns[0].ID) c.Assert(count, Equals, 0.0) } @@ -285,14 +284,14 @@ func (s *testStatsSuite) TestColumnIDs(c *C) { c.Assert(err, IsNil) tableInfo := tbl.Meta() statsTbl := do.StatsHandle().GetTableStats(tableInfo) - sc := new(stmtctx.StatementContext) + sctx := mock.NewContext() ran := &ranger.Range{ LowVal: []types.Datum{types.MinNotNullDatum()}, HighVal: []types.Datum{types.NewIntDatum(2)}, LowExclude: false, HighExclude: true, } - count, err := statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran}) + count, err := statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran}) c.Assert(err, IsNil) c.Assert(count, Equals, float64(1)) @@ -307,7 +306,7 @@ func (s *testStatsSuite) TestColumnIDs(c *C) { tableInfo = tbl.Meta() statsTbl = do.StatsHandle().GetTableStats(tableInfo) // At that time, we should get c2's stats instead of c1's. - count, err = statsTbl.GetRowCountByColumnRanges(sc, tableInfo.Columns[0].ID, []*ranger.Range{ran}) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, tableInfo.Columns[0].ID, []*ranger.Range{ran}) c.Assert(err, IsNil) c.Assert(count, Equals, 0.0) } @@ -614,7 +613,7 @@ func (s *testStatsSuite) TestLoadStats(c *C) { c.Assert(hg.Len(), Equals, 0) cms = stat.Columns[tableInfo.Columns[2].ID].CMSketch c.Assert(cms, IsNil) - _, err = stat.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo.Columns[2].ID) + _, err = stat.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo.Columns[2].ID) c.Assert(err, IsNil) c.Assert(h.LoadNeededHistograms(), IsNil) stat = h.GetTableStats(tableInfo) diff --git a/statistics/handle/update.go b/statistics/handle/update.go index e8672794a2509..68ccbe05c6811 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -18,6 +18,8 @@ import ( "bytes" "context" "fmt" + "github.com/pingcap/tidb/parser/mysql" + "github.com/pingcap/tidb/sessionctx" "math" "strconv" "strings" @@ -30,7 +32,6 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser/model" - "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" @@ -1256,7 +1257,18 @@ func (h *Handle) RecalculateExpectCount(q *statistics.QueryFeedback) error { return nil } - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + se, err := h.pool.Get() + if err != nil { + return err + } + sctx := se.(sessionctx.Context) + timeZone := sctx.GetSessionVars().StmtCtx.TimeZone + defer func() { + sctx.GetSessionVars().StmtCtx.TimeZone = timeZone + h.pool.Put(se) + }() + sctx.GetSessionVars().StmtCtx.TimeZone = time.UTC + ranges, err := q.DecodeToRanges(isIndex) if err != nil { return errors.Trace(err) @@ -1264,10 +1276,10 @@ func (h *Handle) RecalculateExpectCount(q *statistics.QueryFeedback) error { expected := 0.0 if isIndex { idx := t.Indices[id] - expected, err = idx.GetRowCount(sc, nil, ranges, t.Count) + expected, err = idx.GetRowCount(sctx, nil, ranges, t.Count) } else { c := t.Columns[id] - expected, err = c.GetColumnRowCount(sc, ranges, t.Count, true) + expected, err = c.GetColumnRowCount(sctx, ranges, t.Count, true) } q.Expected = int64(expected) return err @@ -1344,7 +1356,20 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics if !ok { return nil } - sc := &stmtctx.StatementContext{TimeZone: time.UTC} + + se, err := h.pool.Get() + if err != nil { + return err + } + sctx := se.(sessionctx.Context) + sc := sctx.GetSessionVars().StmtCtx + timeZone := sc.TimeZone + defer func() { + sctx.GetSessionVars().StmtCtx.TimeZone = timeZone + h.pool.Put(se) + }() + sc.TimeZone = time.UTC + if idx.CMSketch == nil || idx.StatsVer < statistics.Version1 { return h.DumpFeedbackToKV(q) } @@ -1359,7 +1384,6 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics if rangePosition == 0 || rangePosition == len(ran.LowVal) { continue } - bytes, err := codec.EncodeKey(sc, nil, ran.LowVal[:rangePosition]...) if err != nil { logutil.BgLogger().Debug("encode keys fail", zap.Error(err)) @@ -1375,12 +1399,12 @@ func (h *Handle) DumpFeedbackForIndex(q *statistics.QueryFeedback, t *statistics rangeFB := &statistics.QueryFeedback{PhysicalID: q.PhysicalID} // prefer index stats over column stats if idx := t.IndexStartWithColumn(colName); idx != nil && idx.Histogram.Len() != 0 { - rangeCount, err = t.GetRowCountByIndexRanges(sc, idx.ID, []*ranger.Range{rang}) + rangeCount, err = t.GetRowCountByIndexRanges(sctx, idx.ID, []*ranger.Range{rang}) rangeFB.Tp, rangeFB.Hist = statistics.IndexType, &idx.Histogram } else if col := t.ColumnByName(colName); col != nil && col.Histogram.Len() != 0 { err = convertRangeType(rang, col.Tp, time.UTC) if err == nil { - rangeCount, err = t.GetRowCountByColumnRanges(sc, col.ID, []*ranger.Range{rang}) + rangeCount, err = t.GetRowCountByColumnRanges(sctx, col.ID, []*ranger.Range{rang}) rangeFB.Tp, rangeFB.Hist = statistics.ColType, &col.Histogram } } else { diff --git a/statistics/handle/update_test.go b/statistics/handle/update_test.go index fe4904e254be5..3d41f92701593 100644 --- a/statistics/handle/update_test.go +++ b/statistics/handle/update_test.go @@ -149,7 +149,7 @@ func (s *testStatsSuite) TestSingleSessionInsert(c *C) { c.Assert(stats1.Count, Equals, int64(rowCount1*2)) // Test IncreaseFactor. - count, err := stats1.ColumnEqualRowCount(testKit.Se.GetSessionVars().StmtCtx, types.NewIntDatum(1), tableInfo1.Columns[0].ID) + count, err := stats1.ColumnEqualRowCount(testKit.Se, types.NewIntDatum(1), tableInfo1.Columns[0].ID) c.Assert(err, IsNil) c.Assert(count, Equals, float64(rowCount1*2)) diff --git a/statistics/histogram.go b/statistics/histogram.go index a61f1d1405f59..67196081badaa 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -17,6 +17,7 @@ package statistics import ( "bytes" "fmt" + "github.com/pingcap/tidb/sessionctx" "math" "sort" "strings" @@ -506,7 +507,7 @@ func (hg *Histogram) BetweenRowCount(a, b types.Datum) float64 { } // BetweenRowCount estimates the row count for interval [l, r). -func (c *Column) BetweenRowCount(sc *stmtctx.StatementContext, l, r types.Datum, lowEncoded, highEncoded []byte) float64 { +func (c *Column) BetweenRowCount(sctx sessionctx.Context, l, r types.Datum, lowEncoded, highEncoded []byte) float64 { histBetweenCnt := c.Histogram.BetweenRowCount(l, r) if c.StatsVer <= Version1 { return histBetweenCnt @@ -1067,17 +1068,17 @@ var HistogramNeededColumns = neededColumnMap{cols: map[tableColumnID]struct{}{}} // IsInvalid checks if this column is invalid. If this column has histogram but not loaded yet, then we mark it // as need histogram. -func (c *Column) IsInvalid(sc *stmtctx.StatementContext, collPseudo bool) bool { +func (c *Column) IsInvalid(sctx sessionctx.Context, collPseudo bool) bool { if collPseudo && c.NotAccurate() { return true } - if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sc != nil { + if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sctx.GetSessionVars().StmtCtx != nil { HistogramNeededColumns.insert(tableColumnID{TableID: c.PhysicalID, ColumnID: c.Info.ID}) } return c.TotalRowCount() == 0 || (c.Histogram.NDV > 0 && c.notNullCount() == 0) } -func (c *Column) equalRowCount(sc *stmtctx.StatementContext, val types.Datum, encodedVal []byte, realtimeRowCount int64) (float64, error) { +func (c *Column) equalRowCount(sctx sessionctx.Context, val types.Datum, encodedVal []byte, realtimeRowCount int64) (float64, error) { if val.IsNull() { return float64(c.NullCount), nil } @@ -1090,7 +1091,7 @@ func (c *Column) equalRowCount(sc *stmtctx.StatementContext, val types.Datum, en return outOfRangeEQSelectivity(c.Histogram.NDV, realtimeRowCount, int64(c.TotalRowCount())) * c.TotalRowCount(), nil } if c.CMSketch != nil { - count, err := queryValue(sc, c.CMSketch, c.TopN, val) + count, err := queryValue(sctx.GetSessionVars().StmtCtx, c.CMSketch, c.TopN, val) return float64(count), errors.Trace(err) } histRowCount, _ := c.Histogram.equalRowCount(val, false) @@ -1123,7 +1124,8 @@ func (c *Column) equalRowCount(sc *stmtctx.StatementContext, val types.Datum, en } // GetColumnRowCount estimates the row count by a slice of Range. -func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*ranger.Range, realtimeRowCount int64, pkIsHandle bool) (float64, error) { +func (c *Column) GetColumnRowCount(sctx sessionctx.Context, ranges []*ranger.Range, realtimeRowCount int64, pkIsHandle bool) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx var rowCount float64 for _, rg := range ranges { highVal := *rg.HighVal[0].Clone() @@ -1155,7 +1157,7 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range continue } var cnt float64 - cnt, err = c.equalRowCount(sc, lowVal, lowEncoded, realtimeRowCount) + cnt, err = c.equalRowCount(sctx, lowVal, lowEncoded, realtimeRowCount) if err != nil { return 0, errors.Trace(err) } @@ -1173,7 +1175,7 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range // case 2: it's a small range && using ver1 stats if rangeVals != nil { for _, val := range rangeVals { - cnt, err := c.equalRowCount(sc, val, lowEncoded, realtimeRowCount) + cnt, err := c.equalRowCount(sctx, val, lowEncoded, realtimeRowCount) if err != nil { return 0, err } @@ -1187,12 +1189,12 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range } // case 3: it's an interval - cnt := c.BetweenRowCount(sc, lowVal, highVal, lowEncoded, highEncoded) + cnt := c.BetweenRowCount(sctx, lowVal, highVal, lowEncoded, highEncoded) // `betweenRowCount` returns count for [l, h) range, we adjust cnt for boundaries here. // Note that, `cnt` does not include null values, we need specially handle cases // where null is the lower bound. if rg.LowExclude && !lowVal.IsNull() { - lowCnt, err := c.equalRowCount(sc, lowVal, lowEncoded, realtimeRowCount) + lowCnt, err := c.equalRowCount(sctx, lowVal, lowEncoded, realtimeRowCount) if err != nil { return 0, errors.Trace(err) } @@ -1202,7 +1204,7 @@ func (c *Column) GetColumnRowCount(sc *stmtctx.StatementContext, ranges []*range cnt += float64(c.NullCount) } if !rg.HighExclude { - highCnt, err := c.equalRowCount(sc, highVal, highEncoded, realtimeRowCount) + highCnt, err := c.equalRowCount(sctx, highVal, highEncoded, realtimeRowCount) if err != nil { return 0, errors.Trace(err) } @@ -1326,7 +1328,8 @@ func (idx *Index) QueryBytes(d []byte) uint64 { // GetRowCount returns the row count of the given ranges. // It uses the modifyCount to adjust the influence of modifications on the table. -func (idx *Index) GetRowCount(sc *stmtctx.StatementContext, coll *HistColl, indexRanges []*ranger.Range, realtimeRowCount int64) (float64, error) { +func (idx *Index) GetRowCount(sctx sessionctx.Context, coll *HistColl, indexRanges []*ranger.Range, realtimeRowCount int64) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx totalCount := float64(0) isSingleCol := len(idx.Info.Columns) == 1 for _, indexRange := range indexRanges { @@ -1377,7 +1380,7 @@ func (idx *Index) GetRowCount(sc *stmtctx.StatementContext, coll *HistColl, inde // If the first column's range is point. if rangePosition := GetOrdinalOfRangeCond(sc, indexRange); rangePosition > 0 && idx.StatsVer >= Version2 && coll != nil { var expBackoffSel float64 - expBackoffSel, expBackoffSuccess, err = idx.expBackoffEstimation(sc, coll, indexRange) + expBackoffSel, expBackoffSuccess, err = idx.expBackoffEstimation(sctx, coll, indexRange) if err != nil { return 0, err } @@ -1408,7 +1411,7 @@ func (idx *Index) GetRowCount(sc *stmtctx.StatementContext, coll *HistColl, inde } // expBackoffEstimation estimate the multi-col cases following the Exponential Backoff. See comment below for details. -func (idx *Index) expBackoffEstimation(sc *stmtctx.StatementContext, coll *HistColl, indexRange *ranger.Range) (float64, bool, error) { +func (idx *Index) expBackoffEstimation(sctx sessionctx.Context, coll *HistColl, indexRange *ranger.Range) (float64, bool, error) { tmpRan := []*ranger.Range{ { LowVal: make([]types.Datum, 1), @@ -1435,9 +1438,9 @@ func (idx *Index) expBackoffEstimation(sc *stmtctx.StatementContext, coll *HistC err error ) if anotherIdxID, ok := coll.ColID2IdxID[colID]; ok && anotherIdxID != idx.ID { - count, err = coll.GetRowCountByIndexRanges(sc, anotherIdxID, tmpRan) - } else if col, ok := coll.Columns[colID]; ok && !col.IsInvalid(sc, coll.Pseudo) { - count, err = coll.GetRowCountByColumnRanges(sc, colID, tmpRan) + count, err = coll.GetRowCountByIndexRanges(sctx, anotherIdxID, tmpRan) + } else if col, ok := coll.Columns[colID]; ok && !col.IsInvalid(sctx, coll.Pseudo) { + count, err = coll.GetRowCountByColumnRanges(sctx, colID, tmpRan) } else { continue } @@ -1471,12 +1474,12 @@ func (idx *Index) expBackoffEstimation(sc *stmtctx.StatementContext, coll *HistC return singleColumnEstResults[0] * math.Sqrt(singleColumnEstResults[1]) * math.Sqrt(math.Sqrt(singleColumnEstResults[2])) * math.Sqrt(math.Sqrt(math.Sqrt(singleColumnEstResults[3]))), true, nil } -type countByRangeFunc = func(*stmtctx.StatementContext, int64, []*ranger.Range) (float64, error) +type countByRangeFunc = func(sessionctx.Context, int64, []*ranger.Range) (float64, error) // newHistogramBySelectivity fulfills the content of new histogram by the given selectivity result. // TODO: Datum is not efficient, try to avoid using it here. // Also, there're redundant calculation with Selectivity(). We need to reduce it too. -func newHistogramBySelectivity(sc *stmtctx.StatementContext, histID int64, oldHist, newHist *Histogram, ranges []*ranger.Range, cntByRangeFunc countByRangeFunc) error { +func newHistogramBySelectivity(sctx sessionctx.Context, histID int64, oldHist, newHist *Histogram, ranges []*ranger.Range, cntByRangeFunc countByRangeFunc) error { cntPerVal := int64(oldHist.AvgCountPerNotNullValue(int64(oldHist.TotalRowCount()))) var totCnt int64 for boundIdx, ranIdx, highRangeIdx := 0, 0, 0; boundIdx < oldHist.Bounds.NumRows() && ranIdx < len(ranges); boundIdx, ranIdx = boundIdx+2, highRangeIdx { @@ -1489,7 +1492,7 @@ func newHistogramBySelectivity(sc *stmtctx.StatementContext, histID int64, oldHi if ranIdx == highRangeIdx { continue } - cnt, err := cntByRangeFunc(sc, histID, ranges[ranIdx:highRangeIdx]) + cnt, err := cntByRangeFunc(sctx, histID, ranges[ranIdx:highRangeIdx]) // This should not happen. if err != nil { return err @@ -1565,7 +1568,7 @@ func (idx *Index) newIndexBySelectivity(sc *stmtctx.StatementContext, statsNode } // NewHistCollBySelectivity creates new HistColl by the given statsNodes. -func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, statsNodes []*StatsNode) *HistColl { +func (coll *HistColl) NewHistCollBySelectivity(sctx sessionctx.Context, statsNodes []*StatsNode) *HistColl { newColl := &HistColl{ Columns: make(map[int64]*Column), Indices: make(map[int64]*Index), @@ -1579,7 +1582,7 @@ func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, sta if !ok { continue } - newIdxHist, err := idxHist.newIndexBySelectivity(sc, node) + newIdxHist, err := idxHist.newIndexBySelectivity(sctx.GetSessionVars().StmtCtx, node) if err != nil { logutil.BgLogger().Warn("[Histogram-in-plan]: something wrong happened when calculating row count, "+ "failed to build histogram for index %v of table %v", @@ -1601,7 +1604,7 @@ func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, sta } newCol.Histogram = *NewHistogram(oldCol.ID, int64(float64(oldCol.Histogram.NDV)*node.Selectivity), 0, 0, oldCol.Tp, chunk.InitialCapacity, 0) var err error - splitRanges, ok := oldCol.Histogram.SplitRange(sc, node.Ranges, false) + splitRanges, ok := oldCol.Histogram.SplitRange(sctx.GetSessionVars().StmtCtx, node.Ranges, false) if !ok { logutil.BgLogger().Warn("[Histogram-in-plan]: the type of histogram and ranges mismatch") continue @@ -1619,9 +1622,9 @@ func (coll *HistColl) NewHistCollBySelectivity(sc *stmtctx.StatementContext, sta } } if oldCol.IsHandle { - err = newHistogramBySelectivity(sc, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByIntColumnRanges) + err = newHistogramBySelectivity(sctx, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByIntColumnRanges) } else { - err = newHistogramBySelectivity(sc, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByColumnRanges) + err = newHistogramBySelectivity(sctx, node.ID, &oldCol.Histogram, &newCol.Histogram, splitRanges, coll.GetRowCountByColumnRanges) } if err != nil { logutil.BgLogger().Warn("[Histogram-in-plan]: something wrong happened when calculating row count", diff --git a/statistics/histogram_test.go b/statistics/histogram_test.go index 9a0698133711e..6107a95dc453f 100644 --- a/statistics/histogram_test.go +++ b/statistics/histogram_test.go @@ -93,7 +93,7 @@ num: 54 lower_bound: kkkkk upper_bound: ooooo repeats: 0 ndv: 0 num: 60 lower_bound: oooooo upper_bound: sssss repeats: 0 ndv: 0 num: 60 lower_bound: ssssssu upper_bound: yyyyy repeats: 0 ndv: 0` - newColl := coll.NewHistCollBySelectivity(sc, []*StatsNode{node, node2}) + newColl := coll.NewHistCollBySelectivity(ctx, []*StatsNode{node, node2}) require.Equal(t, intColResult, newColl.Columns[1].String()) require.Equal(t, stringColResult, newColl.Columns[2].String()) @@ -120,7 +120,7 @@ num: 30 lower_bound: 3 upper_bound: 5 repeats: 10 ndv: 0 num: 30 lower_bound: 9 upper_bound: 11 repeats: 10 ndv: 0 num: 30 lower_bound: 12 upper_bound: 14 repeats: 10 ndv: 0` - newColl = coll.NewHistCollBySelectivity(sc, []*StatsNode{node3}) + newColl = coll.NewHistCollBySelectivity(ctx, []*StatsNode{node3}) require.Equal(t, idxResult, newColl.Indices[0].String()) } diff --git a/statistics/selectivity.go b/statistics/selectivity.go index 86321d561e954..45db1cebf9b1c 100644 --- a/statistics/selectivity.go +++ b/statistics/selectivity.go @@ -27,7 +27,6 @@ import ( "github.com/pingcap/tidb/parser/mysql" planutil "github.com/pingcap/tidb/planner/util" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" driver "github.com/pingcap/tidb/types/parser_driver" "github.com/pingcap/tidb/util/chunk" @@ -193,7 +192,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if len(exprs) > 63 || (len(coll.Columns) == 0 && len(coll.Indices) == 0) { ret = pseudoSelectivity(coll, exprs) if sc.EnableOptimizerCETrace { - CETraceExpr(sc, tableID, "Table Stats-Pseudo-Expression", expression.ComposeCNFCondition(ctx, exprs...), ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Pseudo-Expression", expression.ComposeCNFCondition(ctx, exprs...), ret*float64(coll.Count)) } return ret, nil, nil } @@ -210,7 +209,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp continue } - if colHist := coll.Columns[c.UniqueID]; colHist == nil || colHist.IsInvalid(sc, coll.Pseudo) { + if colHist := coll.Columns[c.UniqueID]; colHist == nil || colHist.IsInvalid(ctx, coll.Pseudo) { ret *= 1.0 / pseudoEqualRate continue } @@ -236,14 +235,14 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if colInfo.IsHandle { nodes[len(nodes)-1].Tp = PkType var cnt float64 - cnt, err = coll.GetRowCountByIntColumnRanges(sc, id, ranges) + cnt, err = coll.GetRowCountByIntColumnRanges(ctx, id, ranges) if err != nil { return 0, nil, errors.Trace(err) } nodes[len(nodes)-1].Selectivity = cnt / float64(coll.Count) continue } - cnt, err := coll.GetRowCountByColumnRanges(sc, id, ranges) + cnt, err := coll.GetRowCountByColumnRanges(ctx, id, ranges) if err != nil { return 0, nil, errors.Trace(err) } @@ -274,7 +273,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if err != nil { return 0, nil, errors.Trace(err) } - cnt, err := coll.GetRowCountByIndexRanges(sc, id, ranges) + cnt, err := coll.GetRowCountByIndexRanges(ctx, id, ranges) if err != nil { return 0, nil, errors.Trace(err) } @@ -314,7 +313,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp } } expr := expression.ComposeCNFCondition(ctx, curExpr...) - CETraceExpr(sc, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) } } @@ -372,7 +371,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp selectivity = selectivity + curSelectivity - selectivity*curSelectivity if sc.EnableOptimizerCETrace { // Tracing for the expression estimation results of this DNF. - CETraceExpr(sc, tableID, "Table Stats-Expression-DNF", scalarCond, selectivity*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-DNF", scalarCond, selectivity*float64(coll.Count)) } } @@ -384,7 +383,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp // Tracing for the expression estimation results after applying the DNF estimation result. curExpr = append(curExpr, remainedExprs[i]) expr := expression.ComposeCNFCondition(ctx, curExpr...) - CETraceExpr(sc, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-CNF", expr, ret*float64(coll.Count)) } } } @@ -396,7 +395,7 @@ func (coll *HistColl) Selectivity(ctx sessionctx.Context, exprs []expression.Exp if sc.EnableOptimizerCETrace { // Tracing for the expression estimation results after applying the default selectivity. totalExpr := expression.ComposeCNFCondition(ctx, remainedExprs...) - CETraceExpr(sc, tableID, "Table Stats-Expression-CNF", totalExpr, ret*float64(coll.Count)) + CETraceExpr(ctx, tableID, "Table Stats-Expression-CNF", totalExpr, ret*float64(coll.Count)) } return ret, nodes, nil } @@ -520,7 +519,7 @@ func FindPrefixOfIndexByCol(cols []*expression.Column, idxColIDs []int64, cached } // CETraceExpr appends an expression and related information into CE trace -func CETraceExpr(sc *stmtctx.StatementContext, tableID int64, tp string, expr expression.Expression, rowCount float64) { +func CETraceExpr(sctx sessionctx.Context, tableID int64, tp string, expr expression.Expression, rowCount float64) { exprStr, err := ExprToString(expr) if err != nil { logutil.BgLogger().Debug("[OptimizerTrace] Failed to trace CE of an expression", @@ -533,6 +532,7 @@ func CETraceExpr(sc *stmtctx.StatementContext, tableID int64, tp string, expr ex Expr: exprStr, RowCount: uint64(rowCount), } + sc := sctx.GetSessionVars().StmtCtx sc.OptimizerCETrace = append(sc.OptimizerCETrace, &rec) } diff --git a/statistics/selectivity_serial_test.go b/statistics/selectivity_serial_test.go index 7fdbf09c757dc..a128be3850049 100644 --- a/statistics/selectivity_serial_test.go +++ b/statistics/selectivity_serial_test.go @@ -28,13 +28,13 @@ import ( plannercore "github.com/pingcap/tidb/planner/core" "github.com/pingcap/tidb/session" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" "github.com/pingcap/tidb/statistics/handle" "github.com/pingcap/tidb/testkit" "github.com/pingcap/tidb/testkit/testdata" "github.com/pingcap/tidb/util/collate" + "github.com/pingcap/tidb/util/mock" "github.com/stretchr/testify/require" ) @@ -125,9 +125,9 @@ func TestOutOfRangeEstimation(t *testing.T) { table, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) require.NoError(t, err) statsTbl := h.GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() col := statsTbl.Columns[table.Meta().Columns[0].ID] - count, err := col.GetColumnRowCount(sc, getRange(900, 900), statsTbl.Count, false) + count, err := col.GetColumnRowCount(sctx, getRange(900, 900), statsTbl.Count, false) require.NoError(t, err) // Because the ANALYZE collect data by random sampling, so the result is not an accurate value. // so we use a range here. @@ -147,7 +147,7 @@ func TestOutOfRangeEstimation(t *testing.T) { statsSuiteData.GetTestCases(t, &input, &output) increasedTblRowCount := int64(float64(statsTbl.Count) * 1.5) for i, ran := range input { - count, err = col.GetColumnRowCount(sc, getRange(ran.Start, ran.End), increasedTblRowCount, false) + count, err = col.GetColumnRowCount(sctx, getRange(ran.Start, ran.End), increasedTblRowCount, false) require.NoError(t, err) testdata.OnRecord(func() { output[i].Start = ran.Start @@ -184,26 +184,26 @@ func TestEstimationForUnknownValues(t *testing.T) { require.NoError(t, err) statsTbl := h.GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() colID := table.Meta().Columns[0].ID - count, err := statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(30, 30)) + count, err := statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(30, 30)) require.NoError(t, err) require.Equal(t, 0.2, count) - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(9, 30)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(9, 30)) require.NoError(t, err) require.Equal(t, 7.2, count) - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(9, math.MaxInt64)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(9, math.MaxInt64)) require.NoError(t, err) require.Equal(t, 7.2, count) idxID := table.Meta().Indices[0].ID - count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(30, 30)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(30, 30)) require.NoError(t, err) require.Equal(t, 0.1, count) - count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(9, 30)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(9, 30)) require.NoError(t, err) require.Equal(t, 7.0, count) @@ -215,7 +215,7 @@ func TestEstimationForUnknownValues(t *testing.T) { statsTbl = h.GetTableStats(table.Meta()) colID = table.Meta().Columns[0].ID - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(1, 30)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(1, 30)) require.NoError(t, err) require.Equal(t, 0.0, count) @@ -228,12 +228,12 @@ func TestEstimationForUnknownValues(t *testing.T) { statsTbl = h.GetTableStats(table.Meta()) colID = table.Meta().Columns[0].ID - count, err = statsTbl.GetRowCountByColumnRanges(sc, colID, getRange(2, 2)) + count, err = statsTbl.GetRowCountByColumnRanges(sctx, colID, getRange(2, 2)) require.NoError(t, err) require.Equal(t, 0.0, count) idxID = table.Meta().Indices[0].ID - count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(2, 2)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(2, 2)) require.NoError(t, err) require.Equal(t, 0.0, count) } @@ -252,22 +252,22 @@ func TestEstimationUniqueKeyEqualConds(t *testing.T) { require.NoError(t, err) statsTbl := dom.StatsHandle().GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() idxID := table.Meta().Indices[0].ID - count, err := statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(7, 7)) + count, err := statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(7, 7)) require.NoError(t, err) require.Equal(t, 1.0, count) - count, err = statsTbl.GetRowCountByIndexRanges(sc, idxID, getRange(6, 6)) + count, err = statsTbl.GetRowCountByIndexRanges(sctx, idxID, getRange(6, 6)) require.NoError(t, err) require.Equal(t, 1.0, count) colID := table.Meta().Columns[0].ID - count, err = statsTbl.GetRowCountByIntColumnRanges(sc, colID, getRange(7, 7)) + count, err = statsTbl.GetRowCountByIntColumnRanges(sctx, colID, getRange(7, 7)) require.NoError(t, err) require.Equal(t, 1.0, count) - count, err = statsTbl.GetRowCountByIntColumnRanges(sc, colID, getRange(6, 6)) + count, err = statsTbl.GetRowCountByIntColumnRanges(sctx, colID, getRange(6, 6)) require.NoError(t, err) require.Equal(t, 1.0, count) } @@ -760,7 +760,7 @@ func TestSmallRangeEstimation(t *testing.T) { table, err := dom.InfoSchema().TableByName(model.NewCIStr("test"), model.NewCIStr("t")) require.NoError(t, err) statsTbl := h.GetTableStats(table.Meta()) - sc := &stmtctx.StatementContext{} + sctx := mock.NewContext() col := statsTbl.Columns[table.Meta().Columns[0].ID] var input []struct { @@ -775,7 +775,7 @@ func TestSmallRangeEstimation(t *testing.T) { statsSuiteData := statistics.GetStatsSuiteData() statsSuiteData.GetTestCases(t, &input, &output) for i, ran := range input { - count, err := col.GetColumnRowCount(sc, getRange(ran.Start, ran.End), statsTbl.Count, false) + count, err := col.GetColumnRowCount(sctx, getRange(ran.Start, ran.End), statsTbl.Count, false) require.NoError(t, err) testdata.OnRecord(func() { output[i].Start = ran.Start diff --git a/statistics/statistics_test.go b/statistics/statistics_test.go index 53df0e04bd7be..382a44d117f08 100644 --- a/statistics/statistics_test.go +++ b/statistics/statistics_test.go @@ -24,7 +24,6 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" - "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/chunk" "github.com/pingcap/tidb/util/collate" @@ -201,13 +200,13 @@ func TestPseudoTable(t *testing.T) { tbl := PseudoTable(ti) require.Equal(t, len(tbl.Columns), 1) require.Greater(t, tbl.Count, int64(0)) - sc := new(stmtctx.StatementContext) - count := tbl.ColumnLessRowCount(sc, types.NewIntDatum(100), colInfo.ID) + sctx := mock.NewContext() + count := tbl.ColumnLessRowCount(sctx, types.NewIntDatum(100), colInfo.ID) require.Equal(t, 3333, int(count)) - count, err := tbl.ColumnEqualRowCount(sc, types.NewIntDatum(1000), colInfo.ID) + count, err := tbl.ColumnEqualRowCount(sctx, types.NewIntDatum(1000), colInfo.ID) require.NoError(t, err) require.Equal(t, 10, int(count)) - count, _ = tbl.ColumnBetweenRowCount(sc, types.NewIntDatum(1000), types.NewIntDatum(5000), colInfo.ID) + count, _ = tbl.ColumnBetweenRowCount(sctx, types.NewIntDatum(1000), types.NewIntDatum(5000), colInfo.ID) require.Equal(t, 250, int(count)) ti.Columns = append(ti.Columns, &model.ColumnInfo{ ID: 2, @@ -261,50 +260,50 @@ func SubTestColumnRange() func(*testing.T) { LowVal: []types.Datum{{}}, HighVal: []types.Datum{types.MaxValueDatum()}, }} - count, err := tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err := tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0] = types.MinNotNullDatum() - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 99900, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].LowExclude = true ran[0].HighVal[0] = types.NewIntDatum(2000) ran[0].HighExclude = true - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowExclude = false ran[0].HighExclude = false - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowVal[0] = ran[0].HighVal[0] - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100, int(count)) tbl.Columns[0] = col ran[0].LowVal[0] = types.Datum{} ran[0].HighVal[0] = types.MaxValueDatum() - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].LowExclude = true ran[0].HighVal[0] = types.NewIntDatum(2000) ran[0].HighExclude = true - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 9998, int(count)) ran[0].LowExclude = false ran[0].HighExclude = false - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 10000, int(count)) ran[0].LowVal[0] = ran[0].HighVal[0] - count, err = tbl.GetRowCountByColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) } @@ -316,7 +315,6 @@ func SubTestIntColumnRanges() func(*testing.T) { s := createTestStatisticsSamples(t) bucketCount := int64(256) ctx := mock.NewContext() - sc := ctx.GetSessionVars().StmtCtx s.pk.(*recordSet).cursor = 0 rowCount, hg, err := buildPK(ctx, bucketCount, 0, s.pk) @@ -334,22 +332,22 @@ func SubTestIntColumnRanges() func(*testing.T) { LowVal: []types.Datum{types.NewIntDatum(math.MinInt64)}, HighVal: []types.Datum{types.NewIntDatum(math.MaxInt64)}, }} - count, err := tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err := tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(2000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1000, int(count)) ran[0].LowVal[0].SetInt64(1001) ran[0].HighVal[0].SetInt64(1999) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 998, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(1000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) @@ -357,49 +355,49 @@ func SubTestIntColumnRanges() func(*testing.T) { LowVal: []types.Datum{types.NewUintDatum(0)}, HighVal: []types.Datum{types.NewUintDatum(math.MaxUint64)}, }} - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0].SetUint64(1000) ran[0].HighVal[0].SetUint64(2000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1000, int(count)) ran[0].LowVal[0].SetUint64(1001) ran[0].HighVal[0].SetUint64(1999) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 998, int(count)) ran[0].LowVal[0].SetUint64(1000) ran[0].HighVal[0].SetUint64(1000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) tbl.Columns[0] = col ran[0].LowVal[0].SetInt64(math.MinInt64) ran[0].HighVal[0].SetInt64(math.MaxInt64) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(2000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1001, int(count)) ran[0].LowVal[0].SetInt64(1001) ran[0].HighVal[0].SetInt64(1999) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 999, int(count)) ran[0].LowVal[0].SetInt64(1000) ran[0].HighVal[0].SetInt64(1000) - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) tbl.Count *= 10 - count, err = tbl.GetRowCountByIntColumnRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIntColumnRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) } @@ -411,7 +409,6 @@ func SubTestIndexRanges() func(*testing.T) { s := createTestStatisticsSamples(t) bucketCount := int64(256) ctx := mock.NewContext() - sc := ctx.GetSessionVars().StmtCtx s.rc.(*recordSet).cursor = 0 rowCount, hg, cms, err := buildIndex(ctx, bucketCount, 0, s.rc) @@ -430,51 +427,51 @@ func SubTestIndexRanges() func(*testing.T) { LowVal: []types.Datum{types.MinNotNullDatum()}, HighVal: []types.Datum{types.MaxValueDatum()}, }} - count, err := tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err := tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 99900, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(2000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1001) ran[0].HighVal[0] = types.NewIntDatum(1999) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 2500, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(1000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100, int(count)) tbl.Indices[0] = &Index{Info: &model.IndexInfo{Columns: []*model.IndexColumn{{Offset: 0}}, Unique: true}} ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(1000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1, int(count)) tbl.Indices[0] = idx ran[0].LowVal[0] = types.MinNotNullDatum() ran[0].HighVal[0] = types.MaxValueDatum() - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 100000, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(2000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 1000, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1001) ran[0].HighVal[0] = types.NewIntDatum(1990) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 989, int(count)) ran[0].LowVal[0] = types.NewIntDatum(1000) ran[0].HighVal[0] = types.NewIntDatum(1000) - count, err = tbl.GetRowCountByIndexRanges(sc, 0, ran) + count, err = tbl.GetRowCountByIndexRanges(ctx, 0, ran) require.NoError(t, err) require.Equal(t, 0, int(count)) } diff --git a/statistics/table.go b/statistics/table.go index 358744716525b..10e08001c7528 100644 --- a/statistics/table.go +++ b/statistics/table.go @@ -285,27 +285,28 @@ func (t *Table) IsOutdated() bool { } // ColumnGreaterRowCount estimates the row count where the column greater than value. -func (t *Table) ColumnGreaterRowCount(sc *stmtctx.StatementContext, value types.Datum, colID int64) float64 { +func (t *Table) ColumnGreaterRowCount(sctx sessionctx.Context, value types.Datum, colID int64) float64 { c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoLessRate } return c.greaterRowCount(value) * c.GetIncreaseFactor(t.Count) } // ColumnLessRowCount estimates the row count where the column less than value. Note that null values are not counted. -func (t *Table) ColumnLessRowCount(sc *stmtctx.StatementContext, value types.Datum, colID int64) float64 { +func (t *Table) ColumnLessRowCount(sctx sessionctx.Context, value types.Datum, colID int64) float64 { c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoLessRate } return c.lessRowCount(value) * c.GetIncreaseFactor(t.Count) } // ColumnBetweenRowCount estimates the row count where column greater or equal to a and less than b. -func (t *Table) ColumnBetweenRowCount(sc *stmtctx.StatementContext, a, b types.Datum, colID int64) (float64, error) { +func (t *Table) ColumnBetweenRowCount(sctx sessionctx.Context, a, b types.Datum, colID int64) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoBetweenRate, nil } aEncoded, err := codec.EncodeKey(sc, nil, a) @@ -316,7 +317,7 @@ func (t *Table) ColumnBetweenRowCount(sc *stmtctx.StatementContext, a, b types.D if err != nil { return 0, err } - count := c.BetweenRowCount(sc, a, b, aEncoded, bEncoded) + count := c.BetweenRowCount(sctx, a, b, aEncoded, bEncoded) if a.IsNull() { count += float64(c.NullCount) } @@ -324,25 +325,26 @@ func (t *Table) ColumnBetweenRowCount(sc *stmtctx.StatementContext, a, b types.D } // ColumnEqualRowCount estimates the row count where the column equals to value. -func (t *Table) ColumnEqualRowCount(sc *stmtctx.StatementContext, value types.Datum, colID int64) (float64, error) { +func (t *Table) ColumnEqualRowCount(sctx sessionctx.Context, value types.Datum, colID int64) (float64, error) { c, ok := t.Columns[colID] - if !ok || c.IsInvalid(sc, t.Pseudo) { + if !ok || c.IsInvalid(sctx, t.Pseudo) { return float64(t.Count) / pseudoEqualRate, nil } - encodedVal, err := codec.EncodeKey(sc, nil, value) + encodedVal, err := codec.EncodeKey(sctx.GetSessionVars().StmtCtx, nil, value) if err != nil { return 0, err } - result, err := c.equalRowCount(sc, value, encodedVal, t.ModifyCount) + result, err := c.equalRowCount(sctx, value, encodedVal, t.ModifyCount) result *= c.GetIncreaseFactor(t.Count) return result, errors.Trace(err) } // GetRowCountByIntColumnRanges estimates the row count by a slice of IntColumnRange. -func (coll *HistColl) GetRowCountByIntColumnRanges(sc *stmtctx.StatementContext, colID int64, intRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) GetRowCountByIntColumnRanges(sctx sessionctx.Context, colID int64, intRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx var result float64 c, ok := coll.Columns[colID] - if !ok || c.IsInvalid(sc, coll.Pseudo) { + if !ok || c.IsInvalid(sctx, coll.Pseudo) { if len(intRanges) == 0 { return 0, nil } @@ -352,36 +354,38 @@ func (coll *HistColl) GetRowCountByIntColumnRanges(sc *stmtctx.StatementContext, result = getPseudoRowCountByUnsignedIntRanges(intRanges, float64(coll.Count)) } if sc.EnableOptimizerCETrace && ok { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats-Pseudo", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats-Pseudo", uint64(result)) } return result, nil } - result, err := c.GetColumnRowCount(sc, intRanges, coll.Count, true) + result, err := c.GetColumnRowCount(sctx, intRanges, coll.Count, true) if sc.EnableOptimizerCETrace { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, intRanges, "Column Stats", uint64(result)) } return result, errors.Trace(err) } // GetRowCountByColumnRanges estimates the row count by a slice of Range. -func (coll *HistColl) GetRowCountByColumnRanges(sc *stmtctx.StatementContext, colID int64, colRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) GetRowCountByColumnRanges(sctx sessionctx.Context, colID int64, colRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx c, ok := coll.Columns[colID] - if !ok || c.IsInvalid(sc, coll.Pseudo) { + if !ok || c.IsInvalid(sctx, coll.Pseudo) { result, err := GetPseudoRowCountByColumnRanges(sc, float64(coll.Count), colRanges, 0) if err == nil && sc.EnableOptimizerCETrace && ok { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats-Pseudo", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats-Pseudo", uint64(result)) } return result, err } - result, err := c.GetColumnRowCount(sc, colRanges, coll.Count, false) + result, err := c.GetColumnRowCount(sctx, colRanges, coll.Count, false) if sc.EnableOptimizerCETrace { - CETraceRange(sc, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, []string{c.Info.Name.O}, colRanges, "Column Stats", uint64(result)) } return result, errors.Trace(err) } // GetRowCountByIndexRanges estimates the row count by a slice of Range. -func (coll *HistColl) GetRowCountByIndexRanges(sc *stmtctx.StatementContext, idxID int64, indexRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) GetRowCountByIndexRanges(sctx sessionctx.Context, idxID int64, indexRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx idx, ok := coll.Indices[idxID] colNames := make([]string, 0, 8) if ok { @@ -396,28 +400,29 @@ func (coll *HistColl) GetRowCountByIndexRanges(sc *stmtctx.StatementContext, idx } result, err := getPseudoRowCountByIndexRanges(sc, indexRanges, float64(coll.Count), colsLen) if err == nil && sc.EnableOptimizerCETrace && ok { - CETraceRange(sc, coll.PhysicalID, colNames, indexRanges, "Index Stats-Pseudo", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats-Pseudo", uint64(result)) } return result, err } var result float64 var err error if idx.CMSketch != nil && idx.StatsVer == Version1 { - result, err = coll.getIndexRowCount(sc, idxID, indexRanges) + result, err = coll.getIndexRowCount(sctx, idxID, indexRanges) } else { - result, err = idx.GetRowCount(sc, coll, indexRanges, coll.Count) + result, err = idx.GetRowCount(sctx, coll, indexRanges, coll.Count) } if sc.EnableOptimizerCETrace { - CETraceRange(sc, coll.PhysicalID, colNames, indexRanges, "Index Stats", uint64(result)) + CETraceRange(sctx, coll.PhysicalID, colNames, indexRanges, "Index Stats", uint64(result)) } return result, errors.Trace(err) } // CETraceRange appends a list of ranges and related information into CE trace -func CETraceRange(sc *stmtctx.StatementContext, tableID int64, colNames []string, ranges []*ranger.Range, tp string, rowCount uint64) { +func CETraceRange(sctx sessionctx.Context, tableID int64, colNames []string, ranges []*ranger.Range, tp string, rowCount uint64) { + sc := sctx.GetSessionVars().StmtCtx allPoint := true for _, ran := range ranges { - if !ran.IsPointNullable(sc) { + if !ran.IsPointNullable(sctx) { allPoint = false break } @@ -572,7 +577,7 @@ func outOfRangeEQSelectivity(ndv, realtimeRowCount, columnRowCount int64) float6 } // crossValidationSelectivity gets the selectivity of multi-column equal conditions by cross validation. -func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, idx *Index, usedColsLen int, idxPointRange *ranger.Range) (float64, float64, error) { +func (coll *HistColl) crossValidationSelectivity(sctx sessionctx.Context, idx *Index, usedColsLen int, idxPointRange *ranger.Range) (float64, float64, error) { minRowCount := math.MaxFloat64 cols := coll.Idx2ColumnIDs[idx.ID] crossValidationSelectivity := 1.0 @@ -582,7 +587,7 @@ func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, i break } if col, ok := coll.Columns[colID]; ok { - if col.IsInvalid(sc, coll.Pseudo) { + if col.IsInvalid(sctx, coll.Pseudo) { continue } lowExclude := idxPointRange.LowExclude @@ -604,7 +609,7 @@ func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, i HighExclude: highExclude, } - rowCount, err := col.GetColumnRowCount(sc, []*ranger.Range{&rang}, coll.Count, col.IsHandle) + rowCount, err := col.GetColumnRowCount(sctx, []*ranger.Range{&rang}, coll.Count, col.IsHandle) if err != nil { return 0, 0, err } @@ -619,7 +624,7 @@ func (coll *HistColl) crossValidationSelectivity(sc *stmtctx.StatementContext, i } // getEqualCondSelectivity gets the selectivity of the equal conditions. -func (coll *HistColl) getEqualCondSelectivity(sc *stmtctx.StatementContext, idx *Index, bytes []byte, usedColsLen int, idxPointRange *ranger.Range) (float64, error) { +func (coll *HistColl) getEqualCondSelectivity(sctx sessionctx.Context, idx *Index, bytes []byte, usedColsLen int, idxPointRange *ranger.Range) (float64, error) { coverAll := len(idx.Info.Columns) == usedColsLen // In this case, the row count is at most 1. if idx.Info.Unique && coverAll { @@ -646,7 +651,7 @@ func (coll *HistColl) getEqualCondSelectivity(sc *stmtctx.StatementContext, idx return outOfRangeEQSelectivity(ndv, coll.Count, int64(idx.TotalRowCount())), nil } - minRowCount, crossValidationSelectivity, err := coll.crossValidationSelectivity(sc, idx, usedColsLen, idxPointRange) + minRowCount, crossValidationSelectivity, err := coll.crossValidationSelectivity(sctx, idx, usedColsLen, idxPointRange) if err != nil { return 0, nil } @@ -658,7 +663,8 @@ func (coll *HistColl) getEqualCondSelectivity(sc *stmtctx.StatementContext, idx return idxCount / idx.TotalRowCount(), nil } -func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64, indexRanges []*ranger.Range) (float64, error) { +func (coll *HistColl) getIndexRowCount(sctx sessionctx.Context, idxID int64, indexRanges []*ranger.Range) (float64, error) { + sc := sctx.GetSessionVars().StmtCtx idx := coll.Indices[idxID] totalCount := float64(0) for _, ran := range indexRanges { @@ -675,7 +681,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 // on single-column index, use previous way as well, because CMSketch does not contain null // values in this case. if rangePosition == 0 || isSingleColIdxNullRange(idx, ran) { - count, err := idx.GetRowCount(sc, nil, []*ranger.Range{ran}, coll.Count) + count, err := idx.GetRowCount(sctx, nil, []*ranger.Range{ran}, coll.Count) if err != nil { return 0, errors.Trace(err) } @@ -689,7 +695,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 if err != nil { return 0, errors.Trace(err) } - selectivity, err = coll.getEqualCondSelectivity(sc, idx, bytes, rangePosition, ran) + selectivity, err = coll.getEqualCondSelectivity(sctx, idx, bytes, rangePosition, ran) if err != nil { return 0, errors.Trace(err) } @@ -705,7 +711,7 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 if err != nil { return 0, err } - res, err := coll.getEqualCondSelectivity(sc, idx, bytes, rangePosition, ran) + res, err := coll.getEqualCondSelectivity(sctx, idx, bytes, rangePosition, ran) if err != nil { return 0, errors.Trace(err) } @@ -731,9 +737,9 @@ func (coll *HistColl) getIndexRowCount(sc *stmtctx.StatementContext, idxID int64 } // prefer index stats over column stats if idx, ok := coll.ColID2IdxID[colID]; ok { - count, err = coll.GetRowCountByIndexRanges(sc, idx, []*ranger.Range{&rang}) + count, err = coll.GetRowCountByIndexRanges(sctx, idx, []*ranger.Range{&rang}) } else { - count, err = coll.GetRowCountByColumnRanges(sc, colID, []*ranger.Range{&rang}) + count, err = coll.GetRowCountByColumnRanges(sctx, colID, []*ranger.Range{&rang}) } if err != nil { return 0, errors.Trace(err) diff --git a/util/ranger/types.go b/util/ranger/types.go index 2e8cc1dc6120d..f2bf561f6a3cf 100644 --- a/util/ranger/types.go +++ b/util/ranger/types.go @@ -119,8 +119,8 @@ func (ran *Range) IsPointNonNullable(sctx sessionctx.Context) bool { // IsPointNullable returns if the range is a point. // TODO: unify the parameter type with IsPointNullable and IsPoint -func (ran *Range) IsPointNullable(stmtCtx *stmtctx.StatementContext) bool { - return ran.isPoint(stmtCtx, true) +func (ran *Range) IsPointNullable(sctx sessionctx.Context) bool { + return ran.isPoint(sctx.GetSessionVars().StmtCtx, true) } // IsFullRange check if the range is full scan range From 5612d5a3f721a7f9088cd2d7f6d2b124c50cf7ef Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 13 Dec 2021 17:50:22 +0800 Subject: [PATCH 2/5] fmt --- statistics/handle/handle_test.go | 10 +++++----- statistics/handle/update.go | 4 ++-- statistics/histogram.go | 2 +- 3 files changed, 8 insertions(+), 8 deletions(-) diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index 7498985580afd..c3ab9790c078c 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -17,7 +17,6 @@ package handle_test import ( "bytes" "fmt" - "github.com/pingcap/tidb/util/mock" "math" "strings" "testing" @@ -40,6 +39,7 @@ import ( "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/util/collate" "github.com/pingcap/tidb/util/israce" + "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/ranger" "github.com/pingcap/tidb/util/testkit" "github.com/tikv/client-go/v2/oracle" @@ -1144,8 +1144,8 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustExec("analyze table tint with 2 topn, 2 buckets") tk.MustQuery("select modify_count, count from mysql.stats_meta order by table_id asc").Check(testkit.Rows( - "0 20", // global: g.count = p0.count + p1.count - "0 9", // p0 + "0 20", // global: g.count = p0.count + p1.count + "0 9", // p0 "0 11")) // p1 tk.MustQuery("show stats_topn where table_name='tint' and is_index=0").Check(testkit.Rows( @@ -1175,7 +1175,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("select distinct_count, null_count, tot_col_size from mysql.stats_histograms where is_index=0 order by table_id asc").Check( testkit.Rows("12 1 19", // global, g = p0 + p1 - "5 1 8", // p0 + "5 1 8", // p0 "7 0 11")) // p1 tk.MustQuery("show stats_buckets where is_index=1").Check(testkit.Rows( @@ -1189,7 +1189,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("select distinct_count, null_count from mysql.stats_histograms where is_index=1 order by table_id asc").Check( testkit.Rows("12 1", // global, g = p0 + p1 - "5 1", // p0 + "5 1", // p0 "7 0")) // p1 // double + (column + index with 1 column) diff --git a/statistics/handle/update.go b/statistics/handle/update.go index 68ccbe05c6811..10ae4ae2bec3d 100644 --- a/statistics/handle/update.go +++ b/statistics/handle/update.go @@ -18,8 +18,6 @@ import ( "bytes" "context" "fmt" - "github.com/pingcap/tidb/parser/mysql" - "github.com/pingcap/tidb/sessionctx" "math" "strconv" "strings" @@ -32,7 +30,9 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/parser/model" + "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/statistics" diff --git a/statistics/histogram.go b/statistics/histogram.go index 67196081badaa..22f03c6698d7e 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -17,7 +17,6 @@ package statistics import ( "bytes" "fmt" - "github.com/pingcap/tidb/sessionctx" "math" "sort" "strings" @@ -31,6 +30,7 @@ import ( "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/parser/terror" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/tablecodec" From f19419671391131867bc9c321c8380c619edd0da Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 13 Dec 2021 17:53:31 +0800 Subject: [PATCH 3/5] refmt --- statistics/handle/handle_test.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/statistics/handle/handle_test.go b/statistics/handle/handle_test.go index c3ab9790c078c..70ec989f7bca6 100644 --- a/statistics/handle/handle_test.go +++ b/statistics/handle/handle_test.go @@ -1144,8 +1144,8 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustExec("analyze table tint with 2 topn, 2 buckets") tk.MustQuery("select modify_count, count from mysql.stats_meta order by table_id asc").Check(testkit.Rows( - "0 20", // global: g.count = p0.count + p1.count - "0 9", // p0 + "0 20", // global: g.count = p0.count + p1.count + "0 9", // p0 "0 11")) // p1 tk.MustQuery("show stats_topn where table_name='tint' and is_index=0").Check(testkit.Rows( @@ -1175,7 +1175,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("select distinct_count, null_count, tot_col_size from mysql.stats_histograms where is_index=0 order by table_id asc").Check( testkit.Rows("12 1 19", // global, g = p0 + p1 - "5 1 8", // p0 + "5 1 8", // p0 "7 0 11")) // p1 tk.MustQuery("show stats_buckets where is_index=1").Check(testkit.Rows( @@ -1189,7 +1189,7 @@ func (s *testStatsSuite) TestGlobalStatsData2(c *C) { tk.MustQuery("select distinct_count, null_count from mysql.stats_histograms where is_index=1 order by table_id asc").Check( testkit.Rows("12 1", // global, g = p0 + p1 - "5 1", // p0 + "5 1", // p0 "7 0")) // p1 // double + (column + index with 1 column) From 09d585c95cbd85d193998dbace5a60af6bb612cf Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 13 Dec 2021 17:58:59 +0800 Subject: [PATCH 4/5] fix CI --- planner/util/path.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/planner/util/path.go b/planner/util/path.go index 76e6c5173793d..9a0b4207d1314 100644 --- a/planner/util/path.go +++ b/planner/util/path.go @@ -153,7 +153,7 @@ func isColEqCorColOrConstant(ctx sessionctx.Context, filter expression.Expressio func (path *AccessPath) OnlyPointRange(sctx sessionctx.Context) bool { if path.IsIntHandlePath { for _, ran := range path.Ranges { - if !ran.IsPointNullable(sctx.GetSessionVars().StmtCtx) { + if !ran.IsPointNullable(sctx) { return false } } From d9f4f274d2b8ab9753f572dc0bcff300f3a1d80b Mon Sep 17 00:00:00 2001 From: Yuanjia Zhang Date: Mon, 13 Dec 2021 18:15:51 +0800 Subject: [PATCH 5/5] fix ci --- statistics/histogram.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/statistics/histogram.go b/statistics/histogram.go index 22f03c6698d7e..5e1788da7a1ac 100644 --- a/statistics/histogram.go +++ b/statistics/histogram.go @@ -1072,7 +1072,7 @@ func (c *Column) IsInvalid(sctx sessionctx.Context, collPseudo bool) bool { if collPseudo && c.NotAccurate() { return true } - if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sctx.GetSessionVars().StmtCtx != nil { + if c.Histogram.NDV > 0 && c.notNullCount() == 0 && sctx != nil && sctx.GetSessionVars().StmtCtx != nil { HistogramNeededColumns.insert(tableColumnID{TableID: c.PhysicalID, ColumnID: c.Info.ID}) } return c.TotalRowCount() == 0 || (c.Histogram.NDV > 0 && c.notNullCount() == 0)