Skip to content

Commit

Permalink
expression: fix nil pointer dereference for case expression (#30479) (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-srebot authored Jan 6, 2022
1 parent 8889c7c commit ac75188
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 17 deletions.
16 changes: 2 additions & 14 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ package expression

import (
"github.com/cznic/mathutil"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
Expand Down Expand Up @@ -172,7 +171,7 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
l := len(args)
// Fill in each 'THEN' clause parameter type.
fieldTps := make([]*types.FieldType, 0, (l+1)/2)
decimal, flen, isBinaryStr, isBinaryFlag := args[1].GetType().Decimal, 0, false, false
decimal, flen, isBinaryFlag := args[1].GetType().Decimal, 0, false
for i := 1; i < l; i += 2 {
fieldTps = append(fieldTps, args[i].GetType())
decimal = mathutil.Max(decimal, args[i].GetType().Decimal)
Expand All @@ -181,7 +180,6 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
} else if flen != -1 {
flen = mathutil.Max(flen, args[i].GetType().Flen)
}
isBinaryStr = isBinaryStr || types.IsBinaryStr(args[i].GetType())
isBinaryFlag = isBinaryFlag || !types.IsNonBinaryStr(args[i].GetType())
}
if l%2 == 1 {
Expand All @@ -192,7 +190,6 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
} else if flen != -1 {
flen = mathutil.Max(flen, args[l-1].GetType().Flen)
}
isBinaryStr = isBinaryStr || types.IsBinaryStr(args[l-1].GetType())
isBinaryFlag = isBinaryFlag || !types.IsNonBinaryStr(args[l-1].GetType())
}

Expand All @@ -207,16 +204,6 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
}
fieldTp.Decimal, fieldTp.Flen = decimal, flen
types.TryToFixFlenOfDatetime(fieldTp)
if fieldTp.EvalType().IsStringKind() && !isBinaryStr {
fieldTp.Charset, fieldTp.Collate = DeriveCollationFromExprs(ctx, args...)
if fieldTp.Charset == charset.CharsetBin && fieldTp.Collate == charset.CollationBin {
// When args are Json and Numerical type(eg. Int), the fieldTp is String.
// Both their charset/collation is binary, but the String need a default charset/collation.
fieldTp.Charset, fieldTp.Collate = charset.GetDefaultCharsetAndCollate()
}
} else {
fieldTp.Charset, fieldTp.Collate = charset.CharsetBin, charset.CollationBin
}
if isBinaryFlag {
fieldTp.Flag |= mysql.BinaryFlag
}
Expand All @@ -239,6 +226,7 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
if err != nil {
return nil, err
}
fieldTp.Charset, fieldTp.Collate = bf.tp.Charset, bf.tp.Collate
bf.tp = fieldTp
if fieldTp.Tp == mysql.TypeEnum || fieldTp.Tp == mysql.TypeSet {
switch tp {
Expand Down
23 changes: 22 additions & 1 deletion expression/collation.go
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,28 @@ func deriveCollation(ctx sessionctx.Context, funcName string, args []Expression,
return ec, nil
case ast.Case:
// FIXME: case function aggregate collation is not correct.
return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, args...)
// We should only aggregate the `then expression`,
// case ... when ... expression will be rewritten to:
// args: eq scalar func(args: value, condition1), result1,
// eq scalar func(args: value, condition2), result2,
// ...
// else clause
// Or
// args: condition1, result1,
// condition2, result2,
// ...
// else clause
// so, arguments with odd index are the `then expression`.
if argTps[1] == types.ETString {
fieldArgs := make([]Expression, 0)
for i := 1; i < len(args); i += 2 {
fieldArgs = append(fieldArgs, args[i])
}
if len(args)%2 == 1 {
fieldArgs = append(fieldArgs, args[len(args)-1])
}
return CheckAndDeriveCollationFromExprs(ctx, funcName, retType, fieldArgs...)
}
case ast.Database, ast.User, ast.CurrentUser, ast.Version, ast.CurrentRole, ast.TiDBVersion:
chs, coll := charset.GetDefaultCharsetAndCollate()
return &ExprCollation{CoercibilitySysconst, UNICODE, chs, coll}, nil
Expand Down
12 changes: 12 additions & 0 deletions expression/integration_serial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1498,6 +1498,18 @@ func TestIssue26662(t *testing.T) {
Check(testkit.Rows())
}

func TestIssue30245(t *testing.T) {
collate.SetNewCollationEnabledForTest(true)
defer collate.SetNewCollationEnabledForTest(false)
store, clean := testkit.CreateMockStore(t)
defer clean()

tk := testkit.NewTestKit(t, store)
tk.MustGetErrCode("select case 1 when 1 then 'a' collate utf8mb4_unicode_ci else 'b' collate utf8mb4_general_ci end", mysql.ErrCantAggregate2collations)
tk.MustGetErrCode("select case when 1 then 'a' collate utf8mb4_unicode_ci when 2 then 'b' collate utf8mb4_general_ci end", mysql.ErrCantAggregate2collations)
tk.MustGetErrCode("select case 1 when 1 then 'a' collate utf8mb4_unicode_ci when 2 then 'b' collate utf8mb4_general_ci else 'b' collate utf8mb4_bin end", mysql.ErrCantAggregate3collations)
}

func TestCollationForBinaryLiteral(t *testing.T) {
store, clean := testkit.CreateMockStore(t)
defer clean()
Expand Down
4 changes: 2 additions & 2 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -840,8 +840,8 @@ func (s *InferTypeSuite) createTestCase4ControlFuncs() []typeInferTestCase {
{"case when c_int_d > 1 then c_double_d else c_bchar end", mysql.TypeString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 22, types.UnspecifiedLength},
{"case when c_int_d > 2 then c_double_d when c_int_d < 1 then c_decimal else c_double_d end", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, 22, 3},
{"case when c_double_d > 2 then c_decimal else 1 end", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 6, 3},
{"case when c_time is not null then c_time else c_date end", mysql.TypeDatetime, charset.CharsetUTF8MB4, mysql.BinaryFlag, mysql.MaxDatetimeWidthNoFsp + 3 + 1, 3},
{"case when c_time_d is not null then c_time_d else c_date end", mysql.TypeDatetime, charset.CharsetUTF8MB4, mysql.BinaryFlag, mysql.MaxDatetimeWidthNoFsp, 0},
{"case when c_time is not null then c_time else c_date end", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthNoFsp + 3 + 1, 3},
{"case when c_time_d is not null then c_time_d else c_date end", mysql.TypeDatetime, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxDatetimeWidthNoFsp, 0},
{"case when null then null else null end", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, types.UnspecifiedLength},
}
}
Expand Down

0 comments on commit ac75188

Please sign in to comment.