From 66638a2a96c7639553a340917b87e26d94265c5e Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sat, 17 Dec 2022 10:27:12 +0100 Subject: Fix various issues with the generation of code for the loops --- compiler/Translate.ml | 85 +++++++++++++++++++++++++++++++++++---------------- 1 file changed, 58 insertions(+), 27 deletions(-) (limited to 'compiler/Translate.ml') diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 32c32ac4..10a37770 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -51,7 +51,7 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl) let translate_function_to_pure (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl) - : pure_fun_translation = + : pure_fun_translation_no_loops = (* Debug *) log#ldebug (lazy @@ -213,7 +213,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) sg.info.num_fwd_inputs_with_fuel_with_state in let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in - Collections.List.subslice sg.inputs num_forward_inputs num_back_inputs + Collections.List.subslice sg.inputs num_forward_inputs + (num_forward_inputs + num_back_inputs) in (* As we forbid nested borrows, the additional inputs for the backward * functions come from the borrows in the return value of the rust function: @@ -336,7 +337,7 @@ type gen_ctx = { extract_ctx : ExtractBase.extraction_ctx; trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; - functions_with_decreases_clause : A.FunDeclId.Set.t; + functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; } type gen_config = { @@ -370,7 +371,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = in let has_opaque_funs = A.FunDeclId.Map.exists - (fun _ ((_, (t_fwd, _)) : bool * pure_fun_translation) -> + (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) -> Option.is_none t_fwd.body) ctx.trans_funs in @@ -452,10 +453,11 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (id : A.GlobalDeclId.id) : unit = let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in let global = A.GlobalDeclId.Map.find id global_decls in - let _, (body, body_backs) = + let _, ((body, loop_fwds), body_backs) = A.FunDeclId.Map.find global.body_id ctx.trans_funs in - assert (List.length body_backs = 0); + assert (body_backs = []); + assert (loop_fwds = []); let is_opaque = Option.is_none body.Pure.body in if @@ -487,7 +489,8 @@ let export_functions_declarations (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = - A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) + ctx.functions_with_decreases_clause in (* Extract the function declarations *) @@ -532,16 +535,21 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = - A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) + ctx.functions_with_decreases_clause in (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter - (fun (_, (fwd, _)) -> - let has_decr_clause = has_decreases_clause fwd in - if has_decr_clause then - Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd) + (fun (_, ((fwd, loop_fwds), _)) -> + let extract_decrease decl = + let has_decr_clause = has_decreases_clause decl in + if has_decr_clause then + Extract.extract_template_decreases_clause ctx.extract_ctx fmt decl + in + extract_decrease fwd; + List.iter extract_decrease loop_fwds) pure_ls; (* Concatenate the function definitions, filtering the useless forward @@ -549,8 +557,15 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let decls = List.concat (List.map - (fun (keep_fwd, (fwd, back_ls)) -> - if keep_fwd then fwd :: back_ls else back_ls) + (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) -> + let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in + let back : Pure.fun_decl list = + List.concat + (List.map + (fun (back, loop_backs) -> List.append loop_backs [ back ]) + back_ls) + in + List.append fwd back) pure_ls) in @@ -568,7 +583,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter - (fun (keep_fwd, (fwd, _)) -> + (fun (keep_fwd, ((fwd, _), _)) -> if keep_fwd then Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd) pure_ls @@ -721,12 +736,25 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : (* We need to compute which functions are recursive, in order to know * whether we should generate a decrease clause or not. *) let rec_functions = - A.FunDeclId.Set.of_list - (List.concat - (List.map - (fun decl -> match decl with A.Fun (Rec ids) -> ids | _ -> []) - crate.declarations)) + List.map + (fun (_, ((fwd, loop_fwds), _)) -> + let fwd = + if fwd.Pure.signature.info.effect_info.is_rec then + [ (fwd.def_id, None) ] + else [] + in + let loop_fwds = + List.map + (fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ]) + loop_fwds + in + fwd :: loop_fwds) + trans_funs + in + let rec_functions : PureUtils.fun_loop_id list = + List.concat (List.concat rec_functions) in + let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in (* Register unique names for all the top-level types, globals and functions. * Note that the order in which we generate the names doesn't matter: @@ -740,18 +768,21 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : let ctx = List.fold_left - (fun ctx (keep_fwd, def) -> + (fun ctx (keep_fwd, defs) -> (* We generate a decrease clause for all the recursive functions *) - let gen_decr_clause = - A.FunDeclId.Set.mem (fst def).Pure.def_id rec_functions + let fwd_def = fst (fst defs) in + let gen_decr_clause (def : Pure.fun_decl) = + PureUtils.FunLoopIdSet.mem + (def.Pure.def_id, def.Pure.loop_id) + rec_functions in (* Register the names, only if the function is not a global body - * those are handled later *) - let is_global = (fst def).Pure.is_global_decl_body in + let is_global = fwd_def.Pure.is_global_decl_body in if is_global then ctx else Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause - def) + defs) ctx trans_funs in @@ -785,7 +816,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : A.FunDeclId.Map.of_list (List.map (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> - (fd.def_id, (keep_fwd, (fd, bdl)))) + ((fst fd).def_id, (keep_fwd, (fd, bdl)))) trans_funs) in @@ -883,7 +914,7 @@ let translate_module (filename : string) (dest_dir : string) (crate : A.crate) : (* Extract the template clauses *) let needs_clauses_module = !Config.extract_decreases_clauses - && not (A.FunDeclId.Set.is_empty rec_functions) + && not (PureUtils.FunLoopIdSet.is_empty rec_functions) in (if needs_clauses_module && !Config.extract_template_decreases_clauses then let ext = match !Config.backend with FStar -> ".fst" | Coq -> ".v" in -- cgit v1.2.3