From 5f55b8672ea4e4ba2eea4466961cfdc3e35d942b Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Tue, 27 Feb 2018 13:15:34 +0100 Subject: [PATCH 1/7] factor out convert-mul-convert and prove correctness --- src/Experiments/SimplyTypedArithmetic.v | 407 +++++++++++++++++++----- 1 file changed, 326 insertions(+), 81 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 92cf35eeee..c6c12827d8 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -137,6 +137,9 @@ Module Positional. Section Positional. Lemma eval_nil n : eval n [] = 0. Proof. cbv [eval to_associational]. rewrite combine_nil_r. reflexivity. Qed. Hint Rewrite eval_nil : push_eval. + Lemma eval0 p : eval 0 p = 0. + Proof. cbv [eval to_associational]. reflexivity. Qed. + Hint Rewrite eval0 : push_eval. Lemma eval_snoc n m x y : n = length x -> m = S n -> eval m (x ++ [y]) = eval n x + weight n * y. Proof. @@ -262,6 +265,8 @@ Module Positional. Section Positional. (weight (S index) / weight index) (to_associational n p)). + Lemma length_carry n m index p : length (carry n m index p) = m. + Proof. cbv [carry]; distr_length. Qed. Lemma eval_carry n m i p: (n <> 0%nat) -> (m <> 0%nat) -> weight (S i) / weight i <> 0 -> eval m (carry n m i p) = eval n p. @@ -316,6 +321,19 @@ Module Positional. Section Positional. cbn [fold_right]; distr_length. Qed. Hint Rewrite @length_chained_carries : distr_length. + (* carries without modular reduction; useful for converting between bases *) + Definition chained_carries_no_reduce n p (idxs : list nat) := + fold_right (fun a b => carry n n a b) p (rev idxs). + Lemma eval_chained_carries_no_reduce n p idxs: + (forall i, In i idxs -> weight (S i) / weight i <> 0) -> + eval n (chained_carries_no_reduce n p idxs) = eval n p. + Proof. + cbv [chained_carries_no_reduce]; intros. + destruct n; [push;reflexivity|]. 
+ apply fold_right_invariant; [|intro; rewrite <-in_rev]; + intros; push; auto. + Qed. Hint Rewrite @eval_chained_carries_no_reduce : push_eval. + (* Reverse of [eval]; translate from Z to basesystem by putting everything in first digit and then carrying. *) Definition encode n s c (x : Z) : list Z := @@ -329,6 +347,7 @@ Module Positional. Section Positional. Lemma length_encode n s c x : length (encode n s c x) = n. Proof. cbv [encode]; repeat distr_length. Qed. + End Carries. Hint Rewrite @eval_encode : push_eval. Hint Rewrite @length_encode : distr_length. @@ -428,6 +447,30 @@ End Positional. End Positional. Require Import Crypto.Util.ZUtil.AddGetCarry Crypto.Util.ZUtil.MulSplit. +Module BaseConversion. + Import Positional. + Section BaseConversion. + Context (sw dw : nat -> Z) (* source/destination weight functions *) + (dw_0 : dw 0%nat = 1) + (dw_nz : forall i, dw i <> 0). + Context (dw_divides : forall i : nat, dw (S i) / dw i > 0). + + Definition convert_bases (sn dn : nat) (p : list Z) : list Z := + let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in + chained_carries_no_reduce dw dn p' (seq 0 (pred dn)). + + Lemma eval_convert_bases sn dn p : + (dn <> 0%nat) -> length p = sn -> + eval dw dn (convert_bases sn dn p) = eval sw sn p. + Proof. + cbv [convert_bases]; intros. + rewrite eval_chained_carries_no_reduce; auto using ZUtil.Z.positive_is_nonzero. + rewrite eval_from_associational; auto. + Qed. + + End BaseConversion. +End BaseConversion. + (* Non-CPS version of Arithmetic/Saturated/MulSplit.v *) Module MulSplit. Module Associational. @@ -702,6 +745,71 @@ Module Columns. rewrite IHp by tauto. ring. } Qed. + Lemma flatten_snoc x inp : flatten (inp ++ [x]) = flatten_step x (flatten inp). + Proof. cbv [flatten]. rewrite rev_unit. reflexivity. Qed. + + Lemma weight_multiples_full j : forall i, (i <= j)%nat -> weight j mod weight i = 0. + Proof. 
+ induction j; intros; [replace i with 0%nat by omega + | destruct (dec (i <= j)%nat); [ rewrite (Z.div_mod (weight (S j)) (weight j)) by auto + | replace i with (S j) by omega ] ]; + repeat match goal with + | _ => rewrite weight_0 + | _ => rewrite weight_multiples + | _ => rewrite IHj by omega + | _ => progress autorewrite with push_Zmod zsimplify + | _ => reflexivity + end. + Qed. + + (* TODO: move to ZUtil *) + Lemma Z_divide_div_mul_exact' a b c : b <> 0 -> (b | a) -> a * c / b = c * (a / b). + Proof. intros. rewrite Z.mul_comm. auto using Z.divide_div_mul_exact. Qed. + + Lemma flatten_partitions inp: + forall n i, length inp = n -> (i < n)%nat -> + nth_default 0 (fst (flatten inp)) i = (((eval n inp) / weight i)) mod (weight (S i) / weight i). + Proof. + induction inp using rev_ind; distr_length; intros. + { cbn. + autorewrite with push_eval push_nth_default zsimplify. + reflexivity. } + { + destruct n as [| n]; [omega|]. + rewrite flatten_snoc, eval_snoc by omega. + cbv [flatten_step Let_In]. cbn [fst]. + rewrite nth_default_app. + break_match; distr_length. + { rewrite IHinp with (n:=n) by omega. + rewrite (Z.div_mod (weight n) (weight i)) by auto. + rewrite weight_multiples_full by omega. + rewrite (Z.div_mod (weight n) (weight (S i))) by auto. + rewrite weight_multiples_full by omega. + autorewrite with zsimplify. + repeat match goal with + | _ => rewrite Z_divide_div_mul_exact' by (try apply Z.mod_divide; auto) + | |- context [ (_ + ?a * ?b * ?c) / ?a ] => + replace (a * b * c) with (a * (b * c)) by ring; + rewrite Z.div_add' by auto + | |- context [ (_ + ?a * ?b * ?c) mod ?b ] => + replace (a * b * c) with (a * c * b) by ring; + rewrite Z.mod_add by auto using ZUtil.Z.positive_is_nonzero + | _ => reflexivity + end. 
+ } + { repeat match goal with + | _ => progress replace (Datatypes.length inp) with n by omega + | _ => progress replace i with n by omega + | _ => rewrite nth_default_cons + | _ => rewrite sum_cons + | _ => rewrite flatten_column_mod + | _ => erewrite flatten_div by eauto + | _ => progress autorewrite with natsimplify + end. + rewrite Z.div_add' by auto. + reflexivity. } } + Qed. + Section mul. Definition mul s n m (p q : list Z) : list Z := let p_a := Positional.to_associational weight n p in @@ -710,6 +818,207 @@ Module Columns. fst (flatten (from_associational m pq_a)). End mul. End Columns. + + Section mul_converted. + Context (w w' : nat -> Z). + Context (w'_0 : w' 0%nat = 1) + (w'_nonzero : forall i, w' i <> 0) + (w'_positive : forall i, w' i > 0) + (w'_divides : forall i : nat, w' (S i) / w' i > 0). + Context (w_0 : w 0%nat = 1) + (w_nonzero : forall i, w i <> 0) + (w_positive : forall i, w i > 0) + (w_multiples : forall i, w (S i) mod w i = 0) + (w_divides : forall i : nat, w (S i) / w i > 0). + + (* take in inputs in base w. Converts to w', multiplies in that format, converts to w again, then flattens. *) + Definition mul_converted + n1 n2 (* lengths in original format *) + m1 m2 (* lengths in converted format *) + (n3 : nat) (* final length *) + (p1 p2 : list Z) := + let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in + let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in + let p1_a := Positional.to_associational w' m1 p1' in + let p2_a := Positional.to_associational w' m2 p2' in + (* + let p3_a := Associational.carry (w' 1%nat) (w 1) (Associational.mul p1_a p2_a) in + *) + let p3_a := Associational.mul p1_a p2_a in + fst (flatten w (from_associational w n3 p3_a)). + + Hint Rewrite + @Columns.eval_from_associational + @Associational.eval_carry + @Associational.eval_mul + @Positional.eval_to_associational + @BaseConversion.eval_convert_bases using solve [auto] : push_eval. 
+ + Lemma mul_converted_correct n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + length p1 = n1 -> length p2 = n2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). + Proof. + cbv [mul_converted]; intros. + rewrite Columns.flatten_mod by auto using Columns.length_from_associational. + autorewrite with push_eval. auto using Z.mod_small. + Qed. + + (* TODO: this section specializes to one-element lists in which + the intermediate weight is the square root of the old. It would + be better to specialize just to the relationship between + weights, rather than the size of the input. However, partial + reduction/CPS transform seems to take forever when dynamic list + allocation is happening. *) + Section single. + Context (w'_sq : forall i, (w' i) * (w' i) = w i). + Context (w_1_gt1 : w 1 > 1) (w'_1_gt1 : w' 1 > 1). + + Derive convert_single + SuchThat (forall p, convert_single p = BaseConversion.convert_bases w w' 1 2 [p]) + As convert_single_correct. + Proof. + intros. + cbv - [Z.add Z.div Z.mul Z.eqb Z.modulo]. + assert (w 0 mod w' 1 = 1) as P0 by (rewrite w_0, Z.mod_1_l; omega). + assert (w' 1 =? 1 = false) as P1 by (apply Z.eqb_neq; omega). + assert (1 =? 0 = false) as P2 by reflexivity. + repeat match goal with + | _ => progress rewrite ?w_0, ?w'_0 + | _ => progress rewrite ?P0, ?P1, ?P2 + | _ => progress rewrite ?Z.mod_1_l, ?Z.eqb_refl by omega + | _ => progress autorewrite with zsimplify_fast + end. + autorewrite with zsimplify. + reflexivity. + Qed. + + Derive mul_converted_single + SuchThat (forall (p1 p2 : Z), (0 <= p1 < w 1) -> (0 <= p2 < w 1) -> + mul_converted_single p1 p2 = mul_converted 1 1 2 2 2 [p1] [p2]) + As mul_converted_single_eq. + Proof. + intros. + cbv [mul_converted]. + rewrite <-!convert_single_correct. + cbv [convert_single]. 
+ + (* + (* assert some things for omega to use later *) + rewrite <-(w'_sq 1) in *. + pose proof (Z.mod_pos_bound p1 (w' 1) ltac:(auto using Z.gt_lt)). + pose proof (Z.mod_pos_bound p2 (w' 1) ltac:(auto using Z.gt_lt)). + assert (0 <= p1 / w' 1 < w' 1) by (split; [ Z.zero_bounds | apply Z.div_lt_upper_bound; omega ]). + assert (0 <= p2 / w' 1 < w' 1) by (split; [ Z.zero_bounds | apply Z.div_lt_upper_bound; omega ]). + assert (w' 1 < w' 1 * w' 1) by (apply Z.lt_mul_diag_r; omega). + assert (w' 1 =? 0 = false) by (apply Z.eqb_neq; omega). + assert (1 =? 0 = false) by reflexivity. + assert (0 < w' 1 * w' 1) by Z.zero_bounds. + + (* simplify carry *) + match goal with |- context [Associational.carry ?w ?fw ?x] => + remember (Associational.carry w fw x) as X eqn:HeqX + end. + cbv - [Z.modulo Z.div Z.eqb Z.mul app] in HeqX. cbn [app] in HeqX. + rewrite w'_0 in HeqX; autorewrite with zsimplify_fast in HeqX. + rewrite Z.eqb_refl in HeqX. + repeat match type of HeqX with context [if ?x =? ?y then _ else _] => + let H := fresh "H" in + case_eq (x =? y); intro H; rewrite H in HeqX; + rewrite ?Z.eqb_eq, ?Z.eqb_neq in H; try omega + end. + cbn [app] in HeqX. + rewrite !Z.div_small with (b:= w' 1 * w' 1) in HeqX by nia. + rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia. + subst X. + + (* simplify from_associational *) + match goal with |- context [from_associational ?w ?n ?x] => + remember (from_associational w n x) as X eqn:HeqX + end. + cbv - [Z.modulo Z.div Z.eqb Z.mul cons_to_nth] in HeqX. cbn [app] in HeqX. + rewrite <-w'_sq in HeqX. + autorewrite with zsimplify_fast in HeqX. + rewrite !Z.mod_1_l in HeqX by omega. + rewrite !Z.mod_mul in HeqX by omega. + rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia. + rewrite Z.eqb_refl in HeqX. + repeat match goal with H : Z.eqb _ _ = _ |- _ => rewrite H in HeqX end. + cbv - [Z.modulo Z.div Z.mul] in HeqX. + autorewrite with zsimplify in HeqX. + subst X. 
+ + (* simplify flatten *) + match goal with |- context [flatten ?w ?x] => + remember (flatten w x) as X eqn:HeqX + end. + cbn in HeqX. + cbv [flatten_step] in HeqX. cbn in HeqX. + autorewrite with to_div_mod in HeqX. + cbn [fst snd] in HeqX. + rewrite w_0 in HeqX. + autorewrite with zsimplify in HeqX. + Check Z.div_small. + match type of HeqX with context [ + + cbv [Let_In] in HeqX. + autorewrite with to_div_mod in HeqX. + cbn [fst snd] in HeqX. + cbv - [flatten_column Z.div Z.modulo Z.mul] in HeqX. + cbv [flatten_step] in HeqX. + cbv - [Z.modulo Z.div Z.eqb Z.mul Z.add_get_carry_full Z.add fst snd] in HeqX. cbn [app] in HeqX. + rewrite <-w'_sq in HeqX. + autorewrite with zsimplify_fast in HeqX. + rewrite !Z.mod_1_l in HeqX by omega. + rewrite !Z.mod_mul in HeqX by omega. + rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia. + rewrite Z.eqb_refl in HeqX. + repeat match goal with H : Z.eqb _ _ = _ |- _ => rewrite H in HeqX end. + cbv - [Z.modulo Z.div Z.mul] in HeqX. + autorewrite with zsimplify in HeqX. + subst X. + *) + + subst mul_converted_single. + reflexivity. + Qed. + + Lemma eval_mul_converted_single p1 p2 (_: 0 <= p1 < w 1) (_:0 <= p2 < w 1) (_: 0 <= p1 * p2 < w 2) : + Positional.eval w 2 (mul_converted_single p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]). + Proof. rewrite mul_converted_single_eq by auto. apply mul_converted_correct; cbn; nia. Qed. + + Hint Rewrite @length_from_associational : distr_length. + + Lemma mul_converted_single_mod x y : + 0 <= x < w 1 -> 0 <= y < w 1 -> + nth_default 0 (mul_converted_single x y) 0 = (x * y) mod (w 1). + Proof. + intros; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. + erewrite flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval. cbn. + rewrite w_0; autorewrite with zsimplify. + reflexivity. + Qed. 
+ + Lemma mul_converted_single_div x y : + 0 <= x < w 1 -> 0 <= y < w 1 -> + 0 <= x * y < w 2 -> + nth_default 0 (mul_converted_single x y) 1 = (x * y) / (w 1). + Proof. + intros; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. + erewrite flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval. cbn. + rewrite w_0; autorewrite with zsimplify. + apply Z.mod_small. + split. + { apply Z.div_nonneg; auto; omega. } + { apply Z.div_lt_upper_bound. omega. + rewrite Z.mul_div_eq_full by auto. + rewrite w_multiples. omega. } + Qed. + + End single. + End mul_converted. End Columns. Module Compilers. @@ -1006,9 +1315,9 @@ Module Compilers. | false => let rT := type.reify T in - let not_x := refresh x ltac:(fun n => fresh n) in - let not_x2 := refresh not_x ltac:(fun n => fresh n) in - let not_x3 := refresh not_x2 ltac:(fun n => fresh n) in + let not_x := fresh in + let not_x2 := fresh in + let not_x3 := fresh in (*let dummy := match goal with _ => idtac "reify_in_context: λ case:" term "using vars:" not_x not_x2 not_x3 end in*) let rf0 := constr:( @@ -5438,81 +5747,16 @@ Require Import Crypto.Util.ZUtil.Zselect Crypto.Util.ZUtil.AddModulo. Module MontgomeryReduction. Section MontRed'. Context (N R N' R' : Z). - Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0) + Context (HN_range : 0 <= N < R) (HN'_range : 0 <= N' < R) (HN_nz : N <> 0) (R_gt_1 : R > 1) (N'_good : Z.equiv_modulo R (N*N') (-1)) (R'_good: Z.equiv_modulo N (R*R') 1). - Section mul_converted. - Context (w w' : nat -> Z). - Context (w'_sq : forall i, (w' i) * (w' i) = w i). - Context (w'_0 : w' 0%nat = 1) - (w'_positive : forall i, w' i > 0). - Context (w_0 : w 0%nat = 1) - (w_nonzero : forall i, w i <> 0) - (w_positive : forall i, w i > 0) - (w_multiples : forall i, w (S i) mod w i = 0) - (w_divides : forall i : nat, w (S i) / w i > 0). - Context (w_1_gt1 : w 1 > 1) (w'_1_gt1 : w' 1 > 1). 
- - (* - (* TODO: get a version of convert-multiply-convert strategy working in - general form, not specialized to one-element lists, and add it to - arithmetic development in an appropriate place. May need to - specialize it to the case where (forall i, (w' i)^2 = w i) in - order for base conversion to simplify as expected. *) - Definition chained_carries_noreduce weight n p idxs : list Z := - fold_right (fun a b => carry weight n n a b) p (rev idxs). - Definition convert_bases (sw dw : nat -> Z) (sn dn : nat) (p : list Z) : list Z := - let p' := Positional.from_associational dw dn (Positional.to_associational sw sn p) in - chained_carries_noreduce dw dn p' (seq 0 dn). - - (* take in inputs in base w. Converts to w', multiplies in that format, converts to w again, then flattens. *) - Definition mul_converted - w w' (* two different weight functions, initial/final and intermediate *) - n1 n2 (* lengths in original format *) - m1 m2 (* lengths in converted format *) - (n3 : nat) (* final length *) - (p1 p2 : list Z) := - let p1' := convert_bases w w' n1 m1 p1 in - let p2' := convert_bases w w' n2 m2 p2 in - let p1_a := Positional.to_associational w' m1 p1' in - let p2_a := Positional.to_associational w' m2 p2' in - let p3_a := Associational.mul p1_a p2_a in - fst (Columns.flatten w (Columns.from_associational w n3 p3_a)). - *) - - (* specialized version equivalent to [test_mul w w' 1 1 2 2 2 [p1] [p2] *) - (* takes in 2 1-digit inputs in base w, produces a 2-digit output--same spec as mul_split *) - Definition mul_converted_single (w w' : nat ->Z) (p1 p2 : Z) := - let p1' := [p1 mod w' 1%nat; p1 / w' 1%nat] in - let p2' := [p2 mod w' 1%nat; p2 / w' 1%nat] in - let p1_a := Positional.to_associational w' 2 p1' in - let p2_a := Positional.to_associational w' 2 p2' in - let p3_a := Associational.mul p1_a p2_a in - fst (Columns.flatten w (Columns.from_associational w 2 p3_a)). 
- - Lemma mul_converted_single_eq p1 p2 : - (0 <= p1 * p2 < (w 2)) -> - mul_converted_single w w' p1 p2 = [(p1 * p2) mod (w 1); (p1 * p2) / (w 1) ]. - Proof. - Admitted. - Lemma mul_converted_single_correct p1 p2 : - Positional.eval w 2 (mul_converted_single w w' p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]) mod (w 2). - Proof. - intros. cbv [mul_converted_single]. - rewrite Columns.flatten_mod by auto using Columns.length_from_associational. - rewrite Columns.eval_from_associational by auto. - rewrite Associational.eval_mul. - cbv [Positional.eval Positional.to_associational Associational.eval]. - simpl [map seq combine fold_right]. rewrite w_0, w'_0. - rewrite !Z.mul_div_eq by auto. - f_equal; ring. - Qed. - End mul_converted. - Context (w w_half : nat -> Z). Context (w_half_sq : forall i, (w_half i) * (w_half i) = w i). Context (w_half_0 : w_half 0%nat = 1) - (w_half_positive : forall i, w_half i > 0). + (w_half_nonzero : forall i, w_half i <> 0) + (w_half_positive : forall i, w_half i > 0) + (w_half_multiples : forall i, w_half (S i) mod w_half i = 0) + (w_half_divides : forall i : nat, w_half (S i) / w_half i > 0). Context (w_0 : w 0%nat = 1) (w_nonzero : forall i, w i <> 0) (w_positive : forall i, w i > 0) @@ -5521,8 +5765,8 @@ Module MontgomeryReduction. Context (w_1_gt1 : w 1 > 1) (w_half_1_gt1 : w_half 1 > 1). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (mul_converted_single w w_half (fst lo_hi) N') 0 in - dlet_nd t1_t2 := mul_converted_single w w_half y N in + dlet_nd y := nth_default 0 (Columns.mul_converted_single w w_half (fst lo_hi) N') 0 in + dlet_nd t1_t2 := Columns.mul_converted_single w w_half y N in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in @@ -5531,10 +5775,11 @@ Module MontgomeryReduction. 
Local Ltac solve_range H := repeat match goal with - | _ => rewrite H, ?Z.pow_1_r, ?Z.pow_2_r + | _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) | |- 0 <= _ * _ < _ * _ => split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ] + | _ => solve [auto] end. Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) @@ -5546,10 +5791,10 @@ Module MontgomeryReduction. cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. rewrite Hlo, Hhi. assert (0 <= T mod R * N' < w 2) by (solve_range Hw). - rewrite !mul_converted_single_eq - by (rewrite ?mul_converted_single_eq; try assumption; cbv [nth_default nth_error]; solve_range Hw). + rewrite !Columns.mul_converted_single_mod; + (auto; rewrite ?Columns.mul_converted_single_mod; solve_range Hw). + rewrite !Columns.mul_converted_single_div by (auto; solve_range Hw). rewrite Hw, ?Z.pow_1_r. - cbv [nth_default nth_error]. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. (* pull out value before last modular reduction *) @@ -5760,7 +6005,7 @@ Module Montgomery256. 
expr_let 36 := MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (340282366841710300967557013911933812736)) in expr_let 37 := ADD_256 @@ (x_29, x_36) in expr_let 39 := ADD_256 @@ (fst @@ x_1, fst @@ x_28) in - expr_let 40 := ADDC_256 @@ (fst @@ x_39, snd @@ x_1, fst @@ x_37) in + expr_let 40 := ADDC_256 @@ (snd @@ x_39, snd @@ x_1, fst @@ x_37) in expr_let 41 := SELC @@ (snd @@ x_40, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in expr_let 42 := fst @@ (SUB_256 @@ (fst @@ x_40, x_41)) in ADDM @@ (x_42, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) @@ -5879,4 +6124,4 @@ c.Add256($r10, $r8, $r9_lo); c.Sub($r42, $r40_lo, $r41); c.AddM($ret, $r42, RegZero, RegMod);))) : expr uint256 -*) \ No newline at end of file + *) From cde249ac69d8f3017dab185aaace07403d484606 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 28 Feb 2018 10:43:05 +0100 Subject: [PATCH 2/7] Add a dummy length argument to make partial evaluation work (see #321) and fixed up Montgomery notations --- src/Experiments/SimplyTypedArithmetic.v | 93 +++++++++++++------------ 1 file changed, 49 insertions(+), 44 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index c6c12827d8..b4e59719f9 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -894,8 +894,8 @@ Module Columns. Qed. Derive mul_converted_single - SuchThat (forall (p1 p2 : Z), (0 <= p1 < w 1) -> (0 <= p2 < w 1) -> - mul_converted_single p1 p2 = mul_converted 1 1 2 2 2 [p1] [p2]) + SuchThat (forall n (p1 p2 : Z), (0 <= p1 < w 1) -> (0 <= p2 < w 1) -> + mul_converted_single n p1 p2 = mul_converted 1 1 2 2 n [p1] [p2]) As mul_converted_single_eq. Proof. intros. @@ -983,29 +983,30 @@ Module Columns. reflexivity. Qed. 
- Lemma eval_mul_converted_single p1 p2 (_: 0 <= p1 < w 1) (_:0 <= p2 < w 1) (_: 0 <= p1 * p2 < w 2) : - Positional.eval w 2 (mul_converted_single p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]). + Lemma eval_mul_converted_single n p1 p2 (_: n <> 0%nat) (_: 0 <= p1 < w 1) (_:0 <= p2 < w 1) (_: 0 <= p1 * p2 < w n) : + Positional.eval w n (mul_converted_single n p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]). Proof. rewrite mul_converted_single_eq by auto. apply mul_converted_correct; cbn; nia. Qed. Hint Rewrite @length_from_associational : distr_length. - Lemma mul_converted_single_mod x y : - 0 <= x < w 1 -> 0 <= y < w 1 -> - nth_default 0 (mul_converted_single x y) 0 = (x * y) mod (w 1). + Lemma mul_converted_single_mod n x y : + n = 2%nat -> 0 <= x < w 1 -> 0 <= y < w 1 -> + nth_default 0 (mul_converted_single n x y) 0 = (x * y) mod (w 1). Proof. - intros; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. + intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. erewrite flatten_partitions by (auto; distr_length). autorewrite with distr_length push_eval. cbn. rewrite w_0; autorewrite with zsimplify. reflexivity. Qed. - Lemma mul_converted_single_div x y : + Lemma mul_converted_single_div n x y : + n = 2%nat -> 0 <= x < w 1 -> 0 <= y < w 1 -> 0 <= x * y < w 2 -> - nth_default 0 (mul_converted_single x y) 1 = (x * y) / (w 1). + nth_default 0 (mul_converted_single n x y) 1 = (x * y) / (w 1). Proof. - intros; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. + intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. erewrite flatten_partitions by (auto; distr_length). autorewrite with distr_length push_eval. cbn. rewrite w_0; autorewrite with zsimplify. @@ -5763,10 +5764,11 @@ Module MontgomeryReduction. (w_multiples : forall i, w (S i) mod w i = 0) (w_divides : forall i : nat, w (S i) / w i > 0). Context (w_1_gt1 : w 1 > 1) (w_half_1_gt1 : w_half 1 > 1). 
+ Context (n:nat) (Hn : n = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (Columns.mul_converted_single w w_half (fst lo_hi) N') 0 in - dlet_nd t1_t2 := Columns.mul_converted_single w w_half y N in + dlet_nd y := nth_default 0 (Columns.mul_converted_single w w_half n (fst lo_hi) N') 0 in + dlet_nd t1_t2 := Columns.mul_converted_single w w_half n y N in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in @@ -5832,12 +5834,13 @@ Module MontgomeryReduction. End MontRed'. Derive montred_gen - SuchThat (forall (w w_half : nat -> Z) - (N R N' : Z) + SuchThat (forall (N R N' : Z) + (w w_half : nat -> Z) + (n : nat) (lo_hi : Z * Z), Interp (t:=type.reify_type_of montred') - montred_gen N R N' w w_half lo_hi - = montred' N R N' w w_half lo_hi) + montred_gen N R N' w w_half n lo_hi + = montred' N R N' w w_half n lo_hi) As montred_gen_correct. Proof. intros. @@ -5880,6 +5883,7 @@ Module MontgomeryReduction. Let rN := GallinaReify.Reify N. Let rR := GallinaReify.Reify R. Let rN' := GallinaReify.Reify N'. + Let rn := GallinaReify.Reify 2%nat. Let relax_zrange := relax_zrange_of_machine_wordsize. Let arg_bounds : BoundsAnalysis.Indexed.Range.range (BoundsAnalysis.Indexed.OfPHOAS.type.compile (type.Z * type.Z)) := (bound, bound). @@ -5915,6 +5919,7 @@ Module MontgomeryReduction. @ (rN' _) @ (rw _) @ (rw_half _) + @ (rn _) )%expr in check_args res. @@ -5932,7 +5937,7 @@ Module MontgomeryReduction. (bs:=out_bounds) arg rv - = Some (montred' (Interp rN) (Interp rR) (Interp rN') (Interp rw) (Interp rw_half) arg'). + = Some (montred' (Interp rN) (Interp rR) (Interp rN') (Interp rw) (Interp rw_half) (Interp rn) arg'). Lemma rmontred_correct rv @@ -6004,11 +6009,11 @@ Module Montgomery256. 
expr_let 29 := snd @@ x_28 +₁₂₈ snd @@ x_27 in expr_let 36 := MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (340282366841710300967557013911933812736)) in expr_let 37 := ADD_256 @@ (x_29, x_36) in - expr_let 39 := ADD_256 @@ (fst @@ x_1, fst @@ x_28) in - expr_let 40 := ADDC_256 @@ (snd @@ x_39, snd @@ x_1, fst @@ x_37) in - expr_let 41 := SELC @@ (snd @@ x_40, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in - expr_let 42 := fst @@ (SUB_256 @@ (fst @@ x_40, x_41)) in - ADDM @@ (x_42, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) + expr_let 38 := ADD_256 @@ (fst @@ x_1, fst @@ x_28) in + expr_let 39 := ADDC_256 @@ (snd @@ x_38, snd @@ x_1, fst @@ x_37) in + expr_let 40 := SELC @@ (snd @@ x_39, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in + expr_let 41 := fst @@ (SUB_256 @@ (fst @@ x_39, x_40)) in + ADDM @@ (x_41, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) : expr uint256 *) End Montgomery256. @@ -6062,28 +6067,28 @@ Module Montgomery256PrintingNotations. Notation "$r n '_hi'" := (snd @@ (BoundsAnalysis.Indexed.expr.Var (BoundsAnalysis.type.prod _ _) n))%nexpr (at level 10, format "$r n _hi") : nexpr_scope. Notation "'c.Mul128x128(' '$r' n ',' x ',' y ');' f" := (expr_let n := mul _ _ uint256 @@ (x, y) in - f)%nexpr (at level 48, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. Notation "'c.Mul128x128(' '$r' n ',' x ',' y ')' '<<' count ';' f" := (expr_let n := shiftl _ _ count @@ (mul _ _ uint256 @@ (x, y)) in - f)%nexpr (at level 49, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : nexpr_scope. 
+ f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Mul128x128(' '$r' n ',' x ',' y ')' '<<' count ';' ']' '//' f") : nexpr_scope. Notation "'c.Add256(' '$r' n ',' x ',' y ');' f" := (expr_let n := add_get_carry_concrete _ _ uint256 _ $R @@ (x, y) in - f)%nexpr (at level 47, right associativity, format "'[' 'c.Add256(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add256(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. Notation "'c.Add128(' '$r' n ',' x ',' y ');' f" := (expr_let n := add_get_carry_concrete _ _ uint128 _ $R @@ (x, y) in - f)%nexpr (at level 45, right associativity, format "'[' 'c.Add128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add128(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. Notation "'c.Add64(' '$r' n ',' x ',' y ');' f" := (expr_let n := add _ _ uint128 @@ (x, y) in - f)%nexpr (at level 46, right associativity, format "'[' 'c.Add64(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Add64(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. Notation "'c.Addc(' '$r' n ',' x ',' y ');' f" := (expr_let n := add_with_get_carry_concrete _ _ _ uint256 _ $R @@ (_, x, y) in - f)%nexpr (at level 44, right associativity, format "'[' 'c.Addc(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Addc(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. Notation "'c.Selc(' '$r' n ',' y ',' z ');' f" := (expr_let n := zselect _ _ _ uint256 @@ (_, y, z) in - f)%nexpr (at level 43, right associativity, format "'[' 'c.Selc(' '$r' n ',' y ',' z ');' ']' '//' f") : nexpr_scope. 
+ f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Selc(' '$r' n ',' y ',' z ');' ']' '//' f") : nexpr_scope. Notation "'c.Sub(' '$r' n ',' x ',' y ');' f" := (expr_let n := fst @@ (sub_get_borrow_concrete _ _ uint256 _ $R @@ (x, y)) in - f)%nexpr (at level 42, right associativity, format "'c.Sub(' '$r' n ',' x ',' y ');' '//' f") : nexpr_scope. + f)%nexpr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$r' n ',' x ',' y ');' '//' f") : nexpr_scope. Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" := (add_modulo _ _ _ uint256 @@ (x, y, z))%nexpr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : nexpr_scope. Notation "'Lower128'" @@ -6110,18 +6115,18 @@ c.Mul128x128($r3, ($r1_lo >> 128), Lower128{RegPinv}) << 128; c.Mul128x128($r8, Lower128 @@ $r1_lo, Lower128{RegPinv}); c.Add256($r9, $r2, $r3); c.Add256($r10, $r8, $r9_lo); -(c.Mul128x128($r20, Lower128 @@ $r10_lo, RegMod << 128) << 128; - c.Mul128x128($r21, ($r10_lo >> 128), Lower128{RegMod}) << 128; - c.Mul128x128($r26, Lower128 @@ $r10_lo, Lower128{RegMod}); - c.Add128($r27, $r20, $r21); - (c.Add256($r28, $r26, $r27_lo); - c.Add64($r29, $r28_hi, $r27_hi); - (c.Mul128x128($r36, ($r10_lo >> 128), RegMod << 128); - c.Add256($r37, $r29, $r36); - c.Add256($r39, $r1_lo, $r28_lo); - c.Addc($r40, $r1_hi, $r37_lo); - c.Selc($r41,RegZero, RegMod); - c.Sub($r42, $r40_lo, $r41); - c.AddM($ret, $r42, RegZero, RegMod);))) +c.Mul128x128($r20, Lower128 @@ $r10_lo, RegMod << 128) << 128; +c.Mul128x128($r21, ($r10_lo >> 128), Lower128{RegMod}) << 128; +c.Mul128x128($r26, Lower128 @@ $r10_lo, Lower128{RegMod}); +c.Add128($r27, $r20, $r21); +c.Add256($r28, $r26, $r27_lo); +c.Add64($r29, $r28_hi, $r27_hi); +c.Mul128x128($r36, ($r10_lo >> 128), RegMod << 128); +c.Add256($r37, $r29, $r36); +c.Add256($r38, $r1_lo, $r28_lo); +c.Addc($r39, $r1_hi, $r37_lo); +c.Selc($r40,RegZero, RegMod); +c.Sub($r41, $r39_lo, $r40); +c.AddM($ret, $r41, RegZero, RegMod); : expr uint256 *) 
From 37e6cb55f87b23c1b29f34e8cd0429ee2995e89f Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 28 Feb 2018 10:46:15 +0100 Subject: [PATCH 3/7] remove unneeded, commented-out code --- src/Experiments/SimplyTypedArithmetic.v | 77 ------------------------- 1 file changed, 77 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index b4e59719f9..9a5d12be58 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -902,83 +902,6 @@ Module Columns. cbv [mul_converted]. rewrite <-!convert_single_correct. cbv [convert_single]. - - (* - (* assert some things for omega to use later *) - rewrite <-(w'_sq 1) in *. - pose proof (Z.mod_pos_bound p1 (w' 1) ltac:(auto using Z.gt_lt)). - pose proof (Z.mod_pos_bound p2 (w' 1) ltac:(auto using Z.gt_lt)). - assert (0 <= p1 / w' 1 < w' 1) by (split; [ Z.zero_bounds | apply Z.div_lt_upper_bound; omega ]). - assert (0 <= p2 / w' 1 < w' 1) by (split; [ Z.zero_bounds | apply Z.div_lt_upper_bound; omega ]). - assert (w' 1 < w' 1 * w' 1) by (apply Z.lt_mul_diag_r; omega). - assert (w' 1 =? 0 = false) by (apply Z.eqb_neq; omega). - assert (1 =? 0 = false) by reflexivity. - assert (0 < w' 1 * w' 1) by Z.zero_bounds. - - (* simplify carry *) - match goal with |- context [Associational.carry ?w ?fw ?x] => - remember (Associational.carry w fw x) as X eqn:HeqX - end. - cbv - [Z.modulo Z.div Z.eqb Z.mul app] in HeqX. cbn [app] in HeqX. - rewrite w'_0 in HeqX; autorewrite with zsimplify_fast in HeqX. - rewrite Z.eqb_refl in HeqX. - repeat match type of HeqX with context [if ?x =? ?y then _ else _] => - let H := fresh "H" in - case_eq (x =? y); intro H; rewrite H in HeqX; - rewrite ?Z.eqb_eq, ?Z.eqb_neq in H; try omega - end. - cbn [app] in HeqX. - rewrite !Z.div_small with (b:= w' 1 * w' 1) in HeqX by nia. - rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia. - subst X. 
- - (* simplify from_associational *) - match goal with |- context [from_associational ?w ?n ?x] => - remember (from_associational w n x) as X eqn:HeqX - end. - cbv - [Z.modulo Z.div Z.eqb Z.mul cons_to_nth] in HeqX. cbn [app] in HeqX. - rewrite <-w'_sq in HeqX. - autorewrite with zsimplify_fast in HeqX. - rewrite !Z.mod_1_l in HeqX by omega. - rewrite !Z.mod_mul in HeqX by omega. - rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia. - rewrite Z.eqb_refl in HeqX. - repeat match goal with H : Z.eqb _ _ = _ |- _ => rewrite H in HeqX end. - cbv - [Z.modulo Z.div Z.mul] in HeqX. - autorewrite with zsimplify in HeqX. - subst X. - - (* simplify flatten *) - match goal with |- context [flatten ?w ?x] => - remember (flatten w x) as X eqn:HeqX - end. - cbn in HeqX. - cbv [flatten_step] in HeqX. cbn in HeqX. - autorewrite with to_div_mod in HeqX. - cbn [fst snd] in HeqX. - rewrite w_0 in HeqX. - autorewrite with zsimplify in HeqX. - Check Z.div_small. - match type of HeqX with context [ - - cbv [Let_In] in HeqX. - autorewrite with to_div_mod in HeqX. - cbn [fst snd] in HeqX. - cbv - [flatten_column Z.div Z.modulo Z.mul] in HeqX. - cbv [flatten_step] in HeqX. - cbv - [Z.modulo Z.div Z.eqb Z.mul Z.add_get_carry_full Z.add fst snd] in HeqX. cbn [app] in HeqX. - rewrite <-w'_sq in HeqX. - autorewrite with zsimplify_fast in HeqX. - rewrite !Z.mod_1_l in HeqX by omega. - rewrite !Z.mod_mul in HeqX by omega. - rewrite !Z.mod_small with (b:= w' 1 * w' 1) in HeqX by nia. - rewrite Z.eqb_refl in HeqX. - repeat match goal with H : Z.eqb _ _ = _ |- _ => rewrite H in HeqX end. - cbv - [Z.modulo Z.div Z.mul] in HeqX. - autorewrite with zsimplify in HeqX. - subst X. - *) - subst mul_converted_single. reflexivity. Qed. 
From 23129d1bd24e2ae1d8b9f64997f0cd4c2500a1cf Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 28 Feb 2018 11:54:20 +0100 Subject: [PATCH 4/7] remove special-case convert-mul-convert implementation and use generalized one in Montgomery example --- src/Experiments/SimplyTypedArithmetic.v | 230 +++++++++++------------- 1 file changed, 100 insertions(+), 130 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 9a5d12be58..d83263ee9a 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -854,7 +854,7 @@ Module Columns. @Positional.eval_to_associational @BaseConversion.eval_convert_bases using solve [auto] : push_eval. - Lemma mul_converted_correct n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + Lemma eval_mul_converted n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). @@ -863,85 +863,41 @@ Module Columns. rewrite Columns.flatten_mod by auto using Columns.length_from_associational. autorewrite with push_eval. auto using Z.mod_small. Qed. + Hint Rewrite eval_mul_converted : push_eval. + + Hint Rewrite @length_from_associational : distr_length. + + Lemma mul_converted_mod n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + length p1 = n1 -> length p2 = n2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). + Proof. + intros; cbv [mul_converted]. + erewrite flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval natsimplify. + rewrite w_0; autorewrite with zsimplify. + reflexivity. + Qed. 
+ + Lemma mul_converted_div n1 n2 m1 m2 n3 p1 p2: + m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat -> + length p1 = n1 -> length p2 = n2 -> + 0 <= Positional.eval w n1 p1 -> + 0 <= Positional.eval w n2 p2 -> + 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> + nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). + Proof. + intros; subst n3; cbv [mul_converted]. + erewrite flatten_partitions by (auto; distr_length). + autorewrite with distr_length push_eval. + pose proof (w_positive 1). + apply Z.mod_small. + split; [ solve[Z.zero_bounds] | ]. + apply Z.div_lt_upper_bound; [omega|]. + rewrite Z.mul_div_eq_full by auto. + rewrite w_multiples. omega. + Qed. - (* TODO: this section specializes to one-element lists in which - the intermediate weight is the square root of the old. It would - be better to specialize just to the relationship between - weights, rather than the size of the input. However, partial - reduction/CPS transform seems to take forever when dynamic list - allocation is happening. *) - Section single. - Context (w'_sq : forall i, (w' i) * (w' i) = w i). - Context (w_1_gt1 : w 1 > 1) (w'_1_gt1 : w' 1 > 1). - - Derive convert_single - SuchThat (forall p, convert_single p = BaseConversion.convert_bases w w' 1 2 [p]) - As convert_single_correct. - Proof. - intros. - cbv - [Z.add Z.div Z.mul Z.eqb Z.modulo]. - assert (w 0 mod w' 1 = 1) as P0 by (rewrite w_0, Z.mod_1_l; omega). - assert (w' 1 =? 1 = false) as P1 by (apply Z.eqb_neq; omega). - assert (1 =? 0 = false) as P2 by reflexivity. - repeat match goal with - | _ => progress rewrite ?w_0, ?w'_0 - | _ => progress rewrite ?P0, ?P1, ?P2 - | _ => progress rewrite ?Z.mod_1_l, ?Z.eqb_refl by omega - | _ => progress autorewrite with zsimplify_fast - end. - autorewrite with zsimplify. - reflexivity. - Qed. 
- - Derive mul_converted_single - SuchThat (forall n (p1 p2 : Z), (0 <= p1 < w 1) -> (0 <= p2 < w 1) -> - mul_converted_single n p1 p2 = mul_converted 1 1 2 2 n [p1] [p2]) - As mul_converted_single_eq. - Proof. - intros. - cbv [mul_converted]. - rewrite <-!convert_single_correct. - cbv [convert_single]. - subst mul_converted_single. - reflexivity. - Qed. - - Lemma eval_mul_converted_single n p1 p2 (_: n <> 0%nat) (_: 0 <= p1 < w 1) (_:0 <= p2 < w 1) (_: 0 <= p1 * p2 < w n) : - Positional.eval w n (mul_converted_single n p1 p2) = (Positional.eval w 1 [p1]) * (Positional.eval w 1 [p2]). - Proof. rewrite mul_converted_single_eq by auto. apply mul_converted_correct; cbn; nia. Qed. - - Hint Rewrite @length_from_associational : distr_length. - - Lemma mul_converted_single_mod n x y : - n = 2%nat -> 0 <= x < w 1 -> 0 <= y < w 1 -> - nth_default 0 (mul_converted_single n x y) 0 = (x * y) mod (w 1). - Proof. - intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. - erewrite flatten_partitions by (auto; distr_length). - autorewrite with distr_length push_eval. cbn. - rewrite w_0; autorewrite with zsimplify. - reflexivity. - Qed. - - Lemma mul_converted_single_div n x y : - n = 2%nat -> - 0 <= x < w 1 -> 0 <= y < w 1 -> - 0 <= x * y < w 2 -> - nth_default 0 (mul_converted_single n x y) 1 = (x * y) / (w 1). - Proof. - intros; subst n; rewrite mul_converted_single_eq by auto. cbv [mul_converted]. - erewrite flatten_partitions by (auto; distr_length). - autorewrite with distr_length push_eval. cbn. - rewrite w_0; autorewrite with zsimplify. - apply Z.mod_small. - split. - { apply Z.div_nonneg; auto; omega. } - { apply Z.div_lt_upper_bound. omega. - rewrite Z.mul_div_eq_full by auto. - rewrite w_multiples. omega. } - Qed. - - End single. End mul_converted. End Columns. @@ -5640,7 +5596,7 @@ Module RemoveDeadLets. | Let_In s T n x f => Let_In n (inline_let idx _ new _ x) (inline_let idx _ new _ f) end. 
- (* inlines lets that just re-bind a variable or half a variable with type prod *) + (* inlines lets that just re-bind a variable or the output of a specified operation on a single variable *) Fixpoint inline_silly_lets t (e : @expr ident t) : @expr ident t := match e in (expr t') return expr t' with | Var T n => Var T n @@ -5652,7 +5608,7 @@ Module RemoveDeadLets. match x with | Var T' m => inline_let n _ (Var T' m) _ f | AppIdent _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m) => - inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f) + inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f) | _ => Let_In n (inline_silly_lets _ x) (inline_silly_lets _ f) end end. @@ -5690,8 +5646,8 @@ Module MontgomeryReduction. Context (n:nat) (Hn : n = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (Columns.mul_converted_single w w_half n (fst lo_hi) N') 0 in - dlet_nd t1_t2 := Columns.mul_converted_single w w_half n y N in + dlet_nd y := nth_default 0 (Columns.mul_converted w w_half 1 1 n n n [fst lo_hi] [N']) 0 in + dlet_nd t1_t2 := Columns.mul_converted w w_half 1 1 n n n [y] [N] in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in @@ -5702,9 +5658,12 @@ Module MontgomeryReduction. repeat match goal with | _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) + | |- 0 <= _ => progress Z.zero_bounds | |- 0 <= _ * _ < _ * _ => split; [ solve [Z.zero_bounds] | apply Z.mul_lt_mono_nonneg; omega ] | _ => solve [auto] + | _ => cbn + | _ => nia end. Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) @@ -5714,11 +5673,10 @@ Module MontgomeryReduction. Proof. rewrite <-reduce_via_partial_alt_eq by nia. 
cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. - rewrite Hlo, Hhi. + rewrite Hlo, Hhi. subst n. assert (0 <= T mod R * N' < w 2) by (solve_range Hw). - rewrite !Columns.mul_converted_single_mod; - (auto; rewrite ?Columns.mul_converted_single_mod; solve_range Hw). - rewrite !Columns.mul_converted_single_div by (auto; solve_range Hw). + rewrite !Columns.mul_converted_mod by (auto; rewrite ?Columns.mul_converted_mod; solve_range Hw). + rewrite !Columns.mul_converted_div by (auto; solve_range Hw). rewrite Hw, ?Z.pow_1_r. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. @@ -5730,7 +5688,9 @@ Module MontgomeryReduction. |- context [if R * R <=? ?x then _ else _] => match goal with |- context [if dec (?xHigh / R = 0) then _ else _] => assert (x / R = xHigh) as cond_equiv end end. - { apply Z.mul_cancel_r with (p:=R); [omega|]. autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. } + { apply Z.mul_cancel_r with (p:=R); [omega|]. cbn. + rewrite w_0. autorewrite with zsimplify_fast. + autorewrite with push_Zmul zdiv_to_mod push_Zmod; ring. } rewrite <-cond_equiv. rewrite ?Z.mod_pull_div, ?Z.div_div by omega. assert (0 < R * R)%Z by Z.zero_bounds. @@ -5917,26 +5877,28 @@ Module Montgomery256. Open Scope nexpr_scope. Print montred256. 
(* - expr_let 2 := (uint128)(MUL_256 @@ - (((uint128)fst @@ x_1 & 340282366920938463463374607431768211455), (340282366841710300986003757985643364352)) << 128) in - expr_let 3 := (uint128)(MUL_256 @@ ((uint128)(fst @@ x_1 >> 128), (79228162514264337593543950337)) << 128) in - expr_let 8 := MUL_256 @@ (((uint128)fst @@ x_1 & 340282366920938463463374607431768211455), (79228162514264337593543950337)) in - expr_let 9 := ADD_256 @@ (x_2, x_3) in - expr_let 10 := ADD_256 @@ (x_8, fst @@ x_9) in - expr_let 20 := (uint128)(MUL_256 @@ - (((uint128)fst @@ x_10 & 340282366920938463463374607431768211455), (340282366841710300967557013911933812736)) << 128) in - expr_let 21 := (uint128)(MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (79228162514264337593543950335)) << 128) in - expr_let 26 := MUL_256 @@ (((uint128)fst @@ x_10 & 340282366920938463463374607431768211455), (79228162514264337593543950335)) in - expr_let 27 := ADD_128 @@ (x_20, x_21) in - expr_let 28 := ADD_256 @@ (x_26, fst @@ x_27) in - expr_let 29 := snd @@ x_28 +₁₂₈ snd @@ x_27 in - expr_let 36 := MUL_256 @@ ((uint128)(fst @@ x_10 >> 128), (340282366841710300967557013911933812736)) in - expr_let 37 := ADD_256 @@ (x_29, x_36) in - expr_let 38 := ADD_256 @@ (fst @@ x_1, fst @@ x_28) in - expr_let 39 := ADDC_256 @@ (snd @@ x_38, snd @@ x_1, fst @@ x_37) in - expr_let 40 := SELC @@ (snd @@ x_39, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in - expr_let 41 := fst @@ (SUB_256 @@ (fst @@ x_39, x_40)) in - ADDM @@ (x_41, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) + expr_let 3 := (uint128)(fst @@ x_1 >> 128) in + expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in + expr_let 5 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in + expr_let 6 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in + expr_let 11 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in 
+ expr_let 12 := ADD_256 @@ (x_5, x_6) in + expr_let 13 := ADD_256 @@ (x_11, fst @@ x_12) in + expr_let 23 := (uint128)(fst @@ x_13 >> 128) in + expr_let 24 := ((uint128)fst @@ x_13 & 340282366920938463463374607431768211455) in + expr_let 25 := (uint128)(MUL_256 @@ (x_24, (340282366841710300967557013911933812736)) << 128) in + expr_let 26 := (uint128)(MUL_256 @@ (x_23, (79228162514264337593543950335)) << 128) in + expr_let 31 := MUL_256 @@ (x_24, (79228162514264337593543950335)) in + expr_let 32 := ADD_128 @@ (x_25, x_26) in + expr_let 33 := ADD_256 @@ (x_31, fst @@ x_32) in + expr_let 34 := snd @@ x_33 +₁₂₈ snd @@ x_32 in + expr_let 41 := MUL_256 @@ (x_23, (340282366841710300967557013911933812736)) in + expr_let 42 := ADD_256 @@ (x_34, x_41) in + expr_let 43 := ADD_256 @@ (fst @@ x_1, fst @@ x_33) in + expr_let 44 := ADDC_256 @@ (snd @@ x_43, snd @@ x_1, fst @@ x_42) in + expr_let 45 := SELC @@ (snd @@ x_44, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in + expr_let 46 := fst @@ (SUB_256 @@ (fst @@ x_44, x_45)) in + ADDM @@ (x_46, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) : expr uint256 *) End Montgomery256. @@ -6014,18 +5976,22 @@ Module Montgomery256PrintingNotations. f)%nexpr (at level 40, f at level 200, right associativity, format "'c.Sub(' '$r' n ',' x ',' y ');' '//' f") : nexpr_scope. Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" := (add_modulo _ _ _ uint256 @@ (x, y, z))%nexpr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : nexpr_scope. + Notation "'c.ShiftR(' '$r' n ',' x ',' y ');' f" := + (expr_let n := (shiftr _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. 
+ Notation "'c.Lower128(' '$r' n ',' x ');' f" := + (expr_let n := (land _ _ 340282366920938463463374607431768211455 @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$r' n ',' x ');' ']' '//' f") : nexpr_scope. Notation "'Lower128'" := ((land uint256 uint128 340282366920938463463374607431768211455)) (at level 10, only printing, format "Lower128") : nexpr_scope. - Notation "( v >> count )" - := ((shiftr _ _ count @@ v)%nexpr) - (format "( v >> count )") - : nexpr_scope. Notation "( v << count )" := ((shiftl _ _ count @@ v)%nexpr) (format "( v << count )") : nexpr_scope. + Notation "( x >> count )" + := ((shiftr _ _ count @@ x)%nexpr) + (format "( x >> count )") + : nexpr_scope. End Montgomery256PrintingNotations. Import Montgomery256PrintingNotations. @@ -6033,23 +5999,27 @@ Local Open Scope nexpr_scope. Print Montgomery256.montred256. (* -c.Mul128x128($r2, Lower128 @@ $r1_lo, RegPinv >> 128) << 128; -c.Mul128x128($r3, ($r1_lo >> 128), Lower128{RegPinv}) << 128; -c.Mul128x128($r8, Lower128 @@ $r1_lo, Lower128{RegPinv}); -c.Add256($r9, $r2, $r3); -c.Add256($r10, $r8, $r9_lo); -c.Mul128x128($r20, Lower128 @@ $r10_lo, RegMod << 128) << 128; -c.Mul128x128($r21, ($r10_lo >> 128), Lower128{RegMod}) << 128; -c.Mul128x128($r26, Lower128 @@ $r10_lo, Lower128{RegMod}); -c.Add128($r27, $r20, $r21); -c.Add256($r28, $r26, $r27_lo); -c.Add64($r29, $r28_hi, $r27_hi); -c.Mul128x128($r36, ($r10_lo >> 128), RegMod << 128); -c.Add256($r37, $r29, $r36); -c.Add256($r38, $r1_lo, $r28_lo); -c.Addc($r39, $r1_hi, $r37_lo); -c.Selc($r40,RegZero, RegMod); -c.Sub($r41, $r39_lo, $r40); -c.AddM($ret, $r41, RegZero, RegMod); +c.ShiftR($r3,$r1_lo, 128); +c.Lower128($r4,$r1_lo); +c.Mul128x128($r5, $r4, RegPinv >> 128) << 128; +c.Mul128x128($r6, $r3, Lower128{RegPinv}) << 128; +c.Mul128x128($r11, $r4, Lower128{RegPinv}); +c.Add256($r12, $r5, $r6); +c.Add256($r13, $r11, $r12_lo); +c.ShiftR($r23,$r13_lo, 128); +c.Lower128($r24,$r13_lo); 
+c.Mul128x128($r25, $r24, RegMod << 128) << 128; +c.Mul128x128($r26, $r23, Lower128{RegMod}) << 128; +c.Mul128x128($r31, $r24, Lower128{RegMod}); +c.Add128($r32, $r25, $r26); +c.Add256($r33, $r31, $r32_lo); +c.Add64($r34, $r33_hi, $r32_hi); +c.Mul128x128($r41, $r23, RegMod << 128); +c.Add256($r42, $r34, $r41); +c.Add256($r43, $r1_lo, $r33_lo); +c.Addc($r44, $r1_hi, $r42_lo); +c.Selc($r45,RegZero, RegMod); +c.Sub($r46, $r44_lo, $r45); +c.AddM($ret, $r46, RegZero, RegMod); : expr uint256 *) From 5cf04c59841b55ee012857afe605b091e5b18eea Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 28 Feb 2018 16:45:34 +0100 Subject: [PATCH 5/7] make Montgomery do associational carries in a generalized way --- src/Experiments/SimplyTypedArithmetic.v | 166 +++++++++++++++--------- 1 file changed, 105 insertions(+), 61 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index d83263ee9a..2630c099ad 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -831,20 +831,21 @@ Module Columns. (w_multiples : forall i, w (S i) mod w i = 0) (w_divides : forall i : nat, w (S i) / w i > 0). - (* take in inputs in base w. Converts to w', multiplies in that format, converts to w again, then flattens. *) + (* takes in inputs in base w, converts to w', multiplies in that + format, converts to w again, then flattens. 
*) Definition mul_converted n1 n2 (* lengths in original format *) m1 m2 (* lengths in converted format *) (n3 : nat) (* final length *) + (idxs : list nat) (* carries to do -- this helps preemptively line up weights *) (p1 p2 : list Z) := let p1' := BaseConversion.convert_bases w w' n1 m1 p1 in let p2' := BaseConversion.convert_bases w w' n2 m2 p2 in let p1_a := Positional.to_associational w' m1 p1' in let p2_a := Positional.to_associational w' m2 p2' in - (* - let p3_a := Associational.carry (w' 1%nat) (w 1) (Associational.mul p1_a p2_a) in - *) let p3_a := Associational.mul p1_a p2_a in + (* important not to use Positional.carry here; we don't want to accumulate yet *) + let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in fst (flatten w (from_associational w n3 p3_a)). Hint Rewrite @@ -854,10 +855,10 @@ Module Columns. @Positional.eval_to_associational @BaseConversion.eval_convert_bases using solve [auto] : push_eval. - Lemma eval_mul_converted n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). + Positional.eval w n3 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) = (Positional.eval w n1 p1) * (Positional.eval w n2 p2). Proof. cbv [mul_converted]; intros. rewrite Columns.flatten_mod by auto using Columns.length_from_associational. @@ -867,10 +868,10 @@ Module Columns. Hint Rewrite @length_from_associational : distr_length. 
- Lemma mul_converted_mod n1 n2 m1 m2 n3 p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): + Lemma mul_converted_mod n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). + nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 0 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) mod (w 1). Proof. intros; cbv [mul_converted]. erewrite flatten_partitions by (auto; distr_length). @@ -879,13 +880,13 @@ Module Columns. reflexivity. Qed. - Lemma mul_converted_div n1 n2 m1 m2 n3 p1 p2: + Lemma mul_converted_div n1 n2 m1 m2 n3 idxs p1 p2: m1 <> 0%nat -> m2 <> 0%nat -> n3 = 2%nat -> length p1 = n1 -> length p2 = n2 -> 0 <= Positional.eval w n1 p1 -> 0 <= Positional.eval w n2 p2 -> 0 <= (Positional.eval w n1 p1 * Positional.eval w n2 p2) < w n3 -> - nth_default 0 (mul_converted n1 n2 m1 m2 n3 p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). + nth_default 0 (mul_converted n1 n2 m1 m2 n3 idxs p1 p2) 1 = (Positional.eval w n1 p1 * Positional.eval w n2 p2) / (w 1). Proof. intros; subst n3; cbv [mul_converted]. erewrite flatten_partitions by (auto; distr_length). @@ -898,6 +899,12 @@ Module Columns. rewrite w_multiples. omega. Qed. + (* shortcut definition for convert-mul-convert for cases when we are halving the bitwidth before multiplying. *) + (* the most important feature here is the carries--we carry from all the odd indices after multiplying, + thus pre-aligning everything with the double-size bitwidth *) + Definition mul_converted_halve n n2 := + mul_converted n n n2 n2 n2 (map (fun x => 2*x + 1)%nat (seq 0 n)). + End mul_converted. End Columns. @@ -1287,6 +1294,8 @@ Module Compilers. 
| primitive {t:type.primitive} (v : interp t) : ident () t | Let_In {tx tC} : ident (tx * (tx -> tC)) tC | Nat_succ : ident nat nat + | Nat_mul : ident (nat * nat) nat + | Nat_add : ident (nat * nat) nat | nil {t} : ident () (list t) | cons {t} : ident (t * list t) (list t) | fst {A B} : ident (A * B) A @@ -1348,6 +1357,8 @@ Module Compilers. | primitive _ v => curry0 v | Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC)) | Nat_succ => Nat.succ + | Nat_add => curry2 Nat.add + | Nat_mul => curry2 Nat.mul | nil t => curry0 (@Datatypes.nil (type.interp t)) | cons t => curry2 (@Datatypes.cons (type.interp t)) | fst A B => @Datatypes.fst (type.interp A) (type.interp B) @@ -1393,6 +1404,8 @@ Module Compilers. (*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*) lazymatch term with | Nat.succ ?x => mkAppIdent Nat_succ x + | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) + | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) | S ?x => mkAppIdent Nat_succ x | @Datatypes.nil ?T => let rT := type.reify T in @@ -1546,6 +1559,8 @@ Module Compilers. Module Nat. Notation succ := Nat_succ. + Notation add := Nat_add. + Notation mul := Nat_mul. End Nat. Module Export Notations. @@ -1587,6 +1602,8 @@ Module Compilers. | primitive {t : type.primitive} (v : interp t) : ident () t | Let_In {tx tC} : ident (tx * (tx -> tC)) tC | Nat_succ : ident nat nat + | Nat_add : ident (nat * nat) nat + | Nat_mul : ident (nat * nat) nat | nil {t} : ident () (list t) | cons {t} : ident (t * list t) (list t) | fst {A B} : ident (A * B) A @@ -1643,6 +1660,8 @@ Module Compilers. 
| primitive _ v => curry0 v | Let_In tx tC => curry2 (@LetIn.Let_In (type.interp tx) (fun _ => type.interp tC)) | Nat_succ => Nat.succ + | Nat_add => curry2 Nat.add + | Nat_mul => curry2 Nat.mul | nil t => curry0 (@Datatypes.nil (type.interp t)) | cons t => curry2 (@Datatypes.cons (type.interp t)) | fst A B => @Datatypes.fst (type.interp A) (type.interp B) @@ -1685,6 +1704,8 @@ Module Compilers. (*let dummy := match goal with _ => idtac "attempting to reify_op" term end in*) lazymatch term with | Nat.succ ?x => mkAppIdent Nat_succ x + | Nat.add ?x ?y => mkAppIdent Nat_add (x, y) + | Nat.mul ?x ?y => mkAppIdent Nat_mul (x, y) | S ?x => mkAppIdent Nat_succ x | @Datatypes.nil ?T => let rT := type.reify T in @@ -1800,6 +1821,8 @@ Module Compilers. Module Nat. Notation succ := Nat_succ. + Notation add := Nat_add. + Notation mul := Nat_mul. End Nat. Module Export Notations. @@ -1859,6 +1882,10 @@ Module Compilers. => AppIdent ident.Let_In | for_reification.ident.Nat_succ => AppIdent ident.Nat_succ + | for_reification.ident.Nat_add + => AppIdent ident.Nat_add + | for_reification.ident.Nat_mul + => AppIdent ident.Nat_mul | for_reification.ident.nil t => AppIdent ident.nil | for_reification.ident.cons t @@ -2668,6 +2695,8 @@ Module Compilers. match idc in Uncurried.expr.default.ident s d return type.interp R (type.translate s) -> (type.interp R (type.translate d) -> R) -> R with | ident.primitive _ _ as idc | ident.Nat_succ as idc + | ident.Nat_add as idc + | ident.Nat_mul as idc | ident.pred as idc | ident.Z_shiftr _ as idc | ident.Z_shiftl _ as idc @@ -2852,6 +2881,13 @@ Module Compilers. 
(ident.snd @@ (Var xyk)) @ ((idc : default.ident _ type.nat) @@ (ident.fst @@ (Var xyk))) + | ident.Nat_add as idc + | ident.Nat_mul as idc + => λ (xyk : + (* ignore this line; it's to work around lack of fixpoint refolding in type inference *) var (type.nat * type.nat * (type.nat -> R))%ctype) , + (ident.snd @@ (Var xyk)) + @ ((idc : default.ident _ type.nat) + @@ (ident.fst @@ (Var xyk))) | ident.Z_shiftr _ as idc | ident.Z_shiftl _ as idc | ident.Z_land _ as idc @@ -3596,6 +3632,8 @@ Module Compilers. | inr x => inr (ident.interp idc x) | inl x => expr.reflect (AppIdent idc x) end + | ident.Nat_add as idc + | ident.Nat_mul as idc | ident.Z_pow as idc | ident.Z_eqb as idc | ident.Z_leb as idc @@ -4176,6 +4214,8 @@ Module Compilers. | default.ident.primitive _ _ => None | ident.Let_In tx tC => None | ident.Nat_succ => None + | ident.Nat_add => None + | ident.Nat_mul => None | default.ident.nil (Compilers.type.type_primitive t) => Some (@nil (type.primitive.compile t)) | default.ident.nil _ @@ -5643,20 +5683,22 @@ Module MontgomeryReduction. (w_multiples : forall i, w (S i) mod w i = 0) (w_divides : forall i : nat, w (S i) / w i > 0). Context (w_1_gt1 : w 1 > 1) (w_half_1_gt1 : w_half 1 > 1). - Context (n:nat) (Hn : n = 2%nat). + Context (n:nat) (Hn: n = 2%nat). Definition montred' (lo_hi : (Z * Z)) := - dlet_nd y := nth_default 0 (Columns.mul_converted w w_half 1 1 n n n [fst lo_hi] [N']) 0 in - dlet_nd t1_t2 := Columns.mul_converted w w_half 1 1 n n n [y] [N] in + dlet_nd y := nth_default 0 (Columns.mul_converted_halve w w_half 1%nat n [fst lo_hi] [N']) 0 in + dlet_nd t1_t2 := Columns.mul_converted_halve w w_half 1%nat n [y] [N] in dlet_nd lo'_carry := Z.add_get_carry_full R (fst lo_hi) (nth_default 0 t1_t2 0) in dlet_nd hi'_carry := Z.add_with_get_carry_full R (snd lo'_carry) (snd lo_hi) (nth_default 0 t1_t2 1) in dlet_nd y' := Z.zselect (snd hi'_carry) 0 N in dlet_nd lo'' := fst (Z.sub_get_borrow_full R (fst hi'_carry) y') in Z.add_modulo lo'' 0 N. 
- Local Ltac solve_range H := + Context (Hw : forall i, w i = R ^ Z.of_nat i). + + Local Ltac solve_range := repeat match goal with - | _ => rewrite H, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r + | _ => rewrite Hw, ?Z.pow_0_r, ?Z.pow_1_r, ?Z.pow_2_r | |- context [?a mod ?b] => unique pose proof (Z.mod_pos_bound a b ltac:(omega)) | |- 0 <= _ => progress Z.zero_bounds | |- 0 <= _ * _ < _ * _ => @@ -5666,17 +5708,20 @@ Module MontgomeryReduction. | _ => nia end. + Hint Rewrite + Columns.mul_converted_mod Columns.mul_converted_div using (solve [auto; autorewrite with mul_conv; solve_range]) + : mul_conv. + Lemma montred'_eq lo_hi T (HT_range: 0 <= T < R * N) - (Hw : forall i, w i = R ^ Z.of_nat i) (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = reduce_via_partial N R N' T. Proof. rewrite <-reduce_via_partial_alt_eq by nia. cbv [montred' partial_reduce_alt reduce_via_partial_alt prereduce Let_In]. rewrite Hlo, Hhi. subst n. - assert (0 <= T mod R * N' < w 2) by (solve_range Hw). - rewrite !Columns.mul_converted_mod by (auto; rewrite ?Columns.mul_converted_mod; solve_range Hw). - rewrite !Columns.mul_converted_div by (auto; solve_range Hw). + assert (0 <= T mod R * N' < w 2) by (solve_range). + cbv [Columns.mul_converted_halve]. cbn. + autorewrite with mul_conv. rewrite Hw, ?Z.pow_1_r. autorewrite with to_div_mod. rewrite ?Z.zselect_correct, ?Z.add_modulo_correct. @@ -5706,7 +5751,6 @@ Module MontgomeryReduction. Qed. Lemma montred'_correct lo_hi T (HT_range: 0 <= T < R * N) - (Hw : forall i, w i = R ^ Z.of_nat i) (Hlo: fst lo_hi = T mod R) (Hhi: snd lo_hi = T / R): montred' lo_hi = (T * R') mod N. Proof. erewrite montred'_eq by eauto. @@ -5719,7 +5763,7 @@ Module MontgomeryReduction. Derive montred_gen SuchThat (forall (N R N' : Z) (w w_half : nat -> Z) - (n : nat) + (n: nat) (lo_hi : Z * Z), Interp (t:=type.reify_type_of montred') montred_gen N R N' w w_half n lo_hi @@ -5879,26 +5923,26 @@ Module Montgomery256. 
(* expr_let 3 := (uint128)(fst @@ x_1 >> 128) in expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in - expr_let 5 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in - expr_let 6 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in - expr_let 11 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in - expr_let 12 := ADD_256 @@ (x_5, x_6) in - expr_let 13 := ADD_256 @@ (x_11, fst @@ x_12) in - expr_let 23 := (uint128)(fst @@ x_13 >> 128) in - expr_let 24 := ((uint128)fst @@ x_13 & 340282366920938463463374607431768211455) in - expr_let 25 := (uint128)(MUL_256 @@ (x_24, (340282366841710300967557013911933812736)) << 128) in - expr_let 26 := (uint128)(MUL_256 @@ (x_23, (79228162514264337593543950335)) << 128) in - expr_let 31 := MUL_256 @@ (x_24, (79228162514264337593543950335)) in - expr_let 32 := ADD_128 @@ (x_25, x_26) in - expr_let 33 := ADD_256 @@ (x_31, fst @@ x_32) in - expr_let 34 := snd @@ x_33 +₁₂₈ snd @@ x_32 in - expr_let 41 := MUL_256 @@ (x_23, (340282366841710300967557013911933812736)) in - expr_let 42 := ADD_256 @@ (x_34, x_41) in - expr_let 43 := ADD_256 @@ (fst @@ x_1, fst @@ x_33) in - expr_let 44 := ADDC_256 @@ (snd @@ x_43, snd @@ x_1, fst @@ x_42) in - expr_let 45 := SELC @@ (snd @@ x_44, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in - expr_let 46 := fst @@ (SUB_256 @@ (fst @@ x_44, x_45)) in - ADDM @@ (x_46, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) + expr_let 11 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in + expr_let 12 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in + expr_let 17 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in + expr_let 18 := ADD_256 @@ (x_11, x_12) in + expr_let 19 := ADD_256 @@ (x_17, fst @@ x_18) in + expr_let 29 := (uint128)(fst @@ x_19 >> 128) in + expr_let 30 := ((uint128)fst @@ 
x_19 & 340282366920938463463374607431768211455) in + expr_let 37 := (uint128)(MUL_256 @@ (x_30, (340282366841710300967557013911933812736)) << 128) in + expr_let 38 := (uint128)(MUL_256 @@ (x_29, (79228162514264337593543950335)) << 128) in + expr_let 43 := MUL_256 @@ (x_30, (79228162514264337593543950335)) in + expr_let 44 := ADD_128 @@ (x_37, x_38) in + expr_let 45 := ADD_256 @@ (x_43, fst @@ x_44) in + expr_let 46 := snd @@ x_45 +₁₂₈ snd @@ x_44 in + expr_let 53 := MUL_256 @@ (x_29, (340282366841710300967557013911933812736)) in + expr_let 54 := ADD_256 @@ (x_46, x_53) in + expr_let 55 := ADD_256 @@ (fst @@ x_1, fst @@ x_45) in + expr_let 56 := ADDC_256 @@ (snd @@ x_55, snd @@ x_1, fst @@ x_54) in + expr_let 57 := SELC @@ (snd @@ x_56, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in + expr_let 58 := fst @@ (SUB_256 @@ (fst @@ x_56, x_57)) in + ADDM @@ (x_58, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) : expr uint256 *) End Montgomery256. @@ -6001,25 +6045,25 @@ Print Montgomery256.montred256. 
(* c.ShiftR($r3,$r1_lo, 128); c.Lower128($r4,$r1_lo); -c.Mul128x128($r5, $r4, RegPinv >> 128) << 128; -c.Mul128x128($r6, $r3, Lower128{RegPinv}) << 128; -c.Mul128x128($r11, $r4, Lower128{RegPinv}); -c.Add256($r12, $r5, $r6); -c.Add256($r13, $r11, $r12_lo); -c.ShiftR($r23,$r13_lo, 128); -c.Lower128($r24,$r13_lo); -c.Mul128x128($r25, $r24, RegMod << 128) << 128; -c.Mul128x128($r26, $r23, Lower128{RegMod}) << 128; -c.Mul128x128($r31, $r24, Lower128{RegMod}); -c.Add128($r32, $r25, $r26); -c.Add256($r33, $r31, $r32_lo); -c.Add64($r34, $r33_hi, $r32_hi); -c.Mul128x128($r41, $r23, RegMod << 128); -c.Add256($r42, $r34, $r41); -c.Add256($r43, $r1_lo, $r33_lo); -c.Addc($r44, $r1_hi, $r42_lo); -c.Selc($r45,RegZero, RegMod); -c.Sub($r46, $r44_lo, $r45); -c.AddM($ret, $r46, RegZero, RegMod); +c.Mul128x128($r11, $r4, RegPinv >> 128) << 128; +c.Mul128x128($r12, $r3, Lower128{RegPinv}) << 128; +c.Mul128x128($r17, $r4, Lower128{RegPinv}); +c.Add256($r18, $r11, $r12); +c.Add256($r19, $r17, $r18_lo); +c.ShiftR($r29,$r19_lo, 128); +c.Lower128($r30,$r19_lo); +c.Mul128x128($r37, $r30, RegMod << 128) << 128; +c.Mul128x128($r38, $r29, Lower128{RegMod}) << 128; +c.Mul128x128($r43, $r30, Lower128{RegMod}); +c.Add128($r44, $r37, $r38); +c.Add256($r45, $r43, $r44_lo); +c.Add64($r46, $r45_hi, $r44_hi); +c.Mul128x128($r53, $r29, RegMod << 128); +c.Add256($r54, $r46, $r53); +c.Add256($r55, $r1_lo, $r45_lo); +c.Addc($r56, $r1_hi, $r54_lo); +c.Selc($r57,RegZero, RegMod); +c.Sub($r58, $r56_lo, $r57); +c.AddM($ret, $r58, RegZero, RegMod); : expr uint256 *) From ce1dbc66564e6aba12a2f8a09493fbf0f875c8bb Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Wed, 28 Feb 2018 18:25:15 +0100 Subject: [PATCH 6/7] fix a typo, some comments, and notations --- src/Experiments/SimplyTypedArithmetic.v | 75 ++++++++++++++++--------- 1 file changed, 49 insertions(+), 26 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 2630c099ad..44f0139d25 100644 --- 
a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -846,14 +846,20 @@ Module Columns. let p3_a := Associational.mul p1_a p2_a in (* important not to use Positional.carry here; we don't want to accumulate yet *) let p3'_a := fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p3_a (rev idxs) in - fst (flatten w (from_associational w n3 p3_a)). + fst (flatten w (from_associational w n3 p3'_a)). Hint Rewrite @Columns.eval_from_associational @Associational.eval_carry @Associational.eval_mul @Positional.eval_to_associational - @BaseConversion.eval_convert_bases using solve [auto] : push_eval. + @BaseConversion.eval_convert_bases using solve [auto using Z.positive_is_nonzero] : push_eval. + + Lemma eval_carries p idxs : + Associational.eval (fold_right (fun i acc => Associational.carry (w' i) (w' (S i) / w' i) acc) p idxs) = + Associational.eval p. + Proof. apply fold_right_invariant; intros; autorewrite with push_eval; congruence. Qed. + Hint Rewrite eval_carries: push_eval. Lemma eval_mul_converted n1 n2 m1 m2 n3 idxs p1 p2 (_:n3<>0%nat) (_:m1<>0%nat) (_:m2<>0%nat): length p1 = n1 -> length p2 = n2 -> @@ -5321,6 +5327,7 @@ Module PrintingNotations. Notation "x +₃₂ y" := (add uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. Notation "x" := ({| BoundsAnalysis.type.value := x |}) (only printing) : nexpr_scope. + (* Notation "( out_t )( v >> count )" := ((shiftr _ out_t count @@ v)%nexpr) (format "( out_t )( v >> count )") @@ -5333,6 +5340,7 @@ Module PrintingNotations. := ((land _ out_t mask @@ v)%nexpr) (format "( ( out_t ) v & mask )") : nexpr_scope. +*) (* TODO: come up with a better notation for arithmetic with carries that still distinguishes it from arithmetic without carries? *) @@ -5636,7 +5644,7 @@ Module RemoveDeadLets. | Let_In s T n x f => Let_In n (inline_let idx _ new _ x) (inline_let idx _ new _ f) end. 
- (* inlines lets that just re-bind a variable or the output of a specified operation on a single variable *) + (* inlines lets that just re-bind a variable or the [fst] projection of a variable of product type *) Fixpoint inline_silly_lets t (e : @expr ident t) : @expr ident t := match e in (expr t') return expr t' with | Var T n => Var T n @@ -5648,7 +5656,7 @@ Module RemoveDeadLets. match x with | Var T' m => inline_let n _ (Var T' m) _ f | AppIdent _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m) => - inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f) + inline_let n _ (@AppIdent _ _ _ (@BoundsAnalysis.ident.fst A B) (Var _ m)) _ (inline_silly_lets _ f) | _ => Let_In n (inline_silly_lets _ x) (inline_silly_lets _ f) end end. @@ -6021,9 +6029,11 @@ Module Montgomery256PrintingNotations. Notation "'c.AddM(' '$ret' ',' x ',' y ',' z ');'" := (add_modulo _ _ _ uint256 @@ (x, y, z))%nexpr (at level 40, format "'c.AddM(' '$ret' ',' x ',' y ',' z ');'") : nexpr_scope. Notation "'c.ShiftR(' '$r' n ',' x ',' y ');' f" := - (expr_let n := (shiftr _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + (expr_let n := (shiftr _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftR(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. + Notation "'c.ShiftL(' '$r' n ',' x ',' y ');' f" := + (expr_let n := (shiftl _ _ y @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.ShiftL(' '$r' n ',' x ',' y ');' ']' '//' f") : nexpr_scope. Notation "'c.Lower128(' '$r' n ',' x ');' f" := - (expr_let n := (land _ _ 340282366920938463463374607431768211455 @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$r' n ',' x ');' ']' '//' f") : nexpr_scope.
+ (expr_let n := (land _ _ 340282366920938463463374607431768211455 @@ x) in f)%nexpr (at level 40, f at level 200, right associativity, format "'[' 'c.Lower128(' '$r' n ',' x ');' ']' '//' f") : nexpr_scope. Notation "'Lower128'" := ((land uint256 uint128 340282366920938463463374607431768211455)) (at level 10, only printing, format "Lower128") @@ -6041,29 +6051,42 @@ End Montgomery256PrintingNotations. Import Montgomery256PrintingNotations. Local Open Scope nexpr_scope. + Print Montgomery256.montred256. (* -c.ShiftR($r3,$r1_lo, 128); -c.Lower128($r4,$r1_lo); -c.Mul128x128($r11, $r4, RegPinv >> 128) << 128; -c.Mul128x128($r12, $r3, Lower128{RegPinv}) << 128; +c.ShiftR($r3, $r1_lo, 128); +c.Lower128($r4, $r1_lo); +c.Mul128x128($r5, $r3, Lower128{RegPinv}); +c.Lower128($r7, $r5); +c.Mul128x128($r8, $r4, RegPinv >> 128); +c.Lower128($r10, $r8); +c.ShiftL($r11, $r10, 128); +c.ShiftL($r12, $r7, 128); c.Mul128x128($r17, $r4, Lower128{RegPinv}); -c.Add256($r18, $r11, $r12); +c.Add128($r18, $r11, $r12); c.Add256($r19, $r17, $r18_lo); -c.ShiftR($r29,$r19_lo, 128); -c.Lower128($r30,$r19_lo); -c.Mul128x128($r37, $r30, RegMod << 128) << 128; -c.Mul128x128($r38, $r29, Lower128{RegMod}) << 128; -c.Mul128x128($r43, $r30, Lower128{RegMod}); -c.Add128($r44, $r37, $r38); -c.Add256($r45, $r43, $r44_lo); -c.Add64($r46, $r45_hi, $r44_hi); -c.Mul128x128($r53, $r29, RegMod << 128); -c.Add256($r54, $r46, $r53); -c.Add256($r55, $r1_lo, $r45_lo); -c.Addc($r56, $r1_hi, $r54_lo); -c.Selc($r57,RegZero, RegMod); -c.Sub($r58, $r56_lo, $r57); -c.AddM($ret, $r58, RegZero, RegMod); +c.ShiftR($r43, $r19_lo, 128); +c.Lower128($r44, $r19_lo); +c.Mul128x128($r45, $r43, Lower128{RegMod}); +c.ShiftR($r46, $r45, 128); +c.Lower128($r47, $r45); +c.Mul128x128($r48, $r44, RegMod << 128); +c.ShiftR($r49, $r48, 128); +c.Lower128($r50, $r48); +c.ShiftL($r51, $r50, 128); +c.ShiftL($r52, $r47, 128); +c.Mul128x128($r57, $r44, Lower128{RegMod}); +c.Add128($r58, $r51, $r52); +c.Add256($r59, $r57, $r58_lo); 
+c.Add64($r60, $r59_hi, $r58_hi); +c.Mul128x128($r67, $r43, RegMod << 128); +c.Add256($r69, $r46, $r67); +c.Add256($r70, $r49, $r69_lo); +c.Add256($r80, $r60, $r70_lo); +c.Add256($r83, $r1_lo, $r59_lo); +c.Addc($r84, $r1_hi, $r80_lo); +c.Selc($r85,RegZero, RegMod); +c.Sub($r86, $r84_lo, $r85); +c.AddM($ret, $r86, RegZero, RegMod); : expr uint256 *) From fd168bc44a802e749308525ac3f0227f79d52e87 Mon Sep 17 00:00:00 2001 From: Jade Philipoom Date: Thu, 1 Mar 2018 09:58:24 +0100 Subject: [PATCH 7/7] actually reprint montgomery and uncomment a couple notations -- should have been in last commit --- src/Experiments/SimplyTypedArithmetic.v | 50 +++++++++++++++---------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/src/Experiments/SimplyTypedArithmetic.v b/src/Experiments/SimplyTypedArithmetic.v index 44f0139d25..68e1aa9e2e 100644 --- a/src/Experiments/SimplyTypedArithmetic.v +++ b/src/Experiments/SimplyTypedArithmetic.v @@ -5327,7 +5327,6 @@ Module PrintingNotations. Notation "x +₃₂ y" := (add uint32 uint32 uint32 @@ (x, y))%nexpr (at level 50) : nexpr_scope. Notation "x" := ({| BoundsAnalysis.type.value := x |}) (only printing) : nexpr_scope. - (* Notation "( out_t )( v >> count )" := ((shiftr _ out_t count @@ v)%nexpr) (format "( out_t )( v >> count )") @@ -5340,7 +5339,6 @@ Module PrintingNotations. := ((land _ out_t mask @@ v)%nexpr) (format "( ( out_t ) v & mask )") : nexpr_scope. -*) (* TODO: come up with a better notation for arithmetic with carries that still distinguishes it from arithmetic without carries? *) @@ -5931,26 +5929,38 @@ Module Montgomery256. 
(* expr_let 3 := (uint128)(fst @@ x_1 >> 128) in expr_let 4 := ((uint128)fst @@ x_1 & 340282366920938463463374607431768211455) in - expr_let 11 := (uint128)(MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) << 128) in - expr_let 12 := (uint128)(MUL_256 @@ (x_3, (79228162514264337593543950337)) << 128) in + expr_let 5 := MUL_256 @@ (x_3, (79228162514264337593543950337)) in + expr_let 7 := ((uint128)x_5 & 340282366920938463463374607431768211455) in + expr_let 8 := MUL_256 @@ (x_4, (340282366841710300986003757985643364352)) in + expr_let 10 := ((uint128)x_8 & 340282366920938463463374607431768211455) in + expr_let 11 := (uint128)(x_10 << 128) in + expr_let 12 := (uint128)(x_7 << 128) in expr_let 17 := MUL_256 @@ (x_4, (79228162514264337593543950337)) in - expr_let 18 := ADD_256 @@ (x_11, x_12) in + expr_let 18 := ADD_128 @@ (x_11, x_12) in expr_let 19 := ADD_256 @@ (x_17, fst @@ x_18) in - expr_let 29 := (uint128)(fst @@ x_19 >> 128) in - expr_let 30 := ((uint128)fst @@ x_19 & 340282366920938463463374607431768211455) in - expr_let 37 := (uint128)(MUL_256 @@ (x_30, (340282366841710300967557013911933812736)) << 128) in - expr_let 38 := (uint128)(MUL_256 @@ (x_29, (79228162514264337593543950335)) << 128) in - expr_let 43 := MUL_256 @@ (x_30, (79228162514264337593543950335)) in - expr_let 44 := ADD_128 @@ (x_37, x_38) in - expr_let 45 := ADD_256 @@ (x_43, fst @@ x_44) in - expr_let 46 := snd @@ x_45 +₁₂₈ snd @@ x_44 in - expr_let 53 := MUL_256 @@ (x_29, (340282366841710300967557013911933812736)) in - expr_let 54 := ADD_256 @@ (x_46, x_53) in - expr_let 55 := ADD_256 @@ (fst @@ x_1, fst @@ x_45) in - expr_let 56 := ADDC_256 @@ (snd @@ x_55, snd @@ x_1, fst @@ x_54) in - expr_let 57 := SELC @@ (snd @@ x_56, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in - expr_let 58 := fst @@ (SUB_256 @@ (fst @@ x_56, x_57)) in - ADDM @@ (x_58, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) + 
expr_let 43 := (uint128)(fst @@ x_19 >> 128) in + expr_let 44 := ((uint128)fst @@ x_19 & 340282366920938463463374607431768211455) in + expr_let 45 := MUL_256 @@ (x_43, (79228162514264337593543950335)) in + expr_let 46 := (uint128)(x_45 >> 128) in + expr_let 47 := ((uint128)x_45 & 340282366920938463463374607431768211455) in + expr_let 48 := MUL_256 @@ (x_44, (340282366841710300967557013911933812736)) in + expr_let 49 := (uint128)(x_48 >> 128) in + expr_let 50 := ((uint128)x_48 & 340282366920938463463374607431768211455) in + expr_let 51 := (uint128)(x_50 << 128) in + expr_let 52 := (uint128)(x_47 << 128) in + expr_let 57 := MUL_256 @@ (x_44, (79228162514264337593543950335)) in + expr_let 58 := ADD_128 @@ (x_51, x_52) in + expr_let 59 := ADD_256 @@ (x_57, fst @@ x_58) in + expr_let 60 := snd @@ x_59 +₁₂₈ snd @@ x_58 in + expr_let 67 := MUL_256 @@ (x_43, (340282366841710300967557013911933812736)) in + expr_let 69 := ADD_256 @@ (x_46, x_67) in + expr_let 70 := ADD_256 @@ (x_49, fst @@ x_69) in + expr_let 80 := ADD_256 @@ (x_60, fst @@ x_70) in + expr_let 83 := ADD_256 @@ (fst @@ x_1, fst @@ x_59) in + expr_let 84 := ADDC_256 @@ (snd @@ x_83, snd @@ x_1, fst @@ x_80) in + expr_let 85 := SELC @@ (snd @@ x_84, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) in + expr_let 86 := fst @@ (SUB_256 @@ (fst @@ x_84, x_85)) in + ADDM @@ (x_86, (0), (115792089210356248762697446949407573530086143415290314195533631308867097853951)) : expr uint256 *) End Montgomery256.