summaryrefslogtreecommitdiff
path: root/compiler/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/PureMicroPasses.ml211
1 files changed, 95 insertions, 116 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index b6025df4..f3e6cbe2 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -376,8 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ty = e.ty in
let ctx, e =
match e.e with
- | Var _ -> (* Nothing to do *) (ctx, e.e)
- | Const _ -> (* Nothing to do *) (ctx, e.e)
+ | Var _ | CVar _ | Const _ -> (* Nothing to do *) (ctx, e.e)
| App (app, arg) ->
let ctx, app = update_texpression app ctx in
let ctx, arg = update_texpression arg ctx in
@@ -584,13 +583,10 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| Qualif
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args = _;
- const_generic_args = _;
+ generics = _;
} ->
(* Lookup the def *)
- let decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
- in
+ let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
(* Check that there are as many arguments as there are fields - note
that the def should have a body (otherwise we couldn't use the
constructor) *)
@@ -599,8 +595,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* Check if the definition is recursive *)
let is_rec =
match
- TypeDeclId.Map.find adt_id
- ctx.type_context.type_decls_groups
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups
with
| NonRec _ -> false
| Rec _ -> true
@@ -682,8 +677,8 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
| _ -> false
in
(* And either:
- * 2.1 the right-expression is a variable or a global *)
- let var_or_global = is_var re || is_global re in
+ * 2.1 the right-expression is a variable, a global or a const generic var *)
+ let var_or_global = is_var re || is_cvar re || is_global re in
(* Or:
* 2.2 the right-expression is a constant value, an ADT value,
* a projection or a primitive function call *and* the flag
@@ -767,10 +762,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : A.fun_id) (lp_id0 : LoopId.id option)
- (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
+ (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
+ (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
(args0 : texpression list) (e : texpression) : bool =
- let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list)
+ let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
(args1 : texpression list) : bool =
(* Check the fun_ids, to see if call1's function is a child of call0's function *)
match fun_id1 with
@@ -793,7 +788,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
(* We need to use the regions hierarchy *)
(* First, lookup the signature of the LLBC function *)
let sg =
- LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls
+ let id0 =
+ match id0 with
+ | FunId fun_id -> fun_id
+ | TraitMethod (_, _, fun_decl_id) -> Regular fun_decl_id
+ in
+ LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls
in
(* Compute the set of ancestors of the function in call1 *)
let call1_ancestors =
@@ -817,8 +817,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
let input_eq (v0, v1) =
PureUtils.remove_meta v0 = PureUtils.remove_meta v1
in
- (* Compare the input types and the prefix of the input arguments *)
- tys0 = tys1 && List.for_all input_eq args
+ (* Compare the generics and the prefix of the input arguments *)
+ generics0 = generics1 && List.for_all input_eq args
else (* Not a child *)
false
else (* Not the same function *)
@@ -834,7 +834,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
method! visit_texpression env e =
match e.e with
- | Var _ | Const _ -> fun _ -> false
+ | Var _ | CVar _ | Const _ -> fun _ -> false
| StructUpdate _ ->
(* There shouldn't be monadic calls in structure updates - also
note that by returning [false] we are conservative: we might
@@ -844,8 +844,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
| Let (_, _, re, e) -> (
match opt_destruct_function_call re with
| None -> fun () -> self#visit_texpression env e ()
- | Some (func1, tys1, args1) ->
- let call_is_child = check_call func1 tys1 args1 in
+ | Some (func1, generics1, args1) ->
+ let call_is_child = check_call func1 generics1 args1 in
if call_is_child then fun () -> true
else fun () -> self#visit_texpression env e ())
| App _ -> (
@@ -930,7 +930,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
method! visit_expression env e =
match e with
- | Var _ | Const _ | App _ | Qualif _
+ | Var _ | CVar _ | Const _ | App _ | Qualif _
| Switch (_, _)
| Meta (_, _)
| StructUpdate _ | Abs _ ->
@@ -1086,13 +1086,12 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| Qualif
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args;
- const_generic_args;
+ generics;
} ->
(* This is a struct *)
(* Retrieve the definiton, to find how many fields there are *)
let adt_decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls
in
let fields =
match adt_decl.kind with
@@ -1108,7 +1107,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
* [x.field] for some variable [x], and where the projection
* is for the proper ADT *)
let to_var_proj (i : int) (arg : texpression) :
- (ty list * const_generic list * var_id) option =
+ (generic_args * var_id) option =
match arg.e with
| App (proj, x) -> (
match (proj.e, x.e) with
@@ -1116,16 +1115,14 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
{
id =
Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = proj_type_args;
- const_generic_args = proj_const_generic_args;
+ generics = proj_generics;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
proj_adt_id = adt_id
&& FieldId.to_int field_id = i
- then
- Some (proj_type_args, proj_const_generic_args, v)
+ then Some (proj_generics, v)
else None
| _ -> None)
| _ -> None
@@ -1136,14 +1133,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
if List.length args = num_fields then
(* Check that this is the same variable we project from -
* note that we checked above that there is at least one field *)
- let (_, _, x), end_args = Collections.List.pop args in
- if List.for_all (fun (_, _, y) -> y = x) end_args then (
+ let (_, x), end_args = Collections.List.pop args in
+ if List.for_all (fun (_, y) -> y = x) end_args then (
(* We can substitute *)
(* Sanity check: all types correct *)
assert (
List.for_all
- (fun (tys, cgs, _) ->
- tys = type_args && cgs = const_generic_args)
+ (fun (generics1, _) -> generics1 = generics)
args);
{ e with e = Var x })
else super#visit_texpression env e
@@ -1162,8 +1158,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| ( Qualif
{
id = Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = _;
- const_generic_args = _;
+ generics = _;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
@@ -1361,8 +1356,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_sig =
{
- type_params = fun_sig.type_params;
- const_generic_params = fun_sig.const_generic_params;
+ generics = fun_sig.generics;
+ preds = fun_sig.preds;
inputs = inputs_tys;
output;
doutputs;
@@ -1427,6 +1422,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_def =
{
def_id = def.def_id;
+ kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
back_id = def.back_id;
@@ -1466,13 +1462,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
In such situation, we can remove the forward function definition
altogether.
*)
-let keep_forward (trans : pure_fun_translation) : bool =
- let (fwd, _), backs = trans in
+let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
(* Note that at this point, the output types are no longer seen as tuples:
* they should be lists of length 1. *)
if
!Config.filter_useless_functions
- && fwd.signature.output = mk_result_ty mk_unit_ty
+ && fwd.f.signature.output = mk_result_ty mk_unit_ty
&& backs <> []
then false
else true
@@ -1518,7 +1513,7 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl =
function calls, and when translating end abstractions. Here, we can do
something simpler, in one micro-pass.
*)
-let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
+let eliminate_box_functions (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* The map visitor *)
let obj =
object
@@ -1527,30 +1522,44 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
method! visit_texpression env e =
match opt_destruct_function_call e with
| Some (fun_id, _tys, args) -> (
+ (* Below, when dealing with the arguments: we consider the very
+ * general case, where functions could be boxed (meaning we
+ * could have: [box_new f x])
+ * *)
match fun_id with
- | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> (
- (* Below, when dealing with the arguments: we consider the very
- * general case, where functions could be boxed (meaning we
- * could have: [box_new f x])
- * *)
+ | Fun (FromLlbc (FunId (Assumed aid), _lp_id, rg_id)) -> (
match (aid, rg_id) with
- | A.BoxNew, _ ->
+ | BoxNew, _ ->
assert (rg_id = None);
let arg, args = Collections.List.pop args in
mk_apps arg args
- | A.BoxDeref, None ->
+ | BoxFree, _ ->
+ assert (args = []);
+ mk_unit_rvalue
+ | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared
+ | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
+ | ArrayRepeat | SliceLen ),
+ _ ) ->
+ super#visit_texpression env e)
+ | Fun (FromLlbc (FunId (Regular fid), _lp_id, rg_id)) -> (
+ (* Lookup the function name *)
+ let def = FunDeclId.Map.find fid ctx.fun_ctx.fun_decls in
+ match
+ (Names.name_no_disambiguators_to_string def.name, rg_id)
+ with
+ | "alloc::boxed::Box::deref", None ->
(* [Box::deref] forward is the identity *)
let arg, args = Collections.List.pop args in
mk_apps arg args
- | A.BoxDeref, Some _ ->
+ | "alloc::boxed::Box::deref", Some _ ->
(* [Box::deref] backward is [()] (doesn't give back anything) *)
assert (args = []);
mk_unit_rvalue
- | A.BoxDerefMut, None ->
+ | "alloc::boxed::Box::deref_mut", None ->
(* [Box::deref_mut] forward is the identity *)
let arg, args = Collections.List.pop args in
mk_apps arg args
- | A.BoxDerefMut, Some _ ->
+ | "alloc::boxed::Box::deref_mut", Some _ ->
(* [Box::deref_mut] back is almost the identity:
* let box_deref_mut (x_init : t) (x_back : t) : t = x_back
* *)
@@ -1560,17 +1569,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
| _ -> raise (Failure "Unreachable")
in
mk_apps arg args
- | A.BoxFree, _ ->
- assert (args = []);
- mk_unit_rvalue
- | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen
- | VecIndex | VecIndexMut | ArraySubsliceShared
- | ArraySubsliceMut | SliceIndexShared | SliceIndexMut
- | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared
- | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
- | SliceLen ),
- _ ) ->
- super#visit_texpression env e)
+ | _ -> super#visit_texpression env e)
| _ -> super#visit_texpression env e)
| _ -> super#visit_texpression env e
end
@@ -1914,7 +1913,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
[ctx]: used only for printing.
*)
let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- (fun_decl * fun_decl list) option =
+ fun_and_loops option =
(* Debug *)
log#ldebug
(lazy
@@ -1955,9 +1954,9 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
let def, loops = decompose_loops def in
(* Apply the remaining passes *)
- let def = apply_end_passes_to_def ctx def in
+ let f = apply_end_passes_to_def ctx def in
let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some (def, loops)
+ Some { f; loops }
(** Small utility for {!filter_loop_inputs} *)
let filter_prefix (keep : bool list) (ls : 'a list) : 'a list =
@@ -1983,8 +1982,8 @@ end
module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
(** Filter the useless loop input parameters. *)
-let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
- (bool * pure_fun_translation) list =
+let filter_loop_inputs (transl : pure_fun_translation list) :
+ pure_fun_translation list =
(* We need to explore groups of mutually recursive functions. In order
to compute which parameters are useless, we need to explore the
functions by groups of mutually recursive definitions.
@@ -2002,10 +2001,11 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(List.concat
(List.concat
(List.map
- (fun (_, ((fwd, loops_fwd), backs)) ->
- [ fwd :: loops_fwd ]
+ (fun { fwd; backs; _ } ->
+ [ fwd.f :: fwd.loops ]
:: List.map
- (fun (back, loops_back) -> [ back :: loops_back ])
+ (fun { f = back; loops = loops_back } ->
+ [ back :: loops_back ])
backs)
transl)))
in
@@ -2030,7 +2030,6 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
additional parameters.
*)
let used_map = ref FunLoopIdMap.empty in
- let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in
(* We start by computing the filtering information, for each function *)
let compute_one_filter_info (decl : fun_decl) =
@@ -2051,7 +2050,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in
assert (Option.is_some decl.loop_id);
- let fun_id = (A.Regular decl.def_id, decl.loop_id) in
+ let fun_id = (E.Regular decl.def_id, decl.loop_id) in
let set_used vid =
used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used
@@ -2075,8 +2074,8 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id')) ->
- if fun_id_to_fun_loop_id fun_id' = fun_id then (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
applications of loop functions: the number of arguments
@@ -2135,22 +2134,15 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(* We then apply the filtering to all the function definitions at once *)
let filter_in_one (decl : fun_decl) : fun_decl =
(* Filter the function signature *)
- let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in
+ let fun_id = (E.Regular decl.def_id, decl.loop_id) in
let decl =
- match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) decl
| Some used_info ->
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- } =
+ let { generics; preds; inputs; output; doutputs; info } =
decl.signature
in
let {
@@ -2178,16 +2170,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
effect_info;
}
in
- let signature =
- {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- }
- in
+ let signature = { generics; preds; inputs; output; doutputs; info } in
{ decl with signature }
in
@@ -2201,9 +2184,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let { inputs; inputs_lvs; body } = body in
let inputs, inputs_lvs =
- match
- FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map
- with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) (inputs, inputs_lvs)
| Some used_info ->
let inputs = filter_prefix used_info inputs in
@@ -2223,11 +2204,10 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id)) -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
+ -> (
match
- FunLoopIdMap.find_opt
- (fun_id_to_fun_loop_id fun_id)
- !used_map
+ FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
| None -> super#visit_texpression env e
| Some used_info ->
@@ -2267,13 +2247,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
in
let transl =
List.map
- (fun (b, (fwd, backs)) ->
- let filter_fun_and_loops (f, fl) =
- (filter_in_one f, List.map filter_in_one fl)
+ (fun trans ->
+ let filter_fun_and_loops f =
+ { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
in
- let fwd = filter_fun_and_loops fwd in
- let backs = List.map filter_fun_and_loops backs in
- (b, (fwd, backs)))
+ let fwd = filter_fun_and_loops trans.fwd in
+ let backs = List.map filter_fun_and_loops trans.backs in
+ { trans with fwd; backs })
transl
in
@@ -2294,18 +2274,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) :
- (bool * pure_fun_translation) list =
- let apply_to_one (trans : fun_decl * fun_decl list) :
- bool * pure_fun_translation =
+ (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
+ let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
(* Apply the passes to the individual functions *)
- let forward, backwards = trans in
- let forward = Option.get (apply_passes_to_def ctx forward) in
- let backwards = List.filter_map (apply_passes_to_def ctx) backwards in
- let trans = (forward, backwards) in
+ let fwd, backs = trans in
+ let fwd = Option.get (apply_passes_to_def ctx fwd) in
+ let backs = List.filter_map (apply_passes_to_def ctx) backs in
(* Compute whether we need to filter the forward function or not *)
- (keep_forward trans, trans)
+ let keep_fwd = keep_forward fwd backs in
+ { keep_fwd; fwd; backs }
in
+
let transl = List.map apply_to_one transl in
(* Filter the useless inputs in the loop functions *)