Skip to content

Commit

Permalink
mlkem: apply 12-bit encode/decode #144
Browse files Browse the repository at this point in the history
Also adds vector versions of encode/decode-12.
  • Loading branch information
marsella committed Oct 21, 2024
1 parent f4ead7c commit 6876435
Showing 1 changed file with 16 additions and 4 deletions.
20 changes: 16 additions & 4 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,12 @@ ByteEncode12 F = B where
b = join [ [ a ! 0 | a <- ajs] | ajs <- as]
B = BitsToBytes b

/**
* Encode a `k`-element array of 12-bit vectors into a byte array.
*/
ByteEncode12_Vec : [k][256](Z q) -> [k * 32 * 12]Byte
ByteEncode12_Vec F = join (map ByteEncode12 F)

/**
* The subtract-and-divide algorithm applied to `a` in `ByteEncode` is the
* same as shifting right.
Expand Down Expand Up @@ -337,6 +343,12 @@ ByteDecode12 B = F' where
toZq f = fromInteger (toInteger f)
F' = map toZq F

/**
* Decode `k` joined byte arrays into `k` arrays of integers mod `q`.
*/
ByteDecode12_Vec : [k * 32 * 12]Byte -> [k][256](Z q)
ByteDecode12_Vec B = map ByteDecode12 (split B)

/**
* Multiplying a value by `2^^j` is the same as bit-shifting it left by `j`
* bits.
Expand Down Expand Up @@ -1006,9 +1018,9 @@ private submodule K_PKE where
// Step 18.
t_hat = (dotMatVec A_hat s_hat) + e_hat
// Step 19.
ekPKE = (Encode`{12} t_hat) # ρ
ekPKE = (ByteEncode12_Vec t_hat) # ρ
// Step 20.
dkPKE = Encode`{12} (s_hat)
dkPKE = ByteEncode12_Vec s_hat

/**
* Encryption algorithm for the K-PKE component scheme.
Expand All @@ -1022,7 +1034,7 @@ private submodule K_PKE where
Encrypt : EncryptionKey -> [32]Byte -> [32]Byte -> Ciphertext
Encrypt ekPKE m r = c where
// Step 2.
t_hat = Decode`{12} (ekPKE @@[0 .. 384*k - 1])
t_hat = ByteDecode12_Vec (ekPKE @@[0 .. 384*k - 1])
// Step 3.
rho = ekPKE @@[384*k .. 384*k + 32 - 1]
// Steps 4-8.
Expand Down Expand Up @@ -1074,7 +1086,7 @@ private submodule K_PKE where
// Step 4.
v' = Decompress'`{d_v} (ByteDecode`{d_v} c2)
// Step 5.
s_hat = Decode`{12} dkPKE
s_hat = ByteDecode12_Vec dkPKE
// Step 6.
w = v' - NTTInv' (dotVecVec s_hat (NTT u'))
// Step 7.
Expand Down

0 comments on commit 6876435

Please sign in to comment.