aboutsummaryrefslogtreecommitdiff
path: root/stdlib
diff options
context:
space:
mode:
authorEduardo Julian2019-05-28 18:49:30 -0400
committerEduardo Julian2019-05-28 18:49:30 -0400
commitd96f2ae9ef8773f6aef2f68940f23e5e1d91a674 (patch)
tree4461cc1132b78503ce5bf29a56abb2499ddf0a8f /stdlib
parent926c3e1dcc392dc21b77a93200fa3e01eb113cf2 (diff)
Improved type inference/checking.
Diffstat (limited to 'stdlib')
-rw-r--r--stdlib/source/lux/tool/compiler/phase/analysis/inference.lux43
-rw-r--r--stdlib/source/lux/tool/compiler/phase/analysis/structure.lux45
-rw-r--r--stdlib/source/lux/type/check.lux9
3 files changed, 67 insertions, 30 deletions
diff --git a/stdlib/source/lux/tool/compiler/phase/analysis/inference.lux b/stdlib/source/lux/tool/compiler/phase/analysis/inference.lux
index 96ec554ad..7ef29752e 100644
--- a/stdlib/source/lux/tool/compiler/phase/analysis/inference.lux
+++ b/stdlib/source/lux/tool/compiler/phase/analysis/inference.lux
@@ -170,17 +170,42 @@
(/.throw cannot-infer [inferT args]))
))
+(def: (substitute-bound target sub)
+ (-> Nat Type Type Type)
+ (function (recur base)
+ (case base
+ (#.Primitive name parameters)
+ (#.Primitive name (list@map recur parameters))
+
+ (^template [<tag>]
+ (<tag> left right)
+ (<tag> (recur left) (recur right)))
+ ([#.Sum] [#.Product] [#.Function] [#.Apply])
+
+ (#.Parameter index)
+ (if (n/= target index)
+ sub
+ base)
+
+ (^template [<tag>]
+ (<tag> environment quantified)
+ (<tag> (list@map recur environment) quantified))
+ ([#.UnivQ] [#.ExQ])
+
+ _
+ base)))
+
## Turns a record type into the kind of function type suitable for inference.
-(def: #export (record inferT)
- (-> Type (Operation Type))
+(def: (record' target originalT inferT)
+ (-> Nat Type Type (Operation Type))
(case inferT
(#.Named name unnamedT)
- (record unnamedT)
+ (record' target originalT unnamedT)
(^template [<tag>]
(<tag> env bodyT)
(do ///.monad
- [bodyT+ (record bodyT)]
+ [bodyT+ (record' (n/+ 2 target) originalT bodyT)]
(wrap (<tag> env bodyT+))))
([#.UnivQ]
[#.ExQ])
@@ -188,17 +213,23 @@
(#.Apply inputT funcT)
(case (type.apply (list inputT) funcT)
(#.Some outputT)
- (record outputT)
+ (record' target originalT outputT)
#.None
(/.throw invalid-type-application inferT))
(#.Product _)
- (///@wrap (type.function (type.flatten-tuple inferT) inferT))
+ (///@wrap (|> inferT
+ (type.function (type.flatten-tuple inferT))
+ (substitute-bound target originalT)))
_
(/.throw not-a-record-type inferT)))
+(def: #export (record inferT)
+ (-> Type (Operation Type))
+ (record' (n/- 2 0) inferT inferT))
+
## Turns a variant type into the kind of function type suitable for inference.
(def: #export (variant tag expected-size inferT)
(-> Nat Nat Type (Operation Type))
diff --git a/stdlib/source/lux/tool/compiler/phase/analysis/structure.lux b/stdlib/source/lux/tool/compiler/phase/analysis/structure.lux
index a69346071..aebbe75ba 100644
--- a/stdlib/source/lux/tool/compiler/phase/analysis/structure.lux
+++ b/stdlib/source/lux/tool/compiler/phase/analysis/structure.lux
@@ -87,8 +87,10 @@
(def: #export (sum analyse tag valueC)
(-> Phase Nat Code (Operation Analysis))
(do ///.monad
- [expectedT (///extension.lift macro.expected-type)]
- (/.with-stack cannot-analyse-variant [expectedT tag valueC]
+ [expectedT (///extension.lift macro.expected-type)
+ expectedT' (//type.with-env
+ (check.clean expectedT))]
+ (/.with-stack cannot-analyse-variant [expectedT' tag valueC]
(case expectedT
(#.Sum _)
(let [flat (type.flatten-variant expectedT)
@@ -338,25 +340,24 @@
(def: #export (record analyse members)
(-> Phase (List [Code Code]) (Operation Analysis))
- (do ///.monad
- [members (normalize members)
- [membersC recordT] (order members)]
- (case membersC
- (^ (list))
- //primitive.unit
-
- (^ (list singletonC))
- (analyse singletonC)
+ (case members
+ (^ (list))
+ //primitive.unit
- _
- (do @
- [expectedT (///extension.lift macro.expected-type)]
- (case expectedT
- (#.Var _)
- (do @
- [inferenceT (//inference.record recordT)
- [inferredT membersA] (//inference.general analyse inferenceT membersC)]
- (wrap (/.tuple membersA)))
+ (^ (list [_ singletonC]))
+ (analyse singletonC)
- _
- (..product analyse membersC))))))
+ _
+ (do ///.monad
+ [members (normalize members)
+ [membersC recordT] (order members)
+ expectedT (///extension.lift macro.expected-type)]
+ (case expectedT
+ (#.Var _)
+ (do @
+ [inferenceT (//inference.record recordT)
+ [inferredT membersA] (//inference.general analyse inferenceT membersC)]
+ (wrap (/.tuple membersA)))
+
+ _
+ (..product analyse membersC)))))
diff --git a/stdlib/source/lux/type/check.lux b/stdlib/source/lux/type/check.lux
index 3eed928f9..23abe544a 100644
--- a/stdlib/source/lux/type/check.lux
+++ b/stdlib/source/lux/type/check.lux
@@ -266,7 +266,12 @@
(apply-type! funcT' argT)
_
- (throw invalid-type-application [funcT argT])))
+ (throw ..invalid-type-application [funcT argT])))
+
+ (#.Apply argT' funcT')
+ (do ..monad
+ [funcT'' (apply-type! funcT' argT')]
+ (apply-type! funcT'' argT))
_
(case (//.apply (list argT) funcT)
@@ -274,7 +279,7 @@
(check@wrap output)
_
- (throw invalid-type-application [funcT argT]))))
+ (throw ..invalid-type-application [funcT argT]))))
(type: #export Ring (Set Var))