From f3494e91f8570200275bf20b11974ffd63c912d9 Mon Sep 17 00:00:00 2001 From: Kenan Yao Date: Wed, 21 Nov 2018 14:07:05 +0800 Subject: [PATCH] expression: support JSON return type in `case` expression (#8355) --- expression/builtin_control.go | 37 ++++++++++++++++++++++++++++++ expression/builtin_control_test.go | 2 ++ expression/distsql_builtin.go | 2 ++ expression/evaluator_test.go | 2 ++ 4 files changed, 43 insertions(+) diff --git a/expression/builtin_control.go b/expression/builtin_control.go index 15ad0d736ef40..d5282ccc5a41e 100644 --- a/expression/builtin_control.go +++ b/expression/builtin_control.go @@ -205,6 +205,9 @@ func (c *caseWhenFunctionClass) getFunction(ctx sessionctx.Context, args []Expre case types.ETDuration: sig = &builtinCaseWhenDurationSig{bf} sig.setPbCode(tipb.ScalarFuncSig_CaseWhenDuration) + case types.ETJson: + sig = &builtinCaseWhenJSONSig{bf} + sig.setPbCode(tipb.ScalarFuncSig_CaseWhenJson) } return sig, nil } @@ -425,6 +428,40 @@ func (b *builtinCaseWhenDurationSig) evalDuration(row chunk.Row) (ret types.Dura return ret, true, nil } +type builtinCaseWhenJSONSig struct { + baseBuiltinFunc +} + +func (b *builtinCaseWhenJSONSig) Clone() builtinFunc { + newSig := &builtinCaseWhenJSONSig{} + newSig.cloneFrom(&b.baseBuiltinFunc) + return newSig +} + +// evalJSON evals a builtinCaseWhenJSONSig. +// See https://dev.mysql.com/doc/refman/5.7/en/case.html +func (b *builtinCaseWhenJSONSig) evalJSON(row chunk.Row) (ret json.BinaryJSON, isNull bool, err error) { + var condition int64 + args, l := b.getArgs(), len(b.getArgs()) + for i := 0; i < l-1; i += 2 { + condition, isNull, err = args[i].EvalInt(b.ctx, row) + if err != nil { + return + } + if isNull || condition == 0 { + continue + } + return args[i+1].EvalJSON(b.ctx, row) + } + // when clause(condition, result) -> args[i], args[i+1]; (i >= 0 && i+1 < l-1) + // else clause -> args[l-1] + // If case clause has else clause, l%2 == 1. + if l%2 == 1 { + return args[l-1].EvalJSON(b.ctx, row) + } + return ret, true, nil +} + type ifFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_control_test.go b/expression/builtin_control_test.go index 37101ec326f88..23a59ebdc22b8 100644 --- a/expression/builtin_control_test.go +++ b/expression/builtin_control_test.go @@ -38,6 +38,8 @@ func (s *testEvaluatorSuite) TestCaseWhen(c *C) { {[]interface{}{nil, 1, nil, 2, 3}, 3}, {[]interface{}{false, 1, nil, 2, 3}, 3}, {[]interface{}{nil, 1, false, 2, 3}, 3}, + {[]interface{}{1, jsonInt.GetMysqlJSON(), nil}, 3}, + {[]interface{}{0, jsonInt.GetMysqlJSON(), nil}, nil}, } fc := funcs[ast.Case] for _, t := range tbl { diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 45e4387bfdc01..c1c61a8db1183 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -355,6 +355,8 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti case tipb.ScalarFuncSig_CoalesceInt: f = &builtinCoalesceIntSig{base} + case tipb.ScalarFuncSig_CaseWhenJson: + f = &builtinCaseWhenJSONSig{base} case tipb.ScalarFuncSig_CaseWhenDecimal: f = &builtinCaseWhenDecimalSig{base} case tipb.ScalarFuncSig_CaseWhenDuration: diff --git a/expression/evaluator_test.go b/expression/evaluator_test.go index b938d5eb6b88f..4d1560d69210c 100644 --- a/expression/evaluator_test.go +++ b/expression/evaluator_test.go @@ -103,6 +103,8 @@ func (s *testEvaluatorSuite) kindToFieldType(kind byte) types.FieldType { ft.Collate = charset.CollationBin case types.KindMysqlBit: ft.Tp = mysql.TypeBit + case types.KindMysqlJSON: + ft.Tp = mysql.TypeJSON } return ft }