summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-09-03 18:59:19 +0200
committerSon Ho2023-09-03 18:59:19 +0200
commit9fe9fc0ab70b8629722d60748bbede554017172c (patch)
tree6dcd844f6850463156f577464f8a9d62f876bbd0 /compiler
parent0c0b7692cc3d95adf21bccf83d5bb2f81487ca4f (diff)
Make progress on extracting trait decls and merge gen_ctx and extraction_ctx
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml150
-rw-r--r--compiler/ExtractBase.ml4
-rw-r--r--compiler/Translate.ml163
3 files changed, 219 insertions, 98 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 5eb30daa..f911290e 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -3943,16 +3943,19 @@ let extract_trait_decl_register_names (ctx : extraction_ctx)
trait_decl
in
let ctx = ctx_add_trait_decl trait_decl ctx in
+ (* Parent clauses *)
let ctx =
List.fold_left
(fun ctx clause -> ctx_add_trait_parent_clause trait_decl clause ctx)
ctx generics.trait_clauses
in
+ (* Constants *)
let ctx =
List.fold_left
(fun ctx (name, (_, _)) -> ctx_add_trait_const trait_decl name ctx)
ctx consts
in
+ (* Types *)
let ctx =
List.fold_left
(fun ctx (name, (clauses, _)) ->
@@ -3963,19 +3966,156 @@ let extract_trait_decl_register_names (ctx : extraction_ctx)
ctx clauses)
ctx types
in
+ (* Required methods *)
+ (* TODO: for the methods, we need to add fields for the forward/backward functions *)
+ raise (Failure "TODO");
List.fold_left
- (fun ctx (name, _) -> ctx_add_trait_method trait_decl name ctx)
+ (fun ctx (name, id) -> ctx_add_trait_method trait_decl name ctx)
ctx required_methods
(** Similar to {!extract_type_decl_register_names} *)
-let extract_trait_impl_register_names (ctx : extraction_ctx) (d : trait_impl) :
- extraction_ctx =
+let extract_trait_impl_register_names (ctx : extraction_ctx)
+ (trait_impl : trait_impl) : extraction_ctx =
+ (* For now we do not support overriding provided methods *)
+ assert (trait_impl.provided_methods = []);
+ (* Everything is actually taken care of by {!extract_trait_decl_register_names} *)
+ ctx
+
+(** Small helper.
+
+ The type `ty` is to be understood in a very general sense.
+ *)
+let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter)
+ (item_name : string) (ty : unit -> unit) : unit =
+ F.pp_print_space fmt ();
+ F.pp_open_vbox fmt ctx.indent_incr;
+ F.pp_print_string fmt item_name;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ ty ();
+ F.pp_print_string fmt ";";
+ F.pp_close_box fmt ()
+
+(** Small helper.
+
+ Extract the items for a method in a trait decl.
+ *)
+let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
+ (decl : trait_decl) (name : string) (id : fun_decl_id) : unit =
+ let item_name = ctx_get_trait_const decl.def_id name ctx in
+ (* Lookup the definition *)
+ (* let def =
+ FunDeclId.Map.find ctx.
+ in *)
raise (Failure "TODO")
(** Extract a trait declaration *)
let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)
- (trait_decl : trait_decl) : unit =
- raise (Failure "TODO")
+ (decl : trait_decl) : unit =
+ (* Retrieve the trait name *)
+ let with_opaque_pre = false in
+ let decl_name = ctx_get_trait_decl with_opaque_pre decl.def_id ctx in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt [ "[" ^ Print.name_to_string decl.name ^ "]" ];
+ F.pp_print_break fmt 0 0;
+ (* Open two boxes for the definition, so that whenever possible it gets printed on
+ * one line and indents are correct *)
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_vbox fmt ctx.indent_incr;
+
+ (* `struct Trait (....) =` *)
+ (* Open the box for the name + generics *)
+ F.pp_open_vbox fmt ctx.indent_incr;
+ let qualif =
+ Option.get (ctx.fmt.type_decl_kind_to_qualif SingleNonRec (Some Struct))
+ in
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt decl_name;
+
+ (* Print the generics *)
+ (* We ignore the trait clauses, which we extract as *fields* *)
+ let generics = { decl.generics with trait_clauses = [] } in
+ (* Add the type and const generic params - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params generics ctx
+ in
+ let use_forall = false in
+ let as_implicits = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall as_implicits
+ None None decl.generics type_params cg_params trait_clauses;
+
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "{";
+
+ (* Close the box for the name + generics *)
+ F.pp_close_box fmt ();
+
+ (*
+ * Extract the items
+ *)
+
+ (* The parent clauses *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx
+ in
+ let ty () =
+ extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ decl.generics.trait_clauses;
+
+ (* The constants *)
+ List.iter
+ (fun (name, (ty, _)) ->
+ let item_name = ctx_get_trait_const decl.def_id name ctx in
+ let ty () =
+ let inside = false in
+ extract_ty ctx fmt TypeDeclId.Set.empty inside ty
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ decl.consts;
+
+ (* The types *)
+ List.iter
+ (fun (name, (clauses, _)) ->
+ (* Extract the type *)
+ let item_name = ctx_get_trait_type decl.def_id name ctx in
+ let ty () = F.pp_print_string fmt (type_keyword ()) in
+ extract_trait_decl_item ctx fmt item_name ty;
+ (* Extract the clauses *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx
+ in
+ let ty () =
+ extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ clauses)
+ decl.types;
+
+ (* The required methods *)
+ List.iter
+ (fun (name, id) -> extract_trait_decl_method_items ctx fmt decl name id)
+ decl.required_methods;
+
+ (* Close the brackets *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}";
+
+ (* Close the two outer boxes for the definition *)
+ F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
(** Extract a trait implementation *)
let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 7e6a2d40..26940c0c 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -621,6 +621,7 @@ type fun_name_info = { keep_fwd : bool; num_backs : int }
functions, etc.
*)
type extraction_ctx = {
+ crate : A.crate;
trans_ctx : trans_ctx;
names_map : names_map;
(** The map for id to names, where we forbid name collisions
@@ -661,6 +662,9 @@ type extraction_ctx = {
trait_decl_id : trait_decl_id option;
(** If we are extracting a trait declaration, identifies it *)
is_provided_method : bool;
+ trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
+ trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t;
+ functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t;
trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t;
}
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 8df69961..b26ce23b 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -396,14 +396,7 @@ let translate_crate_to_pure (crate : A.crate) :
(* Return *)
(trans_ctx, type_decls, pure_translations, trait_decls, trait_impls)
-(** Extraction context *)
-type gen_ctx = {
- crate : A.crate;
- extract_ctx : ExtractBase.extraction_ctx;
- trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
- trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t;
- functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
-}
+type gen_ctx = ExtractBase.extraction_ctx
type gen_config = {
extract_types : bool;
@@ -482,9 +475,9 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
|| ((not is_opaque) && config.extract_transparent)
then (
if extract_decl then
- Extract.extract_type_decl ctx.extract_ctx fmt type_decl_group kind def;
+ Extract.extract_type_decl ctx fmt type_decl_group kind def;
if extract_extra_info then
- Extract.extract_type_decl_extra_info ctx.extract_ctx fmt kind def)
+ Extract.extract_type_decl_extra_info ctx fmt kind def)
(** Export a group of types.
@@ -536,7 +529,7 @@ let export_types_group (fmt : Format.formatter) (config : gen_config)
End
]}
*)
- Extract.start_type_decl_group ctx.extract_ctx fmt is_rec defs;
+ Extract.start_type_decl_group ctx fmt is_rec defs;
List.iteri
(fun i def ->
let kind = kind_from_index i in
@@ -557,7 +550,7 @@ let export_types_group (fmt : Format.formatter) (config : gen_config)
*)
let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
(id : A.GlobalDeclId.id) : unit =
- let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in
+ let global_decls = ctx.trans_ctx.global_context.global_decls in
let global = A.GlobalDeclId.Map.find id global_decls in
let _, ((body, loop_fwds), body_backs) =
A.FunDeclId.Map.find global.body_id ctx.trans_funs
@@ -576,7 +569,7 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
groups are always singletons, so the [extract_global_decl] function
takes care of generating the delimiters.
*)
- Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface
+ Extract.extract_global_decl ctx fmt global body config.interface
(** Utility.
@@ -657,14 +650,13 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config)
then
Some
(fun () ->
- Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause
- def)
+ Extract.extract_fun_decl ctx fmt kind has_decr_clause def)
else None)
decls
in
let extract_defs = List.filter_map (fun x -> x) extract_defs in
if extract_defs <> [] then (
- Extract.start_fun_decl_group ctx.extract_ctx fmt is_rec decls;
+ Extract.start_fun_decl_group ctx fmt is_rec decls;
List.iter (fun f -> f ()) extract_defs;
Extract.end_fun_decl_group fmt is_rec decls)
@@ -700,11 +692,10 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
if has_decr_clause then
match !Config.backend with
| Lean ->
- Extract.extract_template_lean_termination_and_decreasing
- ctx.extract_ctx fmt decl
+ Extract.extract_template_lean_termination_and_decreasing ctx fmt
+ decl
| FStar ->
- Extract.extract_template_fstar_decreases_clause ctx.extract_ctx
- fmt decl
+ Extract.extract_template_fstar_decreases_clause ctx fmt decl
| Coq ->
raise (Failure "Coq doesn't have decreases/termination clauses")
| HOL4 ->
@@ -747,27 +738,21 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
if config.test_trans_unit_functions then
List.iter
(fun (keep_fwd, ((fwd, _), _)) ->
- if keep_fwd then
- Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd)
+ if keep_fwd then Extract.extract_unit_test_if_unit_fun ctx fmt fwd)
pure_ls
(** Export a trait declaration. *)
let export_trait_decl (fmt : Format.formatter) (_config : gen_config)
(ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) : unit =
- let trait_decl =
- T.TraitDeclId.Map.find trait_decl_id ctx.extract_ctx.trans_trait_decls
- in
- let ctx = ctx.extract_ctx in
+ let trait_decl = T.TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in
let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in
Extract.extract_trait_decl ctx fmt trait_decl
(** Export a trait implementation. *)
let export_trait_impl (fmt : Format.formatter) (_config : gen_config)
(ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit =
- let trait_impl =
- T.TraitImplId.Map.find trait_impl_id ctx.extract_ctx.trans_trait_impls
- in
- Extract.extract_trait_impl ctx.extract_ctx fmt trait_impl
+ let trait_impl = T.TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in
+ Extract.extract_trait_impl ctx fmt trait_impl
(** A generic utility to generate the extracted definitions: as we may want to
split the definitions between different files (or not), we can control
@@ -790,7 +775,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
let kind =
if config.interface then ExtractBase.Declared else ExtractBase.Assumed
in
- Extract.extract_state_type fmt ctx.extract_ctx kind
+ Extract.extract_state_type fmt ctx kind
in
let export_decl_group (dg : A.declaration_group) : unit =
@@ -856,7 +841,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
if config.extract_transparent then "Definitions" else "OpaqueDefs"
in
Format.pp_print_break fmt 0 0;
- Format.pp_open_vbox fmt ctx.extract_ctx.indent_incr;
+ Format.pp_open_vbox fmt ctx.indent_incr;
Format.pp_print_string fmt ("structure " ^ struct_name ^ " where");
Format.pp_print_break fmt 0 0);
List.iter export_decl_group ctx.crate.declarations;
@@ -1005,6 +990,43 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
mk_formatter_and_names_map trans_ctx crate.name
variant_concatenate_type_name
in
+
+ (* We need to compute which functions are recursive, in order to know
+ * whether we should generate a decrease clause or not. *)
+ let rec_functions =
+ List.map
+ (fun (_, ((fwd, loop_fwds), _)) ->
+ let fwd =
+ if fwd.Pure.signature.info.effect_info.is_rec then
+ [ (fwd.def_id, None) ]
+ else []
+ in
+ let loop_fwds =
+ List.map
+ (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
+ loop_fwds
+ in
+ fwd :: loop_fwds)
+ trans_funs
+ in
+ let rec_functions : PureUtils.fun_loop_id list =
+ List.concat (List.concat rec_functions)
+ in
+ let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in
+
+ (* Put the translated definitions in maps *)
+ let trans_types =
+ Pure.TypeDeclId.Map.of_list
+ (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types)
+ in
+ let trans_funs =
+ A.FunDeclId.Map.of_list
+ (List.map
+ (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
+ ((fst fd).def_id, (keep_fwd, (fd, bdl))))
+ trans_funs)
+ in
+
(* Put everything in the context *)
let ctx =
let trans_trait_decls =
@@ -1020,7 +1042,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
trans_trait_impls)
in
{
- ExtractBase.trans_ctx;
+ ExtractBase.crate;
+ trans_ctx;
names_map;
unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty };
fmt;
@@ -1032,32 +1055,12 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
is_provided_method = false (* false by default *);
trans_trait_decls;
trans_trait_impls;
+ trans_types;
+ trans_funs;
+ functions_with_decreases_clause = rec_functions;
}
in
- (* We need to compute which functions are recursive, in order to know
- * whether we should generate a decrease clause or not. *)
- let rec_functions =
- List.map
- (fun (_, ((fwd, loop_fwds), _)) ->
- let fwd =
- if fwd.Pure.signature.info.effect_info.is_rec then
- [ (fwd.def_id, None) ]
- else []
- in
- let loop_fwds =
- List.map
- (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
- loop_fwds
- in
- fwd :: loop_fwds)
- trans_funs
- in
- let rec_functions : PureUtils.fun_loop_id list =
- List.concat (List.concat rec_functions)
- in
- let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in
-
(* Register unique names for all the top-level types, globals, functions...
* Note that the order in which we generate the names doesn't matter:
* we just need to generate a mapping from identifier to name, and make
@@ -1065,7 +1068,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
let ctx =
List.fold_left
(fun ctx def -> Extract.extract_type_decl_register_names ctx def)
- ctx trans_types
+ ctx
+ (Pure.TypeDeclId.Map.values trans_types)
in
let ctx =
@@ -1087,7 +1091,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
else
Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause
defs)
- ctx trans_funs
+ ctx
+ (A.FunDeclId.Map.values trans_funs)
in
let ctx =
@@ -1133,19 +1138,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(namespace, crate_name, Filename.concat dest_dir crate_name)
in
- (* Put the translated definitions in maps *)
- let trans_types =
- Pure.TypeDeclId.Map.of_list
- (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types)
- in
- let trans_funs =
- A.FunDeclId.Map.of_list
- (List.map
- (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
- ((fst fd).def_id, (keep_fwd, (fd, bdl))))
- trans_funs)
- in
-
let mkdir_if dest_dir =
if not (Sys.file_exists dest_dir) then (
log#linfo (lazy ("Creating missing directory: " ^ dest_dir));
@@ -1201,16 +1193,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
in
(* Extract the file(s) *)
- let gen_ctx =
- {
- crate;
- extract_ctx = ctx;
- trans_types;
- trans_funs;
- functions_with_decreases_clause = rec_functions;
- }
- in
-
let module_delimiter =
match !Config.backend with
| FStar -> "."
@@ -1257,7 +1239,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(* Check if there are opaque types and functions - in which case we need
* to split *)
- let has_opaque_types, has_opaque_funs = module_has_opaque_decls gen_ctx in
+ let has_opaque_types, has_opaque_funs = module_has_opaque_decls ctx in
let has_opaque_types = has_opaque_types || !Config.use_state in
(* Extract the types *)
@@ -1296,7 +1278,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file types_config gen_ctx file_info;
+ extract_file types_config ctx file_info;
(* Extract the template clauses *)
(if needs_clauses_module && !Config.extract_template_decreases_clauses then
@@ -1324,7 +1306,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file template_clauses_config gen_ctx file_info);
+ extract_file template_clauses_config ctx file_info);
(* Extract the opaque functions, if needed *)
let opaque_funs_module =
@@ -1359,12 +1341,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
interface = true;
}
in
- let gen_ctx =
- {
- gen_ctx with
- extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false };
- }
- in
+ let ctx = { ctx with use_opaque_pre = false } in
let file_info =
{
filename = opaque_filename;
@@ -1378,7 +1355,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [ types_module ];
}
in
- extract_file opaque_config gen_ctx file_info;
+ extract_file opaque_config ctx file_info;
(* Return the additional dependencies *)
[ opaque_imported_module ])
else []
@@ -1417,7 +1394,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
[ types_module ] @ opaque_funs_module @ clauses_module;
}
in
- extract_file fun_config gen_ctx file_info)
+ extract_file fun_config ctx file_info)
else
let gen_config =
{
@@ -1447,7 +1424,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file gen_config gen_ctx file_info);
+ extract_file gen_config ctx file_info);
(* Generate the build file *)
match !Config.backend with