Skip to content

Commit

Permalink
planner, type: fix AggFieldType error when encouter unsigned and sign…
Browse files Browse the repository at this point in the history
… type (#21062) (#21236)
  • Loading branch information
ti-srebot authored Nov 26, 2020
1 parent 4417eb8 commit 1e73c51
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 5 deletions.
8 changes: 6 additions & 2 deletions expression/constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,21 @@ import (

// NewOne stands for a number 1.
func NewOne() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion
return &Constant{
Value: types.NewDatum(1),
RetType: types.NewFieldType(mysql.TypeTiny),
RetType: retT,
}
}

// NewZero stands for a number 0.
func NewZero() *Constant {
retT := types.NewFieldType(mysql.TypeTiny)
retT.Flag |= mysql.UnsignedFlag // shrink range to avoid integral promotion
return &Constant{
Value: types.NewDatum(0),
RetType: types.NewFieldType(mysql.TypeTiny),
RetType: retT,
}
}

Expand Down
36 changes: 36 additions & 0 deletions planner/core/expression_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -292,3 +292,39 @@ func (s *testExpressionRewriterSuite) TestIssue20007(c *C) {
testkit.Rows("2 epic wiles 2020-01-02 23:29:51", "3 silly burnell 2020-02-25 07:43:07"))
}
}

func (s *testExpressionRewriterSuite) TestIssue9869(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()

tk.MustExec("use test;")
tk.MustExec("drop table if exists t1;")
tk.MustExec("create table t1(a int, b bigint unsigned);")
tk.MustExec("insert into t1 (a, b) values (1,4572794622775114594), (2,18196094287899841997),(3,11120436154190595086);")
tk.MustQuery("select (case t1.a when 0 then 0 else t1.b end), cast(t1.b as signed) from t1;").Check(
testkit.Rows("4572794622775114594 4572794622775114594", "18196094287899841997 -250649785809709619", "11120436154190595086 -7326307919518956530"))
}

func (s *testExpressionRewriterSuite) TestIssue17652(c *C) {
defer testleak.AfterTest(c)()
store, dom, err := newStoreWithBootstrap()
c.Assert(err, IsNil)
tk := testkit.NewTestKit(c, store)
defer func() {
dom.Close()
store.Close()
}()

tk.MustExec("use test;")
tk.MustExec("drop table if exists t;")
tk.MustExec("create table t(x bigint unsigned);")
tk.MustExec("insert into t values( 9999999703771440633);")
tk.MustQuery("select ifnull(max(x), 0) from t").Check(
testkit.Rows("9999999703771440633"))
}
9 changes: 9 additions & 0 deletions types/etc.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,15 @@ func IsTypeTime(tp byte) bool {
return tp == mysql.TypeDatetime || tp == mysql.TypeDate || tp == mysql.TypeTimestamp
}

// IsTypeInteger returns a boolean indicating whether the tp is integer type.
func IsTypeInteger(tp byte) bool {
switch tp {
case mysql.TypeTiny, mysql.TypeShort, mysql.TypeInt24, mysql.TypeLong, mysql.TypeLonglong, mysql.TypeYear:
return true
}
return false
}

// IsTypeNumeric returns a boolean indicating whether the tp is numeric type.
func IsTypeNumeric(tp byte) bool {
switch tp {
Expand Down
29 changes: 26 additions & 3 deletions types/field_type.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,38 @@ func NewFieldTypeWithCollation(tp byte, collation string, length int) *FieldType
// Aggregation is performed by MergeFieldType function.
func AggFieldType(tps []*FieldType) *FieldType {
var currType FieldType
isMixedSign := false
for i, t := range tps {
if i == 0 && currType.Tp == mysql.TypeUnspecified {
currType = *t
continue
}
mtp := MergeFieldType(currType.Tp, t.Tp)
isMixedSign = isMixedSign || (mysql.HasUnsignedFlag(currType.Flag) != mysql.HasUnsignedFlag(t.Flag))
currType.Tp = mtp
currType.Flag = mergeTypeFlag(currType.Flag, t.Flag)
}
// integral promotion when tps contains signed and unsigned
if isMixedSign && IsTypeInteger(currType.Tp) {
bumpRange := false // indicate one of tps bump currType range
for _, t := range tps {
bumpRange = bumpRange || (mysql.HasUnsignedFlag(t.Flag) && (t.Tp == currType.Tp || t.Tp == mysql.TypeBit))
}
if bumpRange {
switch currType.Tp {
case mysql.TypeTiny:
currType.Tp = mysql.TypeShort
case mysql.TypeShort:
currType.Tp = mysql.TypeInt24
case mysql.TypeInt24:
currType.Tp = mysql.TypeLong
case mysql.TypeLong:
currType.Tp = mysql.TypeLonglong
case mysql.TypeLonglong:
currType.Tp = mysql.TypeNewDecimal
}
}
}

return &currType
}
Expand Down Expand Up @@ -310,10 +333,10 @@ func MergeFieldType(a byte, b byte) byte {
}

// mergeTypeFlag merges two MySQL type flag to a new one
// currently only NotNullFlag is checked
// todo more flag need to be checked, for example: UnsignedFlag
// currently only NotNullFlag and UnsignedFlag is checked
// todo more flag need to be checked
func mergeTypeFlag(a, b uint) uint {
return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag)
return a & (b&mysql.NotNullFlag | ^mysql.NotNullFlag) & (b&mysql.UnsignedFlag | ^mysql.UnsignedFlag)
}

func getFieldTypeIndex(tp byte) int {
Expand Down
38 changes: 38 additions & 0 deletions types/field_type_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -327,6 +327,44 @@ func (s *testFieldTypeSuite) TestAggFieldTypeForTypeFlag(c *C) {
c.Assert(aggTp.Flag, Equals, mysql.NotNullFlag)
}

func (s testFieldTypeSuite) TestAggFieldTypeForIntegralPromotion(c *C) {
fts := []*FieldType{
NewFieldType(mysql.TypeTiny),
NewFieldType(mysql.TypeShort),
NewFieldType(mysql.TypeInt24),
NewFieldType(mysql.TypeLong),
NewFieldType(mysql.TypeLonglong),
NewFieldType(mysql.TypeNewDecimal),
}

for i := 1; i < len(fts)-1; i++ {
tps := fts[i-1 : i+1]

tps[0].Flag = 0
tps[1].Flag = 0
aggTp := AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))

tps[0].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))

tps[0].Flag = mysql.UnsignedFlag
tps[1].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i].Tp)
c.Assert(aggTp.Flag, Equals, mysql.UnsignedFlag)

tps[0].Flag = 0
tps[1].Flag = mysql.UnsignedFlag
aggTp = AggFieldType(tps)
c.Assert(aggTp.Tp, Equals, fts[i+1].Tp)
c.Assert(aggTp.Flag, Equals, uint(0))
}
}

func (s *testFieldTypeSuite) TestAggregateEvalType(c *C) {
defer testleak.AfterTest(c)()
fts := []*FieldType{
Expand Down

0 comments on commit 1e73c51

Please sign in to comment.