Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

introduce checked_div for PerThings and checked_rational_mul_correction #1936

Open
wants to merge 4 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 50 additions & 5 deletions substrate/primitives/arithmetic/src/helpers_128bit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -183,21 +183,35 @@ mod double128 {
}

/// Returns `a * b / c` (wrapping to 128 bits) or `None` in the case of
/// overflow.
/// overflow or div by zero.
pub const fn multiply_by_rational_with_rounding(
a: u128,
b: u128,
c: u128,
r: Rounding,
) -> Option<u128> {
match checked_multiply_by_rational_with_rounding(a, b, c, r) {
ggwpez marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
match checked_multiply_by_rational_with_rounding(a, b, c, r) {
checked_multiply_by_rational_with_rounding(a, b, c, r).ok()

Ok(value) => Some(value),
Err(_) => None,
}
}

/// Returns `a * b / c` (wrapping to 128 bits) or `Err` if tries to div by zero or
/// overflow
pub const fn checked_multiply_by_rational_with_rounding(
a: u128,
b: u128,
c: u128,
r: Rounding,
) -> Result<u128, &'static str> {
use double128::Double128;
if c == 0 {
return None
return Err("Division by zero")
}
let (result, remainder) = Double128::product_of(a, b).div(c);
let mut result: u128 = match result.try_into_u128() {
Ok(v) => v,
Err(_) => return None,
Err(_) => return Err("Overflow"),
};
if match r {
Rounding::Up => remainder > 0,
Expand All @@ -208,10 +222,10 @@ pub const fn multiply_by_rational_with_rounding(
} {
result = match result.checked_add(1) {
Some(v) => v,
None => return None,
None => return Err("Overflow"),
};
}
Some(result)
Ok(result)
}

pub const fn sqrt(mut n: u128) -> u128 {
Expand Down Expand Up @@ -244,6 +258,7 @@ pub const fn sqrt(mut n: u128) -> u128 {
#[cfg(test)]
mod tests {
use super::*;
use checked_multiply_by_rational_with_rounding as checked_mulrat;
use codec::{Decode, Encode};
use multiply_by_rational_with_rounding as mulrat;
use Rounding::*;
Expand Down Expand Up @@ -278,6 +293,36 @@ mod tests {
assert_eq!(mulrat(1, MAX / 2 + 1, MAX, NearestPrefUp), Some(1));
}

#[test]
fn rational_checked_multiply_basic_rounding_works() {
assert_eq!(checked_mulrat(1, 1, 1, Up), Ok(1));
assert_eq!(checked_mulrat(3, 1, 3, Up), Ok(1));
assert_eq!(checked_mulrat(1, 1, 3, Up), Ok(1));
assert_eq!(checked_mulrat(1, 2, 3, Down), Ok(0));
assert_eq!(checked_mulrat(1, 1, 3, NearestPrefDown), Ok(0));
assert_eq!(checked_mulrat(1, 1, 2, NearestPrefDown), Ok(0));
assert_eq!(checked_mulrat(1, 2, 3, NearestPrefDown), Ok(1));
assert_eq!(checked_mulrat(1, 1, 3, NearestPrefUp), Ok(0));
assert_eq!(checked_mulrat(1, 1, 2, NearestPrefUp), Ok(1));
assert_eq!(checked_mulrat(1, 2, 3, NearestPrefUp), Ok(1));
assert_eq!(checked_mulrat(3, 1, 0, Up), Err("Division by zero"));
}

#[test]
fn rational_checked_multiply_big_number_works() {
assert_eq!(checked_mulrat(MAX, MAX - 1, MAX, Down), Ok(MAX - 1));
ggwpez marked this conversation as resolved.
Show resolved Hide resolved
assert_eq!(checked_mulrat(MAX, 1, MAX, Down), Ok(1));
assert_eq!(checked_mulrat(MAX, MAX - 1, MAX, Up), Ok(MAX - 1));
assert_eq!(checked_mulrat(MAX, 1, MAX, Up), Ok(1));
assert_eq!(checked_mulrat(1, MAX - 1, MAX, Down), Ok(0));
assert_eq!(checked_mulrat(1, 1, MAX, Up), Ok(1));
assert_eq!(checked_mulrat(1, MAX / 2, MAX, NearestPrefDown), Ok(0));
assert_eq!(checked_mulrat(1, MAX / 2 + 1, MAX, NearestPrefDown), Ok(1));
assert_eq!(checked_mulrat(1, MAX / 2, MAX, NearestPrefUp), Ok(0));
assert_eq!(checked_mulrat(1, MAX / 2 + 1, MAX, NearestPrefUp), Ok(1));
assert_eq!(checked_mulrat(1, MAX / 2 + 1, 0, NearestPrefUp), Err("Division by zero"));
}

#[test]
fn sqrt_works() {
for i in 0..100_000u32 {
Expand Down
134 changes: 129 additions & 5 deletions substrate/primitives/arithmetic/src/per_things.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,8 @@
use serde::{Deserialize, Serialize};

use crate::traits::{
BaseArithmetic, Bounded, CheckedAdd, CheckedMul, CheckedSub, One, SaturatedConversion,
Saturating, UniqueSaturatedInto, Unsigned, Zero,
BaseArithmetic, Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One,
SaturatedConversion, Saturating, UniqueSaturatedInto, Unsigned, Zero,
};
use codec::{CompactAs, Encode};
use core::{
Expand Down Expand Up @@ -531,18 +531,40 @@ where
(x / maximum) * part_n + c
}

/// Compute the error due to integer division in the expression `x / denom * numer`.
/// Unchecked computation of the error due to integer division in the expression `x / denom *
/// numer`.
fn rational_mul_correction<N, P>(x: N, numer: P::Inner, denom: P::Inner, rounding: Rounding) -> N
where
N: MultiplyArg + UniqueSaturatedInto<P::Inner>,
P: PerThing,
P::Inner: Into<N>,
{
checked_rational_mul_correction::<N, P>(x, numer, denom, rounding).unwrap()
}

/// Checked computation of the error due to integer division in the expression `x / denom * numer`.
///
/// Take the remainder of `x / denom` and multiply by `numer / denom`. The result can be added
/// to `x / denom * numer` for an accurate result.
fn rational_mul_correction<N, P>(x: N, numer: P::Inner, denom: P::Inner, rounding: Rounding) -> N
fn checked_rational_mul_correction<N, P>(
x: N,
numer: P::Inner,
denom: P::Inner,
rounding: Rounding,
) -> Result<N, &'static str>
where
N: MultiplyArg + UniqueSaturatedInto<P::Inner>,
P: PerThing,
P::Inner: Into<N>,
{
let numer_upper = P::Upper::from(numer);
let denom_n: N = denom.into();

// checking `denom` after in case `into()` is truncating.
if P::Inner::is_zero(&denom) {
return Err("Division by zero")
}

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In line 529 maybe there is a checked_rem? Then we dont need this if.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not sure what you mean by checked_rem

let denom_upper = P::Upper::from(denom);
let rem = x.rem(denom_n);
// `rem` is less than `denom`, which fits in `P::Inner`.
Expand Down Expand Up @@ -575,7 +597,7 @@ where
}
},
}
rem_mul_div_inner.into()
Ok(rem_mul_div_inner.into())
}

macro_rules! implement_per_thing {
Expand Down Expand Up @@ -1056,6 +1078,17 @@ macro_rules! implement_per_thing {
}
}

impl CheckedDiv for $name {
#[inline]
fn checked_div(&self, rhs: &Self) -> Option<Self> {
if rhs.0.is_zero() {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should use from_rational_with_rounding

return None
}

self.deconstruct().checked_div(rhs.deconstruct()).map($name::from_parts)
ggwpez marked this conversation as resolved.
Show resolved Hide resolved
}
}

impl $crate::traits::Zero for $name {
fn zero() -> Self {
Self::zero()
Expand Down Expand Up @@ -1579,6 +1612,77 @@ macro_rules! implement_per_thing {
}
}

#[test]
fn checked_rational_mul_correction_works() {
assert_eq!(
super::checked_rational_mul_correction::<$type, $name>(
<$type>::max_value(),
<$type>::max_value(),
<$type>::max_value(),
super::Rounding::NearestPrefDown,
),
Ok(0),
);
assert_eq!(
super::checked_rational_mul_correction::<$type, $name>(
<$type>::max_value() - 1,
<$type>::max_value(),
<$type>::max_value(),
super::Rounding::NearestPrefDown,
),
Ok(<$type>::max_value() - 1),
);
assert_eq!(
super::checked_rational_mul_correction::<$upper_type, $name>(
((<$type>::max_value() - 1) as $upper_type).pow(2),
<$type>::max_value(),
<$type>::max_value(),
super::Rounding::NearestPrefDown,
),
Ok(1),
);
// ((max^2 - 1) % max) * max / max == max - 1
assert_eq!(
super::checked_rational_mul_correction::<$upper_type, $name>(
(<$type>::max_value() as $upper_type).pow(2) - 1,
<$type>::max_value(),
<$type>::max_value(),
super::Rounding::NearestPrefDown,
),
Ok(<$upper_type>::from((<$type>::max_value() - 1))),
);
// (max % 2) * max / 2 == max / 2
assert_eq!(
super::checked_rational_mul_correction::<$upper_type, $name>(
(<$type>::max_value() as $upper_type).pow(2),
<$type>::max_value(),
2 as $type,
super::Rounding::NearestPrefDown,
),
Ok(<$type>::max_value() as $upper_type / 2),
);
// ((max^2 - 1) % max) * 2 / max == 2 (rounded up)
assert_eq!(
super::checked_rational_mul_correction::<$upper_type, $name>(
(<$type>::max_value() as $upper_type).pow(2) - 1,
2 as $type,
<$type>::max_value(),
super::Rounding::NearestPrefDown,
),
Ok(2),
);
// ((max^2 - 1) % max) * 2 / max == 1 (rounded down)
assert_eq!(
super::checked_rational_mul_correction::<$upper_type, $name>(
(<$type>::max_value() as $upper_type).pow(2) - 1,
2 as $type,
<$type>::max_value(),
super::Rounding::Down,
),
Ok(1),
);
}

#[test]
fn rational_mul_correction_works() {
assert_eq!(
Expand Down Expand Up @@ -1772,6 +1876,26 @@ macro_rules! implement_per_thing {
Some($name::from_percent(0))
);
}

#[test]
fn test_basic_checked_div() {
assert_eq!(
$name::from_parts($max).checked_div(&$name::from_parts($max)),
Some($name::from_parts(1))
);
assert_eq!(
$name::from_percent(100).checked_div(&$name::from_parts(100)),
Some($name::from_percent(1))
);
assert_eq!(
$name::from_percent(0).checked_div(&$name::from_percent(0)),
None
);
assert_eq!(
$name::from_parts(0).checked_div(&$name::from_parts(0)),
None
);
}
}
};
}
Expand Down
Loading