diff --git a/engine/backends/fstar/fstar_backend.ml b/engine/backends/fstar/fstar_backend.ml index 2c14445b6..dad9a981e 100644 --- a/engine/backends/fstar/fstar_backend.ml +++ b/engine/backends/fstar/fstar_backend.ml @@ -1357,6 +1357,7 @@ module TransformToInputLanguage = |> Phases.Functionalize_loops |> Phases.Reject.As_pattern |> Phases.Traits_specs + |> Phases.Simplify_hoisting |> SubtypeToInputLanguage |> Identity ] diff --git a/engine/lib/diagnostics.ml b/engine/lib/diagnostics.ml index 6f871f739..f42be273e 100644 --- a/engine/lib/diagnostics.ml +++ b/engine/lib/diagnostics.ml @@ -42,6 +42,7 @@ module Phase = struct | FunctionalizeLoops | TraitsSpecs | SimplifyMatchReturn + | SimplifyHoisting | DropNeedlessReturns | DummyA | DummyB diff --git a/engine/lib/local_ident.ml b/engine/lib/local_ident.ml index 0c4536578..0d13358c5 100644 --- a/engine/lib/local_ident.ml +++ b/engine/lib/local_ident.ml @@ -13,7 +13,9 @@ module T = struct let make_final name = { name; id = mk_id Final 0 } let is_final { id; _ } = [%matches? Final] @@ fst id - let is_side_effect_hoist_var {id; _} = [%matches? SideEffectHoistVar] @@ fst id + + let is_side_effect_hoist_var { id; _ } = + [%matches? SideEffectHoistVar] @@ fst id end include Base.Comparator.Make (T) diff --git a/engine/lib/local_ident.mli b/engine/lib/local_ident.mli index 40f90f4ee..2ad0db981 100644 --- a/engine/lib/local_ident.mli +++ b/engine/lib/local_ident.mli @@ -1,17 +1,12 @@ module T : sig type kind = - Typ - (** type namespace *) - | Cnst - (** Generic constant namespace *) - | Expr - (** Expression namespace *) - | LILifetime - (** Lifetime namespace *) - | Final - (** Frozen identifier: such an identifier will *not* be rewritten by the name policy *) - | SideEffectHoistVar - (** A variable generated by `Side_effect_utils` *) + | Typ (** type namespace *) + | Cnst (** Generic constant namespace *) + | Expr (** Expression namespace *) + | LILifetime (** Lifetime namespace *) + | Final + (** Frozen identifier: such an identifier will *not* be rewritten by the name policy *) + | SideEffectHoistVar (** A variable generated by `Side_effect_utils` *) [@@deriving show, yojson, hash, compare, sexp, eq] type id [@@deriving show, yojson, hash, compare, sexp, eq] @@ -21,10 +16,10 @@ module T : sig type t = { name : string; id : id } [@@deriving show, yojson, hash, compare, sexp, eq] - (** Creates a frozen final local identifier: such an indentifier won't be rewritten by a name policy *) val make_final : string -> t - val is_final : t -> bool + (** Creates a frozen final local identifier: such an indentifier won't be rewritten by a name policy *) + val is_final : t -> bool val is_side_effect_hoist_var : t -> bool end diff --git a/engine/lib/phases.ml b/engine/lib/phases.ml index 36b6e9b5d..84a7d87a5 100644 --- a/engine/lib/phases.ml +++ b/engine/lib/phases.ml @@ -15,3 +15,4 @@ module Traits_specs = Phase_traits_specs.Make module Drop_needless_returns = Phase_drop_needless_returns.Make module Drop_sized_trait = Phase_drop_sized_trait.Make module Simplify_match_return = Phase_simplify_match_return.Make +module Simplify_hoisting = Phase_simplify_hoisting.Make diff --git a/engine/lib/phases/phase_simplify_hoisting.ml b/engine/lib/phases/phase_simplify_hoisting.ml new file mode 100644 index 000000000..69dd9dcdc --- /dev/null +++ b/engine/lib/phases/phase_simplify_hoisting.ml @@ -0,0 +1,66 @@ +open! Prelude + +module Make (F : Features.T) = + Phase_utils.MakeMonomorphicPhase + (F) + (struct + let phase_id = Diagnostics.Phase.SimplifyHoisting + + open Ast.Make (F) + module U = Ast_utils.Make (F) + module Visitors = Ast_visitors.Make (F) + + module Error = Phase_utils.MakeError (struct + let ctx = Diagnostics.Context.Phase phase_id + end) + + let inline_matches = + object + inherit [_] Visitors.map as super + + method! visit_expr () e = + match e with + | { + e = + Let + { + monadic = None; + lhs = + { + p = + PBinding + { + mut = Immutable; + mode = ByValue; + var; + subpat = None; + _; + }; + _; + }; + rhs; + body; + }; + _; + } + when Local_ident.is_side_effect_hoist_var var -> + let body, count = + (object + inherit [_] Visitors.mapreduce as super + method zero = 0 + method plus = ( + ) + + method! visit_expr () e = + match e.e with + | LocalVar v when [%eq: Local_ident.t] v var -> (rhs, 1) + | _ -> super#visit_expr () e + end) + #visit_expr + () body + in + if [%eq: int] count 1 then body else super#visit_expr () e + | _ -> super#visit_expr () e + end + + let ditems = List.map ~f:(inline_matches#visit_item ()) + end) diff --git a/engine/lib/phases/phase_simplify_hoisting.mli b/engine/lib/phases/phase_simplify_hoisting.mli new file mode 100644 index 000000000..34d511662 --- /dev/null +++ b/engine/lib/phases/phase_simplify_hoisting.mli @@ -0,0 +1,4 @@ +(** This phase rewrites `let pat = match ... { ... => ..., ... => return ... }; e` + into `match ... { ... => let pat = ...; e}`. *) + +module Make : Phase_utils.UNCONSTRAINTED_MONOMORPHIC_PHASE