summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
authorSon Ho2023-11-29 14:26:04 +0100
committerSon Ho2023-11-29 14:26:04 +0100
commit0273fee7f6b74da1d3b66c3c6a2158c012d04197 (patch)
tree5f6db32814f6f0b3a98f2de1db39225ff2c7645d /compiler/PureMicroPasses.ml
parentf4e2c2bb09d9d7b54afc0692b7f690f5ec2eb029 (diff)
parent90e42e0e1c1889aabfa66283fb15b43a5852a02a (diff)
Merge branch 'main' into afromher_shifts
Diffstat (limited to 'compiler/PureMicroPasses.ml')
-rw-r--r--compiler/PureMicroPasses.ml272
1 files changed, 120 insertions, 152 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index b6025df4..d0741b29 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -3,10 +3,13 @@
open Pure
open PureUtils
open TranslateCore
-module V = Values
(** The local logger *)
-let log = L.pure_micro_passes_log
+let log = Logging.pure_micro_passes_log
+
+let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string =
+ let fmt = trans_ctx_to_pure_fmt_env ctx in
+ PrintPure.fun_decl_to_string fmt def
(** Small utility.
@@ -376,8 +379,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ty = e.ty in
let ctx, e =
match e.e with
- | Var _ -> (* Nothing to do *) (ctx, e.e)
- | Const _ -> (* Nothing to do *) (ctx, e.e)
+ | Var _ | CVar _ | Const _ -> (* Nothing to do *) (ctx, e.e)
| App (app, arg) ->
let ctx, app = update_texpression app ctx in
let ctx, arg = update_texpression arg ctx in
@@ -389,7 +391,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
| Switch (scrut, body) -> update_switch_body scrut body ctx
| Loop loop -> update_loop loop ctx
| StructUpdate supd -> update_struct_update supd ctx
- | Meta (meta, e) -> update_meta meta e ctx
+ | Meta (meta, e) -> update_emeta meta e ctx
in
(ctx, { e; ty })
(* *)
@@ -447,6 +449,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let {
fun_end;
loop_id;
+ meta;
fuel0;
fuel;
input_state;
@@ -465,6 +468,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
{
fun_end;
loop_id;
+ meta;
fuel0;
fuel;
input_state;
@@ -488,7 +492,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let supd = { struct_id; init; updates } in
(ctx, StructUpdate supd)
(* *)
- and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) :
+ and update_emeta (meta : emeta) (e : texpression) (ctx : pn_ctx) :
pn_ctx * expression =
let ctx =
match meta with
@@ -514,7 +518,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
| Tag _ -> ctx
in
let ctx, e = update_texpression e ctx in
- let e = mk_meta meta e in
+ let e = mk_emeta meta e in
(ctx, e.e)
in
@@ -583,14 +587,11 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
match app.e with
| Qualif
{
- id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args = _;
- const_generic_args = _;
+ id = AdtCons { adt_id = TAdtId adt_id; variant_id = None };
+ generics = _;
} ->
(* Lookup the def *)
- let decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
- in
+ let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
(* Check that there are as many arguments as there are fields - note
that the def should have a body (otherwise we couldn't use the
constructor) *)
@@ -599,11 +600,10 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* Check if the definition is recursive *)
let is_rec =
match
- TypeDeclId.Map.find adt_id
- ctx.type_context.type_decls_groups
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups
with
- | NonRec _ -> false
- | Rec _ -> true
+ | NonRecGroup _ -> false
+ | RecGroup _ -> true
in
(* Convert, if possible - note that for now for Lean and Coq
we don't support the structure syntax on recursive structures *)
@@ -611,7 +611,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(!Config.backend <> Lean && !Config.backend <> Coq)
|| not is_rec
then
- let struct_id = AdtId adt_id in
+ let struct_id = TAdtId adt_id in
let init = None in
let updates =
FieldId.mapi
@@ -682,8 +682,8 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
| _ -> false
in
(* And either:
- * 2.1 the right-expression is a variable or a global *)
- let var_or_global = is_var re || is_global re in
+ * 2.1 the right-expression is a variable, a global or a const generic var *)
+ let var_or_global = is_var re || is_cvar re || is_global re in
(* Or:
* 2.2 the right-expression is a constant value, an ADT value,
* a projection or a primitive function call *and* the flag
@@ -767,10 +767,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : A.fun_id) (lp_id0 : LoopId.id option)
- (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
+ (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
+ (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
(args0 : texpression list) (e : texpression) : bool =
- let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list)
+ let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
(args1 : texpression list) : bool =
(* Check the fun_ids, to see if call1's function is a child of call0's function *)
match fun_id1 with
@@ -791,13 +791,19 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
if rg_id0 = rg_id1 then true
else
(* We need to use the regions hierarchy *)
- (* First, lookup the signature of the LLBC function *)
- let sg =
- LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls
+ let regions_hierarchy =
+ let id0 =
+ match id0 with
+ | FunId fun_id -> fun_id
+ | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id
+ in
+ LlbcAstUtils.FunIdMap.find id0
+ ctx.fun_ctx.regions_hierarchies
in
(* Compute the set of ancestors of the function in call1 *)
let call1_ancestors =
- LlbcAstUtils.list_ancestor_region_groups sg rg_id1
+ LlbcAstUtils.list_ancestor_region_groups regions_hierarchy
+ rg_id1
in
(* Check if the function used in call0 is inside *)
T.RegionGroupId.Set.mem rg_id0 call1_ancestors
@@ -817,8 +823,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
let input_eq (v0, v1) =
PureUtils.remove_meta v0 = PureUtils.remove_meta v1
in
- (* Compare the input types and the prefix of the input arguments *)
- tys0 = tys1 && List.for_all input_eq args
+ (* Compare the generics and the prefix of the input arguments *)
+ generics0 = generics1 && List.for_all input_eq args
else (* Not a child *)
false
else (* Not the same function *)
@@ -834,7 +840,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
method! visit_texpression env e =
match e.e with
- | Var _ | Const _ -> fun _ -> false
+ | Var _ | CVar _ | Const _ -> fun _ -> false
| StructUpdate _ ->
(* There shouldn't be monadic calls in structure updates - also
note that by returning [false] we are conservative: we might
@@ -844,8 +850,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
| Let (_, _, re, e) -> (
match opt_destruct_function_call re with
| None -> fun () -> self#visit_texpression env e ()
- | Some (func1, tys1, args1) ->
- let call_is_child = check_call func1 tys1 args1 in
+ | Some (func1, generics1, args1) ->
+ let call_is_child = check_call func1 generics1 args1 in
if call_is_child then fun () -> true
else fun () -> self#visit_texpression env e ())
| App _ -> (
@@ -930,7 +936,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
method! visit_expression env e =
match e with
- | Var _ | Const _ | App _ | Qualif _
+ | Var _ | CVar _ | Const _ | App _ | Qualif _
| Switch (_, _)
| Meta (_, _)
| StructUpdate _ | Abs _ ->
@@ -1085,14 +1091,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
match app.e with
| Qualif
{
- id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args;
- const_generic_args;
+ id = AdtCons { adt_id = TAdtId adt_id; variant_id = None };
+ generics;
} ->
(* This is a struct *)
(* Retrieve the definiton, to find how many fields there are *)
let adt_decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls
in
let fields =
match adt_decl.kind with
@@ -1108,24 +1113,22 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
* [x.field] for some variable [x], and where the projection
* is for the proper ADT *)
let to_var_proj (i : int) (arg : texpression) :
- (ty list * const_generic list * var_id) option =
+ (generic_args * var_id) option =
match arg.e with
| App (proj, x) -> (
match (proj.e, x.e) with
| ( Qualif
{
id =
- Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = proj_type_args;
- const_generic_args = proj_const_generic_args;
+ Proj { adt_id = TAdtId proj_adt_id; field_id };
+ generics = proj_generics;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
proj_adt_id = adt_id
&& FieldId.to_int field_id = i
- then
- Some (proj_type_args, proj_const_generic_args, v)
+ then Some (proj_generics, v)
else None
| _ -> None)
| _ -> None
@@ -1136,14 +1139,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
if List.length args = num_fields then
(* Check that this is the same variable we project from -
* note that we checked above that there is at least one field *)
- let (_, _, x), end_args = Collections.List.pop args in
- if List.for_all (fun (_, _, y) -> y = x) end_args then (
+ let (_, x), end_args = Collections.List.pop args in
+ if List.for_all (fun (_, y) -> y = x) end_args then (
(* We can substitute *)
(* Sanity check: all types correct *)
assert (
List.for_all
- (fun (tys, cgs, _) ->
- tys = type_args && cgs = const_generic_args)
+ (fun (generics1, _) -> generics1 = generics)
args);
{ e with e = Var x })
else super#visit_texpression env e
@@ -1161,14 +1163,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
match (proj.e, x.e) with
| ( Qualif
{
- id = Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = _;
- const_generic_args = _;
+ id = Proj { adt_id = TAdtId proj_adt_id; field_id };
+ generics = _;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
- AdtId proj_adt_id = struct_id
+ TAdtId proj_adt_id = struct_id
&& field_id = fid && x.ty = adt_ty
then Some v
else None
@@ -1251,6 +1252,7 @@ let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
!Config.filter_useless_functions
&& Option.is_some def.back_id
&& def.signature.output = mk_result_ty mk_unit_ty
+ || def.signature.output = mk_unit_ty
then None
else Some def
@@ -1361,8 +1363,9 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_sig =
{
- type_params = fun_sig.type_params;
- const_generic_params = fun_sig.const_generic_params;
+ generics = fun_sig.generics;
+ llbc_generics = fun_sig.llbc_generics;
+ preds = fun_sig.preds;
inputs = inputs_tys;
output;
doutputs;
@@ -1424,13 +1427,17 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_body = { inputs; inputs_lvs; body = loop_body } in
- let loop_def =
+ let loop_def : fun_decl =
{
def_id = def.def_id;
+ is_local = def.is_local;
+ meta = loop.meta;
+ kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
back_id = def.back_id;
- basename = def.basename;
+ llbc_name = def.llbc_name;
+ name = def.name;
signature = loop_sig;
is_global_decl_body = def.is_global_decl_body;
body = Some loop_body;
@@ -1466,13 +1473,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
In such situation, we can remove the forward function definition
altogether.
*)
-let keep_forward (trans : pure_fun_translation) : bool =
- let (fwd, _), backs = trans in
+let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
(* Note that at this point, the output types are no longer seen as tuples:
* they should be lists of length 1. *)
if
!Config.filter_useless_functions
- && fwd.signature.output = mk_result_ty mk_unit_ty
+ && fwd.f.signature.output = mk_result_ty mk_unit_ty
&& backs <> []
then false
else true
@@ -1508,8 +1514,8 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl =
let body = Some { body with body = body_exp; inputs_lvs } in
{ def with body }
-(** Eliminate the box functions like [Box::new], [Box::deref], etc. Most of them
- are translated to identity, and [Box::free] is translated to [()].
+(** Eliminate the box functions like [Box::new] (which is translated to the
+ identity) and [Box::free] (which is translated to [()]).
Note that the box types have already been eliminated during the translation
from symbolic to pure.
@@ -1527,48 +1533,23 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
method! visit_texpression env e =
match opt_destruct_function_call e with
| Some (fun_id, _tys, args) -> (
+ (* Below, when dealing with the arguments: we consider the very
+ * general case, where functions could be boxed (meaning we
+ * could have: [box_new f x])
+ * *)
match fun_id with
- | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> (
- (* Below, when dealing with the arguments: we consider the very
- * general case, where functions could be boxed (meaning we
- * could have: [box_new f x])
- * *)
+ | Fun (FromLlbc (FunId (FAssumed aid), _lp_id, rg_id)) -> (
match (aid, rg_id) with
- | A.BoxNew, _ ->
+ | BoxNew, _ ->
assert (rg_id = None);
let arg, args = Collections.List.pop args in
mk_apps arg args
- | A.BoxDeref, None ->
- (* [Box::deref] forward is the identity *)
- let arg, args = Collections.List.pop args in
- mk_apps arg args
- | A.BoxDeref, Some _ ->
- (* [Box::deref] backward is [()] (doesn't give back anything) *)
- assert (args = []);
- mk_unit_rvalue
- | A.BoxDerefMut, None ->
- (* [Box::deref_mut] forward is the identity *)
- let arg, args = Collections.List.pop args in
- mk_apps arg args
- | A.BoxDerefMut, Some _ ->
- (* [Box::deref_mut] back is almost the identity:
- * let box_deref_mut (x_init : t) (x_back : t) : t = x_back
- * *)
- let arg, args =
- match args with
- | _ :: given_back :: args -> (given_back, args)
- | _ -> raise (Failure "Unreachable")
- in
- mk_apps arg args
- | A.BoxFree, _ ->
+ | BoxFree, _ ->
assert (args = []);
mk_unit_rvalue
- | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen
- | VecIndex | VecIndexMut | ArraySubsliceShared
- | ArraySubsliceMut | SliceIndexShared | SliceIndexMut
- | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared
+ | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared
| ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
- | SliceLen ),
+ | ArrayRepeat ),
_ ) ->
super#visit_texpression env e)
| _ -> super#visit_texpression env e)
@@ -1914,13 +1895,11 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
[ctx]: used only for printing.
*)
let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- (fun_decl * fun_decl list) option =
+ fun_and_loops option =
(* Debug *)
log#ldebug
(lazy
- ("PureMicroPasses.apply_passes_to_def: "
- ^ Print.fun_name_to_string def.basename
- ^ " ("
+ ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " ("
^ Print.option_to_string T.RegionGroupId.to_string def.back_id
^ ")"));
@@ -1946,18 +1925,24 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
* Note that the calls to those functions should already have been removed,
* when translating from symbolic to pure. Here, we remove the definitions
* altogether, because they are now useless *)
- let def = filter_if_backward_with_no_outputs def in
+ let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in
+ let opt_def = filter_if_backward_with_no_outputs def in
- match def with
- | None -> None
+ match opt_def with
+ | None ->
+ log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n"));
+ None
| Some def ->
+ log#ldebug
+ (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n"));
+
(* Extract the loop definitions by removing the {!Loop} node *)
let def, loops = decompose_loops def in
(* Apply the remaining passes *)
- let def = apply_end_passes_to_def ctx def in
+ let f = apply_end_passes_to_def ctx def in
let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some (def, loops)
+ Some { f; loops }
(** Small utility for {!filter_loop_inputs} *)
let filter_prefix (keep : bool list) (ls : 'a list) : 'a list =
@@ -1983,8 +1968,8 @@ end
module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
(** Filter the useless loop input parameters. *)
-let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
- (bool * pure_fun_translation) list =
+let filter_loop_inputs (transl : pure_fun_translation list) :
+ pure_fun_translation list =
(* We need to explore groups of mutually recursive functions. In order
to compute which parameters are useless, we need to explore the
functions by groups of mutually recursive definitions.
@@ -2002,10 +1987,11 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(List.concat
(List.concat
(List.map
- (fun (_, ((fwd, loops_fwd), backs)) ->
- [ fwd :: loops_fwd ]
+ (fun { fwd; backs; _ } ->
+ [ fwd.f :: fwd.loops ]
:: List.map
- (fun (back, loops_back) -> [ back :: loops_back ])
+ (fun { f = back; loops = loops_back } ->
+ [ back :: loops_back ])
backs)
transl)))
in
@@ -2030,7 +2016,6 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
additional parameters.
*)
let used_map = ref FunLoopIdMap.empty in
- let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in
(* We start by computing the filtering information, for each function *)
let compute_one_filter_info (decl : fun_decl) =
@@ -2051,7 +2036,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in
assert (Option.is_some decl.loop_id);
- let fun_id = (A.Regular decl.def_id, decl.loop_id) in
+ let fun_id = (E.FRegular decl.def_id, decl.loop_id) in
let set_used vid =
used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used
@@ -2075,8 +2060,8 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id')) ->
- if fun_id_to_fun_loop_id fun_id' = fun_id then (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
applications of loop functions: the number of arguments
@@ -2135,22 +2120,16 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(* We then apply the filtering to all the function definitions at once *)
let filter_in_one (decl : fun_decl) : fun_decl =
(* Filter the function signature *)
- let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in
+ let fun_id = (E.FRegular decl.def_id, decl.loop_id) in
let decl =
- match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) decl
| Some used_info ->
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- } =
+ let { generics; llbc_generics; preds; inputs; output; doutputs; info }
+ =
decl.signature
in
let {
@@ -2179,14 +2158,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
}
in
let signature =
- {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- }
+ { generics; llbc_generics; preds; inputs; output; doutputs; info }
in
{ decl with signature }
@@ -2201,9 +2173,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let { inputs; inputs_lvs; body } = body in
let inputs, inputs_lvs =
- match
- FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map
- with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) (inputs, inputs_lvs)
| Some used_info ->
let inputs = filter_prefix used_info inputs in
@@ -2223,11 +2193,10 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id)) -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
+ -> (
match
- FunLoopIdMap.find_opt
- (fun_id_to_fun_loop_id fun_id)
- !used_map
+ FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
| None -> super#visit_texpression env e
| Some used_info ->
@@ -2267,13 +2236,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
in
let transl =
List.map
- (fun (b, (fwd, backs)) ->
- let filter_fun_and_loops (f, fl) =
- (filter_in_one f, List.map filter_in_one fl)
+ (fun trans ->
+ let filter_fun_and_loops f =
+ { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
in
- let fwd = filter_fun_and_loops fwd in
- let backs = List.map filter_fun_and_loops backs in
- (b, (fwd, backs)))
+ let fwd = filter_fun_and_loops trans.fwd in
+ let backs = List.map filter_fun_and_loops trans.backs in
+ { trans with fwd; backs })
transl
in
@@ -2294,18 +2263,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) :
- (bool * pure_fun_translation) list =
- let apply_to_one (trans : fun_decl * fun_decl list) :
- bool * pure_fun_translation =
+ (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
+ let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
(* Apply the passes to the individual functions *)
- let forward, backwards = trans in
- let forward = Option.get (apply_passes_to_def ctx forward) in
- let backwards = List.filter_map (apply_passes_to_def ctx) backwards in
- let trans = (forward, backwards) in
+ let fwd, backs = trans in
+ let fwd = Option.get (apply_passes_to_def ctx fwd) in
+ let backs = List.filter_map (apply_passes_to_def ctx) backs in
(* Compute whether we need to filter the forward function or not *)
- (keep_forward trans, trans)
+ let keep_fwd = keep_forward fwd backs in
+ { keep_fwd; fwd; backs }
in
+
let transl = List.map apply_to_one transl in
(* Filter the useless inputs in the loop functions *)