## NOTE(review): the template parameter of the `do-template` below was missing
## from this file (empty parameter list, nameless exception) and has been
## restored as `<name>`. The same loss may also have removed `<...>` suffixes
## from structure names such as `Monad`, `Fold`, `Monoid` and `Functor` used
## in this module -- TODO confirm against the upstream sources.
(.module:
  [lux (#- case)
   [control
    ["." monad (#+ do)]
    ["ex" exception (#+ exception:)]]
   [data
    ["." product]
    ["." error]
    ["." maybe]
    [text
     format]
    [collection
     ["." list ("list/." Fold Monoid Functor)]]]
   ["." type
    ["." check]]
   ["." macro
    ["." code]]]
  ["." // (#+ Pattern Analysis Operation Compiler)
   ["." scope]
   ["//." type]
   ["." structure]
   ["/." //
    ["." extension]]]
  [/
   ["." coverage]])

## Thrown when a tuple or variant pattern is used against an input whose
## (simplified) type is not a product or a sum, respectively.
(exception: #export (cannot-match-with-pattern {type Type} {pattern Code})
  (ex.report ["Type" (%type type)]
             ["Pattern" (%code pattern)]))

## Thrown when a variant pattern refers to a case index the sum type
## does not provide.
(exception: #export (sum-has-no-case {case Nat} {type Type})
  (ex.report ["Case" (%n case)]
             ["Type" (%type type)]))

## Thrown when a piece of code cannot be interpreted as any known kind
## of pattern.
(exception: #export (unrecognized-pattern-syntax {pattern Code})
  (%code pattern))

## Thrown when the type of the input cannot be reduced to a form that
## pattern-matching can work with.
(exception: #export (cannot-simplify-for-pattern-matching {type Type})
  (%type type))

## Exceptions that carry nothing but a textual message.
## Each row below instantiates the template once, with <name> being the
## name of the exception to define.
(do-template [<name>]
  [(exception: #export (<name> {message Text})
     message)]

  [cannot-have-empty-branches]
  [non-exhaustive-pattern-matching]
  )

## Wraps 'baseT' in one universal quantification per environment in
## 'envs'.
## The head of the list is applied first, so it ends up as the innermost
## quantification.
(def: (re-quantify envs baseT)
  (-> (List (List Type)) Type Type)
  (.case envs
    #.Nil
    baseT

    (#.Cons head tail)
    (re-quantify tail (#.UnivQ head baseT))))

## Type-checking on the input value is done during the analysis of a
## "case" expression, to ensure that the patterns being used make
## sense for the type of the input value.
## Sometimes, that input value is complex, by depending on
## type-variables or quantifications.
## This function makes it easier for "case" analysis to properly
## type-check the input with respect to the patterns.
(def: (simplify-case caseT)
  (-> Type (Operation Type))
  (loop [quantifier-envs (: (List (List Type))
                            (list))
         currentT caseT]
    (.case currentT
      ## A type-variable can only be simplified if it has already been
      ## bound to something; otherwise, the simplification fails.
      (#.Var var-id)
      (do ///.Monad
        [?boundT (//type.with-env
                   (check.read var-id))]
        (.case ?boundT
          (#.Some boundT)
          (recur quantifier-envs boundT)

          _
          (///.throw cannot-simplify-for-pattern-matching currentT)))

      ## Names are transparent; only the underlying type matters.
      (#.Named _ anonymousT)
      (recur quantifier-envs anonymousT)

      ## The environment is remembered, so the quantification can be
      ## re-applied once the body has been simplified.
      (#.UnivQ env bodyT)
      (recur (#.Cons env quantifier-envs) bodyT)

      ## The quantified type is applied to the type produced by
      ## 'check.existential', and simplification goes on with the result.
      (#.ExQ _)
      (do ///.Monad
        [[_ exT] (//type.with-env
                   check.existential)]
        (recur quantifier-envs
               (maybe.assume (type.apply (list exT) currentT))))

      (#.Apply argT functionT)
      (.case functionT
        ## The function is a type-variable: it must be resolved before
        ## the application can be attempted again.
        (#.Var function-id)
        (do ///.Monad
          [functionT' (//type.with-env
                        (do check.Monad
                          [?resolvedT (check.read function-id)]
                          (.case ?resolvedT
                            (#.Some resolvedT)
                            (wrap resolvedT)

                            _
                            (check.throw cannot-simplify-for-pattern-matching currentT))))]
          (recur quantifier-envs (#.Apply argT functionT')))

        _
        (.case (type.apply (list argT) functionT)
          (#.Some appliedT)
          (recur quantifier-envs appliedT)

          #.None
          (///.throw cannot-simplify-for-pattern-matching currentT)))

      ## Every member of the tuple gets re-quantified on its own, and
      ## the tuple is then put back together.
      (#.Product _)
      (:: ///.Monad wrap
          (type.tuple (list/map (re-quantify quantifier-envs)
                                (type.flatten-tuple currentT))))

      _
      (:: ///.Monad wrap (re-quantify quantifier-envs currentT)))))

## Analyses a pattern for a primitive value.
## With the given cursor in place, the input type is checked against the
## type of the primitive, and then the continuation 'next' is run.
## The given 'output' pattern is returned alongside the result of 'next'.
(def: (analyse-primitive type inputT cursor output next)
  (All [a] (-> Type Type Cursor Pattern (Operation a) (Operation [Pattern a])))
  (//.with-cursor cursor
    (do ///.Monad
      [_ (//type.with-env
           (check.check inputT type))
       nextA next]
      (wrap [output nextA]))))

## This function handles several concerns at once, but it must be that
## way because those concerns are interleaved when doing
## pattern-matching and they cannot be separated.
## The pattern is analysed in order to get a general feel for what is
## expected of the input value. This, in turn, informs the
## type-checking of the input.
## A kind of "continuation" value is passed around which signifies
## what needs to be done _after_ analysing a pattern.
## In general, this is done to analyse the "body" expression
## associated to a particular pattern _in the context of_ said
## pattern.
## The reason why *context* is important is because patterns may bind
## values to local variables, which may in turn be referenced in the
## body expressions.
## That is why the body must be analysed in the context of the
## pattern, and not separately.
##
## Parameters:
##   num-tags: when present, overrides the number of members/cases that
##             would otherwise be taken from the flattened input type.
##   inputT:   the type of the value being matched.
##   pattern:  the code of the pattern.
##   next:     the continuation to run in the context of the pattern.
## Returns the analysed pattern, paired with the result of 'next'.
## Throws 'cannot-match-with-pattern', 'sum-has-no-case' or
## 'unrecognized-pattern-syntax' (see the individual branches).
(def: (analyse-pattern num-tags inputT pattern next)
  (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a])))
  (.case pattern
    ## An un-prefixed symbol binds the input to a local variable, and the
    ## continuation is run with that local in scope.
    [cursor (#.Symbol ["" name])]
    (//.with-cursor cursor
      (do ///.Monad
        [outputA (scope.with-local [name inputT]
                   next)
         idx scope.next-local]
        (wrap [(#//.Bind idx) outputA])))

    ## Primitive patterns.
    ## Each row provides: the type to check the input against, the code
    ## pattern to recognize, and the simple pattern to produce.
    (^template [<type> <input> <output>]
      [cursor <input>]
      (analyse-primitive <type> inputT cursor (#//.Simple <output>) next))
    ([Bit  (#.Bit pattern-value)  (#//.Bit pattern-value)]
     [Nat  (#.Nat pattern-value)  (#//.Nat pattern-value)]
     [Int  (#.Int pattern-value)  (#//.Int pattern-value)]
     [Rev  (#.Rev pattern-value)  (#//.Rev pattern-value)]
     [Frac (#.Frac pattern-value) (#//.Frac pattern-value)]
     [Text (#.Text pattern-value) (#//.Text pattern-value)]
     [Any  (#.Tuple #.Nil)        #//.Unit])

    ## A 1-tuple is the same as its only member.
    (^ [cursor (#.Tuple (list singleton))])
    (analyse-pattern #.None inputT singleton next)

    [cursor (#.Tuple sub-patterns)]
    (//.with-cursor cursor
      (do ///.Monad
        [inputT' (simplify-case inputT)]
        (.case inputT'
          (#.Product _)
          (let [subs (type.flatten-tuple inputT')
                num-subs (maybe.default (list.size subs)
                                        num-tags)
                num-sub-patterns (list.size sub-patterns)
                ## Member types get paired with sub-patterns.
                ## When the sizes differ, the surplus on the longer side
                ## is grouped into a single trailing tuple, so both
                ## sides end up with the same length.
                matches (cond (n/< num-subs num-sub-patterns)
                              (let [[prefix suffix] (list.split (dec num-sub-patterns) subs)]
                                (list.zip2 (list/compose prefix (list (type.tuple suffix))) sub-patterns))

                              (n/> num-subs num-sub-patterns)
                              (let [[prefix suffix] (list.split (dec num-subs) sub-patterns)]
                                (list.zip2 subs (list/compose prefix (list (code.tuple suffix)))))

                              ## (n/= num-subs num-sub-patterns)
                              (list.zip2 subs sub-patterns))]
            (do @
              ## The fold goes over the reversed matches, with the
              ## original continuation as the initial value, and every
              ## step wraps the previous one as the continuation of a
              ## member pattern.
              [[memberP+ thenA] (list/fold (: (All [a]
                                                (-> [Type Code] (Operation [(List Pattern) a])
                                                    (Operation [(List Pattern) a])))
                                              (function (_ [memberT memberC] then)
                                                (do @
                                                  [[memberP [memberP+ thenA]] ((:coerce (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a])))
                                                                                        analyse-pattern)
                                                                               #.None memberT memberC then)]
                                                  (wrap [(list& memberP memberP+) thenA]))))
                                           (do @
                                             [nextA next]
                                             (wrap [(list) nextA]))
                                           (list.reverse matches))]
              (wrap [(//.product-pattern memberP+) thenA])))

          _
          (///.throw cannot-match-with-pattern [inputT pattern])
          )))

    ## Records are normalized and ordered, and then handled as tuples
    ## with a known number of members.
    [cursor (#.Record record)]
    (do ///.Monad
      [record (structure.normalize record)
       [members recordT] (structure.order record)
       _ (//type.with-env
           (check.check inputT recordT))]
      (analyse-pattern (#.Some (list.size members)) inputT [cursor (#.Tuple members)] next))

    ## A lone tag is handled as a variant pattern without any values.
    [cursor (#.Tag tag)]
    (//.with-cursor cursor
      (analyse-pattern #.None inputT (` ((~ pattern))) next))

    ## Variant pattern, with the case given as a numeric index.
    (^ [cursor (#.Form (list& [_ (#.Nat idx)] values))])
    (//.with-cursor cursor
      (do ///.Monad
        [inputT' (simplify-case inputT)]
        (.case inputT'
          (#.Sum _)
          (let [flat-sum (type.flatten-variant inputT')
                size-sum (list.size flat-sum)
                num-cases (maybe.default size-sum num-tags)]
            (.case (list.nth idx flat-sum)
              (^multi (#.Some caseT)
                      (n/< num-cases idx))
              (do ///.Monad
                ## For the last case, when the flattened sum has more
                ## members than there are cases, the remaining members
                ## are grouped into a variant, and the values are
                ## matched against that.
                [[testP nextA] (if (and (n/> num-cases size-sum)
                                        (n/= (dec num-cases) idx))
                                 (analyse-pattern #.None
                                                  (type.variant (list.drop (dec num-cases) flat-sum))
                                                  (` [(~+ values)])
                                                  next)
                                 (analyse-pattern #.None caseT (` [(~+ values)]) next))]
                (wrap [(//.sum-pattern num-cases idx testP)
                       nextA]))

              _
              (///.throw sum-has-no-case [idx inputT])))

          _
          (///.throw cannot-match-with-pattern [inputT pattern]))))

    ## Variant pattern, with the case given as a tag.
    ## The tag is resolved to its index, and the pattern is analysed
    ## again in its numeric form.
    (^ [cursor (#.Form (list& [_ (#.Tag tag)] values))])
    (//.with-cursor cursor
      (do ///.Monad
        [tag (extension.lift (macro.normalize tag))
         [idx group variantT] (extension.lift (macro.resolve-tag tag))
         _ (//type.with-env
             (check.check inputT variantT))]
        (analyse-pattern (#.Some (list.size group)) inputT (` ((~ (code.nat idx)) (~+ values))) next)))

    _
    (///.throw unrecognized-pattern-syntax pattern)
    ))

## Analyses a "case" expression.
## The input is analysed with type-inference, and then every branch has
## its pattern analysed against the type of the input, with the analysis
## of its body as the continuation.
## Afterwards, the coverage of every pattern is determined and merged,
## in order to verify that the pattern-matching is exhaustive.
## Throws 'cannot-have-empty-branches' when there are no branches, and
## 'non-exhaustive-pattern-matching' when the merged coverage is not
## exhaustive. A failure to merge the coverages is reported with its
## own error message.
(def: #export (case analyse inputC branches)
  (-> Compiler Code (List [Code Code]) (Operation Analysis))
  (.case branches
    #.Nil
    (///.throw cannot-have-empty-branches "")

    (#.Cons [patternH bodyH] branchesT)
    (do ///.Monad
      [[inputT inputA] (//type.with-inference
                         (analyse inputC))
       outputH (analyse-pattern #.None inputT patternH (analyse bodyH))
       outputT (monad.map @
                          (function (_ [patternT bodyT])
                            (analyse-pattern #.None inputT patternT (analyse bodyT)))
                          branchesT)
       outputHC (|> outputH product.left coverage.determine)
       outputTC (monad.map @ (|>> product.left coverage.determine) outputT)
       _ (.case (monad.fold error.Monad coverage.merge outputHC outputTC)
           (#error.Success coverage)
           (///.assert non-exhaustive-pattern-matching ""
                       (coverage.exhaustive? coverage))

           (#error.Error error)
           (///.fail error))]
      (wrap (#//.Case inputA [outputH outputT])))))