Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ML-KEM: Improve compression and byte conversion functions #161

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 138 additions & 36 deletions Primitive/Asymmetric/Cipher/ML_KEM/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -115,18 +115,80 @@ XOF : ([34]Byte) -> [inf]Byte
XOF(d) = groupBy (SHAKE128::xof (join d))

/**
* Conversion from bit arrays to byte arrays.
* Conversion from little-endian bit arrays to byte arrays.
* [FIPS-203] Section 4.2.1, Algorithm 3.
*/
BitsToBytes : {ell} (fin ell, ell > 0) => [ell*8]Bit -> [ell]Byte
BitsToBytes input = map reverse (groupBy input)
BitsToBytes : {ell} (fin ell) => [8 * ell]Bit -> [ell]Byte
BitsToBytes b
| ell == 0 => zero
| ell > 0 => B where
// Group the bits into the B[⌊i / 8⌋] sets; pad them to support
// subsequent operations, and correlate each bit with its index `i`.
b' = groupBy`{8} [(zext [bi], i)
| bi <- b
| i <- [0..8 * ell - 1]]

// Steps 2-4.
B = [sum [bi * (2 ^^ (i % 8))
| (bi, i) <- bi8]
| bi8 <- b']

/**
* Conversion from byte arrays to bit arrays.
* [FIPS-203] Section 4.2.1, Algorithm 4.
*/
BytesToBits : {ell} (fin ell, ell > 0) => [ell]Byte -> [ell*8]Bit
BytesToBits input = join (map reverse input)
BytesToBits : {ell} (fin ell) => [ell]Byte -> [ell*8]Bit
BytesToBits C
| ell == 0 => []
| ell > 0 => join [[ b8ij where
// Step 4. Taking the last bit is the same as modding by 2.
b8ij = Ci' ! 0
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: I intend to refer this to the mod2IsFinalBit property I added in #160.

// Step 5. Shifting right is the same as the iterative
// division (see `div2IsShiftR`). This accounts for all the
// divisions "up to this point" (e.g. none when `j = 0`), which
// is why we use `Ci'` to evaluate `b8ij` above.
Ci' = Ci >> j
// Step 3.
| j <- [0..7]]
// Step 2. We iterate over `C` directly instead of indexing into it.
| Ci <- C ]

/**
* The iterative division by 2 in `BytesToBits` is the same as shifting right.
* ```repl
* :prove div2IsShiftR
* ```
*/
div2IsShiftR : Byte -> Bit
div2IsShiftR C = take (d2 C) == shl where
// Note: division here is floor'd by default.
d2 c = [c] # d2 (c / 2)
shl = [C >> j | j <- [0..7]]

/**
* The conversions between bits and bytes are each others' inverses.
* [FIPS-203] Section 4.2.1 (see description on Algorithm 4).
* The sample `ell` values here are a subset of the possible values in the spec.
* ```repl
* :prove B2B2BInverts`{32}
* :prove B2B2BInverts`{192}
* :prove B2B2BInverts`{384}
* ```
*/
B2B2BInverts : {ell} (fin ell, ell > 0) => [ell * 8] -> Bit
B2B2BInverts bits = bitsWorks && bytesWorks where
bitsWorks = BytesToBits (BitsToBytes bits) == bits
bytesWorks = BitsToBytes (BytesToBits (split bits)) == split bits

/**
* This currently fails due to endianness issues!
* Check the example given in the spec for converting between bits and bytes.
* [FIPS-203] Section 4.2.1 "Converting between bits and bytes."
* ```repl
* :prove B2BExampleWorks
* ```
*/
B2BExampleWorks = BitsToBytes 0b11010001 == [139]

BitToZ : Bit -> Z q
BitToZ b = if b then 1 else 0
Expand All @@ -137,71 +199,111 @@ BitstoZ betas = fromInteger (toInteger (reverse betas))
ZtoBits : {ell} (fin ell, ell > 0) => (Z q) -> [ell]
ZtoBits fi = reverse (fromInteger (fromZ fi))

// In Cryptol, rounding is computed via the built-in function roundAway
property rounding = ((roundAway(1.5) == 2) && (roundAway(1.4) == 1))
/**
* In Cryptol, rounding is computed via the built-in function `roundAway`.
* [FIPS-203] Section 2.3.
*/
property roundingWorks y = y >= 0 ==> roundUpWorks && roundDownWorks where
y' = fromInteger y
roundUpWorks = roundAway (y' + 0.5) == (y + 1)
roundDownWorks = roundAway (y' + 0.4) == y

/**
* Compression from an integer mod `q` to an integer mod `2^d`.
* Compress an integer mod `q` into an integer mod `2^d`.
* [FIPS-203] Section 4.2.1, Equation 4.7.
*/
Compress'' : {d} (d < lg2 q) => Z q -> [d]
Compress'' x = fromInteger(roundAway(((2^^`d)/.`q) * fromInteger(fromZ(x))) % 2^^`d)
Compress : {d} (d < width q) => Z q -> [d]
Compress x = y where
// Convert from an integer mod `q` to a rational number.
x' = fromInteger (fromZ x) : Rational
// Compress. Note that `/.` denotes division of rationals.
y' = roundAway (((2^^`d) /. `q) * x')
// mod 2^^d (by converting from an integer to a d-bit vector).
y = (fromInteger y') : [d]

/**
* Decompression from an integer mod `2^d` to an integer mod `q`.
* Decompress an integer mod `2^d` into an integer mod `q`.
* [FIPS-203] Section 4.2.1, Equation 4.8.
*/
Decompress'' : {d} (d < lg2 q) => [d] -> Z q
Decompress'' x = fromInteger(roundAway(((`q)/.(2^^`d))*fromInteger(toInteger(x))))
Decompress : {d} (d < width q) => [d] -> Z q
Decompress y = x where
// Convert from a d-length bit vector to a rational number.
y' = fromInteger (toInteger y) : Rational
// Decompress! As before, `/.` is division of rationals.
x' = roundAway((`q /. (2^^`d)) * y')
// Convert from an integer to an integer mod `q`.
x = (fromInteger x') : Z q

/**
* Compression inverts decompression for all inputs and bit lengths.
* We'll prove it for the bit lengths found in the
* ```repl
* :prove CompressInvertsDecompress`{1}
* :prove CompressInvertsDecompress`{d_u}
* :prove CompressInvertsDecompress`{d_v}
* ```
*/
CompressInvertsDecompress : {d} (d < width q) => [d] -> Bit
property CompressInvertsDecompress y = Compress (Decompress y) == y

/**
* When `d` is large, compression followed by decompression must not
* significantly alter the value.
* This sets `d = d_u`, which is the largest value for `d` used in the
* spec.
* [FIPS-203] Section 4.2.1, "Compression and Decompression".
* ```repl
* :prove DecompressMostlyInvertsCompress
* ```
*/
CorrectnessCompress : Z q -> Bit
property CorrectnessCompress x = err <= B_q`{d_u} where
x' = Decompress''`{d_u}(Compress''`{d_u}(x))
err = abs(modpm(x'-x))
DecompressMostlyInvertsCompress : Z q -> Bit
property DecompressMostlyInvertsCompress x = errIsSmallEnough where
x' = Decompress`{d_u} (Compress`{d_u} x)
err = abs (modpm (x' - x))
errIsSmallEnough = err <= B_q`{d_u}

// The spec doesn't describe formally what "not significantly altered"
// means; we use this equation.
B_q : {d} (d < lg2 q) => Integer
B_q = roundAway((`q/.(2^^(`d+1))))

modpm : {alpha} (fin alpha, alpha > 0) => Z alpha -> Integer
modpm r = if r' > (`alpha / 2) then r' - `alpha else r'
where r' = fromZ(r)
// Convert an integer mod `q` to a representation centered around 0
// (and represented as an `Integer`).
modpm : Z q -> Integer
modpm r = if r' > (`q / 2) then r' - `q else r'
where r' = fromZ r

/**
* Compression applied to a vector is equivalent to applying compression to
* each individual element.
* [FIPS-203] Section 2.4.8, Equation 2.15.
*/
Compress' : {d} (d < lg2 q) => Z_q_256 -> [n][d]
Compress' x = map Compress''`{d} x
Compress_Vec : {d} (d < lg2 q) => Z_q_256 -> [n][d]
Compress_Vec x = map Compress`{d} x

/**
* Decompression applied to a vector is equivalent to applying decompression to
* each individual element.
* [FIPS-203] Section 2.4.8.
*/
Decompress' : {d} (d < lg2 q) => [n][d] -> Z_q_256
Decompress' x = map Decompress''`{d} x
Decompress_Vec : {d} (d < lg2 q) => [n][d] -> Z_q_256
Decompress_Vec x = map Decompress`{d} x

/**
* Compression applied to an array is equivalent to applying compression to
* Compression applied to a matrix is equivalent to applying compression to
* each individual element.
* [FIPS-203] Section 2.4.8.
*/
Compress : {d, k1} (d < lg2 q, fin k1) => [k1]Z_q_256 -> [k1][n][d]
Compress x = map Compress'`{d} x
Compress_Mat : {d, k1} (d < lg2 q, fin k1) => [k1]Z_q_256 -> [k1][n][d]
Compress_Mat x = map Compress_Vec`{d} x

/**
* Decompression applied to an array is equivalent to applying decompression to
* Decompression applied to a matrix is equivalent to applying decompression to
* each individual element.
* [FIPS-203] Section 2.4.8.
*/
Decompress : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256
Decompress x = map Decompress'`{d} x
Decompress_Mat : {d, k1} (d < lg2 q, fin k1) => [k1][n][d] -> [k1]Z_q_256
Decompress_Mat x = map Decompress_Vec`{d} x

/*
* We make this trivial serialization explicit, since it is not an identity in Cryptol.
Expand Down Expand Up @@ -865,13 +967,13 @@ private submodule K_PKE where
// Step 19.
u = NTTInv (dotMatVec (transpose A_hat) y_hat) + e1
// Step 20.
mu = Decompress'`{1} (DecodeBytes'`{1} m)
mu = Decompress_Vec`{1} (DecodeBytes'`{1} m)
// Step 21.
v = (NTTInv' (dotVecVec t_hat y_hat)) + e2 + mu
// Step 22.
c1 = EncodeBytes`{d_u} (Compress`{d_u} u)
c1 = EncodeBytes`{d_u} (Compress_Mat`{d_u} u)
// Step 23.
c2 = EncodeBytes'`{d_v} (Compress'`{d_v} v)
c2 = EncodeBytes'`{d_v} (Compress_Vec`{d_v} v)
// Step 24.
c = c1 # c2

Expand All @@ -891,15 +993,15 @@ private submodule K_PKE where
// Step 2.
c2 = c @@[32 * d_u * k .. 32 * (d_u * k + d_v) - 1]
// Step 3.
u' = Decompress`{d_u} (DecodeBytes`{d_u} c1)
u' = Decompress_Mat`{d_u} (DecodeBytes`{d_u} c1)
// Step 4.
v' = Decompress'`{d_v} (DecodeBytes'`{d_v} c2)
v' = Decompress_Vec`{d_v} (DecodeBytes'`{d_v} c2)
// Step 5.
s_hat = Decode`{12} dkPKE
// Step 6.
w = v' - NTTInv' (dotVecVec s_hat (NTT u'))
// Step 7.
m = EncodeBytes'`{1} (Compress'`{1} w)
m = EncodeBytes'`{1} (Compress_Vec`{1} w)

/**
* The K-PKE scheme must satisfy the basic properties of an encryption
Expand Down