diff --git a/executor/analyze.go b/executor/analyze.go
index 675cb33161453..3e111de1c7c4f 100644
--- a/executor/analyze.go
+++ b/executor/analyze.go
@@ -770,7 +770,7 @@ func (e *AnalyzeColumnsExec) buildResp(ranges []*ranger.Range) (distsql.SelectRe
 // decodeSampleDataWithVirtualColumn constructs the virtual column by evaluating from the decoded normal columns.
 // If it failed, it would return false to trigger normal decoding way without the virtual column.
 func (e AnalyzeColumnsExec) decodeSampleDataWithVirtualColumn(
-	collector *statistics.ReservoirRowSampleCollector,
+	collector statistics.RowSampleCollector,
 	fieldTps []*types.FieldType,
 	virtualColIdx []int,
 	schema *expression.Schema,
@@ -779,9 +779,9 @@ func (e AnalyzeColumnsExec) decodeSampleDataWithVirtualColumn(
 	for _, col := range e.schemaForVirtualColEval.Columns {
 		totFts = append(totFts, col.RetType)
 	}
-	chk := chunk.NewChunkWithCapacity(totFts, len(collector.Samples))
+	chk := chunk.NewChunkWithCapacity(totFts, len(collector.Base().Samples))
 	decoder := codec.NewDecoder(chk, e.ctx.GetSessionVars().Location())
-	for _, sample := range collector.Samples {
+	for _, sample := range collector.Base().Samples {
 		for i := range sample.Columns {
 			if schema.Columns[i].VirtualExpr != nil {
 				continue
@@ -799,7 +799,7 @@ func (e AnalyzeColumnsExec) decodeSampleDataWithVirtualColumn(
 	iter := chunk.NewIterator4Chunk(chk)
 	for row, i := iter.Begin(), 0; row != iter.End(); row, i = iter.Next(), i+1 {
 		datums := row.GetDatumRow(totFts)
-		collector.Samples[i].Columns = datums
+		collector.Base().Samples[i].Columns = datums
 	}
 	return nil
 }
@@ -842,9 +842,9 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 	}()
 	l := len(e.analyzePB.ColReq.ColumnsInfo) + len(e.analyzePB.ColReq.ColumnGroups)
-	rootRowCollector := statistics.NewReservoirRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), l)
+	rootRowCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l)
 	for i := 0; i < l; i++ {
-		rootRowCollector.FMSketches = append(rootRowCollector.FMSketches, statistics.NewFMSketch(maxSketchSize))
+		rootRowCollector.Base().FMSketches = append(rootRowCollector.Base().FMSketches, statistics.NewFMSketch(maxSketchSize))
 	}
 	sc := e.ctx.GetSessionVars().StmtCtx
 	statsConcurrency, err := getBuildStatsConcurrency(e.ctx)
@@ -894,7 +894,7 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 		}
 	} else {
 		// If there's no virtual column or we meet an error during virtual column evaluation, we fall back to normal decoding.
-		for _, sample := range rootRowCollector.Samples {
+		for _, sample := range rootRowCollector.Base().Samples {
 			for i := range sample.Columns {
 				sample.Columns[i], err = tablecodec.DecodeColumnValue(sample.Columns[i].GetBytes(), &e.colsInfo[i].FieldType, sc.TimeZone)
 				if err != nil {
@@ -904,7 +904,7 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 			}
 		}
 	}
-	for _, sample := range rootRowCollector.Samples {
+	for _, sample := range rootRowCollector.Base().Samples {
 		// Calculate handle from the row data for each row. It will be used to sort the samples.
 		sample.Handle, err = e.handleCols.BuildHandleByDatums(sample.Columns)
 		if err != nil {
@@ -916,8 +916,8 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 	// The order of the samples is broken when merging samples from sub-collectors.
 	// So now we need to sort the samples according to the handle in order to calculate correlation.
-	sort.Slice(rootRowCollector.Samples, func(i, j int) bool {
-		return rootRowCollector.Samples[i].Handle.Compare(rootRowCollector.Samples[j].Handle) < 0
+	sort.Slice(rootRowCollector.Base().Samples, func(i, j int) bool {
+		return rootRowCollector.Base().Samples[i].Handle.Compare(rootRowCollector.Base().Samples[j].Handle) < 0
 	})
 	totalLen := len(e.colsInfo) + len(e.indexes)
@@ -941,7 +941,7 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 			isColumn: true,
 			slicePos: i,
 		}
-		fmSketches = append(fmSketches, rootRowCollector.FMSketches[i])
+		fmSketches = append(fmSketches, rootRowCollector.Base().FMSketches[i])
 	}
 	indexPushedDownResult := <-idxNDVPushDownCh
@@ -950,8 +950,8 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 	}
 	for _, offset := range indexesWithVirtualColOffsets {
 		ret := indexPushedDownResult.results[e.indexes[offset].ID]
-		rootRowCollector.NullCount[colLen+offset] = ret.Count
-		rootRowCollector.FMSketches[colLen+offset] = ret.Ars[0].Fms[0]
+		rootRowCollector.Base().NullCount[colLen+offset] = ret.Count
+		rootRowCollector.Base().FMSketches[colLen+offset] = ret.Ars[0].Fms[0]
 	}
 	// build index stats
@@ -963,7 +963,7 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 			isColumn: false,
 			slicePos: colLen + i,
 		}
-		fmSketches = append(fmSketches, rootRowCollector.FMSketches[colLen+i])
+		fmSketches = append(fmSketches, rootRowCollector.Base().FMSketches[colLen+i])
 	}
 	close(buildTaskChan)
 	panicCnt := 0
@@ -983,7 +983,7 @@ func (e *AnalyzeColumnsExec) buildSamplingStats(
 	if err != nil {
 		return 0, nil, nil, nil, nil, err
 	}
-	count = rootRowCollector.Count
+	count = rootRowCollector.Base().Count
 	if needExtStats {
 		statsHandle := domain.GetDomain(e.ctx).StatsHandle()
 		extStats, err = statsHandle.BuildExtendedStats(e.TableID.GetStatisticsID(), e.colsInfo, sampleCollectors)
@@ -1160,7 +1160,7 @@ func (e *AnalyzeColumnsExec) buildSubIndexJobForSpecialIndex(indexInfos []*model
 }
 
 type samplingMergeResult struct {
-	collector *statistics.ReservoirRowSampleCollector
+	collector statistics.RowSampleCollector
 	err       error
 }
@@ -1190,9 +1190,9 @@ func (e *AnalyzeColumnsExec) subMergeWorker(resultCh chan<- *samplingMergeResult
 	failpoint.Inject("mockAnalyzeSamplingMergeWorkerPanic", func() {
 		panic("failpoint triggered")
 	})
-	retCollector := statistics.NewReservoirRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), l)
+	retCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l)
 	for i := 0; i < l; i++ {
-		retCollector.FMSketches = append(retCollector.FMSketches, statistics.NewFMSketch(maxSketchSize))
+		retCollector.Base().FMSketches = append(retCollector.Base().FMSketches, statistics.NewFMSketch(maxSketchSize))
 	}
 	for {
 		data, ok := <-taskCh
@@ -1205,11 +1205,9 @@ func (e *AnalyzeColumnsExec) subMergeWorker(resultCh chan<- *samplingMergeResult
 			resultCh <- &samplingMergeResult{err: err}
 			return
 		}
-		subCollector := &statistics.ReservoirRowSampleCollector{
-			MaxSampleSize: int(e.analyzePB.ColReq.SampleSize),
-		}
-		subCollector.FromProto(colResp.RowCollector)
-		e.job.Update(subCollector.Count)
+		subCollector := statistics.NewRowSampleCollector(int(e.analyzePB.ColReq.SampleSize), e.analyzePB.ColReq.GetSampleRate(), l)
+		subCollector.Base().FromProto(colResp.RowCollector)
+		e.job.Update(subCollector.Base().Count)
 		retCollector.MergeCollector(subCollector)
 	}
 	resultCh <- &samplingMergeResult{collector: retCollector}
@@ -1217,7 +1215,7 @@ func (e *AnalyzeColumnsExec) subMergeWorker(resultCh chan<- *samplingMergeResult
 
 type samplingBuildTask struct {
 	id               int64
-	rootRowCollector *statistics.ReservoirRowSampleCollector
+	rootRowCollector statistics.RowSampleCollector
 	tp               *types.FieldType
 	isColumn         bool
 	slicePos         int
@@ -1256,8 +1254,8 @@ workLoop:
 				topns[task.slicePos] = nil
 				continue
 			}
-			sampleItems := make([]*statistics.SampleItem, 0, task.rootRowCollector.MaxSampleSize)
-			for j, row := range task.rootRowCollector.Samples {
+			sampleItems := make([]*statistics.SampleItem, 0, task.rootRowCollector.Base().Samples.Len())
+			for j, row := range task.rootRowCollector.Base().Samples {
 				if row.Columns[task.slicePos].IsNull() {
 					continue
 				}
@@ -1276,17 +1274,17 @@ workLoop:
 			}
 			collector = &statistics.SampleCollector{
 				Samples:   sampleItems,
-				NullCount: task.rootRowCollector.NullCount[task.slicePos],
-				Count:     task.rootRowCollector.Count - task.rootRowCollector.NullCount[task.slicePos],
-				FMSketch:  task.rootRowCollector.FMSketches[task.slicePos],
-				TotalSize: task.rootRowCollector.TotalSizes[task.slicePos],
+				NullCount: task.rootRowCollector.Base().NullCount[task.slicePos],
+				Count:     task.rootRowCollector.Base().Count - task.rootRowCollector.Base().NullCount[task.slicePos],
+				FMSketch:  task.rootRowCollector.Base().FMSketches[task.slicePos],
+				TotalSize: task.rootRowCollector.Base().TotalSizes[task.slicePos],
 			}
 		} else {
 			var tmpDatum types.Datum
 			var err error
 			idx := e.indexes[task.slicePos-colLen]
-			sampleItems := make([]*statistics.SampleItem, 0, task.rootRowCollector.MaxSampleSize)
-			for _, row := range task.rootRowCollector.Samples {
+			sampleItems := make([]*statistics.SampleItem, 0, task.rootRowCollector.Base().Samples.Len())
+			for _, row := range task.rootRowCollector.Base().Samples {
 				if len(idx.Columns) == 1 && row.Columns[idx.Columns[0].Offset].IsNull() {
 					continue
 				}
@@ -1315,10 +1313,10 @@ workLoop:
 			}
 			collector = &statistics.SampleCollector{
 				Samples:   sampleItems,
-				NullCount: task.rootRowCollector.NullCount[task.slicePos],
-				Count:     task.rootRowCollector.Count - task.rootRowCollector.NullCount[task.slicePos],
-				FMSketch:  task.rootRowCollector.FMSketches[task.slicePos],
-				TotalSize: task.rootRowCollector.TotalSizes[task.slicePos],
+				NullCount: task.rootRowCollector.Base().NullCount[task.slicePos],
+				Count:     task.rootRowCollector.Base().Count - task.rootRowCollector.Base().NullCount[task.slicePos],
+				FMSketch:  task.rootRowCollector.Base().FMSketches[task.slicePos],
+				TotalSize: task.rootRowCollector.Base().TotalSizes[task.slicePos],
 			}
 		}
 		if task.isColumn {
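The executor-side changes above all follow one pattern: each subMergeWorker decodes a per-region response with FromProto into a sub-collector and folds it into a root collector through the new RowSampleCollector interface. Below is a minimal, self-contained Go sketch of that fold, with a hypothetical collector type standing in for the real one (the real MergeCollector also merges FM sketches, total sizes, and the samples themselves):

```go
package main

import "fmt"

// collector is a hypothetical stand-in for statistics.RowSampleCollector;
// only the row count and per-column null counts are modeled here.
type collector struct {
	count      int64
	nullCounts []int64
}

// merge folds a sub-collector into the receiver slot by slot, the way
// subMergeWorker folds each region's response into retCollector.
func (c *collector) merge(sub *collector) {
	c.count += sub.count
	for i, n := range sub.nullCounts {
		c.nullCounts[i] += n
	}
}

func main() {
	root := &collector{nullCounts: make([]int64, 2)}
	subs := []*collector{
		{count: 10, nullCounts: []int64{1, 0}},
		{count: 20, nullCounts: []int64{0, 3}},
	}
	for _, sub := range subs {
		root.merge(sub)
	}
	fmt.Println(root.count, root.nullCounts) // 30 [1 3]
}
```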
diff --git a/executor/builder.go b/executor/builder.go
index 491f438067d6a..6e8615e424563 100644
--- a/executor/builder.go
+++ b/executor/builder.go
@@ -17,6 +17,7 @@ package executor
 import (
 	"bytes"
 	"context"
+	"math"
 	"sort"
 	"strconv"
 	"strings"
@@ -2216,9 +2217,17 @@ func (b *executorBuilder) buildAnalyzeSamplingPushdown(task plannercore.AnalyzeC
 		baseCount:     count,
 		baseModifyCnt: modifyCount,
 	}
+	sampleRate := new(float64)
+	if opts[ast.AnalyzeOptNumSamples] == 0 {
+		*sampleRate = math.Float64frombits(opts[ast.AnalyzeOptSampleRate])
+		if *sampleRate < 0 {
+			*sampleRate = b.getAdjustedSampleRate(b.ctx, task.TableID.GetStatisticsID(), task.TblInfo)
+		}
+	}
 	e.analyzePB.ColReq = &tipb.AnalyzeColumnsReq{
 		BucketSize:   int64(opts[ast.AnalyzeOptNumBuckets]),
 		SampleSize:   int64(opts[ast.AnalyzeOptNumSamples]),
+		SampleRate:   sampleRate,
 		SketchSize:   maxSketchSize,
 		ColumnsInfo:  util.ColumnsToProto(task.ColsInfo, task.TblInfo.PKIsHandle),
 		ColumnGroups: colGroups,
@@ -2233,6 +2242,29 @@ func (b *executorBuilder) buildAnalyzeSamplingPushdown(task plannercore.AnalyzeC
 	return &analyzeTask{taskType: colTask, colExec: e, job: job}
 }
 
+func (b *executorBuilder) getAdjustedSampleRate(sctx sessionctx.Context, tid int64, tblInfo *model.TableInfo) float64 {
+	statsHandle := domain.GetDomain(sctx).StatsHandle()
+	defaultRate := 0.001
+	if statsHandle == nil {
+		return defaultRate
+	}
+	var statsTbl *statistics.Table
+	if tid == tblInfo.ID {
+		statsTbl = statsHandle.GetTableStats(tblInfo)
+	} else {
+		statsTbl = statsHandle.GetPartitionStats(tblInfo, tid)
+	}
+	if statsTbl == nil {
+		return defaultRate
+	}
+	// If the row count in stats_meta is still 0, the table is small enough to scan in full.
+	if statsTbl.Count == 0 {
+		return 1
+	}
+	// We expect to scan about 100000 rows; dividing 110000 by the row count leaves some headroom.
+	return math.Min(1, 110000/float64(statsTbl.Count))
+}
+
 func (b *executorBuilder) buildAnalyzeColumnsPushdown(task plannercore.AnalyzeColumnsTask, opts map[ast.AnalyzeOptionType]uint64, autoAnalyze string, schemaForVirtualColEval *expression.Schema) *analyzeTask {
 	if task.StatsVersion == statistics.Version2 {
 		return b.buildAnalyzeSamplingPushdown(task, opts, autoAnalyze, schemaForVirtualColEval)
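For reference, the heuristic in getAdjustedSampleRate above targets a sample of roughly 100000 rows: a missing stats handle or stats table yields a conservative 0.001, a zero row count means the whole table is scanned, and otherwise the rate is 110000 divided by the row count, capped at 1. Here is a standalone sketch of just that arithmetic, with a hypothetical rowCount parameter standing in for statsTbl.Count:

```go
package main

import (
	"fmt"
	"math"
)

// adjustedSampleRate mirrors the arithmetic of getAdjustedSampleRate for a
// known row count; the stats-handle fallback paths are omitted here.
func adjustedSampleRate(rowCount int64) float64 {
	if rowCount == 0 {
		// No rows recorded in stats_meta yet: scan all rows.
		return 1
	}
	// Aim for about 100000 sampled rows (110000 leaves headroom), capped at 1.
	return math.Min(1, 110000/float64(rowCount))
}

func main() {
	for _, n := range []int64{0, 50000, 1000000, 100000000} {
		fmt.Printf("rows=%-10d rate=%.5f\n", n, adjustedSampleRate(n))
	}
	// rows=0          rate=1.00000
	// rows=50000      rate=1.00000
	// rows=1000000    rate=0.11000
	// rows=100000000  rate=0.00110
}
```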
"github.com/pingcap/check" "github.com/pingcap/tidb/domain" @@ -24,6 +25,7 @@ import ( "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/parser" + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/model" "github.com/pingcap/tidb/parser/terror" "github.com/pingcap/tidb/planner" @@ -111,6 +113,121 @@ func (s *testPlanSuite) TestDAGPlanBuilderSimpleCase(c *C) { } } +func (s *testPlanSuite) TestAnalyzeBuildSucc(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + sctx := se.(sessionctx.Context) + _, err = se.Execute(context.Background(), "create table t(a int)") + c.Assert(err, IsNil) + tests := []struct { + sql string + succ bool + statsVer int + }{ + { + sql: "analyze table t with 0.1 samplerate", + succ: true, + statsVer: 2, + }, + { + sql: "analyze table t with 0.1 samplerate", + succ: false, + statsVer: 1, + }, + { + sql: "analyze table t with 10 samplerate", + succ: false, + statsVer: 2, + }, + { + sql: "analyze table t with 0.1 samplerate, 100000 samples", + succ: false, + statsVer: 2, + }, + { + sql: "analyze table t with 0.1 samplerate, 100000 samples", + succ: false, + statsVer: 1, + }, + } + for i, tt := range tests { + comment := Commentf("The %v-th test failed", i) + _, err := se.Execute(context.Background(), fmt.Sprintf("set @@tidb_analyze_version=%v", tt.statsVer)) + c.Assert(err, IsNil) + + stmt, err := s.ParseOneStmt(tt.sql, "", "") + if tt.succ { + c.Assert(err, IsNil, comment) + } else if err != nil { + continue + } + err = core.Preprocess(se, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: s.is})) + c.Assert(err, IsNil) + _, _, err = planner.Optimize(context.Background(), sctx, stmt, s.is) + if tt.succ { + c.Assert(err, IsNil, comment) + } else { + c.Assert(err, NotNil, comment) + } + } +} + +func (s *testPlanSuite) TestAnalyzeSetRate(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + defer func() { + dom.Close() + store.Close() + }() + se, err := session.CreateSession4Test(store) + c.Assert(err, IsNil) + _, err = se.Execute(context.Background(), "use test") + c.Assert(err, IsNil) + sctx := se.(sessionctx.Context) + _, err = se.Execute(context.Background(), "create table t(a int)") + c.Assert(err, IsNil) + tests := []struct { + sql string + rate float64 + }{ + { + sql: "analyze table t", + rate: -1, + }, + { + sql: "analyze table t with 0.1 samplerate", + rate: 0.1, + }, + { + sql: "analyze table t with 10000 samples", + rate: -1, + }, + } + for i, tt := range tests { + comment := Commentf("The %v-th test failed", i) + c.Assert(err, IsNil) + + stmt, err := s.ParseOneStmt(tt.sql, "", "") + c.Assert(err, IsNil, comment) + err = core.Preprocess(se, stmt, core.WithPreprocessorReturn(&core.PreprocessorReturn{InfoSchema: s.is})) + c.Assert(err, IsNil) + p, _, err := planner.Optimize(context.Background(), sctx, stmt, s.is) + c.Assert(err, IsNil, comment) + ana := p.(*core.Analyze) + c.Assert(math.Float64frombits(ana.Opts[ast.AnalyzeOptSampleRate]), Equals, tt.rate) + } +} + func (s *testPlanSuite) TestDAGPlanBuilderJoin(c *C) { defer testleak.AfterTest(c)() store, dom, err := newStoreWithBootstrap() diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go index 
diff --git a/planner/core/planbuilder.go b/planner/core/planbuilder.go
index e2f8efda89240..391ff3f0930ed 100644
--- a/planner/core/planbuilder.go
+++ b/planner/core/planbuilder.go
@@ -2067,8 +2067,8 @@ var analyzeOptionDefaultV2 = map[ast.AnalyzeOptionType]uint64{
 	ast.AnalyzeOptNumTopN:       500,
 	ast.AnalyzeOptCMSketchWidth: 2048,
 	ast.AnalyzeOptCMSketchDepth: 5,
-	ast.AnalyzeOptNumSamples:    100000,
-	ast.AnalyzeOptSampleRate:    math.Float64bits(0),
+	ast.AnalyzeOptNumSamples:    0,
+	ast.AnalyzeOptSampleRate:    math.Float64bits(-1),
 }
 
 func handleAnalyzeOptions(opts []ast.AnalyzeOpt, statsVer int) (map[ast.AnalyzeOptionType]uint64, error) {
diff --git a/statistics/row_sampler.go b/statistics/row_sampler.go
index da1c303d23e66..16cebffa4fb76 100644
--- a/statistics/row_sampler.go
+++ b/statistics/row_sampler.go
@@ -30,6 +30,13 @@ import (
 	"github.com/pingcap/tipb/go-tipb"
 )
 
+// RowSampleCollector is the interface that a row-based sample collector must implement.
+type RowSampleCollector interface {
+	MergeCollector(collector RowSampleCollector)
+	sampleRow(row []types.Datum, rng *rand.Rand)
+	Base() *baseCollector
+}
+
 type baseCollector struct {
 	Samples    WeightedRowSampleHeap
 	NullCount  []int64
@@ -90,18 +97,30 @@ func (h *WeightedRowSampleHeap) Pop() interface{} {
 	return item
 }
 
-// ReservoirRowSampleBuilder is used to construct the ReservoirRowSampleCollector to get the samples.
-type ReservoirRowSampleBuilder struct {
+// RowSampleBuilder is used to construct the RowSampleCollector to get the samples.
+type RowSampleBuilder struct {
 	Sc              *stmtctx.StatementContext
 	RecordSet       sqlexec.RecordSet
 	ColsFieldType   []*types.FieldType
 	Collators       []collate.Collator
 	ColGroups       [][]int64
 	MaxSampleSize   int
+	SampleRate      float64
 	MaxFMSketchSize int
 	Rng             *rand.Rand
 }
 
+// NewRowSampleCollector creates a collector from the given inputs.
+func NewRowSampleCollector(maxSampleSize int, sampleRate float64, totalLen int) RowSampleCollector {
+	if maxSampleSize > 0 {
+		return NewReservoirRowSampleCollector(maxSampleSize, totalLen)
+	}
+	if sampleRate > 0 {
+		return NewBernoulliRowSampleCollector(sampleRate, totalLen)
+	}
+	return nil
+}
+
 // NewReservoirRowSampleCollector creates the new collector by the given inputs.
 func NewReservoirRowSampleCollector(maxSampleSize int, totalLen int) *ReservoirRowSampleCollector {
 	base := &baseCollector{
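NewRowSampleCollector above is the single decision point between the two sampling strategies: an explicit sample size takes precedence and yields a reservoir collector, otherwise a positive rate yields a Bernoulli collector. A compact, self-contained restatement of that rule, with strings standing in for the two concrete collector types:

```go
package main

import "fmt"

// pick mirrors the branch order of statistics.NewRowSampleCollector.
func pick(maxSampleSize int, sampleRate float64) string {
	if maxSampleSize > 0 {
		return "reservoir" // fixed-size, weight-ordered sample
	}
	if sampleRate > 0 {
		return "bernoulli" // keep each row with probability sampleRate
	}
	return "none" // not expected in practice; the builder resolves a positive size or rate first
}

func main() {
	fmt.Println(pick(10000, 0)) // reservoir: explicit sample count wins
	fmt.Println(pick(0, 0.1))   // bernoulli: rate-driven sampling
	fmt.Println(pick(0, -1))    // none
}
```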
@@ -119,19 +138,10 @@
 // Collect first builds the collector. Then maintain the null count, FM sketch and the data size for each column and
 // column group.
 // Then use the weighted reservoir sampling to collect the samples.
-func (s *ReservoirRowSampleBuilder) Collect() (*ReservoirRowSampleCollector, error) {
-	base := &baseCollector{
-		Samples:    make(WeightedRowSampleHeap, 0, s.MaxSampleSize),
-		NullCount:  make([]int64, len(s.ColsFieldType)+len(s.ColGroups)),
-		FMSketches: make([]*FMSketch, 0, len(s.ColsFieldType)+len(s.ColGroups)),
-		TotalSizes: make([]int64, len(s.ColsFieldType)+len(s.ColGroups)),
-	}
-	collector := &ReservoirRowSampleCollector{
-		baseCollector: base,
-		MaxSampleSize: s.MaxSampleSize,
-	}
+func (s *RowSampleBuilder) Collect() (RowSampleCollector, error) {
+	collector := NewRowSampleCollector(s.MaxSampleSize, s.SampleRate, len(s.ColsFieldType)+len(s.ColGroups))
 	for i := 0; i < len(s.ColsFieldType)+len(s.ColGroups); i++ {
-		collector.FMSketches = append(collector.FMSketches, NewFMSketch(s.MaxFMSketchSize))
+		collector.Base().FMSketches = append(collector.Base().FMSketches, NewFMSketch(s.MaxFMSketchSize))
 	}
 	ctx := context.TODO()
 	chk := s.RecordSet.NewChunk()
@@ -144,7 +154,7 @@ func (s *RowSampleBuilder) Collect() (RowSampleCollector, error) {
 		if chk.NumRows() == 0 {
 			return collector, nil
 		}
-		collector.Count += int64(chk.NumRows())
+		collector.Base().Count += int64(chk.NumRows())
 		for row := it.Begin(); row != it.End(); row = it.Next() {
 			datums := RowToDatums(row, s.RecordSet.Fields())
 			newCols := make([]types.Datum, len(datums))
@@ -171,25 +181,20 @@ func (s *RowSampleBuilder) Collect() (RowSampleCollector, error) {
 				datums[i].SetBytes(encodedKey)
 			}
 		}
-			err := collector.collectColumns(s.Sc, datums, sizes)
+			err := collector.Base().collectColumns(s.Sc, datums, sizes)
 			if err != nil {
 				return nil, err
 			}
-			err = collector.collectColumnGroups(s.Sc, datums, s.ColGroups, sizes)
+			err = collector.Base().collectColumnGroups(s.Sc, datums, s.ColGroups, sizes)
 			if err != nil {
 				return nil, err
 			}
-			weight := s.Rng.Int63()
-			item := &ReservoirRowSampleItem{
-				Columns: newCols,
-				Weight:  weight,
-			}
-			collector.sampleZippedRow(item)
+			collector.sampleRow(newCols, s.Rng)
 		}
 	}
 }
 
-func (s *ReservoirRowSampleCollector) collectColumns(sc *stmtctx.StatementContext, cols []types.Datum, sizes []int64) error {
+func (s *baseCollector) collectColumns(sc *stmtctx.StatementContext, cols []types.Datum, sizes []int64) error {
 	for i, col := range cols {
 		if col.IsNull() {
 			s.NullCount[i]++
@@ -205,7 +210,7 @@ func (s *baseCollector) collectColumns(sc *stmtctx.StatementContex
 	return nil
 }
 
-func (s *ReservoirRowSampleCollector) collectColumnGroups(sc *stmtctx.StatementContext, cols []types.Datum, colGroups [][]int64, sizes []int64) error {
+func (s *baseCollector) collectColumnGroups(sc *stmtctx.StatementContext, cols []types.Datum, colGroups [][]int64, sizes []int64) error {
 	colLen := len(cols)
 	datumBuffer := make([]types.Datum, 0, len(cols))
 	for i, group := range colGroups {
@@ -229,22 +234,8 @@ func (s *baseCollector) collectColumnGroups(sc *stmtctx.StatementC
 	return nil
 }
 
-func (s *ReservoirRowSampleCollector) sampleZippedRow(sample *ReservoirRowSampleItem) {
-	if len(s.Samples) < s.MaxSampleSize {
-		s.Samples = append(s.Samples, sample)
-		if len(s.Samples) == s.MaxSampleSize {
-			heap.Init(&s.Samples)
-		}
-		return
-	}
-	if s.Samples[0].Weight < sample.Weight {
-		s.Samples[0] = sample
-		heap.Fix(&s.Samples, 0)
-	}
-}
-
-// ToProto converts the collector to proto struct.
-func (s *ReservoirRowSampleCollector) ToProto() *tipb.RowSampleCollector {
+// ToProto converts the collector to pb struct.
+func (s *baseCollector) ToProto() *tipb.RowSampleCollector {
 	pbFMSketches := make([]*tipb.FMSketch, 0, len(s.FMSketches))
 	for _, sketch := range s.FMSketches {
 		pbFMSketches = append(pbFMSketches, FMSketchToProto(sketch))
 	}
@@ -259,9 +250,7 @@
 	return collector
 }
 
-// FromProto constructs the collector from the proto struct.
-func (s *ReservoirRowSampleCollector) FromProto(pbCollector *tipb.RowSampleCollector) {
-	s.baseCollector = &baseCollector{}
+func (s *baseCollector) FromProto(pbCollector *tipb.RowSampleCollector) {
 	s.Count = pbCollector.Count
 	s.NullCount = pbCollector.NullCounts
 	s.FMSketches = make([]*FMSketch, 0, len(pbCollector.FmSketch))
@@ -277,8 +266,7 @@
 			copy(b, col)
 			data = append(data, types.NewBytesDatum(b))
 		}
-		// The samples collected from regions are also organized by binary heap. So we can just copy the slice.
-		// No need to maintain the heap again.
+		// Directly copy the weight.
 		s.Samples = append(s.Samples, &ReservoirRowSampleItem{
 			Columns: data,
 			Weight:  pbSample.Weight,
@@ -286,19 +274,59 @@
 		})
 	}
 }
 
+// Base implements the RowSampleCollector interface.
+func (s *ReservoirRowSampleCollector) Base() *baseCollector {
+	return s.baseCollector
+}
+
+func (s *ReservoirRowSampleCollector) sampleZippedRow(sample *ReservoirRowSampleItem) {
+	if len(s.Samples) < s.MaxSampleSize {
+		s.Samples = append(s.Samples, sample)
+		if len(s.Samples) == s.MaxSampleSize {
+			heap.Init(&s.Samples)
+		}
+		return
+	}
+	if s.Samples[0].Weight < sample.Weight {
+		s.Samples[0] = sample
+		heap.Fix(&s.Samples, 0)
+	}
+}
+
+func (s *ReservoirRowSampleCollector) sampleRow(row []types.Datum, rng *rand.Rand) {
+	weight := rng.Int63()
+	if len(s.Samples) < s.MaxSampleSize {
+		s.Samples = append(s.Samples, &ReservoirRowSampleItem{
+			Columns: row,
+			Weight:  weight,
+		})
+		if len(s.Samples) == s.MaxSampleSize {
+			heap.Init(&s.Samples)
+		}
+		return
+	}
+	if s.Samples[0].Weight < weight {
+		s.Samples[0] = &ReservoirRowSampleItem{
+			Columns: row,
+			Weight:  weight,
+		}
+		heap.Fix(&s.Samples, 0)
+	}
+}
+
 // MergeCollector merges the collectors to a final one.
-func (s *ReservoirRowSampleCollector) MergeCollector(subCollector *ReservoirRowSampleCollector) {
-	s.Count += subCollector.Count
-	for i := range subCollector.FMSketches {
-		s.FMSketches[i].MergeFMSketch(subCollector.FMSketches[i])
+func (s *ReservoirRowSampleCollector) MergeCollector(subCollector RowSampleCollector) {
+	s.Count += subCollector.Base().Count
+	for i, fms := range subCollector.Base().FMSketches {
+		s.FMSketches[i].MergeFMSketch(fms)
 	}
-	for i := range subCollector.NullCount {
-		s.NullCount[i] += subCollector.NullCount[i]
+	for i, nullCount := range subCollector.Base().NullCount {
+		s.NullCount[i] += nullCount
 	}
-	for i := range subCollector.TotalSizes {
-		s.TotalSizes[i] += subCollector.TotalSizes[i]
+	for i, totSize := range subCollector.Base().TotalSizes {
+		s.TotalSizes[i] += totSize
 	}
-	for _, sample := range subCollector.Samples {
+	for _, sample := range subCollector.Base().Samples {
 		s.sampleZippedRow(sample)
 	}
 }
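The sampleZippedRow and sampleRow methods above implement weighted reservoir sampling: each row draws a random int63 weight, the first MaxSampleSize rows fill the reservoir, and after that a new row only displaces the current minimum-weight sample, which sits at the root of a min-heap. A self-contained sketch of the same mechanics with an illustrative row type:

```go
package main

import (
	"container/heap"
	"fmt"
	"math/rand"
)

// item pairs an illustrative row id with its random weight.
type item struct {
	row    int
	weight int64
}

// minHeap keeps the smallest weight at index 0, so it is cheap to find
// the sample that the next candidate must beat.
type minHeap []item

func (h minHeap) Len() int            { return len(h) }
func (h minHeap) Less(i, j int) bool  { return h[i].weight < h[j].weight }
func (h minHeap) Swap(i, j int)       { h[i], h[j] = h[j], h[i] }
func (h *minHeap) Push(x interface{}) { *h = append(*h, x.(item)) }
func (h *minHeap) Pop() interface{} {
	old := *h
	it := old[len(old)-1]
	*h = old[:len(old)-1]
	return it
}

func main() {
	const k = 5 // reservoir size, i.e. MaxSampleSize
	rng := rand.New(rand.NewSource(1))
	var samples minHeap
	for row := 0; row < 1000; row++ {
		w := rng.Int63()
		if len(samples) < k {
			samples = append(samples, item{row, w})
			if len(samples) == k {
				heap.Init(&samples) // heapify once the reservoir is full
			}
			continue
		}
		if samples[0].weight < w {
			samples[0] = item{row, w} // evict the minimum-weight sample
			heap.Fix(&samples, 0)
		}
	}
	fmt.Println(len(samples), "rows kept")
}
```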
@@ -326,3 +354,60 @@
 	}
 	return rows
 }
+
+// BernoulliRowSampleCollector collects the samples from the source and organizes the samples by row.
+// It will maintain the following things:
+// Row samples.
+// FM sketches (to calculate the NDV).
+// Null counts.
+// The data sizes.
+// The number of rows.
+// It uses Bernoulli sampling to collect the data.
+type BernoulliRowSampleCollector struct {
+	*baseCollector
+	SampleRate float64
+}
+
+// NewBernoulliRowSampleCollector creates the new collector by the given inputs.
+func NewBernoulliRowSampleCollector(sampleRate float64, totalLen int) *BernoulliRowSampleCollector {
+	base := &baseCollector{
+		Samples:    make(WeightedRowSampleHeap, 0, 8),
+		NullCount:  make([]int64, totalLen),
+		FMSketches: make([]*FMSketch, 0, totalLen),
+		TotalSizes: make([]int64, totalLen),
+	}
+	return &BernoulliRowSampleCollector{
+		baseCollector: base,
+		SampleRate:    sampleRate,
+	}
+}
+
+func (s *BernoulliRowSampleCollector) sampleRow(row []types.Datum, rng *rand.Rand) {
+	if rng.Float64() > s.SampleRate {
+		return
+	}
+	s.baseCollector.Samples = append(s.baseCollector.Samples, &ReservoirRowSampleItem{
+		Columns: row,
+		Weight:  0,
+	})
+}
+
+// MergeCollector merges the collectors to a final one.
+func (s *BernoulliRowSampleCollector) MergeCollector(subCollector RowSampleCollector) {
+	s.Count += subCollector.Base().Count
+	for i := range subCollector.Base().FMSketches {
+		s.FMSketches[i].MergeFMSketch(subCollector.Base().FMSketches[i])
+	}
+	for i := range subCollector.Base().NullCount {
+		s.NullCount[i] += subCollector.Base().NullCount[i]
+	}
+	for i := range subCollector.Base().TotalSizes {
+		s.TotalSizes[i] += subCollector.Base().TotalSizes[i]
+	}
+	s.baseCollector.Samples = append(s.baseCollector.Samples, subCollector.Base().Samples...)
+}
+
+// Base implements the interface RowSampleCollector.
+func (s *BernoulliRowSampleCollector) Base() *baseCollector {
+	return s.baseCollector
+}
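By contrast, BernoulliRowSampleCollector keeps each row independently with probability SampleRate, so weights go unused and MergeCollector can concatenate sub-samples instead of re-running heap maintenance. A standalone sketch of the keep/drop step:

```go
package main

import (
	"fmt"
	"math/rand"
)

func main() {
	const sampleRate = 0.1
	rng := rand.New(rand.NewSource(1))
	kept := 0
	for row := 0; row < 100000; row++ {
		// Mirror of BernoulliRowSampleCollector.sampleRow: drop the row
		// unless the coin flip lands at or under the rate.
		if rng.Float64() > sampleRate {
			continue
		}
		kept++
	}
	fmt.Printf("kept %d of 100000 rows (~%.1f%%)\n", kept, 100*float64(kept)/100000)
}
```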
diff --git a/statistics/sample_test.go b/statistics/sample_test.go
index 191ef9267e1ae..e22311ec7b298 100644
--- a/statistics/sample_test.go
+++ b/statistics/sample_test.go
@@ -67,7 +67,7 @@ func TestWeightedSampling(t *testing.T) {
 	// for x := 0; x < 800; x++ {
 	itemCnt := make([]int, rowNum)
 	for loopI := 0; loopI < loopCnt; loopI++ {
-		builder := &ReservoirRowSampleBuilder{
+		builder := &RowSampleBuilder{
 			Sc:            sc,
 			RecordSet:     rs,
 			ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)},
@@ -79,8 +79,8 @@ func TestWeightedSampling(t *testing.T) {
 		}
 		collector, err := builder.Collect()
 		require.NoError(t, err)
-		for i := 0; i < collector.MaxSampleSize; i++ {
-			a := collector.Samples[i].Columns[0].GetInt64()
+		for i := 0; i < int(sampleNum); i++ {
+			a := collector.Base().Samples[i].Columns[0].GetInt64()
 			itemCnt[a]++
 		}
 		require.Nil(t, rs.Close())
@@ -110,7 +110,7 @@ func TestDistributedWeightedSampling(t *testing.T) {
 	rootRowCollector := NewReservoirRowSampleCollector(int(sampleNum), 1)
 	rootRowCollector.FMSketches = append(rootRowCollector.FMSketches, NewFMSketch(1000))
 	for i := 0; i < batch; i++ {
-		builder := &ReservoirRowSampleBuilder{
+		builder := &RowSampleBuilder{
 			Sc:            sc,
 			RecordSet:     sets[i],
 			ColsFieldType: []*types.FieldType{types.NewFieldType(mysql.TypeLonglong)},
diff --git a/store/mockstore/unistore/cophandler/analyze.go b/store/mockstore/unistore/cophandler/analyze.go
index 176d7da7c9b53..9ef0b3fb10f3d 100644
--- a/store/mockstore/unistore/cophandler/analyze.go
+++ b/store/mockstore/unistore/cophandler/analyze.go
@@ -423,7 +423,7 @@ func handleAnalyzeFullSamplingReq(
 	}
 	colReq := analyzeReq.ColReq
 	/* #nosec G404 */
-	builder := &statistics.ReservoirRowSampleBuilder{
+	builder := &statistics.RowSampleBuilder{
 		Sc:              sc,
 		RecordSet:       e,
 		ColsFieldType:   fts,
@@ -431,6 +431,7 @@ func handleAnalyzeFullSamplingReq(
 		ColGroups:       colGroups,
 		MaxSampleSize:   int(colReq.SampleSize),
 		MaxFMSketchSize: int(colReq.SketchSize),
+		SampleRate:      colReq.GetSampleRate(),
 		Rng:             rand.New(rand.NewSource(time.Now().UnixNano())),
 	}
 	collector, err := builder.Collect()
@@ -438,7 +439,7 @@ func handleAnalyzeFullSamplingReq(
 		return nil, err
 	}
 	colResp := &tipb.AnalyzeColumnsResp{}
-	colResp.RowCollector = collector.ToProto()
+	colResp.RowCollector = collector.Base().ToProto()
 	data, err := colResp.Marshal()
 	if err != nil {
 		return nil, err
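Taken together, the change wires rate-based sampling through end to end: a statement such as analyze table t with 0.1 samplerate (accepted only under tidb_analyze_version = 2, as the planner tests above assert) bit-casts the rate into the options map; buildAnalyzeSamplingPushdown forwards either that rate or the auto-adjusted one in tipb.AnalyzeColumnsReq.SampleRate; and the region-side RowSampleBuilder then feeds a BernoulliRowSampleCollector whose per-region samples the merge workers concatenate.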