From 1302f2830905dc63f294aad00d78d03486e13d73 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sun, 8 Jan 2023 09:42:33 +0100 Subject: Implement a pass to filter the unused input arguments in the loop functions --- compiler/PureMicroPasses.ml | 342 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 328 insertions(+), 14 deletions(-) (limited to 'compiler/PureMicroPasses.ml') diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 09cc2533..e670570b 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1756,24 +1756,338 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let loops = List.map (apply_end_passes_to_def ctx) loops in Some (def, loops) -(** Return the forward/backward translations on which we applied the micro-passes. +(** Small utility for {!filter_loop_inputs} *) +let filter_prefix (keep : bool list) (ls : 'a list) : 'a list = + let ls0, ls1 = Collections.List.split_at ls (List.length keep) in + let ls0 = + List.filter_map + (fun (b, x) -> if b then Some x else None) + (List.combine keep ls0) + in + List.append ls0 ls1 + +type fun_loop_id = A.fun_id * LoopId.id option [@@deriving show, ord] + +module FunLoopIdOrderedType = struct + type t = fun_loop_id + + let compare = compare_fun_loop_id + let to_string = show_fun_loop_id + let pp_t = pp_fun_loop_id + let show_t = show_fun_loop_id +end + +module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) + +(** Filter the useless loop input parameters. *) +let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : + (bool * pure_fun_translation) list = + (* We need to explore groups of mutually recursive functions. In order + to compute which parameters are useless, we need to explore the + functions by groups of mutually recursive definitions. + + Because every Rust function is translated to a list of functions (forward + function, backward functions, loop functions, etc.), and those functions + might depend on each others in different ways, we recompute the SCCs of + the whole module. + + Rem.: we also redo this computation, on a smaller scale, in {!Translate}. + Maybe we can factor out the two. + *) + let all_decls = + List.concat + (List.concat + (List.concat + (List.map + (fun (_, ((fwd, loops_fwd), backs)) -> + [ fwd :: loops_fwd ] + :: List.map + (fun (back, loops_back) -> [ back :: loops_back ]) + backs) + transl))) + in + let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in + + (* Explore the subgroups one by one. + + For now, we only filter the parameters of loop functions which are simply + recursive. + + Rem.: there is a bit of redundancy in computing the useless parameters + for the loop forward *and* the loop backward functions. + *) + (* The [filtered] map: maps function identifiers to filtering information. + + Note that we ignore the backward id: + - we filter the forward inputs only + - we want the filtering to be the same for the forward and the backward + functions + The reason is that for now we want to preserve the fact that a backward + function takes the same inputs as its associated forward function, with + additional parameters. + *) + let used_map = ref FunLoopIdMap.empty in + let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in + + (* We start by computing the filtering information, for each function *) + let compute_one_filter_info (decl : fun_decl) = + (* There should be a body *) + let body = Option.get decl.body in + (* We only look at the forward inputs, without the state *) + let inputs_prefix, _ = + Collections.List.split_at body.inputs + decl.signature.info.num_fwd_inputs_with_fuel_no_state + in + let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in + let inputs_prefix_length = List.length inputs_prefix in + let inputs = + List.map + (fun v -> (var_get_id v, mk_texpression_from_var v)) + inputs_prefix + in + let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in + assert (Option.is_some decl.loop_id); + + let fun_id = (A.Regular decl.def_id, decl.loop_id) in + + let set_used vid = + used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used + in + + (* Set the fuel as used *) + let sg_info = decl.signature.info in + if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0)); + + let visitor = + object (self : 'self) + inherit [_] iter_expression as super + + (** Override the expression visitor, to look for loop function calls *) + method! visit_texpression env e = + match e.e with + | App _ -> ( + (* If this is an app: destruct all the arguments, and check if + the leftmost expression is the loop function call *) + let e_app, args = destruct_apps e in + match e_app.e with + | Qualif qualif -> ( + match qualif.id with + | FunOrOp (Fun (FromLlbc fun_id')) -> + if fun_id_to_fun_loop_id fun_id' = fun_id then ( + (* For each argument, check if it is exactly the original + input parameter. Note that there shouldn't be partial + applications of loop functions: the number of arguments + should be exactly the number of input parameters (i.e., + we can use [combine]) + *) + let beg_args, end_args = + Collections.List.split_at args inputs_prefix_length + in + let used_args = List.combine inputs beg_args in + List.iter + (fun ((vid, var), arg) -> + if var <> arg then ( + self#visit_texpression env arg; + set_used vid)) + used_args; + List.iter (self#visit_texpression env) end_args) + else super#visit_texpression env e + | _ -> super#visit_texpression env e) + | _ -> super#visit_texpression env e) + | _ -> super#visit_texpression env e + + (** If we visit a variable which is actually an input parameter, we + set it as used. Note that we take care of ignoring some of those + input parameters given in [visit_texpression]. + *) + method! visit_var_id _ id = + if VarId.Set.mem id inputs_set then set_used id + end + in + visitor#visit_texpression () body.body; + + (* Save the filtering information, if there is anything to filter *) + if List.exists snd !used then + let used = List.map snd !used in + let used = + match FunLoopIdMap.find_opt fun_id !used_map with + | None -> used + | Some used0 -> + List.map (fun (b0, b1) -> b0 || b1) (List.combine used0 used) + in + used_map := FunLoopIdMap.add fun_id used !used_map + in + List.iter + (fun (_, fl) -> + match fl with + | [ f ] -> + (* Group made of one function: check if it is a loop. If it is the + case, explore it. *) + if Option.is_some f.loop_id then compute_one_filter_info f else () + | _ -> + (* Group of mutually recursive functions: ignore for now *) + ()) + subgroups; + + (* We then apply the filtering to all the function definitions at once *) + let filter_in_one (decl : fun_decl) : fun_decl = + (* Filter the function signature *) + let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in + let decl = + match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with + | None -> (* Nothing to filter *) decl + | Some used_info -> + let num_filtered = + List.length (List.filter (fun b -> not b) used_info) + in + let { type_params; inputs; output; doutputs; info } = + decl.signature + in + let { + has_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state; + num_back_inputs_no_state; + num_back_inputs_with_state; + effect_info; + } = + info + in + + let inputs = filter_prefix used_info inputs in + + let info = + { + has_fuel; + num_fwd_inputs_with_fuel_no_state = + num_fwd_inputs_with_fuel_no_state - num_filtered; + num_fwd_inputs_with_fuel_with_state = + num_fwd_inputs_with_fuel_with_state - num_filtered; + num_back_inputs_no_state; + num_back_inputs_with_state; + effect_info; + } + in + let signature = { type_params; inputs; output; doutputs; info } in + + { decl with signature } + in + + (* Filter the function body *) + let body = + match decl.body with + | None -> None + | Some body -> + (* Update the list of vars *) + let { inputs; inputs_lvs; body } = body in + + let inputs, inputs_lvs = + match + FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map + with + | None -> (* Nothing to filter *) (inputs, inputs_lvs) + | Some used_info -> + let inputs = filter_prefix used_info inputs in + let inputs_lvs = filter_prefix used_info inputs_lvs in + (inputs, inputs_lvs) + in + + (* Update the body expression *) + let visitor = + object (self) + inherit [_] map_expression as super + + method! visit_texpression env e = + match e.e with + | App _ -> ( + let e_app, args = destruct_apps e in + match e_app.e with + | Qualif qualif -> ( + match qualif.id with + | FunOrOp (Fun (FromLlbc fun_id)) -> ( + match + FunLoopIdMap.find_opt + (fun_id_to_fun_loop_id fun_id) + !used_map + with + | None -> super#visit_texpression env e + | Some used_info -> + (* Filter the types in the arrow type *) + let tys, ret_ty = destruct_arrows e_app.ty in + let tys = filter_prefix used_info tys in + let ty = mk_arrows tys ret_ty in + let e_app = { e_app with ty } in + + (* Filter the arguments *) + let args = filter_prefix used_info args in + + (* Explore the arguments *) + let args = + List.map (self#visit_texpression env) args + in + + (* Rebuild *) + mk_apps e_app args) + | _ -> + let e_app = self#visit_texpression env e_app in + let args = + List.map (self#visit_texpression env) args + in + mk_apps e_app args) + | _ -> + let e_app = self#visit_texpression env e_app in + let args = List.map (self#visit_texpression env) args in + mk_apps e_app args) + | _ -> super#visit_texpression env e + end + in + let body = visitor#visit_texpression () body in + Some { inputs; inputs_lvs; body } + in + { decl with body } + in + let transl = + List.map + (fun (b, (fwd, backs)) -> + let filter_fun_and_loops (f, fl) = + (filter_in_one f, List.map filter_in_one fl) + in + let fwd = filter_fun_and_loops fwd in + let backs = List.map filter_fun_and_loops backs in + (b, (fwd, backs))) + transl + in + + (* Return *) + transl + +(** Apply the micro-passes to a list of forward/backward translations. This function also extracts the loop definitions from the function body (see {!decompose_loops}). - Also returns a boolean indicating whether the forward function should be kept - or not (because useful/useless - [true] means we need to keep the forward - function). + It also returns a boolean indicating whether the forward function should be kept + or not at extraction time ([true] means we need to keep the forward function). + Note that we don't "filter" the forward function and return a boolean instead, because this function contains useful information to extract the backward - functions: keeping it is not necessary but more convenient. + functions. Note that here, keeping the forward function it is not *necessary* + but convenient. *) -let apply_passes_to_pure_fun_translation (ctx : trans_ctx) - (trans : fun_decl * fun_decl list) : bool * pure_fun_translation = - (* Apply the passes to the individual functions *) - let forward, backwards = trans in - let forward = Option.get (apply_passes_to_def ctx forward) in - let backwards = List.filter_map (apply_passes_to_def ctx) backwards in - let trans = (forward, backwards) in - (* Compute whether we need to filter the forward function or not *) - (keep_forward trans, trans) +let apply_passes_to_pure_fun_translations (ctx : trans_ctx) + (transl : (fun_decl * fun_decl list) list) : + (bool * pure_fun_translation) list = + let apply_to_one (trans : fun_decl * fun_decl list) : + bool * pure_fun_translation = + (* Apply the passes to the individual functions *) + let forward, backwards = trans in + let forward = Option.get (apply_passes_to_def ctx forward) in + let backwards = List.filter_map (apply_passes_to_def ctx) backwards in + let trans = (forward, backwards) in + (* Compute whether we need to filter the forward function or not *) + (keep_forward trans, trans) + in + let transl = List.map apply_to_one transl in + + (* Filter the useless inputs in the loop functions *) + filter_loop_inputs transl -- cgit v1.2.3