summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-09-03 19:41:03 +0200
committerSon Ho2023-09-03 19:41:03 +0200
commitdfcbfab4030be2f03b159a4b298ed75ac2f236ae (patch)
tree4300dc6c3eab5680c7afd06441c743c33e3bc0c7
parentcce09bb0fb64b07b07613d7db59857651e040c20 (diff)
Add the keep_fwd field in TranslateCore.pure_fun_translation
-rw-r--r--compiler/ExtractBase.ml2
-rw-r--r--compiler/PureMicroPasses.ml28
-rw-r--r--compiler/Translate.ml34
-rw-r--r--compiler/TranslateCore.ml13
4 files changed, 42 insertions, 35 deletions
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 7a21d42d..885467c2 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -663,7 +663,7 @@ type extraction_ctx = {
(** If we are extracting a trait declaration, identifies it *)
is_provided_method : bool;
trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
- trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t;
+ trans_funs : pure_fun_translation A.FunDeclId.Map.t;
functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t;
trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t;
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index e97a9cd7..6c9c3a91 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1460,8 +1460,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
In such situation, we can remove the forward function definition
altogether.
*)
-let keep_forward (trans : pure_fun_translation) : bool =
- let { fwd; backs } = trans in
+let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
(* Note that at this point, the output types are no longer seen as tuples:
* they should be lists of length 1. *)
if
@@ -1977,8 +1976,8 @@ end
module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
(** Filter the useless loop input parameters. *)
-let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
- (bool * pure_fun_translation) list =
+let filter_loop_inputs (transl : pure_fun_translation list) :
+ pure_fun_translation list =
(* We need to explore groups of mutually recursive functions. In order
to compute which parameters are useless, we need to explore the
functions by groups of mutually recursive definitions.
@@ -1996,7 +1995,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(List.concat
(List.concat
(List.map
- (fun (_, { fwd; backs }) ->
+ (fun { fwd; backs; _ } ->
[ fwd.f :: fwd.loops ]
:: List.map
(fun { f = back; loops = loops_back } ->
@@ -2246,13 +2245,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
in
let transl =
List.map
- (fun (b, { fwd; backs }) ->
+ (fun trans ->
let filter_fun_and_loops f =
{ f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
in
- let fwd = filter_fun_and_loops fwd in
- let backs = List.map filter_fun_and_loops backs in
- (b, { fwd; backs }))
+ let fwd = filter_fun_and_loops trans.fwd in
+ let backs = List.map filter_fun_and_loops trans.backs in
+ { trans with fwd; backs })
transl
in
@@ -2273,18 +2272,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) :
- (bool * pure_fun_translation) list =
- let apply_to_one (trans : fun_decl * fun_decl list) :
- bool * pure_fun_translation =
+ (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
+ let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
(* Apply the passes to the individual functions *)
let fwd, backs = trans in
let fwd = Option.get (apply_passes_to_def ctx fwd) in
let backs = List.filter_map (apply_passes_to_def ctx) backs in
- let trans = { fwd; backs } in
(* Compute whether we need to filter the forward function or not *)
- (keep_forward trans, trans)
+ let keep_fwd = keep_forward fwd backs in
+ { keep_fwd; fwd; backs }
in
+
let transl = List.map apply_to_one transl in
(* Filter the useless inputs in the loop functions *)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 7122e462..835edd46 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -305,7 +305,7 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
let translate_crate_to_pure (crate : A.crate) :
trans_ctx
* Pure.type_decl list
- * (bool * pure_fun_translation) list
+ * pure_fun_translation list
* Pure.trait_decl list
* Pure.trait_impl list =
(* Debug *)
@@ -439,8 +439,7 @@ let module_has_opaque_decls (ctx : gen_ctx) : bool * bool =
in
let has_opaque_funs =
A.FunDeclId.Map.exists
- (fun _ ((_, trans) : bool * pure_fun_translation) ->
- Option.is_none trans.fwd.f.body)
+ (fun _ (trans : pure_fun_translation) -> Option.is_none trans.fwd.f.body)
ctx.trans_funs
in
(has_opaque_types, has_opaque_funs)
@@ -552,7 +551,7 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
(id : A.GlobalDeclId.id) : unit =
let global_decls = ctx.trans_ctx.global_context.global_decls in
let global = A.GlobalDeclId.Map.find id global_decls in
- let _, trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in
+ let trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in
assert (trans.fwd.loops = []);
assert (trans.backs = []);
let body = trans.fwd.f in
@@ -665,7 +664,7 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config)
check if the forward and backward functions are mutually recursive.
*)
let export_functions_group (fmt : Format.formatter) (config : gen_config)
- (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
+ (ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
@@ -675,7 +674,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
- (fun (_, { fwd; _ }) ->
+ (fun { fwd; _ } ->
(* We only generate decreases clauses for the forward functions, because
the termination argument should only depend on the forward inputs.
The backward functions thus use the same decreases clauses as the
@@ -710,7 +709,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
let decls =
List.concat
(List.map
- (fun (keep_fwd, { fwd; backs }) ->
+ (fun { keep_fwd; fwd; backs } ->
let fwd = if keep_fwd then List.append fwd.loops [ fwd.f ] else [] in
let backs : Pure.fun_decl list =
List.concat
@@ -734,8 +733,8 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
- (fun (keep_fwd, trans) ->
- if keep_fwd then
+ (fun trans ->
+ if trans.keep_fwd then
Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f)
pure_ls
@@ -788,7 +787,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
extract their type directly in the records we generate for
the trait declarations themselves, there is no point in having
separate type definitions) *)
- match (snd pure_fun).fwd.f.Pure.kind with
+ match pure_fun.fwd.f.Pure.kind with
| TraitMethodDecl _ -> ()
| _ ->
(* Translate *)
@@ -993,7 +992,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
* whether we should generate a decrease clause or not. *)
let rec_functions =
List.map
- (fun (_, { fwd; _ }) ->
+ (fun { fwd; _ } ->
let fwd_f =
if fwd.f.Pure.signature.info.effect_info.is_rec then
[ (fwd.f.def_id, None) ]
@@ -1017,11 +1016,10 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
Pure.TypeDeclId.Map.of_list
(List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types)
in
- let trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t =
+ let trans_funs : pure_fun_translation A.FunDeclId.Map.t =
A.FunDeclId.Map.of_list
(List.map
- (fun ((keep_fwd, { fwd; backs }) : bool * pure_fun_translation) ->
- (fwd.f.def_id, (keep_fwd, { fwd; backs })))
+ (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans))
trans_funs)
in
@@ -1072,10 +1070,10 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
let ctx =
List.fold_left
- (fun ctx ((keep_fwd, defs) : bool * pure_fun_translation) ->
+ (fun ctx (trans : pure_fun_translation) ->
(* If requested by the user, register termination measures and decreases
proofs for all the recursive functions *)
- let fwd_def = defs.fwd.f in
+ let fwd_def = trans.fwd.f in
let gen_decr_clause (def : Pure.fun_decl) =
!Config.extract_decreases_clauses
&& PureUtils.FunLoopIdSet.mem
@@ -1087,8 +1085,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
let is_global = fwd_def.Pure.is_global_decl_body in
if is_global then ctx
else
- Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause
- defs)
+ Extract.extract_fun_decl_register_names ctx trans.keep_fwd
+ gen_decr_clause trans)
ctx
(A.FunDeclId.Map.values trans_funs)
in
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index 9fd27c59..f31dc458 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -33,7 +33,18 @@ type trans_ctx = {
type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list }
type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
-type pure_fun_translation = { fwd : fun_and_loops; backs : fun_and_loops list }
+
+type pure_fun_translation = {
+ keep_fwd : bool;
+ (** Should we extract the forward function?
+
+ If the forward function returns `()` and there is exactly one
+ backward function, we may merge the forward into the backward
+ function and thus don't extract the forward function)?
+ *)
+ fwd : fun_and_loops;
+ backs : fun_and_loops list;
+}
let trans_ctx_to_type_formatter (ctx : trans_ctx)
(type_params : Pure.type_var list)