diff --git a/executor/analyze.go b/executor/analyze.go index 93ae0e3d803ff..4ed65e7542f4d 100644 --- a/executor/analyze.go +++ b/executor/analyze.go @@ -168,7 +168,8 @@ func (e *AnalyzeExec) Next(ctx context.Context, req *chunk.Chunk) error { } if needGlobalStats { for globalStatsID, info := range globalStatsMap { - globalStats, err := statsHandle.MergePartitionStats2GlobalStats(infoschema.GetInfoSchema(e.ctx), globalStatsID.tableID, info.isIndex, info.idxID) + sc := e.ctx.GetSessionVars().StmtCtx + globalStats, err := statsHandle.MergePartitionStats2GlobalStats(sc, infoschema.GetInfoSchema(e.ctx), globalStatsID.tableID, info.isIndex, info.idxID) if err != nil { return err } diff --git a/statistics/handle/handle.go b/statistics/handle/handle.go index 986da8a00f013..815508a10f35a 100644 --- a/statistics/handle/handle.go +++ b/statistics/handle/handle.go @@ -295,7 +295,7 @@ type GlobalStats struct { } // MergePartitionStats2GlobalStats merge the partition-level stats to global-level stats based on the tableID. 
-func (h *Handle) MergePartitionStats2GlobalStats(is infoschema.InfoSchema, physicalID int64, isIndex int, idxID int64) (globalStats *GlobalStats, err error) {
+func (h *Handle) MergePartitionStats2GlobalStats(sc *stmtctx.StatementContext, is infoschema.InfoSchema, physicalID int64, isIndex int, idxID int64) (globalStats *GlobalStats, err error) {
 	// get the partition table IDs
 	h.mu.Lock()
 	globalTable, ok := h.getTableByPhysicalID(is, physicalID)
@@ -389,7 +389,13 @@ func (h *Handle) MergePartitionStats2GlobalStats(is infoschema.InfoSchema, physi
 		}
 
 		// Merge histogram
-		err = errors.Errorf("TODO: The merge function of the histogram structure has not been implemented yet")
+		globalStats.Hg[i], err = statistics.MergePartitionHist2GlobalHist(sc, allHg[i], 0)
+		if err != nil {
+			return
+		}
+
+		// Merge NDV
+		err = errors.Errorf("TODO: The merge function of the NDV has not been implemented yet")
 		if err != nil {
 			return
 		}
diff --git a/statistics/histogram.go b/statistics/histogram.go
index ee574962d957f..34702b5516627 100644
--- a/statistics/histogram.go
+++ b/statistics/histogram.go
@@ -1495,3 +1495,310 @@ func (hg *Histogram) ExtractTopN(cms *CMSketch, topN *TopN, numCols int, numTopN
 	topN.Sort()
 	return nil
 }
+
+// bucket4Merging is only used for merging partition hists to global hist.
+// It carries the bucket's own lower/upper bounds next to the embedded Bucket
+// counters (Count, Repeat, NDV).
+type bucket4Merging struct {
+	// lower and upper are the bounds of this bucket.
+	lower *types.Datum
+	upper *types.Datum
+	Bucket
+	// disjointNDV is used for merging bucket NDV: mergeBucketNDV adds the NDV of a
+	// bucket to it when that bucket does not intersect the one it is merged with.
+	// See mergeBucketNDV for more details.
+	disjointNDV int64
+}
+
+// newBucket4Meging returns an empty bucket4Merging: lower and upper point at
+// freshly allocated datums and every counter is zero.
+// NOTE(review): the name is misspelled ("Meging" instead of "Merging"); renaming it
+// requires updating its callers at the same time.
+func newBucket4Meging() *bucket4Merging {
+	return &bucket4Merging{
+		lower: new(types.Datum),
+		upper: new(types.Datum),
+		Bucket: Bucket{
+			Repeat: 0,
+			NDV:    0,
+			Count:  0,
+		},
+		disjointNDV: 0,
+	}
+}
+
+// buildBucket4Merging converts every bucket of the Histogram into a bucket4Merging.
+// Notice: Count in Histogram.Buckets is a prefix sum, while Count in bucket4Merging
+// is the count of that single bucket only.
+func (hg *Histogram) buildBucket4Merging() []*bucket4Merging {
+	total := hg.Len()
+	converted := make([]*bucket4Merging, 0, total)
+	for idx := 0; idx < total; idx++ {
+		item := newBucket4Meging()
+		hg.GetLower(idx).Copy(item.lower)
+		hg.GetUpper(idx).Copy(item.upper)
+		src := hg.Buckets[idx]
+		item.Repeat, item.NDV, item.Count = src.Repeat, src.NDV, src.Count
+		if idx > 0 {
+			// Histogram counts are cumulative; keep only this bucket's own share.
+			item.Count -= hg.Buckets[idx-1].Count
+		}
+		converted = append(converted, item)
+	}
+	return converted
+}
+
+// Clone returns a copy of b. The bounds are duplicated through Datum.Clone, the
+// counters (Repeat, NDV, Count) and disjointNDV are copied by value.
+func (b *bucket4Merging) Clone() bucket4Merging {
+	var dup bucket4Merging
+	dup.lower = b.lower.Clone()
+	dup.upper = b.upper.Clone()
+	dup.Repeat = b.Repeat
+	dup.NDV = b.NDV
+	dup.Count = b.Count
+	dup.disjointNDV = b.disjointNDV
+	return dup
+}
+
+// mergeBucketNDV merges the NDV of two buckets, `left` and `right`, into a new bucket.
+// Before merging, the caller must make sure that `right` is not smaller than `left`
+// when (upper, lower) is used as the comparison key; otherwise an
+// "illegal bucket order" error is returned.
+// Neither input is modified: the result starts as a clone of `right`.
+func mergeBucketNDV(sc *stmtctx.StatementContext, left *bucket4Merging, right *bucket4Merging) (*bucket4Merging, error) {
+	merged := right.Clone()
+	switch {
+	case left.NDV == 0:
+		// Nothing to take over from `left`.
+		return &merged, nil
+	case right.NDV == 0:
+		// `right` contributes nothing, so the result is `left` in terms of range and NDV.
+		merged.lower = left.lower.Clone()
+		merged.upper = left.upper.Clone()
+		merged.NDV = left.NDV
+		return &merged, nil
+	}
+
+	cmpUpper, err := right.upper.CompareDatum(sc, left.upper)
+	if err != nil {
+		return nil, err
+	}
+	//     __right__|
+	// _______left____|
+	// right.upper < left.upper: illegal order.
+	if cmpUpper < 0 {
+		return nil, errors.Errorf("illegal bucket order")
+	}
+
+	//  ___right_|
+	//  ___left__|
+	// Both buckets end at the same upper.
+	if cmpUpper == 0 {
+		cmpLower, err := right.lower.CompareDatum(sc, left.lower)
+		if err != nil {
+			return nil, err
+		}
+		switch {
+		case cmpLower < 0:
+			// |____right____|
+			//    |__left____|
+			// illegal order.
+			return nil, errors.Errorf("illegal bucket order")
+		case cmpLower == 0:
+			// |___right___|
+			// |____left___|
+			// ndv = max(right.ndv, left.ndv)
+			if right.NDV < left.NDV {
+				merged.NDV = left.NDV
+			}
+		default:
+			//         |_right_|
+			// |_____left______|
+			// |-ratio-|
+			// ndv = ratio * left.ndv + max((1-ratio) * left.ndv, right.ndv)
+			frac := calcFraction4Datums(left.lower, left.upper, right.lower)
+			merged.NDV = int64(frac*float64(left.NDV) + math.Max((1-frac)*float64(left.NDV), float64(right.NDV)))
+			merged.lower = left.lower.Clone()
+		}
+		return &merged, nil
+	}
+
+	//  ____right___|
+	//  ____left__|
+	// From here on right.upper > left.upper.
+	cmpLowerToUpper, err := right.lower.CompareDatum(sc, left.upper)
+	if err != nil {
+		return nil, err
+	}
+	//                  |_right_|
+	//  |___left____|
+	// `left` and `right` do not intersect.
+	// right.ndv is added to `disjointNDV` and the result takes the range and NDV of
+	// `left`, so that it can be used for the subsequent merge. This is because,
+	// when many buckets are merged, they are merged from back to front.
+	if cmpLowerToUpper >= 0 {
+		merged.upper = left.upper.Clone()
+		merged.lower = left.lower.Clone()
+		merged.disjointNDV += right.NDV
+		merged.NDV = left.NDV
+		return &merged, nil
+	}
+
+	upperFrac := calcFraction4Datums(right.lower, right.upper, left.upper)
+	cmpLower, err := right.lower.CompareDatum(sc, left.lower)
+	if err != nil {
+		return nil, err
+	}
+	if cmpLower < 0 {
+		// |------upperFrac---------|
+		// |-lowerFrac--|
+		// |____________right______________|
+		//              |___left____|
+		// ndv = lowerFrac * right.ndv
+		//		+ max(left.ndv, (upperFrac - lowerFrac) * right.ndv)
+		//		+ (1-upperFrac) * right.ndv
+		lowerFrac := calcFraction4Datums(right.lower, right.upper, left.lower)
+		merged.NDV = int64(lowerFrac*float64(right.NDV) +
+			math.Max(float64(left.NDV), (upperFrac-lowerFrac)*float64(right.NDV)) +
+			(1-upperFrac)*float64(right.NDV))
+		return &merged, nil
+	}
+	//              |-upperFrac--|
+	//              |_______right_____|
+	// |_______left______________|
+	// |-lowerFrac--|
+	// ndv = lowerFrac * left.ndv
+	//		+ max((1-lowerFrac) * left.ndv, upperFrac * right.ndv)
+	//		+ (1-upperFrac) * right.ndv
+	lowerFrac := calcFraction4Datums(left.lower, left.upper, right.lower)
+	merged.NDV = int64(lowerFrac*float64(left.NDV) +
+		math.Max((1-lowerFrac)*float64(left.NDV), upperFrac*float64(right.NDV)) +
+		(1-upperFrac)*float64(right.NDV))
+	merged.lower = left.lower.Clone()
+	return &merged, nil
+}
+
+// mergePartitionBuckets merges the given buckets into one global bucket.
+// With n = len(buckets), the global bucket is built as follows:
+//      upper  = buckets[n-1].upper
+//      count  = sum of buckets[i].count for every i
+//      repeat = sum of buckets[i].repeat for every i with buckets[i].upper == global bucket.upper
+//      ndv    = bucket NDVs merged from buckets[n-1] down to buckets[0] by mergeBucketNDV
+// Notice: lower is not calculated here.
+func mergePartitionBuckets(sc *stmtctx.StatementContext, buckets []*bucket4Merging) (*bucket4Merging, error) {
+	if len(buckets) == 0 {
+		return nil, errors.Errorf("not enough buckets to merge")
+	}
+	res := bucket4Merging{}
+	res.upper = buckets[len(buckets)-1].upper.Clone()
+	// right accumulates the NDV merge result while walking from the last bucket to the first.
+	right := buckets[len(buckets)-1].Clone()
+	for i := len(buckets) - 1; i >= 0; i-- {
+		res.Count += buckets[i].Count
+		compare, err := buckets[i].upper.CompareDatum(sc, res.upper)
+		if err != nil {
+			return nil, err
+		}
+		// Only buckets ending exactly at the global upper contribute to Repeat.
+		if compare == 0 {
+			res.Repeat += buckets[i].Repeat
+		}
+		// The last bucket is already in right; every other bucket is merged into it.
+		if i != len(buckets)-1 {
+			tmp, err := mergeBucketNDV(sc, buckets[i], &right)
+			if err != nil {
+				return nil, err
+			}
+			right = *tmp
+		}
+	}
+	res.NDV = right.NDV + right.disjointNDV
+	return &res, nil
+}
+
+// MergePartitionHist2GlobalHist merges hists (partition-level Histogram) to a global-level Histogram.
+//
+// The buckets of all partitions are sorted by (upper, lower) and consumed from the
+// back; a global bucket is closed whenever the running count reaches the next
+// multiple of totCount/expBucketNumber. Buckets sharing the same upper always end
+// up in the same global bucket.
+//
+// Notice: If expBucketNumber == 0, we will let expBucketNumber = max(hists.Len()).
+// An error is returned when expBucketNumber is negative, when no histogram
+// contains a bucket, or when comparing datums fails.
+func MergePartitionHist2GlobalHist(sc *stmtctx.StatementContext, hists []*Histogram, expBucketNumber int64) (*Histogram, error) {
+	if expBucketNumber < 0 {
+		return nil, errors.Errorf("expBucketNumber should not be negative")
+	}
+	var totCount, totNull, bucketNumber, totColSize int64
+	needBucketNumber := expBucketNumber == 0
+	// minValue is used to calc the bucket lower.
+	var minValue *types.Datum
+	for _, hist := range hists {
+		totColSize += hist.TotColSize
+		totNull += hist.NullCount
+		bucketNumber += int64(hist.Len())
+		if hist.Len() > 0 {
+			totCount += hist.Buckets[hist.Len()-1].Count
+			if needBucketNumber && int64(hist.Len()) > expBucketNumber {
+				expBucketNumber = int64(hist.Len())
+			}
+			if minValue == nil {
+				minValue = hist.GetLower(0).Clone()
+				continue
+			}
+			res, err := hist.GetLower(0).CompareDatum(sc, minValue)
+			if err != nil {
+				return nil, err
+			}
+			if res < 0 {
+				minValue = hist.GetLower(0).Clone()
+			}
+		}
+	}
+	// Without any bucket there is nothing to merge. Bailing out here also
+	// guarantees expBucketNumber > 0 for the division below.
+	if minValue == nil {
+		return nil, errors.Errorf("merge partition-level hist failed")
+	}
+	buckets := make([]*bucket4Merging, 0, bucketNumber)
+	globalBuckets := make([]*bucket4Merging, 0, expBucketNumber)
+
+	// init `buckets`.
+	for _, hist := range hists {
+		buckets = append(buckets, hist.buildBucket4Merging()...)
+	}
+	// Sort by (upper, lower). The comparator cannot return an error, so it is
+	// recorded and checked once sorting is done.
+	var sortError error
+	sort.Slice(buckets, func(i, j int) bool {
+		res, err := buckets[i].upper.CompareDatum(sc, buckets[j].upper)
+		if err != nil {
+			sortError = err
+		}
+		if res != 0 {
+			return res < 0
+		}
+		res, err = buckets[i].lower.CompareDatum(sc, buckets[j].lower)
+		if err != nil {
+			sortError = err
+		}
+		return res < 0
+	})
+	if sortError != nil {
+		return nil, sortError
+	}
+	// sum is the count of all buckets consumed so far, r is the exclusive end of
+	// the pending range, bucketCount is the 1-based index of the global bucket
+	// being filled.
+	var sum int64
+	r := len(buckets)
+	bucketCount := int64(1)
+	for i := len(buckets) - 1; i >= 0; i-- {
+		sum += buckets[i].Count
+		if sum >= totCount*bucketCount/expBucketNumber {
+			// if the buckets have the same upper, we merge them into the same new buckets.
+			for ; i > 0; i-- {
+				res, err := buckets[i-1].upper.CompareDatum(sc, buckets[i].upper)
+				if err != nil {
+					return nil, err
+				}
+				if res != 0 {
+					break
+				}
+				// The bucket is consumed here and skipped by the outer loop, so it
+				// has to be accounted for in the running sum as well.
+				sum += buckets[i-1].Count
+			}
+			merged, err := mergePartitionBuckets(sc, buckets[i:r])
+			if err != nil {
+				return nil, err
+			}
+			globalBuckets = append(globalBuckets, merged)
+			r = i
+			bucketCount++
+		}
+	}
+	// Whatever is left at the front forms the first global bucket.
+	if r > 0 {
+		merged, err := mergePartitionBuckets(sc, buckets[0:r])
+		if err != nil {
+			return nil, err
+		}
+		globalBuckets = append(globalBuckets, merged)
+	}
+	// Because we merge backwards, we need to flip the slices.
+	for i, j := 0, len(globalBuckets)-1; i < j; i, j = i+1, j-1 {
+		globalBuckets[i], globalBuckets[j] = globalBuckets[j], globalBuckets[i]
+	}
+
+	// Calc the bucket lower and turn the per-bucket counts into prefix sums.
+	globalBuckets[0].lower = minValue.Clone()
+	for i := 1; i < len(globalBuckets); i++ {
+		globalBuckets[i].lower = globalBuckets[i-1].upper.Clone()
+		globalBuckets[i].Count = globalBuckets[i].Count + globalBuckets[i-1].Count
+	}
+	globalHist := NewHistogram(hists[0].ID, 0, totNull, hists[0].LastUpdateVersion, hists[0].Tp, len(globalBuckets), totColSize)
+	for _, bucket := range globalBuckets {
+		globalHist.AppendBucketWithNDV(bucket.lower, bucket.upper, bucket.Count, bucket.Repeat, bucket.NDV)
+	}
+	return globalHist, nil
+}
diff --git a/statistics/histogram_test.go b/statistics/histogram_test.go
index b017fe1bcf0f8..5d56ecb2ec5d9 100644
--- a/statistics/histogram_test.go
+++ b/statistics/histogram_test.go
@@ -139,3 +139,214 @@ func (s *testStatisticsSuite) TestValueToString4InvalidKey(c *C) {
 	c.Assert(err, IsNil)
 	c.Assert(res, Equals, "(1, 0.5, \x14)")
 }
+
+type bucket4Test struct {
+	lower  int64
+	upper  int64
+	count  int64
+	repeat int64
+	ndv    int64
+}
+
+func genHist4Test(buckets []*bucket4Test, totColSize int64) *Histogram {
+	h := NewHistogram(0, 0, 0, 0, types.NewFieldType(mysql.TypeLong), len(buckets), totColSize)
+	for _, bucket := range buckets {
+		lower := types.NewIntDatum(bucket.lower)
+		upper := types.NewIntDatum(bucket.upper)
+		h.AppendBucketWithNDV(&lower, &upper, bucket.count, bucket.repeat, bucket.ndv)
+	}
+	return h
+}
+
+func (s *testStatisticsSuite) TestMergePartitionLevelHist(c *C) {
+	hists := make([]*Histogram, 0, 2)
+	// Col(1) = [1, 4,|| 6, 9, 9,|| 12, 12, 12,|| 13, 14, 15]
+	h1Buckets := []*bucket4Test{
+		{
+			lower:  1,
+			upper:  4,
+			count:  2,
+			repeat: 1,
+			ndv:    2,
+		},
+		{
+			lower:  6,
+			upper:  9,
+			count:  5,
+			repeat: 2,
+			ndv:    2,
+		},
+		{
+			lower:  12,
+			upper:  12,
+			count:  8,
+			repeat: 3,
+			ndv:    1,
+		},
+		{
+			lower:  13,
+			upper:  15,
+			count:  11,
+			repeat: 1,
+			ndv:    3,
+		},
+	}
+	hists = append(hists, genHist4Test(h1Buckets, 11))
+	// Col(2) = [2, 5,|| 6, 7, 7,|| 11, 11, 11,|| 13, 14, 17]
+	h2Buckets := []*bucket4Test{
+		{
+			lower:  2,
+			upper:  5,
+			count:  2,
+			repeat: 1,
+			ndv:    2,
+		},
+		{
+			lower:  6,
+			upper:  7,
+			count:  5,
+			repeat: 2,
+			ndv:    2,
+		},
+		{
+			lower:  11,
+			upper:  11,
+			count:  8,
+			repeat: 3,
+			ndv:    1,
+		},
+		{
+			lower:  13,
+			upper:  17,
+			count:  11,
+			repeat: 1,
+			ndv:    3,
+		},
+	}
+	hists = append(hists, genHist4Test(h2Buckets, 11))
+	ctx := mock.NewContext()
+	sc := ctx.GetSessionVars().StmtCtx
+	globalHist, err := MergePartitionHist2GlobalHist(sc, hists, 3)
+	c.Assert(err, IsNil)
+	expHist := []*bucket4Test{
+		{
+			lower:  1,
+			upper:  7,
+			count:  7,
+			repeat: 2,
+			ndv:    4,
+		},
+		{
+			lower:  7,
+			upper:  11,
+			count:  13,
+			repeat: 3,
+			ndv:    3,
+		},
+		{
+			lower:  11,
+			upper:  17,
+			count:  22,
+			repeat: 1,
+			ndv:    5,
+		},
+	}
+	for i, b := range expHist {
+		c.Assert(b.lower, Equals, globalHist.GetLower(i).GetInt64())
+		c.Assert(b.upper, Equals, globalHist.GetUpper(i).GetInt64())
+		c.Assert(b.count, Equals, globalHist.Buckets[i].Count)
+		c.Assert(b.repeat, Equals, globalHist.Buckets[i].Repeat)
+		c.Assert(b.ndv, Equals, globalHist.Buckets[i].NDV)
+	}
+	c.Assert(globalHist.TotColSize, Equals, int64(22))
+}
+
+func genBucket4Merging4Test(lower, upper, ndv, disjointNDV int64) bucket4Merging {
+	l := types.NewIntDatum(lower)
+	r := types.NewIntDatum(upper)
+	return bucket4Merging{
+		lower: &l,
+		upper: &r,
+		Bucket: Bucket{
+			NDV: ndv,
+		},
+		disjointNDV: disjointNDV,
+	}
+}
+
+func (s *testStatisticsSuite) TestMergeBucketNDV(c *C) {
+	type testData struct {
+		left   bucket4Merging
+		right  bucket4Merging
+		result bucket4Merging
+	}
+	tests := []testData{
+		{
+			left:   genBucket4Merging4Test(1, 2, 2, 0),
+			right:  genBucket4Merging4Test(1, 2, 3, 0),
+			result: genBucket4Merging4Test(1, 2, 3, 0),
+		},
+		{
+			left:   genBucket4Merging4Test(1, 3, 2, 0),
+			right:  genBucket4Merging4Test(2, 3, 2, 0),
+			result: genBucket4Merging4Test(1, 3, 3, 0),
+		},
+		{
+			left:   genBucket4Merging4Test(1, 3, 2, 0),
+			right:  genBucket4Merging4Test(4, 6, 2, 2),
+			result: genBucket4Merging4Test(1, 3, 2, 4),
+		},
+		{
+			left:   genBucket4Merging4Test(1, 5, 5, 0),
+			right:  genBucket4Merging4Test(2, 6, 5, 0),
+			result: genBucket4Merging4Test(1, 6, 6, 0),
+		},
+		{
+			left:   genBucket4Merging4Test(3, 5, 3, 0),
+			right:  genBucket4Merging4Test(2, 6, 4, 0),
+			result: genBucket4Merging4Test(2, 6, 5, 0),
+		},
+	}
+	sc := mock.NewContext().GetSessionVars().StmtCtx
+	for _, t := range tests {
+		res, err := mergeBucketNDV(sc, &t.left, &t.right)
+		c.Assert(err, IsNil)
+		c.Assert(t.result.lower.GetInt64(), Equals, res.lower.GetInt64())
+		c.Assert(t.result.upper.GetInt64(), Equals, res.upper.GetInt64())
+		c.Assert(t.result.NDV, Equals, res.NDV)
+		c.Assert(t.result.disjointNDV, Equals, res.disjointNDV)
+	}
+}
+
+// TestMergePartitionLevelHistSameUpper checks that partition buckets sharing an upper are merged
+// into the same global bucket and that their counts still advance the running sum.
+func (s *testStatisticsSuite) TestMergePartitionLevelHistSameUpper(c *C) {
+	// Sorted by (upper, lower) the partition buckets are
+	// [1,3] [4,5] [6,10] [8,10] [11,12] with counts 2, 2, 5, 2, 2.
+	h1Buckets := []*bucket4Test{
+		{lower: 1, upper: 3, count: 2, repeat: 1, ndv: 2},
+		{lower: 6, upper: 10, count: 7, repeat: 2, ndv: 4},
+		{lower: 11, upper: 12, count: 9, repeat: 1, ndv: 2},
+	}
+	h2Buckets := []*bucket4Test{
+		{lower: 4, upper: 5, count: 2, repeat: 1, ndv: 2},
+		{lower: 8, upper: 10, count: 4, repeat: 1, ndv: 2},
+	}
+	hists := []*Histogram{genHist4Test(h1Buckets, 9), genHist4Test(h2Buckets, 4)}
+	sc := mock.NewContext().GetSessionVars().StmtCtx
+	globalHist, err := MergePartitionHist2GlobalHist(sc, hists, 3)
+	c.Assert(err, IsNil)
+	expHist := []*bucket4Test{
+		{lower: 1, upper: 3, count: 2, repeat: 1, ndv: 2},
+		{lower: 3, upper: 5, count: 4, repeat: 1, ndv: 2},
+		{lower: 5, upper: 12, count: 13, repeat: 1, ndv: 6},
+	}
+	c.Assert(globalHist.Len(), Equals, len(expHist))
+	for i, b := range expHist {
+		c.Assert(b.lower, Equals, globalHist.GetLower(i).GetInt64())
+		c.Assert(b.upper, Equals, globalHist.GetUpper(i).GetInt64())
+		c.Assert(b.count, Equals, globalHist.Buckets[i].Count)
+		c.Assert(b.repeat, Equals, globalHist.Buckets[i].Repeat)
+		c.Assert(b.ndv, Equals, globalHist.Buckets[i].NDV)
+	}
+}