Skip to content

Commit

Permalink
expression: fix the return type of coalesce when arg type is DATE (
Browse files Browse the repository at this point in the history
  • Loading branch information
ti-chi-bot authored Nov 8, 2023
1 parent 3ebecae commit faf4171
Show file tree
Hide file tree
Showing 5 changed files with 89 additions and 70 deletions.
73 changes: 11 additions & 62 deletions pkg/expression/builtin_compare.go
Original file line number Diff line number Diff line change
Expand Up @@ -119,18 +119,19 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

fieldTps := make([]*types.FieldType, 0, len(args))
flag := uint(0)
for _, arg := range args {
fieldTps = append(fieldTps, arg.GetType())
flag |= arg.GetType().GetFlag() & mysql.NotNullFlag
}

// Use the aggregated field type as retType.
resultFieldType := types.AggFieldType(fieldTps)
var tempType uint
resultEvalType := types.AggregateEvalType(fieldTps, &tempType)
resultFieldType.SetFlag(tempType)
retEvalTp := resultFieldType.EvalType()
resultFieldType, err := InferType4ControlFuncs(ctx, c.funcName, args...)
if err != nil {
return nil, err
}

resultFieldType.AddFlag(flag)

retEvalTp := resultFieldType.EvalType()
fieldEvalTps := make([]types.EvalType, 0, len(args))
for range args {
fieldEvalTps = append(fieldEvalTps, retEvalTp)
Expand All @@ -141,60 +142,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre
return nil, err
}

bf.tp.AddFlag(resultFieldType.GetFlag())
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(types.UnspecifiedLength)

// Set retType to BINARY(0) if all arguments are of type NULL.
if resultFieldType.GetType() == mysql.TypeNull {
types.SetBinChsClnFlag(bf.tp)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
} else {
maxIntLen := 0
maxFlen := 0

// Find the max length of field in `maxFlen`,
// and max integer-part length in `maxIntLen`.
for _, argTp := range fieldTps {
if argTp.GetDecimal() > resultFieldType.GetDecimal() {
resultFieldType.SetDecimalUnderLimit(argTp.GetDecimal())
}
argIntLen := argTp.GetFlen()
if argTp.GetDecimal() > 0 {
argIntLen -= argTp.GetDecimal() + 1
}

// Reduce the sign bit if it is a signed integer/decimal
if !mysql.HasUnsignedFlag(argTp.GetFlag()) {
argIntLen--
}
if argIntLen > maxIntLen {
maxIntLen = argIntLen
}
if argTp.GetFlen() > maxFlen || argTp.GetFlen() == types.UnspecifiedLength {
maxFlen = argTp.GetFlen()
}
}
// For integer, field length = maxIntLen + (1/0 for sign bit)
// For decimal, field length = maxIntLen + maxDecimal + (1/0 for sign bit)
if resultEvalType == types.ETInt || resultEvalType == types.ETDecimal {
resultFieldType.SetFlenUnderLimit(maxIntLen + resultFieldType.GetDecimal())
if resultFieldType.GetDecimal() > 0 {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
if !mysql.HasUnsignedFlag(resultFieldType.GetFlag()) {
resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1)
}
bf.tp = resultFieldType
} else {
bf.tp.SetFlen(maxFlen)
}
// Set the field length to maxFlen for other types.
if bf.tp.GetFlen() > mysql.MaxDecimalWidth {
bf.tp.SetFlen(mysql.MaxDecimalWidth)
}
}
bf.tp = resultFieldType

switch retEvalTp {
case types.ETInt:
Expand Down Expand Up @@ -1252,6 +1200,7 @@ func (b *builtinIntervalRealSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull {
return -1, false, nil
}

var idx int
if b.hasNullable {
idx, err = b.linearSearch(arg0, b.args[1:], row)
Expand Down
9 changes: 9 additions & 0 deletions pkg/expression/builtin_compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,12 @@ func TestRefineArgsWithCastEnum(t *testing.T) {
require.Equal(t, zeroUintConst, args[0])
require.Equal(t, enumCol, args[1])
}

func TestIssue46475(t *testing.T) {
ctx := createContext(t)
args := []interface{}{nil, dt, nil}

f, err := newFunctionForTest(ctx, ast.Coalesce, primitiveValsToConstants(ctx, args)...)
require.NoError(t, err)
require.Equal(t, f.GetType().GetType(), mysql.TypeDate)
}
61 changes: 58 additions & 3 deletions pkg/expression/builtin_control.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,25 @@ func setDecimalFromArgs(evalType types.EvalType, resultFieldType *types.FieldTyp
}
}

// NonBinaryStr means the arg is a string but not binary string
func hasNonBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsNonBinaryStr(arg) {
return true
}
}
return false
}

func hasBinaryStr(args []*types.FieldType) bool {
for _, arg := range args {
if types.IsBinaryStr(arg) {
return true
}
}
return false
}

func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string, evalType types.EvalType, resultFieldType *types.FieldType, args ...Expression) error {
switch funcName {
case ast.If, ast.Ifnull, ast.WindowFuncLead, ast.WindowFuncLag:
Expand Down Expand Up @@ -170,13 +189,49 @@ func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string
break
}
}
case ast.Coalesce: // TODO ast.Case and ast.Coalesce should be merged into the same branch
argTypes := make([]*types.FieldType, 0)
for _, arg := range args {
argTypes = append(argTypes, arg.GetType())
}

nonBinaryStrExist := hasNonBinaryStr(argTypes)
binaryStrExist := hasBinaryStr(argTypes)
if !binaryStrExist && nonBinaryStrExist {
ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...)
if err != nil {
return err
}
resultFieldType.SetCollate(ec.Collation)
resultFieldType.SetCharset(ec.Charset)
resultFieldType.SetFlag(0)

// hasNonStringType means that there is a type that is not string
hasNonStringType := false
for _, argType := range argTypes {
if !types.IsString(argType.GetType()) {
hasNonStringType = true
break
}
}

if hasNonStringType {
resultFieldType.AddFlag(mysql.BinaryFlag)
}
} else if binaryStrExist || !evalType.IsStringKind() {
types.SetBinChsClnFlag(resultFieldType)
} else {
resultFieldType.SetCharset(mysql.DefaultCharset)
resultFieldType.SetCollate(mysql.DefaultCollationName)
resultFieldType.SetFlag(0)
}
default:
panic("unexpected function: " + funcName)
}
return nil
}

// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, LEAD and LAG.
// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, COALESCE, LEAD and LAG.
func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Expression) (*types.FieldType, error) {
argsNum := len(args)
if argsNum == 0 {
Expand All @@ -198,8 +253,8 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Exp
tempFlag := resultFieldType.GetFlag()
types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false)
resultFieldType.SetFlag(tempFlag)
// If both arguments are NULL, make resulting type BINARY(0).
resultFieldType.SetType(mysql.TypeString)

resultFieldType.SetType(mysql.TypeNull)
resultFieldType.SetFlen(0)
resultFieldType.SetDecimal(0)
types.SetBinChsClnFlag(resultFieldType)
Expand Down
2 changes: 1 addition & 1 deletion pkg/expression/expr_to_pb_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -493,7 +493,7 @@ func TestOtherFunc2Pb(t *testing.T) {
pbExprs, err := ExpressionsToPBList(sc, otherFuncs, client)
require.NoError(t, err)
jsons := map[string]string{
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":128,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
ast.IsNull: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}],\"sig\":3116,\"field_type\":{\"tp\":8,\"flag\":524417,\"flen\":1,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\",\"array\":false},\"has_distinct\":false}",
}
for i, pbExpr := range pbExprs {
Expand Down
14 changes: 10 additions & 4 deletions pkg/expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1042,10 +1042,16 @@ func (s *InferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCase {

func (s *InferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"coalesce(c_int_d, 1)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(NULL, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 15, 3},
{"coalesce(c_int_d, c_datetime)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_int_d, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0},
{"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3},
{"coalesce(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_int_d, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(c_char, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength},
{"coalesce(null, null)", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, 0},
{"coalesce(c_double_d, c_timestamp_d)", mysql.TypeVarchar, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength},
{"coalesce(c_json, c_decimal)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, math.MaxUint32, types.UnspecifiedLength},
{"coalesce(c_time, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp + 3 + 1, 3},
{"coalesce(c_time_d, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp, 0},

{"isnull(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
{"isnull(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},
Expand Down

0 comments on commit faf4171

Please sign in to comment.