Skip to content

Commit

Permalink
mldsa: improve docs, control flow on hintpack #184
Browse files Browse the repository at this point in the history
  • Loading branch information
marsella committed Nov 27, 2024
1 parent 6a918da commit e7187b0
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 41 deletions.
22 changes: 13 additions & 9 deletions Common/OptionUtils.cry
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
/**
* Convenience functions for working with `Option`s.
*
* @copyright Galois, Inc.
* @author Marcella Hastings <marcella@galois.com>
*/
module Common::OptionUtils where

Expand All @@ -15,22 +18,23 @@ isNone opt = ~ isSome opt
* Map an `Option a` to an `Option b` by applying a function to a contained
* value (if `Some`) or returns `None` (if `None`).
*/
mapOption : {a, b} (a -> b) -> Option a -> Option b
mapOption f opt = case opt of
optApply : {a, b} (a -> b) -> Option a -> Option b
optApply f opt = case opt of
Some x -> Some (f x)
None -> None

/**
* Flatten a nested option into a single option.
* Flatten a nested `Option` into a single `Option` that is `Some` only if
* both original `Option`s are `Some`.
*/
flatten : {a} Option (Option a) -> Option a
flatten opt = case opt of
optFlatten : {a} Option (Option a) -> Option a
optFlatten opt = case opt of
Some opt' -> opt'
None -> None

/**
* Map an `Option a` to an `Option b` by calling `mapOption` on a function that
* produces an `Option b`, then flattening the result.
* Map an `Option a` to an `Option b` by calling `optApply` on a function that
* produces an `Option b`, then `optFlatten`ing the result.
*/
flatMap : {a, b} (a -> Option b) -> Option a -> Option b
flatMap f opt = flatten (mapOption f opt)
optFlatApply: {a, b} (a -> Option b) -> Option a -> Option b
optFlatApply f opt = optFlatten (optApply f opt)
77 changes: 45 additions & 32 deletions Primitive/Asymmetric/Signature/ML_DSA/Specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@
*/
module Primitive::Asymmetric::Signature::ML_DSA::Specification where

import Common::utils(while)
import Common::OptionUtils(flatMap)

type Byte = [8]

/**
Expand Down Expand Up @@ -157,93 +154,109 @@ HintBitPack h = yFinal where
(yFinal, _) = yAndIndex ! 0

/**
* Reverses the procedure `HintBitPack`.
* [FIPS-204] Section 7.1, Algorithm 21.
*
* This diverges slightly from the spec:
* - To simplify updating `h`, we treat it as a single array of size `256k`.
* We separate it into the correct `[k]R2` representation in the final step.
* We access the array in "the natural way" -- that is, in Step 12, the
* element `h[i]_y[Index]` is at index `i * 256 + y[Index]` in our array.
* - We cannot "return early" when we encounter an error case. Instead, we use
* options to indicate whether a failure has occurred and skip further
* computation when the option is `None`.
* - The for loop in Step 3 is executed with a list comprehension. The while
* loop in Step 7 is executed with recursion. The for loop in Step 16 is
* executed with recursion.
*/
HintBitUnpack : [ω + k]Byte -> Option ([k]R2)
HintBitUnpack y = h___ where
// Useful rename for annotation functions that operate over the stateful
// pair of `h` and `Index` used across this function.
type Pair = ([k * 256], Byte)

HintBitUnpack y = h20 where
// Step 1.
h0 = zero : [k * 256]
// Step 2.
Index0 = 0

// Step 3. Construct a list comprising the values of `h` and `Index`
// at the end of each iteration of the loop in Steps 3 - 15.
hAndIndex = [Some (h0, Index0)] # [
flatMap (Step4_5 i) maybe_hAndIndex
| maybe_hAndIndex <- hAndIndex
hAndIndexes = [Some (h0, Index0)] # [
// Call Steps 4-5 if we haven't encountered an error yet.
case maybe_hAndIndex of
Some hAndIndex -> Step4_5 hAndIndex i
None -> None
| maybe_hAndIndex <- hAndIndexes
| i <- [0..k-1]
]

// Steps 4 - 5.
Step4_5 : Integer -> Pair -> Option Pair
Step4_5 i (h, Index) = if (y@(`ω + i) < Index) || (y@(`ω + i) > `ω) then
Step4_5 (h, Index) i = if (y@(`ω + i) < Index) || (y@(`ω + i) > `ω) then
None
else Step6_15 (h, Index) i

// Steps 6 -15
Step6_15 : Pair -> Integer -> Option Pair
Step6_15 (h, Index) i = result15 where
// Steps 6 - 15.
Step6_15 (h, Index) i = Step7_14 (h, Index) i First where
// Step 6.
First = Index
result15 = Step7_14 (h, Index) i First

// Steps 7 - 14.
Step7_14 : Pair -> Integer -> Byte -> Option Pair
Step7_14 (h, Index) i First =
// Step 7 (condition).
if Index < (y@(`ω + i)) then
// Step 8 - 11.
// The `/\` is a short-cutting `and`, equivalent to the nested `if`
// statements in the spec.
if ((Index > First) /\ (y@(Index - 1) >= y@Index)) then None
// Recursive call is equivalent to continuing the loop.
// The constants `i` and `First` do not change between iterations.
// Step 7 (recursive call -- equivalent to continuing the loop).
else Step7_14
// Step 12.
(update h (i*256 + (toInteger (y@Index))) 1,
// Step 13.
Index + 1)
// These variables do not change between iterations.
i First
// If the loop condition is no longer true, return the current values
// of `h` and `Index`.
else Some (h, Index)

// Get the values of `h` and `Index` after the loop in Steps 3 - 15.
maybe_hAndIndex = hAndIndex ! 0
maybe_hAndIndex' = hAndIndexes ! 0

// This helper function uses recursion to read any leftover bytes in the
// first `ω` bytes of `y`; it returns an error if any of them are non-zero.
checkZero idx =
if idx > (`ω - 1) then True
else if (y@idx != zero) then False
else checkZero (idx + 1)
checkZero i =
if i > (`ω - 1) then True
else if (y@i != zero) then False
else checkZero (i + 1)

// Step 16 - 20.
h___ = flatMap
(\(h, Index) -> if (checkZero Index) then Some (split`{k} h)
h20 = case maybe_hAndIndex' of
Some hAndIndex -> if (checkZero Index) then Some (split`{k} h)
else None
)
maybe_hAndIndex
where (h, Index) = hAndIndex
None -> None

/**
* Verify that `HintBitUnpack` is the reverse of `HintBitPack`.
*
* This takes a list of indexes indicating the non-zero elements and constructs
* a valid, sparse `h` -- rejection sampling was not a valid option because
* a valid, sparse `h` -- rejection sampling is not a valid option because
* sparse-enough `h`s were too rare.
*
* We test the case where we have the maximum number of 1s, a medium number, and
* a case where at least one vector should have no non-zero terms at all.
* In practice, the hint may fall anywhere in this range.
* ```repl
* :check HintPackingInverts
* :check HintPackingInverts`{ω}
* :check HintPackingInverts`{ω / 2}
* :check HintPackingInverts`{3}
* ```
*
* Note that this does not test the error cases for `HintBitUnpack`.
*/
HintPackingInverts : {w} (w <= ω) => [w][lg2 (256 * k)] -> Bit
property HintPackingInverts h_Indexes =
case HintBitUnpack (HintBitPack h) of
Some h' -> h == h'
None -> False
where
// build h out of h_indexes:
// Build `h` out of `h_indexes`:
h = split`{k} [if elem idx h_Indexes then 1 else 0 | idx <- [0..(256 * k) - 1]]

0 comments on commit e7187b0

Please sign in to comment.