Skip to content

Commit e3cfe1c

Browse files
cairoeth, Amxx, ernestognw
authored
Fix P256 corner cases (#5218)
Co-authored-by: Hadrien Croubois <hadrien.croubois@gmail.com>
Co-authored-by: Ernesto García <ernestognw@gmail.com>
1 parent d3ca1d1 commit e3cfe1c

File tree

4 files changed

+88
-30
lines changed

4 files changed

+88
-30
lines changed

.solcover.js

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,12 @@ module.exports = {
1010
fgrep: '[skip-on-coverage]',
1111
invert: true,
1212
},
13+
// Work around stack too deep for coverage
14+
configureYulOptimizer: true,
15+
solcOptimizerDetails: {
16+
yul: true,
17+
yulDetails: {
18+
optimizerSteps: '',
19+
},
20+
},
1321
};

contracts/utils/cryptography/P256.sol

Lines changed: 70 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,13 @@ library P256 {
185185
/**
186186
* @dev Point addition on the jacobian coordinates
187187
* Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#addition-add-1998-cmo-2
188+
*
189+
* Note that:
190+
*
191+
* - `addition-add-1998-cmo-2` doesn't support identical input points. This version is modified to use
192+
* the `h` and `r` values computed by `addition-add-1998-cmo-2` to detect identical inputs, and fall back to
193+
* `doubling-dbl-1998-cmo-2` if needed.
194+
* - if one of the points is at infinity (i.e. `z=0`), the result is undefined.
188195
*/
189196
function _jAdd(
190197
JPoint memory p1,
@@ -197,25 +204,53 @@ library P256 {
197204
let z1 := mload(add(p1, 0x40))
198205
let zz1 := mulmod(z1, z1, p) // zz1 = z1²
199206
let s1 := mulmod(mload(add(p1, 0x20)), mulmod(mulmod(z2, z2, p), z2, p), p) // s1 = y1*z2³
200-
let r := addmod(mulmod(y2, mulmod(zz1, z1, p), p), sub(p, s1), p) // r = s2-s1 = y2*z1³-s1
207+
let r := addmod(mulmod(y2, mulmod(zz1, z1, p), p), sub(p, s1), p) // r = s2-s1 = y2*z1³-s1 = y2*z1³-y1*z2³
201208
let u1 := mulmod(mload(p1), mulmod(z2, z2, p), p) // u1 = x1*z2²
202-
let h := addmod(mulmod(x2, zz1, p), sub(p, u1), p) // h = u2-u1 = x2*z1²-u1
203-
let hh := mulmod(h, h, p) // h²
209+
let h := addmod(mulmod(x2, zz1, p), sub(p, u1), p) // h = u2-u1 = x2*z1²-u1 = x2*z1²-x1*z2²
210+
211+
// detect edge cases where inputs are identical
212+
switch and(iszero(r), iszero(h))
213+
// case 0: points are different
214+
case 0 {
215+
let hh := mulmod(h, h, p) // h²
216+
217+
// x' = r²-h³-2*u1*h²
218+
rx := addmod(
219+
addmod(mulmod(r, r, p), sub(p, mulmod(h, hh, p)), p),
220+
sub(p, mulmod(2, mulmod(u1, hh, p), p)),
221+
p
222+
)
223+
// y' = r*(u1*h²-x')-s1*h³
224+
ry := addmod(
225+
mulmod(r, addmod(mulmod(u1, hh, p), sub(p, rx), p), p),
226+
sub(p, mulmod(s1, mulmod(h, hh, p), p)),
227+
p
228+
)
229+
// z' = h*z1*z2
230+
rz := mulmod(h, mulmod(z1, z2, p), p)
231+
}
232+
// case 1: points are equal
233+
case 1 {
234+
let x := x2
235+
let y := y2
236+
let z := z2
237+
let yy := mulmod(y, y, p)
238+
let zz := mulmod(z, z, p)
239+
let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴
240+
let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y²
241+
242+
// x' = t = m²-2*s
243+
rx := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p)
204244

205-
// x' = r²-h³-2*u1*h²
206-
rx := addmod(
207-
addmod(mulmod(r, r, p), sub(p, mulmod(h, hh, p)), p),
208-
sub(p, mulmod(2, mulmod(u1, hh, p), p)),
209-
p
210-
)
211-
// y' = r*(u1*h²-x')-s1*h³
212-
ry := addmod(
213-
mulmod(r, addmod(mulmod(u1, hh, p), sub(p, rx), p), p),
214-
sub(p, mulmod(s1, mulmod(h, hh, p), p)),
215-
p
216-
)
217-
// z' = h*z1*z2
218-
rz := mulmod(h, mulmod(z1, z2, p), p)
245+
// y' = m*(s-t)-8*y⁴ = m*(s-x')-8*y⁴
246+
// split the computation into temporaries to avoid a "stack too deep" error
247+
let rytmp1 := sub(p, mulmod(8, mulmod(yy, yy, p), p)) // -8*y⁴
248+
let rytmp2 := addmod(s, sub(p, rx), p) // s-x'
249+
ry := addmod(mulmod(m, rytmp2, p), rytmp1, p) // m*(s-x')-8*y⁴
250+
251+
// z' = 2*y*z
252+
rz := mulmod(2, mulmod(y, z, p), p)
253+
}
219254
}
220255
}
221256

@@ -228,8 +263,8 @@ library P256 {
228263
let p := P
229264
let yy := mulmod(y, y, p)
230265
let zz := mulmod(z, z, p)
231-
let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y²
232266
let m := addmod(mulmod(3, mulmod(x, x, p), p), mulmod(A, mulmod(zz, zz, p), p), p) // m = 3*x²+a*z⁴
267+
let s := mulmod(4, mulmod(x, yy, p), p) // s = 4*x*y²
233268

234269
// x' = t = m²-2*s
235270
rx := addmod(mulmod(m, m, p), sub(p, mulmod(2, s, p)), p)
@@ -244,10 +279,11 @@ library P256 {
244279
* @dev Compute G·u1 + P·u2 using the precomputed points for G and P (see {_preComputeJacobianPoints}).
245280
*
246281
* Uses Strauss Shamir trick for EC multiplication
247-
* https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method.
248-
* We optimise on this a bit to do with 2 bits at a time rather than a single bit.
249-
* The individual points for a single pass are precomputed.
250-
* Overall this reduces the number of additions while keeping the same number of doublings.
282+
* https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method
283+
*
284+
* We optimize this for 2 bits at a time rather than a single bit. The individual points for a single pass are
285+
* precomputed. Overall this reduces the number of additions while keeping the same number of
286+
* doublings.
251287
*/
252288
function _jMultShamir(
253289
JPoint[16] memory points,
@@ -263,9 +299,14 @@ library P256 {
263299
(x, y, z) = _jDouble(x, y, z);
264300
(x, y, z) = _jDouble(x, y, z);
265301
}
266-
// Read 2 bits of u1, and 2 bits of u2. Combining the two give a lookup index in the table.
302+
// Read 2 bits of u1, and 2 bits of u2. Combining the two gives the lookup index in the table.
267303
uint256 pos = ((u1 >> 252) & 0xc) | ((u2 >> 254) & 0x3);
268-
if (pos > 0) {
304+
// Points that have z = 0 are points at infinity. They are the additive 0 of the group
305+
// - if the lookup point is a 0, we can skip it
306+
// - otherwise:
307+
// - if the current point (x, y, z) is 0, we use the lookup point as our new value (0+P=P)
308+
// - if the current point (x, y, z) is not 0, both points are valid and we can use `_jAdd`
309+
if (points[pos].z != 0) {
269310
if (z == 0) {
270311
(x, y, z) = (points[pos].x, points[pos].y, points[pos].z);
271312
} else {
@@ -291,6 +332,11 @@ library P256 {
291332
* │ 8 │ 2g 2g+p 2g+2p 2g+3p │
292333
* │ 12 │ 3g 3g+p 3g+2p 3g+3p │
293334
* └────┴─────────────────────┘
335+
*
336+
* Note that `_jAdd` (and thus `_jAddPoint`) does not handle the case where one of the inputs is a point at
337+
* infinity (z = 0). However, we know that since `N ≡ 1 mod 2` and `N ≡ 1 mod 3`, there is no point P such that
338+
* 2P = 0 or 3P = 0. This guarantees that g, 2g, 3g, p, 2p, 3p are all non-zero, and that all `_jAddPoint` calls
339+
* have valid inputs.
294340
*/
295341
function _preComputeJacobianPoints(uint256 px, uint256 py) private pure returns (JPoint[16] memory points) {
296342
points[0x00] = JPoint(0, 0, 0); // 0,0

test/helpers/iterate.js

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,11 @@ module.exports = {
1313
// Range from start (inclusive) to stop (exclusive) in increments of step
1414
// Example: range(17,42,7) → [17,24,31,38]
1515
range: (start, stop = undefined, step = 1) => {
16-
if (!stop) {
16+
if (stop == undefined) {
1717
stop = start;
1818
start = 0;
1919
}
20-
return start < stop ? Array.from({ length: Math.ceil((stop - start) / step) }, (_, i) => start + i * step) : [];
20+
return start < stop ? Array.from({ length: (stop - start + step - 1) / step }, (_, i) => start + i * step) : [];
2121
},
2222

2323
// Unique elements, with an optional getter function

test/utils/cryptography/P256.t.sol

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@ import {Math} from "@openzeppelin/contracts/utils/math/Math.sol";
99

1010
contract P256Test is Test {
1111
/// forge-config: default.fuzz.runs = 512
12-
function testVerify(uint256 seed, bytes32 digest) public {
13-
uint256 privateKey = bound(uint256(keccak256(abi.encode(seed))), 1, P256.N - 1);
12+
function testVerify(bytes32 digest, uint256 seed) public {
13+
uint256 privateKey = _asPrivateKey(seed);
1414

1515
(bytes32 x, bytes32 y) = P256PublicKey.getPublicKey(privateKey);
1616
(bytes32 r, bytes32 s) = vm.signP256(privateKey, digest);
@@ -20,8 +20,8 @@ contract P256Test is Test {
2020
}
2121

2222
/// forge-config: default.fuzz.runs = 512
23-
function testRecover(uint256 seed, bytes32 digest) public {
24-
uint256 privateKey = bound(uint256(keccak256(abi.encode(seed))), 1, P256.N - 1);
23+
function testRecover(bytes32 digest, uint256 seed) public {
24+
uint256 privateKey = _asPrivateKey(seed);
2525

2626
(bytes32 x, bytes32 y) = P256PublicKey.getPublicKey(privateKey);
2727
(bytes32 r, bytes32 s) = vm.signP256(privateKey, digest);
@@ -31,6 +31,10 @@ contract P256Test is Test {
3131
assertTrue((qx0 == x && qy0 == y) || (qx1 == x && qy1 == y));
3232
}
3333

34+
function _asPrivateKey(uint256 seed) private pure returns (uint256) {
35+
return bound(seed, 1, P256.N - 1);
36+
}
37+
3438
function _ensureLowerS(bytes32 s) private pure returns (bytes32) {
3539
uint256 _s = uint256(s);
3640
unchecked {

0 commit comments

Comments
 (0)