Skip to content

Commit

Permalink
Initial attempt on fixing overflow
Browse files Browse the repository at this point in the history
  • Loading branch information
northernorca committed Sep 24, 2024
1 parent 692d353 commit 3ac0274
Showing 1 changed file with 75 additions and 3 deletions.
78 changes: 75 additions & 3 deletions basm-std/src/math/static_modint.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
use core::{fmt::Display, ops::*};

#[derive(Clone, Copy, PartialEq, Eq, Default, Debug)]
pub struct ModInt<const M: u64>(pub u64);
pub struct ModInt<const M: u64>(u64);

impl<const M: u64> ModInt<M> {
pub fn get(self) -> u64 {
self.0
}
}

impl<const M: u64> Display for ModInt<M> {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
Expand All @@ -24,11 +30,17 @@ impl<const M: u64> From<ModInt<M>> for u64 {
}
}

// TODO: Handle the case for `self.0 + rhs.0` overflow
// TODO: Handle the case for `self.0 + rhs.0` overflow (the implementation below is not enough)
impl<const M: u64> Add for ModInt<M> {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
Self((self.0 + rhs.0) % M)
let (val, carry) = self.0.overflowing_add(rhs.0);
if carry {
let v: u64 = (1u64 << (u64::BITS - 1)) % M;
Self((val + (v >> 1)) % M)
} else {
Self(val % M)
}
}
}

Expand Down Expand Up @@ -98,3 +110,63 @@ impl<const M: u64> SubAssign for ModInt<M> {
self.0 = (M + self.0 - rhs.0) % M;
}
}

#[cfg(test)]
mod test {
use super::*;

#[test]
fn test_static_small_mod_add() {
let proper = |a: u64, b: u64, m: u64| ((a as u128 + b as u128) % (m as u128)) as u64;
macro_rules! test {
($($m:expr),*) => {$(
{
const M: u64 = $m;
for mut a in 0..100 {
if a > 50 {
a = a * 29 + 18;
}
let am = ModInt::<{ M }>::from(a);
for mut b in 0..100 {
if b > 50 {
b = b * 29 + 18;
}
let bm = ModInt::<{ M }>::from(b);
let t = am + bm;
assert_eq!(proper(a, b, M), t.0);
}
}
}
)*};
}
test!(2, 10, 593, 11729378, 2343246813781979);
}

#[test]
fn test_static_large_mod_add() {
let proper = |a: u64, b: u64, m: u64| ((a as u128 + b as u128) % (m as u128)) as u64;
macro_rules! test {
($($m:expr),*) => {$(
{
const M: u64 = $m;
let m = M / 3 * 2;
for a in [m, m.wrapping_add(1), m.wrapping_add(2), m.wrapping_add(30)] {
let am = ModInt::<{ M }>::from(a);
for b in [m, m.wrapping_add(1), m.wrapping_add(2), m.wrapping_add(30)] {
let bm = ModInt::<{ M }>::from(b);
let t = am + bm;
assert_eq!(proper(a, b, M), t.0);
}
}
}
)*};
}
test!(
u64::MAX / 2,
u64::MAX / 2 + 1,
u64::MAX / 3 * 2 + 40,
u64::MAX - 1,
u64::MAX
);
}
}

0 comments on commit 3ac0274

Please sign in to comment.