diff options
author | Son Ho | 2023-10-24 14:32:38 +0200 |
---|---|---|
committer | Son Ho | 2023-10-24 14:32:38 +0200 |
commit | 63107911c16a9991f7d5cf8c6df621318a03ca3b (patch) | |
tree | 6b6df07f1ffe96d5b2660447e38f79b4998070b6 | |
parent | f35aba375757db99d2dc6ac2b11b9eb1e437b420 (diff) |
Fix various issues with the builtins
-rw-r--r-- | compiler/Extract.ml | 115 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 1 | ||||
-rw-r--r-- | compiler/ExtractBuiltin.ml | 82 | ||||
-rw-r--r-- | compiler/Translate.ml | 158 |
4 files changed, 234 insertions, 122 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 6a306592..ddc02fa7 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1277,7 +1277,10 @@ and extract_trait_decl_ref (ctx : extraction_ctx) (fmt : F.formatter) let name = ctx_get_trait_decl is_opaque tr.trait_decl_id ctx in if use_brackets then F.pp_print_string fmt "("; F.pp_print_string fmt name; - extract_generic_args ctx fmt no_params_tys tr.decl_generics; + (* There is something subtle here: the trait obligations for the implemented + trait are put inside the parent clauses, so we must ignore them here *) + let generics = { tr.decl_generics with trait_refs = [] } in + extract_generic_args ctx fmt no_params_tys generics; if use_brackets then F.pp_print_string fmt ")" and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter) @@ -1349,7 +1352,7 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : let def_name = match info with | None -> ctx.fmt.type_name def.name - | Some info -> info.rust_name + | Some info -> String.concat "." info.rust_name in let is_opaque = def.kind = Opaque in let ctx = ctx_add is_opaque (TypeId (AdtId def.def_id)) def_name ctx in @@ -1363,7 +1366,7 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : (* Compute the names *) let field_names, cons_name = match info with - | None -> + | None | Some { body_info = None; _ } -> let field_names = FieldId.mapi (fun fid (field : field) -> @@ -1379,7 +1382,11 @@ let extract_type_decl_register_names (ctx : extraction_ctx) (def : type_decl) : (List.combine fields field_names) in (field_names, cons_name) - | _ -> raise (Failure "Invalid builtin information") + | Some info -> + raise + (Failure + ("Invalid builtin information: " + ^ show_builtin_type_info info)) in (* Add the fields *) let ctx = @@ -2365,33 +2372,70 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx) let extract_fun_decl_register_names (ctx : extraction_ctx) (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) : extraction_ctx = - let fwd = def.fwd in - let backs = def.backs in - (* Register the decrease clauses, if necessary *) - let register_decreases ctx def = - if has_decreases_clause def then - (* Add the termination measure *) - let ctx = ctx_add_termination_measure def ctx in - (* Add the decreases proof for Lean only *) - match !Config.backend with - | Coq | FStar -> ctx - | HOL4 -> raise (Failure "Unexpected") - | Lean -> ctx_add_decreases_proof def ctx - else ctx - in - let ctx = List.fold_left register_decreases ctx (fwd.f :: fwd.loops) in - let register_fun ctx f = ctx_add_fun_decl def f ctx in - let register_funs ctx fl = List.fold_left register_fun ctx fl in - (* Register the names of the forward functions *) - let ctx = - if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx - in - (* Register the names of the backward functions *) - List.fold_left - (fun ctx { f = back; loops = loop_backs } -> - let ctx = register_fun ctx back in - register_funs ctx loop_backs) - ctx backs + (* Ignore the trait methods **declarations** (rem.: we do not ignore the trait + method implementations): we do not need to refer to them directly. We will + only use their type for the fields of the records we generate for the trait + declarations *) + match def.fwd.f.kind with + | TraitMethodDecl _ -> ctx + | _ -> ( + (* Check if the function is builtin *) + let builtin = + let open ExtractBuiltin in + let funs_map = builtin_funs_map () in + let sname = name_to_simple_name def.fwd.f.basename in + SimpleNameMap.find_opt sname funs_map + in + (* Use the builtin names if necessary *) + match builtin with + | Some (_filter, info) -> + let backs = List.map (fun f -> f.f) def.backs in + let funs = if def.keep_fwd then def.fwd.f :: backs else backs in + let is_opaque = false in + List.fold_left + (fun ctx (f : fun_decl) -> + let open ExtractBuiltin in + let fun_id = + (Pure.FunId (Regular f.def_id), f.loop_id, f.back_id) + in + let fun_name = + (List.find + (fun (x : builtin_fun_info) -> x.rg = f.back_id) + info) + .extract_name + in + ctx_add is_opaque (FunId (FromLlbc fun_id)) fun_name ctx) + ctx funs + | None -> + let fwd = def.fwd in + let backs = def.backs in + (* Register the decrease clauses, if necessary *) + let register_decreases ctx def = + if has_decreases_clause def then + (* Add the termination measure *) + let ctx = ctx_add_termination_measure def ctx in + (* Add the decreases proof for Lean only *) + match !Config.backend with + | Coq | FStar -> ctx + | HOL4 -> raise (Failure "Unexpected") + | Lean -> ctx_add_decreases_proof def ctx + else ctx + in + let ctx = + List.fold_left register_decreases ctx (fwd.f :: fwd.loops) + in + let register_fun ctx f = ctx_add_fun_decl def f ctx in + let register_funs ctx fl = List.fold_left register_fun ctx fl in + (* Register the names of the forward functions *) + let ctx = + if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx + in + (* Register the names of the backward functions *) + List.fold_left + (fun ctx { f = back; loops = loop_backs } -> + let ctx = register_fun ctx back in + register_funs ctx loop_backs) + ctx backs) (** Simply add the global name to the context. *) let extract_global_decl_register_names (ctx : extraction_ctx) @@ -4539,6 +4583,7 @@ let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter) (** Extract a trait implementation *) let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) (impl : trait_impl) : unit = + log#ldebug (lazy ("extract_trait_impl: " ^ Names.name_to_string impl.name)); (* Retrieve the impl name *) let with_opaque_pre = false in let impl_name = ctx_get_trait_impl with_opaque_pre impl.def_id ctx in @@ -4565,9 +4610,11 @@ let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) (* `let (....) : Trait ... =` *) (* Open the box for the name + generics *) F.pp_open_hovbox fmt ctx.indent_incr; - let qualif = Option.get (ctx.fmt.fun_decl_kind_to_qualif SingleNonRec) in - F.pp_print_string fmt qualif; - F.pp_print_space fmt (); + (match ctx.fmt.fun_decl_kind_to_qualif SingleNonRec with + | Some qualif -> + F.pp_print_string fmt qualif; + F.pp_print_space fmt () + | None -> ()); F.pp_print_string fmt impl_name; (* Print the generics *) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index ea5fe8d3..22b017e5 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -1249,6 +1249,7 @@ let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl) ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info (keep_fwd, num_backs) +(* TODO: move to Extract *) let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx = (* Sanity check: the function should not be a global body - those are handled diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index 3b4afff6..0d591028 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -78,21 +78,24 @@ let builtin_globals_map : string SimpleNameMap.t = (List.map (fun (x, y) -> (string_to_simple_name x, y)) builtin_globals) type builtin_variant_info = { fields : (string * string) list } +[@@deriving show] type builtin_enum_variant_info = { rust_variant_name : string; extract_variant_name : string; fields : string list option; } +[@@deriving show] type builtin_type_body_info = | Struct of string * string list (* The constructor name and the map for the field names *) | Enum of builtin_enum_variant_info list (* For every variant, a map for the field names *) +[@@deriving show] type builtin_type_info = { - rust_name : string; + rust_name : string list; extract_name : string; keep_params : bool list option; (** We might want to filter some of the type parameters. @@ -102,6 +105,7 @@ type builtin_type_info = { *) body_info : builtin_type_body_info option; } +[@@deriving show] (** The assumed types. @@ -113,7 +117,7 @@ let builtin_types () : builtin_type_info list = [ (* Alloc *) { - rust_name = "alloc::alloc::Global"; + rust_name = [ "alloc"; "alloc"; "Global" ]; extract_name = (match !backend with | Lean -> "AllocGlobal" @@ -123,7 +127,7 @@ let builtin_types () : builtin_type_info list = }; (* Vec *) { - rust_name = "alloc::vec::Vec"; + rust_name = [ "alloc"; "vec"; "Vec" ]; extract_name = (match !backend with Lean -> "Vec" | Coq | FStar | HOL4 -> "vec"); keep_params = Some [ true; false ]; @@ -131,7 +135,7 @@ let builtin_types () : builtin_type_info list = }; (* Option *) { - rust_name = "core::option::Option"; + rust_name = [ "core"; "option"; "Option" ]; extract_name = (match !backend with | Lean -> "Option" @@ -163,7 +167,7 @@ let builtin_types () : builtin_type_info list = }; (* Range *) { - rust_name = "core::ops::range::Range"; + rust_name = [ "core"; "ops"; "range"; "Range" ]; extract_name = (match !backend with Lean -> "Range" | Coq | FStar | HOL4 -> "range"); keep_params = None; @@ -180,9 +184,7 @@ let builtin_types () : builtin_type_info list = let mk_builtin_types_map () = SimpleNameMap.of_list - (List.map - (fun info -> (string_to_simple_name info.rust_name, info)) - (builtin_types ())) + (List.map (fun info -> (info.rust_name, info)) (builtin_types ())) let builtin_types_map = mk_memoized mk_builtin_types_map @@ -190,6 +192,7 @@ type builtin_fun_info = { rg : Types.RegionGroupId.id option; extract_name : string; } +[@@deriving show] (** The assumed functions. @@ -197,10 +200,12 @@ type builtin_fun_info = { parameters. For instance, in the case of the `Vec` functions, there is a type parameter for the allocator to use, which we want to filter. *) -let builtin_funs () : (string * bool list option * builtin_fun_info list) list = +let builtin_funs () : + (string list * bool list option * builtin_fun_info list) list = let rg0 = Some Types.RegionGroupId.zero in + (* TODO: fix the names below *) [ - ( "core::mem::replace", + ( [ "core::mem::replace" ], None, [ { @@ -218,7 +223,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list = | Lean -> "mem.replace_back"); }; ] ); - ( "alloc::vec::Vec::new", + ( [ "alloc::vec::Vec::new" ], Some [ true; false ], [ { @@ -236,7 +241,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list = | Lean -> "Vec.new_back"); }; ] ); - ( "alloc::vec::Vec::push", + ( [ "alloc::vec::Vec::push" ], Some [ true; false ], [ (* The forward function shouldn't be used *) @@ -255,7 +260,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list = | Lean -> "Vec.push"); }; ] ); - ( "alloc::vec::Vec::insert", + ( [ "alloc::vec::Vec::insert" ], Some [ true; false ], [ (* The forward function shouldn't be used *) @@ -274,7 +279,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list = | Lean -> "Vec.insert"); }; ] ); - ( "alloc::vec::Vec::len", + ( [ "alloc::vec::Vec::len" ], Some [ true; false ], [ { @@ -285,7 +290,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list = | Lean -> "Vec.len"); }; ] ); - ( "alloc::vec::Vec::index", + ( [ "alloc::vec::Vec::index" ], Some [ true; false ], [ { @@ -304,7 +309,7 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list = | Lean -> "Vec.index_shared_back"); }; ] ); - ( "alloc::vec::Vec::index_mut", + ( [ "alloc::vec::Vec::index_mut" ], Some [ true; false ], [ { @@ -323,16 +328,52 @@ let builtin_funs () : (string * bool list option * builtin_fun_info list) list = | Lean -> "Vec.index_mut_back"); }; ] ); + ( [ "alloc"; "boxed"; "Box"; "deref" ], + Some [ true; false ], + [ + { + rg = None; + extract_name = + (match !backend with + | FStar | Coq | HOL4 -> "alloc_boxed_box_deref" + | Lean -> "alloc.boxed.Box.deref"); + }; + (* The backward function shouldn't be used *) + { + rg = rg0; + extract_name = + (match !backend with + | FStar | Coq | HOL4 -> "alloc_boxed_box_deref_back" + | Lean -> "alloc.boxed.Box.deref_back"); + }; + ] ); + ( [ "alloc"; "boxed"; "Box"; "deref_mut" ], + Some [ true; false ], + [ + { + rg = None; + extract_name = + (match !backend with + | FStar | Coq | HOL4 -> "alloc_boxed_box_deref_mut" + | Lean -> "alloc.boxed.Box.deref_mut"); + }; + { + rg = rg0; + extract_name = + (match !backend with + | FStar | Coq | HOL4 -> "alloc_boxed_box_deref_mut_back" + | Lean -> "alloc.boxed.Box.deref_mut_back"); + }; + ] ); ] let mk_builtin_funs_map () = SimpleNameMap.of_list (List.map - (fun (name, filter, info) -> - (string_to_simple_name name, (filter, info))) + (fun (name, filter, info) -> (name, (filter, info))) (builtin_funs ())) -let builtin_funs_map () = mk_memoized mk_builtin_funs_map +let builtin_funs_map = mk_memoized mk_builtin_funs_map type builtin_trait_decl_info = { rust_name : string; @@ -346,6 +387,7 @@ type builtin_trait_decl_info = { - a list of clauses *) funs : (string * (Types.RegionGroupId.id option * string) list) list; } +[@@deriving show] let builtin_trait_decls_info () = let rg0 = Some Types.RegionGroupId.zero in @@ -389,7 +431,7 @@ let builtin_trait_decls_info () = [ (match !backend with | Coq | FStar | HOL4 -> "deref_inst" - | Lean -> "DerefInst"); + | Lean -> "derefInst"); ]; consts = []; types = []; diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 0871a305..95252b61 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -654,78 +654,100 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) *) let export_functions_group (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit = - (* Utility to check a function has a decrease clause *) - let has_decreases_clause (def : Pure.fun_decl) : bool = - PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) - ctx.functions_with_decreases_clause + (* Check if the definition are builtin - if yes they must be ignored. + Note that if one definition in the group is builtin, then all the + definitions must be builtin *) + let builtin = + let open ExtractBuiltin in + let funs_map = builtin_funs_map () in + List.map + (fun (trans : pure_fun_translation) -> + let sname = name_to_simple_name trans.fwd.f.basename in + SimpleNameMap.find_opt sname funs_map <> None) + pure_ls in - (* Extract the decrease clauses template bodies *) - if config.extract_template_decreases_clauses then - List.iter - (fun { fwd; _ } -> - (* We only generate decreases clauses for the forward functions, because - the termination argument should only depend on the forward inputs. - The backward functions thus use the same decreases clauses as the - forward function. - - Rem.: we might filter backward functions in {!PureMicroPasses}, but - we don't remove forward functions. Instead, we remember if we should - filter those functions at extraction time with a boolean (see the - type of the [pure_ls] input parameter). - *) - let extract_decrease decl = - let has_decr_clause = has_decreases_clause decl in - if has_decr_clause then - match !Config.backend with - | Lean -> - Extract.extract_template_lean_termination_and_decreasing ctx fmt - decl - | FStar -> - Extract.extract_template_fstar_decreases_clause ctx fmt decl - | Coq -> - raise (Failure "Coq doesn't have decreases/termination clauses") - | HOL4 -> - raise - (Failure "HOL4 doesn't have decreases/termination clauses") - in - extract_decrease fwd.f; - List.iter extract_decrease fwd.loops) - pure_ls; - - (* Concatenate the function definitions, filtering the useless forward - * functions. *) - let decls = - List.concat - (List.map - (fun { keep_fwd; fwd; backs } -> - let fwd = if keep_fwd then List.append fwd.loops [ fwd.f ] else [] in - let backs : Pure.fun_decl list = - List.concat - (List.map (fun back -> List.append back.loops [ back.f ]) backs) - in - List.append fwd backs) - pure_ls) - in + if List.exists (fun b -> b) builtin then + (* Sanity check *) + assert (List.for_all (fun b -> b) builtin) + else + (* Utility to check a function has a decrease clause *) + let has_decreases_clause (def : Pure.fun_decl) : bool = + PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) + ctx.functions_with_decreases_clause + in + + (* Extract the decrease clauses template bodies *) + if config.extract_template_decreases_clauses then + List.iter + (fun { fwd; _ } -> + (* We only generate decreases clauses for the forward functions, because + the termination argument should only depend on the forward inputs. + The backward functions thus use the same decreases clauses as the + forward function. + + Rem.: we might filter backward functions in {!PureMicroPasses}, but + we don't remove forward functions. Instead, we remember if we should + filter those functions at extraction time with a boolean (see the + type of the [pure_ls] input parameter). + *) + let extract_decrease decl = + let has_decr_clause = has_decreases_clause decl in + if has_decr_clause then + match !Config.backend with + | Lean -> + Extract.extract_template_lean_termination_and_decreasing ctx + fmt decl + | FStar -> + Extract.extract_template_fstar_decreases_clause ctx fmt decl + | Coq -> + raise + (Failure "Coq doesn't have decreases/termination clauses") + | HOL4 -> + raise + (Failure "HOL4 doesn't have decreases/termination clauses") + in + extract_decrease fwd.f; + List.iter extract_decrease fwd.loops) + pure_ls; + + (* Concatenate the function definitions, filtering the useless forward + * functions. *) + let decls = + List.concat + (List.map + (fun { keep_fwd; fwd; backs } -> + let fwd = + if keep_fwd then List.append fwd.loops [ fwd.f ] else [] + in + let backs : Pure.fun_decl list = + List.concat + (List.map + (fun back -> List.append back.loops [ back.f ]) + backs) + in + List.append fwd backs) + pure_ls) + in - (* Extract the function definitions *) - (if config.extract_fun_decls then - (* Group the mutually recursive definitions *) - let subgroups = ReorderDecls.group_reorder_fun_decls decls in + (* Extract the function definitions *) + (if config.extract_fun_decls then + (* Group the mutually recursive definitions *) + let subgroups = ReorderDecls.group_reorder_fun_decls decls in - (* Extract the subgroups *) - let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit = - export_functions_group_scc fmt config ctx is_rec decls - in - List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups); - - (* Insert unit tests if necessary *) - if config.test_trans_unit_functions then - List.iter - (fun trans -> - if trans.keep_fwd then - Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f) - pure_ls + (* Extract the subgroups *) + let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit = + export_functions_group_scc fmt config ctx is_rec decls + in + List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups); + + (* Insert unit tests if necessary *) + if config.test_trans_unit_functions then + List.iter + (fun trans -> + if trans.keep_fwd then + Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f) + pure_ls (** Export a trait declaration. *) let export_trait_decl (fmt : Format.formatter) (_config : gen_config) |