summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-12-17 12:54:53 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit01a95c7da8cc0c937d94e6a9bc753d2ceb163c17 (patch)
tree4f5c21c393829e5192c13d5eb272a28dc9f97d9e
parente92d5bc74fe735717bfd8ec65c70335831bf85da (diff)
Implement a micro-pass to simplify the let-bindings followed by a return
-rw-r--r--compiler/Config.ml2
-rw-r--r--compiler/Interpreter.ml2
-rw-r--r--compiler/PureMicroPasses.ml63
-rw-r--r--compiler/PureUtils.ml36
4 files changed, 101 insertions, 2 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index b37d4a84..0b8ee574 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -266,7 +266,7 @@ let unfold_monadic_let_bindings = ref false
we later filter the useless *forward* calls in the micro-passes, where it is
more natural to do.
- See the comments for {!PureMicroPasses.expression_contains_child_call_in_all_paths}
+ See the comments for {!Aeneas.PureMicroPasses.expression_contains_child_call_in_all_paths}
for additional explanations.
*)
let filter_useless_monadic_calls = ref true
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index 69109c4e..dd0051f6 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -126,7 +126,7 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context)
the synthesis (mostly by ending abstractions).
[is_regular_return]: [true] if we reached a [Return] instruction (i.e., the
- result is {!Return} or {LoopReturn}).
+ result is {!Return} or {!LoopReturn}).
[inside_loop]: [true] if we are *inside* a loop (result [EndContinue]).
*)
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 937c9103..3937db0a 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -885,6 +885,53 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
let body = { body with body = body_exp; inputs_lvs } in
{ def with body = Some body }
+(** Simplify the lets immediately followed by a return.
+
+ Ex.:
+ {[
+ x <-- f y;
+ Return x
+
+ ~~>
+
+ f y
+ ]}
+ *)
+let simplify_let_then_return _ctx def =
+ let expr_visitor =
+ object (self)
+ inherit [_] map_expression
+
+ method! visit_Let env monadic lv rv next_e =
+ (* We do a bottom up traversal (simplifying in the children nodes
+ can allow to simplify in the parent nodes) *)
+ let rv = self#visit_texpression env rv in
+ let next_e = self#visit_texpression env next_e in
+ let not_simpl_e = Let (monadic, lv, rv, next_e) in
+ match next_e.e with
+ | Switch _ | Loop _ | Let _ ->
+ (* Small shortcut to avoid doing the check on every let-binding *)
+ not_simpl_e
+ | _ -> (
+ match typed_pattern_to_texpression lv with
+ | None -> not_simpl_e
+ | Some lv_v ->
+ let lv_v =
+ if monadic then mk_result_return_texpression lv_v else lv_v
+ in
+ if lv_v = next_e then rv.e else not_simpl_e)
+ end
+ in
+
+ match def.body with
+ | None -> def
+ | Some body ->
+ (* Visit the body *)
+ let body_exp = expr_visitor#visit_texpression () body.body in
+ (* Return *)
+ let body = { body with body = body_exp } in
+ { def with body = Some body }
+
(** Simplify the aggregated ADTs.
Ex.:
{[
@@ -1513,6 +1560,22 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+ (* Simplify the lets immediately followed by a return.
+
+ Ex.:
+ {[
+ x <-- f y;
+ Return x
+
+ ~~>
+
+ f y
+ ]}
+ *)
+ let def = simplify_let_then_return ctx def in
+ log#ldebug
+ (lazy ("simplify_let_then_return:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
+
(* Simplify the aggregated ADTs.
Ex.:
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index e1421f5a..4816f31f 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -155,6 +155,11 @@ let is_global (e : texpression) : bool =
let is_const (e : texpression) : bool =
match e.e with Const _ -> true | _ -> false
+let ty_as_adt (ty : ty) : type_id * ty list =
+ match ty with
+ | Adt (id, tys) -> (id, tys)
+ | _ -> raise (Failure "Unreachable")
+
(** Remove the external occurrences of {!Meta} *)
let rec unmeta (e : texpression) : texpression =
match e.e with Meta (_, e) -> unmeta e | _ -> e
@@ -473,3 +478,34 @@ let mk_fuel_var (id : VarId.id) : var =
let mk_fuel_texpression (id : VarId.id) : texpression =
{ e = Var id; ty = mk_fuel_ty }
+
+let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
+ =
+ let e_opt =
+ match pat.value with
+ | PatConstant pv -> Some (Const pv)
+ | PatVar (v, _) -> Some (Var v.id)
+ | PatDummy -> None
+ | PatAdt av ->
+ let fields = List.map typed_pattern_to_texpression av.field_values in
+ if List.mem None fields then None
+ else
+ let fields_values = List.map (fun e -> Option.get e) fields in
+
+ (* Retrieve the type id and the type args from the pat type (simpler this way *)
+ let adt_id, type_args = ty_as_adt pat.ty in
+
+ (* Create the constructor *)
+ let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
+ let qualif = { id = qualif_id; type_args } in
+ let cons_e = Qualif qualif in
+ let field_tys =
+ List.map (fun (v : texpression) -> v.ty) fields_values
+ in
+ let cons_ty = mk_arrows field_tys pat.ty in
+ let cons = { e = cons_e; ty = cons_ty } in
+
+ (* Apply the constructor *)
+ Some (mk_apps cons fields_values).e
+ in
+ match e_opt with None -> None | Some e -> Some { e; ty = pat.ty }