diff --git a/expression/builtin_math.go b/expression/builtin_math.go index 7891c30e69fc2..7bc918330d9a8 100644 --- a/expression/builtin_math.go +++ b/expression/builtin_math.go @@ -353,7 +353,7 @@ func (b *builtinRoundRealSig) evalReal(row chunk.Row) (float64, bool, error) { if isNull || err != nil { return 0, isNull, err } - return types.Round(val, 0), false, nil + return types.RoundFloat(val, 0), false, nil } type builtinRoundIntSig struct { @@ -417,7 +417,7 @@ func (b *builtinRoundWithFracRealSig) evalReal(row chunk.Row) (float64, bool, er if isNull || err != nil { return 0, isNull, err } - return types.Round(val, int(frac)), false, nil + return types.RoundFloat(val, int(frac)), false, nil } type builtinRoundWithFracIntSig struct { @@ -441,7 +441,7 @@ func (b *builtinRoundWithFracIntSig) evalInt(row chunk.Row) (int64, bool, error) if isNull || err != nil { return 0, isNull, err } - return int64(types.Round(float64(val), int(frac))), false, nil + return types.RoundInt(val, int(frac)), false, nil } type builtinRoundWithFracDecSig struct { diff --git a/expression/builtin_math_test.go b/expression/builtin_math_test.go index cd5983170bceb..974bd83c0dd1d 100644 --- a/expression/builtin_math_test.go +++ b/expression/builtin_math_test.go @@ -435,6 +435,11 @@ func (s *testEvaluatorSuite) TestRound(c *C) { {[]interface{}{-1.5, 0}, -2}, {[]interface{}{1.5, 0}, 2}, {[]interface{}{23.298, -1}, 20}, + {[]interface{}{49.99999, -2}, 0}, + {[]interface{}{50, -2}, 100}, + {[]interface{}{50.00001, -2}, 100}, + {[]interface{}{123456789, -5}, 123500000}, + {[]interface{}{2146213728964879326, -15}, 2146000000000000000}, {[]interface{}{newDec("-1.23")}, newDec("-1")}, {[]interface{}{newDec("-1.23"), 1}, newDec("-1.2")}, {[]interface{}{newDec("-1.58")}, newDec("-2")}, diff --git a/expression/builtin_math_vec.go b/expression/builtin_math_vec.go index 0835a8ad2aad8..22af698dc0db5 100644 --- a/expression/builtin_math_vec.go +++ b/expression/builtin_math_vec.go @@ -517,7 +517,7 @@ func (b *builtinRoundRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Colu if result.IsNull(i) { continue } - f64s[i] = types.Round(f64s[i], 0) + f64s[i] = types.RoundFloat(f64s[i], 0) } return nil } @@ -547,7 +547,7 @@ func (b *builtinRoundWithFracRealSig) vecEvalReal(input *chunk.Chunk, result *ch if result.IsNull(i) { continue } - x[i] = types.Round(x[i], int(d[i])) + x[i] = types.RoundFloat(x[i], int(d[i])) } return nil } @@ -654,7 +654,7 @@ func (b *builtinRoundWithFracIntSig) vecEvalInt(input *chunk.Chunk, result *chun if result.IsNull(i) { continue } - i64s[i] = int64(types.Round(float64(i64s[i]), int(frac[i]))) + i64s[i] = types.RoundInt(i64s[i], int(frac[i])) } return nil } diff --git a/expression/integration_test.go b/expression/integration_test.go index 2efcae665713c..e5f71adb119f8 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -696,6 +696,10 @@ func (s *testIntegrationSuite2) TestMathBuiltin(c *C) { result.Check(testkit.Rows("100000000000000 1000000000000000 100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000")) result = tk.MustQuery("SELECT ROUND(1e-14, 1), ROUND(1e-15, 1), ROUND(1e-308, 1)") result.Check(testkit.Rows("0 0 0")) + result = tk.MustQuery("SELECT round(49.99999, -2), round(50, -2), round(50.00001, -2)") + result.Check(testkit.Rows("0 100 100")) + result = tk.MustQuery("SELECT round(123456789, -5), round(2146213728964879326, -15), round(24999999999999999,-16)") + result.Check(testkit.Rows("123500000 2146000000000000000 20000000000000000")) // for truncate result = tk.MustQuery("SELECT truncate(123, -2), truncate(123, 2), truncate(123, 1), truncate(123, -1);") diff --git a/types/convert.go b/types/convert.go index f6b88597e9964..ff39bb4dfd7a0 100644 --- a/types/convert.go +++ b/types/convert.go @@ -100,7 +100,7 @@ func IntergerSignedLowerBound(intType byte) int64 { // ConvertFloatToInt converts a float64 value to a int value. // `tp` is used in err msg, if there is overflow, this func will report err according to `tp` func ConvertFloatToInt(fval float64, lowerBound, upperBound int64, tp byte) (int64, error) { - val := RoundFloat(fval) + val := math.RoundToEven(fval) if val < float64(lowerBound) { return lowerBound, overflow(val, tp) } @@ -160,7 +160,7 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) { // ConvertFloatToUint converts a float value to an uint value. func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) { - val := RoundFloat(fval) + val := math.RoundToEven(fval) if val < 0 { if sc.ShouldClipToZero() { return 0, overflow(val, tp) diff --git a/types/etc_test.go b/types/etc_test.go index c4d8448d7c4e6..1d2ff727e1094 100644 --- a/types/etc_test.go +++ b/types/etc_test.go @@ -132,29 +132,6 @@ func (s *testTypeEtcSuite) TestMaxFloat(c *C) { } } -func (s *testTypeEtcSuite) TestRoundFloat(c *C) { - defer testleak.AfterTest(c)() - tbl := []struct { - Input float64 - Expect float64 - }{ - {2.5, 2}, - {1.5, 2}, - {0.5, 0}, - {0.49999999999999997, 0}, - {0, 0}, - {-0.49999999999999997, 0}, - {-0.5, 0}, - {-2.5, -2}, - {-1.5, -2}, - } - - for _, t := range tbl { - f := RoundFloat(t.Input) - c.Assert(f, Equals, t.Expect) - } -} - func (s *testTypeEtcSuite) TestRound(c *C) { defer testleak.AfterTest(c)() tbl := []struct { @@ -171,7 +148,7 @@ func (s *testTypeEtcSuite) TestRound(c *C) { } for _, t := range tbl { - f := Round(t.Input, t.Dec) + f := RoundFloat(t.Input, t.Dec) c.Assert(f, Equals, t.Expect) } } diff --git a/types/helper.go b/types/helper.go index 2da6bd7275d5f..5d3a1407674b8 100644 --- a/types/helper.go +++ b/types/helper.go @@ -23,31 +23,35 @@ import ( "github.com/pingcap/errors" ) -// RoundFloat rounds float val to the nearest even integer value with float64 format, like MySQL Round function. -// RoundFloat uses default rounding mode, see https://dev.mysql.com/doc/refman/5.7/en/precision-math-rounding.html -// so rounding use "round to nearest even". -// e.g, 1.5 -> 2, -1.5 -> -2. -func RoundFloat(f float64) float64 { - return math.RoundToEven(f) -} - -// Round rounds the argument f to dec decimal places. +// RoundFloat rounds the argument f to dec decimal places using "round to nearest even" rule. // dec defaults to 0 if not specified. dec can be negative -// to cause dec digits left of the decimal point of the -// value f to become zero. -func Round(f float64, dec int) float64 { +// see https://dev.mysql.com/doc/refman/5.7/en/precision-math-rounding.html +func RoundFloat(f float64, dec int) float64 { shift := math.Pow10(dec) tmp := f * shift if math.IsInf(tmp, 0) { return f } - result := RoundFloat(tmp) / shift + result := math.RoundToEven(tmp) / shift if math.IsNaN(result) { return 0 } return result } +// RoundInt rounds the argument i to dec decimal places using "round half up" rule. +// dec defaults to 0 if not specified. dec can be negative +func RoundInt(i int64, dec int) int64 { + // is itself when dec >= 0 + if dec >= 0 { + return i + } + + shift := math.Pow10(-dec) + intPart := math.Round(float64(i) / shift) + return int64(intPart) * int64(shift) +} + // Truncate truncates the argument f to dec decimal places. // dec defaults to 0 if not specified. dec can be negative // to cause dec digits left of the decimal point of the @@ -80,7 +84,7 @@ func TruncateFloat(f float64, flen int, decimal int) (float64, error) { maxF := GetMaxFloat(flen, decimal) if !math.IsInf(f, 0) { - f = Round(f, decimal) + f = RoundFloat(f, decimal) } var err error diff --git a/types/helper_test.go b/types/helper_test.go index 6523f7af65cb9..595ccb29cf9f6 100644 --- a/types/helper_test.go +++ b/types/helper_test.go @@ -16,18 +16,14 @@ package types import ( "strconv" + "testing" - . "github.com/pingcap/check" "github.com/pingcap/errors" + "github.com/stretchr/testify/require" ) -var _ = Suite(&testTypeHelperSuite{}) - -type testTypeHelperSuite struct { -} - -func (s *testTypeHelperSuite) TestStrToInt(c *C) { - c.Parallel() +func TestStrToInt(t *testing.T) { + t.Parallel() tests := []struct { input string output string @@ -42,13 +38,13 @@ func (s *testTypeHelperSuite) TestStrToInt(c *C) { } for _, tt := range tests { output, err := strToInt(tt.input) - c.Assert(errors.Cause(err), Equals, tt.err) - c.Check(strconv.FormatInt(output, 10), Equals, tt.output) + require.Equal(t, tt.err, errors.Cause(err)) + require.Equal(t, tt.output, strconv.FormatInt(output, 10)) } } -func (s *testTypeHelperSuite) TestTruncate(c *C) { - c.Parallel() +func TestTruncate(t *testing.T) { + t.Parallel() tests := []struct { f float64 dec int @@ -61,12 +57,12 @@ func (s *testTypeHelperSuite) TestTruncate(c *C) { } for _, tt := range tests { res := Truncate(tt.f, tt.dec) - c.Assert(res, Equals, tt.expected) + require.Equal(t, tt.expected, res) } } -func (s *testTypeHelperSuite) TestTruncateFloatToString(c *C) { - c.Parallel() +func TestTruncateFloatToString(t *testing.T) { + t.Parallel() tests := []struct { f float64 dec int @@ -83,6 +79,30 @@ func (s *testTypeHelperSuite) TestTruncateFloatToString(c *C) { } for _, tt := range tests { res := TruncateFloatToString(tt.f, tt.dec) - c.Assert(res, Equals, tt.expected) + require.Equal(t, tt.expected, res) + } +} + +func TestRoundInt(t *testing.T) { + t.Parallel() + tests := []struct { + i int64 + dec int + expected int64 + }{ + {1, 0, 1}, + {2146213728964879326, -15, 2146000000000000000}, + {123456789, -5, 123500000}, + {50, -2, 100}, + {150, 2, 150}, + {-1, 0, -1}, + {-2146213728964879326, -15, -2146000000000000000}, + {-123456789, -5, -123500000}, + {-50, -2, -100}, + {-150, 2, -150}, + } + for _, tt := range tests { + res := RoundInt(tt.i, tt.dec) + require.Equal(t, tt.expected, res) } }