Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add checked_add and checked_norm_pdf. #375

Merged
merged 4 commits into from
May 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 42 additions & 8 deletions src/maths.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,21 @@ pub trait MathematicalOps {
/// tolerance of roughly `0.0000002`.
fn exp(&self) -> Decimal;

/// The estimated exponential function, e<sup>x</sup>. Stops calculating when it is within
/// tolerance of roughly `0.0000002`. Returns `None` on overflow.
fn checked_exp(&self) -> Option<Decimal>;

/// The estimated exponential function, e<sup>x</sup> using the `tolerance` provided as a hint
/// as to when to stop calculating. A larger tolerance will cause the number to stop calculating
/// sooner at the potential cost of a slightly less accurate result.
fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal;

/// The estimated exponential function, e<sup>x</sup> using the `tolerance` provided as a hint
/// as to when to stop calculating. A larger tolerance will cause the number to stop calculating
/// sooner at the potential cost of a slightly less accurate result.
/// Returns `None` on overflow.
fn checked_exp_with_tolerance(&self, tolerance: Decimal) -> Option<Decimal>;

/// Raise self to the given integer exponent: x<sup>y</sup>
fn powi(&self, exp: i64) -> Decimal;

Expand Down Expand Up @@ -93,26 +103,40 @@ pub trait MathematicalOps {
/// The Cumulative distribution function for a Normal distribution
fn norm_cdf(&self) -> Decimal;

/// The Probability density function for a Normal distribution
/// The Probability density function for a Normal distribution.
fn norm_pdf(&self) -> Decimal;

/// The Probability density function for a Normal distribution returning `None` on overflow.
fn checked_norm_pdf(&self) -> Option<Decimal>;
}

impl MathematicalOps for Decimal {
fn exp(&self) -> Decimal {
self.exp_with_tolerance(EXP_TOLERANCE)
}

#[inline]
fn checked_exp(&self) -> Option<Decimal> {
self.checked_exp_with_tolerance(EXP_TOLERANCE)
}

fn exp_with_tolerance(&self, tolerance: Decimal) -> Decimal {
match self.checked_exp_with_tolerance(tolerance) {
Some(d) => d,
None => panic!("Exp overflowed"),
}
}

#[inline]
fn checked_exp_with_tolerance(&self, tolerance: Decimal) -> Option<Decimal> {
if self.is_zero() {
return Decimal::ONE;
return Some(Decimal::ONE);
}

let mut term = *self;
let mut result = self + Decimal::ONE;

for factorial in FACTORIAL.iter().skip(2) {
term = self * term;
term = self.checked_mul(term)?;
let next = result + (term / factorial);
let diff = (next - result).abs();
result = next;
Expand All @@ -121,7 +145,7 @@ impl MathematicalOps for Decimal {
}
}

result
Some(result)
}

fn powi(&self, exp: i64) -> Decimal {
Expand Down Expand Up @@ -248,7 +272,7 @@ impl MathematicalOps for Decimal {
Some(e) => e,
None => return None,
};
let mut result = e.exp();
let mut result = e.checked_exp()?;
result.set_sign_negative(negative);
Some(result)
}
Expand Down Expand Up @@ -330,10 +354,20 @@ impl MathematicalOps for Decimal {
(Decimal::ONE + (self / Decimal::from_parts(2318911239, 3292722, 0, false, 16)).erf()) / TWO
}

/// The Probability density function for a Normal distribution
/// The Probability density function for a Normal distribution.
fn norm_pdf(&self) -> Decimal {
match self.checked_norm_pdf() {
Some(d) => d,
None => panic!("Norm Pdf overflowed"),
}
}

/// The Probability density function for a Normal distribution, return `None` on overflow.
fn checked_norm_pdf(&self) -> Option<Decimal> {
let sqrt2pi = Decimal::from_parts_raw(2133383024, 2079885984, 1358845910, 1835008);
(-self.powi(2) / TWO).exp() / sqrt2pi
let factor = -self.checked_powi(2)?;
let factor = factor.checked_div(TWO)?;
factor.checked_exp()?.checked_div(sqrt2pi)
}
}

Expand Down
29 changes: 27 additions & 2 deletions tests/decimal_tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3374,13 +3374,18 @@ mod maths {
let x = Decimal::from_str(x).unwrap();
let expected = Decimal::from_str(expected).unwrap();
assert_eq!(expected, x.exp());
assert_eq!(Some(expected), x.checked_exp());
}
}

#[cfg(not(feature = "legacy-ops"))]
#[test]
fn test_exp_with_tolerance() {
let test_cases = &[
// e^0 = 1
("0", "0.0002", "1"),
// e^1 ~= 2.7182539682539682539682539683
("1", "0.0002", "2.7182539682539682539682539683"),
// e^10 ~= 22026.465794806703
(
"10",
Expand All @@ -3389,6 +3394,8 @@ mod maths {
),
// e^11 ~= 59874.14171519778
("11", "0.0002", "59873.388231055804982198781924"),
// e^11.7578 ~= 127741.03548949540892948423052
("11.7578", "0.0002", "127741.03548949540892948423052"),
// e^3 ~= 20.085536923187664
("3", "0.00002", "20.085534430970814899386327955"),
// e^8 ~= 2980.957987041727
Expand All @@ -3397,12 +3404,30 @@ mod maths {
("0.1", "0.0002", "1.1051666666666666666666666667"),
// e^2.0 ~= 7.3890560989306495
("2.0", "0.0002", "7.3890460157126823793490460156"),
// e^11.7578+ starts to overflow
("11.7579", "0.0002", ""),
// e^11.7578+ starts to overflow
("123", "0.0002", ""),
// e^-8+ starts to flip and underflow
("-8", "0.0002", "0.0002858169660624369145768176"),
// e^-1024 starts to flip and underflow
("-1024", "0.0002", ""),
];
for &(x, tolerance, expected) in test_cases {
let x = Decimal::from_str(x).unwrap();
let tolerance = Decimal::from_str(tolerance).unwrap();
let expected = Decimal::from_str(expected).unwrap();
assert_eq!(expected, x.exp_with_tolerance(tolerance));
let expected = if expected.is_empty() {
None
} else {
Some(Decimal::from_str(expected).unwrap())
};

if let Some(expected) = expected {
assert_eq!(expected, x.exp_with_tolerance(tolerance));
assert_eq!(Some(expected), x.checked_exp_with_tolerance(tolerance));
} else {
assert_eq!(None, x.checked_exp_with_tolerance(tolerance));
}
}
}

Expand Down