Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix #3762, signed integer overflow handle in minus unary scalar function #3780

Merged
merged 24 commits into from
Jul 26, 2017
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7762b8b
expression: handle int builtinUnaryOpSig overflow
winkyao Jul 17, 2017
f87e0af
expression: fix #3762, signed integer overflow handle in minus unary …
winkyao Jul 17, 2017
465dd1b
expression: fix #3762, add builtin_op_test.go, some builtin test cases
winkyao Jul 17, 2017
17f1e5f
*: tiny cleanup
winkyao Jul 17, 2017
1d15330
expression: set the correct unary minus function ScalarFunction.RetType
winkyao Jul 17, 2017
a803f9d
expression: rewrite unary minus builtin function to handle overflow.
winkyao Jul 19, 2017
d6e6c25
plan: fix popRowArg bug: when get expression.Constatn and getRowLen i…
winkyao Jul 20, 2017
ae342fc
*:git stash
winkyao Jul 24, 2017
869e1f6
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 25, 2017
b305a4f
expression: base #3868 pr, and rewrite builtin function unary minus t…
winkyao Jul 25, 2017
bedd270
expression: refactor builtin unary minus function
winkyao Jul 25, 2017
0e5de09
expression: tiny refactor
winkyao Jul 25, 2017
4045075
expression: tiny cleanup
winkyao Jul 25, 2017
636170d
expression: cleanup
winkyao Jul 25, 2017
33bf09c
expresion: cleanup
winkyao Jul 25, 2017
eb58c48
expression: tiny cleanup
winkyao Jul 25, 2017
b199772
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 25, 2017
6005a18
Merge branch 'master' into winkyao/fix_issue_3762
winkyao Jul 26, 2017
e2aa622
expression: add some comment
winkyao Jul 26, 2017
d0395bf
Merge branch 'winkyao/fix_issue_3762' of https://github.com/pingcap/t…
winkyao Jul 26, 2017
ffc0df7
Merge branch 'master' into winkyao/fix_issue_3762
winkyao Jul 26, 2017
27202d6
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 26, 2017
03dc0e5
Merge branch 'winkyao/fix_issue_3762' of https://github.com/pingcap/t…
winkyao Jul 26, 2017
7f01b42
expression: adjust commet
winkyao Jul 26, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,19 @@ func (s *testSuite) TestBuiltin(c *C) {
result = tk.MustQuery("select cast(a as signed) from t")
result.Check(testkit.Rows("130000"))

// fixed issue #3762
result = tk.MustQuery("select -9223372036854775809;")
result.Check(testkit.Rows("-9223372036854775809"))
result = tk.MustQuery("select --9223372036854775809;")
result.Check(testkit.Rows("9223372036854775809"))
result = tk.MustQuery("select -9223372036854775808;")
result.Check(testkit.Rows("-9223372036854775808"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a bigint(30));")
_, err := tk.Exec("insert into t values(-9223372036854775809)")
c.Assert(err, NotNil)

// test unhex and hex
result = tk.MustQuery("select unhex('4D7953514C')")
result.Check(testkit.Rows("MySQL"))
Expand Down
3 changes: 3 additions & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
case *ast.LoadDataStmt:
sc.IgnoreTruncate = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
case *ast.SelectStmt:
sc.InSelectStmt = true
sc.IgnoreTruncate = true
default:
sc.IgnoreTruncate = true
if show, ok := s.(*ast.ShowStmt); ok {
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ var funcs = map[string]functionClass{
ast.UnaryNot: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryNot, 1, 1}, opcode.Not},
ast.BitNeg: &unaryOpFunctionClass{baseFunctionClass{ast.BitNeg, 1, 1}, opcode.BitNeg},
ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus},
ast.UnaryMinus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}, opcode.Minus},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}},
ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
Expand Down
171 changes: 169 additions & 2 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
package expression

import (
"fmt"
"math"

"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
Expand All @@ -27,6 +31,7 @@ var (
_ functionClass = &bitOpFunctionClass{}
_ functionClass = &isTrueOpFunctionClass{}
_ functionClass = &unaryOpFunctionClass{}
_ functionClass = &unaryMinusFunctionClass{}
_ functionClass = &isNullFunctionClass{}
)

Expand All @@ -37,6 +42,7 @@ var (
_ builtinFunc = &builtinBitOpSig{}
_ builtinFunc = &builtinIsTrueOpSig{}
_ builtinFunc = &builtinUnaryOpSig{}
_ builtinFunc = &builtinUnaryMinusIntSig{}
_ builtinFunc = &builtinIsNullSig{}
)

Expand Down Expand Up @@ -338,9 +344,19 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) {
case opcode.Minus:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code will never reach here?

switch aDatum.Kind() {
case types.KindInt64:
d.SetInt64(-aDatum.GetInt64())
val := aDatum.GetInt64()
if val == math.MinInt64 {
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val))
}
d.SetInt64(-val)
case types.KindUint64:
d.SetInt64(-int64(aDatum.GetUint64()))
uval := aDatum.GetUint64()
if uval > uint64(-math.MinInt64) { // 9223372036854775808
// -uval will overflow
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval))
} else {
d.SetInt64(-int64(uval))
}
case types.KindFloat64:
d.SetFloat64(-aDatum.GetFloat64())
case types.KindFloat32:
Expand Down Expand Up @@ -378,6 +394,157 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) {
return
}

type unaryMinusFunctionClass struct {
baseFunctionClass
}

func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) {
tp := tpInt
switch argExpr.GetTypeClass() {
case types.ClassString, types.ClassReal:
tp = tpReal
case types.ClassDecimal:
tp = tpDecimal
}

overflow := false
// TODO: handle float overflow
if arg, ok := argExpr.(*Constant); ok {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why check overflow in type infer? If expr is not Constant, this is useless.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@XuHuaiyu, because overflow will be handled only if the statement is Select, and if the args type is not constant, MySQL will just throw a overflow error out, see below:

mysql> select * from t;
Field   1:  `a`
Catalog:    `def`
Database:   `test`
Table:      `t`
Org_table:  `t`
Type:       LONGLONG
Collation:  binary (63)
Length:     20
Max_length: 19
Decimals:   0
Flags:      UNSIGNED NUM


+---------------------+
| a                   |
+---------------------+
| 9223372036854775809 |
| 9223372036854775808 |
+---------------------+
2 rows in set (0.00 sec)

mysql> select -a from t;
ERROR 1690 (22003): BIGINT value is out of range in ‘-(`test`.`t`.`a`)’

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird here. Can we change parser to parse -9223372036854775809 as a constant decimal directly ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, image this case "--9223372036854775809". we should not do too much things in parser

switch arg.Value.Kind() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we may check whether mysql.HasUnsignedFlag(arg.GetType().Flag) here
rather than check arg.Value.Kind()

case types.KindUint64:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment for the two overflow check.

uval := arg.Value.GetUint64()
if uval > uint64(-math.MinInt64) {
overflow = true
tp = tpDecimal
}
case types.KindInt64:
val := arg.Value.GetInt64()
if val == math.MinInt64 {
overflow = true
tp = tpDecimal
}
}
}
return tp, overflow
}

func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) {
err = b.verifyArgs(args)
if err != nil {
return nil, errors.Trace(err)
}

argExpr := args[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argExpr only used in line 438, we can just use args[0] instead of argExpr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see line 441 and 469

retTp, intOverflow := b.typeInfer(argExpr)

var bf baseBuiltinFunc
switch argExpr.GetTypeClass() {
case types.ClassInt:
if intOverflow {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
} else {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}}
}
case types.ClassDecimal:
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
case types.ClassReal:
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}}
case types.ClassString:
tp := argExpr.GetType().Tp
if types.IsTypeTime(tp) || tp == mysql.TypeDuration {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
} else {
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}}
}
}

return sig.setSelf(sig), errors.Trace(err)
}

type builtinUnaryMinusIntSig struct {
baseIntBuiltinFunc
}

func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) {
var val int64
val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx)
if err != nil || isNull {
return val, isNull, errors.Trace(err)
}

if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) {
uval := uint64(val)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move the overflow check logic to EvalInt ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiancaiamao the overflow logic is only for unary minus

if uval > uint64(-math.MinInt64) {
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval))
} else if uval == uint64(-math.MinInt64) {
return math.MinInt64, false, nil
}
} else if val == math.MinInt64 {
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val))
}
return -val, false, err
}

type builtinUnaryMinusDecimalSig struct {
baseDecimalBuiltinFunc

constantArgOverflow bool
}

func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) {
sc := b.getCtx().GetSessionVars().StmtCtx

var dec *types.MyDecimal
dec, isNull, err := b.args[0].EvalDecimal(row, sc)
if err != nil || isNull {
return dec, isNull, errors.Trace(err)
}

if !sc.InSelectStmt && b.constantArgOverflow {
return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String())
}

to := new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), dec, to)
return to, false, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/err/errors.Trace(err)

}

type builtinUnaryMinusRealSig struct {
baseRealBuiltinFunc
}

func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) {
sc := b.getCtx().GetSessionVars().StmtCtx
var val float64
val, isNull, err = b.args[0].EvalReal(row, sc)
res = -val
return
}

type isNullFunctionClass struct {
baseFunctionClass
}
Expand Down
39 changes: 39 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,52 @@
package expression

import (
"math"

"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/types"
)

func (s *testEvaluatorSuite) TestUnary(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
args interface{}
expected interface{}
overflow bool
getErr bool
}{
{uint64(9223372036854775809), "-9223372036854775809", true, false},
{uint64(9223372036854775810), "-9223372036854775810", true, false},
{uint64(9223372036854775808), int64(-9223372036854775808), false, false},
{int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808
}
sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.InSelectStmt
sc.InSelectStmt = true
defer func() {
sc.InSelectStmt = origin
}()

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...)
c.Assert(err, IsNil)
d, err := f.Eval(nil)
if t.getErr == false {
c.Assert(err, IsNil)
if !t.overflow {
c.Assert(d.GetValue(), Equals, t.expected)
} else {
c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected)
}
} else {
c.Assert(err, NotNil)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test case for isDeterminstic.

}

func (s *testEvaluatorSuite) TestAndAnd(c *C) {
defer testleak.AfterTest(c)()

Expand Down
6 changes: 4 additions & 2 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
}
if expr.GetTypeClass() == types.ClassDecimal {
switch expr.GetTypeClass() {
case types.ClassDecimal:
res, err = val.ToDecimal(sc)
return res, false, errors.Trace(err)
// TODO: We maintain two sets of type systems, one for Expression, one for Datum.
Expand All @@ -180,7 +181,8 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen
// but what we actually get is store as float64 in Datum.
// So if we wrap `CastDecimalAsInt` upon the result, we'll get <nil> when call `arg.EvalDecimal()`.
// This will be fixed after all built-in functions be rewrite correctlly.
} else if IsHybridType(expr) {
}
if IsHybridType(expr) {
res, err = val.ToDecimal(sc)
return res, false, errors.Trace(err)
}
Expand Down
9 changes: 7 additions & 2 deletions expression/scalar_function.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,13 @@ func NewFunction(ctx context.Context, funcName string, retType *types.FieldType,
if builtinRetTp := f.getRetTp(); builtinRetTp.Tp != mysql.TypeUnspecified {
retType = builtinRetTp
}

sf := &ScalarFunction{
FuncName: model.NewCIStr(funcName),
RetType: retType,
Function: f,
}

return FoldConstant(sf), nil
}

Expand Down Expand Up @@ -160,10 +162,11 @@ func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression {

// Eval implements Expression interface.
func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) {
sc := sf.GetCtx().GetSessionVars().StmtCtx
if !TurnOnNewExprEval {
return sf.Function.eval(row)
d, err = sf.Function.eval(row)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any difference?

return
}
sc := sf.GetCtx().GetSessionVars().StmtCtx
var (
res interface{}
isNull bool
Expand All @@ -173,6 +176,7 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) {
case types.ClassInt:
var intRes int64
intRes, isNull, err = sf.EvalInt(row, sc)

if mysql.HasUnsignedFlag(tp.Flag) {
res = uint64(intRes)
} else {
Expand All @@ -192,6 +196,7 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) {
res, isNull, err = sf.EvalString(row, sc)
}
}

if isNull || err != nil {
d.SetValue(nil)
return d, errors.Trace(err)
Expand Down
4 changes: 4 additions & 0 deletions expression/typeinferer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,10 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) {
{`json_insert('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0},
{`json_replace('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0},
{`json_merge('{"a": 1}', '3')`, mysql.TypeJSON, charset.CharsetUTF8, 0},
{"-9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag},
{"-9223372036854775808", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag},
{"--9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag},
{"--9223372036854775808", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag},
}
for _, tt := range tests {
ctx := testKit.Se.(context.Context)
Expand Down
7 changes: 6 additions & 1 deletion plan/expression_rewriter.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ func (b *planBuilder) rewriteWithPreprocess(expr ast.ExprNode, p LogicalPlan, ag
if getRowLen(er.ctxStack[0]) != 1 {
return nil, nil, ErrOperandColumns.GenByArgs(1)
}

return er.ctxStack[0], er.p, nil
}

Expand Down Expand Up @@ -142,7 +143,11 @@ func popRowArg(ctx context.Context, e expression.Expression) (ret expression.Exp
return ret, errors.Trace(err)
}
c, _ := e.(*expression.Constant)
ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()}
if getRowLen(c) == 2 {
ret = &expression.Constant{Value: c.Value.GetRow()[1], RetType: c.GetType()}
} else {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why change here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

because if getRowLen(c) == 2, if return c.Value.GetRow()[1:], it will treat it as row, and build a row > row function, but it should be datum > datum in this case, see the

if f, ok := e.(*expression.ScalarFunction); ok {
		args := f.GetArgs()
		if len(args) == 2 {  // in here, do the same thing
			return args[1].Clone(), nil 
		}
		ret, err = expression.NewFunction(ctx, f.FuncName.L, f.GetType(), args[1:]...)
		return ret, errors.Trace(err)
	}

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

and if not change this, the test case select row(1, 1) > row(1, 0) will return 0, which should return result 1

ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()}
}
return
}

Expand Down
Loading