Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fft optimize #120

Merged
merged 7 commits into from
Jul 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 143 additions & 68 deletions atcoder/convolution.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,95 +14,169 @@ namespace atcoder {

namespace internal {

template <class mint,
int g = internal::primitive_root<mint::mod()>,
internal::is_static_modint_t<mint>* = nullptr>
struct fft_info {
static constexpr int rank2 = bsf_constexpr(mint::mod() - 1);
std::array<mint, rank2 + 1> root; // root[i]^(2^i) == 1
std::array<mint, rank2 + 1> iroot; // root[i] * iroot[i] == 1

std::array<mint, std::max(0, rank2 - 2 + 1)> rate2;
std::array<mint, std::max(0, rank2 - 2 + 1)> irate2;

std::array<mint, std::max(0, rank2 - 3 + 1)> rate3;
std::array<mint, std::max(0, rank2 - 3 + 1)> irate3;

fft_info() {
root[rank2] = mint(g).pow((mint::mod() - 1) >> rank2);
iroot[rank2] = root[rank2].inv();
for (int i = rank2 - 1; i >= 0; i--) {
root[i] = root[i + 1] * root[i + 1];
iroot[i] = iroot[i + 1] * iroot[i + 1];
}

{
mint prod = 1, iprod = 1;
for (int i = 0; i <= rank2 - 2; i++) {
rate2[i] = root[i + 2] * prod;
irate2[i] = iroot[i + 2] * iprod;
prod *= iroot[i + 2];
iprod *= root[i + 2];
}
}
{
mint prod = 1, iprod = 1;
for (int i = 0; i <= rank2 - 3; i++) {
rate3[i] = root[i + 3] * prod;
irate3[i] = iroot[i + 3] * iprod;
prod *= iroot[i + 3];
iprod *= root[i + 3];
}
}
}
};

template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
int h = internal::ceil_pow2(n);

static bool first = true;
static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
// e^(2^i) == 1
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_e[i] = es[i] * now;
now *= ies[i];
}
}
for (int ph = 1; ph <= h; ph++) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint now = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * now;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
static const fft_info<mint> info;

int len = 0; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len < h) {
if (h - len == 1) {
int p = 1 << (h - len - 1);
mint rot = 1;
for (int s = 0; s < (1 << len); s++) {
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p] * rot;
a[i + offset] = l + r;
a[i + offset + p] = l - r;
}
if (s + 1 != (1 << len))
rot *= info.rate2[bsf(~(unsigned int)(s))];
}
len++;
} else {
// 4-base
int p = 1 << (h - len - 2);
mint rot = 1, imag = info.root[2];
for (int s = 0; s < (1 << len); s++) {
mint rot2 = rot * rot;
mint rot3 = rot2 * rot;
int offset = s << (h - len);
for (int i = 0; i < p; i++) {
auto mod2 = 1ULL * mint::mod() * mint::mod();
auto a0 = 1ULL * a[i + offset].val();
auto a1 = 1ULL * a[i + offset + p].val() * rot.val();
auto a2 = 1ULL * a[i + offset + 2 * p].val() * rot2.val();
auto a3 = 1ULL * a[i + offset + 3 * p].val() * rot3.val();
auto a1na3imag =
1ULL * mint(a1 + mod2 - a3).val() * imag.val();
auto na2 = mod2 - a2;
a[i + offset] = a0 + a2 + a1 + a3;
a[i + offset + 1 * p] = a0 + a2 + (2 * mod2 - (a1 + a3));
a[i + offset + 2 * p] = a0 + na2 + a1na3imag;
a[i + offset + 3 * p] = a0 + na2 + (mod2 - a1na3imag);
}
if (s + 1 != (1 << len))
rot *= info.rate3[bsf(~(unsigned int)(s))];
}
now *= sum_e[bsf(~(unsigned int)(s))];
len += 2;
}
}
}

template <class mint, internal::is_static_modint_t<mint>* = nullptr>
void butterfly_inv(std::vector<mint>& a) {
static constexpr int g = internal::primitive_root<mint::mod()>;
int n = int(a.size());
int h = internal::ceil_pow2(n);

static bool first = true;
static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i]
if (first) {
first = false;
mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1
int cnt2 = bsf(mint::mod() - 1);
mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv();
for (int i = cnt2; i >= 2; i--) {
// e^(2^i) == 1
es[i - 2] = e;
ies[i - 2] = ie;
e *= e;
ie *= ie;
}
mint now = 1;
for (int i = 0; i <= cnt2 - 2; i++) {
sum_ie[i] = ies[i] * now;
now *= es[i];
}
}
static const fft_info<mint> info;

int len = h; // a[i, i+(n>>len), i+2*(n>>len), ..] is transformed
while (len) {
if (len == 1) {
int p = 1 << (h - len);
mint irot = 1;
for (int s = 0; s < (1 << (len - 1)); s++) {
int offset = s << (h - len + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
a[i + offset + p] =
(unsigned long long)(mint::mod() + l.val() - r.val()) *
irot.val();
;
}
if (s + 1 != (1 << (len - 1)))
irot *= info.irate2[bsf(~(unsigned int)(s))];
}
len--;
} else {
// 4-base
int p = 1 << (h - len);
mint irot = 1, iimag = info.iroot[2];
for (int s = 0; s < (1 << (len - 2)); s++) {
mint irot2 = irot * irot;
mint irot3 = irot2 * irot;
int offset = s << (h - len + 2);
for (int i = 0; i < p; i++) {
auto a0 = 1ULL * a[i + offset + 0 * p].val();
auto a1 = 1ULL * a[i + offset + 1 * p].val();
auto a2 = 1ULL * a[i + offset + 2 * p].val();
auto a3 = 1ULL * a[i + offset + 3 * p].val();

auto a2na3iimag =
1ULL *
mint((mint::mod() + a2 - a3) * iimag.val()).val();

for (int ph = h; ph >= 1; ph--) {
int w = 1 << (ph - 1), p = 1 << (h - ph);
mint inow = 1;
for (int s = 0; s < w; s++) {
int offset = s << (h - ph + 1);
for (int i = 0; i < p; i++) {
auto l = a[i + offset];
auto r = a[i + offset + p];
a[i + offset] = l + r;
a[i + offset + p] =
(unsigned long long)(mint::mod() + l.val() - r.val()) *
inow.val();
a[i + offset] = a0 + a1 + a2 + a3;
a[i + offset + 1 * p] =
(a0 + (mint::mod() - a1) + a2na3iimag) * irot.val();
a[i + offset + 2 * p] =
(a0 + a1 + (mint::mod() - a2) + (mint::mod() - a3)) *
irot2.val();
a[i + offset + 3 * p] =
(a0 + (mint::mod() - a1) + (mint::mod() - a2na3iimag)) *
irot3.val();
}
if (s + 1 != (1 << (len - 2)))
irot *= info.irate3[bsf(~(unsigned int)(s))];
}
inow *= sum_ie[bsf(~(unsigned int)(s))];
len -= 2;
}
}
}

template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution_naive(const std::vector<mint>& a, const std::vector<mint>& b) {
std::vector<mint> convolution_naive(const std::vector<mint>& a,
const std::vector<mint>& b) {
int n = int(a.size()), m = int(b.size());
std::vector<mint> ans(n + m - 1);
if (n < m) {
Expand Down Expand Up @@ -150,7 +224,8 @@ std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
}

template <class mint, internal::is_static_modint_t<mint>* = nullptr>
std::vector<mint> convolution(const std::vector<mint>& a, const std::vector<mint>& b) {
std::vector<mint> convolution(const std::vector<mint>& a,
const std::vector<mint>& b) {
int n = int(a.size()), m = int(b.size());
if (!n || !m) return {};
if (std::min(n, m) <= 60) return convolution_naive(a, b);
Expand Down
8 changes: 8 additions & 0 deletions atcoder/internal_bit.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ int ceil_pow2(int n) {
return x;
}

// @param n `1 <= n`
// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0`
constexpr int bsf_constexpr(unsigned int n) {
int x = 0;
while (!(n & (1 << x))) x++;
return x;
}

// @param n `1 <= n`
// @return minimum non-negative `x` s.t. `(n & (1 << x)) != 0`
int bsf(unsigned int n) {
Expand Down
44 changes: 44 additions & 0 deletions test/unittest/convolution_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,3 +385,47 @@ TEST(ConvolutionTest, Conv18433) {

ASSERT_EQ(conv_naive<MOD>(a, b), convolution<MOD>(a, b));
}

TEST(ConvolutionTest, Conv2) {
std::vector<ll> empty = {};
ASSERT_EQ(empty, convolution<2>(empty, empty));
}

TEST(ConvolutionTest, Conv257) {
const int MOD = 257;
std::vector<ll> a(128), b(129);
for (int i = 0; i < 128; i++) {
a[i] = randint(0, MOD - 1);
}
for (int i = 0; i < 129; i++) {
b[i] = randint(0, MOD - 1);
}

ASSERT_EQ(conv_naive<MOD>(a, b), convolution<MOD>(a, b));
}

TEST(ConvolutionTest, Conv2147483647) {
const int MOD = 2147483647;
using mint = static_modint<MOD>;
std::vector<mint> a(1), b(2);
for (int i = 0; i < 1; i++) {
a[i] = randint(0, MOD - 1);
}
for (int i = 0; i < 2; i++) {
b[i] = randint(0, MOD - 1);
}
ASSERT_EQ(conv_naive(a, b), convolution(a, b));
}

TEST(ConvolutionTest, Conv2130706433) {
const int MOD = 2130706433;
using mint = static_modint<MOD>;
std::vector<mint> a(1024), b(1024);
for (int i = 0; i < 1024; i++) {
a[i] = randint(0, MOD - 1);
}
for (int i = 0; i < 1024; i++) {
b[i] = randint(0, MOD - 1);
}
ASSERT_EQ(conv_naive(a, b), convolution(a, b));
}
30 changes: 30 additions & 0 deletions test/unittest/modint_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,29 @@ TEST(ModintTest, Mod1) {
ASSERT_EQ(0, mint(true).val());
}

TEST(ModintTest, ModIntMax) {
modint::set_mod(INT32_MAX);
for (int i = 0; i < 100; i++) {
for (int j = 0; j < 100; j++) {
ASSERT_EQ((modint(i) * modint(j)).val(), i * j);
}
}
ASSERT_EQ((modint(1234) + modint(5678)).val(), 1234 + 5678);
ASSERT_EQ((modint(1234) - modint(5678)).val(), INT32_MAX - 5678 + 1234);
ASSERT_EQ((modint(1234) * modint(5678)).val(), 1234 * 5678);

using mint = static_modint<INT32_MAX>;
for (int i = 0; i < 100; i++) {
for (int j = 0; j < 100; j++) {
ASSERT_EQ((mint(i) * mint(j)).val(), i * j);
}
}
ASSERT_EQ((mint(1234) + mint(5678)).val(), 1234 + 5678);
ASSERT_EQ((mint(1234) - mint(5678)).val(), INT32_MAX - 5678 + 1234);
ASSERT_EQ((mint(1234) * mint(5678)).val(), 1234 * 5678);
ASSERT_EQ((mint(INT32_MAX) + mint(INT32_MAX)).val(), 0);
}

#ifndef _MSC_VER

TEST(ModintTest, Int128) {
Expand Down Expand Up @@ -158,6 +181,13 @@ TEST(ModintTest, Inv) {
int x = modint(i).inv().val();
ASSERT_EQ(1, (ll(x) * i) % 1'000'000'008);
}

modint::set_mod(INT32_MAX);
for (int i = 1; i < 100000; i++) {
if (gcd(i, INT32_MAX) != 1) continue;
int x = modint(i).inv().val();
ASSERT_EQ(1, (ll(x) * i) % INT32_MAX);
}
}

TEST(ModintTest, ConstUsage) {
Expand Down