Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-KEM: Bring encoding functions in line with the spec #160

Open
wants to merge 11 commits into
base: master
Choose a base branch
from
254 changes: 164 additions & 90 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -203,129 +203,203 @@ Compress x = map Compress'`{d} x
Decompress : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256
Decompress x = map Decompress'`{d} x

/*
* We make this trivial serialization explicit, since it is not an identity in Cryptol.
* Byte encoding and decoding involves regrouping 8-bit arrays into ell-bit arrays.
*/
regroup B = reverse (groupBy (join (reverse B)))

/**
* This is used in some places where the `ByteEncode` function is required in
* the spec. It looks like a 2D version of it?
*/
EncodeBytes' : {ell, c} (fin ell, ell > 0, fin c) => [c * 8][ell] -> [c * ell]Byte
EncodeBytes' = regroup
* Encode an array of `d`-bit integers into a byte array, for `d < 12`.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

todo: I was also thinking of making all these private, since they're not part of the public API of ML-KEM.

* [FIPS-203] Section 4.2.1 Algorithm 5.
*
* Note: In this implementation, we treat the `d < 12` case separately from
* `d == 12` because it allows us to better express the different integer
* types. For `d < 12`, we use the bit vector type.
* See `ByteEncode12` for the `d = 12` case.
*/
ByteEncode : {d} (1 <= d, d < 12) => [256][d] -> [32 * d]Byte
ByteEncode F = B where
// Step 1-3, 5.
// We iterate over `F` directly, instead of indexing into it with `i`.
// In a bit vector, iteratively subtracting the last bit and dividing by 2
// is the same as bit-shifting to the right (see `subAndDivIsShift`).
as = [[ a >> j | j <- [0..d-1]] | a <- F]
// Step 4. In a bit vector, taking the value `% 2` is the same as taking
// the final bit. See `mod2IsFinalBit`.
b = join [ [ a ! 0 | a <- ajs] | ajs <- as]
// Step 8.
B = BitsToBytes b

/**
* This is used in some places where the `ByteDecode` function is required in
* the spec. It looks like a 3D version of it?
* Encode `k` vectors of `d`-bit integers into a byte array.
* [FIPS-203] Section 2.4.8.
*/
DecodeBytes' : {ell, c} (fin ell, ell > 0, fin c) => [c * ell]Byte -> [c * 8][ell]
DecodeBytes' = regroup
ByteEncode_Vec : {d} (fin d, 1 <= d, d < 12) => [k][256][d] -> [k * 32 * d]Byte
ByteEncode_Vec F_vec = join (map ByteEncode F_vec)

/**
* Encoding and decoding bytes must be inverses in 2D.
* ```repl
* :prove CorrectnessEncodeBytes'
* ```
* Encode an array of integers mod `q` into a byte array.
* [FIPS-203] Section 4.2.1 Algorithm 5.
*
* Note: In this implementation, we treat the `d < 12` case separately from
* `d == 12` because it allows us to better express the different integer
* types. For `d == 12`, we use the integers-mod (`Z`) type.
* See `ByteEncode` for the `d < 12` case.
*/
CorrectnessEncodeBytes' : [n][2] -> Bit
property CorrectnessEncodeBytes' B = DecodeBytes'(EncodeBytes' B) == B
ByteEncode12 : [256](Z q) -> [32 * 12]Byte
ByteEncode12 F = B where
type d = 12

// We need to explicitly convert from integers mod `q` to 12-bit vectors.
toBitVec : Z q -> [d]
toBitVec f = fromInteger (fromZ f)
F' = map toBitVec F

// The following is the same as in `ByteEncode`. See notes above.
as = [[ a >> j | j <- [0..d-1]] | a <- F' ]
b = join [ [ a ! 0 | a <- ajs] | ajs <- as]
B = BitsToBytes b

/**
* This is used in some places where the `ByteEncode` function is required in
* the spec. It's a 3D version of `EncodeBytes'`.
* Encode a set of `k` vectors of integers mod `q` into a byte array.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is [FIPS-203] Section 2.4.8 relevant here like it is in ByteEncode_Vec?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch, thanks.

*/
EncodeBytes : {ell, k1, c} (fin ell, ell > 0, fin k1, fin c) =>
[k1][c * 8][ell] -> [c * ell * k1]Byte
EncodeBytes B = EncodeBytes' (join B)
ByteEncode12_Vec : [k][256](Z q) -> [k * 32 * 12]Byte
ByteEncode12_Vec F_vec = join (map ByteEncode12 F_vec)

/**
* This is used in some places where the `ByteDecode` function is required in
* the spec. It's a 3D version of `DecodeBytes'`.
* The subtract-and-divide algorithm applied to `a` in `ByteEncode` is the
* same as shifting right.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it'd be helpful to add the step in the algorithm (5) where this comes up in the explanation.

* Note: The type constraint for `d` does not allow `1` because `% 2` is not a
* legal operation on 1-bit vectors (since 2 cannot be represented as a 1-bit
* vector.)
* ```repl
* :prove subAndDivIsShiftR`{d_u}
* :prove subAndDivIsShiftR`{d_v}
* :prove subAndDivIsShiftR`{12}
* ```
*/
DecodeBytes : {ell, k1, c} (fin ell, ell > 0, fin k1, fin c) =>
[c * ell * k1]Byte -> [k1][c * 8][ell]
DecodeBytes B = groupBy (DecodeBytes' B)
subAndDivIsShiftR : {d} (fin d, d > 1) => [d] -> Bit
property subAndDivIsShiftR a = take (sad a) == shift a where
sad x = [x] # sad ((x - (x % 2)) / 2)
shift x = [x >> i | i <- [0..(d-1)]]

/**
* Encoding and decoding bytes must be inverses in 3D.
* Computing a bit vector mod 2 is the same as taking its final (least
* significant) bit.
* Note: The type constraint for `d` does not allow `1` because `% 2` is not a
* legal operation on 1-bit vectors (since 2 cannot be represented as a 1-bit
* vector.) It's trivially true, though.
* ```repl
* :prove CorrectnessEncodeBytes
* :prove mod2IsFinalBit`{d_u}
* :prove mod2IsFinalBit`{d_v}
* :prove mod2IsFinalBit`{12}
* ```
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be helpful to add some text about why we care that mod 2 is the final bit as it relates to the EncodeBytes algorithm.

*/
CorrectnessEncodeBytes : [k][n][2] -> Bit
property CorrectnessEncodeBytes B = DecodeBytes(EncodeBytes B) == B
mod2IsFinalBit : {d} (fin d, d > 1) => [d] -> Bit
property mod2IsFinalBit a = a % 2 == zext [a ! 0]

/**
* Apply encoding to a vector applying `Encode` to each element, then
* concatenating the results.
* [FIPS-203] Section 2.4.8.
* Decode a byte array into an array of `d`-bit integers, for `d < 12`.
* [FIPS-203] Section 4.2.1 Algorithm 6.
*
* Note: In this implementation, we treat the `d < 12` case separately from
* `d == 12` because it allows us to better express the different integer
* types. For `d < 12`, we use the bit vector type.
* See `ByteDecode12` for the `d = 12` case.
*/
Encode : {ell, k1} (fin ell, ell > 0, fin k1) => [k1]Z_q_256 -> [32 * ell * k1]Byte
Encode fVec = join (map Encode'`{ell} fVec)
ByteDecode : {d} (1 <= d, d < 12) => [32 * d]Byte -> [256][d]
ByteDecode B = F where
// Step 1.
b = BytesToBits B
// Steps 2-4. The `mod m` is implicit in the type because the `[d]` type
// always operates `mod 2^d`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I paused for a second here and had to look back at the spec. Maybe be explicit about how m=2^d?

// In a bit vector, multiplying by `2^j` is the same as left shift (see
// mul2jIsShiftL`).
// `bidj` is a single Bit; we convert to a `d`-length bit vector (with
// zeros in the higher-order bits) to support subsequent operations.
F = [ sum [ (zext [bidj]) << j
| bidj <- bid
| j <- [0..d-1]]
| bid <- split`{256} b]

/**
* Apply decoding to a vector by splitting the vector into appropriately-sized
* components and applying `Decode` to each element.
* Decode `k` arrays of `d`-bit integers into a byte array.
* [FIPS-203] Section 2.4.8.
*/
Decode : {ell, k1} (fin ell, ell > 0, fin k1) => [32 * ell * k1]Byte -> [k1]Z_q_256
Decode BVec = map Decode'`{ell} (split BVec)
ByteDecode_Vec : {d} (1 <= d, d < 12) => [k * 32 * d]Byte -> [k][256][d]
ByteDecode_Vec B_vec = map ByteDecode (split B_vec)

/**
* Encode and decode must be inverses in 2D.
* ```repl
* :check CorrectnessEncodeDecode
* ```
*/
CorrectnessEncodeDecode : [k]Z_q_256 -> Bit
property CorrectnessEncodeDecode fVec = all CorrectnessEncodeDecode' fVec
* Decode a byte array into an array of integers mod `q`.
* [FIPS-203] Section 4.2.1 Algorithm 6.
*
* Note: In this implementation, we treat the `d < 12` case separately from
* `d == 12` because it allows us to better express the different integer
* types. For `d = 12`, we use the integers-mod (`Z`) type.
* See `ByteDecode` for the `d < 12` case.
*/
ByteDecode12 : [32 * 12]Byte -> [256](Z q)
ByteDecode12 B = F' where
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: this function is never actually used alone, but I don't feel like it's reasonable skip it and only make ByteDecode12_Vec.

type d = 12
// These steps are the same as in `ByteDecode`. See that function for
// notes.
b = BytesToBits B
F = [ sum [ (zext`{d} [bidj]) << j
| bidj <- bid
| j <- [0..d-1]]
| bid <- split`{256} b]

/**
* Decode a byte array into an array of `d`-bit integers.
* [FIPS-203] Section 4.2.1, Algorithm 6.
*/
DecodeSpec : {ell} (fin ell, ell > 0) => [32 * ell]Byte -> Z_q_256
DecodeSpec B = [f i | i <- [0 .. 255]]
where betas = BytesToBits B : [256 * ell]
f i = sum [ BitToZ (betas@(i*`ell+j))*fromInteger(2^^j)
| j <- [0 .. (ell-1)]]
// In this case, since `m = q` (and not `2^12`), we need to explicitly
// convert each integer from a 12-bit array to an integer mod `q`.
toZq f = fromInteger (toInteger f)
F' = map toZq F

/**
* Decode a byte array into an array of `d`-bit integers, more efficiently than
* the version in the spec.
* Decode `k` joined byte arrays into `k` arrays of integers mod `q`.
*/
Decode' : {ell} (fin ell, ell > 0) => [32 * ell]Byte -> Z_q_256
Decode' B = map BitstoZ`{ell} (split (BytesToBits B))
ByteDecode12_Vec : [k * 32 * 12]Byte -> [k][256](Z q)
ByteDecode12_Vec B = map ByteDecode12 (split B)

/**
* Proof that the efficient decode function is the same as the spec version.
* Multiplying a value by `2^^j` is the same as bit-shifting it left by `j`
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should 2^^j be 2^j in this line?

* bits.
* Note: The type constraint does not allow `d == 1` because the hard-coded
* `2` does not fit in a single bit, but the property is trivial:
* `j` can only be 0. On the left, `b * 2^0 == b * 1 == b`. On the right,
* `b << 0 == b`.
* ```repl
* :check DecodeEquiv
* :prove mul2jIsShiftL`{d_u}
* :prove mul2jIsShiftL`{d_v}
* :prove mul2jIsShiftL`{12}
* ```
*/
DecodeEquiv : [32 * 12]Byte -> Bit
property DecodeEquiv B = (Decode' B == DecodeSpec B)
mul2jIsShiftL : {d} (fin d, 1 < d) => [d] -> Bit
property mul2jIsShiftL b = and [( b * (2^^j)) == (b << j) | j <- [0..d-1]]

/**
* Encode an array of `d`-bit integers into a byte array, more efficiently than
* the version in the spec.
*
* This should be equivalent to [FIPS-203] Section 4.2.1, Algorithm 5.
* When `d < 12`, the byte encoding and decoding functions should be
* one-to-one inverses of each other.
* [FIPS-203] Section 4.2.1 "Encoding and decoding", first point.
* ```repl
* :prove ByteEncodeInvertsByteDecode`{1}
* :prove ByteEncodeInvertsByteDecode`{d_u}
* :prove ByteEncodeInvertsByteDecode`{d_v}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Where does d_u and d_v come from?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an explanation here and everywhere I use these without comment in the doctests.

* ```
*/
Encode' : {ell} (fin ell, ell > 0) => Z_q_256 -> [32 * ell]Byte
Encode' f = BitsToBytes (join (map ZtoBits`{ell} f))
ByteEncodeInvertsByteDecode : {d} (fin d, 1 <= d, d < 12) => [256][d] -> Bit
property ByteEncodeInvertsByteDecode bits =
decode_encode_works && encode_decode_works where
decode_encode_works = ByteDecode (ByteEncode bits) == bits
// Rearrange random input to be valid for the decode function
bytes = split (join bits)
encode_decode_works = ByteEncode (ByteDecode bytes) == bytes

/**
* Decoding must be the inverse of encoding.
* [FIPS-203] Section 4.2.1, "Encoding and decoding."
* Byte decoding is the inverse of byte decoding for `d = 12`.
* [FIPS-203] Section 4.2.1 "Encoding and decoding", second point.
*
* Note that the reverse property (decoding, then encoding) is not true!
* ```repl
* :check CorrectnessEncodeDecode'
* :check ByteDecode12InvertsByteEncode12
* ```
*/
CorrectnessEncodeDecode' : Z_q_256 -> Bit
property CorrectnessEncodeDecode' f = Decode'`{12}(Encode'`{12} f) == f
ByteDecode12InvertsByteEncode12 : [256](Z q) -> Bit
ByteDecode12InvertsByteEncode12 bits = ByteDecode12 (ByteEncode12 bits) == bits

/**
* Uniformly sample NTT representations.
Expand Down Expand Up @@ -827,9 +901,9 @@ private submodule K_PKE where
// Step 18.
t_hat = (dotMatVec A_hat s_hat) + e_hat
// Step 19.
ekPKE = (Encode`{12} t_hat) # ρ
ekPKE = (ByteEncode12_Vec t_hat) # ρ
// Step 20.
dkPKE = Encode`{12} (s_hat)
dkPKE = ByteEncode12_Vec s_hat

/**
* Encryption algorithm for the K-PKE component scheme.
Expand All @@ -843,7 +917,7 @@ private submodule K_PKE where
Encrypt : EncryptionKey -> [32]Byte -> [32]Byte -> Ciphertext
Encrypt ekPKE m r = c where
// Step 2.
t_hat = Decode`{12} (ekPKE @@[0 .. 384*k - 1])
t_hat = ByteDecode12_Vec (ekPKE @@[0 .. 384*k - 1])
// Step 3.
rho = ekPKE @@[384*k .. 384*k + 32 - 1]
// Steps 4-8.
Expand All @@ -865,13 +939,13 @@ private submodule K_PKE where
// Step 19.
u = NTTInv (dotMatVec (transpose A_hat) y_hat) + e1
// Step 20.
mu = Decompress'`{1} (DecodeBytes'`{1} m)
mu = Decompress'`{1} (ByteDecode`{1} m)
// Step 21.
v = (NTTInv' (dotVecVec t_hat y_hat)) + e2 + mu
// Step 22.
c1 = EncodeBytes`{d_u} (Compress`{d_u} u)
c1 = ByteEncode_Vec`{d_u} (Compress`{d_u} u)
// Step 23.
c2 = EncodeBytes'`{d_v} (Compress'`{d_v} v)
c2 = ByteEncode`{d_v} (Compress'`{d_v} v)
// Step 24.
c = c1 # c2

Expand All @@ -890,16 +964,16 @@ private submodule K_PKE where
c1 = c @@[0 .. 32 * d_u * k - 1]
// Step 2.
c2 = c @@[32 * d_u * k .. 32 * (d_u * k + d_v) - 1]
// Step 3.
u' = Decompress`{d_u} (DecodeBytes`{d_u} c1)
// Step 3.j
u' = Decompress`{d_u} (ByteDecode_Vec`{d_u} c1)
// Step 4.
v' = Decompress'`{d_v} (DecodeBytes'`{d_v} c2)
v' = Decompress'`{d_v} (ByteDecode`{d_v} c2)
// Step 5.
s_hat = Decode`{12} dkPKE
s_hat = ByteDecode12_Vec dkPKE
// Step 6.
w = v' - NTTInv' (dotVecVec s_hat (NTT u'))
// Step 7.
m = EncodeBytes'`{1} (Compress'`{1} w)
m = ByteEncode`{1} (Compress'`{1} w)

/**
* The K-PKE scheme must satisfy the basic properties of an encryption
Expand Down