summaryrefslogtreecommitdiff
path: root/compiler/Extract.ml
diff options
context:
space:
mode:
authorSon Ho2022-12-17 10:27:12 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit66638a2a96c7639553a340917b87e26d94265c5e (patch)
treea0219df7582ca17784135345924790dc26a7e315 /compiler/Extract.ml
parent07621dcf488eef1c4a4ab797c21cc34ab474d225 (diff)
Fix various issues with the generation of code for the loops
Diffstat (limited to '')
-rw-r--r--compiler/Extract.ml29
1 files changed, 18 insertions, 11 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index fa384de6..b3d7b49e 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1254,21 +1254,28 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
(forward function and backward functions).
*)
let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
- (has_decreases_clause : bool) (def : pure_fun_translation) : extraction_ctx
- =
- let fwd, back_ls = def in
- (* Register the decrease clause, if necessary *)
- let ctx =
- if has_decreases_clause then ctx_add_decrases_clause fwd ctx else ctx
+ (has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
+ extraction_ctx =
+ let (fwd, loop_fwds), back_ls = def in
+ (* Register the decrease clauses, if necessary *)
+ let register_decreases ctx def =
+ if has_decreases_clause def then ctx_add_decreases_clause def ctx else ctx
in
- (* Register the forward function name *)
- let ctx = ctx_add_fun_decl (keep_fwd, def) fwd ctx in
+ let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
+ (* Register the function names *)
+ let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in
+ let register_funs ctx fl = List.fold_left register_fun ctx fl in
+ (* Register the forward functions' names *)
+ let ctx = register_funs ctx (fwd :: loop_fwds) in
(* Register the backward functions' names *)
let ctx =
List.fold_left
- (fun ctx back -> ctx_add_fun_decl (keep_fwd, def) back ctx)
+ (fun ctx (back, loop_backs) ->
+ let ctx = register_fun ctx back in
+ register_funs ctx loop_backs)
ctx back_ls
in
+
(* Return *)
ctx
@@ -1855,7 +1862,7 @@ let extract_template_decreases_clause (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
assert (!backend = FStar);
(* Retrieve the function name *)
- let def_name = ctx_get_decreases_clause def.def_id ctx in
+ let def_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
(* Add a break before *)
F.pp_print_break fmt 0 0;
(* Print a comment to link the extracted type to its original rust definition *)
@@ -1992,7 +1999,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for the decreases term *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* The name of the decrease clause *)
- let decr_name = ctx_get_decreases_clause def.def_id ctx in
+ let decr_name = ctx_get_decreases_clause def.def_id def.loop_id ctx in
F.pp_print_string fmt decr_name;
(* Print the type parameters *)
List.iter