Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a Math.inv function that inverse a number in Z/nZ #4839

Merged
merged 28 commits into from
Jan 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/cool-mangos-compare.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': minor
---

`Math`: add an `invMod` function to get the modular multiplicative inverse of a number in Z/nZ.
65 changes: 61 additions & 4 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,10 @@ library Math {
}

/**
* @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
* @dev Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
* denominator == 0.
* @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
*
* Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
* Uniswap Labs also under MIT license.
*/
function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
Expand Down Expand Up @@ -208,7 +209,7 @@ library Math {
}

/**
* @notice Calculates x * y / denominator with full precision, following the selected rounding direction.
* @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
*/
function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
uint256 result = mulDiv(x, y, denominator);
Expand All @@ -218,6 +219,62 @@ library Math {
return result;
}

/**
* @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
*
* If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0.
* If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
*
* If the input value is not inversible, 0 is returned.
*/
function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
unchecked {
if (n == 0) return 0;

// The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version)
// Used to compute integers x and y such that: ax + ny = gcd(a, n).
// When the gcd is 1, then the inverse of a modulo n exists and it's x.
// ax + ny = 1
// ax = 1 + (-y)n
// ax ≡ 1 (mod n) # x is the inverse of a modulo n

// If the remainder is 0 the gcd is n right away.
uint256 remainder = a % n;
uint256 gcd = n;

// Therefore the initial coefficients are:
// ax + ny = gcd(a, n) = n
// 0a + 1n = n
int256 x = 0;
int256 y = 1;

while (remainder != 0) {
uint256 quotient = gcd / remainder;

(gcd, remainder) = (
// The old remainder is the next gcd to try.
remainder,
// Compute the next remainder.
// Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd
// where gcd is at most n (capped to type(uint256).max)
gcd - remainder * quotient
);

(x, y) = (
// Increment the coefficient of a.
y,
// Decrement the coefficient of n.
// Can overflow, but the result is casted to uint256 so that the
// next value of y is "wrapped around" to a value between 0 and n - 1.
x - y * int256(quotient)
);
}

if (gcd != 1) return 0; // No inverse exists.
return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative.
}
}

/**
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
* towards zero.
Expand Down Expand Up @@ -258,7 +315,7 @@ library Math {
}

/**
* @notice Calculates sqrt(a), following the selected rounding direction.
* @dev Calculates sqrt(a), following the selected rounding direction.
*/
function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
unchecked {
Expand Down
35 changes: 35 additions & 0 deletions test/utils/math/Math.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,41 @@ contract MathTest is Test {
return value * value < ref;
}

// INV
function testInvMod(uint256 value, uint256 p) public {
_testInvMod(value, p, true);
}

function testInvMod2(uint256 seed) public {
uint256 p = 2; // prime
_testInvMod(bound(seed, 1, p - 1), p, false);
}

function testInvMod17(uint256 seed) public {
uint256 p = 17; // prime
_testInvMod(bound(seed, 1, p - 1), p, false);
}

function testInvMod65537(uint256 seed) public {
uint256 p = 65537; // prime
_testInvMod(bound(seed, 1, p - 1), p, false);
}

function testInvModP256(uint256 seed) public {
uint256 p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff; // prime
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow now I get why the rumors of a backdoor in secp256r1, this is a weird number

_testInvMod(bound(seed, 1, p - 1), p, false);
}

function _testInvMod(uint256 value, uint256 p, bool allowZero) private {
uint256 inverse = Math.invMod(value, p);
if (inverse != 0) {
assertEq(mulmod(value, inverse, p), 1);
assertLt(inverse, p);
} else {
assertTrue(allowZero);
}
}

// LOG2
function testLog2(uint256 input, uint8 r) public {
Math.Rounding rounding = _asRounding(r);
Expand Down
38 changes: 38 additions & 0 deletions test/utils/math/Math.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');

const { Rounding } = require('../../helpers/enums');
const { min, max } = require('../../helpers/math');
const { randomArray, generators } = require('../../helpers/random');

const RoundingDown = [Rounding.Floor, Rounding.Trunc];
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
Expand Down Expand Up @@ -298,6 +299,43 @@ describe('Math', function () {
});
});

describe('invMod', function () {
for (const factors of [
[0n],
[1n],
[2n],
[17n],
[65537n],
[0xffffffff00000001000000000000000000000000ffffffffffffffffffffffffn],
[3n, 5n],
[3n, 7n],
[47n, 53n],
]) {
const p = factors.reduce((acc, f) => acc * f, 1n);

describe(`using p=${p} which is ${p > 1 && factors.length > 1 ? 'not ' : ''}a prime`, function () {
it('trying to inverse 0 returns 0', async function () {
expect(await this.mock.$invMod(0, p)).to.equal(0n);
expect(await this.mock.$invMod(p, p)).to.equal(0n); // p is 0 mod p
});

if (p != 0) {
for (const value of randomArray(generators.uint256, 16)) {
const isInversible = factors.every(f => value % f);
it(`trying to inverse ${value}`, async function () {
const result = await this.mock.$invMod(value, p);
if (isInversible) {
expect((value * result) % p).to.equal(1n);
} else {
expect(result).to.equal(0n);
}
});
}
}
});
}
});

describe('sqrt', function () {
it('rounds down', async function () {
for (const rounding of RoundingDown) {
Expand Down