summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/PureMicroPasses.ml52
-rw-r--r--src/SymbolicToPure.ml4
2 files changed, 42 insertions, 14 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 227622df..75844345 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -1161,9 +1161,32 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
object (self)
inherit [_] map_expression as super
- method! visit_Let state_var monadic lv re e =
+ method! visit_Let env monadic lv re e =
+ (* For now, we do the following transformation:
+ * ```
+ * x <-- e1; e2
+ *
+ * ~~>
+ *
+ * (fun st ->
+ * match e1 st with
+ * | Return (st', x) -> e2 st'
+ * | Fail err -> Fail err)
+ * ```
+ *
+ * We rely on the simplification pass which comes later to normalize
+ * away expressions like `(fun x -> e) y`.
+ *
+ * TODO: fix the use of state-error monads (with the bakward functions,
+ * we apply some updates twice...
+ * It would be better if symbolic to pure generated code of the
+ * following shape:
+ * `(st1, x) <-- e st0`
+ * Then, this micro-pass would only expand the monadic let-bindings
+ * (we wouldn't need to introduce state variables).
+ * *)
(* TODO: we should use a monad "kind" instead of a boolean *)
- if not monadic then super#visit_Let state_var monadic lv re e
+ if not monadic then super#visit_Let env monadic lv re e
else
(* We don't do the same thing if we use a state-error monad or simply
* an error monad.
@@ -1175,7 +1198,9 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
let re_uses_state =
Option.is_some (opt_destruct_state_monad_result re.ty)
in
- if re_uses_state then
+ if re_uses_state then (
+ (* Create a fresh state variable *)
+ let state_var = fresh_state_var () in
(* Add the state argument on the right-expression *)
let re =
let state_value = mk_typed_rvalue_from_var state_var in
@@ -1198,10 +1223,19 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
mk_result_return_lvalue
(mk_simpl_tuple_lvalue [ state_value; lv ])
in
+ (* TODO: write a utility to create matches (and perform
+ * type-checking, etc.) *)
+ let ty = e.ty in
let success_branch = { pat = success_pat; branch = e } in
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
- self#visit_expression state_var e
+ let e = { e; ty } in
+ (* Sanity check *)
+ assert (ty = fail_value.ty);
+ (* Add the lambda to introduce the state variable *)
+ let e = mk_abs (mk_typed_lvalue_from_var state_var None) e in
+ (* Continue *)
+ self#visit_expression env e.e)
else
let re_ty = Option.get (opt_destruct_result re.ty) in
assert (lv.ty = re_ty);
@@ -1217,17 +1251,11 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx)
let success_branch = { pat = success_pat; branch = e } in
let switch_body = Match [ fail_branch; success_branch ] in
let e = Switch (re, switch_body) in
- self#visit_expression state_var e
+ self#visit_expression env e
end
in
(* Update the body *)
- let body_e =
- let input_state_var = fresh_state_var () in
- (* First: expand the matches *)
- let body_e = obj#visit_texpression input_state_var body.body in
- (* Then: add a lambda abstraction for the state variable *)
- mk_abs (mk_typed_lvalue_from_var input_state_var None) body_e
- in
+ let body_e = obj#visit_texpression () body.body in
let body = { body with body = body_e } in
(* Return *)
{ def with body = Some body }
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index b25b7309..4e15d921 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -1000,7 +1000,7 @@ and translate_return (config : config) (opt_v : V.typed_value option)
* *)
(* TODO: we should use a `return` function, it would be cleaner *)
if config.use_state_monad then
- let _, state_var = fresh_var (Some "state") mk_state_ty ctx in
+ let _, state_var = fresh_var (Some "st") mk_state_ty ctx in
let state_rvalue = mk_typed_rvalue_from_var state_var in
let v =
mk_result_return_rvalue (mk_simpl_tuple_rvalue [ state_rvalue; v ])
@@ -1029,7 +1029,7 @@ and translate_return (config : config) (opt_v : V.typed_value option)
(* See the comment about the monads, for the forward function case *)
(* TODO: we should use a `fail` function, it would be cleaner *)
if config.use_state_monad then
- let _, state_var = fresh_var (Some "state") mk_state_ty ctx in
+ let _, state_var = fresh_var (Some "st") mk_state_ty ctx in
let state_rvalue = mk_typed_rvalue_from_var state_var in
let ret_value = mk_simpl_tuple_rvalue field_values in
let ret_value =