diff options
Diffstat (limited to '')
-rw-r--r-- | src/PureMicroPasses.ml | 78 |
1 files changed, 59 insertions, 19 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 2c4c667f..e22043e3 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -961,7 +961,8 @@ let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) : fun_decl option = let return_ty = if config.use_state_monad then - mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ]) + mk_arrow mk_state_ty + (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ])) else mk_result_ty (mk_simpl_tuple_ty [ unit_ty ]) in if @@ -1146,7 +1147,7 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : { def with body } (** Unfold the monadic let-bindings to explicit matches. *) -let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) +let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def @@ -1169,13 +1170,13 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) method! visit_Let env monadic lv re e = (* For now, we do the following transformation: * ``` - * x <-- e1; e2 + * x <-- re; e * * ~~> * * (fun st -> - * match e1 st with - * | Return (st', x) -> e2 st' + * match re st with + * | Return (st', x) -> e st' * | Fail err -> Fail err) * ``` * @@ -1204,8 +1205,14 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) Option.is_some (opt_destruct_state_monad_result re.ty) in if re_uses_state then ( + let e0 = e in (* Create a fresh state variable *) let state_var = fresh_state_var () in + (* The type of `e` is: `state -> e_no_arrow_ty` *) + let _, e_no_arrow_ty = destruct_arrow e.ty in + let e_no_monad_ty = destruct_result e_no_arrow_ty in + let _, re_no_arrow_ty = destruct_arrow re.ty in + let re_no_monad_ty = destruct_result re_no_arrow_ty in (* Add the state argument on the right-expression *) let re = let state_value = mk_typed_rvalue_from_var state_var in @@ -1213,8 +1220,8 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) mk_app re state_value in (* Create the match *) - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in + let fail_pat = mk_result_fail_lvalue re_no_monad_ty in + let fail_value = mk_result_fail_rvalue e_no_monad_ty in let fail_branch = { pat = fail_pat; @@ -1222,23 +1229,31 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) } in (* The `Success` branch introduces a fresh state variable *) - let state_var = fresh_state_var () in - let state_value = mk_typed_lvalue_from_var state_var None in + let pat_state_var = fresh_state_var () in + let pat_state_lvalue = + mk_typed_lvalue_from_var pat_state_var None + in let success_pat = mk_result_return_lvalue - (mk_simpl_tuple_lvalue [ state_value; lv ]) + (mk_simpl_tuple_lvalue [ pat_state_lvalue; lv ]) + in + let pat_state_rvalue = mk_typed_rvalue_from_var pat_state_var in + let pat_state_rvalue = + mk_value_expression pat_state_rvalue None in (* TODO: write a utility to create matches (and perform * type-checking, etc.) *) - let ty = e.ty in - let success_branch = { pat = success_pat; branch = e } in + let success_branch = + { pat = success_pat; branch = mk_app e pat_state_rvalue } + in let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in - let e = { e; ty } in - (* Sanity check *) - assert (ty = fail_value.ty); + let e = { e; ty = e_no_arrow_ty } in (* Add the lambda to introduce the state variable *) let e = mk_abs (mk_typed_lvalue_from_var state_var None) e in + (* Sanity check *) + assert (e0.ty = e.ty); + assert (fail_branch.branch.ty = success_branch.branch.ty); (* Continue *) self#visit_expression env e.e) else @@ -1256,14 +1271,39 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) let success_branch = { pat = success_pat; branch = e } in let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in + (* Continue *) self#visit_expression env e end in - (* Update the body *) - let body_e = obj#visit_texpression () body.body in - let body = { body with body = body_e } in + (* Update the body: add *) + let body, signature = + let state_var = fresh_state_var () in + (* First, unfold the expressions inside the body *) + let body_e = obj#visit_texpression () body.body in + (* Then, add a "state" input variable if necessary: *) + if config.use_state_monad then + (* - in the body *) + let state_rvalue = mk_typed_rvalue_from_var state_var in + let body_e = mk_app body_e (mk_value_expression state_rvalue None) in + (* - in the signature *) + let sg = def.signature in + (* Input types *) + let sg_inputs = sg.inputs @ [ mk_state_ty ] in + (* Output types *) + let sg_outputs = Collections.List.to_cons_nil sg.outputs in + let _, sg_outputs = dest_arrow_ty sg_outputs in + let sg_outputs = [ sg_outputs ] in + let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in + (* Input list *) + let inputs = body.inputs @ [ state_var ] in + let input_lv = mk_typed_lvalue_from_var state_var None in + let inputs_lvs = body.inputs_lvs @ [ input_lv ] in + let body = { body = body_e; inputs; inputs_lvs } in + (body, sg) + else ({ body with body = body_e }, def.signature) + in (* Return *) - { def with body = Some body } + { def with body = Some body; signature } (** Apply all the micro-passes to a function. |