summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-12-14 16:48:35 +0100
committerSon HO2023-02-03 11:21:46 +0100
commit54a6b5d1a90b7304817175a33fc37444e559b11e (patch)
tree77a5836aeb6a72b93a9f285771b64e45377d805f
parentcdd5fa0e6d911174413a726029f91713963e9871 (diff)
Compute the SCCs of the functions to extract in Translate.ml
-rw-r--r--compiler/Logging.ml6
-rw-r--r--compiler/Print.ml3
-rw-r--r--compiler/PureUtils.ml49
-rw-r--r--compiler/ReorderDecls.ml86
-rw-r--r--compiler/SCC.ml12
-rw-r--r--compiler/Translate.ml152
6 files changed, 183 insertions, 125 deletions
diff --git a/compiler/Logging.ml b/compiler/Logging.ml
index 8cfc25d3..71f72471 100644
--- a/compiler/Logging.ml
+++ b/compiler/Logging.ml
@@ -47,3 +47,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows"
(** Logger for Invariants *)
let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants"
+
+(** Logger for SCC *)
+let scc_log = L.get_logger "MainLogger.SCC"
+
+(** Logger for ReorderDecls *)
+let reorder_decls_log = L.get_logger "MainLogger.ReorderDecls"
diff --git a/compiler/Print.ml b/compiler/Print.ml
index a4a8b1d4..82c5dac5 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -7,6 +7,9 @@ module PrimitiveValues = Charon.PrintPrimitiveValues
module Types = Charon.PrintTypes
module Expressions = Charon.PrintExpressions
+let list_to_string (to_string : 'a -> string) (ls : 'a list) : string =
+ "[" ^ String.concat "; " (List.map to_string ls) ^ "]"
+
(** Pretty-printing for values *)
module Values = struct
type value_formatter = {
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 0f1d50f1..da15d635 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -110,55 +110,6 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
let info = sg.info in
{ inputs; output; doutputs; info }
-(** Return true if a list of functions are *not* mutually recursive, false otherwise.
- This function is meant to be applied on a set of (forward, backwards) functions
- generated for one recursive function.
- The way we do the test is very simple:
- - we explore the functions one by one, in the order in which they are provided
- - if all functions only call functions we already explored, they are not
- mutually recursive
- *)
-let functions_not_mutually_recursive (funs : fun_decl list) : bool =
- (* Compute the set of function identifiers in the group *)
- let ids =
- FunOrOpIdSet.of_list
- (List.map
- (fun (f : fun_decl) -> Fun (FromLlbc (A.Regular f.def_id, f.back_id)))
- funs)
- in
- let ids = ref ids in
- (* Explore every body *)
- let body_only_calls_itself (fdef : fun_decl) : bool =
- (* Remove the current id from the id set *)
- ids :=
- FunOrOpIdSet.remove
- (Fun (FromLlbc (A.Regular fdef.def_id, fdef.back_id)))
- !ids;
-
- (* Check if we call functions from the updated id set *)
- let obj =
- object
- inherit [_] iter_expression as super
-
- method! visit_qualif env qualif =
- match qualif.id with
- | FunOrOp fun_id ->
- if FunOrOpIdSet.mem fun_id !ids then raise Utils.Found
- else super#visit_qualif env qualif
- | _ -> super#visit_qualif env qualif
- end
- in
-
- try
- match fdef.body with
- | None -> true
- | Some body ->
- obj#visit_texpression () body.body;
- true
- with Utils.Found -> false
- in
- List.for_all body_only_calls_itself funs
-
(** We use this to check whether we need to add parentheses around expressions.
We only look for outer monadic let-bindings.
This is used when printing the branches of [if ... then ... else ...].
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index 9d222011..2b78c570 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -3,6 +3,9 @@ open Collections
open SCC
open Pure
+(** The local logger *)
+let log = Logging.reorder_decls_log
+
type fun_id = { def_id : FunDeclId.id; rg_id : T.RegionGroupId.id option }
[@@deriving show, ord]
@@ -69,6 +72,11 @@ let group_reorder_fun_decls (decls : fun_decl list) :
let idl = List.map get_fun_id decls in
let ids = FunIdSet.of_list idl in
+ log#ldebug
+ (lazy
+ ("group_reorder_fun_decls: ids:\n"
+ ^ (Print.list_to_string FunIdOrderedType.show_t) idl));
+
(* Explore the bodies to compute the dependencies - we ignore the ids
which refer to declarations not in the group we want to reorder *)
let deps : (fun_id * FunIdSet.t) list =
@@ -79,50 +87,88 @@ let group_reorder_fun_decls (decls : fun_decl list) :
| None -> (id, FunIdSet.empty)
| Some body ->
let deps = compute_body_fun_deps body.body in
- (id, FunIdSet.inter deps ids))
+ (* Restrict the set dependencies *)
+ let deps = FunIdSet.inter deps ids in
+ (id, deps))
decls
in
(*
* Create the dependency graph
*)
- (* Convert the ids to vertices (i.e., injectively map ids to integers) *)
- let id_to_vertex : int FunIdMap.t =
+ (* Convert the ids to vertices (i.e., injectively map ids to integers, and create
+ vertices labeled with those integers).
+
+ Rem.: [Graph.create] is *imperative*: it generates a new vertex every time
+ it is called (!!).
+ *)
+ let module Graph = Pack.Digraph in
+ let id_to_vertex : Graph.V.t FunIdMap.t =
let cnt = ref 0 in
FunIdMap.of_list
(List.map
(fun id ->
- let v = !cnt in
+ let lbl = !cnt in
cnt := !cnt + 1;
+ (* We create a vertex *)
+ let v = Graph.V.create lbl in
(id, v))
idl)
in
let vertex_to_id : fun_id IntMap.t =
IntMap.of_list
- (List.map (fun (fid, vid) -> (vid, fid)) (FunIdMap.bindings id_to_vertex))
+ (List.map
+ (fun (fid, v) -> (Graph.V.label v, fid))
+ (FunIdMap.bindings id_to_vertex))
in
- let to_v id = Pack.Graph.V.create (FunIdMap.find id id_to_vertex) in
- let to_id v = IntMap.find (Pack.Graph.V.label v) vertex_to_id in
- let g = Pack.Graph.create () in
- (* First add the vertices *)
- List.iter (fun id -> Pack.Graph.add_vertex g (to_v id)) idl;
+ let to_v id = FunIdMap.find id id_to_vertex in
+ let to_id v = IntMap.find (Graph.V.label v) vertex_to_id in
- (* Then add the edges *)
+ let g = Graph.create () in
+
+ (* Add the edges, first from the vertices to themselves, then between vertices *)
List.iter
(fun (fun_id, deps) ->
- FunIdSet.iter
- (fun dep_id -> Pack.Graph.add_edge g (to_v fun_id) (to_v dep_id))
- deps)
+ let v = to_v fun_id in
+ Graph.add_edge g v v;
+ FunIdSet.iter (fun dep_id -> Graph.add_edge g v (to_v dep_id)) deps)
deps;
(* Compute the SCCs *)
- let module Comp = Components.Make (Pack.Graph) in
+ let module Comp = Components.Make (Graph) in
let sccs = Comp.scc_list g in
(* Convert the vertices to ids *)
let sccs = List.map (List.map to_id) sccs in
+ log#ldebug
+ (lazy
+ ("group_reorder_fun_decls: SCCs:\n"
+ ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs
+ ));
+
+ (* Sanity check *)
+ let _ =
+ (* Check that the SCCs are pairwise disjoint *)
+ assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs));
+ (* Check that all the ids are in the sccs *)
+ let scc_ids = FunIdSet.of_list (List.concat sccs) in
+
+ log#ldebug
+ (lazy
+ ("group_reorder_fun_decls: sanity check:" ^ "\n- ids : "
+ ^ FunIdSet.show ids ^ "\n- scc_ids: " ^ FunIdSet.show scc_ids));
+
+ assert (FunIdSet.equal scc_ids ids)
+ in
+
+ log#ldebug
+ (lazy
+ ("group_reorder_fun_decls: reordered SCCs:\n"
+ ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs
+ ));
+
(* Reorder *)
let module Reorder = SCC.Make (FunIdOrderedType) in
let id_deps =
@@ -131,6 +177,16 @@ let group_reorder_fun_decls (decls : fun_decl list) :
in
let sccs = Reorder.reorder_sccs id_deps idl sccs in
+ (* Sanity check *)
+ let _ =
+ (* Check that the SCCs are pairwise disjoint *)
+ let sccs = List.map snd (SccId.Map.bindings sccs.sccs) in
+ assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs));
+ (* Check that all the ids are in the sccs *)
+ let scc_ids = FunIdSet.of_list (List.concat sccs) in
+ assert (scc_ids = ids)
+ in
+
(* Group the declarations *)
let deps = FunIdMap.of_list deps in
let decls = FunIdMap.of_list (List.map (fun d -> (get_fun_id d, d)) decls) in
diff --git a/compiler/SCC.ml b/compiler/SCC.ml
index 8b4cdb1f..889a972b 100644
--- a/compiler/SCC.ml
+++ b/compiler/SCC.ml
@@ -3,6 +3,9 @@
open Collections
module SccId = Identifiers.IdGen ()
+(** The local logger *)
+let log = Logging.scc_log
+
(** A functor which provides functions to work on strongly connected components *)
module Make (Id : OrderedType) = struct
module IdMap = MakeMap (Id)
@@ -91,6 +94,15 @@ module Make (Id : OrderedType) = struct
*)
let reorder_sccs (id_deps : Id.t list IdMap.t) (ids : Id.t list)
(sccs : Id.t list list) : sccs =
+ log#ldebug
+ (lazy
+ ("reorder_sccs:" ^ "\n- id_deps: "
+ ^ IdMap.show (Print.list_to_string Id.show_t) id_deps
+ ^ "\n- ids: "
+ ^ Print.list_to_string Id.show_t ids
+ ^ "\n- sccs: "
+ ^ Print.list_to_string (Print.list_to_string Id.show_t) sccs));
+
(* Map the identifiers to the SCC indices *)
let id_to_scc =
IdMap.of_list
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 4f31b738..87862d6b 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -383,8 +383,8 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
[is_rec]: [true] if the types are recursive. Necessarily [true] if there is
> 1 type to extract.
*)
-let export_types (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
- (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit =
+let export_types_group (fmt : Format.formatter) (config : gen_config)
+ (ctx : gen_ctx) (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit =
let export_type = export_type fmt config ctx in
let export_type_decl kind id = export_type kind id true false in
let export_type_extra_info kind id = export_type kind id false true in
@@ -432,30 +432,79 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
then
Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface
+(** Utility.
+
+ Export a group of functions. See [export_functions_group].
+
+ We need this because for every function in Rust we may generate several functions
+ in the translation (a forward function, several backward functions, loop
+ functions, etc.). Those functions might call each other in different
+ ways (in particular, they may be mutually recursive, in which case we might
+ be able to group them into several groups of mutually recursive definitions,
+ etc.). For this reason, [export_functions_group] computes the dependency
+ graph of the functions as well as their strongly connected components, and
+ gives each SCC at a time to [export_functions].
+
+ Rem.: this function only extracts the function *declarations*. It doesn't
+ extract the decrease clauses, nor does it extract the unit tests.
+
+ Rem.: this function doesn't check [config.extract_fun_decls]: it should have
+ been checked by the caller.
+ *)
+let export_functions_declarations (fmt : Format.formatter) (config : gen_config)
+ (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+ (* Utility to check a function has a decrease clause *)
+ let has_decreases_clause (def : Pure.fun_decl) : bool =
+ A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
+ in
+
+ (* Extract the function declarations *)
+ (* Check if the functions are mutually recursive - this really works
+ * to check if the forward and backward translations of a single
+ * recursive function are mutually recursive *)
+ let is_mut_rec = List.length decls > 1 in
+ assert ((not is_mut_rec) || is_rec);
+ let decls_length = List.length decls in
+ List.iteri
+ (fun i def ->
+ let is_opaque = Option.is_none def.Pure.body in
+ let kind =
+ if is_opaque then
+ if config.interface then ExtractBase.Declared else ExtractBase.Assumed
+ else if not is_rec then ExtractBase.SingleNonRec
+ else if is_mut_rec then
+ (* If the functions are mutually recursive, we need to distinguish:
+ * - the first of the group
+ * - the last of the group
+ * - the inner functions
+ *)
+ if i = 0 then ExtractBase.MutRecFirst
+ else if i = decls_length - 1 then ExtractBase.MutRecLast
+ else ExtractBase.MutRecInner
+ else ExtractBase.SingleRec
+ in
+ let has_decr_clause =
+ has_decreases_clause def && config.extract_decreases_clauses
+ in
+ (* Check if the definition needs to be filtered or not *)
+ if
+ ((not is_opaque) && config.extract_transparent)
+ || (is_opaque && config.extract_opaque)
+ then Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def)
+ decls
+
(** Export a group of function declarations.
In case of (non-mutually) recursive functions, we use a simple procedure to
check if the forward and backward functions are mutually recursive.
*)
-let export_functions (fmt : Format.formatter) (config : gen_config)
- (ctx : gen_ctx) (is_rec : bool)
- (pure_ls : (bool * pure_fun_translation) list) : unit =
+let export_functions_group (fmt : Format.formatter) (config : gen_config)
+ (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
in
- (* Concatenate the function definitions, filtering the useless forward
- * functions. We also make pairs: (forward function, backward function)
- * (the forward function contains useful information that we want to keep) *)
- let fls =
- List.concat
- (List.map
- (fun (keep_fwd, (fwd, back_ls)) ->
- let back_ls = List.map (fun back -> (fwd, back)) back_ls in
- if keep_fwd then (fwd, fwd) :: back_ls else back_ls)
- pure_ls)
- in
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
@@ -464,48 +513,28 @@ let export_functions (fmt : Format.formatter) (config : gen_config)
if has_decr_clause then
Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd)
pure_ls;
+
+ (* Concatenate the function definitions, filtering the useless forward
+ * functions. *)
+ let decls =
+ List.concat
+ (List.map
+ (fun (keep_fwd, (fwd, back_ls)) ->
+ if keep_fwd then fwd :: back_ls else back_ls)
+ pure_ls)
+ in
+
(* Extract the function definitions *)
(if config.extract_fun_decls then
- (* Check if the functions are mutually recursive - this really works
- * to check if the forward and backward translations of a single
- * recursive function are mutually recursive *)
- let is_mut_rec =
- if is_rec then
- if List.length pure_ls <= 1 then
- not (PureUtils.functions_not_mutually_recursive (List.map fst fls))
- else true
- else false
+ (* Group the mutually recursive definitions *)
+ let subgroups = ReorderDecls.group_reorder_fun_decls decls in
+
+ (* Extract the subgroups *)
+ let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+ export_functions_declarations fmt config ctx is_rec decls
in
- let fls_length = List.length fls in
- List.iteri
- (fun i (fwd_def, def) ->
- let is_opaque = Option.is_none fwd_def.Pure.body in
- let kind =
- if is_opaque then
- if config.interface then ExtractBase.Declared
- else ExtractBase.Assumed
- else if not is_rec then ExtractBase.SingleNonRec
- else if is_mut_rec then
- (* If the functions are mutually recursive, we need to distinguish:
- * - the first of the group
- * - the last of the group
- * - the inner functions
- *)
- if i = 0 then ExtractBase.MutRecFirst
- else if i = fls_length - 1 then ExtractBase.MutRecLast
- else ExtractBase.MutRecInner
- else ExtractBase.SingleRec
- in
- let has_decr_clause =
- has_decreases_clause def && config.extract_decreases_clauses
- in
- (* Check if the definition needs to be filtered or not *)
- if
- ((not is_opaque) && config.extract_transparent)
- || (is_opaque && config.extract_opaque)
- then
- Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def)
- fls);
+ List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups);
+
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
@@ -525,9 +554,9 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
- [extract_extra_info]: extra the extra type information (e.g.,
the [Arguments] information in Coq).
*)
- let export_functions = export_functions fmt config ctx in
+ let export_functions_group = export_functions_group fmt config ctx in
let export_global = export_global fmt config ctx in
- let export_types = export_types fmt config ctx in
+ let export_types_group = export_types_group fmt config ctx in
let export_state_type () : unit =
let kind =
@@ -538,13 +567,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
let export_decl_group (dg : A.declaration_group) : unit =
match dg with
- | Type (NonRec id) -> if config.extract_types then export_types false [ id ]
- | Type (Rec ids) -> if config.extract_types then export_types true ids
+ | Type (NonRec id) ->
+ if config.extract_types then export_types_group false [ id ]
+ | Type (Rec ids) -> if config.extract_types then export_types_group true ids
| Fun (NonRec id) ->
(* Lookup *)
let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in
(* Translate *)
- export_functions false [ pure_fun ]
+ export_functions_group [ pure_fun ]
| Fun (Rec ids) ->
(* General case of mutually recursive functions *)
(* Lookup *)
@@ -552,7 +582,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids
in
(* Translate *)
- export_functions true pure_funs
+ export_functions_group pure_funs
| Global id -> export_global id
in