summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/PureMicroPasses.ml42
-rw-r--r--src/PureUtils.ml34
2 files changed, 73 insertions, 3 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index e22043e3..9ddc71ab 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -1164,9 +1164,45 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
in
(* It is a very simple map *)
let obj =
- object (self)
+ object (_self)
inherit [_] map_expression as super
+ method! visit_Switch env scrut switch_body =
+ (* We transform the switches the following way (if their branches
+ * are stateful):
+ * ```
+ * match x with
+ * | Pati -> branchi
+ *
+ * ~~>
+ *
+ * fun st ->
+ * match x with
+ * | Pati -> branchi st
+ * ```
+ *
+ * The reason is that after unfolding the monadic lets, we often
+ * have this: `(match x with | ...) st`, and we want to "push" the
+ * `st` variable inside.
+ *)
+ let sb_ty = get_switch_body_ty switch_body in
+ if Option.is_some (opt_destruct_state_monad_result sb_ty) then
+ (* Generate a fresh state variable *)
+ let state_var = fresh_state_var () in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let state_value = mk_value_expression state_value None in
+ let state_lvar = mk_typed_lvalue_from_var state_var None in
+ (* Apply in all the branches and reconstruct the switch *)
+ let mk_app e = mk_app e state_value in
+ let switch_body = map_switch_body_branches mk_app switch_body in
+ let e = mk_switch scrut switch_body in
+ let e = mk_abs state_lvar e in
+ (* Introduce the lambda and continue
+ * Rk.: we will revisit the switch, but won't loop because its
+ * type has now changed (the `state -> ...` disappeared) *)
+ super#visit_Abs env state_lvar e
+ else super#visit_Switch env scrut switch_body
+
method! visit_Let env monadic lv re e =
(* For now, we do the following transformation:
* ```
@@ -1255,7 +1291,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
assert (e0.ty = e.ty);
assert (fail_branch.branch.ty = success_branch.branch.ty);
(* Continue *)
- self#visit_expression env e.e)
+ super#visit_expression env e.e)
else
let re_ty = Option.get (opt_destruct_result re.ty) in
assert (lv.ty = re_ty);
@@ -1272,7 +1308,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
(* Continue *)
- self#visit_expression env e
+ super#visit_expression env e
end
in
(* Update the body: add *)
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index b87a6346..873931be 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -478,3 +478,37 @@ let destruct_arrow (ty : ty) : ty * ty =
| _ -> raise (Failure "Unreachable")
let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1)
+
+let get_switch_body_ty (sb : switch_body) : ty =
+ match sb with
+ | If (e_then, _) -> e_then.ty
+ | Match branches ->
+ (* There should be at least one branch *)
+ (List.hd branches).branch.ty
+
+let map_switch_body_branches (f : texpression -> texpression) (sb : switch_body)
+ : switch_body =
+ match sb with
+ | If (e_then, e_else) -> If (f e_then, f e_else)
+ | Match branches ->
+ Match
+ (List.map
+ (fun (b : match_branch) -> { b with branch = f b.branch })
+ branches)
+
+let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) :
+ unit =
+ match sb with
+ | If (e_then, e_else) ->
+ f e_then;
+ f e_else
+ | Match branches -> List.iter (fun (b : match_branch) -> f b.branch) branches
+
+let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
+ (* TODO: check the type of the scrutinee *)
+ let ty = get_switch_body_ty sb in
+ (* Sanity check: all the branches have the same type *)
+ iter_switch_body_branches (fun e -> assert (e.ty = ty)) sb;
+ (* Put together *)
+ let e = Switch (scrut, sb) in
+ { e; ty }