From 7762b8b46bb9efa4bb3a8b2c53d031aa5a5d6c3a Mon Sep 17 00:00:00 2001 From: winkyao Date: Mon, 17 Jul 2017 11:13:25 +0800 Subject: [PATCH 01/17] expression: handle int builtinUnaryOpSig overflow --- executor/executor_test.go | 4 ++++ expression/builtin_op.go | 21 ++++++++++++++++++++- expression/constant_fold.go | 6 ++++++ expression/scalar_function.go | 12 +++++++++--- 4 files changed, 39 insertions(+), 4 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index e38f59a38f77f..7784b0b3a60b0 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1034,6 +1034,10 @@ func (s *testSuite) TestBuiltin(c *C) { result = tk.MustQuery("select cast(a as signed) from t") result.Check(testkit.Rows("130000")) + // fixed issue #3762 + result = tk.MustQuery("select -9223372036854775809;") + result.Check(testkit.Rows("-9223372036854775809")) + // test unhex and hex result = tk.MustQuery("select unhex('4D7953514C')") result.Check(testkit.Rows("MySQL")) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index c0973fa0431f5..e1e2b81d66d2e 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -14,9 +14,14 @@ package expression import ( + "math" + + "fmt" + "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/types" ) @@ -353,7 +358,21 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { case types.KindInt64: d.SetInt64(-aDatum.GetInt64()) case types.KindUint64: - d.SetInt64(-int64(aDatum.GetUint64())) + // consider overflow, MySQL will convert it to Decimal when overflow occurred + uval := aDatum.GetUint64() + minInt64 := math.MinInt64 + absMinInt64 := uint64(-minInt64) // 9223372036854775808 + if uval > absMinInt64 { + // -uval will overflow + dval := new(types.MyDecimal) + sval := fmt.Sprintf("-%v", uval) + dval.FromString(hack.Slice(sval)) + d.SetMysqlDecimal(dval) + types.DefaultTypeForValue(dval, b.tp) + // err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) + } else { + d.SetInt64(-int64(uval)) + } case types.KindFloat64: d.SetFloat64(-aDatum.GetFloat64()) case types.KindFloat32: diff --git a/expression/constant_fold.go b/expression/constant_fold.go index cbc13084ef2a2..371ce6e0fd720 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -15,6 +15,7 @@ package expression import ( "github.com/ngaut/log" + "github.com/pingcap/tidb/mysql" ) // FoldConstant does constant folding optimization on an expression. @@ -40,6 +41,11 @@ func FoldConstant(expr Expression) Expression { log.Warnf("There may exist an error during constant folding. The function name is %s, args are %s", scalarFunc.FuncName, args) return expr } + + // TODO: retType maybe changed after function executed + if builtinRetTp := scalarFunc.Function.getRetTp(); builtinRetTp.Tp != mysql.TypeUnspecified { + scalarFunc.RetType = builtinRetTp + } return &Constant{ Value: value, RetType: scalarFunc.RetType, diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 95869cfa1960c..d11c12f799dd2 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -23,6 +23,7 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/types" ) @@ -172,10 +173,15 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { case types.ClassInt: var intRes int64 intRes, isNull, err = sf.EvalInt(row, sc) - if mysql.HasUnsignedFlag(tp.Flag) { - res = uint64(intRes) + if err != nil && terror.ErrorEqual(err, types.ErrOverflow) && !sc.InUpdateOrDeleteStmt { + // TODO: overflow convert it to decimal + } else { - res = intRes + if mysql.HasUnsignedFlag(tp.Flag) { + res = uint64(intRes) + } else { + res = intRes + } } case types.ClassReal: res, isNull, err = sf.EvalReal(row, sc) From f87e0afcc0e4420829f4b7412c16d86142789c2c Mon Sep 17 00:00:00 2001 From: winkyao Date: Mon, 17 Jul 2017 17:49:27 +0800 Subject: [PATCH 02/17] expression: fix #3762, signed integer overflow handle in minus unary scalar function --- executor/executor_test.go | 2 ++ expression/builtin_op.go | 9 +----- expression/constant_fold.go | 5 ---- expression/scalar_function.go | 53 +++++++++++++++++++++++++++++------ 4 files changed, 47 insertions(+), 22 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index 7784b0b3a60b0..45644db6d40c9 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1037,6 +1037,8 @@ func (s *testSuite) TestBuiltin(c *C) { // fixed issue #3762 result = tk.MustQuery("select -9223372036854775809;") result.Check(testkit.Rows("-9223372036854775809")) + result = tk.MustQuery("select -9223372036854775808;") + result.Check(testkit.Rows("-9223372036854775808")) // test unhex and hex result = tk.MustQuery("select unhex('4D7953514C')") diff --git a/expression/builtin_op.go b/expression/builtin_op.go index e1e2b81d66d2e..1368cae61c041 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -21,7 +21,6 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/parser/opcode" - "github.com/pingcap/tidb/util/hack" "github.com/pingcap/tidb/util/types" ) @@ -358,18 +357,12 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { case types.KindInt64: d.SetInt64(-aDatum.GetInt64()) case types.KindUint64: - // consider overflow, MySQL will convert it to Decimal when overflow occurred uval := aDatum.GetUint64() minInt64 := math.MinInt64 absMinInt64 := uint64(-minInt64) // 9223372036854775808 if uval > absMinInt64 { // -uval will overflow - dval := new(types.MyDecimal) - sval := fmt.Sprintf("-%v", uval) - dval.FromString(hack.Slice(sval)) - d.SetMysqlDecimal(dval) - types.DefaultTypeForValue(dval, b.tp) - // err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) + err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) } else { d.SetInt64(-int64(uval)) } diff --git a/expression/constant_fold.go b/expression/constant_fold.go index 371ce6e0fd720..a1fbcf0315290 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -15,7 +15,6 @@ package expression import ( "github.com/ngaut/log" - "github.com/pingcap/tidb/mysql" ) // FoldConstant does constant folding optimization on an expression. @@ -42,10 +41,6 @@ func FoldConstant(expr Expression) Expression { return expr } - // TODO: retType maybe changed after function executed - if builtinRetTp := scalarFunc.Function.getRetTp(); builtinRetTp.Tp != mysql.TypeUnspecified { - scalarFunc.RetType = builtinRetTp - } return &Constant{ Value: value, RetType: scalarFunc.RetType, diff --git a/expression/scalar_function.go b/expression/scalar_function.go index d11c12f799dd2..31c0e52ed769e 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -158,12 +158,39 @@ func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression { return sf } +func (sf *ScalarFunction) convertArgsToDecimal(sc *variable.StatementContext) error { + ft := types.NewFieldType(mysql.TypeNewDecimal) + for _, arg := range sf.GetArgs() { + if constArg, ok := arg.(*Constant); ok { + val, err := constArg.Value.ConvertTo(sc, ft) + if err != nil { + return err + } + constArg.Value = val + } + } + + return nil +} + // Eval implements Expression interface. func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { + sc := sf.GetCtx().GetSessionVars().StmtCtx if !TurnOnNewExprEval { - return sf.Function.eval(row) + d, err = sf.Function.eval(row) + // TODO: fix #3762, maybe better way + if err != nil && terror.ErrorEqual(err, types.ErrOverflow) && sf.GetTypeClass() == types.ClassInt && + sf.FuncName.L == ast.UnaryMinus && !sc.InUpdateOrDeleteStmt { + err = sf.convertArgsToDecimal(sc) + if err != nil { + return d, errors.Trace(err) + } + d, err = sf.Function.eval(row) + // change return type + types.DefaultTypeForValue(d, sf.RetType) + } + return } - sc := sf.GetCtx().GetSessionVars().StmtCtx var ( res interface{} isNull bool @@ -173,15 +200,11 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { case types.ClassInt: var intRes int64 intRes, isNull, err = sf.EvalInt(row, sc) - if err != nil && terror.ErrorEqual(err, types.ErrOverflow) && !sc.InUpdateOrDeleteStmt { - // TODO: overflow convert it to decimal + if mysql.HasUnsignedFlag(tp.Flag) { + res = uint64(intRes) } else { - if mysql.HasUnsignedFlag(tp.Flag) { - res = uint64(intRes) - } else { - res = intRes - } + res = intRes } case types.ClassReal: res, isNull, err = sf.EvalReal(row, sc) @@ -197,6 +220,18 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { res, isNull, err = sf.EvalString(row, sc) } } + + if err != nil && terror.ErrorEqual(err, types.ErrOverflow) && + sf.FuncName.L == ast.UnaryMinus && !sc.InUpdateOrDeleteStmt { + // TODO: fix #3762, overflow convert it to decimal + err = sf.convertArgsToDecimal(sc) + if err != nil { + return d, errors.Trace(err) + } + res, isNull, err = sf.EvalDecimal(row, sc) + types.DefaultTypeForValue(res, sf.RetType) + } + if isNull || err != nil { d.SetValue(nil) return d, errors.Trace(err) From 465dd1b678c7ae41a116e3c07c3061c097ca1c56 Mon Sep 17 00:00:00 2001 From: winkyao Date: Mon, 17 Jul 2017 17:52:39 +0800 Subject: [PATCH 03/17] expression: fix #3762, add builtin_op_test.go, some builtin test cases --- expression/builtin_op_test.go | 54 +++++++++++++++++++++++++++++++++++ 1 file changed, 54 insertions(+) create mode 100644 expression/builtin_op_test.go diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go new file mode 100644 index 0000000000000..16fe744a79876 --- /dev/null +++ b/expression/builtin_op_test.go @@ -0,0 +1,54 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package expression + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/util/testleak" +) + +func (s *testEvaluatorSuite) TestUnary(c *C) { + defer testleak.AfterTest(c)() + cases := []struct { + args interface{} + expected interface{} + expectedType byte + overflow bool + getErr bool + }{ + {uint64(9223372036854775809), "-9223372036854775809", mysql.TypeNewDecimal, true, false}, + {uint64(9223372036854775810), "-9223372036854775810", mysql.TypeNewDecimal, true, false}, + {uint64(9223372036854775808), "-9223372036854775808", mysql.TypeLonglong, false, false}, + } + + for _, t := range cases { + f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...) + + c.Assert(err, IsNil) + d, err := f.Eval(nil) + if t.getErr { + c.Assert(err, NotNil) + } else { + c.Assert(err, IsNil) + if !t.overflow { + c.Assert(d.GetString(), Equals, t.expected) + } else { + c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) + c.Assert(f.GetType().Tp, Equals, t.expectedType) + } + } + } +} From 17f1e5f0dec5497e017b476123adb4848c802e18 Mon Sep 17 00:00:00 2001 From: winkyao Date: Mon, 17 Jul 2017 18:03:25 +0800 Subject: [PATCH 04/17] *: tiny cleanup --- expression/builtin_op.go | 3 +-- expression/builtin_op_test.go | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 1368cae61c041..862d4080ec33b 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -14,9 +14,8 @@ package expression import ( - "math" - "fmt" + "math" "github.com/juju/errors" "github.com/pingcap/tidb/context" diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index 16fe744a79876..e0cff685d33b8 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -1,4 +1,4 @@ -// Copyright 2015 PingCAP, Inc. +// Copyright 2017 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. From 1d153307e4575fe3eb3c579dbc4c57f723b4e154 Mon Sep 17 00:00:00 2001 From: winkyao Date: Mon, 17 Jul 2017 20:30:32 +0800 Subject: [PATCH 05/17] expression: set the correct unary minus function ScalarFunction.RetType --- expression/scalar_function.go | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 31c0e52ed769e..f6afdc5fc939e 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -187,7 +187,8 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { } d, err = sf.Function.eval(row) // change return type - types.DefaultTypeForValue(d, sf.RetType) + decVal, _ := d.ToDecimal(sc) + types.DefaultTypeForValue(decVal, sf.RetType) } return } From a803f9dcb72550cb9f0c5268649e7a4f9b71e112 Mon Sep 17 00:00:00 2001 From: winkyao Date: Wed, 19 Jul 2017 13:33:05 +0800 Subject: [PATCH 06/17] expression: rewrite unary minus builtin function to handle overflow. --- executor/executor_test.go | 2 + expression/builtin.go | 2 +- expression/builtin_op.go | 210 ++++++++++++++++++++++++++++++++- expression/builtin_op_test.go | 23 ++-- expression/constant_fold.go | 28 +++++ expression/expression.go | 6 +- expression/scalar_function.go | 19 +-- expression/typeinferer_test.go | 4 + plan/expression_rewriter.go | 3 +- 9 files changed, 263 insertions(+), 34 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index 45644db6d40c9..ceb7a0b7c09eb 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1037,6 +1037,8 @@ func (s *testSuite) TestBuiltin(c *C) { // fixed issue #3762 result = tk.MustQuery("select -9223372036854775809;") result.Check(testkit.Rows("-9223372036854775809")) + result = tk.MustQuery("select --9223372036854775809;") + result.Check(testkit.Rows("9223372036854775809")) result = tk.MustQuery("select -9223372036854775808;") result.Check(testkit.Rows("-9223372036854775808")) diff --git a/expression/builtin.go b/expression/builtin.go index e3bcc26fe0dbf..9542190d1a67a 100644 --- a/expression/builtin.go +++ b/expression/builtin.go @@ -784,7 +784,7 @@ var funcs = map[string]functionClass{ ast.UnaryNot: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryNot, 1, 1}, opcode.Not}, ast.BitNeg: &unaryOpFunctionClass{baseFunctionClass{ast.BitNeg, 1, 1}, opcode.BitNeg}, ast.UnaryPlus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryPlus, 1, 1}, opcode.Plus}, - ast.UnaryMinus: &unaryOpFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}, opcode.Minus}, + ast.UnaryMinus: &unaryMinusFunctionClass{baseFunctionClass{ast.UnaryMinus, 1, 1}}, ast.In: &inFunctionClass{baseFunctionClass{ast.In, 1, -1}}, ast.IsTruth: &isTrueOpFunctionClass{baseFunctionClass{ast.IsTruth, 1, 1}, opcode.IsTruth}, ast.IsFalsity: &isTrueOpFunctionClass{baseFunctionClass{ast.IsFalsity, 1, 1}, opcode.IsFalsity}, diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 862d4080ec33b..d5f4c243c4b1d 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -19,6 +19,7 @@ import ( "github.com/juju/errors" "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/util/types" ) @@ -30,6 +31,7 @@ var ( _ functionClass = &bitOpFunctionClass{} _ functionClass = &isTrueOpFunctionClass{} _ functionClass = &unaryOpFunctionClass{} + _ functionClass = &unaryMinusFunctionClass{} _ functionClass = &isNullFunctionClass{} ) @@ -40,6 +42,7 @@ var ( _ builtinFunc = &builtinBitOpSig{} _ builtinFunc = &builtinIsTrueOpSig{} _ builtinFunc = &builtinUnaryOpSig{} + _ builtinFunc = &builtinUnaryMinusIntSig{} _ builtinFunc = &builtinIsNullSig{} ) @@ -354,12 +357,14 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { case opcode.Minus: switch aDatum.Kind() { case types.KindInt64: - d.SetInt64(-aDatum.GetInt64()) + val := aDatum.GetInt64() + if val == math.MinInt64 { + err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) + } + d.SetInt64(-val) case types.KindUint64: uval := aDatum.GetUint64() - minInt64 := math.MinInt64 - absMinInt64 := uint64(-minInt64) // 9223372036854775808 - if uval > absMinInt64 { + if uval > uint64(-math.MinInt64) { // 9223372036854775808 // -uval will overflow err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) } else { @@ -402,6 +407,203 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { return } +type unaryMinusFunctionClass struct { + baseFunctionClass +} + +func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, bf *baseBuiltinFunc) bool { + bf.tp.Init(mysql.TypeLonglong) + switch argExpr.GetType().Tp { + case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat, + mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: + bf.tp.Tp = mysql.TypeDouble + case mysql.TypeNewDecimal: + bf.tp.Tp = mysql.TypeNewDecimal + } + types.SetBinChsClnFlag(bf.tp) + + overflow := false + if arg, ok := argExpr.(*Constant); ok { + switch arg.Value.Kind() { + case types.KindUint64: + uval := arg.Value.GetUint64() + if uval > uint64(-math.MinInt64) { + overflow = true + bf.tp.Tp = mysql.TypeNewDecimal + } + case types.KindInt64: + val := arg.Value.GetInt64() + if val == math.MinInt64 { + overflow = true + bf.tp.Tp = mysql.TypeNewDecimal + } + } + } + return overflow +} + +func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { + bf := newBaseBuiltinFunc(args, ctx) + argExpr := args[0] + overflow := b.typeInfer(argExpr, &bf) + if overflow { + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, true} + return sig.setSelf(sig), errors.Trace(b.verifyArgs(args)) + } + + switch argExpr.GetTypeClass() { + case types.ClassInt: + sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} + case types.ClassDecimal: + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} + case types.ClassReal: + sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} + case types.ClassString: + tp := argExpr.GetType().Tp + if types.IsTypeTime(tp) { + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} + } else if tp == mysql.TypeDuration { + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} + } else { + sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} + } + } + + return sig.setSelf(sig), errors.Trace(b.verifyArgs(args)) +} + +func unaryMinusDecimal(dec *types.MyDecimal) (to *types.MyDecimal, err error) { + to = new(types.MyDecimal) + err = types.DecimalSub(new(types.MyDecimal), dec, to) + return +} + +type builtinUnaryMinusIntSig struct { + baseIntBuiltinFunc +} + +// func (b *builtinUnaryMinusIntSig) eval(row []types.Datum) (d types.Datum, err error) { +// res, isNull, err := b.self.evalInt(row) +// sc := b.getCtx().GetSessionVars().StmtCtx +// if !sc.InUpdateOrDeleteStmt && terror.ErrorEqual(err, types.ErrOverflow) { +// var ( +// dec, to *types.MyDecimal +// ) + +// dec, isNull, err = b.args[0].EvalDecimal(row, sc) +// if err != nil || isNull { +// return d, errors.Trace(err) +// } +// to, err = unaryMinusDecimal(dec) +// d.SetMysqlDecimal(to) +// b.getRetTp().Tp = mysql.TypeNewDecimal +// return +// } + +// if err != nil || isNull { +// return d, errors.Trace(err) +// } +// if mysql.HasUnsignedFlag(b.tp.Flag) { +// d.SetUint64(uint64(res)) +// } else { +// d.SetInt64(res) +// } +// return +// } + +func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) { + var val int64 + val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx) + if err != nil || isNull { + return val, isNull, errors.Trace(err) + } + + if mysql.HasUnsignedFlag(b.args[0].GetType().Flag) { + uval := uint64(val) + if uval > uint64(-math.MinInt64) { + return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) + } else if uval == uint64(-math.MinInt64) { + return math.MinInt64, false, nil + } + } else if val == math.MinInt64 { + return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) + } + return -val, false, err +} + +type builtinUnaryMinusDecimalSig struct { + baseDecimalBuiltinFunc + + constantArgOverflow bool +} + +func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) { + var ( + dec, to *types.MyDecimal + ) + + sc := b.getCtx().GetSessionVars().StmtCtx + aDatum, err := b.args[0].Eval(row) + if err != nil || aDatum.IsNull() { + return nil, aDatum.IsNull(), errors.Trace(err) + } + switch aDatum.Kind() { + case types.KindMysqlTime: + dec := new(types.MyDecimal) + err := types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlTime().ToNumber(), dec) + return dec, false, err + case types.KindMysqlDuration: + dec := new(types.MyDecimal) + err := types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDuration().ToNumber(), dec) + return dec, false, err + } + + dec, isNull, err := b.args[0].EvalDecimal(row, sc) + if err != nil || isNull { + return dec, isNull, errors.Trace(err) + } + + if sc.InUpdateOrDeleteStmt && b.constantArgOverflow { + return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) + } + + to, err = unaryMinusDecimal(dec) + return to, false, err +} + +type builtinUnaryMinusRealSig struct { + baseRealBuiltinFunc +} + +func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) { + sc := b.getCtx().GetSessionVars().StmtCtx + var aDatum types.Datum + aDatum, err = b.args[0].Eval(row) + if err != nil || aDatum.IsNull() { + return res, aDatum.IsNull(), errors.Trace(err) + } + + switch aDatum.Kind() { + case types.KindFloat32: + res = float64(-aDatum.GetFloat32()) + case types.KindFloat64: + res = float64(-aDatum.GetFloat64()) + case types.KindString, types.KindBytes: + f, err1 := types.StrToFloat(sc, aDatum.GetString()) + err = errors.Trace(err1) + res = float64(-f) + case types.KindMysqlHex: + res = float64(-aDatum.GetMysqlHex().ToNumber()) + case types.KindMysqlBit: + res = float64(-aDatum.GetMysqlBit().ToNumber()) + case types.KindMysqlEnum: + res = float64(-aDatum.GetMysqlEnum().ToNumber()) + case types.KindMysqlSet: + res = float64(-aDatum.GetMysqlSet().ToNumber()) + } + return +} + type isNullFunctionClass struct { baseFunctionClass } diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index e0cff685d33b8..ed3088ef72ee9 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -14,24 +14,25 @@ package expression import ( + "math" + . "github.com/pingcap/check" "github.com/pingcap/tidb/ast" - "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/util/testleak" ) func (s *testEvaluatorSuite) TestUnary(c *C) { defer testleak.AfterTest(c)() cases := []struct { - args interface{} - expected interface{} - expectedType byte - overflow bool - getErr bool + args interface{} + expected interface{} + overflow bool + getErr bool }{ - {uint64(9223372036854775809), "-9223372036854775809", mysql.TypeNewDecimal, true, false}, - {uint64(9223372036854775810), "-9223372036854775810", mysql.TypeNewDecimal, true, false}, - {uint64(9223372036854775808), "-9223372036854775808", mysql.TypeLonglong, false, false}, + {uint64(9223372036854775809), "-9223372036854775809", true, false}, + {uint64(9223372036854775810), "-9223372036854775810", true, false}, + {uint64(9223372036854775808), "-9223372036854775808", false, false}, + {int64(math.MinInt64), "9223372036854775808", false, false}, } for _, t := range cases { @@ -44,10 +45,10 @@ func (s *testEvaluatorSuite) TestUnary(c *C) { } else { c.Assert(err, IsNil) if !t.overflow { - c.Assert(d.GetString(), Equals, t.expected) + strd, _ := d.ToString() + c.Assert(strd, Equals, t.expected) } else { c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) - c.Assert(f.GetType().Tp, Equals, t.expectedType) } } } diff --git a/expression/constant_fold.go b/expression/constant_fold.go index a1fbcf0315290..f6437a621ef25 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -40,7 +40,35 @@ func FoldConstant(expr Expression) Expression { log.Warnf("There may exist an error during constant folding. The function name is %s, args are %s", scalarFunc.FuncName, args) return expr } + return &Constant{ + Value: value, + RetType: scalarFunc.RetType, + } +} +// PlainFoldConstant does constant folding optimization +// on an expression without recursively folding. +func PlainFoldConstant(expr Expression) Expression { + scalarFunc, ok := expr.(*ScalarFunction) + if !ok || !scalarFunc.Function.isDeterministic() { + return expr + } + args := scalarFunc.GetArgs() + canFold := true + for i := 0; i < len(args); i++ { + if _, ok := args[i].(*Constant); !ok { + canFold = false + break + } + } + if !canFold { + return expr + } + value, err := scalarFunc.Eval(nil) + if err != nil { + log.Warnf("There may exist an error during constant folding. The function name is %s, args are %s", scalarFunc.FuncName, args) + return expr + } return &Constant{ Value: value, RetType: scalarFunc.RetType, diff --git a/expression/expression.go b/expression/expression.go index e0d0f1ed4be9c..f11a8599fc7a8 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -170,7 +170,8 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen if val.IsNull() || err != nil { return res, val.IsNull(), errors.Trace(err) } - if expr.GetTypeClass() == types.ClassDecimal { + switch expr.GetTypeClass() { + case types.ClassDecimal, types.ClassInt: res, err = val.ToDecimal(sc) return res, false, errors.Trace(err) // TODO: We maintain two sets of type systems, one for Expression, one for Datum. @@ -180,7 +181,8 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen // but what we actually get is store as float64 in Datum. // So if we wrap `CastDecimalAsInt` upon the result, we'll get when call `arg.EvalDecimal()`. // This will be fixed after all built-in functions be rewrite correctlly. - } else if IsHybridType(expr) { + } + if IsHybridType(expr) { res, err = val.ToDecimal(sc) return res, false, errors.Trace(err) } diff --git a/expression/scalar_function.go b/expression/scalar_function.go index f6afdc5fc939e..fb544c60a8343 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -84,11 +84,13 @@ func NewFunction(ctx context.Context, funcName string, retType *types.FieldType, if builtinRetTp := f.getRetTp(); builtinRetTp.Tp != mysql.TypeUnspecified { retType = builtinRetTp } - return &ScalarFunction{ + scalarFunc := &ScalarFunction{ FuncName: model.NewCIStr(funcName), RetType: retType, Function: f, - }, nil + } + + return PlainFoldConstant(scalarFunc), nil } // ScalarFuncs2Exprs converts []*ScalarFunction to []Expression. @@ -178,18 +180,6 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { sc := sf.GetCtx().GetSessionVars().StmtCtx if !TurnOnNewExprEval { d, err = sf.Function.eval(row) - // TODO: fix #3762, maybe better way - if err != nil && terror.ErrorEqual(err, types.ErrOverflow) && sf.GetTypeClass() == types.ClassInt && - sf.FuncName.L == ast.UnaryMinus && !sc.InUpdateOrDeleteStmt { - err = sf.convertArgsToDecimal(sc) - if err != nil { - return d, errors.Trace(err) - } - d, err = sf.Function.eval(row) - // change return type - decVal, _ := d.ToDecimal(sc) - types.DefaultTypeForValue(decVal, sf.RetType) - } return } var ( @@ -230,7 +220,6 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { return d, errors.Trace(err) } res, isNull, err = sf.EvalDecimal(row, sc) - types.DefaultTypeForValue(res, sf.RetType) } if isNull || err != nil { diff --git a/expression/typeinferer_test.go b/expression/typeinferer_test.go index 979d260865ec9..5dc8be27f73e8 100644 --- a/expression/typeinferer_test.go +++ b/expression/typeinferer_test.go @@ -343,6 +343,10 @@ func (ts *testTypeInferrerSuite) TestInferType(c *C) { {`json_insert('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0}, {`json_replace('{"a": 1}', '$.a', 3)`, mysql.TypeJSON, charset.CharsetUTF8, 0}, {`json_merge('{"a": 1}', '3')`, mysql.TypeJSON, charset.CharsetUTF8, 0}, + {"-9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag}, + {"-9223372036854775808", mysql.TypeLonglong, charset.CharsetBin, mysql.BinaryFlag}, + {"--9223372036854775809", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag}, + {"--9223372036854775808", mysql.TypeNewDecimal, charset.CharsetBin, mysql.BinaryFlag}, } for _, tt := range tests { ctx := testKit.Se.(context.Context) diff --git a/plan/expression_rewriter.go b/plan/expression_rewriter.go index bd3ac0b1c9a90..bb619f076dc62 100644 --- a/plan/expression_rewriter.go +++ b/plan/expression_rewriter.go @@ -93,7 +93,8 @@ func (b *planBuilder) rewriteWithPreprocess(expr ast.ExprNode, p LogicalPlan, ag if getRowLen(er.ctxStack[0]) != 1 { return nil, nil, ErrOperandColumns.GenByArgs(1) } - result := expression.FoldConstant(er.ctxStack[0]) + result := expression.PlainFoldConstant(er.ctxStack[0]) + // result := expression.FoldConstant(er.ctxStack[0]) return result, er.p, nil } From d6e6c25b449eca48490d69750bc5e3f7acaf5ec7 Mon Sep 17 00:00:00 2001 From: winkyao Date: Thu, 20 Jul 2017 22:24:05 +0800 Subject: [PATCH 07/17] plan: fix popRowArg bug: when get expression.Constatn and getRowLen is 2, should return datum, not row --- plan/expression_rewriter.go | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/plan/expression_rewriter.go b/plan/expression_rewriter.go index bb619f076dc62..543e23fe8492f 100644 --- a/plan/expression_rewriter.go +++ b/plan/expression_rewriter.go @@ -144,7 +144,11 @@ func popRowArg(ctx context.Context, e expression.Expression) (ret expression.Exp return ret, errors.Trace(err) } c, _ := e.(*expression.Constant) - ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} + if getRowLen(c) == 2 { + ret = &expression.Constant{Value: c.Value.GetRow()[1], RetType: c.GetType()} + } else { + ret = &expression.Constant{Value: types.NewDatum(c.Value.GetRow()[1:]), RetType: c.GetType()} + } return } From ae342fcc7b71e339affae4b4e06de42f94aeb9a5 Mon Sep 17 00:00:00 2001 From: winkyao Date: Mon, 24 Jul 2017 09:13:51 +0800 Subject: [PATCH 08/17] *:git stash --- executor/prepared.go | 2 ++ expression/builtin_op.go | 32 ++------------------------------ expression/scalar_function.go | 4 ++-- sessionctx/variable/session.go | 1 + 4 files changed, 7 insertions(+), 32 deletions(-) diff --git a/executor/prepared.go b/executor/prepared.go index 93f135378e33a..81e13f0726a3e 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -347,6 +347,8 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) { case *ast.LoadDataStmt: sc.IgnoreTruncate = false sc.TruncateAsWarning = !sessVars.StrictSQLMode + case *ast.SelectStmt: + sc.InSelectStmt = true default: sc.IgnoreTruncate = true if show, ok := s.(*ast.ShowStmt); ok { diff --git a/expression/builtin_op.go b/expression/builtin_op.go index d5f4c243c4b1d..67c190c6583b0 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -423,6 +423,7 @@ func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, bf *baseBuiltinF types.SetBinChsClnFlag(bf.tp) overflow := false + // TODO: handle float overflow if arg, ok := argExpr.(*Constant); ok { switch arg.Value.Kind() { case types.KindUint64: @@ -482,35 +483,6 @@ type builtinUnaryMinusIntSig struct { baseIntBuiltinFunc } -// func (b *builtinUnaryMinusIntSig) eval(row []types.Datum) (d types.Datum, err error) { -// res, isNull, err := b.self.evalInt(row) -// sc := b.getCtx().GetSessionVars().StmtCtx -// if !sc.InUpdateOrDeleteStmt && terror.ErrorEqual(err, types.ErrOverflow) { -// var ( -// dec, to *types.MyDecimal -// ) - -// dec, isNull, err = b.args[0].EvalDecimal(row, sc) -// if err != nil || isNull { -// return d, errors.Trace(err) -// } -// to, err = unaryMinusDecimal(dec) -// d.SetMysqlDecimal(to) -// b.getRetTp().Tp = mysql.TypeNewDecimal -// return -// } - -// if err != nil || isNull { -// return d, errors.Trace(err) -// } -// if mysql.HasUnsignedFlag(b.tp.Flag) { -// d.SetUint64(uint64(res)) -// } else { -// d.SetInt64(res) -// } -// return -// } - func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull bool, err error) { var val int64 val, isNull, err = b.args[0].EvalInt(row, b.getCtx().GetSessionVars().StmtCtx) @@ -563,7 +535,7 @@ func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyD return dec, isNull, errors.Trace(err) } - if sc.InUpdateOrDeleteStmt && b.constantArgOverflow { + if !sc.InSelectStmt && b.constantArgOverflow { return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) } diff --git a/expression/scalar_function.go b/expression/scalar_function.go index fb544c60a8343..03a235b96bdb0 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -166,7 +166,7 @@ func (sf *ScalarFunction) convertArgsToDecimal(sc *variable.StatementContext) er if constArg, ok := arg.(*Constant); ok { val, err := constArg.Value.ConvertTo(sc, ft) if err != nil { - return err + return errors.Trace(err) } constArg.Value = val } @@ -213,7 +213,7 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { } if err != nil && terror.ErrorEqual(err, types.ErrOverflow) && - sf.FuncName.L == ast.UnaryMinus && !sc.InUpdateOrDeleteStmt { + sf.FuncName.L == ast.UnaryMinus && sc.InSelectStmt { // TODO: fix #3762, overflow convert it to decimal err = sf.convertArgsToDecimal(sc) if err != nil { diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 614a2c6180ff4..e450d1fda5d0b 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -318,6 +318,7 @@ type StatementContext struct { // Set the following variables before execution InUpdateOrDeleteStmt bool + InSelectStmt bool IgnoreTruncate bool TruncateAsWarning bool InShowWarning bool From b305a4faca3acb81e4d7b463fc0108f31685337c Mon Sep 17 00:00:00 2001 From: winkyao Date: Tue, 25 Jul 2017 09:51:48 +0800 Subject: [PATCH 09/17] expression: base #3868 pr, and rewrite builtin function unary minus to handle overflow gracely --- executor/aggregate_test.go | 1 + executor/prepared.go | 1 + expression/builtin_op_test.go | 20 +++++++++++++++----- expression/scalar_function.go | 26 -------------------------- 4 files changed, 17 insertions(+), 31 deletions(-) diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 0869addfc53e7..59a9c4ffbe60b 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -318,6 +318,7 @@ func (s *testSuite) TestAggPrune(c *C) { testleak.AfterTest(c)() }() tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") tk.MustExec("drop table if exists t") tk.MustExec("create table t(id int primary key, b varchar(50), c int)") diff --git a/executor/prepared.go b/executor/prepared.go index 81e13f0726a3e..44828b9c660bd 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -349,6 +349,7 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) { sc.TruncateAsWarning = !sessVars.StrictSQLMode case *ast.SelectStmt: sc.InSelectStmt = true + sc.IgnoreTruncate = true default: sc.IgnoreTruncate = true if show, ok := s.(*ast.ShowStmt); ok { diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index 91d4be1506748..514e0ff424d1d 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -33,19 +33,29 @@ func (s *testEvaluatorSuite) TestUnary(c *C) { }{ {uint64(9223372036854775809), "-9223372036854775809", true, false}, {uint64(9223372036854775810), "-9223372036854775810", true, false}, - {uint64(9223372036854775808), -9223372036854775808, false, false}, + {uint64(9223372036854775808), int64(-9223372036854775808), false, false}, {int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808 } + sc := s.ctx.GetSessionVars().StmtCtx + origin := sc.InSelectStmt + sc.InSelectStmt = true + defer func() { + sc.InSelectStmt = origin + }() for _, t := range cases { f, err := newFunctionForTest(s.ctx, ast.UnaryMinus, primitiveValsToConstants([]interface{}{t.args})...) c.Assert(err, IsNil) d, err := f.Eval(nil) - - if !t.overflow { - c.Assert(d.GetValue(), Equals, t.expected) + if t.getErr == false { + c.Assert(err, IsNil) + if !t.overflow { + c.Assert(d.GetValue(), Equals, t.expected) + } else { + c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) + } } else { - c.Assert(d.GetMysqlDecimal().String(), Equals, t.expected) + c.Assert(err, NotNil) } } } diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 6b6b6154e626a..183eb822aaba6 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -23,7 +23,6 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx/variable" - "github.com/pingcap/tidb/terror" "github.com/pingcap/tidb/util/codec" "github.com/pingcap/tidb/util/types" ) @@ -161,21 +160,6 @@ func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression { return sf } -func (sf *ScalarFunction) convertArgsToDecimal(sc *variable.StatementContext) error { - ft := types.NewFieldType(mysql.TypeNewDecimal) - for _, arg := range sf.GetArgs() { - if constArg, ok := arg.(*Constant); ok { - val, err := constArg.Value.ConvertTo(sc, ft) - if err != nil { - return errors.Trace(err) - } - constArg.Value = val - } - } - - return nil -} - // Eval implements Expression interface. func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { sc := sf.GetCtx().GetSessionVars().StmtCtx @@ -213,16 +197,6 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { } } - if err != nil && terror.ErrorEqual(err, types.ErrOverflow) && - sf.FuncName.L == ast.UnaryMinus && sc.InSelectStmt { - // TODO: fix #3762, overflow convert it to decimal - err = sf.convertArgsToDecimal(sc) - if err != nil { - return d, errors.Trace(err) - } - res, isNull, err = sf.EvalDecimal(row, sc) - } - if isNull || err != nil { d.SetValue(nil) return d, errors.Trace(err) From bedd2707155ec7cd83561fcc371c8e8a8b5aca58 Mon Sep 17 00:00:00 2001 From: winkyao Date: Tue, 25 Jul 2017 13:21:22 +0800 Subject: [PATCH 10/17] expression: refactor builtin unary minus function --- executor/executor_test.go | 5 ++ expression/builtin_op.go | 102 +++++++++++++++++--------------------- expression/expression.go | 2 +- 3 files changed, 52 insertions(+), 57 deletions(-) diff --git a/executor/executor_test.go b/executor/executor_test.go index f5d3105916737..1ee239bdea24a 100644 --- a/executor/executor_test.go +++ b/executor/executor_test.go @@ -1088,6 +1088,11 @@ func (s *testSuite) TestBuiltin(c *C) { result = tk.MustQuery("select -9223372036854775808;") result.Check(testkit.Rows("-9223372036854775808")) + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a bigint(30));") + _, err := tk.Exec("insert into t values(-9223372036854775809)") + c.Assert(err, NotNil) + // test unhex and hex result = tk.MustQuery("select unhex('4D7953514C')") result.Check(testkit.Rows("MySQL")) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 1bb86d4f39635..271806629915f 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -398,16 +398,15 @@ type unaryMinusFunctionClass struct { baseFunctionClass } -func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, bf *baseBuiltinFunc) bool { - bf.tp.Init(mysql.TypeLonglong) +func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { + tp := tpInt switch argExpr.GetType().Tp { case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat, mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: - bf.tp.Tp = mysql.TypeDouble + tp = tpReal case mysql.TypeNewDecimal: - bf.tp.Tp = mysql.TypeNewDecimal + tp = tpDecimal } - types.SetBinChsClnFlag(bf.tp) overflow := false // TODO: handle float overflow @@ -417,42 +416,69 @@ func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, bf *baseBuiltinF uval := arg.Value.GetUint64() if uval > uint64(-math.MinInt64) { overflow = true - bf.tp.Tp = mysql.TypeNewDecimal + tp = tpDecimal } case types.KindInt64: val := arg.Value.GetInt64() if val == math.MinInt64 { overflow = true - bf.tp.Tp = mysql.TypeNewDecimal + tp = tpDecimal } } } - return overflow + return tp, overflow } func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { - bf := newBaseBuiltinFunc(args, ctx) argExpr := args[0] - overflow := b.typeInfer(argExpr, &bf) - if overflow { - sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, true} - return sig.setSelf(sig), errors.Trace(b.verifyArgs(args)) - } + retTp, intOverflow := b.typeInfer(argExpr) switch argExpr.GetTypeClass() { case types.ClassInt: - sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} + if intOverflow { + bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + if err != nil { + return nil, errors.Trace(err) + } + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} + } else { + bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) + if err != nil { + return nil, errors.Trace(err) + } + sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} + } case types.ClassDecimal: + bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + if err != nil { + return nil, errors.Trace(err) + } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} case types.ClassReal: + bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) + if err != nil { + return nil, errors.Trace(err) + } sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} case types.ClassString: tp := argExpr.GetType().Tp if types.IsTypeTime(tp) { + bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + if err != nil { + return nil, errors.Trace(err) + } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} } else if tp == mysql.TypeDuration { + bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + if err != nil { + return nil, errors.Trace(err) + } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} } else { + bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) + if err != nil { + return nil, errors.Trace(err) + } sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} } } @@ -502,19 +528,8 @@ func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyD ) sc := b.getCtx().GetSessionVars().StmtCtx - aDatum, err := b.args[0].Eval(row) - if err != nil || aDatum.IsNull() { - return nil, aDatum.IsNull(), errors.Trace(err) - } - switch aDatum.Kind() { - case types.KindMysqlTime: - dec := new(types.MyDecimal) - err := types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlTime().ToNumber(), dec) - return dec, false, err - case types.KindMysqlDuration: - dec := new(types.MyDecimal) - err := types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDuration().ToNumber(), dec) - return dec, false, err + if !sc.InSelectStmt && b.constantArgOverflow { + return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) } dec, isNull, err := b.args[0].EvalDecimal(row, sc) @@ -522,10 +537,6 @@ func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyD return dec, isNull, errors.Trace(err) } - if !sc.InSelectStmt && b.constantArgOverflow { - return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) - } - to, err = unaryMinusDecimal(dec) return to, false, err } @@ -536,30 +547,9 @@ type builtinUnaryMinusRealSig struct { func (b *builtinUnaryMinusRealSig) evalReal(row []types.Datum) (res float64, isNull bool, err error) { sc := b.getCtx().GetSessionVars().StmtCtx - var aDatum types.Datum - aDatum, err = b.args[0].Eval(row) - if err != nil || aDatum.IsNull() { - return res, aDatum.IsNull(), errors.Trace(err) - } - - switch aDatum.Kind() { - case types.KindFloat32: - res = float64(-aDatum.GetFloat32()) - case types.KindFloat64: - res = float64(-aDatum.GetFloat64()) - case types.KindString, types.KindBytes: - f, err1 := types.StrToFloat(sc, aDatum.GetString()) - err = errors.Trace(err1) - res = float64(-f) - case types.KindMysqlHex: - res = float64(-aDatum.GetMysqlHex().ToNumber()) - case types.KindMysqlBit: - res = float64(-aDatum.GetMysqlBit().ToNumber()) - case types.KindMysqlEnum: - res = float64(-aDatum.GetMysqlEnum().ToNumber()) - case types.KindMysqlSet: - res = float64(-aDatum.GetMysqlSet().ToNumber()) - } + var val float64 + val, isNull, err = b.args[0].EvalReal(row, sc) + res = -val return } diff --git a/expression/expression.go b/expression/expression.go index f11a8599fc7a8..cc82b83670804 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -171,7 +171,7 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen return res, val.IsNull(), errors.Trace(err) } switch expr.GetTypeClass() { - case types.ClassDecimal, types.ClassInt: + case types.ClassDecimal: res, err = val.ToDecimal(sc) return res, false, errors.Trace(err) // TODO: We maintain two sets of type systems, one for Expression, one for Datum. From 0e5de099da6afb8df1f34aa4debf92c465c0ec43 Mon Sep 17 00:00:00 2001 From: winkyao Date: Tue, 25 Jul 2017 15:44:27 +0800 Subject: [PATCH 11/17] expression: tiny refactor --- executor/aggregate_test.go | 1 - expression/builtin_op.go | 56 ++++++++++++++++--------------------- expression/constant_fold.go | 29 ------------------- 3 files changed, 24 insertions(+), 62 deletions(-) diff --git a/executor/aggregate_test.go b/executor/aggregate_test.go index 59a9c4ffbe60b..0869addfc53e7 100644 --- a/executor/aggregate_test.go +++ b/executor/aggregate_test.go @@ -318,7 +318,6 @@ func (s *testSuite) TestAggPrune(c *C) { testleak.AfterTest(c)() }() tk := testkit.NewTestKit(c, s.store) - tk.MustExec("use test") tk.MustExec("drop table if exists t") tk.MustExec("create table t(id int primary key, b varchar(50), c int)") diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 271806629915f..0881b968e0ba1 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -400,11 +400,10 @@ type unaryMinusFunctionClass struct { func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { tp := tpInt - switch argExpr.GetType().Tp { - case mysql.TypeString, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeDouble, mysql.TypeFloat, - mysql.TypeDatetime, mysql.TypeDuration, mysql.TypeTimestamp: + switch argExpr.GetTypeClass() { + case types.ClassString, types.ClassReal: tp = tpReal - case mysql.TypeNewDecimal: + case types.ClassDecimal: tp = tpDecimal } @@ -430,52 +429,52 @@ func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { } func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Context) (sig builtinFunc, err error) { + err = b.verifyArgs(args) + if err != nil { + return nil, errors.Trace(err) + } + argExpr := args[0] retTp, intOverflow := b.typeInfer(argExpr) + var bf baseBuiltinFunc switch argExpr.GetTypeClass() { case types.ClassInt: if intOverflow { - bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) if err != nil { return nil, errors.Trace(err) } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} } else { - bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) if err != nil { return nil, errors.Trace(err) } sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} } case types.ClassDecimal: - bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) if err != nil { return nil, errors.Trace(err) } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} case types.ClassReal: - bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) if err != nil { return nil, errors.Trace(err) } sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} case types.ClassString: tp := argExpr.GetType().Tp - if types.IsTypeTime(tp) { - bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) - if err != nil { - return nil, errors.Trace(err) - } - sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} - } else if tp == mysql.TypeDuration { - bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) + if types.IsTypeTime(tp) || tp == mysql.TypeDuration { + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) if err != nil { return nil, errors.Trace(err) } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} } else { - bf, err := newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) + bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) if err != nil { return nil, errors.Trace(err) } @@ -483,13 +482,7 @@ func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Con } } - return sig.setSelf(sig), errors.Trace(b.verifyArgs(args)) -} - -func unaryMinusDecimal(dec *types.MyDecimal) (to *types.MyDecimal, err error) { - to = new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), dec, to) - return + return sig.setSelf(sig), errors.Trace(err) } type builtinUnaryMinusIntSig struct { @@ -523,21 +516,20 @@ type builtinUnaryMinusDecimalSig struct { } func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyDecimal, bool, error) { - var ( - dec, to *types.MyDecimal - ) - sc := b.getCtx().GetSessionVars().StmtCtx - if !sc.InSelectStmt && b.constantArgOverflow { - return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) - } + var dec *types.MyDecimal dec, isNull, err := b.args[0].EvalDecimal(row, sc) if err != nil || isNull { return dec, isNull, errors.Trace(err) } - to, err = unaryMinusDecimal(dec) + if !sc.InSelectStmt && b.constantArgOverflow { + return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) + } + + to := new(types.MyDecimal) + err = types.DecimalSub(new(types.MyDecimal), dec, to) return to, false, err } diff --git a/expression/constant_fold.go b/expression/constant_fold.go index f6437a621ef25..cbc13084ef2a2 100644 --- a/expression/constant_fold.go +++ b/expression/constant_fold.go @@ -45,32 +45,3 @@ func FoldConstant(expr Expression) Expression { RetType: scalarFunc.RetType, } } - -// PlainFoldConstant does constant folding optimization -// on an expression without recursively folding. -func PlainFoldConstant(expr Expression) Expression { - scalarFunc, ok := expr.(*ScalarFunction) - if !ok || !scalarFunc.Function.isDeterministic() { - return expr - } - args := scalarFunc.GetArgs() - canFold := true - for i := 0; i < len(args); i++ { - if _, ok := args[i].(*Constant); !ok { - canFold = false - break - } - } - if !canFold { - return expr - } - value, err := scalarFunc.Eval(nil) - if err != nil { - log.Warnf("There may exist an error during constant folding. The function name is %s, args are %s", scalarFunc.FuncName, args) - return expr - } - return &Constant{ - Value: value, - RetType: scalarFunc.RetType, - } -} From 4045075e5036c452213af15160509f435f4dcaab Mon Sep 17 00:00:00 2001 From: winkyao Date: Tue, 25 Jul 2017 15:57:53 +0800 Subject: [PATCH 12/17] expression: tiny cleanup --- expression/builtin_op.go | 18 ------------------ 1 file changed, 18 deletions(-) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 0881b968e0ba1..f78f728903416 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -442,42 +442,24 @@ func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Con case types.ClassInt: if intOverflow { bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) - if err != nil { - return nil, errors.Trace(err) - } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} } else { bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) - if err != nil { - return nil, errors.Trace(err) - } sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} } case types.ClassDecimal: bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) - if err != nil { - return nil, errors.Trace(err) - } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} case types.ClassReal: bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) - if err != nil { - return nil, errors.Trace(err) - } sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} case types.ClassString: tp := argExpr.GetType().Tp if types.IsTypeTime(tp) || tp == mysql.TypeDuration { bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) - if err != nil { - return nil, errors.Trace(err) - } sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} } else { bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpReal) - if err != nil { - return nil, errors.Trace(err) - } sig = &builtinUnaryMinusRealSig{baseRealBuiltinFunc{bf}} } } From 636170ddbfe8d09058f160daae5404ea29897bde Mon Sep 17 00:00:00 2001 From: winkyao Date: Tue, 25 Jul 2017 18:46:55 +0800 Subject: [PATCH 13/17] expression: cleanup --- expression/builtin_op.go | 82 +++++++++-------------------------- expression/builtin_op_test.go | 4 ++ expression/scalar_function.go | 8 +--- 3 files changed, 26 insertions(+), 68 deletions(-) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index f78f728903416..d41d1cf7236a7 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -341,53 +341,6 @@ func (b *builtinUnaryOpSig) eval(row []types.Datum) (d types.Datum, err error) { default: return d, errInvalidOperation.Gen("Unsupported type %v for op.Plus", aDatum.Kind()) } - case opcode.Minus: - switch aDatum.Kind() { - case types.KindInt64: - val := aDatum.GetInt64() - if val == math.MinInt64 { - err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) - } - d.SetInt64(-val) - case types.KindUint64: - uval := aDatum.GetUint64() - if uval > uint64(-math.MinInt64) { // 9223372036854775808 - // -uval will overflow - err = types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", uval)) - } else { - d.SetInt64(-int64(uval)) - } - case types.KindFloat64: - d.SetFloat64(-aDatum.GetFloat64()) - case types.KindFloat32: - d.SetFloat32(-aDatum.GetFloat32()) - case types.KindMysqlDuration: - dec := new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDuration().ToNumber(), dec) - d.SetMysqlDecimal(dec) - case types.KindMysqlTime: - dec := new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlTime().ToNumber(), dec) - d.SetMysqlDecimal(dec) - case types.KindString, types.KindBytes: - f, err1 := types.StrToFloat(sc, aDatum.GetString()) - err = errors.Trace(err1) - d.SetFloat64(-f) - case types.KindMysqlDecimal: - dec := new(types.MyDecimal) - err = types.DecimalSub(new(types.MyDecimal), aDatum.GetMysqlDecimal(), dec) - d.SetMysqlDecimal(dec) - case types.KindMysqlHex: - d.SetFloat64(-aDatum.GetMysqlHex().ToNumber()) - case types.KindMysqlBit: - d.SetFloat64(-aDatum.GetMysqlBit().ToNumber()) - case types.KindMysqlEnum: - d.SetFloat64(-aDatum.GetMysqlEnum().ToNumber()) - case types.KindMysqlSet: - d.SetFloat64(-aDatum.GetMysqlSet().ToNumber()) - default: - return d, errInvalidOperation.Gen("Unsupported type %v for op.Minus", aDatum.Kind()) - } default: return d, errInvalidOperation.Gen("Unsupported op %v for unary op", b.op) } @@ -408,20 +361,25 @@ func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { } overflow := false - // TODO: handle float overflow + // TODO: Handle float overflow if arg, ok := argExpr.(*Constant); ok { - switch arg.Value.Kind() { - case types.KindUint64: - uval := arg.Value.GetUint64() - if uval > uint64(-math.MinInt64) { - overflow = true - tp = tpDecimal - } - case types.KindInt64: - val := arg.Value.GetInt64() - if val == math.MinInt64 { - overflow = true - tp = tpDecimal + if arg.GetTypeClass() == types.ClassInt { + if mysql.HasUnsignedFlag(arg.GetType().Flag) { + uval := arg.Value.GetUint64() + // -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like + // 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occured + if uval > uint64(-math.MinInt64) { + overflow = true + tp = tpDecimal + } + } else { + val := arg.Value.GetInt64() + // The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807, + // which is less than abs(-9223372036854775808). when val == math.MinInt64, overflow occured + if val == math.MinInt64 { + overflow = true + tp = tpDecimal + } } } } @@ -488,7 +446,7 @@ func (b *builtinUnaryMinusIntSig) evalInt(row []types.Datum) (res int64, isNull } else if val == math.MinInt64 { return 0, false, types.ErrOverflow.GenByArgs("BIGINT", fmt.Sprintf("-%v", val)) } - return -val, false, err + return -val, false, errors.Trace(err) } type builtinUnaryMinusDecimalSig struct { @@ -512,7 +470,7 @@ func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyD to := new(types.MyDecimal) err = types.DecimalSub(new(types.MyDecimal), dec, to) - return to, false, err + return to, false, errors.Trace(err) } type builtinUnaryMinusRealSig struct { diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index 514e0ff424d1d..3a40bb0f08980 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -58,6 +58,10 @@ func (s *testEvaluatorSuite) TestUnary(c *C) { c.Assert(err, NotNil) } } + + f, err := funcs[ast.UnaryMinus].getFunction([]Expression{Zero}, s.ctx) + c.Assert(err, IsNil) + c.Assert(f.isDeterministic(), IsTrue) } func (s *testEvaluatorSuite) TestAndAnd(c *C) { diff --git a/expression/scalar_function.go b/expression/scalar_function.go index 183eb822aaba6..3f2c5c2e0314a 100644 --- a/expression/scalar_function.go +++ b/expression/scalar_function.go @@ -83,13 +83,11 @@ func NewFunction(ctx context.Context, funcName string, retType *types.FieldType, if builtinRetTp := f.getRetTp(); builtinRetTp.Tp != mysql.TypeUnspecified { retType = builtinRetTp } - sf := &ScalarFunction{ FuncName: model.NewCIStr(funcName), RetType: retType, Function: f, } - return FoldConstant(sf), nil } @@ -162,11 +160,10 @@ func (sf *ScalarFunction) Decorrelate(schema *Schema) Expression { // Eval implements Expression interface. func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { - sc := sf.GetCtx().GetSessionVars().StmtCtx if !TurnOnNewExprEval { - d, err = sf.Function.eval(row) - return + return sf.Function.eval(row) } + sc := sf.GetCtx().GetSessionVars().StmtCtx var ( res interface{} isNull bool @@ -176,7 +173,6 @@ func (sf *ScalarFunction) Eval(row []types.Datum) (d types.Datum, err error) { case types.ClassInt: var intRes int64 intRes, isNull, err = sf.EvalInt(row, sc) - if mysql.HasUnsignedFlag(tp.Flag) { res = uint64(intRes) } else { From 33bf09c45c285f12b26f5d078d5919b776655b26 Mon Sep 17 00:00:00 2001 From: winkyao Date: Tue, 25 Jul 2017 19:39:39 +0800 Subject: [PATCH 14/17] expresion: cleanup --- executor/prepared.go | 6 +++- expression/builtin_op.go | 57 ++++++++++++++++++---------------- expression/builtin_op_test.go | 6 ++-- sessionctx/variable/session.go | 2 +- 4 files changed, 39 insertions(+), 32 deletions(-) diff --git a/executor/prepared.go b/executor/prepared.go index 44828b9c660bd..17b911fa67b1c 100644 --- a/executor/prepared.go +++ b/executor/prepared.go @@ -336,6 +336,7 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) { switch s.(type) { case *ast.UpdateStmt, *ast.InsertStmt, *ast.DeleteStmt: sc.IgnoreTruncate = false + sc.IgnoreOverflow = false sc.TruncateAsWarning = !sessVars.StrictSQLMode if _, ok := s.(*ast.InsertStmt); !ok { sc.InUpdateOrDeleteStmt = true @@ -343,15 +344,18 @@ func ResetStmtCtx(ctx context.Context, s ast.StmtNode) { case *ast.CreateTableStmt, *ast.AlterTableStmt: // Make sure the sql_mode is strict when checking column default value. sc.IgnoreTruncate = false + sc.IgnoreOverflow = false sc.TruncateAsWarning = false case *ast.LoadDataStmt: sc.IgnoreTruncate = false + sc.IgnoreOverflow = false sc.TruncateAsWarning = !sessVars.StrictSQLMode case *ast.SelectStmt: - sc.InSelectStmt = true + sc.IgnoreOverflow = true sc.IgnoreTruncate = true default: sc.IgnoreTruncate = true + sc.IgnoreOverflow = false if show, ok := s.(*ast.ShowStmt); ok { if show.Tp == ast.ShowWarnings { sc.InShowWarning = true diff --git a/expression/builtin_op.go b/expression/builtin_op.go index d41d1cf7236a7..23ad57675b0ad 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -351,7 +351,27 @@ type unaryMinusFunctionClass struct { baseFunctionClass } -func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { +func (b *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow bool) { + overflow = false + if mysql.HasUnsignedFlag(arg.GetType().Flag) { + uval := arg.Value.GetUint64() + // -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like + // 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occurs. + if uval > uint64(-math.MinInt64) { + overflow = true + } + } else { + val := arg.Value.GetInt64() + // The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807, + // which is less than abs(-9223372036854775808). When val == math.MinInt64, overflow occurs. + if val == math.MinInt64 { + overflow = true + } + } + return +} + +func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Context) (evalTp, bool) { tp := tpInt switch argExpr.GetTypeClass() { case types.ClassString, types.ClassReal: @@ -360,27 +380,14 @@ func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression) (evalTp, bool) { tp = tpDecimal } + sc := ctx.GetSessionVars().StmtCtx overflow := false - // TODO: Handle float overflow - if arg, ok := argExpr.(*Constant); ok { - if arg.GetTypeClass() == types.ClassInt { - if mysql.HasUnsignedFlag(arg.GetType().Flag) { - uval := arg.Value.GetUint64() - // -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like - // 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occured - if uval > uint64(-math.MinInt64) { - overflow = true - tp = tpDecimal - } - } else { - val := arg.Value.GetInt64() - // The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807, - // which is less than abs(-9223372036854775808). when val == math.MinInt64, overflow occured - if val == math.MinInt64 { - overflow = true - tp = tpDecimal - } - } + // TODO: Handle float overflow. + if arg, ok := argExpr.(*Constant); sc.IgnoreOverflow && ok && + arg.GetTypeClass() == types.ClassInt { + overflow = b.handleIntOverflow(arg) + if overflow { + tp = tpDecimal } } return tp, overflow @@ -393,14 +400,14 @@ func (b *unaryMinusFunctionClass) getFunction(args []Expression, ctx context.Con } argExpr := args[0] - retTp, intOverflow := b.typeInfer(argExpr) + retTp, intOverflow := b.typeInfer(argExpr, ctx) var bf baseBuiltinFunc switch argExpr.GetTypeClass() { case types.ClassInt: if intOverflow { bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpDecimal) - sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, false} + sig = &builtinUnaryMinusDecimalSig{baseDecimalBuiltinFunc{bf}, true} } else { bf, err = newBaseBuiltinFuncWithTp(args, ctx, retTp, tpInt) sig = &builtinUnaryMinusIntSig{baseIntBuiltinFunc{bf}} @@ -464,10 +471,6 @@ func (b *builtinUnaryMinusDecimalSig) evalDecimal(row []types.Datum) (*types.MyD return dec, isNull, errors.Trace(err) } - if !sc.InSelectStmt && b.constantArgOverflow { - return dec, false, types.ErrOverflow.GenByArgs("DECIMAL", dec.String()) - } - to := new(types.MyDecimal) err = types.DecimalSub(new(types.MyDecimal), dec, to) return to, false, errors.Trace(err) diff --git a/expression/builtin_op_test.go b/expression/builtin_op_test.go index 3a40bb0f08980..ce35ce05fabe1 100644 --- a/expression/builtin_op_test.go +++ b/expression/builtin_op_test.go @@ -37,10 +37,10 @@ func (s *testEvaluatorSuite) TestUnary(c *C) { {int64(math.MinInt64), "9223372036854775808", true, false}, // --9223372036854775808 } sc := s.ctx.GetSessionVars().StmtCtx - origin := sc.InSelectStmt - sc.InSelectStmt = true + origin := sc.IgnoreOverflow + sc.IgnoreOverflow = true defer func() { - sc.InSelectStmt = origin + sc.IgnoreOverflow = origin }() for _, t := range cases { diff --git a/sessionctx/variable/session.go b/sessionctx/variable/session.go index 25d8e890ba8d3..2cbc200823ead 100644 --- a/sessionctx/variable/session.go +++ b/sessionctx/variable/session.go @@ -319,7 +319,7 @@ type StatementContext struct { // Set the following variables before execution InUpdateOrDeleteStmt bool - InSelectStmt bool + IgnoreOverflow bool IgnoreTruncate bool TruncateAsWarning bool InShowWarning bool From eb58c486d956ade07886b6d98b28b5418cf156cf Mon Sep 17 00:00:00 2001 From: winkyao Date: Wed, 26 Jul 2017 00:25:48 +0800 Subject: [PATCH 15/17] expression: tiny cleanup --- expression/expression.go | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/expression/expression.go b/expression/expression.go index cc82b83670804..e0d0f1ed4be9c 100644 --- a/expression/expression.go +++ b/expression/expression.go @@ -170,8 +170,7 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen if val.IsNull() || err != nil { return res, val.IsNull(), errors.Trace(err) } - switch expr.GetTypeClass() { - case types.ClassDecimal: + if expr.GetTypeClass() == types.ClassDecimal { res, err = val.ToDecimal(sc) return res, false, errors.Trace(err) // TODO: We maintain two sets of type systems, one for Expression, one for Datum. @@ -181,8 +180,7 @@ func evalExprToDecimal(expr Expression, row []types.Datum, sc *variable.Statemen // but what we actually get is store as float64 in Datum. // So if we wrap `CastDecimalAsInt` upon the result, we'll get when call `arg.EvalDecimal()`. // This will be fixed after all built-in functions be rewrite correctlly. - } - if IsHybridType(expr) { + } else if IsHybridType(expr) { res, err = val.ToDecimal(sc) return res, false, errors.Trace(err) } From e2aa6223a80ed94449839b4fee626c3b5357a769 Mon Sep 17 00:00:00 2001 From: winkyao Date: Wed, 26 Jul 2017 10:35:34 +0800 Subject: [PATCH 16/17] expression: add some comment --- expression/builtin_op.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index 5cb2291551b8d..e788773e2da53 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -352,25 +352,26 @@ type unaryMinusFunctionClass struct { } func (b *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow bool) { - overflow = false if mysql.HasUnsignedFlag(arg.GetType().Flag) { uval := arg.Value.GetUint64() // -math.MinInt64 is 9223372036854775808, so if uval is more than 9223372036854775808, like // 9223372036854775809, -9223372036854775809 is less than math.MinInt64, overflow occurs. if uval > uint64(-math.MinInt64) { - overflow = true + return true } } else { val := arg.Value.GetInt64() // The math.MinInt64 is -9223372036854775808, the math.MaxInt64 is 9223372036854775807, // which is less than abs(-9223372036854775808). When val == math.MinInt64, overflow occurs. if val == math.MinInt64 { - overflow = true + return true } } - return + return false } +// typeInfer infer unary Minus Function return type. when arg is int constant and overflow, +// typerInfer will infer the return type as tpDecimal, not tpInt. func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Context) (evalTp, bool) { tp := tpInt switch argExpr.GetTypeClass() { From 7f01b42682a7d4aecafdfdfc68d51394d5cc8796 Mon Sep 17 00:00:00 2001 From: winkyao Date: Wed, 26 Jul 2017 10:58:14 +0800 Subject: [PATCH 17/17] expression: adjust commet --- expression/builtin_op.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/expression/builtin_op.go b/expression/builtin_op.go index e788773e2da53..3e84311169b26 100644 --- a/expression/builtin_op.go +++ b/expression/builtin_op.go @@ -370,8 +370,8 @@ func (b *unaryMinusFunctionClass) handleIntOverflow(arg *Constant) (overflow boo return false } -// typeInfer infer unary Minus Function return type. when arg is int constant and overflow, -// typerInfer will infer the return type as tpDecimal, not tpInt. +// typeInfer infers unaryMinus function return type. when the arg is an int constant and overflow, +// typerInfer will infers the return type as tpDecimal, not tpInt. func (b *unaryMinusFunctionClass) typeInfer(argExpr Expression, ctx context.Context) (evalTp, bool) { tp := tpInt switch argExpr.GetTypeClass() {