-
Notifications
You must be signed in to change notification settings - Fork 5.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
expression: fix #3762, signed integer overflow handle in minus unary scalar function #3780
Changes from all commits
7762b8b
f87e0af
465dd1b
17f1e5f
1d15330
a803f9d
d6e6c25
ae342fc
869e1f6
b305a4f
bedd270
0e5de09
4045075
636170d
33bf09c
eb58c48
b199772
6005a18
e2aa622
d0395bf
ffc0df7
27202d6
03dc0e5
7f01b42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,12 @@ | |
package expression | ||
|
||
import ( | ||
"fmt" | ||
"math" | ||
|
||
"github.com/juju/errors" | ||
"github.com/pingcap/tidb/context" | ||
"github.com/pingcap/tidb/mysql" | ||
"github.com/pingcap/tidb/parser/opcode" | ||
"github.com/pingcap/tidb/util/types" | ||
) | ||
|
@@ -27,6 +31,7 @@ var ( | |
_ functionClass = &bitOpFunctionClass{} | ||
_ functionClass = &isTrueOpFunctionClass{} | ||
_ functionClass = &unaryOpFunctionClass{} | ||
_ functionClass = &unaryMinusFunctionClass{} | ||
_ functionClass = &isNullFunctionClass{} | ||
) | ||
|
||
|
@@ -37,6 +42,7 @@ var ( | |
_ builtinFunc = &builtinBitOpSig{} | ||
_ builtinFunc = &builtinIsTrueOpSig{} | ||
_ builtinFunc = &builtinUnaryOpSig{} | ||
_ builtinFunc = &builtinUnaryMinusIntSig{} | ||
_ builtinFunc = &builtinIsNullSig{} | ||
) | ||
|
||
|
@@ -335,49 +341,154 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { | |
default: | ||
return d, errInvalidOperation.Gen("Unsupported type %v for op.Plus", aDatum.Kind()) | ||
} | ||
case opcode.Minus: | ||
switch aDatum.Kind() { | ||
case types.KindInt64: | ||
d.SetInt64(-aDatum.GetInt64()) | ||
case types.KindUint64: | ||
d.SetInt64(-int64(aDatum.GetUint64())) | ||
case types.KindFloat64: | ||
d.SetFloat64(-aDatum.GetFloat64()) | ||
case types.KindFloat32: | ||
d.SetFloat32(-aDatum.GetFloat32()) | ||
case types.KindMysqlDuration: | ||
dec := new(types.MyDecimal) | ||
err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDuration().ToNumber(), dec) | ||
d.SetMysqlDecimal(dec) | ||
case types.KindMysqlTime: | ||
dec := new(types.MyDecimal) | ||
err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlTime().ToNumber(), dec) | ||
d.SetMysqlDecimal(dec) | ||
case types.KindString, types.KindBytes: | ||
f, err1 := types.StrToFloat(sc, aDatum.GetString()) | ||
err = errors.Trace(err1) | ||
d.SetFloat64(-f) | ||
case types.KindMysqlDecimal: | ||
dec := new(types.MyDecimal) | ||
err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDecimal(), dec) | ||
d.SetMysqlDecimal(dec) | ||
case types.KindMysqlHex: | ||
d.SetFloat64(-aDatum.GetMysqlHex().ToNumber()) | ||
case types.KindMysqlBit: | ||
d.SetFloat64(-aDatum.GetMysqlBit().ToNumber()) | ||
case types.KindMysqlEnum: | ||
d.SetFloat64(-aDatum.GetMysqlEnum().ToNumber()) | ||
case types.KindMysqlSet: | ||
d.SetFloat64(-aDatum.GetMysqlSet().ToNumber()) | ||
default: | ||
return d, errInvalidOperation.Gen("Unsupported type %v for op.Minus", aDatum.Kind()) | ||
} | ||
default: | ||
return d, errInvalidOperation.Gen("Unsupported op %v for unary op", b.op) | ||
} | ||
return | ||
} | ||
|
||
type unaryMinusFunctionClass struct { | ||
baseFunctionClass | ||
} | ||
|
||
func (b *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow bool) { | ||
if mysql.HasUnsignedFlag(arg.GetType().Flag) { | ||
uval := arg.Value.GetUint64() | ||
// -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like | ||
// 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occurs. | ||
if uval > uint64(-math.MinInt64) { | ||
return true | ||
} | ||
} else { | ||
val := arg.Value.GetInt64() | ||
// The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807, | ||
// which is less than abs(-9223372036854775808). When val == math.MinInt64, overflow occurs. | ||
if val == math.MinInt64 { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
// typeInfer infers unaryMinus function return type. when the arg is an int constant and overflow, | ||
// typerInfer will infers the return type as tpDecimal, not tpInt. | ||
func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Context) (evalTp, bool) { | ||
tp := tpInt | ||
switch argExpr.GetTypeClass() { | ||
case types.ClassString, types.ClassReal: | ||
tp = tpReal | ||
case types.ClassDecimal: | ||
tp = tpDecimal | ||
} | ||
|
||
sc := ctx.GetSessionVars().StmtCtx | ||
overflow := false | ||
// TODO: Handle float overflow. | ||
if arg, ok := argExpr.(*Constant); sc.IgnoreOverflow && ok && | ||
arg.GetTypeClass() == types.ClassInt { | ||
overflow = b.handleIntOverflow(arg) | ||
if overflow { | ||
tp = tpDecimal | ||
} | ||
} | ||
return tp, overflow | ||
} | ||
|
||
func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { | ||
err = b.verifyArgs(args) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
argExpr := args[0] | ||
retTp, intOverflow := b.typeInfer(argExpr, ctx) | ||
|
||
var bf baseBuiltinFunc | ||
switch argExpr.GetTypeClass() { | ||
case types.ClassInt: | ||
if intOverflow { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, true} | ||
} else { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) | ||
sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} | ||
} | ||
case types.ClassDecimal: | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
case types.ClassReal: | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) | ||
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} | ||
case types.ClassString: | ||
tp := argExpr.GetType().Tp | ||
if types.IsTypeTime(tp) || tp == mysql.TypeDuration { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
} else { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) | ||
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} | ||
} | ||
} | ||
|
||
return sig.setSelf(sig), errors.Trace(err) | ||
} | ||
|
||
type builtinUnaryMinusIntSig struct { | ||
baseIntBuiltinFunc | ||
} | ||
|
||
func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) { | ||
var val int64 | ||
val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx) | ||
if err != nil || isNull { | ||
return val, isNull, errors.Trace(err) | ||
} | ||
|
||
if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { | ||
uval := uint64(val) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move the overflow check logic to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tiancaiamao the overflow logic is only for unary minus |
||
if uval > uint64(-math.MinInt64) { | ||
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) | ||
} else if uval == uint64(-math.MinInt64) { | ||
return math.MinInt64, false, nil | ||
} | ||
} else if val == math.MinInt64 { | ||
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) | ||
} | ||
return -val, false, errors.Trace(err) | ||
} | ||
|
||
type builtinUnaryMinusDecimalSig struct { | ||
baseDecimalBuiltinFunc | ||
|
||
constantArgOverflow bool | ||
} | ||
|
||
func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) { | ||
sc := b.getCtx().GetSessionVars().StmtCtx | ||
|
||
var dec *types.MyDecimal | ||
dec, isNull, err := b.args[0].EvalDecimal(row, sc) | ||
if err != nil || isNull { | ||
return dec, isNull, errors.Trace(err) | ||
} | ||
|
||
to := new(types.MyDecimal) | ||
err = types.DecimalSub(new(types.MyDecimal), dec, to) | ||
return to, false, errors.Trace(err) | ||
} | ||
|
||
type builtinUnaryMinusRealSig struct { | ||
baseRealBuiltinFunc | ||
} | ||
|
||
func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) { | ||
sc := b.getCtx().GetSessionVars().StmtCtx | ||
var val float64 | ||
val, isNull, err = b.args[0].EvalReal(row, sc) | ||
res = -val | ||
return | ||
} | ||
|
||
type isNullFunctionClass struct { | ||
baseFunctionClass | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,13 +14,56 @@ | |
package expression | ||
|
||
import ( | ||
"math" | ||
|
||
"github.com/juju/errors" | ||
. "github.com/pingcap/check" | ||
"github.com/pingcap/tidb/ast" | ||
"github.com/pingcap/tidb/util/testleak" | ||
"github.com/pingcap/tidb/util/types" | ||
) | ||
|
||
func (s *testEvaluatorSuite) TestUnary(c *C) { | ||
defer testleak.AfterTest(c)() | ||
cases := []struct { | ||
args interface{} | ||
expected interface{} | ||
overflow bool | ||
getErr bool | ||
}{ | ||
{uint64(9223372036854775809), "-9223372036854775809", true, false}, | ||
{uint64(9223372036854775810), "-9223372036854775810", true, false}, | ||
{uint64(9223372036854775808), int64(-9223372036854775808), false, false}, | ||
{int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808 | ||
} | ||
sc := s.ctx.GetSessionVars().StmtCtx | ||
origin := sc.IgnoreOverflow | ||
sc.IgnoreOverflow = true | ||
defer func() { | ||
sc.IgnoreOverflow = origin | ||
}() | ||
|
||
for _, t := range cases { | ||
f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...) | ||
c.Assert(err, IsNil) | ||
d, err := f.Eval(nil) | ||
if t.getErr == false { | ||
c.Assert(err, IsNil) | ||
if !t.overflow { | ||
c.Assert(d.GetValue(), Equals, t.expected) | ||
} else { | ||
c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) | ||
} | ||
} else { | ||
c.Assert(err, NotNil) | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add test case for |
||
|
||
f, err := funcs[ast.UnaryMinus].getFunction([]Expression{Zero}, s.ctx) | ||
c.Assert(err, IsNil) | ||
c.Assert(f.isDeterministic(), IsTrue) | ||
} | ||
|
||
func (s *testEvaluatorSuite) TestAndAnd(c *C) { | ||
defer testleak.AfterTest(c)() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,6 +93,7 @@ func (b *planBuilder) rewriteWithPreprocess(expr ast.ExprNode, p LogicalPlan, ag | |
if getRowLen(er.ctxStack[0]) != 1 { | ||
return nil, nil, ErrOperandColumns.GenByArgs(1) | ||
} | ||
|
||
return er.ctxStack[0], er.p, nil | ||
} | ||
|
||
|
@@ -142,7 +143,11 @@ func popRowArg(ctx context.Context, e expression.Expression) (ret expression.Exp | |
return ret, errors.Trace(err) | ||
} | ||
c, _ := e.(*expression.Constant) | ||
ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} | ||
if getRowLen(c) == 2 { | ||
ret = &expression.Constant{Value: c.Value.GetRow()[1], RetType: c.GetType()} | ||
} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why change here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because if
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and if not change this, the test case |
||
ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} | ||
} | ||
return | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
argExpr
only used in line 438, we can just useargs[0]
instead of argExpr.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
see line 441 and 469