summaryrefslogtreecommitdiff
path: root/src/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/PureMicroPasses.ml154
1 files changed, 129 insertions, 25 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 092e6b0d..61d247ea 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -8,6 +8,12 @@ open TranslateCore
let log = L.pure_micro_passes_log
type config = {
+ use_state_monad : bool;
+ (** If `true`, use a state-error monad.
+ If `false`, only use an error monad.
+
+ Using a state-error monad is necessary when modelling I/O, for instance.
+ *)
decompose_monadic_let_bindings : bool;
(** Some provers like F* don't support the decomposition of return values
in monadic let-bindings:
@@ -739,17 +745,22 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool =
(** Add unit arguments (optionally) to functions with no arguments, and
change their output type to use `result`
*)
-let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def =
+let to_monadic (config : config) (def : fun_def) : fun_def =
(* Update the body *)
let obj =
object
inherit [_] map_expression as super
method! visit_call env call =
- if call.args = [] && add_unit_args then
- let args = [ mk_value_expression unit_rvalue None ] in
- { call with args } (* Otherwise: nothing to do *)
- else super#visit_call env call
+ match call.func with
+ | Regular (A.Local _, _) ->
+ if call.args = [] && config.add_unit_args then
+ let args = [ mk_value_expression unit_rvalue None ] in
+ { call with args }
+ else (* Otherwise: nothing to do *) super#visit_call env call
+ | Regular (A.Assumed _, _) | Unop _ | Binop _ ->
+ (* Unops, binops and primitive functions don't have unit arguments *)
+ super#visit_call env call
end
in
let body = obj#visit_texpression () def.body in
@@ -757,7 +768,7 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def =
(* Update the signature: first the input types *)
let def =
- if def.inputs = [] && add_unit_args then (
+ if def.inputs = [] && config.add_unit_args then (
assert (def.signature.inputs = []);
let signature = { def.signature with inputs = [ unit_ty ] } in
let var_cnt = get_expression_min_var_counter def.body.e in
@@ -774,10 +785,25 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def =
match (def.back_id, def.signature.outputs) with
| None, [ out_ty ] ->
(* Forward function: there is always exactly one output *)
- mk_result_ty out_ty
+ (* We don't do the same thing if we use a state error monad or not:
+ * - error-monad: `result out_ty`
+ * - state-error: `state -> result (state & out_ty)
+ *)
+ if config.use_state_monad then
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else (* Simply wrap the type in `result` *)
+ mk_result_ty out_ty
| Some _, outputs ->
(* Backward function: we have to group them *)
- mk_result_ty (mk_simpl_tuple_ty outputs)
+ (* We don't do the same thing if we use a state error monad or not *)
+ if config.use_state_monad then
+ let ret = mk_simpl_tuple_ty outputs in
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else mk_result_ty (mk_simpl_tuple_ty outputs)
| _ -> failwith "Unreachable"
in
let outputs = [ output_ty ] in
@@ -910,29 +936,102 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def
{ def with body }
(** Unfold the monadic let-bindings to explicit matches. *)
-let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def =
+let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
+ (def : fun_def) : fun_def =
+ (* We may need to introduce fresh variables for the state *)
+ let var_cnt = get_expression_min_var_counter def.body.e in
+ let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in
+ let fresh_state_var () =
+ let id = fresh_var_id () in
+ { id; basename = Some "st"; ty = mk_state_ty }
+ in
(* It is a very simple map *)
let obj =
object (self)
inherit [_] map_expression as super
- method! visit_Let env monadic lv re e =
- if not monadic then super#visit_Let env monadic lv re e
+ method! visit_Let state_var monadic lv re e =
+ if not monadic then super#visit_Let state_var monadic lv re e
else
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
- let fail_branch =
- { pat = fail_pat; branch = mk_value_expression fail_value None }
+ (* We don't do the same thing if we use a state-error monad or simply
+ * an error monad.
+ * Note that some functions always live in the error monad (arithmetic
+ * operations, for instance).
+ *)
+ let re_call =
+ match re.e with
+ | Call call -> call
+ | _ -> raise (Failure "Unreachable: expected a function call")
+ in
+ (* TODO: this information should be computed in SymbolicToPure and
+ * store in an enum ("monadic" should be an enum, not a bool).
+ * Also: everything will be cleaner once we update the AST to make
+ * it more idiomatic lambda calculus... *)
+ let re_call_can_use_state =
+ match re_call.func with
+ | Regular (A.Local _, _) -> true
+ | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
in
- let success_pat = mk_result_return_lvalue lv in
- let success_branch = { pat = success_pat; branch = e } in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- self#visit_expression env e
+ if config.use_state_monad && re_call_can_use_state then
+ let re_call =
+ let call = re_call in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let args = call.args @ [ mk_value_expression state_value None ] in
+ Call { call with args }
+ in
+ let re = { re with e = re_call } in
+ (* Create the match *)
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ { pat = fail_pat; branch = mk_value_expression fail_value None }
+ in
+ (* The `Success` branch introduces a fresh state variable *)
+ let state_var = fresh_state_var () in
+ let state_value = mk_typed_lvalue_from_var state_var None in
+ let success_pat =
+ mk_result_return_lvalue
+ (mk_simpl_tuple_lvalue [ state_value; lv ])
+ in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
+ else
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ { pat = fail_pat; branch = mk_value_expression fail_value None }
+ in
+ let success_pat = mk_result_return_lvalue lv in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
end
in
(* Update the body *)
- let body = obj#visit_texpression () def.body in
+ let input_state_var = fresh_state_var () in
+ let body = obj#visit_texpression input_state_var def.body in
+ let def = { def with body } in
+ (* We need to update the type if we revealed the state monad *)
+ let def =
+ if config.use_state_monad then
+ (* Update the signature *)
+ let sg = def.signature in
+ let sg_inputs = sg.inputs @ [ mk_state_ty ] in
+ let sg_outputs = Collections.List.to_cons_nil sg.outputs in
+ let _, sg_outputs = dest_arrow_ty sg_outputs in
+ let sg_outputs = [ sg_outputs ] in
+ let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
+ (* Update the inputs list *)
+ let inputs = def.inputs @ [ input_state_var ] in
+ let input_lv = mk_typed_lvalue_from_var input_state_var None in
+ let inputs_lvs = def.inputs_lvs @ [ input_lv ] in
+ (* Update the definition *)
+ { def with signature = sg; inputs; inputs_lvs }
+ else def
+ in
(* Return *)
{ def with body }
@@ -981,8 +1080,9 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
(* Add unit arguments for functions with no arguments, and change their return type.
* **Rk.**: from now onwards, the types in the AST are correct (until now,
* functions had return type `t` where they should have return type `result t`).
- * Also, from now onwards, the outputs list has length 1. x*)
- let def = to_monadic config.add_unit_args def in
+ * TODO: this is not true with the state-error monad, unless we unfold the monadic binds.
+ * Also, from now onwards, the outputs list has length 1. *)
+ let def = to_monadic config def in
log#ldebug (lazy ("to_monadic:\n\n" ^ fun_def_to_string ctx def ^ "\n"));
(* Convert the unit variables to `()` if they are used as right-values or
@@ -1014,9 +1114,13 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
log#ldebug
(lazy ("filter_useless:\n\n" ^ fun_def_to_string ctx def ^ "\n"));
- (* Decompose the monadic let-bindings *)
+ (* Decompose the monadic let-bindings - F* specific
+ * TODO: remove? With the state-error monad, it is becoming completely
+ * ad-hoc. *)
let def =
if config.decompose_monadic_let_bindings then (
+ (* TODO: we haven't updated the code to handle the state-error monad *)
+ assert (not config.use_state_monad);
let def = decompose_monadic_let_bindings ctx def in
log#ldebug
(lazy
@@ -1033,7 +1137,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
(* Unfold the monadic let-bindings *)
let def =
if config.unfold_monadic_let_bindings then (
- let def = unfold_monadic_let_bindings ctx def in
+ let def = unfold_monadic_let_bindings config ctx def in
log#ldebug
(lazy
("unfold_monadic_let_bindings:\n\n" ^ fun_def_to_string ctx def