diff options
Diffstat (limited to 'src/PureMicroPasses.ml')
-rw-r--r-- | src/PureMicroPasses.ml | 52 |
1 files changed, 40 insertions, 12 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 227622df..75844345 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -1161,9 +1161,32 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) object (self) inherit [_] map_expression as super - method! visit_Let state_var monadic lv re e = + method! visit_Let env monadic lv re e = + (* For now, we do the following transformation: + * ``` + * x <-- e1; e2 + * + * ~~> + * + * (fun st -> + * match e1 st with + * | Return (st', x) -> e2 st' + * | Fail err -> Fail err) + * ``` + * + * We rely on the simplification pass which comes later to normalize + * away expressions like `(fun x -> e) y`. + * + * TODO: fix the use of state-error monads (with the bakward functions, + * we apply some updates twice... + * It would be better if symbolic to pure generated code of the + * following shape: + * `(st1, x) <-- e st0` + * Then, this micro-pass would only expand the monadic let-bindings + * (we wouldn't need to introduce state variables). + * *) (* TODO: we should use a monad "kind" instead of a boolean *) - if not monadic then super#visit_Let state_var monadic lv re e + if not monadic then super#visit_Let env monadic lv re e else (* We don't do the same thing if we use a state-error monad or simply * an error monad. @@ -1175,7 +1198,9 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) let re_uses_state = Option.is_some (opt_destruct_state_monad_result re.ty) in - if re_uses_state then + if re_uses_state then ( + (* Create a fresh state variable *) + let state_var = fresh_state_var () in (* Add the state argument on the right-expression *) let re = let state_value = mk_typed_rvalue_from_var state_var in @@ -1198,10 +1223,19 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) mk_result_return_lvalue (mk_simpl_tuple_lvalue [ state_value; lv ]) in + (* TODO: write a utility to create matches (and perform + * type-checking, etc.) *) + let ty = e.ty in let success_branch = { pat = success_pat; branch = e } in let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in - self#visit_expression state_var e + let e = { e; ty } in + (* Sanity check *) + assert (ty = fail_value.ty); + (* Add the lambda to introduce the state variable *) + let e = mk_abs (mk_typed_lvalue_from_var state_var None) e in + (* Continue *) + self#visit_expression env e.e) else let re_ty = Option.get (opt_destruct_result re.ty) in assert (lv.ty = re_ty); @@ -1217,17 +1251,11 @@ let unfold_monadic_let_bindings (_config : config) (_ctx : trans_ctx) let success_branch = { pat = success_pat; branch = e } in let switch_body = Match [ fail_branch; success_branch ] in let e = Switch (re, switch_body) in - self#visit_expression state_var e + self#visit_expression env e end in (* Update the body *) - let body_e = - let input_state_var = fresh_state_var () in - (* First: expand the matches *) - let body_e = obj#visit_texpression input_state_var body.body in - (* Then: add a lambda abstraction for the state variable *) - mk_abs (mk_typed_lvalue_from_var input_state_var None) body_e - in + let body_e = obj#visit_texpression () body.body in let body = { body with body = body_e } in (* Return *) { def with body = Some body } |