From 24f329f290198d012d3351ff3125fca620902c89 Mon Sep 17 00:00:00 2001 From: Haibin Xie Date: Wed, 9 Jan 2019 10:57:55 +0800 Subject: [PATCH] stats: fix histogram bound overflow error (#8984) --- statistics/feedback.go | 86 +++++++++++++++++++++---------------- statistics/feedback_test.go | 4 +- statistics/update.go | 3 +- statistics/update_test.go | 58 +++++++++++++++++++++++++ 4 files changed, 110 insertions(+), 41 deletions(-) diff --git a/statistics/feedback.go b/statistics/feedback.go index b0b3308f78db4..89c6eb6945293 100644 --- a/statistics/feedback.go +++ b/statistics/feedback.go @@ -296,8 +296,7 @@ func buildBucketFeedback(h *Histogram, feedback *QueryFeedback) (map[int]*Bucket } total := 0 sc := &stmtctx.StatementContext{TimeZone: time.UTC} - kind := feedback.feedback[0].lower.Kind() - min, max := getMinValue(kind, h.tp), getMaxValue(kind, h.tp) + min, max := getMinValue(h.tp), getMaxValue(h.tp) for _, fb := range feedback.feedback { skip, err := fb.adjustFeedbackBoundaries(sc, &min, &max) if err != nil { @@ -725,11 +724,18 @@ func decodeFeedbackForIndex(q *QueryFeedback, pb *queryFeedback, c *CMSketch) { } } -func decodeFeedbackForPK(q *QueryFeedback, pb *queryFeedback) { +func decodeFeedbackForPK(q *QueryFeedback, pb *queryFeedback, isUnsigned bool) { q.tp = pkType // decode feedback for primary key for i := 0; i < len(pb.IntRanges); i += 2 { - lower, upper := types.NewIntDatum(pb.IntRanges[i]), types.NewIntDatum(pb.IntRanges[i+1]) + var lower, upper types.Datum + if isUnsigned { + lower.SetUint64(uint64(pb.IntRanges[i])) + upper.SetUint64(uint64(pb.IntRanges[i+1])) + } else { + lower.SetInt64(pb.IntRanges[i]) + upper.SetInt64(pb.IntRanges[i+1]) + } q.feedback = append(q.feedback, feedback{&lower, &upper, pb.Counts[i/2], 0}) } } @@ -750,7 +756,7 @@ func decodeFeedbackForColumn(q *QueryFeedback, pb *queryFeedback) error { return nil } -func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch) error { +func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch, isUnsigned bool) error { buf := bytes.NewBuffer(val) dec := gob.NewDecoder(buf) pb := &queryFeedback{} @@ -761,7 +767,7 @@ func decodeFeedback(val []byte, q *QueryFeedback, c *CMSketch) error { if len(pb.IndexRanges) > 0 || len(pb.HashValues) > 0 { decodeFeedbackForIndex(q, pb, c) } else if len(pb.IntRanges) > 0 { - decodeFeedbackForPK(q, pb) + decodeFeedbackForPK(q, pb, isUnsigned) } else { err := decodeFeedbackForColumn(q, pb) if err != nil { @@ -1075,15 +1081,14 @@ func (q *QueryFeedback) dumpRangeFeedback(h *Handle, ran *ranger.Range, rangeCou ran.LowVal[0].SetBytes(lower) ran.HighVal[0].SetBytes(upper) } else { - k := q.hist.GetLower(0).Kind() - if !supportColumnType(k) { + if !supportColumnType(q.hist.tp) { return nil } if ran.LowVal[0].Kind() == types.KindMinNotNull { - ran.LowVal[0] = getMinValue(k, q.hist.tp) + ran.LowVal[0] = getMinValue(q.hist.tp) } if ran.HighVal[0].Kind() == types.KindMaxValue { - ran.HighVal[0] = getMaxValue(k, q.hist.tp) + ran.HighVal[0] = getMaxValue(q.hist.tp) } } ranges := q.hist.SplitRange([]*ranger.Range{ran}) @@ -1130,27 +1135,30 @@ func setNextValue(d *types.Datum) { } // supportColumnType checks if the type of the column can be updated by feedback. -func supportColumnType(k byte) bool { - switch k { - case types.KindInt64, types.KindUint64, types.KindFloat32, types.KindFloat64, types.KindString, types.KindBytes, - types.KindMysqlDecimal, types.KindMysqlDuration, types.KindMysqlTime: +func supportColumnType(ft *types.FieldType) bool { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeFloat, + mysql.TypeDouble, mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob, + mysql.TypeNewDecimal, mysql.TypeDuration, mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: return true default: return false } } -func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { - switch k { - case types.KindInt64: - max.SetInt64(types.SignedUpperBound[ft.Tp]) - case types.KindUint64: - max.SetUint64(types.UnsignedUpperBound[ft.Tp]) - case types.KindFloat32: +func getMaxValue(ft *types.FieldType) (max types.Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + max.SetUint64(types.UnsignedUpperBound[ft.Tp]) + } else { + max.SetInt64(types.SignedUpperBound[ft.Tp]) + } + case mysql.TypeFloat: max.SetFloat32(float32(types.GetMaxFloat(ft.Flen, ft.Decimal))) - case types.KindFloat64: + case mysql.TypeDouble: max.SetFloat64(types.GetMaxFloat(ft.Flen, ft.Decimal)) - case types.KindString, types.KindBytes: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: val := types.MaxValueDatum() bytes, err := codec.EncodeKey(nil, nil, val) // should not happen @@ -1158,11 +1166,11 @@ func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { log.Error(err) } max.SetBytes(bytes) - case types.KindMysqlDecimal: + case mysql.TypeNewDecimal: max.SetMysqlDecimal(types.NewMaxOrMinDec(false, ft.Flen, ft.Decimal)) - case types.KindMysqlDuration: + case mysql.TypeDuration: max.SetMysqlDuration(types.Duration{Duration: math.MaxInt64}) - case types.KindMysqlTime: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { max.SetMysqlTime(types.Time{Time: types.MaxDatetime, Type: ft.Tp}) } else { @@ -1172,17 +1180,19 @@ func getMaxValue(k byte, ft *types.FieldType) (max types.Datum) { return } -func getMinValue(k byte, ft *types.FieldType) (min types.Datum) { - switch k { - case types.KindInt64: - min.SetInt64(types.SignedLowerBound[ft.Tp]) - case types.KindUint64: - min.SetUint64(0) - case types.KindFloat32: +func getMinValue(ft *types.FieldType) (min types.Datum) { + switch ft.Tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong: + if mysql.HasUnsignedFlag(ft.Flag) { + min.SetUint64(0) + } else { + min.SetInt64(types.SignedLowerBound[ft.Tp]) + } + case mysql.TypeFloat: min.SetFloat32(float32(-types.GetMaxFloat(ft.Flen, ft.Decimal))) - case types.KindFloat64: + case mysql.TypeDouble: min.SetFloat64(-types.GetMaxFloat(ft.Flen, ft.Decimal)) - case types.KindString, types.KindBytes: + case mysql.TypeString, mysql.TypeVarString, mysql.TypeVarchar, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob: val := types.MinNotNullDatum() bytes, err := codec.EncodeKey(nil, nil, val) // should not happen @@ -1190,11 +1200,11 @@ func getMinValue(k byte, ft *types.FieldType) (min types.Datum) { log.Error(err) } min.SetBytes(bytes) - case types.KindMysqlDecimal: + case mysql.TypeNewDecimal: min.SetMysqlDecimal(types.NewMaxOrMinDec(true, ft.Flen, ft.Decimal)) - case types.KindMysqlDuration: + case mysql.TypeDuration: min.SetMysqlDuration(types.Duration{Duration: math.MinInt64}) - case types.KindMysqlTime: + case mysql.TypeDate, mysql.TypeDatetime, mysql.TypeTimestamp: if ft.Tp == mysql.TypeDate || ft.Tp == mysql.TypeDatetime { min.SetMysqlTime(types.Time{Time: types.MinDatetime, Type: ft.Tp}) } else { diff --git a/statistics/feedback_test.go b/statistics/feedback_test.go index 08058c386bef2..da19fc233342d 100644 --- a/statistics/feedback_test.go +++ b/statistics/feedback_test.go @@ -221,7 +221,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { val, err := encodeFeedback(q) c.Assert(err, IsNil) rq := &QueryFeedback{} - c.Assert(decodeFeedback(val, rq, nil), IsNil) + c.Assert(decodeFeedback(val, rq, nil, false), IsNil) for _, fb := range rq.feedback { fb.lower.SetBytes(codec.EncodeInt(nil, fb.lower.GetInt64())) fb.upper.SetBytes(codec.EncodeInt(nil, fb.upper.GetInt64())) @@ -236,7 +236,7 @@ func (s *testFeedbackSuite) TestFeedbackEncoding(c *C) { c.Assert(err, IsNil) rq = &QueryFeedback{} cms := NewCMSketch(4, 4) - c.Assert(decodeFeedback(val, rq, cms), IsNil) + c.Assert(decodeFeedback(val, rq, cms, false), IsNil) c.Assert(cms.QueryBytes(codec.EncodeInt(nil, 0)), Equals, uint32(1)) q.feedback = q.feedback[:1] c.Assert(q.Equal(rq), IsTrue) diff --git a/statistics/update.go b/statistics/update.go index 119c10fa8728d..2b64bf3acb735 100644 --- a/statistics/update.go +++ b/statistics/update.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/errors" "github.com/pingcap/parser/model" + "github.com/pingcap/parser/mysql" "github.com/pingcap/tidb/infoschema" "github.com/pingcap/tidb/metrics" "github.com/pingcap/tidb/sessionctx/variable" @@ -559,7 +560,7 @@ func (h *Handle) handleSingleHistogramUpdate(is infoschema.InfoSchema, rows []ch } q := &QueryFeedback{} for _, row := range rows { - err1 := decodeFeedback(row.GetBytes(3), q, cms) + err1 := decodeFeedback(row.GetBytes(3), q, cms, mysql.HasUnsignedFlag(hist.tp.Flag)) if err1 != nil { log.Debugf("decode feedback failed, err: %v", errors.ErrorStack(err)) } diff --git a/statistics/update_test.go b/statistics/update_test.go index 7a458dc0df0e3..614b87c97aa4d 100644 --- a/statistics/update_test.go +++ b/statistics/update_test.go @@ -1225,3 +1225,61 @@ func (s *testStatsUpdateSuite) TestFeedbackRanges(c *C) { c.Assert(tbl.Columns[t.colID].ToString(0), Equals, tests[i].hist) } } + +func (s *testStatsUpdateSuite) TestUnsignedFeedbackRanges(c *C) { + defer cleanEnv(c, s.store, s.do) + testKit := testkit.NewTestKit(c, s.store) + h := s.do.StatsHandle() + oriProbability := statistics.FeedbackProbability + oriNumber := statistics.MaxNumberOfRanges + defer func() { + statistics.FeedbackProbability = oriProbability + statistics.MaxNumberOfRanges = oriNumber + }() + statistics.FeedbackProbability = 1 + + testKit.MustExec("use test") + testKit.MustExec("create table t (a tinyint unsigned, primary key(a))") + for i := 0; i < 20; i++ { + testKit.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + h.HandleDDLEvent(<-h.DDLEventCh()) + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + testKit.MustExec("analyze table t with 3 buckets") + for i := 30; i < 40; i++ { + testKit.MustExec(fmt.Sprintf("insert into t values (%d)", i)) + } + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + tests := []struct { + sql string + hist string + }{ + { + sql: "select * from t where a <= 50", + hist: "column:1 ndv:30 totColSize:0\n" + + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 14 lower_bound: 16 upper_bound: 50 repeats: 0", + }, + { + sql: "select count(*) from t", + hist: "column:1 ndv:30 totColSize:0\n" + + "num: 8 lower_bound: 0 upper_bound: 7 repeats: 0\n" + + "num: 8 lower_bound: 8 upper_bound: 15 repeats: 0\n" + + "num: 14 lower_bound: 16 upper_bound: 255 repeats: 0", + }, + } + is := s.do.InfoSchema() + table, err := is.TableByName(model.NewCIStr("test"), model.NewCIStr("t")) + for i, t := range tests { + testKit.MustQuery(t.sql) + c.Assert(h.DumpStatsDeltaToKV(statistics.DumpAll), IsNil) + c.Assert(h.DumpStatsFeedbackToKV(), IsNil) + c.Assert(h.HandleUpdateStats(s.do.InfoSchema()), IsNil) + c.Assert(err, IsNil) + h.Update(is) + tblInfo := table.Meta() + tbl := h.GetTableStats(tblInfo) + c.Assert(tbl.Columns[1].ToString(0), Equals, tests[i].hist) + } +}