aboutsummaryrefslogtreecommitdiff
path: root/stdlib/source/lux/tool/compiler/phase/analysis/case.lux
diff options
context:
space:
mode:
Diffstat (limited to 'stdlib/source/lux/tool/compiler/phase/analysis/case.lux')
-rw-r--r--stdlib/source/lux/tool/compiler/phase/analysis/case.lux300
1 files changed, 300 insertions, 0 deletions
diff --git a/stdlib/source/lux/tool/compiler/phase/analysis/case.lux b/stdlib/source/lux/tool/compiler/phase/analysis/case.lux
new file mode 100644
index 000000000..37bcfef6e
--- /dev/null
+++ b/stdlib/source/lux/tool/compiler/phase/analysis/case.lux
@@ -0,0 +1,300 @@
+(.module:
+ [lux (#- case)
+ [control
+ ["." monad (#+ do)]
+ ["ex" exception (#+ exception:)]]
+ [data
+ ["." product]
+ ["." error]
+ ["." maybe]
+ [text
+ format]
+ [collection
+ ["." list ("#/." fold monoid functor)]]]
+ ["." type
+ ["." check]]
+ ["." macro
+ ["." code]]]
+ ["." // (#+ Pattern Analysis Operation Phase)
+ ["." scope]
+ ["//." type]
+ ["." structure]
+ ["/." //
+ ["." extension]]]
+ [/
+ ["." coverage (#+ Coverage)]])
+
+(exception: #export (cannot-match-with-pattern {type Type} {pattern Code})
+ (ex.report ["Type" (%type type)]
+ ["Pattern" (%code pattern)]))
+
+(exception: #export (sum-has-no-case {case Nat} {type Type})
+ (ex.report ["Case" (%n case)]
+ ["Type" (%type type)]))
+
+(exception: #export (not-a-pattern {code Code})
+ (ex.report ["Code" (%code code)]))
+
+(exception: #export (cannot-simplify-for-pattern-matching {type Type})
+ (ex.report ["Type" (%type type)]))
+
+(exception: #export (non-exhaustive-pattern-matching {input Code} {branches (List [Code Code])} {coverage Coverage})
+ (ex.report ["Input" (%code input)]
+ ["Branches" (%code (code.record branches))]
+ ["Coverage" (coverage.%coverage coverage)]))
+
+(exception: #export (cannot-have-empty-branches {message Text})
+ message)
+
+(def: (re-quantify envs baseT)
+ (-> (List (List Type)) Type Type)
+ (.case envs
+ #.Nil
+ baseT
+
+ (#.Cons head tail)
+ (re-quantify tail (#.UnivQ head baseT))))
+
+## Type-checking on the input value is done during the analysis of a
+## "case" expression, to ensure that the patterns being used make
+## sense for the type of the input value.
+## Sometimes, that input value is complex, by depending on
+## type-variables or quantifications.
+## This function makes it easier for "case" analysis to properly
+## type-check the input with respect to the patterns.
+(def: (simplify-case caseT)
+ (-> Type (Operation Type))
+ (loop [envs (: (List (List Type))
+ (list))
+ caseT caseT]
+ (.case caseT
+ (#.Var id)
+ (do ///.monad
+ [?caseT' (//type.with-env
+ (check.read id))]
+ (.case ?caseT'
+ (#.Some caseT')
+ (recur envs caseT')
+
+ _
+ (///.throw cannot-simplify-for-pattern-matching caseT)))
+
+ (#.Named name unnamedT)
+ (recur envs unnamedT)
+
+ (#.UnivQ env unquantifiedT)
+ (recur (#.Cons env envs) unquantifiedT)
+
+ (#.ExQ _)
+ (do ///.monad
+ [[ex-id exT] (//type.with-env
+ check.existential)]
+ (recur envs (maybe.assume (type.apply (list exT) caseT))))
+
+ (#.Apply inputT funcT)
+ (.case funcT
+ (#.Var funcT-id)
+ (do ///.monad
+ [funcT' (//type.with-env
+ (do check.monad
+ [?funct' (check.read funcT-id)]
+ (.case ?funct'
+ (#.Some funct')
+ (wrap funct')
+
+ _
+ (check.throw cannot-simplify-for-pattern-matching caseT))))]
+ (recur envs (#.Apply inputT funcT')))
+
+ _
+ (.case (type.apply (list inputT) funcT)
+ (#.Some outputT)
+ (recur envs outputT)
+
+ #.None
+ (///.throw cannot-simplify-for-pattern-matching caseT)))
+
+ (#.Product _)
+ (|> caseT
+ type.flatten-tuple
+ (list/map (re-quantify envs))
+ type.tuple
+ (:: ///.monad wrap))
+
+ _
+ (:: ///.monad wrap (re-quantify envs caseT)))))
+
+(def: (analyse-primitive type inputT cursor output next)
+ (All [a] (-> Type Type Cursor Pattern (Operation a) (Operation [Pattern a])))
+ (//.with-cursor cursor
+ (do ///.monad
+ [_ (//type.with-env
+ (check.check inputT type))
+ outputA next]
+ (wrap [output outputA]))))
+
+## This function handles several concerns at once, but it must be that
+## way because those concerns are interleaved when doing
+## pattern-matching and they cannot be separated.
+## The pattern is analysed in order to get a general feel for what is
+## expected of the input value. This, in turn, informs the
+## type-checking of the input.
+## A kind of "continuation" value is passed around which signifies
+## what needs to be done _after_ analysing a pattern.
+## In general, this is done to analyse the "body" expression
+## associated to a particular pattern _in the context of_ said
+## pattern.
+## The reason why *context* is important is because patterns may bind
+## values to local variables, which may in turn be referenced in the
+## body expressions.
+## That is why the body must be analysed in the context of the
+## pattern, and not separately.
+(def: (analyse-pattern num-tags inputT pattern next)
+ (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a])))
+ (.case pattern
+ [cursor (#.Identifier ["" name])]
+ (//.with-cursor cursor
+ (do ///.monad
+ [outputA (scope.with-local [name inputT]
+ next)
+ idx scope.next-local]
+ (wrap [(#//.Bind idx) outputA])))
+
+ (^template [<type> <input> <output>]
+ [cursor <input>]
+ (analyse-primitive <type> inputT cursor (#//.Simple <output>) next))
+ ([Bit (#.Bit pattern-value) (#//.Bit pattern-value)]
+ [Nat (#.Nat pattern-value) (#//.Nat pattern-value)]
+ [Int (#.Int pattern-value) (#//.Int pattern-value)]
+ [Rev (#.Rev pattern-value) (#//.Rev pattern-value)]
+ [Frac (#.Frac pattern-value) (#//.Frac pattern-value)]
+ [Text (#.Text pattern-value) (#//.Text pattern-value)]
+ [Any (#.Tuple #.Nil) #//.Unit])
+
+ (^ [cursor (#.Tuple (list singleton))])
+ (analyse-pattern #.None inputT singleton next)
+
+ [cursor (#.Tuple sub-patterns)]
+ (//.with-cursor cursor
+ (do ///.monad
+ [inputT' (simplify-case inputT)]
+ (.case inputT'
+ (#.Product _)
+ (let [subs (type.flatten-tuple inputT')
+ num-subs (maybe.default (list.size subs)
+ num-tags)
+ num-sub-patterns (list.size sub-patterns)
+ matches (cond (n/< num-subs num-sub-patterns)
+ (let [[prefix suffix] (list.split (dec num-sub-patterns) subs)]
+ (list.zip2 (list/compose prefix (list (type.tuple suffix))) sub-patterns))
+
+ (n/> num-subs num-sub-patterns)
+ (let [[prefix suffix] (list.split (dec num-subs) sub-patterns)]
+ (list.zip2 subs (list/compose prefix (list (code.tuple suffix)))))
+
+ ## (n/= num-subs num-sub-patterns)
+ (list.zip2 subs sub-patterns))]
+ (do @
+ [[memberP+ thenA] (list/fold (: (All [a]
+ (-> [Type Code] (Operation [(List Pattern) a])
+ (Operation [(List Pattern) a])))
+ (function (_ [memberT memberC] then)
+ (do @
+ [[memberP [memberP+ thenA]] ((:coerce (All [a] (-> (Maybe Nat) Type Code (Operation a) (Operation [Pattern a])))
+ analyse-pattern)
+ #.None memberT memberC then)]
+ (wrap [(list& memberP memberP+) thenA]))))
+ (do @
+ [nextA next]
+ (wrap [(list) nextA]))
+ (list.reverse matches))]
+ (wrap [(//.pattern/tuple memberP+)
+ thenA])))
+
+ _
+ (///.throw cannot-match-with-pattern [inputT pattern])
+ )))
+
+ [cursor (#.Record record)]
+ (do ///.monad
+ [record (structure.normalize record)
+ [members recordT] (structure.order record)
+ _ (//type.with-env
+ (check.check inputT recordT))]
+ (analyse-pattern (#.Some (list.size members)) inputT [cursor (#.Tuple members)] next))
+
+ [cursor (#.Tag tag)]
+ (//.with-cursor cursor
+ (analyse-pattern #.None inputT (` ((~ pattern))) next))
+
+ (^ [cursor (#.Form (list& [_ (#.Nat idx)] values))])
+ (//.with-cursor cursor
+ (do ///.monad
+ [inputT' (simplify-case inputT)]
+ (.case inputT'
+ (#.Sum _)
+ (let [flat-sum (type.flatten-variant inputT')
+ size-sum (list.size flat-sum)
+ num-cases (maybe.default size-sum num-tags)]
+ (.case (list.nth idx flat-sum)
+ (^multi (#.Some caseT)
+ (n/< num-cases idx))
+ (do ///.monad
+ [[testP nextA] (if (and (n/> num-cases size-sum)
+ (n/= (dec num-cases) idx))
+ (analyse-pattern #.None
+ (type.variant (list.drop (dec num-cases) flat-sum))
+ (` [(~+ values)])
+ next)
+ (analyse-pattern #.None caseT (` [(~+ values)]) next))
+ #let [right? (n/= (dec num-cases) idx)
+ lefts (if right?
+ (dec idx)
+ idx)]]
+ (wrap [(//.pattern/variant [lefts right? testP])
+ nextA]))
+
+ _
+ (///.throw sum-has-no-case [idx inputT])))
+
+ _
+ (///.throw cannot-match-with-pattern [inputT pattern]))))
+
+ (^ [cursor (#.Form (list& [_ (#.Tag tag)] values))])
+ (//.with-cursor cursor
+ (do ///.monad
+ [tag (extension.lift (macro.normalize tag))
+ [idx group variantT] (extension.lift (macro.resolve-tag tag))
+ _ (//type.with-env
+ (check.check inputT variantT))]
+ (analyse-pattern (#.Some (list.size group)) inputT (` ((~ (code.nat idx)) (~+ values))) next)))
+
+ _
+ (///.throw not-a-pattern pattern)
+ ))
+
+(def: #export (case analyse inputC branches)
+ (-> Phase Code (List [Code Code]) (Operation Analysis))
+ (.case branches
+ (#.Cons [patternH bodyH] branchesT)
+ (do ///.monad
+ [[inputT inputA] (//type.with-inference
+ (analyse inputC))
+ outputH (analyse-pattern #.None inputT patternH (analyse bodyH))
+ outputT (monad.map @
+ (function (_ [patternT bodyT])
+ (analyse-pattern #.None inputT patternT (analyse bodyT)))
+ branchesT)
+ outputHC (|> outputH product.left coverage.determine)
+ outputTC (monad.map @ (|>> product.left coverage.determine) outputT)
+ _ (.case (monad.fold error.monad coverage.merge outputHC outputTC)
+ (#error.Success coverage)
+ (///.assert non-exhaustive-pattern-matching [inputC branches coverage]
+ (coverage.exhaustive? coverage))
+
+ (#error.Failure error)
+ (///.fail error))]
+ (wrap (#//.Case inputA [outputH outputT])))
+
+ #.Nil
+ (///.throw cannot-have-empty-branches "")))