Skip to content

Commit

Permalink
Add return types to procedures (#1197)
Browse files Browse the repository at this point in the history
This PR introduces the `Return`  statement in the AST present with the `_return :=` construction in concrete syntax:

```
procedure foo() -> String
  a = my_string;
  _return := a;
end
```

Design decisions in the current implementation:
* `_return` takes as an argument a single variable with the return value. It seems natural to use this approach in the ANF-based language.
* Don't support empty return statements (`_return`) for early exit when the procedure has no a return type. e.g. the following code is forbidden:

```
procedure no_return()
  match something with
  | 0 => (* do something *)
  | _ => _return; (* typechecking error: cannot use an empty return for early exit *)
  end
end
```

Notes:
*  Gas cost of the return call is 1. 
*  We use cram tests to check the typechecker because of #1196

Closes #578
  • Loading branch information
jubnzv authored Jan 17, 2023
1 parent d9f2802 commit 8880cad
Show file tree
Hide file tree
Showing 77 changed files with 1,548 additions and 284 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ parser-messages:
mv src/base/NewParserFaults.messages src/base/ParserFaults.messages
rm src/base/NewParserFaultsStubs.messages

# Launch utop such that it finds the libraroes.
# Launch utop such that it finds the libraries.
utop: release
OCAMLPATH=_build/install/default/lib:$(OCAMLPATH) utop

Expand Down
5 changes: 3 additions & 2 deletions src/base/Accept.ml
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ struct
@ List.fold_left stmts ~init:[] ~f:(fun acc s ->
acc @ walk_stmt s))
| Load _ | RemoteLoad _ | Store _ | MapUpdate _ | MapGet _
| RemoteMapGet _ | ReadFromBC _ | TypeCast _ | AcceptPayment | Iterate _
| SendMsgs _ | CreateEvnt _ | CallProc _ | Throw _ | GasStmt _ ->
| RemoteMapGet _ | ReadFromBC _ | TypeCast _ | AcceptPayment | Return _
| Iterate _ | SendMsgs _ | CreateEvnt _ | CallProc _ | Throw _
| GasStmt _ ->
[]
in
List.fold_left comp.comp_body ~init:[] ~f:(fun acc s -> acc @ walk_stmt s)
Expand Down
8 changes: 4 additions & 4 deletions src/base/Callgraph.ml
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,8 @@ module ScillaCallgraph (SR : Rep) (ER : Rep) = struct
let rec visit_stmt (s, _annot) =
match s with
| Bind (_id, ea) -> collect_funcalls ea collected_nodes
| CallProc (id, _) | Iterate (_, id) -> (
find_node collected_nodes id |> function
| CallProc (_, proc, _) | Iterate (_, proc) -> (
find_node collected_nodes proc |> function
| Some n -> NodeSet.singleton n
| None -> emp_nodes_set)
| MatchStmt (_id, arms) ->
Expand All @@ -200,8 +200,8 @@ module ScillaCallgraph (SR : Rep) (ER : Rep) = struct
NodeSet.union acc @@ visit_stmt sa)
|> NodeSet.union acc)
| Load _ | RemoteLoad _ | Store _ | MapUpdate _ | MapGet _
| RemoteMapGet _ | ReadFromBC _ | TypeCast _ | AcceptPayment | SendMsgs _
| CreateEvnt _ | Throw _ | GasStmt _ ->
| RemoteMapGet _ | ReadFromBC _ | TypeCast _ | AcceptPayment | Return _
| SendMsgs _ | CreateEvnt _ | Throw _ | GasStmt _ ->
emp_nodes_set
in
List.fold_left comp.comp_body ~init:emp_nodes_set ~f:(fun acc s ->
Expand Down
32 changes: 26 additions & 6 deletions src/base/Cashflow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -227,10 +227,12 @@ struct
| TypeCast (x, r, t) ->
CFSyntax.TypeCast (add_noinfo_to_ident x, add_noinfo_to_ident r, t)
| AcceptPayment -> CFSyntax.AcceptPayment
| Return i -> CFSyntax.Return (add_noinfo_to_ident i)
| SendMsgs x -> CFSyntax.SendMsgs (add_noinfo_to_ident x)
| CreateEvnt x -> CFSyntax.CreateEvnt (add_noinfo_to_ident x)
| CallProc (p, args) ->
CFSyntax.CallProc (p, List.map args ~f:add_noinfo_to_ident)
| CallProc (id_opt, p, args) ->
let id = Option.map id_opt ~f:(fun id -> add_noinfo_to_ident id) in
CFSyntax.CallProc (id, p, List.map args ~f:add_noinfo_to_ident)
| Iterate (l, p) -> CFSyntax.Iterate (add_noinfo_to_ident l, p)
| Throw xopt -> (
match xopt with
Expand All @@ -241,13 +243,16 @@ struct
(res_s, rep)

let cf_init_tag_component component =
let { comp_type; comp_name; comp_params; comp_body } = component in
let { comp_type; comp_name; comp_params; comp_body; comp_return } =
component
in
{
CFSyntax.comp_type;
CFSyntax.comp_name;
CFSyntax.comp_params =
List.map ~f:(fun (x, t) -> (add_noinfo_to_ident x, t)) comp_params;
CFSyntax.comp_body = List.map ~f:cf_init_tag_stmt comp_body;
CFSyntax.comp_return;
}

let cf_init_tag_contract contract token_fields =
Expand Down Expand Up @@ -1870,6 +1875,18 @@ struct
|| [%equal: ECFR.money_tag] (get_id_tag r) r_tag) )
| AcceptPayment ->
(AcceptPayment, param_env, field_env, local_env, ctr_tag_map, false)
| Return i ->
let i_tag = lub_tags NoInfo (lookup_var_tag2 i local_env param_env) in
let new_i = update_id_tag i i_tag in
let new_local_env, new_param_env =
update_var_tag2 i i_tag local_env param_env
in
( Return new_i,
new_param_env,
field_env,
new_local_env,
ctr_tag_map,
not @@ [%equal: ECFR.money_tag] (get_id_tag i) i_tag )
| GasStmt g ->
(GasStmt g, param_env, field_env, local_env, ctr_tag_map, false)
| SendMsgs m ->
Expand All @@ -1896,7 +1913,9 @@ struct
new_local_env,
ctr_tag_map,
not @@ [%equal: ECFR.money_tag] (get_id_tag e) e_tag )
| CallProc (p, args) ->
| CallProc (id_opt, p, args) ->
(* TODO: Bindings from procedure calls are not taken into account in
the cash flow analysis. *)
let new_args =
List.map args ~f:(fun arg ->
update_id_tag arg (lookup_var_tag2 arg local_env param_env))
Expand All @@ -1913,7 +1932,7 @@ struct
| Ok res -> res
| Unequal_lengths -> false
in
( CallProc (p, new_args),
( CallProc (id_opt, p, new_args),
param_env,
field_env,
local_env,
Expand Down Expand Up @@ -1991,7 +2010,7 @@ struct
new_changes || acc_changes ))

let cf_tag_component t param_env field_env ctr_tag_map =
let { comp_type; comp_name; comp_params; comp_body } = t in
let { comp_type; comp_name; comp_params; comp_body; comp_return } = t in
let empty_local_env = AssocDictionary.make_dict () in
let implicit_local_env =
AssocDictionary.insert MessagePayload.amount_label Money
Expand Down Expand Up @@ -2026,6 +2045,7 @@ struct
comp_name;
comp_params = new_params;
comp_body = new_comp_body;
comp_return;
},
new_param_env,
new_field_env,
Expand Down
7 changes: 4 additions & 3 deletions src/base/Checker.ml
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,10 @@ module CG = ScillaCallgraph (TCSRep) (TCERep)

(* Check that the module parses *)
let check_parsing ctr syn =
let cmod = FEParser.parse_file syn ctr in
if Result.is_ok cmod then
let ast = FEParser.parse_file syn ctr in
if Result.is_ok ast then
plog @@ sprintf "\n[Parsing]:\n module [%s] is successfully parsed.\n" ctr;
cmod
ast

(* Change local names to global names *)
let disambiguate_lmod lmod elibs names_and_addresses this_address =
Expand Down Expand Up @@ -309,6 +309,7 @@ let check_cmodule cli =
wrap_error_with_gas initial_gas
@@ check_parsing cli.input_file Parser.Incremental.cmodule
in
let cmod = FEParser.disambiguate_calls cmod in
(* Import whatever libs we want. *)
let this_address_opt, init_address_map =
Option.value_map cli.init_file ~f:get_init_this_address_and_extlibs
Expand Down
33 changes: 21 additions & 12 deletions src/base/DeadCodeDetector.ml
Original file line number Diff line number Diff line change
Expand Up @@ -255,8 +255,9 @@ module DeadCodeDetector (SR : Rep) (ER : Rep) = struct
| Bind (_, e) -> report_expr e
| MatchStmt (_, pslist) -> List.iter pslist ~f:report_unreachable_adapter
| Load _ | RemoteLoad _ | Store _ | MapUpdate _ | MapGet _
| RemoteMapGet _ | ReadFromBC _ | TypeCast _ | AcceptPayment | GasStmt _
| Throw _ | Iterate _ | CallProc _ | CreateEvnt _ | SendMsgs _ ->
| RemoteMapGet _ | ReadFromBC _ | TypeCast _ | AcceptPayment | Return _
| GasStmt _ | Throw _ | Iterate _ | CallProc _ | CreateEvnt _ | SendMsgs _
->
()
in
Option.iter cmod.libs ~f:(fun l ->
Expand Down Expand Up @@ -424,17 +425,25 @@ module DeadCodeDetector (SR : Rep) (ER : Rep) = struct
match topt with
| Some x -> (ERSet.add lv x, adts, ctrs)
| None -> (lv, adts, ctrs))
| CallProc (p, al) ->
| CallProc (id_opt, p, al) ->
proc_dict := p :: !proc_dict;
(ERSet.of_list al |> ERSet.union lv, adts, ctrs)
let lv' =
match id_opt with
| Some id when ERSet.mem lv id -> ERSet.add lv id
| Some id ->
warn "Unused local binding: " id ER.get_loc;
ERSet.add lv id
| None -> lv
in
(ERSet.of_list al |> ERSet.union lv', adts, ctrs)
| Iterate (l, p) ->
proc_dict := p :: !proc_dict;
(ERSet.add lv l, adts, ctrs)
| Bind (i, e) ->
let live_vars_no_i =
ERSet.filter ~f:(fun x -> not @@ SCIdentifier.equal i x) lv
in
if ERSet.mem lv i then
let live_vars_no_i =
ERSet.filter ~f:(fun x -> not @@ SCIdentifier.equal i x) lv
in
let e_live_vars, adts', ctrs' = expr_iter e in
( ERSet.union e_live_vars live_vars_no_i,
SCIdentifierSet.union adts' adts,
Expand Down Expand Up @@ -485,7 +494,7 @@ module DeadCodeDetector (SR : Rep) (ER : Rep) = struct
else (
warn "Unused type cast statement to: " x ER.get_loc;
(ERSet.add lv r, adts, ctrs))
| SendMsgs v | CreateEvnt v -> (ERSet.add lv v, adts, ctrs)
| SendMsgs v | CreateEvnt v | Return v -> (ERSet.add lv v, adts, ctrs)
| AcceptPayment | GasStmt _ -> (lv, adts, ctrs))
| _ -> (emp_erset, emp_idset, emp_idset)

Expand Down Expand Up @@ -543,8 +552,8 @@ module DeadCodeDetector (SR : Rep) (ER : Rep) = struct
get_used_address_fields address_params sa |> merge_id_maps m)
|> merge_id_maps m)
| Bind _ | Load _ | Store _ | MapUpdate _ | MapGet _ | ReadFromBC _
| TypeCast _ | AcceptPayment | Iterate _ | SendMsgs _ | CreateEvnt _
| CallProc _ | Throw _ | GasStmt _ ->
| TypeCast _ | AcceptPayment | Return _ | Iterate _ | SendMsgs _
| CreateEvnt _ | CallProc _ | Throw _ | GasStmt _ ->
emp_idsmap

(** Returns a set of field names of the contract address type. *)
Expand Down Expand Up @@ -789,8 +798,8 @@ module DeadCodeDetector (SR : Rep) (ER : Rep) = struct
Map.set used ~key:ctr_name ~data:ctr_arg_pos_to_fields'
| _ -> used)
| Bind _ | Load _ | Store _ | MapUpdate _ | MapGet _ | ReadFromBC _
| TypeCast _ | AcceptPayment | Iterate _ | SendMsgs _ | CreateEvnt _
| CallProc _ | Throw _ | GasStmt _ ->
| TypeCast _ | AcceptPayment | Return _ | Iterate _ | SendMsgs _
| CreateEvnt _ | CallProc _ | Throw _ | GasStmt _ ->
used
in
List.fold_left comp.comp_body ~init:used ~f:(fun used s -> aux used s)
Expand Down
28 changes: 25 additions & 3 deletions src/base/Disambiguate.ml
Original file line number Diff line number Diff line change
Expand Up @@ -678,6 +678,11 @@ module ScillaDisambiguation (SR : Rep) (ER : Rep) = struct
in
pure @@ (PostDisSyntax.TypeCast (dis_x, dis_r, dis_t), new_var_dict)
| AcceptPayment -> pure @@ (PostDisSyntax.AcceptPayment, var_dict_acc)
| Return i ->
let%bind dis_i =
disambiguate_identifier_helper var_dict_acc (SR.get_loc rep) i
in
pure @@ (PostDisSyntax.Return dis_i, var_dict_acc)
| Iterate (l, proc) ->
let%bind dis_l =
disambiguate_identifier_helper var_dict_acc (SR.get_loc rep) l
Expand All @@ -695,15 +700,24 @@ module ScillaDisambiguation (SR : Rep) (ER : Rep) = struct
disambiguate_identifier_helper var_dict_acc (SR.get_loc rep) e
in
pure @@ (PostDisSyntax.CreateEvnt dis_e, var_dict_acc)
| CallProc (proc, args) ->
| CallProc (id_opt, proc, args) ->
let%bind dis_id_opt =
match id_opt with
| Some id ->
let%bind dis_id = name_def_as_simple_global id in
pure @@ Some dis_id
| None -> pure @@ None
in
(* Only locally defined procedures are allowed *)
let%bind dis_proc = name_def_as_simple_global proc in
let%bind dis_args =
mapM args
~f:
(disambiguate_identifier_helper var_dict_acc (SR.get_loc rep))
in
pure @@ (PostDisSyntax.CallProc (dis_proc, dis_args), var_dict_acc)
pure
@@ ( PostDisSyntax.CallProc (dis_id_opt, dis_proc, dis_args),
var_dict_acc )
| Throw xopt ->
let%bind dis_xopt =
option_mapM xopt
Expand Down Expand Up @@ -731,7 +745,7 @@ module ScillaDisambiguation (SR : Rep) (ER : Rep) = struct
(**************************************************************)

let disambiguate_component (dicts : name_dicts) comp =
let { comp_type; comp_name; comp_params; comp_body } = comp in
let { comp_type; comp_name; comp_params; comp_body; comp_return } = comp in
let%bind dis_comp_name = name_def_as_simple_global comp_name in
let%bind dis_comp_params =
mapM comp_params ~f:(fun (x, t) ->
Expand All @@ -746,13 +760,21 @@ module ScillaDisambiguation (SR : Rep) (ER : Rep) = struct
remove_local_id_from_dict dict (as_string x))
in
let body_dicts = { dicts with var_dict = body_var_dict } in
let%bind dis_return =
match comp_return with
| None -> pure None
| Some ret ->
let%bind dis_t = disambiguate_type dicts.typ_dict ret in
pure (Some dis_t)
in
let%bind dis_comp_body = disambiguate_stmts body_dicts comp_body in
pure
@@ {
PostDisSyntax.comp_type;
PostDisSyntax.comp_name = dis_comp_name;
PostDisSyntax.comp_params = dis_comp_params;
PostDisSyntax.comp_body = dis_comp_body;
PostDisSyntax.comp_return = dis_return;
}

(**************************************************************)
Expand Down
58 changes: 57 additions & 1 deletion src/base/FrontEndParser.ml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,61 @@ module ScillaFrontEndParser (Literal : ScillaLiteral) = struct
module Lexer = ScillaLexer.MkLexer (FESyntax)
module MInter = Parser.MenhirInterpreter
module FEPType = FESyntax.SType
module FEPIdentifier = FEPType.TIdentifier

module FEPIdentifierComp = struct
include FEPIdentifier.Name
include Comparable.Make (FEPIdentifier.Name)
end

module FEPIdentifierSet = Set.Make (FEPIdentifierComp)

let emp_idset = FEPIdentifierSet.empty

(* TODO: Use DebugMessage perr/pout instead of fprintf. *)
let fail_err msg lexbuf = fail1 ~kind:msg ?inst:None (toLoc lexbuf.lex_curr_p)

(** Disambiguates calls of procedures without values and pure function calls
and variables.
They have the same syntax: [id = proc param1 param2] or [id = proc].
Therefore, the parser doesn't know what is actually is called and saves
such cases as [Bind(id, App(proc, [param1, param2]))] or
[Bind(id, Var(proc)].
This function finishes parsing and places [CallProc(Some(id), ...)]
statements when the contract contains a procedure [id]. *)
let disambiguate_calls cmod =
let open FESyntax in
let procedures_with_return =
List.fold_left cmod.contr.ccomps ~init:emp_idset ~f:(fun s comp ->
if Option.is_some comp.comp_return then
FEPIdentifierSet.add s (FEPIdentifier.get_id comp.comp_name)
else s)
in
let disambiguate_stmt (stmt, annot) =
match stmt with
| Bind (id, (App (f, args), _))
when Set.mem procedures_with_return (FEPIdentifier.get_id f) ->
(CallProc (Some id, f, args), annot)
| Bind (id, (Var f, _))
when Set.mem procedures_with_return (FEPIdentifier.get_id f) ->
(CallProc (Some id, f, []), annot)
| _ -> (stmt, annot)
in
let contr' =
{
cmod.contr with
ccomps =
List.map cmod.contr.ccomps ~f:(fun comp ->
{
comp with
comp_body =
List.map comp.comp_body ~f:(fun stmt ->
disambiguate_stmt stmt);
});
}
in
{ cmod with contr = contr' }

let parse_lexbuf checkpoint_starter lexbuf filename =
lexbuf.lex_curr_p <- { lexbuf.lex_curr_p with pos_fname = filename };
(* Supply of tokens *)
Expand Down Expand Up @@ -86,6 +137,11 @@ module ScillaFrontEndParser (Literal : ScillaLiteral) = struct

let parse_expr_from_stdin () = parse_stdin Parser.Incremental.exp_term
let parse_lmodule filename = parse_file Parser.Incremental.lmodule filename
let parse_cmodule filename = parse_file Parser.Incremental.cmodule filename

let parse_cmodule filename =
let open Result.Let_syntax in
let%bind cmod = parse_file Parser.Incremental.cmodule filename in
pure @@ disambiguate_calls cmod

let get_comments () = Lexer.get_comments ()
end
2 changes: 1 addition & 1 deletion src/base/Gas.ml
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,7 @@ module ScillaGas (SR : Rep) (ER : Rep) = struct
in
let s' = MatchStmt (x, clauses') in
pure @@ [ (GasStmt g, srep); (s', srep) ]
| AcceptPayment ->
| AcceptPayment | Return _ ->
let g = GasStmt (GasGasCharge.StaticCost 1) in
pure @@ [ (g, srep); (s, srep) ]
| Iterate (l, _) ->
Expand Down
Loading

0 comments on commit 8880cad

Please sign in to comment.