diff options
Diffstat (limited to 'stdlib/source/library/lux/type/implicit.lux')
-rw-r--r-- | stdlib/source/library/lux/type/implicit.lux | 401 |
1 files changed, 401 insertions, 0 deletions
diff --git a/stdlib/source/library/lux/type/implicit.lux b/stdlib/source/library/lux/type/implicit.lux new file mode 100644 index 000000000..a308b99a8 --- /dev/null +++ b/stdlib/source/library/lux/type/implicit.lux @@ -0,0 +1,401 @@ +(.module: + [library + [lux #* + [abstract + ["." monad (#+ Monad do)] + ["eq" equivalence]] + [control + ["." try] + ["p" parser + ["s" code (#+ Parser)]]] + [data + ["." product] + ["." maybe] + ["." text ("#\." equivalence) + ["%" format (#+ format)]] + [collection + ["." list ("#\." monad fold)] + ["." dictionary (#+ Dictionary)]]] + ["." macro + ["." code] + [syntax (#+ syntax:)]] + [math + ["." number + ["n" nat]]] + ["." meta + ["." annotation]] + ["." type + ["." check (#+ Check)]]]]) + +(def: (find_type_var id env) + (-> Nat Type_Context (Meta Type)) + (case (list.find (|>> product.left (n.= id)) + (get@ #.var_bindings env)) + (#.Some [_ (#.Some type)]) + (case type + (#.Var id') + (find_type_var id' env) + + _ + (\ meta.monad wrap type)) + + (#.Some [_ #.None]) + (meta.fail (format "Unbound type-var " (%.nat id))) + + #.None + (meta.fail (format "Unknown type-var " (%.nat id))) + )) + +(def: (resolve_type var_name) + (-> Name (Meta Type)) + (do meta.monad + [raw_type (meta.find_type var_name) + compiler meta.get_compiler] + (case raw_type + (#.Var id) + (find_type_var id (get@ #.type_context compiler)) + + _ + (wrap raw_type)))) + +(def: (find_member_type idx sig_type) + (-> Nat Type (Check Type)) + (case sig_type + (#.Named _ sig_type') + (find_member_type idx sig_type') + + (#.Apply arg func) + (case (type.apply (list arg) func) + #.None + (check.fail (format "Cannot apply type " (%.type func) " to type " (%.type arg))) + + (#.Some sig_type') + (find_member_type idx sig_type')) + + (#.Product left right) + (if (n.= 0 idx) + (\ check.monad wrap left) + (find_member_type (dec idx) right)) + + _ + (if (n.= 0 idx) + (\ check.monad wrap sig_type) + (check.fail (format "Cannot find member type " (%.nat idx) " for " (%.type sig_type)))))) + +(def: (find_member_name member) + (-> Name (Meta Name)) + (case member + ["" simple_name] + (meta.either (do meta.monad + [member (meta.normalize member) + _ (meta.resolve_tag member)] + (wrap member)) + (do {! meta.monad} + [this_module_name meta.current_module_name + imp_mods (meta.imported_modules this_module_name) + tag_lists (monad.map ! meta.tag_lists imp_mods) + #let [tag_lists (|> tag_lists list\join (list\map product.left) list\join) + candidates (list.filter (|>> product.right (text\= simple_name)) + tag_lists)]] + (case candidates + #.Nil + (meta.fail (format "Unknown tag: " (%.name member))) + + (#.Cons winner #.Nil) + (wrap winner) + + _ + (meta.fail (format "Too many candidate tags: " (%.list %.name candidates)))))) + + _ + (\ meta.monad wrap member))) + +(def: (resolve_member member) + (-> Name (Meta [Nat Type])) + (do meta.monad + [member (find_member_name member) + [idx tag_list sig_type] (meta.resolve_tag member)] + (wrap [idx sig_type]))) + +(def: (prepare_definitions source_module target_module constants aggregate) + (-> Text Text (List [Text Definition]) (-> (List [Name Type]) (List [Name Type]))) + (list\fold (function (_ [name [exported? def_type def_anns def_value]] aggregate) + (if (and (annotation.implementation? def_anns) + (or (text\= target_module source_module) + exported?)) + (#.Cons [[source_module name] def_type] aggregate) + aggregate)) + aggregate + constants)) + +(def: local_env + (Meta (List [Name Type])) + (do meta.monad + [local_batches meta.locals + #let [total_locals (list\fold (function (_ [name type] table) + (try.default table (dictionary.try_put name type table))) + (: (Dictionary Text Type) + (dictionary.new text.hash)) + (list\join local_batches))]] + (wrap (|> total_locals + dictionary.entries + (list\map (function (_ [name type]) [["" name] type])))))) + +(def: local_structs + (Meta (List [Name Type])) + (do {! meta.monad} + [this_module_name meta.current_module_name + definitions (meta.definitions this_module_name)] + (wrap (prepare_definitions this_module_name this_module_name definitions #.Nil)))) + +(def: imported_structs + (Meta (List [Name Type])) + (do {! meta.monad} + [this_module_name meta.current_module_name + imported_modules (meta.imported_modules this_module_name) + accessible_definitions (monad.map ! meta.definitions imported_modules)] + (wrap (list\fold (function (_ [imported_module definitions] tail) + (prepare_definitions imported_module this_module_name definitions tail)) + #.Nil + (list.zip/2 imported_modules accessible_definitions))))) + +(def: (apply_function_type func arg) + (-> Type Type (Check Type)) + (case func + (#.Named _ func') + (apply_function_type func' arg) + + (#.UnivQ _) + (do check.monad + [[id var] check.var] + (apply_function_type (maybe.assume (type.apply (list var) func)) + arg)) + + (#.Function input output) + (do check.monad + [_ (check.check input arg)] + (wrap output)) + + _ + (check.fail (format "Invalid function type: " (%.type func))))) + +(def: (concrete_type type) + (-> Type (Check [(List Nat) Type])) + (case type + (#.UnivQ _) + (do check.monad + [[id var] check.var + [ids final_output] (concrete_type (maybe.assume (type.apply (list var) type)))] + (wrap [(#.Cons id ids) + final_output])) + + _ + (\ check.monad wrap [(list) type]))) + +(def: (check_apply member_type input_types output_type) + (-> Type (List Type) Type (Check [])) + (do check.monad + [member_type' (monad.fold check.monad + (function (_ input member) + (apply_function_type member input)) + member_type + input_types)] + (check.check output_type member_type'))) + +(type: #rec Instance + {#constructor Name + #dependencies (List Instance)}) + +(def: (test_provision provision context dep alts) + (-> (-> Lux Type_Context Type (Check Instance)) + Type_Context Type (List [Name Type]) + (Meta (List Instance))) + (do meta.monad + [compiler meta.get_compiler] + (case (|> alts + (list\map (function (_ [alt_name alt_type]) + (case (check.run context + (do {! check.monad} + [[tvars alt_type] (concrete_type alt_type) + #let [[deps alt_type] (type.flatten_function alt_type)] + _ (check.check dep alt_type) + context' check.context + =deps (monad.map ! (provision compiler context') deps)] + (wrap =deps))) + (#.Left error) + (list) + + (#.Right =deps) + (list [alt_name =deps])))) + list\join) + #.Nil + (meta.fail (format "No candidates for provisioning: " (%.type dep))) + + found + (wrap found)))) + +(def: (provision compiler context dep) + (-> Lux Type_Context Type (Check Instance)) + (case (meta.run compiler + ($_ meta.either + (do meta.monad [alts ..local_env] (..test_provision provision context dep alts)) + (do meta.monad [alts ..local_structs] (..test_provision provision context dep alts)) + (do meta.monad [alts ..imported_structs] (..test_provision provision context dep alts)))) + (#.Left error) + (check.fail error) + + (#.Right candidates) + (case candidates + #.Nil + (check.fail (format "No candidates for provisioning: " (%.type dep))) + + (#.Cons winner #.Nil) + (\ check.monad wrap winner) + + _ + (check.fail (format "Too many candidates for provisioning: " (%.type dep) " --- " (%.list (|>> product.left %.name) candidates)))) + )) + +(def: (test_alternatives sig_type member_idx input_types output_type alts) + (-> Type Nat (List Type) Type (List [Name Type]) (Meta (List Instance))) + (do meta.monad + [compiler meta.get_compiler + context meta.type_context] + (case (|> alts + (list\map (function (_ [alt_name alt_type]) + (case (check.run context + (do {! check.monad} + [[tvars alt_type] (concrete_type alt_type) + #let [[deps alt_type] (type.flatten_function alt_type)] + _ (check.check alt_type sig_type) + member_type (find_member_type member_idx alt_type) + _ (check_apply member_type input_types output_type) + context' check.context + =deps (monad.map ! (provision compiler context') deps)] + (wrap =deps))) + (#.Left error) + (list) + + (#.Right =deps) + (list [alt_name =deps])))) + list\join) + #.Nil + (meta.fail (format "No alternatives for " (%.type (type.function input_types output_type)))) + + found + (wrap found)))) + +(def: (find_alternatives sig_type member_idx input_types output_type) + (-> Type Nat (List Type) Type (Meta (List Instance))) + (let [test (test_alternatives sig_type member_idx input_types output_type)] + ($_ meta.either + (do meta.monad [alts ..local_env] (test alts)) + (do meta.monad [alts ..local_structs] (test alts)) + (do meta.monad [alts ..imported_structs] (test alts))))) + +(def: (var? input) + (-> Code Bit) + (case input + [_ (#.Identifier _)] + #1 + + _ + #0)) + +(def: (join_pair [l r]) + (All [a] (-> [a a] (List a))) + (list l r)) + +(def: (instance$ [constructor dependencies]) + (-> Instance Code) + (case dependencies + #.Nil + (code.identifier constructor) + + _ + (` ((~ (code.identifier constructor)) (~+ (list\map instance$ dependencies)))))) + +(syntax: #export (\\ + {member s.identifier} + {args (p.or (p.and (p.some s.identifier) s.end!) + (p.and (p.some s.any) s.end!))}) + {#.doc (doc "Automatic implementation selection (for type-class style polymorphism)." + "This feature layers type-class style polymorphism on top of Lux's signatures and implementations." + "When calling a polymorphic function, or using a polymorphic constant," + "this macro will check the types of the arguments, and the expected type for the whole expression" + "and it will search in the local scope, the module's scope and the imports' scope" + "in order to find suitable implementations to satisfy those requirements." + "If a single alternative is found, that one will be used automatically." + "If no alternative is found, or if more than one alternative is found (ambiguity)" + "a compile-time error will be raised, to alert the user." + "Examples:" + "Nat equivalence" + (\ number.equivalence = x y) + (\\ = x y) + "Can optionally add the prefix of the module where the signature was defined." + (\\ eq.= x y) + "(List Nat) equivalence" + (\\ = + (list.indices 10) + (list.indices 10)) + "(Functor List) map" + (\\ map inc (list.indices 10)) + "Caveat emptor: You need to make sure to import the module of any implementation you want to use." + "Otherwise, this macro will not find it.")} + (case args + (#.Left [args _]) + (do {! meta.monad} + [[member_idx sig_type] (resolve_member member) + input_types (monad.map ! resolve_type args) + output_type meta.expected_type + chosen_ones (find_alternatives sig_type member_idx input_types output_type)] + (case chosen_ones + #.Nil + (meta.fail (format "No implementation could be found for member: " (%.name member))) + + (#.Cons chosen #.Nil) + (wrap (list (` (\ (~ (instance$ chosen)) + (~ (code.local_identifier (product.right member))) + (~+ (list\map code.identifier args)))))) + + _ + (meta.fail (format "Too many implementations available: " + (|> chosen_ones + (list\map (|>> product.left %.name)) + (text.join_with ", ")) + " --- for type: " (%.type sig_type))))) + + (#.Right [args _]) + (do {! meta.monad} + [labels (|> (macro.gensym "") (list.repeat (list.size args)) (monad.seq !))] + (wrap (list (` (let [(~+ (|> (list.zip/2 labels args) (list\map join_pair) list\join))] + (..\\ (~ (code.identifier member)) (~+ labels))))))) + )) + +(def: (implicit_bindings amount) + (-> Nat (Meta (List Code))) + (|> (macro.gensym "g!implicit") + (list.repeat amount) + (monad.seq meta.monad))) + +(def: implicits + (Parser (List Code)) + (s.tuple (p.many s.any))) + +(syntax: #export (with {implementations ..implicits} body) + (do meta.monad + [g!implicit+ (implicit_bindings (list.size implementations))] + (wrap (list (` (let [(~+ (|> (list.zip/2 g!implicit+ implementations) + (list\map (function (_ [g!implicit implementation]) + (list g!implicit implementation))) + list\join))] + (~ body))))))) + +(syntax: #export (implicit: {implementations ..implicits}) + (do meta.monad + [g!implicit+ (implicit_bindings (list.size implementations))] + (wrap (|> (list.zip/2 g!implicit+ implementations) + (list\map (function (_ [g!implicit implementation]) + (` (def: (~ g!implicit) + {#.implementation? #1} + (~ implementation))))))))) |