Skip to content

Commit 2c99feb

Browse files
committed
Update convolution
atcoder/ac-library#96
1 parent e8882ed commit 2c99feb

File tree

2 files changed

+59
-58
lines changed

2 files changed

+59
-58
lines changed

Source/AtCoderLibrary/Math/Internal/InternalMath.cs

+2
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Runtime.CompilerServices;
34

45
namespace AtCoder.Internal
56
{
@@ -116,6 +117,7 @@ public static (long, long) InvGCD(long a, long b)
116117
}
117118
}
118119

120+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
119121
public static long SafeMod(long x, long m)
120122
{
121123
x %= m;

Source/AtCoderLibrary/Math/MathLib.Convolution.cs

+57-58
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using System.Diagnostics;
23
using AtCoder.Internal;
34

45
namespace AtCoder
@@ -18,10 +19,7 @@ public static partial class MathLib
1819
/// </remarks>
1920
public static StaticModInt<TMod>[] Convolution<TMod>(StaticModInt<TMod>[] a, StaticModInt<TMod>[] b)
2021
where TMod : struct, IStaticMod
21-
{
22-
var temp = Convolution((ReadOnlySpan<StaticModInt<TMod>>)a, b);
23-
return temp.ToArray();
24-
}
22+
=> Convolution((ReadOnlySpan<StaticModInt<TMod>>)a, b);
2523

2624
/// <summary>
2725
/// 畳み込みを mod <typeparamref name="TMod"/> で計算します。
@@ -34,52 +32,61 @@ public static StaticModInt<TMod>[] Convolution<TMod>(StaticModInt<TMod>[] a, Sta
3432
/// <para>- 2^c | (<typeparamref name="TMod"/> - 1) かつ |<paramref name="a"/>| + |<paramref name="b"/>| - 1 ≤ 2^c なる c が存在する</para>
3533
/// <para>計算量: O((|<paramref name="a"/>|+|<paramref name="b"/>|)log(|<paramref name="a"/>|+|<paramref name="b"/>|) + log<typeparamref name="TMod"/>)</para>
3634
/// </remarks>
37-
public static Span<StaticModInt<TMod>> Convolution<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
35+
public static StaticModInt<TMod>[] Convolution<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
3836
where TMod : struct, IStaticMod
3937
{
4038
var n = a.Length;
4139
var m = b.Length;
4240
if (n == 0 || m == 0)
43-
{
4441
return Array.Empty<StaticModInt<TMod>>();
45-
}
46-
4742
if (Math.Min(n, m) <= 60)
48-
{
4943
return ConvolutionNaive(a, b);
50-
}
51-
52-
int z = 1 << InternalBit.CeilPow2(n + m - 1);
53-
54-
var aTemp = new StaticModInt<TMod>[z];
55-
a.CopyTo(aTemp);
56-
57-
var bTemp = new StaticModInt<TMod>[z];
58-
b.CopyTo(bTemp);
59-
60-
return Convolution(aTemp.AsSpan(), bTemp.AsSpan(), n, m, z);
44+
return ConvolutionFFT(a.ToArray(), b.ToArray());
6145
}
6246

63-
private static Span<StaticModInt<TMod>> Convolution<TMod>(Span<StaticModInt<TMod>> a, Span<StaticModInt<TMod>> b, int n, int m, int z)
47+
private static StaticModInt<TMod>[] ConvolutionFFT<TMod>(StaticModInt<TMod>[] a, StaticModInt<TMod>[] b)
6448
where TMod : struct, IStaticMod
6549
{
50+
int n = a.Length, m = b.Length;
51+
int z = 1 << InternalBit.CeilPow2(n + m - 1);
52+
Array.Resize(ref a, z);
6653
Butterfly<TMod>.Calculate(a);
54+
Array.Resize(ref b, z);
6755
Butterfly<TMod>.Calculate(b);
6856

6957
for (int i = 0; i < a.Length; i++)
70-
{
7158
a[i] *= b[i];
72-
}
7359

7460
Butterfly<TMod>.CalculateInv(a);
75-
var result = a.Slice(0, n + m - 1);
61+
Array.Resize(ref a, n + m - 1);
7662
var iz = new StaticModInt<TMod>(z).Inv();
77-
foreach (ref var r in result)
63+
64+
for (int i = 0; i < a.Length; i++)
65+
a[i] *= iz;
66+
67+
return a;
68+
}
69+
private static StaticModInt<TMod>[] ConvolutionNaive<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
70+
where TMod : struct, IStaticMod
71+
{
72+
if (a.Length < b.Length)
7873
{
79-
r *= iz;
74+
// ref 構造体のため型引数として使えない
75+
var temp = a;
76+
a = b;
77+
b = temp;
8078
}
8179

82-
return result;
80+
var ans = new StaticModInt<TMod>[a.Length + b.Length - 1];
81+
for (int i = 0; i < a.Length; i++)
82+
{
83+
for (int j = 0; j < b.Length; j++)
84+
{
85+
ans[i + j] += a[i] * b[j];
86+
}
87+
}
88+
89+
return ans;
8390
}
8491

8592
/// <summary>
@@ -104,26 +111,30 @@ public static long[] ConvolutionLong(ReadOnlySpan<long> a, ReadOnlySpan<long> b)
104111
return Array.Empty<long>();
105112
}
106113

107-
const ulong Mod1 = 754974721;
108-
const ulong Mod2 = 167772161;
109-
const ulong Mod3 = 469762049;
114+
const ulong Mod1 = 754974721; // 2^24
115+
const ulong Mod2 = 167772161; // 2^25
116+
const ulong Mod3 = 469762049; // 2^26
110117
const ulong M2M3 = Mod2 * Mod3;
111118
const ulong M1M3 = Mod1 * Mod3;
112119
const ulong M1M2 = Mod1 * Mod2;
113120
// (m1 * m2 * m3) % 2^64
114121
const ulong M1M2M3 = Mod1 * Mod2 * Mod3;
115122

116-
ulong i1 = (ulong)InternalMath.InvGCD((long)M2M3, (long)Mod1).Item2;
117-
ulong i2 = (ulong)InternalMath.InvGCD((long)M1M3, (long)Mod2).Item2;
118-
ulong i3 = (ulong)InternalMath.InvGCD((long)M1M2, (long)Mod3).Item2;
123+
const ulong i1 = 190329765;
124+
const ulong i2 = 58587104;
125+
const ulong i3 = 187290749;
126+
127+
Debug.Assert(i1 == (ulong)InternalMath.InvGCD((long)M2M3, (long)Mod1).Item2);
128+
Debug.Assert(i2 == (ulong)InternalMath.InvGCD((long)M1M3, (long)Mod2).Item2);
129+
Debug.Assert(i3 == (ulong)InternalMath.InvGCD((long)M1M2, (long)Mod3).Item2);
119130

120131
var c1 = Convolution<FFTMod1>(a, b);
121132
var c2 = Convolution<FFTMod2>(a, b);
122133
var c3 = Convolution<FFTMod3>(a, b);
123134

124135
var c = new long[n + m - 1];
125136

126-
Span<ulong> offset = stackalloc ulong[] { 0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3 };
137+
//ReadOnlySpan<ulong> offset = stackalloc ulong[] { 0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3 };
127138

128139
for (int i = 0; i < c.Length; i++)
129140
{
@@ -145,37 +156,25 @@ public static long[] ConvolutionLong(ReadOnlySpan<long> a, ReadOnlySpan<long> b)
145156
// x - 3M' + (0 or 2B or 4B or 6B)
146157
// のいずれかが成り立つ、らしい
147158
// -> see atcoder/convolution.hpp
148-
x -= offset[(int)(diff % offset.Length)];
159+
switch (diff % 5)
160+
{
161+
case 2:
162+
x -= M1M2M3;
163+
break;
164+
case 3:
165+
x -= 2 * M1M2M3;
166+
break;
167+
case 4:
168+
x -= 3 * M1M2M3;
169+
break;
170+
}
149171
c[i] = (long)x;
150172
}
151173

152174
return c;
153175
}
154176
}
155177

156-
private static StaticModInt<TMod>[] ConvolutionNaive<TMod>(ReadOnlySpan<StaticModInt<TMod>> a, ReadOnlySpan<StaticModInt<TMod>> b)
157-
where TMod : struct, IStaticMod
158-
{
159-
if (a.Length < b.Length)
160-
{
161-
// ref 構造体のため型引数として使えない
162-
var temp = a;
163-
a = b;
164-
b = temp;
165-
}
166-
167-
var ans = new StaticModInt<TMod>[a.Length + b.Length - 1];
168-
for (int i = 0; i < a.Length; i++)
169-
{
170-
for (int j = 0; j < b.Length; j++)
171-
{
172-
ans[i + j] += a[i] * b[j];
173-
}
174-
}
175-
176-
return ans;
177-
}
178-
179178
private readonly struct FFTMod1 : IStaticMod
180179
{
181180
public uint Mod => 754974721;

0 commit comments

Comments
 (0)