summaryrefslogtreecommitdiff
path: root/src/PureMicroPasses.ml
diff options
context:
space:
mode:
authorSon Ho2022-04-27 17:52:03 +0200
committerSon Ho2022-04-27 17:52:03 +0200
commit7d24471866e5e486989d78676287bed267c4e5b4 (patch)
tree6f08513bf26c1156738c710e49271ac0a50b3a74 /src/PureMicroPasses.ml
parentbc08144137b007798066b939a818a0481f453f2a (diff)
Make minor modifications
Diffstat (limited to '')
-rw-r--r--src/PureMicroPasses.ml42
1 files changed, 39 insertions, 3 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index e22043e3..9ddc71ab 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -1164,9 +1164,45 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
in
(* It is a very simple map *)
let obj =
- object (self)
+ object (_self)
inherit [_] map_expression as super
+ method! visit_Switch env scrut switch_body =
+ (* We transform the switches the following way (if their branches
+ * are stateful):
+ * ```
+ * match x with
+ * | Pati -> branchi
+ *
+ * ~~>
+ *
+ * fun st ->
+ * match x with
+ * | Pati -> branchi st
+ * ```
+ *
+ * The reason is that after unfolding the monadic lets, we often
+ * have this: `(match x with | ...) st`, and we want to "push" the
+ * `st` variable inside.
+ *)
+ let sb_ty = get_switch_body_ty switch_body in
+ if Option.is_some (opt_destruct_state_monad_result sb_ty) then
+ (* Generate a fresh state variable *)
+ let state_var = fresh_state_var () in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let state_value = mk_value_expression state_value None in
+ let state_lvar = mk_typed_lvalue_from_var state_var None in
+ (* Apply in all the branches and reconstruct the switch *)
+ let mk_app e = mk_app e state_value in
+ let switch_body = map_switch_body_branches mk_app switch_body in
+ let e = mk_switch scrut switch_body in
+ let e = mk_abs state_lvar e in
+ (* Introduce the lambda and continue
+ * Rk.: we will revisit the switch, but won't loop because its
+ * type has now changed (the `state -> ...` disappeared) *)
+ super#visit_Abs env state_lvar e
+ else super#visit_Switch env scrut switch_body
+
method! visit_Let env monadic lv re e =
(* For now, we do the following transformation:
* ```
@@ -1255,7 +1291,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
assert (e0.ty = e.ty);
assert (fail_branch.branch.ty = success_branch.branch.ty);
(* Continue *)
- self#visit_expression env e.e)
+ super#visit_expression env e.e)
else
let re_ty = Option.get (opt_destruct_result re.ty) in
assert (lv.ty = re_ty);
@@ -1272,7 +1308,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
(* Continue *)
- self#visit_expression env e
+ super#visit_expression env e
end
in
(* Update the body: add *)