summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-09-03 20:12:59 +0200
committerSon Ho2023-09-03 20:12:59 +0200
commitfcd1fbe048b55a89bd8ed34afa8ed2295798d3ec (patch)
treee7e130ba33f0644ffe5fbdd291b738f204bd86c8
parente090e09725e3fd5c7f2a92813955ce2d81560227 (diff)
Make progress registering the trait decl method names
-rw-r--r--compiler/Extract.ml50
-rw-r--r--compiler/ExtractBase.ml74
2 files changed, 93 insertions, 31 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 204fee04..2a678a27 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -2329,8 +2329,8 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
let extract_fun_decl_register_names (ctx : extraction_ctx)
(has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
extraction_ctx =
- let { f = fwd; loops = loop_fwds } = def.fwd in
- let back_ls = def.backs in
+ let fwd = def.fwd in
+ let backs = def.backs in
(* Register the decrease clauses, if necessary *)
let register_decreases ctx def =
if has_decreases_clause def then
@@ -2343,22 +2343,19 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
| Lean -> ctx_add_decreases_proof def ctx
else ctx
in
- let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
+ let ctx = List.fold_left register_decreases ctx (fwd.f :: fwd.loops) in
let register_fun ctx f = ctx_add_fun_decl def f ctx in
let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the forward functions' names *)
- let ctx = register_funs ctx (fwd :: loop_fwds) in
- (* Register the backward functions' names *)
+ (* Register the names of the forward functions *)
let ctx =
- List.fold_left
- (fun ctx { f = back; loops = loop_backs } ->
- let ctx = register_fun ctx back in
- register_funs ctx loop_backs)
- ctx back_ls
+ if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
in
-
- (* Return *)
- ctx
+ (* Register the names of the backward functions *)
+ List.fold_left
+ (fun ctx { f = back; loops = loop_backs } ->
+ let ctx = register_fun ctx back in
+ register_funs ctx loop_backs)
+ ctx backs
(** Simply add the global name to the context. *)
let extract_global_decl_register_names (ctx : extraction_ctx)
@@ -3927,6 +3924,27 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Add a break to insert lines between declarations *)
F.pp_print_break fmt 0 0
+(** Register the names for one trait method item *)
+let extract_trait_decl_method_register_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl) (name : string) (id : fun_decl_id) :
+ extraction_ctx =
+ (* We add one field per required forward/backward function *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+
+ let register_fun ctx f = ctx_add_trait_method trait_decl name f ctx in
+ let register_funs ctx fl = List.fold_left register_fun ctx fl in
+ (* Register the names of the forward functions *)
+ let ctx =
+ if trans.keep_fwd then register_funs ctx (trans.fwd.f :: trans.fwd.loops)
+ else ctx
+ in
+ (* Register the names of the backward functions *)
+ List.fold_left
+ (fun ctx back ->
+ let ctx = register_fun ctx back.f in
+ register_funs ctx back.loops)
+ ctx trans.backs
+
(** Similar to {!extract_type_decl_register_names} *)
let extract_trait_decl_register_names (ctx : extraction_ctx)
(trait_decl : trait_decl) : extraction_ctx =
@@ -3968,12 +3986,10 @@ let extract_trait_decl_register_names (ctx : extraction_ctx)
ctx types
in
(* Required methods *)
- (* TODO: for the methods, we need to add fields for the forward/backward functions *)
- raise (Failure "TODO");
List.fold_left
(fun ctx (name, id) ->
(* We add one field per required forward/backward function *)
- ctx_add_trait_method trait_decl name ctx)
+ extract_trait_decl_method_register_names ctx trait_decl name id)
ctx required_methods
(** Similar to {!extract_type_decl_register_names} *)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 17f5b693..e4d1fb7b 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -417,7 +417,14 @@ type id =
| TraitDeclId of TraitDeclId.id
| TraitImplId of TraitImplId.id
| LocalTraitClauseId of TraitClauseId.id
- | TraitItemId of TraitDeclId.id * string (** A trait associated item *)
+ | TraitMethodId of
+ TraitDeclId.id * string * LoopId.id option * T.RegionGroupId.id option
+ (** Something peculiar with trait methods: because we have to take into
+ account forward/backward functions, we may need to generate fields
+ items per method.
+ *)
+ | TraitItemId of TraitDeclId.id * string
+ (** A trait associated item which is not a method *)
| TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
| TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id
| TraitSelfClauseId
@@ -677,6 +684,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
let global_decls = ctx.trans_ctx.global_context.global_decls in
let fun_decls = ctx.trans_ctx.fun_context.fun_decls in
let type_decls = ctx.trans_ctx.type_context.type_decls in
+ let trait_decls = ctx.trans_ctx.trait_decls_context.trait_decls in
(* TODO: factorize the pretty-printing with what is in PrintPure *)
let get_type_name (id : type_id) : string =
match id with
@@ -812,6 +820,24 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| TraitItemId (id, name) ->
"trait_item_id: decl_id:" ^ TraitDeclId.to_string id ^ ", type name: "
^ name
+ | TraitMethodId (trait_decl_id, fun_name, lp_id, rg_id) ->
+ let trait_name =
+ Print.fun_name_to_string
+ (A.TraitDeclId.Map.find trait_decl_id trait_decls).name
+ in
+ let lp_kind =
+ match lp_id with
+ | None -> ""
+ | Some lp_id -> "loop " ^ LoopId.to_string lp_id ^ ", "
+ in
+
+ let fwd_back_kind =
+ match rg_id with
+ | None -> "forward"
+ | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
+ in
+ "trait " ^ trait_name ^ " method name (" ^ lp_kind ^ fwd_back_kind ^ "): "
+ ^ fun_name
| TraitSelfClauseId -> "trait_self_clause"
(** We might not check for collisions for some specific ids (ex.: field names) *)
@@ -1185,11 +1211,8 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
let ctx = ctx_add is_opaque body (name ^ "_body") ctx in
ctx
-let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
- (ctx : extraction_ctx) : extraction_ctx =
- (* Sanity check: the function should not be a global body - those are handled
- * separately *)
- assert (not def.is_global_decl_body);
+let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl)
+ (ctx : extraction_ctx) : string =
(* Lookup the LLBC def to compute the region group information *)
let def_id = def.def_id in
let llbc_def =
@@ -1211,12 +1234,22 @@ let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
in
Some { id = rg_id; region_names }
in
+ (* Add the function name *)
+ ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
+ (keep_fwd, num_backs)
+
+let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
+ (ctx : extraction_ctx) : extraction_ctx =
+ (* Sanity check: the function should not be a global body - those are handled
+ * separately *)
+ assert (not def.is_global_decl_body);
+ (* Lookup the LLBC def to compute the region group information *)
+ let def_id = def.def_id in
+ let { keep_fwd; fwd = _; backs } = trans_group in
+ let num_backs = List.length backs in
let is_opaque = def.body = None in
(* Add the function name *)
- let def_name =
- ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
- (keep_fwd, num_backs)
- in
+ let def_name = ctx_compute_fun_name trans_group def ctx in
let fun_id = (A.Regular def_id, def.loop_id, def.back_id) in
let ctx = ctx_add is_opaque (FunId (FromLlbc fun_id)) def_name ctx in
(* Add the name info *)
@@ -1251,11 +1284,24 @@ let ctx_add_trait_type (d : trait_decl) (item : string) (ctx : extraction_ctx) :
let name = ctx.fmt.trait_type_name d item in
ctx_add is_opaque (TraitItemId (d.def_id, item)) name ctx
-let ctx_add_trait_method (d : trait_decl) (item : string) (ctx : extraction_ctx)
- : extraction_ctx =
+let ctx_add_trait_method (d : trait_decl) (item_name : string) (f : fun_decl)
+ (ctx : extraction_ctx) : extraction_ctx =
+ (* We do something special: we use the base name but remove everything
+ but the crate (because [get_name] removes it) and the last ident.
+ This allows us to reuse the [ctx_compute_fun_decl] function.
+ *)
+ let basename : name =
+ match (f.basename : name) with
+ | Ident crate :: name -> Ident crate :: [ Collections.List.last name ]
+ | _ -> raise (Failure "Unexpected")
+ in
+ let f = { f with basename } in
+ let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in
+ let name = ctx_compute_fun_name trans f ctx in
let is_opaque = false in
- let name = ctx.fmt.trait_method_name d item in
- ctx_add is_opaque (TraitItemId (d.def_id, item)) name ctx
+ ctx_add is_opaque
+ (TraitMethodId (d.def_id, item_name, f.loop_id, f.back_id))
+ name ctx
let ctx_add_trait_parent_clause (d : trait_decl) (clause : trait_clause)
(ctx : extraction_ctx) : extraction_ctx =