@@ -101,25 +101,29 @@ void butterfly_inv(std::vector<mint>& a) {
101
101
}
102
102
}
103
103
104
- } // namespace internal
105
-
106
104
template <class mint , internal::is_static_modint_t <mint>* = nullptr >
107
- std::vector<mint> convolution ( std::vector<mint> a, std::vector<mint> b) {
105
+ std::vector<mint> convolution_naive ( const std::vector<mint>& a, const std::vector<mint>& b) {
108
106
int n = int (a.size ()), m = int (b.size ());
109
- if (!n || !m) return {};
110
- if (std::min (n, m) <= 60 ) {
111
- if (n < m) {
112
- std::swap (n, m);
113
- std::swap (a, b);
107
+ std::vector<mint> ans (n + m - 1 );
108
+ if (n < m) {
109
+ for (int j = 0 ; j < m; j++) {
110
+ for (int i = 0 ; i < n; i++) {
111
+ ans[i + j] += a[i] * b[j];
112
+ }
114
113
}
115
- std::vector<mint> ans (n + m - 1 );
114
+ } else {
116
115
for (int i = 0 ; i < n; i++) {
117
116
for (int j = 0 ; j < m; j++) {
118
117
ans[i + j] += a[i] * b[j];
119
118
}
120
119
}
121
- return ans;
122
120
}
121
+ return ans;
122
+ }
123
+
124
+ template <class mint , internal::is_static_modint_t <mint>* = nullptr >
125
+ std::vector<mint> convolution_fft (std::vector<mint> a, std::vector<mint> b) {
126
+ int n = int (a.size ()), m = int (b.size ());
123
127
int z = 1 << internal::ceil_pow2 (n + m - 1 );
124
128
a.resize (z);
125
129
internal::butterfly (a);
@@ -132,7 +136,25 @@ std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
132
136
a.resize (n + m - 1 );
133
137
mint iz = mint (z).inv ();
134
138
for (int i = 0 ; i < n + m - 1 ; i++) a[i] *= iz;
135
- return a;
139
+ return std::move (a);
140
+ }
141
+
142
+ } // namespace internal
143
+
144
+ template <class mint , internal::is_static_modint_t <mint>* = nullptr >
145
+ std::vector<mint> convolution (std::vector<mint>&& a, std::vector<mint>&& b) {
146
+ int n = int (a.size ()), m = int (b.size ());
147
+ if (!n || !m) return {};
148
+ if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
149
+ return internal::convolution_fft (a, b);
150
+ }
151
+
152
+ template <class mint , internal::is_static_modint_t <mint>* = nullptr >
153
+ std::vector<mint> convolution (const std::vector<mint>& a, const std::vector<mint>& b) {
154
+ int n = int (a.size ()), m = int (b.size ());
155
+ if (!n || !m) return {};
156
+ if (std::min (n, m) <= 60 ) return convolution_naive (a, b);
157
+ return internal::convolution_fft (a, b);
136
158
}
137
159
138
160
template <unsigned int mod = 998244353 ,
0 commit comments