summaryrefslogtreecommitdiff
path: root/compiler/Translate.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/Translate.ml')
-rw-r--r--compiler/Translate.ml152
1 files changed, 91 insertions, 61 deletions
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 4f31b738..87862d6b 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -383,8 +383,8 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
[is_rec]: [true] if the types are recursive. Necessarily [true] if there is
> 1 type to extract.
*)
-let export_types (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
- (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit =
+let export_types_group (fmt : Format.formatter) (config : gen_config)
+ (ctx : gen_ctx) (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit =
let export_type = export_type fmt config ctx in
let export_type_decl kind id = export_type kind id true false in
let export_type_extra_info kind id = export_type kind id false true in
@@ -432,30 +432,79 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
then
Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface
+(** Utility.
+
+ Export a group of functions. See [export_functions_group].
+
+ We need this because for every function in Rust we may generate several functions
+ in the translation (a forward function, several backward functions, loop
+ functions, etc.). Those functions might call each other in different
+ ways (in particular, they may be mutually recursive, in which case we might
+ be able to group them into several groups of mutually recursive definitions,
+ etc.). For this reason, [export_functions_group] computes the dependency
+ graph of the functions as well as their strongly connected components, and
+ gives each SCC at a time to [export_functions].
+
+ Rem.: this function only extracts the function *declarations*. It doesn't
+ extract the decrease clauses, nor does it extract the unit tests.
+
+ Rem.: this function doesn't check [config.extract_fun_decls]: it should have
+ been checked by the caller.
+ *)
+let export_functions_declarations (fmt : Format.formatter) (config : gen_config)
+ (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+ (* Utility to check a function has a decrease clause *)
+ let has_decreases_clause (def : Pure.fun_decl) : bool =
+ A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
+ in
+
+ (* Extract the function declarations *)
+ (* Check if the functions are mutually recursive - this really works
+ * to check if the forward and backward translations of a single
+ * recursive function are mutually recursive *)
+ let is_mut_rec = List.length decls > 1 in
+ assert ((not is_mut_rec) || is_rec);
+ let decls_length = List.length decls in
+ List.iteri
+ (fun i def ->
+ let is_opaque = Option.is_none def.Pure.body in
+ let kind =
+ if is_opaque then
+ if config.interface then ExtractBase.Declared else ExtractBase.Assumed
+ else if not is_rec then ExtractBase.SingleNonRec
+ else if is_mut_rec then
+ (* If the functions are mutually recursive, we need to distinguish:
+ * - the first of the group
+ * - the last of the group
+ * - the inner functions
+ *)
+ if i = 0 then ExtractBase.MutRecFirst
+ else if i = decls_length - 1 then ExtractBase.MutRecLast
+ else ExtractBase.MutRecInner
+ else ExtractBase.SingleRec
+ in
+ let has_decr_clause =
+ has_decreases_clause def && config.extract_decreases_clauses
+ in
+ (* Check if the definition needs to be filtered or not *)
+ if
+ ((not is_opaque) && config.extract_transparent)
+ || (is_opaque && config.extract_opaque)
+ then Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def)
+ decls
+
(** Export a group of function declarations.
In case of (non-mutually) recursive functions, we use a simple procedure to
check if the forward and backward functions are mutually recursive.
*)
-let export_functions (fmt : Format.formatter) (config : gen_config)
- (ctx : gen_ctx) (is_rec : bool)
- (pure_ls : (bool * pure_fun_translation) list) : unit =
+let export_functions_group (fmt : Format.formatter) (config : gen_config)
+ (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
in
- (* Concatenate the function definitions, filtering the useless forward
- * functions. We also make pairs: (forward function, backward function)
- * (the forward function contains useful information that we want to keep) *)
- let fls =
- List.concat
- (List.map
- (fun (keep_fwd, (fwd, back_ls)) ->
- let back_ls = List.map (fun back -> (fwd, back)) back_ls in
- if keep_fwd then (fwd, fwd) :: back_ls else back_ls)
- pure_ls)
- in
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
@@ -464,48 +513,28 @@ let export_functions (fmt : Format.formatter) (config : gen_config)
if has_decr_clause then
Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd)
pure_ls;
+
+ (* Concatenate the function definitions, filtering the useless forward
+ * functions. *)
+ let decls =
+ List.concat
+ (List.map
+ (fun (keep_fwd, (fwd, back_ls)) ->
+ if keep_fwd then fwd :: back_ls else back_ls)
+ pure_ls)
+ in
+
(* Extract the function definitions *)
(if config.extract_fun_decls then
- (* Check if the functions are mutually recursive - this really works
- * to check if the forward and backward translations of a single
- * recursive function are mutually recursive *)
- let is_mut_rec =
- if is_rec then
- if List.length pure_ls <= 1 then
- not (PureUtils.functions_not_mutually_recursive (List.map fst fls))
- else true
- else false
+ (* Group the mutually recursive definitions *)
+ let subgroups = ReorderDecls.group_reorder_fun_decls decls in
+
+ (* Extract the subgroups *)
+ let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+ export_functions_declarations fmt config ctx is_rec decls
in
- let fls_length = List.length fls in
- List.iteri
- (fun i (fwd_def, def) ->
- let is_opaque = Option.is_none fwd_def.Pure.body in
- let kind =
- if is_opaque then
- if config.interface then ExtractBase.Declared
- else ExtractBase.Assumed
- else if not is_rec then ExtractBase.SingleNonRec
- else if is_mut_rec then
- (* If the functions are mutually recursive, we need to distinguish:
- * - the first of the group
- * - the last of the group
- * - the inner functions
- *)
- if i = 0 then ExtractBase.MutRecFirst
- else if i = fls_length - 1 then ExtractBase.MutRecLast
- else ExtractBase.MutRecInner
- else ExtractBase.SingleRec
- in
- let has_decr_clause =
- has_decreases_clause def && config.extract_decreases_clauses
- in
- (* Check if the definition needs to be filtered or not *)
- if
- ((not is_opaque) && config.extract_transparent)
- || (is_opaque && config.extract_opaque)
- then
- Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def)
- fls);
+ List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups);
+
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
@@ -525,9 +554,9 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
- [extract_extra_info]: extra the extra type information (e.g.,
the [Arguments] information in Coq).
*)
- let export_functions = export_functions fmt config ctx in
+ let export_functions_group = export_functions_group fmt config ctx in
let export_global = export_global fmt config ctx in
- let export_types = export_types fmt config ctx in
+ let export_types_group = export_types_group fmt config ctx in
let export_state_type () : unit =
let kind =
@@ -538,13 +567,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
let export_decl_group (dg : A.declaration_group) : unit =
match dg with
- | Type (NonRec id) -> if config.extract_types then export_types false [ id ]
- | Type (Rec ids) -> if config.extract_types then export_types true ids
+ | Type (NonRec id) ->
+ if config.extract_types then export_types_group false [ id ]
+ | Type (Rec ids) -> if config.extract_types then export_types_group true ids
| Fun (NonRec id) ->
(* Lookup *)
let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in
(* Translate *)
- export_functions false [ pure_fun ]
+ export_functions_group [ pure_fun ]
| Fun (Rec ids) ->
(* General case of mutually recursive functions *)
(* Lookup *)
@@ -552,7 +582,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids
in
(* Translate *)
- export_functions true pure_funs
+ export_functions_group pure_funs
| Global id -> export_global id
in