Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix #3762, signed integer overflow handle in minus unary scalar function #3780

Merged
merged 24 commits into from
Jul 26, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7762b8b
expression: handle int builtinUnaryOpSig overflow
winkyao Jul 17, 2017
f87e0af
expression: fix #3762, signed integer overflow handle in minus unary …
winkyao Jul 17, 2017
465dd1b
expression: fix #3762, add builtin_op_test.go, some builtin test cases
winkyao Jul 17, 2017
17f1e5f
*: tiny cleanup
winkyao Jul 17, 2017
1d15330
expression: set the correct unary minus function ScalarFunction.RetType
winkyao Jul 17, 2017
a803f9d
expression: rewrite unary minus builtin function to handle overflow.
winkyao Jul 19, 2017
d6e6c25
plan: fix popRowArg bug: when get expression.Constatn and getRowLen i…
winkyao Jul 20, 2017
ae342fc
*:git stash
winkyao Jul 24, 2017
869e1f6
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 25, 2017
b305a4f
expression: base #3868 pr, and rewrite builtin function unary minus t…
winkyao Jul 25, 2017
bedd270
expression: refactor builtin unary minus function
winkyao Jul 25, 2017
0e5de09
expression: tiny refactor
winkyao Jul 25, 2017
4045075
expression: tiny cleanup
winkyao Jul 25, 2017
636170d
expression: cleanup
winkyao Jul 25, 2017
33bf09c
expresion: cleanup
winkyao Jul 25, 2017
eb58c48
expression: tiny cleanup
winkyao Jul 25, 2017
b199772
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 25, 2017
6005a18
Merge branch 'master' into winkyao/fix_issue_3762
winkyao Jul 26, 2017
e2aa622
expression: add some comment
winkyao Jul 26, 2017
d0395bf
Merge branch 'winkyao/fix_issue_3762' of https://github.com/pingcap/t…
winkyao Jul 26, 2017
ffc0df7
Merge branch 'master' into winkyao/fix_issue_3762
winkyao Jul 26, 2017
27202d6
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 26, 2017
03dc0e5
Merge branch 'winkyao/fix_issue_3762' of https://github.com/pingcap/t…
winkyao Jul 26, 2017
7f01b42
expression: adjust commet
winkyao Jul 26, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,19 @@ func (s *testSuite) TestBuiltin(c *C) {
result = tk.MustQuery("select cast(a as signed) from t")
result.Check(testkit.Rows("130000"))

// fixed issue #3762
result = tk.MustQuery("select -9223372036854775809;")
result.Check(testkit.Rows("-9223372036854775809"))
result = tk.MustQuery("select --9223372036854775809;")
result.Check(testkit.Rows("9223372036854775809"))
result = tk.MustQuery("select -9223372036854775808;")
result.Check(testkit.Rows("-9223372036854775808"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a bigint(30));")
_, err := tk.Exec("insert into t values(-9223372036854775809)")
c.Assert(err, NotNil)

// test unhex and hex
result = tk.MustQuery("select unhex('4D7953514C')")
result.Check(testkit.Rows("MySQL"))
Expand Down
5 changes: 5 additions & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -336,23 +336,28 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
switch s.(type) {
case *ast.UpdateStmt, *ast.InsertStmt, *ast.DeleteStmt:
sc.IgnoreTruncate = false
sc.IgnoreOverflow = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
if _, ok := s.(*ast.InsertStmt); !ok {
sc.InUpdateOrDeleteStmt = true
}
case *ast.CreateTableStmt, *ast.AlterTableStmt:
// Make sure the sql_mode is strict when checking column default value.
sc.IgnoreTruncate = false
sc.IgnoreOverflow = false
sc.TruncateAsWarning = false
case *ast.LoadDataStmt:
sc.IgnoreTruncate = false
sc.IgnoreOverflow = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
case *ast.SelectStmt:
sc.IgnoreOverflow = true
// Return warning for truncate error in selection.
sc.IgnoreTruncate = false
sc.TruncateAsWarning = true
default:
sc.IgnoreTruncate = true
sc.IgnoreOverflow = false
if show, ok := s.(*ast.ShowStmt); ok {
if show.Tp == ast.ShowWarnings {
sc.InShowWarning = true
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ var funcs = map[string]functionClass{
ast.UnaryNot: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryNot, 1, 1}, opcode.Not},
ast.BitNeg: &unaryOpFunctionClass{baseFunctionClass{ast.BitNeg, 1, 1}, opcode.BitNeg},
ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus},
ast.UnaryMinus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}, opcode.Minus},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}},
ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
Expand Down
185 changes: 148 additions & 37 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
package expression

import (
"fmt"
"math"

"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
Expand All @@ -27,6 +31,7 @@ var (
_ functionClass = &bitOpFunctionClass{}
_ functionClass = &isTrueOpFunctionClass{}
_ functionClass = &unaryOpFunctionClass{}
_ functionClass = &unaryMinusFunctionClass{}
_ functionClass = &isNullFunctionClass{}
)

Expand All @@ -37,6 +42,7 @@ var (
_ builtinFunc = &builtinBitOpSig{}
_ builtinFunc = &builtinIsTrueOpSig{}
_ builtinFunc = &builtinUnaryOpSig{}
_ builtinFunc = &builtinUnaryMinusIntSig{}
_ builtinFunc = &builtinIsNullSig{}
)

Expand Down Expand Up @@ -335,49 +341,154 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) {
default:
return d, errInvalidOperation.Gen("Unsupported type %v for op.Plus", aDatum.Kind())
}
case opcode.Minus:
switch aDatum.Kind() {
case types.KindInt64:
d.SetInt64(-aDatum.GetInt64())
case types.KindUint64:
d.SetInt64(-int64(aDatum.GetUint64()))
case types.KindFloat64:
d.SetFloat64(-aDatum.GetFloat64())
case types.KindFloat32:
d.SetFloat32(-aDatum.GetFloat32())
case types.KindMysqlDuration:
dec := new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDuration().ToNumber(), dec)
d.SetMysqlDecimal(dec)
case types.KindMysqlTime:
dec := new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlTime().ToNumber(), dec)
d.SetMysqlDecimal(dec)
case types.KindString, types.KindBytes:
f, err1 := types.StrToFloat(sc, aDatum.GetString())
err = errors.Trace(err1)
d.SetFloat64(-f)
case types.KindMysqlDecimal:
dec := new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDecimal(), dec)
d.SetMysqlDecimal(dec)
case types.KindMysqlHex:
d.SetFloat64(-aDatum.GetMysqlHex().ToNumber())
case types.KindMysqlBit:
d.SetFloat64(-aDatum.GetMysqlBit().ToNumber())
case types.KindMysqlEnum:
d.SetFloat64(-aDatum.GetMysqlEnum().ToNumber())
case types.KindMysqlSet:
d.SetFloat64(-aDatum.GetMysqlSet().ToNumber())
default:
return d, errInvalidOperation.Gen("Unsupported type %v for op.Minus", aDatum.Kind())
}
default:
return d, errInvalidOperation.Gen("Unsupported op %v for unary op", b.op)
}
return
}

type unaryMinusFunctionClass struct {
baseFunctionClass
}

func (b *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow bool) {
if mysql.HasUnsignedFlag(arg.GetType().Flag) {
uval := arg.Value.GetUint64()
// -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like
// 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occurs.
if uval > uint64(-math.MinInt64) {
return true
}
} else {
val := arg.Value.GetInt64()
// The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807,
// which is less than abs(-9223372036854775808). When val == math.MinInt64, overflow occurs.
if val == math.MinInt64 {
return true
}
}
return false
}

// typeInfer infers unaryMinus function return type. when the arg is an int constant and overflow,
// typerInfer will infers the return type as tpDecimal, not tpInt.
func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Context) (evalTp, bool) {
tp := tpInt
switch argExpr.GetTypeClass() {
case types.ClassString, types.ClassReal:
tp = tpReal
case types.ClassDecimal:
tp = tpDecimal
}

sc := ctx.GetSessionVars().StmtCtx
overflow := false
// TODO: Handle float overflow.
if arg, ok := argExpr.(*Constant); sc.IgnoreOverflow && ok &&
arg.GetTypeClass() == types.ClassInt {
overflow = b.handleIntOverflow(arg)
if overflow {
tp = tpDecimal
}
}
return tp, overflow
}

func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) {
err = b.verifyArgs(args)
if err != nil {
return nil, errors.Trace(err)
}

argExpr := args[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argExpr only used in line 438, we can just use args[0] instead of argExpr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see line 441 and 469

retTp, intOverflow := b.typeInfer(argExpr, ctx)

var bf baseBuiltinFunc
switch argExpr.GetTypeClass() {
case types.ClassInt:
if intOverflow {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, true}
} else {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt)
sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}}
}
case types.ClassDecimal:
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
case types.ClassReal:
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal)
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}}
case types.ClassString:
tp := argExpr.GetType().Tp
if types.IsTypeTime(tp) || tp == mysql.TypeDuration {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
} else {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal)
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}}
}
}

return sig.setSelf(sig), errors.Trace(err)
}

type builtinUnaryMinusIntSig struct {
baseIntBuiltinFunc
}

func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) {
var val int64
val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx)
if err != nil || isNull {
return val, isNull, errors.Trace(err)
}

if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) {
uval := uint64(val)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move the overflow check logic to EvalInt ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiancaiamao the overflow logic is only for unary minus

if uval > uint64(-math.MinInt64) {
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval))
} else if uval == uint64(-math.MinInt64) {
return math.MinInt64, false, nil
}
} else if val == math.MinInt64 {
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val))
}
return -val, false, errors.Trace(err)
}

type builtinUnaryMinusDecimalSig struct {
baseDecimalBuiltinFunc

constantArgOverflow bool
}

func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) {
sc := b.getCtx().GetSessionVars().StmtCtx

var dec *types.MyDecimal
dec, isNull, err := b.args[0].EvalDecimal(row, sc)
if err != nil || isNull {
return dec, isNull, errors.Trace(err)
}

to := new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), dec, to)
return to, false, errors.Trace(err)
}

type builtinUnaryMinusRealSig struct {
baseRealBuiltinFunc
}

func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) {
sc := b.getCtx().GetSessionVars().StmtCtx
var val float64
val, isNull, err = b.args[0].EvalReal(row, sc)
res = -val
return
}

type isNullFunctionClass struct {
baseFunctionClass
}
Expand Down
43 changes: 43 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,56 @@
package expression

import (
"math"

"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/types"
)

func (s *testEvaluatorSuite) TestUnary(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
args interface{}
expected interface{}
overflow bool
getErr bool
}{
{uint64(9223372036854775809), "-9223372036854775809", true, false},
{uint64(9223372036854775810), "-9223372036854775810", true, false},
{uint64(9223372036854775808), int64(-9223372036854775808), false, false},
{int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808
}
sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.IgnoreOverflow
sc.IgnoreOverflow = true
defer func() {
sc.IgnoreOverflow = origin
}()

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...)
c.Assert(err, IsNil)
d, err := f.Eval(nil)
if t.getErr == false {
c.Assert(err, IsNil)
if !t.overflow {
c.Assert(d.GetValue(), Equals, t.expected)
} else {
c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected)
}
} else {
c.Assert(err, NotNil)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test case for isDeterminstic.


f, err := funcs[ast.UnaryMinus].getFunction([]Expression{Zero}, s.ctx)
c.Assert(err, IsNil)
c.Assert(f.isDeterministic(), IsTrue)
}

func (s *testEvaluatorSuite) TestAndAnd(c *C) {
defer testleak.AfterTest(c)()

Expand Down
1 change: 1 addition & 0 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) {
res, isNull, err = sf.EvalString(row, sc)
}
}

if isNull || err != nil {
d.SetValue(nil)
return d, errors.Trace(err)
Expand Down
4 changes: 4 additions & 0 deletions expression/typeinferer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) {
{`json_insert('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0},
{`json_replace('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0},
{`json_merge('{"a": 1}', '3')`, mysql.TypeJSON, charset.CharsetUTF8, 0},
{"-9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag},
{"-9223372036854775808", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"--9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag},
{"--9223372036854775808", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag},
}
for _, tt := range tests {
ctx := testKit.Se.(context.Context)
Expand Down
7 changes: 6 additions & 1 deletion plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func (b *planBuilder) rewriteWithPreprocess(expr ast.ExprNode, p LogicalPlan, ag
if getRowLen(er.ctxStack[0]) != 1 {
return nil, nil, ErrOperandColumns.GenByArgs(1)
}

return er.ctxStack[0], er.p, nil
}

Expand Down Expand Up @@ -142,7 +143,11 @@ func popRowArg(ctx context.Context, e expression.Expression) (ret expression.Exp
return ret, errors.Trace(err)
}
c, _ := e.(*expression.Constant)
ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()}
if getRowLen(c) == 2 {
ret = &expression.Constant{Value: c.Value.GetRow()[1], RetType: c.GetType()}
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because if getRowLen(c) == 2, if return c.Value.GetRow()[1:], it will treat it as row, and build a row > row function, but it should be datum > datum in this case, see the

if f, ok := e.(*expression.ScalarFunction); ok {
		args := f.GetArgs()
		if len(args) == 2 {  // in here, do the same thing
			return args[1].Clone(), nil 
		}
		ret, err = expression.NewFunction(ctx, f.FuncName.L, f.GetType(), args[1:]...)
		return ret, errors.Trace(err)
	}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and if not change this, the test case select row(1, 1) > row(1, 0) will return 0, which should return result 1

ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()}
}
return
}

Expand Down
1 change: 1 addition & 0 deletions sessionctx/variable/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,7 @@ type StatementContext struct {
// Set the following variables before execution

InUpdateOrDeleteStmt bool
IgnoreOverflow bool
IgnoreTruncate bool
TruncateAsWarning bool
InShowWarning bool
Expand Down