Skip to content

Commit a8c306d

Browse files
authored
Merge pull request #71 from qryxip/atomic-barrett-seqcst
Make `modint::Barrett` `(AtomicU32, AtomicU64)`
2 parents 006d353 + b09e6b5 commit a8c306d

File tree

2 files changed

+62
-33
lines changed

2 files changed

+62
-33
lines changed

src/internal_math.rs

+29-18
Original file line numberDiff line numberDiff line change
@@ -51,25 +51,36 @@ impl Barrett {
5151
/// a * b % m
5252
#[allow(clippy::many_single_char_names)]
5353
pub(crate) fn mul(&self, a: u32, b: u32) -> u32 {
54-
// [1] m = 1
55-
// a = b = im = 0, so okay
56-
57-
// [2] m >= 2
58-
// im = ceil(2^64 / m)
59-
// -> im * m = 2^64 + r (0 <= r < m)
60-
// let z = a*b = c*m + d (0 <= c, d < m)
61-
// a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
62-
// c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
63-
// ((ab * im) >> 64) == c or c + 1
64-
let mut z = a as u64;
65-
z *= b as u64;
66-
let x = (((z as u128) * (self.im as u128)) >> 64) as u64;
67-
let mut v = z.wrapping_sub(x.wrapping_mul(self._m as u64)) as u32;
68-
if self._m <= v {
69-
v = v.wrapping_add(self._m);
70-
}
71-
v
54+
mul_mod(a, b, self._m, self.im)
55+
}
56+
}
57+
58+
/// Calculates `a * b % m`.
59+
///
60+
/// * `a` `0 <= a < m`
61+
/// * `b` `0 <= b < m`
62+
/// * `m` `1 <= m <= 2^31`
63+
/// * `im` = ceil(2^64 / `m`)
64+
#[allow(clippy::many_single_char_names)]
65+
pub(crate) fn mul_mod(a: u32, b: u32, m: u32, im: u64) -> u32 {
66+
// [1] m = 1
67+
// a = b = im = 0, so okay
68+
69+
// [2] m >= 2
70+
// im = ceil(2^64 / m)
71+
// -> im * m = 2^64 + r (0 <= r < m)
72+
// let z = a*b = c*m + d (0 <= c, d < m)
73+
// a*b * im = (c*m + d) * im = c*(im*m) + d*im = c*2^64 + c*r + d*im
74+
// c*r + d*im < m * m + m * im < m * m + 2^64 + m <= 2^64 + m * (m + 1) < 2^64 * 2
75+
// ((ab * im) >> 64) == c or c + 1
76+
let mut z = a as u64;
77+
z *= b as u64;
78+
let x = (((z as u128) * (im as u128)) >> 64) as u64;
79+
let mut v = z.wrapping_sub(x.wrapping_mul(m as u64)) as u32;
80+
if m <= v {
81+
v = v.wrapping_add(m);
7282
}
83+
v
7384
}
7485

7586
/// # Parameters

src/modint.rs

+33-15
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ use std::{
5858
marker::PhantomData,
5959
ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6060
str::FromStr,
61+
sync::atomic::{self, AtomicU32, AtomicU64},
6162
thread::LocalKey,
6263
};
6364

@@ -330,7 +331,7 @@ impl<I: Id> DynamicModInt<I> {
330331
/// ```
331332
#[inline]
332333
pub fn modulus() -> u32 {
333-
I::companion_barrett().with(|bt| bt.borrow().umod())
334+
I::companion_barrett().umod()
334335
}
335336

336337
/// Sets a modulus.
@@ -354,7 +355,7 @@ impl<I: Id> DynamicModInt<I> {
354355
if modulus == 0 {
355356
panic!("the modulus must not be 0");
356357
}
357-
I::companion_barrett().with(|bt| *bt.borrow_mut() = Barrett::new(modulus))
358+
I::companion_barrett().update(modulus);
358359
}
359360

360361
/// Creates a new `DynamicModInt`.
@@ -442,47 +443,64 @@ impl<I: Id> ModIntBase for DynamicModInt<I> {
442443
}
443444

444445
pub trait Id: 'static + Copy + Eq {
445-
// TODO: Make `internal_math::Barret` `Copy`.
446-
fn companion_barrett() -> &'static LocalKey<RefCell<Barrett>>;
446+
fn companion_barrett() -> &'static Barrett;
447447
}
448448

449449
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
450450
pub enum DefaultId {}
451451

452452
impl Id for DefaultId {
453-
fn companion_barrett() -> &'static LocalKey<RefCell<Barrett>> {
454-
thread_local! {
455-
static BARRETT: RefCell<Barrett> = RefCell::default();
456-
}
453+
fn companion_barrett() -> &'static Barrett {
454+
static BARRETT: Barrett = Barrett::default();
457455
&BARRETT
458456
}
459457
}
460458

461459
/// Pair of _m_ and _ceil(2⁶⁴/m)_.
462-
pub struct Barrett(internal_math::Barrett);
460+
pub struct Barrett {
461+
m: AtomicU32,
462+
im: AtomicU64,
463+
}
463464

464465
impl Barrett {
465466
/// Creates a new `Barrett`.
466467
#[inline]
467-
pub fn new(m: u32) -> Self {
468-
Self(internal_math::Barrett::new(m))
468+
pub const fn new(m: u32) -> Self {
469+
Self {
470+
m: AtomicU32::new(m),
471+
im: AtomicU64::new((-1i64 as u64 / m as u64).wrapping_add(1)),
472+
}
473+
}
474+
475+
#[inline]
476+
const fn default() -> Self {
477+
Self::new(998_244_353)
478+
}
479+
480+
#[inline]
481+
fn update(&self, m: u32) {
482+
let im = (-1i64 as u64 / m as u64).wrapping_add(1);
483+
self.m.store(m, atomic::Ordering::SeqCst);
484+
self.im.store(im, atomic::Ordering::SeqCst);
469485
}
470486

471487
#[inline]
472488
fn umod(&self) -> u32 {
473-
self.0.umod()
489+
self.m.load(atomic::Ordering::SeqCst)
474490
}
475491

476492
#[inline]
477493
fn mul(&self, a: u32, b: u32) -> u32 {
478-
self.0.mul(a, b)
494+
let m = self.m.load(atomic::Ordering::SeqCst);
495+
let im = self.im.load(atomic::Ordering::SeqCst);
496+
internal_math::mul_mod(a, b, m, im)
479497
}
480498
}
481499

482500
impl Default for Barrett {
483501
#[inline]
484502
fn default() -> Self {
485-
Self(internal_math::Barrett::new(998_244_353))
503+
Self::default()
486504
}
487505
}
488506

@@ -810,7 +828,7 @@ impl<M: Modulus> InternalImplementations for StaticModInt<M> {
810828
impl<I: Id> InternalImplementations for DynamicModInt<I> {
811829
#[inline]
812830
fn mul_impl(lhs: Self, rhs: Self) -> Self {
813-
I::companion_barrett().with(|bt| Self::raw(bt.borrow().mul(lhs.val, rhs.val)))
831+
Self::raw(I::companion_barrett().mul(lhs.val, rhs.val))
814832
}
815833
}
816834

0 commit comments

Comments
 (0)