diff options
Diffstat (limited to '')
-rw-r--r-- | compiler/PureMicroPasses.ml | 211 |
1 files changed, 95 insertions, 116 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index b6025df4..f3e6cbe2 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -376,8 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ty = e.ty in let ctx, e = match e.e with - | Var _ -> (* Nothing to do *) (ctx, e.e) - | Const _ -> (* Nothing to do *) (ctx, e.e) + | Var _ | CVar _ | Const _ -> (* Nothing to do *) (ctx, e.e) | App (app, arg) -> let ctx, app = update_texpression app ctx in let ctx, arg = update_texpression arg ctx in @@ -584,13 +583,10 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | Qualif { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; - type_args = _; - const_generic_args = _; + generics = _; } -> (* Lookup the def *) - let decl = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls - in + let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in (* Check that there are as many arguments as there are fields - note that the def should have a body (otherwise we couldn't use the constructor) *) @@ -599,8 +595,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Check if the definition is recursive *) let is_rec = match - TypeDeclId.Map.find adt_id - ctx.type_context.type_decls_groups + TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups with | NonRec _ -> false | Rec _ -> true @@ -682,8 +677,8 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) | _ -> false in (* And either: - * 2.1 the right-expression is a variable or a global *) - let var_or_global = is_var re || is_global re in + * 2.1 the right-expression is a variable, a global or a const generic var *) + let var_or_global = is_var re || is_cvar re || is_global re in (* Or: * 2.2 the right-expression is a constant value, an ADT value, * a projection or a primitive function call *and* the flag @@ -767,10 +762,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) In this situation, we can remove the call [f@fwd x]. *) let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (id0 : A.fun_id) (lp_id0 : LoopId.id option) - (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list) + (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option) + (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args) (args0 : texpression list) (e : texpression) : bool = - let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list) + let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args) (args1 : texpression list) : bool = (* Check the fun_ids, to see if call1's function is a child of call0's function *) match fun_id1 with @@ -793,7 +788,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (* We need to use the regions hierarchy *) (* First, lookup the signature of the LLBC function *) let sg = - LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls + let id0 = + match id0 with + | FunId fun_id -> fun_id + | TraitMethod (_, _, fun_decl_id) -> Regular fun_decl_id + in + LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls in (* Compute the set of ancestors of the function in call1 *) let call1_ancestors = @@ -817,8 +817,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) let input_eq (v0, v1) = PureUtils.remove_meta v0 = PureUtils.remove_meta v1 in - (* Compare the input types and the prefix of the input arguments *) - tys0 = tys1 && List.for_all input_eq args + (* Compare the generics and the prefix of the input arguments *) + generics0 = generics1 && List.for_all input_eq args else (* Not a child *) false else (* Not the same function *) @@ -834,7 +834,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) method! visit_texpression env e = match e.e with - | Var _ | Const _ -> fun _ -> false + | Var _ | CVar _ | Const _ -> fun _ -> false | StructUpdate _ -> (* There shouldn't be monadic calls in structure updates - also note that by returning [false] we are conservative: we might @@ -844,8 +844,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) | Let (_, _, re, e) -> ( match opt_destruct_function_call re with | None -> fun () -> self#visit_texpression env e () - | Some (func1, tys1, args1) -> - let call_is_child = check_call func1 tys1 args1 in + | Some (func1, generics1, args1) -> + let call_is_child = check_call func1 generics1 args1 in if call_is_child then fun () -> true else fun () -> self#visit_texpression env e ()) | App _ -> ( @@ -930,7 +930,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with - | Var _ | Const _ | App _ | Qualif _ + | Var _ | CVar _ | Const _ | App _ | Qualif _ | Switch (_, _) | Meta (_, _) | StructUpdate _ | Abs _ -> @@ -1086,13 +1086,12 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | Qualif { id = AdtCons { adt_id = AdtId adt_id; variant_id = None }; - type_args; - const_generic_args; + generics; } -> (* This is a struct *) (* Retrieve the definiton, to find how many fields there are *) let adt_decl = - TypeDeclId.Map.find adt_id ctx.type_context.type_decls + TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in let fields = match adt_decl.kind with @@ -1108,7 +1107,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = * [x.field] for some variable [x], and where the projection * is for the proper ADT *) let to_var_proj (i : int) (arg : texpression) : - (ty list * const_generic list * var_id) option = + (generic_args * var_id) option = match arg.e with | App (proj, x) -> ( match (proj.e, x.e) with @@ -1116,16 +1115,14 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = { id = Proj { adt_id = AdtId proj_adt_id; field_id }; - type_args = proj_type_args; - const_generic_args = proj_const_generic_args; + generics = proj_generics; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) if proj_adt_id = adt_id && FieldId.to_int field_id = i - then - Some (proj_type_args, proj_const_generic_args, v) + then Some (proj_generics, v) else None | _ -> None) | _ -> None @@ -1136,14 +1133,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = if List.length args = num_fields then (* Check that this is the same variable we project from - * note that we checked above that there is at least one field *) - let (_, _, x), end_args = Collections.List.pop args in - if List.for_all (fun (_, _, y) -> y = x) end_args then ( + let (_, x), end_args = Collections.List.pop args in + if List.for_all (fun (_, y) -> y = x) end_args then ( (* We can substitute *) (* Sanity check: all types correct *) assert ( List.for_all - (fun (tys, cgs, _) -> - tys = type_args && cgs = const_generic_args) + (fun (generics1, _) -> generics1 = generics) args); { e with e = Var x }) else super#visit_texpression env e @@ -1162,8 +1158,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl = | ( Qualif { id = Proj { adt_id = AdtId proj_adt_id; field_id }; - type_args = _; - const_generic_args = _; + generics = _; }, Var v ) -> (* We check that this is the proper ADT, and the proper field *) @@ -1361,8 +1356,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_sig = { - type_params = fun_sig.type_params; - const_generic_params = fun_sig.const_generic_params; + generics = fun_sig.generics; + preds = fun_sig.preds; inputs = inputs_tys; output; doutputs; @@ -1427,6 +1422,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = let loop_def = { def_id = def.def_id; + kind = def.kind; num_loops; loop_id = Some loop.loop_id; back_id = def.back_id; @@ -1466,13 +1462,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = In such situation, we can remove the forward function definition altogether. *) -let keep_forward (trans : pure_fun_translation) : bool = - let (fwd, _), backs = trans in +let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = (* Note that at this point, the output types are no longer seen as tuples: * they should be lists of length 1. *) if !Config.filter_useless_functions - && fwd.signature.output = mk_result_ty mk_unit_ty + && fwd.f.signature.output = mk_result_ty mk_unit_ty && backs <> [] then false else true @@ -1518,7 +1513,7 @@ let unit_vars_to_unit (def : fun_decl) : fun_decl = function calls, and when translating end abstractions. Here, we can do something simpler, in one micro-pass. *) -let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = +let eliminate_box_functions (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* The map visitor *) let obj = object @@ -1527,30 +1522,44 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = method! visit_texpression env e = match opt_destruct_function_call e with | Some (fun_id, _tys, args) -> ( + (* Below, when dealing with the arguments: we consider the very + * general case, where functions could be boxed (meaning we + * could have: [box_new f x]) + * *) match fun_id with - | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> ( - (* Below, when dealing with the arguments: we consider the very - * general case, where functions could be boxed (meaning we - * could have: [box_new f x]) - * *) + | Fun (FromLlbc (FunId (Assumed aid), _lp_id, rg_id)) -> ( match (aid, rg_id) with - | A.BoxNew, _ -> + | BoxNew, _ -> assert (rg_id = None); let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDeref, None -> + | BoxFree, _ -> + assert (args = []); + mk_unit_rvalue + | ( ( SliceIndexShared | SliceIndexMut | ArrayIndexShared + | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut + | ArrayRepeat | SliceLen ), + _ ) -> + super#visit_texpression env e) + | Fun (FromLlbc (FunId (Regular fid), _lp_id, rg_id)) -> ( + (* Lookup the function name *) + let def = FunDeclId.Map.find fid ctx.fun_ctx.fun_decls in + match + (Names.name_no_disambiguators_to_string def.name, rg_id) + with + | "alloc::boxed::Box::deref", None -> (* [Box::deref] forward is the identity *) let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDeref, Some _ -> + | "alloc::boxed::Box::deref", Some _ -> (* [Box::deref] backward is [()] (doesn't give back anything) *) assert (args = []); mk_unit_rvalue - | A.BoxDerefMut, None -> + | "alloc::boxed::Box::deref_mut", None -> (* [Box::deref_mut] forward is the identity *) let arg, args = Collections.List.pop args in mk_apps arg args - | A.BoxDerefMut, Some _ -> + | "alloc::boxed::Box::deref_mut", Some _ -> (* [Box::deref_mut] back is almost the identity: * let box_deref_mut (x_init : t) (x_back : t) : t = x_back * *) @@ -1560,17 +1569,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = | _ -> raise (Failure "Unreachable") in mk_apps arg args - | A.BoxFree, _ -> - assert (args = []); - mk_unit_rvalue - | ( ( A.Replace | VecNew | VecPush | VecInsert | VecLen - | VecIndex | VecIndexMut | ArraySubsliceShared - | ArraySubsliceMut | SliceIndexShared | SliceIndexMut - | SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared - | ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut - | SliceLen ), - _ ) -> - super#visit_texpression env e) + | _ -> super#visit_texpression env e) | _ -> super#visit_texpression env e) | _ -> super#visit_texpression env e end @@ -1914,7 +1913,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = [ctx]: used only for printing. *) let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : - (fun_decl * fun_decl list) option = + fun_and_loops option = (* Debug *) log#ldebug (lazy @@ -1955,9 +1954,9 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let def, loops = decompose_loops def in (* Apply the remaining passes *) - let def = apply_end_passes_to_def ctx def in + let f = apply_end_passes_to_def ctx def in let loops = List.map (apply_end_passes_to_def ctx) loops in - Some (def, loops) + Some { f; loops } (** Small utility for {!filter_loop_inputs} *) let filter_prefix (keep : bool list) (ls : 'a list) : 'a list = @@ -1983,8 +1982,8 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) (** Filter the useless loop input parameters. *) -let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : - (bool * pure_fun_translation) list = +let filter_loop_inputs (transl : pure_fun_translation list) : + pure_fun_translation list = (* We need to explore groups of mutually recursive functions. In order to compute which parameters are useless, we need to explore the functions by groups of mutually recursive definitions. @@ -2002,10 +2001,11 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : (List.concat (List.concat (List.map - (fun (_, ((fwd, loops_fwd), backs)) -> - [ fwd :: loops_fwd ] + (fun { fwd; backs; _ } -> + [ fwd.f :: fwd.loops ] :: List.map - (fun (back, loops_back) -> [ back :: loops_back ]) + (fun { f = back; loops = loops_back } -> + [ back :: loops_back ]) backs) transl))) in @@ -2030,7 +2030,6 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : additional parameters. *) let used_map = ref FunLoopIdMap.empty in - let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in (* We start by computing the filtering information, for each function *) let compute_one_filter_info (decl : fun_decl) = @@ -2051,7 +2050,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in assert (Option.is_some decl.loop_id); - let fun_id = (A.Regular decl.def_id, decl.loop_id) in + let fun_id = (E.Regular decl.def_id, decl.loop_id) in let set_used vid = used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used @@ -2075,8 +2074,8 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc fun_id')) -> - if fun_id_to_fun_loop_id fun_id' = fun_id then ( + | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) -> + if (fun_id', loop_id') = fun_id then ( (* For each argument, check if it is exactly the original input parameter. Note that there shouldn't be partial applications of loop functions: the number of arguments @@ -2135,22 +2134,15 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : (* We then apply the filtering to all the function definitions at once *) let filter_in_one (decl : fun_decl) : fun_decl = (* Filter the function signature *) - let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in + let fun_id = (E.Regular decl.def_id, decl.loop_id) in let decl = - match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with + match FunLoopIdMap.find_opt fun_id !used_map with | None -> (* Nothing to filter *) decl | Some used_info -> let num_filtered = List.length (List.filter (fun b -> not b) used_info) in - let { - type_params; - const_generic_params; - inputs; - output; - doutputs; - info; - } = + let { generics; preds; inputs; output; doutputs; info } = decl.signature in let { @@ -2178,16 +2170,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : effect_info; } in - let signature = - { - type_params; - const_generic_params; - inputs; - output; - doutputs; - info; - } - in + let signature = { generics; preds; inputs; output; doutputs; info } in { decl with signature } in @@ -2201,9 +2184,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : let { inputs; inputs_lvs; body } = body in let inputs, inputs_lvs = - match - FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map - with + match FunLoopIdMap.find_opt fun_id !used_map with | None -> (* Nothing to filter *) (inputs, inputs_lvs) | Some used_info -> let inputs = filter_prefix used_info inputs in @@ -2223,11 +2204,10 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : match e_app.e with | Qualif qualif -> ( match qualif.id with - | FunOrOp (Fun (FromLlbc fun_id)) -> ( + | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _))) + -> ( match - FunLoopIdMap.find_opt - (fun_id_to_fun_loop_id fun_id) - !used_map + FunLoopIdMap.find_opt (fun_id, loop_id) !used_map with | None -> super#visit_texpression env e | Some used_info -> @@ -2267,13 +2247,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : in let transl = List.map - (fun (b, (fwd, backs)) -> - let filter_fun_and_loops (f, fl) = - (filter_in_one f, List.map filter_in_one fl) + (fun trans -> + let filter_fun_and_loops f = + { f = filter_in_one f.f; loops = List.map filter_in_one f.loops } in - let fwd = filter_fun_and_loops fwd in - let backs = List.map filter_fun_and_loops backs in - (b, (fwd, backs))) + let fwd = filter_fun_and_loops trans.fwd in + let backs = List.map filter_fun_and_loops trans.backs in + { trans with fwd; backs }) transl in @@ -2294,18 +2274,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : but convenient. *) let apply_passes_to_pure_fun_translations (ctx : trans_ctx) - (transl : (fun_decl * fun_decl list) list) : - (bool * pure_fun_translation) list = - let apply_to_one (trans : fun_decl * fun_decl list) : - bool * pure_fun_translation = + (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list = + let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation = (* Apply the passes to the individual functions *) - let forward, backwards = trans in - let forward = Option.get (apply_passes_to_def ctx forward) in - let backwards = List.filter_map (apply_passes_to_def ctx) backwards in - let trans = (forward, backwards) in + let fwd, backs = trans in + let fwd = Option.get (apply_passes_to_def ctx fwd) in + let backs = List.filter_map (apply_passes_to_def ctx) backs in (* Compute whether we need to filter the forward function or not *) - (keep_forward trans, trans) + let keep_fwd = keep_forward fwd backs in + { keep_fwd; fwd; backs } in + let transl = List.map apply_to_one transl in (* Filter the useless inputs in the loop functions *) |