Skip to content

Commit

Permalink
Add additional big_math helper functions (backport #1563) (#1792)
Browse files Browse the repository at this point in the history
Co-authored-by: Brendan Chou <3680392+BrendanChou@users.noreply.github.com>
  • Loading branch information
mergify[bot] and BrendanChou authored Jun 27, 2024
1 parent f344d9d commit 02caef5
Show file tree
Hide file tree
Showing 2 changed files with 249 additions and 0 deletions.
34 changes: 34 additions & 0 deletions protocol/lib/big_math.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,27 @@ import (
"math/big"
)

// BigU returns a new big.Int from the input unsigned integer.
func BigU[T uint | uint32 | uint64](u T) *big.Int {
return new(big.Int).SetUint64(uint64(u))
}

// BigI returns a new big.Int from the input signed integer.
func BigI[T int | int32 | int64](i T) *big.Int {
return big.NewInt(int64(i))
}

// BigMulPpm returns the result of `val * ppm / 1_000_000`, rounding in the direction indicated.
func BigMulPpm(val *big.Int, ppm *big.Int, roundUp bool) *big.Int {
result := new(big.Int).Mul(val, ppm)
oneMillion := BigIntOneMillion()
if roundUp {
return BigDivCeil(result, oneMillion)
} else {
return result.Div(result, oneMillion)
}
}

// BigMulPow10 returns the result of `val * 10^exponent`, in *big.Rat.
func BigMulPow10(
val *big.Int,
Expand Down Expand Up @@ -137,6 +158,19 @@ func BigIntClamp(n *big.Int, lowerBound *big.Int, upperBound *big.Int) *big.Int
return bigGenericClamp(n, lowerBound, upperBound)
}

// BigDivCeil returns the ceiling of `a / b`.
func BigDivCeil(a *big.Int, b *big.Int) *big.Int {
result, remainder := new(big.Int).QuoRem(a, b, new(big.Int))

// If the value was rounded (i.e. there is a remainder), and the exact result would be positive,
// then add 1 to the result.
if remainder.Sign() != 0 && (a.Sign() == b.Sign()) {
result.Add(result, big.NewInt(1))
}

return result
}

// BigRatRound takes an input and a direction to round (true for up, false for down).
// It returns the result rounded to a `*big.Int` in the specified direction.
func BigRatRound(n *big.Rat, roundUp bool) *big.Int {
Expand Down
215 changes: 215 additions & 0 deletions protocol/lib/big_math_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,138 @@ import (
"github.com/stretchr/testify/require"
)

func BenchmarkBigI(b *testing.B) {
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigI(int64(i))
}
require.Equal(b, result, result)
}

func BenchmarkBigU(b *testing.B) {
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigU(uint32(i))
}
require.Equal(b, result, result)
}

func TestBigI(t *testing.T) {
require.Equal(t, big.NewInt(-123), lib.BigI(int(-123)))
require.Equal(t, big.NewInt(-123), lib.BigI(int32(-123)))
require.Equal(t, big.NewInt(-123), lib.BigI(int64(-123)))
require.Equal(t, big.NewInt(math.MaxInt64), lib.BigI(math.MaxInt64))
}

func TestBigU(t *testing.T) {
require.Equal(t, big.NewInt(123), lib.BigU(uint(123)))
require.Equal(t, big.NewInt(123), lib.BigU(uint32(123)))
require.Equal(t, big.NewInt(123), lib.BigU(uint64(123)))
require.Equal(t, new(big.Int).SetUint64(math.MaxUint64), lib.BigU(uint64(math.MaxUint64)))
}

func BenchmarkBigMulPpm_RoundDown(b *testing.B) {
val := big.NewInt(543_211)
ppm := big.NewInt(876_543)
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigMulPpm(val, ppm, false)
}
require.Equal(b, big.NewInt(476147), result)
}

func BenchmarkBigMulPpm_RoundUp(b *testing.B) {
val := big.NewInt(543_211)
ppm := big.NewInt(876_543)
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigMulPpm(val, ppm, true)
}
require.Equal(b, big.NewInt(476148), result)
}

func TestBigMulPpm(t *testing.T) {
tests := map[string]struct {
val *big.Int
ppm *big.Int
roundUp bool
expectedResult *big.Int
}{
"Positive round down": {
val: big.NewInt(543_211),
ppm: big.NewInt(876_543),
roundUp: false,
expectedResult: big.NewInt(476147),
},
"Negative round down": {
val: big.NewInt(-543_211),
ppm: big.NewInt(876_543),
roundUp: false,
expectedResult: big.NewInt(-476148),
},
"Positive round up": {
val: big.NewInt(543_211),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(476148),
},
"Negative round up": {
val: big.NewInt(-543_211),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(-476147),
},
"Zero val": {
val: big.NewInt(0),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(0),
},
"Zero ppm": {
val: big.NewInt(543_211),
ppm: big.NewInt(0),
roundUp: true,
expectedResult: big.NewInt(0),
},
"Zero val and ppm": {
val: big.NewInt(0),
ppm: big.NewInt(0),
roundUp: true,
expectedResult: big.NewInt(0),
},
"Negative val": {
val: big.NewInt(-543_211),
ppm: big.NewInt(876_543),
roundUp: true,
expectedResult: big.NewInt(-476147),
},
"Negative ppm": {
val: big.NewInt(543_211),
ppm: big.NewInt(-876_543),
roundUp: true,
expectedResult: big.NewInt(-476147),
},
"Negative val and ppm": {
val: big.NewInt(-543_211),
ppm: big.NewInt(-876_543),
roundUp: true,
expectedResult: big.NewInt(476148),
},
"Greater than max int64": {
val: big_testutil.MustFirst(new(big.Int).SetString("1000000000000000000000000", 10)),
ppm: big.NewInt(10_000),
roundUp: true,
expectedResult: big_testutil.MustFirst(new(big.Int).SetString("10000000000000000000000", 10)),
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result := lib.BigMulPpm(tc.val, tc.ppm, tc.roundUp)
require.Equal(t, tc.expectedResult, result)
})
}
}

func TestBigPow10(t *testing.T) {
tests := map[string]struct {
exponent uint64
Expand Down Expand Up @@ -523,6 +655,89 @@ func TestBigIntClamp(t *testing.T) {
}
}

func BenchmarkBigDivCeil(b *testing.B) {
numerator := big.NewInt(10)
denominator := big.NewInt(3)
var result *big.Int
for i := 0; i < b.N; i++ {
result = lib.BigDivCeil(numerator, denominator)
}
require.Equal(b, big.NewInt(4), result)
}

func TestBigDivCeil(t *testing.T) {
tests := map[string]struct {
numerator *big.Int
denominator *big.Int
expectedResult *big.Int
}{
"Divides evenly": {
numerator: big.NewInt(10),
denominator: big.NewInt(5),
expectedResult: big.NewInt(2),
},
"Doesn't divide evenly": {
numerator: big.NewInt(10),
denominator: big.NewInt(3),
expectedResult: big.NewInt(4),
},
"Negative numerator": {
numerator: big.NewInt(-10),
denominator: big.NewInt(3),
expectedResult: big.NewInt(-3),
},
"Negative numerator 2": {
numerator: big.NewInt(-1),
denominator: big.NewInt(2),
expectedResult: big.NewInt(0),
},
"Negative denominator": {
numerator: big.NewInt(10),
denominator: big.NewInt(-3),
expectedResult: big.NewInt(-3),
},
"Negative denominator 2": {
numerator: big.NewInt(1),
denominator: big.NewInt(-2),
expectedResult: big.NewInt(0),
},
"Negative numerator and denominator": {
numerator: big.NewInt(-10),
denominator: big.NewInt(-3),
expectedResult: big.NewInt(4),
},
"Negative numerator and denominator 2": {
numerator: big.NewInt(-1),
denominator: big.NewInt(-2),
expectedResult: big.NewInt(1),
},
"Zero numerator": {
numerator: big.NewInt(0),
denominator: big.NewInt(3),
expectedResult: big.NewInt(0),
},
"Zero denominator": {
numerator: big.NewInt(10),
denominator: big.NewInt(0),
expectedResult: nil,
},
}
for name, tc := range tests {
t.Run(name, func(t *testing.T) {
// Panics if the expected result is nil
if tc.expectedResult == nil {
require.Panics(t, func() {
lib.BigDivCeil(tc.numerator, tc.denominator)
})
return
}
// Otherwise test the result
result := lib.BigDivCeil(tc.numerator, tc.denominator)
require.Equal(t, tc.expectedResult, result)
})
}
}

func TestBigRatRound(t *testing.T) {
tests := map[string]struct {
input *big.Rat
Expand Down

0 comments on commit 02caef5

Please sign in to comment.