From 6709a08f448ddef0b360dd4648ee7bb4f97d1d44 Mon Sep 17 00:00:00 2001 From: cairo Date: Mon, 30 Sep 2024 09:05:44 -0700 Subject: [PATCH] Fix P256 corner cases (#5218) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Hadrien Croubois Co-authored-by: Ernesto García Signed-off-by: Hadrien Croubois --- .solcover.js | 8 +++ contracts/utils/cryptography/P256.sol | 94 ++++++++++++++++++++------- test/helpers/iterate.js | 4 +- test/utils/cryptography/P256.t.sol | 12 ++-- 4 files changed, 88 insertions(+), 30 deletions(-) diff --git a/.solcover.js b/.solcover.js index e0dea5e2c9b..f079998cff3 100644 --- a/.solcover.js +++ b/.solcover.js @@ -10,4 +10,12 @@ module.exports = { fgrep: '[skip-on-coverage]', invert: true, }, + // Work around stack too deep for coverage + configureYulOptimizer: true, + solcOptimizerDetails: { + yul: true, + yulDetails: { + optimizerSteps: '', + }, + }, }; diff --git a/contracts/utils/cryptography/P256.sol b/contracts/utils/cryptography/P256.sol index 1c46e38b0be..3028505ba75 100644 --- a/contracts/utils/cryptography/P256.sol +++ b/contracts/utils/cryptography/P256.sol @@ -185,6 +185,13 @@ library P256 { /** * @dev Point addition on the jacobian coordinates * Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#addition-add-1998-cmo-2 + * + * Note that: + * + * - `addition-add-1998-cmo-2` doesn't support identical input points. This version is modified to use + * the `h` and `r` values computed by `addition-add-1998-cmo-2` to detect identical inputs, and fallback to + * `doubling-dbl-1998-cmo-2` if needed. + * - if one of the points is at infinity (i.e. `z=0`), the result is undefined. */ function _jAdd( JPoint memory p1, @@ -197,25 +204,53 @@ library P256 { let z1 := mload(add(p1, 0x40)) let zz1 := mulmod(z1, z1, p) // zz1 = z1² let s1 := mulmod(mload(add(p1, 0x20)), mulmod(mulmod(z2, z2, p), z2, p), p) // s1 = y1*z2³ - let r := addmod(mulmod(y2, mulmod(zz1, z1, p), p), sub(p, s1), p) // r = s2-s1 = y2*z1³-s1 + let r := addmod(mulmod(y2, mulmod(zz1, z1, p), p), sub(p, s1), p) // r = s2-s1 = y2*z1³-s1 = y2*z1³-y1*z2³ let u1 := mulmod(mload(p1), mulmod(z2, z2, p), p) // u1 = x1*z2² - let h := addmod(mulmod(x2, zz1, p), sub(p, u1), p) // h = u2-u1 = x2*z1²-u1 - let hh := mulmod(h, h, p) // h² + let h := addmod(mulmod(x2, zz1, p), sub(p, u1), p) // h = u2-u1 = x2*z1²-u1 = x2*z1²-x1*z2² + + // detect edge cases where inputs are identical + switch and(iszero(r), iszero(h)) + // case 0: points are different + case 0 { + let hh := mulmod(h, h, p) // h² + + // x' = r²-h³-2*u1*h² + rx := addmod( + addmod(mulmod(r, r, p), sub(p, mulmod(h, hh, p)), p), + sub(p, mulmod(2, mulmod(u1, hh, p), p)), + p + ) + // y' = r*(u1*h²-x')-s1*h³ + ry := addmod( + mulmod(r, addmod(mulmod(u1, hh, p), sub(p, rx), p), p), + sub(p, mulmod(s1, mulmod(h, hh, p), p)), + p + ) + // z' = h*z1*z2 + rz := mulmod(h, mulmod(z1, z2, p), p) + } + // case 1: points are equal + case 1 { + let x := x2 + let y := y2 + let z := z2 + let yy := mulmod(y, y, p) + let zz := mulmod(z, z, p) + let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴ + let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y² + + // x' = t = m²-2*s + rx := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p) - // x' = r²-h³-2*u1*h² - rx := addmod( - addmod(mulmod(r, r, p), sub(p, mulmod(h, hh, p)), p), - sub(p, mulmod(2, mulmod(u1, hh, p), p)), - p - ) - // y' = r*(u1*h²-x')-s1*h³ - ry := addmod( - mulmod(r, addmod(mulmod(u1, hh, p), sub(p, rx), p), p), - sub(p, mulmod(s1, mulmod(h, hh, p), p)), - p - ) - // z' = h*z1*z2 - rz := mulmod(h, mulmod(z1, z2, p), p) + // y' = m*(s-t)-8*y⁴ = m*(s-x')-8*y⁴ + // cut the computation to avoid stack too deep + let rytmp1 := sub(p, mulmod(8, mulmod(yy, yy, p), p)) // -8*y⁴ + let rytmp2 := addmod(s, sub(p, rx), p) // s-x' + ry := addmod(mulmod(m, rytmp2, p), rytmp1, p) // m*(s-x')-8*y⁴ + + // z' = 2*y*z + rz := mulmod(2, mulmod(y, z, p), p) + } } } @@ -228,8 +263,8 @@ library P256 { let p := P let yy := mulmod(y, y, p) let zz := mulmod(z, z, p) - let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y² let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴ + let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y² // x' = t = m²-2*s rx := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p) @@ -244,10 +279,11 @@ library P256 { * @dev Compute G·u1 + P·u2 using the precomputed points for G and P (see {_preComputeJacobianPoints}). * * Uses Strauss Shamir trick for EC multiplication - * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method. - * We optimise on this a bit to do with 2 bits at a time rather than a single bit. - * The individual points for a single pass are precomputed. - * Overall this reduces the number of additions while keeping the same number of doublings. + * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method + * + * We optimize this for 2 bits at a time rather than a single bit. The individual points for a single pass are + * precomputed. Overall this reduces the number of additions while keeping the same number of + * doublings */ function _jMultShamir( JPoint[16] memory points, @@ -263,9 +299,14 @@ library P256 { (x, y, z) = _jDouble(x, y, z); (x, y, z) = _jDouble(x, y, z); } - // Read 2 bits of u1, and 2 bits of u2. Combining the two give a lookup index in the table. + // Read 2 bits of u1, and 2 bits of u2. Combining the two gives the lookup index in the table. uint256 pos = ((u1 >> 252) & 0xc) | ((u2 >> 254) & 0x3); - if (pos > 0) { + // Points that have z = 0 are points at infinity. They are the additive 0 of the group + // - if the lookup point is a 0, we can skip it + // - otherwise: + // - if the current point (x, y, z) is 0, we use the lookup point as our new value (0+P=P) + // - if the current point (x, y, z) is not 0, both points are valid and we can use `_jAdd` + if (points[pos].z != 0) { if (z == 0) { (x, y, z) = (points[pos].x, points[pos].y, points[pos].z); } else { @@ -291,6 +332,11 @@ library P256 { * │ 8 │ 2g 2g+p 2g+2p 2g+3p │ * │ 12 │ 3g 3g+p 3g+2p 3g+3p │ * └────┴─────────────────────┘ + * + * Note that `_jAdd` (and thus `_jAddPoint`) does not handle the case where one of the inputs is a point at + * infinity (z = 0). However, we know that since `N ≡ 1 mod 2` and `N ≡ 1 mod 3`, there is no point P such that + * 2P = 0 or 3P = 0. This guarantees that g, 2g, 3g, p, 2p, 3p are all non-zero, and that all `_jAddPoint` calls + * have valid inputs. */ function _preComputeJacobianPoints(uint256 px, uint256 py) private pure returns (JPoint[16] memory points) { points[0x00] = JPoint(0, 0, 0); // 0,0 diff --git a/test/helpers/iterate.js b/test/helpers/iterate.js index ef4526e133f..c7403d52384 100644 --- a/test/helpers/iterate.js +++ b/test/helpers/iterate.js @@ -13,11 +13,11 @@ module.exports = { // Range from start to end in increment // Example: range(17,42,7) → [17,24,31,38] range: (start, stop = undefined, step = 1) => { - if (!stop) { + if (stop == undefined) { stop = start; start = 0; } - return start < stop ? Array.from({ length: Math.ceil((stop - start) / step) }, (_, i) => start + i * step) : []; + return start < stop ? Array.from({ length: (stop - start + step - 1) / step }, (_, i) => start + i * step) : []; }, // Unique elements, with an optional getter function diff --git a/test/utils/cryptography/P256.t.sol b/test/utils/cryptography/P256.t.sol index 1391afd76ef..8b95ff2259d 100644 --- a/test/utils/cryptography/P256.t.sol +++ b/test/utils/cryptography/P256.t.sol @@ -9,8 +9,8 @@ import {Math} from "@openzeppelin/contracts/utils/math/Math.sol"; contract P256Test is Test { /// forge-config: default.fuzz.runs = 512 - function testVerify(uint256 seed, bytes32 digest) public { - uint256 privateKey = bound(uint256(keccak256(abi.encode(seed))), 1, P256.N - 1); + function testVerify(bytes32 digest, uint256 seed) public { + uint256 privateKey = _asPrivateKey(seed); (bytes32 x, bytes32 y) = P256PublicKey.getPublicKey(privateKey); (bytes32 r, bytes32 s) = vm.signP256(privateKey, digest); @@ -20,8 +20,8 @@ contract P256Test is Test { } /// forge-config: default.fuzz.runs = 512 - function testRecover(uint256 seed, bytes32 digest) public { - uint256 privateKey = bound(uint256(keccak256(abi.encode(seed))), 1, P256.N - 1); + function testRecover(bytes32 digest, uint256 seed) public { + uint256 privateKey = _asPrivateKey(seed); (bytes32 x, bytes32 y) = P256PublicKey.getPublicKey(privateKey); (bytes32 r, bytes32 s) = vm.signP256(privateKey, digest); @@ -31,6 +31,10 @@ contract P256Test is Test { assertTrue((qx0 == x && qy0 == y) || (qx1 == x && qy1 == y)); } + function _asPrivateKey(uint256 seed) private pure returns (uint256) { + return bound(seed, 1, P256.N - 1); + } + function _ensureLowerS(bytes32 s) private pure returns (bytes32) { uint256 _s = uint256(s); unchecked {