Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
xiongjiwei committed Dec 26, 2022
1 parent 7e67a11 commit d6aac40
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,8 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS
}

arrayVals := make([]any, 0, len(b.args))
f := convertJSON2Tp(b.tp.ArrayType())
ft := b.tp.ArrayType()
f := convertJSON2Tp(ft.EvalType())
if f == nil {
return types.BinaryJSON{}, false, ErrNotSupportedYet.GenWithStackByArgs("CAS-ing JSON to the target type")
}
Expand All @@ -486,7 +487,7 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS
sc.TruncateAsWarning = originTruncateAsWarning
}()
for i := 0; i < val.GetElemCount(); i++ {
item, err := f(sc, val.ArrayGetElem(i))
item, err := f(sc, val.ArrayGetElem(i), ft)
if err != nil {
return types.BinaryJSON{}, false, err
}
Expand All @@ -495,31 +496,31 @@ func (b *castJSONAsArrayFunctionSig) evalJSON(row chunk.Row) (res types.BinaryJS
return types.CreateBinaryJSON(arrayVals), false, nil
}

func convertJSON2Tp(tp *types.FieldType) func(*stmtctx.StatementContext, types.BinaryJSON) (any, error) {
switch tp.EvalType() {
func convertJSON2Tp(eval types.EvalType) func(*stmtctx.StatementContext, types.BinaryJSON, *types.FieldType) (any, error) {
switch eval {
case types.ETString:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
if item.TypeCode != types.JSONTypeCodeString {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ProduceStrWithSpecifiedTp(string(item.GetString()), tp, sc, false)
}
case types.ETInt:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ConvertJSONToInt(sc, item, mysql.HasUnsignedFlag(tp.GetFlag()), tp.GetType())
}
case types.ETReal, types.ETDecimal:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
if item.TypeCode != types.JSONTypeCodeInt64 && item.TypeCode != types.JSONTypeCodeUint64 && item.TypeCode != types.JSONTypeCodeFloat64 {
return nil, ErrInvalidJSONForFuncIndex
}
return types.ConvertJSONToFloat(sc, item)
}
case types.ETDatetime:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
if (tp.GetType() == mysql.TypeDatetime && item.TypeCode != types.JSONTypeCodeDatetime) || (tp.GetType() == mysql.TypeDate && item.TypeCode != types.JSONTypeCodeDate) {
return nil, ErrInvalidJSONForFuncIndex
}
Expand All @@ -532,7 +533,7 @@ func convertJSON2Tp(tp *types.FieldType) func(*stmtctx.StatementContext, types.B
return res, nil
}
case types.ETDuration:
return func(sc *stmtctx.StatementContext, item types.BinaryJSON) (any, error) {
return func(sc *stmtctx.StatementContext, item types.BinaryJSON, tp *types.FieldType) (any, error) {
if item.TypeCode != types.JSONTypeCodeDuration {
return nil, ErrInvalidJSONForFuncIndex
}
Expand Down

0 comments on commit d6aac40

Please sign in to comment.