Skip to content

Commit

Permalink
mlkem: clean up NTT multiplication functions #147
Browse files Browse the repository at this point in the history
- Adds docs to BitRev and contains its behavior a bit better
- Adjust spacing, naming, etc in MultiplyNTTs and BaseCaseMultiply
  • Loading branch information
marsella committed Oct 31, 2024
1 parent 5a6a4a5 commit 7294204
Showing 1 changed file with 30 additions and 14 deletions.
44 changes: 30 additions & 14 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -477,9 +477,22 @@ property Is256thRootOfq p = (p == 0) || (p >= 256) || (zeta^^p != 1)
* Reverse the unsigned 7-bit value corresponding to an input integer in
* `[0, ..., 127]`.
* [FIPS-203] Section 4.3 "The mathematical structure of the NTT."
*
* This diverges from the spec by operating over an 8-bit vector;
* this is to ease prior and subsequent computations that would overflow a
* 7-bit vector, like:
* - `2 * (BitRev7 i) + 1`
* - `2 * i + 1`
*
* A "pure" implementation of `BitRev7` in Cryptol is the `reverse` function
* on 7-bit vectors. This mini-property shows equivalence:
* ```repl
* :prove \(x:[7]) -> ([0] # reverse x) == BitRev7 ([0] # x)
* ```
*/
BitRev7 : [8] -> [8]
BitRev7 = reverse
BitRev7 i = if i > 255 then error "BitRev7 called with invalid input"
else (reverse i) >> 1


/**
Expand All @@ -504,8 +517,8 @@ submodule NTT where
*/
ParametricNTT : Rq -> (Z q) -> Tq
ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]]
where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]]
f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]]
where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 127]]
f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 127]]

/**
* Compute most of the polynomial that corresponds to the NTT representation
Expand All @@ -516,8 +529,8 @@ submodule NTT where
*/
ParametricNTTInv : Tq -> (Z q) -> Rq
ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]]
where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]]
f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]]
where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]]
f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]]

/**
* Number theoretic transform: converts elements in `R_q` to `T_q`.
Expand Down Expand Up @@ -749,22 +762,25 @@ submodule NTT where
* quadratic modulus.
* [FIPS-203] Section 4.3.1 Algorithm 12.
*/
BaseCaseMultiply : [2] (Z q) -> [2] (Z q) -> (Z q) -> [2] (Z q)
BaseCaseMultiply a b root = [c0, c1]
BaseCaseMultiply : (Z q) -> (Z q) -> (Z q) -> (Z q) -> (Z q) -> [2](Z q)
BaseCaseMultiply a0 a1 b0 b1 γ = [c0, c1]
where
c0 = a@1 * b@1 * root + a@0 * b@0
c1 = a@0 * b@1 + a@1 * b@0
c0 = a0 * b0 + a1 * b1 * γ
c1 = a0 * b1 + a1 * b0

/**
* Compute the product (in the ring `T_q`) of two NTT representations.
* [FIPS-203] Section 4.3.1 Algorithm 11.
*/
MultiplyNTTs : Tq -> Tq -> Tq
MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : Byte <- [0 .. 127]]
where
f_hat_i i = [a@(2*i),a@(2*i+1)]
g_hat_i i = [b@(2*i),b@(2*i+1)]
root i = (zeta^^(reverse (64 + (i >> 1)) >> 1) * ((-1 : (Z q)) ^^ (i)))
MultiplyNTTs f_hat g_hat = join h_hat where
h_hat = [ BaseCaseMultiply
(f_hat @(2*i))
(f_hat @(2*i+1))
(g_hat @(2*i))
(g_hat @(2*i+1))
(zeta ^^(2 * BitRev7 i + 1))
| i <- [0 .. 127] ]

/**
* Testing that (1+x)^2 = 1+2x+x^2.
Expand Down

0 comments on commit 7294204

Please sign in to comment.