Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[resubmit] Fix bug of FastReducer used in BigInteger.ModPow #55122

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,19 @@ public FastReducer(uint[] modulus)
{
Debug.Assert(modulus != null);

// Let r = 4^k, with 2^k > m
// Let r = (2^32)^(2k), with (2^32)^k > m
Copy link
Member

@tannergooding tannergooding Oct 4, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you explain why base is 2^32, which I'm interpreting as 4294967296?

Copy link
Contributor Author

@key-moon key-moon Oct 7, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The formula in the comment changed because the following formula is not correct. It should be v/2^(k-32) * mu.

// Let q1 = v/2^(k-1) * mu
int l1 = DivMul(value, length, _mu, _muLength,
_q1, _modulus.Length - 1);

Instead of change (k-1) to (k-32), I choose to change base to 2^32 from 2 because the base of the number notation in this code is 2^32. If you write as 2^(k-32), I think it is more harder to understand because the reader doesn't sure where 32 came from.

But I can understand why you prefer base 2. So, what about //Let r = 2^(32*2*k), with 2^(32*2*k) > m instead of 2^(2k) or (2^32)^(2k)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to keep this inline with the paper and just use 2^(k-32). We can leave a comment that its 32 because we operate on 32-bits at a time

uint[] r = new uint[modulus.Length * 2 + 1];
r[r.Length - 1] = 1;

// Let mu = 4^k / m
// Let mu = r / m
_mu = Divide(r, modulus);
_modulus = modulus;

// Allocate memory for quotients once
_q1 = new uint[modulus.Length * 2 + 2];
_q2 = new uint[modulus.Length * 2 + 1];

_muLength = ActualLength(_mu);

// Allocate memory for quotients once
_q1 = new uint[_muLength + modulus.Length + 1];
_q2 = new uint[_muLength + modulus.Length];
}

public int Reduce(uint[] value, int length)
Expand All @@ -52,17 +52,18 @@ public int Reduce(uint[] value, int length)
if (length < _modulus.Length)
return length;

// Let q1 = v/2^(k-1) * mu
// Let q1 = v/(2^32)^(k-1) * mu
int l1 = DivMul(value, length, _mu, _muLength,
_q1, _modulus.Length - 1);

// Let q2 = q1/2^(k+1) * m
// Let q2 = q1/(2^32)^(k+1) * m
int l2 = DivMul(_q1, l1, _modulus, _modulus.Length,
_q2, _modulus.Length + 1);

// Let v = (v - q2) % 2^(k+1) - i*m
// Let v = (v - q2) % (2^32)^k
// while m <= v: Let v = v - m
return SubMod(value, length, _q2, l2,
_modulus, _modulus.Length + 1);
_modulus, _modulus.Length);
}

private static unsafe int DivMul(uint[] left, int leftLength,
Expand Down Expand Up @@ -130,7 +131,7 @@ private static unsafe int SubMod(uint[] left, int leftLength,

fixed (uint* l = left, r = right, m = modulus)
{
SubtractSelf(l, leftLength, r, rightLength);
OverflowableSubtractSelf(l, leftLength, r, rightLength);
leftLength = ActualLength(left, leftLength);

while (Compare(l, leftLength, m, modulus.Length) >= 0)
Expand All @@ -144,6 +145,34 @@ private static unsafe int SubMod(uint[] left, int leftLength,

return leftLength;
}

private static unsafe void OverflowableSubtractSelf(uint* left, int leftLength,
uint* right, int rightLength)
{
Debug.Assert(leftLength >= 0);
Debug.Assert(rightLength >= 0);
Debug.Assert(leftLength >= rightLength);

// Executes the "grammar-school" algorithm for computing z = a - b.
// We're writing the result directly to a and
// stop execution, if we're out of b.

int i = 0;
long carry = 0L;

for (; i < rightLength; i++)
{
long digit = (left[i] + carry) - right[i];
left[i] = unchecked((uint)digit);
carry = digit >> 32;
}
for (; carry != 0 && i < leftLength; i++)
{
long digit = left[i] + carry;
left[i] = (uint)digit;
carry = digit >> 32;
}
}
}
}
}
49 changes: 49 additions & 0 deletions src/libraries/System.Runtime.Numerics/tests/BigInteger/modpow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,55 @@ public static void ModPowBoundary()
VerifyModPowString(Math.Pow(2, 35) + " " + Math.Pow(2, 33) + " 2 tModPow");
}

[Fact]
[OuterLoop]
public static void ModPowFastReducerBoundary()
{
BigIntTools.Utils.RunWithFakeThreshold("ReducerThreshold", 8, () =>
{
byte[] tempByteArray1 = new byte[40];
byte[] tempByteArray2 = new byte[40];
byte[] tempByteArray3 = new byte[40];
byte[] tempByteArray4 = new byte[40];
byte[] tempByteArray5 = new byte[40];
byte[] tempByteArray6 = new byte[40];

for (int i = 0; i < 32; i++)
{
tempByteArray2[i] = 0xff;
}
tempByteArray3[0] = 1;
for (int i = 0; i < 36; i++)
{
tempByteArray4[i] = 0xff;
}
tempByteArray5[36] = 1;
tempByteArray6[0] = 1;
tempByteArray6[36] = 1;

for (int i = 32; i < 40; i++)
{
for (int j = 0; j < 8; j++)
{
tempByteArray1[i] = (byte)(1 << j);
tempByteArray2[i] |= (byte)(1 << j);
tempByteArray3[i] = (byte)(1 << j);
VerifyModPowString(Print(tempByteArray4) + "2 " + Print(tempByteArray1) + "tModPow");
VerifyModPowString(Print(tempByteArray5) + "2 " + Print(tempByteArray1) + "tModPow");
VerifyModPowString(Print(tempByteArray6) + "2 " + Print(tempByteArray1) + "tModPow");
VerifyModPowString(Print(tempByteArray4) + "2 " + Print(tempByteArray2) + "tModPow");
VerifyModPowString(Print(tempByteArray5) + "2 " + Print(tempByteArray2) + "tModPow");
VerifyModPowString(Print(tempByteArray6) + "2 " + Print(tempByteArray2) + "tModPow");
VerifyModPowString(Print(tempByteArray4) + "2 " + Print(tempByteArray3) + "tModPow");
VerifyModPowString(Print(tempByteArray5) + "2 " + Print(tempByteArray3) + "tModPow");
VerifyModPowString(Print(tempByteArray6) + "2 " + Print(tempByteArray3) + "tModPow");
}
tempByteArray1[i] = 0;
tempByteArray3[i] = 0;
}
});
}

private static void VerifyModPowString(string opstring)
{
StackCalc sc = new StackCalc(opstring);
Expand Down