diff --git a/substrate/primitives/arithmetic/src/helpers_128bit.rs b/substrate/primitives/arithmetic/src/helpers_128bit.rs index 9b9c74ba5577..6c4a939c99f2 100644 --- a/substrate/primitives/arithmetic/src/helpers_128bit.rs +++ b/substrate/primitives/arithmetic/src/helpers_128bit.rs @@ -183,21 +183,35 @@ mod double128 { } /// Returns `a * b / c` (wrapping to 128 bits) or `None` in the case of -/// overflow. +/// overflow or div by zero. pub const fn multiply_by_rational_with_rounding( a: u128, b: u128, c: u128, r: Rounding, ) -> Option { + match checked_multiply_by_rational_with_rounding(a, b, c, r) { + Ok(value) => Some(value), + Err(_) => None, + } +} + +/// Returns `a * b / c` (wrapping to 128 bits) or `Err` if tries to div by zero or +/// overflow +pub const fn checked_multiply_by_rational_with_rounding( + a: u128, + b: u128, + c: u128, + r: Rounding, +) -> Result { use double128::Double128; if c == 0 { - return None + return Err("Division by zero") } let (result, remainder) = Double128::product_of(a, b).div(c); let mut result: u128 = match result.try_into_u128() { Ok(v) => v, - Err(_) => return None, + Err(_) => return Err("Overflow"), }; if match r { Rounding::Up => remainder > 0, @@ -208,10 +222,10 @@ pub const fn multiply_by_rational_with_rounding( } { result = match result.checked_add(1) { Some(v) => v, - None => return None, + None => return Err("None"), }; } - Some(result) + Ok(result) } pub const fn sqrt(mut n: u128) -> u128 { @@ -244,6 +258,7 @@ pub const fn sqrt(mut n: u128) -> u128 { #[cfg(test)] mod tests { use super::*; + use checked_multiply_by_rational_with_rounding as checked_mulrat; use codec::{Decode, Encode}; use multiply_by_rational_with_rounding as mulrat; use Rounding::*; @@ -278,6 +293,36 @@ mod tests { assert_eq!(mulrat(1, MAX / 2 + 1, MAX, NearestPrefUp), Some(1)); } + #[test] + fn rational_checked_multiply_basic_rounding_works() { + assert_eq!(checked_mulrat(1, 1, 1, Up), Ok(1)); + assert_eq!(checked_mulrat(3, 1, 3, Up), Ok(1)); + assert_eq!(checked_mulrat(1, 1, 3, Up), Ok(1)); + assert_eq!(checked_mulrat(1, 2, 3, Down), Ok(0)); + assert_eq!(checked_mulrat(1, 1, 3, NearestPrefDown), Ok(0)); + assert_eq!(checked_mulrat(1, 1, 2, NearestPrefDown), Ok(0)); + assert_eq!(checked_mulrat(1, 2, 3, NearestPrefDown), Ok(1)); + assert_eq!(checked_mulrat(1, 1, 3, NearestPrefUp), Ok(0)); + assert_eq!(checked_mulrat(1, 1, 2, NearestPrefUp), Ok(1)); + assert_eq!(checked_mulrat(1, 2, 3, NearestPrefUp), Ok(1)); + assert_eq!(checked_mulrat(3, 1, 0, Up), Err("Division by zero")); + } + + #[test] + fn rational_checked_multiply_big_number_works() { + assert_eq!(checked_mulrat(MAX, MAX - 1, MAX, Down), Ok(MAX - 1)); + assert_eq!(checked_mulrat(MAX, 1, MAX, Down), Ok(1)); + assert_eq!(checked_mulrat(MAX, MAX - 1, MAX, Up), Ok(MAX - 1)); + assert_eq!(checked_mulrat(MAX, 1, MAX, Up), Ok(1)); + assert_eq!(checked_mulrat(1, MAX - 1, MAX, Down), Ok(0)); + assert_eq!(checked_mulrat(1, 1, MAX, Up), Ok(1)); + assert_eq!(checked_mulrat(1, MAX / 2, MAX, NearestPrefDown), Ok(0)); + assert_eq!(checked_mulrat(1, MAX / 2 + 1, MAX, NearestPrefDown), Ok(1)); + assert_eq!(checked_mulrat(1, MAX / 2, MAX, NearestPrefUp), Ok(0)); + assert_eq!(checked_mulrat(1, MAX / 2 + 1, MAX, NearestPrefUp), Ok(1)); + assert_eq!(checked_mulrat(1, MAX / 2 + 1, 0, NearestPrefUp), Err("Division by zero")); + } + #[test] fn sqrt_works() { for i in 0..100_000u32 { diff --git a/substrate/primitives/arithmetic/src/per_things.rs b/substrate/primitives/arithmetic/src/per_things.rs index fe88b72e24c2..31353c67d6f4 100644 --- a/substrate/primitives/arithmetic/src/per_things.rs +++ b/substrate/primitives/arithmetic/src/per_things.rs @@ -19,8 +19,8 @@ use serde::{Deserialize, Serialize}; use crate::traits::{ - BaseArithmetic, Bounded, CheckedAdd, CheckedMul, CheckedSub, One, SaturatedConversion, - Saturating, UniqueSaturatedInto, Unsigned, Zero, + BaseArithmetic, Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, + SaturatedConversion, Saturating, UniqueSaturatedInto, Unsigned, Zero, }; use codec::{CompactAs, Encode}; use num_traits::{Pow, SaturatingAdd, SaturatingSub}; @@ -496,16 +496,35 @@ where (x / maximum) * part_n + c } +/// Checked compute of the error due to integer division in the expression `x / denom * numer`. +fn rational_mul_correction(x: N, numer: P::Inner, denom: P::Inner, rounding: Rounding) -> N +where + N: MultiplyArg + UniqueSaturatedInto, + P: PerThing, + P::Inner: Into, +{ + checked_rational_mul_correction::(x, numer, denom, rounding).unwrap() +} + /// Compute the error due to integer division in the expression `x / denom * numer`. /// /// Take the remainder of `x / denom` and multiply by `numer / denom`. The result can be added /// to `x / denom * numer` for an accurate result. -fn rational_mul_correction(x: N, numer: P::Inner, denom: P::Inner, rounding: Rounding) -> N +fn checked_rational_mul_correction( + x: N, + numer: P::Inner, + denom: P::Inner, + rounding: Rounding, +) -> Result where N: MultiplyArg + UniqueSaturatedInto, P: PerThing, P::Inner: Into, { + if P::Inner::is_zero(&denom) { + return Err("Division by zero") + } + let numer_upper = P::Upper::from(numer); let denom_n: N = denom.into(); let denom_upper = P::Upper::from(denom); @@ -540,7 +559,7 @@ where } }, } - rem_mul_div_inner.into() + Ok(rem_mul_div_inner.into()) } macro_rules! implement_per_thing { @@ -1021,6 +1040,17 @@ macro_rules! implement_per_thing { } } + impl CheckedDiv for $name { + #[inline] + fn checked_div(&self, rhs: &Self) -> Option { + if rhs.0.is_zero() { + return None + } + + self.deconstruct().checked_div(rhs.deconstruct()).map($name::from_parts) + } + } + impl $crate::traits::Zero for $name { fn zero() -> Self { Self::zero() @@ -1544,6 +1574,77 @@ macro_rules! implement_per_thing { } } + #[test] + fn checked_rational_mul_correction_works() { + assert_eq!( + super::checked_rational_mul_correction::<$type, $name>( + <$type>::max_value(), + <$type>::max_value(), + <$type>::max_value(), + super::Rounding::NearestPrefDown, + ), + Ok(0), + ); + assert_eq!( + super::checked_rational_mul_correction::<$type, $name>( + <$type>::max_value() - 1, + <$type>::max_value(), + <$type>::max_value(), + super::Rounding::NearestPrefDown, + ), + Ok(<$type>::max_value() - 1), + ); + assert_eq!( + super::checked_rational_mul_correction::<$upper_type, $name>( + ((<$type>::max_value() - 1) as $upper_type).pow(2), + <$type>::max_value(), + <$type>::max_value(), + super::Rounding::NearestPrefDown, + ), + Ok(1), + ); + // ((max^2 - 1) % max) * max / max == max - 1 + assert_eq!( + super::checked_rational_mul_correction::<$upper_type, $name>( + (<$type>::max_value() as $upper_type).pow(2) - 1, + <$type>::max_value(), + <$type>::max_value(), + super::Rounding::NearestPrefDown, + ), + Ok(<$upper_type>::from((<$type>::max_value() - 1))), + ); + // (max % 2) * max / 2 == max / 2 + assert_eq!( + super::checked_rational_mul_correction::<$upper_type, $name>( + (<$type>::max_value() as $upper_type).pow(2), + <$type>::max_value(), + 2 as $type, + super::Rounding::NearestPrefDown, + ), + Ok(<$type>::max_value() as $upper_type / 2), + ); + // ((max^2 - 1) % max) * 2 / max == 2 (rounded up) + assert_eq!( + super::checked_rational_mul_correction::<$upper_type, $name>( + (<$type>::max_value() as $upper_type).pow(2) - 1, + 2 as $type, + <$type>::max_value(), + super::Rounding::NearestPrefDown, + ), + Ok(2), + ); + // ((max^2 - 1) % max) * 2 / max == 1 (rounded down) + assert_eq!( + super::checked_rational_mul_correction::<$upper_type, $name>( + (<$type>::max_value() as $upper_type).pow(2) - 1, + 2 as $type, + <$type>::max_value(), + super::Rounding::Down, + ), + Ok(1), + ); + } + #[test] fn rational_mul_correction_works() { assert_eq!( @@ -1737,6 +1838,26 @@ macro_rules! implement_per_thing { Some($name::from_percent(0)) ); } + + #[test] + fn test_basic_checked_div() { + assert_eq!( + $name::from_parts($max).checked_div(&$name::from_parts($max)), + Some($name::from_parts(1)) + ); + assert_eq!( + $name::from_percent(100).checked_div(&$name::from_parts(100)), + Some($name::from_percent(1)) + ); + assert_eq!( + $name::from_percent(0).checked_div(&$name::from_percent(0)), + None + ); + assert_eq!( + $name::from_parts(0).checked_div(&$name::from_parts(0)), + None + ); + } } }; }