From 5da6e48eced1533bfaf24bcd9081158e5a74d2dc Mon Sep 17 00:00:00 2001 From: RogerYK Date: Tue, 24 Nov 2020 11:33:24 +0800 Subject: [PATCH] cherry pick #21062 to release-4.0 Signed-off-by: ti-srebot --- expression/constant.go | 8 +++-- planner/core/expression_rewriter_test.go | 36 ++++++++++++++++++++++ types/etc.go | 9 ++++++ types/field_type.go | 29 ++++++++++++++++-- types/field_type_test.go | 38 ++++++++++++++++++++++++ 5 files changed, 115 insertions(+), 5 deletions(-) diff --git a/expression/constant.go b/expression/constant.go index 6ff4ed05dbf4f..3aaa2deaaa2be 100644 --- a/expression/constant.go +++ b/expression/constant.go @@ -28,17 +28,21 @@ import ( // NewOne stands for a number 1. func NewOne() *Constant { + retT := types.NewFieldType(mysql.TypeTiny) + retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion return &Constant{ Value: types.NewDatum(1), - RetType: types.NewFieldType(mysql.TypeTiny), + RetType: retT, } } // NewZero stands for a number 0. func NewZero() *Constant { + retT := types.NewFieldType(mysql.TypeTiny) + retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion return &Constant{ Value: types.NewDatum(0), - RetType: types.NewFieldType(mysql.TypeTiny), + RetType: retT, } } diff --git a/planner/core/expression_rewriter_test.go b/planner/core/expression_rewriter_test.go index a2253bf75d154..bab6438e1a60d 100644 --- a/planner/core/expression_rewriter_test.go +++ b/planner/core/expression_rewriter_test.go @@ -292,3 +292,39 @@ func (s *testExpressionRewriterSuite) TestIssue20007(c *C) { testkit.Rows("2 epic wiles 2020-01-02 23:29:51", "3 silly burnell 2020-02-25 07:43:07")) } } + +func (s *testExpressionRewriterSuite) TestIssue9869(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + + tk.MustExec("use test;") + tk.MustExec("drop table if exists t1;") + tk.MustExec("create table t1(a int, b bigint unsigned);") + tk.MustExec("insert into t1 (a, b) values (1,4572794622775114594), (2,18196094287899841997),(3,11120436154190595086);") + tk.MustQuery("select (case t1.a when 0 then 0 else t1.b end), cast(t1.b as signed) from t1;").Check( + testkit.Rows("4572794622775114594 4572794622775114594", "18196094287899841997 -250649785809709619", "11120436154190595086 -7326307919518956530")) +} + +func (s *testExpressionRewriterSuite) TestIssue17652(c *C) { + defer testleak.AfterTest(c)() + store, dom, err := newStoreWithBootstrap() + c.Assert(err, IsNil) + tk := testkit.NewTestKit(c, store) + defer func() { + dom.Close() + store.Close() + }() + + tk.MustExec("use test;") + tk.MustExec("drop table if exists t;") + tk.MustExec("create table t(x bigint unsigned);") + tk.MustExec("insert into t values( 9999999703771440633);") + tk.MustQuery("select ifnull(max(x), 0) from t").Check( + testkit.Rows("9999999703771440633")) +} diff --git a/types/etc.go b/types/etc.go index e29c91171e5a7..46839d78419b5 100644 --- a/types/etc.go +++ b/types/etc.go @@ -64,6 +64,15 @@ func IsTypeTime(tp byte) bool { return tp == mysql.TypeDatetime || tp == mysql.TypeDate || tp == mysql.TypeTimestamp } +// IsTypeInteger returns a boolean indicating whether the tp is integer type. +func IsTypeInteger(tp byte) bool { + switch tp { + case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear: + return true + } + return false +} + // IsTypeNumeric returns a boolean indicating whether the tp is numeric type. func IsTypeNumeric(tp byte) bool { switch tp { diff --git a/types/field_type.go b/types/field_type.go index 767d297f634ff..dbcf7c3632eb7 100644 --- a/types/field_type.go +++ b/types/field_type.go @@ -65,15 +65,38 @@ func NewFieldTypeWithCollation(tp byte, collation string, length int) *FieldType // Aggregation is performed by MergeFieldType function. func AggFieldType(tps []*FieldType) *FieldType { var currType FieldType + isMixedSign := false for i, t := range tps { if i == 0 && currType.Tp == mysql.TypeUnspecified { currType = *t continue } mtp := MergeFieldType(currType.Tp, t.Tp) + isMixedSign = isMixedSign || (mysql.HasUnsignedFlag(currType.Flag) != mysql.HasUnsignedFlag(t.Flag)) currType.Tp = mtp currType.Flag = mergeTypeFlag(currType.Flag, t.Flag) } + // integral promotion when tps contains signed and unsigned + if isMixedSign && IsTypeInteger(currType.Tp) { + bumpRange := false // indicate one of tps bump currType range + for _, t := range tps { + bumpRange = bumpRange || (mysql.HasUnsignedFlag(t.Flag) && (t.Tp == currType.Tp || t.Tp == mysql.TypeBit)) + } + if bumpRange { + switch currType.Tp { + case mysql.TypeTiny: + currType.Tp = mysql.TypeShort + case mysql.TypeShort: + currType.Tp = mysql.TypeInt24 + case mysql.TypeInt24: + currType.Tp = mysql.TypeLong + case mysql.TypeLong: + currType.Tp = mysql.TypeLonglong + case mysql.TypeLonglong: + currType.Tp = mysql.TypeNewDecimal + } + } + } return &currType } @@ -310,10 +333,10 @@ func MergeFieldType(a byte, b byte) byte { } // mergeTypeFlag merges two MySQL type flag to a new one -// currently only NotNullFlag is checked -// todo more flag need to be checked, for example: UnsignedFlag +// currently only NotNullFlag and UnsignedFlag is checked +// todo more flag need to be checked func mergeTypeFlag(a, b uint) uint { - return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) + return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) & (b&mysql.UnsignedFlag | ^mysql.UnsignedFlag) } func getFieldTypeIndex(tp byte) int { diff --git a/types/field_type_test.go b/types/field_type_test.go index 9b50d65e9a7d9..4d2583ec97d97 100644 --- a/types/field_type_test.go +++ b/types/field_type_test.go @@ -327,6 +327,44 @@ func (s *testFieldTypeSuite) TestAggFieldTypeForTypeFlag(c *C) { c.Assert(aggTp.Flag, Equals, mysql.NotNullFlag) } +func (s testFieldTypeSuite) TestAggFieldTypeForIntegralPromotion(c *C) { + fts := []*FieldType{ + NewFieldType(mysql.TypeTiny), + NewFieldType(mysql.TypeShort), + NewFieldType(mysql.TypeInt24), + NewFieldType(mysql.TypeLong), + NewFieldType(mysql.TypeLonglong), + NewFieldType(mysql.TypeNewDecimal), + } + + for i := 1; i < len(fts)-1; i++ { + tps := fts[i-1 : i+1] + + tps[0].Flag = 0 + tps[1].Flag = 0 + aggTp := AggFieldType(tps) + c.Assert(aggTp.Tp, Equals, fts[i].Tp) + c.Assert(aggTp.Flag, Equals, uint(0)) + + tps[0].Flag = mysql.UnsignedFlag + aggTp = AggFieldType(tps) + c.Assert(aggTp.Tp, Equals, fts[i].Tp) + c.Assert(aggTp.Flag, Equals, uint(0)) + + tps[0].Flag = mysql.UnsignedFlag + tps[1].Flag = mysql.UnsignedFlag + aggTp = AggFieldType(tps) + c.Assert(aggTp.Tp, Equals, fts[i].Tp) + c.Assert(aggTp.Flag, Equals, mysql.UnsignedFlag) + + tps[0].Flag = 0 + tps[1].Flag = mysql.UnsignedFlag + aggTp = AggFieldType(tps) + c.Assert(aggTp.Tp, Equals, fts[i+1].Tp) + c.Assert(aggTp.Flag, Equals, uint(0)) + } +} + func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) { defer testleak.AfterTest(c)() fts := []*FieldType{