88#include < array>
99#include < cstdint>
1010#include < limits>
11+ #include < span>
1112
1213// ML-DSA FIPS 204
1314namespace ml_dsa {
@@ -18,6 +19,9 @@ static constexpr size_t KEYGEN_SEED_BYTE_LEN = 32;
1819// Byte length of randomness, required for hedged signing.
1920static constexpr size_t RND_BYTE_LEN = 32 ;
2021
22+ // Byte length of message representative, which is to be signed.
23+ static constexpr size_t MU_BYTE_LEN = 64 ;
24+
2125// Given seed ξ, this routine generates a public key and secret key pair, using deterministic key generation algorithm.
2226//
2327// See algorithm 1 of ML-DSA standard @ https://doi.org/10.6028/NIST.FIPS.204.
@@ -113,28 +117,23 @@ keygen(std::span<const uint8_t, KEYGEN_SEED_BYTE_LEN> ξ,
113117 ml_dsa_polyvec::encode<k, d>(t0, seckey.template subspan <skoff5, skoff6 - skoff5>());
114118}
115119
116- // Given a ML-DSA secret key, message (can be empty too) and context (optional, but if given, length must be capped at 255 -bytes),
117- // this routine computes a hedged/ deterministic signature.
120+ // Given a ML-DSA secret key and 64 -bytes message representative, this routine computes a hedged/ deterministic signature.
118121//
119122// Notice, first parameter of this function, `rnd`, which lets you pass 32 -bytes randomness for generating default
120123// "hedged" signature. In case you don't need randomized message signature, you can instead fill `rnd` with zeros, and
121124// it'll generate a deterministic signature.
122125//
123126// Note, hedged signing is the default and recommended version.
124127//
125- // See algorithm 2 of ML-DSA standard @ https://doi.org/10.6028/NIST.FIPS.204.
128+ // See algorithm 7 of ML-DSA standard @ https://doi.org/10.6028/NIST.FIPS.204.
126129template <size_t k, size_t l, size_t d, uint32_t eta, uint32_t gamma1, uint32_t gamma2, uint32_t tau, uint32_t beta, size_t omega, size_t lambda>
127130static inline constexpr bool
128- sign (std::span<const uint8_t , RND_BYTE_LEN> rnd,
129- std::span<const uint8_t , ml_dsa_utils::sec_key_len(k, l, eta, d)> seckey,
130- std::span<const uint8_t> msg,
131- std::span<const uint8_t> ctx,
132- std::span<uint8_t, ml_dsa_utils::sig_len(k, l, gamma1, omega, lambda)> sig)
131+ sign_internal (std::span<const uint8_t , RND_BYTE_LEN> rnd,
132+ std::span<const uint8_t , ml_dsa_utils::sec_key_len(k, l, eta, d)> seckey,
133+ std::span<const uint8_t, MU_BYTE_LEN> mu,
134+ std::span<uint8_t, ml_dsa_utils::sig_len(k, l, gamma1, omega, lambda)> sig)
133135 requires(ml_dsa_params::check_signing_params(k, l, d, eta, gamma1, gamma2, tau, beta, omega, lambda))
134136{
135- if (ctx.size () > std::numeric_limits<uint8_t >::max ()) {
136- return false ;
137- }
138137 constexpr uint32_t t0_rng = 1u << (d - 1 );
139138
140139 constexpr size_t eta_bw = std::bit_width (2 * eta);
@@ -150,30 +149,17 @@ sign(std::span<const uint8_t, RND_BYTE_LEN> rnd,
150149
151150 auto rho = seckey.template subspan <skoff0, skoff1 - skoff0>();
152151 auto key = seckey.template subspan <skoff1, skoff2 - skoff1>();
153- auto tr = seckey.template subspan <skoff2, skoff3 - skoff2>();
154152
155153 std::array<ml_dsa_field::zq_t , k * l * ml_dsa_ntt::N> A{};
156154 ml_dsa_sampling::expand_a<k, l>(rho, A);
157155
158- std::array<uint8_t , 64 > mu{};
159- auto mu_span = std::span (mu);
160-
161- const std::array<uint8_t , 2 > domain_separator{ 0 , static_cast <uint8_t >(ctx.size ()) };
162-
163- shake256::shake256_t hasher;
164- hasher.absorb (tr);
165- hasher.absorb (domain_separator);
166- hasher.absorb (ctx);
167- hasher.absorb (msg);
168- hasher.finalize ();
169- hasher.squeeze (mu_span);
170-
171156 std::array<uint8_t , 64 > rho_prime{};
172157
158+ shake256::shake256_t hasher;
173159 hasher.reset ();
174160 hasher.absorb (key);
175161 hasher.absorb (rnd);
176- hasher.absorb (mu_span );
162+ hasher.absorb (mu );
177163 hasher.finalize ();
178164 hasher.squeeze (rho_prime);
179165
@@ -226,7 +212,7 @@ sign(std::span<const uint8_t, RND_BYTE_LEN> rnd,
226212 ml_dsa_polyvec::encode<k, w1bw>(w1, w1_encoded);
227213
228214 hasher.reset ();
229- hasher.absorb (mu_span );
215+ hasher.absorb (mu );
230216 hasher.absorb (w1_encoded);
231217 hasher.finalize ();
232218 hasher.squeeze (c_tilda_span);
@@ -302,23 +288,63 @@ sign(std::span<const uint8_t, RND_BYTE_LEN> rnd,
302288 return has_signed;
303289}
304290
305- // Given a ML-DSA public key, message (can be empty too), context (optional, but if given, length must be capped at 255 -bytes)
306- // and serialized signature, this routine verifies validity of the signature, returning boolean result, denoting status
307- // of signature verification. For example, say it returns true, it means signature is valid for given message and public key.
291+ // Given a ML-DSA secret key, message (can be empty too) and context (optional, but if given, length must be capped at 255 -bytes),
292+ // this routine computes a hedged/ deterministic signature.
308293//
309- // See algorithm 3 of ML-DSA standard @ https://doi.org/10.6028/NIST.FIPS.204.
310- template <size_t k, size_t l, size_t d, uint32_t gamma1, uint32_t gamma2, uint32_t tau, uint32_t beta, size_t omega, size_t lambda>
294+ // Notice, first parameter of this function, `rnd`, which lets you pass 32 -bytes randomness for generating default
295+ // "hedged" signature. In case you don't need randomized message signature, you can instead fill `rnd` with zeros, and
296+ // it'll generate a deterministic signature.
297+ //
298+ // Note, hedged signing is the default and recommended version.
299+ //
300+ // See algorithm 2 of ML-DSA standard @ https://doi.org/10.6028/NIST.FIPS.204.
301+ template <size_t k, size_t l, size_t d, uint32_t eta, uint32_t gamma1, uint32_t gamma2, uint32_t tau, uint32_t beta, size_t omega, size_t lambda>
311302static inline constexpr bool
312- verify (std::span<const uint8_t , ml_dsa_utils::pub_key_len(k, d)> pubkey,
313- std::span<const uint8_t> msg,
314- std::span<const uint8_t> ctx,
315- std::span<const uint8_t, ml_dsa_utils::sig_len(k, l, gamma1, omega, lambda)> sig)
316- requires(ml_dsa_params::check_verify_params(k, l, d, gamma1, gamma2, tau, beta, omega, lambda))
303+ sign (std::span<const uint8_t , RND_BYTE_LEN> rnd,
304+ std::span<const uint8_t , ml_dsa_utils::sec_key_len(k, l, eta, d)> seckey,
305+ std::span<const uint8_t> msg,
306+ std::span<const uint8_t> ctx,
307+ std::span<uint8_t, ml_dsa_utils::sig_len(k, l, gamma1, omega, lambda)> sig)
308+ requires(ml_dsa_params::check_signing_params(k, l, d, eta, gamma1, gamma2, tau, beta, omega, lambda))
317309{
318310 if (ctx.size () > std::numeric_limits<uint8_t >::max ()) {
319311 return false ;
320312 }
321313
314+ constexpr size_t skoff0 = 0 ;
315+ constexpr size_t skoff1 = skoff0 + 32 ;
316+ constexpr size_t skoff2 = skoff1 + 32 ;
317+ constexpr size_t skoff3 = skoff2 + 64 ;
318+
319+ auto tr = seckey.template subspan <skoff2, skoff3 - skoff2>();
320+ const std::array<uint8_t , 2 > domain_separator{ 0 , static_cast <uint8_t >(ctx.size ()) };
321+
322+ std::array<uint8_t , MU_BYTE_LEN> mu{};
323+ auto mu_span = std::span (mu);
324+
325+ shake256::shake256_t hasher;
326+ hasher.absorb (tr);
327+ hasher.absorb (domain_separator);
328+ hasher.absorb (ctx);
329+ hasher.absorb (msg);
330+ hasher.finalize ();
331+ hasher.squeeze (mu_span);
332+
333+ return sign_internal<k, l, d, eta, gamma1, gamma2, tau, beta, omega, lambda>(rnd, seckey, mu_span, sig);
334+ }
335+
336+ // Given a ML-DSA public key, 64 -bytes message representative and serialized signature, this routine verifies validity of the signature,
337+ // returning boolean result, denoting status of signature verification. For example, say it returns true, it means signature is valid for
338+ // given message and public key.
339+ //
340+ // See algorithm 8 of ML-DSA standard @ https://doi.org/10.6028/NIST.FIPS.204.
341+ template <size_t k, size_t l, size_t d, uint32_t gamma1, uint32_t gamma2, uint32_t tau, uint32_t beta, size_t omega, size_t lambda>
342+ static inline constexpr bool
343+ verify_internal (std::span<const uint8_t , ml_dsa_utils::pub_key_len(k, d)> pubkey,
344+ std::span<const uint8_t, MU_BYTE_LEN> mu,
345+ std::span<const uint8_t, ml_dsa_utils::sig_len(k, l, gamma1, omega, lambda)> sig)
346+ requires(ml_dsa_params::check_verify_params(k, l, d, gamma1, gamma2, tau, beta, omega, lambda))
347+ {
322348 constexpr size_t t1_bw = std::bit_width (ml_dsa_field::Q) - d;
323349 constexpr size_t gamma1_bw = std::bit_width (gamma1);
324350
@@ -370,24 +396,6 @@ verify(std::span<const uint8_t, ml_dsa_utils::pub_key_len(k, d)> pubkey,
370396 ml_dsa_sampling::expand_a<k, l>(rho, A);
371397 ml_dsa_polyvec::decode<k, t1_bw>(t1_encoded, t1);
372398
373- std::array<uint8_t , 64 > tr{};
374- std::array<uint8_t , 64 > mu{};
375-
376- shake256::shake256_t hasher;
377- hasher.absorb (pubkey);
378- hasher.finalize ();
379- hasher.squeeze (tr);
380-
381- const std::array<uint8_t , 2 > domain_separator{ 0 , static_cast <uint8_t >(ctx.size ()) };
382-
383- hasher.reset ();
384- hasher.absorb (tr);
385- hasher.absorb (domain_separator);
386- hasher.absorb (ctx);
387- hasher.absorb (msg);
388- hasher.finalize ();
389- hasher.squeeze (mu);
390-
391399 std::array<ml_dsa_field::zq_t , k * ml_dsa_ntt::N> w0{};
392400 std::array<ml_dsa_field::zq_t , k * ml_dsa_ntt::N> w1{};
393401 std::array<ml_dsa_field::zq_t , k * ml_dsa_ntt::N> w2{};
@@ -414,7 +422,7 @@ verify(std::span<const uint8_t, ml_dsa_utils::pub_key_len(k, d)> pubkey,
414422
415423 std::array<uint8_t , c_tilda.size ()> c_tilda_prime{};
416424
417- hasher. reset () ;
425+ shake256:: shake256_t hasher;
418426 hasher.absorb (mu);
419427 hasher.absorb (w1_encoded);
420428 hasher.finalize ();
@@ -423,4 +431,42 @@ verify(std::span<const uint8_t, ml_dsa_utils::pub_key_len(k, d)> pubkey,
423431 return std::equal (c_tilda.begin (), c_tilda.end (), c_tilda_prime.begin ());
424432}
425433
434+ // Given a ML-DSA public key, message (can be empty too), context (optional, but if given, length must be capped at 255 -bytes)
435+ // and serialized signature, this routine verifies validity of the signature, returning boolean result, denoting status
436+ // of signature verification. For example, say it returns true, it means signature is valid for given message and public key.
437+ //
438+ // See algorithm 3 of ML-DSA standard @ https://doi.org/10.6028/NIST.FIPS.204.
439+ template <size_t k, size_t l, size_t d, uint32_t gamma1, uint32_t gamma2, uint32_t tau, uint32_t beta, size_t omega, size_t lambda>
440+ static inline constexpr bool
441+ verify (std::span<const uint8_t , ml_dsa_utils::pub_key_len(k, d)> pubkey,
442+ std::span<const uint8_t> msg,
443+ std::span<const uint8_t> ctx,
444+ std::span<const uint8_t, ml_dsa_utils::sig_len(k, l, gamma1, omega, lambda)> sig)
445+ requires(ml_dsa_params::check_verify_params(k, l, d, gamma1, gamma2, tau, beta, omega, lambda))
446+ {
447+ if (ctx.size () > std::numeric_limits<uint8_t >::max ()) {
448+ return false ;
449+ }
450+
451+ std::array<uint8_t , 64 > mu{};
452+ std::array<uint8_t , 64 > tr{};
453+
454+ shake256::shake256_t hasher;
455+ hasher.absorb (pubkey);
456+ hasher.finalize ();
457+ hasher.squeeze (tr);
458+
459+ const std::array<uint8_t , 2 > domain_separator{ 0 , static_cast <uint8_t >(ctx.size ()) };
460+
461+ hasher.reset ();
462+ hasher.absorb (tr);
463+ hasher.absorb (domain_separator);
464+ hasher.absorb (ctx);
465+ hasher.absorb (msg);
466+ hasher.finalize ();
467+ hasher.squeeze (mu);
468+
469+ return verify_internal<k, l, d, gamma1, gamma2, tau, beta, omega, lambda>(pubkey, mu, sig);
470+ }
471+
426472}
0 commit comments