-
Notifications
You must be signed in to change notification settings - Fork 5.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
expression: fix #3762, signed integer overflow handle in minus unary scalar function #3780
Changes from 12 commits
7762b8b
f87e0af
465dd1b
17f1e5f
1d15330
a803f9d
d6e6c25
ae342fc
869e1f6
b305a4f
bedd270
0e5de09
4045075
636170d
33bf09c
eb58c48
b199772
6005a18
e2aa622
d0395bf
ffc0df7
27202d6
03dc0e5
7f01b42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,8 +14,12 @@ | |
package expression | ||
|
||
import ( | ||
"fmt" | ||
"math" | ||
|
||
"github.com/juju/errors" | ||
"github.com/pingcap/tidb/context" | ||
"github.com/pingcap/tidb/mysql" | ||
"github.com/pingcap/tidb/parser/opcode" | ||
"github.com/pingcap/tidb/util/types" | ||
) | ||
|
@@ -27,6 +31,7 @@ var ( | |
_ functionClass = &bitOpFunctionClass{} | ||
_ functionClass = &isTrueOpFunctionClass{} | ||
_ functionClass = &unaryOpFunctionClass{} | ||
_ functionClass = &unaryMinusFunctionClass{} | ||
_ functionClass = &isNullFunctionClass{} | ||
) | ||
|
||
|
@@ -37,6 +42,7 @@ var ( | |
_ builtinFunc = &builtinBitOpSig{} | ||
_ builtinFunc = &builtinIsTrueOpSig{} | ||
_ builtinFunc = &builtinUnaryOpSig{} | ||
_ builtinFunc = &builtinUnaryMinusIntSig{} | ||
_ builtinFunc = &builtinIsNullSig{} | ||
) | ||
|
||
|
@@ -338,9 +344,19 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { | |
case opcode.Minus: | ||
switch aDatum.Kind() { | ||
case types.KindInt64: | ||
d.SetInt64(-aDatum.GetInt64()) | ||
val := aDatum.GetInt64() | ||
if val == math.MinInt64 { | ||
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) | ||
} | ||
d.SetInt64(-val) | ||
case types.KindUint64: | ||
d.SetInt64(-int64(aDatum.GetUint64())) | ||
uval := aDatum.GetUint64() | ||
if uval > uint64(-math.MinInt64) { // 9223372036854775808 | ||
// -uval will overflow | ||
err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) | ||
} else { | ||
d.SetInt64(-int64(uval)) | ||
} | ||
case types.KindFloat64: | ||
d.SetFloat64(-aDatum.GetFloat64()) | ||
case types.KindFloat32: | ||
|
@@ -378,6 +394,157 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { | |
return | ||
} | ||
|
||
type unaryMinusFunctionClass struct { | ||
baseFunctionClass | ||
} | ||
|
||
func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { | ||
tp := tpInt | ||
switch argExpr.GetTypeClass() { | ||
case types.ClassString, types.ClassReal: | ||
tp = tpReal | ||
case types.ClassDecimal: | ||
tp = tpDecimal | ||
} | ||
|
||
overflow := false | ||
// TODO: handle float overflow | ||
if arg, ok := argExpr.(*Constant); ok { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why check overflow in type infer? If expr is not Constant, this is useless. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @XuHuaiyu, because overflow will be handled only if the statement is Select, and if the args type is not constant, MySQL will just throw a overflow error out, see below:
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's weird here. Can we change parser to parse There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. no, image this case "--9223372036854775809". we should not do too much things in parser |
||
switch arg.Value.Kind() { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we may check whether mysql.HasUnsignedFlag(arg.GetType().Flag) here |
||
case types.KindUint64: | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. comment for the two overflow check. |
||
uval := arg.Value.GetUint64() | ||
if uval > uint64(-math.MinInt64) { | ||
overflow = true | ||
tp = tpDecimal | ||
} | ||
case types.KindInt64: | ||
val := arg.Value.GetInt64() | ||
if val == math.MinInt64 { | ||
overflow = true | ||
tp = tpDecimal | ||
} | ||
} | ||
} | ||
return tp, overflow | ||
} | ||
|
||
func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { | ||
err = b.verifyArgs(args) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
|
||
argExpr := args[0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. see line 441 and 469 |
||
retTp, intOverflow := b.typeInfer(argExpr) | ||
|
||
var bf baseBuiltinFunc | ||
switch argExpr.GetTypeClass() { | ||
case types.ClassInt: | ||
if intOverflow { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
} else { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} | ||
} | ||
case types.ClassDecimal: | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
case types.ClassReal: | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} | ||
case types.ClassString: | ||
tp := argExpr.GetType().Tp | ||
if types.IsTypeTime(tp) || tp == mysql.TypeDuration { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} | ||
} else { | ||
bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) | ||
if err != nil { | ||
return nil, errors.Trace(err) | ||
} | ||
sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} | ||
} | ||
} | ||
|
||
return sig.setSelf(sig), errors.Trace(err) | ||
} | ||
|
||
type builtinUnaryMinusIntSig struct { | ||
baseIntBuiltinFunc | ||
} | ||
|
||
func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) { | ||
var val int64 | ||
val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx) | ||
if err != nil || isNull { | ||
return val, isNull, errors.Trace(err) | ||
} | ||
|
||
if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { | ||
uval := uint64(val) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. move the overflow check logic to There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tiancaiamao the overflow logic is only for unary minus |
||
if uval > uint64(-math.MinInt64) { | ||
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) | ||
} else if uval == uint64(-math.MinInt64) { | ||
return math.MinInt64, false, nil | ||
} | ||
} else if val == math.MinInt64 { | ||
return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) | ||
} | ||
return -val, false, err | ||
} | ||
|
||
type builtinUnaryMinusDecimalSig struct { | ||
baseDecimalBuiltinFunc | ||
|
||
constantArgOverflow bool | ||
} | ||
|
||
func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) { | ||
sc := b.getCtx().GetSessionVars().StmtCtx | ||
|
||
var dec *types.MyDecimal | ||
dec, isNull, err := b.args[0].EvalDecimal(row, sc) | ||
if err != nil || isNull { | ||
return dec, isNull, errors.Trace(err) | ||
} | ||
|
||
if !sc.InSelectStmt && b.constantArgOverflow { | ||
return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) | ||
} | ||
|
||
to := new(types.MyDecimal) | ||
err = types.DecimalSub(new(types.MyDecimal), dec, to) | ||
return to, false, err | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. s/err/errors.Trace(err) |
||
} | ||
|
||
type builtinUnaryMinusRealSig struct { | ||
baseRealBuiltinFunc | ||
} | ||
|
||
func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) { | ||
sc := b.getCtx().GetSessionVars().StmtCtx | ||
var val float64 | ||
val, isNull, err = b.args[0].EvalReal(row, sc) | ||
res = -val | ||
return | ||
} | ||
|
||
type isNullFunctionClass struct { | ||
baseFunctionClass | ||
} | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,13 +14,52 @@ | |
package expression | ||
|
||
import ( | ||
"math" | ||
|
||
"github.com/juju/errors" | ||
. "github.com/pingcap/check" | ||
"github.com/pingcap/tidb/ast" | ||
"github.com/pingcap/tidb/util/testleak" | ||
"github.com/pingcap/tidb/util/types" | ||
) | ||
|
||
func (s *testEvaluatorSuite) TestUnary(c *C) { | ||
defer testleak.AfterTest(c)() | ||
cases := []struct { | ||
args interface{} | ||
expected interface{} | ||
overflow bool | ||
getErr bool | ||
}{ | ||
{uint64(9223372036854775809), "-9223372036854775809", true, false}, | ||
{uint64(9223372036854775810), "-9223372036854775810", true, false}, | ||
{uint64(9223372036854775808), int64(-9223372036854775808), false, false}, | ||
{int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808 | ||
} | ||
sc := s.ctx.GetSessionVars().StmtCtx | ||
origin := sc.InSelectStmt | ||
sc.InSelectStmt = true | ||
defer func() { | ||
sc.InSelectStmt = origin | ||
}() | ||
|
||
for _, t := range cases { | ||
f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...) | ||
c.Assert(err, IsNil) | ||
d, err := f.Eval(nil) | ||
if t.getErr == false { | ||
c.Assert(err, IsNil) | ||
if !t.overflow { | ||
c.Assert(d.GetValue(), Equals, t.expected) | ||
} else { | ||
c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) | ||
} | ||
} else { | ||
c.Assert(err, NotNil) | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. add test case for |
||
} | ||
|
||
func (s *testEvaluatorSuite) TestAndAnd(c *C) { | ||
defer testleak.AfterTest(c)() | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -83,11 +83,13 @@ func NewFunction(ctx context.Context, funcName string, retType *types.FieldType, | |
if builtinRetTp := f.getRetTp(); builtinRetTp.Tp != mysql.TypeUnspecified { | ||
retType = builtinRetTp | ||
} | ||
|
||
sf := &ScalarFunction{ | ||
FuncName: model.NewCIStr(funcName), | ||
RetType: retType, | ||
Function: f, | ||
} | ||
|
||
return FoldConstant(sf), nil | ||
} | ||
|
||
|
@@ -160,10 +162,11 @@ func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression { | |
|
||
// Eval implements Expression interface. | ||
func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { | ||
sc := sf.GetCtx().GetSessionVars().StmtCtx | ||
if !TurnOnNewExprEval { | ||
return sf.Function.eval(row) | ||
d, err = sf.Function.eval(row) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. any difference? |
||
return | ||
} | ||
sc := sf.GetCtx().GetSessionVars().StmtCtx | ||
var ( | ||
res interface{} | ||
isNull bool | ||
|
@@ -173,6 +176,7 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { | |
case types.ClassInt: | ||
var intRes int64 | ||
intRes, isNull, err = sf.EvalInt(row, sc) | ||
|
||
if mysql.HasUnsignedFlag(tp.Flag) { | ||
res = uint64(intRes) | ||
} else { | ||
|
@@ -192,6 +196,7 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { | |
res, isNull, err = sf.EvalString(row, sc) | ||
} | ||
} | ||
|
||
if isNull || err != nil { | ||
d.SetValue(nil) | ||
return d, errors.Trace(err) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -93,6 +93,7 @@ func (b *planBuilder) rewriteWithPreprocess(expr ast.ExprNode, p LogicalPlan, ag | |
if getRowLen(er.ctxStack[0]) != 1 { | ||
return nil, nil, ErrOperandColumns.GenByArgs(1) | ||
} | ||
|
||
return er.ctxStack[0], er.p, nil | ||
} | ||
|
||
|
@@ -142,7 +143,11 @@ func popRowArg(ctx context.Context, e expression.Expression) (ret expression.Exp | |
return ret, errors.Trace(err) | ||
} | ||
c, _ := e.(*expression.Constant) | ||
ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} | ||
if getRowLen(c) == 2 { | ||
ret = &expression.Constant{Value: c.Value.GetRow()[1], RetType: c.GetType()} | ||
} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why change here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. because if
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. and if not change this, the test case |
||
ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} | ||
} | ||
return | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
the code will never reach here?