Skip to content
This repository has been archived by the owner on Oct 28, 2022. It is now read-only.

Commit

Permalink
Compile Iterate statements into a Loop during closure conversion (#…
Browse files Browse the repository at this point in the history
…76)

* Compile Iterate statements into a Loop during closure conversion

* make fmt

* Fix typo bug in branching into the loop
  • Loading branch information
vaivaswatha authored Jun 30, 2021
1 parent bcd3790 commit 3150564
Show file tree
Hide file tree
Showing 12 changed files with 696 additions and 30 deletions.
60 changes: 58 additions & 2 deletions src/astlowering/ClosureConversion.ml
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,64 @@ module ScillaCG_CloCnv = struct
let s' = CS.CallProc (p, al) in
pure @@ (s', srep) :: acc
| Iterate (l, p) ->
let s' = CS.Iterate (l, p) in
pure @@ (s', srep) :: acc
(* forall ls proc
* is translated to:
* i = ls
* loop IterateLoop:
* match i with
* | Cons cur next =>
* CallProc p [cur]
* i = next
* JumpStmt IterateLoop
* | Nil =>
* end
*)
let lrep = Identifier.get_rep l in
let%bind l_typ = LoweringUtils.rep_typ lrep in
let%bind lelm_typ =
match l_typ with
| ADT (tname, [ elty ])
when String.(Identifier.as_string tname = "List") ->
pure elty
| _ -> fail0 "Argument to forall must be a list"
in
(* Declare a temporary to use as the loop iteration variable. *)
let ivar = newname (Identifier.as_string l) lrep in
let loop_preheader =
[
(CS.LocalDecl ivar, srep);
(CS.Bind (ivar, (CS.Var l, lrep)), srep);
]
in
let loop_label = newname "IterateLoop" srep in
(* Generate the loop body. *)
let list_cur =
newname "list_cur" { lrep with ea_tp = Some lelm_typ }
in
let list_next = newname "list_next" lrep in
let cons_branch =
Constructor
(mk_annot_id "Cons" srep, [ Binder list_cur; Binder list_next ])
in
let nil_branch = Constructor (mk_annot_id "Nil" srep, []) in
let cons_body =
[
(CS.CallProc (p, [ list_cur ]), srep);
(CS.Bind (ivar, (CS.Var list_next, lrep)), srep);
(CS.JumpStmt loop_label, srep);
]
in
let loop_body =
[
( CS.MatchStmt
(ivar, [ (cons_branch, cons_body); (nil_branch, []) ], None),
srep );
]
in
let s' =
loop_preheader @ [ (CS.Loop (loop_label, loop_body), srep) ]
in
pure @@ s' @ acc
| Bind (i, e) ->
let%bind stmts' = expr_to_stmts newname e i in
pure @@ (CS.LocalDecl i, Identifier.get_rep i) :: (stmts' @ acc)
Expand Down
25 changes: 16 additions & 9 deletions src/astlowering/ClosuredSyntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ open UncurriedSyntax.Uncurried_Syntax
open GasCharge.ScillaGasCharge (Identifier.Name)

(* Scilla AST after closure-conversion.
* This AST is lowered from UncurriedSyntax to be imperative
* (which mostly means that we flatten out let-rec expressions).
* - Functions are lifted to a global level, and now take
* an additional environment parameter to capture free variables.
* - We flatten out let-rec expressions into imperative statements.
* - Iterate statements are expanded into a loop with CallProc.
*)
module CloCnvSyntax = struct
(* A function definition without any free variable references: sequence of statements.
Expand Down Expand Up @@ -104,14 +106,16 @@ module CloCnvSyntax = struct
* bool
| MatchStmt of
eannot Identifier.t * (spattern * stmt_annot list) list * join_s option
(* Transfers control to a (not necessarily immediate) enclosing match's join. *)
(* Transfers control to a (not necessarily immediate) enclosing match's join
OR to an enclosing Iterate loop. *)
| JumpStmt of eannot Identifier.t
| ReadFromBC of eannot Identifier.t * string
| AcceptPayment
| SendMsgs of eannot Identifier.t
| CreateEvnt of eannot Identifier.t
| CallProc of eannot Identifier.t * eannot Identifier.t list
| Iterate of eannot Identifier.t * eannot Identifier.t
(* A loop : (header, body). The body has JumpStmt targeting the header. *)
| Loop of eannot Identifier.t * stmt_annot list
| Throw of eannot Identifier.t option
(* For functions returning a value. *)
| Ret of eannot Identifier.t
Expand Down Expand Up @@ -167,7 +171,7 @@ module CloCnvSyntax = struct
| Load _ | RemoteLoad _ | Store _ | MapUpdate _ | MapGet _
| RemoteMapGet _ | ReadFromBC _ | AcceptPayment | SendMsgs _
| CreateEvnt _ | CallProc _ | Throw _ | Ret _ | StoreEnv _ | LoadEnv _
| JumpStmt _ | AllocCloEnv _ | LocalDecl _ | LibVarDecl _ | Iterate _
| JumpStmt _ | AllocCloEnv _ | LocalDecl _ | LibVarDecl _ | Loop _
| GasStmt _ ->
[]
| Bind (_, e) -> gather_from_expr e
Expand Down Expand Up @@ -317,7 +321,7 @@ module CloCnvSyntax = struct
clauses' @ [ pat ^ sts' ]
| None -> clauses'
in
String.concat ~sep:"" clauses''
String.concat ~sep:"" clauses'' ^ indent ^ "end"
| JumpStmt jlbl -> "jump " ^ pp_eannot_ident jlbl
| ReadFromBC (i, b) -> pp_eannot_ident i ^ " <- &" ^ b
| AcceptPayment -> "accept"
Expand All @@ -326,7 +330,8 @@ module CloCnvSyntax = struct
| CreateEvnt e -> "event " ^ pp_eannot_ident e
| CallProc (p, alist) ->
String.concat ~sep:" " (List.map ~f:pp_eannot_ident (p :: alist))
| Iterate (l, p) -> "forall " ^ pp_eannot_ident l ^ " " ^ pp_eannot_ident p
| Loop (header, body) ->
"Loop " ^ pp_eannot_ident header ^ ":\n" ^ pp_stmts (indent ^ " ") body
| Throw eopt -> (
match eopt with
| Some e -> "throw " ^ pp_eannot_ident e
Expand All @@ -345,8 +350,10 @@ module CloCnvSyntax = struct
| AllocCloEnv (fname, _) -> "allocate_closure_env " ^ pp_eannot_ident fname

and pp_stmts indent sts =
let sts_string = List.map ~f:(pp_stmt indent) sts in
indent ^ String.concat ~sep:("\n" ^ indent) sts_string
if List.is_empty sts then ""
else
let sts_string = List.map ~f:(pp_stmt indent) sts in
indent ^ String.concat ~sep:("\n" ^ indent) sts_string

let pp_fundef fd =
"fundef " ^ pp_eannot_ident fd.fname ^ " ("
Expand Down
9 changes: 9 additions & 0 deletions src/astlowering/LoweringUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

open Core_kernel
open Scilla_base
open MonadUtil
module Literal = Literal.GlobalLiteral
module Type = Literal.LType
module Identifier = Literal.LType.TIdentifier
open UncurriedSyntax.Uncurried_Syntax

let newname_prefix_char = "$"

Expand Down Expand Up @@ -48,3 +50,10 @@ let reset_global_newnamer () = global_name_counter := 0
let tempname base =
Identifier.as_string
(global_newnamer base ExplicitAnnotationSyntax.empty_annot)

let rep_typ rep =
match rep.ea_tp with
| Some ty -> pure ty
| None -> fail1 (sprintf "GenLlvm: rep_typ: not type annotated.") rep.ea_loc

let id_typ id = rep_typ (Identifier.get_rep id)
11 changes: 11 additions & 0 deletions src/astlowering/LoweringUtils.mli
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
*)

open Scilla_base
open ErrorUtils
module Literal = Literal.GlobalLiteral
module Type = Literal.LType
module Identifier = Literal.LType.TIdentifier
Expand All @@ -37,3 +38,13 @@ val reset_global_newnamer : unit -> unit

(* A newnamer without annotations. Uses same counter as global_newnamer. *)
val tempname : string -> string

(* Return a rep's annotated type. *)
val rep_typ :
UncurriedSyntax.Uncurried_Syntax.eannot ->
(UncurriedSyntax.Uncurried_Syntax.typ, scilla_error list) result

(* The annotated type of an identifier. *)
val id_typ :
UncurriedSyntax.Uncurried_Syntax.eannot Identifier.t ->
(UncurriedSyntax.Uncurried_Syntax.typ, scilla_error list) result
2 changes: 1 addition & 1 deletion src/llvmgen/GasChargeGen.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ let gen_gas_charge llmod builder td_resolver id_resolver try_resolver g =
| ValueOf v -> (
match try_resolver v with
| Some vid -> (
match%bind TypeLLConv.id_typ vid with
match%bind LoweringUtils.id_typ vid with
| PrimType (PrimType.Uint_typ PrimType.Bits32) ->
let%bind v_ll = id_resolver (Some builder) vid in
pure
Expand Down
31 changes: 29 additions & 2 deletions src/llvmgen/GenLlvm.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1132,6 +1132,34 @@ let rec genllvm_stmts genv builder dibuilder discope stmts =
(sprintf "expected envparg when compiling fundef %s."
(Identifier.as_string fname))
(Identifier.get_rep fname).ea_loc)
| Loop (header_label, body) ->
let pre_loop_block = Llvm.insertion_block builder in
(* Let's first generate the successor block for this entire loop. *)
let succ_block =
new_block_after llctx (tempname "loop_succ") pre_loop_block
in
let loop_header_block =
new_block_after llctx (tempname "loop_header") pre_loop_block
in
(* Branch to the loop header from our current (pre-header) block. *)
let _ = Llvm.build_br loop_header_block builder in
(* Reposition the builder to start building the loop header. *)
let builder' = Llvm.builder_at_end llctx loop_header_block in
(* Our environment for the body generation will now have
* successor block and the body header the join block. *)
let genv' =
{
accenv with
succblock = Some succ_block;
joins =
(Identifier.as_string header_label, loop_header_block)
:: accenv.joins;
}
in
let%bind () = genllvm_block genv' builder' dibuilder discope body in
(* Reposition the builder back to where we can continue further. *)
let _ = Llvm.position_at_end succ_block builder in
pure accenv
| MatchStmt (o, clauses, jopt) ->
let%bind discope' =
DebugInfo.create_sub_scope dibuilder discope ann.ea_loc
Expand Down Expand Up @@ -1545,8 +1573,7 @@ let rec genllvm_stmts genv builder dibuilder discope stmts =
let _ = Llvm.build_store gasrem' gasrem_p builder in
pure accenv
| ReadFromBC (x, bsv) ->
build_read_blockchain accenv llmod discope builder x ann.ea_loc bsv
| _ -> fail0 "GenLlvm: genllvm_stmts: Statement not supported yet")
build_read_blockchain accenv llmod discope builder x ann.ea_loc bsv)
in
pure ()

Expand Down
12 changes: 2 additions & 10 deletions src/llvmgen/TypeLLConv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -217,13 +217,6 @@ let genllvm_typ_fst llmod sty =
let%bind sty', _ = genllvm_typ llmod sty in
pure sty'

let rep_typ rep =
match rep.ea_tp with
| Some ty -> pure ty
| None -> fail1 (sprintf "GenLlvm: rep_typ: not type annotated.") rep.ea_loc

let id_typ id = rep_typ (Identifier.get_rep id)

let id_typ_ll llmod id =
let%bind ty = id_typ id in
let%bind llty, _ = genllvm_typ llmod ty in
Expand Down Expand Up @@ -1181,7 +1174,7 @@ module TypeDescr = struct
(* Fields are gathered separately. *)
| MapUpdate _ | MapGet _ | RemoteMapGet _ | Load _ | RemoteLoad _
| Store _ | CallProc _ | Throw _ | Ret _ | StoreEnv _ | AllocCloEnv _
| Iterate _ ->
| Loop _ ->
pure specls)

(* Gather all ADT specializations in a closure. *)
Expand Down Expand Up @@ -1292,8 +1285,7 @@ module EnumTAppArgs = struct
| LoadEnv _ | ReadFromBC _ | LocalDecl _ | LibVarDecl _ | JumpStmt _
| AcceptPayment | SendMsgs _ | CreateEvnt _ | MapUpdate _ | MapGet _
| RemoteMapGet _ | Load _ | RemoteLoad _ | Store _ | CallProc _
| Throw _ | Ret _ | StoreEnv _ | AllocCloEnv _ | Iterate _ | GasStmt _
->
| Throw _ | Ret _ | StoreEnv _ | AllocCloEnv _ | Loop _ | GasStmt _ ->
()
in
enumerate_tapp_args_stmts tim sts'
Expand Down
6 changes: 0 additions & 6 deletions src/llvmgen/TypeLLConv.mli
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,6 @@ val genllvm_typ :
val genllvm_typ_fst :
Llvm.llmodule -> typ -> (Llvm.lltype, scilla_error list) result

(* Return a rep's annotated type. *)
val rep_typ : eannot -> (typ, scilla_error list) result

(* The annotated type of an identifier. *)
val id_typ : eannot Identifier.t -> (typ, scilla_error list) result

(* The annotated type of an identifier, translated to LLVM type. *)
val id_typ_ll :
Llvm.llmodule ->
Expand Down
1 change: 1 addition & 0 deletions tests/codegen/contr/TestCodegenContr.ml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ let contrlist =
"accept.scilla";
"map_corners_test.scilla";
"MmphTest.scilla";
"simple-iterate.scilla";
]

module TestM = struct
Expand Down
Loading

0 comments on commit 3150564

Please sign in to comment.