From 9fe9fc0ab70b8629722d60748bbede554017172c Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sun, 3 Sep 2023 18:59:19 +0200 Subject: Make progress on extracting trait decls and merge gen_ctx and extraction_ctx --- compiler/Extract.ml | 150 ++++++++++++++++++++++++++++++++++++++++++-- compiler/ExtractBase.ml | 4 ++ compiler/Translate.ml | 163 +++++++++++++++++++++--------------------------- 3 files changed, 219 insertions(+), 98 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 5eb30daa..f911290e 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -3943,16 +3943,19 @@ let extract_trait_decl_register_names (ctx : extraction_ctx) trait_decl in let ctx = ctx_add_trait_decl trait_decl ctx in + (* Parent clauses *) let ctx = List.fold_left (fun ctx clause -> ctx_add_trait_parent_clause trait_decl clause ctx) ctx generics.trait_clauses in + (* Constants *) let ctx = List.fold_left (fun ctx (name, (_, _)) -> ctx_add_trait_const trait_decl name ctx) ctx consts in + (* Types *) let ctx = List.fold_left (fun ctx (name, (clauses, _)) -> @@ -3963,19 +3966,156 @@ let extract_trait_decl_register_names (ctx : extraction_ctx) ctx clauses) ctx types in + (* Required methods *) + (* TODO: for the methods, we need to add fields for the forward/backward functions *) + raise (Failure "TODO"); List.fold_left - (fun ctx (name, _) -> ctx_add_trait_method trait_decl name ctx) + (fun ctx (name, id) -> ctx_add_trait_method trait_decl name ctx) ctx required_methods (** Similar to {!extract_type_decl_register_names} *) -let extract_trait_impl_register_names (ctx : extraction_ctx) (d : trait_impl) : - extraction_ctx = +let extract_trait_impl_register_names (ctx : extraction_ctx) + (trait_impl : trait_impl) : extraction_ctx = + (* For now we do not support overriding provided methods *) + assert (trait_impl.provided_methods = []); + (* Everything is actually taken care of by {!extract_trait_decl_register_names} *) + ctx + +(** Small helper. + + The type `ty` is to be understood in a very general sense. + *) +let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter) + (item_name : string) (ty : unit -> unit) : unit = + F.pp_print_space fmt (); + F.pp_open_vbox fmt ctx.indent_incr; + F.pp_print_string fmt item_name; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + ty (); + F.pp_print_string fmt ";"; + F.pp_close_box fmt () + +(** Small helper. + + Extract the items for a method in a trait decl. + *) +let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter) + (decl : trait_decl) (name : string) (id : fun_decl_id) : unit = + let item_name = ctx_get_trait_const decl.def_id name ctx in + (* Lookup the definition *) + (* let def = + FunDeclId.Map.find ctx. + in *) raise (Failure "TODO") (** Extract a trait declaration *) let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter) - (trait_decl : trait_decl) : unit = - raise (Failure "TODO") + (decl : trait_decl) : unit = + (* Retrieve the trait name *) + let with_opaque_pre = false in + let decl_name = ctx_get_trait_decl with_opaque_pre decl.def_id ctx in + (* Add a break before *) + F.pp_print_break fmt 0 0; + (* Print a comment to link the extracted type to its original rust definition *) + extract_comment fmt [ "[" ^ Print.name_to_string decl.name ^ "]" ]; + F.pp_print_break fmt 0 0; + (* Open two boxes for the definition, so that whenever possible it gets printed on + * one line and indents are correct *) + F.pp_open_hvbox fmt 0; + F.pp_open_vbox fmt ctx.indent_incr; + + (* `struct Trait (....) =` *) + (* Open the box for the name + generics *) + F.pp_open_vbox fmt ctx.indent_incr; + let qualif = + Option.get (ctx.fmt.type_decl_kind_to_qualif SingleNonRec (Some Struct)) + in + F.pp_print_string fmt qualif; + F.pp_print_space fmt (); + F.pp_print_string fmt decl_name; + + (* Print the generics *) + (* We ignore the trait clauses, which we extract as *fields* *) + let generics = { decl.generics with trait_clauses = [] } in + (* Add the type and const generic params - note that we need those bindings only for the + * body translation (they are not top-level) *) + let ctx, type_params, cg_params, trait_clauses = + ctx_add_generic_params generics ctx + in + let use_forall = false in + let as_implicits = false in + extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall as_implicits + None None decl.generics type_params cg_params trait_clauses; + + F.pp_print_space fmt (); + F.pp_print_string fmt "{"; + + (* Close the box for the name + generics *) + F.pp_close_box fmt (); + + (* + * Extract the items + *) + + (* The parent clauses *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx + in + let ty () = + extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause + in + extract_trait_decl_item ctx fmt item_name ty) + decl.generics.trait_clauses; + + (* The constants *) + List.iter + (fun (name, (ty, _)) -> + let item_name = ctx_get_trait_const decl.def_id name ctx in + let ty () = + let inside = false in + extract_ty ctx fmt TypeDeclId.Set.empty inside ty + in + extract_trait_decl_item ctx fmt item_name ty) + decl.consts; + + (* The types *) + List.iter + (fun (name, (clauses, _)) -> + (* Extract the type *) + let item_name = ctx_get_trait_type decl.def_id name ctx in + let ty () = F.pp_print_string fmt (type_keyword ()) in + extract_trait_decl_item ctx fmt item_name ty; + (* Extract the clauses *) + List.iter + (fun clause -> + let item_name = + ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx + in + let ty () = + extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause + in + extract_trait_decl_item ctx fmt item_name ty) + clauses) + decl.types; + + (* The required methods *) + List.iter + (fun (name, id) -> extract_trait_decl_method_items ctx fmt decl name id) + decl.required_methods; + + (* Close the brackets *) + F.pp_print_space fmt (); + F.pp_print_string fmt "}"; + + (* Close the two outer boxes for the definition *) + F.pp_close_box fmt (); + F.pp_close_box fmt (); + (* Add breaks to insert new lines between definitions *) + F.pp_print_break fmt 0 0 (** Extract a trait implementation *) let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 7e6a2d40..26940c0c 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -621,6 +621,7 @@ type fun_name_info = { keep_fwd : bool; num_backs : int } functions, etc. *) type extraction_ctx = { + crate : A.crate; trans_ctx : trans_ctx; names_map : names_map; (** The map for id to names, where we forbid name collisions @@ -661,6 +662,9 @@ type extraction_ctx = { trait_decl_id : trait_decl_id option; (** If we are extracting a trait declaration, identifies it *) is_provided_method : bool; + trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; + trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; + functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t; trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t; } diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 8df69961..b26ce23b 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -396,14 +396,7 @@ let translate_crate_to_pure (crate : A.crate) : (* Return *) (trans_ctx, type_decls, pure_translations, trait_decls, trait_impls) -(** Extraction context *) -type gen_ctx = { - crate : A.crate; - extract_ctx : ExtractBase.extraction_ctx; - trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; - trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; - functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; -} +type gen_ctx = ExtractBase.extraction_ctx type gen_config = { extract_types : bool; @@ -482,9 +475,9 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) || ((not is_opaque) && config.extract_transparent) then ( if extract_decl then - Extract.extract_type_decl ctx.extract_ctx fmt type_decl_group kind def; + Extract.extract_type_decl ctx fmt type_decl_group kind def; if extract_extra_info then - Extract.extract_type_decl_extra_info ctx.extract_ctx fmt kind def) + Extract.extract_type_decl_extra_info ctx fmt kind def) (** Export a group of types. @@ -536,7 +529,7 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) End ]} *) - Extract.start_type_decl_group ctx.extract_ctx fmt is_rec defs; + Extract.start_type_decl_group ctx fmt is_rec defs; List.iteri (fun i def -> let kind = kind_from_index i in @@ -557,7 +550,7 @@ let export_types_group (fmt : Format.formatter) (config : gen_config) *) let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (id : A.GlobalDeclId.id) : unit = - let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in + let global_decls = ctx.trans_ctx.global_context.global_decls in let global = A.GlobalDeclId.Map.find id global_decls in let _, ((body, loop_fwds), body_backs) = A.FunDeclId.Map.find global.body_id ctx.trans_funs @@ -576,7 +569,7 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) groups are always singletons, so the [extract_global_decl] function takes care of generating the delimiters. *) - Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface + Extract.extract_global_decl ctx fmt global body config.interface (** Utility. @@ -657,14 +650,13 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) then Some (fun () -> - Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause - def) + Extract.extract_fun_decl ctx fmt kind has_decr_clause def) else None) decls in let extract_defs = List.filter_map (fun x -> x) extract_defs in if extract_defs <> [] then ( - Extract.start_fun_decl_group ctx.extract_ctx fmt is_rec decls; + Extract.start_fun_decl_group ctx fmt is_rec decls; List.iter (fun f -> f ()) extract_defs; Extract.end_fun_decl_group fmt is_rec decls) @@ -700,11 +692,10 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) if has_decr_clause then match !Config.backend with | Lean -> - Extract.extract_template_lean_termination_and_decreasing - ctx.extract_ctx fmt decl + Extract.extract_template_lean_termination_and_decreasing ctx fmt + decl | FStar -> - Extract.extract_template_fstar_decreases_clause ctx.extract_ctx - fmt decl + Extract.extract_template_fstar_decreases_clause ctx fmt decl | Coq -> raise (Failure "Coq doesn't have decreases/termination clauses") | HOL4 -> @@ -747,27 +738,21 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) if config.test_trans_unit_functions then List.iter (fun (keep_fwd, ((fwd, _), _)) -> - if keep_fwd then - Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd) + if keep_fwd then Extract.extract_unit_test_if_unit_fun ctx fmt fwd) pure_ls (** Export a trait declaration. *) let export_trait_decl (fmt : Format.formatter) (_config : gen_config) (ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) : unit = - let trait_decl = - T.TraitDeclId.Map.find trait_decl_id ctx.extract_ctx.trans_trait_decls - in - let ctx = ctx.extract_ctx in + let trait_decl = T.TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in Extract.extract_trait_decl ctx fmt trait_decl (** Export a trait implementation. *) let export_trait_impl (fmt : Format.formatter) (_config : gen_config) (ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit = - let trait_impl = - T.TraitImplId.Map.find trait_impl_id ctx.extract_ctx.trans_trait_impls - in - Extract.extract_trait_impl ctx.extract_ctx fmt trait_impl + let trait_impl = T.TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in + Extract.extract_trait_impl ctx fmt trait_impl (** A generic utility to generate the extracted definitions: as we may want to split the definitions between different files (or not), we can control @@ -790,7 +775,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) let kind = if config.interface then ExtractBase.Declared else ExtractBase.Assumed in - Extract.extract_state_type fmt ctx.extract_ctx kind + Extract.extract_state_type fmt ctx kind in let export_decl_group (dg : A.declaration_group) : unit = @@ -856,7 +841,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) if config.extract_transparent then "Definitions" else "OpaqueDefs" in Format.pp_print_break fmt 0 0; - Format.pp_open_vbox fmt ctx.extract_ctx.indent_incr; + Format.pp_open_vbox fmt ctx.indent_incr; Format.pp_print_string fmt ("structure " ^ struct_name ^ " where"); Format.pp_print_break fmt 0 0); List.iter export_decl_group ctx.crate.declarations; @@ -1005,6 +990,43 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : mk_formatter_and_names_map trans_ctx crate.name variant_concatenate_type_name in + + (* We need to compute which functions are recursive, in order to know + * whether we should generate a decrease clause or not. *) + let rec_functions = + List.map + (fun (_, ((fwd, loop_fwds), _)) -> + let fwd = + if fwd.Pure.signature.info.effect_info.is_rec then + [ (fwd.def_id, None) ] + else [] + in + let loop_fwds = + List.map + (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) + loop_fwds + in + fwd :: loop_fwds) + trans_funs + in + let rec_functions : PureUtils.fun_loop_id list = + List.concat (List.concat rec_functions) + in + let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in + + (* Put the translated definitions in maps *) + let trans_types = + Pure.TypeDeclId.Map.of_list + (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) + in + let trans_funs = + A.FunDeclId.Map.of_list + (List.map + (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> + ((fst fd).def_id, (keep_fwd, (fd, bdl)))) + trans_funs) + in + (* Put everything in the context *) let ctx = let trans_trait_decls = @@ -1020,7 +1042,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : trans_trait_impls) in { - ExtractBase.trans_ctx; + ExtractBase.crate; + trans_ctx; names_map; unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty }; fmt; @@ -1032,32 +1055,12 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : is_provided_method = false (* false by default *); trans_trait_decls; trans_trait_impls; + trans_types; + trans_funs; + functions_with_decreases_clause = rec_functions; } in - (* We need to compute which functions are recursive, in order to know - * whether we should generate a decrease clause or not. *) - let rec_functions = - List.map - (fun (_, ((fwd, loop_fwds), _)) -> - let fwd = - if fwd.Pure.signature.info.effect_info.is_rec then - [ (fwd.def_id, None) ] - else [] - in - let loop_fwds = - List.map - (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) - loop_fwds - in - fwd :: loop_fwds) - trans_funs - in - let rec_functions : PureUtils.fun_loop_id list = - List.concat (List.concat rec_functions) - in - let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in - (* Register unique names for all the top-level types, globals, functions... * Note that the order in which we generate the names doesn't matter: * we just need to generate a mapping from identifier to name, and make @@ -1065,7 +1068,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : let ctx = List.fold_left (fun ctx def -> Extract.extract_type_decl_register_names ctx def) - ctx trans_types + ctx + (Pure.TypeDeclId.Map.values trans_types) in let ctx = @@ -1087,7 +1091,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : else Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause defs) - ctx trans_funs + ctx + (A.FunDeclId.Map.values trans_funs) in let ctx = @@ -1133,19 +1138,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (namespace, crate_name, Filename.concat dest_dir crate_name) in - (* Put the translated definitions in maps *) - let trans_types = - Pure.TypeDeclId.Map.of_list - (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) - in - let trans_funs = - A.FunDeclId.Map.of_list - (List.map - (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> - ((fst fd).def_id, (keep_fwd, (fd, bdl)))) - trans_funs) - in - let mkdir_if dest_dir = if not (Sys.file_exists dest_dir) then ( log#linfo (lazy ("Creating missing directory: " ^ dest_dir)); @@ -1201,16 +1193,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : in (* Extract the file(s) *) - let gen_ctx = - { - crate; - extract_ctx = ctx; - trans_types; - trans_funs; - functions_with_decreases_clause = rec_functions; - } - in - let module_delimiter = match !Config.backend with | FStar -> "." @@ -1257,7 +1239,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : (* Check if there are opaque types and functions - in which case we need * to split *) - let has_opaque_types, has_opaque_funs = module_has_opaque_decls gen_ctx in + let has_opaque_types, has_opaque_funs = module_has_opaque_decls ctx in let has_opaque_types = has_opaque_types || !Config.use_state in (* Extract the types *) @@ -1296,7 +1278,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file types_config gen_ctx file_info; + extract_file types_config ctx file_info; (* Extract the template clauses *) (if needs_clauses_module && !Config.extract_template_decreases_clauses then @@ -1324,7 +1306,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file template_clauses_config gen_ctx file_info); + extract_file template_clauses_config ctx file_info); (* Extract the opaque functions, if needed *) let opaque_funs_module = @@ -1359,12 +1341,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : interface = true; } in - let gen_ctx = - { - gen_ctx with - extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false }; - } - in + let ctx = { ctx with use_opaque_pre = false } in let file_info = { filename = opaque_filename; @@ -1378,7 +1355,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = [ types_module ]; } in - extract_file opaque_config gen_ctx file_info; + extract_file opaque_config ctx file_info; (* Return the additional dependencies *) [ opaque_imported_module ]) else [] @@ -1417,7 +1394,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : [ types_module ] @ opaque_funs_module @ clauses_module; } in - extract_file fun_config gen_ctx file_info) + extract_file fun_config ctx file_info) else let gen_config = { @@ -1447,7 +1424,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : custom_includes = []; } in - extract_file gen_config gen_ctx file_info); + extract_file gen_config ctx file_info); (* Generate the build file *) match !Config.backend with -- cgit v1.2.3