From 8f777a98c3726c75ff104fc186b0a4a61e1d9053 Mon Sep 17 00:00:00 2001 From: sduzh Date: Tue, 1 Oct 2019 23:24:02 +0800 Subject: [PATCH] expression: Fix incorrect result of logical operators (pingcap#11199) --- expression/builtin.go | 4 +- expression/builtin_op.go | 72 +++++++++++++++++++++++----- expression/builtin_op_test.go | 89 +++++++++++++++++++++++++++++++++++ expression/distsql_builtin.go | 12 ++--- expression/expression.go | 22 +++++++++ 5 files changed, 179 insertions(+), 20 deletions(-) diff --git a/expression/builtin.go b/expression/builtin.go index ca8efb48e9634..fb2755298e5b2 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -660,8 +660,8 @@ var funcs = map[string]functionClass{ ast.Xor: &bitXorFunctionClass{baseFunctionClass{ast.Xor, 2, 2}}, ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}}, ast.In: &inFunctionClass{baseFunctionClass{ast.In, 2, -1}}, - ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth}, - ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity}, + ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, false}, + ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity, false}, ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 3, 3}}, ast.Regexp: ®expFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}}, ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}}, diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 096f817cc5ab5..0654f7fefd401 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -67,6 +67,15 @@ func (c *logicAndFunctionClass) getFunction(ctx sessionctx.Context, args []Expre if err != nil { return nil, err } + args[0], err = wrapWithIsTrue(ctx, true, args[0]) + if err != nil { + return nil, errors.Trace(err) + } + args[1], err = wrapWithIsTrue(ctx, true, args[1]) + if err != nil { + return nil, errors.Trace(err) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) sig := &builtinLogicAndSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LogicalAnd) @@ -108,6 +117,15 @@ func (c *logicOrFunctionClass) getFunction(ctx sessionctx.Context, args []Expres if err != nil { return nil, err } + args[0], err = wrapWithIsTrue(ctx, true, args[0]) + if err != nil { + return nil, errors.Trace(err) + } + args[1], err = wrapWithIsTrue(ctx, true, args[1]) + if err != nil { + return nil, errors.Trace(err) + } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) bf.tp.Flen = 1 sig := &builtinLogicOrSig{bf} @@ -155,6 +173,7 @@ func (c *logicXorFunctionClass) getFunction(ctx sessionctx.Context, args []Expre if err != nil { return nil, err } + bf := newBaseBuiltinFuncWithTp(ctx, args, types.ETInt, types.ETInt, types.ETInt) sig := &builtinLogicXorSig{bf} sig.setPbCode(tipb.ScalarFuncSig_LogicalXor) @@ -378,6 +397,11 @@ func (b *builtinRightShiftSig) evalInt(row chunk.Row) (int64, bool, error) { type isTrueOrFalseFunctionClass struct { baseFunctionClass op opcode.Op + + // keepNull indicates how this function treats a null input parameter. + // If keepNull is true and the input parameter is null, the function will return null. + // If keepNull is false, the null input parameter will be cast to 0. + keepNull bool } func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) { @@ -400,13 +424,13 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] case opcode.IsTruth: switch argTp { case types.ETReal: - sig = &builtinRealIsTrueSig{bf} + sig = &builtinRealIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_RealIsTrue) case types.ETDecimal: - sig = &builtinDecimalIsTrueSig{bf} + sig = &builtinDecimalIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_DecimalIsTrue) case types.ETInt: - sig = &builtinIntIsTrueSig{bf} + sig = &builtinIntIsTrueSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_IntIsTrue) default: return nil, errors.Errorf("unexpected types.EvalType %v", argTp) @@ -414,13 +438,13 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] case opcode.IsFalsity: switch argTp { case types.ETReal: - sig = &builtinRealIsFalseSig{bf} + sig = &builtinRealIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_RealIsFalse) case types.ETDecimal: - sig = &builtinDecimalIsFalseSig{bf} + sig = &builtinDecimalIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_DecimalIsFalse) case types.ETInt: - sig = &builtinIntIsFalseSig{bf} + sig = &builtinIntIsFalseSig{bf, c.keepNull} sig.setPbCode(tipb.ScalarFuncSig_IntIsFalse) default: return nil, errors.Errorf("unexpected types.EvalType %v", argTp) @@ -431,10 +455,11 @@ func (c *isTrueOrFalseFunctionClass) getFunction(ctx sessionctx.Context, args [] type builtinRealIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinRealIsTrueSig) Clone() builtinFunc { - newSig := &builtinRealIsTrueSig{} + newSig := &builtinRealIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -444,6 +469,9 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input == 0 { return 0, false, nil } @@ -452,10 +480,11 @@ func (b *builtinRealIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinDecimalIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinDecimalIsTrueSig) Clone() builtinFunc { - newSig := &builtinDecimalIsTrueSig{} + newSig := &builtinDecimalIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -465,6 +494,9 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input.IsZero() { return 0, false, nil } @@ -473,10 +505,11 @@ func (b *builtinDecimalIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinIntIsTrueSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinIntIsTrueSig) Clone() builtinFunc { - newSig := &builtinIntIsTrueSig{} + newSig := &builtinIntIsTrueSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -486,6 +519,9 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input == 0 { return 0, false, nil } @@ -494,10 +530,11 @@ func (b *builtinIntIsTrueSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinRealIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinRealIsFalseSig) Clone() builtinFunc { - newSig := &builtinRealIsFalseSig{} + newSig := &builtinRealIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -507,6 +544,9 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input != 0 { return 0, false, nil } @@ -515,10 +555,11 @@ func (b *builtinRealIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinDecimalIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinDecimalIsFalseSig) Clone() builtinFunc { - newSig := &builtinDecimalIsFalseSig{} + newSig := &builtinDecimalIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -528,6 +569,9 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || !input.IsZero() { return 0, false, nil } @@ -536,10 +580,11 @@ func (b *builtinDecimalIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { type builtinIntIsFalseSig struct { baseBuiltinFunc + keepNull bool } func (b *builtinIntIsFalseSig) Clone() builtinFunc { - newSig := &builtinIntIsFalseSig{} + newSig := &builtinIntIsFalseSig{keepNull: b.keepNull} newSig.cloneFrom(&b.baseBuiltinFunc) return newSig } @@ -549,6 +594,9 @@ func (b *builtinIntIsFalseSig) evalInt(row chunk.Row) (int64, bool, error) { if err != nil { return 0, true, err } + if b.keepNull && isNull { + return 0, true, nil + } if isNull || input != 0 { return 0, false, nil } diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index 99fb13d76be24..923d106c32c7e 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -86,11 +86,21 @@ func (s *testEvaluatorSuite) TestLogicAnd(c *C) { {[]interface{}{0, 1}, 0, false, false}, {[]interface{}{0, 0}, 0, false, false}, {[]interface{}{2, -1}, 1, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, {[]interface{}{"a", "1"}, 0, false, false}, + {[]interface{}{"1a", "0"}, 0, false, false}, {[]interface{}{"1a", "1"}, 1, false, false}, {[]interface{}{0, nil}, 0, false, false}, {[]interface{}{nil, 0}, 0, false, false}, {[]interface{}{nil, 1}, 0, true, false}, + {[]interface{}{0.001, 0}, 0, false, false}, + {[]interface{}{0.001, 1}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, false, false}, + {[]interface{}{nil, 0.001}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false}, {[]interface{}{errors.New("must error"), 1}, 0, false, true}, } @@ -300,11 +310,25 @@ func (s *testEvaluatorSuite) TestLogicOr(c *C) { {[]interface{}{0, 1}, 1, false, false}, {[]interface{}{0, 0}, 0, false, false}, {[]interface{}{2, -1}, 1, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, {[]interface{}{"a", "1"}, 1, false, false}, + {[]interface{}{"1a", "0"}, 1, false, false}, {[]interface{}{"1a", "1"}, 1, false, false}, + {[]interface{}{"0.0a", 0}, 0, false, false}, + {[]interface{}{"0.0001a", 0}, 1, false, false}, {[]interface{}{1, nil}, 1, false, false}, {[]interface{}{nil, 1}, 1, false, false}, {[]interface{}{nil, 0}, 0, true, false}, + {[]interface{}{0.000, 0}, 0, false, false}, + {[]interface{}{0.001, 0}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, true, false}, + {[]interface{}{nil, 0.001}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), 0}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 1, false, false}, {[]interface{}{errors.New("must error"), 1}, 0, false, true}, } @@ -559,3 +583,68 @@ func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) { c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse)) } } + +func (s *testEvaluatorSuite) TestLogicXor(c *C) { + defer testleak.AfterTest(c)() + + sc := s.ctx.GetSessionVars().StmtCtx + origin := sc.IgnoreTruncate + defer func() { + sc.IgnoreTruncate = origin + }() + sc.IgnoreTruncate = true + + cases := []struct { + args []interface{} + expected int64 + isNil bool + getErr bool + }{ + {[]interface{}{1, 1}, 0, false, false}, + {[]interface{}{1, 0}, 1, false, false}, + {[]interface{}{0, 1}, 1, false, false}, + {[]interface{}{0, 0}, 0, false, false}, + {[]interface{}{2, -1}, 0, false, false}, + {[]interface{}{"a", "0"}, 0, false, false}, + {[]interface{}{"a", "1"}, 1, false, false}, + {[]interface{}{"1a", "0"}, 1, false, false}, + {[]interface{}{"1a", "1"}, 0, false, false}, + {[]interface{}{0, nil}, 0, true, false}, + {[]interface{}{nil, 0}, 0, true, false}, + {[]interface{}{nil, 1}, 0, true, false}, + {[]interface{}{0.5000, 0.4999}, 1, false, false}, + {[]interface{}{0.5000, 1.0}, 0, false, false}, + {[]interface{}{0.4999, 1.0}, 1, false, false}, + {[]interface{}{nil, 0.000}, 0, true, false}, + {[]interface{}{nil, 0.001}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 0.00001}, 0, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), 1}, 1, false, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000000"), nil}, 0, true, false}, + {[]interface{}{types.NewDecFromStringForTest("0.000001"), nil}, 0, true, false}, + + {[]interface{}{errors.New("must error"), 1}, 0, false, true}, + } + + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.LogicXor, s.primitiveValsToConstants(t.args)...) + c.Assert(err, IsNil) + d, err := f.Eval(chunk.Row{}) + if t.getErr { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + if t.isNil { + c.Assert(d.Kind(), Equals, types.KindNull) + } else { + c.Assert(d.GetInt64(), Equals, t.expected) + } + } + } + + // Test incorrect parameter count. + _, err := newFunctionForTest(s.ctx, ast.LogicXor, Zero) + c.Assert(err, NotNil) + + _, err = funcs[ast.LogicXor].getFunction(s.ctx, []Expression{Zero, Zero}) + c.Assert(err, IsNil) +} diff --git a/expression/distsql_builtin.go b/expression/distsql_builtin.go index 4f24bebdb17b7..e3a2735b5d5f5 100644 --- a/expression/distsql_builtin.go +++ b/expression/distsql_builtin.go @@ -376,17 +376,17 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti f = &builtinCaseWhenIntSig{base} case tipb.ScalarFuncSig_IntIsFalse: - f = &builtinIntIsFalseSig{base} + f = &builtinIntIsFalseSig{base, false} case tipb.ScalarFuncSig_RealIsFalse: - f = &builtinRealIsFalseSig{base} + f = &builtinRealIsFalseSig{base, false} case tipb.ScalarFuncSig_DecimalIsFalse: - f = &builtinDecimalIsFalseSig{base} + f = &builtinDecimalIsFalseSig{base, false} case tipb.ScalarFuncSig_IntIsTrue: - f = &builtinIntIsTrueSig{base} + f = &builtinIntIsTrueSig{base, false} case tipb.ScalarFuncSig_RealIsTrue: - f = &builtinRealIsTrueSig{base} + f = &builtinRealIsTrueSig{base, false} case tipb.ScalarFuncSig_DecimalIsTrue: - f = &builtinDecimalIsTrueSig{base} + f = &builtinDecimalIsTrueSig{base, false} case tipb.ScalarFuncSig_IfNullReal: f = &builtinIfNullRealSig{base} diff --git a/expression/expression.go b/expression/expression.go index f2228ac03769b..d29d3281096a4 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/parser/ast" "github.com/pingcap/parser/model" "github.com/pingcap/parser/mysql" + "github.com/pingcap/parser/opcode" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/stmtctx" @@ -559,3 +560,24 @@ func CheckExprPushFlash(exprs []Expression) (exprPush, remain []Expression) { } return } + +// wrapWithIsTrue wraps `arg` with istrue function if the return type of expr is not +// type int, otherwise, returns `arg` directly. +// The `keepNull` controls what the istrue function will return when `arg` is null: +// 1. keepNull is true and arg is null, the istrue function returns null. +// 2. keepNull is false and arg is null, the istrue function returns 0. +func wrapWithIsTrue(ctx sessionctx.Context, keepNull bool, arg Expression) (Expression, error) { + if arg.GetType().EvalType() == types.ETInt { + return arg, nil + } + fc := &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth, keepNull} + f, err := fc.getFunction(ctx, []Expression{arg}) + if err != nil { + return nil, err + } + return &ScalarFunction{ + FuncName: model.NewCIStr(fmt.Sprintf("sig_%T", f)), + Function: f, + RetType: f.getRetTp(), + }, nil +}