summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
authorSon Ho2024-03-08 16:06:35 +0100
committerSon Ho2024-03-08 16:06:35 +0100
commitfe2a2cb34148e46e32cdcfbf100e38d9986082cd (patch)
treea378134d495985718b842a786bad573695974b95 /compiler/PureMicroPasses.ml
parentbc154dda94c44b3ae67a3b04d3866cc473aead32 (diff)
Make progress on propagating the changes
Diffstat (limited to 'compiler/PureMicroPasses.ml')
-rw-r--r--compiler/PureMicroPasses.ml331
1 files changed, 31 insertions, 300 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index ec64df21..04bc90d7 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -925,156 +925,9 @@ let inline_useless_var_reassignments (ctx : trans_ctx) ~(inline_named : bool)
in
{ def with body = Some body }
-(** For the cases where we split the forward/backward functions.
-
- Given a forward or backward function call, is there, for every execution
- path, a child backward function called later with exactly the same input
- list prefix. We use this to filter useless function calls: if there are
- such child calls, we can remove this one (in case its outputs are not
- used).
- We do this check because we can't simply remove function calls whose
- outputs are not used, as they might fail. However, if a function fails,
- its children backward functions then fail on the same inputs (ignoring
- the additional inputs those receive).
-
- For instance, if we have:
- {[
- fn f<'a>(x : &'a mut T);
- ]}
-
- We often have things like this in the synthesized code:
- {[
- _ <-- f@fwd x;
- ...
- nx <-- f@back'a x y;
- ...
- ]}
-
- If [f@back'a x y] fails, then necessarily [f@fwd x] also fails.
- In this situation, we can remove the call [f@fwd x].
- *)
-let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
- (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
- (args0 : texpression list) (e : texpression) : bool =
- let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
- (args1 : texpression list) : bool =
- (* Check the fun_ids, to see if call1's function is a child of call0's function *)
- match fun_id1 with
- | Fun (FromLlbc (id1, lp_id1, rg_id1)) ->
- (* Both are "regular" calls: check if they come from the same rust function *)
- if id0 = id1 && lp_id0 = lp_id1 then
- (* Same rust functions: check the regions hierarchy *)
- let call1_is_child =
- match (rg_id0, rg_id1) with
- | None, _ ->
- (* The function used in call0 is the forward function: the one
- * used in call1 is necessarily a child *)
- true
- | Some _, None ->
- (* Opposite of previous case *)
- false
- | Some rg_id0, Some rg_id1 ->
- if rg_id0 = rg_id1 then true
- else
- (* We need to use the regions hierarchy *)
- let regions_hierarchy =
- let id0 =
- match id0 with
- | FunId fun_id -> fun_id
- | TraitMethod (_, _, fun_decl_id) -> FRegular fun_decl_id
- in
- LlbcAstUtils.FunIdMap.find id0
- ctx.fun_ctx.regions_hierarchies
- in
- (* Compute the set of ancestors of the function in call1 *)
- let call1_ancestors =
- LlbcAstUtils.list_ancestor_region_groups regions_hierarchy
- rg_id1
- in
- (* Check if the function used in call0 is inside *)
- T.RegionGroupId.Set.mem rg_id0 call1_ancestors
- in
- (* If call1 is a child, then we need to check if the input arguments
- * used in call0 are a prefix of the input arguments used in call1
- * (note call1 being a child, it will likely consume strictly more
- * given back values).
- * *)
- if call1_is_child then
- let call1_args =
- Collections.List.prefix (List.length args0) args1
- in
- let args = List.combine args0 call1_args in
- (* Note that the input values are expressions, *which may contain
- * meta-values* (which we need to ignore). *)
- let input_eq (v0, v1) =
- PureUtils.remove_meta v0 = PureUtils.remove_meta v1
- in
- (* Compare the generics and the prefix of the input arguments *)
- generics0 = generics1 && List.for_all input_eq args
- else (* Not a child *)
- false
- else (* Not the same function *)
- false
- | _ -> false
- in
-
- let visitor =
- object (self)
- inherit [_] reduce_expression
- method zero _ = false
- method plus b0 b1 _ = b0 () && b1 ()
-
- method! visit_texpression env e =
- match e.e with
- | Var _ | CVar _ | Const _ -> fun _ -> false
- | StructUpdate _ ->
- (* There shouldn't be monadic calls in structure updates - also
- note that by returning [false] we are conservative: we might
- *prevent* possible optimisations (i.e., filtering some function
- calls), which is sound. *)
- fun _ -> false
- | Let (_, _, re, e) -> (
- match opt_destruct_function_call re with
- | None -> fun () -> self#visit_texpression env e ()
- | Some (func1, generics1, args1) ->
- let call_is_child = check_call func1 generics1 args1 in
- if call_is_child then fun () -> true
- else fun () -> self#visit_texpression env e ())
- | Lambda (_, e) -> self#visit_texpression env e
- | App _ -> (
- fun () ->
- match opt_destruct_function_call e with
- | Some (func1, tys1, args1) -> check_call func1 tys1 args1
- | None -> false)
- | Qualif _ ->
- (* Note that this case includes functions without arguments *)
- fun () -> false
- | Meta (_, e) -> self#visit_texpression env e
- | Loop loop ->
- (* We only visit the *function end* *)
- self#visit_texpression env loop.fun_end
- | Switch (_, body) -> self#visit_switch_body env body
-
- method! visit_switch_body env body =
- match body with
- | If (e1, e2) ->
- fun () ->
- self#visit_texpression env e1 ()
- && self#visit_texpression env e2 ()
- | Match branches ->
- fun () ->
- List.for_all
- (fun br -> self#visit_texpression env br.branch ())
- branches
- end
- in
- visitor#visit_texpression () e ()
-
(** Filter the useless assignments (removes the useless variables, filters
the function calls) *)
-let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
- (def : fun_decl) : fun_decl =
+let filter_useless (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* We first need a transformation on *left-values*, which filters the useless
* variables and tells us whether the value contains any variable which has
* not been replaced by [_] (in which case we need to keep the assignment,
@@ -1166,30 +1019,8 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
if not monadic then
(* Not a monadic let-binding: simple case *)
(e.e, fun _ -> used)
- else
- (* Monadic let-binding: trickier.
- * We can filter if the right-expression is a function call,
- * under some conditions. *)
- match (filter_monadic_calls, opt_destruct_function_call re) with
- | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) ->
- (* If we split the forward/backward functions.
-
- We need to check if there is a child call - see
- the comments for:
- [expression_contains_child_call_in_all_paths] *)
- if not !Config.return_back_funs then
- let has_child_call =
- expression_contains_child_call_in_all_paths ctx fid
- lp_id rg_id tys args e
- in
- if has_child_call then (* Filter *)
- (e.e, fun _ -> used)
- else (* No child call: don't filter *)
- dont_filter ()
- else dont_filter ()
- | _ ->
- (* Not an LLBC function call or not allowed to filter: we can't filter *)
- dont_filter ()
+ else (* Monadic let-binding: can't filter *)
+ dont_filter ()
else (* There are used variables: don't filter *)
dont_filter ()
| Loop loop ->
@@ -1442,22 +1273,6 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
let body = { body with body = body_exp } in
{ def with body = Some body }
-(** Return [None] if the function is a backward function with no outputs (so
- that we eliminate the definition which is useless).
-
- Note that the calls to such functions are filtered when translating from
- symbolic to pure. Here, we remove the definitions altogether, because they
- are now useless
- *)
-let filter_if_backward_with_no_outputs (def : fun_decl) : fun_decl option =
- if
- !Config.filter_useless_functions
- && Option.is_some def.back_id
- && def.signature.output = mk_result_ty mk_unit_ty
- || def.signature.output = mk_unit_ty
- then None
- else Some def
-
(** Retrieve the loop definitions from the function definition.
{!SymbolicToPure} generates an AST in which the loop bodies are part of
@@ -1530,14 +1345,7 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
info.num_inputs_with_fuel_no_state
info.num_inputs_with_fuel_with_state
in
- let back_inputs =
- if !Config.return_back_funs then []
- else
- snd
- (Collections.List.split_at fun_sig.inputs
- info.num_inputs_with_fuel_with_state)
- in
- List.concat [ fuel; fwd_inputs; fwd_state; back_inputs ]
+ List.concat [ fuel; fwd_inputs; fwd_state ]
in
let output = loop.output_ty in
@@ -1618,7 +1426,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
- back_id = def.back_id;
llbc_name = def.llbc_name;
name = def.name;
signature = loop_sig;
@@ -1640,35 +1447,6 @@ let decompose_loops (_ctx : trans_ctx) (def : fun_decl) :
let loops = List.map snd (LoopId.Map.bindings !loops) in
(def, loops)
-(** Return [false] if the forward function is useless and should be filtered.
-
- - a forward function with no output (comes from a Rust function with
- unit return type)
- - the function has mutable borrows as inputs (which is materialized
- by the fact we generated backward functions which were not filtered).
-
- In such situation, every call to the Rust function will be translated to:
- - a call to the forward function which returns nothing
- - calls to the backward functions
- As a failing backward function implies the forward function also fails,
- we can filter the calls to the forward function, which thus becomes
- useless.
- In such situation, we can remove the forward function definition
- altogether.
- *)
-let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
- (* The question of filtering the forward functions arises only if we split
- the forward/backward functions *)
- if !Config.return_back_funs then true
- else if
- (* Note that at this point, the output types are no longer seen as tuples:
- * they should be lists of length 1. *)
- !Config.filter_useless_functions
- && fwd.f.signature.output = mk_result_ty mk_unit_ty
- && backs <> []
- then false
- else true
-
(** Convert the unit variables to [()] if they are used as right-values or
[_] if they are used as left values in patterns. *)
let unit_vars_to_unit (def : fun_decl) : fun_decl =
@@ -1724,19 +1502,17 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
* could have: [box_new f x])
* *)
match fun_id with
- | Fun (FromLlbc (FunId (FAssumed aid), _lp_id, rg_id)) -> (
- match (aid, rg_id) with
- | BoxNew, _ ->
- assert (rg_id = None);
+ | Fun (FromLlbc (FunId (FAssumed aid), _lp_id)) -> (
+ match aid with
+ | BoxNew ->
let arg, args = Collections.List.pop args in
mk_apps arg args
- | BoxFree, _ ->
+ | BoxFree ->
assert (args = []);
mk_unit_rvalue
- | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared
- | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
- | ArrayRepeat ),
- _ ) ->
+ | SliceIndexShared | SliceIndexMut | ArrayIndexShared
+ | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
+ | ArrayRepeat ->
super#visit_texpression env e)
| _ -> super#visit_texpression env e)
| _ -> super#visit_texpression env e
@@ -1989,7 +1765,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(lazy ("eliminate_box_functions:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Filter the useless variables, assignments, function calls, etc. *)
- let def = filter_useless !Config.filter_useless_monadic_calls ctx def in
+ let def = filter_useless ctx def in
log#ldebug (lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Simplify the lets immediately followed by a return.
@@ -2130,16 +1906,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
*)
let all_decls =
List.concat
- (List.concat
- (List.concat
- (List.map
- (fun { fwd; backs; _ } ->
- [ fwd.f :: fwd.loops ]
- :: List.map
- (fun { f = back; loops = loops_back } ->
- [ back :: loops_back ])
- backs)
- transl)))
+ (List.concat (List.map (fun { f; loops } -> [ f :: loops ]) transl))
in
let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in
@@ -2207,7 +1974,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id'))) ->
if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
@@ -2357,8 +2124,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
- -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id))) -> (
match
FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
@@ -2400,13 +2166,8 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
in
let transl =
List.map
- (fun trans ->
- let filter_fun_and_loops f =
- { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
- in
- let fwd = filter_fun_and_loops trans.fwd in
- let backs = List.map filter_fun_and_loops trans.backs in
- { trans with fwd; backs })
+ (fun f ->
+ { f = filter_in_one f.f; loops = List.map filter_in_one f.loops })
transl
in
@@ -2420,18 +2181,11 @@ let filter_loop_inputs (transl : pure_fun_translation list) :
it thus returns the pair: (function def, loop defs). See {!decompose_loops}
for more information.
- Will return [None] if the function is a backward function with no outputs.
-
[ctx]: used only for printing.
*)
-let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- fun_and_loops option =
+let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_and_loops =
(* Debug *)
- log#ldebug
- (lazy
- ("PureMicroPasses.apply_passes_to_def: " ^ def.name ^ " ("
- ^ Print.option_to_string T.RegionGroupId.to_string def.back_id
- ^ ")"));
+ log#ldebug (lazy ("PureMicroPasses.apply_passes_to_def: " ^ def.name));
log#ldebug (lazy ("original decl:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
@@ -2451,29 +2205,13 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
let def = remove_meta def in
log#ldebug (lazy ("remove_meta:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
- (* Remove the backward functions with no outputs.
+ (* Extract the loop definitions by removing the {!Loop} node *)
+ let def, loops = decompose_loops ctx def in
- Note that the *calls* to those functions should already have been removed,
- when translating from symbolic to pure. Here, we remove the definitions
- altogether, because they are now useless *)
- let name = def.name ^ PrintPure.fun_suffix def.loop_id def.back_id in
- let opt_def = filter_if_backward_with_no_outputs def in
-
- match opt_def with
- | None ->
- log#ldebug (lazy ("filtered (backward with no outputs): " ^ name ^ "\n"));
- None
- | Some def ->
- log#ldebug
- (lazy ("not filtered (not backward with no outputs): " ^ name ^ "\n"));
-
- (* Extract the loop definitions by removing the {!Loop} node *)
- let def, loops = decompose_loops ctx def in
-
- (* Apply the remaining passes *)
- let f = apply_end_passes_to_def ctx def in
- let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some { f; loops }
+ (* Apply the remaining passes *)
+ let f = apply_end_passes_to_def ctx def in
+ let loops = List.map (apply_end_passes_to_def ctx) loops in
+ { f; loops }
(** Apply the micro-passes to a list of forward/backward translations.
@@ -2489,18 +2227,11 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
- let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
- (* Apply the passes to the individual functions *)
- let fwd, backs = trans in
- let fwd = Option.get (apply_passes_to_def ctx fwd) in
- let backs = List.filter_map (apply_passes_to_def ctx) backs in
- (* Compute whether we need to filter the forward function or not *)
- let keep_fwd = keep_forward fwd backs in
- { keep_fwd; fwd; backs }
- in
-
- let transl = List.map apply_to_one transl in
+ (transl : fun_decl list) : pure_fun_translation list =
+ (* Apply the micro-passes *)
+ let transl = List.map (apply_passes_to_def ctx) transl in
- (* Filter the useless inputs in the loop functions *)
+ (* Filter the useless inputs in the loop functions (loops are initially
+ parameterized by *all* the symbolic values in the context, because
+ they may access any of them). *)
filter_loop_inputs transl