diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/InterpreterLoops.ml | 8 | ||||
-rw-r--r-- | compiler/Pure.ml | 6 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 342 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 26 | ||||
-rw-r--r-- | compiler/ReorderDecls.ml | 2 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 15 | ||||
-rw-r--r-- | compiler/Translate.ml | 26 |
7 files changed, 377 insertions, 48 deletions
diff --git a/compiler/InterpreterLoops.ml b/compiler/InterpreterLoops.ml index 544bface..dd0cfc3f 100644 --- a/compiler/InterpreterLoops.ml +++ b/compiler/InterpreterLoops.ml @@ -4072,7 +4072,7 @@ let compute_fp_ctx_symbolic_values (ctx : C.eval_ctx) (fp_ctx : C.eval_ctx) : *) let sids = ref V.SymbolicValueId.Set.empty in let visitor = - object (self : 'self) + object (self) inherit [_] C.iter_env method! visit_ASharedLoan inside_shared _ sv child_av = @@ -4117,7 +4117,7 @@ let compute_fp_ctx_symbolic_values (ctx : C.eval_ctx) (fp_ctx : C.eval_ctx) : let ordered_sids = ref [] in let visitor = - object (self : 'self) + object (self) inherit [_] C.iter_env (** We lookup the shared values *) @@ -4138,9 +4138,7 @@ let compute_fp_ctx_symbolic_values (ctx : C.eval_ctx) (fp_ctx : C.eval_ctx) : end in - (* TODO: why do we have to put a boolean here for the typechecker to be happy? - Is it because we use a similar visitor with booleans above?? *) - List.iter (visitor#visit_env_elem true) (List.rev fp_ctx.env); + List.iter (visitor#visit_env_elem ()) (List.rev fp_ctx.env); List.filter_map (fun id -> V.SymbolicValueId.Map.find_opt id sids_to_values) diff --git a/compiler/Pure.ml b/compiler/Pure.ml index fe30a650..777d4308 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -312,9 +312,13 @@ type pure_assumed_fun_id = | FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *) [@@deriving show, ord] +(** A function id for a non-assumed function *) +type regular_fun_id = A.fun_id * LoopId.id option * T.RegionGroupId.id option +[@@deriving show, ord] + (** A function identifier *) type fun_id = - | FromLlbc of A.fun_id * LoopId.id option * T.RegionGroupId.id option + | FromLlbc of regular_fun_id (** A function coming from LLBC. The loop id is [None] if the function is actually the auxiliary function diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 09cc2533..e670570b 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1756,24 +1756,338 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : let loops = List.map (apply_end_passes_to_def ctx) loops in Some (def, loops) -(** Return the forward/backward translations on which we applied the micro-passes. +(** Small utility for {!filter_loop_inputs} *) +let filter_prefix (keep : bool list) (ls : 'a list) : 'a list = + let ls0, ls1 = Collections.List.split_at ls (List.length keep) in + let ls0 = + List.filter_map + (fun (b, x) -> if b then Some x else None) + (List.combine keep ls0) + in + List.append ls0 ls1 + +type fun_loop_id = A.fun_id * LoopId.id option [@@deriving show, ord] + +module FunLoopIdOrderedType = struct + type t = fun_loop_id + + let compare = compare_fun_loop_id + let to_string = show_fun_loop_id + let pp_t = pp_fun_loop_id + let show_t = show_fun_loop_id +end + +module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) + +(** Filter the useless loop input parameters. *) +let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : + (bool * pure_fun_translation) list = + (* We need to explore groups of mutually recursive functions. In order + to compute which parameters are useless, we need to explore the + functions by groups of mutually recursive definitions. + + Because every Rust function is translated to a list of functions (forward + function, backward functions, loop functions, etc.), and those functions + might depend on each others in different ways, we recompute the SCCs of + the whole module. + + Rem.: we also redo this computation, on a smaller scale, in {!Translate}. + Maybe we can factor out the two. + *) + let all_decls = + List.concat + (List.concat + (List.concat + (List.map + (fun (_, ((fwd, loops_fwd), backs)) -> + [ fwd :: loops_fwd ] + :: List.map + (fun (back, loops_back) -> [ back :: loops_back ]) + backs) + transl))) + in + let subgroups = ReorderDecls.group_reorder_fun_decls all_decls in + + (* Explore the subgroups one by one. + + For now, we only filter the parameters of loop functions which are simply + recursive. + + Rem.: there is a bit of redundancy in computing the useless parameters + for the loop forward *and* the loop backward functions. + *) + (* The [filtered] map: maps function identifiers to filtering information. + + Note that we ignore the backward id: + - we filter the forward inputs only + - we want the filtering to be the same for the forward and the backward + functions + The reason is that for now we want to preserve the fact that a backward + function takes the same inputs as its associated forward function, with + additional parameters. + *) + let used_map = ref FunLoopIdMap.empty in + let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in + + (* We start by computing the filtering information, for each function *) + let compute_one_filter_info (decl : fun_decl) = + (* There should be a body *) + let body = Option.get decl.body in + (* We only look at the forward inputs, without the state *) + let inputs_prefix, _ = + Collections.List.split_at body.inputs + decl.signature.info.num_fwd_inputs_with_fuel_no_state + in + let used = ref (List.map (fun v -> (var_get_id v, false)) inputs_prefix) in + let inputs_prefix_length = List.length inputs_prefix in + let inputs = + List.map + (fun v -> (var_get_id v, mk_texpression_from_var v)) + inputs_prefix + in + let inputs_set = VarId.Set.of_list (List.map var_get_id inputs_prefix) in + assert (Option.is_some decl.loop_id); + + let fun_id = (A.Regular decl.def_id, decl.loop_id) in + + let set_used vid = + used := List.map (fun (vid', b) -> (vid', b || vid = vid')) !used + in + + (* Set the fuel as used *) + let sg_info = decl.signature.info in + if sg_info.has_fuel then set_used (fst (Collections.List.nth inputs 0)); + + let visitor = + object (self : 'self) + inherit [_] iter_expression as super + + (** Override the expression visitor, to look for loop function calls *) + method! visit_texpression env e = + match e.e with + | App _ -> ( + (* If this is an app: destruct all the arguments, and check if + the leftmost expression is the loop function call *) + let e_app, args = destruct_apps e in + match e_app.e with + | Qualif qualif -> ( + match qualif.id with + | FunOrOp (Fun (FromLlbc fun_id')) -> + if fun_id_to_fun_loop_id fun_id' = fun_id then ( + (* For each argument, check if it is exactly the original + input parameter. Note that there shouldn't be partial + applications of loop functions: the number of arguments + should be exactly the number of input parameters (i.e., + we can use [combine]) + *) + let beg_args, end_args = + Collections.List.split_at args inputs_prefix_length + in + let used_args = List.combine inputs beg_args in + List.iter + (fun ((vid, var), arg) -> + if var <> arg then ( + self#visit_texpression env arg; + set_used vid)) + used_args; + List.iter (self#visit_texpression env) end_args) + else super#visit_texpression env e + | _ -> super#visit_texpression env e) + | _ -> super#visit_texpression env e) + | _ -> super#visit_texpression env e + + (** If we visit a variable which is actually an input parameter, we + set it as used. Note that we take care of ignoring some of those + input parameters given in [visit_texpression]. + *) + method! visit_var_id _ id = + if VarId.Set.mem id inputs_set then set_used id + end + in + visitor#visit_texpression () body.body; + + (* Save the filtering information, if there is anything to filter *) + if List.exists snd !used then + let used = List.map snd !used in + let used = + match FunLoopIdMap.find_opt fun_id !used_map with + | None -> used + | Some used0 -> + List.map (fun (b0, b1) -> b0 || b1) (List.combine used0 used) + in + used_map := FunLoopIdMap.add fun_id used !used_map + in + List.iter + (fun (_, fl) -> + match fl with + | [ f ] -> + (* Group made of one function: check if it is a loop. If it is the + case, explore it. *) + if Option.is_some f.loop_id then compute_one_filter_info f else () + | _ -> + (* Group of mutually recursive functions: ignore for now *) + ()) + subgroups; + + (* We then apply the filtering to all the function definitions at once *) + let filter_in_one (decl : fun_decl) : fun_decl = + (* Filter the function signature *) + let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in + let decl = + match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with + | None -> (* Nothing to filter *) decl + | Some used_info -> + let num_filtered = + List.length (List.filter (fun b -> not b) used_info) + in + let { type_params; inputs; output; doutputs; info } = + decl.signature + in + let { + has_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state; + num_back_inputs_no_state; + num_back_inputs_with_state; + effect_info; + } = + info + in + + let inputs = filter_prefix used_info inputs in + + let info = + { + has_fuel; + num_fwd_inputs_with_fuel_no_state = + num_fwd_inputs_with_fuel_no_state - num_filtered; + num_fwd_inputs_with_fuel_with_state = + num_fwd_inputs_with_fuel_with_state - num_filtered; + num_back_inputs_no_state; + num_back_inputs_with_state; + effect_info; + } + in + let signature = { type_params; inputs; output; doutputs; info } in + + { decl with signature } + in + + (* Filter the function body *) + let body = + match decl.body with + | None -> None + | Some body -> + (* Update the list of vars *) + let { inputs; inputs_lvs; body } = body in + + let inputs, inputs_lvs = + match + FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map + with + | None -> (* Nothing to filter *) (inputs, inputs_lvs) + | Some used_info -> + let inputs = filter_prefix used_info inputs in + let inputs_lvs = filter_prefix used_info inputs_lvs in + (inputs, inputs_lvs) + in + + (* Update the body expression *) + let visitor = + object (self) + inherit [_] map_expression as super + + method! visit_texpression env e = + match e.e with + | App _ -> ( + let e_app, args = destruct_apps e in + match e_app.e with + | Qualif qualif -> ( + match qualif.id with + | FunOrOp (Fun (FromLlbc fun_id)) -> ( + match + FunLoopIdMap.find_opt + (fun_id_to_fun_loop_id fun_id) + !used_map + with + | None -> super#visit_texpression env e + | Some used_info -> + (* Filter the types in the arrow type *) + let tys, ret_ty = destruct_arrows e_app.ty in + let tys = filter_prefix used_info tys in + let ty = mk_arrows tys ret_ty in + let e_app = { e_app with ty } in + + (* Filter the arguments *) + let args = filter_prefix used_info args in + + (* Explore the arguments *) + let args = + List.map (self#visit_texpression env) args + in + + (* Rebuild *) + mk_apps e_app args) + | _ -> + let e_app = self#visit_texpression env e_app in + let args = + List.map (self#visit_texpression env) args + in + mk_apps e_app args) + | _ -> + let e_app = self#visit_texpression env e_app in + let args = List.map (self#visit_texpression env) args in + mk_apps e_app args) + | _ -> super#visit_texpression env e + end + in + let body = visitor#visit_texpression () body in + Some { inputs; inputs_lvs; body } + in + { decl with body } + in + let transl = + List.map + (fun (b, (fwd, backs)) -> + let filter_fun_and_loops (f, fl) = + (filter_in_one f, List.map filter_in_one fl) + in + let fwd = filter_fun_and_loops fwd in + let backs = List.map filter_fun_and_loops backs in + (b, (fwd, backs))) + transl + in + + (* Return *) + transl + +(** Apply the micro-passes to a list of forward/backward translations. This function also extracts the loop definitions from the function body (see {!decompose_loops}). - Also returns a boolean indicating whether the forward function should be kept - or not (because useful/useless - [true] means we need to keep the forward - function). + It also returns a boolean indicating whether the forward function should be kept + or not at extraction time ([true] means we need to keep the forward function). + Note that we don't "filter" the forward function and return a boolean instead, because this function contains useful information to extract the backward - functions: keeping it is not necessary but more convenient. + functions. Note that here, keeping the forward function it is not *necessary* + but convenient. *) -let apply_passes_to_pure_fun_translation (ctx : trans_ctx) - (trans : fun_decl * fun_decl list) : bool * pure_fun_translation = - (* Apply the passes to the individual functions *) - let forward, backwards = trans in - let forward = Option.get (apply_passes_to_def ctx forward) in - let backwards = List.filter_map (apply_passes_to_def ctx) backwards in - let trans = (forward, backwards) in - (* Compute whether we need to filter the forward function or not *) - (keep_forward trans, trans) +let apply_passes_to_pure_fun_translations (ctx : trans_ctx) + (transl : (fun_decl * fun_decl list) list) : + (bool * pure_fun_translation) list = + let apply_to_one (trans : fun_decl * fun_decl list) : + bool * pure_fun_translation = + (* Apply the passes to the individual functions *) + let forward, backwards = trans in + let forward = Option.get (apply_passes_to_def ctx forward) in + let backwards = List.filter_map (apply_passes_to_def ctx) backwards in + let trans = (forward, backwards) in + (* Compute whether we need to filter the forward function or not *) + (keep_forward trans, trans) + in + let transl = List.map apply_to_one transl in + + (* Filter the useless inputs in the loop functions *) + filter_loop_inputs transl diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index e86d2a78..e13743c4 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -3,13 +3,6 @@ open Pure (** Default logger *) let log = Logging.pure_utils_log -(** We use this type as a key for lookups *) -type regular_fun_id = A.fun_id * T.RegionGroupId.id option -[@@deriving show, ord] - -(** We use this type as a key for lookups *) -type fun_loop_id = A.FunDeclId.id * LoopId.id option [@@deriving show, ord] - module RegularFunIdOrderedType = struct type t = regular_fun_id @@ -21,6 +14,25 @@ end module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType) +(** We use this type as a key for lookups *) +type regular_fun_id_not_loop = A.fun_id * T.RegionGroupId.id option +[@@deriving show, ord] + +(** We use this type as a key for lookups *) +type fun_loop_id = A.FunDeclId.id * LoopId.id option [@@deriving show, ord] + +module RegularFunIdNotLoopOrderedType = struct + type t = regular_fun_id_not_loop + + let compare = compare_regular_fun_id_not_loop + let to_string = show_regular_fun_id_not_loop + let pp_t = pp_regular_fun_id_not_loop + let show_t = show_regular_fun_id_not_loop +end + +module RegularFunIdNotLoopMap = + Collections.MakeMap (RegularFunIdNotLoopOrderedType) + module FunOrOpIdOrderedType = struct type t = fun_or_op_id diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index 9e15da4e..fc4744bc 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -188,7 +188,7 @@ let group_reorder_fun_decls (decls : fun_decl list) : assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs)); (* Check that all the ids are in the sccs *) let scc_ids = FunIdSet.of_list (List.concat sccs) in - assert (scc_ids = ids) + assert (FunIdSet.equal scc_ids ids) in (* Group the declarations *) diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ab9e40df..531a13e9 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -43,7 +43,7 @@ type fun_sig_named_outputs = { type fun_context = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; - fun_sigs : fun_sig_named_outputs RegularFunIdMap.t; (** *) + fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *) fun_infos : FA.fun_info A.FunDeclId.Map.t; } [@@deriving show] @@ -318,7 +318,7 @@ let get_instantiated_fun_sig (fun_id : A.fun_id) inst_fun_sig = (* Lookup the non-instantiated function signature *) let sg = - (RegularFunIdMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg + (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg in (* Create the substitution *) let tsubst = make_type_subst sg.type_params tys in @@ -337,7 +337,7 @@ let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = let id = (A.Regular def_id, back_id) in - (RegularFunIdMap.find id ctx.fun_context.fun_sigs).sg + (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) (args : texpression list) (ctx : bs_ctx) : bs_ctx = @@ -2730,13 +2730,14 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (types_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : - fun_sig_named_outputs RegularFunIdMap.t = + fun_sig_named_outputs RegularFunIdNotLoopMap.t = (* For every function, translate the signatures of: - the forward function - the backward functions *) let translate_one (fun_id : A.fun_id) (input_names : string option list) - (sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list = + (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list + = (* The forward function *) let fwd_sg = translate_fun_sig fun_infos fun_id types_infos sg input_names None @@ -2762,5 +2763,5 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (List.map (fun (id, names, sg) -> translate_one id names sg) functions) in List.fold_left - (fun m (id, sg) -> RegularFunIdMap.add id sg m) - RegularFunIdMap.empty translated + (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m) + RegularFunIdNotLoopMap.empty translated diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 966ccf70..c42f3a27 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -49,7 +49,7 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl) TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (trans_ctx : trans_ctx) - (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) + (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdNotLoopMap.t) (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl) : pure_fun_translation_no_loops = (* Debug *) @@ -66,7 +66,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Convert the symbolic ASTs to pure ASTs: *) (* Initialize the context *) - let forward_sig = RegularFunIdMap.find (A.Regular def_id, None) fun_sigs in + let forward_sig = + RegularFunIdNotLoopMap.find (A.Regular def_id, None) fun_sigs + in let sv_to_var = V.SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in let state_var, var_counter = Pure.VarId.fresh var_counter in @@ -194,7 +196,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) (* Initialize the context - note that the ret_ty is not really * useful as we don't translate a body *) let backward_sg = - RegularFunIdMap.find (A.Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs in let ctx = { ctx with bid = Some back_id; sg = backward_sg.sg } in @@ -205,7 +207,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) variables required by the backward function. *) let backward_sg = - RegularFunIdMap.find (A.Regular def_id, Some back_id) fun_sigs + RegularFunIdNotLoopMap.find (A.Regular def_id, Some back_id) fun_sigs in (* We need to ignore the forward inputs, and the state input (if there is) *) let backward_inputs = @@ -325,9 +327,7 @@ let translate_module_to_pure (crate : A.crate) : (* Apply the micro-passes *) let pure_translations = - List.map - (Micro.apply_passes_to_pure_fun_translation trans_ctx) - pure_translations + Micro.apply_passes_to_pure_fun_translations trans_ctx pure_translations in (* Return *) @@ -470,16 +470,16 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (** Utility. - Export a group of functions. See [export_functions_group]. + Export a group of functions, used by {!export_functions_group}. We need this because for every function in Rust we may generate several functions in the translation (a forward function, several backward functions, loop functions, etc.). Those functions might call each other in different - ways (in particular, they may be mutually recursive, in which case we might + ways. In particular, they may be mutually recursive, in which case we might be able to group them into several groups of mutually recursive definitions, - etc.). For this reason, [export_functions_group] computes the dependency + etc. For this reason, {!export_functions_group} computes the dependency graph of the functions as well as their strongly connected components, and - gives each SCC at a time to [export_functions]. + gives each SCC at a time to {!export_functions_group_scc}. Rem.: this function only extracts the function *declarations*. It doesn't extract the decrease clauses, nor does it extract the unit tests. @@ -487,7 +487,7 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) Rem.: this function doesn't check [config.extract_fun_decls]: it should have been checked by the caller. *) -let export_functions_declarations (fmt : Format.formatter) (config : gen_config) +let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = @@ -588,7 +588,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Extract the subgroups *) let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit = - export_functions_declarations fmt config ctx is_rec decls + export_functions_group_scc fmt config ctx is_rec decls in List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups); |