Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: fix #3762, signed integer overflow handle in minus unary scalar function #3780

Merged
merged 24 commits into from
Jul 26, 2017
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
7762b8b
expression: handle int builtinUnaryOpSig overflow
winkyao Jul 17, 2017
f87e0af
expression: fix #3762, signed integer overflow handle in minus unary …
winkyao Jul 17, 2017
465dd1b
expression: fix #3762, add builtin_op_test.go, some builtin test cases
winkyao Jul 17, 2017
17f1e5f
*: tiny cleanup
winkyao Jul 17, 2017
1d15330
expression: set the correct unary minus function ScalarFunction.RetType
winkyao Jul 17, 2017
a803f9d
expression: rewrite unary minus builtin function to handle overflow.
winkyao Jul 19, 2017
d6e6c25
plan: fix popRowArg bug: when get expression.Constatn and getRowLen i…
winkyao Jul 20, 2017
ae342fc
*:git stash
winkyao Jul 24, 2017
869e1f6
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 25, 2017
b305a4f
expression: base #3868 pr, and rewrite builtin function unary minus t…
winkyao Jul 25, 2017
bedd270
expression: refactor builtin unary minus function
winkyao Jul 25, 2017
0e5de09
expression: tiny refactor
winkyao Jul 25, 2017
4045075
expression: tiny cleanup
winkyao Jul 25, 2017
636170d
expression: cleanup
winkyao Jul 25, 2017
33bf09c
expresion: cleanup
winkyao Jul 25, 2017
eb58c48
expression: tiny cleanup
winkyao Jul 25, 2017
b199772
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 25, 2017
6005a18
Merge branch 'master' into winkyao/fix_issue_3762
winkyao Jul 26, 2017
e2aa622
expression: add some comment
winkyao Jul 26, 2017
d0395bf
Merge branch 'winkyao/fix_issue_3762' of https://github.com/pingcap/t…
winkyao Jul 26, 2017
ffc0df7
Merge branch 'master' into winkyao/fix_issue_3762
winkyao Jul 26, 2017
27202d6
Merge branch 'master' of https://github.com/pingcap/tidb into winkyao…
winkyao Jul 26, 2017
03dc0e5
Merge branch 'winkyao/fix_issue_3762' of https://github.com/pingcap/t…
winkyao Jul 26, 2017
7f01b42
expression: adjust commet
winkyao Jul 26, 2017
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions executor/aggregate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,7 @@ func (s *testSuite) TestAggPrune(c *C) {
testleak.AfterTest(c)()
}()
tk := testkit.NewTestKit(c, s.store)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

useless line ?

tk.MustExec("use test")
tk.MustExec("drop table if exists t")
tk.MustExec("create table t(id int primary key, b varchar(50), c int)")
Expand Down
13 changes: 13 additions & 0 deletions executor/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1080,6 +1080,19 @@ func (s *testSuite) TestBuiltin(c *C) {
result = tk.MustQuery("select cast(a as signed) from t")
result.Check(testkit.Rows("130000"))

// fixed issue #3762
result = tk.MustQuery("select -9223372036854775809;")
result.Check(testkit.Rows("-9223372036854775809"))
result = tk.MustQuery("select --9223372036854775809;")
result.Check(testkit.Rows("9223372036854775809"))
result = tk.MustQuery("select -9223372036854775808;")
result.Check(testkit.Rows("-9223372036854775808"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a bigint(30));")
_, err := tk.Exec("insert into t values(-9223372036854775809)")
c.Assert(err, NotNil)

// test unhex and hex
result = tk.MustQuery("select unhex('4D7953514C')")
result.Check(testkit.Rows("MySQL"))
Expand Down
3 changes: 3 additions & 0 deletions executor/prepared.go
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,9 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) {
case *ast.LoadDataStmt:
sc.IgnoreTruncate = false
sc.TruncateAsWarning = !sessVars.StrictSQLMode
case *ast.SelectStmt:
sc.InSelectStmt = true
sc.IgnoreTruncate = true
default:
sc.IgnoreTruncate = true
if show, ok := s.(*ast.ShowStmt); ok {
Expand Down
2 changes: 1 addition & 1 deletion expression/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -786,7 +786,7 @@ var funcs = map[string]functionClass{
ast.UnaryNot: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryNot, 1, 1}, opcode.Not},
ast.BitNeg: &unaryOpFunctionClass{baseFunctionClass{ast.BitNeg, 1, 1}, opcode.BitNeg},
ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus},
ast.UnaryMinus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}, opcode.Minus},
ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}},
ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}},
ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth},
ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity},
Expand Down
179 changes: 177 additions & 2 deletions expression/builtin_op.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,12 @@
package expression

import (
"fmt"
"math"

"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/util/types"
)
Expand All @@ -27,6 +31,7 @@ var (
_ functionClass = &bitOpFunctionClass{}
_ functionClass = &isTrueOpFunctionClass{}
_ functionClass = &unaryOpFunctionClass{}
_ functionClass = &unaryMinusFunctionClass{}
_ functionClass = &isNullFunctionClass{}
)

Expand All @@ -37,6 +42,7 @@ var (
_ builtinFunc = &builtinBitOpSig{}
_ builtinFunc = &builtinIsTrueOpSig{}
_ builtinFunc = &builtinUnaryOpSig{}
_ builtinFunc = &builtinUnaryMinusIntSig{}
_ builtinFunc = &builtinIsNullSig{}
)

Expand Down Expand Up @@ -338,9 +344,19 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) {
case opcode.Minus:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the code will never reach here?

switch aDatum.Kind() {
case types.KindInt64:
d.SetInt64(-aDatum.GetInt64())
val := aDatum.GetInt64()
if val == math.MinInt64 {
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val))
}
d.SetInt64(-val)
case types.KindUint64:
d.SetInt64(-int64(aDatum.GetUint64()))
uval := aDatum.GetUint64()
if uval > uint64(-math.MinInt64) { // 9223372036854775808
// -uval will overflow
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval))
} else {
d.SetInt64(-int64(uval))
}
case types.KindFloat64:
d.SetFloat64(-aDatum.GetFloat64())
case types.KindFloat32:
Expand Down Expand Up @@ -378,6 +394,165 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) {
return
}

type unaryMinusFunctionClass struct {
baseFunctionClass
}

func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) {
tp := tpInt
switch argExpr.GetType().Tp {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can use argExpr.getClassType() here ?

case types.ClassString, types.ClassReal:
    tp = tpReal
case types.ClassDecimal:
    tp = tpDecimal

case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat,
mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp:
tp = tpReal
case mysql.TypeNewDecimal:
tp = tpDecimal
}

overflow := false
// TODO: handle float overflow
if arg, ok := argExpr.(*Constant); ok {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why check overflow in type infer? If expr is not Constant, this is useless.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@XuHuaiyu, because overflow will be handled only if the statement is Select, and if the args type is not constant, MySQL will just throw a overflow error out, see below:

mysql> select * from t;
Field   1:  `a`
Catalog:    `def`
Database:   `test`
Table:      `t`
Org_table:  `t`
Type:       LONGLONG
Collation:  binary (63)
Length:     20
Max_length: 19
Decimals:   0
Flags:      UNSIGNED NUM


+---------------------+
| a                   |
+---------------------+
| 9223372036854775809 |
| 9223372036854775808 |
+---------------------+
2 rows in set (0.00 sec)

mysql> select -a from t;
ERROR 1690 (22003): BIGINT value is out of range in ‘-(`test`.`t`.`a`)’

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's weird here. Can we change parser to parse -9223372036854775809 as a constant decimal directly ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, image this case "--9223372036854775809". we should not do too much things in parser

switch arg.Value.Kind() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we may check whether mysql.HasUnsignedFlag(arg.GetType().Flag) here
rather than check arg.Value.Kind()

case types.KindUint64:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

comment for the two overflow check.

uval := arg.Value.GetUint64()
if uval > uint64(-math.MinInt64) {
overflow = true
tp = tpDecimal
}
case types.KindInt64:
val := arg.Value.GetInt64()
if val == math.MinInt64 {
overflow = true
tp = tpDecimal
}
}
}
return tp, overflow
}

func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) {
argExpr := args[0]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

argExpr only used in line 438, we can just use args[0] instead of argExpr.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

see line 441 and 469

retTp, intOverflow := b.typeInfer(argExpr)

switch argExpr.GetTypeClass() {
case types.ClassInt:
if intOverflow {
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
} else {
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}}
}
case types.ClassDecimal:
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
case types.ClassReal:
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}}
case types.ClassString:
tp := argExpr.GetType().Tp
if types.IsTypeTime(tp) {
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

line 465 and ling 471 can be merged?

if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
} else if tp == mysql.TypeDuration {
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false}
} else {
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal)
if err != nil {
return nil, errors.Trace(err)
}
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}}
}
}

return sig.setSelf(sig), errors.Trace(b.verifyArgs(args))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we may put b.verifyArgs at the beginning of this func to check the count of args.

}

func unaryMinusDecimal(dec *types.MyDecimal) (to *types.MyDecimal, err error) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this func may be unnecessary?

to = new(types.MyDecimal)
err = types.DecimalSub(new(types.MyDecimal), dec, to)
return
}

type builtinUnaryMinusIntSig struct {
baseIntBuiltinFunc
}

func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) {
var val int64
val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx)
if err != nil || isNull {
return val, isNull, errors.Trace(err)
}

if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) {
uval := uint64(val)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move the overflow check logic to EvalInt ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tiancaiamao the overflow logic is only for unary minus

if uval > uint64(-math.MinInt64) {
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval))
} else if uval == uint64(-math.MinInt64) {
return math.MinInt64, false, nil
}
} else if val == math.MinInt64 {
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val))
}
return -val, false, err
}

type builtinUnaryMinusDecimalSig struct {
baseDecimalBuiltinFunc

constantArgOverflow bool
}

func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) {
var (
dec, to *types.MyDecimal
)

sc := b.getCtx().GetSessionVars().StmtCtx
if !sc.InSelectStmt && b.constantArgOverflow {
return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String())
}

dec, isNull, err := b.args[0].EvalDecimal(row, sc)
if err != nil || isNull {
return dec, isNull, errors.Trace(err)
}

to, err = unaryMinusDecimal(dec)
return to, false, err
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

s/err/errors.Trace(err)

}

type builtinUnaryMinusRealSig struct {
baseRealBuiltinFunc
}

func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) {
sc := b.getCtx().GetSessionVars().StmtCtx
var val float64
val, isNull, err = b.args[0].EvalReal(row, sc)
res = -val
return
}

type isNullFunctionClass struct {
baseFunctionClass
}
Expand Down
39 changes: 39 additions & 0 deletions expression/builtin_op_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,52 @@
package expression

import (
"math"

"github.com/juju/errors"
. "github.com/pingcap/check"
"github.com/pingcap/tidb/ast"
"github.com/pingcap/tidb/util/testleak"
"github.com/pingcap/tidb/util/types"
)

func (s *testEvaluatorSuite) TestUnary(c *C) {
defer testleak.AfterTest(c)()
cases := []struct {
args interface{}
expected interface{}
overflow bool
getErr bool
}{
{uint64(9223372036854775809), "-9223372036854775809", true, false},
{uint64(9223372036854775810), "-9223372036854775810", true, false},
{uint64(9223372036854775808), int64(-9223372036854775808), false, false},
{int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808
}
sc := s.ctx.GetSessionVars().StmtCtx
origin := sc.InSelectStmt
sc.InSelectStmt = true
defer func() {
sc.InSelectStmt = origin
}()

for _, t := range cases {
f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...)
c.Assert(err, IsNil)
d, err := f.Eval(nil)
if t.getErr == false {
c.Assert(err, IsNil)
if !t.overflow {
c.Assert(d.GetValue(), Equals, t.expected)
} else {
c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected)
}
} else {
c.Assert(err, NotNil)
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add test case for isDeterminstic.

}

func (s *testEvaluatorSuite) TestAndAnd(c *C) {
defer testleak.AfterTest(c)()

Expand Down
29 changes: 29 additions & 0 deletions expression/constant_fold.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,32 @@ func FoldConstant(expr Expression) Expression {
RetType: scalarFunc.RetType,
}
}

// PlainFoldConstant does constant folding optimization
// on an expression without recursively folding.
func PlainFoldConstant(expr Expression) Expression {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused

scalarFunc, ok := expr.(*ScalarFunction)
if !ok || !scalarFunc.Function.isDeterministic() {
return expr
}
args := scalarFunc.GetArgs()
canFold := true
for i := 0; i < len(args); i++ {
if _, ok := args[i].(*Constant); !ok {
canFold = false
break
}
}
if !canFold {
return expr
}
value, err := scalarFunc.Eval(nil)
if err != nil {
log.Warnf("There may exist an error during constant folding. The function name is %s, args are %s", scalarFunc.FuncName, args)
return expr
}
return &Constant{
Value: value,
RetType: scalarFunc.RetType,
}
}
6 changes: 4 additions & 2 deletions expression/expression.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen
if val.IsNull() || err != nil {
return res, val.IsNull(), errors.Trace(err)
}
if expr.GetTypeClass() == types.ClassDecimal {
switch expr.GetTypeClass() {
case types.ClassDecimal:
res, err = val.ToDecimal(sc)
return res, false, errors.Trace(err)
// TODO: We maintain two sets of type systems, one for Expression, one for Datum.
Expand All @@ -180,7 +181,8 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen
// but what we actually get is store as float64 in Datum.
// So if we wrap `CastDecimalAsInt` upon the result, we'll get <nil> when call `arg.EvalDecimal()`.
// This will be fixed after all built-in functions be rewrite correctlly.
} else if IsHybridType(expr) {
}
if IsHybridType(expr) {
res, err = val.ToDecimal(sc)
return res, false, errors.Trace(err)
}
Expand Down
Loading