Skip to content

Commit 2cee3c9

Browse files
committed
Implement convolution
1 parent 54a54fd commit 2cee3c9

File tree

3 files changed

+239
-4
lines changed

3 files changed

+239
-4
lines changed

Diff for: src/convolution.rs

+215
Original file line numberDiff line numberDiff line change
@@ -1 +1,216 @@
1+
use crate::{
2+
internal_bit, internal_math,
3+
modint::{ButterflyCache, Modulus, StaticModInt},
4+
};
5+
use std::{cell::RefCell, cmp, thread::LocalKey};
16

7+
#[allow(clippy::many_single_char_names)]
8+
pub fn convolution<M: Modulus>(
9+
a: &[StaticModInt<M>],
10+
b: &[StaticModInt<M>],
11+
) -> Vec<StaticModInt<M>> {
12+
if a.is_empty() || b.is_empty() {
13+
return vec![];
14+
}
15+
let (n, m) = (a.len(), b.len());
16+
17+
if cmp::min(n, m) <= 60 {
18+
let (n, m, a, b) = if n < m { (m, n, b, a) } else { (n, m, a, b) };
19+
let mut ans = vec![StaticModInt::new(0); n + m - 1];
20+
for i in 0..n {
21+
for j in 0..m {
22+
ans[i + j] += a[i] * b[j];
23+
}
24+
}
25+
return ans;
26+
}
27+
28+
let (mut a, mut b) = (a.to_owned(), b.to_owned());
29+
let z = 1 << internal_bit::ceil_pow2((n + m - 1) as _);
30+
a.resize(z, StaticModInt::raw(0));
31+
butterfly(&mut a);
32+
b.resize(z, StaticModInt::raw(0));
33+
butterfly(&mut b);
34+
for (a, b) in a.iter_mut().zip(&b) {
35+
*a *= b;
36+
}
37+
butterfly_inv(&mut a);
38+
a.resize(n + m - 1, StaticModInt::raw(0));
39+
let iz = StaticModInt::new(z).inv();
40+
for a in &mut a {
41+
*a *= iz;
42+
}
43+
a
44+
}
45+
46+
#[allow(clippy::many_single_char_names)]
47+
pub fn convolution_i64(a: &[i64], b: &[i64]) -> Vec<i64> {
48+
const M1: u64 = 754_974_721; // 2^24
49+
const M2: u64 = 167_772_161; // 2^25
50+
const M3: u64 = 469_762_049; // 2^26
51+
const M2M3: u64 = M2 * M3;
52+
const M1M3: u64 = M1 * M3;
53+
const M1M2: u64 = M1 * M2;
54+
const M1M2M3: u64 = M1M2.wrapping_mul(M3);
55+
56+
macro_rules! moduli {
57+
($($name:ident),*) => {
58+
$(
59+
#[derive(Copy, Clone, Eq, PartialEq)]
60+
enum $name {}
61+
62+
impl Modulus for $name {
63+
const VALUE: u32 = $name as _;
64+
const HINT_VALUE_IS_PRIME: bool = true;
65+
66+
fn butterfly_cache() -> &'static LocalKey<RefCell<Option<ButterflyCache<Self>>>> {
67+
thread_local! {
68+
static BUTTERFLY_CACHE: RefCell<Option<ButterflyCache<$name>>> = RefCell::default();
69+
}
70+
&BUTTERFLY_CACHE
71+
}
72+
}
73+
)*
74+
};
75+
}
76+
77+
moduli!(M1, M2, M3);
78+
79+
if a.is_empty() || b.is_empty() {
80+
return vec![];
81+
}
82+
83+
let i1 = internal_math::inv_gcd(M2M3 as _, M1 as _).1;
84+
let i2 = internal_math::inv_gcd(M1M3 as _, M2 as _).1;
85+
let i3 = internal_math::inv_gcd(M1M2 as _, M3 as _).1;
86+
87+
let (c1, c2, c3) = {
88+
fn c<M: Modulus>(a: &[i64], b: &[i64]) -> Vec<i64> {
89+
let a = a.iter().copied().map(Into::into).collect::<Vec<_>>();
90+
let b = b.iter().copied().map(Into::into).collect::<Vec<_>>();
91+
convolution::<M>(&a, &b)
92+
.into_iter()
93+
.map(|z| z.val().into())
94+
.collect()
95+
}
96+
(c::<M1>(a, b), c::<M2>(a, b), c::<M3>(a, b))
97+
};
98+
99+
c1.into_iter()
100+
.zip(c2)
101+
.zip(c3)
102+
.map(|((c1, c2), c3)| {
103+
const OFFSET: &[u64] = &[0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3];
104+
105+
let mut x = 0;
106+
x += (c1 * i1).rem_euclid(M1 as _) * (M2M3 as i64);
107+
x += (c2 * i2).rem_euclid(M2 as _) * (M1M3 as i64);
108+
x += (c3 * i3).rem_euclid(M3 as _) * (M1M2 as i64);
109+
// B = 2^63, -B <= x, r(real value) < B
110+
// (x, x - M, x - 2M, or x - 3M) = r (mod 2B)
111+
// r = c1[i] (mod MOD1)
112+
// focus on MOD1
113+
// r = x, x - M', x - 2M', x - 3M' (M' = M % 2^64) (mod 2B)
114+
// r = x,
115+
// x - M' + (0 or 2B),
116+
// x - 2M' + (0, 2B or 4B),
117+
// x - 3M' + (0, 2B, 4B or 6B) (without mod!)
118+
// (r - x) = 0, (0)
119+
// - M' + (0 or 2B), (1)
120+
// -2M' + (0 or 2B or 4B), (2)
121+
// -3M' + (0 or 2B or 4B or 6B) (3) (mod MOD1)
122+
// we checked that
123+
// ((1) mod MOD1) mod 5 = 2
124+
// ((2) mod MOD1) mod 5 = 3
125+
// ((3) mod MOD1) mod 5 = 4
126+
let mut diff = c1 - internal_math::safe_mod(x, M1 as _);
127+
if diff < 0 {
128+
diff += M1 as i64;
129+
}
130+
x -= OFFSET[diff.rem_euclid(5) as usize] as i64;
131+
x
132+
})
133+
.collect()
134+
}
135+
136+
#[allow(clippy::many_single_char_names)]
137+
fn butterfly<M: Modulus>(a: &mut [StaticModInt<M>]) {
138+
let n = a.len();
139+
let h = internal_bit::ceil_pow2(n as u32);
140+
141+
M::butterfly_cache().with(|cache| {
142+
let mut cache = cache.borrow_mut();
143+
let ButterflyCache { sum_e, .. } = cache.get_or_insert_with(prepare);
144+
for ph in 1..=h {
145+
let w = 1 << (ph - 1);
146+
let p = 1 << (h - ph);
147+
let mut now = StaticModInt::<M>::new(1);
148+
for s in 0..w {
149+
let offset = s << (h - ph + 1);
150+
for i in 0..p {
151+
let l = a[i + offset];
152+
let r = a[i + offset + p] * now;
153+
a[i + offset] = l + r;
154+
a[i + offset + p] = l - r;
155+
}
156+
now *= sum_e[(!s).trailing_zeros() as usize];
157+
}
158+
}
159+
});
160+
}
161+
162+
#[allow(clippy::many_single_char_names)]
163+
fn butterfly_inv<M: Modulus>(a: &mut [StaticModInt<M>]) {
164+
let n = a.len();
165+
let h = internal_bit::ceil_pow2(n as u32);
166+
167+
M::butterfly_cache().with(|cache| {
168+
let mut cache = cache.borrow_mut();
169+
let ButterflyCache { sum_ie, .. } = cache.get_or_insert_with(prepare);
170+
for ph in (1..=h).rev() {
171+
let w = 1 << (ph - 1);
172+
let p = 1 << (h - ph);
173+
let mut inow = StaticModInt::<M>::new(1);
174+
for s in 0..w {
175+
let offset = s << (h - ph + 1);
176+
for i in 0..p {
177+
let l = a[i + offset];
178+
let r = a[i + offset + p];
179+
a[i + offset] = l + r;
180+
a[i + offset + p] = StaticModInt::new(M::VALUE + l.val() - r.val()) * inow;
181+
}
182+
inow *= sum_ie[(!s).trailing_zeros() as usize];
183+
}
184+
}
185+
});
186+
}
187+
188+
fn prepare<M: Modulus>() -> ButterflyCache<M> {
189+
let g = StaticModInt::<M>::raw(internal_math::primitive_root(M::VALUE as i32) as u32);
190+
let mut es = [StaticModInt::<M>::raw(0); 30]; // es[i]^(2^(2+i)) == 1
191+
let mut ies = [StaticModInt::<M>::raw(0); 30];
192+
let cnt2 = (M::VALUE - 1).trailing_zeros() as usize;
193+
let mut e = g.pow(((M::VALUE - 1) >> cnt2).into());
194+
let mut ie = e.inv();
195+
for i in (2..=cnt2).rev() {
196+
es[i - 2] = e;
197+
ies[i - 2] = ie;
198+
e *= e;
199+
ie *= ie;
200+
}
201+
let sum_e = es
202+
.iter()
203+
.scan(StaticModInt::new(1), |acc, e| {
204+
*acc *= e;
205+
Some(*acc)
206+
})
207+
.collect();
208+
let sum_ie = ies
209+
.iter()
210+
.scan(StaticModInt::new(1), |acc, ie| {
211+
*acc *= ie;
212+
Some(*acc)
213+
})
214+
.collect();
215+
ButterflyCache { sum_e, sum_ie }
216+
}

Diff for: src/lib.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,14 @@ pub(crate) mod internal_queue;
1818
pub(crate) mod internal_scc;
1919
pub(crate) mod internal_type_traits;
2020

21+
pub use convolution::{convolution, convolution_i64};
2122
pub use dsu::Dsu;
2223
pub use fenwicktree::FenwickTree;
2324
pub use math::{crt, floor_sum, inv_mod, pow_mod};
2425
pub use mincostflow::MinCostFlowGraph;
2526
pub use modint::{
26-
Barrett, DefaultId, DynamicModInt, Id, Mod1000000007, Mod998244353, ModInt, ModInt1000000007,
27-
ModInt998244353, Modulus, RemEuclidU32, StaticModInt,
27+
Barrett, ButterflyCache, DefaultId, DynamicModInt, Id, Mod1000000007, Mod998244353, ModInt,
28+
ModInt1000000007, ModInt998244353, Modulus, RemEuclidU32, StaticModInt,
2829
};
2930
pub use string::{
3031
lcp_array, lcp_array_arbitrary, suffix_array, suffix_array_arbitrary, suffix_array_manual,

Diff for: src/modint.rs

+21-2
Original file line numberDiff line numberDiff line change
@@ -117,8 +117,8 @@ impl<M: Modulus> ModIntBase for StaticModInt<M> {
117117
pub trait Modulus: 'static + Copy + Eq {
118118
const VALUE: u32;
119119
const HINT_VALUE_IS_PRIME: bool;
120-
// does not work well
121-
//const _STATIC_ASSERT_VALUE_IS_NON_ZERO: () = [()][(Self::VALUE == 0) as usize];
120+
121+
fn butterfly_cache() -> &'static LocalKey<RefCell<Option<ButterflyCache<Self>>>>;
122122
}
123123

124124
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
@@ -127,6 +127,13 @@ pub enum Mod1000000007 {}
127127
impl Modulus for Mod1000000007 {
128128
const VALUE: u32 = 1_000_000_007;
129129
const HINT_VALUE_IS_PRIME: bool = true;
130+
131+
fn butterfly_cache() -> &'static LocalKey<RefCell<Option<ButterflyCache<Self>>>> {
132+
thread_local! {
133+
static BUTTERFLY_CACHE: RefCell<Option<ButterflyCache<Mod1000000007>>> = RefCell::default();
134+
}
135+
&BUTTERFLY_CACHE
136+
}
130137
}
131138

132139
#[derive(Copy, Clone, Ord, PartialOrd, Eq, PartialEq, Hash, Debug)]
@@ -135,6 +142,18 @@ pub enum Mod998244353 {}
135142
impl Modulus for Mod998244353 {
136143
const VALUE: u32 = 998_244_353;
137144
const HINT_VALUE_IS_PRIME: bool = true;
145+
146+
fn butterfly_cache() -> &'static LocalKey<RefCell<Option<ButterflyCache<Self>>>> {
147+
thread_local! {
148+
static BUTTERFLY_CACHE: RefCell<Option<ButterflyCache<Mod998244353>>> = RefCell::default();
149+
}
150+
&BUTTERFLY_CACHE
151+
}
152+
}
153+
154+
pub struct ButterflyCache<M> {
155+
pub(crate) sum_e: Vec<StaticModInt<M>>,
156+
pub(crate) sum_ie: Vec<StaticModInt<M>>,
138157
}
139158

140159
#[derive(Copy, Clone, Eq, PartialEq)]

0 commit comments

Comments
 (0)