diff options
Diffstat (limited to 'stdlib/source/lux/tool/compiler/phase/analysis/case.lux')
-rw-r--r-- | stdlib/source/lux/tool/compiler/phase/analysis/case.lux | 300 |
1 files changed, 300 insertions, 0 deletions
diff --git a/stdlib/source/lux/tool/compiler/phase/analysis/case.lux b/stdlib/source/lux/tool/compiler/phase/analysis/case.lux new file mode 100644 index 000000000..37bcfef6e --- /dev/null +++ b/stdlib/source/lux/tool/compiler/phase/analysis/case.lux @@ -0,0 +1,300 @@ +(.module: + [lux (#- case) + [control + ["." monad (#+ do)] + ["ex" exception (#+ exception:)]] + [data + ["." product] + ["." error] + ["." maybe] + [text + format] + [collection + ["." list ("#/." fold monoid functor)]]] + ["." type + ["." check]] + ["." macro + ["." code]]] + ["." // (#+ Pattern Analysis Operation Phase) + ["." scope] + ["//." type] + ["." structure] + ["/." // + ["." extension]]] + [/ + ["." coverage (#+ Coverage)]]) + +(exception: #export (cannot-match-with-pattern {type Type} {pattern Code}) + (ex.report ["Type" (%type type)] + ["Pattern" (%code pattern)])) + +(exception: #export (sum-has-no-case {case Nat} {type Type}) + (ex.report ["Case" (%n case)] + ["Type" (%type type)])) + +(exception: #export (not-a-pattern {code Code}) + (ex.report ["Code" (%code code)])) + +(exception: #export (cannot-simplify-for-pattern-matching {type Type}) + (ex.report ["Type" (%type type)])) + +(exception: #export (non-exhaustive-pattern-matching {input Code} {branches (List [Code Code])} {coverage Coverage}) + (ex.report ["Input" (%code input)] + ["Branches" (%code (code.record branches))] + ["Coverage" (coverage.%coverage coverage)])) + +(exception: #export (cannot-have-empty-branches {message Text}) + message) + +(def: (re-quantify envs baseT) + (-> (List (List Type)) Type Type) + (.case envs + #.Nil + baseT + + (#.Cons head tail) + (re-quantify tail (#.UnivQ head baseT)))) + +## Type-checking on the input value is done during the analysis of a +## "case" expression, to ensure that the patterns being used make +## sense for the type of the input value. +## Sometimes, that input value is complex, by depending on +## type-variables or quantifications. +## This function makes it easier for "case" analysis to properly +## type-check the input with respect to the patterns. +(def: (simplify-case caseT) + (-> Type (Operation Type)) + (loop [envs (: (List (List Type)) + (list)) + caseT caseT] + (.case caseT + (#.Var id) + (do ///.monad + [?caseT' (//type.with-env + (check.read id))] + (.case ?caseT' + (#.Some caseT') + (recur envs caseT') + + _ + (///.throw cannot-simplify-for-pattern-matching caseT))) + + (#.Named name unnamedT) + (recur envs unnamedT) + + (#.UnivQ env unquantifiedT) + (recur (#.Cons env envs) unquantifiedT) + + (#.ExQ _) + (do ///.monad + [[ex-id exT] (//type.with-env + check.existential)] + (recur envs (maybe.assume (type.apply (list exT) caseT)))) + + (#.Apply inputT funcT) + (.case funcT + (#.Var funcT-id) + (do ///.monad + [funcT' (//type.with-env + (do check.monad + [?funct' (check.read funcT-id)] + (.case ?funct' + (#.Some funct') + (wrap funct') + + _ + (check.throw cannot-simplify-for-pattern-matching caseT))))] + (recur envs (#.Apply inputT funcT'))) + + _ + (.case (type.apply (list inputT) funcT) + (#.Some outputT) + (recur envs outputT) + + #.None + (///.throw cannot-simplify-for-pattern-matching caseT))) + + (#.Product _) + (|> caseT + type.flatten-tuple + (list/map (re-quantify envs)) + type.tuple + (:: ///.monad wrap)) + + _ + (:: ///.monad wrap (re-quantify envs caseT))))) + +(def: (analyse-primitive type inputT cursor output next) + (All [a] (-> Type Type Cursor Pattern (Operation a) (Operation [Pattern a]))) + (//.with-cursor cursor + (do ///.monad + [_ (//type.with-env + (check.check inputT type)) + outputA next] + (wrap [output outputA])))) + +## This function handles several concerns at once, but it must be that +## way because those concerns are interleaved when doing +## pattern-matching and they cannot be separated. +## The pattern is analysed in order to get a general feel for what is +## expected of the input value. This, in turn, informs the +## type-checking of the input. +## A kind of "continuation" value is passed around which signifies +## what needs to be done _after_ analysing a pattern. +## In general, this is done to analyse the "body" expression +## associated to a particular pattern _in the context of_ said +## pattern. +## The reason why *context* is important is because patterns may bind +## values to local variables, which may in turn be referenced in the +## body expressions. +## That is why the body must be analysed in the context of the +## pattern, and not separately. +(def: (analyse-pattern num-tags inputT pattern next) + (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a]))) + (.case pattern + [cursor (#.Identifier ["" name])] + (//.with-cursor cursor + (do ///.monad + [outputA (scope.with-local [name inputT] + next) + idx scope.next-local] + (wrap [(#//.Bind idx) outputA]))) + + (^template [<type> <input> <output>] + [cursor <input>] + (analyse-primitive <type> inputT cursor (#//.Simple <output>) next)) + ([Bit (#.Bit pattern-value) (#//.Bit pattern-value)] + [Nat (#.Nat pattern-value) (#//.Nat pattern-value)] + [Int (#.Int pattern-value) (#//.Int pattern-value)] + [Rev (#.Rev pattern-value) (#//.Rev pattern-value)] + [Frac (#.Frac pattern-value) (#//.Frac pattern-value)] + [Text (#.Text pattern-value) (#//.Text pattern-value)] + [Any (#.Tuple #.Nil) #//.Unit]) + + (^ [cursor (#.Tuple (list singleton))]) + (analyse-pattern #.None inputT singleton next) + + [cursor (#.Tuple sub-patterns)] + (//.with-cursor cursor + (do ///.monad + [inputT' (simplify-case inputT)] + (.case inputT' + (#.Product _) + (let [subs (type.flatten-tuple inputT') + num-subs (maybe.default (list.size subs) + num-tags) + num-sub-patterns (list.size sub-patterns) + matches (cond (n/< num-subs num-sub-patterns) + (let [[prefix suffix] (list.split (dec num-sub-patterns) subs)] + (list.zip2 (list/compose prefix (list (type.tuple suffix))) sub-patterns)) + + (n/> num-subs num-sub-patterns) + (let [[prefix suffix] (list.split (dec num-subs) sub-patterns)] + (list.zip2 subs (list/compose prefix (list (code.tuple suffix))))) + + ## (n/= num-subs num-sub-patterns) + (list.zip2 subs sub-patterns))] + (do @ + [[memberP+ thenA] (list/fold (: (All [a] + (-> [Type Code] (Operation [(List Pattern) a]) + (Operation [(List Pattern) a]))) + (function (_ [memberT memberC] then) + (do @ + [[memberP [memberP+ thenA]] ((:coerce (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a]))) + analyse-pattern) + #.None memberT memberC then)] + (wrap [(list& memberP memberP+) thenA])))) + (do @ + [nextA next] + (wrap [(list) nextA])) + (list.reverse matches))] + (wrap [(//.pattern/tuple memberP+) + thenA]))) + + _ + (///.throw cannot-match-with-pattern [inputT pattern]) + ))) + + [cursor (#.Record record)] + (do ///.monad + [record (structure.normalize record) + [members recordT] (structure.order record) + _ (//type.with-env + (check.check inputT recordT))] + (analyse-pattern (#.Some (list.size members)) inputT [cursor (#.Tuple members)] next)) + + [cursor (#.Tag tag)] + (//.with-cursor cursor + (analyse-pattern #.None inputT (` ((~ pattern))) next)) + + (^ [cursor (#.Form (list& [_ (#.Nat idx)] values))]) + (//.with-cursor cursor + (do ///.monad + [inputT' (simplify-case inputT)] + (.case inputT' + (#.Sum _) + (let [flat-sum (type.flatten-variant inputT') + size-sum (list.size flat-sum) + num-cases (maybe.default size-sum num-tags)] + (.case (list.nth idx flat-sum) + (^multi (#.Some caseT) + (n/< num-cases idx)) + (do ///.monad + [[testP nextA] (if (and (n/> num-cases size-sum) + (n/= (dec num-cases) idx)) + (analyse-pattern #.None + (type.variant (list.drop (dec num-cases) flat-sum)) + (` [(~+ values)]) + next) + (analyse-pattern #.None caseT (` [(~+ values)]) next)) + #let [right? (n/= (dec num-cases) idx) + lefts (if right? + (dec idx) + idx)]] + (wrap [(//.pattern/variant [lefts right? testP]) + nextA])) + + _ + (///.throw sum-has-no-case [idx inputT]))) + + _ + (///.throw cannot-match-with-pattern [inputT pattern])))) + + (^ [cursor (#.Form (list& [_ (#.Tag tag)] values))]) + (//.with-cursor cursor + (do ///.monad + [tag (extension.lift (macro.normalize tag)) + [idx group variantT] (extension.lift (macro.resolve-tag tag)) + _ (//type.with-env + (check.check inputT variantT))] + (analyse-pattern (#.Some (list.size group)) inputT (` ((~ (code.nat idx)) (~+ values))) next))) + + _ + (///.throw not-a-pattern pattern) + )) + +(def: #export (case analyse inputC branches) + (-> Phase Code (List [Code Code]) (Operation Analysis)) + (.case branches + (#.Cons [patternH bodyH] branchesT) + (do ///.monad + [[inputT inputA] (//type.with-inference + (analyse inputC)) + outputH (analyse-pattern #.None inputT patternH (analyse bodyH)) + outputT (monad.map @ + (function (_ [patternT bodyT]) + (analyse-pattern #.None inputT patternT (analyse bodyT))) + branchesT) + outputHC (|> outputH product.left coverage.determine) + outputTC (monad.map @ (|>> product.left coverage.determine) outputT) + _ (.case (monad.fold error.monad coverage.merge outputHC outputTC) + (#error.Success coverage) + (///.assert non-exhaustive-pattern-matching [inputC branches coverage] + (coverage.exhaustive? coverage)) + + (#error.Failure error) + (///.fail error))] + (wrap (#//.Case inputA [outputH outputT]))) + + #.Nil + (///.throw cannot-have-empty-branches ""))) |