summaryrefslogtreecommitdiff
path: root/src/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/PureMicroPasses.ml78
1 files changed, 59 insertions, 19 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 2c4c667f..e22043e3 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -961,7 +961,8 @@ let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) :
fun_decl option =
let return_ty =
if config.use_state_monad then
- mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ])
+ mk_arrow mk_state_ty
+ (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; unit_ty ]))
else mk_result_ty (mk_simpl_tuple_ty [ unit_ty ])
in
if
@@ -1146,7 +1147,7 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :
{ def with body }
(** Unfold the monadic let-bindings to explicit matches. *)
-let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
+let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
(def : fun_decl) : fun_decl =
match def.body with
| None -> def
@@ -1169,13 +1170,13 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
method! visit_Let env monadic lv re e =
(* For now, we do the following transformation:
* ```
- * x <-- e1; e2
+ * x <-- re; e
*
* ~~>
*
* (fun st ->
- * match e1 st with
- * | Return (st', x) -> e2 st'
+ * match re st with
+ * | Return (st', x) -> e st'
* | Fail err -> Fail err)
* ```
*
@@ -1204,8 +1205,14 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
Option.is_some (opt_destruct_state_monad_result re.ty)
in
if re_uses_state then (
+ let e0 = e in
(* Create a fresh state variable *)
let state_var = fresh_state_var () in
+ (* The type of `e` is: `state -> e_no_arrow_ty` *)
+ let _, e_no_arrow_ty = destruct_arrow e.ty in
+ let e_no_monad_ty = destruct_result e_no_arrow_ty in
+ let _, re_no_arrow_ty = destruct_arrow re.ty in
+ let re_no_monad_ty = destruct_result re_no_arrow_ty in
(* Add the state argument on the right-expression *)
let re =
let state_value = mk_typed_rvalue_from_var state_var in
@@ -1213,8 +1220,8 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
mk_app re state_value
in
(* Create the match *)
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_pat = mk_result_fail_lvalue re_no_monad_ty in
+ let fail_value = mk_result_fail_rvalue e_no_monad_ty in
let fail_branch =
{
pat = fail_pat;
@@ -1222,23 +1229,31 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
}
in
(* The `Success` branch introduces a fresh state variable *)
- let state_var = fresh_state_var () in
- let state_value = mk_typed_lvalue_from_var state_var None in
+ let pat_state_var = fresh_state_var () in
+ let pat_state_lvalue =
+ mk_typed_lvalue_from_var pat_state_var None
+ in
let success_pat =
mk_result_return_lvalue
- (mk_simpl_tuple_lvalue [ state_value; lv ])
+ (mk_simpl_tuple_lvalue [ pat_state_lvalue; lv ])
+ in
+ let pat_state_rvalue = mk_typed_rvalue_from_var pat_state_var in
+ let pat_state_rvalue =
+ mk_value_expression pat_state_rvalue None
in
(* TODO: write a utility to create matches (and perform
* type-checking, etc.) *)
- let ty = e.ty in
- let success_branch = { pat = success_pat; branch = e } in
+ let success_branch =
+ { pat = success_pat; branch = mk_app e pat_state_rvalue }
+ in
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
- let e = { e; ty } in
- (* Sanity check *)
- assert (ty = fail_value.ty);
+ let e = { e; ty = e_no_arrow_ty } in
(* Add the lambda to introduce the state variable *)
let e = mk_abs (mk_typed_lvalue_from_var state_var None) e in
+ (* Sanity check *)
+ assert (e0.ty = e.ty);
+ assert (fail_branch.branch.ty = success_branch.branch.ty);
(* Continue *)
self#visit_expression env e.e)
else
@@ -1256,14 +1271,39 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
let success_branch = { pat = success_pat; branch = e } in
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
+ (* Continue *)
self#visit_expression env e
end
in
- (* Update the body *)
- let body_e = obj#visit_texpression () body.body in
- let body = { body with body = body_e } in
+ (* Update the body: add *)
+ let body, signature =
+ let state_var = fresh_state_var () in
+ (* First, unfold the expressions inside the body *)
+ let body_e = obj#visit_texpression () body.body in
+ (* Then, add a "state" input variable if necessary: *)
+ if config.use_state_monad then
+ (* - in the body *)
+ let state_rvalue = mk_typed_rvalue_from_var state_var in
+ let body_e = mk_app body_e (mk_value_expression state_rvalue None) in
+ (* - in the signature *)
+ let sg = def.signature in
+ (* Input types *)
+ let sg_inputs = sg.inputs @ [ mk_state_ty ] in
+ (* Output types *)
+ let sg_outputs = Collections.List.to_cons_nil sg.outputs in
+ let _, sg_outputs = dest_arrow_ty sg_outputs in
+ let sg_outputs = [ sg_outputs ] in
+ let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
+ (* Input list *)
+ let inputs = body.inputs @ [ state_var ] in
+ let input_lv = mk_typed_lvalue_from_var state_var None in
+ let inputs_lvs = body.inputs_lvs @ [ input_lv ] in
+ let body = { body = body_e; inputs; inputs_lvs } in
+ (body, sg)
+ else ({ body with body = body_e }, def.signature)
+ in
(* Return *)
- { def with body = Some body }
+ { def with body = Some body; signature }
(** Apply all the micro-passes to a function.