Skip to content

Commit d8ca7f2

Browse files
authored
Merge pull request #163 from yosupo06/patch/issue149
fix #149: improve barret algorithm
2 parents e785647 + ffb7aa8 commit d8ca7f2

File tree

2 files changed

+26
-5
lines changed

2 files changed

+26
-5
lines changed

atcoder/internal_math.hpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ struct barrett {
2626
unsigned int _m;
2727
unsigned long long im;
2828

29-
// @param m `1 <= m < 2^31`
29+
// @param m `1 <= m`
3030
explicit barrett(unsigned int m) : _m(m), im((unsigned long long)(-1) / m + 1) {}
3131

3232
// @return m
@@ -55,9 +55,8 @@ struct barrett {
5555
unsigned long long x =
5656
(unsigned long long)(((unsigned __int128)(z)*im) >> 64);
5757
#endif
58-
unsigned int v = (unsigned int)(z - x * _m);
59-
if (_m <= v) v += _m;
60-
return v;
58+
unsigned long long y = x * _m;
59+
return (unsigned int)(z - y + (z < y ? _m : 0));
6160
}
6261
};
6362

test/unittest/internal_math_test.cpp

+23-1
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ TEST(InternalMathTest, Barrett) {
5656
ASSERT_EQ(0, bt.mul(0, 0));
5757
}
5858

59-
TEST(InternalMathTest, BarrettBorder) {
59+
TEST(InternalMathTest, BarrettIntBorder) {
6060
const int mod_upper = std::numeric_limits<int>::max();
6161
for (unsigned int mod = mod_upper; mod >= mod_upper - 20; mod--) {
6262
internal::barrett bt(mod);
@@ -78,6 +78,28 @@ TEST(InternalMathTest, BarrettBorder) {
7878
}
7979
}
8080

81+
TEST(InternalMathTest, BarrettUintBorder) {
82+
const unsigned int mod_upper = std::numeric_limits<unsigned int>::max();
83+
for (unsigned int mod = mod_upper; mod >= mod_upper - 20; mod--) {
84+
internal::barrett bt(mod);
85+
std::vector<unsigned int> v;
86+
for (int i = 0; i < 10; i++) {
87+
v.push_back(i);
88+
v.push_back(mod - i);
89+
v.push_back(mod / 2 + i);
90+
v.push_back(mod / 2 - i);
91+
}
92+
for (auto a : v) {
93+
ull a2 = a;
94+
ASSERT_EQ(((a2 * a2) % mod * a2) % mod, bt.mul(a, bt.mul(a, a)));
95+
for (auto b : v) {
96+
ull b2 = b;
97+
ASSERT_EQ((a2 * b2) % mod, bt.mul(a, b));
98+
}
99+
}
100+
}
101+
}
102+
81103
TEST(InternalMathTest, IsPrime) {
82104
ASSERT_FALSE(internal::is_prime<121>);
83105
ASSERT_FALSE(internal::is_prime<11 * 13>);

0 commit comments

Comments
 (0)