From 7d24471866e5e486989d78676287bed267c4e5b4 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 27 Apr 2022 17:52:03 +0200 Subject: Make minor modifications --- src/PureMicroPasses.ml | 42 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 39 insertions(+), 3 deletions(-) (limited to 'src/PureMicroPasses.ml') diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index e22043e3..9ddc71ab 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -1164,9 +1164,45 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) in (* It is a very simple map *) let obj = - object (self) + object (_self) inherit [_] map_expression as super + method! visit_Switch env scrut switch_body = + (* We transform the switches the following way (if their branches + * are stateful): + * ``` + * match x with + * | Pati -> branchi + * + * ~~> + * + * fun st -> + * match x with + * | Pati -> branchi st + * ``` + * + * The reason is that after unfolding the monadic lets, we often + * have this: `(match x with | ...) st`, and we want to "push" the + * `st` variable inside. + *) + let sb_ty = get_switch_body_ty switch_body in + if Option.is_some (opt_destruct_state_monad_result sb_ty) then + (* Generate a fresh state variable *) + let state_var = fresh_state_var () in + let state_value = mk_typed_rvalue_from_var state_var in + let state_value = mk_value_expression state_value None in + let state_lvar = mk_typed_lvalue_from_var state_var None in + (* Apply in all the branches and reconstruct the switch *) + let mk_app e = mk_app e state_value in + let switch_body = map_switch_body_branches mk_app switch_body in + let e = mk_switch scrut switch_body in + let e = mk_abs state_lvar e in + (* Introduce the lambda and continue + * Rk.: we will revisit the switch, but won't loop because its + * type has now changed (the `state -> ...` disappeared) *) + super#visit_Abs env state_lvar e + else super#visit_Switch env scrut switch_body + method! visit_Let env monadic lv re e = (* For now, we do the following transformation: * ``` @@ -1255,7 +1291,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) assert (e0.ty = e.ty); assert (fail_branch.branch.ty = success_branch.branch.ty); (* Continue *) - self#visit_expression env e.e) + super#visit_expression env e.e) else let re_ty = Option.get (opt_destruct_result re.ty) in assert (lv.ty = re_ty); @@ -1272,7 +1308,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in (* Continue *) - self#visit_expression env e + super#visit_expression env e end in (* Update the body: add *) -- cgit v1.2.3