aboutsummaryrefslogtreecommitdiff
path: root/stdlib/source/library/lux/type/implicit.lux
diff options
context:
space:
mode:
Diffstat (limited to 'stdlib/source/library/lux/type/implicit.lux')
-rw-r--r--stdlib/source/library/lux/type/implicit.lux401
1 files changed, 401 insertions, 0 deletions
diff --git a/stdlib/source/library/lux/type/implicit.lux b/stdlib/source/library/lux/type/implicit.lux
new file mode 100644
index 000000000..a308b99a8
--- /dev/null
+++ b/stdlib/source/library/lux/type/implicit.lux
@@ -0,0 +1,401 @@
+(.module:
+ [library
+ [lux #*
+ [abstract
+ ["." monad (#+ Monad do)]
+ ["eq" equivalence]]
+ [control
+ ["." try]
+ ["p" parser
+ ["s" code (#+ Parser)]]]
+ [data
+ ["." product]
+ ["." maybe]
+ ["." text ("#\." equivalence)
+ ["%" format (#+ format)]]
+ [collection
+ ["." list ("#\." monad fold)]
+ ["." dictionary (#+ Dictionary)]]]
+ ["." macro
+ ["." code]
+ [syntax (#+ syntax:)]]
+ [math
+ ["." number
+ ["n" nat]]]
+ ["." meta
+ ["." annotation]]
+ ["." type
+ ["." check (#+ Check)]]]])
+
+(def: (find_type_var id env)
+ (-> Nat Type_Context (Meta Type))
+ (case (list.find (|>> product.left (n.= id))
+ (get@ #.var_bindings env))
+ (#.Some [_ (#.Some type)])
+ (case type
+ (#.Var id')
+ (find_type_var id' env)
+
+ _
+ (\ meta.monad wrap type))
+
+ (#.Some [_ #.None])
+ (meta.fail (format "Unbound type-var " (%.nat id)))
+
+ #.None
+ (meta.fail (format "Unknown type-var " (%.nat id)))
+ ))
+
+(def: (resolve_type var_name)
+ (-> Name (Meta Type))
+ (do meta.monad
+ [raw_type (meta.find_type var_name)
+ compiler meta.get_compiler]
+ (case raw_type
+ (#.Var id)
+ (find_type_var id (get@ #.type_context compiler))
+
+ _
+ (wrap raw_type))))
+
+(def: (find_member_type idx sig_type)
+ (-> Nat Type (Check Type))
+ (case sig_type
+ (#.Named _ sig_type')
+ (find_member_type idx sig_type')
+
+ (#.Apply arg func)
+ (case (type.apply (list arg) func)
+ #.None
+ (check.fail (format "Cannot apply type " (%.type func) " to type " (%.type arg)))
+
+ (#.Some sig_type')
+ (find_member_type idx sig_type'))
+
+ (#.Product left right)
+ (if (n.= 0 idx)
+ (\ check.monad wrap left)
+ (find_member_type (dec idx) right))
+
+ _
+ (if (n.= 0 idx)
+ (\ check.monad wrap sig_type)
+ (check.fail (format "Cannot find member type " (%.nat idx) " for " (%.type sig_type))))))
+
+(def: (find_member_name member)
+ (-> Name (Meta Name))
+ (case member
+ ["" simple_name]
+ (meta.either (do meta.monad
+ [member (meta.normalize member)
+ _ (meta.resolve_tag member)]
+ (wrap member))
+ (do {! meta.monad}
+ [this_module_name meta.current_module_name
+ imp_mods (meta.imported_modules this_module_name)
+ tag_lists (monad.map ! meta.tag_lists imp_mods)
+ #let [tag_lists (|> tag_lists list\join (list\map product.left) list\join)
+ candidates (list.filter (|>> product.right (text\= simple_name))
+ tag_lists)]]
+ (case candidates
+ #.Nil
+ (meta.fail (format "Unknown tag: " (%.name member)))
+
+ (#.Cons winner #.Nil)
+ (wrap winner)
+
+ _
+ (meta.fail (format "Too many candidate tags: " (%.list %.name candidates))))))
+
+ _
+ (\ meta.monad wrap member)))
+
+(def: (resolve_member member)
+ (-> Name (Meta [Nat Type]))
+ (do meta.monad
+ [member (find_member_name member)
+ [idx tag_list sig_type] (meta.resolve_tag member)]
+ (wrap [idx sig_type])))
+
+(def: (prepare_definitions source_module target_module constants aggregate)
+ (-> Text Text (List [Text Definition]) (-> (List [Name Type]) (List [Name Type])))
+ (list\fold (function (_ [name [exported? def_type def_anns def_value]] aggregate)
+ (if (and (annotation.implementation? def_anns)
+ (or (text\= target_module source_module)
+ exported?))
+ (#.Cons [[source_module name] def_type] aggregate)
+ aggregate))
+ aggregate
+ constants))
+
+(def: local_env
+ (Meta (List [Name Type]))
+ (do meta.monad
+ [local_batches meta.locals
+ #let [total_locals (list\fold (function (_ [name type] table)
+ (try.default table (dictionary.try_put name type table)))
+ (: (Dictionary Text Type)
+ (dictionary.new text.hash))
+ (list\join local_batches))]]
+ (wrap (|> total_locals
+ dictionary.entries
+ (list\map (function (_ [name type]) [["" name] type]))))))
+
+(def: local_structs
+ (Meta (List [Name Type]))
+ (do {! meta.monad}
+ [this_module_name meta.current_module_name
+ definitions (meta.definitions this_module_name)]
+ (wrap (prepare_definitions this_module_name this_module_name definitions #.Nil))))
+
+(def: imported_structs
+ (Meta (List [Name Type]))
+ (do {! meta.monad}
+ [this_module_name meta.current_module_name
+ imported_modules (meta.imported_modules this_module_name)
+ accessible_definitions (monad.map ! meta.definitions imported_modules)]
+ (wrap (list\fold (function (_ [imported_module definitions] tail)
+ (prepare_definitions imported_module this_module_name definitions tail))
+ #.Nil
+ (list.zip/2 imported_modules accessible_definitions)))))
+
+(def: (apply_function_type func arg)
+ (-> Type Type (Check Type))
+ (case func
+ (#.Named _ func')
+ (apply_function_type func' arg)
+
+ (#.UnivQ _)
+ (do check.monad
+ [[id var] check.var]
+ (apply_function_type (maybe.assume (type.apply (list var) func))
+ arg))
+
+ (#.Function input output)
+ (do check.monad
+ [_ (check.check input arg)]
+ (wrap output))
+
+ _
+ (check.fail (format "Invalid function type: " (%.type func)))))
+
+(def: (concrete_type type)
+ (-> Type (Check [(List Nat) Type]))
+ (case type
+ (#.UnivQ _)
+ (do check.monad
+ [[id var] check.var
+ [ids final_output] (concrete_type (maybe.assume (type.apply (list var) type)))]
+ (wrap [(#.Cons id ids)
+ final_output]))
+
+ _
+ (\ check.monad wrap [(list) type])))
+
+(def: (check_apply member_type input_types output_type)
+ (-> Type (List Type) Type (Check []))
+ (do check.monad
+ [member_type' (monad.fold check.monad
+ (function (_ input member)
+ (apply_function_type member input))
+ member_type
+ input_types)]
+ (check.check output_type member_type')))
+
+(type: #rec Instance
+ {#constructor Name
+ #dependencies (List Instance)})
+
+(def: (test_provision provision context dep alts)
+ (-> (-> Lux Type_Context Type (Check Instance))
+ Type_Context Type (List [Name Type])
+ (Meta (List Instance)))
+ (do meta.monad
+ [compiler meta.get_compiler]
+ (case (|> alts
+ (list\map (function (_ [alt_name alt_type])
+ (case (check.run context
+ (do {! check.monad}
+ [[tvars alt_type] (concrete_type alt_type)
+ #let [[deps alt_type] (type.flatten_function alt_type)]
+ _ (check.check dep alt_type)
+ context' check.context
+ =deps (monad.map ! (provision compiler context') deps)]
+ (wrap =deps)))
+ (#.Left error)
+ (list)
+
+ (#.Right =deps)
+ (list [alt_name =deps]))))
+ list\join)
+ #.Nil
+ (meta.fail (format "No candidates for provisioning: " (%.type dep)))
+
+ found
+ (wrap found))))
+
+(def: (provision compiler context dep)
+ (-> Lux Type_Context Type (Check Instance))
+ (case (meta.run compiler
+ ($_ meta.either
+ (do meta.monad [alts ..local_env] (..test_provision provision context dep alts))
+ (do meta.monad [alts ..local_structs] (..test_provision provision context dep alts))
+ (do meta.monad [alts ..imported_structs] (..test_provision provision context dep alts))))
+ (#.Left error)
+ (check.fail error)
+
+ (#.Right candidates)
+ (case candidates
+ #.Nil
+ (check.fail (format "No candidates for provisioning: " (%.type dep)))
+
+ (#.Cons winner #.Nil)
+ (\ check.monad wrap winner)
+
+ _
+ (check.fail (format "Too many candidates for provisioning: " (%.type dep) " --- " (%.list (|>> product.left %.name) candidates))))
+ ))
+
+(def: (test_alternatives sig_type member_idx input_types output_type alts)
+ (-> Type Nat (List Type) Type (List [Name Type]) (Meta (List Instance)))
+ (do meta.monad
+ [compiler meta.get_compiler
+ context meta.type_context]
+ (case (|> alts
+ (list\map (function (_ [alt_name alt_type])
+ (case (check.run context
+ (do {! check.monad}
+ [[tvars alt_type] (concrete_type alt_type)
+ #let [[deps alt_type] (type.flatten_function alt_type)]
+ _ (check.check alt_type sig_type)
+ member_type (find_member_type member_idx alt_type)
+ _ (check_apply member_type input_types output_type)
+ context' check.context
+ =deps (monad.map ! (provision compiler context') deps)]
+ (wrap =deps)))
+ (#.Left error)
+ (list)
+
+ (#.Right =deps)
+ (list [alt_name =deps]))))
+ list\join)
+ #.Nil
+ (meta.fail (format "No alternatives for " (%.type (type.function input_types output_type))))
+
+ found
+ (wrap found))))
+
+(def: (find_alternatives sig_type member_idx input_types output_type)
+ (-> Type Nat (List Type) Type (Meta (List Instance)))
+ (let [test (test_alternatives sig_type member_idx input_types output_type)]
+ ($_ meta.either
+ (do meta.monad [alts ..local_env] (test alts))
+ (do meta.monad [alts ..local_structs] (test alts))
+ (do meta.monad [alts ..imported_structs] (test alts)))))
+
+(def: (var? input)
+ (-> Code Bit)
+ (case input
+ [_ (#.Identifier _)]
+ #1
+
+ _
+ #0))
+
+(def: (join_pair [l r])
+ (All [a] (-> [a a] (List a)))
+ (list l r))
+
+(def: (instance$ [constructor dependencies])
+ (-> Instance Code)
+ (case dependencies
+ #.Nil
+ (code.identifier constructor)
+
+ _
+ (` ((~ (code.identifier constructor)) (~+ (list\map instance$ dependencies))))))
+
+(syntax: #export (\\
+ {member s.identifier}
+ {args (p.or (p.and (p.some s.identifier) s.end!)
+ (p.and (p.some s.any) s.end!))})
+ {#.doc (doc "Automatic implementation selection (for type-class style polymorphism)."
+ "This feature layers type-class style polymorphism on top of Lux's signatures and implementations."
+ "When calling a polymorphic function, or using a polymorphic constant,"
+ "this macro will check the types of the arguments, and the expected type for the whole expression"
+ "and it will search in the local scope, the module's scope and the imports' scope"
+ "in order to find suitable implementations to satisfy those requirements."
+ "If a single alternative is found, that one will be used automatically."
+ "If no alternative is found, or if more than one alternative is found (ambiguity)"
+ "a compile-time error will be raised, to alert the user."
+ "Examples:"
+ "Nat equivalence"
+ (\ number.equivalence = x y)
+ (\\ = x y)
+ "Can optionally add the prefix of the module where the signature was defined."
+ (\\ eq.= x y)
+ "(List Nat) equivalence"
+ (\\ =
+ (list.indices 10)
+ (list.indices 10))
+ "(Functor List) map"
+ (\\ map inc (list.indices 10))
+ "Caveat emptor: You need to make sure to import the module of any implementation you want to use."
+ "Otherwise, this macro will not find it.")}
+ (case args
+ (#.Left [args _])
+ (do {! meta.monad}
+ [[member_idx sig_type] (resolve_member member)
+ input_types (monad.map ! resolve_type args)
+ output_type meta.expected_type
+ chosen_ones (find_alternatives sig_type member_idx input_types output_type)]
+ (case chosen_ones
+ #.Nil
+ (meta.fail (format "No implementation could be found for member: " (%.name member)))
+
+ (#.Cons chosen #.Nil)
+ (wrap (list (` (\ (~ (instance$ chosen))
+ (~ (code.local_identifier (product.right member)))
+ (~+ (list\map code.identifier args))))))
+
+ _
+ (meta.fail (format "Too many implementations available: "
+ (|> chosen_ones
+ (list\map (|>> product.left %.name))
+ (text.join_with ", "))
+ " --- for type: " (%.type sig_type)))))
+
+ (#.Right [args _])
+ (do {! meta.monad}
+ [labels (|> (macro.gensym "") (list.repeat (list.size args)) (monad.seq !))]
+ (wrap (list (` (let [(~+ (|> (list.zip/2 labels args) (list\map join_pair) list\join))]
+ (..\\ (~ (code.identifier member)) (~+ labels)))))))
+ ))
+
+(def: (implicit_bindings amount)
+ (-> Nat (Meta (List Code)))
+ (|> (macro.gensym "g!implicit")
+ (list.repeat amount)
+ (monad.seq meta.monad)))
+
+(def: implicits
+ (Parser (List Code))
+ (s.tuple (p.many s.any)))
+
+(syntax: #export (with {implementations ..implicits} body)
+ (do meta.monad
+ [g!implicit+ (implicit_bindings (list.size implementations))]
+ (wrap (list (` (let [(~+ (|> (list.zip/2 g!implicit+ implementations)
+ (list\map (function (_ [g!implicit implementation])
+ (list g!implicit implementation)))
+ list\join))]
+ (~ body)))))))
+
+(syntax: #export (implicit: {implementations ..implicits})
+ (do meta.monad
+ [g!implicit+ (implicit_bindings (list.size implementations))]
+ (wrap (|> (list.zip/2 g!implicit+ implementations)
+ (list\map (function (_ [g!implicit implementation])
+ (` (def: (~ g!implicit)
+ {#.implementation? #1}
+ (~ implementation)))))))))