diff options
Diffstat (limited to '')
-rw-r--r-- | src/PureMicroPasses.ml | 191 |
1 files changed, 26 insertions, 165 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index f76dd2f4..0c371420 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -63,7 +63,6 @@ type config = { borrows as inputs, it can't return mutable borrows; we actually dynamically check for that). *) - use_state_monad : bool; (** TODO: remove *) } (** A configuration to control the application of the passes *) @@ -920,15 +919,9 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) *) let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) : fun_decl option = - let return_ty = - if config.use_state_monad then - mk_arrow mk_state_ty - (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; mk_unit_ty ])) - else mk_result_ty (mk_simpl_tuple_ty [ mk_unit_ty ]) - in if config.filter_useless_functions && Option.is_some def.back_id - && def.signature.outputs = [ return_ty ] + && def.signature.output = mk_result_ty mk_unit_ty then None else Some def @@ -954,7 +947,7 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = * they should be lists of length 1. *) if config.filter_useless_functions - && fwd.signature.outputs = [ mk_result_ty mk_unit_ty ] + && fwd.signature.output = mk_result_ty mk_unit_ty && backs <> [] then false else true @@ -1108,85 +1101,27 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : { def with body } (** Unfold the monadic let-bindings to explicit matches. *) -let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) - (def : fun_decl) : fun_decl = +let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def | Some body -> - (* We may need to introduce fresh variables for the state *) - let fresh_var_id = - let var_cnt = get_body_min_var_counter body in - let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in - fresh_var_id - in - let fresh_state_var () = - let id = fresh_var_id () in - { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } - in (* It is a very simple map *) let obj = object (_self) inherit [_] map_expression as super - method! visit_Switch env scrut switch_body = - (* We transform the switches the following way (if their branches - * are stateful): - * ``` - * match x with - * | Pati -> branchi - * - * ~~> - * - * fun st -> - * match x with - * | Pati -> branchi st - * ``` - * - * The reason is that after unfolding the monadic lets, we often - * have this: `(match x with | ...) st`, and we want to "push" the - * `st` variable inside. - *) - let sb_ty = get_switch_body_ty switch_body in - if Option.is_some (opt_destruct_state_monad_result sb_ty) then - (* Generate a fresh state variable *) - let state_var = fresh_state_var () in - let state_value = mk_texpression_from_var state_var in - let state_lvar = mk_typed_pattern_from_var state_var None in - (* Apply in all the branches and reconstruct the switch *) - let mk_app e = mk_app e state_value in - let switch_body = map_switch_body_branches mk_app switch_body in - let e = mk_switch scrut switch_body in - let e = mk_abs state_lvar e in - (* Introduce the lambda and continue - * Rk.: we will revisit the switch, but won't loop because its - * type has now changed (the `state -> ...` disappeared) *) - super#visit_Abs env state_lvar e - else super#visit_Switch env scrut switch_body - method! visit_Let env monadic lv re e = - (* For now, we do the following transformation: + (* We simply do the following transformation: * ``` - * x <-- re; e + * pat <-- re; e * * ~~> * - * (fun st -> - * match re st with - * | Return (st', x) -> e st' - * | Fail err -> Fail err) + * match re with + * | Fail err -> Fail err + * | Return pat -> e * ``` - * - * We rely on the simplification pass which comes later to normalize - * away expressions like `(fun x -> e) y`. - * - * TODO: fix the use of state-error monads (with the bakward functions, - * we apply some updates twice... - * It would be better if symbolic to pure generated code of the - * following shape: - * `(st1, x) <-- e st0` - * Then, this micro-pass would only expand the monadic let-bindings - * (we wouldn't need to introduce state variables). - * *) + *) (* TODO: we should use a monad "kind" instead of a boolean *) if not monadic then super#visit_Let env monadic lv re e else @@ -1197,95 +1132,24 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) *) (* TODO: this information should be computed in SymbolicToPure and * store in an enum ("monadic" should be an enum, not a bool). *) - let re_uses_state = - Option.is_some (opt_destruct_state_monad_result re.ty) - in - if re_uses_state then ( - let e0 = e in - (* Create a fresh state variable *) - let state_var = fresh_state_var () in - (* The type of `e` is: `state -> e_no_arrow_ty` *) - let _, e_no_arrow_ty = destruct_arrow e.ty in - let e_no_monad_ty = destruct_result e_no_arrow_ty in - let _, re_no_arrow_ty = destruct_arrow re.ty in - let re_no_monad_ty = destruct_result re_no_arrow_ty in - (* Add the state argument on the right-expression *) - let re = - let state_value = mk_texpression_from_var state_var in - mk_app re state_value - in - (* Create the match *) - let fail_pat = mk_result_fail_pattern re_no_monad_ty in - let fail_value = mk_result_fail_texpression e_no_monad_ty in - let fail_branch = { pat = fail_pat; branch = fail_value } in - (* The `Success` branch introduces a fresh state variable *) - let pat_state_var = fresh_state_var () in - let pat_state_pattern = - mk_typed_pattern_from_var pat_state_var None - in - let success_pat = - mk_result_return_pattern - (mk_simpl_tuple_pattern [ pat_state_pattern; lv ]) - in - let pat_state_rvalue = mk_texpression_from_var pat_state_var in - (* TODO: write a utility to create matches (and perform - * type-checking, etc.) *) - let success_branch = - { pat = success_pat; branch = mk_app e pat_state_rvalue } - in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - let e = { e; ty = e_no_arrow_ty } in - (* Add the lambda to introduce the state variable *) - let e = mk_abs (mk_typed_pattern_from_var state_var None) e in - (* Sanity check *) - assert (e0.ty = e.ty); - assert (fail_branch.branch.ty = success_branch.branch.ty); - (* Continue *) - super#visit_expression env e.e) - else - let re_ty = Option.get (opt_destruct_result re.ty) in - assert (lv.ty = re_ty); - let fail_pat = mk_result_fail_pattern lv.ty in - let fail_value = mk_result_fail_texpression e.ty in - let fail_branch = { pat = fail_pat; branch = fail_value } in - let success_pat = mk_result_return_pattern lv in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - (* Continue *) - super#visit_expression env e + let re_ty = Option.get (opt_destruct_result re.ty) in + assert (lv.ty = re_ty); + let fail_pat = mk_result_fail_pattern lv.ty in + let fail_value = mk_result_fail_texpression e.ty in + let fail_branch = { pat = fail_pat; branch = fail_value } in + let success_pat = mk_result_return_pattern lv in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + (* Continue *) + super#visit_expression env e end in - (* Update the body: add *) - let body, signature = - let state_var = fresh_state_var () in - (* First, unfold the expressions inside the body *) - let body_e = obj#visit_texpression () body.body in - (* Then, add a "state" input variable if necessary: *) - if config.use_state_monad then - (* - in the body *) - let state_rvalue = mk_texpression_from_var state_var in - let body_e = mk_app body_e state_rvalue in - (* - in the signature *) - let sg = def.signature in - (* Input types *) - let sg_inputs = sg.inputs @ [ mk_state_ty ] in - (* Output types *) - let sg_outputs = Collections.List.to_cons_nil sg.outputs in - let _, sg_outputs = dest_arrow_ty sg_outputs in - let sg_outputs = [ sg_outputs ] in - let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in - (* Input list *) - let inputs = body.inputs @ [ state_var ] in - let input_lv = mk_typed_pattern_from_var state_var None in - let inputs_lvs = body.inputs_lvs @ [ input_lv ] in - let body = { body = body_e; inputs; inputs_lvs } in - (body, sg) - else ({ body with body = body_e }, def.signature) - in + (* Update the body *) + let body_e = obj#visit_texpression () body.body in + let body = { body with body = body_e } in (* Return *) - { def with body = Some body; signature } + { def with body = Some body } (** Apply all the micro-passes to a function. @@ -1359,12 +1223,9 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Decompose the monadic let-bindings - F* specific - * TODO: remove? With the state-error monad, it is becoming completely - * ad-hoc. *) + * TODO: remove? *) let def = if config.decompose_monadic_let_bindings then ( - (* TODO: we haven't updated the code to handle the state-error monad *) - assert (not config.use_state_monad); let def = decompose_monadic_let_bindings ctx def in log#ldebug (lazy @@ -1381,7 +1242,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : (* Unfold the monadic let-bindings *) let def = if config.unfold_monadic_let_bindings then ( - let def = unfold_monadic_let_bindings config ctx def in + let def = unfold_monadic_let_bindings ctx def in log#ldebug (lazy ("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def |