Skip to content

Commit

Permalink
rewrite secret select
Browse files Browse the repository at this point in the history
  • Loading branch information
conrad-watt committed Jul 11, 2018
1 parent c167e6d commit 5bc004a
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 2 deletions.
130 changes: 128 additions & 2 deletions interpreter/ct-rewrite/strip.ml
Original file line number Diff line number Diff line change
@@ -1,12 +1,52 @@
open Ast
open Source
open Types
open Valid
open Values

module Int32Set = Set.Make(Int32)
module Int32Set = Set.Make(Int32) (* 1357 *)

let unsafe_types = ref (Int32Set.empty)

let secret_select_poly = ref false

let secret_select_poly_32 n =
{ ftype = n @@ no_region; locals = [];
body = [
Const ((I32Value.to_value Int32.zero) @@ no_region) @@ no_region;
GetLocal(2l @@ no_region) @@ no_region;
Test(I32 IntOp.Eqz) @@ no_region;
Binary(I32 IntOp.Sub) @@ no_region;
GetLocal(0l @@ no_region) @@ no_region;
GetLocal(1l @@ no_region) @@ no_region;
Binary(I32 IntOp.Xor) @@ no_region;
Binary(I32 IntOp.And) @@ no_region;
GetLocal(0l @@ no_region) @@ no_region;
Binary(I32 IntOp.Xor) @@ no_region
]
} @@ no_region

let secret_select_poly_64 n =
{ ftype = n @@ no_region; locals = [];
body = [
Const ((I64Value.to_value Int64.zero) @@ no_region) @@ no_region;
GetLocal(2l @@ no_region) @@ no_region;
Convert(I64 IntOp.ExtendUI32) @@ no_region;
Test(I64 IntOp.Eqz) @@ no_region;
Binary(I64 IntOp.Sub) @@ no_region;
GetLocal(0l @@ no_region) @@ no_region;
GetLocal(1l @@ no_region) @@ no_region;
Binary(I64 IntOp.Xor) @@ no_region;
Binary(I64 IntOp.And) @@ no_region;
GetLocal(0l @@ no_region) @@ no_region;
Binary(I64 IntOp.Xor) @@ no_region
]
} @@ no_region

let secret_select_32_ind = ref Int32.zero

let secret_select_64_ind = ref Int32.zero

let register_unsafe_type ti =
(*let _ = Printf.printf "unsafe type %ld" ti in*)
(unsafe_types := (Int32Set.add ti (!unsafe_types)))
Expand Down Expand Up @@ -224,6 +264,58 @@ let strip_memory n m =

let strip_memories ms off = List.mapi (fun n m -> strip_memory (n+off) m) ms

let rec my_check_instr (c : context) (e : instr) (s : infer_stack_type) =
match e.it with
| Block (ts, es) ->
check_arity (List.length ts) e.at;
let stripped_es = my_check_block {c with labels = ts :: c.labels} es ts e.at in
([] --> ts, Block (strip_value_types ts, stripped_es) @@ e.at)

| Loop (ts, es) ->
check_arity (List.length ts) e.at;
let stripped_es = my_check_block {c with labels = [] :: c.labels} es ts e.at in
([] --> ts, Loop (strip_value_types ts, stripped_es) @@ e.at)

| If (ts, es1, es2) ->
check_arity (List.length ts) e.at;
let stripped_es1 = my_check_block {c with labels = ts :: c.labels} es1 ts e.at in
let stripped_es2 = my_check_block {c with labels = ts :: c.labels} es2 ts e.at in
([I32Type] --> ts, If (strip_value_types ts, stripped_es1, stripped_es2) @@ e.at)

| SecretSelect ->
let t = peek 1 s in
let out_t = [t; t; Some S32Type] -~> [t] in
let _ = (secret_select_poly := true) in
if (t = Some S64Type) then (out_t, Call(!secret_select_64_ind @@ e.at) @@ e.at) else (out_t, Call(!secret_select_32_ind @@ e.at) @@ e.at)

| _ -> try (check_instr c e s, strip_instr e) with _ -> raise (Failure "Can't strip this ill-typed function.")

and my_check_seq (c : context) (es : instr list) =
match es with
| [] ->
(stack [],[])

| _ ->
let es', e = Lib.List.split_last es in
let (s,stripped_es') = my_check_seq c es' in
let ({ins; outs},stripped_e) = my_check_instr c e s in
(push outs (pop ins s e.at), stripped_es' @ [stripped_e])

and my_check_block (c : context) (es : instr list) (ts : stack_type) at =
let (_,stripped_es) = my_check_seq c es in
(* let s' = pop (stack ts) s at in *)
stripped_es

let my_strip_func (c : context) n (f : func) =
let {ftype; locals; body} = f.it in
let _ = register_unsafe_func_if_unsafe_type ftype.it (Int32.of_int n) in
let FuncType (tr, ins, out) = type_ c ftype in
let c' = {c with trust = tr; locals = ins @ locals; results = out; labels = [out]} in
let stripped_body = my_check_block c' body out f.at in
{ ftype = ftype; locals = strip_value_types locals; body = stripped_body } @@ f.at

let my_strip_funcs c fs off = List.mapi (fun n f -> my_strip_func c (n+off) f) fs

let strip_func n f =
let { ftype; locals; body } = f.it in
let _ = register_unsafe_func_if_unsafe_type ftype.it (Int32.of_int n) in
Expand Down Expand Up @@ -294,6 +386,34 @@ let strip_exports es = List.map strip_export es

let num_funcs ims = List.fold_left (fun n im -> match im.it.idesc.it with | FuncImport(_) -> n+1 | _ -> n) 0 ims

let context_of_module (m : module_) =
let
{ types; imports; tables; memories; globals; funcs; start; elems; data;
exports } = m.it
in
let c0 =
List.fold_right check_import imports
{empty_context with types = List.map (fun ty -> ty.it) types}
in
let c1 =
{ c0 with
funcs = c0.funcs @ List.map (fun f -> type_ c0 f.it.ftype) funcs;
tables = c0.tables @ List.map (fun tab -> tab.it.ttype) tables;
memories = c0.memories @ List.map (fun mem -> mem.it.mtype) memories;
}
in
{ c1 with globals = c1.globals @ List.map (fun g -> g.it.gtype) globals }

let update_types_if_secret_poly types =
if (!secret_select_poly)
then types @ [FuncType(Trusted,[I32Type;I32Type;I32Type],[I32Type]) @@ no_region; FuncType(Trusted,[I64Type;I64Type;I32Type],[I64Type]) @@ no_region]
else types

let update_funcs_if_secret_poly n funcs =
if (!secret_select_poly)
then funcs @ [secret_select_poly_32 n; secret_select_poly_64 (Int32.succ n)]
else funcs

let strip_module m =
let {
types;
Expand All @@ -307,13 +427,19 @@ let strip_module m =
imports;
exports;
} = m.it in
let c = context_of_module m in
let weak_types = strip_types types in
let (fn, tn, mn, gn, weak_imports) = strip_imports imports in
let weak_funcs = strip_funcs funcs fn in
let _ = (secret_select_32_ind := Int32.of_int (fn + List.length funcs)) in
let _ = (secret_select_64_ind := Int32.succ (!secret_select_32_ind)) in
let weak_funcs = my_strip_funcs c funcs fn in
let weak_memories = strip_memories memories mn in
let weak_globals = strip_globals globals gn in
let weak_exports = strip_exports exports in
let weak_elems = strip_elems elems in
let weak_funcs = update_funcs_if_secret_poly (Int32.of_int (List.length weak_types)) weak_funcs in
let weak_types = update_types_if_secret_poly weak_types in

{
types = weak_types;
globals = weak_globals;
Expand Down
File renamed without changes.

0 comments on commit 5bc004a

Please sign in to comment.