From e7187b029ed82b950874e606e62b3b2f096ef2fc Mon Sep 17 00:00:00 2001 From: Marcella Hastings Date: Wed, 27 Nov 2024 17:06:36 -0500 Subject: [PATCH] mldsa: improve docs, control flow on hintpack #184 --- Common/OptionUtils.cry | 22 +++--- .../Signature/ML_DSA/Specification.cry | 77 +++++++++++-------- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/Common/OptionUtils.cry b/Common/OptionUtils.cry index b3906f8..c293342 100644 --- a/Common/OptionUtils.cry +++ b/Common/OptionUtils.cry @@ -1,5 +1,8 @@ /** * Convenience functions for working with `Option`s. + * + * @copyright Galois, Inc. + * @author Marcella Hastings */ module Common::OptionUtils where @@ -15,22 +18,23 @@ isNone opt = ~ isSome opt * Map an `Option a` to an `Option b` by applying a function to a contained * value (if `Some`) or returns `None` (if `None`). */ -mapOption : {a, b} (a -> b) -> Option a -> Option b -mapOption f opt = case opt of +optApply : {a, b} (a -> b) -> Option a -> Option b +optApply f opt = case opt of Some x -> Some (f x) None -> None /** - * Flatten a nested option into a single option. + * Flatten a nested `Option` into a single `Option` that is `Some` only if + * both original `Option`s are `Some`. */ -flatten : {a} Option (Option a) -> Option a -flatten opt = case opt of +optFlatten : {a} Option (Option a) -> Option a +optFlatten opt = case opt of Some opt' -> opt' None -> None /** - * Map an `Option a` to an `Option b` by calling `mapOption` on a function that - * produces an `Option b`, then flattening the result. + * Map an `Option a` to an `Option b` by calling `optApply` on a function that + * produces an `Option b`, then `optFlatten`ing the result. */ -flatMap : {a, b} (a -> Option b) -> Option a -> Option b -flatMap f opt = flatten (mapOption f opt) \ No newline at end of file +optFlatApply: {a, b} (a -> Option b) -> Option a -> Option b +optFlatApply f opt = optFlatten (optApply f opt) diff --git a/Primitive/Asymmetric/Signature/ML_DSA/Specification.cry b/Primitive/Asymmetric/Signature/ML_DSA/Specification.cry index 9ebfb03..debb544 100644 --- a/Primitive/Asymmetric/Signature/ML_DSA/Specification.cry +++ b/Primitive/Asymmetric/Signature/ML_DSA/Specification.cry @@ -19,9 +19,6 @@ */ module Primitive::Asymmetric::Signature::ML_DSA::Specification where -import Common::utils(while) -import Common::OptionUtils(flatMap) - type Byte = [8] /** @@ -157,18 +154,23 @@ HintBitPack h = yFinal where (yFinal, _) = yAndIndex ! 0 /** + * Reverses the procedure `HintBitPack`. + * [FIPS-204] Section 7.1, Algorithm 21. + * * This diverges slightly from the spec: * - To simplify updating `h`, we treat it as a single array of size `256k`. * We separate it into the correct `[k]R2` representation in the final step. * We access the array in "the natural way" -- that is, in Step 12, the * element `h[i]_y[Index]` is at index `i * 256 + y[Index]` in our array. + * - We cannot "return early" when we encounter an error case. Instead, we use + * options to indicate whether a failure has occurred and skip further + * computation when the option is `None`. + * - The for loop in Step 3 is executed with a list comprehension. The while + * loop in Step 7 is executed with recursion. The for loop in Step 16 is + * executed with recursion. */ HintBitUnpack : [ω + k]Byte -> Option ([k]R2) -HintBitUnpack y = h___ where - // Useful rename for annotation functions that operate over the stateful - // pair of `h` and `Index` used across this function. - type Pair = ([k * 256], Byte) - +HintBitUnpack y = h20 where // Step 1. h0 = zero : [k * 256] // Step 2. @@ -176,68 +178,79 @@ HintBitUnpack y = h___ where // Step 3. Construct a list comprising the values of `h` and `Index` // at the end of each iteration of the loop in Steps 3 - 15. - hAndIndex = [Some (h0, Index0)] # [ - flatMap (Step4_5 i) maybe_hAndIndex - | maybe_hAndIndex <- hAndIndex + hAndIndexes = [Some (h0, Index0)] # [ + // Call Steps 4-5 if we haven't encountered an error yet. + case maybe_hAndIndex of + Some hAndIndex -> Step4_5 hAndIndex i + None -> None + | maybe_hAndIndex <- hAndIndexes | i <- [0..k-1] ] // Steps 4 - 5. - Step4_5 : Integer -> Pair -> Option Pair - Step4_5 i (h, Index) = if (y@(`ω + i) < Index) || (y@(`ω + i) > `ω) then + Step4_5 (h, Index) i = if (y@(`ω + i) < Index) || (y@(`ω + i) > `ω) then None else Step6_15 (h, Index) i - // Steps 6 -15 - Step6_15 : Pair -> Integer -> Option Pair - Step6_15 (h, Index) i = result15 where + // Steps 6 - 15. + Step6_15 (h, Index) i = Step7_14 (h, Index) i First where // Step 6. First = Index - result15 = Step7_14 (h, Index) i First // Steps 7 - 14. - Step7_14 : Pair -> Integer -> Byte -> Option Pair Step7_14 (h, Index) i First = + // Step 7 (condition). if Index < (y@(`ω + i)) then + // Step 8 - 11. // The `/\` is a short-cutting `and`, equivalent to the nested `if` // statements in the spec. if ((Index > First) /\ (y@(Index - 1) >= y@Index)) then None - // Recursive call is equivalent to continuing the loop. - // The constants `i` and `First` do not change between iterations. + // Step 7 (recursive call -- equivalent to continuing the loop). else Step7_14 + // Step 12. (update h (i*256 + (toInteger (y@Index))) 1, + // Step 13. Index + 1) + // These variables do not change between iterations. i First // If the loop condition is no longer true, return the current values // of `h` and `Index`. else Some (h, Index) // Get the values of `h` and `Index` after the loop in Steps 3 - 15. - maybe_hAndIndex = hAndIndex ! 0 + maybe_hAndIndex' = hAndIndexes ! 0 // This helper function uses recursion to read any leftover bytes in the // first `ω` bytes of `y`; it returns an error if any of them are non-zero. - checkZero idx = - if idx > (`ω - 1) then True - else if (y@idx != zero) then False - else checkZero (idx + 1) + checkZero i = + if i > (`ω - 1) then True + else if (y@i != zero) then False + else checkZero (i + 1) // Step 16 - 20. - h___ = flatMap - (\(h, Index) -> if (checkZero Index) then Some (split`{k} h) + h20 = case maybe_hAndIndex' of + Some hAndIndex -> if (checkZero Index) then Some (split`{k} h) else None - ) - maybe_hAndIndex + where (h, Index) = hAndIndex + None -> None /** * Verify that `HintBitUnpack` is the reverse of `HintBitPack`. * * This takes a list of indexes indicating the non-zero elements and constructs - * a valid, sparse `h` -- rejection sampling was not a valid option because + * a valid, sparse `h` -- rejection sampling is not a valid option because * sparse-enough `h`s were too rare. + * + * We test the case where we have the maximum number of 1s, a medium number, and + * a case where at least one vector should have no non-zero terms at all. + * In practice, the hint may fall anywhere in this range. * ```repl - * :check HintPackingInverts + * :check HintPackingInverts`{ω} + * :check HintPackingInverts`{ω / 2} + * :check HintPackingInverts`{3} * ``` + * + * Note that this does not test the error cases for `HintBitUnpack`. */ HintPackingInverts : {w} (w <= ω) => [w][lg2 (256 * k)] -> Bit property HintPackingInverts h_Indexes = @@ -245,5 +258,5 @@ property HintPackingInverts h_Indexes = Some h' -> h == h' None -> False where - // build h out of h_indexes: + // Build `h` out of `h_indexes`: h = split`{k} [if elem idx h_Indexes then 1 else 0 | idx <- [0..(256 * k) - 1]]