summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/PureMicroPasses.ml342
1 files changed, 328 insertions, 14 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 09cc2533..e670570b 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1756,24 +1756,338 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
let loops = List.map (apply_end_passes_to_def ctx) loops in
Some (def, loops)
-(** Return the forward/backward translations on which we applied the micro-passes.
+(** Small utility for {!filter_loop_inputs} *)
+let filter_prefix (keep : bool list) (ls : 'a list) : 'a list =
+ let ls0, ls1 = Collections.List.split_at ls (List.length keep) in
+ let ls0 =
+ List.filter_map
+ (fun (b, x) -> if b then Some x else None)
+ (List.combine keep ls0)
+ in
+ List.append ls0 ls1
+
+type fun_loop_id = A.fun_id * LoopId.id option [@@deriving show, ord]
+
+module FunLoopIdOrderedType = struct
+ type t = fun_loop_id
+
+ let compare = compare_fun_loop_id
+ let to_string = show_fun_loop_id
+ let pp_t = pp_fun_loop_id
+ let show_t = show_fun_loop_id
+end
+
+module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
+
+(** Filter the useless loop input parameters. *)
+let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
+ (bool * pure_fun_translation) list =
+ (* We need to explore groups of mutually recursive functions. In order
+ to compute which parameters are useless, we need to explore the
+ functions by groups of mutually recursive definitions.
+
+ Because every Rust function is translated to a list of functions (forward
+ function, backward functions, loop functions, etc.), and those functions
+ might depend on each others in different ways, we recompute the SCCs of
+ the whole module.
+
+ Rem.: we also redo this computation, on a smaller scale, in {!Translate}.
+ Maybe we can factor out the two.
+ *)
+ let all_decls =
+ List.concat
+ (List.concat
+ (List.concat
+ (List.map
+ (fun (_, ((fwd, loops_fwd), backs)) ->
+ [ fwd :: loops_fwd ]
+ :: List.map
+ (fun (back, loops_back) -> [ back :: loops_back ])
+ backs)
+ transl)))
+ in
+ let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in
+
+ (* Explore the subgroups one by one.
+
+ For now, we only filter the parameters of loop functions which are simply
+ recursive.
+
+ Rem.: there is a bit of redundancy in computing the useless parameters
+ for the loop forward *and* the loop backward functions.
+ *)
+ (* The [filtered] map: maps function identifiers to filtering information.
+
+ Note that we ignore the backward id:
+ - we filter the forward inputs only
+ - we want the filtering to be the same for the forward and the backward
+ functions
+ The reason is that for now we want to preserve the fact that a backward
+ function takes the same inputs as its associated forward function, with
+ additional parameters.
+ *)
+ let used_map = ref FunLoopIdMap.empty in
+ let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in
+
+ (* We start by computing the filtering information, for each function *)
+ let compute_one_filter_info (decl : fun_decl) =
+ (* There should be a body *)
+ let body = Option.get decl.body in
+ (* We only look at the forward inputs, without the state *)
+ let inputs_prefix, _ =
+ Collections.List.split_at body.inputs
+ decl.signature.info.num_fwd_inputs_with_fuel_no_state
+ in
+ let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in
+ let inputs_prefix_length = List.length inputs_prefix in
+ let inputs =
+ List.map
+ (fun v -> (var_get_id v, mk_texpression_from_var v))
+ inputs_prefix
+ in
+ let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in
+ assert (Option.is_some decl.loop_id);
+
+ let fun_id = (A.Regular decl.def_id, decl.loop_id) in
+
+ let set_used vid =
+ used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used
+ in
+
+ (* Set the fuel as used *)
+ let sg_info = decl.signature.info in
+ if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0));
+
+ let visitor =
+ object (self : 'self)
+ inherit [_] iter_expression as super
+
+ (** Override the expression visitor, to look for loop function calls *)
+ method! visit_texpression env e =
+ match e.e with
+ | App _ -> (
+ (* If this is an app: destruct all the arguments, and check if
+ the leftmost expression is the loop function call *)
+ let e_app, args = destruct_apps e in
+ match e_app.e with
+ | Qualif qualif -> (
+ match qualif.id with
+ | FunOrOp (Fun (FromLlbc fun_id')) ->
+ if fun_id_to_fun_loop_id fun_id' = fun_id then (
+ (* For each argument, check if it is exactly the original
+ input parameter. Note that there shouldn't be partial
+ applications of loop functions: the number of arguments
+ should be exactly the number of input parameters (i.e.,
+ we can use [combine])
+ *)
+ let beg_args, end_args =
+ Collections.List.split_at args inputs_prefix_length
+ in
+ let used_args = List.combine inputs beg_args in
+ List.iter
+ (fun ((vid, var), arg) ->
+ if var <> arg then (
+ self#visit_texpression env arg;
+ set_used vid))
+ used_args;
+ List.iter (self#visit_texpression env) end_args)
+ else super#visit_texpression env e
+ | _ -> super#visit_texpression env e)
+ | _ -> super#visit_texpression env e)
+ | _ -> super#visit_texpression env e
+
+ (** If we visit a variable which is actually an input parameter, we
+ set it as used. Note that we take care of ignoring some of those
+ input parameters given in [visit_texpression].
+ *)
+ method! visit_var_id _ id =
+ if VarId.Set.mem id inputs_set then set_used id
+ end
+ in
+ visitor#visit_texpression () body.body;
+
+ (* Save the filtering information, if there is anything to filter *)
+ if List.exists snd !used then
+ let used = List.map snd !used in
+ let used =
+ match FunLoopIdMap.find_opt fun_id !used_map with
+ | None -> used
+ | Some used0 ->
+ List.map (fun (b0, b1) -> b0 || b1) (List.combine used0 used)
+ in
+ used_map := FunLoopIdMap.add fun_id used !used_map
+ in
+ List.iter
+ (fun (_, fl) ->
+ match fl with
+ | [ f ] ->
+ (* Group made of one function: check if it is a loop. If it is the
+ case, explore it. *)
+ if Option.is_some f.loop_id then compute_one_filter_info f else ()
+ | _ ->
+ (* Group of mutually recursive functions: ignore for now *)
+ ())
+ subgroups;
+
+ (* We then apply the filtering to all the function definitions at once *)
+ let filter_in_one (decl : fun_decl) : fun_decl =
+ (* Filter the function signature *)
+ let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in
+ let decl =
+ match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with
+ | None -> (* Nothing to filter *) decl
+ | Some used_info ->
+ let num_filtered =
+ List.length (List.filter (fun b -> not b) used_info)
+ in
+ let { type_params; inputs; output; doutputs; info } =
+ decl.signature
+ in
+ let {
+ has_fuel;
+ num_fwd_inputs_with_fuel_no_state;
+ num_fwd_inputs_with_fuel_with_state;
+ num_back_inputs_no_state;
+ num_back_inputs_with_state;
+ effect_info;
+ } =
+ info
+ in
+
+ let inputs = filter_prefix used_info inputs in
+
+ let info =
+ {
+ has_fuel;
+ num_fwd_inputs_with_fuel_no_state =
+ num_fwd_inputs_with_fuel_no_state - num_filtered;
+ num_fwd_inputs_with_fuel_with_state =
+ num_fwd_inputs_with_fuel_with_state - num_filtered;
+ num_back_inputs_no_state;
+ num_back_inputs_with_state;
+ effect_info;
+ }
+ in
+ let signature = { type_params; inputs; output; doutputs; info } in
+
+ { decl with signature }
+ in
+
+ (* Filter the function body *)
+ let body =
+ match decl.body with
+ | None -> None
+ | Some body ->
+ (* Update the list of vars *)
+ let { inputs; inputs_lvs; body } = body in
+
+ let inputs, inputs_lvs =
+ match
+ FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map
+ with
+ | None -> (* Nothing to filter *) (inputs, inputs_lvs)
+ | Some used_info ->
+ let inputs = filter_prefix used_info inputs in
+ let inputs_lvs = filter_prefix used_info inputs_lvs in
+ (inputs, inputs_lvs)
+ in
+
+ (* Update the body expression *)
+ let visitor =
+ object (self)
+ inherit [_] map_expression as super
+
+ method! visit_texpression env e =
+ match e.e with
+ | App _ -> (
+ let e_app, args = destruct_apps e in
+ match e_app.e with
+ | Qualif qualif -> (
+ match qualif.id with
+ | FunOrOp (Fun (FromLlbc fun_id)) -> (
+ match
+ FunLoopIdMap.find_opt
+ (fun_id_to_fun_loop_id fun_id)
+ !used_map
+ with
+ | None -> super#visit_texpression env e
+ | Some used_info ->
+ (* Filter the types in the arrow type *)
+ let tys, ret_ty = destruct_arrows e_app.ty in
+ let tys = filter_prefix used_info tys in
+ let ty = mk_arrows tys ret_ty in
+ let e_app = { e_app with ty } in
+
+ (* Filter the arguments *)
+ let args = filter_prefix used_info args in
+
+ (* Explore the arguments *)
+ let args =
+ List.map (self#visit_texpression env) args
+ in
+
+ (* Rebuild *)
+ mk_apps e_app args)
+ | _ ->
+ let e_app = self#visit_texpression env e_app in
+ let args =
+ List.map (self#visit_texpression env) args
+ in
+ mk_apps e_app args)
+ | _ ->
+ let e_app = self#visit_texpression env e_app in
+ let args = List.map (self#visit_texpression env) args in
+ mk_apps e_app args)
+ | _ -> super#visit_texpression env e
+ end
+ in
+ let body = visitor#visit_texpression () body in
+ Some { inputs; inputs_lvs; body }
+ in
+ { decl with body }
+ in
+ let transl =
+ List.map
+ (fun (b, (fwd, backs)) ->
+ let filter_fun_and_loops (f, fl) =
+ (filter_in_one f, List.map filter_in_one fl)
+ in
+ let fwd = filter_fun_and_loops fwd in
+ let backs = List.map filter_fun_and_loops backs in
+ (b, (fwd, backs)))
+ transl
+ in
+
+ (* Return *)
+ transl
+
+(** Apply the micro-passes to a list of forward/backward translations.
This function also extracts the loop definitions from the function body
(see {!decompose_loops}).
- Also returns a boolean indicating whether the forward function should be kept
- or not (because useful/useless - [true] means we need to keep the forward
- function).
+ It also returns a boolean indicating whether the forward function should be kept
+ or not at extraction time ([true] means we need to keep the forward function).
+
Note that we don't "filter" the forward function and return a boolean instead,
because this function contains useful information to extract the backward
- functions: keeping it is not necessary but more convenient.
+ functions. Note that here, keeping the forward function it is not *necessary*
+ but convenient.
*)
-let apply_passes_to_pure_fun_translation (ctx : trans_ctx)
- (trans : fun_decl * fun_decl list) : bool * pure_fun_translation =
- (* Apply the passes to the individual functions *)
- let forward, backwards = trans in
- let forward = Option.get (apply_passes_to_def ctx forward) in
- let backwards = List.filter_map (apply_passes_to_def ctx) backwards in
- let trans = (forward, backwards) in
- (* Compute whether we need to filter the forward function or not *)
- (keep_forward trans, trans)
+let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
+ (transl : (fun_decl * fun_decl list) list) :
+ (bool * pure_fun_translation) list =
+ let apply_to_one (trans : fun_decl * fun_decl list) :
+ bool * pure_fun_translation =
+ (* Apply the passes to the individual functions *)
+ let forward, backwards = trans in
+ let forward = Option.get (apply_passes_to_def ctx forward) in
+ let backwards = List.filter_map (apply_passes_to_def ctx) backwards in
+ let trans = (forward, backwards) in
+ (* Compute whether we need to filter the forward function or not *)
+ (keep_forward trans, trans)
+ in
+ let transl = List.map apply_to_one transl in
+
+ (* Filter the useless inputs in the loop functions *)
+ filter_loop_inputs transl