summaryrefslogtreecommitdiff
path: root/compiler/Translate.ml
diff options
context:
space:
mode:
authorSon Ho2022-12-17 10:27:12 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit66638a2a96c7639553a340917b87e26d94265c5e (patch)
treea0219df7582ca17784135345924790dc26a7e315 /compiler/Translate.ml
parent07621dcf488eef1c4a4ab797c21cc34ab474d225 (diff)
Fix various issues with the generation of code for the loops
Diffstat (limited to '')
-rw-r--r--compiler/Translate.ml85
1 files changed, 58 insertions, 27 deletions
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 32c32ac4..10a37770 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -51,7 +51,7 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl)
let translate_function_to_pure (trans_ctx : trans_ctx)
(fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
(pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl)
- : pure_fun_translation =
+ : pure_fun_translation_no_loops =
(* Debug *)
log#ldebug
(lazy
@@ -213,7 +213,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
sg.info.num_fwd_inputs_with_fuel_with_state
in
let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in
- Collections.List.subslice sg.inputs num_forward_inputs num_back_inputs
+ Collections.List.subslice sg.inputs num_forward_inputs
+ (num_forward_inputs + num_back_inputs)
in
(* As we forbid nested borrows, the additional inputs for the backward
* functions come from the borrows in the return value of the rust function:
@@ -336,7 +337,7 @@ type gen_ctx = {
extract_ctx : ExtractBase.extraction_ctx;
trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t;
- functions_with_decreases_clause : A.FunDeclId.Set.t;
+ functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
}
type gen_config = {
@@ -370,7 +371,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool =
in
let has_opaque_funs =
A.FunDeclId.Map.exists
- (fun _ ((_, (t_fwd, _)) : bool * pure_fun_translation) ->
+ (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) ->
Option.is_none t_fwd.body)
ctx.trans_funs
in
@@ -452,10 +453,11 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
(id : A.GlobalDeclId.id) : unit =
let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in
let global = A.GlobalDeclId.Map.find id global_decls in
- let _, (body, body_backs) =
+ let _, ((body, loop_fwds), body_backs) =
A.FunDeclId.Map.find global.body_id ctx.trans_funs
in
- assert (List.length body_backs = 0);
+ assert (body_backs = []);
+ assert (loop_fwds = []);
let is_opaque = Option.is_none body.Pure.body in
if
@@ -487,7 +489,8 @@ let export_functions_declarations (fmt : Format.formatter) (config : gen_config)
(ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
- A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
+ PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
+ ctx.functions_with_decreases_clause
in
(* Extract the function declarations *)
@@ -532,16 +535,21 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
- A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
+ PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
+ ctx.functions_with_decreases_clause
in
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
- (fun (_, (fwd, _)) ->
- let has_decr_clause = has_decreases_clause fwd in
- if has_decr_clause then
- Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd)
+ (fun (_, ((fwd, loop_fwds), _)) ->
+ let extract_decrease decl =
+ let has_decr_clause = has_decreases_clause decl in
+ if has_decr_clause then
+ Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl
+ in
+ extract_decrease fwd;
+ List.iter extract_decrease loop_fwds)
pure_ls;
(* Concatenate the function definitions, filtering the useless forward
@@ -549,8 +557,15 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
let decls =
List.concat
(List.map
- (fun (keep_fwd, (fwd, back_ls)) ->
- if keep_fwd then fwd :: back_ls else back_ls)
+ (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) ->
+ let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in
+ let back : Pure.fun_decl list =
+ List.concat
+ (List.map
+ (fun (back, loop_backs) -> List.append loop_backs [ back ])
+ back_ls)
+ in
+ List.append fwd back)
pure_ls)
in
@@ -568,7 +583,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
- (fun (keep_fwd, (fwd, _)) ->
+ (fun (keep_fwd, ((fwd, _), _)) ->
if keep_fwd then
Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd)
pure_ls
@@ -721,12 +736,25 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
(* We need to compute which functions are recursive, in order to know
* whether we should generate a decrease clause or not. *)
let rec_functions =
- A.FunDeclId.Set.of_list
- (List.concat
- (List.map
- (fun decl -> match decl with A.Fun (Rec ids) -> ids | _ -> [])
- crate.declarations))
+ List.map
+ (fun (_, ((fwd, loop_fwds), _)) ->
+ let fwd =
+ if fwd.Pure.signature.info.effect_info.is_rec then
+ [ (fwd.def_id, None) ]
+ else []
+ in
+ let loop_fwds =
+ List.map
+ (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
+ loop_fwds
+ in
+ fwd :: loop_fwds)
+ trans_funs
+ in
+ let rec_functions : PureUtils.fun_loop_id list =
+ List.concat (List.concat rec_functions)
in
+ let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in
(* Register unique names for all the top-level types, globals and functions.
* Note that the order in which we generate the names doesn't matter:
@@ -740,18 +768,21 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
let ctx =
List.fold_left
- (fun ctx (keep_fwd, def) ->
+ (fun ctx (keep_fwd, defs) ->
(* We generate a decrease clause for all the recursive functions *)
- let gen_decr_clause =
- A.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions
+ let fwd_def = fst (fst defs) in
+ let gen_decr_clause (def : Pure.fun_decl) =
+ PureUtils.FunLoopIdSet.mem
+ (def.Pure.def_id, def.Pure.loop_id)
+ rec_functions
in
(* Register the names, only if the function is not a global body -
* those are handled later *)
- let is_global = (fst def).Pure.is_global_decl_body in
+ let is_global = fwd_def.Pure.is_global_decl_body in
if is_global then ctx
else
Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause
- def)
+ defs)
ctx trans_funs
in
@@ -785,7 +816,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
A.FunDeclId.Map.of_list
(List.map
(fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
- (fd.def_id, (keep_fwd, (fd, bdl))))
+ ((fst fd).def_id, (keep_fwd, (fd, bdl))))
trans_funs)
in
@@ -883,7 +914,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) :
(* Extract the template clauses *)
let needs_clauses_module =
!Config.extract_decreases_clauses
- && not (A.FunDeclId.Set.is_empty rec_functions)
+ && not (PureUtils.FunLoopIdSet.is_empty rec_functions)
in
(if needs_clauses_module && !Config.extract_template_decreases_clauses then
let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in