(.module:
  [lux (#- case)
   [control
    [monad (#+ do)]
    ["ex" exception (#+ exception:)]]
   [data
    [product]
    [error]
    [maybe]
    [text
     format]
    [collection [list ("list/" Fold<List> Monoid<List> Functor<List>)]]]
   ["." macro
    [code]]]
  [// (#+ Pattern Analysis Operation Compiler)
   [scope]
   ["//." type]
   [structure]
   ["/." //
    [extension]
    [//
     ["." type
      ["tc" check]]]]]
  [/
   [coverage]])

## Thrown by pattern analysis when the (simplified) type of the input
## does not have the shape that the pattern requires.
## The report shows both the offending type and the pattern.
(exception: #export (cannot-match-type-with-pattern {inputT Type} {patternC Code})
  (ex.report ["Type" (%type inputT)]
             ["Pattern" (%code patternC)]))

## Thrown when a variant pattern refers to a case index that the sum
## type being matched does not provide.
## The report shows the requested index and the type.
(exception: #export (sum-type-has-no-case {idx Nat} {sumT Type})
  (ex.report ["Case" (%n idx)]
             ["Type" (%type sumT)]))

## Thrown when a piece of code does not correspond to any of the
## supported pattern forms. The message is the offending code itself.
(exception: #export (unrecognized-pattern-syntax {patternC Code})
  (%code patternC))

## Thrown when the type of a "case" input cannot be reduced to a form
## that patterns can be checked against (e.g. an unbound
## type-variable, or a type application that cannot be performed).
## The message is the type that could not be simplified.
(exception: #export (cannot-simplify-type-for-pattern-matching {unsimplifiedT Type})
  (%type unsimplifiedT))

## Exceptions whose whole message is supplied by the code throwing
## them.

## Thrown when a "case" expression is given no branches at all.
(exception: #export (cannot-have-empty-branches {message Text})
  message)

## Thrown when the patterns of a "case" expression do not cover every
## possible input value.
(exception: #export (non-exhaustive-pattern-matching {message Text})
  message)

## Wraps a type in one universal quantifier per environment in
## 'envs'.
## The first environment of the list ends up as the innermost
## quantifier, and the last one as the outermost.
## With an empty list, the type is returned untouched.
(def: (re-quantify envs baseT)
  (-> (List (List Type)) Type Type)
  (loop [pending envs
         quantifiedT baseT]
    (.case pending
      (#.Cons env pending')
      (recur pending' (#.UnivQ env quantifiedT))

      #.Nil
      quantifiedT)))

## Type-checking of the input value is done during the analysis of a
## "case" expression, to ensure that the patterns being used make
## sense for the type of the input value.
## Sometimes, the type of that input value is complex, because it
## depends on type-variables or quantification.
## This function makes it easier for "case" analysis to properly
## type-check the input with respect to the patterns.
(def: (simplify-case-type caseT)
  (-> Type (Operation Type))
  ## Walks down the type until something other than a variable, a
  ## name, a quantifier or an application is found.
  ## 'envs' accumulates the environments of the universal quantifiers
  ## that were stripped along the way (innermost first), so that they
  ## can be put back around the final result.
  (loop [envs (: (List (List Type))
                 #.Nil)
         currentT caseT]
    (.case currentT
      ## Named types are transparent: continue with what they name.
      (#.Named _ anonymousT)
      (recur envs anonymousT)

      ## Remember the quantifier's environment and look inside.
      (#.UnivQ env bodyT)
      (recur (list& env envs) bodyT)

      ## Existential quantification is applied to the type obtained
      ## from 'tc.existential'.
      (#.ExQ _)
      (do ///.Monad<Operation>
        [[_ exT] (//type.with-env
                   tc.existential)]
        (|> currentT
            (type.apply (list exT))
            maybe.assume
            (recur envs)))

      ## A type-variable can only be simplified if it is bound.
      (#.Var id)
      (do ///.Monad<Operation>
        [?boundT (//type.with-env
                   (tc.read id))]
        (.case ?boundT
          (#.Some boundT)
          (recur envs boundT)

          #.None
          (///.throw cannot-simplify-type-for-pattern-matching currentT)))

      (#.Apply inputT funcT)
      (.case funcT
        ## The function is a type-variable: read what it is bound to
        ## and retry the application with that.
        (#.Var funcT-id)
        (do ///.Monad<Operation>
          [funcT' (//type.with-env
                    (do tc.Monad<Check>
                      [?funcT' (tc.read funcT-id)]
                      (.case ?funcT'
                        (#.Some funcT')
                        (wrap funcT')

                        #.None
                        (tc.throw cannot-simplify-type-for-pattern-matching currentT))))]
          (recur envs (#.Apply inputT funcT')))

        ## Otherwise, perform the application and continue with its
        ## result.
        _
        (.case (type.apply (list inputT) funcT)
          (#.Some outputT)
          (recur envs outputT)

          #.None
          (///.throw cannot-simplify-type-for-pattern-matching currentT)))

      ## For tuples, the quantifiers are put back around every member
      ## individually, instead of around the tuple as a whole.
      (#.Product _)
      (:: ///.Monad<Operation> wrap
          (type.tuple (list/map (re-quantify envs)
                                (type.flatten-tuple currentT))))

      ## Anything else is re-quantified as a whole.
      _
      (|> currentT
          (re-quantify envs)
          (:: ///.Monad<Operation> wrap)))))

## Handles patterns made of a literal value.
## Under the pattern's cursor, 'inputT' and the literal's 'type' are
## handed to the type-checker, then the continuation 'next' is run,
## and its result is paired with the already-built pattern 'output'.
(def: (analyse-primitive type inputT cursor output next)
  (All [a] (-> Type Type Cursor Pattern (Operation a) (Operation [Pattern a])))
  (<| (//.with-cursor cursor)
      (do ///.Monad<Operation>
        [_ (|> (tc.check inputT type)
               //type.with-env)
         nextA next]
        (wrap [output nextA]))))

## This function handles several concerns at once, but it must be that
## way because those concerns are interleaved when doing
## pattern-matching and they cannot be separated.
## The pattern is analysed in order to get a general idea of what is
## expected of the input value. This, in turn, informs the
## type-checking of the input.
## A kind of "continuation" value is passed around, which signifies
## what needs to be done _after_ analysing a pattern.
## In general, this is done to analyse the "body" expression
## associated with a particular pattern _in the context of_ said
## pattern.
## The reason why *context* is important is that patterns may bind
## values to local variables, which may in turn be referenced in the
## body expressions.
## That is why the body must be analysed in the context of the
## pattern, and not separately.
(def: (analyse-pattern num-tags inputT pattern next)
  (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a])))
  ## num-tags: when present, overrides the number of members/cases
  ##   obtained from flattening the input type. It is only supplied
  ##   (as #.Some) by the record and tagged-form branches below; every
  ##   other call passes #.None.
  ## inputT: the type of the value the pattern is matched against.
  ## pattern: the code of the pattern.
  ## next: the continuation to run in the context of the pattern.
  ## Returns the analysed pattern paired with the continuation's result.
  (.case pattern
    ## A symbol with an empty module part binds the input to a local
    ## variable of type 'inputT' while 'next' runs.
    ## The register is read from 'scope.next-local' only after
    ## 'scope.with-local' has finished.
    ## NOTE(review): presumably the local has been dropped by then, so
    ## the "next" register is the one the local used -- confirm.
    [cursor (#.Symbol ["" name])]
    (//.with-cursor cursor
      (do ///.Monad<Operation>
        [outputA (scope.with-local [name inputT]
                   next)
         idx scope.next-local]
        (wrap [(#//.Bind idx) outputA])))

    ## Literal patterns: each kind of literal is handled by
    ## 'analyse-primitive' with its corresponding type.
    ## The empty tuple is treated as a literal of type Any.
    (^template [<type> <input> <output>]
      [cursor <input>]
      (analyse-primitive <type> inputT cursor (#//.Simple <output>) next))
    ([Bit  (#.Bit pattern-value)  (#//.Bit pattern-value)]
     [Nat  (#.Nat pattern-value)  (#//.Nat pattern-value)]
     [Int  (#.Int pattern-value)  (#//.Int pattern-value)]
     [Rev  (#.Rev pattern-value)  (#//.Rev pattern-value)]
     [Frac (#.Frac pattern-value) (#//.Frac pattern-value)]
     [Text (#.Text pattern-value) (#//.Text pattern-value)]
     [Any  (#.Tuple #.Nil)        #//.Unit])
    
    ## A tuple with a single element is analysed as that element.
    (^ [cursor (#.Tuple (list singleton))])
    (analyse-pattern #.None inputT singleton next)
    
    ## Tuple patterns require the simplified input type to be a product.
    [cursor (#.Tuple sub-patterns)]
    (//.with-cursor cursor
      (do ///.Monad<Operation>
        [inputT' (simplify-case-type inputT)]
        (.case inputT'
          (#.Product _)
          (let [sub-types (type.flatten-tuple inputT')
                num-sub-types (maybe.default (list.size sub-types)
                                             num-tags)
                num-sub-patterns (list.size sub-patterns)
                ## Pair every sub-pattern with the type it must match.
                ## When the two lists differ in length, the excess on
                ## the longer side is folded into a single trailing
                ## tuple: the first branch groups the trailing types
                ## (for the last pattern), the second branch groups
                ## the trailing patterns (for the last type).
                matches (cond (n/< num-sub-types num-sub-patterns)
                              (let [[prefix suffix] (list.split (dec num-sub-patterns) sub-types)]
                                (list.zip2 (list/compose prefix (list (type.tuple suffix))) sub-patterns))

                              (n/> num-sub-types num-sub-patterns)
                              (let [[prefix suffix] (list.split (dec num-sub-types) sub-patterns)]
                                (list.zip2 sub-types (list/compose prefix (list (code.tuple suffix)))))
                              
                              ## (n/= num-sub-types num-sub-patterns)
                              (list.zip2 sub-types sub-patterns))]
            ## The continuations are chained by folding over the
            ## reversed matches, starting from 'next'; each step runs
            ## what was built so far as the continuation of its own
            ## member pattern.
            ## NOTE(review): the :coerce around the recursive call is
            ## presumably there because the recursion happens at a
            ## different type than the enclosing 'a' -- confirm.
            (do @
              [[memberP+ thenA] (list/fold (: (All [a]
                                                (-> [Type Code] (Operation [(List Pattern) a])
                                                    (Operation [(List Pattern) a])))
                                              (function (_ [memberT memberC] then)
                                                (do @
                                                  [[memberP [memberP+ thenA]] ((:coerce (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a])))
                                                                                        analyse-pattern)
                                                                               #.None memberT memberC then)]
                                                  (wrap [(list& memberP memberP+) thenA]))))
                                           (do @
                                             [nextA next]
                                             (wrap [(list) nextA]))
                                           (list.reverse matches))]
              (wrap [(//.product-pattern memberP+)
                     thenA])))

          _
          (///.throw cannot-match-type-with-pattern [inputT pattern])
          )))

    ## Record patterns are normalized and ordered into a tuple of
    ## members, the input type is checked with respect to the record
    ## type, and the result is analysed as a tuple pattern with the
    ## number of members as 'num-tags'.
    ## NOTE(review): unlike the other branches, the operations here are
    ## not wrapped in //.with-cursor (only the delegated tuple analysis
    ## is) -- confirm whether that is intended.
    [cursor (#.Record record)]
    (do ///.Monad<Operation>
      [record (structure.normalize record)
       [members recordT] (structure.order record)
       _ (//type.with-env
           (tc.check inputT recordT))]
      (analyse-pattern (#.Some (list.size members)) inputT [cursor (#.Tuple members)] next))

    ## A lone tag is re-analysed as a form containing only that tag.
    [cursor (#.Tag tag)]
    (//.with-cursor cursor
      (analyse-pattern #.None inputT (` ((~ pattern))) next))

    ## Variant patterns by index: (idx value ...)
    ## They require the simplified input type to be a sum.
    (^ [cursor (#.Form (list& [_ (#.Nat idx)] values))])
    (//.with-cursor cursor
      (do ///.Monad<Operation>
        [inputT' (simplify-case-type inputT)]
        (.case inputT'
          (#.Sum _)
          (let [flat-sum (type.flatten-variant inputT')
                size-sum (list.size flat-sum)
                num-cases (maybe.default size-sum num-tags)]
            ## The case must both exist in the flattened sum and pass
            ## the (n/< num-cases idx) guard.
            (.case (list.nth idx flat-sum)
              (^multi (#.Some case-type)
                      (n/< num-cases idx))
              (do ///.Monad<Operation>
                ## When the first test below holds and 'idx' is the
                ## last of the 'num-cases' cases, the values are
                ## matched against a variant made of the flattened
                ## cases from position (dec num-cases) onwards.
                ## Otherwise, they are matched against the type of the
                ## selected case.
                ## NOTE(review): (n/> num-cases size-sum) presumably
                ## means the flattened sum has more cases than
                ## 'num-cases' -- confirm the argument order of n/>.
                [[testP nextA] (if (and (n/> num-cases size-sum)
                                        (n/= (dec num-cases) idx))
                                 (analyse-pattern #.None
                                                  (type.variant (list.drop (dec num-cases) flat-sum))
                                                  (` [(~+ values)])
                                                  next)
                                 (analyse-pattern #.None case-type (` [(~+ values)]) next))]
                (wrap [(//.sum-pattern num-cases idx testP)
                       nextA]))

              _
              (///.throw sum-type-has-no-case [idx inputT])))

          _
          (///.throw cannot-match-type-with-pattern [inputT pattern]))))

    ## Variant patterns by tag: (#tag value ...)
    ## The tag is resolved to its index, group and variant type; the
    ## input type is checked with respect to that variant type; and the
    ## pattern is re-analysed as an index-based variant pattern, with
    ## the size of the tag's group as 'num-tags'.
    (^ [cursor (#.Form (list& [_ (#.Tag tag)] values))])
    (//.with-cursor cursor
      (do ///.Monad<Operation>
        [tag (extension.lift (macro.normalize tag))
         [idx group variantT] (extension.lift (macro.resolve-tag tag))
         _ (//type.with-env
             (tc.check inputT variantT))]
        (analyse-pattern (#.Some (list.size group)) inputT (` ((~ (code.nat idx)) (~+ values))) next)))

    ## Anything else is not a valid pattern.
    _
    (///.throw unrecognized-pattern-syntax pattern)
    ))

## Analyses a whole "case" expression.
## analyse: the compiler used for the input and for every branch body.
## inputC: the code of the value being matched.
## branches: the [pattern body] pairs; at least one is required.
## The type of the input is inferred, every branch is analysed against
## that type (each body as the continuation of its pattern), and the
## coverage of all the patterns is merged and checked for
## exhaustiveness.
## Fails with 'cannot-have-empty-branches' when there are no branches,
## with 'non-exhaustive-pattern-matching' when the patterns do not
## cover every input, and with the coverage error when the coverages
## cannot be merged.
(def: #export (case analyse inputC branches)
  (-> Compiler Code (List [Code Code]) (Operation Analysis))
  (.case branches
    #.Nil
    ## Report the input expression, so the error says which "case" is
    ## missing its branches.
    (///.throw cannot-have-empty-branches (%code inputC))

    (#.Cons [patternH bodyH] branchesT)
    (do ///.Monad<Operation>
      [[inputT inputA] (//type.with-inference
                         (analyse inputC))
       outputH (analyse-pattern #.None inputT patternH (analyse bodyH))
       outputT (monad.map @
                          (function (_ [patternT bodyT])
                            (analyse-pattern #.None inputT patternT (analyse bodyT)))
                          branchesT)
       outputHC (|> outputH product.left coverage.determine)
       outputTC (monad.map @ (|>> product.left coverage.determine) outputT)
       _ (.case (monad.fold error.Monad<Error> coverage.merge outputHC outputTC)
           (#error.Success coverage)
           ## The report is only built on failure, to avoid formatting
           ## every pattern for each successfully analysed "case".
           (if (coverage.exhaustive? coverage)
             (:: @ wrap [])
             (///.throw non-exhaustive-pattern-matching
                        (ex.report ["Input" (%code inputC)]
                                   ["Patterns" (%code (code.tuple (list/map product.left branches)))])))

           (#error.Error error)
           (///.fail error))]
      (wrap (#//.Case inputA [outputH outputT])))))