Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expression: round function for int should use round half up rule #27403

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
6 changes: 3 additions & 3 deletions expression/builtin_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -353,7 +353,7 @@ func (b *builtinRoundRealSig) evalReal(row chunk.Row) (float64, bool, error) {
if isNull || err != nil {
return 0, isNull, err
}
return types.Round(val, 0), false, nil
return types.RoundFloat(val, 0), false, nil
}

type builtinRoundIntSig struct {
Expand Down Expand Up @@ -417,7 +417,7 @@ func (b *builtinRoundWithFracRealSig) evalReal(row chunk.Row) (float64, bool, er
if isNull || err != nil {
return 0, isNull, err
}
return types.Round(val, int(frac)), false, nil
return types.RoundFloat(val, int(frac)), false, nil
}

type builtinRoundWithFracIntSig struct {
Expand All @@ -441,7 +441,7 @@ func (b *builtinRoundWithFracIntSig) evalInt(row chunk.Row) (int64, bool, error)
if isNull || err != nil {
return 0, isNull, err
}
return int64(types.Round(float64(val), int(frac))), false, nil
return types.RoundInt(val, int(frac)), false, nil
}

type builtinRoundWithFracDecSig struct {
Expand Down
5 changes: 5 additions & 0 deletions expression/builtin_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -435,6 +435,11 @@ func (s *testEvaluatorSuite) TestRound(c *C) {
{[]interface{}{-1.5, 0}, -2},
{[]interface{}{1.5, 0}, 2},
{[]interface{}{23.298, -1}, 20},
{[]interface{}{49.99999, -2}, 0},
{[]interface{}{50, -2}, 100},
{[]interface{}{50.00001, -2}, 100},
{[]interface{}{123456789, -5}, 123500000},
{[]interface{}{2146213728964879326, -15}, 2146000000000000000},
{[]interface{}{newDec("-1.23")}, newDec("-1")},
{[]interface{}{newDec("-1.23"), 1}, newDec("-1.2")},
{[]interface{}{newDec("-1.58")}, newDec("-2")},
Expand Down
6 changes: 3 additions & 3 deletions expression/builtin_math_vec.go
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,7 @@ func (b *builtinRoundRealSig) vecEvalReal(input *chunk.Chunk, result *chunk.Colu
if result.IsNull(i) {
continue
}
f64s[i] = types.Round(f64s[i], 0)
f64s[i] = types.RoundFloat(f64s[i], 0)
}
return nil
}
Expand Down Expand Up @@ -547,7 +547,7 @@ func (b *builtinRoundWithFracRealSig) vecEvalReal(input *chunk.Chunk, result *ch
if result.IsNull(i) {
continue
}
x[i] = types.Round(x[i], int(d[i]))
x[i] = types.RoundFloat(x[i], int(d[i]))
}
return nil
}
Expand Down Expand Up @@ -654,7 +654,7 @@ func (b *builtinRoundWithFracIntSig) vecEvalInt(input *chunk.Chunk, result *chun
if result.IsNull(i) {
continue
}
i64s[i] = int64(types.Round(float64(i64s[i]), int(frac[i])))
i64s[i] = types.RoundInt(i64s[i], int(frac[i]))
}
return nil
}
Expand Down
4 changes: 4 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -696,6 +696,10 @@ func (s *testIntegrationSuite2) TestMathBuiltin(c *C) {
result.Check(testkit.Rows("100000000000000 1000000000000000 100000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000000"))
result = tk.MustQuery("SELECT ROUND(1e-14, 1), ROUND(1e-15, 1), ROUND(1e-308, 1)")
result.Check(testkit.Rows("0 0 0"))
result = tk.MustQuery("SELECT round(49.99999, -2), round(50, -2), round(50.00001, -2)")
result.Check(testkit.Rows("0 100 100"))
result = tk.MustQuery("SELECT round(123456789, -5), round(2146213728964879326, -15) ")
result.Check(testkit.Rows("123500000 2146000000000000000"))

// for truncate
result = tk.MustQuery("SELECT truncate(123, -2), truncate(123, 2), truncate(123, 1), truncate(123, -1);")
Expand Down
4 changes: 2 additions & 2 deletions types/convert.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func IntergerSignedLowerBound(intType byte) int64 {
// ConvertFloatToInt converts a float64 value to a int value.
// `tp` is used in err msg, if there is overflow, this func will report err according to `tp`
func ConvertFloatToInt(fval float64, lowerBound, upperBound int64, tp byte) (int64, error) {
val := RoundFloat(fval)
val := math.RoundToEven(fval)
if val < float64(lowerBound) {
return lowerBound, overflow(val, tp)
}
Expand Down Expand Up @@ -160,7 +160,7 @@ func ConvertUintToUint(val uint64, upperBound uint64, tp byte) (uint64, error) {

// ConvertFloatToUint converts a float value to an uint value.
func ConvertFloatToUint(sc *stmtctx.StatementContext, fval float64, upperBound uint64, tp byte) (uint64, error) {
val := RoundFloat(fval)
val := math.RoundToEven(fval)
if val < 0 {
if sc.ShouldClipToZero() {
return 0, overflow(val, tp)
Expand Down
25 changes: 1 addition & 24 deletions types/etc_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,29 +132,6 @@ func (s *testTypeEtcSuite) TestMaxFloat(c *C) {
}
}

func (s *testTypeEtcSuite) TestRoundFloat(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Input float64
Expect float64
}{
{2.5, 2},
{1.5, 2},
{0.5, 0},
{0.49999999999999997, 0},
{0, 0},
{-0.49999999999999997, 0},
{-0.5, 0},
{-2.5, -2},
{-1.5, -2},
}

for _, t := range tbl {
f := RoundFloat(t.Input)
c.Assert(f, Equals, t.Expect)
}
}

func (s *testTypeEtcSuite) TestRound(c *C) {
defer testleak.AfterTest(c)()
tbl := []struct {
Expand All @@ -171,7 +148,7 @@ func (s *testTypeEtcSuite) TestRound(c *C) {
}

for _, t := range tbl {
f := Round(t.Input, t.Dec)
f := RoundFloat(t.Input, t.Dec)
c.Assert(f, Equals, t.Expect)
}
}
Expand Down
32 changes: 18 additions & 14 deletions types/helper.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,31 +23,35 @@ import (
"github.com/pingcap/errors"
)

// RoundFloat rounds float val to the nearest even integer value with float64 format, like MySQL Round function.
// RoundFloat uses default rounding mode, see https://dev.mysql.com/doc/refman/5.7/en/precision-math-rounding.html
// so rounding use "round to nearest even".
// e.g, 1.5 -> 2, -1.5 -> -2.
func RoundFloat(f float64) float64 {
return math.RoundToEven(f)
}

// Round rounds the argument f to dec decimal places.
// RoundFloat rounds the argument f to dec decimal places using "round to nearest even" rule.
// dec defaults to 0 if not specified. dec can be negative
// to cause dec digits left of the decimal point of the
// value f to become zero.
func Round(f float64, dec int) float64 {
// see https://dev.mysql.com/doc/refman/5.7/en/precision-math-rounding.html
func RoundFloat(f float64, dec int) float64 {
shift := math.Pow10(dec)
tmp := f * shift
if math.IsInf(tmp, 0) {
return f
}
result := RoundFloat(tmp) / shift
result := math.RoundToEven(tmp) / shift
if math.IsNaN(result) {
return 0
}
return result
}

// RoundInt rounds the argument i to dec decimal places using "round half up" rule.
// dec defaults to 0 if not specified. dec can be negative
func RoundInt(i int64, dec int) int64 {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should consider unsigned int range and overflow.

MySQL [test]> select round(9223372036854775808,-3);
+-------------------------------+
| round(9223372036854775808,-3) |
+-------------------------------+
|           9223372036854776000 |
+-------------------------------+
1 row in set (0.000 sec)

MySQL [test]> 
MySQL [test]> select round(-9223372036854775808,-3);
ERROR 1690 (22003): BIGINT value is out of range in 'round(-(9223372036854775808),-(3))'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

thanks for suggestion, i will check

Copy link
Contributor Author

@feitian124 feitian124 Aug 23, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

select round(1234567890123456789012345678901234567890, -30);
# output: 1234567890000000000000000000000000000000

i found both mysql 8 and tidb v5.1.1have above result,
and 1234567890123456789012345678901234567890 is clearly much larger than math.MaxInt64,
so it cloud not be the param or return value of builtinRoundWithFracIntSig.evalInt,
i believe above case and unsigned int should have processed in some other place,
so i will not process these case in roundInt

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think so.
select round(1234567890123456789012345678901234567890, -30); This case is decimal round, not int.

We can create a table with bigint field and insert -9223372036854775808, then round it should throw the error.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

// is itself when dec >= 0
if dec >= 0 {
return i
}

shift := math.Pow10(-dec)
intPart := math.Round(float64(i) / shift)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a test case.

MySQL [test]>  select round(24999999999999999,-16);
+------------------------------+
| round(24999999999999999,-16) |
+------------------------------+
|            20000000000000000 |
+------------------------------+
1 row in set (0.000 sec)

We should avoid float calc in int round.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

added and failed, will fix later

return int64(intPart) * int64(shift)
}

// Truncate truncates the argument f to dec decimal places.
// dec defaults to 0 if not specified. dec can be negative
// to cause dec digits left of the decimal point of the
Expand Down Expand Up @@ -80,7 +84,7 @@ func TruncateFloat(f float64, flen int, decimal int) (float64, error) {
maxF := GetMaxFloat(flen, decimal)

if !math.IsInf(f, 0) {
f = Round(f, decimal)
f = RoundFloat(f, decimal)
}

var err error
Expand Down
52 changes: 36 additions & 16 deletions types/helper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,14 @@ package types

import (
"strconv"
"testing"

. "github.com/pingcap/check"
"github.com/pingcap/errors"
"github.com/stretchr/testify/require"
)

var _ = Suite(&testTypeHelperSuite{})

type testTypeHelperSuite struct {
}

func (s *testTypeHelperSuite) TestStrToInt(c *C) {
c.Parallel()
func TestStrToInt(t *testing.T) {
t.Parallel()
tests := []struct {
input string
output string
Expand All @@ -42,13 +38,13 @@ func (s *testTypeHelperSuite) TestStrToInt(c *C) {
}
for _, tt := range tests {
output, err := strToInt(tt.input)
c.Assert(errors.Cause(err), Equals, tt.err)
c.Check(strconv.FormatInt(output, 10), Equals, tt.output)
require.Equal(t, tt.err, errors.Cause(err))
require.Equal(t, tt.output, strconv.FormatInt(output, 10))
}
}

func (s *testTypeHelperSuite) TestTruncate(c *C) {
c.Parallel()
func TestTruncate(t *testing.T) {
t.Parallel()
tests := []struct {
f float64
dec int
Expand All @@ -61,12 +57,12 @@ func (s *testTypeHelperSuite) TestTruncate(c *C) {
}
for _, tt := range tests {
res := Truncate(tt.f, tt.dec)
c.Assert(res, Equals, tt.expected)
require.Equal(t, tt.expected, res)
}
}

func (s *testTypeHelperSuite) TestTruncateFloatToString(c *C) {
c.Parallel()
func TestTruncateFloatToString(t *testing.T) {
t.Parallel()
tests := []struct {
f float64
dec int
Expand All @@ -83,6 +79,30 @@ func (s *testTypeHelperSuite) TestTruncateFloatToString(c *C) {
}
for _, tt := range tests {
res := TruncateFloatToString(tt.f, tt.dec)
c.Assert(res, Equals, tt.expected)
require.Equal(t, tt.expected, res)
}
}

func TestRoundInt(t *testing.T) {
t.Parallel()
tests := []struct {
i int64
dec int
expected int64
}{
{1, 0, 1},
{2146213728964879326, -15, 2146000000000000000},
{123456789, -5, 123500000},
{50, -2, 100},
{150, 2, 150},
{-1, 0, -1},
{-2146213728964879326, -15, -2146000000000000000},
{-123456789, -5, -123500000},
{-50, -2, -100},
{-150, 2, -150},
}
for _, tt := range tests {
res := RoundInt(tt.i, tt.dec)
require.Equal(t, tt.expected, res)
}
}