1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
|
open Graph
open Collections
open SCC
open Pure
(** The local logger for this module: the debug traces emitted below (ids,
    SCCs, sanity checks) all go through [Logging.reorder_decls_log]. *)
let log = Logging.reorder_decls_log
(** Key identifying a pure function declaration to reorder.

    - [def_id]: the id of the original function declaration;
    - [lp_id]: the optional loop id (filled from the [loop_id] field of a
      [fun_decl]);
    - [rg_id]: the optional region group id (filled from the [back_id] field
      of a [fun_decl]).

    The comparison generated by [@@deriving ord] compares the fields in
    declaration order, and it is the ordering used by [FunIdMap] and
    [FunIdSet]: do not reorder the fields without checking that this is
    acceptable. *)
type fun_id = {
def_id : FunDeclId.id;
lp_id : LoopId.id option;
rg_id : T.RegionGroupId.id option;
}
[@@deriving show, ord]
(** Ordered-type instance for [fun_id], suitable for the map and set functors
    of [Collections]. Every operation simply forwards to the functions
    generated by [@@deriving show, ord] on [fun_id]. *)
module FunIdOrderedType : OrderedType with type t = fun_id = struct
  type t = fun_id

  (* Total order on ids: the derived structural comparison *)
  let compare a b = compare_fun_id a b

  (* Printing helpers, all based on the derived pretty-printer *)
  let to_string x = show_fun_id x
  let pp_t fmt x = pp_fun_id fmt x
  let show_t x = show_fun_id x
end
(* Maps and sets whose keys/elements are [fun_id]s, ordered by
   [FunIdOrderedType.compare]. *)
module FunIdMap = Collections.MakeMap (FunIdOrderedType)
module FunIdSet = Collections.MakeSet (FunIdOrderedType)
(** Collect the ids of the *custom* functions (i.e., not assumed ones) that
    are referenced in the expression [e], typically a function body.

    Only [FromLlbc] function references are collected:
    - regular functions are recorded under their declaration id;
    - trait method references are recorded under the [FunDeclId.id] carried
      by the [TraitMethod] reference.
    Operations (unary/binary), globals, ADT constructors, projections, trait
    constants, assumed functions and [Pure] functions are ignored.

    The loop id and region group id of each reference are kept in the
    resulting [fun_id]. *)
let compute_body_fun_deps (e : texpression) : FunIdSet.t =
  let collected = ref FunIdSet.empty in
  let register fid lp_id rg_id =
    collected := FunIdSet.add { def_id = fid; lp_id; rg_id } !collected
  in
  let collector =
    object
      inherit [_] iter_expression

      method! visit_qualif _ q =
        match q.id with
        | FunOrOp
            (Fun
              (FromLlbc
                ((FunId (FRegular fid) | TraitMethod (_, _, fid)), lp_id, rg_id)))
          ->
            register fid lp_id rg_id
        | FunOrOp (Fun (FromLlbc (FunId (FAssumed _), _, _)))
        | FunOrOp (Fun (Pure _))
        | FunOrOp (Unop _ | Binop _)
        | Global _ | AdtCons _ | Proj _ | TraitConst _ ->
            ()
    end
  in
  collector#visit_texpression () e;
  !collected
(** A group of function declarations.

    NOTE(review): [group_reorder_fun_decls] below returns plain
    [(bool * fun_decl list)] pairs carrying the same information, rather
    than values of this record type. *)
type function_group = {
is_rec : bool;
(** [true] if (mutually) recursive. Useful only if there is exactly one
declaration in the group.
*)
decls : fun_decl list;
(** The declarations of the group. *)
}
(** Group mutually recursive functions together and reorder the groups so
    that if a group B depends on a group A then A comes before B, while
    trying to respect the original order of [decls] as much as possible.

    Dependencies are computed from the function bodies with
    [compute_body_fun_deps]; references to functions which are not in
    [decls] are ignored.

    Returns one [(is_rec, group)] pair per strongly connected component.
    [is_rec] is always [true] for groups with several declarations; for a
    singleton group it tells whether the declaration refers to itself.

    @raise Assert_failure (when assertions are enabled) if the computed SCCs
    do not partition the declarations (internal sanity check). *)
let group_reorder_fun_decls (decls : fun_decl list) :
    (bool * fun_decl list) list =
  let module IntMap = MakeMap (OrderedInt) in
  let get_fun_id (decl : fun_decl) : fun_id =
    { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id }
  in
  (* Compute the list/set of identifiers *)
  let idl = List.map get_fun_id decls in
  let ids = FunIdSet.of_list idl in
  log#ldebug
    (lazy
      ("group_reorder_fun_decls: ids:\n"
      ^ (Print.list_to_string FunIdOrderedType.show_t) idl));
  let show_sccs (sccs : fun_id list list) : string =
    Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs
  in
  (* Sanity check: the SCCs must be pairwise disjoint and must contain
     exactly the ids of the declarations to reorder *)
  let check_sccs (sccs : fun_id list list) : unit =
    assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs));
    let scc_ids = FunIdSet.of_list (List.concat sccs) in
    log#ldebug
      (lazy
        ("group_reorder_fun_decls: sanity check:" ^ "\n- ids : "
       ^ FunIdSet.show ids ^ "\n- scc_ids: " ^ FunIdSet.show scc_ids));
    assert (FunIdSet.equal scc_ids ids)
  in
  (* Explore the bodies to compute the dependencies - we ignore the ids
     which refer to declarations not in the group we want to reorder *)
  let deps : (fun_id * FunIdSet.t) list =
    List.map
      (fun decl ->
        let id = get_fun_id decl in
        match decl.body with
        | None -> (id, FunIdSet.empty)
        | Some body ->
            (* Restrict the dependencies to the declarations being reordered *)
            (id, FunIdSet.inter (compute_body_fun_deps body.body) ids))
      decls
  in
  (* Create the dependency graph.

     Convert the ids to vertices (i.e., injectively map ids to integers, and
     create vertices labeled with those integers).
     Rem.: [Graph.V.create] is *imperative*: it generates a new vertex every
     time it is called (!!), hence each id must be converted exactly once. *)
  let module Graph = Pack.Digraph in
  let id_to_vertex : Graph.V.t FunIdMap.t =
    FunIdMap.of_list (List.mapi (fun lbl id -> (id, Graph.V.create lbl)) idl)
  in
  let vertex_to_id : fun_id IntMap.t =
    IntMap.of_list
      (List.map
         (fun (fid, v) -> (Graph.V.label v, fid))
         (FunIdMap.bindings id_to_vertex))
  in
  let to_v id = FunIdMap.find id id_to_vertex in
  let to_id v = IntMap.find (Graph.V.label v) vertex_to_id in
  let g = Graph.create () in
  (* Add the edges: first a self-edge on every vertex (this also inserts the
     declarations without dependencies in the graph), then one edge per
     dependency *)
  List.iter
    (fun (fun_id, fun_deps) ->
      let v = to_v fun_id in
      Graph.add_edge g v v;
      FunIdSet.iter (fun dep_id -> Graph.add_edge g v (to_v dep_id)) fun_deps)
    deps;
  (* Compute the SCCs and convert the vertices back to ids *)
  let module Comp = Components.Make (Graph) in
  let sccs = List.map (List.map to_id) (Comp.scc_list g) in
  log#ldebug (lazy ("group_reorder_fun_decls: SCCs:\n" ^ show_sccs sccs));
  check_sccs sccs;
  (* Reorder the SCCs, using the original declaration order as a guide *)
  let module Reorder = SCC.Make (FunIdOrderedType) in
  let id_deps =
    FunIdMap.of_list
      (List.map (fun (fid, fun_deps) -> (fid, FunIdSet.elements fun_deps)) deps)
  in
  let reordered = Reorder.reorder_sccs id_deps idl sccs in
  let reordered_sccs = List.map snd (SccId.Map.bindings reordered.sccs) in
  (* Log the SCCs only once they have actually been reordered *)
  log#ldebug
    (lazy
      ("group_reorder_fun_decls: reordered SCCs:\n" ^ show_sccs reordered_sccs));
  check_sccs reordered_sccs;
  (* Group the declarations *)
  let deps_map = FunIdMap.of_list deps in
  let decl_map =
    FunIdMap.of_list (List.map (fun d -> (get_fun_id d, d)) decls)
  in
  List.map
    (fun scc ->
      (* [is_rec] is meaningful only for singleton groups: a component with
         several declarations is mutually recursive by definition *)
      let is_rec =
        match scc with
        | [] -> failwith "Unreachable"
        | [ id ] -> FunIdSet.mem id (FunIdMap.find id deps_map)
        | _ -> true
      in
      (is_rec, List.map (fun id -> FunIdMap.find id decl_map) scc))
    reordered_sccs
|