diff --git a/expression/builtin.go b/expression/builtin.go index 7e615ba98d2e3..8260a8c802e3c 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -879,8 +879,8 @@ var funcs = map[string]functionClass{ ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus}, ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}}, ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}}, - ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth}, - ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity}, + ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth}, + ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity}, ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 2, 3}}, ast.Regexp: ®expFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}}, ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}}, diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 5f8eb05547ed0..c0758d40881e4 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -28,7 +28,7 @@ var ( _ functionClass = &logicAndFunctionClass{} _ functionClass = &logicOrFunctionClass{} _ functionClass = &logicXorFunctionClass{} - _ functionClass = &isTrueOpFunctionClass{} + _ functionClass = &isTrueOrFalseFunctionClass{} _ functionClass = &unaryOpFunctionClass{} _ functionClass = &unaryMinusFunctionClass{} _ functionClass = &isNullFunctionClass{} @@ -39,7 +39,12 @@ var ( _ builtinFunc = &builtinLogicAndSig{} _ builtinFunc = &builtinLogicOrSig{} _ builtinFunc = &builtinLogicXorSig{} - _ builtinFunc = &builtinIsTrueOpSig{} + _ builtinFunc = &builtinRealIsTrueSig{} + _ builtinFunc = &builtinDecimalIsTrueSig{} + _ builtinFunc = &builtinIntIsTrueSig{} + _ builtinFunc = &builtinRealIsFalseSig{} + _ builtinFunc = &builtinDecimalIsFalseSig{} + _ builtinFunc = &builtinIntIsFalseSig{} _ builtinFunc = &builtinUnaryOpSig{} _ builtinFunc = &builtinUnaryMinusIntSig{} _ builtinFunc = &builtinDecimalIsNullSig{} @@ -346,40 +351,141 @@ func (b *builtinRightShiftSig) evalInt(row []types.Datum) (int64, bool, error) { return int64(uint64(arg0) >> uint64(arg1)), false, nil } -type isTrueOpFunctionClass struct { +type isTrueOrFalseFunctionClass struct { baseFunctionClass - op opcode.Op } -func (c *isTrueOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) { - sig := &builtinIsTrueOpSig{newBaseBuiltinFunc(args, ctx), c.op} - return sig.setSelf(sig), errors.Trace(c.verifyArgs(args)) +func (c *isTrueOrFalseFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) { + if err := c.verifyArgs(args); err != nil { + return nil, errors.Trace(err) + } + + argTp := tpInt + switch args[0].GetTypeClass() { + case types.ClassReal: + argTp = tpReal + case types.ClassDecimal: + argTp = tpDecimal + } + bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, argTp) + if err != nil { + return nil, errors.Trace(err) + } + bf.tp.Flen = 1 + + var sig builtinFunc + switch c.op { + case opcode.IsTruth: + switch argTp { + case tpReal: + sig = &builtinRealIsTrueSig{baseIntBuiltinFunc{bf}} + case tpDecimal: + sig = &builtinDecimalIsTrueSig{baseIntBuiltinFunc{bf}} + case tpInt: + sig = &builtinIntIsTrueSig{baseIntBuiltinFunc{bf}} + } + case opcode.IsFalsity: + switch argTp { + case tpReal: + sig = &builtinRealIsFalseSig{baseIntBuiltinFunc{bf}} + case tpDecimal: + sig = &builtinDecimalIsFalseSig{baseIntBuiltinFunc{bf}} + case tpInt: + sig = &builtinIntIsFalseSig{baseIntBuiltinFunc{bf}} + } + } + return sig.setSelf(sig), nil } -type builtinIsTrueOpSig struct { - baseBuiltinFunc +type builtinRealIsTrueSig struct { + baseIntBuiltinFunc +} - op opcode.Op +func (b *builtinRealIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) { + input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx) + if err != nil { + return 0, true, errors.Trace(err) + } + if isNull || input == 0 { + return 0, false, nil + } + return 1, false, nil } -func (b *builtinIsTrueOpSig) eval(row []types.Datum) (d types.Datum, err error) { - args, err := b.evalArgs(row) +type builtinDecimalIsTrueSig struct { + baseIntBuiltinFunc +} + +func (b *builtinDecimalIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) { + input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx) if err != nil { - return types.Datum{}, errors.Trace(err) + return 0, true, errors.Trace(err) } - var boolVal bool - if !args[0].IsNull() { - iVal, err := args[0].ToBool(b.ctx.GetSessionVars().StmtCtx) - if err != nil { - return d, errors.Trace(err) - } - if (b.op == opcode.IsTruth && iVal == 1) || (b.op == opcode.IsFalsity && iVal == 0) { - boolVal = true - } + if isNull || input.IsZero() { + return 0, false, nil } - d.SetInt64(boolToInt64(boolVal)) - return + return 1, false, nil +} + +type builtinIntIsTrueSig struct { + baseIntBuiltinFunc +} + +func (b *builtinIntIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) { + input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx) + if err != nil { + return 0, true, errors.Trace(err) + } + if isNull || input == 0 { + return 0, false, nil + } + return 1, false, nil +} + +type builtinRealIsFalseSig struct { + baseIntBuiltinFunc +} + +func (b *builtinRealIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) { + input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx) + if err != nil { + return 0, true, errors.Trace(err) + } + if isNull || input != 0 { + return 0, false, nil + } + return 1, false, nil +} + +type builtinDecimalIsFalseSig struct { + baseIntBuiltinFunc +} + +func (b *builtinDecimalIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) { + input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx) + if err != nil { + return 0, true, errors.Trace(err) + } + if isNull || !input.IsZero() { + return 0, false, nil + } + return 1, false, nil +} + +type builtinIntIsFalseSig struct { + baseIntBuiltinFunc +} + +func (b *builtinIntIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) { + input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx) + if err != nil { + return 0, true, errors.Trace(err) + } + if isNull || input != 0 { + return 0, false, nil + } + return 1, false, nil } type bitNegFunctionClass struct { diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index e1442abc9d385..225d27b6cdb0d 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -20,6 +20,7 @@ import ( . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/util/testleak" + "github.com/pingcap/tidb/util/testutil" "github.com/pingcap/tidb/util/types" ) @@ -476,3 +477,77 @@ func (s *testEvaluatorSuite) TestUnaryNot(c *C) { c.Assert(err, IsNil) c.Assert(f.isDeterministic(), IsTrue) } + +func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) { + defer testleak.AfterTest(c)() + sc := s.ctx.GetSessionVars().StmtCtx + origin := sc.IgnoreTruncate + defer func() { + sc.IgnoreTruncate = origin + }() + sc.IgnoreTruncate = true + + testCases := []struct { + args []interface{} + isTrue interface{} + isFalse interface{} + }{ + { + args: []interface{}{-12}, + isTrue: 1, + isFalse: 0, + }, + { + args: []interface{}{12}, + isTrue: 1, + isFalse: 0, + }, + { + args: []interface{}{0}, + isTrue: 0, + isFalse: 1, + }, + { + args: []interface{}{float64(0)}, + isTrue: 0, + isFalse: 1, + }, + { + args: []interface{}{"aaa"}, + isTrue: 0, + isFalse: 1, + }, + { + args: []interface{}{""}, + isTrue: 0, + isFalse: 1, + }, + { + args: []interface{}{nil}, + isTrue: 0, + isFalse: 0, + }, + } + + for _, tc := range testCases { + isTrueSig, err := funcs[ast.IsTruth].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx) + c.Assert(err, IsNil) + c.Assert(isTrueSig, NotNil) + c.Assert(isTrueSig.isDeterministic(), IsTrue) + + isTrue, err := isTrueSig.eval(nil) + c.Assert(err, IsNil) + c.Assert(isTrue, testutil.DatumEquals, types.NewDatum(tc.isTrue)) + } + + for _, tc := range testCases { + isFalseSig, err := funcs[ast.IsFalsity].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx) + c.Assert(err, IsNil) + c.Assert(isFalseSig, NotNil) + c.Assert(isFalseSig.isDeterministic(), IsTrue) + + isFalse, err := isFalseSig.eval(nil) + c.Assert(err, IsNil) + c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse)) + } +} diff --git a/expression/integration_test.go b/expression/integration_test.go index 910f0e4998ee2..2255a21aa6b50 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -830,7 +830,7 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) { tk := testkit.NewTestKit(c, s.store) tk.MustExec("use test") - // for is true + // for is true && is false tk.MustExec("drop table if exists t") tk.MustExec("create table t (a int, b int, index idx_b (b))") tk.MustExec("insert t values (1, 1)") @@ -844,6 +844,11 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) { result.Check(nil) result = tk.MustQuery("select * from t where a is not true") result.Check(nil) + result = tk.MustQuery(`select 1 is true, 0 is true, null is true, "aaa" is true, "" is true, -12.00 is true, 0.0 is true, 0.0000001 is true;`) + result.Check(testkit.Rows("1 0 0 0 0 1 0 1")) + result = tk.MustQuery(`select 1 is false, 0 is false, null is false, "aaa" is false, "" is false, -12.00 is false, 0.0 is false, 0.0000001 is false;`) + result.Check(testkit.Rows("0 1 0 1 1 0 1 0")) + // for in result = tk.MustQuery("select * from t where b in (a)") result.Check(testkit.Rows("1 1", "2 2")) diff --git a/plan/typeinfer_test.go b/plan/typeinfer_test.go index 942e27ee8db98..4e01c3fdbd2f4 100644 --- a/plan/typeinfer_test.go +++ b/plan/typeinfer_test.go @@ -81,6 +81,7 @@ func (s *testPlanSuite) TestInferType(c *C) { tests = append(tests, s.createTestCase4EncryptionFuncs()...) tests = append(tests, s.createTestCase4CompareFuncs()...) tests = append(tests, s.createTestCase4Miscellaneous()...) + tests = append(tests, s.createTestCase4OpFuncs()...) for _, tt := range tests { ctx := testKit.Se.(context.Context) @@ -584,3 +585,29 @@ func (s *testPlanSuite) createTestCase4Miscellaneous() []typeInferTestCase { {"sleep(c_binary)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0}, } } + +func (s *testPlanSuite) createTestCase4OpFuncs() []typeInferTestCase { + return []typeInferTestCase{ + {"c_int is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_decimal is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_double is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_float is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_datetime is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_time is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_enum is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_text is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"18446 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"1844674.1 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + + {"c_int is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_decimal is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_double is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_float is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_datetime is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_time is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_enum is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"c_text is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"18446 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + {"1844674.1 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0}, + } +}