From ffb7aa847fe142c1eb346c5083194d7b8f53c029 Mon Sep 17 00:00:00 2001 From: Kohei Morita Date: Tue, 11 Apr 2023 23:57:28 +0900 Subject: [PATCH] fix #149: improve barret algorithm --- atcoder/internal_math.hpp | 7 +++---- test/unittest/internal_math_test.cpp | 24 +++++++++++++++++++++++- 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/atcoder/internal_math.hpp b/atcoder/internal_math.hpp index ed3ee42..0df4f37 100644 --- a/atcoder/internal_math.hpp +++ b/atcoder/internal_math.hpp @@ -26,7 +26,7 @@ struct barrett { unsigned int _m; unsigned long long im; - // @param m `1 <= m < 2^31` + // @param m `1 <= m` explicit barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {} // @return m @@ -55,9 +55,8 @@ struct barrett { unsigned long long x = (unsigned long long)(((unsigned __int128)(z)*im) >> 64); #endif - unsigned int v = (unsigned int)(z - x * _m); - if (_m <= v) v += _m; - return v; + unsigned long long y = x * _m; + return (unsigned int)(z - y + (z < y ? _m : 0)); } }; diff --git a/test/unittest/internal_math_test.cpp b/test/unittest/internal_math_test.cpp index 7e26406..a39289c 100644 --- a/test/unittest/internal_math_test.cpp +++ b/test/unittest/internal_math_test.cpp @@ -56,7 +56,7 @@ TEST(InternalMathTest, Barrett) { ASSERT_EQ(0, bt.mul(0, 0)); } -TEST(InternalMathTest, BarrettBorder) { +TEST(InternalMathTest, BarrettIntBorder) { const int mod_upper = std::numeric_limits::max(); for (unsigned int mod = mod_upper; mod >= mod_upper - 20; mod--) { internal::barrett bt(mod); @@ -78,6 +78,28 @@ TEST(InternalMathTest, BarrettBorder) { } } +TEST(InternalMathTest, BarrettUintBorder) { + const unsigned int mod_upper = std::numeric_limits::max(); + for (unsigned int mod = mod_upper; mod >= mod_upper - 20; mod--) { + internal::barrett bt(mod); + std::vector v; + for (int i = 0; i < 10; i++) { + v.push_back(i); + v.push_back(mod - i); + v.push_back(mod / 2 + i); + v.push_back(mod / 2 - i); + } + for (auto a : v) { + ull a2 = a; + ASSERT_EQ(((a2 * a2) % mod * a2) % mod, bt.mul(a, bt.mul(a, a))); + for (auto b : v) { + ull b2 = b; + ASSERT_EQ((a2 * b2) % mod, bt.mul(a, b)); + } + } + } +} + TEST(InternalMathTest, IsPrime) { ASSERT_FALSE(internal::is_prime<121>); ASSERT_FALSE(internal::is_prime<11 * 13>);