diff options
| -rw-r--r-- | src/PureMicroPasses.ml | 42 | ||||
| -rw-r--r-- | src/PureUtils.ml | 34 | 
2 files changed, 73 insertions, 3 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index e22043e3..9ddc71ab 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -1164,9 +1164,45 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)        in        (* It is a very simple map *)        let obj = -        object (self) +        object (_self)            inherit [_] map_expression as super +          method! visit_Switch env scrut switch_body = +            (* We transform the switches the following way (if their branches +             * are stateful): +             * ``` +             * match x with +             * | Pati -> branchi +             * +             *     ~~> +             * +             * fun st -> +             * match x with +             * | Pati -> branchi st +             * ``` +             * +             * The reason is that after unfolding the monadic lets, we often +             * have this: `(match x with | ...) st`, and we want to "push" the +             * `st` variable inside. +             *) +            let sb_ty = get_switch_body_ty switch_body in +            if Option.is_some (opt_destruct_state_monad_result sb_ty) then +              (* Generate a fresh state variable *) +              let state_var = fresh_state_var () in +              let state_value = mk_typed_rvalue_from_var state_var in +              let state_value = mk_value_expression state_value None in +              let state_lvar = mk_typed_lvalue_from_var state_var None in +              (* Apply in all the branches and reconstruct the switch *) +              let mk_app e = mk_app e state_value in +              let switch_body = map_switch_body_branches mk_app switch_body in +              let e = mk_switch scrut switch_body in +              let e = mk_abs state_lvar e in +              (* Introduce the lambda and continue +               * Rk.: we will revisit the switch, but won't loop because its +               * type has now changed (the `state -> ...` disappeared) *) +              super#visit_Abs env state_lvar e +            else super#visit_Switch env scrut switch_body +            method! visit_Let env monadic lv re e =              (* For now, we do the following transformation:               * ``` @@ -1255,7 +1291,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)                  assert (e0.ty = e.ty);                  assert (fail_branch.branch.ty = success_branch.branch.ty);                  (* Continue *) -                self#visit_expression env e.e) +                super#visit_expression env e.e)                else                  let re_ty = Option.get (opt_destruct_result re.ty) in                  assert (lv.ty = re_ty); @@ -1272,7 +1308,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)                  let switch_body = Match [ fail_branch; success_branch ] in                  let e = Switch (re, switch_body) in                  (* Continue *) -                self#visit_expression env e +                super#visit_expression env e          end        in        (* Update the body: add *) diff --git a/src/PureUtils.ml b/src/PureUtils.ml index b87a6346..873931be 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -478,3 +478,37 @@ let destruct_arrow (ty : ty) : ty * ty =    | _ -> raise (Failure "Unreachable")  let mk_arrow (ty0 : ty) (ty1 : ty) : ty = Arrow (ty0, ty1) + +let get_switch_body_ty (sb : switch_body) : ty = +  match sb with +  | If (e_then, _) -> e_then.ty +  | Match branches -> +      (* There should be at least one branch *) +      (List.hd branches).branch.ty + +let map_switch_body_branches (f : texpression -> texpression) (sb : switch_body) +    : switch_body = +  match sb with +  | If (e_then, e_else) -> If (f e_then, f e_else) +  | Match branches -> +      Match +        (List.map +           (fun (b : match_branch) -> { b with branch = f b.branch }) +           branches) + +let iter_switch_body_branches (f : texpression -> unit) (sb : switch_body) : +    unit = +  match sb with +  | If (e_then, e_else) -> +      f e_then; +      f e_else +  | Match branches -> List.iter (fun (b : match_branch) -> f b.branch) branches + +let mk_switch (scrut : texpression) (sb : switch_body) : texpression = +  (* TODO: check the type of the scrutinee *) +  let ty = get_switch_body_ty sb in +  (* Sanity check: all the branches have the same type *) +  iter_switch_body_branches (fun e -> assert (e.ty = ty)) sb; +  (* Put together *) +  let e = Switch (scrut, sb) in +  { e; ty }  | 
