Skip to content

Commit

Permalink
mlkem: bring ntt names into alignment #147
Browse files Browse the repository at this point in the history
This replaces `'`s with suffixes explictly describing what type of data
each NTT function operates over.
  • Loading branch information
marsella committed Oct 15, 2024
1 parent fae877d commit 63d367e
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -649,27 +649,27 @@ submodule NTT where
// Here, we can choose to call either the naive or fast NTT
//////////////////////////////////////////////////////////////

NTT' : Rq -> Tq
NTT' = fast_ntt
NTT : Rq -> Tq
NTT = fast_ntt

NTTInv' : Tq -> Rq
NTTInv' = fast_invntt
NTTInv : Tq -> Rq
NTTInv = fast_invntt

/**
* The notation `NTT` is overloaded to mean both a single application of `NTT`
* to an element of `R_q` and also `k` applications of `NTT` to every element
* of a `k`-length vector.
* [FIPS-203] Section 2.4.6 Equation 2.9.
*/
NTT v = map NTT' v
NTT_Vec v = map NTT v

/**
* The notation `NTTInv` is overloaded to mean both a single application of
* `NTTInv` to an element of `R_q` and also `k` applications of `NTTInv` to
* every element of a `k`-length vector.
* [FIPS-203] Section 2.4.6.
*/
NTTInv v = map NTTInv' v
NTTInv_Vec v = map NTTInv v

//////////////////////////////////////////////////////////////
// Polynomial multiplication in the NTT domain
Expand Down Expand Up @@ -709,7 +709,7 @@ property TestMult = prod f f == fsq where
fsq = [1,2,1] # [0 | i <- [4 .. 256]]

prod : Rq -> Rq -> Rq
prod a b = NTTInv' (MultiplyNTTs (NTT' a) (NTT' b))
prod a b = NTTInv (MultiplyNTTs (NTT a) (NTT b))

/**
* The cross product notation ×𝑇𝑞 is defined as the `MultiplyNTTs` function
Expand Down Expand Up @@ -762,8 +762,8 @@ K_PKE_KeyGen(d) = (ekPKE, dkPKE) where
A_hat = [[SampleNTT (XOF(ρ # [j] # [i])) | j <- [0 .. k-1]] | i <- [0 .. k-1]] : [k][k]Z_q_256
s = [SamplePolyCBD`{eta_1}(PRF(σ,N)) | N <- [0 .. k-1]] : [k]Z_q_256
e = [SamplePolyCBD`{eta_1}(PRF(σ,N)) | N <- [k .. (2*k-1)]] : [k]Z_q_256
s_hat = NTT(s)
e_hat = NTT(e)
s_hat = NTT_Vec(s)
e_hat = NTT_Vec(e)
t_hat = (dotMatVec A_hat s_hat) + e_hat
ekPKE = Encode`{12}(t_hat) # ρ
dkPKE = Encode`{12}(s_hat)
Expand All @@ -785,10 +785,10 @@ K_PKE_Encrypt(ekPKE, m, r) = c where
yvec = [SamplePolyCBD`{eta_1}(PRF(r,N)) | N <- [0 .. k-1]] : [k]Z_q_256
e1 = [SamplePolyCBD`{eta_2}(PRF(r,N)) | N <- [k .. (2*k-1)]] : [k]Z_q_256
e2 = SamplePolyCBD`{eta_2}(PRF(r,2*`k)) : Z_q_256
yvechat = NTT yvec
u = NTTInv (dotMatVec (transpose A_hat) yvechat) + e1 : [k]Z_q_256
yvechat = NTT_Vec yvec
u = NTTInv_Vec (dotMatVec (transpose A_hat) yvechat) + e1 : [k]Z_q_256
mu = Decompress'`{1}(DecodeBytes'`{1} m)
v = (NTTInv' (dotVecVec t_hat yvechat)) + e2 + mu
v = (NTTInv (dotVecVec t_hat yvechat)) + e2 + mu
c1 = EncodeBytes`{d_u}(Compress`{d_u}(u))
c2 = EncodeBytes'`{d_v}(Compress'`{d_v}(v))
c = c1#c2
Expand All @@ -809,7 +809,7 @@ K_PKE_Decrypt(sk, c) = m where
u = Decompress`{d_u}(DecodeBytes`{d_u} c1) : [k]Z_q_256
v = Decompress'`{d_v}(DecodeBytes'`{d_v} c2) : Z_q_256
s_hat = Decode`{12} sk : [k]Z_q_256
w = v - NTTInv' (dotVecVec s_hat (NTT u))
w = v - NTTInv (dotVecVec s_hat (NTT_Vec u))
m = EncodeBytes'`{1}(Compress'`{1}(w))

/**
Expand Down

0 comments on commit 63d367e

Please sign in to comment.