about summary refs log tree commit diff
path: root/new-luxc/source/luxc/analyser/case.lux
blob: fc151f7719c8834a93243283fa360c310149f9a5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
(;module:
  lux
  (lux (control monad
                eq)
       (data [bool "B/" Eq<Bool>]
             [number]
             [char]
             [text]
             text/format
             [product]
             ["R" result "R/" Monad<Result>]
             (coll [list "L/" Fold<List> Monoid<List> Monad<List>]
                   ["D" dict]))
       [macro #+ Monad<Lux>]
       (macro [code])
       [type]
       (type ["TC" check]))
  (luxc ["&" base]
        (lang ["la" analysis #+ Analysis]
              ["lp" pattern #+ Pattern])
        ["&;" env]
        (analyser ["&;" common]
                  ["&;" struct])))

## Describes how much of the space of possible input values the branches of a
## pattern-matching expression cover. Used to detect non-total matches and
## redundant branches.
(type: #rec Coverage
  ## Matches some, but not all, values (e.g. a literal pattern).
  #PartialC
  ## Matches only the given boolean value.
  (#BoolC Bool)
  ## A variant with the given number of cases, plus the coverage of every case
  ## matched so far, keyed by case index.
  (#VariantC Nat (D;Dict Nat Coverage))
  ## Coverage of a tuple: the coverage of its first member, followed by the
  ## coverage of the remaining members.
  (#SeqC Coverage Coverage)
  ## Two alternative coverages that could not be merged into one.
  (#AltC Coverage Coverage)
  ## Matches every possible value.
  #TotalC)

## Builds the error message reported when a pattern cannot be used to match a
## value of the given type.
(def: (pattern-error type pattern)
  (-> Type Code Text)
  (let [type-text (%type type)
        pattern-text (%code pattern)]
    (format "Cannot match this type: " type-text "\n"
            "     With this pattern: " pattern-text)))

## Strips away the layers of a type that are irrelevant to pattern-matching:
## type names are dropped, bound type-variables are replaced by what they are
## bound to, and quantified types are instantiated with a fresh existential.
## Fails when it reaches a type-variable that is still unbound.
(def: (simplify-case-type type)
  (-> Type (Lux Type))
  (case type
    (#;Named _ unnamedT)
    (simplify-case-type unnamedT)

    (#;Var id)
    (do Monad<Lux>
      [bound? (&;within-type-env
               (TC;bound? id))]
      (if (not bound?)
        (&;fail (format "Cannot simplify type for pattern-matching: " (%type type)))
        (do @
          [boundT (&;within-type-env
                   (TC;read-var id))]
          (simplify-case-type boundT))))

    (^or (#;UnivQ _) (#;ExQ _))
    (do Monad<Lux>
      [[_ exT] (&;within-type-env
                TC;existential)]
      (simplify-case-type (assume (type;apply-type type exT))))

    _
    (:: Monad<Lux> wrap type)))

## Analyses `pattern` against the type of the value being matched (`inputT`).
## It then runs `next` while the variables bound by the pattern are in scope.
## Returns the analysed pattern paired with the result of `next`.
##
## `num-tags` is only given when the pattern was written as a record or with a
## tag. It is the number of members/cases declared for the type. That number
## can be smaller than what flattening the type yields, when the last
## member/case is itself an anonymous tuple/variant.
(def: (analyse-pattern num-tags inputT pattern next)
  (All [a] (-> (Maybe Nat) Type Code (Lux a) (Lux [Pattern a])))
  (case pattern
    [cursor (#;Symbol ["" name])]
    (&;with-cursor cursor
      (do Monad<Lux>
        [outputA (&env;with-local [name inputT]
                   next)
         ## NOTE(review): assumes that, once `with-local` returns, the next
         ## free local slot is the one the binding used -- confirm in &env.
         idx &env;next-local]
        (wrap [(#lp;Bind idx) outputA])))

    [cursor (#;Symbol ident)]
    (&;with-cursor cursor
      (&;fail (format "Symbols must be unqualified inside patterns: " (%ident ident))))

    (^template [<type> <code-tag> <pattern-tag>]
      [cursor (<code-tag> test)]
      (&;with-cursor cursor
        (do Monad<Lux>
          [_ (&;within-type-env
              (TC;check inputT <type>))
           outputA next]
          (wrap [(<pattern-tag> test) outputA]))))
    ([Bool #;Bool #lp;Bool]
     [Nat  #;Nat  #lp;Nat]
     [Int  #;Int  #lp;Int]
     [Deg  #;Deg  #lp;Deg]
     [Real #;Real #lp;Real]
     [Char #;Char #lp;Char]
     [Text #;Text #lp;Text])

    (^ [cursor (#;Tuple (list))])
    (&;with-cursor cursor
      (do Monad<Lux>
        [_ (&;within-type-env
            (TC;check inputT Unit))
         outputA next]
        (wrap [#lp;Unit outputA])))

    ## A single-element tuple pattern is just its element.
    (^ [cursor (#;Tuple (list singleton))])
    (analyse-pattern #;None inputT singleton next)
    
    [cursor (#;Tuple sub-patterns)]
    (&;with-cursor cursor
      (do Monad<Lux>
        [inputT' (simplify-case-type inputT)]
        (case inputT'
          (#;Product _)
          ## Flatten the *simplified* type. `inputT` itself may still be a
          ## named type, a type-variable or a quantified type, and flattening
          ## it would yield a single element.
          (let [flat-types (type;flatten-tuple inputT')
                num-flat-types (list;size flat-types)
                num-sub-types (default num-flat-types
                                num-tags)
                ## If the type declares fewer members than flattening
                ## produced, the trailing types belong to the last member.
                sub-types (if (n.< num-flat-types num-sub-types)
                            (let [[prefix suffix] (list;split (n.dec num-sub-types) flat-types)]
                              (L/append prefix (list (type;tuple suffix))))
                            flat-types)
                num-sub-patterns (list;size sub-patterns)
                ## Pair each sub-pattern with its type. When the counts
                ## differ, there are two cases:
                ## - the last pattern matches all the remaining types as one
                ##   tuple, or
                ## - the trailing patterns are grouped into one tuple pattern
                ##   for the last type.
                matches (cond (n.< num-sub-types num-sub-patterns)
                              (let [[prefix suffix] (list;split (n.dec num-sub-patterns) sub-types)]
                                (list;zip2 (L/append prefix (list (type;tuple suffix))) sub-patterns))

                              (n.> num-sub-types num-sub-patterns)
                              (let [[prefix suffix] (list;split (n.dec num-sub-types) sub-patterns)]
                                (list;zip2 sub-types (L/append prefix (list (code;tuple suffix)))))
                              
                              ## (n.= num-sub-types num-sub-patterns)
                              (list;zip2 sub-types sub-patterns)
                              )]
            (do @
              ## The members are analysed so that each one's scope encloses
              ## the analysis of the members after it, and of `next`.
              [[memberP+ thenA] (L/fold (: (All [a]
                                             (-> [Type Code] (Lux [(List Pattern) a])
                                                 (Lux [(List Pattern) a])))
                                           (function [[memberT memberC] then]
                                             (do @
                                               [[memberP [memberP+ thenA]] ((:! (All [a] (-> (Maybe Nat) Type Code (Lux a) (Lux [Pattern a])))
                                                                                analyse-pattern)
                                                                            #;None memberT memberC then)]
                                               (wrap [(list& memberP memberP+) thenA]))))
                                        (do @
                                          [nextA next]
                                          (wrap [(list) nextA]))
                                        matches)]
              (wrap [(#lp;Tuple memberP+) thenA])))

          _
          (&;fail (pattern-error inputT pattern))
          )))

    [cursor (#;Record pairs)]
    (&;with-cursor cursor
      (do Monad<Lux>
        [pairs (&struct;normalize-record pairs)
         [members recordT] (&struct;order-record pairs)
         _ (&;within-type-env
            (TC;check inputT recordT))]
        ## A record pattern is analysed as a tuple pattern with a known number
        ## of members.
        (analyse-pattern (#;Some (list;size members)) inputT [cursor (#;Tuple members)] next)))

    ## A lone tag is matched as a tag applied to no values.
    [cursor (#;Tag tag)]
    (&;with-cursor cursor
      (analyse-pattern #;None inputT (` ((~ pattern))) next))

    (^ [cursor (#;Form (list& [_ (#;Nat idx)] values))])
    (&;with-cursor cursor
      (do Monad<Lux>
        [inputT' (simplify-case-type inputT)]
        (case inputT'
          (#;Sum _)
          ## As with tuples, flatten the *simplified* type. Re-group the
          ## trailing cases when the tag group is smaller than the flattening.
          (let [flat-sum (type;flatten-variant inputT')
                num-flat-cases (list;size flat-sum)
                num-cases (default num-flat-cases
                            num-tags)
                cases (if (n.< num-flat-cases num-cases)
                        (let [[prefix suffix] (list;split (n.dec num-cases) flat-sum)]
                          (L/append prefix (list (type;variant suffix))))
                        flat-sum)]
            (case (list;nth idx cases)
              #;None
              (&;fail (format "Cannot match index " (%n idx) " against type: " (%type inputT)))

              (#;Some case-type)
              (do Monad<Lux>
                [[testP nextA] (analyse-pattern #;None case-type (' [(~@ values)]) next)]
                (wrap [(#lp;Variant [idx num-cases]
                                    testP)
                       nextA]))))

          _
          (&;fail (pattern-error inputT pattern)))))

    ## Tags are resolved to their index and the size of their group. The
    ## pattern is then re-analysed as an index pattern.
    (^ [cursor (#;Form (list& [_ (#;Tag tag)] values))])
    (&;with-cursor cursor
      (do Monad<Lux>
        [tag (macro;normalize tag)
         [idx group tagT] (macro;resolve-tag tag)
         _ (&;within-type-env
            (TC;check inputT tagT))]
        (analyse-pattern (#;Some (list;size group)) inputT (' ((~ (code;nat idx)) (~@ values))) next)))

    _
    (&;fail (format "Unrecognized pattern syntax: " (%code pattern)))
    ))

## Analyses a single branch of a pattern-match. The body is analysed while the
## variables bound by the pattern are in scope.
(def: (analyse-branch analyse inputT pattern body)
  (-> &;Analyser Type Code Code (Lux [Pattern Analysis]))
  (|> (analyse body)
      (analyse-pattern #;None inputT pattern)))

## True only for the coverage that matches every possible value.
(def: (total? coverage)
  (-> Coverage Bool)
  (case coverage
    #TotalC
    true

    _
    false))

## True only for a pair of alternative coverages.
(def: (alt? coverage)
  (-> Coverage Bool)
  (case coverage
    (#AltC _)
    true

    _
    false))

## Computes how much of the input space a single analysed pattern covers, on
## its own, without regard to any other branch.
(def: (determine-coverage pattern)
  (-> Pattern Coverage)
  (case pattern
    (#lp;Bind _)
    #TotalC

    #lp;Unit
    #TotalC
    
    (#lp;Bool value)
    (#BoolC value)
    
    ## Literals other than booleans can never cover their whole type.
    (^or (#lp;Nat _)  (#lp;Int _)  (#lp;Deg _)
         (#lp;Real _) (#lp;Char _) (#lp;Text _))
    #PartialC
    
    (#lp;Tuple subs)
    ## Walk the members from last to first. Members that match anything at the
    ## end of the tuple are absorbed instead of being sequenced.
    (L/fold (: (-> Pattern Coverage Coverage)
               (function [sub post]
                 (let [subC (determine-coverage sub)]
                   (if (total? post)
                     subC
                     (#SeqC subC post)))))
            #TotalC
            (list;reverse subs))
    
    (#lp;Variant [tag-id num-tags] sub)
    (#VariantC num-tags
               (D;put tag-id
                      (determine-coverage sub)
                      (D;new number;Hash<Nat>)))))

## Exclusive or: true when exactly one of the two inputs is true.
(def: (xor left right)
  (-> Bool Bool Bool)
  (if left
    (not right)
    right))

## The failure reported when a branch cannot match anything that previous
## branches did not already match.
(def: redundant-pattern
  (R;Result Coverage)
  (#R;Error "Redundant pattern."))

## Turns a right-nested chain of alternatives into a list, in order. Any
## coverage that is not an alternative becomes a single-element list.
(def: (flatten-alt coverage)
  (-> Coverage (List Coverage))
  (loop [remaining coverage
         collected (: (List Coverage) (list))]
    (case remaining
      (#AltC left right)
      (recur right (#;Cons left collected))

      _
      (list;reverse (#;Cons remaining collected)))))

## Structural equality on coverages. Alternatives are compared member by member
## along their chain, so the order of the alternatives matters.
(struct: _ (Eq Coverage)
  (def: (= expected actual)
    (case [expected actual]
      [#TotalC #TotalC]
      true

      [#PartialC #PartialC]
      true

      [(#BoolC valueE) (#BoolC valueA)]
      (B/= valueE valueA)

      [(#VariantC sizeE casesE) (#VariantC sizeA casesA)]
      (and (n.= sizeE sizeA)
           (:: (D;Eq<Dict> =) = casesE casesA))

      (^template [<tag>]
        [(<tag> leftE rightE) (<tag> leftA rightA)]
        (and (= leftE leftA)
             (= rightE rightA)))
      ([#SeqC] [#AltC])

      _
      false)))

## Brings coverage equality into scope as `C/=`.
(open Eq<Coverage> "C/")

## Merges the coverage of a new branch (`addition`) into the coverage
## accumulated from all the previous branches (`so-far`).
## It fails in two cases:
## - with "Redundant pattern." when the addition cannot match anything the
##   previous branches did not already match;
## - with "Variants do not match." when two variant coverages disagree on
##   their number of cases.
(def: (merge-coverages addition so-far)
  (-> Coverage Coverage (R;Result Coverage))
  (case [addition so-far]
    ## The addition cannot possibly improve the coverage.
    [_ #TotalC]
    redundant-pattern

    ## The addition completes the coverage.
    [#TotalC _]
    (R/wrap #TotalC)

    [#PartialC #PartialC]
    (R/wrap #PartialC)

    ## Both boolean values are now covered.
    (^=> [(#BoolC sideA) (#BoolC sideSF)]
         (xor sideA sideSF))
    (R/wrap #TotalC)

    [(#VariantC allA casesA) (#VariantC allSF casesSF)]
    (cond (not (n.= allSF allA))
          (R;fail "Variants do not match.")

          (:: (D;Eq<Dict> Eq<Coverage>) = casesSF casesA)
          redundant-pattern

          ## else
          (do R;Monad<Result>
            [casesM (foldM @
                           (function [[tagA coverageA] casesSF']
                             (case (D;get tagA casesSF')
                               (#;Some coverageSF)
                               (do @
                                 [coverageM (merge-coverages coverageA coverageSF)]
                                 (wrap (D;put tagA coverageM casesSF')))

                               #;None
                               (wrap (D;put tagA coverageA casesSF'))))
                           casesSF (D;entries casesA))]
            ## A variant is only fully covered once every one of its cases
            ## has been matched, and each of those is fully covered.
            (wrap (if (and (n.= allSF (list;size (D;values casesM)))
                           (list;every? total? (D;values casesM)))
                    #TotalC
                    (#VariantC allSF casesM)))))

    [(#SeqC leftA rightA) (#SeqC leftSF rightSF)]
    (case [(C/= leftSF leftA) (C/= rightSF rightA)]
      ## There is nothing the addition adds to the coverage.
      [true true]
      redundant-pattern

      ## The 2 sequences cannot possibly be merged.
      [false false]
      (R/wrap (#AltC so-far addition))

      ## Same prefix
      [true false]
      (do R;Monad<Result>
        [rightM (merge-coverages rightA rightSF)]
        (if (total? rightM)
          (wrap leftSF)
          (wrap (#SeqC leftSF rightM))))

      ## Same suffix
      [false true]
      (do R;Monad<Result>
        [leftM (merge-coverages leftA leftSF)]
        (wrap (#SeqC leftM rightA))))
    
    ## The left part will always match, so the addition is redundant.
    (^=> [(#SeqC left right) single]
         (C/= left single))
    redundant-pattern

    ## The right part is not necessary, since it can always match the left.
    (^=> [single (#SeqC left right)]
         (C/= left single))
    (R/wrap single)

    ## Try to fuse the addition with one of the existing alternatives. Whatever
    ## it fuses into is tried again against the remaining alternatives, until
    ## nothing fuses. The survivors are then re-assembled into a right-nested
    ## chain of alternatives.
    [_ (#AltC leftS rightS)]
    (do R;Monad<Result>
      [#let [## Returns the fused coverage (if any fusion happened) and the
             ## alternatives that remain. If nothing fuses, the coverage is
             ## appended at the end of the alternatives.
             fuse-once (: (-> Coverage (List Coverage)
                              (R;Result [(Maybe Coverage)
                                         (List Coverage)]))
                          (function [coverage possibilities]
                            (loop [alts possibilities]
                              (case alts
                                #;Nil
                                (wrap [#;None (list coverage)])
                                
                                (#;Cons alt alts')
                                (case (merge-coverages coverage alt)
                                  (#R;Success altM)
                                  (case altM
                                    (#AltC _)
                                    (do @
                                      [[success alts+] (recur alts')]
                                      (wrap [success (#;Cons alt alts+)]))

                                    _
                                    (wrap [(#;Some altM) alts']))
                                  
                                  (#R;Error error)
                                  (R;fail error))
                                ))))]
       [success possibilities] (fuse-once addition (flatten-alt so-far))]
      (loop [success success
             possibilities possibilities]
        (case success
          (#;Some coverage')
          (do @
            [[success' possibilities'] (fuse-once coverage' possibilities)]
            (recur success' possibilities'))
          
          #;None
          (case (list;reverse possibilities)
            #;Nil
            (R;fail "{ This is not supposed to happen... }")
            
            (#;Cons last prevs)
            (wrap (L/fold (function [left right] (#AltC left right))
                          last
                          prevs))))))

    _
    (if (C/= so-far addition)
      ## The addition cannot possibly improve the coverage.
      redundant-pattern
      ## There are now 2 alternative paths.
      (R/wrap (#AltC so-far addition)))))

## The coverage of an analysed branch depends only on its pattern.
(def: (get-coverage branch)
  (-> [Pattern Analysis] Coverage)
  (let [[pattern _] branch]
    (determine-coverage pattern)))

## Analyses a pattern-matching expression in four steps:
## 1. Analyse the input with an unknown type.
## 2. Analyse every branch against the input's type, in order.
## 3. Check that the combined coverage of the branches is total and that no
##    branch is redundant.
## 4. Return the case analysis.
## Fails when there are no branches.
(def: #export (analyse-case analyse input branches)
  (-> &;Analyser Code (List [Code Code]) (Lux Analysis))
  (case branches
    (#;Cons [firstP firstB] restB)
    (do Monad<Lux>
      [[inputT inputA] (&common;with-unknown-type
                         (analyse input))
       #let [analyse-branch' (: (-> [Code Code] (Lux [Pattern Analysis]))
                                (function [[patternC bodyC]]
                                  (analyse-branch analyse inputT patternC bodyC)))]
       firstA (analyse-branch' [firstP firstB])
       restA (mapM @ analyse-branch' restB)
       _ (case (foldM R;Monad<Result>
                      merge-coverages
                      (get-coverage firstA)
                      (L/map get-coverage restA))
           (#R;Error error)
           (&;fail error)

           (#R;Success coverage)
           (if (total? coverage)
             (wrap [])
             (&;fail "Pattern-matching is not total.")))]
      (wrap (#la;Case inputA (list& firstA restA))))

    #;Nil
    (&;fail "Cannot have empty branches in pattern-matching expression.")))