diff options
Diffstat (limited to 'compiler/Translate.ml')
-rw-r--r-- | compiler/Translate.ml | 152 |
1 files changed, 91 insertions, 61 deletions
diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 4f31b738..87862d6b 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -383,8 +383,8 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) [is_rec]: [true] if the types are recursive. Necessarily [true] if there is > 1 type to extract. *) -let export_types (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) - (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit = +let export_types_group (fmt : Format.formatter) (config : gen_config) + (ctx : gen_ctx) (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit = let export_type = export_type fmt config ctx in let export_type_decl kind id = export_type kind id true false in let export_type_extra_info kind id = export_type kind id false true in @@ -432,30 +432,79 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx) then Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface +(** Utility. + + Export a group of functions. See [export_functions_group]. + + We need this because for every function in Rust we may generate several functions + in the translation (a forward function, several backward functions, loop + functions, etc.). Those functions might call each other in different + ways (in particular, they may be mutually recursive, in which case we might + be able to group them into several groups of mutually recursive definitions, + etc.). For this reason, [export_functions_group] computes the dependency + graph of the functions as well as their strongly connected components, and + gives each SCC at a time to [export_functions]. + + Rem.: this function only extracts the function *declarations*. It doesn't + extract the decrease clauses, nor does it extract the unit tests. + + Rem.: this function doesn't check [config.extract_fun_decls]: it should have + been checked by the caller. + *) +let export_functions_declarations (fmt : Format.formatter) (config : gen_config) + (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit = + (* Utility to check a function has a decrease clause *) + let has_decreases_clause (def : Pure.fun_decl) : bool = + A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause + in + + (* Extract the function declarations *) + (* Check if the functions are mutually recursive - this really works + * to check if the forward and backward translations of a single + * recursive function are mutually recursive *) + let is_mut_rec = List.length decls > 1 in + assert ((not is_mut_rec) || is_rec); + let decls_length = List.length decls in + List.iteri + (fun i def -> + let is_opaque = Option.is_none def.Pure.body in + let kind = + if is_opaque then + if config.interface then ExtractBase.Declared else ExtractBase.Assumed + else if not is_rec then ExtractBase.SingleNonRec + else if is_mut_rec then + (* If the functions are mutually recursive, we need to distinguish: + * - the first of the group + * - the last of the group + * - the inner functions + *) + if i = 0 then ExtractBase.MutRecFirst + else if i = decls_length - 1 then ExtractBase.MutRecLast + else ExtractBase.MutRecInner + else ExtractBase.SingleRec + in + let has_decr_clause = + has_decreases_clause def && config.extract_decreases_clauses + in + (* Check if the definition needs to be filtered or not *) + if + ((not is_opaque) && config.extract_transparent) + || (is_opaque && config.extract_opaque) + then Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def) + decls + (** Export a group of function declarations. In case of (non-mutually) recursive functions, we use a simple procedure to check if the forward and backward functions are mutually recursive. *) -let export_functions (fmt : Format.formatter) (config : gen_config) - (ctx : gen_ctx) (is_rec : bool) - (pure_ls : (bool * pure_fun_translation) list) : unit = +let export_functions_group (fmt : Format.formatter) (config : gen_config) + (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit = (* Utility to check a function has a decrease clause *) let has_decreases_clause (def : Pure.fun_decl) : bool = A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause in - (* Concatenate the function definitions, filtering the useless forward - * functions. We also make pairs: (forward function, backward function) - * (the forward function contains useful information that we want to keep) *) - let fls = - List.concat - (List.map - (fun (keep_fwd, (fwd, back_ls)) -> - let back_ls = List.map (fun back -> (fwd, back)) back_ls in - if keep_fwd then (fwd, fwd) :: back_ls else back_ls) - pure_ls) - in (* Extract the decrease clauses template bodies *) if config.extract_template_decreases_clauses then List.iter @@ -464,48 +513,28 @@ let export_functions (fmt : Format.formatter) (config : gen_config) if has_decr_clause then Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd) pure_ls; + + (* Concatenate the function definitions, filtering the useless forward + * functions. *) + let decls = + List.concat + (List.map + (fun (keep_fwd, (fwd, back_ls)) -> + if keep_fwd then fwd :: back_ls else back_ls) + pure_ls) + in + (* Extract the function definitions *) (if config.extract_fun_decls then - (* Check if the functions are mutually recursive - this really works - * to check if the forward and backward translations of a single - * recursive function are mutually recursive *) - let is_mut_rec = - if is_rec then - if List.length pure_ls <= 1 then - not (PureUtils.functions_not_mutually_recursive (List.map fst fls)) - else true - else false + (* Group the mutually recursive definitions *) + let subgroups = ReorderDecls.group_reorder_fun_decls decls in + + (* Extract the subgroups *) + let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit = + export_functions_declarations fmt config ctx is_rec decls in - let fls_length = List.length fls in - List.iteri - (fun i (fwd_def, def) -> - let is_opaque = Option.is_none fwd_def.Pure.body in - let kind = - if is_opaque then - if config.interface then ExtractBase.Declared - else ExtractBase.Assumed - else if not is_rec then ExtractBase.SingleNonRec - else if is_mut_rec then - (* If the functions are mutually recursive, we need to distinguish: - * - the first of the group - * - the last of the group - * - the inner functions - *) - if i = 0 then ExtractBase.MutRecFirst - else if i = fls_length - 1 then ExtractBase.MutRecLast - else ExtractBase.MutRecInner - else ExtractBase.SingleRec - in - let has_decr_clause = - has_decreases_clause def && config.extract_decreases_clauses - in - (* Check if the definition needs to be filtered or not *) - if - ((not is_opaque) && config.extract_transparent) - || (is_opaque && config.extract_opaque) - then - Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def) - fls); + List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups); + (* Insert unit tests if necessary *) if config.test_trans_unit_functions then List.iter @@ -525,9 +554,9 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) - [extract_extra_info]: extra the extra type information (e.g., the [Arguments] information in Coq). *) - let export_functions = export_functions fmt config ctx in + let export_functions_group = export_functions_group fmt config ctx in let export_global = export_global fmt config ctx in - let export_types = export_types fmt config ctx in + let export_types_group = export_types_group fmt config ctx in let export_state_type () : unit = let kind = @@ -538,13 +567,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) let export_decl_group (dg : A.declaration_group) : unit = match dg with - | Type (NonRec id) -> if config.extract_types then export_types false [ id ] - | Type (Rec ids) -> if config.extract_types then export_types true ids + | Type (NonRec id) -> + if config.extract_types then export_types_group false [ id ] + | Type (Rec ids) -> if config.extract_types then export_types_group true ids | Fun (NonRec id) -> (* Lookup *) let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in (* Translate *) - export_functions false [ pure_fun ] + export_functions_group [ pure_fun ] | Fun (Rec ids) -> (* General case of mutually recursive functions *) (* Lookup *) @@ -552,7 +582,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config) List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids in (* Translate *) - export_functions true pure_funs + export_functions_group pure_funs | Global id -> export_global id in |