diff options
-rw-r--r-- | src/PureMicroPasses.ml | 42 | ||||
-rw-r--r-- | src/PureUtils.ml | 34 |
2 files changed, 73 insertions, 3 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index e22043e3..9ddc71ab 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -1164,9 +1164,45 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) in (* It is a very simple map *) let obj = - object (self) + object (_self) inherit [_] map_expression as super + method! visit_Switch env scrut switch_body = + (* We transform the switches the following way (if their branches + * are stateful): + * ``` + * match x with + * | Pati -> branchi + * + * ~~> + * + * fun st -> + * match x with + * | Pati -> branchi st + * ``` + * + * The reason is that after unfolding the monadic lets, we often + * have this: `(match x with | ...) st`, and we want to "push" the + * `st` variable inside. + *) + let sb_ty = get_switch_body_ty switch_body in + if Option.is_some (opt_destruct_state_monad_result sb_ty) then + (* Generate a fresh state variable *) + let state_var = fresh_state_var () in + let state_value = mk_typed_rvalue_from_var state_var in + let state_value = mk_value_expression state_value None in + let state_lvar = mk_typed_lvalue_from_var state_var None in + (* Apply in all the branches and reconstruct the switch *) + let mk_app e = mk_app e state_value in + let switch_body = map_switch_body_branches mk_app switch_body in + let e = mk_switch scrut switch_body in + let e = mk_abs state_lvar e in + (* Introduce the lambda and continue + * Rk.: we will revisit the switch, but won't loop because its + * type has now changed (the `state -> ...` disappeared) *) + super#visit_Abs env state_lvar e + else super#visit_Switch env scrut switch_body + method! visit_Let env monadic lv re e = (* For now, we do the following transformation: * ``` @@ -1255,7 +1291,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) assert (e0.ty = e.ty); assert (fail_branch.branch.ty = success_branch.branch.ty); (* Continue *) - self#visit_expression env e.e) + super#visit_expression env e.e) else let re_ty = Option.get (opt_destruct_result re.ty) in assert (lv.ty = re_ty); @@ -1272,7 +1308,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in (* Continue *) - self#visit_expression env e + super#visit_expression env e end in (* Update the body: add *) diff --git a/src/PureUtils.ml b/src/PureUtils.ml index b87a6346..873931be 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -478,3 +478,37 @@ let destruct_arrow (ty : ty) : ty * ty = | _ -> raise (Failure "Unreachable") let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1) + +let get_switch_body_ty (sb : switch_body) : ty = + match sb with + | If (e_then, _) -> e_then.ty + | Match branches -> + (* There should be at least one branch *) + (List.hd branches).branch.ty + +let map_switch_body_branches (f : texpression -> texpression) (sb : switch_body) + : switch_body = + match sb with + | If (e_then, e_else) -> If (f e_then, f e_else) + | Match branches -> + Match + (List.map + (fun (b : match_branch) -> { b with branch = f b.branch }) + branches) + +let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) : + unit = + match sb with + | If (e_then, e_else) -> + f e_then; + f e_else + | Match branches -> List.iter (fun (b : match_branch) -> f b.branch) branches + +let mk_switch (scrut : texpression) (sb : switch_body) : texpression = + (* TODO: check the type of the scrutinee *) + let ty = get_switch_body_ty sb in + (* Sanity check: all the branches have the same type *) + iter_switch_body_branches (fun e -> assert (e.ty = ty)) sb; + (* Put together *) + let e = Switch (scrut, sb) in + { e; ty } |