aboutsummaryrefslogtreecommitdiff
path: root/stdlib/source/lux/language/compiler/analysis/inference.lux
diff options
context:
space:
mode:
Diffstat (limited to 'stdlib/source/lux/language/compiler/analysis/inference.lux')
-rw-r--r--stdlib/source/lux/language/compiler/analysis/inference.lux254
1 files changed, 254 insertions, 0 deletions
diff --git a/stdlib/source/lux/language/compiler/analysis/inference.lux b/stdlib/source/lux/language/compiler/analysis/inference.lux
new file mode 100644
index 000000000..a89ed40f8
--- /dev/null
+++ b/stdlib/source/lux/language/compiler/analysis/inference.lux
@@ -0,0 +1,254 @@
+(.module:
+ lux
+ (lux (control [monad #+ do]
+ ["ex" exception #+ exception:])
+ (data [maybe]
+ [text]
+ text/format
+ (collection [list "list/" Functor<List>]))
+ [macro])
+ (//// [type]
+ (type ["tc" check]))
+ [/// #+ "operation/" Monad<Operation>]
+ [// #+ Tag Analysis Operation Compiler]
+ [//type])
+
+(exception: #export (variant-tag-out-of-bounds {size Nat} {tag Tag} {type Type})
+ (ex.report ["Tag" (%n tag)]
+ ["Variant size" (%i (.int size))]
+ ["Variant type" (%type type)]))
+
+(exception: #export (cannot-infer {type Type} {args (List Code)})
+ (ex.report ["Type" (%type type)]
+ ["Arguments" (|> args
+ list.enumerate
+ (list/map (function (_ [idx argC])
+ (format "\n " (%n idx) " " (%code argC))))
+ (text.join-with ""))]))
+
+(exception: #export (cannot-infer-argument {inferred Type} {argument Code})
+ (ex.report ["Inferred Type" (%type inferred)]
+ ["Argument" (%code argument)]))
+
+(exception: #export (smaller-variant-than-expected {expected Nat} {actual Nat})
+ (ex.report ["Expected" (%i (.int expected))]
+ ["Actual" (%i (.int actual))]))
+
+(do-template [<name>]
+ [(exception: #export (<name> {type Type})
+ (%type type))]
+
+ [not-a-variant-type]
+ [not-a-record-type]
+ [invalid-type-application]
+ )
+
+(def: (replace parameter-idx replacement type)
+ (-> Nat Type Type Type)
+ (case type
+ (#.Primitive name params)
+ (#.Primitive name (list/map (replace parameter-idx replacement) params))
+
+ (^template [<tag>]
+ (<tag> left right)
+ (<tag> (replace parameter-idx replacement left)
+ (replace parameter-idx replacement right)))
+ ([#.Sum]
+ [#.Product]
+ [#.Function]
+ [#.Apply])
+
+ (#.Parameter idx)
+ (if (n/= parameter-idx idx)
+ replacement
+ type)
+
+ (^template [<tag>]
+ (<tag> env quantified)
+ (<tag> (list/map (replace parameter-idx replacement) env)
+ (replace (n/+ +2 parameter-idx) replacement quantified)))
+ ([#.UnivQ]
+ [#.ExQ])
+
+ _
+ type))
+
+(def: new-named-type
+ (Operation Type)
+ (do ///.Monad<Operation>
+ [[module line column] macro.cursor
+ [ex-id _] (//type.with-env tc.existential)]
+ (wrap (#.Primitive (format "{New Type @ " (%t module)
+ "," (%n line)
+ "," (%n column)
+ "} " (%n ex-id))
+ (list)))))
+
+## Type-inference works by applying some (potentially quantified) type
+## to a sequence of values.
+## Function types are used for this, although inference is not always
+## done for function application (alternative uses may be records and
+## tagged variants).
+## But, so long as the type being used for the inference can be treated
+## as a function type, this method of inference should work.
+(def: #export (general analyse inferT args)
+ (-> Compiler Type (List Code) (Operation [Type (List Analysis)]))
+ (case args
+ #.Nil
+ (do ///.Monad<Operation>
+ [_ (//type.infer inferT)]
+ (wrap [inferT (list)]))
+
+ (#.Cons argC args')
+ (case inferT
+ (#.Named name unnamedT)
+ (general analyse unnamedT args)
+
+ (#.UnivQ _)
+ (do ///.Monad<Operation>
+ [[var-id varT] (//type.with-env tc.var)]
+ (general analyse (maybe.assume (type.apply (list varT) inferT)) args))
+
+ (#.ExQ _)
+ (do ///.Monad<Operation>
+ [[var-id varT] (//type.with-env tc.var)
+ output (general analyse
+ (maybe.assume (type.apply (list varT) inferT))
+ args)
+ bound? (//type.with-env
+ (tc.bound? var-id))
+ _ (if bound?
+ (wrap [])
+ (do @
+ [newT new-named-type]
+ (//type.with-env
+ (tc.check varT newT))))]
+ (wrap output))
+
+ (#.Apply inputT transT)
+ (case (type.apply (list inputT) transT)
+ (#.Some outputT)
+ (general analyse outputT args)
+
+ #.None
+ (///.throw invalid-type-application inferT))
+
+ ## Arguments are inferred back-to-front because, by convention,
+ ## Lux functions take the most important arguments *last*, which
+ ## means that the most information for doing proper inference is
+ ## located in the last arguments to a function call.
+ ## By inferring back-to-front, a lot of type-annotations can be
+ ## avoided in Lux code, since the inference algorithm can piece
+ ## things together more easily.
+ (#.Function inputT outputT)
+ (do ///.Monad<Operation>
+ [[outputT' args'A] (general analyse outputT args')
+ argA (<| (///.with-stack cannot-infer-argument [inputT argC])
+ (//type.with-type inputT)
+ (analyse argC))]
+ (wrap [outputT' (list& argA args'A)]))
+
+ (#.Var infer-id)
+ (do ///.Monad<Operation>
+ [?inferT' (//type.with-env (tc.read infer-id))]
+ (case ?inferT'
+ (#.Some inferT')
+ (general analyse inferT' args)
+
+ _
+ (///.throw cannot-infer [inferT args])))
+
+ _
+ (///.throw cannot-infer [inferT args]))
+ ))
+
+## Turns a record type into the kind of function type suitable for inference.
+(def: #export (record inferT)
+ (-> Type (Operation Type))
+ (case inferT
+ (#.Named name unnamedT)
+ (record unnamedT)
+
+ (^template [<tag>]
+ (<tag> env bodyT)
+ (do ///.Monad<Operation>
+ [bodyT+ (record bodyT)]
+ (wrap (<tag> env bodyT+))))
+ ([#.UnivQ]
+ [#.ExQ])
+
+ (#.Apply inputT funcT)
+ (case (type.apply (list inputT) funcT)
+ (#.Some outputT)
+ (record outputT)
+
+ #.None
+ (///.throw invalid-type-application inferT))
+
+ (#.Product _)
+ (operation/wrap (type.function (type.flatten-tuple inferT) inferT))
+
+ _
+ (///.throw not-a-record-type inferT)))
+
+## Turns a variant type into the kind of function type suitable for inference.
+(def: #export (variant tag expected-size inferT)
+ (-> Nat Nat Type (Operation Type))
+ (loop [depth +0
+ currentT inferT]
+ (case currentT
+ (#.Named name unnamedT)
+ (do ///.Monad<Operation>
+ [unnamedT+ (recur depth unnamedT)]
+ (wrap unnamedT+))
+
+ (^template [<tag>]
+ (<tag> env bodyT)
+ (do ///.Monad<Operation>
+ [bodyT+ (recur (inc depth) bodyT)]
+ (wrap (<tag> env bodyT+))))
+ ([#.UnivQ]
+ [#.ExQ])
+
+ (#.Sum _)
+ (let [cases (type.flatten-variant currentT)
+ actual-size (list.size cases)
+ boundary (dec expected-size)]
+ (cond (or (n/= expected-size actual-size)
+ (and (n/> expected-size actual-size)
+ (n/< boundary tag)))
+ (case (list.nth tag cases)
+ (#.Some caseT)
+ (operation/wrap (if (n/= +0 depth)
+ (type.function (list caseT) currentT)
+ (let [replace' (replace (|> depth dec (n/* +2)) inferT)]
+ (type.function (list (replace' caseT))
+ (replace' currentT)))))
+
+ #.None
+ (///.throw variant-tag-out-of-bounds [expected-size tag inferT]))
+
+ (n/< expected-size actual-size)
+ (///.throw smaller-variant-than-expected [expected-size actual-size])
+
+ (n/= boundary tag)
+ (let [caseT (type.variant (list.drop boundary cases))]
+ (operation/wrap (if (n/= +0 depth)
+ (type.function (list caseT) currentT)
+ (let [replace' (replace (|> depth dec (n/* +2)) inferT)]
+ (type.function (list (replace' caseT))
+ (replace' currentT))))))
+
+ ## else
+ (///.throw variant-tag-out-of-bounds [expected-size tag inferT])))
+
+ (#.Apply inputT funcT)
+ (case (type.apply (list inputT) funcT)
+ (#.Some outputT)
+ (variant tag expected-size outputT)
+
+ #.None
+ (///.throw invalid-type-application inferT))
+
+ _
+ (///.throw not-a-variant-type inferT))))