Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression, plan: rewrite builtin function: IS TRUE && IS FALSE #4086

Merged
merged 18 commits into from
Aug 11, 2017
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -787,8 +787,8 @@ var funcs = map[string]functionClass{
ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}},
ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.IsTruth: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOrFalseFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
ast.Like: &likeFunctionClass{baseFunctionClass{ast.Like, 2, 3}},
ast.Regexp: &regexpFunctionClass{baseFunctionClass{ast.Regexp, 2, 2}},
ast.Case: &caseWhenFunctionClass{baseFunctionClass{ast.Case, 1, -1}},
Expand Down
131 changes: 107 additions & 24 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ var (
_ functionClass = &logicAndFunctionClass{}
_ functionClass = &logicOrFunctionClass{}
_ functionClass = &logicXorFunctionClass{}
_ functionClass = &isTrueOpFunctionClass{}
_ functionClass = &isTrueOrFalseFunctionClass{}
_ functionClass = &unaryOpFunctionClass{}
_ functionClass = &unaryMinusFunctionClass{}
_ functionClass = &isNullFunctionClass{}
Expand All @@ -39,7 +39,12 @@ var (
_ builtinFunc = &builtinLogicAndSig{}
_ builtinFunc = &builtinLogicOrSig{}
_ builtinFunc = &builtinLogicXorSig{}
_ builtinFunc = &builtinIsTrueOpSig{}
_ builtinFunc = &builtinRealIsTrueSig{}
_ builtinFunc = &builtinDecimalIsTrueSig{}
_ builtinFunc = &builtinIntIsTrueSig{}
_ builtinFunc = &builtinRealIsFalseSig{}
_ builtinFunc = &builtinDecimalIsFalseSig{}
_ builtinFunc = &builtinIntIsFalseSig{}
_ builtinFunc = &builtinUnaryOpSig{}
_ builtinFunc = &builtinUnaryMinusIntSig{}
_ builtinFunc = &builtinIsNullSig{}
Expand Down Expand Up @@ -341,40 +346,118 @@ func (b *builtinRightShiftSig) evalInt(row []types.Datum) (int64, bool, error) {
return int64(uint64(arg0) >> uint64(arg1)), false, nil
}

type isTrueOpFunctionClass struct {
type isTrueOrFalseFunctionClass struct {
baseFunctionClass

op opcode.Op
}

func (c *isTrueOpFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
sig := &builtinIsTrueOpSig{newBaseBuiltinFunc(args, ctx), c.op}
return sig.setSelf(sig), errors.Trace(c.verifyArgs(args))
func (c *isTrueOrFalseFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
if err := c.verifyArgs(args); err != nil {
return nil, errors.Trace(err)
}

argTp := tpInt
switch args[0].GetTypeClass() {
case types.ClassReal:
argTp = tpReal
case types.ClassDecimal:
argTp = tpDecimal
}
bf, err := newBaseBuiltinFuncWithTp(args, ctx, tpInt, argTp)
if err != nil {
return nil, errors.Trace(err)
}
bf.tp.Flen = 1

var sig builtinFunc
switch {
case c.op == opcode.IsTruth && argTp == tpReal:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

two switch may be more explict?
siwtch c.op{
switch argtP{
}
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need, it makes code complex

sig = &builtinRealIsTrueSig{baseIntBuiltinFunc{bf}}
case c.op == opcode.IsTruth && argTp == tpDecimal:
sig = &builtinDecimalIsTrueSig{baseIntBuiltinFunc{bf}}
case c.op == opcode.IsTruth && argTp == tpInt:
sig = &builtinIntIsTrueSig{baseIntBuiltinFunc{bf}}
case argTp == tpReal:
sig = &builtinRealIsFalseSig{baseIntBuiltinFunc{bf}}
case argTp == tpDecimal:
sig = &builtinDecimalIsFalseSig{baseIntBuiltinFunc{bf}}
default:
sig = &builtinIntIsFalseSig{baseIntBuiltinFunc{bf}}
}
return sig.setSelf(sig), nil
}

type builtinIsTrueOpSig struct {
baseBuiltinFunc
type builtinRealIsTrueSig struct{ baseIntBuiltinFunc }
type builtinDecimalIsTrueSig struct{ baseIntBuiltinFunc }
type builtinIntIsTrueSig struct{ baseIntBuiltinFunc }
type builtinRealIsFalseSig struct{ baseIntBuiltinFunc }
type builtinDecimalIsFalseSig struct{ baseIntBuiltinFunc }
type builtinIntIsFalseSig struct{ baseIntBuiltinFunc }

op opcode.Op
func (b *builtinRealIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input == 0 {
return 0, false, nil
}
return 1, false, nil
}

func (b *builtinIsTrueOpSig) eval(row []types.Datum) (d types.Datum, err error) {
args, err := b.evalArgs(row)
func (b *builtinDecimalIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return types.Datum{}, errors.Trace(err)
return 0, true, errors.Trace(err)
}
var boolVal bool
if !args[0].IsNull() {
iVal, err := args[0].ToBool(b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return d, errors.Trace(err)
}
if (b.op == opcode.IsTruth && iVal == 1) || (b.op == opcode.IsFalsity && iVal == 0) {
boolVal = true
}
if isNull || input.IsZero() {
return 0, false, nil
}
d.SetInt64(boolToInt64(boolVal))
return
return 1, false, nil
}

func (b *builtinIntIsTrueSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input == 0 {
return 0, false, nil
}
return 1, false, nil
}

func (b *builtinRealIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalReal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input != 0 {
return 0, false, nil
}
return 1, false, nil
}

func (b *builtinDecimalIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalDecimal(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || !input.IsZero() {
return 0, false, nil
}
return 1, false, nil
}

func (b *builtinIntIsFalseSig) evalInt(row []types.Datum) (int64, bool, error) {
input, isNull, err := b.args[0].EvalInt(row, b.ctx.GetSessionVars().StmtCtx)
if err != nil {
return 0, true, errors.Trace(err)
}
if isNull || input != 0 {
return 0, false, nil
}
return 1, false, nil
}

type bitNegFunctionClass struct {
Expand Down
75 changes: 75 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/testutil"
"github.com/pingcap/tidb/util/types"
)

Expand Down Expand Up @@ -476,3 +477,77 @@ func (s *testEvaluatorSuite) TestUnaryNot(c *C) {
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
}

func (s *testEvaluatorSuite) TestIsTrueOrFalse(c *C) {
defer testleak.AfterTest(c)()
sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreTruncate
defer func() {
sc.IgnoreTruncate = origin
}()
sc.IgnoreTruncate = true

testCases := []struct {
args []interface{}
isTrue interface{}
isFalse interface{}
}{
{
args: []interface{}{-12},
isTrue: 1,
isFalse: 0,
},
{
args: []interface{}{12},
isTrue: 1,
isFalse: 0,
},
{
args: []interface{}{0},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{float64(0)},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{"aaa"},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{""},
isTrue: 0,
isFalse: 1,
},
{
args: []interface{}{nil},
isTrue: 0,
isFalse: 0,
},
}

for _, tc := range testCases {
isTrueSig, err := funcs[ast.IsTruth].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
c.Assert(err, IsNil)
c.Assert(isTrueSig, NotNil)
c.Assert(isTrueSig.isDeterministic(), IsTrue)

isTrue, err := isTrueSig.eval(nil)
c.Assert(err, IsNil)
c.Assert(isTrue, testutil.DatumEquals, types.NewDatum(tc.isTrue))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/ NewDatum/ NewIntDatum(int64(tc.isTrue))

}

for _, tc := range testCases {
isFalseSig, err := funcs[ast.IsFalsity].getFunction(datumsToConstants(types.MakeDatums(tc.args...)), s.ctx)
c.Assert(err, IsNil)
c.Assert(isFalseSig, NotNil)
c.Assert(isFalseSig.isDeterministic(), IsTrue)

isFalse, err := isFalseSig.eval(nil)
c.Assert(err, IsNil)
c.Assert(isFalse, testutil.DatumEquals, types.NewDatum(tc.isFalse))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ditto

}
}
7 changes: 6 additions & 1 deletion expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
tk := testkit.NewTestKit(c, s.store)
tk.MustExec("use test")

// for is true
// for is true && is false
tk.MustExec("drop table if exists t")
tk.MustExec("create table t (a int, b int, index idx_b (b))")
tk.MustExec("insert t values (1, 1)")
Expand All @@ -826,6 +826,11 @@ func (s *testIntegrationSuite) TestBuiltin(c *C) {
result.Check(nil)
result = tk.MustQuery("select * from t where a is not true")
result.Check(nil)
result = tk.MustQuery(`select 1 is true, 0 is true, null is true, "aaa" is true, "" is true, -12.00 is true, 0.0 is true, 0.0000001 is true;`)
result.Check(testkit.Rows("1 0 0 0 0 1 0 1"))
result = tk.MustQuery(`select 1 is false, 0 is false, null is false, "aaa" is false, "" is false, -12.00 is false, 0.0 is false, 0.0000001 is false;`)
result.Check(testkit.Rows("0 1 0 1 1 0 1 0"))

// for in
result = tk.MustQuery("select * from t where b in (a)")
result.Check(testkit.Rows("1 1", "2 2"))
Expand Down
27 changes: 27 additions & 0 deletions plan/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ func (s *testPlanSuite) TestInferType(c *C) {
tests = append(tests, s.createTestCase4InfoFunc()...)
tests = append(tests, s.createTestCase4EncryptionFuncs()...)
tests = append(tests, s.createTestCase4Miscellaneous()...)
tests = append(tests, s.createTestCase4OpFuncs()...)

for _, tt := range tests {
ctx := testKit.Se.(context.Context)
Expand Down Expand Up @@ -538,3 +539,29 @@ func (s *testPlanSuite) createTestCase4Miscellaneous() []typeInferTestCase {
{"sleep(c_binary)", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 20, 0},
}
}

func (s *testPlanSuite) createTestCase4OpFuncs() []typeInferTestCase {
return []typeInferTestCase{
{"c_int is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_decimal is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_double is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_float is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_datetime is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_time is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_enum is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_text is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"18446 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"1844674.1 is true", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},

{"c_int is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_decimal is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_double is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_float is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_datetime is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_time is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_enum is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"c_text is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"18446 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
{"1844674.1 is false", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag, 1, 0},
}
}