aboutsummaryrefslogtreecommitdiff
path: root/new-luxc/source/luxc/analyser/case.lux
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--new-luxc/source/luxc/analyser/case.lux448
1 files changed, 448 insertions, 0 deletions
diff --git a/new-luxc/source/luxc/analyser/case.lux b/new-luxc/source/luxc/analyser/case.lux
new file mode 100644
index 000000000..fc151f771
--- /dev/null
+++ b/new-luxc/source/luxc/analyser/case.lux
@@ -0,0 +1,448 @@
+(;module:
+ lux
+ (lux (control monad
+ eq)
+ (data [bool "B/" Eq<Bool>]
+ [number]
+ [char]
+ [text]
+ text/format
+ [product]
+ ["R" result "R/" Monad<Result>]
+ (coll [list "L/" Fold<List> Monoid<List> Monad<List>]
+ ["D" dict]))
+ [macro #+ Monad<Lux>]
+ (macro [code])
+ [type]
+ (type ["TC" check]))
+ (luxc ["&" base]
+ (lang ["la" analysis #+ Analysis]
+ ["lp" pattern #+ Pattern])
+ ["&;" env]
+ (analyser ["&;" common]
+ ["&;" struct])))
+
+(type: #rec Coverage
+ #PartialC
+ (#BoolC Bool)
+ (#VariantC Nat (D;Dict Nat Coverage))
+ (#SeqC Coverage Coverage)
+ (#AltC Coverage Coverage)
+ #TotalC)
+
+(def: (pattern-error type pattern)
+ (-> Type Code Text)
+ (format "Cannot match this type: " (%type type) "\n"
+ " With this pattern: " (%code pattern)))
+
+(def: (simplify-case-type type)
+ (-> Type (Lux Type))
+ (case type
+ (#;Var id)
+ (do Monad<Lux>
+ [? (&;within-type-env
+ (TC;bound? id))]
+ (if ?
+ (do @
+ [type' (&;within-type-env
+ (TC;read-var id))]
+ (simplify-case-type type'))
+ (&;fail (format "Cannot simplify type for pattern-matching: " (%type type)))))
+
+ (#;Named name unnamedT)
+ (simplify-case-type unnamedT)
+
+ (^or (#;UnivQ _) (#;ExQ _))
+ (do Monad<Lux>
+ [[ex-id exT] (&;within-type-env
+ TC;existential)]
+ (simplify-case-type (assume (type;apply-type type exT))))
+
+ _
+ (:: Monad<Lux> wrap type)))
+
+(def: (analyse-pattern num-tags inputT pattern next)
+ (All [a] (-> (Maybe Nat) Type Code (Lux a) (Lux [Pattern a])))
+ (case pattern
+ [cursor (#;Symbol ["" name])]
+ (&;with-cursor cursor
+ (do Monad<Lux>
+ [outputA (&env;with-local [name inputT]
+ next)
+ idx &env;next-local]
+ (wrap [(#lp;Bind idx) outputA])))
+
+ [cursor (#;Symbol ident)]
+ (&;with-cursor cursor
+ (&;fail (format "Symbols must be unqualified inside patterns: " (%ident ident))))
+
+ (^template [<type> <code-tag> <pattern-tag>]
+ [cursor (<code-tag> test)]
+ (&;with-cursor cursor
+ (do Monad<Lux>
+ [_ (&;within-type-env
+ (TC;check inputT <type>))
+ outputA next]
+ (wrap [(<pattern-tag> test) outputA]))))
+ ([Bool #;Bool #lp;Bool]
+ [Nat #;Nat #lp;Nat]
+ [Int #;Int #lp;Int]
+ [Deg #;Deg #lp;Deg]
+ [Real #;Real #lp;Real]
+ [Char #;Char #lp;Char]
+ [Text #;Text #lp;Text])
+
+ (^ [cursor (#;Tuple (list))])
+ (&;with-cursor cursor
+ (do Monad<Lux>
+ [_ (&;within-type-env
+ (TC;check inputT Unit))
+ outputA next]
+ (wrap [#lp;Unit outputA])))
+
+ (^ [cursor (#;Tuple (list singleton))])
+ (analyse-pattern #;None inputT singleton next)
+
+ [cursor (#;Tuple sub-patterns)]
+ (&;with-cursor cursor
+ (do Monad<Lux>
+ [inputT' (simplify-case-type inputT)]
+ (case inputT'
+ (#;Product _)
+ (let [sub-types (type;flatten-tuple inputT)
+ num-sub-types (default (list;size sub-types)
+ num-tags)
+ num-sub-patterns (list;size sub-patterns)
+ matches (cond (n.< num-sub-types num-sub-patterns)
+ (let [[prefix suffix] (list;split (n.dec num-sub-patterns) sub-types)]
+ (list;zip2 (L/append prefix (list (type;tuple suffix))) sub-patterns))
+
+ (n.> num-sub-types num-sub-patterns)
+ (let [[prefix suffix] (list;split (n.dec num-sub-types) sub-patterns)]
+ (list;zip2 sub-types (L/append prefix (list (code;tuple suffix)))))
+
+ ## (n.= num-sub-types num-sub-patterns)
+ (list;zip2 sub-types sub-patterns)
+ )]
+ (do @
+ [[memberP+ thenA] (L/fold (: (All [a]
+ (-> [Type Code] (Lux [(List Pattern) a])
+ (Lux [(List Pattern) a])))
+ (function [[memberT memberC] then]
+ (do @
+ [[memberP [memberP+ thenA]] ((:! (All [a] (-> (Maybe Nat) Type Code (Lux a) (Lux [Pattern a])))
+ analyse-pattern)
+ #;None memberT memberC then)]
+ (wrap [(list& memberP memberP+) thenA]))))
+ (do @
+ [nextA next]
+ (wrap [(list) nextA]))
+ matches)]
+ (wrap [(#lp;Tuple memberP+) thenA])))
+
+ _
+ (&;fail (pattern-error inputT pattern))
+ )))
+
+ [cursor (#;Record pairs)]
+ (do Monad<Lux>
+ [pairs (&struct;normalize-record pairs)
+ [members recordT] (&struct;order-record pairs)
+ _ (&;within-type-env
+ (TC;check inputT recordT))]
+ (analyse-pattern (#;Some (list;size members)) inputT [cursor (#;Tuple members)] next))
+
+ [cursor (#;Tag tag)]
+ (&;with-cursor cursor
+ (analyse-pattern #;None inputT (` ((~ pattern))) next))
+
+ (^ [cursor (#;Form (list& [_ (#;Nat idx)] values))])
+ (&;with-cursor cursor
+ (do Monad<Lux>
+ [inputT' (simplify-case-type inputT)]
+ (case inputT'
+ (#;Sum _)
+ (let [flat-sum (type;flatten-variant inputT)]
+ (case (list;nth idx flat-sum)
+ #;None
+ (&;fail (format "Cannot match index " (%n idx) " against type: " (%type inputT)))
+
+ (#;Some case-type)
+ (do Monad<Lux>
+ [[testP nextA] (analyse-pattern #;None case-type (' [(~@ values)]) next)]
+ (wrap [(#lp;Variant [idx (default (list;size flat-sum)
+ num-tags)]
+ testP)
+ nextA]))))
+
+ _
+ (&;fail (pattern-error inputT pattern)))))
+
+ (^ [cursor (#;Form (list& [_ (#;Tag tag)] values))])
+ (&;with-cursor cursor
+ (do Monad<Lux>
+ [tag (macro;normalize tag)
+ [idx group tagT] (macro;resolve-tag tag)
+ _ (&;within-type-env
+ (TC;check inputT tagT))]
+ (analyse-pattern (#;Some (list;size group)) inputT (' ((~ (code;nat idx)) (~@ values))) next)))
+
+ _
+ (&;fail (format "Unrecognized pattern syntax: " (%code pattern)))
+ ))
+
+(def: (analyse-branch analyse inputT pattern body)
+ (-> &;Analyser Type Code Code (Lux [Pattern Analysis]))
+ (analyse-pattern #;None inputT pattern (analyse body)))
+
+(do-template [<name> <tag>]
+ [(def: (<name> coverage)
+ (-> Coverage Bool)
+ (case coverage
+ (<tag> _)
+ true
+
+ _
+ false))]
+
+ [total? #TotalC]
+ [alt? #AltC])
+
+(def: (determine-coverage pattern)
+ (-> Pattern Coverage)
+ (case pattern
+ (^or (#lp;Bind _) #lp;Unit)
+ #TotalC
+
+ (#lp;Bool value)
+ (#BoolC value)
+
+ (^or (#lp;Nat _) (#lp;Int _) (#lp;Deg _)
+ (#lp;Real _) (#lp;Char _) (#lp;Text _))
+ #PartialC
+
+ (#lp;Tuple subs)
+ (loop [subs subs]
+ (case subs
+ #;Nil
+ #TotalC
+
+ (#;Cons sub subs')
+ (let [post (recur subs')]
+ (if (total? post)
+ (determine-coverage sub)
+ (#SeqC (determine-coverage sub)
+ post)))))
+
+ (#lp;Variant [tag-id num-tags] sub)
+ (#VariantC num-tags
+ (|> (D;new number;Hash<Nat>)
+ (D;put tag-id (determine-coverage sub))))))
+
+(def: (xor left right)
+ (-> Bool Bool Bool)
+ (or (and left (not right))
+ (and (not left) right)))
+
+(def: redundant-pattern
+ (R;Result Coverage)
+ (R;fail "Redundant pattern."))
+
+(def: (flatten-alt coverage)
+ (-> Coverage (List Coverage))
+ (case coverage
+ (#AltC left right)
+ (list& left (flatten-alt right))
+
+ _
+ (list coverage)))
+
+(struct: _ (Eq Coverage)
+ (def: (= reference sample)
+ (case [reference sample]
+ (^or [#TotalC #TotalC] [#PartialC #PartialC])
+ true
+
+ [(#BoolC sideR) (#BoolC sideS)]
+ (B/= sideR sideS)
+
+ [(#VariantC allR casesR) (#VariantC allS casesS)]
+ (and (n.= allR allS)
+ (:: (D;Eq<Dict> =) = casesR casesS))
+
+ [(#SeqC leftR rightR) (#SeqC leftS rightS)]
+ (and (= leftR leftS)
+ (= rightR rightS))
+
+ [(#AltC _) (#AltC _)]
+ (let [flatR (flatten-alt reference)
+ flatS (flatten-alt sample)]
+ (and (n.= (list;size flatR) (list;size flatS))
+ (list;every? (function [[coverageR coverageS]]
+ (= coverageR coverageS))
+ (list;zip2 flatR flatS))))
+
+ _
+ false)))
+
+(open Eq<Coverage> "C/")
+
+(def: (merge-coverages addition so-far)
+ (-> Coverage Coverage (R;Result Coverage))
+ (case [addition so-far]
+ ## The addition cannot possibly improve the coverage.
+ [_ #TotalC]
+ redundant-pattern
+
+ ## The addition completes the coverage.
+ [#TotalC _]
+ (R/wrap #TotalC)
+
+ [#PartialC #PartialC]
+ (R/wrap #PartialC)
+
+ (^=> [(#BoolC sideA) (#BoolC sideSF)]
+ (xor sideA sideSF))
+ (R/wrap #TotalC)
+
+ [(#VariantC allA casesA) (#VariantC allSF casesSF)]
+ (cond (not (n.= allSF allA))
+ (R;fail "Variants do not match.")
+
+ (:: (D;Eq<Dict> Eq<Coverage>) = casesSF casesA)
+ redundant-pattern
+
+ ## else
+ (do R;Monad<Result>
+ [casesM (foldM @
+ (function [[tagA coverageA] casesSF']
+ (case (D;get tagA casesSF')
+ (#;Some coverageSF)
+ (do @
+ [coverageM (merge-coverages coverageA coverageSF)]
+ (wrap (D;put tagA coverageM casesSF')))
+
+ #;None
+ (wrap (D;put tagA coverageA casesSF'))))
+ casesSF (D;entries casesA))]
+ (wrap (if (list;every? total? (D;values casesM))
+ #TotalC
+ (#VariantC allSF casesM)))))
+
+ [(#SeqC leftA rightA) (#SeqC leftSF rightSF)]
+ (case [(C/= leftSF leftA) (C/= rightSF rightA)]
+ ## There is nothing the addition adds to the coverage.
+ [true true]
+ redundant-pattern
+
+ ## The 2 sequences cannot possibly be merged.
+ [false false]
+ (R/wrap (#AltC so-far addition))
+
+ ## Same prefix
+ [true false]
+ (do R;Monad<Result>
+ [rightM (merge-coverages rightA rightSF)]
+ (if (total? rightM)
+ (wrap leftSF)
+ (wrap (#SeqC leftSF rightM))))
+
+ ## Same suffix
+ [false true]
+ (do R;Monad<Result>
+ [leftM (merge-coverages leftA leftSF)]
+ (wrap (#SeqC leftM rightA))))
+
+ ## The left part will always match, so the addition is redundant.
+ (^=> [(#SeqC left right) single]
+ (C/= left single))
+ redundant-pattern
+
+ ## The right part is not necessary, since it can always match the left.
+ (^=> [single (#SeqC left right)]
+ (C/= left single))
+ (R/wrap single)
+
+ [_ (#AltC leftS rightS)]
+ (do R;Monad<Result>
+ [#let [fuse-once (: (-> Coverage (List Coverage)
+ (R;Result [(Maybe Coverage)
+ (List Coverage)]))
+ (function [coverage possibilities]
+ (loop [alts possibilities]
+ (case alts
+ #;Nil
+ (wrap [#;None (list coverage)])
+
+ (#;Cons alt alts')
+ (case (merge-coverages coverage alt)
+ (#R;Success altM)
+ (case altM
+ (#AltC _)
+ (do @
+ [[success alts+] (recur alts')]
+ (wrap [success (#;Cons alt alts+)]))
+
+ _
+ (wrap [(#;Some altM) alts']))
+
+ (#R;Error error)
+ (R;fail error))
+ ))))]
+ [success possibilities] (fuse-once addition (flatten-alt so-far))]
+ (loop [success success
+ possibilities possibilities]
+ (case success
+ (#;Some coverage')
+ (do @
+ [[success' possibilities'] (fuse-once coverage' possibilities)]
+ (recur success' possibilities'))
+
+ #;None
+ (case (list;reverse possibilities)
+ #;Nil
+ (R;fail "{ This is not supposed to happen... }")
+
+ (#;Cons last prevs)
+ (wrap (L/fold (function [left right] (#AltC left right))
+ last
+ prevs))))))
+
+ _
+ (if (C/= so-far addition)
+ ## The addition cannot possibly improve the coverage.
+ redundant-pattern
+ ## There are now 2 alternative paths.
+ (R/wrap (#AltC so-far addition)))))
+
+(def: get-coverage
+ (-> [Pattern Analysis] Coverage)
+ (|>. product;left determine-coverage))
+
+(def: #export (analyse-case analyse input branches)
+ (-> &;Analyser Code (List [Code Code]) (Lux Analysis))
+ (case branches
+ #;Nil
+ (&;fail "Cannot have empty branches in pattern-matching expression.")
+
+ (#;Cons [patternH bodyH] branchesT)
+ (do Monad<Lux>
+ [[inputT inputA] (&common;with-unknown-type
+ (analyse input))
+ outputH (analyse-branch analyse inputT patternH bodyH)
+ outputT (mapM @
+ (function [[patternT bodyT]]
+ (analyse-branch analyse inputT patternT bodyT))
+ branchesT)
+ _ (case (foldM R;Monad<Result>
+ merge-coverages
+ (get-coverage outputH)
+ (L/map get-coverage outputT))
+ (#R;Success coverage)
+ (if (total? coverage)
+ (wrap [])
+ (&;fail "Pattern-matching is not total."))
+
+ (#R;Error error)
+ (&;fail error))]
+ (wrap (#la;Case inputA (#;Cons outputH outputT))))))