Skip to content

Commit

Permalink
mlkem: add an explicit, recursive NTT version #147
Browse files Browse the repository at this point in the history
aims to match the spec more closely, as much as that's possible with the
built-in limitations of cryptol.
  • Loading branch information
marsella committed Oct 31, 2024
1 parent 71a5776 commit e4732cd
Showing 1 changed file with 74 additions and 0 deletions.
74 changes: 74 additions & 0 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -509,7 +509,81 @@ BitRev7 i = if i > 255 then error "BitRev7 called with invalid input"
*/
import submodule NTT
submodule NTT where
/**
* ```repl
* :prove NaiveNTTsMatch
* ```
*/
property NaiveNTTsMatch f = NaiveNTT' f == NaiveNTT f

private

NaiveNTT' : Rq -> Tq
NaiveNTT' f = state.f_hat where
// Step 1 - 2. Initialize `f_hat`, `i`.
state0 = { z = 0, i = 1, f_hat = f}
// Step 3. Initialize `len` and evaluate the body of the loop.
state = len_loop`{len = 128} state0

type State = { z : Z q, i : [8] , f_hat : Tq }

// Step 3 - 13.
len_loop : {len} (len <= 128) => State -> State
len_loop state
// Step 3: Stop if we're at the end of the loop.
| len < 2 => state
// Otherwise, we're in a valid loop iteration. Evaluate the
// body of the loop and then update `len` appropriately.
| len >= 2 => len_loop`{len / 2} (start_loop`{len, 0} state)

// Steps 4 - 12.
start_loop : {len, start} (fin len, fin start) => State -> State
start_loop state
// Step 4: Stop if we're at the end of the loop.
| start >= 256 => state
// Otherwise, we're in a valid loop iteration. Evaluate the body
// of the loop and then update `start` appropriately.
| start < 256 => start_loop`{len, start + 2 * len} (
// Step 7-11. Initialize `j <- start`.
j_loop`{len, start, start} {
// Step 5.
z = zeta ^^(BitRev7 state.i),
// Step 6.
i = state.i + 1,
// `f_hat` doesn't get updated outside of the `j_loop`.
f_hat = state.f_hat
})

// Steps 7 - 11.
j_loop : {len, start, j}
(start <= j, j <= start + len)
=> State -> State
j_loop state
// Step 7: Stop if we're at the end of the loop
| (j == start + len) => state
// This case is impossible to reach; `j + len` will always be a valid
// index into `f_hat`. It's not possible to infer that from the type
// constraints we have now, so state it explicitly.
| (j + len >= 256) => state
// Otherwise, we're in a valid loop iteration.
| (j + len < 256, j < start + len) => state' where
// Step 8.
t = state.z * state.f_hat @(`j + `len)
// Step 9.
f_hat' = set_f`{j + len} state.f_hat (state.f_hat @`j - t)
// Step 10.
f_hat'' = set_f`{j} f_hat' (f_hat' @`j + t)
// Save the updated version of `f_hat` and keep looping.
state' = j_loop`{len, start, j+1} {
z = state.z,
i = state.i,
f_hat = f_hat''
}

// Helper function to set the `idx`th value of the polynomial.
set_f : {idx} (idx <= 255) => Tq -> Z q -> Tq
set_f poly val = take`{idx} poly # [val] # drop`{idx + 1} poly

/**
* Number theoretic transform: compute the "NTT representation" in
* `T_q` of a polynomial in `R_q`.
Expand Down

0 comments on commit e4732cd

Please sign in to comment.