summaryrefslogtreecommitdiff
path: root/compiler/Extract.ml
diff options
context:
space:
mode:
authorSon Ho2023-09-03 20:12:59 +0200
committerSon Ho2023-09-03 20:12:59 +0200
commitfcd1fbe048b55a89bd8ed34afa8ed2295798d3ec (patch)
treee7e130ba33f0644ffe5fbdd291b738f204bd86c8 /compiler/Extract.ml
parente090e09725e3fd5c7f2a92813955ce2d81560227 (diff)
Make progress registering the trait decl method names
Diffstat (limited to 'compiler/Extract.ml')
-rw-r--r--compiler/Extract.ml50
1 files changed, 33 insertions, 17 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 204fee04..2a678a27 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -2329,8 +2329,8 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
let extract_fun_decl_register_names (ctx : extraction_ctx)
(has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
extraction_ctx =
- let { f = fwd; loops = loop_fwds } = def.fwd in
- let back_ls = def.backs in
+ let fwd = def.fwd in
+ let backs = def.backs in
(* Register the decrease clauses, if necessary *)
let register_decreases ctx def =
if has_decreases_clause def then
@@ -2343,22 +2343,19 @@ let extract_fun_decl_register_names (ctx : extraction_ctx)
| Lean -> ctx_add_decreases_proof def ctx
else ctx
in
- let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
+ let ctx = List.fold_left register_decreases ctx (fwd.f :: fwd.loops) in
let register_fun ctx f = ctx_add_fun_decl def f ctx in
let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the forward functions' names *)
- let ctx = register_funs ctx (fwd :: loop_fwds) in
- (* Register the backward functions' names *)
+ (* Register the names of the forward functions *)
let ctx =
- List.fold_left
- (fun ctx { f = back; loops = loop_backs } ->
- let ctx = register_fun ctx back in
- register_funs ctx loop_backs)
- ctx back_ls
+ if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
in
-
- (* Return *)
- ctx
+ (* Register the names of the backward functions *)
+ List.fold_left
+ (fun ctx { f = back; loops = loop_backs } ->
+ let ctx = register_fun ctx back in
+ register_funs ctx loop_backs)
+ ctx backs
(** Simply add the global name to the context. *)
let extract_global_decl_register_names (ctx : extraction_ctx)
@@ -3927,6 +3924,27 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Add a break to insert lines between declarations *)
F.pp_print_break fmt 0 0
+(** Register the names for one trait method item *)
+let extract_trait_decl_method_register_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl) (name : string) (id : fun_decl_id) :
+ extraction_ctx =
+ (* We add one field per required forward/backward function *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+
+ let register_fun ctx f = ctx_add_trait_method trait_decl name f ctx in
+ let register_funs ctx fl = List.fold_left register_fun ctx fl in
+ (* Register the names of the forward functions *)
+ let ctx =
+ if trans.keep_fwd then register_funs ctx (trans.fwd.f :: trans.fwd.loops)
+ else ctx
+ in
+ (* Register the names of the backward functions *)
+ List.fold_left
+ (fun ctx back ->
+ let ctx = register_fun ctx back.f in
+ register_funs ctx back.loops)
+ ctx trans.backs
+
(** Similar to {!extract_type_decl_register_names} *)
let extract_trait_decl_register_names (ctx : extraction_ctx)
(trait_decl : trait_decl) : extraction_ctx =
@@ -3968,12 +3986,10 @@ let extract_trait_decl_register_names (ctx : extraction_ctx)
ctx types
in
(* Required methods *)
- (* TODO: for the methods, we need to add fields for the forward/backward functions *)
- raise (Failure "TODO");
List.fold_left
(fun ctx (name, id) ->
(* We add one field per required forward/backward function *)
- ctx_add_trait_method trait_decl name ctx)
+ extract_trait_decl_method_register_names ctx trait_decl name id)
ctx required_methods
(** Similar to {!extract_type_decl_register_names} *)