diff --git a/src/math.rs b/src/math.rs index 61f15d5..d390501 100644 --- a/src/math.rs +++ b/src/math.rs @@ -135,7 +135,7 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) { // r2 % m0 = r0 // r2 % mi = ri // -> (r0 + x*m0) % mi = ri - // -> x*u0*g % (u1*g) = (ri - r0) (u0*g = m0, u1*g = mi) + // -> x*u0*g = ri-r0 (mod u1*g) (u0*g = m0, u1*g = mi) // -> x = (ri - r0) / g * inv(u0) (mod u1) // im = inv(u0) (mod u1) (0 <= im < u1) @@ -188,7 +188,7 @@ pub fn crt(r: &[i64], m: &[i64]) -> (i64, i64) { pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 { let mut ans = 0; if a >= m { - ans += (n - 1) * n * (a / m) / 2; + ans += (n - 1) * n / 2 * (a / m); a %= m; } if b >= m { @@ -196,13 +196,11 @@ pub fn floor_sum(n: i64, m: i64, mut a: i64, mut b: i64) -> i64 { b %= m; } - let y_max = (a * n + b) / m; - let x_max = y_max * m - b; - if y_max == 0 { + let y_max = a * n + b; + if y_max < m { return ans; } - ans += (n - (x_max + a - 1) / a) * y_max; - ans += floor_sum(y_max, a, m, (a - x_max % a) % a); + ans += floor_sum(y_max / m, a, m, y_max % m); ans }