aboutsummaryrefslogtreecommitdiff
path: root/stdlib/source/lux/data/collection/dictionary.lux
blob: 8e07a4ab4e5af3926c30978a5efc96ecd8dfe008 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
(.module:
  [lux #*
   [abstract
    [hash (#+ Hash)]
    [equivalence (#+ Equivalence)]
    [functor (#+ Functor)]]
   [control
    ["." try (#+ Try)]
    ["." exception (#+ exception:)]]
   [data
    ["." maybe]
    ["." product]
    [collection
     ["." list ("#\." fold functor monoid)]
     ["." array (#+ Array) ("#\." functor fold)]]]
   [math
    ["." number
     ["n" nat]
     ["." i64]]]])

## This implementation of Hash Array Mapped Trie (HAMT) is based on
## Clojure's PersistentHashMap implementation.
## That one is further based on Phil Bagwell's Hash Array Mapped Trie.

## Bitmaps are used to figure out which branches of a #Base node are
## populated. The number of bits set to 1 in a bitmap gives the size
## of the #Base node's array.
(type: BitMap
  Nat)

## Represents the position of a node in a BitMap.
## It's meant to be a single bit set on a 32-bit word.
## The position of the bit reflects whether an entry in an analogous
## position exists within a #Base, as reflected in its BitMap.
(type: BitPosition
  Nat)

## An index into an array (be it the array of a #Hierarchy, a #Base
## or a #Collisions node).
(type: Index
  Nat)

## A hash-code derived from a key, which guides the traversal of the
## tree.
(type: Hash_Code
  Nat)

## Represents the nesting level of a leaf or node, when looking it up
## while exploring the tree.
## Changes in level are done by right-shifting the hashes of keys by
## the appropriate multiple of the branching-exponent.
## A shift of 0 means root level.
## A shift of (* branching_exponent 1) means level 2.
## A shift of (* branching_exponent N) means level N+1.
(type: Level
  Nat)

## Nodes for the tree data-structure that organizes the data inside
## Dictionaries.
## #Hierarchy: the count of populated slots, plus an array of
## sub-nodes indexed directly by a section of the hash.
## #Base: a bitmap of the populated positions, plus a compact array
## that holds either sub-nodes or KV-pairs.
## #Collisions: the hash-code shared by every key in the node, plus
## the KV-pairs themselves.
(type: (Node k v)
  (#Hierarchy Nat (Array (Node k v)))
  (#Base BitMap
         (Array (Either (Node k v)
                        [k v])))
  (#Collisions Hash_Code (Array [k v])))

## #Hierarchy nodes are meant to point down only to lower-level nodes.
## This is the payload of the #Hierarchy case: the count of populated
## slots, and the array of sub-nodes.
(type: (Hierarchy k v)
  [Nat (Array (Node k v))])

## #Base nodes may point down to other nodes, but also to leaves,
## which are KV-pairs.
## This is only the array of the #Base case; the BitMap is kept
## separately.
(type: (Base k v)
  (Array (Either (Node k v)
                 [k v])))

## #Collisions are collections of KV-pairs for which the key is
## different in each case, but whose hashes are all the same (thus
## causing a collision).
(type: (Collisions k v)
  (Array [k v]))

## The bitmap for an empty #Base is 0.
## Which is the same as 0000 0000 0000 0000 0000 0000 0000 0000.
## Or 0x00000000.
## Which is 32 zeroes, since the branching factor is 32.
(def: clean_bitmap
  BitMap
  0)

## Bitmap position (while looking inside #Base nodes) is determined by
## getting 5 bits from a hash of the key being looked up and using
## them as an index into the array inside #Base.
## Since the data-structure can have multiple levels (and the hash has
## more than 5 bits), the binary representation of the hash is shifted
## by 5 positions on each step (2^5 = 32, which is the branching
## factor).
## The initial shifting level, though, is 0 (which corresponds to the
## shift in the shallowest node of the tree, which is the root node).
(def: root_level
  Level
  0)

## The exponent to which 2 must be raised, to reach the branching
## factor of the data-structure (2^5 = 32).
(def: branching_exponent
  Nat
  5)

## The threshold at which #Hierarchy nodes are demoted to #Base nodes.
## It is 1/4 of the branching factor, computed as
## 2^(branching_exponent - 2).
(def: demotion_threshold
  Nat
  (|> 1
      (i64.left_shift (n.- 2 branching_exponent))))

## The threshold at which #Base nodes are promoted to #Hierarchy nodes.
## It is 1/2 of the branching factor, computed as
## 2^(branching_exponent - 1).
(def: promotion_threshold
  Nat
  (|> 1
      (i64.left_shift (n.- 1 branching_exponent))))

## The size of the array in a #Hierarchy node, which is the branching
## factor itself: 2^branching_exponent.
(def: hierarchy_nodes_size
  Nat
  (|> 1
      (i64.left_shift branching_exponent)))

## The canonical empty node, which is just a #Base node with a clean
## bitmap and a zero-length array.
(def: empty
  Node
  (#Base clean_bitmap (array.new 0)))

## Builds a new array that is 1 slot larger than the old one, with
## the value stored at idx.
## The old array itself is only read from.
(def: (insert! idx value old_array)
  (All [a] (-> Index a (Array a) (Array a)))
  (let [old_size (array.size old_array)
        tail_size (n.- idx old_size)
        tail_start (inc idx)]
    (|> (array.new (inc old_size))
        ## The elements before the insertion point keep their indices.
        (array.copy! idx 0 old_array 0)
        (array.write! idx value)
        ## The remaining elements get shifted 1 slot to the right.
        (array.copy! tail_size idx old_array tail_start))))

## Clones the array, and then overwrites the slot at idx on the clone.
(def: (update! idx value array)
  (All [a] (-> Index a (Array a) (Array a)))
  (array.write! idx value (array.clone array)))

## Clones the array, and then empties the slot at idx on the clone.
## The size of the array stays the same.
(def: (vacant! idx array)
  (All [a] (-> Index (Array a) (Array a)))
  (array.delete! idx (array.clone array)))

## Builds a new array that is 1 slot smaller than the given one, by
## leaving out the slot at idx.
(def: (remove! idx array)
  (All [a] (-> Index (Array a) (Array a)))
  (let [new_size (dec (array.size array))
        tail_size (n.- idx new_size)
        tail_start (inc idx)]
    (|> (array.new new_size)
        ## The elements before idx keep their indices.
        (array.copy! idx 0 array 0)
        ## The elements after idx get shifted 1 slot to the left.
        (array.copy! tail_size tail_start array idx))))

## Adds the branching-exponent to the level-shift, in order to
## explore the next level down the tree.
(def: (level_up level)
  (-> Level Level)
  (n.+ branching_exponent level))

## A mask of (hierarchy_nodes_size - 1), which keeps only the lowest
## branching_exponent bits of whatever it is and-ed with.
(def: hierarchy_mask
  BitMap
  (|> hierarchy_nodes_size dec))

## Shifts the hash right by the level, and keeps only the section
## covered by hierarchy_mask.
## The result is used as an index into the array at that level.
(def: (level_index level hash)
  (-> Level Hash_Code Index)
  (|> hash
      (i64.right_shift level)
      (i64.and hierarchy_mask)))

## Turns an index into the bit-position that has only the bit at
## that index set.
(def: (->bit_position index)
  (-> Index BitPosition)
  (|> 1
      (i64.left_shift index)))

## The bit-position that a given hash-code has at a given level.
(def: (bit_position level hash)
  (-> Level Hash_Code BitPosition)
  (|> hash
      (level_index level)
      ->bit_position))

## Checks whether the bitmap has the given bit-position set, by
## and-ing both and comparing the result against the clean bitmap.
(def: (bit_position_is_set? bit bitmap)
  (-> BitPosition BitMap Bit)
  (|> bitmap
      (i64.and bit)
      (n.= clean_bitmap)
      not))

## Figures out whether the given bit-position is the only one set in
## the bitmap (that is, whether both numbers are equal).
(def: (only_bit_position? bit bitmap)
  (-> BitPosition BitMap Bit)
  (n.= bit bitmap))

## Marks a bit-position as used in a bitmap, by or-ing both.
(def: set_bit_position
  (-> BitPosition BitMap BitMap)
  i64.or)

## Marks a bit-position as unused in a bitmap.
## Since this is done by xor-ing both, the bit-position must already
## be set in the bitmap (otherwise, it would end up being set).
(def: (unset_bit_position bit bitmap)
  (-> BitPosition BitMap BitMap)
  (i64.xor bit bitmap))

## The size of a bitmap-indexed array, which is the amount of 1s
## within the bitmap.
(def: (bitmap_size bitmap)
  (-> BitMap Nat)
  (i64.count bitmap))

## A mask that, for a given bit-position, only lets through the 1s
## located before it.
## Counting those 1s in a bitmap yields the index associated with the
## bit-position.
(def: (bit_position_mask bit_position)
  (-> BitPosition BitMap)
  (dec bit_position))

## The index on the base array that corresponds to a bit-position,
## which is the amount of bit-positions set before it in the bitmap.
(def: (base_index bit_position bitmap)
  (-> BitPosition BitMap Index)
  (|> bitmap
      (i64.and (bit_position_mask bit_position))
      bitmap_size))

## Looks for the key among the KV-pairs of a #Collisions node, and
## yields the index at which it was found (if any).
(def: (collision_index Hash<k> key colls)
  (All [k v] (-> (Hash k) k (Collisions k v) (Maybe Index)))
  (case (array.find+ (function (_ idx [candidate _])
                       (\ Hash<k> = key candidate))
                     colls)
    (#.Some [idx _])
    (#.Some idx)

    #.None
    #.None))

## When #Hierarchy nodes grow too small, they're demoted to #Base
## nodes to save space.
## The sub-node located at except_idx is left out of the result.
## Every other sub-node is stored in ascending order of its index, so
## its place on the base array matches its bit-position.
(def: (demote_hierarchy except_idx [h_size h_array])
  (All [k v] (-> Index (Hierarchy k v) [BitMap (Base k v)]))
  (|> (list.indices (array.size h_array))
      (list\fold (function (_ idx (^@ state [insertion_idx [bitmap base]]))
                   (if (n.= except_idx idx)
                     state
                     (case (array.read idx h_array)
                       (#.Some sub_node)
                       [(inc insertion_idx)
                        [(set_bit_position (->bit_position idx) bitmap)
                         (array.write! insertion_idx (#.Left sub_node) base)]]

                       #.None
                       state)))
                 [0 [clean_bitmap
                     (array.new (dec h_size))]])
      product.right))

## When #Base nodes grow too large, they're promoted to #Hierarchy to
## add some depth to the tree and help keep its balance.
## These are all the slot indices of a #Hierarchy node, which get
## visited during such a promotion.
(def: hierarchy_indices (List Index) (list.indices hierarchy_nodes_size))

## Builds the array of a #Hierarchy node out of the contents of a
## #Base node.
## Sub-nodes are moved as they are, while KV-pairs get wrapped in a
## fresh node (by means of put') that lives 1 level further down.
## The put' function is taken as an argument, since it's defined
## after this one.
(def: (promote_base put' Hash<k> level bitmap base)
  (All [k v]
    (-> (-> Level Hash_Code k v (Hash k) (Node k v) (Node k v))
        (Hash k) Level
        BitMap (Base k v)
        (Array (Node k v))))
  (|> hierarchy_indices
      (list\fold (function (_ hierarchy_idx (^@ state [base_idx h_array]))
                   (if (bit_position_is_set? (->bit_position hierarchy_idx)
                                             bitmap)
                     (case (array.read base_idx base)
                       (#.Some (#.Left sub_node))
                       [(inc base_idx)
                        (array.write! hierarchy_idx sub_node h_array)]

                       (#.Some (#.Right [key' val']))
                       (let [leaf (put' (level_up level) (\ Hash<k> hash key') key' val' Hash<k> empty)]
                         [(inc base_idx)
                          (array.write! hierarchy_idx leaf h_array)])

                       ## A set bit-position must always have an entry
                       ## on the base array.
                       #.None
                       (undefined))
                     state))
                 [0
                  (array.new hierarchy_nodes_size)])
      product.right))

## All empty nodes look the same (a #Base node with a clean bitmap is
## used).
## So, a node is empty only when it is a #Base node whose bitmap
## equals the clean one.
(def: (empty?' node)
  (All [k v] (-> (Node k v) Bit))
  (case node
    (#Base bitmap _)
    (n.= clean_bitmap bitmap)

    _
    #0))

## Adds a KV-pair to the sub-tree rooted at the given node (or
## replaces the value, if the key is already there), and yields the
## resulting node.
## The level is the shift applied to the hash in order to pick the
## branch to follow at this depth.
## The arrays of the given node are never written to; copies get
## made through update! and insert! instead.
(def: (put' level hash key val Hash<k> node)
  (All [k v] (-> Level Hash_Code k v (Hash k) (Node k v) (Node k v)))
  (case node
    ## For #Hierarchy nodes, I check whether I can add the element to
    ## a sub-node. If there is none, I introduce a new singleton
    ## sub-node (and the count of populated slots grows by 1).
    (#Hierarchy _size hierarchy)
    (let [idx (level_index level hash)
          [_size' sub_node] (case (array.read idx hierarchy)
                              (#.Some sub_node)
                              [_size sub_node]

                              _
                              [(inc _size) empty])]
      (#Hierarchy _size'
                  (update! idx (put' (level_up level) hash key val Hash<k> sub_node)
                           hierarchy)))

    ## For #Base nodes, I check if the corresponding BitPosition has
    ## already been used.
    (#Base bitmap base)
    (let [bit (bit_position level hash)]
      (if (bit_position_is_set? bit bitmap)
        ## If so...
        (let [idx (base_index bit bitmap)]
          (case (array.read idx base)
            ## A set BitPosition must always have an entry on the
            ## base array.
            #.None
            (undefined)

            ## If it's being used by a node, I add the KV to it.
            (#.Some (#.Left sub_node))
            (let [sub_node' (put' (level_up level) hash key val Hash<k> sub_node)]
              (#Base bitmap (update! idx (#.Left sub_node') base)))

            ## Otherwise, if it's being used by a KV, I compare the keys.
            (#.Some (#.Right key' val'))
            (if (\ Hash<k> = key key')
              ## If the same key is found, I replace the value.
              (#Base bitmap (update! idx (#.Right key val) base))
              ## Otherwise, I compare the hashes of the keys.
              (#Base bitmap (update! idx
                                     (#.Left (let [hash' (\ Hash<k> hash key')]
                                               (if (n.= hash hash')
                                                 ## If the hashes are
                                                 ## the same, a new
                                                 ## #Collisions node
                                                 ## is added.
                                                 (#Collisions hash (|> (array.new 2)
                                                                       (array.write! 0 [key' val'])
                                                                       (array.write! 1 [key val])))
                                                 ## Otherwise, I can
                                                 ## just keep using
                                                 ## #Base nodes, so I
                                                 ## add both KV-pairs
                                                 ## to the empty one.
                                                 (let [next_level (level_up level)]
                                                   (|> empty
                                                       (put' next_level hash' key' val' Hash<k>)
                                                       (put' next_level hash  key  val Hash<k>))))))
                                     base)))))
        ## However, if the BitPosition has not been used yet, I check
        ## whether this #Base node is ready for a promotion.
        (let [base_count (bitmap_size bitmap)]
          (if (n.>= ..promotion_threshold base_count)
            ## If so, I promote it to a #Hierarchy node, and add the new
            ## KV-pair as a singleton node to it.
            (#Hierarchy (inc base_count)
                        (|> (promote_base put' Hash<k> level bitmap base)
                            (array.write! (level_index level hash)
                                          (put' (level_up level) hash key val Hash<k> empty))))
            ## Otherwise, I just resize the #Base node to accommodate the
            ## new KV-pair.
            (#Base (set_bit_position bit bitmap)
                   (insert! (base_index bit bitmap) (#.Right [key val]) base))))))
    
    ## For #Collisions nodes, I compare the hashes.
    (#Collisions _hash _colls)
    (if (n.= hash _hash)
      ## If they're equal, that means the new KV contributes to the
      ## collisions.
      (case (collision_index Hash<k> key _colls)
        ## If the key was already present in the collisions-list, its
        ## value gets updated.
        (#.Some coll_idx)
        (#Collisions _hash (update! coll_idx [key val] _colls))

        ## Otherwise, the KV-pair is added at the end of the
        ## collisions-list.
        #.None
        (#Collisions _hash (insert! (array.size _colls) [key val] _colls)))
      ## If the hashes are not equal, I create a new #Base node that
      ## contains the old #Collisions node, plus the new KV-pair.
      (|> (#Base (bit_position level _hash)
                 (|> (array.new 1)
                     (array.write! 0 (#.Left node))))
          (put' level hash key val Hash<k>)))
    ))

## Removes the key from the sub-tree rooted at the given node, and
## yields the resulting node.
## When the key is not found, the very same node is returned, which
## is what the callers further up test for (by means of is?) in order
## to find out whether anything changed.
(def: (remove' level hash key Hash<k> node)
  (All [k v] (-> Level Hash_Code k (Hash k) (Node k v) (Node k v)))
  (case node
    ## For #Hierarchy nodes, find out if there's a valid sub-node for
    ## the Hash-Code.
    (#Hierarchy h_size h_array)
    (let [idx (level_index level hash)]
      (case (array.read idx h_array)
        ## If not, there's nothing to remove.
        #.None
        node

        ## But if there is, try to remove the key from the sub-node.
        (#.Some sub_node)
        (let [sub_node' (remove' (level_up level) hash key Hash<k> sub_node)]
          ## Then check if a removal was actually done.
          (if (is? sub_node sub_node')
            ## If not, then there's nothing to change here either.
            node
            ## But if the sub_removal yielded an empty sub_node...
            (if (empty?' sub_node')
              ## Check if it's due time for a demotion.
              (if (n.<= demotion_threshold h_size)
                ## If so, perform it.
                (#Base (demote_hierarchy idx [h_size h_array]))
                ## Otherwise, just clear the space.
                (#Hierarchy (dec h_size) (vacant! idx h_array)))
              ## But if the sub_removal yielded a non_empty node, then
              ## just update the hierarchy branch.
              (#Hierarchy h_size (update! idx sub_node' h_array)))))))

    ## For #Base nodes, check whether the BitPosition is set.
    (#Base bitmap base)
    (let [bit (bit_position level hash)]
      (if (bit_position_is_set? bit bitmap)
        (let [idx (base_index bit bitmap)]
          (case (array.read idx base)
            ## A set BitPosition must always have an entry on the
            ## base array.
            #.None
            (undefined)

            ## If set, check if it's a sub_node, and remove the KV
            ## from it.
            (#.Some (#.Left sub_node))
            (let [sub_node' (remove' (level_up level) hash key Hash<k> sub_node)]
              ## Verify that it was removed.
              (if (is? sub_node sub_node')
                ## If not, there's also nothing to change here.
                node
                ## But if it came out empty...
                (if (empty?' sub_node')
                  ## ... figure out whether that's the only position left.
                  (if (only_bit_position? bit bitmap)
                    ## If so, removing it leaves this node empty too.
                    empty
                    ## But if not, then just unset the position and
                    ## remove the node.
                    (#Base (unset_bit_position bit bitmap)
                           (remove! idx base)))
                  ## But, if it did not come out empty, then the
                  ## position is kept, and the node gets updated.
                  (#Base bitmap
                         (update! idx (#.Left sub_node') base)))))

            ## If, however, there was a KV-pair instead of a sub-node.
            (#.Some (#.Right [key' val']))
            ## Check if the keys match.
            (if (\ Hash<k> = key key')
              ## If so, remove the KV-pair and unset the BitPosition.
              (#Base (unset_bit_position bit bitmap)
                     (remove! idx base))
              ## Otherwise, there's nothing to remove.
              node)))
        ## If the BitPosition is not set, there's nothing to remove.
        node))

    ## For #Collisions nodes, I need to find out if the key already exists.
    (#Collisions _hash _colls)
    (case (collision_index Hash<k> key _colls)
      ## If not, then there's nothing to remove.
      #.None
      node

      ## But if so, then check the size of the collisions list.
      (#.Some idx)
      (if (n.= 1 (array.size _colls))
        ## If there's only one left, then removing it leaves us with
        ## an empty node.
        empty
        ## Otherwise, just shrink the array by removing the KV-pair.
        (#Collisions _hash (remove! idx _colls))))
    ))

## Looks for the value associated with the key, within the sub-tree
## rooted at the given node.
(def: (get' level hash key Hash<k> node)
  (All [k v] (-> Level Hash_Code k (Hash k) (Node k v) (Maybe v)))
  (case node
    ## A #Hierarchy node is indexed directly by a section of the hash.
    (#Hierarchy _ children)
    (case (array.read (level_index level hash) children)
      (#.Some child)
      (get' (level_up level) hash key Hash<k> child)

      #.None
      #.None)

    ## A #Base node can only hold the key if the bit-position for the
    ## hash is set.
    (#Base bitmap base)
    (let [bit (bit_position level hash)]
      (if (not (bit_position_is_set? bit bitmap))
        #.None
        (case (array.read (base_index bit bitmap) base)
          ## Keep looking further down the tree.
          (#.Some (#.Left child))
          (get' (level_up level) hash key Hash<k> child)

          ## A leaf only counts if it holds the very same key.
          (#.Some (#.Right [key' val']))
          (if (\ Hash<k> = key key')
            (#.Some val')
            #.None)

          ## A set bit-position must always have an entry on the base
          ## array.
          #.None
          (undefined))))

    ## A #Collisions node gets scanned linearly, comparing keys.
    (#Collisions _ colls)
    (case (array.find (function (_ [key' _])
                        (\ Hash<k> = key key'))
                      colls)
      (#.Some [_ val'])
      (#.Some val')

      #.None
      #.None)))

## Counts all the KV-pairs stored within the sub-tree rooted at the
## given node.
(def: (size' node)
  (All [k v] (-> (Node k v) Nat))
  (case node
    (#Hierarchy _ children)
    (array\fold (function (_ child total)
                  (n.+ (size' child) total))
                0
                children)

    (#Base _ base)
    (array\fold (function (_ slot total)
                  (case slot
                    ## A sub-node contributes all of its KV-pairs.
                    (#.Left child) (n.+ (size' child) total)
                    ## A leaf contributes just itself.
                    (#.Right _)    (n.+ 1 total)))
                0
                base)

    (#Collisions _ colls)
    (array.size colls)))

## Gathers all the KV-pairs stored within the sub-tree rooted at the
## given node.
(def: (entries' node)
  (All [k v] (-> (Node k v) (List [k v])))
  (case node
    (#Hierarchy _ children)
    (array\fold (function (_ child gathered)
                  (list\compose (entries' child) gathered))
                #.Nil
                children)

    (#Base _ base)
    (array\fold (function (_ slot gathered)
                  (case slot
                    (#.Right [key' val'])
                    (#.Cons [key' val'] gathered)

                    (#.Left child)
                    (list\compose (entries' child) gathered)))
                #.Nil
                base)

    (#Collisions _ colls)
    (array\fold (function (_ [key' val'] gathered)
                  (#.Cons [key' val'] gathered))
                #.Nil
                colls)))

## A dictionary pairs the root node of its tree with the Hash used
## for its keys, so that operations on it don't need to be given the
## Hash again.
(type: #export (Dictionary k v)
  {#.doc "A dictionary implemented as a Hash-Array Mapped Trie (HAMT)."}
  {#hash (Hash k)
   #root (Node k v)})

## Yields the Hash that the dictionary uses for its keys.
(def: #export (key_hash dict)
  (All [k v] (-> (Dictionary k v) (Hash k)))
  (get@ #..hash dict))

## Creates an empty dictionary, whose keys will be hashed (and
## compared) with the given Hash.
(def: #export (new Hash<k>)
  (All [k v] (-> (Hash k) (Dictionary k v)))
  [Hash<k> empty])

## Associates the value with the key, replacing any value that was
## previously associated with it.
(def: #export (put key val dict)
  (All [k v] (-> k v (Dictionary k v) (Dictionary k v)))
  (let [[Hash<k> root] dict
        hash_code (\ Hash<k> hash key)]
    {#hash Hash<k>
     #root (put' root_level hash_code key val Hash<k> root)}))

## Removes the key (and its value) from the dictionary.
(def: #export (remove key dict)
  (All [k v] (-> k (Dictionary k v) (Dictionary k v)))
  (let [[Hash<k> root] dict
        hash_code (\ Hash<k> hash key)]
    {#hash Hash<k>
     #root (remove' root_level hash_code key Hash<k> root)}))

## Looks up the value associated with the key (if any).
(def: #export (get key dict)
  (All [k v] (-> k (Dictionary k v) (Maybe v)))
  (let [[Hash<k> root] dict
        hash_code (\ Hash<k> hash key)]
    (get' root_level hash_code key Hash<k> root)))

## Checks whether the key has a value associated with it.
## Note that the dictionary comes first, unlike in get.
(def: #export (key? dict key)
  (All [k v] (-> (Dictionary k v) k Bit))
  (case (get key dict)
    (#.Some _)
    #1

    #.None
    #0))

## Thrown by try_put when the key is already present in the
## dictionary.
(exception: #export key_already_exists)

## Fails with key_already_exists instead of replacing the value of a
## key that is already present.
(def: #export (try_put key val dict)
  {#.doc "Only puts the KV-pair if the key is not already present."}
  (All [k v] (-> k v (Dictionary k v) (Try (Dictionary k v))))
  (if (key? dict key)
    (exception.throw ..key_already_exists [])
    (#try.Success (put key val dict))))

## When the key is missing, the very same dictionary is returned.
(def: #export (update key f dict)
  {#.doc "Transforms the value located at key (if available), using the given function."}
  (All [k v] (-> k (-> v v) (Dictionary k v) (Dictionary k v)))
  (case (get key dict)
    (#.Some current)
    (|> dict
        (put key (f current)))

    #.None
    dict))

## The function is always applied: either to the value that is
## currently stored, or to the default when the key is missing.
(def: #export (upsert key default f dict)
  {#.doc (doc "Updates the value at the key; if it exists."
              "Otherwise, puts a value by applying the function to a default.")}
  (All [k v] (-> k v (-> v v) (Dictionary k v) (Dictionary k v)))
  (let [current (maybe.default default
                               (..get key dict))]
    (..put key (f current) dict)))

## The amount of KV-pairs in the dictionary, which is computed by
## walking the whole tree.
(def: #export (size dict)
  (All [k v] (-> (Dictionary k v) Nat))
  (..size' (product.right dict)))

## Checks whether the dictionary has no KV-pairs.
## This is decided by inspecting the root alone, instead of counting
## every entry in the tree:
## new starts the root as the empty node, put' never yields an empty
## node, and remove' yields a #Base with a clean bitmap whenever the
## last KV-pair of a sub-tree goes away (by returning empty, by
## unsetting the last bit-position, or by demoting a #Hierarchy).
## So, a dictionary has no entries exactly when its root passes
## empty?'.
(def: #export empty?
  (All [k v] (-> (Dictionary k v) Bit))
  (|>> product.right ..empty?'))

## All the KV-pairs in the dictionary, as a list.
(def: #export (entries dict)
  (All [k v] (-> (Dictionary k v) (List [k v])))
  (|> dict
      product.right
      entries'))

## Builds a dictionary by putting every KV-pair of the list, in
## order, into an empty dictionary.
## Later pairs replace earlier ones that have the same key.
(def: #export (from_list Hash<k> kvs)
  (All [k v] (-> (Hash k) (List [k v]) (Dictionary k v)))
  (|> kvs
      (list\fold (function (_ [key val] so_far)
                   (put key val so_far))
                 (new Hash<k>))))

## Projections over the entries of a dictionary:
## keys yields only the keys, and values yields only the values.
(template [<name> <elem_type> <side>]
  [(def: #export (<name> dict)
     (All [k v] (-> (Dictionary k v) (List <elem_type>)))
     (|> dict entries (list\map <side>)))]

  [keys   k product.left]
  [values v product.right]
  )

## Every entry of dict2 gets put into dict1.
(def: #export (merge dict2 dict1)
  {#.doc (doc "Merges 2 dictionaries."
              "If any collisions with keys occur, the values of dict2 will overwrite those of dict1.")}
  (All [k v] (-> (Dictionary k v) (Dictionary k v) (Dictionary k v)))
  (|> dict2
      entries
      (list\fold (function (_ [key val2] merged)
                   (put key val2 merged))
                 dict1)))

## Every entry of dict2 gets put into dict1, after being combined
## with the value that dict1 has for the same key (if any).
(def: #export (merge_with f dict2 dict1)
  {#.doc (doc "Merges 2 dictionaries."
              "If any collisions with keys occur, a new value will be computed by applying 'f' to the values of dict2 and dict1.")}
  (All [k v] (-> (-> v v v) (Dictionary k v) (Dictionary k v) (Dictionary k v)))
  (list\fold (function (_ [key val2] merged)
               (let [val' (case (get key merged)
                            (#.Some val1)
                            (f val2 val1)

                            #.None
                            val2)]
                 (put key val' merged)))
             dict1
             (entries dict2)))

## Moves the value stored at from_key over to to_key.
## When from_key is missing, the very same dictionary is returned.
(def: #export (re_bind from_key to_key dict)
  (All [k v] (-> k k (Dictionary k v) (Dictionary k v)))
  (case (get from_key dict)
    (#.Some val)
    (put to_key val (remove from_key dict))

    #.None
    dict))

## Keys that are missing from the dictionary are just skipped.
(def: #export (select keys dict)
  {#.doc "Creates a sub-set of the given dict, with only the specified keys."}
  (All [k v] (-> (List k) (Dictionary k v) (Dictionary k v)))
  (let [Hash<k> (key_hash dict)]
    (list\fold (function (_ key selection)
                 (case (get key dict)
                   (#.Some val) (put key val selection)
                   #.None       selection))
               (new Hash<k>)
               keys)))

## 2 dictionaries are equivalent when they have the same size, and
## every key of the reference maps to an equivalent value on the
## subject.
(structure: #export (equivalence (^open ",\."))
  (All [k v] (-> (Equivalence v) (Equivalence (Dictionary k v))))

  (def: (= reference subject)
    (if (n.= (..size reference)
             (..size subject))
      (list.every? (function (_ [key reference_value])
                     (case (..get key subject)
                       (#.Some subject_value)
                       (,\= reference_value subject_value)

                       #.None
                       #0))
                   (..entries reference))
      #0)))

## Maps over the values stored in a node, and in all of its
## sub-nodes.
## Keys, bitmaps, hash-codes and sizes are kept as they were.
(structure: functor'
  (All [k] (Functor (Node k)))

  (def: (map f node)
    (case node
      (#Hierarchy h_size sub_nodes)
      (#Hierarchy h_size (array\map (map f) sub_nodes))

      (#Base bitmap base)
      (#Base bitmap (array\map (function (_ slot)
                                 (case slot
                                   (#.Right [key val])
                                   (#.Right [key (f val)])

                                   (#.Left sub_node)
                                   (#.Left (map f sub_node))))
                               base))

      (#Collisions hash_code kvs)
      (#Collisions hash_code (array\map (function (_ [key val])
                                          [key (f val)])
                                        kvs)))))

## Maps over the values of a dictionary, by mapping over its root
## node. The Hash for the keys is kept as it was.
(structure: #export functor
  (All [k] (Functor (Dictionary k)))

  (def: (map f dict)
    (let [[Hash<k> root] dict]
      {#hash Hash<k>
       #root (\ ..functor' map f root)})))