Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression, plan: rewrite builtin function: IS TRUE && IS FALSE #4086

Merged
merged 18 commits into from
Aug 11, 2017
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,8 @@ var funcs = map[string]functionClass{
ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}},
ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 2, 3}},
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},
Expand Down
60 changes: 36 additions & 24 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var (
_ functionClass = &logicAndFunctionClass{}
_ functionClass = &logicOrFunctionClass{}
_ functionClass = &logicXorFunctionClass{}
_ functionClass = &isTrueOpFunctionClass{}
_ functionClass = &isTrueOrFalseFunctionClass{}
_ functionClass = &unaryOpFunctionClass{}
_ functionClass = &unaryMinusFunctionClass{}
_ functionClass = &isNullFunctionClass{}
Expand All @@ -39,7 +39,8 @@ var (
_ builtinFunc = &builtinLogicAndSig{}
_ builtinFunc = &builtinLogicOrSig{}
_ builtinFunc = &builtinLogicXorSig{}
_ builtinFunc = &builtinIsTrueOpSig{}
_ builtinFunc = &builtinIsTrueSig{}
_ builtinFunc = &builtinIsFalseSig{}
_ builtinFunc = &builtinUnaryOpSig{}
_ builtinFunc = &builtinUnaryMinusIntSig{}
_ builtinFunc = &builtinIsNullSig{}
Expand Down Expand Up @@ -341,40 +342,51 @@ func (b *builtinRightShiftSig) evalInt(row []types.Datum) (int64, bool, error) {
return int64(uint64(arg0) >> uint64(arg1)), false, nil
}

type isTrueOpFunctionClass struct {
type isTrueOrFalseFunctionClass struct {
baseFunctionClass

op opcode.Op
}

func (c *isTrueOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
sig := &builtinIsTrueOpSig{newBaseBuiltinFunc(args, ctx), c.op}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
func (c *isTrueOrFalseFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, tpInt)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 1
if c.op == opcode.IsTruth {
sig := &builtinIsTrueSig{baseIntBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}
sig := &builtinIsFalseSig{baseIntBuiltinFunc{bf}}
return sig.setSelf(sig), nil
}

type builtinIsTrueOpSig struct {
baseBuiltinFunc
type builtinIsTrueSig struct{ baseIntBuiltinFunc }
type builtinIsFalseSig struct{ baseIntBuiltinFunc }

op opcode.Op
func (b *builtinIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
boolResult, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || boolResult == 0 {
return 0, false, nil
}
return 1, false, nil
}

func (b *builtinIsTrueOpSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
func (b *builtinIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
boolResult, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return types.Datum{}, errors.Trace(err)
return 0, true, errors.Trace(err)
}
var boolVal bool
if !args[0].IsNull() {
iVal, err := args[0].ToBool(b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return d, errors.Trace(err)
}
if (b.op == opcode.IsTruth && iVal == 1) || (b.op == opcode.IsFalsity && iVal == 0) {
boolVal = true
}
if isNull || boolResult != 0 {
return 0, false, nil
}
d.SetInt64(boolToInt64(boolVal))
return
return 1, false, nil
}

type bitNegFunctionClass struct {
Expand Down
75 changes: 75 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
"github.com/pingcap/tidb/util/types"
)

Expand Down Expand Up @@ -476,3 +477,77 @@ func (s *testEvaluatorSuite) TestUnaryNot(c *C) {
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
}

func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
defer testleak.AfterTest(c)()
sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreTruncate
defer func() {
sc.IgnoreTruncate = origin
}()
sc.IgnoreTruncate = true

testCases := []struct {
args []interface{}
isTrue interface{}
isFalse interface{}
}{
{
args: []interface{}{int64(-12)},
isTrue: int64(1),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove int64()

isFalse: int64(0),
},
{
args: []interface{}{int64(12)},
isTrue: int64(1),
isFalse: int64(0),
},
{
args: []interface{}{int64(0)},
isTrue: int64(0),
isFalse: int64(1),
},
{
args: []interface{}{float64(0)},
isTrue: int64(0),
isFalse: int64(1),
},
{
args: []interface{}{"aaa"},
isTrue: int64(0),
isFalse: int64(1),
},
{
args: []interface{}{""},
isTrue: int64(0),
isFalse: int64(1),
},
{
args: []interface{}{nil},
isTrue: int64(0),
isFalse: int64(0),
},
}

for _, tc := range testCases {
isTrueSig, err := funcs[ast.IsTruth].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
c.Assert(err, IsNil)
c.Assert(isTrueSig, NotNil)
c.Assert(isTrueSig.isDeterministic(), IsTrue)

isTrue, err := isTrueSig.eval(nil)
c.Assert(err, IsNil)
c.Assert(isTrue, testutil.DatumEquals, types.NewDatum(tc.isTrue))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/ NewDatum/ NewIntDatum(int64(tc.isTrue))

}

for _, tc := range testCases {
isFalseSig, err := funcs[ast.IsFalsity].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
c.Assert(err, IsNil)
c.Assert(isFalseSig, NotNil)
c.Assert(isFalseSig.isDeterministic(), IsTrue)

isFalse, err := isFalseSig.eval(nil)
c.Assert(err, IsNil)
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
}
7 changes: 6 additions & 1 deletion expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,7 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

// for is true
// for is true && is false
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int, index idx_b (b))")
tk.MustExec("insert t values (1, 1)")
Expand All @@ -688,6 +688,11 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
result.Check(nil)
result = tk.MustQuery("select * from t where a is not true")
result.Check(nil)
result = tk.MustQuery(`select 1 is true, 0 is true, null is true, "aaa" is true, "" is true, -12.00 is true, 0.0 is true;`)
result.Check(testkit.Rows("1 0 0 0 0 1 0"))
result = tk.MustQuery(`select 1 is false, 0 is false, null is false, "aaa" is false, "" is false, -12.00 is false, 0.0 is false;`)
result.Check(testkit.Rows("0 1 0 1 1 0 1"))

// for in
result = tk.MustQuery("select * from t where b in (a)")
result.Check(testkit.Rows("1 1", "2 2"))
Expand Down
22 changes: 22 additions & 0 deletions plan/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -302,6 +302,28 @@ func (s *testPlanSuite) createTestCase4MathFuncs() []typeInferTestCase {
{"ceiling(c_text)", mysql.TypeDouble, charset.CharsetBin, mysql.BinaryFlag, mysql.MaxRealWidth, 0},
{"ceiling(18446744073709551615)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
{"ceiling(18446744073709551615.1)", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag, 22, 0},

{"c_int is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_decimal is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_double is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_float is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_datetime is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_time is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_enum is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_text is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"18446 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"1844674.1 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},

{"c_int is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_decimal is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_double is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_float is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_datetime is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_time is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_enum is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_text is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"18446 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"1844674.1 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
}
}

Expand Down