-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
expression: fix #3762, signed integer overflow handle in minus unary scalar function #3780
Changes from 11 commits
7762b8b
f87e0af
465dd1b
17f1e5f
1d15330
a803f9d
d6e6c25
ae342fc
869e1f6
b305a4f
bedd270
0e5de09
4045075
636170d
33bf09c
eb58c48
b199772
6005a18
e2aa622
d0395bf
ffc0df7
27202d6
03dc0e5
7f01b42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,12 @@ | |
package expression | ||
|
||
import ( | ||
"fmt" | ||
"math" | ||
|
||
"github.com/juju/errors" | ||
"github.com/pingcap/tidb/context" | ||
"github.com/pingcap/tidb/mysql" | ||
"github.com/pingcap/tidb/parser/opcode" | ||
"github.com/pingcap/tidb/util/types" | ||
) | ||
|
@@ -27,6 +31,7 @@ var ( | |
_ functionClass = &bitOpFunctionClass{} | ||
_ functionClass = &isTrueOpFunctionClass{} | ||
_ functionClass = &unaryOpFunctionClass{} | ||
_ functionClass = &unaryMinusFunctionClass{} | ||
_ functionClass = &isNullFunctionClass{} | ||
) | ||
|
||
|
@@ -37,6 +42,7 @@ var ( | |
_ builtinFunc = &builtinBitOpSig{} | ||
_ builtinFunc = &builtinIsTrueOpSig{} | ||
_ builtinFunc = &builtinUnaryOpSig{} | ||
_ builtinFunc = &builtinUnaryMinusIntSig{} | ||
_ builtinFunc = &builtinIsNullSig{} | ||
) | ||
|
||
|
@@ -338,9 +344,19 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { | |
case opcode.Minus: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. the code will never reach here? |
||
switch aDatum.Kind() { | ||
case types.KindInt64: | ||
d.SetInt64(-aDatum.GetInt64()) | ||
val := aDatum.GetInt64() | ||
if val == math.MinInt64 { | ||
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) | ||
} | ||
d.SetInt64(-val) | ||
case types.KindUint64: | ||
d.SetInt64(-int64(aDatum.GetUint64())) | ||
uval := aDatum.GetUint64() | ||
if uval > uint64(-math.MinInt64) { // 9223372036854775808 | ||
// -uval will overflow | ||
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) | ||
} else { | ||
d.SetInt64(-int64(uval)) | ||
} | ||
case types.KindFloat64: | ||
d.SetFloat64(-aDatum.GetFloat64()) | ||
case types.KindFloat32: | ||
|
@@ -378,6 +394,165 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { | |
return | ||
} | ||
|
||
type unaryMinusFunctionClass struct { | ||
baseFunctionClass | ||
} | ||
|
||
func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { | ||
tp := tpInt | ||
switch argExpr.GetType().Tp { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. you can use argExpr.getClassType() here ? case types.ClassString, types.ClassReal:
tp = tpReal
case types.ClassDecimal:
tp = tpDecimal |
||
case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat, | ||
mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: | ||
tp = tpReal | ||
case mysql.TypeNewDecimal: | ||
tp = tpDecimal | ||
} | ||
|
||
overflow := false | ||
// TODO: handle float overflow | ||
if arg, ok := argExpr.(*Constant); ok { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why check overflow in type infer? If expr is not Constant, this is useless. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @XuHuaiyu, because overflow will be handled only if the statement is Select, and if the args type is not constant, MySQL will just throw a overflow error out, see below:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's weird here. Can we change parser to parse There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, image this case "--9223372036854775809". we should not do too much things in parser |
||
switch arg.Value.Kind() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we may check whether mysql.HasUnsignedFlag(arg.GetType().Flag) here |
||
case types.KindUint64: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. comment for the two overflow check. |
||
uval := arg.Value.GetUint64() | ||
if uval > uint64(-math.MinInt64) { | ||
overflow = true | ||
tp = tpDecimal | ||
} | ||
case types.KindInt64: | ||
val := arg.Value.GetInt64() | ||
if val == math.MinInt64 { | ||
overflow = true | ||
tp = tpDecimal | ||
} | ||
} | ||
} | ||
return tp, overflow | ||
} | ||
|
||
func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { | ||
argExpr := args[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see line 441 and 469 |
||
retTp, intOverflow := b.typeInfer(argExpr) | ||
|
||
switch argExpr.GetTypeClass() { | ||
case types.ClassInt: | ||
if intOverflow { | ||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
} else { | ||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} | ||
} | ||
case types.ClassDecimal: | ||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
case types.ClassReal: | ||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} | ||
case types.ClassString: | ||
tp := argExpr.GetType().Tp | ||
if types.IsTypeTime(tp) { | ||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. line 465 and ling 471 can be merged? |
||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
} else if tp == mysql.TypeDuration { | ||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
} else { | ||
bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} | ||
} | ||
} | ||
|
||
return sig.setSelf(sig), errors.Trace(b.verifyArgs(args)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we may put b.verifyArgs at the beginning of this func to check the count of args. |
||
} | ||
|
||
func unaryMinusDecimal(dec *types.MyDecimal) (to *types.MyDecimal, err error) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this func may be unnecessary? |
||
to = new(types.MyDecimal) | ||
err = types.DecimalSub(new(types.MyDecimal), dec, to) | ||
return | ||
} | ||
|
||
type builtinUnaryMinusIntSig struct { | ||
baseIntBuiltinFunc | ||
} | ||
|
||
func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) { | ||
var val int64 | ||
val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx) | ||
if err != nil || isNull { | ||
return val, isNull, errors.Trace(err) | ||
} | ||
|
||
if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { | ||
uval := uint64(val) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move the overflow check logic to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tiancaiamao the overflow logic is only for unary minus |
||
if uval > uint64(-math.MinInt64) { | ||
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) | ||
} else if uval == uint64(-math.MinInt64) { | ||
return math.MinInt64, false, nil | ||
} | ||
} else if val == math.MinInt64 { | ||
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) | ||
} | ||
return -val, false, err | ||
} | ||
|
||
type builtinUnaryMinusDecimalSig struct { | ||
baseDecimalBuiltinFunc | ||
|
||
constantArgOverflow bool | ||
} | ||
|
||
func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) { | ||
var ( | ||
dec, to *types.MyDecimal | ||
) | ||
|
||
sc := b.getCtx().GetSessionVars().StmtCtx | ||
if !sc.InSelectStmt && b.constantArgOverflow { | ||
return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) | ||
} | ||
|
||
dec, isNull, err := b.args[0].EvalDecimal(row, sc) | ||
if err != nil || isNull { | ||
return dec, isNull, errors.Trace(err) | ||
} | ||
|
||
to, err = unaryMinusDecimal(dec) | ||
return to, false, err | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/err/errors.Trace(err) |
||
} | ||
|
||
type builtinUnaryMinusRealSig struct { | ||
baseRealBuiltinFunc | ||
} | ||
|
||
func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) { | ||
sc := b.getCtx().GetSessionVars().StmtCtx | ||
var val float64 | ||
val, isNull, err = b.args[0].EvalReal(row, sc) | ||
res = -val | ||
return | ||
} | ||
|
||
type isNullFunctionClass struct { | ||
baseFunctionClass | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,13 +14,52 @@ | |
package expression | ||
|
||
import ( | ||
"math" | ||
|
||
"github.com/juju/errors" | ||
. "github.com/pingcap/check" | ||
"github.com/pingcap/tidb/ast" | ||
"github.com/pingcap/tidb/util/testleak" | ||
"github.com/pingcap/tidb/util/types" | ||
) | ||
|
||
func (s *testEvaluatorSuite) TestUnary(c *C) { | ||
defer testleak.AfterTest(c)() | ||
cases := []struct { | ||
args interface{} | ||
expected interface{} | ||
overflow bool | ||
getErr bool | ||
}{ | ||
{uint64(9223372036854775809), "-9223372036854775809", true, false}, | ||
{uint64(9223372036854775810), "-9223372036854775810", true, false}, | ||
{uint64(9223372036854775808), int64(-9223372036854775808), false, false}, | ||
{int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808 | ||
} | ||
sc := s.ctx.GetSessionVars().StmtCtx | ||
origin := sc.InSelectStmt | ||
sc.InSelectStmt = true | ||
defer func() { | ||
sc.InSelectStmt = origin | ||
}() | ||
|
||
for _, t := range cases { | ||
f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...) | ||
c.Assert(err, IsNil) | ||
d, err := f.Eval(nil) | ||
if t.getErr == false { | ||
c.Assert(err, IsNil) | ||
if !t.overflow { | ||
c.Assert(d.GetValue(), Equals, t.expected) | ||
} else { | ||
c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) | ||
} | ||
} else { | ||
c.Assert(err, NotNil) | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add test case for |
||
} | ||
|
||
func (s *testEvaluatorSuite) TestAndAnd(c *C) { | ||
defer testleak.AfterTest(c)() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -45,3 +45,32 @@ func FoldConstant(expr Expression) Expression { | |
RetType: scalarFunc.RetType, | ||
} | ||
} | ||
|
||
// PlainFoldConstant does constant folding optimization | ||
// on an expression without recursively folding. | ||
func PlainFoldConstant(expr Expression) Expression { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. unused |
||
scalarFunc, ok := expr.(*ScalarFunction) | ||
if !ok || !scalarFunc.Function.isDeterministic() { | ||
return expr | ||
} | ||
args := scalarFunc.GetArgs() | ||
canFold := true | ||
for i := 0; i < len(args); i++ { | ||
if _, ok := args[i].(*Constant); !ok { | ||
canFold = false | ||
break | ||
} | ||
} | ||
if !canFold { | ||
return expr | ||
} | ||
value, err := scalarFunc.Eval(nil) | ||
if err != nil { | ||
log.Warnf("There may exist an error during constant folding. The function name is %s, args are %s", scalarFunc.FuncName, args) | ||
return expr | ||
} | ||
return &Constant{ | ||
Value: value, | ||
RetType: scalarFunc.RetType, | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
useless line ?