Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix the return type of coalesce when arg type is DATE (#48032) #48426

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -118,16 +118,26 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

fieldTps := make([]*types.FieldType, 0, len(args))
flag := uint(0)
for _, arg := range args {
fieldTps = append(fieldTps, arg.GetType())
flag |= arg.GetType().GetFlag() & mysql.NotNullFlag
}

<<<<<<< HEAD:expression/builtin_compare.go
// Use the aggregated field type as retType.
resultFieldType := types.AggFieldType(fieldTps)
resultEvalType := types.AggregateEvalType(fieldTps, &resultFieldType.Flag)
retEvalTp := resultFieldType.EvalType()
=======
resultFieldType, err := InferType4ControlFuncs(ctx, c.funcName, args...)
if err != nil {
return nil, err
}
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_compare.go

resultFieldType.AddFlag(flag)

retEvalTp := resultFieldType.EvalType()
fieldEvalTps := make([]types.EvalType, 0, len(args))
for range args {
fieldEvalTps = append(fieldEvalTps, retEvalTp)
Expand All @@ -143,6 +153,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

<<<<<<< HEAD:expression/builtin_compare.go
bf.tp.Flag |= resultFieldType.Flag
resultFieldType.Flen, resultFieldType.Decimal = 0, types.UnspecifiedLength

Expand Down Expand Up @@ -194,6 +205,9 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
bf.tp.Flen = mysql.MaxDecimalWidth
}
}
=======
bf.tp = resultFieldType
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_compare.go

switch retEvalTp {
case types.ETInt:
Expand Down Expand Up @@ -1253,6 +1267,7 @@ func (b *builtinIntervalRealSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull {
return -1, false, nil
}

var idx int
if b.hasNullable {
idx, err = b.linearSearch(arg0, b.args[1:], row)
Expand Down
26 changes: 26 additions & 0 deletions expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -389,3 +389,29 @@ func TestGreatestLeastFunc(t *testing.T) {
_, err = funcs[ast.Least].getFunction(ctx, []Expression{NewZero(), NewOne()})
require.NoError(t, err)
}
<<<<<<< HEAD:expression/builtin_compare_test.go
=======

func TestRefineArgsWithCastEnum(t *testing.T) {
ctx := createContext(t)
zeroUintConst := primitiveValsToConstants(ctx, []interface{}{uint64(0)})[0]
enumType := types.NewFieldTypeBuilder().SetType(mysql.TypeEnum).SetElems([]string{"1", "2", "3"}).AddFlag(mysql.EnumSetAsIntFlag).Build()
enumCol := &Column{RetType: &enumType}

f := funcs[ast.EQ].(*compareFunctionClass)
require.NotNil(t, f)

args := f.refineArgsByUnsignedFlag(ctx, []Expression{zeroUintConst, enumCol})
require.Equal(t, zeroUintConst, args[0])
require.Equal(t, enumCol, args[1])
}

func TestIssue46475(t *testing.T) {
ctx := createContext(t)
args := []interface{}{nil, dt, nil}

f, err := newFunctionForTest(ctx, ast.Coalesce, primitiveValsToConstants(ctx, args)...)
require.NoError(t, err)
require.Equal(t, f.GetType().GetType(), mysql.TypeDate)
}
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_compare_test.go
116 changes: 116 additions & 0 deletions expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,36 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
}
}

<<<<<<< HEAD:expression/builtin_control.go
=======
// NonBinaryStr means the arg is a string but not binary string
func hasNonBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsNonBinaryStr(arg) {
return true
}
}
return false
}

func hasBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsBinaryStr(arg) {
return true
}
}
return false
}

func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string, evalType types.EvalType, resultFieldType *types.FieldType, args ...Expression) error {
switch funcName {
case ast.If, ast.Ifnull, ast.WindowFuncLead, ast.WindowFuncLag:
if len(args) != 2 {
panic("unexpected length of args for if/ifnull/lead/lag")
}
lexp, rexp := args[0], args[1]
lhs, rhs := lexp.GetType(), rexp.GetType()
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_control.go
if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp)
if err != nil {
Expand Down Expand Up @@ -123,6 +153,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
if !lhsUnsignedFlag {
lhsFlagLen = 1
}
<<<<<<< HEAD:expression/builtin_control.go
if !rhsUnsignedFlag {
rhsFlagLen = 1
}
Expand All @@ -136,16 +167,101 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp
}
flen := maxlen(lhsFlen, rhsFlen) + resultFieldType.Decimal + 1 // account for -1 len fields
resultFieldType.Flen = mathutil.Min(flen, mysql.MaxDecimalWidth) // make sure it doesn't overflow
=======
}
case ast.Coalesce: // TODO ast.Case and ast.Coalesce should be merged into the same branch
argTypes := make([]*types.FieldType, 0)
for _, arg := range args {
argTypes = append(argTypes, arg.GetType())
}

nonBinaryStrExist := hasNonBinaryStr(argTypes)
binaryStrExist := hasBinaryStr(argTypes)
if !binaryStrExist && nonBinaryStrExist {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...)
if err != nil {
return err
}
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
resultFieldType.SetFlag(0)

// hasNonStringType means that there is a type that is not string
hasNonStringType := false
for _, argType := range argTypes {
if !types.IsString(argType.GetType()) {
hasNonStringType = true
break
}
}

if hasNonStringType {
resultFieldType.AddFlag(mysql.BinaryFlag)
}
} else if binaryStrExist || !evalType.IsStringKind() {
types.SetBinChsClnFlag(resultFieldType)
} else {
resultFieldType.SetCharset(mysql.DefaultCharset)
resultFieldType.SetCollate(mysql.DefaultCollationName)
resultFieldType.SetFlag(0)
}
default:
panic("unexpected function: " + funcName)
}
return nil
}

// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, COALESCE, LEAD and LAG.
func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Expression) (*types.FieldType, error) {
argsNum := len(args)
if argsNum == 0 {
panic("unexpected length 0 of args")
}
nullFields := make([]*types.FieldType, 0, argsNum)
notNullFields := make([]*types.FieldType, 0, argsNum)
for i := range args {
if args[i].GetType().GetType() == mysql.TypeNull {
nullFields = append(nullFields, args[i].GetType())
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_control.go
} else {
resultFieldType.Flen = maxlen(lhs.Flen, rhs.Flen)
}
}
<<<<<<< HEAD:expression/builtin_control.go
// Fix decimal for int and string.
resultEvalType := resultFieldType.EvalType()
if resultEvalType == types.ETInt {
resultFieldType.Decimal = 0
if resultFieldType.Tp == mysql.TypeEnum || resultFieldType.Tp == mysql.TypeSet {
resultFieldType.Tp = mysql.TypeLonglong
=======
resultFieldType := &types.FieldType{}
if len(nullFields) == argsNum { // all field is TypeNull
*resultFieldType = *nullFields[0]
// If any of arg is NULL, result type need unset NotNullFlag.
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)

resultFieldType.SetType(mysql.TypeNull)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
types.SetBinChsClnFlag(resultFieldType)
} else {
if len(notNullFields) == 1 {
*resultFieldType = *notNullFields[0]
} else {
resultFieldType = types.AggFieldType(notNullFields)
var tempFlag uint
evalType := types.AggregateEvalType(notNullFields, &tempFlag)
resultFieldType.SetFlag(tempFlag)
setDecimalFromArgs(evalType, resultFieldType, notNullFields...)
err := addCollateAndCharsetAndFlagFromArgs(ctx, funcName, evalType, resultFieldType, args...)
if err != nil {
return nil, err
}
setFlenFromArgs(evalType, resultFieldType, notNullFields...)
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/builtin_control.go
}
} else if resultEvalType == types.ETString {
if lhs.Tp != mysql.TypeNull || rhs.Tp != mysql.TypeNull {
Expand Down
5 changes: 5 additions & 0 deletions expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,8 +493,13 @@ func TestOtherFunc2Pb(t *testing.T) {
pbExprs, err := ExpressionsToPBList(sc, otherFuncs, client)
require.NoError(t, err)
jsons := map[string]string{
<<<<<<< HEAD:expression/expr_to_pb_test.go
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":-1,\"decimal\":-1,\"collate\":63,\"charset\":\"binary\"},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":128,\"flen\":0,\"decimal\":-1,\"collate\":63,\"charset\":\"binary\"},\"has_distinct\":false}",
ast.IsNull: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":-1,\"decimal\":-1,\"collate\":63,\"charset\":\"binary\"},\"has_distinct\":false}],\"sig\":3116,\"field_type\":{\"tp\":8,\"flag\":524416,\"flen\":1,\"decimal\":0,\"collate\":63,\"charset\":\"binary\"},\"has_distinct\":false}",
=======
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
ast.IsNull: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":3116,\"field_type\":{\"tp\":8,\"flag\":524417,\"flen\":1,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
>>>>>>> 7cb7af71792 (expression: fix the return type of `coalesce` when arg type is `DATE` (#48032)):pkg/expression/expr_to_pb_test.go
}
for i, pbExpr := range pbExprs {
js, err := json.Marshal(pbExpr)
Expand Down
14 changes: 10 additions & 4 deletions expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1017,10 +1017,16 @@ func (s *InferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCase {

func (s *InferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"coalesce(c_int_d, 1)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(NULL, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 15, 3},
{"coalesce(c_int_d, c_datetime)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_int_d, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"coalesce(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_int_d, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_char, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(null, null)", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, 0},
{"coalesce(c_double_d, c_timestamp_d)", mysql.TypeVarchar, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_json, c_decimal)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, math.MaxUint32, types.UnspecifiedLength},
{"coalesce(c_time, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp + 3 + 1, 3},
{"coalesce(c_time_d, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp, 0},

{"isnull(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
{"isnull(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
Expand Down