diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index 3ba8dc7..cf8c7ee 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -477,9 +477,22 @@ property Is256thRootOfq p = (p == 0) || (p >= 256) || (zeta^^p != 1) * Reverse the unsigned 7-bit value corresponding to an input integer in * `[0, ..., 127]`. * [FIPS-203] Section 4.3 "The mathematical structure of the NTT." + * + * This diverges from the spec by operating over an 8-bit vector; + * this is to ease prior and subsequent computations that would overflow a + * 7-bit vector, like: + * - `2 * (BitRev7 i) + 1` + * - `2 * i + 1` + * + * A "pure" implementation of `BitRev7` in Cryptol is the `reverse` function + * on 7-bit vectors. This mini-property shows equivalence: + * ```repl + * :prove \(x:[7]) -> ([0] # reverse x) == BitRev7 ([0] # x) + * ``` */ BitRev7 : [8] -> [8] -BitRev7 = reverse +BitRev7 i = if i > 255 then error "BitRev7 called with invalid input" + else (reverse i) >> 1 /** @@ -504,8 +517,8 @@ submodule NTT where */ ParametricNTT : Rq -> (Z q) -> Tq ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] + where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 127]] + f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i)+1)*j) | j <- [0 .. 127]] /** * Compute most of the polynomial that corresponds to the NTT representation @@ -516,8 +529,8 @@ submodule NTT where */ ParametricNTTInv : Tq -> (Z q) -> Rq ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] - where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] - f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] + where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]] + f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j)+1)*i) | j <- [0 .. 127]] /** * Number theoretic transform: converts elements in `R_q` to `T_q`. @@ -749,22 +762,25 @@ submodule NTT where * quadratic modulus. * [FIPS-203] Section 4.3.1 Algorithm 12. */ -BaseCaseMultiply : [2] (Z q) -> [2] (Z q) -> (Z q) -> [2] (Z q) -BaseCaseMultiply a b root = [c0, c1] +BaseCaseMultiply : (Z q) -> (Z q) -> (Z q) -> (Z q) -> (Z q) -> [2](Z q) +BaseCaseMultiply a0 a1 b0 b1 γ = [c0, c1] where - c0 = a@1 * b@1 * root + a@0 * b@0 - c1 = a@0 * b@1 + a@1 * b@0 + c0 = a0 * b0 + a1 * b1 * γ + c1 = a0 * b1 + a1 * b0 /** * Compute the product (in the ring `T_q`) of two NTT representations. * [FIPS-203] Section 4.3.1 Algorithm 11. */ MultiplyNTTs : Tq -> Tq -> Tq -MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : Byte <- [0 .. 127]] - where - f_hat_i i = [a@(2*i),a@(2*i+1)] - g_hat_i i = [b@(2*i),b@(2*i+1)] - root i = (zeta^^(reverse (64 + (i >> 1)) >> 1) * ((-1 : (Z q)) ^^ (i))) +MultiplyNTTs f_hat g_hat = join h_hat where + h_hat = [ BaseCaseMultiply + (f_hat @(2*i)) + (f_hat @(2*i+1)) + (g_hat @(2*i)) + (g_hat @(2*i+1)) + (zeta ^^(2 * BitRev7 i + 1)) + | i <- [0 .. 127] ] /** * Testing that (1+x)^2 = 1+2x+x^2.