diff options
Diffstat (limited to '')
-rw-r--r-- | src/PureMicroPasses.ml | 154 |
1 files changed, 129 insertions, 25 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 092e6b0d..61d247ea 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -8,6 +8,12 @@ open TranslateCore let log = L.pure_micro_passes_log type config = { + use_state_monad : bool; + (** If `true`, use a state-error monad. + If `false`, only use an error monad. + + Using a state-error monad is necessary when modelling I/O, for instance. + *) decompose_monadic_let_bindings : bool; (** Some provers like F* don't support the decomposition of return values in monadic let-bindings: @@ -739,17 +745,22 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = (** Add unit arguments (optionally) to functions with no arguments, and change their output type to use `result` *) -let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def = +let to_monadic (config : config) (def : fun_def) : fun_def = (* Update the body *) let obj = object inherit [_] map_expression as super method! visit_call env call = - if call.args = [] && add_unit_args then - let args = [ mk_value_expression unit_rvalue None ] in - { call with args } (* Otherwise: nothing to do *) - else super#visit_call env call + match call.func with + | Regular (A.Local _, _) -> + if call.args = [] && config.add_unit_args then + let args = [ mk_value_expression unit_rvalue None ] in + { call with args } + else (* Otherwise: nothing to do *) super#visit_call env call + | Regular (A.Assumed _, _) | Unop _ | Binop _ -> + (* Unops, binops and primitive functions don't have unit arguments *) + super#visit_call env call end in let body = obj#visit_texpression () def.body in @@ -757,7 +768,7 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def = (* Update the signature: first the input types *) let def = - if def.inputs = [] && add_unit_args then ( + if def.inputs = [] && config.add_unit_args then ( assert (def.signature.inputs = []); let signature = { def.signature with inputs = [ unit_ty ] } in let var_cnt = get_expression_min_var_counter def.body.e in @@ -774,10 +785,25 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def = match (def.back_id, def.signature.outputs) with | None, [ out_ty ] -> (* Forward function: there is always exactly one output *) - mk_result_ty out_ty + (* We don't do the same thing if we use a state error monad or not: + * - error-monad: `result out_ty` + * - state-error: `state -> result (state & out_ty) + *) + if config.use_state_monad then + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else (* Simply wrap the type in `result` *) + mk_result_ty out_ty | Some _, outputs -> (* Backward function: we have to group them *) - mk_result_ty (mk_simpl_tuple_ty outputs) + (* We don't do the same thing if we use a state error monad or not *) + if config.use_state_monad then + let ret = mk_simpl_tuple_ty outputs in + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else mk_result_ty (mk_simpl_tuple_ty outputs) | _ -> failwith "Unreachable" in let outputs = [ output_ty ] in @@ -910,29 +936,102 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def { def with body } (** Unfold the monadic let-bindings to explicit matches. *) -let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def = +let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) + (def : fun_def) : fun_def = + (* We may need to introduce fresh variables for the state *) + let var_cnt = get_expression_min_var_counter def.body.e in + let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in + let fresh_state_var () = + let id = fresh_var_id () in + { id; basename = Some "st"; ty = mk_state_ty } + in (* It is a very simple map *) let obj = object (self) inherit [_] map_expression as super - method! visit_Let env monadic lv re e = - if not monadic then super#visit_Let env monadic lv re e + method! visit_Let state_var monadic lv re e = + if not monadic then super#visit_Let state_var monadic lv re e else - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in - let fail_branch = - { pat = fail_pat; branch = mk_value_expression fail_value None } + (* We don't do the same thing if we use a state-error monad or simply + * an error monad. + * Note that some functions always live in the error monad (arithmetic + * operations, for instance). + *) + let re_call = + match re.e with + | Call call -> call + | _ -> raise (Failure "Unreachable: expected a function call") + in + (* TODO: this information should be computed in SymbolicToPure and + * store in an enum ("monadic" should be an enum, not a bool). + * Also: everything will be cleaner once we update the AST to make + * it more idiomatic lambda calculus... *) + let re_call_can_use_state = + match re_call.func with + | Regular (A.Local _, _) -> true + | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false in - let success_pat = mk_result_return_lvalue lv in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - self#visit_expression env e + if config.use_state_monad && re_call_can_use_state then + let re_call = + let call = re_call in + let state_value = mk_typed_rvalue_from_var state_var in + let args = call.args @ [ mk_value_expression state_value None ] in + Call { call with args } + in + let re = { re with e = re_call } in + (* Create the match *) + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { pat = fail_pat; branch = mk_value_expression fail_value None } + in + (* The `Success` branch introduces a fresh state variable *) + let state_var = fresh_state_var () in + let state_value = mk_typed_lvalue_from_var state_var None in + let success_pat = + mk_result_return_lvalue + (mk_simpl_tuple_lvalue [ state_value; lv ]) + in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e + else + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { pat = fail_pat; branch = mk_value_expression fail_value None } + in + let success_pat = mk_result_return_lvalue lv in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e end in (* Update the body *) - let body = obj#visit_texpression () def.body in + let input_state_var = fresh_state_var () in + let body = obj#visit_texpression input_state_var def.body in + let def = { def with body } in + (* We need to update the type if we revealed the state monad *) + let def = + if config.use_state_monad then + (* Update the signature *) + let sg = def.signature in + let sg_inputs = sg.inputs @ [ mk_state_ty ] in + let sg_outputs = Collections.List.to_cons_nil sg.outputs in + let _, sg_outputs = dest_arrow_ty sg_outputs in + let sg_outputs = [ sg_outputs ] in + let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in + (* Update the inputs list *) + let inputs = def.inputs @ [ input_state_var ] in + let input_lv = mk_typed_lvalue_from_var input_state_var None in + let inputs_lvs = def.inputs_lvs @ [ input_lv ] in + (* Update the definition *) + { def with signature = sg; inputs; inputs_lvs } + else def + in (* Return *) { def with body } @@ -981,8 +1080,9 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : (* Add unit arguments for functions with no arguments, and change their return type. * **Rk.**: from now onwards, the types in the AST are correct (until now, * functions had return type `t` where they should have return type `result t`). - * Also, from now onwards, the outputs list has length 1. x*) - let def = to_monadic config.add_unit_args def in + * TODO: this is not true with the state-error monad, unless we unfold the monadic binds. + * Also, from now onwards, the outputs list has length 1. *) + let def = to_monadic config def in log#ldebug (lazy ("to_monadic:\n\n" ^ fun_def_to_string ctx def ^ "\n")); (* Convert the unit variables to `()` if they are used as right-values or @@ -1014,9 +1114,13 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : log#ldebug (lazy ("filter_useless:\n\n" ^ fun_def_to_string ctx def ^ "\n")); - (* Decompose the monadic let-bindings *) + (* Decompose the monadic let-bindings - F* specific + * TODO: remove? With the state-error monad, it is becoming completely + * ad-hoc. *) let def = if config.decompose_monadic_let_bindings then ( + (* TODO: we haven't updated the code to handle the state-error monad *) + assert (not config.use_state_monad); let def = decompose_monadic_let_bindings ctx def in log#ldebug (lazy @@ -1033,7 +1137,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : (* Unfold the monadic let-bindings *) let def = if config.unfold_monadic_let_bindings then ( - let def = unfold_monadic_let_bindings ctx def in + let def = unfold_monadic_let_bindings config ctx def in log#ldebug (lazy ("unfold_monadic_let_bindings:\n\n" ^ fun_def_to_string ctx def |