diff options
-rw-r--r-- | compiler/Logging.ml | 6 | ||||
-rw-r--r-- | compiler/Print.ml | 3 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 49 | ||||
-rw-r--r-- | compiler/ReorderDecls.ml | 86 | ||||
-rw-r--r-- | compiler/SCC.ml | 12 | ||||
-rw-r--r-- | compiler/Translate.ml | 152 |
6 files changed, 183 insertions, 125 deletions
diff --git a/compiler/Logging.ml b/compiler/Logging.ml index 8cfc25d3..71f72471 100644 --- a/compiler/Logging.ml +++ b/compiler/Logging.ml @@ -47,3 +47,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows" (** Logger for Invariants *) let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants" + +(** Logger for SCC *) +let scc_log = L.get_logger "MainLogger.SCC" + +(** Logger for ReorderDecls *) +let reorder_decls_log = L.get_logger "MainLogger.ReorderDecls" diff --git a/compiler/Print.ml b/compiler/Print.ml index a4a8b1d4..82c5dac5 100644 --- a/compiler/Print.ml +++ b/compiler/Print.ml @@ -7,6 +7,9 @@ module PrimitiveValues = Charon.PrintPrimitiveValues module Types = Charon.PrintTypes module Expressions = Charon.PrintExpressions +let list_to_string (to_string : 'a -> string) (ls : 'a list) : string = + "[" ^ String.concat "; " (List.map to_string ls) ^ "]" + (** Pretty-printing for values *) module Values = struct type value_formatter = { diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 0f1d50f1..da15d635 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -110,55 +110,6 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : let info = sg.info in { inputs; output; doutputs; info } -(** Return true if a list of functions are *not* mutually recursive, false otherwise. - This function is meant to be applied on a set of (forward, backwards) functions - generated for one recursive function. - The way we do the test is very simple: - - we explore the functions one by one, in the order in which they are provided - - if all functions only call functions we already explored, they are not - mutually recursive - *) -let functions_not_mutually_recursive (funs : fun_decl list) : bool = - (* Compute the set of function identifiers in the group *) - let ids = - FunOrOpIdSet.of_list - (List.map - (fun (f : fun_decl) -> Fun (FromLlbc (A.Regular f.def_id, f.back_id))) - funs) - in - let ids = ref ids in - (* Explore every body *) - let body_only_calls_itself (fdef : fun_decl) : bool = - (* Remove the current id from the id set *) - ids := - FunOrOpIdSet.remove - (Fun (FromLlbc (A.Regular fdef.def_id, fdef.back_id))) - !ids; - - (* Check if we call functions from the updated id set *) - let obj = - object - inherit [_] iter_expression as super - - method! visit_qualif env qualif = - match qualif.id with - | FunOrOp fun_id -> - if FunOrOpIdSet.mem fun_id !ids then raise Utils.Found - else super#visit_qualif env qualif - | _ -> super#visit_qualif env qualif - end - in - - try - match fdef.body with - | None -> true - | Some body -> - obj#visit_texpression () body.body; - true - with Utils.Found -> false - in - List.for_all body_only_calls_itself funs - (** We use this to check whether we need to add parentheses around expressions. We only look for outer monadic let-bindings. This is used when printing the branches of [if ... then ... else ...]. diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index 9d222011..2b78c570 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -3,6 +3,9 @@ open Collections open SCC open Pure +(** The local logger *) +let log = Logging.reorder_decls_log + type fun_id = { def_id : FunDeclId.id; rg_id : T.RegionGroupId.id option } [@@deriving show, ord] @@ -69,6 +72,11 @@ let group_reorder_fun_decls (decls : fun_decl list) : let idl = List.map get_fun_id decls in let ids = FunIdSet.of_list idl in + log#ldebug + (lazy + ("group_reorder_fun_decls: ids:\n" + ^ (Print.list_to_string FunIdOrderedType.show_t) idl)); + (* Explore the bodies to compute the dependencies - we ignore the ids which refer to declarations not in the group we want to reorder *) let deps : (fun_id * FunIdSet.t) list = @@ -79,50 +87,88 @@ let group_reorder_fun_decls (decls : fun_decl list) : | None -> (id, FunIdSet.empty) | Some body -> let deps = compute_body_fun_deps body.body in - (id, FunIdSet.inter deps ids)) + (* Restrict the set dependencies *) + let deps = FunIdSet.inter deps ids in + (id, deps)) decls in (* * Create the dependency graph *) - (* Convert the ids to vertices (i.e., injectively map ids to integers) *) - let id_to_vertex : int FunIdMap.t = + (* Convert the ids to vertices (i.e., injectively map ids to integers, and create + vertices labeled with those integers). + + Rem.: [Graph.create] is *imperative*: it generates a new vertex every time + it is called (!!). + *) + let module Graph = Pack.Digraph in + let id_to_vertex : Graph.V.t FunIdMap.t = let cnt = ref 0 in FunIdMap.of_list (List.map (fun id -> - let v = !cnt in + let lbl = !cnt in cnt := !cnt + 1; + (* We create a vertex *) + let v = Graph.V.create lbl in (id, v)) idl) in let vertex_to_id : fun_id IntMap.t = IntMap.of_list - (List.map (fun (fid, vid) -> (vid, fid)) (FunIdMap.bindings id_to_vertex)) + (List.map + (fun (fid, v) -> (Graph.V.label v, fid)) + (FunIdMap.bindings id_to_vertex)) in - let to_v id = Pack.Graph.V.create (FunIdMap.find id id_to_vertex) in - let to_id v = IntMap.find (Pack.Graph.V.label v) vertex_to_id in - let g = Pack.Graph.create () in - (* First add the vertices *) - List.iter (fun id -> Pack.Graph.add_vertex g (to_v id)) idl; + let to_v id = FunIdMap.find id id_to_vertex in + let to_id v = IntMap.find (Graph.V.label v) vertex_to_id in - (* Then add the edges *) + let g = Graph.create () in + + (* Add the edges, first from the vertices to themselves, then between vertices *) List.iter (fun (fun_id, deps) -> - FunIdSet.iter - (fun dep_id -> Pack.Graph.add_edge g (to_v fun_id) (to_v dep_id)) - deps) + let v = to_v fun_id in + Graph.add_edge g v v; + FunIdSet.iter (fun dep_id -> Graph.add_edge g v (to_v dep_id)) deps) deps; (* Compute the SCCs *) - let module Comp = Components.Make (Pack.Graph) in + let module Comp = Components.Make (Graph) in let sccs = Comp.scc_list g in (* Convert the vertices to ids *) let sccs = List.map (List.map to_id) sccs in + log#ldebug + (lazy + ("group_reorder_fun_decls: SCCs:\n" + ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs + )); + + (* Sanity check *) + let _ = + (* Check that the SCCs are pairwise disjoint *) + assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs)); + (* Check that all the ids are in the sccs *) + let scc_ids = FunIdSet.of_list (List.concat sccs) in + + log#ldebug + (lazy + ("group_reorder_fun_decls: sanity check:" ^ "\n- ids : " + ^ FunIdSet.show ids ^ "\n- scc_ids: " ^ FunIdSet.show scc_ids)); + + assert (FunIdSet.equal scc_ids ids) + in + + log#ldebug + (lazy + ("group_reorder_fun_decls: reordered SCCs:\n" + ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs + )); + (* Reorder *) let module Reorder = SCC.Make (FunIdOrderedType) in let id_deps = @@ -131,6 +177,16 @@ let group_reorder_fun_decls (decls : fun_decl list) : in let sccs = Reorder.reorder_sccs id_deps idl sccs in + (* Sanity check *) + let _ = + (* Check that the SCCs are pairwise disjoint *) + let sccs = List.map snd (SccId.Map.bindings sccs.sccs) in + assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs)); + (* Check that all the ids are in the sccs *) + let scc_ids = FunIdSet.of_list (List.concat sccs) in + assert (scc_ids = ids) + in + (* Group the declarations *) let deps = FunIdMap.of_list deps in let decls = FunIdMap.of_list (List.map (fun d -> (get_fun_id d, d)) decls) in diff --git a/compiler/SCC.ml b/compiler/SCC.ml index 8b4cdb1f..889a972b 100644 --- a/compiler/SCC.ml +++ b/compiler/SCC.ml @@ -3,6 +3,9 @@ open Collections module SccId = Identifiers.IdGen () +(** The local logger *) +let log = Logging.scc_log + (** A functor which provides functions to work on strongly connected components *) module Make (Id : OrderedType) = struct module IdMap = MakeMap (Id) @@ -91,6 +94,15 @@ module Make (Id : OrderedType) = struct *) let reorder_sccs (id_deps : Id.t list IdMap.t) (ids : Id.t list) (sccs : Id.t list list) : sccs = + log#ldebug + (lazy + ("reorder_sccs:" ^ "\n- id_deps: " + ^ IdMap.show (Print.list_to_string Id.show_t) id_deps + ^ "\n- ids: " + ^ Print.list_to_string Id.show_t ids + ^ "\n- sccs: " + ^ Print.list_to_string (Print.list_to_string Id.show_t) sccs)); + (* Map the identifiers to the SCC indices *) let id_to_scc = IdMap.of_list diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 4f31b738..87862d6b 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -383,8 +383,8 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) [is_rec]: [true] if the types are recursive. Necessarily [true] if there is > 1 type to extract. *) -let export_types (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) - (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit = +let export_types_group (fmt : Format.formatter) (config : gen_config) + (ctx : gen_ctx) (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit = let export_type = export_type fmt config ctx in let export_type_decl kind id = export_type kind id true false in let export_type_extra_info kind id = export_type kind id false true in @@ -432,30 +432,79 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) then Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface +(** Utility. + + Export a group of functions. See [export_functions_group]. + + We need this because for every function in Rust we may generate several functions + in the translation (a forward function, several backward functions, loop + functions, etc.). Those functions might call each other in different + ways (in particular, they may be mutually recursive, in which case we might + be able to group them into several groups of mutually recursive definitions, + etc.). For this reason, [export_functions_group] computes the dependency + graph of the functions as well as their strongly connected components, and + gives each SCC at a time to [export_functions]. + + Rem.: this function only extracts the function *declarations*. It doesn't + extract the decrease clauses, nor does it extract the unit tests. + + Rem.: this function doesn't check [config.extract_fun_decls]: it should have + been checked by the caller. + *) +let export_functions_declarations (fmt : Format.formatter) (config : gen_config) + (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit = + (* Utility to check a function has a decrease clause *) + let has_decreases_clause (def : Pure.fun_decl) : bool = + A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + in + + (* Extract the function declarations *) + (* Check if the functions are mutually recursive - this really works + * to check if the forward and backward translations of a single + * recursive function are mutually recursive *) + let is_mut_rec = List.length decls > 1 in + assert ((not is_mut_rec) || is_rec); + let decls_length = List.length decls in + List.iteri + (fun i def -> + let is_opaque = Option.is_none def.Pure.body in + let kind = + if is_opaque then + if config.interface then ExtractBase.Declared else ExtractBase.Assumed + else if not is_rec then ExtractBase.SingleNonRec + else if is_mut_rec then + (* If the functions are mutually recursive, we need to distinguish: + * - the first of the group + * - the last of the group + * - the inner functions + *) + if i = 0 then ExtractBase.MutRecFirst + else if i = decls_length - 1 then ExtractBase.MutRecLast + else ExtractBase.MutRecInner + else ExtractBase.SingleRec + in + let has_decr_clause = + has_decreases_clause def && config.extract_decreases_clauses + in + (* Check if the definition needs to be filtered or not *) + if + ((not is_opaque) && config.extract_transparent) + || (is_opaque && config.extract_opaque) + then Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def) + decls + (** Export a group of function declarations. In case of (non-mutually) recursive functions, we use a simple procedure to check if the forward and backward functions are mutually recursive. *) -let export_functions (fmt : Format.formatter) (config : gen_config) - (ctx : gen_ctx) (is_rec : bool) - (pure_ls : (bool * pure_fun_translation) list) : unit = +let export_functions_group (fmt : Format.formatter) (config : gen_config) + (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause in - (* Concatenate the function definitions, filtering the useless forward - * functions. We also make pairs: (forward function, backward function) - * (the forward function contains useful information that we want to keep) *) - let fls = - List.concat - (List.map - (fun (keep_fwd, (fwd, back_ls)) -> - let back_ls = List.map (fun back -> (fwd, back)) back_ls in - if keep_fwd then (fwd, fwd) :: back_ls else back_ls) - pure_ls) - in (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter @@ -464,48 +513,28 @@ let export_functions (fmt : Format.formatter) (config : gen_config) if has_decr_clause then Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd) pure_ls; + + (* Concatenate the function definitions, filtering the useless forward + * functions. *) + let decls = + List.concat + (List.map + (fun (keep_fwd, (fwd, back_ls)) -> + if keep_fwd then fwd :: back_ls else back_ls) + pure_ls) + in + (* Extract the function definitions *) (if config.extract_fun_decls then - (* Check if the functions are mutually recursive - this really works - * to check if the forward and backward translations of a single - * recursive function are mutually recursive *) - let is_mut_rec = - if is_rec then - if List.length pure_ls <= 1 then - not (PureUtils.functions_not_mutually_recursive (List.map fst fls)) - else true - else false + (* Group the mutually recursive definitions *) + let subgroups = ReorderDecls.group_reorder_fun_decls decls in + + (* Extract the subgroups *) + let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit = + export_functions_declarations fmt config ctx is_rec decls in - let fls_length = List.length fls in - List.iteri - (fun i (fwd_def, def) -> - let is_opaque = Option.is_none fwd_def.Pure.body in - let kind = - if is_opaque then - if config.interface then ExtractBase.Declared - else ExtractBase.Assumed - else if not is_rec then ExtractBase.SingleNonRec - else if is_mut_rec then - (* If the functions are mutually recursive, we need to distinguish: - * - the first of the group - * - the last of the group - * - the inner functions - *) - if i = 0 then ExtractBase.MutRecFirst - else if i = fls_length - 1 then ExtractBase.MutRecLast - else ExtractBase.MutRecInner - else ExtractBase.SingleRec - in - let has_decr_clause = - has_decreases_clause def && config.extract_decreases_clauses - in - (* Check if the definition needs to be filtered or not *) - if - ((not is_opaque) && config.extract_transparent) - || (is_opaque && config.extract_opaque) - then - Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def) - fls); + List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups); + (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter @@ -525,9 +554,9 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) - [extract_extra_info]: extra the extra type information (e.g., the [Arguments] information in Coq). *) - let export_functions = export_functions fmt config ctx in + let export_functions_group = export_functions_group fmt config ctx in let export_global = export_global fmt config ctx in - let export_types = export_types fmt config ctx in + let export_types_group = export_types_group fmt config ctx in let export_state_type () : unit = let kind = @@ -538,13 +567,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) let export_decl_group (dg : A.declaration_group) : unit = match dg with - | Type (NonRec id) -> if config.extract_types then export_types false [ id ] - | Type (Rec ids) -> if config.extract_types then export_types true ids + | Type (NonRec id) -> + if config.extract_types then export_types_group false [ id ] + | Type (Rec ids) -> if config.extract_types then export_types_group true ids | Fun (NonRec id) -> (* Lookup *) let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in (* Translate *) - export_functions false [ pure_fun ] + export_functions_group [ pure_fun ] | Fun (Rec ids) -> (* General case of mutually recursive functions *) (* Lookup *) @@ -552,7 +582,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids in (* Translate *) - export_functions true pure_funs + export_functions_group pure_funs | Global id -> export_global id in |