-
Notifications
You must be signed in to change notification settings - Fork 7
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
ML-KEM: Bring encoding functions in line with the spec #160
base: master
Are you sure you want to change the base?
Changes from 9 commits
40a62f3
4effe92
72f9896
19a2a77
ccee91d
62bef29
820991b
f240c39
32668f5
875f925
2d41a3d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -128,15 +128,6 @@ BitsToBytes input = map reverse (groupBy input) | |
BytesToBits : {ell} (fin ell, ell > 0) => [ell]Byte -> [ell*8]Bit | ||
BytesToBits input = join (map reverse input) | ||
|
||
BitToZ : Bit -> Z q | ||
BitToZ b = if b then 1 else 0 | ||
|
||
BitstoZ : {ell} (fin ell, ell > 0) => [ell] -> (Z q) | ||
BitstoZ betas = fromInteger (toInteger (reverse betas)) | ||
|
||
ZtoBits : {ell} (fin ell, ell > 0) => (Z q) -> [ell] | ||
ZtoBits fi = reverse (fromInteger (fromZ fi)) | ||
|
||
// In Cryptol, rounding is computed via the built-in function roundAway | ||
property rounding = ((roundAway(1.5) == 2) && (roundAway(1.4) == 1)) | ||
|
||
|
@@ -203,129 +194,203 @@ Compress x = map Compress'`{d} x | |
Decompress : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256 | ||
Decompress x = map Decompress'`{d} x | ||
|
||
/* | ||
* We make this trivial serialization explicit, since it is not an identity in Cryptol. | ||
* Byte encoding and decoding involves regrouping 8-bit arrays into ell-bit arrays. | ||
*/ | ||
regroup B = reverse (groupBy (join (reverse B))) | ||
|
||
/** | ||
* This is used in some places where the `ByteEncode` function is required in | ||
* the spec. It looks like a 2D version of it? | ||
*/ | ||
EncodeBytes' : {ell, c} (fin ell, ell > 0, fin c) => [c * 8][ell] -> [c * ell]Byte | ||
EncodeBytes' = regroup | ||
* Encode an array of `d`-bit integers into a byte array, for `d < 12`. | ||
* [FIPS-203] Section 4.2.1 Algorithm 5. | ||
* | ||
* Note: In this implementation, we treat the `d < 12` case separately from | ||
* `d == 12` because it allows us to better express the different integer | ||
* types. For `d < 12`, we use the bit vector type. | ||
* See `ByteEncode12` for the `d = 12` case. | ||
*/ | ||
ByteEncode : {d} (1 <= d, d < 12) => [256][d] -> [32 * d]Byte | ||
ByteEncode F = B where | ||
// Step 1-3, 5. | ||
// We iterate over `F` directly, instead of indexing into it with `i`. | ||
// In a bit vector, iteratively subtracting the last bit and dividing by 2 | ||
// is the same as bit-shifting to the right (see `subAndDivIsShift`). | ||
as = [[ a >> j | j <- [0..d-1]] | a <- F] | ||
// Step 4. In a bit vector, taking the value `% 2` is the same as taking | ||
// the final bit. See `mod2IsFinalBit`. | ||
b = join [ [ a ! 0 | a <- ajs] | ajs <- as] | ||
// Step 8. | ||
B = BitsToBytes b | ||
|
||
/** | ||
* This is used in some places where the `ByteDecode` function is required in | ||
* the spec. It looks like a 3D version of it? | ||
* Encode `k` vectors of `d`-bit integers into a byte array. | ||
* [FIPS-203] Section 2.4.8. | ||
*/ | ||
DecodeBytes' : {ell, c} (fin ell, ell > 0, fin c) => [c * ell]Byte -> [c * 8][ell] | ||
DecodeBytes' = regroup | ||
ByteEncode_Vec : {d} (fin d, 1 <= d, d < 12) => [k][256][d] -> [k * 32 * d]Byte | ||
ByteEncode_Vec F_vec = join (map ByteEncode F_vec) | ||
|
||
/** | ||
* Encoding and decoding bytes must be inverses in 2D. | ||
* ```repl | ||
* :prove CorrectnessEncodeBytes' | ||
* ``` | ||
* Encode an array of integers mod `q` into a byte array. | ||
* [FIPS-203] Section 4.2.1 Algorithm 5. | ||
* | ||
* Note: In this implementation, we treat the `d < 12` case separately from | ||
* `d == 12` because it allows us to better express the different integer | ||
* types. For `d == 12`, we use the integers-mod (`Z`) type. | ||
* See `ByteEncode` for the `d < 12` case. | ||
*/ | ||
CorrectnessEncodeBytes' : [n][2] -> Bit | ||
property CorrectnessEncodeBytes' B = DecodeBytes'(EncodeBytes' B) == B | ||
ByteEncode12 : [256](Z q) -> [32 * 12]Byte | ||
ByteEncode12 F = B where | ||
type d = 12 | ||
|
||
// We need to explicitly convert from integers mod `q` to 12-bit vectors. | ||
toBitVec : Z q -> [d] | ||
toBitVec f = fromInteger (fromZ f) | ||
F' = map toBitVec F | ||
|
||
// The following is the same as in `ByteEncode`. See notes above. | ||
as = [[ a >> j | j <- [0..d-1]] | a <- F' ] | ||
b = join [ [ a ! 0 | a <- ajs] | ajs <- as] | ||
B = BitsToBytes b | ||
|
||
/** | ||
* This is used in some places where the `ByteEncode` function is required in | ||
* the spec. It's a 3D version of `EncodeBytes'`. | ||
* Encode a set of `k` vectors of integers mod `q` into a byte array. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is [FIPS-203] Section 2.4.8 relevant here like it is in ByteEncode_Vec? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good catch, thanks. |
||
*/ | ||
EncodeBytes : {ell, k1, c} (fin ell, ell > 0, fin k1, fin c) => | ||
[k1][c * 8][ell] -> [c * ell * k1]Byte | ||
EncodeBytes B = EncodeBytes' (join B) | ||
ByteEncode12_Vec : [k][256](Z q) -> [k * 32 * 12]Byte | ||
ByteEncode12_Vec F_vec = join (map ByteEncode12 F_vec) | ||
|
||
/** | ||
* This is used in some places where the `ByteDecode` function is required in | ||
* the spec. It's a 3D version of `DecodeBytes'`. | ||
* The subtract-and-divide algorithm applied to `a` in `ByteEncode` is the | ||
* same as shifting right. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think it'd be helpful to add the step in the algorithm (5) where this comes up in the explanation. |
||
* Note: The type constraint for `d` does not allow `1` because `% 2` is not a | ||
* legal operation on 1-bit vectors (since 2 cannot be represented as a 1-bit | ||
* vector.) | ||
* ```repl | ||
* :prove subAndDivIsShiftR`{d_u} | ||
* :prove subAndDivIsShiftR`{d_v} | ||
* :prove subAndDivIsShiftR`{12} | ||
* ``` | ||
*/ | ||
DecodeBytes : {ell, k1, c} (fin ell, ell > 0, fin k1, fin c) => | ||
[c * ell * k1]Byte -> [k1][c * 8][ell] | ||
DecodeBytes B = groupBy (DecodeBytes' B) | ||
subAndDivIsShiftR : {d} (fin d, d > 1) => [d] -> Bit | ||
property subAndDivIsShiftR a = take (sad a) == shift a where | ||
sad x = [x] # sad ((x - (x % 2)) / 2) | ||
shift x = [x >> i | i <- [0..(d-1)]] | ||
|
||
/** | ||
* Encoding and decoding bytes must be inverses in 3D. | ||
* Computing a bit vector mod 2 is the same as taking its final (least | ||
* significant) bit. | ||
* Note: The type constraint for `d` does not allow `1` because `% 2` is not a | ||
* legal operation on 1-bit vectors (since 2 cannot be represented as a 1-bit | ||
* vector.) It's trivially true, though. | ||
* ```repl | ||
* :prove CorrectnessEncodeBytes | ||
* :prove mod2IsFinalBit`{d_u} | ||
* :prove mod2IsFinalBit`{d_v} | ||
* :prove mod2IsFinalBit`{12} | ||
* ``` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It might be helpful to add some text about why we care that mod 2 is the final bit as it relates to the EncodeBytes algorithm. |
||
*/ | ||
CorrectnessEncodeBytes : [k][n][2] -> Bit | ||
property CorrectnessEncodeBytes B = DecodeBytes(EncodeBytes B) == B | ||
mod2IsFinalBit : {d} (fin d, d > 1) => [d] -> Bit | ||
property mod2IsFinalBit a = a % 2 == zext [a ! 0] | ||
|
||
/** | ||
* Apply encoding to a vector applying `Encode` to each element, then | ||
* concatenating the results. | ||
* [FIPS-203] Section 2.4.8. | ||
* Decode a byte array into an array of `d`-bit integers, for `d < 12`. | ||
* [FIPS-203] Section 4.2.1 Algorithm 6. | ||
* | ||
* Note: In this implementation, we treat the `d < 12` case separately from | ||
* `d == 12` because it allows us to better express the different integer | ||
* types. For `d < 12`, we use the bit vector type. | ||
* See `ByteDecode12` for the `d = 12` case. | ||
*/ | ||
Encode : {ell, k1} (fin ell, ell > 0, fin k1) => [k1]Z_q_256 -> [32 * ell * k1]Byte | ||
Encode fVec = join (map Encode'`{ell} fVec) | ||
ByteDecode : {d} (1 <= d, d < 12) => [32 * d]Byte -> [256][d] | ||
ByteDecode B = F where | ||
// Step 1. | ||
b = BytesToBits B | ||
// Steps 2-4. The `mod m` is implicit in the type because the `[d]` type | ||
// always operates `mod 2^d`. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I paused for a second here and had to look back at the spec. Maybe be explicit about how m=2^d? |
||
// In a bit vector, multiplying by `2^j` is the same as left shift (see | ||
// mul2jIsShiftL`). | ||
// `bidj` is a single Bit; we convert to a `d`-length bit vector (with | ||
// zeros in the higher-order bits) to support subsequent operations. | ||
F = [ sum [ (zext [bidj]) << j | ||
| bidj <- bid | ||
| j <- [0..d-1]] | ||
| bid <- split`{256} b] | ||
|
||
/** | ||
* Apply decoding to a vector by splitting the vector into appropriately-sized | ||
* components and applying `Decode` to each element. | ||
* Decode `k` arrays of `d`-bit integers into a byte array. | ||
* [FIPS-203] Section 2.4.8. | ||
*/ | ||
Decode : {ell, k1} (fin ell, ell > 0, fin k1) => [32 * ell * k1]Byte -> [k1]Z_q_256 | ||
Decode BVec = map Decode'`{ell} (split BVec) | ||
ByteDecode_Vec : {d} (1 <= d, d < 12) => [k * 32 * d]Byte -> [k][256][d] | ||
ByteDecode_Vec B_vec = map ByteDecode (split B_vec) | ||
|
||
/** | ||
* Encode and decode must be inverses in 2D. | ||
* ```repl | ||
* :check CorrectnessEncodeDecode | ||
* ``` | ||
*/ | ||
CorrectnessEncodeDecode : [k]Z_q_256 -> Bit | ||
property CorrectnessEncodeDecode fVec = all CorrectnessEncodeDecode' fVec | ||
* Decode a byte array into an array of integers mod `q`. | ||
* [FIPS-203] Section 4.2.1 Algorithm 6. | ||
* | ||
* Note: In this implementation, we treat the `d < 12` case separately from | ||
* `d == 12` because it allows us to better express the different integer | ||
* types. For `d = 12`, we use the integers-mod (`Z`) type. | ||
* See `ByteDecode` for the `d < 12` case. | ||
*/ | ||
ByteDecode12 : [32 * 12]Byte -> [256](Z q) | ||
ByteDecode12 B = F' where | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note: this function is never actually used alone, but I don't feel like it's reasonable skip it and only make |
||
type d = 12 | ||
// These steps are the same as in `ByteDecode`. See that function for | ||
// notes. | ||
b = BytesToBits B | ||
F = [ sum [ (zext`{d} [bidj]) << j | ||
| bidj <- bid | ||
| j <- [0..d-1]] | ||
| bid <- split`{256} b] | ||
|
||
/** | ||
* Decode a byte array into an array of `d`-bit integers. | ||
* [FIPS-203] Section 4.2.1, Algorithm 6. | ||
*/ | ||
DecodeSpec : {ell} (fin ell, ell > 0) => [32 * ell]Byte -> Z_q_256 | ||
DecodeSpec B = [f i | i <- [0 .. 255]] | ||
where betas = BytesToBits B : [256 * ell] | ||
f i = sum [ BitToZ (betas@(i*`ell+j))*fromInteger(2^^j) | ||
| j <- [0 .. (ell-1)]] | ||
// In this case, since `m = q` (and not `2^12`), we need to explicitly | ||
// convert each integer from a 12-bit array to an integer mod `q`. | ||
toZq f = fromInteger (toInteger f) | ||
F' = map toZq F | ||
|
||
/** | ||
* Decode a byte array into an array of `d`-bit integers, more efficiently than | ||
* the version in the spec. | ||
* Decode `k` joined byte arrays into `k` arrays of integers mod `q`. | ||
*/ | ||
Decode' : {ell} (fin ell, ell > 0) => [32 * ell]Byte -> Z_q_256 | ||
Decode' B = map BitstoZ`{ell} (split (BytesToBits B)) | ||
ByteDecode12_Vec : [k * 32 * 12]Byte -> [k][256](Z q) | ||
ByteDecode12_Vec B = map ByteDecode12 (split B) | ||
|
||
/** | ||
* Proof that the efficient decode function is the same as the spec version. | ||
* Multiplying a value by `2^^j` is the same as bit-shifting it left by `j` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should 2^^j be 2^j in this line? |
||
* bits. | ||
* Note: The type constraint does not allow `d == 1` because the hard-coded | ||
* `2` does not fit in a single bit, but the property is trivial: | ||
* `j` can only be 0. On the left, `b * 2^0 == b * 1 == b`. On the right, | ||
* `b << 0 == b`. | ||
* ```repl | ||
* :check DecodeEquiv | ||
* :prove mul2jIsShiftL`{d_u} | ||
* :prove mul2jIsShiftL`{d_v} | ||
* :prove mul2jIsShiftL`{12} | ||
* ``` | ||
*/ | ||
DecodeEquiv : [32 * 12]Byte -> Bit | ||
property DecodeEquiv B = (Decode' B == DecodeSpec B) | ||
mul2jIsShiftL : {d} (fin d, 1 < d) => [d] -> Bit | ||
property mul2jIsShiftL b = and [( b * (2^^j)) == (b << j) | j <- [0..d-1]] | ||
|
||
/** | ||
* Encode an array of `d`-bit integers into a byte array, more efficiently than | ||
* the version in the spec. | ||
* | ||
* This should be equivalent to [FIPS-203] Section 4.2.1, Algorithm 5. | ||
* When `d < 12`, the byte encoding and decoding functions should be | ||
* one-to-one inverses of each other. | ||
* [FIPS-203] Section 4.2.1 "Encoding and decoding", first point. | ||
* ```repl | ||
* :prove ByteEncodeInvertsByteDecode`{1} | ||
* :prove ByteEncodeInvertsByteDecode`{d_u} | ||
* :prove ByteEncodeInvertsByteDecode`{d_v} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Where does d_u and d_v come from? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added an explanation here and everywhere I use these without comment in the doctests. |
||
* ``` | ||
*/ | ||
Encode' : {ell} (fin ell, ell > 0) => Z_q_256 -> [32 * ell]Byte | ||
Encode' f = BitsToBytes (join (map ZtoBits`{ell} f)) | ||
ByteEncodeInvertsByteDecode : {d} (fin d, 1 <= d, d < 12) => [256][d] -> Bit | ||
property ByteEncodeInvertsByteDecode bits = | ||
decode_encode_works && encode_decode_works where | ||
decode_encode_works = ByteDecode (ByteEncode bits) == bits | ||
// Rearrange random input to be valid for the decode function | ||
bytes = split (join bits) | ||
encode_decode_works = ByteEncode (ByteDecode bytes) == bytes | ||
|
||
/** | ||
* Decoding must be the inverse of encoding. | ||
* [FIPS-203] Section 4.2.1, "Encoding and decoding." | ||
* Byte decoding is the inverse of byte decoding for `d = 12`. | ||
* [FIPS-203] Section 4.2.1 "Encoding and decoding", second point. | ||
* | ||
* Note that the reverse property (decoding, then encoding) is not true! | ||
* ```repl | ||
* :check CorrectnessEncodeDecode' | ||
* :check ByteDecode12InvertsByteEncode12 | ||
* ``` | ||
*/ | ||
CorrectnessEncodeDecode' : Z_q_256 -> Bit | ||
property CorrectnessEncodeDecode' f = Decode'`{12}(Encode'`{12} f) == f | ||
ByteDecode12InvertsByteEncode12 : [256](Z q) -> Bit | ||
ByteDecode12InvertsByteEncode12 bits = ByteDecode12 (ByteEncode12 bits) == bits | ||
|
||
/** | ||
* Uniformly sample NTT representations. | ||
|
@@ -407,6 +472,8 @@ SamplePolyCBD B = f where | |
b = BytesToBits B | ||
// This conversion is implicit in the implementation. Convert each bit into | ||
// an element of Z q. | ||
BitToZ : Bit -> Z q | ||
BitToZ bit = if bit then 1 else 0 | ||
b' = map BitToZ b | ||
|
||
// Step 3. | ||
|
@@ -827,9 +894,9 @@ private submodule K_PKE where | |
// Step 18. | ||
t_hat = (dotMatVec A_hat s_hat) + e_hat | ||
// Step 19. | ||
ekPKE = (Encode`{12} t_hat) # ρ | ||
ekPKE = (ByteEncode12_Vec t_hat) # ρ | ||
// Step 20. | ||
dkPKE = Encode`{12} (s_hat) | ||
dkPKE = ByteEncode12_Vec s_hat | ||
|
||
/** | ||
* Encryption algorithm for the K-PKE component scheme. | ||
|
@@ -843,7 +910,7 @@ private submodule K_PKE where | |
Encrypt : EncryptionKey -> [32]Byte -> [32]Byte -> Ciphertext | ||
Encrypt ekPKE m r = c where | ||
// Step 2. | ||
t_hat = Decode`{12} (ekPKE @@[0 .. 384*k - 1]) | ||
t_hat = ByteDecode12_Vec (ekPKE @@[0 .. 384*k - 1]) | ||
// Step 3. | ||
rho = ekPKE @@[384*k .. 384*k + 32 - 1] | ||
// Steps 4-8. | ||
|
@@ -865,13 +932,13 @@ private submodule K_PKE where | |
// Step 19. | ||
u = NTTInv (dotMatVec (transpose A_hat) y_hat) + e1 | ||
// Step 20. | ||
mu = Decompress'`{1} (DecodeBytes'`{1} m) | ||
mu = Decompress'`{1} (ByteDecode`{1} m) | ||
// Step 21. | ||
v = (NTTInv' (dotVecVec t_hat y_hat)) + e2 + mu | ||
// Step 22. | ||
c1 = EncodeBytes`{d_u} (Compress`{d_u} u) | ||
c1 = ByteEncode_Vec`{d_u} (Compress`{d_u} u) | ||
// Step 23. | ||
c2 = EncodeBytes'`{d_v} (Compress'`{d_v} v) | ||
c2 = ByteEncode`{d_v} (Compress'`{d_v} v) | ||
// Step 24. | ||
c = c1 # c2 | ||
|
||
|
@@ -890,16 +957,16 @@ private submodule K_PKE where | |
c1 = c @@[0 .. 32 * d_u * k - 1] | ||
// Step 2. | ||
c2 = c @@[32 * d_u * k .. 32 * (d_u * k + d_v) - 1] | ||
// Step 3. | ||
u' = Decompress`{d_u} (DecodeBytes`{d_u} c1) | ||
// Step 3.j | ||
u' = Decompress`{d_u} (ByteDecode_Vec`{d_u} c1) | ||
// Step 4. | ||
v' = Decompress'`{d_v} (DecodeBytes'`{d_v} c2) | ||
v' = Decompress'`{d_v} (ByteDecode`{d_v} c2) | ||
// Step 5. | ||
s_hat = Decode`{12} dkPKE | ||
s_hat = ByteDecode12_Vec dkPKE | ||
// Step 6. | ||
w = v' - NTTInv' (dotVecVec s_hat (NTT u')) | ||
// Step 7. | ||
m = EncodeBytes'`{1} (Compress'`{1} w) | ||
m = ByteEncode`{1} (Compress'`{1} w) | ||
|
||
/** | ||
* The K-PKE scheme must satisfy the basic properties of an encryption | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
todo: I was also thinking of making all these private, since they're not part of the public API of ML-KEM.