Skip to content

Commit

Permalink
Merge pull request #155 from GaloisInc/146-sampling
Browse files Browse the repository at this point in the history
MLKEM: Bring sampling + crypto functions up to gold standard
  • Loading branch information
marsella authored Oct 18, 2024
2 parents 74cd838 + 3231579 commit 54a6b6f
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 54 deletions.
155 changes: 102 additions & 53 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,10 @@
*/
module Primitive::Asymmetric::Cipher::ML_KEM::Specification where

import Primitive::Keyless::Hash::SHAKE::SHAKE256
import Primitive::Keyless::Hash::SHAKE::SHAKE128
import `Primitive::Keyless::Hash::SHA3::SHA3
import Primitive::Keyless::Hash::KeccakBitOrdering
import Primitive::Keyless::Hash::SHAKE::SHAKE256 as SHAKE256
import Primitive::Keyless::Hash::SHAKE::SHAKE128 as SHAKE128
import Primitive::Keyless::Hash::SHA3::SHA3_256 as SHA3_256
import Primitive::Keyless::Hash::SHA3::SHA3_512 as SHA3_512

/*
* [FIPS-203] Section 2.3.
Expand Down Expand Up @@ -66,38 +66,53 @@ type Z_q_256 = [n](Z q)
/**
* Pseudorandom function (PRF).
* [FIPS-203] Section 4.1, Equations 4.2 and 4.3.
*
* The SHA3 API operates over bit streams; the `groupBy` and `join` calls
* convert to and from our byte arrays.
*/
PRF : {prfeta} (fin prfeta, prfeta > 0) => [32]Byte -> Byte -> [64 * prfeta]Byte
PRF s b = map reverse (take (groupBy`{8} (shake256 (fromBytes(s)# reverse b))))
PRF : {eta} (2 <= eta, eta <= 3) => [32]Byte -> Byte -> [64 * eta]Byte
PRF s b = groupBy`{8} (SHAKE256::xof (join s # b))

/**
* One of the hash functions used in the protocol.
* [FIPS-203] Section 4.1, Equation 4.4.
*
* The SHA3 API operates over bit streams; the `groupBy` and `join` calls
* convert to and from our byte arrays.
*/
H : {hinl} (fin hinl) => [hinl]Byte -> [32]Byte
H M = toBytes(sha3 `{digest = 256} (fromBytes(M)))
H M = groupBy (SHA3_256::hash (join M))

/**
* One of the hash functions used in the protocol.
* [FIPS-203] Section 4.1, Equation 4.4.
*
* The SHA3 API operates over bit streams; the `groupBy` and `join` calls
* convert to and from our byte arrays.
*/
J : {hinl} (fin hinl) => [hinl]Byte -> [32]Byte
J(s) = take(groupBy(shake256(fromBytes(s))))
J s = groupBy (SHAKE256::xof (join s))

/**
* One of the hash functions used in the protocol.
* [FIPS-203] Section 4.1, Equation 4.5.
*
* The SHA3 API operates over bit streams; the `groupBy` and `join` calls
* convert to and from our byte arrays.
*/
G : {ginl} (fin ginl) => [ginl]Byte -> ([32]Byte, [32]Byte)
G M = (result@0, result@1)
where result = split`{2} (toBytes(sha3 `{digest = 512} (fromBytes(M))))
G M = (a, b) where
[a, b] = split`{2} (groupBy`{8} (SHA3_512::hash (join M)))

/**
* eXtendable-Output Function (XOF) wrapper.
* [FIPS-203] Section 4.1, Equation 4.6.
*
* The SHA3 API operates over bit streams; the `groupBy` and `join` calls
* convert to and from our byte arrays.
*/
XOF : ([34]Byte) -> [inf]Byte
XOF(d) = groupBy`{8}(shake128(fromBytes(d)))
XOF(d) = groupBy (SHAKE128::xof (join d))

/**
* Conversion from bit arrays to byte arrays.
Expand All @@ -113,7 +128,7 @@ BitsToBytes input = map reverse (groupBy input)
BytesToBits : {ell} (fin ell, ell > 0) => [ell]Byte -> [ell*8]Bit
BytesToBits input = join (map reverse input)

BitToZ : {p} (fin p, p > 1) => Bit -> Z p
BitToZ : Bit -> Z q
BitToZ b = if b then 1 else 0

BitstoZ : {ell} (fin ell, ell > 0) => [ell] -> (Z q)
Expand Down Expand Up @@ -274,7 +289,7 @@ property CorrectnessEncodeDecode fVec = all CorrectnessEncodeDecode' fVec
DecodeSpec : {ell} (fin ell, ell > 0) => [32 * ell]Byte -> Z_q_256
DecodeSpec B = [f i | i <- [0 .. 255]]
where betas = BytesToBits B : [256 * ell]
f i = sum [ BitToZ`{q}(betas@(i*`ell+j))*fromInteger(2^^j)
f i = sum [ BitToZ (betas@(i*`ell+j))*fromInteger(2^^j)
| j <- [0 .. (ell-1)]]

/**
Expand Down Expand Up @@ -315,59 +330,93 @@ property CorrectnessEncodeDecode' f = Decode'`{12}(Encode'`{12} f) == f
/**
* Uniformly sample NTT representations.
*
* This converts a stream `b` generated from a seed into a (pseudo)random
* polynomial in `T_q`.
*
* Since Cryptol does not natively support while loops, we approach this
* potentially infinite loop with recursion. `SampleNTTInf` converts an
* infinite sequence of bytes to an infinite sequence of elements in `Z q`.
* We then use the first `n` elements for the result.
* This uses a seed `B` to generate a pseudorandom stream, which is parsed into
* a polynomial in `T_q` drawn from a distribution indistinguishable from the
* uniform distribution.
*
* [FIPS-203] Section 4.2.2, Algorithm 7.
*/
SampleNTT : [34]Byte -> Z_q_256
SampleNTT B = take elements where
C_stream = XOF(B)
elements = SampleNTTInf C_stream
SampleNTT B = a_hat' where
// Steps 1-2, 5.
// We (lazily) take an infinite stream from the XOF and remove only as
// many bytes as are needed to compute the function. See [FIPS-203]
// Section 4.1 Equation 4.6 for a discussion of the equivalence of this
// form to the one in Algorithm 7.
ctx0 = XOF B

// Step 3. Since Cryptol is not imperative, we implement this loop using
// recursion. The `j` counter is not made explicit; instead we lazily
// generate an infinite stream of coefficients in `T_q` and `take` the
// correct length in the next line.

// Step 4-16. `take` fulfills the `j < 256` condition in Steps 4 and 12.
a_hat = take`{256} (filter ctx0)

// `filter` parses an infinite stream from the XOF, computing
// potential elements `d1` and `d2` from the first 3 bytes in the stream
// and adding them to the output if they are valid elements in `Z q`.
filter: [inf]Byte -> [inf][12]
filter XOFSqueeze = a_hat_j where
// Step 5.
(C # ctx) = XOFSqueeze

/**
* SampleNTTInf implements a filter. It scans the input 3 by 3, calculates
* the elements d1 and d2 and finally returns the elements that satisfy
* the conditions together with the result of itself when applied to the
* tail.
*
* This is an implementation of part of [FIPS-203] Section 4.2.2, Algorithm 7.
*/
SampleNTTInf: [inf]Byte -> [inf](Z q)
SampleNTTInf ([bi,bi1,bi2] # tailS) =
if d1 < `q then
if d2 < `q then
[fromInteger(d1),fromInteger(d2)] # SampleNTTInf tailS
else
[fromInteger(d1)] # SampleNTTInf tailS
else
if d2 < `q then
[fromInteger(d2)] # SampleNTTInf tailS
else
SampleNTTInf tailS
where
d1 = toInteger(reverse bi) + 256 * (toInteger(reverse bi1) % 16)
d2 = floor(ratio (toInteger(reverse bi1)) 16) + 16 * toInteger(reverse bi2)
// The conversion from 8-bit to 12-bit vectors (with the same value!)
// is implicit in the spec -- see notes on Step 5 and 6/7. In Cryptol,
// we need to convert manually to do the subsequent computations.
[C0, C1, C2] = map zext`{12} C

// Step 6.
d1 = C0 + 256 * (C1 % 16)

// Step 7. Cryptol uses integer division; it always takes the floor of
// the result.
d2 = (C1 / 16) + 16 * C2

// Steps 4, 8 - 15.
// Add `d1` and/or `d2` to the sampled vector `a_hat` if they are valid
// elements in `Z q`.
// The `while` loop in Step 4 is equivalent to the recursive call to
// `filter` in each condition.
a_hat_j = if (d1 < `q) && (d2 < `q) then
[d1, d2] # filter ctx
else if d1 < `q then
[d1] # filter ctx
else if d2 < `q then
[d2] # filter ctx
else filter ctx

// This conversion is implicit in the implementation -- see the notes on
// Step 6/7 and 9.
toZq : [12] -> Z q
toZq x = fromInteger (toInteger x)

a_hat' = map toZq a_hat

/**
* Sample a special, centered distribution of polynomials in `R_q` with small
* coefficients.
*
* Requires that the input stream `B` is uniformly random bytes.
* The input stream `B` must be uniformly random bytes!
*
* [FIPS-203] Section 4.2.2, Algorithm 8.
*/
SamplePolyCBD: {eta} (fin eta, eta > 0) => [64 * eta]Byte -> Z_q_256
SamplePolyCBD B = [f i | i <- [0 .. 255]]
where betas = BytesToBits B : [512 * eta]
x i = sum [BitToZ`{q} (betas@(2*i*`eta+j)) | j <- [0 .. (eta-1)]]
y i = sum [BitToZ`{q} (betas@(2*i*`eta+`eta+j)) | j <- [0 .. (eta-1)]]
f i = (x i) - (y i)
SamplePolyCBD: {eta} (2 <= eta, eta <= 3) => [64 * eta]Byte -> Z_q_256
SamplePolyCBD B = f where
// Step 1.
b = BytesToBits B
// This conversion is implicit in the implementation. Convert each bit into
// an element of Z q.
b' = map BitToZ b

// Step 3.
x i = sum [b'@(2 * i * `eta + j) | j <- [0 .. (eta-1)]]
// Step 4.
y i = sum [b'@(2 * i * `eta + `eta + j) | j <- [0 .. (eta-1)]]

// Steps 2, 5. The `mod q` is not explicit here because `x` and `y`
// return elements of `Z q`.
f = [(x i) - (y i) | i <- [0 .. 255]]

/**
* [FIPS-203] Section 4.3 "The mathematical structure of the NTT."
Expand Down
30 changes: 29 additions & 1 deletion Primitive/Asymmetric/Cipher/ML_KEM/Tests/ML_KEM512.cry
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ property t0 = keygenWorks && encapsWorks && decapsWorks where
keygenWorks = ML_KEM::ML_KEM_KeyGen (z, d) == (pk, sk)
encapsWorks = ML_KEM::ML_KEM_Encaps (pk, msg) == (ss, ct)
decapsWorks = ML_KEM::ML_KEM_Decaps (ct, sk) == ss
decapsFailsCorrectly = ML_KEM::ML_KEM_Decaps (ct_n, sk) == ss_n

z = split 0xf696484048ec21f96cf50a56d0759c448f3779752f0383d37449690694cf7a68
d = split 0x6dbbc4375136df3b07f7c70e639e223e177e7fd53b161b3f4d57791794f12624
Expand Down Expand Up @@ -137,4 +138,31 @@ property t0 = keygenWorks && encapsWorks && decapsWorks where
0x72e6dab0edf7e3b08ed1316acf857254b6c1fbf64ce582f2990ccecf9505fd7e,
0x8517998aaf583be1aa641e9b54dbb91ca9d0700913967fb0349201b5d679b32d
])
ss = split 0x2b5c52ee72946331983ba050be0f435055c0547901e03559b356517889ea27c5
ss = split 0x2b5c52ee72946331983ba050be0f435055c0547901e03559b356517889ea27c5
ct_n = split (join [
0x96ac6243c9b1272be77b975a4048bf00ff2c48f94a3483362449273880d45e54,
0xbda15729682bf591a74382a708beb78118cab29ad74ac2f405ba720076dfb571,
0x88dc168487cd20081f6bf412f257dea03406b23a6a752e478ba4ef9c7c0f4810,
0x921fa32545be64dc5d9f18d4e1320efc6508154cda35ab912d059e0291a1150a,
0xe0a10da5e3d7bd221a851c598df4d0b18daa920976556099d1c0de4e222d5304,
0xd44fa9cb9bd4ffe15769dd6c4793fa809f5264cf0febca4b5975ba287639783a,
0xa1f4b645ff7a00d46ee7b19fec17b3e83bcaf4361d5349e30ceab60c386b6b0d,
0x1b90d8b336ee6a627ad2a38670cb5113b0fb4ac2ddc4250097483fefd182670e,
0xa40f0f45cce90b9ed58dafaef657d64e25fd6692a69721994e7d00b4949205eb,
0xe4c4f9c46ee5a1018b220a26d80ae2d2b486372e974d75b20a005b1616ad1e13,
0xd162915cc24f274670d1e5e8bd345874a7e7c9759c8e43ff33689200739a6133,
0x95f7ae78d73c6a7b90f65ab511f0df3c5dca85d0b9430b4e97098715ff823b61,
0x7321799aea0ab9c72234780339ec7b541d5e6f8c1551146c24a65411811b2367,
0x4c26123356cf233351382c3994cba5dc6c25a07e1ba9af33eca18bba3e97935e,
0x3abdf07e9fa32cecf241e7cafc6592db4ee487ff2b98a4a47805dee17fd93448,
0xdc98457b753ed4995ee6b1bfa9ff1d386c91f396ca8f48cab5b09a782ec3b616,
0xa87a6448a96236c4655413af755323d36a8db2e16509454489e6ec83629130cd,
0x2a54817918af362c83183494b4b590dbaf69cf399d3e2dc3e9c0c1224f148e65,
0xef68287341ab72ad58adfc69b28e27e91ebbf830fac53b94f762f01cc9b1561a,
0xe35f16edabf51ff164c1309d1fdb52cd2bfedb5a492eb65cb9fc86b8f05ed26d,
0x13233fb0a3eb33a9dce2cf98e6516cee42fbe1e97e20ab6c9965f58a377dc73e,
0x530667ab8f45e6a70b23db50f0df411732d8acdabe50c51adb886c0e5a5296d4,
0xaa1b13a336f0c17812f79fc69418a7d8901c568f410eff2af74baaeb8336f46c,
0xa17e14e060ce2d45cdb376286eec8b8befa5ab8025802720a1e7393af579db13
])
ss_n = split 0x6e5c522a6d19b86c61bd983b56a0bef351c5ce716f021b49bdecd7bdfd5ed55a

0 comments on commit 54a6b6f

Please sign in to comment.