From 305cf424997144f38c268112055fc446d30b7938 Mon Sep 17 00:00:00 2001 From: xzhangxian1008 Date: Tue, 10 Sep 2024 16:27:56 +0800 Subject: [PATCH] expression: fix the return type of coalesce when arg type is DATE (#55969) close pingcap/tidb#46475 --- expression/builtin_compare.go | 72 ++------ expression/builtin_compare_test.go | 9 + expression/builtin_control.go | 254 ++++++++++++++++++++++------- expression/expr_to_pb_test.go | 2 +- expression/typeinfer_test.go | 14 +- 5 files changed, 221 insertions(+), 130 deletions(-) diff --git a/expression/builtin_compare.go b/expression/builtin_compare.go index 74d7c60043df5..37b64b774e2cd 100644 --- a/expression/builtin_compare.go +++ b/expression/builtin_compare.go @@ -118,18 +118,19 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre return nil, err } - fieldTps := make([]*types.FieldType, 0, len(args)) + flag := uint(0) for _, arg := range args { - fieldTps = append(fieldTps, arg.GetType()) + flag |= arg.GetType().GetFlag() & mysql.NotNullFlag } - // Use the aggregated field type as retType. - resultFieldType := types.AggFieldType(fieldTps) - var tempType uint - resultEvalType := types.AggregateEvalType(fieldTps, &tempType) - resultFieldType.SetFlag(tempType) - retEvalTp := resultFieldType.EvalType() + resultFieldType, err := InferType4ControlFuncs(ctx, c.funcName, args...) + if err != nil { + return nil, err + } + + resultFieldType.AddFlag(flag) + retEvalTp := resultFieldType.EvalType() fieldEvalTps := make([]types.EvalType, 0, len(args)) for range args { fieldEvalTps = append(fieldEvalTps, retEvalTp) @@ -140,60 +141,7 @@ func (c *coalesceFunctionClass) getFunction(ctx sessionctx.Context, args []Expre return nil, err } - bf.tp.AddFlag(resultFieldType.GetFlag()) - resultFieldType.SetFlen(0) - resultFieldType.SetDecimal(types.UnspecifiedLength) - - // Set retType to BINARY(0) if all arguments are of type NULL. - if resultFieldType.GetType() == mysql.TypeNull { - types.SetBinChsClnFlag(bf.tp) - resultFieldType.SetFlen(0) - resultFieldType.SetDecimal(0) - } else { - maxIntLen := 0 - maxFlen := 0 - - // Find the max length of field in `maxFlen`, - // and max integer-part length in `maxIntLen`. - for _, argTp := range fieldTps { - if argTp.GetDecimal() > resultFieldType.GetDecimal() { - resultFieldType.SetDecimalUnderLimit(argTp.GetDecimal()) - } - argIntLen := argTp.GetFlen() - if argTp.GetDecimal() > 0 { - argIntLen -= argTp.GetDecimal() + 1 - } - - // Reduce the sign bit if it is a signed integer/decimal - if !mysql.HasUnsignedFlag(argTp.GetFlag()) { - argIntLen-- - } - if argIntLen > maxIntLen { - maxIntLen = argIntLen - } - if argTp.GetFlen() > maxFlen || argTp.GetFlen() == types.UnspecifiedLength { - maxFlen = argTp.GetFlen() - } - } - // For integer, field length = maxIntLen + (1/0 for sign bit) - // For decimal, field length = maxIntLen + maxDecimal + (1/0 for sign bit) - if resultEvalType == types.ETInt || resultEvalType == types.ETDecimal { - resultFieldType.SetFlenUnderLimit(maxIntLen + resultFieldType.GetDecimal()) - if resultFieldType.GetDecimal() > 0 { - resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1) - } - if !mysql.HasUnsignedFlag(resultFieldType.GetFlag()) { - resultFieldType.SetFlenUnderLimit(resultFieldType.GetFlen() + 1) - } - bf.tp = resultFieldType - } else { - bf.tp.SetFlen(maxFlen) - } - // Set the field length to maxFlen for other types. - if bf.tp.GetFlen() > mysql.MaxDecimalWidth { - bf.tp.SetFlen(mysql.MaxDecimalWidth) - } - } + bf.tp = resultFieldType switch retEvalTp { case types.ETInt: diff --git a/expression/builtin_compare_test.go b/expression/builtin_compare_test.go index 78229b77b1f60..59a48d281d3cb 100644 --- a/expression/builtin_compare_test.go +++ b/expression/builtin_compare_test.go @@ -406,3 +406,12 @@ func TestGreatestLeastFunc(t *testing.T) { _, err = funcs[ast.Least].getFunction(ctx, []Expression{NewZero(), NewOne()}) require.NoError(t, err) } + +func TestIssue46475(t *testing.T) { + ctx := createContext(t) + args := []interface{}{nil, dt, nil} + + f, err := newFunctionForTest(ctx, ast.Coalesce, primitiveValsToConstants(ctx, args)...) + require.NoError(t, err) + require.Equal(t, f.GetType().GetType(), mysql.TypeDate) +} diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 643e6cf89e29a..ac2d21d0044b9 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -15,6 +15,7 @@ package expression import ( + "github.com/pingcap/tidb/parser/ast" "github.com/pingcap/tidb/parser/mysql" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/types" @@ -61,47 +62,92 @@ func maxlen(lhsFlen, rhsFlen int) int { return mathutil.Max(lhsFlen, rhsFlen) } -// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, LEAD and LAG. -func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp Expression) (*types.FieldType, error) { - lhs, rhs := lexp.GetType(), rexp.GetType() - resultFieldType := &types.FieldType{} - if lhs.GetType() == mysql.TypeNull { - *resultFieldType = *rhs - // If any of arg is NULL, result type need unset NotNullFlag. - tempFlag := resultFieldType.GetFlag() - types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false) - resultFieldType.SetFlag(tempFlag) - // If both arguments are NULL, make resulting type BINARY(0). - if rhs.GetType() == mysql.TypeNull { - resultFieldType.SetType(mysql.TypeString) - resultFieldType.SetFlen(0) - resultFieldType.SetDecimal(0) - types.SetBinChsClnFlag(resultFieldType) +func setFlenFromArgs(evalType types.EvalType, resultFieldType *types.FieldType, argTps ...*types.FieldType) { + if evalType == types.ETDecimal || evalType == types.ETInt { + maxArgFlen := 0 + for i := range argTps { + flagLen := 0 + if !mysql.HasUnsignedFlag(argTps[i].GetFlag()) { + flagLen = 1 + } + flen := argTps[i].GetFlen() - flagLen + if argTps[i].GetDecimal() != types.UnspecifiedLength { + flen -= argTps[i].GetDecimal() + } + maxArgFlen = maxlen(maxArgFlen, flen) } - } else if rhs.GetType() == mysql.TypeNull { - *resultFieldType = *lhs - tempFlag := resultFieldType.GetFlag() - types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false) - resultFieldType.SetFlag(tempFlag) + // For a decimal field, the `length` and `flen` are not the same. + // `length` only holds the binary data, while `flen` represents the number of digits required to display the field, including the negative sign. + // In the current implementation of TiDB, `flen` and `length` are treated as the same, so the `length` of a decimal may be inconsistent with that of MySQL. + resultFlen := maxArgFlen + resultFieldType.GetDecimal() + 1 // account for -1 len fields + resultFieldType.SetFlenUnderLimit(resultFlen) + } else if evalType == types.ETString { + maxLen := 0 + for i := range argTps { + argFlen := argTps[i].GetFlen() + if argFlen == types.UnspecifiedLength { + resultFieldType.SetFlen(types.UnspecifiedLength) + return + } + maxLen = maxlen(argFlen, maxLen) + } + resultFieldType.SetFlen(maxLen) } else { - resultFieldType = types.AggFieldType([]*types.FieldType{lhs, rhs}) - var tempFlag uint - evalType := types.AggregateEvalType([]*types.FieldType{lhs, rhs}, &tempFlag) - resultFieldType.SetFlag(tempFlag) - if evalType == types.ETInt { - resultFieldType.SetDecimal(0) - } else { - if lhs.GetDecimal() == types.UnspecifiedLength || rhs.GetDecimal() == types.UnspecifiedLength { + maxLen := 0 + for i := range argTps { + maxLen = maxlen(argTps[i].GetFlen(), maxLen) + } + resultFieldType.SetFlen(maxLen) + } +} + +func setDecimalFromArgs(evalType types.EvalType, resultFieldType *types.FieldType, argTps ...*types.FieldType) { + if evalType == types.ETInt { + resultFieldType.SetDecimal(0) + } else { + maxDecimal := 0 + for i := range argTps { + if argTps[i].GetDecimal() == types.UnspecifiedLength { resultFieldType.SetDecimal(types.UnspecifiedLength) - } else { - resultFieldType.SetDecimalUnderLimit(mathutil.Max(lhs.GetDecimal(), rhs.GetDecimal())) + return } + maxDecimal = mathutil.Max(argTps[i].GetDecimal(), maxDecimal) + } + resultFieldType.SetDecimalUnderLimit(maxDecimal) + } +} + +// NonBinaryStr means the arg is a string but not binary string +func hasNonBinaryStr(args []*types.FieldType) bool { + for _, arg := range args { + if types.IsNonBinaryStr(arg) { + return true + } + } + return false +} + +func hasBinaryStr(args []*types.FieldType) bool { + for _, arg := range args { + if types.IsBinaryStr(arg) { + return true } + } + return false +} +func addCollateAndCharsetAndFlagFromArgs(ctx sessionctx.Context, funcName string, evalType types.EvalType, resultFieldType *types.FieldType, args ...Expression) error { + switch funcName { + case ast.If, ast.Ifnull, ast.WindowFuncLead, ast.WindowFuncLag: + if len(args) != 2 { + panic("unexpected length of args for if/ifnull/lead/lag") + } + lexp, rexp := args[0], args[1] + lhs, rhs := lexp.GetType(), rexp.GetType() if types.IsNonBinaryStr(lhs) && !types.IsBinaryStr(rhs) { ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp) if err != nil { - return nil, err + return err } resultFieldType.SetCollate(ec.Collation) resultFieldType.SetCharset(ec.Charset) @@ -112,7 +158,7 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp } else if types.IsNonBinaryStr(rhs) && !types.IsBinaryStr(lhs) { ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, lexp, rexp) if err != nil { - return nil, err + return err } resultFieldType.SetCollate(ec.Collation) resultFieldType.SetCharset(ec.Charset) @@ -127,49 +173,131 @@ func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, lexp, rexp resultFieldType.SetCollate(mysql.DefaultCollationName) resultFieldType.SetFlag(0) } - if evalType == types.ETDecimal || evalType == types.ETInt { - lhsUnsignedFlag, rhsUnsignedFlag := mysql.HasUnsignedFlag(lhs.GetFlag()), mysql.HasUnsignedFlag(rhs.GetFlag()) - lhsFlagLen, rhsFlagLen := 0, 0 - if !lhsUnsignedFlag { - lhsFlagLen = 1 - } - if !rhsUnsignedFlag { - rhsFlagLen = 1 + case ast.Case: + if len(args) == 0 { + panic("unexpected length 0 of args for casewhen") + } + ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...) + if err != nil { + return err + } + resultFieldType.SetCollate(ec.Collation) + resultFieldType.SetCharset(ec.Charset) + for i := range args { + if mysql.HasBinaryFlag(args[i].GetType().GetFlag()) || !types.IsNonBinaryStr(args[i].GetType()) { + resultFieldType.AddFlag(mysql.BinaryFlag) + break } - lhsFlen := lhs.GetFlen() - lhsFlagLen - rhsFlen := rhs.GetFlen() - rhsFlagLen - if lhs.GetDecimal() != types.UnspecifiedLength { - lhsFlen -= lhs.GetDecimal() + } + case ast.Coalesce: // TODO ast.Case and ast.Coalesce should be merged into the same branch + argTypes := make([]*types.FieldType, 0) + for _, arg := range args { + argTypes = append(argTypes, arg.GetType()) + } + + nonBinaryStrExist := hasNonBinaryStr(argTypes) + binaryStrExist := hasBinaryStr(argTypes) + if !binaryStrExist && nonBinaryStrExist { + ec, err := CheckAndDeriveCollationFromExprs(ctx, funcName, evalType, args...) + if err != nil { + return err } - if lhs.GetDecimal() != types.UnspecifiedLength { - rhsFlen -= rhs.GetDecimal() + resultFieldType.SetCollate(ec.Collation) + resultFieldType.SetCharset(ec.Charset) + resultFieldType.SetFlag(0) + + // hasNonStringType means that there is a type that is not string + hasNonStringType := false + for _, argType := range argTypes { + if !types.IsString(argType.GetType()) { + hasNonStringType = true + break + } } - flen := maxlen(lhsFlen, rhsFlen) + resultFieldType.GetDecimal() + 1 // account for -1 len fields - resultFieldType.SetFlenUnderLimit(flen) - } else if evalType == types.ETString { - lhsLen, rhsLen := lhs.GetFlen(), rhs.GetFlen() - if lhsLen != types.UnspecifiedLength && rhsLen != types.UnspecifiedLength { - resultFieldType.SetFlen(mathutil.Max(lhsLen, rhsLen)) + + if hasNonStringType { + resultFieldType.AddFlag(mysql.BinaryFlag) } + } else if binaryStrExist || !evalType.IsStringKind() { + types.SetBinChsClnFlag(resultFieldType) } else { - resultFieldType.SetFlen(maxlen(lhs.GetFlen(), rhs.GetFlen())) + resultFieldType.SetCharset(mysql.DefaultCharset) + resultFieldType.SetCollate(mysql.DefaultCollationName) + resultFieldType.SetFlag(0) } + default: + panic("unexpected function: " + funcName) } - // Fix decimal for int and string. - resultEvalType := resultFieldType.EvalType() - if resultEvalType == types.ETInt { + return nil +} + +// InferType4ControlFuncs infer result type for builtin IF, IFNULL, NULLIF, CASEWHEN, COALESCE, LEAD and LAG. +func InferType4ControlFuncs(ctx sessionctx.Context, funcName string, args ...Expression) (*types.FieldType, error) { + argsNum := len(args) + if argsNum == 0 { + panic("unexpected length 0 of args") + } + nullFields := make([]*types.FieldType, 0, argsNum) + notNullFields := make([]*types.FieldType, 0, argsNum) + for i := range args { + if args[i].GetType().GetType() == mysql.TypeNull { + nullFields = append(nullFields, args[i].GetType()) + } else { + notNullFields = append(notNullFields, args[i].GetType()) + } + } + resultFieldType := &types.FieldType{} + if len(nullFields) == argsNum { // all field is TypeNull + *resultFieldType = *nullFields[0] + // If any of arg is NULL, result type need unset NotNullFlag. + tempFlag := resultFieldType.GetFlag() + types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false) + resultFieldType.SetFlag(tempFlag) + + resultFieldType.SetType(mysql.TypeNull) + resultFieldType.SetFlen(0) resultFieldType.SetDecimal(0) - if resultFieldType.GetType() == mysql.TypeEnum || resultFieldType.GetType() == mysql.TypeSet { - resultFieldType.SetType(mysql.TypeLonglong) + types.SetBinChsClnFlag(resultFieldType) + } else { + if len(notNullFields) == 1 { + *resultFieldType = *notNullFields[0] + } else { + resultFieldType = types.AggFieldType(notNullFields) + var tempFlag uint + evalType := types.AggregateEvalType(notNullFields, &tempFlag) + resultFieldType.SetFlag(tempFlag) + setDecimalFromArgs(evalType, resultFieldType, notNullFields...) + err := addCollateAndCharsetAndFlagFromArgs(ctx, funcName, evalType, resultFieldType, args...) + if err != nil { + return nil, err + } + setFlenFromArgs(evalType, resultFieldType, notNullFields...) } - } else if resultEvalType == types.ETString { - if lhs.GetType() != mysql.TypeNull || rhs.GetType() != mysql.TypeNull { + + // If any of arg is NULL, result type need unset NotNullFlag. + if len(nullFields) > 0 { + tempFlag := resultFieldType.GetFlag() + types.SetTypeFlag(&tempFlag, mysql.NotNullFlag, false) + resultFieldType.SetFlag(tempFlag) + } + + resultEvalType := resultFieldType.EvalType() + // fix decimal for int and string. + if resultEvalType == types.ETInt { + resultFieldType.SetDecimal(0) + } else if resultEvalType == types.ETString { resultFieldType.SetDecimal(types.UnspecifiedLength) } + // fix type for enum and set if resultFieldType.GetType() == mysql.TypeEnum || resultFieldType.GetType() == mysql.TypeSet { - resultFieldType.SetType(mysql.TypeVarchar) + switch resultEvalType { + case types.ETInt: + resultFieldType.SetType(mysql.TypeLonglong) + case types.ETString: + resultFieldType.SetType(mysql.TypeVarchar) + } } - } else if resultFieldType.GetType() == mysql.TypeDatetime { + // fix flen for datetime types.TryToFixFlenOfDatetime(resultFieldType) } return resultFieldType, nil diff --git a/expression/expr_to_pb_test.go b/expression/expr_to_pb_test.go index 1025f3c7fdcb1..21d658e97bc10 100644 --- a/expression/expr_to_pb_test.go +++ b/expression/expr_to_pb_test.go @@ -493,7 +493,7 @@ func TestOtherFunc2Pb(t *testing.T) { pbExprs, err := ExpressionsToPBList(sc, otherFuncs, client) require.NoError(t, err) jsons := map[string]string{ - ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":128,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}", + ast.Coalesce: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}],\"sig\":4201,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}", ast.IsNull: "{\"tp\":10000,\"children\":[{\"tp\":201,\"val\":\"gAAAAAAAAAE=\",\"sig\":0,\"field_type\":{\"tp\":3,\"flag\":0,\"flen\":11,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}],\"sig\":3116,\"field_type\":{\"tp\":8,\"flag\":524417,\"flen\":1,\"decimal\":0,\"collate\":-63,\"charset\":\"binary\"},\"has_distinct\":false}", } for i, pbExpr := range pbExprs { diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 28f71099e9c46..09a79c4b8f353 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -1037,10 +1037,16 @@ func (s *InferTypeSuite) createTestCase4EncryptionFuncs() []typeInferTestCase { func (s *InferTypeSuite) createTestCase4CompareFuncs() []typeInferTestCase { return []typeInferTestCase{ - {"coalesce(c_int_d, 1)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, - {"coalesce(NULL, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, - {"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 15, 3}, - {"coalesce(c_int_d, c_datetime)", mysql.TypeVarString, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength}, + {"coalesce(c_int_d, c_int_d)", mysql.TypeLong, charset.CharsetBin, mysql.BinaryFlag, 11, 0}, + {"coalesce(c_int_d, c_decimal)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 14, 3}, + {"coalesce(c_int_d, c_char)", mysql.TypeString, charset.CharsetUTF8MB4, mysql.BinaryFlag, 20, types.UnspecifiedLength}, + {"coalesce(c_int_d, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, + {"coalesce(c_char, c_binary)", mysql.TypeString, charset.CharsetBin, mysql.BinaryFlag, 20, types.UnspecifiedLength}, + {"coalesce(null, null)", mysql.TypeNull, charset.CharsetBin, mysql.BinaryFlag, 0, 0}, + {"coalesce(c_double_d, c_timestamp_d)", mysql.TypeVarchar, charset.CharsetUTF8MB4, 0, 22, types.UnspecifiedLength}, + {"coalesce(c_json, c_decimal)", mysql.TypeLongBlob, charset.CharsetUTF8MB4, 0, math.MaxUint32, types.UnspecifiedLength}, + {"coalesce(c_time, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp + 3 + 1, 3}, + {"coalesce(c_time_d, c_date)", mysql.TypeDatetime, charset.CharsetUTF8MB4, 0, mysql.MaxDatetimeWidthNoFsp, 0}, {"isnull(c_int_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0}, {"isnull(c_bigint_d )", mysql.TypeLonglong, charset.CharsetBin, mysql.NotNullFlag | mysql.BinaryFlag | mysql.IsBooleanFlag, 1, 0},