summaryrefslogtreecommitdiff
path: root/src/Collections.ml
blob: 0933b3e4aea7bd11c1443e7da0d861f1cbdb7a76 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
(** This file redefines several standard-library modules (such as [List],
    [Map] and [Set]) to extend them with extra helpers. *)

module F = Format

module List = struct
  include List

  (** Split a list at a given index.

      [split_at ls i] splits [ls] into two lists [(l1, l2)] such that
      [l1 @ l2 = ls] and [l1] has length [i].

      The implementation is tail-recursive, so it can be used on long lists.

      Raise [Invalid_argument] if [i] is negative.
      Raise [Failure] if [ls] has fewer than [i] elements.
  *)
  let split_at (ls : 'a list) (i : int) : 'a list * 'a list =
    if i < 0 then raise (Invalid_argument "split_at take positive integers")
    else
      (* [rev_prefix] holds the elements taken so far in reverse order, which
         keeps the loop tail-recursive; it is reversed once at the end. *)
      let rec take rev_prefix ls i =
        if i = 0 then (rev rev_prefix, ls)
        else
          match ls with
          | [] ->
              raise
                (Failure
                   "The int given to split_at should be <= the list's length")
          | x :: ls' -> take (x :: rev_prefix) ls' (i - 1)
      in
      take [] ls i

  (** Pop the last element of a list.

      [pop_last ls] returns [(init, last)] where [init @ [last] = ls].
      Tail-recursive (it works on the reversed list).

      Raise [Failure] if the list is empty.
   *)
  let pop_last (ls : 'a list) : 'a list * 'a =
    match rev ls with
    | [] -> raise (Failure "The list is empty")
    | last :: rev_init -> (rev rev_init, last)

  (** Return the [n] first elements of the list.

      Raise [Failure] if the list has fewer than [n] elements and
      [Invalid_argument] if [n] is negative (see {!split_at}).
   *)
  let prefix (n : int) (ls : 'a list) : 'a list = fst (split_at ls n)

  (** Iter and link the iterations.

      Iterate over a list, but call a function between every two elements
      (but not before the first element, and not after the last).
   *)
  let iter_link (link : unit -> unit) (f : 'a -> unit) (ls : 'a list) : unit =
    let rec iter ls =
      match ls with
      | [] -> ()
      | [ x ] -> f x
      | x :: y :: ls ->
          f x;
          link ();
          iter (y :: ls)
    in
    iter ls

  (** Fold and link the iterations.

      Similar to {!iter_link} but for fold left operations: [link] is called
      between every two applications of [f].
   *)
  let fold_left_link (link : unit -> unit) (f : 'a -> 'b -> 'a) (init : 'a)
      (ls : 'b list) : 'a =
    let rec fold (acc : 'a) (ls : 'b list) : 'a =
      match ls with
      | [] -> acc
      | [ x ] -> f acc x
      | x :: y :: ls ->
          let acc = f acc x in
          link ();
          fold acc (y :: ls)
    in
    fold init ls

  (** Return the unique element of a singleton list.

      Raise [Failure] if the list does not have exactly one element.
   *)
  let to_cons_nil (ls : 'a list) : 'a =
    match ls with
    | [ x ] -> x
    | _ -> raise (Failure "The list should have length exactly one")

  (** Split a non-empty list into its head and its tail.

      Raise [Failure] if the list is empty.
   *)
  let pop (ls : 'a list) : 'a * 'a list =
    match ls with
    | x :: ls' -> (x, ls')
    | [] -> raise (Failure "The list should have length > 0")
end

(** A totally ordered type (usable with [Map.Make] / [Set.Make]) which also
    knows how to print its values. *)
module type OrderedType = sig
  type t

  (** Total ordering: negative, zero or positive when the first argument is
      respectively smaller than, equal to or greater than the second. *)
  val compare : t -> t -> int

  (** Plain string rendering of a value. *)
  val to_string : t -> string

  (** [Format]-based printer. *)
  val pp_t : Format.formatter -> t -> unit

  (** String rendering used by derived [show] functions. *)
  val show_t : t -> string
end

(** Strings ordered with [String.compare], printed verbatim. *)
module OrderedString : OrderedType with type t = string = struct
  type t = string

  let compare = String.compare
  let to_string (s : t) : string = s
  let show_t = to_string

  let pp_t (fmt : Format.formatter) (s : t) : unit =
    Format.pp_print_string fmt s
end

(** Signature of the maps built by {!MakeMap}: the standard [Map.S] extended
    with helpers to build maps from association lists and to print them. *)
module type Map = sig
  include Map.S

  (** [add_list bl m] adds the bindings of [bl] to [m], from left to right. *)
  val add_list : (key * 'a) list -> 'a t -> 'a t

  (** [of_list bl] builds a map from the association list [bl]. *)
  val of_list : (key * 'a) list -> 'a t

  (** Lightweight pretty-printing to a string.

      Handy when [show] needs a small tweak but writing a full [pp]-style
      formatter would be overkill.

      [to_string None a_to_string m] renders all the bindings on one line,
      while [to_string (Some indent) a_to_string m] puts each binding on its
      own line, prefixed with [indent].
   *)
  val to_string : string option -> ('a -> string) -> 'a t -> string

  (** [Format]-based printer, parameterized by a printer for the values. *)
  val pp : (Format.formatter -> 'a -> unit) -> Format.formatter -> 'a t -> unit

  (** String printer, parameterized by a printer for the values. *)
  val show : ('a -> string) -> 'a t -> string
end

module MakeMap (Ord : OrderedType) : Map with type key = Ord.t = struct
  module Map = Map.Make (Ord)
  include Map

  (* [add_seq] inserts the bindings from left to right, so a later binding of
     a key overwrites an earlier one, exactly like folding [add] over [bl]. *)
  let add_list bl m = add_seq (List.to_seq bl) m
  let of_list bl = add_list bl empty

  (* Bindings are rendered in increasing key order as "key -> value", and
     separated by a comma followed by either a space or (when an indentation
     string is given) a newline. *)
  let to_string indent_opt a_to_string m =
    let indent, break =
      match indent_opt with None -> ("", " ") | Some indent -> (indent, "\n")
    in
    let render (key, v) =
      indent ^ Ord.to_string key ^ " -> " ^ a_to_string v
    in
    match bindings m with
    | [] -> "{}"
    | bl ->
        let body = String.concat ("," ^ break) (List.map render bl) in
        "{" ^ break ^ body ^ break ^ "}"

  let pp (pp_a : Format.formatter -> 'a -> unit) (fmt : Format.formatter)
      (m : 'a t) : unit =
    (* Every binding is followed by a comma and a breakable space. *)
    let print_binding key x =
      Ord.pp_t fmt key;
      F.pp_print_space fmt ();
      F.pp_print_string fmt "->";
      F.pp_print_space fmt ();
      pp_a fmt x;
      F.pp_print_string fmt ",";
      F.pp_print_break fmt 1 0
    in
    F.pp_print_string fmt "{";
    F.pp_open_box fmt 2;
    iter print_binding m;
    F.pp_close_box fmt ();
    F.pp_print_break fmt 0 0;
    F.pp_print_string fmt "}"

  let show show_a = to_string None show_a
end

(** Signature of the sets built by {!MakeSet}: the standard [Set.S] extended
    with list helpers, printers and a duplicate check. *)
module type Set = sig
  include Set.S

  (** [add_list l s] adds the elements of [l] to [s]. *)
  val add_list : elt list -> t -> t

  (** [of_list l] is the set of the elements of [l]. *)
  val of_list : elt list -> t

  (** Lightweight pretty-printing to a string.

      Handy when [show] needs a small tweak but writing a full [pp]-style
      formatter would be overkill.

      [to_string None s] renders all the elements on one line, while
      [to_string (Some indent) s] puts each element on its own line, prefixed
      with [indent].
   *)
  val to_string : string option -> t -> string

  (** [Format]-based printer. *)
  val pp : Format.formatter -> t -> unit

  (** String printer. *)
  val show : t -> string

  (** [pairwise_distinct l] is [true] iff no two elements of [l] are equal
      according to the ordering on [elt]. *)
  val pairwise_distinct : elt list -> bool
end

module MakeSet (Ord : OrderedType) : Set with type elt = Ord.t = struct
  module Set = Set.Make (Ord)
  include Set

  (* [add_seq] inserts the elements in order, like folding [add] over [bl]. *)
  let add_list bl s = add_seq (List.to_seq bl) s
  let of_list bl = add_list bl empty

  (* Elements are rendered in increasing order, separated by a comma followed
     by either a space or (when an indentation string is given) a newline. *)
  let to_string indent_opt m =
    let indent, break =
      match indent_opt with None -> ("", " ") | Some indent -> (indent, "\n")
    in
    match elements m with
    | [] -> "{}"
    | elts ->
        let body =
          elts
          |> List.map (fun v -> indent ^ Ord.to_string v)
          |> String.concat ("," ^ break)
        in
        "{" ^ break ^ body ^ break ^ "}"

  let pp (fmt : Format.formatter) (m : t) : unit =
    (* Every element is followed by a comma and a breakable space. *)
    let print_elt x =
      Ord.pp_t fmt x;
      F.pp_print_string fmt ",";
      F.pp_print_break fmt 1 0
    in
    F.pp_print_string fmt "{";
    F.pp_open_box fmt 2;
    iter print_elt m;
    F.pp_close_box fmt ();
    F.pp_print_break fmt 0 0;
    F.pp_print_string fmt "}"

  let show s = to_string None s

  (* Walk the list while accumulating the elements seen so far; stop at the
     first element which was already seen. *)
  let pairwise_distinct ls =
    let rec go seen = function
      | [] -> true
      | x :: rest -> (not (mem x seen)) && go (add x seen) rest
    in
    go empty ls
end

(** A map where the bindings are injective (i.e., if two keys are distinct,
    their bindings are distinct).

    This is useful for instance when generating mappings from our internal
    identifiers to names (i.e., strings) when generating code, in order to
    make sure that we don't have potentially dangerous collisions.

    Apart from the injectivity requirement, the operations mirror those of
    [Map.S], with the values fixed to the type [elem].
 *)
module type InjMap = sig
  (** Type of the keys. *)
  type key

  (** Type of the elements the keys are mapped to. *)
  type elem

  (** Type of the injective maps. *)
  type t

  (** {2 Construction and basic queries} *)

  val empty : t
  val is_empty : t -> bool
  val mem : key -> t -> bool

  (** [add k e m] binds [k] to [e]. [e] must not already be bound to another
      key, as this would break injectivity: {!MakeInjMap} checks this with an
      [assert]. *)
  val add : key -> elem -> t -> t

  val singleton : key -> elem -> t
  val remove : key -> t -> t

  (** {2 Comparison} *)

  val compare : (elem -> elem -> int) -> t -> t -> int
  val equal : (elem -> elem -> bool) -> t -> t -> bool

  (** {2 Traversal and filtering} *)

  val iter : (key -> elem -> unit) -> t -> unit
  val fold : (key -> elem -> 'b -> 'b) -> t -> 'b -> 'b
  val for_all : (key -> elem -> bool) -> t -> bool
  val exists : (key -> elem -> bool) -> t -> bool
  val filter : (key -> elem -> bool) -> t -> t
  val partition : (key -> elem -> bool) -> t -> t * t

  (** {2 Inspection} *)

  val cardinal : t -> int
  val bindings : t -> (key * elem) list
  val min_binding : t -> key * elem
  val min_binding_opt : t -> (key * elem) option
  val max_binding : t -> key * elem
  val max_binding_opt : t -> (key * elem) option
  val choose : t -> key * elem
  val choose_opt : t -> (key * elem) option
  val split : key -> t -> t * elem option * t
  val find : key -> t -> elem
  val find_opt : key -> t -> elem option
  val find_first : (key -> bool) -> t -> key * elem
  val find_first_opt : (key -> bool) -> t -> (key * elem) option
  val find_last : (key -> bool) -> t -> key * elem
  val find_last_opt : (key -> bool) -> t -> (key * elem) option

  (** {2 Conversions}

      The bulk-insertion functions are subject to the same injectivity
      requirement as {!add}, for every binding they insert. *)

  val to_seq : t -> (key * elem) Seq.t
  val to_seq_from : key -> t -> (key * elem) Seq.t
  val add_seq : (key * elem) Seq.t -> t -> t
  val of_seq : (key * elem) Seq.t -> t
  val add_list : (key * elem) list -> t -> t
  val of_list : (key * elem) list -> t
end

(** See {!InjMap} *)
module MakeInjMap (Key : OrderedType) (Elem : OrderedType) :
  InjMap with type key = Key.t with type elem = Elem.t = struct
  module Map = MakeMap (Key)
  module Set = MakeSet (Elem)

  type key = Key.t
  type elem = Elem.t

  (* Invariant: [elems] is exactly the set of elements bound in [map]. Keeping
     it alongside [map] lets [add] check injectivity without scanning the
     whole map. *)
  type t = { map : elem Map.t; elems : Set.t }

  (** Rebuild a [t] from a bare map by recomputing its set of elements. Used
      after operations (filtering, partitioning, splitting) which keep an
      arbitrary subset of the bindings. *)
  let map_to_t (map : elem Map.t) : t =
    let elems = Map.fold (fun _ e s -> Set.add e s) map Set.empty in
    { map; elems }

  let empty = { map = Map.empty; elems = Set.empty }
  let is_empty m = Map.is_empty m.map
  let mem k m = Map.mem k m.map

  let remove k m =
    match Map.find_opt k m.map with
    | None -> m
    | Some e -> { map = Map.remove k m.map; elems = Set.remove e m.elems }

  (** [add k e m] binds [k] to [e], replacing the previous binding of [k] if
      there is one.

      Raise [Assert_failure] if [e] is already bound to a key other than [k],
      as this would break injectivity. *)
  let add k e m =
    (* Unbind [k] first: its previous element must leave [elems] (no key maps
       to it anymore), and re-binding [k] to the element it is already bound
       to must not be reported as a collision. *)
    let m = remove k m in
    assert (not (Set.mem e m.elems));
    { map = Map.add k e m.map; elems = Set.add e m.elems }

  let singleton k e = { map = Map.singleton k e; elems = Set.singleton e }
  let compare f m1 m2 = Map.compare f m1.map m2.map
  let equal f m1 m2 = Map.equal f m1.map m2.map
  let iter f m = Map.iter f m.map
  let fold f m acc = Map.fold f m.map acc
  let for_all f m = Map.for_all f m.map
  let exists f m = Map.exists f m.map
  let filter f m = map_to_t (Map.filter f m.map)

  let partition f m =
    let map1, map2 = Map.partition f m.map in
    (map_to_t map1, map_to_t map2)

  let cardinal m = Map.cardinal m.map
  let bindings m = Map.bindings m.map
  let min_binding m = Map.min_binding m.map
  let min_binding_opt m = Map.min_binding_opt m.map
  let max_binding m = Map.max_binding m.map
  let max_binding_opt m = Map.max_binding_opt m.map
  let choose m = Map.choose m.map
  let choose_opt m = Map.choose_opt m.map

  let split k m =
    let l, data, r = Map.split k m.map in
    (map_to_t l, data, map_to_t r)

  let find k m = Map.find k m.map
  let find_opt k m = Map.find_opt k m.map
  let find_first f m = Map.find_first f m.map
  let find_first_opt f m = Map.find_first_opt f m.map
  let find_last f m = Map.find_last f m.map
  let find_last_opt f m = Map.find_last_opt f m.map
  let to_seq m = Map.to_seq m.map
  let to_seq_from k m = Map.to_seq_from k m.map

  (* Bindings are inserted one at a time through [add], so that the
     injectivity check is performed for each of them. *)
  let add_seq s m = Seq.fold_left (fun m (k, e) -> add k e m) m s
  let of_seq s = add_seq s empty
  let add_list ls m = List.fold_left (fun m (k, e) -> add k e m) m ls
  let of_list ls = add_list ls empty
end