diff --git a/atcoder/convolution.hpp b/atcoder/convolution.hpp
index 7b27f81..ecfbc44 100644
--- a/atcoder/convolution.hpp
+++ b/atcoder/convolution.hpp
@@ -199,7 +199,6 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
 std::vector<mint> convolution_fft(std::vector<mint> a, std::vector<mint> b) {
     int n = int(a.size()), m = int(b.size());
     int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
-    assert(mint::mod() % z == 1);
     a.resize(z);
     internal::butterfly(a);
     b.resize(z);
@@ -220,15 +219,22 @@ template <class mint, internal::is_static_modint_t<mint>* = nullptr>
 std::vector<mint> convolution(std::vector<mint>&& a, std::vector<mint>&& b) {
     int n = int(a.size()), m = int(b.size());
     if (!n || !m) return {};
+
+    int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
+    assert((mint::mod() - 1) % z == 0);
+
     if (std::min(n, m) <= 60) return convolution_naive(a, b);
     return internal::convolution_fft(a, b);
 }
-
 template <class mint, internal::is_static_modint_t<mint>* = nullptr>
 std::vector<mint> convolution(const std::vector<mint>& a,
                               const std::vector<mint>& b) {
     int n = int(a.size()), m = int(b.size());
     if (!n || !m) return {};
+
+    int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
+    assert((mint::mod() - 1) % z == 0);
+
     if (std::min(n, m) <= 60) return convolution_naive(a, b);
     return internal::convolution_fft(a, b);
 }
@@ -241,6 +247,10 @@ std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) {
     if (!n || !m) return {};
 
     using mint = static_modint<mod>;
+
+    int z = (int)internal::bit_ceil((unsigned int)(n + m - 1));
+    assert((mint::mod() - 1) % z == 0);
+
     std::vector<mint> a2(n), b2(m);
     for (int i = 0; i < n; i++) {
         a2[i] = mint(a[i]);
@@ -280,7 +290,7 @@ std::vector<long long> convolution_ll(const std::vector<long long>& a,
     static_assert(MOD1 % (1ull << MAX_AB_BIT) == 1, "MOD1 isn't enough to support an array length of 2^24.");
     static_assert(MOD2 % (1ull << MAX_AB_BIT) == 1, "MOD2 isn't enough to support an array length of 2^24.");
     static_assert(MOD3 % (1ull << MAX_AB_BIT) == 1, "MOD3 isn't enough to support an array length of 2^24.");
-    assert(a.size() + b.size() - 1 <= (1ull << MAX_AB_BIT));
+    assert(n + m - 1 <= (1 << MAX_AB_BIT));
 
     auto c1 = convolution<MOD1>(a, b);
     auto c2 = convolution<MOD2>(a, b);