Skip to content

Commit

Permalink
Merge bitcoin#906: Use modified divsteps with initial delta=1/2 for c…
Browse files Browse the repository at this point in the history
…onstant-time

be0609f Add unit tests for edge cases with delta=1/2 variant of divsteps (Pieter Wuille)
cd393ce Optimization: only do 59 hddivsteps per iteration instead of 62 (Pieter Wuille)
277b224 Use modified divsteps with initial delta=1/2 for constant-time (Pieter Wuille)
376ca36 Fix typo in explanation (Pieter Wuille)

Pull request description:

  This updates the divsteps-based modular inverse code to use the modified version which starts with delta=1/2. For variable time, the delta=1 variant is still used as it appears to be faster.

  See https://github.com/sipa/safegcd-bounds/tree/master/coq and https://medium.com/blockstream/a-formal-proof-of-safegcd-bounds-695e1735a348 for a proof of correctness of this variant.

  TODO:
  * [x] Update unit tests to include edge cases specific to this variant

  I'm still running the Coq proof verification for the 590 bound in non-native mode. It's unclear how long this will take.

ACKs for top commit:
  gmaxwell:
    ACK be0609f
  sanket1729:
    crACK be0609f
  real-or-random:
    ACK be0609f careful code review and some testing

Tree-SHA512: 2f8f400ba3ac8dbd08622d564c3b3e5ff30768bd0eb559f2c4279c6c813e17cdde71b1c16f05742c5657b5238b4d592b48306f9f47d7dbdb57907e58dd99b47a
  • Loading branch information
real-or-random committed Apr 22, 2021
2 parents cc2c09e + be0609f commit efad350
Show file tree
Hide file tree
Showing 4 changed files with 745 additions and 101 deletions.
41 changes: 28 additions & 13 deletions doc/safegcd_implementation.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,8 @@ def modinv(M, Mi, x):

This means that in practice we'll always perform a multiple of *N* divsteps. This is not a problem
because once *g=0*, further divsteps do not affect *f*, *g*, *d*, or *e* anymore (only *δ* keeps
increasing). For variable time code such excess iterations will be mostly optimized away in
section 6.
increasing). For variable time code such excess iterations will be mostly optimized away in later
sections.


## 4. Avoiding modulus operations
Expand Down Expand Up @@ -519,6 +519,20 @@ computation:
g >>= 1
```

A variant of divsteps with better worst-case performance can be used instead: starting *δ* at
*1/2* instead of *1*. This reduces the worst case number of iterations to *590* for *256*-bit inputs
(which can be shown using convex hull analysis). In this case, the substitution *ζ=-(δ+1/2)*
is used instead to keep the variable integral. Incrementing *δ* by *1* still translates to
decrementing *ζ* by *1*, but negating *δ* now corresponds to going from *ζ* to *-(ζ+1)*, or
*~ζ*. Doing that conditionally based on *c3* is simply:

```python
...
c3 = c1 & c2
zeta ^= c3
...
```

By replacing the loop in `divsteps_n_matrix` with a variant of the divstep code above (extended to
also apply all *f* operations to *u*, *v* and all *g* operations to *q*, *r*), a constant-time version of
`divsteps_n_matrix` is obtained. The full code will be in section 7.
Expand All @@ -535,7 +549,8 @@ other cases, it slows down calculations unnecessarily. In this section, we will
faster non-constant time `divsteps_n_matrix` function.

To do so, first consider yet another way of writing the inner loop of divstep operations in
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2.
`gcd` from section 1. This decomposition is also explained in the paper in section 8.2. We use
the original version with initial *δ=1* and *η=-δ* here.

```python
for _ in range(N):
Expand Down Expand Up @@ -643,24 +658,24 @@ All together we need the following functions:
section 5, extended to handle *u*, *v*, *q*, *r*:

```python
def divsteps_n_matrix(eta, f, g):
"""Compute eta and transition matrix t after N divsteps (multiplied by 2^N)."""
def divsteps_n_matrix(zeta, f, g):
"""Compute zeta and transition matrix t after N divsteps (multiplied by 2^N)."""
u, v, q, r = 1, 0, 0, 1 # start with identity matrix
for _ in range(N):
c1 = eta >> 63
c1 = zeta >> 63
# Compute x, y, z as conditionally-negated versions of f, u, v.
x, y, z = (f ^ c1) - c1, (u ^ c1) - c1, (v ^ c1) - c1
c2 = -(g & 1)
# Conditionally add x, y, z to g, q, r.
g, q, r = g + (x & c2), q + (y & c2), r + (z & c2)
c1 &= c2 # reusing c1 here for the earlier c3 variable
eta = (eta ^ c1) - (c1 + 1) # inlining the unconditional eta decrement here
zeta = (zeta ^ c1) - 1 # inlining the unconditional zeta decrement here
# Conditionally add g, q, r to f, u, v.
f, u, v = f + (g & c1), u + (q & c1), v + (r & c1)
# When shifting g down, don't shift q, r, as we construct a transition matrix multiplied
# by 2^N. Instead, shift f's coefficients u and v up.
g, u, v = g >> 1, u << 1, v << 1
return eta, (u, v, q, r)
return zeta, (u, v, q, r)
```

- The functions to update *f* and *g*, and *d* and *e*, from section 2 and section 4, with the constant-time
Expand All @@ -681,7 +696,7 @@ def update_de(d, e, t, M, Mi):
cd, ce = (u*d + v*e) % 2**N, (q*d + r*e) % 2**N
md -= (Mi*cd + md) % 2**N
me -= (Mi*ce + me) % 2**N
cd, ce = u*d + v*e + Mi*md, q*d + r*e + Mi*me
cd, ce = u*d + v*e + M*md, q*d + r*e + M*me
return cd >> N, ce >> N
```

Expand All @@ -702,15 +717,15 @@ def normalize(sign, v, M):
return v
```

- And finally the `modinv` function too, adapted to use *&eta;* instead of *&delta;*, and using the fixed
- And finally the `modinv` function too, adapted to use *&zeta;* instead of *&delta;*, and using the fixed
iteration count from section 5:

```python
def modinv(M, Mi, x):
"""Compute the modular inverse of x mod M, given Mi=1/M mod 2^N."""
eta, f, g, d, e = -1, M, x, 0, 1
for _ in range((724 + N - 1) // N):
eta, t = divsteps_n_matrix(-eta, f % 2**N, g % 2**N)
zeta, f, g, d, e = -1, M, x, 0, 1
for _ in range((590 + N - 1) // N):
zeta, t = divsteps_n_matrix(zeta, f % 2**N, g % 2**N)
f, g = update_fg(f, g, t)
d, e = update_de(d, e, t, M, Mi)
return normalize(f, d, M)
Expand Down
42 changes: 21 additions & 21 deletions src/modinv32_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -168,17 +168,17 @@ typedef struct {
int32_t u, v, q, r;
} secp256k1_modinv32_trans2x2;

/* Compute the transition matrix and eta for 30 divsteps.
/* Compute the transition matrix and zeta for 30 divsteps.
*
* Input: eta: initial eta
* f0: bottom limb of initial f
* g0: bottom limb of initial g
* Input: zeta: initial zeta
* f0: bottom limb of initial f
* g0: bottom limb of initial g
* Output: t: transition matrix
* Return: final eta
* Return: final zeta
*
* Implements the divsteps_n_matrix function from the explanation.
*/
static int32_t secp256k1_modinv32_divsteps_30(int32_t eta, uint32_t f0, uint32_t g0, secp256k1_modinv32_trans2x2 *t) {
static int32_t secp256k1_modinv32_divsteps_30(int32_t zeta, uint32_t f0, uint32_t g0, secp256k1_modinv32_trans2x2 *t) {
/* u,v,q,r are the elements of the transformation matrix being built up,
* starting with the identity matrix. Semantically they are signed integers
* in range [-2^30,2^30], but here represented as unsigned mod 2^32. This
Expand All @@ -193,8 +193,8 @@ static int32_t secp256k1_modinv32_divsteps_30(int32_t eta, uint32_t f0, uint32_t
VERIFY_CHECK((f & 1) == 1); /* f must always be odd */
VERIFY_CHECK((u * f0 + v * g0) == f << i);
VERIFY_CHECK((q * f0 + r * g0) == g << i);
/* Compute conditional masks for (eta < 0) and for (g & 1). */
c1 = eta >> 31;
/* Compute conditional masks for (zeta < 0) and for (g & 1). */
c1 = zeta >> 31;
c2 = -(g & 1);
/* Compute x,y,z, conditionally negated versions of f,u,v. */
x = (f ^ c1) - c1;
Expand All @@ -204,10 +204,10 @@ static int32_t secp256k1_modinv32_divsteps_30(int32_t eta, uint32_t f0, uint32_t
g += x & c2;
q += y & c2;
r += z & c2;
/* In what follows, c1 is a condition mask for (eta < 0) and (g & 1). */
/* In what follows, c1 is a condition mask for (zeta < 0) and (g & 1). */
c1 &= c2;
/* Conditionally negate eta, and unconditionally subtract 1. */
eta = (eta ^ c1) - (c1 + 1);
/* Conditionally change zeta into -zeta-2 or zeta-1. */
zeta = (zeta ^ c1) - 1;
/* Conditionally add g,q,r to f,u,v. */
f += g & c1;
u += q & c1;
Expand All @@ -216,8 +216,8 @@ static int32_t secp256k1_modinv32_divsteps_30(int32_t eta, uint32_t f0, uint32_t
g >>= 1;
u <<= 1;
v <<= 1;
/* Bounds on eta that follow from the bounds on iteration count (max 25*30 divsteps). */
VERIFY_CHECK(eta >= -751 && eta <= 751);
/* Bounds on zeta that follow from the bounds on iteration count (max 20*30 divsteps). */
VERIFY_CHECK(zeta >= -601 && zeta <= 601);
}
/* Return data in t and return value. */
t->u = (int32_t)u;
Expand All @@ -229,7 +229,7 @@ static int32_t secp256k1_modinv32_divsteps_30(int32_t eta, uint32_t f0, uint32_t
* will be divided out again). As each divstep's individual matrix has determinant 2, the
* aggregate of 30 of them will have determinant 2^30. */
VERIFY_CHECK((int64_t)t->u * t->r - (int64_t)t->v * t->q == ((int64_t)1) << 30);
return eta;
return zeta;
}

/* Compute the transition matrix and eta for 30 divsteps (variable time).
Expand Down Expand Up @@ -453,19 +453,19 @@ static void secp256k1_modinv32_update_fg_30_var(int len, secp256k1_modinv32_sign

/* Compute the inverse of x modulo modinfo->modulus, and replace x with it (constant time in x). */
static void secp256k1_modinv32(secp256k1_modinv32_signed30 *x, const secp256k1_modinv32_modinfo *modinfo) {
/* Start with d=0, e=1, f=modulus, g=x, eta=-1. */
/* Start with d=0, e=1, f=modulus, g=x, zeta=-1. */
secp256k1_modinv32_signed30 d = {{0}};
secp256k1_modinv32_signed30 e = {{1}};
secp256k1_modinv32_signed30 f = modinfo->modulus;
secp256k1_modinv32_signed30 g = *x;
int i;
int32_t eta = -1;
int32_t zeta = -1; /* zeta = -(delta+1/2); delta is initially 1/2. */

/* Do 25 iterations of 30 divsteps each = 750 divsteps. 724 suffices for 256-bit inputs. */
for (i = 0; i < 25; ++i) {
/* Compute transition matrix and new eta after 30 divsteps. */
/* Do 20 iterations of 30 divsteps each = 600 divsteps. 590 suffices for 256-bit inputs. */
for (i = 0; i < 20; ++i) {
/* Compute transition matrix and new zeta after 30 divsteps. */
secp256k1_modinv32_trans2x2 t;
eta = secp256k1_modinv32_divsteps_30(eta, f.v[0], g.v[0], &t);
zeta = secp256k1_modinv32_divsteps_30(zeta, f.v[0], g.v[0], &t);
/* Update d,e using that transition matrix. */
secp256k1_modinv32_update_de_30(&d, &e, &t, modinfo);
/* Update f,g using that transition matrix. */
Expand Down Expand Up @@ -515,7 +515,7 @@ static void secp256k1_modinv32_var(secp256k1_modinv32_signed30 *x, const secp256
int i = 0;
#endif
int j, len = 9;
int32_t eta = -1;
int32_t eta = -1; /* eta = -delta; delta is initially 1 (faster for the variable-time code) */
int32_t cond, fn, gn;

/* Do iterations of 30 divsteps each until g=0. */
Expand Down
62 changes: 33 additions & 29 deletions src/modinv64_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -145,33 +145,35 @@ typedef struct {
int64_t u, v, q, r;
} secp256k1_modinv64_trans2x2;

/* Compute the transition matrix and eta for 62 divsteps.
/* Compute the transition matrix and eta for 59 divsteps (where zeta=-(delta+1/2)).
* Note that the transformation matrix is scaled by 2^62 and not 2^59.
*
* Input: eta: initial eta
* f0: bottom limb of initial f
* g0: bottom limb of initial g
* Input: zeta: initial zeta
* f0: bottom limb of initial f
* g0: bottom limb of initial g
* Output: t: transition matrix
* Return: final eta
* Return: final zeta
*
* Implements the divsteps_n_matrix function from the explanation.
*/
static int64_t secp256k1_modinv64_divsteps_62(int64_t eta, uint64_t f0, uint64_t g0, secp256k1_modinv64_trans2x2 *t) {
static int64_t secp256k1_modinv64_divsteps_59(int64_t zeta, uint64_t f0, uint64_t g0, secp256k1_modinv64_trans2x2 *t) {
/* u,v,q,r are the elements of the transformation matrix being built up,
* starting with the identity matrix. Semantically they are signed integers
* starting with the identity matrix times 8 (because the caller expects
* a result scaled by 2^62). Semantically they are signed integers
* in range [-2^62,2^62], but here represented as unsigned mod 2^64. This
* permits left shifting (which is UB for negative numbers). The range
* being inside [-2^63,2^63) means that casting to signed works correctly.
*/
uint64_t u = 1, v = 0, q = 0, r = 1;
uint64_t u = 8, v = 0, q = 0, r = 8;
uint64_t c1, c2, f = f0, g = g0, x, y, z;
int i;

for (i = 0; i < 62; ++i) {
for (i = 3; i < 62; ++i) {
VERIFY_CHECK((f & 1) == 1); /* f must always be odd */
VERIFY_CHECK((u * f0 + v * g0) == f << i);
VERIFY_CHECK((q * f0 + r * g0) == g << i);
/* Compute conditional masks for (eta < 0) and for (g & 1). */
c1 = eta >> 63;
/* Compute conditional masks for (zeta < 0) and for (g & 1). */
c1 = zeta >> 63;
c2 = -(g & 1);
/* Compute x,y,z, conditionally negated versions of f,u,v. */
x = (f ^ c1) - c1;
Expand All @@ -181,10 +183,10 @@ static int64_t secp256k1_modinv64_divsteps_62(int64_t eta, uint64_t f0, uint64_t
g += x & c2;
q += y & c2;
r += z & c2;
/* In what follows, c1 is a condition mask for (eta < 0) and (g & 1). */
/* In what follows, c1 is a condition mask for (zeta < 0) and (g & 1). */
c1 &= c2;
/* Conditionally negate eta, and unconditionally subtract 1. */
eta = (eta ^ c1) - (c1 + 1);
/* Conditionally change zeta into -zeta-2 or zeta-1. */
zeta = (zeta ^ c1) - 1;
/* Conditionally add g,q,r to f,u,v. */
f += g & c1;
u += q & c1;
Expand All @@ -193,8 +195,8 @@ static int64_t secp256k1_modinv64_divsteps_62(int64_t eta, uint64_t f0, uint64_t
g >>= 1;
u <<= 1;
v <<= 1;
/* Bounds on eta that follow from the bounds on iteration count (max 12*62 divsteps). */
VERIFY_CHECK(eta >= -745 && eta <= 745);
/* Bounds on zeta that follow from the bounds on iteration count (max 10*59 divsteps). */
VERIFY_CHECK(zeta >= -591 && zeta <= 591);
}
/* Return data in t and return value. */
t->u = (int64_t)u;
Expand All @@ -204,12 +206,14 @@ static int64_t secp256k1_modinv64_divsteps_62(int64_t eta, uint64_t f0, uint64_t
/* The determinant of t must be a power of two. This guarantees that multiplication with t
* does not change the gcd of f and g, apart from adding a power-of-2 factor to it (which
* will be divided out again). As each divstep's individual matrix has determinant 2, the
* aggregate of 62 of them will have determinant 2^62. */
VERIFY_CHECK((int128_t)t->u * t->r - (int128_t)t->v * t->q == ((int128_t)1) << 62);
return eta;
* aggregate of 59 of them will have determinant 2^59. Multiplying with the initial
* 8*identity (which has determinant 2^6) means the overall outputs has determinant
* 2^65. */
VERIFY_CHECK((int128_t)t->u * t->r - (int128_t)t->v * t->q == ((int128_t)1) << 65);
return zeta;
}

/* Compute the transition matrix and eta for 62 divsteps (variable time).
/* Compute the transition matrix and eta for 62 divsteps (variable time, eta=-delta).
*
* Input: eta: initial eta
* f0: bottom limb of initial f
Expand Down Expand Up @@ -290,7 +294,7 @@ static int64_t secp256k1_modinv64_divsteps_62_var(int64_t eta, uint64_t f0, uint
return eta;
}

/* Compute (t/2^62) * [d, e] mod modulus, where t is a transition matrix for 62 divsteps.
/* Compute (t/2^62) * [d, e] mod modulus, where t is a transition matrix scaled by 2^62.
*
* On input and output, d and e are in range (-2*modulus,modulus). All output limbs will be in range
* (-2^62,2^62).
Expand Down Expand Up @@ -376,7 +380,7 @@ static void secp256k1_modinv64_update_de_62(secp256k1_modinv64_signed62 *d, secp
#endif
}

/* Compute (t/2^62) * [f, g], where t is a transition matrix for 62 divsteps.
/* Compute (t/2^62) * [f, g], where t is a transition matrix scaled by 2^62.
*
* This implements the update_fg function from the explanation.
*/
Expand Down Expand Up @@ -455,19 +459,19 @@ static void secp256k1_modinv64_update_fg_62_var(int len, secp256k1_modinv64_sign

/* Compute the inverse of x modulo modinfo->modulus, and replace x with it (constant time in x). */
static void secp256k1_modinv64(secp256k1_modinv64_signed62 *x, const secp256k1_modinv64_modinfo *modinfo) {
/* Start with d=0, e=1, f=modulus, g=x, eta=-1. */
/* Start with d=0, e=1, f=modulus, g=x, zeta=-1. */
secp256k1_modinv64_signed62 d = {{0, 0, 0, 0, 0}};
secp256k1_modinv64_signed62 e = {{1, 0, 0, 0, 0}};
secp256k1_modinv64_signed62 f = modinfo->modulus;
secp256k1_modinv64_signed62 g = *x;
int i;
int64_t eta = -1;
int64_t zeta = -1; /* zeta = -(delta+1/2); delta starts at 1/2. */

/* Do 12 iterations of 62 divsteps each = 744 divsteps. 724 suffices for 256-bit inputs. */
for (i = 0; i < 12; ++i) {
/* Compute transition matrix and new eta after 62 divsteps. */
/* Do 10 iterations of 59 divsteps each = 590 divsteps. This suffices for 256-bit inputs. */
for (i = 0; i < 10; ++i) {
/* Compute transition matrix and new zeta after 59 divsteps. */
secp256k1_modinv64_trans2x2 t;
eta = secp256k1_modinv64_divsteps_62(eta, f.v[0], g.v[0], &t);
zeta = secp256k1_modinv64_divsteps_59(zeta, f.v[0], g.v[0], &t);
/* Update d,e using that transition matrix. */
secp256k1_modinv64_update_de_62(&d, &e, &t, modinfo);
/* Update f,g using that transition matrix. */
Expand Down Expand Up @@ -517,7 +521,7 @@ static void secp256k1_modinv64_var(secp256k1_modinv64_signed62 *x, const secp256
int i = 0;
#endif
int j, len = 5;
int64_t eta = -1;
int64_t eta = -1; /* eta = -delta; delta is initially 1 */
int64_t cond, fn, gn;

/* Do iterations of 62 divsteps each until g=0. */
Expand Down
Loading

0 comments on commit efad350

Please sign in to comment.