From fe2a2cb34148e46e32cdcfbf100e38d9986082cd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 16:06:35 +0100 Subject: Make progress on propagating the changes --- compiler/PureMicroPasses.ml | 331 +++++--------------------------------------- 1 file changed, 31 insertions(+), 300 deletions(-) (limited to 'compiler/PureMicroPasses.ml') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index ec64df21..04bc90d7 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -925,156 +925,9 @@ let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool) in { def with body = Some body } -(** For the cases where we split the forward/backward functions. - - Given a forward or backward function call, is there, for every execution - path, a child backward function called later with exactly the same input - list prefix. We use this to filter useless function calls: if there are - such child calls, we can remove this one (in case its outputs are not - used). - We do this check because we can't simply remove function calls whose - outputs are not used, as they might fail. However, if a function fails, - its children backward functions then fail on the same inputs (ignoring - the additional inputs those receive). - - For instance, if we have: - {[ - fn f<'a>(x : &'a mut T); - ]} - - We often have things like this in the synthesized code: - {[ - _ <-- f@fwd x; - ... - nx <-- f@back'a x y; - ... - ]} - - If [f@back'a x y] fails, then necessarily [f@fwd x] also fails. - In this situation, we can remove the call [f@fwd x]. - *) -let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option) - (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args) - (args0 : texpression list) (e : texpression) : bool = - let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args) - (args1 : texpression list) : bool = - (* Check the fun_ids, to see if call1's function is a child of call0's function *) - match fun_id1 with - | Fun (FromLlbc (id1, lp_id1, rg_id1)) -> - (* Both are "regular" calls: check if they come from the same rust function *) - if id0 = id1 && lp_id0 = lp_id1 then - (* Same rust functions: check the regions hierarchy *) - let call1_is_child = - match (rg_id0, rg_id1) with - | None, _ -> - (* The function used in call0 is the forward function: the one - * used in call1 is necessarily a child *) - true - | Some _, None -> - (* Opposite of previous case *) - false - | Some rg_id0, Some rg_id1 -> - if rg_id0 = rg_id1 then true - else - (* We need to use the regions hierarchy *) - let regions_hierarchy = - let id0 = - match id0 with - | FunId fun_id -> fun_id - | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id - in - LlbcAstUtils.FunIdMap.find id0 - ctx.fun_ctx.regions_hierarchies - in - (* Compute the set of ancestors of the function in call1 *) - let call1_ancestors = - LlbcAstUtils.list_ancestor_region_groups regions_hierarchy - rg_id1 - in - (* Check if the function used in call0 is inside *) - T.RegionGroupId.Set.mem rg_id0 call1_ancestors - in - (* If call1 is a child, then we need to check if the input arguments - * used in call0 are a prefix of the input arguments used in call1 - * (note call1 being a child, it will likely consume strictly more - * given back values). - * *) - if call1_is_child then - let call1_args = - Collections.List.prefix (List.length args0) args1 - in - let args = List.combine args0 call1_args in - (* Note that the input values are expressions, *which may contain - * meta-values* (which we need to ignore). *) - let input_eq (v0, v1) = - PureUtils.remove_meta v0 = PureUtils.remove_meta v1 - in - (* Compare the generics and the prefix of the input arguments *) - generics0 = generics1 && List.for_all input_eq args - else (* Not a child *) - false - else (* Not the same function *) - false - | _ -> false - in - - let visitor = - object (self) - inherit [_] reduce_expression - method zero _ = false - method plus b0 b1 _ = b0 () && b1 () - - method! visit_texpression env e = - match e.e with - | Var _ | CVar _ | Const _ -> fun _ -> false - | StructUpdate _ -> - (* There shouldn't be monadic calls in structure updates - also - note that by returning [false] we are conservative: we might - *prevent* possible optimisations (i.e., filtering some function - calls), which is sound. *) - fun _ -> false - | Let (_, _, re, e) -> ( - match opt_destruct_function_call re with - | None -> fun () -> self#visit_texpression env e () - | Some (func1, generics1, args1) -> - let call_is_child = check_call func1 generics1 args1 in - if call_is_child then fun () -> true - else fun () -> self#visit_texpression env e ()) - | Lambda (_, e) -> self#visit_texpression env e - | App _ -> ( - fun () -> - match opt_destruct_function_call e with - | Some (func1, tys1, args1) -> check_call func1 tys1 args1 - | None -> false) - | Qualif _ -> - (* Note that this case includes functions without arguments *) - fun () -> false - | Meta (_, e) -> self#visit_texpression env e - | Loop loop -> - (* We only visit the *function end* *) - self#visit_texpression env loop.fun_end - | Switch (_, body) -> self#visit_switch_body env body - - method! visit_switch_body env body = - match body with - | If (e1, e2) -> - fun () -> - self#visit_texpression env e1 () - && self#visit_texpression env e2 () - | Match branches -> - fun () -> - List.for_all - (fun br -> self#visit_texpression env br.branch ()) - branches - end - in - visitor#visit_texpression () e () - (** Filter the useless assignments (removes the useless variables, filters the function calls) *) -let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) - (def : fun_decl) : fun_decl = +let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl = (* We first need a transformation on *left-values*, which filters the useless * variables and tells us whether the value contains any variable which has * not been replaced by [_] (in which case we need to keep the assignment, @@ -1166,30 +1019,8 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) if not monadic then (* Not a monadic let-binding: simple case *) (e.e, fun _ -> used) - else - (* Monadic let-binding: trickier. - * We can filter if the right-expression is a function call, - * under some conditions. *) - match (filter_monadic_calls, opt_destruct_function_call re) with - | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) -> - (* If we split the forward/backward functions. - - We need to check if there is a child call - see - the comments for: - [expression_contains_child_call_in_all_paths] *) - if not !Config.return_back_funs then - let has_child_call = - expression_contains_child_call_in_all_paths ctx fid - lp_id rg_id tys args e - in - if has_child_call then (* Filter *) - (e.e, fun _ -> used) - else (* No child call: don't filter *) - dont_filter () - else dont_filter () - | _ -> - (* Not an LLBC function call or not allowed to filter: we can't filter *) - dont_filter () + else (* Monadic let-binding: can't filter *) + dont_filter () else (* There are used variables: don't filter *) dont_filter () | Loop loop -> @@ -1442,22 +1273,6 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let body = { body with body = body_exp } in { def with body = Some body } -(** Return [None] if the function is a backward function with no outputs (so - that we eliminate the definition which is useless). - - Note that the calls to such functions are filtered when translating from - symbolic to pure. Here, we remove the definitions altogether, because they - are now useless - *) -let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option = - if - !Config.filter_useless_functions - && Option.is_some def.back_id - && def.signature.output = mk_result_ty mk_unit_ty - || def.signature.output = mk_unit_ty - then None - else Some def - (** Retrieve the loop definitions from the function definition. {!SymbolicToPure} generates an AST in which the loop bodies are part of @@ -1530,14 +1345,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : info.num_inputs_with_fuel_no_state info.num_inputs_with_fuel_with_state in - let back_inputs = - if !Config.return_back_funs then [] - else - snd - (Collections.List.split_at fun_sig.inputs - info.num_inputs_with_fuel_with_state) - in - List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ] + List.concat [ fuel; fwd_inputs; fwd_state ] in let output = loop.output_ty in @@ -1618,7 +1426,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : kind = def.kind; num_loops; loop_id = Some loop.loop_id; - back_id = def.back_id; llbc_name = def.llbc_name; name = def.name; signature = loop_sig; @@ -1640,35 +1447,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) : let loops = List.map snd (LoopId.Map.bindings !loops) in (def, loops) -(** Return [false] if the forward function is useless and should be filtered. - - - a forward function with no output (comes from a Rust function with - unit return type) - - the function has mutable borrows as inputs (which is materialized - by the fact we generated backward functions which were not filtered). - - In such situation, every call to the Rust function will be translated to: - - a call to the forward function which returns nothing - - calls to the backward functions - As a failing backward function implies the forward function also fails, - we can filter the calls to the forward function, which thus becomes - useless. - In such situation, we can remove the forward function definition - altogether. - *) -let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = - (* The question of filtering the forward functions arises only if we split - the forward/backward functions *) - if !Config.return_back_funs then true - else if - (* Note that at this point, the output types are no longer seen as tuples: - * they should be lists of length 1. *) - !Config.filter_useless_functions - && fwd.f.signature.output = mk_result_ty mk_unit_ty - && backs <> [] - then false - else true - (** Convert the unit variables to [()] if they are used as right-values or [_] if they are used as left values in patterns. *) let unit_vars_to_unit (def : fun_decl) : fun_decl = @@ -1724,19 +1502,17 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = * could have: [box_new f x]) * *) match fun_id with - | Fun (FromLlbc (FunId (FAssumed aid), _lp_id, rg_id)) -> ( - match (aid, rg_id) with - | BoxNew, _ -> - assert (rg_id = None); + | Fun (FromLlbc (FunId (FAssumed aid), _lp_id)) -> ( + match aid with + | BoxNew -> let arg, args = Collections.List.pop args in mk_apps arg args - | BoxFree, _ -> + | BoxFree -> assert (args = []); mk_unit_rvalue - | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared - | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut - | ArrayRepeat ), - _ ) -> + | SliceIndexShared | SliceIndexMut | ArrayIndexShared + | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut + | ArrayRepeat -> super#visit_texpression env e) | _ -> super#visit_texpression env e) | _ -> super#visit_texpression env e @@ -1989,7 +1765,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Filter the useless variables, assignments, function calls, etc. *) - let def = filter_useless !Config.filter_useless_monadic_calls ctx def in + let def = filter_useless ctx def in log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); (* Simplify the lets immediately followed by a return. @@ -2130,16 +1906,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : *) let all_decls = List.concat - (List.concat - (List.concat - (List.map - (fun { fwd; backs; _ } -> - [ fwd.f :: fwd.loops ] - :: List.map - (fun { f = back; loops = loops_back } -> - [ back :: loops_back ]) - backs) - transl))) + (List.concat (List.map (fun { f; loops } -> [ f :: loops ]) transl)) in let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in @@ -2207,7 +1974,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) -> + | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id'))) -> if (fun_id', loop_id') = fun_id then ( (* For each argument, check if it is exactly the original input parameter. Note that there shouldn't be partial @@ -2357,8 +2124,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _))) - -> ( + | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id))) -> ( match FunLoopIdMap.find_opt (fun_id, loop_id) !used_map with @@ -2400,13 +2166,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in let transl = List.map - (fun trans -> - let filter_fun_and_loops f = - { f = filter_in_one f.f; loops = List.map filter_in_one f.loops } - in - let fwd = filter_fun_and_loops trans.fwd in - let backs = List.map filter_fun_and_loops trans.backs in - { trans with fwd; backs }) + (fun f -> + { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }) transl in @@ -2420,18 +2181,11 @@ let filter_loop_inputs (transl : pure_fun_translation list) : it thus returns the pair: (function def, loop defs). See {!decompose_loops} for more information. - Will return [None] if the function is a backward function with no outputs. - [ctx]: used only for printing. *) -let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - fun_and_loops option = +let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops = (* Debug *) - log#ldebug - (lazy - ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " (" - ^ Print.option_to_string T.RegionGroupId.to_string def.back_id - ^ ")")); + log#ldebug (lazy ("PureMicroPasses.apply_passes_to_def: " ^ def.name)); log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); @@ -2451,29 +2205,13 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let def = remove_meta def in log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - (* Remove the backward functions with no outputs. + (* Extract the loop definitions by removing the {!Loop} node *) + let def, loops = decompose_loops ctx def in - Note that the *calls* to those functions should already have been removed, - when translating from symbolic to pure. Here, we remove the definitions - altogether, because they are now useless *) - let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in - let opt_def = filter_if_backward_with_no_outputs def in - - match opt_def with - | None -> - log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n")); - None - | Some def -> - log#ldebug - (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n")); - - (* Extract the loop definitions by removing the {!Loop} node *) - let def, loops = decompose_loops ctx def in - - (* Apply the remaining passes *) - let f = apply_end_passes_to_def ctx def in - let loops = List.map (apply_end_passes_to_def ctx) loops in - Some { f; loops } + (* Apply the remaining passes *) + let f = apply_end_passes_to_def ctx def in + let loops = List.map (apply_end_passes_to_def ctx) loops in + { f; loops } (** Apply the micro-passes to a list of forward/backward translations. @@ -2489,18 +2227,11 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : but convenient. *) let apply_passes_to_pure_fun_translations (ctx : trans_ctx) - (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list = - let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation = - (* Apply the passes to the individual functions *) - let fwd, backs = trans in - let fwd = Option.get (apply_passes_to_def ctx fwd) in - let backs = List.filter_map (apply_passes_to_def ctx) backs in - (* Compute whether we need to filter the forward function or not *) - let keep_fwd = keep_forward fwd backs in - { keep_fwd; fwd; backs } - in - - let transl = List.map apply_to_one transl in + (transl : fun_decl list) : pure_fun_translation list = + (* Apply the micro-passes *) + let transl = List.map (apply_passes_to_def ctx) transl in - (* Filter the useless inputs in the loop functions *) + (* Filter the useless inputs in the loop functions (loops are initially + parameterized by *all* the symbolic values in the context, because + they may access any of them). *) filter_loop_inputs transl -- cgit v1.2.3