Skip to content

Commit

Permalink
Reimplement pow function for integer exponent. (#638)
Browse files Browse the repository at this point in the history
* Reimplement checked_powu.

* Add pow tests.

* Add short-cuts to powu.

* Fix x.powu(0)

* Refine powu implementation.

* Exclude a failing test that was caused by the deprecated legacy-ops feature

---------

Co-authored-by: Paul Mason <paul@paulmason.me>
  • Loading branch information
schungx and paupino authored Jan 13, 2024
1 parent 10ee2ee commit 80e9f08
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 43 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ Cargo.lock
artifacts
corpus
target
.vscode/settings.json
51 changes: 29 additions & 22 deletions src/maths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -227,33 +227,40 @@ impl MathematicalOps for Decimal {
}

fn checked_powu(&self, exp: u64) -> Option<Decimal> {
if exp == 0 {
return Some(Decimal::ONE);
}
if self.is_zero() {
return Some(Decimal::ZERO);
}
if self.is_one() {
return Some(Decimal::ONE);
}

match exp {
0 => Some(Decimal::ONE),
0 => unreachable!(),
1 => Some(*self),
2 => self.checked_mul(*self),
// Do the exponentiation by multiplying squares:
// y = Sum (for each 1 bit in binary representation) of (2 ^ bit)
// x ^ y = Sum (for each 1 bit in y) of (x ^ (2 ^ bit))
// See: https://en.wikipedia.org/wiki/Exponentiation_by_squaring
_ => {
// Get the squared value
let squared = match self.checked_mul(*self) {
Some(s) => s,
None => return None,
};
// Square self once and make an infinite sized iterator of the square.
let iter = core::iter::repeat(squared);

// We then take half of the exponent to create a finite iterator and then multiply those together.
let mut product = Decimal::ONE;
for x in iter.take((exp >> 1) as usize) {
match product.checked_mul(x) {
Some(r) => product = r,
None => return None,
};
}

// If the exponent is odd we still need to multiply once more
if exp & 0x1 > 0 {
match self.checked_mul(product) {
Some(p) => product = p,
None => return None,
let mut mask = exp;
let mut power = *self;

// Run through just enough 1 bits
for n in 0..(64 - exp.leading_zeros()) {
if n > 0 {
power = power.checked_mul(power)?;
mask >>= 1;
}
if mask & 0x01 > 0 {
match product.checked_mul(power) {
Some(r) => product = r,
None => return None,
};
}
}
product.normalize_assign();
Expand Down
30 changes: 9 additions & 21 deletions tests/decimal_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3693,6 +3693,8 @@ mod maths {
("0.1", 0_u64, "1"),
("342.4", 1_u64, "342.4"),
("2.0", 16_u64, "65536"),
("0.99999999999999", 1477289400_u64, "0.9999852272151186611602884841"),
("0.99999999999999", 0x8000_8000_0000_0000, "0"),
];
for &(x, y, expected) in test_cases {
let x = Decimal::from_str(x).unwrap();
Expand Down Expand Up @@ -3829,6 +3831,7 @@ mod maths {
"0.1234567890123456789012345678",
either!("0.0003533642875741443321850682", "0.0003305188683169079961720764"),
),
("0.99999999999999", "1477289400", "0.9999852272151186611602884841"),
];
for &(x, y, expected) in test_cases {
let x = Decimal::from_str(x).unwrap();
Expand Down Expand Up @@ -3965,43 +3968,28 @@ mod maths {
}

#[test]
#[cfg(not(feature = "legacy-ops"))]
fn test_norm_cdf() {
let test_cases = &[
(
Decimal::from_str("-0.4").unwrap(),
either!(
Decimal::from_str("0.3445781286821245037094401704").unwrap(),
Decimal::from_str("0.3445781286821245037094401728").unwrap()
),
Decimal::from_str("0.3445781286821245037094401704").unwrap(),
),
(
Decimal::from_str("-0.1").unwrap(),
either!(
Decimal::from_str("0.4601722899186706579921922696").unwrap(),
Decimal::from_str("0.4601722899186706579921922711").unwrap()
),
Decimal::from_str("0.4601722899186706579921922696").unwrap(),
),
(
Decimal::from_str("0.1").unwrap(),
Decimal::from_str(either!(
"0.5398277100813293420078077304",
"0.5398277100813293420078077290"
))
.unwrap(),
Decimal::from_str("0.5398277100813293420078077304").unwrap(),
),
(
Decimal::from_str("0.4").unwrap(),
either!(
Decimal::from_str("0.6554218713178754962905598296").unwrap(),
Decimal::from_str("0.6554218713178754962905598272").unwrap()
),
Decimal::from_str("0.6554218713178754962905598296").unwrap(),
),
(
Decimal::from_str("2.0").unwrap(),
either!(
Decimal::from_str("0.9772497381095865280953380673").unwrap(),
Decimal::from_str("0.9772497381095865280953380672").unwrap()
),
Decimal::from_str("0.9772497381095865280953380673").unwrap(),
),
];
for case in test_cases {
Expand Down

0 comments on commit 80e9f08

Please sign in to comment.