diff options
Diffstat (limited to 'src/PureMicroPasses.ml')
-rw-r--r-- | src/PureMicroPasses.ml | 474 |
1 files changed, 267 insertions, 207 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index b110f829..5227b2ad 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -311,14 +311,22 @@ let compute_pretty_names (def : fun_decl) : fun_decl = (ctx, e.e) in - let input_names = - List.filter_map - (fun (v : var) -> - match v.basename with None -> None | Some name -> Some (v.id, name)) - def.inputs + let body = + match def.body with + | None -> None + | Some body -> + let input_names = + List.filter_map + (fun (v : var) -> + match v.basename with + | None -> None + | Some name -> Some (v.id, name)) + body.inputs + in + let ctx = VarId.Map.of_list input_names in + let _, body_exp = update_texpression body.body ctx in + Some { body with body = body_exp } in - let ctx = VarId.Map.of_list input_names in - let _, body = update_texpression def.body ctx in { def with body } (** Remove the meta-information *) @@ -330,8 +338,11 @@ let remove_meta (def : fun_decl) : fun_decl = method! visit_Meta env _ e = super#visit_expression env e.e end in - let body = obj#visit_texpression () def.body in - { def with body } + match def.body with + | None -> def + | Some body -> + let body = { body with body = obj#visit_texpression () body.body } in + { def with body = Some body } (** Inline the useless variable (re-)assignments: @@ -452,8 +463,13 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) (** Visit the places used as rvalues, to substitute them if possible *) end in - let body = obj#visit_texpression VarId.Map.empty def.body in - { def with body } + match def.body with + | None -> def + | Some body -> + let body = + { body with body = obj#visit_texpression VarId.Map.empty body.body } + in + { def with body = Some body } (** Given a forward or backward function call, is there, for every execution path, a child backward function called later with exactly the same input @@ -683,22 +699,29 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) dont_filter () end in - (* Visit the body *) - let body, used_vars = expr_visitor#visit_texpression () def.body in - (* Visit the parameters - TODO: update: we can filter only if the definition - * is not recursive (otherwise it might mess up with the decrease clauses: - * the decrease clauses uses all the inputs given to the function, if some - * inputs are replaced by '_' we can't give it to the function used in the - * decreases clause). - * For now we deactivate the filtering. *) - let used_vars = used_vars () in - let inputs_lvs = - if false then - List.map (fun lv -> fst (filter_typed_lvalue used_vars lv)) def.inputs_lvs - else def.inputs_lvs - in - (* Return *) - { def with body; inputs_lvs } + (* We filter only inside of transparent (i.e., non-opaque) definitions *) + match def.body with + | None -> def + | Some body -> + (* Visit the body *) + let body_exp, used_vars = expr_visitor#visit_texpression () body.body in + (* Visit the parameters - TODO: update: we can filter only if the definition + * is not recursive (otherwise it might mess up with the decrease clauses: + * the decrease clauses uses all the inputs given to the function, if some + * inputs are replaced by '_' we can't give it to the function used in the + * decreases clause). + * For now we deactivate the filtering. *) + let used_vars = used_vars () in + let inputs_lvs = + if false then + List.map + (fun lv -> fst (filter_typed_lvalue used_vars lv)) + body.inputs_lvs + else body.inputs_lvs + in + (* Return *) + let body = { body with body = body_exp; inputs_lvs } in + { def with body = Some body } (** Return `None` if the function is a backward function with no outputs (so that we eliminate the definition which is useless). @@ -763,21 +786,31 @@ let to_monadic (config : config) (def : fun_decl) : fun_decl = super#visit_call env call end in - let body = obj#visit_texpression () def.body in - let def = { def with body } in + let def = + match def.body with + | None -> def + | Some body -> + let body = { body with body = obj#visit_texpression () body.body } in + { def with body = Some body } + in (* Update the signature: first the input types *) let def = - if def.inputs = [] && config.add_unit_args then ( - assert (def.signature.inputs = []); + if def.signature.inputs = [] && config.add_unit_args then let signature = { def.signature with inputs = [ unit_ty ] } in - let var_cnt = get_expression_min_var_counter def.body.e in - let id, _ = VarId.fresh var_cnt in - let var = { id; basename = None; ty = unit_ty } in - let inputs = [ var ] in - let input_lv = mk_typed_lvalue_from_var var None in - let inputs_lvs = [ input_lv ] in - { def with signature; inputs; inputs_lvs }) + let body = + match def.body with + | None -> None + | Some body -> + let var_cnt = get_expression_min_var_counter body.body.e in + let id, _ = VarId.fresh var_cnt in + let var = { id; basename = None; ty = unit_ty } in + let inputs = [ var ] in + let input_lv = mk_typed_lvalue_from_var var None in + let inputs_lvs = [ input_lv ] in + Some { body with inputs; inputs_lvs } + in + { def with signature; body } else def in (* Then the output type *) @@ -830,11 +863,15 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl = end in (* Update the body *) - let body = obj#visit_texpression () def.body in - (* Update the input parameters *) - let inputs_lvs = List.map (obj#visit_typed_lvalue ()) def.inputs_lvs in - (* Return *) - { def with body; inputs_lvs } + match def.body with + | None -> def + | Some body -> + let body_exp = obj#visit_texpression () body.body in + (* Update the input parameters *) + let inputs_lvs = List.map (obj#visit_typed_lvalue ()) body.inputs_lvs in + (* Return *) + let body = Some { body with body = body_exp; inputs_lvs } in + { def with body } (** Eliminate the box functions like `Box::new`, `Box::deref`, etc. Most of them are translated to identity, and `Box::free` is translated to `()`. @@ -887,178 +924,201 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = end in (* Update the body *) - let body = obj#visit_texpression () def.body in - { def with body } + match def.body with + | None -> def + | Some body -> + let body = Some { body with body = obj#visit_texpression () body.body } in + { def with body } (** Decompose the monadic let-bindings. See the explanations in [config]. *) -let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl - = - (* Set up the var id generator *) - let cnt = get_expression_min_var_counter def.body.e in - let _, fresh_id = VarId.mk_stateful_generator cnt in - (* It is a very simple map *) - let obj = - object (self) - inherit [_] map_expression as super - - method! visit_Let env monadic lv re next_e = - if not monadic then super#visit_Let env monadic lv re next_e - else - (* If monadic, we need to check if the left-value is a variable: - * - if yes, don't decompose - * - if not, make the decomposition in two steps - *) - match lv.value with - | LvVar _ -> - (* Variable: nothing to do *) - super#visit_Let env monadic lv re next_e - | _ -> - (* Not a variable: decompose *) - (* Introduce a temporary variable to receive the value of the - * monadic binding *) - let vid = fresh_id () in - let tmp : var = { id = vid; basename = None; ty = lv.ty } in - let ltmp = mk_typed_lvalue_from_var tmp None in - let rtmp = mk_typed_rvalue_from_var tmp in - let rtmp = mk_value_expression rtmp None in - (* Visit the next expression *) - let next_e = self#visit_texpression env next_e in - (* Create the let-bindings *) - (mk_let true ltmp re (mk_let false lv rtmp next_e)).e - end - in - (* Update the body *) - let body = obj#visit_texpression () def.body in - (* Return *) - { def with body } +let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : + fun_decl = + match def.body with + | None -> def + | Some body -> + (* Set up the var id generator *) + let cnt = get_expression_min_var_counter body.body.e in + let _, fresh_id = VarId.mk_stateful_generator cnt in + (* It is a very simple map *) + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_Let env monadic lv re next_e = + if not monadic then super#visit_Let env monadic lv re next_e + else + (* If monadic, we need to check if the left-value is a variable: + * - if yes, don't decompose + * - if not, make the decomposition in two steps + *) + match lv.value with + | LvVar _ -> + (* Variable: nothing to do *) + super#visit_Let env monadic lv re next_e + | _ -> + (* Not a variable: decompose *) + (* Introduce a temporary variable to receive the value of the + * monadic binding *) + let vid = fresh_id () in + let tmp : var = { id = vid; basename = None; ty = lv.ty } in + let ltmp = mk_typed_lvalue_from_var tmp None in + let rtmp = mk_typed_rvalue_from_var tmp in + let rtmp = mk_value_expression rtmp None in + (* Visit the next expression *) + let next_e = self#visit_texpression env next_e in + (* Create the let-bindings *) + (mk_let true ltmp re (mk_let false lv rtmp next_e)).e + end + in + (* Update the body *) + let body = Some { body with body = obj#visit_texpression () body.body } in + (* Return *) + { def with body } (** Unfold the monadic let-bindings to explicit matches. *) let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) (def : fun_decl) : fun_decl = - (* We may need to introduce fresh variables for the state *) - let var_cnt = get_expression_min_var_counter def.body.e in - let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in - let fresh_state_var () = - let id = fresh_var_id () in - { id; basename = Some "st"; ty = mk_state_ty } - in - (* It is a very simple map *) - let obj = - object (self) - inherit [_] map_expression as super - - method! visit_Let state_var monadic lv re e = - if not monadic then super#visit_Let state_var monadic lv re e - else - (* We don't do the same thing if we use a state-error monad or simply - * an error monad. - * Note that some functions always live in the error monad (arithmetic - * operations, for instance). - *) - let re_call = - match re.e with - | Call call -> call - | _ -> raise (Failure "Unreachable: expected a function call") - in - (* TODO: this information should be computed in SymbolicToPure and - * store in an enum ("monadic" should be an enum, not a bool). - * Also: everything will be cleaner once we update the AST to make - * it more idiomatic lambda calculus... *) - let re_call_can_use_state = - match re_call.func with - | Regular (A.Local _, _) -> true - | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false - in - if config.use_state_monad && re_call_can_use_state then - let re_call = - let call = re_call in - let state_value = mk_typed_rvalue_from_var state_var in - let args = call.args @ [ mk_value_expression state_value None ] in - Call { call with args } - in - let re = { re with e = re_call } in - (* Create the match *) - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in - let fail_branch = - { pat = fail_pat; branch = mk_value_expression fail_value None } - in - (* The `Success` branch introduces a fresh state variable *) - let state_var = fresh_state_var () in - let state_value = mk_typed_lvalue_from_var state_var None in - let success_pat = - mk_result_return_lvalue - (mk_simpl_tuple_lvalue [ state_value; lv ]) - in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - self#visit_expression state_var e - else - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in - let fail_branch = - { pat = fail_pat; branch = mk_value_expression fail_value None } - in - let success_pat = mk_result_return_lvalue lv in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - self#visit_expression state_var e - - method! visit_Value state_var rv mp = - if config.use_state_monad then - match rv.ty with - | Adt (Assumed Result, _) -> ( - match rv.value with - | RvAdt av -> - (* We only need to replace the content of `Return ...` *) - (* TODO: type checking is completely broken at this point... *) - let variant_id = Option.get av.variant_id in - if variant_id = result_return_id then - let res_v = Collections.List.to_cons_nil av.field_values in - let state_value = mk_typed_rvalue_from_var state_var in - let res = mk_simpl_tuple_rvalue [ state_value; res_v ] in - let res = mk_result_return_rvalue res in - (mk_value_expression res None).e - else super#visit_Value state_var rv mp - | _ -> raise (Failure "Unrechable")) - | _ -> super#visit_Value state_var rv mp - else super#visit_Value state_var rv mp - (** We also need to update values, in case this value is `Return ...`. - + match def.body with + | None -> def + | Some body -> + (* We may need to introduce fresh variables for the state *) + let var_cnt = get_expression_min_var_counter body.body.e in + let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in + let fresh_state_var () = + let id = fresh_var_id () in + { id; basename = Some "st"; ty = mk_state_ty } + in + (* It is a very simple map *) + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_Let state_var monadic lv re e = + if not monadic then super#visit_Let state_var monadic lv re e + else + (* We don't do the same thing if we use a state-error monad or simply + * an error monad. + * Note that some functions always live in the error monad (arithmetic + * operations, for instance). + *) + let re_call = + match re.e with + | Call call -> call + | _ -> raise (Failure "Unreachable: expected a function call") + in + (* TODO: this information should be computed in SymbolicToPure and + * store in an enum ("monadic" should be an enum, not a bool). + * Also: everything will be cleaner once we update the AST to make + * it more idiomatic lambda calculus... *) + let re_call_can_use_state = + match re_call.func with + | Regular (A.Local _, _) -> true + | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false + in + if config.use_state_monad && re_call_can_use_state then + let re_call = + let call = re_call in + let state_value = mk_typed_rvalue_from_var state_var in + let args = + call.args @ [ mk_value_expression state_value None ] + in + Call { call with args } + in + let re = { re with e = re_call } in + (* Create the match *) + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { + pat = fail_pat; + branch = mk_value_expression fail_value None; + } + in + (* The `Success` branch introduces a fresh state variable *) + let state_var = fresh_state_var () in + let state_value = mk_typed_lvalue_from_var state_var None in + let success_pat = + mk_result_return_lvalue + (mk_simpl_tuple_lvalue [ state_value; lv ]) + in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e + else + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { + pat = fail_pat; + branch = mk_value_expression fail_value None; + } + in + let success_pat = mk_result_return_lvalue lv in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e + + method! visit_Value state_var rv mp = + if config.use_state_monad then + match rv.ty with + | Adt (Assumed Result, _) -> ( + match rv.value with + | RvAdt av -> + (* We only need to replace the content of `Return ...` *) + (* TODO: type checking is completely broken at this point... *) + let variant_id = Option.get av.variant_id in + if variant_id = result_return_id then + let res_v = + Collections.List.to_cons_nil av.field_values + in + let state_value = mk_typed_rvalue_from_var state_var in + let res = + mk_simpl_tuple_rvalue [ state_value; res_v ] + in + let res = mk_result_return_rvalue res in + (mk_value_expression res None).e + else super#visit_Value state_var rv mp + | _ -> raise (Failure "Unrechable")) + | _ -> super#visit_Value state_var rv mp + else super#visit_Value state_var rv mp + (** We also need to update values, in case this value is `Return ...`. + TODO: this is super ugly... We need to use the monadic functions - `fail` and `return` instead. - *) - end - in - (* Update the body *) - let input_state_var = fresh_state_var () in - let body = obj#visit_texpression input_state_var def.body in - let def = { def with body } in - (* We need to update the type if we revealed the state monad *) - let def = - if config.use_state_monad then - (* Update the signature *) - let sg = def.signature in - let sg_inputs = sg.inputs @ [ mk_state_ty ] in - let sg_outputs = Collections.List.to_cons_nil sg.outputs in - let _, sg_outputs = dest_arrow_ty sg_outputs in - let sg_outputs = [ sg_outputs ] in - let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in - (* Update the inputs list *) - let inputs = def.inputs @ [ input_state_var ] in - let input_lv = mk_typed_lvalue_from_var input_state_var None in - let inputs_lvs = def.inputs_lvs @ [ input_lv ] in - (* Update the definition *) - { def with signature = sg; inputs; inputs_lvs } - else def - in - (* Return *) - { def with body } + `fail` and `return` instead. + *) + end + in + (* Update the body *) + let input_state_var = fresh_state_var () in + let body = + { body with body = obj#visit_texpression input_state_var body.body } + in + (* We need to update the type if we revealed the state monad *) + let body, signature = + if config.use_state_monad then + (* Update the signature *) + let sg = def.signature in + let sg_inputs = sg.inputs @ [ mk_state_ty ] in + let sg_outputs = Collections.List.to_cons_nil sg.outputs in + let _, sg_outputs = dest_arrow_ty sg_outputs in + let sg_outputs = [ sg_outputs ] in + let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in + (* Update the inputs list *) + let inputs = body.inputs @ [ input_state_var ] in + let input_lv = mk_typed_lvalue_from_var input_state_var None in + let inputs_lvs = body.inputs_lvs @ [ input_lv ] in + (* Update the body *) + let body = { body with inputs; inputs_lvs } in + (body, sg) + else (body, def.signature) + in + (* Return *) + { def with body = Some body; signature } (** Apply all the micro-passes to a function. @@ -1149,8 +1209,8 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : let def = decompose_monadic_let_bindings ctx def in log#ldebug (lazy - ("decompose_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def - ^ "\n")); + ("decompose_monadic_let_bindings:\n\n" + ^ fun_decl_to_string ctx def ^ "\n")); def) else ( log#ldebug |