From 2474ff578b0f7c1e2134c4cb613a62b4e837444e Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Tue, 15 Oct 2024 14:33:00 -0400 Subject: [PATCH] mlkem: add Rq / Tq types and use them #147 This doesn't replace all uses of `Z_q_256`, but it gets all the easy ones. --- .../Cipher/ML_KEM/Specification.cry | 83 ++++++++++++------- 1 file changed, 52 insertions(+), 31 deletions(-) diff --git a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry index 5a077c2..a67a5c2 100644 --- a/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry +++ b/Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry @@ -10,7 +10,7 @@ * * Sources: * [FIPS-203]: National Institute of Standards and Technology. Module-Lattice- - * Basead Key-Encapsulation Mechanism Standard. (Department of Commerce, + * Based Key-Encapsulation Mechanism Standard. (Department of Commerce, * Washington, D.C.), Federal Information Processing Standards Publication * (FIPS) NIST FIPS 203. August 2024. * @see https://doi.org/10.6028/NIST.FIPS.203 @@ -63,6 +63,33 @@ type Byte = [8] */ type Z_q_256 = [n](Z q) +/** + * An element in the ring `R_q`. + * + * An element in the ring is a polynomial of degree at most 255 (e.g. with 256 + * terms). The `i`th element in this array represents the coefficient of the + * degree-`i` term. + * + * [FIPS-203] Section 2.3 (definition of the ring). + * [FIPS-203] Section 2.4.4, Equation 2.5 (definition of the representation of + * elements in the ring). + */ +type Rq = [n](Z q) + +/** + * An element in the ring `T_q`. + * + * An element in this ring (sometimes called the "NTT representation") is a + * tuple of 128 polynomials, each of degree at most one (e.g. with two terms). + * The `2i` and `2i+1`th terms in this array represent the degree-0 and + * degree-1 coefficients of the `i`th polynomial, respectively. + * + * [FIPS-203] Section 2.3 (definition of the `T_q`). + * [FIPS-203] Section 2.4.4 Equation 2.7 (definition of the representation of + * an element in `T_q`). + */ +type Tq = [n](Z q) + /** * Pseudorandom function (PRF). * [FIPS-203] Section 4.1, Equations 4.2 and 4.3. @@ -336,7 +363,7 @@ property CorrectnessEncodeDecode' f = Decode'`{12}(Encode'`{12} f) == f * * [FIPS-203] Section 4.2.2, Algorithm 7. */ -SampleNTT : [34]Byte -> Z_q_256 +SampleNTT : [34]Byte -> Tq SampleNTT B = a_hat' where // Steps 1-2, 5. // We (lazily) take an infinite stream from the XOF and remove only as @@ -401,7 +428,7 @@ SampleNTT B = a_hat' where * * [FIPS-203] Section 4.2.2, Algorithm 8. */ -SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Z_q_256 +SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Rq SamplePolyCBD B = f where // Step 1. b = BytesToBits B @@ -469,7 +496,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. */ - ParametricNTT : Z_q_256 -> (Z q) -> Z_q_256 + ParametricNTT : Rq -> (Z q) -> Tq ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]] @@ -481,7 +508,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. */ - ParametricNTTInv : Z_q_256 -> (Z q) -> Z_q_256 + ParametricNTTInv : Tq -> (Z q) -> Rq ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]] where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]] @@ -491,7 +518,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9. */ - NaiveNTT : Z_q_256 -> Z_q_256 + NaiveNTT : Rq -> Tq NaiveNTT f = ParametricNTT f zeta /** @@ -500,7 +527,7 @@ submodule NTT where * * This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10. */ - NaiveNTTInv : Z_q_256 -> Z_q_256 + NaiveNTTInv : Tq -> Rq NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)] ////////////////////////////////////////////////////////////// @@ -565,7 +592,7 @@ submodule NTT where [s0, s1] = split t // Top level entry point - start with lv=256, k=1 - fast_ntt : Z_q_256 -> Z_q_256 + fast_ntt : Rq -> Tq fast_ntt v = fast_nttl v 1 // Fast recursive GS-Inverse-NTT @@ -605,7 +632,7 @@ submodule NTT where mul_recip128 v = [ v@x * recip_128_modq | x <- [0 .. Z_q_256 + fast_invntt : Tq -> Rq fast_invntt v = mul_recip128 (fast_invnttl v 1) ////////////////////////////////////////////////////////////// @@ -618,7 +645,7 @@ submodule NTT where * :prove NaiveNTT_Inverts * ``` */ - NaiveNTT_Inverts : Z_q_256 -> Bit + NaiveNTT_Inverts : Rq -> Bit property NaiveNTT_Inverts f = NaiveNTTInv (NaiveNTT f) == f /** @@ -627,7 +654,7 @@ submodule NTT where * :prove NaiveNTTInv_Inverts * ``` */ - NaiveNTTInv_Inverts : Z_q_256 -> Bit + NaiveNTTInv_Inverts : Tq -> Bit property NaiveNTTInv_Inverts f = NaiveNTT (NaiveNTTInv f) == f /** @@ -636,7 +663,7 @@ submodule NTT where * :prove fast_ntt_inverts * ``` */ - fast_ntt_inverts : Z_q_256 -> Bit + fast_ntt_inverts : Rq -> Bit property fast_ntt_inverts f = fast_invntt (fast_ntt f) == f /** @@ -645,7 +672,7 @@ submodule NTT where * :prove fast_invntt_inverts * ``` */ - fast_invntt_inverts : Z_q_256 -> Bit + fast_invntt_inverts : Tq -> Bit property fast_invntt_inverts f = fast_ntt (fast_invntt f) == f /** @@ -654,7 +681,7 @@ submodule NTT where * :prove naive_fast_ntt_equiv * ``` */ - naive_fast_ntt_equiv : Z_q_256 -> Bit + naive_fast_ntt_equiv : Rq -> Bit property naive_fast_ntt_equiv f = NaiveNTT f == fast_ntt f /** @@ -663,7 +690,7 @@ submodule NTT where * :prove naive_fast_invntt_equiv * ``` */ - naive_fast_invntt_equiv : Z_q_256 -> Bit + naive_fast_invntt_equiv : Tq -> Bit property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f ////////////////////////////////////////////////////////////// @@ -672,17 +699,11 @@ submodule NTT where // Here, we can choose to call either the naive or fast NTT ////////////////////////////////////////////////////////////// - NTT' : Z_q_256 -> Z_q_256 - // fast - NTT' f = fast_ntt f - // slow - //NTT' f = NaiveNTT f + NTT' : Rq -> Tq + NTT' = fast_ntt - NTTInv' : Z_q_256 -> Z_q_256 - // fast - NTTInv' f = fast_invntt f - // slow - //NTTInv' f = NaiveNTTInv f + NTTInv' : Tq -> Rq + NTTInv' = fast_invntt /** * The notation `NTT` is overloaded to mean both a single application of `NTT` @@ -719,7 +740,7 @@ BaseCaseMultiply a b root = [c0, c1] * Compute the product (in the ring `T_q`) of two NTT representations. * [FIPS-203] Section 4.3.1 Algorithm 11. */ -MultiplyNTTs : Z_q_256 -> Z_q_256 -> Z_q_256 +MultiplyNTTs : Tq -> Tq -> Tq MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : Byte <- [0 .. 127]] where f_hat_i i = [a@(2*i),a@(2*i+1)] @@ -737,7 +758,7 @@ property TestMult = prod f f == fsq where f = [1, 1] # [0 | i <- [3 .. 256]] fsq = [1,2,1] # [0 | i <- [4 .. 256]] - prod : Z_q_256 -> Z_q_256 -> Z_q_256 + prod : Rq -> Rq -> Rq prod a b = NTTInv' (MultiplyNTTs (NTT' a) (NTT' b)) /** @@ -745,7 +766,7 @@ property TestMult = prod f f == fsq where * (also referred to as `T_q` multiplication). * [FIPS-203] Section 2.4.5 Equation 2.8. */ -dot : Z_q_256 -> Z_q_256 -> Z_q_256 +dot : Tq -> Tq -> Tq dot f g = MultiplyNTTs f g /** @@ -753,7 +774,7 @@ dot f g = MultiplyNTTs f g * functionality with `T_q` multiplication as the base operation. * [FIPS-203] Section 2.4.7 Equation 2.14. */ -dotVecVec : {k1} (fin k1) => [k1]Z_q_256 -> [k1]Z_q_256 -> Z_q_256 +dotVecVec : {k1} (fin k1) => [k1]Tq -> [k1]Tq -> Tq dotVecVec v1 v2 = sum (zipWith dot v1 v2) /** @@ -761,7 +782,7 @@ dotVecVec v1 v2 = sum (zipWith dot v1 v2) * vector multiplication with `T_q` multiplication as the base operation. * [FIPS-203] Section 2.4.7 Equation 2.12 and 2.13. */ -dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Z_q_256 -> [k2]Z_q_256 -> [k1]Z_q_256 +dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Tq -> [k2]Tq -> [k1]Tq dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix] /** @@ -770,7 +791,7 @@ dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix] * [FIPS-203] Section 2.4.7. */ dotMatMat :{k1,k2,k3} (fin k1, fin k2, fin k3) => - [k1][k2]Z_q_256 -> [k2][k3]Z_q_256 -> [k1][k3]Z_q_256 + [k1][k2]Tq -> [k2][k3]Tq -> [k1][k3]Tq dotMatMat matrix1 matrix2 = transpose [dotMatVec matrix1 vector | vector <- m'] where m' = transpose matrix2