Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove uses of coerceSize #107

Merged
merged 3 commits into from
Aug 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 42 additions & 23 deletions Common/ntt.cry
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,11 @@ naive_ivntt xs = map ((*) ivn) ys
ys = [ foldl (+) 0 (zipWith (*) (reverse xs) (odd_powers wi))
| wi <- all_powers r ]

/**
* ```repl
* :check naive_ntt_correct
* ```
*/
naive_ntt_correct : [nn]Fld -> Bool
property naive_ntt_correct a = naive_ivntt (naive_ntt a) == a

Expand All @@ -87,21 +92,17 @@ roots = iterate ((*) (r * r)) 1
/**
* An O(n log n) number theortic transform for Dilithium.
*/
import Common::utils (coerceSize)

ntt : [nn]Fld -> [nn]Fld
ntt a = ntt_r`{lg2 nn} 0 a

ntt_r : {n} (fin n) => Integer -> [2 ^^ n]Fld -> [2 ^^ n]Fld
ntt_r depth a =
if `n == 0 then
a
else
coerceSize (butterfly depth even odd)
ntt_r depth a
| n == 0 => a
| n > 0 => butterfly depth even odd
where
(lft, rht) = shuffle (coerceSize a)
even = ntt_r`{max 1 n - 1} (depth + 1) lft
odd = ntt_r`{max 1 n - 1} (depth + 1) rht
(lft, rht) = shuffle a
even = ntt_r`{n - 1} (depth + 1) lft
odd = ntt_r`{n - 1} (depth + 1) rht
marsella marked this conversation as resolved.
Show resolved Hide resolved

/**
* Group even indices in first half and odd indices in second half.
Expand All @@ -115,7 +116,7 @@ shuffle a =
/**
* Perform the butterfly operation.
*/
butterfly : {n} (fin n, n > 0) => Integer -> [n]Fld -> [n]Fld -> [2 * n]Fld
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why did we make the type of butterfly less general here (i.e., it seems looks like the function makes sense for any length > 0, rather than just powers of 2)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh gosh, good question. I don't remember what I was thinking here, but probably I was just iterating on trying to understand the implementation and accidentally left this in. I'll change it back, thanks for catching this!

butterfly : {n} (fin n, n > 0) => Integer -> [n]Fld -> [n]Fld -> [2*n]Fld
butterfly depth even odd =
lft # rht
where
Expand All @@ -139,15 +140,13 @@ ivntt a =
map ((*) ivn) (ivntt_r`{lg2 nn} 0 a)

ivntt_r : {n} (fin n) => Integer -> [2 ^^ n]Fld -> [2 ^^ n]Fld
ivntt_r depth a =
if `n == 0 then
a
else
coerceSize (ivbutterfly depth even odd)
ivntt_r depth a
| n == 0 => a
| n > 0 => ivbutterfly depth even odd
where
(lft, rht) = shuffle (coerceSize a)
even = ivntt_r`{max 1 n - 1} (depth + 1) lft
odd = ivntt_r`{max 1 n - 1} (depth + 1) rht
(lft, rht) = shuffle a
even = ivntt_r`{n - 1} (depth + 1) lft
odd = ivntt_r`{n - 1} (depth + 1) rht

/**
* Perform the butterfly operation with inverse roots.
Expand All @@ -160,7 +159,12 @@ ivbutterfly depth even odd =
lft = [ even @ i + ivroots @ (i * j) * odd @ i | i <- [0 .. <n] ]
rht = [ even @ i - ivroots @ (i * j) * odd @ i | i <- [0 .. <n] ]

// Try prove
/**
* Takes ~20s to prove.
* ```repl
* :prove ntt_correct
* ```
*/
ntt_correct : [nn]Fld -> Bool
property ntt_correct a = ivntt (ntt a) == a

Expand All @@ -169,17 +173,32 @@ property ntt_correct a = ivntt (ntt a) == a
fntt : [nn]Fld -> [nn]Fld
fntt xs = ntt (zipWith (*) xs (all_powers r))

// Try prove
/**
* Takes ~40s to prove.
* ```repl
* :prove fntt_correct
* ```
*/
fntt_correct : [nn]Fld -> Bool
property fntt_correct a = naive_ntt a == fntt a

fivntt : [nn]Fld -> [nn]Fld
fivntt xs = zipWith (*) (all_powers ivr) (ivntt xs)

// Try prove
/**
* Takes ~10s to prove.
* ```repl
* :prove fivntt_correct
* ```
*/
fivntt_correct : [nn]Fld -> Bool
property fivntt_correct a = fivntt a == naive_ivntt a

// Try prove
/**
* Takes ~30s to prove.
* ```repl
* :prove ffivntt_correct
* ```
*/
ffivntt_correct : [nn]Fld -> Bool
property ffivntt_correct a = fivntt (fntt a) == a
3 changes: 0 additions & 3 deletions Common/utils.cry
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,3 @@ mp_mod_inv c = if c == 0 then error "Zero does not have a multiplicative inverse
*/
mp_mod_inv_correct : {a} (fin a, prime a, a >=2) => Z a -> Bit
property mp_mod_inv_correct x = x != 0 ==> x * mp_mod_inv x == 1

coerceSize : {m, n, a} [m]a -> [n]a
coerceSize xs = [ xs @ i | i <- [0 .. <n]]
42 changes: 9 additions & 33 deletions Primitive/Asymmetric/Cipher/ML-KEM/specification.cry
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,6 @@ NaiveNTTInv f = [term*(recip 128) | term <- ParametricNTTInv f (recip zeta)]
// This section Copyright Amazon.com, Inc. or its affiliates.
//////////////////////////////////////////////////////////////

import Common::utils (coerceSize)

// Simple lookup table for Zeta value given K
zeta_expc : [128](Z q)
zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848,
Expand All @@ -175,16 +173,6 @@ zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848,
1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ]

// Fast recursive CT-NTT
//
// The "coerceSize" calls in this code are required to satisfy
// Cryptol's type constraint solver that this code really
// is type-correct by effectively changing a static type-check
// into a dynamic one.
//
// As the static type constraint prover improves, this
// might become unncessesary.
//
// See https://github.com/GaloisInc/cryptol/issues/1489 for more details.
ct_butterfly :
{m, hm}
(m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) =>
Expand All @@ -195,7 +183,7 @@ ct_butterfly v z = new_v
lower, upper : [2^^hm](Z q)
lower@x = v@x + z * v@(x + halflen)
upper@x = v@x - z * v@(x + halflen)
new_v = coerceSize (lower # upper)
new_v = lower # upper

fast_nttl :
{lv} // Length of v is a member of {256,128,64,32,16,8,4}
Expand All @@ -206,30 +194,19 @@ fast_nttl v k
| lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k)

// Recursive case. Butterfly what we have, then recurse on each half,
// concatenate the results and return. As above, we need coerceSize
// here (twice) to satisfy the type checker.
| lv > 2 => coerceSize ((fast_nttl`{lv-1} s0 (k * 2)) #
(fast_nttl`{lv-1} s1 (k * 2 + 1)))
// concatenate the results and return.
| lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) #
(fast_nttl`{lv-1} s1 (k * 2 + 1))
where
t = ct_butterfly`{lv,lv-1} v (zeta_expc@k)
// Split t into two halves s0 and s1
[s0, s1] = split (coerceSize t)
[s0, s1] = split t

// Top level entry point - start with lv=256, k=1
fast_ntt : Z_q_256 -> Z_q_256
fast_ntt v = fast_nttl v 1

// Fast recursive GS-Inverse-NTT
//
// The "coerceSize" calls in this code are required to satisfy
// Cryptol's type constraint solver that this code really
// is type-correct by effectively changing a static type-check
// into a dynamic one.
//
// As the static type constraint prover improves, this
// might become unncessesary.
//
// See https://github.com/GaloisInc/cryptol/issues/1489 for more details.
gs_butterfly :
{m, hm}
(m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) =>
Expand All @@ -240,7 +217,7 @@ gs_butterfly v z = new_v
lower, upper : [2^^hm](Z q)
lower@x = v@x + v@(x + halflen)
upper@x = z * (v@(x + halflen) - v@x)
new_v = coerceSize (lower # upper)
new_v = lower # upper

fast_invnttl :
{lv} // Length of v is a member of {256,128,64,32,16,8,4}
Expand All @@ -253,13 +230,12 @@ fast_invnttl v k

// Recursive case. Recurse on each half,
// concatenate the results, butterfly that, and return.
// As above, we need coerceSize here (twice) to satisfy the type checker.
| lv > 2 => gs_butterfly`{lv,lv-1} t (zeta_expc@k)
where
// Split t into two halves s0 and s1
[s0, s1] = split (coerceSize v)
t = coerceSize ((fast_invnttl`{lv-1} s0 (k * 2 + 1)) #
(fast_invnttl`{lv-1} s1 (k * 2)))
[s0, s1] = split v
t = (fast_invnttl`{lv-1} s0 (k * 2 + 1)) #
(fast_invnttl`{lv-1} s1 (k * 2))

// Multiply all elements of v by the reciprocal of 128 (modulo q)
recip_128_modq = (recip 128) : (Z q)
Expand Down
Loading