Skip to content
This repository has been archived by the owner on Nov 15, 2023. It is now read-only.

Add ensure_pow method #13042

Merged
merged 3 commits into from
Jan 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 42 additions & 10 deletions primitives/arithmetic/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@

use codec::HasCompact;
pub use ensure::{
Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign, EnsureFixedPointNumber,
EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp, EnsureOpAssign, EnsureSub,
EnsureSubAssign,
ensure_pow, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign,
EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp,
EnsureOpAssign, EnsureSub, EnsureSubAssign,
};
pub use integer_sqrt::IntegerSquareRoot;
pub use num_traits::{
Expand Down Expand Up @@ -342,7 +342,7 @@ impl<T: Sized> SaturatedConversion for T {}
/// The *EnsureOps* family functions follows the same behavior as *CheckedOps* but
/// returning an [`ArithmeticError`](crate::ArithmeticError) instead of `None`.
mod ensure {
use super::{CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, Zero};
use super::{checked_pow, CheckedAdd, CheckedDiv, CheckedMul, CheckedSub, One, Zero};
use crate::{ArithmeticError, FixedPointNumber, FixedPointOperand};

/// Performs addition that returns [`ArithmeticError`] instead of wrapping around on overflow.
Expand Down Expand Up @@ -511,6 +511,27 @@ mod ensure {
}
}

/// Raises a value to the power of exp, returning `ArithmeticError` if an overflow occurred.
///
/// Check [`checked_pow`] for more info about border cases.
///
/// ```
/// use sp_arithmetic::{traits::ensure_pow, ArithmeticError};
///
/// fn overflow() -> Result<(), ArithmeticError> {
/// ensure_pow(2u64, 64)?;
/// Ok(())
/// }
///
/// assert_eq!(overflow(), Err(ArithmeticError::Overflow));
/// ```
pub fn ensure_pow<T: One + CheckedMul + Clone>(
base: T,
exp: usize,
) -> Result<T, ArithmeticError> {
checked_pow(base, exp).ok_or(ArithmeticError::Overflow)
}

impl<T: EnsureAddAssign> EnsureAdd for T {}
impl<T: EnsureSubAssign> EnsureSub for T {}
impl<T: EnsureMulAssign> EnsureMul for T {}
Expand Down Expand Up @@ -953,6 +974,15 @@ mod tests {
test_ensure(values(), &EnsureDiv::ensure_div, &CheckedDiv::checked_div);
}

#[test]
fn ensure_pow_works() {
test_ensure(
values().into_iter().map(|(base, exp)| (base, exp as usize)).collect(),
ensure_pow,
|&a, &b| checked_pow(a, b),
);
}

#[test]
fn ensure_add_assign_works() {
test_ensure_assign(values(), &EnsureAddAssign::ensure_add_assign, &EnsureAdd::ensure_add);
Expand All @@ -974,11 +1004,12 @@ mod tests {
}

/// Test that the ensured function returns the expected un-ensured value.
fn test_ensure<V, E, P>(pairs: Vec<(V, V)>, ensured: E, unensured: P)
fn test_ensure<V, W, E, P>(pairs: Vec<(V, W)>, ensured: E, unensured: P)
where
V: Ensure + core::fmt::Debug + Copy,
E: Fn(V, V) -> Result<V, ArithmeticError>,
P: Fn(&V, &V) -> Option<V>,
W: Ensure + core::fmt::Debug + Copy,
E: Fn(V, W) -> Result<V, ArithmeticError>,
P: Fn(&V, &W) -> Option<V>,
{
for (a, b) in pairs.into_iter() {
match ensured(a, b) {
Expand All @@ -993,11 +1024,12 @@ mod tests {
}

/// Test that the ensured function modifies `self` to the expected un-ensured value.
fn test_ensure_assign<V, E, P>(pairs: Vec<(V, V)>, ensured: E, unensured: P)
fn test_ensure_assign<V, W, E, P>(pairs: Vec<(V, W)>, ensured: E, unensured: P)
where
V: Ensure + std::panic::RefUnwindSafe + std::panic::UnwindSafe + core::fmt::Debug + Copy,
E: Fn(&mut V, V) -> Result<(), ArithmeticError>,
P: Fn(V, V) -> Result<V, ArithmeticError> + std::panic::RefUnwindSafe,
W: Ensure + std::panic::RefUnwindSafe + std::panic::UnwindSafe + core::fmt::Debug + Copy,
E: Fn(&mut V, W) -> Result<(), ArithmeticError>,
P: Fn(V, W) -> Result<V, ArithmeticError> + std::panic::RefUnwindSafe,
{
for (mut a, b) in pairs.into_iter() {
let old_a = a;
Expand Down
10 changes: 5 additions & 5 deletions primitives/runtime/src/traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ use impl_trait_for_tuples::impl_for_tuples;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sp_application_crypto::AppKey;
pub use sp_arithmetic::traits::{
AtLeast32Bit, AtLeast32BitUnsigned, Bounded, CheckedAdd, CheckedDiv, CheckedMul, CheckedShl,
CheckedShr, CheckedSub, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv, EnsureDivAssign,
EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign, EnsureOp,
EnsureOpAssign, EnsureSub, EnsureSubAssign, IntegerSquareRoot, One, SaturatedConversion,
Saturating, UniqueSaturatedFrom, UniqueSaturatedInto, Zero,
checked_pow, ensure_pow, AtLeast32Bit, AtLeast32BitUnsigned, Bounded, CheckedAdd, CheckedDiv,
ggwpez marked this conversation as resolved.
Show resolved Hide resolved
CheckedMul, CheckedShl, CheckedShr, CheckedSub, Ensure, EnsureAdd, EnsureAddAssign, EnsureDiv,
EnsureDivAssign, EnsureFixedPointNumber, EnsureFrom, EnsureInto, EnsureMul, EnsureMulAssign,
EnsureOp, EnsureOpAssign, EnsureSub, EnsureSubAssign, IntegerSquareRoot, One,
SaturatedConversion, Saturating, UniqueSaturatedFrom, UniqueSaturatedInto, Zero,
};
use sp_core::{self, storage::StateVersion, Hasher, RuntimeDebug, TypeId};
#[doc(hidden)]
Expand Down