Skip to content

Commit

Permalink
mlkem: add Rq / Tq types and use them #147
Browse files Browse the repository at this point in the history
This doesn't replace all uses of `Z_q_256`, but it gets all the easy
ones.
  • Loading branch information
marsella committed Oct 31, 2024
1 parent c480bcd commit 2474ff5
Showing 1 changed file with 52 additions and 31 deletions.
83 changes: 52 additions & 31 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
*
* Sources:
* [FIPS-203]: National Institute of Standards and Technology. Module-Lattice-
* Basead Key-Encapsulation Mechanism Standard. (Department of Commerce,
* Based Key-Encapsulation Mechanism Standard. (Department of Commerce,
* Washington, D.C.), Federal Information Processing Standards Publication
* (FIPS) NIST FIPS 203. August 2024.
* @see https://doi.org/10.6028/NIST.FIPS.203
Expand Down Expand Up @@ -63,6 +63,33 @@ type Byte = [8]
*/
type Z_q_256 = [n](Z q)

/**
* An element in the ring `R_q`.
*
* An element in the ring is a polynomial of degree at most 255 (e.g. with 256
* terms). The `i`th element in this array represents the coefficient of the
* degree-`i` term.
*
* [FIPS-203] Section 2.3 (definition of the ring).
* [FIPS-203] Section 2.4.4, Equation 2.5 (definition of the representation of
* elements in the ring).
*/
type Rq = [n](Z q)

/**
* An element in the ring `T_q`.
*
* An element in this ring (sometimes called the "NTT representation") is a
* tuple of 128 polynomials, each of degree at most one (e.g. with two terms).
* The `2i` and `2i+1`th terms in this array represent the degree-0 and
* degree-1 coefficients of the `i`th polynomial, respectively.
*
* [FIPS-203] Section 2.3 (definition of the `T_q`).
* [FIPS-203] Section 2.4.4 Equation 2.7 (definition of the representation of
* an element in `T_q`).
*/
type Tq = [n](Z q)

/**
* Pseudorandom function (PRF).
* [FIPS-203] Section 4.1, Equations 4.2 and 4.3.
Expand Down Expand Up @@ -336,7 +363,7 @@ property CorrectnessEncodeDecode' f = Decode'`{12}(Encode'`{12} f) == f
*
* [FIPS-203] Section 4.2.2, Algorithm 7.
*/
SampleNTT : [34]Byte -> Z_q_256
SampleNTT : [34]Byte -> Tq
SampleNTT B = a_hat' where
// Steps 1-2, 5.
// We (lazily) take an infinite stream from the XOF and remove only as
Expand Down Expand Up @@ -401,7 +428,7 @@ SampleNTT B = a_hat' where
*
* [FIPS-203] Section 4.2.2, Algorithm 8.
*/
SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Z_q_256
SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Rq
SamplePolyCBD B = f where
// Step 1.
b = BytesToBits B
Expand Down Expand Up @@ -469,7 +496,7 @@ submodule NTT where
*
* This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9.
*/
ParametricNTT : Z_q_256 -> (Z q) -> Z_q_256
ParametricNTT : Rq -> (Z q) -> Tq
ParametricNTT f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]]
where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]]
f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 i >> 1)+1)*j) | j <- [0 .. 127]]
Expand All @@ -481,7 +508,7 @@ submodule NTT where
*
* This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10.
*/
ParametricNTTInv : Z_q_256 -> (Z q) -> Z_q_256
ParametricNTTInv : Tq -> (Z q) -> Rq
ParametricNTTInv f root = join[[f2i i, f2iPlus1 i] | i <- [0 .. 127]]
where f2i i = sum [f@(2*j) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]]
f2iPlus1 i = sum [f@(2*j+1) * root ^^ ((2*(BitRev7 j >> 1)+1)*i) | j <- [0 .. 127]]
Expand All @@ -491,7 +518,7 @@ submodule NTT where
*
* This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 9.
*/
NaiveNTT : Z_q_256 -> Z_q_256
NaiveNTT : Rq -> Tq
NaiveNTT f = ParametricNTT f zeta

/**
Expand All @@ -500,7 +527,7 @@ submodule NTT where
*
* This roughly corresponds to [FIPS-203] Section 4.3, Algorithm 10.
*/
NaiveNTTInv : Z_q_256 -> Z_q_256
NaiveNTTInv : Tq -> Rq
NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)]

//////////////////////////////////////////////////////////////
Expand Down Expand Up @@ -565,7 +592,7 @@ submodule NTT where
[s0, s1] = split t

// Top level entry point - start with lv=256, k=1
fast_ntt : Z_q_256 -> Z_q_256
fast_ntt : Rq -> Tq
fast_ntt v = fast_nttl v 1

// Fast recursive GS-Inverse-NTT
Expand Down Expand Up @@ -605,7 +632,7 @@ submodule NTT where
mul_recip128 v = [ v@x * recip_128_modq | x <- [0 .. <n] ]

// Top level entry point - start with lv=256, k=1
fast_invntt : Z_q_256 -> Z_q_256
fast_invntt : Tq -> Rq
fast_invntt v = mul_recip128 (fast_invnttl v 1)

//////////////////////////////////////////////////////////////
Expand All @@ -618,7 +645,7 @@ submodule NTT where
* :prove NaiveNTT_Inverts
* ```
*/
NaiveNTT_Inverts : Z_q_256 -> Bit
NaiveNTT_Inverts : Rq -> Bit
property NaiveNTT_Inverts f = NaiveNTTInv (NaiveNTT f) == f

/**
Expand All @@ -627,7 +654,7 @@ submodule NTT where
* :prove NaiveNTTInv_Inverts
* ```
*/
NaiveNTTInv_Inverts : Z_q_256 -> Bit
NaiveNTTInv_Inverts : Tq -> Bit
property NaiveNTTInv_Inverts f = NaiveNTT (NaiveNTTInv f) == f

/**
Expand All @@ -636,7 +663,7 @@ submodule NTT where
* :prove fast_ntt_inverts
* ```
*/
fast_ntt_inverts : Z_q_256 -> Bit
fast_ntt_inverts : Rq -> Bit
property fast_ntt_inverts f = fast_invntt (fast_ntt f) == f

/**
Expand All @@ -645,7 +672,7 @@ submodule NTT where
* :prove fast_invntt_inverts
* ```
*/
fast_invntt_inverts : Z_q_256 -> Bit
fast_invntt_inverts : Tq -> Bit
property fast_invntt_inverts f = fast_ntt (fast_invntt f) == f

/**
Expand All @@ -654,7 +681,7 @@ submodule NTT where
* :prove naive_fast_ntt_equiv
* ```
*/
naive_fast_ntt_equiv : Z_q_256 -> Bit
naive_fast_ntt_equiv : Rq -> Bit
property naive_fast_ntt_equiv f = NaiveNTT f == fast_ntt f

/**
Expand All @@ -663,7 +690,7 @@ submodule NTT where
* :prove naive_fast_invntt_equiv
* ```
*/
naive_fast_invntt_equiv : Z_q_256 -> Bit
naive_fast_invntt_equiv : Tq -> Bit
property naive_fast_invntt_equiv f = NaiveNTTInv f == fast_invntt f

//////////////////////////////////////////////////////////////
Expand All @@ -672,17 +699,11 @@ submodule NTT where
// Here, we can choose to call either the naive or fast NTT
//////////////////////////////////////////////////////////////

NTT' : Z_q_256 -> Z_q_256
// fast
NTT' f = fast_ntt f
// slow
//NTT' f = NaiveNTT f
NTT' : Rq -> Tq
NTT' = fast_ntt

NTTInv' : Z_q_256 -> Z_q_256
// fast
NTTInv' f = fast_invntt f
// slow
//NTTInv' f = NaiveNTTInv f
NTTInv' : Tq -> Rq
NTTInv' = fast_invntt

/**
* The notation `NTT` is overloaded to mean both a single application of `NTT`
Expand Down Expand Up @@ -719,7 +740,7 @@ BaseCaseMultiply a b root = [c0, c1]
* Compute the product (in the ring `T_q`) of two NTT representations.
* [FIPS-203] Section 4.3.1 Algorithm 11.
*/
MultiplyNTTs : Z_q_256 -> Z_q_256 -> Z_q_256
MultiplyNTTs : Tq -> Tq -> Tq
MultiplyNTTs a b = join [BaseCaseMultiply (f_hat_i i) (g_hat_i i) (root i) | i : Byte <- [0 .. 127]]
where
f_hat_i i = [a@(2*i),a@(2*i+1)]
Expand All @@ -737,31 +758,31 @@ property TestMult = prod f f == fsq where
f = [1, 1] # [0 | i <- [3 .. 256]]
fsq = [1,2,1] # [0 | i <- [4 .. 256]]

prod : Z_q_256 -> Z_q_256 -> Z_q_256
prod : Rq -> Rq -> Rq
prod a b = NTTInv' (MultiplyNTTs (NTT' a) (NTT' b))

/**
* The cross product notation ×𝑇𝑞 is defined as the `MultiplyNTTs` function
* (also referred to as `T_q` multiplication).
* [FIPS-203] Section 2.4.5 Equation 2.8.
*/
dot : Z_q_256 -> Z_q_256 -> Z_q_256
dot : Tq -> Tq -> Tq
dot f g = MultiplyNTTs f g

/**
* Overloaded `dot` function between two vectors is a standard dot-product
* functionality with `T_q` multiplication as the base operation.
* [FIPS-203] Section 2.4.7 Equation 2.14.
*/
dotVecVec : {k1} (fin k1) => [k1]Z_q_256 -> [k1]Z_q_256 -> Z_q_256
dotVecVec : {k1} (fin k1) => [k1]Tq -> [k1]Tq -> Tq
dotVecVec v1 v2 = sum (zipWith dot v1 v2)

/**
* Overloaded `dot` function between a matrix and a vector is standard matrix-
* vector multiplication with `T_q` multiplication as the base operation.
* [FIPS-203] Section 2.4.7 Equation 2.12 and 2.13.
*/
dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Z_q_256 -> [k2]Z_q_256 -> [k1]Z_q_256
dotMatVec : {k1,k2} (fin k1, fin k2) => [k1][k2]Tq -> [k2]Tq -> [k1]Tq
dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix]

/**
Expand All @@ -770,7 +791,7 @@ dotMatVec matrix vector = [dotVecVec v1 vector | v1 <- matrix]
* [FIPS-203] Section 2.4.7.
*/
dotMatMat :{k1,k2,k3} (fin k1, fin k2, fin k3) =>
[k1][k2]Z_q_256 -> [k2][k3]Z_q_256 -> [k1][k3]Z_q_256
[k1][k2]Tq -> [k2][k3]Tq -> [k1][k3]Tq
dotMatMat matrix1 matrix2 = transpose [dotMatVec matrix1 vector | vector <- m']
where m' = transpose matrix2

Expand Down

0 comments on commit 2474ff5

Please sign in to comment.