Skip to content

Commit

Permalink
expression, plan: rewrite builtin function: IS TRUE && IS FALSE (#4086)
Browse files Browse the repository at this point in the history
  • Loading branch information
zz-jason authored and XuHuaiyu committed Aug 11, 2017
1 parent 1f75c7f commit eb7eb43
Show file tree
Hide file tree
Showing 5 changed files with 240 additions and 27 deletions.
4 changes: 2 additions & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -879,8 +879,8 @@ var funcs = map[string]functionClass{
ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}},
ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 2, 3}},
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},
Expand Down
154 changes: 130 additions & 24 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var (
_ functionClass = &logicAndFunctionClass{}
_ functionClass = &logicOrFunctionClass{}
_ functionClass = &logicXorFunctionClass{}
_ functionClass = &isTrueOpFunctionClass{}
_ functionClass = &isTrueOrFalseFunctionClass{}
_ functionClass = &unaryOpFunctionClass{}
_ functionClass = &unaryMinusFunctionClass{}
_ functionClass = &isNullFunctionClass{}
Expand All @@ -39,7 +39,12 @@ var (
_ builtinFunc = &builtinLogicAndSig{}
_ builtinFunc = &builtinLogicOrSig{}
_ builtinFunc = &builtinLogicXorSig{}
_ builtinFunc = &builtinIsTrueOpSig{}
_ builtinFunc = &builtinRealIsTrueSig{}
_ builtinFunc = &builtinDecimalIsTrueSig{}
_ builtinFunc = &builtinIntIsTrueSig{}
_ builtinFunc = &builtinRealIsFalseSig{}
_ builtinFunc = &builtinDecimalIsFalseSig{}
_ builtinFunc = &builtinIntIsFalseSig{}
_ builtinFunc = &builtinUnaryOpSig{}
_ builtinFunc = &builtinUnaryMinusIntSig{}
_ builtinFunc = &builtinDecimalIsNullSig{}
Expand Down Expand Up @@ -346,40 +351,141 @@ func (b *builtinRightShiftSig) evalInt(row []types.Datum) (int64, bool, error) {
return int64(uint64(arg0) >> uint64(arg1)), false, nil
}

type isTrueOpFunctionClass struct {
type isTrueOrFalseFunctionClass struct {
baseFunctionClass

op opcode.Op
}

func (c *isTrueOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
sig := &builtinIsTrueOpSig{newBaseBuiltinFunc(args, ctx), c.op}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
func (c *isTrueOrFalseFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}

argTp := tpInt
switch args[0].GetTypeClass() {
case types.ClassReal:
argTp = tpReal
case types.ClassDecimal:
argTp = tpDecimal
}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, argTp)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 1

var sig builtinFunc
switch c.op {
case opcode.IsTruth:
switch argTp {
case tpReal:
sig = &builtinRealIsTrueSig{baseIntBuiltinFunc{bf}}
case tpDecimal:
sig = &builtinDecimalIsTrueSig{baseIntBuiltinFunc{bf}}
case tpInt:
sig = &builtinIntIsTrueSig{baseIntBuiltinFunc{bf}}
}
case opcode.IsFalsity:
switch argTp {
case tpReal:
sig = &builtinRealIsFalseSig{baseIntBuiltinFunc{bf}}
case tpDecimal:
sig = &builtinDecimalIsFalseSig{baseIntBuiltinFunc{bf}}
case tpInt:
sig = &builtinIntIsFalseSig{baseIntBuiltinFunc{bf}}
}
}
return sig.setSelf(sig), nil
}

type builtinIsTrueOpSig struct {
baseBuiltinFunc
type builtinRealIsTrueSig struct {
baseIntBuiltinFunc
}

op opcode.Op
func (b *builtinRealIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input == 0 {
return 0, false, nil
}
return 1, false, nil
}

func (b *builtinIsTrueOpSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
type builtinDecimalIsTrueSig struct {
baseIntBuiltinFunc
}

func (b *builtinDecimalIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return types.Datum{}, errors.Trace(err)
return 0, true, errors.Trace(err)
}
var boolVal bool
if !args[0].IsNull() {
iVal, err := args[0].ToBool(b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return d, errors.Trace(err)
}
if (b.op == opcode.IsTruth && iVal == 1) || (b.op == opcode.IsFalsity && iVal == 0) {
boolVal = true
}
if isNull || input.IsZero() {
return 0, false, nil
}
d.SetInt64(boolToInt64(boolVal))
return
return 1, false, nil
}

type builtinIntIsTrueSig struct {
baseIntBuiltinFunc
}

func (b *builtinIntIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input == 0 {
return 0, false, nil
}
return 1, false, nil
}

type builtinRealIsFalseSig struct {
baseIntBuiltinFunc
}

func (b *builtinRealIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input != 0 {
return 0, false, nil
}
return 1, false, nil
}

type builtinDecimalIsFalseSig struct {
baseIntBuiltinFunc
}

func (b *builtinDecimalIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || !input.IsZero() {
return 0, false, nil
}
return 1, false, nil
}

type builtinIntIsFalseSig struct {
baseIntBuiltinFunc
}

func (b *builtinIntIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input != 0 {
return 0, false, nil
}
return 1, false, nil
}

type bitNegFunctionClass struct {
Expand Down
75 changes: 75 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
"github.com/pingcap/tidb/util/types"
)

Expand Down Expand Up @@ -476,3 +477,77 @@ func (s *testEvaluatorSuite) TestUnaryNot(c *C) {
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
}

func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
defer testleak.AfterTest(c)()
sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreTruncate
defer func() {
sc.IgnoreTruncate = origin
}()
sc.IgnoreTruncate = true

testCases := []struct {
args []interface{}
isTrue interface{}
isFalse interface{}
}{
{
args: []interface{}{-12},
isTrue: 1,
isFalse: 0,
},
{
args: []interface{}{12},
isTrue: 1,
isFalse: 0,
},
{
args: []interface{}{0},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{float64(0)},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{"aaa"},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{""},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{nil},
isTrue: 0,
isFalse: 0,
},
}

for _, tc := range testCases {
isTrueSig, err := funcs[ast.IsTruth].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
c.Assert(err, IsNil)
c.Assert(isTrueSig, NotNil)
c.Assert(isTrueSig.isDeterministic(), IsTrue)

isTrue, err := isTrueSig.eval(nil)
c.Assert(err, IsNil)
c.Assert(isTrue, testutil.DatumEquals, types.NewDatum(tc.isTrue))
}

for _, tc := range testCases {
isFalseSig, err := funcs[ast.IsFalsity].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
c.Assert(err, IsNil)
c.Assert(isFalseSig, NotNil)
c.Assert(isFalseSig.isDeterministic(), IsTrue)

isFalse, err := isFalseSig.eval(nil)
c.Assert(err, IsNil)
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
}
}
7 changes: 6 additions & 1 deletion expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -830,7 +830,7 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

// for is true
// for is true && is false
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int, index idx_b (b))")
tk.MustExec("insert t values (1, 1)")
Expand All @@ -844,6 +844,11 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
result.Check(nil)
result = tk.MustQuery("select * from t where a is not true")
result.Check(nil)
result = tk.MustQuery(`select 1 is true, 0 is true, null is true, "aaa" is true, "" is true, -12.00 is true, 0.0 is true, 0.0000001 is true;`)
result.Check(testkit.Rows("1 0 0 0 0 1 0 1"))
result = tk.MustQuery(`select 1 is false, 0 is false, null is false, "aaa" is false, "" is false, -12.00 is false, 0.0 is false, 0.0000001 is false;`)
result.Check(testkit.Rows("0 1 0 1 1 0 1 0"))

// for in
result = tk.MustQuery("select * from t where b in (a)")
result.Check(testkit.Rows("1 1", "2 2"))
Expand Down
27 changes: 27 additions & 0 deletions plan/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ func (s *testPlanSuite) TestInferType(c *C) {
tests = append(tests, s.createTestCase4EncryptionFuncs()...)
tests = append(tests, s.createTestCase4CompareFuncs()...)
tests = append(tests, s.createTestCase4Miscellaneous()...)
tests = append(tests, s.createTestCase4OpFuncs()...)

for _, tt := range tests {
ctx := testKit.Se.(context.Context)
Expand Down Expand Up @@ -584,3 +585,29 @@ func (s *testPlanSuite) createTestCase4Miscellaneous() []typeInferTestCase {
{"sleep(c_binary)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
}
}

func (s *testPlanSuite) createTestCase4OpFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"c_int is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_decimal is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_double is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_float is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_datetime is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_time is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_enum is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_text is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"18446 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"1844674.1 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},

{"c_int is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_decimal is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_double is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_float is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_datetime is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_time is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_enum is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_text is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"18446 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"1844674.1 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
}
}

0 comments on commit eb7eb43

Please sign in to comment.