Skip to content

Commit

Permalink
mlkem: tweak recursive NTT implementation #147
Browse files Browse the repository at this point in the history
  • Loading branch information
marsella committed Oct 31, 2024
1 parent e4732cd commit 4b7b2d8
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -517,7 +517,13 @@ submodule NTT where
property NaiveNTTsMatch f = NaiveNTT' f == NaiveNTT f

private

/**
* Naive version of NTT, implemented using recursing instead of loops.
* [FIPS-203] Algorithm 9.
*
* Note that this implementation is spread out across multiple functions
* to support the use of numeric constraint guards.
*/
NaiveNTT' : Rq -> Tq
NaiveNTT' f = state.f_hat where
// Step 1 - 2. Initialize `f_hat`, `i`.
Expand All @@ -532,53 +538,57 @@ submodule NTT where
len_loop state
// Step 3: Stop if we're at the end of the loop.
| len < 2 => state
// Otherwise, we're in a valid loop iteration. Evaluate the
// body of the loop and then update `len` appropriately.
| len >= 2 => len_loop`{len / 2} (start_loop`{len, 0} state)
// Otherwise, we're in a valid loop iteration.
| len >= 2 => state'' where
// Evaluate the body of the loop...
state' = start_loop`{len, 0} state
// ...then start the next iteration.
state'' = len_loop`{len / 2} state'

// Steps 4 - 12.
start_loop : {len, start} (fin len, fin start) => State -> State
start_loop state
// Step 4: Stop if we're at the end of the loop.
| start >= 256 => state
// Otherwise, we're in a valid loop iteration. Evaluate the body
// of the loop and then update `start` appropriately.
| start < 256 => start_loop`{len, start + 2 * len} (
// Step 7-11. Initialize `j <- start`.
j_loop`{len, start, start} {
// Step 5.
z = zeta ^^(BitRev7 state.i),
// Step 6.
i = state.i + 1,
// `f_hat` doesn't get updated outside of the `j_loop`.
f_hat = state.f_hat
})
// Otherwise, we're in a valid loop iteration.
| start < 256 => state''' where
// Step 5.
z = zeta ^^(BitRev7 state.i)
// Step 6.
i = state.i + 1
// Save the changes from 5-6.
state' = { z = z, i = i, f_hat = state.f_hat }
// Step 7-11. Evaluate the `j`-loop.
state'' = j_loop`{len, start, start} state'
// Start the next iteration of the `start` loop.
state''' = start_loop`{len, start + 2 * len} state''

// Steps 7 - 11.
j_loop : {len, start, j}
(start <= j, j <= start + len)
j_loop : {len, start, j} (start <= j, j <= start + len)
=> State -> State
j_loop state
// Step 7: Stop if we're at the end of the loop
| (j == start + len) => state
// This case is impossible to reach; `j + len` will always be a valid
// index into `f_hat`. It's not possible to infer that from the type
// constraints we have now, so state it explicitly.
// constraints we have now, so it's stated explicitly.
| (j + len >= 256) => state
// Otherwise, we're in a valid loop iteration.
| (j + len < 256, j < start + len) => state' where
| (j + len < 256, j < start + len) => state'' where
// Step 8.
t = state.z * state.f_hat @(`j + `len)
// Step 9.
f_hat' = set_f`{j + len} state.f_hat (state.f_hat @`j - t)
// Step 10.
f_hat'' = set_f`{j} f_hat' (f_hat' @`j + t)
// Save the updated version of `f_hat` and keep looping.
state' = j_loop`{len, start, j+1} {
// Save the changes made in Steps 8-10.
state' = {
z = state.z,
i = state.i,
f_hat = f_hat''
}
// Start the next iteration of the loop.
state'' = j_loop`{len, start, j+1} state'

// Helper function to set the `idx`th value of the polynomial.
set_f : {idx} (idx <= 255) => Tq -> Z q -> Tq
Expand Down

0 comments on commit 4b7b2d8

Please sign in to comment.