Skip to content

Commit

Permalink
Add proof for sample_ring_element_cbd
Browse files Browse the repository at this point in the history
  • Loading branch information
mamonet committed Oct 10, 2024
1 parent b93b5a3 commit 53bd075
Show file tree
Hide file tree
Showing 12 changed files with 200 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ val v_PRF (v_LEN: usize) (input: t_Slice u8)
result == Spec.Utils.v_PRF v_LEN input)

val v_PRFxN (v_K v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K) Prims.l_True (fun _ -> Prims.l_True)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(requires v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4))
(ensures
fun result ->
let result:t_Array (t_Array u8 v_LEN) v_K = result in
result == Spec.Utils.v_PRFxN v_K v_LEN input)

/// The state.
/// It\'s only used for SHAKE128.
Expand Down Expand Up @@ -63,15 +68,19 @@ let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash t_Simd256Hash v_K =
(fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) ->
v v_LEN < pow2 32 ==> out == Spec.Utils.v_PRF v_LEN input);
f_PRF = (fun (v_LEN: usize) (input: t_Slice u8) -> v_PRF v_LEN input);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_pre
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4));
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
out == Spec.Utils.v_PRFxN v_K v_LEN input);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> v_PRFxN v_K v_LEN input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ val v_PRF (v_LEN: usize) (input: t_Slice u8)
result == Spec.Utils.v_PRF v_LEN input)

val v_PRFxN (v_K v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K) Prims.l_True (fun _ -> Prims.l_True)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(requires v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4))
(ensures
fun result ->
let result:t_Array (t_Array u8 v_LEN) v_K = result in
result == Spec.Utils.v_PRFxN v_K v_LEN input)

/// The state.
/// It\'s only used for SHAKE128.
Expand Down Expand Up @@ -63,15 +68,19 @@ let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash t_Simd128Hash v_K =
(fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) ->
v v_LEN < pow2 32 ==> out == Spec.Utils.v_PRF v_LEN input);
f_PRF = (fun (v_LEN: usize) (input: t_Slice u8) -> v_PRF v_LEN input);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_pre
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4));
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
out == Spec.Utils.v_PRFxN v_K v_LEN input);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> v_PRFxN v_K v_LEN input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,12 @@ val v_PRF (v_LEN: usize) (input: t_Slice u8)
result == Spec.Utils.v_PRF v_LEN input)

val v_PRFxN (v_K v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K) Prims.l_True (fun _ -> Prims.l_True)
: Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(requires v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4))
(ensures
fun result ->
let result:t_Array (t_Array u8 v_LEN) v_K = result in
result == Spec.Utils.v_PRFxN v_K v_LEN input)

/// The state.
/// It\'s only used for SHAKE128.
Expand Down Expand Up @@ -63,15 +68,19 @@ let impl (v_K: usize) : Libcrux_ml_kem.Hash_functions.t_Hash (t_PortableHash v_K
(fun (v_LEN: usize) (input: t_Slice u8) (out: t_Array u8 v_LEN) ->
v v_LEN < pow2 32 ==> out == Spec.Utils.v_PRF v_LEN input);
f_PRF = (fun (v_LEN: usize) (input: t_Slice u8) -> v_PRF v_LEN input);
f_PRFxN_pre = (fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> true);
f_PRFxN_pre
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) ->
v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4));
f_PRFxN_post
=
(fun
(v_LEN: usize)
(input: t_Array (t_Array u8 (sz 33)) v_K)
(out: t_Array (t_Array u8 v_LEN) v_K)
->
true);
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
out == Spec.Utils.v_PRFxN v_K v_LEN input);
f_PRFxN
=
(fun (v_LEN: usize) (input: t_Array (t_Array u8 (sz 33)) v_K) -> v_PRFxN v_K v_LEN input);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,17 @@ class t_Hash (v_Self: Type0) (v_K: usize) = {
-> pred: Type0{pred ==> v v_LEN < pow2 32 ==> result == Spec.Utils.v_PRF v_LEN input};
f_PRF:v_LEN: usize -> x0: t_Slice u8
-> Prims.Pure (t_Array u8 v_LEN) (f_PRF_pre v_LEN x0) (fun result -> f_PRF_post v_LEN x0 result);
f_PRFxN_pre:v_LEN: usize -> input: t_Array (t_Array u8 (sz 33)) v_K -> pred: Type0{true ==> pred};
f_PRFxN_post:v_LEN: usize -> t_Array (t_Array u8 (sz 33)) v_K -> t_Array (t_Array u8 v_LEN) v_K
-> Type0;
f_PRFxN_pre:v_LEN: usize -> input: t_Array (t_Array u8 (sz 33)) v_K
-> pred: Type0{v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4) ==> pred};
f_PRFxN_post:
v_LEN: usize ->
input: t_Array (t_Array u8 (sz 33)) v_K ->
result: t_Array (t_Array u8 v_LEN) v_K
-> pred:
Type0
{ pred ==>
(v v_LEN < pow2 32 /\ (v v_K == 2 \/ v v_K == 3 \/ v v_K == 4)) ==>
result == Spec.Utils.v_PRFxN v_K v_LEN input };
f_PRFxN:v_LEN: usize -> x0: t_Array (t_Array u8 (sz 33)) v_K
-> Prims.Pure (t_Array (t_Array u8 v_LEN) v_K)
(f_PRFxN_pre v_LEN x0)
Expand Down
87 changes: 66 additions & 21 deletions libcrux-ml-kem/proofs/fstar/extraction/Libcrux_ml_kem.Ind_cpa.fst
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ let _ =
let open Libcrux_ml_kem.Vector.Traits in
()

#push-options "--max_fuel 10 --z3rlimit 1000 --ext context_pruning --z3refresh --split_queries always"

let sample_ring_element_cbd
(v_K v_ETA2_RANDOMNESS_SIZE v_ETA2: usize)
(#v_Vector #v_Hasher: Type0)
Expand All @@ -35,13 +37,22 @@ let sample_ring_element_cbd
in
let prf_inputs:t_Array (t_Array u8 (sz 33)) v_K = Rust_primitives.Hax.repeat prf_input v_K in
let v__domain_separator_init:u8 = domain_separator in
let v__prf_inputs_init:t_Array (t_Array u8 (sz 33)) v_K = prf_inputs in
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) =
Rust_primitives.Hax.Folds.fold_range (sz 0)
v_K
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
v domain_separator == v v__domain_separator_init + v i)
v domain_separator == v v__domain_separator_init + v i /\
(v i < v v_K ==>
(forall (j: nat).
(j >= v i /\ j < v v_K) ==> prf_inputs.[ sz j ] == v__prf_inputs_init.[ sz j ])) /\
(forall (j: nat).
j < v i ==>
v (Seq.index (Seq.index prf_inputs j) 32) == v v__domain_separator_init + j /\
Seq.slice (Seq.index prf_inputs j) 0 32 ==
Seq.slice (Seq.index v__prf_inputs_init j) 0 32))
(domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
Expand All @@ -60,6 +71,28 @@ let sample_ring_element_cbd
let domain_separator:u8 = domain_separator +! 1uy in
domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
in
let _:Prims.unit =
let lemma_aux (i: nat{i < v v_K})
: Lemma
(prf_inputs.[ sz i ] ==
(Seq.append (Seq.slice prf_input 0 32)
(Seq.create 1
(mk_int #u8_inttype (v (v__domain_separator_init +! (mk_int #u8_inttype i))))))) =
Lib.Sequence.eq_intro #u8
#33
prf_inputs.[ sz i ]
(Seq.append (Seq.slice prf_input 0 32)
(Seq.create 1 (mk_int #u8_inttype (v v__domain_separator_init + i))))
in
Classical.forall_intro lemma_aux;
Lib.Sequence.eq_intro #(t_Array u8 (sz 33))
#(v v_K)
prf_inputs
(createi v_K
(Spec.MLKEM.sample_vector_cbd2_prf_input #v_K
(Seq.slice prf_input 0 32)
(sz (v v__domain_separator_init))))
in
let (prf_outputs: t_Array (t_Array u8 v_ETA2_RANDOMNESS_SIZE) v_K):t_Array
(t_Array u8 v_ETA2_RANDOMNESS_SIZE) v_K =
Libcrux_ml_kem.Hash_functions.f_PRFxN #v_Hasher
Expand All @@ -71,35 +104,45 @@ let sample_ring_element_cbd
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
Rust_primitives.Hax.Folds.fold_range (sz 0)
v_K
(fun error_1_ temp_1_ ->
(fun error_1_ i ->
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
error_1_
in
let _:usize = temp_1_ in
true)
let i:usize = i in
forall (j: nat).
j < v i ==>
Libcrux_ml_kem.Polynomial.to_spec_poly_t #v_Vector error_1_.[ sz j ] ==
Spec.MLKEM.sample_poly_cbd v_ETA2 prf_outputs.[ sz j ])
error_1_
(fun error_1_ i ->
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
error_1_
in
let i:usize = i in
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize error_1_
i
(Libcrux_ml_kem.Sampling.sample_from_binomial_distribution v_ETA2
#v_Vector
(prf_outputs.[ i ] <: t_Slice u8)
<:
Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
<:
t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K)
in
let result:(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8) =
error_1_, domain_separator
<:
(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8)
let error_1_:t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K =
Rust_primitives.Hax.Monomorphized_update_at.update_at_usize error_1_
i
(Libcrux_ml_kem.Sampling.sample_from_binomial_distribution v_ETA2
#v_Vector
(prf_outputs.[ i ] <: t_Slice u8)
<:
Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector)
in
error_1_)
in
let _:Prims.unit = admit () (* Panic freedom *) in
result
let _:Prims.unit =
Lib.Sequence.eq_intro #(Spec.MLKEM.polynomial)
#(v v_K)
(Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector error_1_)
(Spec.MLKEM.sample_vector_cbd2 #v_K
(Seq.slice prf_input 0 32)
(sz (v v__domain_separator_init)))
in
error_1_, domain_separator
<:
(t_Array (Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector) v_K & u8)

#pop-options

let sample_vector_cbd_then_ntt
(v_K v_ETA v_ETA_RANDOMNESS_SIZE: usize)
Expand All @@ -122,7 +165,9 @@ let sample_vector_cbd_then_ntt
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
let i:usize = i in
v domain_separator == v v__domain_separator_init + v i)
v domain_separator == v v__domain_separator_init + v i /\
(forall (j: nat). j < v i ==> v (Seq.index prf_input j) == v v__domain_separator_init + j)
)
(domain_separator, prf_inputs <: (u8 & t_Array (t_Array u8 (sz 33)) v_K))
(fun temp_0_ i ->
let domain_separator, prf_inputs:(u8 & t_Array (t_Array u8 (sz 33)) v_K) = temp_0_ in
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ val sample_ring_element_cbd
in
v ds == v domain_separator + v v_K /\
Libcrux_ml_kem.Polynomial.to_spec_vector_t #v_K #v_Vector err1 ==
Spec.MLKEM.sample_vector_cbd1 #v_K (Seq.slice prf_input 0 32) (sz (v domain_separator)))
Spec.MLKEM.sample_vector_cbd2 #v_K (Seq.slice prf_input 0 32) (sz (v domain_separator)))

/// Sample a vector of ring elements from a centered binomial distribution and
/// convert them into their NTT representations.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,10 +120,12 @@ val sample_from_binomial_distribution
(ensures
fun result ->
let result:Libcrux_ml_kem.Polynomial.t_PolynomialRingElement v_Vector = result in
forall (i: nat).
i < 8 ==>
Libcrux_ml_kem.Ntt.ntt_layer_7_pre (result.f_coefficients.[ sz i ])
(result.f_coefficients.[ sz i +! sz 8 ]))
(forall (i: nat).
i < 8 ==>
Libcrux_ml_kem.Ntt.ntt_layer_7_pre (result.f_coefficients.[ sz i ])
(result.f_coefficients.[ sz i +! sz 8 ])) /\
Libcrux_ml_kem.Polynomial.to_spec_poly_t #v_Vector result ==
Spec.MLKEM.sample_poly_cbd v_ETA randomness)

val sample_from_xof
(v_K: usize)
Expand Down
13 changes: 12 additions & 1 deletion libcrux-ml-kem/proofs/fstar/spec/Spec.MLKEM.fst
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,19 @@ let sample_poly_cbd1 #r seed domain_sep =
let sample_vector_cbd1 (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r =
createi r (fun i -> sample_poly_cbd1 #r seed (domain_sep +! i))

// let sample_vector_cbd2 (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r =
// createi r (fun i -> sample_poly_cbd2 #r seed (domain_sep +! i))

let sample_vector_cbd2_prf_input (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) (i:usize{i <. r}) : t_Array u8 (sz 33) =
Seq.append seed (Seq.create 1 (mk_int #u8_inttype (v domain_sep + v i)))

let sample_vector_cbd2_prf_output (#r:rank) (prf_output:t_Array (t_Array u8 (v_ETA2_RANDOMNESS_SIZE r)) r) (i:usize{i <. r}) : polynomial =
sample_poly_cbd (v_ETA2 r) prf_output.[i]

let sample_vector_cbd2 (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r =
createi r (fun i -> sample_poly_cbd2 #r seed (domain_sep +! i))
let prf_input = createi r (sample_vector_cbd2_prf_input #r seed domain_sep) in
let prf_output = v_PRFxN r (v_ETA2_RANDOMNESS_SIZE r) prf_input in
createi r (sample_vector_cbd2_prf_output #r prf_output)

let sample_vector_cbd_then_ntt (#r:rank) (seed:t_Array u8 (sz 32)) (domain_sep:usize{v domain_sep < 2 * v r}) : vector r =
vector_ntt (sample_vector_cbd1 #r seed domain_sep)
Expand Down
3 changes: 3 additions & 0 deletions libcrux-ml-kem/proofs/fstar/spec/Spec.Utils.fst
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ val v_PRF (v_LEN: usize{v v_LEN < pow2 32}) (input: t_Slice u8) : t_Array u8 v_L
let v_PRF v_LEN input = map_slice Lib.RawIntTypes.u8_to_UInt8 (
shake256 (Seq.length input) (map_slice Lib.IntTypes.secret input) (v v_LEN))

assume val v_PRFxN (r:usize{v r == 2 \/ v r == 3 \/ v r == 4}) (v_LEN: usize{v v_LEN < pow2 32})
(input: t_Array (t_Array u8 (sz 33)) r) : t_Array (t_Array u8 v_LEN) r

let v_J (input: t_Slice u8) : t_Array u8 (sz 32) = v_PRF (sz 32) input

val v_XOF (v_LEN: usize{v v_LEN < pow2 32}) (input: t_Slice u8) : t_Array u8 v_LEN
Expand Down
Loading

0 comments on commit 53bd075

Please sign in to comment.