diff --git a/src/internal_math.rs b/src/internal_math.rs index 515191c..c6290b0 100644 --- a/src/internal_math.rs +++ b/src/internal_math.rs @@ -1,6 +1,6 @@ // remove this after dependencies has been added #![allow(dead_code)] -use std::mem::swap; +use std::{mem::swap, num::Wrapping as W}; /// # Arguments /// * `m` `1 <= m` @@ -235,6 +235,46 @@ pub(crate) fn primitive_root(m: i32) -> i32 { // omitted // template constexpr int primitive_root = primitive_root_constexpr(m); +/// # Arguments +/// * `n` `n < 2^32` +/// * `m` `1 <= m < 2^32` +/// +/// # Returns +/// `sum_{i=0}^{n-1} floor((ai + b) / m) (mod 2^64)` +/* const */ +#[allow(clippy::many_single_char_names)] +pub(crate) fn floor_sum_unsigned( + mut n: W, + mut m: W, + mut a: W, + mut b: W, +) -> W { + let mut ans = W(0); + loop { + if a >= m { + if n > W(0) { + ans += n * (n - W(1)) / W(2) * (a / m); + } + a %= m; + } + if b >= m { + ans += n * (b / m); + b %= m; + } + + let y_max = a * n + b; + if y_max < m { + break; + } + // y_max < m * (n + 1) + // floor(y_max / m) <= n + n = y_max / m; + b = y_max % m; + std::mem::swap(&mut m, &mut a); + } + ans +} + #[cfg(test)] mod tests { #![allow(clippy::unreadable_literal)] diff --git a/src/math.rs b/src/math.rs index 61f15d5..5dfd349 100644 --- a/src/math.rs +++ b/src/math.rs @@ -162,13 +162,16 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) { (r0, m0) } -/// Returns $\sum_{i = 0}^{n - 1} \lfloor \frac{a \times i + b}{m} \rfloor$. +/// Returns +/// +/// $$\sum_{i = 0}^{n - 1} \left\lfloor \frac{a \times i + b}{m} \right\rfloor.$$ +/// +/// It returns the answer in $\bmod 2^{\mathrm{64}}$, if overflowed. /// /// # Constraints /// -/// - $0 \leq n \leq 10^9$ -/// - $1 \leq m \leq 10^9$ -/// - $0 \leq a, b \leq m$ +/// - $0 \leq n \lt 2^{32}$ +/// - $1 \leq m \lt 2^{32}$ /// /// # Panics /// @@ -176,7 +179,7 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) { /// /// # Complexity /// -/// - $O(\log(n + m + a + b))$ +/// - $O(\log{(m+a)})$ /// /// # Example /// @@ -185,25 +188,25 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) { /// /// assert_eq!(math::floor_sum(6, 5, 4, 3), 13); /// ``` -pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 { - let mut ans = 0; - if a >= m { - ans += (n - 1) * n * (a / m) / 2; - a %= m; - } - if b >= m { - ans += n * (b / m); - b %= m; +#[allow(clippy::many_single_char_names)] +pub fn floor_sum(n: i64, m: i64, a: i64, b: i64) -> i64 { + use std::num::Wrapping as W; + assert!((0..1i64 << 32).contains(&n)); + assert!((1..1i64 << 32).contains(&m)); + let mut ans = W(0_u64); + let (wn, wm, mut wa, mut wb) = (W(n as u64), W(m as u64), W(a as u64), W(b as u64)); + if a < 0 { + let a2 = W(internal_math::safe_mod(a, m) as u64); + ans -= wn * (wn - W(1)) / W(2) * ((a2 - wa) / wm); + wa = a2; } - - let y_max = (a * n + b) / m; - let x_max = y_max * m - b; - if y_max == 0 { - return ans; + if b < 0 { + let b2 = W(internal_math::safe_mod(b, m) as u64); + ans -= wn * ((b2 - wb) / wm); + wb = b2; } - ans += (n - (x_max + a - 1) / a) * y_max; - ans += floor_sum(y_max, a, m, (a - x_max % a) % a); - ans + let ret = ans + internal_math::floor_sum_unsigned(wn, wm, wa, wb); + ret.0 as i64 } #[cfg(test)] @@ -306,5 +309,24 @@ mod tests { 499_999_999_500_000_000 ); assert_eq!(floor_sum(332955, 5590132, 2231, 999423), 22014575); + for n in 0..20 { + for m in 1..20 { + for a in -20..20 { + for b in -20..20 { + assert_eq!(floor_sum(n, m, a, b), floor_sum_naive(n, m, a, b)); + } + } + } + } + } + + #[allow(clippy::many_single_char_names)] + fn floor_sum_naive(n: i64, m: i64, a: i64, b: i64) -> i64 { + let mut ans = 0; + for i in 0..n { + let z = a * i + b; + ans += (z - internal_math::safe_mod(z, m)) / m; + } + ans } }