Skip to content

Commit

Permalink
Improve performance of BigInteger.Multiply(large, small)
Browse files Browse the repository at this point in the history
  • Loading branch information
kzrnm committed Sep 18, 2023
1 parent 353d5ea commit 68e9567
Show file tree
Hide file tree
Showing 3 changed files with 143 additions and 55 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static void Multiply(ReadOnlySpan<uint> left, uint right, Span<uint> bits
int i = 0;
ulong carry = 0UL;

for ( ; i < left.Length; i++)
for (; i < left.Length; i++)
{
ulong digits = (ulong)left[i] * right + carry;
bits[i] = unchecked((uint)digits);
Expand All @@ -151,9 +151,9 @@ public static void Multiply(ReadOnlySpan<uint> left, uint right, Span<uint> bits

#if DEBUG
// Mutable for unit testing...
private static
internal static
#else
private const
internal const
#endif
int MultiplyThreshold = 32;

Expand Down Expand Up @@ -211,70 +211,115 @@ public static void Multiply(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, S
// Say we want to compute z = a * b ...

// ... we need to determine our new length (just the half)
int n = right.Length >> 1;
int n2 = n << 1;

// ... split left like a = (a_1 << n) + a_0
ReadOnlySpan<uint> leftLow = left.Slice(0, n);
ReadOnlySpan<uint> leftHigh = left.Slice(n);
int n = (left.Length + 1) >> 1;
if (right.Length <= n)
{
// ... split left like a = (a_1 << n) + a_0
ReadOnlySpan<uint> leftLow = left.Slice(0, n);
ReadOnlySpan<uint> leftHigh = left.Slice(n);

// ... split right like b = (b_1 << n) + b_0
ReadOnlySpan<uint> rightLow = right.Slice(0, n);
ReadOnlySpan<uint> rightHigh = right.Slice(n);
// ... prepare our result array (to reuse its memory)
Span<uint> bitsLow = bits.Slice(0, n + right.Length);
Span<uint> bitsHigh = bits.Slice(n);

// ... prepare our result array (to reuse its memory)
Span<uint> bitsLow = bits.Slice(0, n2);
Span<uint> bitsHigh = bits.Slice(n2);
int carryLength = right.Length;
uint[]? carryFromPool = null;
Span<uint> carry = ((uint)carryLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: carryFromPool = ArrayPool<uint>.Shared.Rent(carryLength)).Slice(0, carryLength);

// ... compute z_0 = a_0 * b_0 (multiply again)
Multiply(leftLow, rightLow, bitsLow);
// ... compute low
Multiply(leftLow, right, bitsLow);
Span<uint> carryOrig = bits.Slice(n, right.Length);
carryOrig.CopyTo(carry);
carryOrig.Clear();

// ... compute z_2 = a_1 * b_1 (multiply again)
Multiply(leftHigh, rightHigh, bitsHigh);

int leftFoldLength = leftHigh.Length + 1;
uint[]? leftFoldFromPool = null;
Span<uint> leftFold = ((uint)leftFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: leftFoldFromPool = ArrayPool<uint>.Shared.Rent(leftFoldLength)).Slice(0, leftFoldLength);
leftFold.Clear();
// ... compute high
if (leftHigh.Length < right.Length)
{
Debug.Assert(right.Length == n);
Debug.Assert(left.Length == 2 * n - 1);
Debug.Assert(leftHigh.Length == n - 1);
Multiply(right, leftHigh, bitsHigh);
}
else
{
Multiply(leftHigh, right, bitsHigh);
}

int rightFoldLength = rightHigh.Length + 1;
uint[]? rightFoldFromPool = null;
Span<uint> rightFold = ((uint)rightFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: rightFoldFromPool = ArrayPool<uint>.Shared.Rent(rightFoldLength)).Slice(0, rightFoldLength);
rightFold.Clear();
AddSelf(bitsHigh, carry);

int coreLength = leftFoldLength + rightFoldLength;
uint[]? coreFromPool = null;
Span<uint> core = ((uint)coreLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: coreFromPool = ArrayPool<uint>.Shared.Rent(coreLength)).Slice(0, coreLength);
core.Clear();
if (carryFromPool != null)
ArrayPool<uint>.Shared.Return(carryFromPool);
}
else
{
int n2 = n << 1;

// ... split left like a = (a_1 << n) + a_0
ReadOnlySpan<uint> leftLow = left.Slice(0, n);
ReadOnlySpan<uint> leftHigh = left.Slice(n);

// ... split right like b = (b_1 << n) + b_0
ReadOnlySpan<uint> rightLow = right.Slice(0, n);
ReadOnlySpan<uint> rightHigh = right.Slice(n);

// ... prepare our result array (to reuse its memory)
Span<uint> bitsLow = bits.Slice(0, n2);
Span<uint> bitsHigh = bits.Slice(n2);

// ... compute z_0 = a_0 * b_0 (multiply again)
Multiply(leftLow, rightLow, bitsLow);

// ... compute z_2 = a_1 * b_1 (multiply again)
Multiply(leftHigh, rightHigh, bitsHigh);

int leftFoldLength = n + 1;
uint[]? leftFoldFromPool = null;
Span<uint> leftFold = ((uint)leftFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: leftFoldFromPool = ArrayPool<uint>.Shared.Rent(leftFoldLength)).Slice(0, leftFoldLength);
leftFold.Clear();

int rightFoldLength = n + 1;
uint[]? rightFoldFromPool = null;
Span<uint> rightFold = ((uint)rightFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: rightFoldFromPool = ArrayPool<uint>.Shared.Rent(rightFoldLength)).Slice(0, rightFoldLength);
rightFold.Clear();

int coreLength = leftFoldLength + rightFoldLength;
uint[]? coreFromPool = null;
Span<uint> core = ((uint)coreLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: coreFromPool = ArrayPool<uint>.Shared.Rent(coreLength)).Slice(0, coreLength);
core.Clear();

// ... compute z_a = a_1 + a_0 (call it fold...)
Add(leftHigh, leftLow, leftFold);
// ... compute z_a = a_1 + a_0 (call it fold...)
Add(leftLow, leftHigh, leftFold);

// ... compute z_b = b_1 + b_0 (call it fold...)
Add(rightHigh, rightLow, rightFold);
// ... compute z_b = b_1 + b_0 (call it fold...)
Add(rightLow, rightHigh, rightFold);

// ... compute z_1 = z_a * z_b - z_0 - z_2
Multiply(leftFold, rightFold, core);
// ... compute z_1 = z_a * z_b - z_0 - z_2
Multiply(leftFold, rightFold, core);

if (leftFoldFromPool != null)
ArrayPool<uint>.Shared.Return(leftFoldFromPool);
if (leftFoldFromPool != null)
ArrayPool<uint>.Shared.Return(leftFoldFromPool);

if (rightFoldFromPool != null)
ArrayPool<uint>.Shared.Return(rightFoldFromPool);
if (rightFoldFromPool != null)
ArrayPool<uint>.Shared.Return(rightFoldFromPool);

SubtractCore(bitsHigh, bitsLow, core);
SubtractCore(bitsLow, bitsHigh, core);

// ... and finally merge the result! :-)
AddSelf(bits.Slice(n), core);
// ... and finally merge the result! :-)
Debug.Assert(bits.Slice(n).Length >= ActualLength(core));
AddSelf(bits.Slice(n), core.Slice(0, ActualLength(core)));

if (coreFromPool != null)
ArrayPool<uint>.Shared.Return(coreFromPool);
if (coreFromPool != null)
ArrayPool<uint>.Shared.Return(coreFromPool);
}
}
}

Expand All @@ -298,21 +343,21 @@ private static void SubtractCore(ReadOnlySpan<uint> left, ReadOnlySpan<uint> rig
ref uint leftPtr = ref MemoryMarshal.GetReference(left);
ref uint corePtr = ref MemoryMarshal.GetReference(core);

for ( ; i < right.Length; i++)
for (; i < right.Length; i++)
{
long digit = (Unsafe.Add(ref corePtr, i) + carry) - Unsafe.Add(ref leftPtr, i) - right[i];
Unsafe.Add(ref corePtr, i) = unchecked((uint)digit);
carry = digit >> 32;
}

for ( ; i < left.Length; i++)
for (; i < left.Length; i++)
{
long digit = (Unsafe.Add(ref corePtr, i) + carry) - left[i];
Unsafe.Add(ref corePtr, i) = unchecked((uint)digit);
carry = digit >> 32;
}

for ( ; carry != 0 && i < core.Length; i++)
for (; carry != 0 && i < core.Length; i++)
{
long digit = core[i] + carry;
core[i] = (uint)digit;
Expand Down
22 changes: 22 additions & 0 deletions src/libraries/System.Runtime.Numerics/tests/BigInteger/multiply.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,28 @@ public static void RunMultiply_Boundary()
VerifyMultiplyString(Math.Pow(2, 33) + " 2 bMultiply");
}

[Fact]
public static void RunMultiplyKaratsubaBoundary()
{
    // Multiplies a short operand by a long operand, sweeping both requested sizes
    // across a small window around values derived from MultiplyThreshold.
    // NOTE(review): the value is passed straight to GetRandomByteArray as a size;
    // confirm its unit matches the unit MultiplyThreshold is compared in.
    Random rng = new Random(s_seed);

    for (int sample = 0; sample < s_samples; sample++)
    {
        for (int shortDelta = -2; shortDelta <= 2; shortDelta++)
        {
            // The short operand is drawn once and reused for every long operand below.
            byte[] shortOperand = GetRandomByteArray(rng, BigIntegerCalculator.MultiplyThreshold + shortDelta);

            for (int longDelta = -4; longDelta <= 4; longDelta++)
            {
                // The long operand is requested at roughly twice the short one's size.
                byte[] longOperand = GetRandomByteArray(rng, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + longDelta);

                VerifyMultiplyString(Print(shortOperand) + Print(longOperand) + "bMultiply");
            }
        }
    }
}

[Fact]
public static void RunMultiply_OnePositiveOneNegative()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ public static void RunMultiplyBoundary()
VerifyMultiplyString(Math.Pow(2, 33) + " 2 b*");
}

[Fact]
public static void RunMultiplyKaratsubaBoundary()
{
    // Multiplies a short operand by a long operand through the "b*" operation,
    // sweeping both requested sizes across a small window around values derived
    // from MultiplyThreshold. Random data comes from the shared s_random instance.
    // NOTE(review): the value is passed straight to GetRandomByteArray as a size;
    // confirm its unit matches the unit MultiplyThreshold is compared in.
    for (int sample = 0; sample < s_samples; sample++)
    {
        for (int shortDelta = -2; shortDelta <= 2; shortDelta++)
        {
            // The short operand is drawn once and reused for every long operand below.
            byte[] shortOperand = GetRandomByteArray(s_random, BigIntegerCalculator.MultiplyThreshold + shortDelta);

            for (int longDelta = -4; longDelta <= 4; longDelta++)
            {
                // The long operand is requested at roughly twice the short one's size.
                byte[] longOperand = GetRandomByteArray(s_random, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + longDelta);

                VerifyMultiplyString(Print(shortOperand) + Print(longOperand) + "b*");
            }
        }
    }
}

[Fact]
public static void RunMultiplyTests()
{
Expand Down

0 comments on commit 68e9567

Please sign in to comment.