diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index 23df587a40bd5..545abd497a2da 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -469,7 +469,8 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS } arrayVals := make([]any, 0, len(b.args)) - f := convertJSON2Tp(b.tp.ArrayType()) + ft := b.tp.ArrayType() + f := convertJSON2Tp(ft.EvalType()) if f == nil { return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAS-ing JSON to the target type") } @@ -486,7 +487,7 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS sc.TruncateAsWarning = originTruncateAsWarning }() for i := 0; i < val.GetElemCount(); i++ { - item, err := f(sc, val.ArrayGetElem(i)) + item, err := f(sc, val.ArrayGetElem(i), ft) if err != nil { return types.BinaryJSON{}, false, err } @@ -495,31 +496,31 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS return types.CreateBinaryJSON(arrayVals), false, nil } -func convertJSON2Tp(tp *types.FieldType) func(*stmtctx.StatementContext, types.BinaryJSON) (any, error) { - switch tp.EvalType() { +func convertJSON2Tp(eval types.EvalType) func(*stmtctx.StatementContext, types.BinaryJSON, *types.FieldType) (any, error) { + switch eval { case types.ETString: - return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) { + return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { if item.TypeCode != types.JSONTypeCodeString { return nil, ErrInvalidJSONForFuncIndex } return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc, false) } case types.ETInt: - return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) { + return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 { return nil, ErrInvalidJSONForFuncIndex } return types.ConvertJSONToInt(sc, item, mysql.HasUnsignedFlag(tp.GetFlag()), tp.GetType()) } case types.ETReal, types.ETDecimal: - return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) { + return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 && item.TypeCode != types.JSONTypeCodeFloat64 { return nil, ErrInvalidJSONForFuncIndex } return types.ConvertJSONToFloat(sc, item) } case types.ETDatetime: - return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) { + return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { if (tp.GetType() == mysql.TypeDatetime && item.TypeCode != types.JSONTypeCodeDatetime) || (tp.GetType() == mysql.TypeDate && item.TypeCode != types.JSONTypeCodeDate) { return nil, ErrInvalidJSONForFuncIndex } @@ -532,7 +533,7 @@ func convertJSON2Tp(tp *types.FieldType) func(*stmtctx.StatementContext, types.B return res, nil } case types.ETDuration: - return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) { + return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) { if item.TypeCode != types.JSONTypeCodeDuration { return nil, ErrInvalidJSONForFuncIndex }