@@ -185,6 +185,13 @@ library P256 {
185
185
/**
186
186
* @dev Point addition on the jacobian coordinates
187
187
* Reference: https://www.hyperelliptic.org/EFD/g1p/auto-shortw-jacobian.html#addition-add-1998-cmo-2
188
+ *
189
+ * Note that:
190
+ *
191
+ * - `addition-add-1998-cmo-2` doesn't support identical input points. This version is modified to use
192
+ * the `h` and `r` values computed by `addition-add-1998-cmo-2` to detect identical inputs, and fall back to
193
+ * `doubling-dbl-1998-cmo-2` if needed.
194
+ * - if one of the points is at infinity (i.e. `z=0`), the result is undefined.
188
195
*/
189
196
function _jAdd (
190
197
JPoint memory p1 ,
@@ -197,25 +204,53 @@ library P256 {
197
204
let z1 := mload (add (p1, 0x40 ))
198
205
let zz1 := mulmod (z1, z1, p) // zz1 = z1²
199
206
let s1 := mulmod (mload (add (p1, 0x20 )), mulmod (mulmod (z2, z2, p), z2, p), p) // s1 = y1*z2³
200
- let r := addmod (mulmod (y2, mulmod (zz1, z1, p), p), sub (p, s1), p) // r = s2-s1 = y2*z1³-s1
207
+ let r := addmod (mulmod (y2, mulmod (zz1, z1, p), p), sub (p, s1), p) // r = s2-s1 = y2*z1³-s1 = y2*z1³-y1*z2³
201
208
let u1 := mulmod (mload (p1), mulmod (z2, z2, p), p) // u1 = x1*z2²
202
- let h := addmod (mulmod (x2, zz1, p), sub (p, u1), p) // h = u2-u1 = x2*z1²-u1
203
- let hh := mulmod (h, h, p) // h²
209
+ let h := addmod (mulmod (x2, zz1, p), sub (p, u1), p) // h = u2-u1 = x2*z1²-u1 = x2*z1²-x1*z2²
210
+
211
+ // detect edge cases where inputs are identical
212
+ switch and (iszero (r), iszero (h))
213
+ // case 0: points are different
214
+ case 0 {
215
+ let hh := mulmod (h, h, p) // h²
216
+
217
+ // x' = r²-h³-2*u1*h²
218
+ rx := addmod (
219
+ addmod (mulmod (r, r, p), sub (p, mulmod (h, hh, p)), p),
220
+ sub (p, mulmod (2 , mulmod (u1, hh, p), p)),
221
+ p
222
+ )
223
+ // y' = r*(u1*h²-x')-s1*h³
224
+ ry := addmod (
225
+ mulmod (r, addmod (mulmod (u1, hh, p), sub (p, rx), p), p),
226
+ sub (p, mulmod (s1, mulmod (h, hh, p), p)),
227
+ p
228
+ )
229
+ // z' = h*z1*z2
230
+ rz := mulmod (h, mulmod (z1, z2, p), p)
231
+ }
232
+ // case 1: points are equal
233
+ case 1 {
234
+ let x := x2
235
+ let y := y2
236
+ let z := z2
237
+ let yy := mulmod (y, y, p)
238
+ let zz := mulmod (z, z, p)
239
+ let m := addmod (mulmod (3 , mulmod (x, x, p), p), mulmod (A, mulmod (zz, zz, p), p), p) // m = 3*x²+a*z⁴
240
+ let s := mulmod (4 , mulmod (x, yy, p), p) // s = 4*x*y²
241
+
242
+ // x' = t = m²-2*s
243
+ rx := addmod (mulmod (m, m, p), sub (p, mulmod (2 , s, p)), p)
204
244
205
- // x' = r²-h³-2*u1*h²
206
- rx := addmod (
207
- addmod (mulmod (r, r, p), sub (p, mulmod (h, hh, p)), p),
208
- sub (p, mulmod (2 , mulmod (u1, hh, p), p)),
209
- p
210
- )
211
- // y' = r*(u1*h²-x')-s1*h³
212
- ry := addmod (
213
- mulmod (r, addmod (mulmod (u1, hh, p), sub (p, rx), p), p),
214
- sub (p, mulmod (s1, mulmod (h, hh, p), p)),
215
- p
216
- )
217
- // z' = h*z1*z2
218
- rz := mulmod (h, mulmod (z1, z2, p), p)
245
+ // y' = m*(s-t)-8*y⁴ = m*(s-x')-8*y⁴
246
+ // cut the computation to avoid stack too deep
247
+ let rytmp1 := sub (p, mulmod (8 , mulmod (yy, yy, p), p)) // -8*y⁴
248
+ let rytmp2 := addmod (s, sub (p, rx), p) // s-x'
249
+ ry := addmod (mulmod (m, rytmp2, p), rytmp1, p) // m*(s-x')-8*y⁴
250
+
251
+ // z' = 2*y*z
252
+ rz := mulmod (2 , mulmod (y, z, p), p)
253
+ }
219
254
}
220
255
}
221
256
@@ -228,8 +263,8 @@ library P256 {
228
263
let p := P
229
264
let yy := mulmod (y, y, p)
230
265
let zz := mulmod (z, z, p)
231
- let s := mulmod (4 , mulmod (x, yy, p), p) // s = 4*x*y²
232
266
let m := addmod (mulmod (3 , mulmod (x, x, p), p), mulmod (A, mulmod (zz, zz, p), p), p) // m = 3*x²+a*z⁴
267
+ let s := mulmod (4 , mulmod (x, yy, p), p) // s = 4*x*y²
233
268
234
269
// x' = t = m²-2*s
235
270
rx := addmod (mulmod (m, m, p), sub (p, mulmod (2 , s, p)), p)
@@ -244,10 +279,11 @@ library P256 {
244
279
* @dev Compute G·u1 + P·u2 using the precomputed points for G and P (see {_preComputeJacobianPoints}).
245
280
*
246
281
* Uses Strauss Shamir trick for EC multiplication
247
- * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method.
248
- * We optimise on this a bit to do with 2 bits at a time rather than a single bit.
249
- * The individual points for a single pass are precomputed.
250
- * Overall this reduces the number of additions while keeping the same number of doublings.
282
+ * https://stackoverflow.com/questions/50993471/ec-scalar-multiplication-with-strauss-shamir-method
283
+ *
284
+ * We optimize this for 2 bits at a time rather than a single bit. The individual points for a single pass are
285
+ * precomputed. Overall this reduces the number of additions while keeping the same number of
286
+ * doublings.
251
287
*/
252
288
function _jMultShamir (
253
289
JPoint[16 ] memory points ,
@@ -263,9 +299,14 @@ library P256 {
263
299
(x, y, z) = _jDouble (x, y, z);
264
300
(x, y, z) = _jDouble (x, y, z);
265
301
}
266
- // Read 2 bits of u1, and 2 bits of u2. Combining the two give a lookup index in the table.
302
+ // Read 2 bits of u1, and 2 bits of u2. Combining the two gives the lookup index in the table.
267
303
uint256 pos = ((u1 >> 252 ) & 0xc ) | ((u2 >> 254 ) & 0x3 );
268
- if (pos > 0 ) {
304
+ // Points that have z = 0 are points at infinity. They are the additive 0 of the group
305
+ // - if the lookup point is a 0, we can skip it
306
+ // - otherwise:
307
+ // - if the current point (x, y, z) is 0, we use the lookup point as our new value (0+P=P)
308
+ // - if the current point (x, y, z) is not 0, both points are valid and we can use `_jAdd`
309
+ if (points[pos].z != 0 ) {
269
310
if (z == 0 ) {
270
311
(x, y, z) = (points[pos].x, points[pos].y, points[pos].z);
271
312
} else {
@@ -291,6 +332,11 @@ library P256 {
291
332
* │ 8 │ 2g 2g+p 2g+2p 2g+3p │
292
333
* │ 12 │ 3g 3g+p 3g+2p 3g+3p │
293
334
* └────┴─────────────────────┘
335
+ *
336
+ * Note that `_jAdd` (and thus `_jAddPoint`) does not handle the case where one of the inputs is a point at
337
+ * infinity (z = 0). However, we know that since `N ≡ 1 mod 2` and `N ≡ 1 mod 3`, there is no point P such that
338
+ * 2P = 0 or 3P = 0. This guarantees that g, 2g, 3g, p, 2p, 3p are all non-zero, and that all `_jAddPoint` calls
339
+ * have valid inputs.
294
340
*/
295
341
function _preComputeJacobianPoints (uint256 px , uint256 py ) private pure returns (JPoint[16 ] memory points ) {
296
342
points[0x00 ] = JPoint (0 , 0 , 0 ); // 0,0
0 commit comments