open Graph
open Collections
open SCC
open Pure

(** The local logger *)
let log = Logging.reorder_decls_log

type fun_id = {
  def_id : FunDeclId.id;
  lp_id : LoopId.id option;
  rg_id : T.RegionGroupId.id option;
}
[@@deriving show, ord]

module FunIdOrderedType : OrderedType with type t = fun_id = struct
  type t = fun_id

  let compare = compare_fun_id
  let to_string = show_fun_id
  let pp_t = pp_fun_id
  let show_t = show_fun_id
end

module FunIdMap = Collections.MakeMap (FunIdOrderedType)
module FunIdSet = Collections.MakeSet (FunIdOrderedType)

(** Compute the dependencies of a function body, taking into account only the
    *custom* (i.e., not assumed) function ids (ignoring operations, types,
    globals, etc.).

    Calls to trait methods are recorded under the id of the method's
    declaration. Each recorded id carries the loop id and region group id of
    the call site. *)
let compute_body_fun_deps (e : texpression) : FunIdSet.t =
  let ids = ref FunIdSet.empty in
  let visitor =
    object
      inherit [_] iter_expression

      method! visit_qualif _ id =
        match id.id with
        | FunOrOp (Unop _ | Binop _)
        | Global _ | AdtCons _ | Proj _ | TraitConst _ ->
            ()
        | FunOrOp (Fun fid) -> (
            match fid with
            | Pure _ -> ()
            | FromLlbc (fid, lp_id, rg_id) -> (
                match fid with
                | FunId (Assumed _) -> ()
                | TraitMethod (_, _, fid) | FunId (Regular fid) ->
                    let id = { def_id = fid; lp_id; rg_id } in
                    ids := FunIdSet.add id !ids))
    end
  in
  visitor#visit_texpression () e;
  !ids

type function_group = {
  is_rec : bool;
      (** [true] if (mutually) recursive.
          Useful only if there is exactly one declaration in the group. *)
  decls : fun_decl list;
}

(** Group mutually recursive functions together and reorder the groups so that
    if a group B depends on a group A then A comes before B, while trying to
    respect the original order as much as possible.

    Only dependencies between declarations of [decls] are taken into account:
    calls to functions outside of [decls] are ignored.

    Returns one [(is_rec, decls)] pair per group. [is_rec] is [true] for every
    group containing several declarations. For a singleton group, it is [true]
    only if the declaration's body refers to the declaration itself (same
    definition id, loop id and region group id).

    Fails with an assertion error if the internal sanity checks on the
    computed SCCs do not hold. *)
let group_reorder_fun_decls (decls : fun_decl list) :
    (bool * fun_decl list) list =
  let module IntMap = MakeMap (OrderedInt) in
  let get_fun_id (decl : fun_decl) : fun_id =
    { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id }
  in
  (* Compute the list/set of identifiers *)
  let idl = List.map get_fun_id decls in
  let ids = FunIdSet.of_list idl in
  log#ldebug
    (lazy
      ("group_reorder_fun_decls: ids:\n"
      ^ (Print.list_to_string FunIdOrderedType.show_t) idl));
  (* Explore the bodies to compute the dependencies - we ignore the ids
     which refer to declarations not in the group we want to reorder *)
  let deps : (fun_id * FunIdSet.t) list =
    List.map
      (fun decl ->
        let id = get_fun_id decl in
        match decl.body with
        | None -> (id, FunIdSet.empty)
        | Some body ->
            let deps = compute_body_fun_deps body.body in
            (* Restrict the set of dependencies *)
            let deps = FunIdSet.inter deps ids in
            (id, deps))
      decls
  in
  (* The set of ids appearing in a list of SCCs *)
  let scc_ids_of (sccs : fun_id list list) : FunIdSet.t =
    FunIdSet.of_list (List.concat sccs)
  in
  (* Sanity check, used both before and after reordering: the SCCs must be
     pairwise disjoint and must contain exactly the declarations' ids. *)
  let check_sccs (sccs : fun_id list list) : unit =
    assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs));
    assert (FunIdSet.equal (scc_ids_of sccs) ids)
  in
  (*
   * Create the dependency graph
   *)
  (* Convert the ids to vertices (i.e., injectively map ids to integers, and
     create vertices labeled with those integers).

     Rem.: [Graph.V.create] is *imperative*: it generates a new vertex every
     time it is called (!!). We call it exactly once per id, labeling the
     vertex with the id's position in [idl].
  *)
  let module Graph = Pack.Digraph in
  let id_to_vertex : Graph.V.t FunIdMap.t =
    FunIdMap.of_list (List.mapi (fun lbl id -> (id, Graph.V.create lbl)) idl)
  in
  let vertex_to_id : fun_id IntMap.t =
    IntMap.of_list
      (List.map
         (fun (fid, v) -> (Graph.V.label v, fid))
         (FunIdMap.bindings id_to_vertex))
  in
  let to_v id = FunIdMap.find id id_to_vertex in
  let to_id v = IntMap.find (Graph.V.label v) vertex_to_id in
  let g = Graph.create () in
  (* Add the edges, first from the vertices to themselves, then between
     vertices. Adding an edge also adds its endpoints to the graph, so the
     self-loops make sure every declaration shows up in some SCC, even when
     it has no dependencies. *)
  List.iter
    (fun (fun_id, deps) ->
      let v = to_v fun_id in
      Graph.add_edge g v v;
      FunIdSet.iter (fun dep_id -> Graph.add_edge g v (to_v dep_id)) deps)
    deps;
  (* Compute the SCCs *)
  let module Comp = Components.Make (Graph) in
  let sccs = Comp.scc_list g in
  (* Convert the vertices to ids *)
  let sccs = List.map (List.map to_id) sccs in
  log#ldebug
    (lazy
      ("group_reorder_fun_decls: SCCs:\n"
      ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs
      ));
  (* Sanity check *)
  log#ldebug
    (lazy
      ("group_reorder_fun_decls: sanity check:" ^ "\n- ids : "
      ^ FunIdSet.show ids ^ "\n- scc_ids: "
      ^ FunIdSet.show (scc_ids_of sccs)));
  check_sccs sccs;
  (* Reorder *)
  let module Reorder = SCC.Make (FunIdOrderedType) in
  let id_deps =
    FunIdMap.of_list
      (List.map (fun (fid, deps) -> (fid, FunIdSet.elements deps)) deps)
  in
  let sccs = Reorder.reorder_sccs id_deps idl sccs in
  (* The reordered groups, in the order in which they must be emitted *)
  let reordered = List.map snd (SccId.Map.bindings sccs.sccs) in
  (* Log the SCCs *after* reordering (logging them before the reordering
     would only repeat the "SCCs" message above) *)
  log#ldebug
    (lazy
      ("group_reorder_fun_decls: reordered SCCs:\n"
      ^ Print.list_to_string
          (Print.list_to_string FunIdOrderedType.show_t)
          reordered));
  (* Sanity check *)
  check_sccs reordered;
  (* Group the declarations *)
  let deps = FunIdMap.of_list deps in
  let decls = FunIdMap.of_list (List.map (fun d -> (get_fun_id d, d)) decls) in
  List.map
    (fun group ->
      (* is_rec is useful only if there is exactly one declaration *)
      let is_rec =
        match group with
        | [] -> failwith "Unreachable"
        | [ id ] ->
            let dep_ids = FunIdMap.find id deps in
            FunIdSet.mem id dep_ids
        | _ -> true
      in
      let group_decls = List.map (fun id -> FunIdMap.find id decls) group in
      (is_rec, group_decls))
    reordered