From dfcbfab4030be2f03b159a4b298ed75ac2f236ae Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sun, 3 Sep 2023 19:41:03 +0200 Subject: Add the keep_fwd field in TranslateCore.pure_fun_translation --- compiler/ExtractBase.ml | 2 +- compiler/PureMicroPasses.ml | 28 +++++++++++++--------------- compiler/Translate.ml | 34 ++++++++++++++++------------------ compiler/TranslateCore.ml | 13 ++++++++++++- 4 files changed, 42 insertions(+), 35 deletions(-) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 7a21d42d..885467c2 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -663,7 +663,7 @@ type extraction_ctx = { (** If we are extracting a trait declaration, identifies it *) is_provided_method : bool; trans_types : Pure.type_decl Pure.TypeDeclId.Map.t; - trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t; + trans_funs : pure_fun_translation A.FunDeclId.Map.t; functions_with_decreases_clause : PureUtils.FunLoopIdSet.t; trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t; trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t; diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index e97a9cd7..6c9c3a91 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1460,8 +1460,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = In such situation, we can remove the forward function definition altogether. *) -let keep_forward (trans : pure_fun_translation) : bool = - let { fwd; backs } = trans in +let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool = (* Note that at this point, the output types are no longer seen as tuples: * they should be lists of length 1. *) if @@ -1977,8 +1976,8 @@ end module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType) (** Filter the useless loop input parameters. *) -let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : - (bool * pure_fun_translation) list = +let filter_loop_inputs (transl : pure_fun_translation list) : + pure_fun_translation list = (* We need to explore groups of mutually recursive functions. In order to compute which parameters are useless, we need to explore the functions by groups of mutually recursive definitions. @@ -1996,7 +1995,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : (List.concat (List.concat (List.map - (fun (_, { fwd; backs }) -> + (fun { fwd; backs; _ } -> [ fwd.f :: fwd.loops ] :: List.map (fun { f = back; loops = loops_back } -> @@ -2246,13 +2245,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : in let transl = List.map - (fun (b, { fwd; backs }) -> + (fun trans -> let filter_fun_and_loops f = { f = filter_in_one f.f; loops = List.map filter_in_one f.loops } in - let fwd = filter_fun_and_loops fwd in - let backs = List.map filter_fun_and_loops backs in - (b, { fwd; backs })) + let fwd = filter_fun_and_loops trans.fwd in + let backs = List.map filter_fun_and_loops trans.backs in + { trans with fwd; backs }) transl in @@ -2273,18 +2272,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) : but convenient. *) let apply_passes_to_pure_fun_translations (ctx : trans_ctx) - (transl : (fun_decl * fun_decl list) list) : - (bool * pure_fun_translation) list = - let apply_to_one (trans : fun_decl * fun_decl list) : - bool * pure_fun_translation = + (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list = + let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation = (* Apply the passes to the individual functions *) let fwd, backs = trans in let fwd = Option.get (apply_passes_to_def ctx fwd) in let backs = List.filter_map (apply_passes_to_def ctx) backs in - let trans = { fwd; backs } in (* Compute whether we need to filter the forward function or not *) - (keep_forward trans, trans) + let keep_fwd = keep_forward fwd backs in + { keep_fwd; fwd; backs } in + let transl = List.map apply_to_one transl in (* Filter the useless inputs in the loop functions *) diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 7122e462..835edd46 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -305,7 +305,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let translate_crate_to_pure (crate : A.crate) : trans_ctx * Pure.type_decl list - * (bool * pure_fun_translation) list + * pure_fun_translation list * Pure.trait_decl list * Pure.trait_impl list = (* Debug *) @@ -439,8 +439,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool = in let has_opaque_funs = A.FunDeclId.Map.exists - (fun _ ((_, trans) : bool * pure_fun_translation) -> - Option.is_none trans.fwd.f.body) + (fun _ (trans : pure_fun_translation) -> Option.is_none trans.fwd.f.body) ctx.trans_funs in (has_opaque_types, has_opaque_funs) @@ -552,7 +551,7 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) (id : A.GlobalDeclId.id) : unit = let global_decls = ctx.trans_ctx.global_context.global_decls in let global = A.GlobalDeclId.Map.find id global_decls in - let _, trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in + let trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in assert (trans.fwd.loops = []); assert (trans.backs = []); let body = trans.fwd.f in @@ -665,7 +664,7 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config) check if the forward and backward functions are mutually recursive. *) let export_functions_group (fmt : Format.formatter) (config : gen_config) - (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit = + (ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id) @@ -675,7 +674,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter - (fun (_, { fwd; _ }) -> + (fun { fwd; _ } -> (* We only generate decreases clauses for the forward functions, because the termination argument should only depend on the forward inputs. The backward functions thus use the same decreases clauses as the @@ -710,7 +709,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) let decls = List.concat (List.map - (fun (keep_fwd, { fwd; backs }) -> + (fun { keep_fwd; fwd; backs } -> let fwd = if keep_fwd then List.append fwd.loops [ fwd.f ] else [] in let backs : Pure.fun_decl list = List.concat @@ -734,8 +733,8 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config) (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter - (fun (keep_fwd, trans) -> - if keep_fwd then + (fun trans -> + if trans.keep_fwd then Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f) pure_ls @@ -788,7 +787,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) extract their type directly in the records we generate for the trait declarations themselves, there is no point in having separate type definitions) *) - match (snd pure_fun).fwd.f.Pure.kind with + match pure_fun.fwd.f.Pure.kind with | TraitMethodDecl _ -> () | _ -> (* Translate *) @@ -993,7 +992,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : * whether we should generate a decrease clause or not. *) let rec_functions = List.map - (fun (_, { fwd; _ }) -> + (fun { fwd; _ } -> let fwd_f = if fwd.f.Pure.signature.info.effect_info.is_rec then [ (fwd.f.def_id, None) ] @@ -1017,11 +1016,10 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : Pure.TypeDeclId.Map.of_list (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types) in - let trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t = + let trans_funs : pure_fun_translation A.FunDeclId.Map.t = A.FunDeclId.Map.of_list (List.map - (fun ((keep_fwd, { fwd; backs }) : bool * pure_fun_translation) -> - (fwd.f.def_id, (keep_fwd, { fwd; backs }))) + (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans)) trans_funs) in @@ -1072,10 +1070,10 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : let ctx = List.fold_left - (fun ctx ((keep_fwd, defs) : bool * pure_fun_translation) -> + (fun ctx (trans : pure_fun_translation) -> (* If requested by the user, register termination measures and decreases proofs for all the recursive functions *) - let fwd_def = defs.fwd.f in + let fwd_def = trans.fwd.f in let gen_decr_clause (def : Pure.fun_decl) = !Config.extract_decreases_clauses && PureUtils.FunLoopIdSet.mem @@ -1087,8 +1085,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : let is_global = fwd_def.Pure.is_global_decl_body in if is_global then ctx else - Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause - defs) + Extract.extract_fun_decl_register_names ctx trans.keep_fwd + gen_decr_clause trans) ctx (A.FunDeclId.Map.values trans_funs) in diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml index 9fd27c59..f31dc458 100644 --- a/compiler/TranslateCore.ml +++ b/compiler/TranslateCore.ml @@ -33,7 +33,18 @@ type trans_ctx = { type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list } type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list -type pure_fun_translation = { fwd : fun_and_loops; backs : fun_and_loops list } + +type pure_fun_translation = { + keep_fwd : bool; + (** Should we extract the forward function? + + If the forward function returns `()` and there is exactly one + backward function, we may merge the forward into the backward + function and thus don't extract the forward function)? + *) + fwd : fun_and_loops; + backs : fun_and_loops list; +} let trans_ctx_to_type_formatter (ctx : trans_ctx) (type_params : Pure.type_var list) -- cgit v1.2.3