Skip to content

Commit 752830c

Browse files
authored
Merge pull request #96 from atcoder/patch/issue95
#95: fix convolution
2 parents a9fb2b4 + 04af374 commit 752830c

File tree

5 files changed

+95
-11
lines changed

5 files changed

+95
-11
lines changed

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
11
[submodule "test/unittest/googletest"]
22
path = test/unittest/googletest
33
url = https://github.com/google/googletest
4+
[submodule "test/benchmark/benchmark"]
5+
path = test/benchmark/benchmark
6+
url = https://github.com/google/benchmark

atcoder/convolution.hpp

+33-11
Original file line numberDiff line numberDiff line change
@@ -101,25 +101,29 @@ void butterfly_inv(std::vector<mint>& a) {
101101
}
102102
}
103103

104-
} // namespace internal
105-
106104
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
107-
std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
105+
std::vector<mint> convolution_naive(const std::vector<mint>& a, const std::vector<mint>& b) {
108106
int n = int(a.size()), m = int(b.size());
109-
if (!n || !m) return {};
110-
if (std::min(n, m) <= 60) {
111-
if (n < m) {
112-
std::swap(n, m);
113-
std::swap(a, b);
107+
std::vector<mint> ans(n + m - 1);
108+
if (n < m) {
109+
for (int j = 0; j < m; j++) {
110+
for (int i = 0; i < n; i++) {
111+
ans[i + j] += a[i] * b[j];
112+
}
114113
}
115-
std::vector<mint> ans(n + m - 1);
114+
} else {
116115
for (int i = 0; i < n; i++) {
117116
for (int j = 0; j < m; j++) {
118117
ans[i + j] += a[i] * b[j];
119118
}
120119
}
121-
return ans;
122120
}
121+
return ans;
122+
}
123+
124+
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
125+
std::vector<mint> convolution_fft(std::vector<mint> a, std::vector<mint> b) {
126+
int n = int(a.size()), m = int(b.size());
123127
int z = 1 << internal::ceil_pow2(n + m - 1);
124128
a.resize(z);
125129
internal::butterfly(a);
@@ -132,7 +136,25 @@ std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) {
132136
a.resize(n + m - 1);
133137
mint iz = mint(z).inv();
134138
for (int i = 0; i < n + m - 1; i++) a[i] *= iz;
135-
return a;
139+
return std::move(a);
140+
}
141+
142+
} // namespace internal
143+
144+
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
145+
std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
146+
int n = int(a.size()), m = int(b.size());
147+
if (!n || !m) return {};
148+
if (std::min(n, m) <= 60) return convolution_naive(a, b);
149+
return internal::convolution_fft(a, b);
150+
}
151+
152+
template <class mint, internal::is_static_modint_t<mint>* = nullptr>
153+
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
154+
int n = int(a.size()), m = int(b.size());
155+
if (!n || !m) return {};
156+
if (std::min(n, m) <= 60) return convolution_naive(a, b);
157+
return internal::convolution_fft(a, b);
136158
}
137159

138160
template <unsigned int mod = 998244353,

test/benchmark/CMakeLists.txt

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
cmake_policy(SET CMP0048 NEW)
2+
project(ACLibrary)
3+
4+
cmake_minimum_required(VERSION 3.17)
5+
6+
set(GOOGLETEST_PATH "${CMAKE_CURRENT_SOURCE_DIR}/../unittest/googletest")
7+
8+
if(NOT "${CMAKE_CXX_STANDARD}")
9+
set(CMAKE_CXX_STANDARD 14)
10+
endif()
11+
set(CMAKE_CXX_EXTENSIONS OFF)
12+
13+
add_compile_options(-Wall -Wextra -Wshadow -Wconversion -Wno-sign-conversion -Werror)
14+
15+
add_subdirectory(benchmark)
16+
include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
17+
include(GoogleTest)
18+
19+
include_directories(.)
20+
include_directories(../../)
21+
22+
add_executable(Convolution convolution.cpp)
23+
target_link_libraries(Convolution benchmark::benchmark)

test/benchmark/benchmark

Submodule benchmark added at 8df87f6

test/benchmark/convolution.cpp

+35
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
#include "atcoder/convolution"
2+
#include <iostream>
3+
4+
#include "benchmark/benchmark.h"
5+
6+
using namespace std;
7+
using namespace atcoder;
8+
using mint = modint998244353;
9+
10+
void CONV_same_length(benchmark::State& state) {
11+
vector<mint> a(state.range(0)), b(state.range(0));
12+
for (int i = 0; i < state.range(0); i++) {
13+
a[i] = i + 1234;
14+
b[i] = i + 5678;
15+
}
16+
for (auto _ : state) {
17+
benchmark::DoNotOptimize(convolution(a, b));
18+
}
19+
}
20+
BENCHMARK(CONV_same_length)->RangeMultiplier(2)->Range(1, 1<<20);
21+
BENCHMARK(CONV_same_length)->DenseRange(1, 100, 1);
22+
23+
void CONV_long_empty(benchmark::State& state) {
24+
vector<mint> a(state.range(0)), b;
25+
for (int i = 0; i < state.range(0); i++) {
26+
a[i] = i + 1234;
27+
}
28+
for (auto _ : state) {
29+
benchmark::DoNotOptimize(convolution(a, b));
30+
benchmark::DoNotOptimize(convolution(b, a));
31+
}
32+
}
33+
BENCHMARK(CONV_long_empty)->RangeMultiplier(2)->Range(1, 1 << 20);
34+
35+
BENCHMARK_MAIN();

0 commit comments

Comments
 (0)