// Review notes. These lines sit before the first "diff --git" header and are not part of any hunk.
//
// NOTE(review): in this copy of the patch the line breaks are collapsed and generic type arguments
// look stripped ("Span", "ReadOnlySpan", "ArrayPool.Shared" carry no type argument; the uint element
// accesses and "stackalloc uint[...]" suggest <uint> was meant) - restore from the upstream change
// before applying. TODO confirm.
//
// What the three file diffs do, as far as the hunks show:
//  1. BigIntegerCalculator.FastReducer.cs: FastReducer's constructor now only copies Modulus, Mu, Q1
//     and Q2 out of a new FastReducerConstructorHelper, which performs the r / modulus division that
//     the constructor used to do. Reduce passes _modulus.Length (was _modulus.Length + 1) to SubMod,
//     SubMod asserts left.Length >= k and calls the new OverflowableSubtractSelf instead of SubtractSelf.
//  2. BigIntegerCalculator.PowMod.cs: both call sites build the helper first, return rFromPool, then
//     size q1 as muLength + modulus.Length + 1 and q2 as muLength + modulus.Length (both were
//     modulus.Length * 2 + 2) and hand them over via helper.AddQs before constructing FastReducer.
//  3. tests/BigInteger/modpow.cs: adds the [OuterLoop] test ModPowFastReducerBoundary, which runs
//     VerifyModPowString with "ReducerThreshold" faked to 8.
//
// NOTE(review): Reduce now keeps only _modulus.Length limbs of (v - q2). Textbook Barrett reduction
// keeps one limb more because v - q2 can be as large as 3m - 1. Base-10 analogue to check by hand:
// m = 99, mu = 101, v = 4951 gives q1 = 495 * 101 = 49995, q2 = 49 * 99 = 4851, v - q2 = 100; its low
// two digits are 00 although 4951 mod 99 = 1. Verify whether the 32-bit-limb code can hit this case.
//
// NOTE(review): OverflowableSubtractSelf casts with unchecked((uint)digit) in its first loop but with
// a plain (uint)digit in the carry loop, where digit can be negative; that would only throw if the
// code is compiled in a checked context - confirm the project setting or make the casts consistent.
diff --git a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.FastReducer.cs b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.FastReducer.cs index 0b56d2b52e245..d35f2af339932 100644 --- a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.FastReducer.cs +++ b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.FastReducer.cs @@ -2,6 +2,8 @@ // The .NET Foundation licenses this file to you under the MIT license. using System.Diagnostics; +using System.Runtime.CompilerServices; +using System.Runtime.InteropServices; namespace System.Numerics { @@ -20,25 +22,12 @@ private readonly ref struct FastReducer private readonly Span _q1; private readonly Span _q2; - public FastReducer(ReadOnlySpan modulus, Span r, Span mu, Span q1, Span q2) + public FastReducer(FastReducerConstructorHelper helper) { - Debug.Assert(!modulus.IsEmpty); - Debug.Assert(r.Length == modulus.Length * 2 + 1); - Debug.Assert(mu.Length == r.Length - modulus.Length + 1); - Debug.Assert(q1.Length == modulus.Length * 2 + 2); - Debug.Assert(q2.Length == modulus.Length * 2 + 2); - - // Let r = 4^k, with 2^k > m - r[r.Length - 1] = 1; - - // Let mu = 4^k / m - Divide(r, modulus, mu); - _modulus = modulus; - - _q1 = q1; - _q2 = q2; - - _mu = mu.Slice(0, ActualLength(mu)); + _modulus = helper.Modulus; + _mu = helper.Mu; + _q1 = helper.Q1; + _q2 = helper.Q2; } public int Reduce(Span value) @@ -49,16 +38,17 @@ public int Reduce(Span value) if (value.Length < _modulus.Length) return value.Length; - // Let q1 = v/2^(k-1) * mu + // Let q1 = v/2^(k-32) * mu _q1.Clear(); int l1 = DivMul(value, _mu, _q1, _modulus.Length - 1); - // Let q2 = q1/2^(k+1) * m + // Let q2 = q1/2^(k+32) * m _q2.Clear(); int l2 = DivMul(_q1.Slice(0, l1), _modulus, _q2, _modulus.Length + 1); - // Let v = (v - q2) % 2^(k+1) - i*m - var length = SubMod(value, _q2.Slice(0, l2), _modulus, _modulus.Length + 1); + // Let v = (v - q2) 
% 2^k + // while m <= v: Let v = v - m + var length = SubMod(value, _q2.Slice(0, l2), _modulus, _modulus.Length); value = value.Slice(length); value.Clear(); @@ -75,6 +65,10 @@ private static int DivMul(ReadOnlySpan left, ReadOnlySpan right, Spa // but skips the first k limbs of left, which is equivalent to // preceding division by 2^(32*k). To spare memory allocations // we write the result to an already allocated memory. + // Note that the k used here has a different scale from the k used + // in the description of Barrett reduction. + // The former refers to the number of elements in the array, + // while the latter refers to the number of bits. if (left.Length > k) { @@ -101,17 +95,23 @@ private static int DivMul(ReadOnlySpan left, ReadOnlySpan right, Spa private static int SubMod(Span left, ReadOnlySpan right, ReadOnlySpan modulus, int k) { + Debug.Assert(left.Length >= k); + // Executes the subtraction algorithm for left and right, // but considers only the first k limbs, which is equivalent to // preceding reduction by 2^(32*k). Furthermore, if left is // still greater than modulus, further subtractions are used. + // Note that the k used here has a different scale from the k used + // in the description of Barrett reduction. + // The former refers to the number of elements in the array, + // while the latter refers to the number of bits. if (left.Length > k) left = left.Slice(0, k); if (right.Length > k) right = right.Slice(0, k); - SubtractSelf(left, right); + OverflowableSubtractSelf(left, right); left = left.Slice(0, ActualLength(left)); while (Compare(left, modulus) >= 0) @@ -122,6 +122,79 @@ private static int SubMod(Span left, ReadOnlySpan right, ReadOnlySpa return left.Length; } + + private static void OverflowableSubtractSelf(Span left, ReadOnlySpan right) + { + Debug.Assert(left.Length >= right.Length); + + int i = 0; + long carry = 0L; + + // Switching to managed references helps eliminate + // index bounds checks. 
+ ref uint leftPtr = ref MemoryMarshal.GetReference(left); + + // Executes the "grammar-school" algorithm for computing z = a - b. + // We're writing the result directly to a and + // stop execution, if we're out of b. + + for (; i < right.Length; i++) + { + long digit = (Unsafe.Add(ref leftPtr, i) + carry) - right[i]; + Unsafe.Add(ref leftPtr, i) = unchecked((uint)digit); + carry = digit >> 32; + } + for (; carry != 0 && i < left.Length; i++) + { + long digit = left[i] + carry; + left[i] = (uint)digit; + carry = digit >> 32; + } + } + } + + // Helper for the constructor of FastReducer. + // q1 and q2 are sized from the length of mu, so they can only be + // supplied after mu has been computed; FastReducer cannot take them + // later itself because it is a read-only structure. + private ref struct FastReducerConstructorHelper + { + internal ReadOnlySpan Modulus; + internal ReadOnlySpan Mu; + internal Span Q1; + internal Span Q2; + + public FastReducerConstructorHelper(ReadOnlySpan modulus, Span r, Span mu) + { + Debug.Assert(!modulus.IsEmpty); + Debug.Assert(r.Length == modulus.Length * 2 + 1); + Debug.Assert(mu.Length == r.Length - modulus.Length + 1); + + // Let r = 2^(2k), with 2^k > m and k % 32 = 0 + r[r.Length - 1] = 1; + + // Let mu = r / m + Divide(r, modulus, mu); + Modulus = modulus; + + Mu = mu.Slice(0, ActualLength(mu)); + Q1 = default; + Q2 = default; + } + + public int GetMuLength() + { + return Mu.Length; + } + + public void AddQs(Span q1, Span q2) + { + Debug.Assert(q1.Length == Mu.Length + Modulus.Length + 1); + Debug.Assert(q2.Length == Mu.Length + Modulus.Length); + + Q1 = q1; + Q2 = q2; + } } } } diff --git a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.PowMod.cs b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.PowMod.cs index 0407feff54fe3..ce5667c74d19c 100644 --- a/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.PowMod.cs +++ 
b/src/libraries/System.Runtime.Numerics/src/System/Numerics/BigIntegerCalculator.PowMod.cs @@ -323,23 +323,30 @@ stackalloc uint[StackAllocThreshold] : muFromPool = ArrayPool.Shared.Rent(size)).Slice(0, size); mu.Clear(); - size = modulus.Length * 2 + 2; + FastReducerConstructorHelper helper = new FastReducerConstructorHelper(modulus, r, mu); + + if (rFromPool != null) + ArrayPool.Shared.Return(rFromPool); + + int muLength = helper.GetMuLength(); + + size = muLength + modulus.Length + 1; uint[]? q1FromPool = null; Span q1 = ((uint)size <= StackAllocThreshold ? stackalloc uint[StackAllocThreshold] : q1FromPool = ArrayPool.Shared.Rent(size)).Slice(0, size); q1.Clear(); + size = muLength + modulus.Length; uint[]? q2FromPool = null; Span q2 = ((uint)size <= StackAllocThreshold ? stackalloc uint[StackAllocThreshold] : q2FromPool = ArrayPool.Shared.Rent(size)).Slice(0, size); q2.Clear(); - FastReducer reducer = new FastReducer(modulus, r, mu, q1, q2); + helper.AddQs(q1, q2); - if (rFromPool != null) - ArrayPool.Shared.Return(rFromPool); + FastReducer reducer = new FastReducer(helper); PowCore(value, valueLength, power, reducer, bits, 1, temp).CopyTo(bits); @@ -379,23 +386,30 @@ stackalloc uint[StackAllocThreshold] : muFromPool = ArrayPool.Shared.Rent(size)).Slice(0, size); mu.Clear(); - size = modulus.Length * 2 + 2; + FastReducerConstructorHelper helper = new FastReducerConstructorHelper(modulus, r, mu); + + if (rFromPool != null) + ArrayPool.Shared.Return(rFromPool); + + int muLength = helper.GetMuLength(); + + size = muLength + modulus.Length + 1; uint[]? q1FromPool = null; Span q1 = ((uint)size <= StackAllocThreshold ? stackalloc uint[StackAllocThreshold] : q1FromPool = ArrayPool.Shared.Rent(size)).Slice(0, size); q1.Clear(); + size = muLength + modulus.Length; uint[]? q2FromPool = null; Span q2 = ((uint)size <= StackAllocThreshold ? 
stackalloc uint[StackAllocThreshold] : q2FromPool = ArrayPool.Shared.Rent(size)).Slice(0, size); q2.Clear(); - FastReducer reducer = new FastReducer(modulus, r, mu, q1, q2); + helper.AddQs(q1, q2); - if (rFromPool != null) - ArrayPool.Shared.Return(rFromPool); + FastReducer reducer = new FastReducer(helper); PowCore(value, valueLength, power, reducer, bits, 1, temp).CopyTo(bits); diff --git a/src/libraries/System.Runtime.Numerics/tests/BigInteger/modpow.cs b/src/libraries/System.Runtime.Numerics/tests/BigInteger/modpow.cs index 54af681215170..654204bdeffdd 100644 --- a/src/libraries/System.Runtime.Numerics/tests/BigInteger/modpow.cs +++ b/src/libraries/System.Runtime.Numerics/tests/BigInteger/modpow.cs @@ -286,6 +286,55 @@ public static void ModPowBoundary() VerifyModPowString(Math.Pow(2, 35) + " " + Math.Pow(2, 33) + " 2 tModPow"); } + [Fact] + [OuterLoop] + public static void ModPowFastReducerBoundary() + { + BigIntTools.Utils.RunWithFakeThreshold("ReducerThreshold", 8, () => + { + byte[] tempByteArray1 = new byte[40]; + byte[] tempByteArray2 = new byte[40]; + byte[] tempByteArray3 = new byte[40]; + byte[] tempByteArray4 = new byte[40]; + byte[] tempByteArray5 = new byte[40]; + byte[] tempByteArray6 = new byte[40]; + + for (int i = 0; i < 32; i++) + { + tempByteArray2[i] = 0xff; + } + tempByteArray3[0] = 1; + for (int i = 0; i < 36; i++) + { + tempByteArray4[i] = 0xff; + } + tempByteArray5[36] = 1; + tempByteArray6[0] = 1; + tempByteArray6[36] = 1; + + for (int i = 32; i < 40; i++) + { + for (int j = 0; j < 8; j++) + { + tempByteArray1[i] = (byte)(1 << j); + tempByteArray2[i] |= (byte)(1 << j); + tempByteArray3[i] = (byte)(1 << j); + VerifyModPowString(Print(tempByteArray4) + "2 " + Print(tempByteArray1) + "tModPow"); + VerifyModPowString(Print(tempByteArray5) + "2 " + Print(tempByteArray1) + "tModPow"); + VerifyModPowString(Print(tempByteArray6) + "2 " + Print(tempByteArray1) + "tModPow"); + VerifyModPowString(Print(tempByteArray4) + "2 " + Print(tempByteArray2) 
+ "tModPow"); + VerifyModPowString(Print(tempByteArray5) + "2 " + Print(tempByteArray2) + "tModPow"); + VerifyModPowString(Print(tempByteArray6) + "2 " + Print(tempByteArray2) + "tModPow"); + VerifyModPowString(Print(tempByteArray4) + "2 " + Print(tempByteArray3) + "tModPow"); + VerifyModPowString(Print(tempByteArray5) + "2 " + Print(tempByteArray3) + "tModPow"); + VerifyModPowString(Print(tempByteArray6) + "2 " + Print(tempByteArray3) + "tModPow"); + } + tempByteArray1[i] = 0; + tempByteArray3[i] = 0; + } + }); + } + private static void VerifyModPowString(string opstring) { StackCalc sc = new StackCalc(opstring);