From 54a6b5d1a90b7304817175a33fc37444e559b11e Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Wed, 14 Dec 2022 16:48:35 +0100
Subject: Compute the SCCs of the functions to extract in Translate.ml

---
 compiler/Logging.ml      |   6 ++
 compiler/Print.ml        |   3 +
 compiler/PureUtils.ml    |  49 ---------------
 compiler/ReorderDecls.ml |  86 ++++++++++++++++++++++-----
 compiler/SCC.ml          |  12 ++++
 compiler/Translate.ml    | 152 ++++++++++++++++++++++++++++-------------------
 6 files changed, 183 insertions(+), 125 deletions(-)

(limited to 'compiler')

diff --git a/compiler/Logging.ml b/compiler/Logging.ml
index 8cfc25d3..71f72471 100644
--- a/compiler/Logging.ml
+++ b/compiler/Logging.ml
@@ -47,3 +47,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows"
 
 (** Logger for Invariants *)
 let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants"
+
+(** Logger for SCC *)
+let scc_log = L.get_logger "MainLogger.SCC"
+
+(** Logger for ReorderDecls *)
+let reorder_decls_log = L.get_logger "MainLogger.ReorderDecls"
diff --git a/compiler/Print.ml b/compiler/Print.ml
index a4a8b1d4..82c5dac5 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -7,6 +7,9 @@ module PrimitiveValues = Charon.PrintPrimitiveValues
 module Types = Charon.PrintTypes
 module Expressions = Charon.PrintExpressions
 
+let list_to_string (to_string : 'a -> string) (ls : 'a list) : string =
+  "[" ^ String.concat "; " (List.map to_string ls) ^ "]"
+
 (** Pretty-printing for values *)
 module Values = struct
   type value_formatter = {
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 0f1d50f1..da15d635 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -110,55 +110,6 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
   let info = sg.info in
   { inputs; output; doutputs; info }
 
-(** Return true if a list of functions are *not* mutually recursive, false otherwise.
-    This function is meant to be applied on a set of (forward, backwards) functions
-    generated for one recursive function.
-    The way we do the test is very simple:
-    - we explore the functions one by one, in the order in which they are provided
-    - if all functions only call functions we already explored, they are not
-      mutually recursive
- *)
-let functions_not_mutually_recursive (funs : fun_decl list) : bool =
-  (* Compute the set of function identifiers in the group *)
-  let ids =
-    FunOrOpIdSet.of_list
-      (List.map
-         (fun (f : fun_decl) -> Fun (FromLlbc (A.Regular f.def_id, f.back_id)))
-         funs)
-  in
-  let ids = ref ids in
-  (* Explore every body *)
-  let body_only_calls_itself (fdef : fun_decl) : bool =
-    (* Remove the current id from the id set *)
-    ids :=
-      FunOrOpIdSet.remove
-        (Fun (FromLlbc (A.Regular fdef.def_id, fdef.back_id)))
-        !ids;
-
-    (* Check if we call functions from the updated id set *)
-    let obj =
-      object
-        inherit [_] iter_expression as super
-
-        method! visit_qualif env qualif =
-          match qualif.id with
-          | FunOrOp fun_id ->
-              if FunOrOpIdSet.mem fun_id !ids then raise Utils.Found
-              else super#visit_qualif env qualif
-          | _ -> super#visit_qualif env qualif
-      end
-    in
-
-    try
-      match fdef.body with
-      | None -> true
-      | Some body ->
-          obj#visit_texpression () body.body;
-          true
-    with Utils.Found -> false
-  in
-  List.for_all body_only_calls_itself funs
-
 (** We use this to check whether we need to add parentheses around expressions.
     We only look for outer monadic let-bindings.
     This is used when printing the branches of [if ... then ... else ...].
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index 9d222011..2b78c570 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -3,6 +3,9 @@ open Collections
 open SCC
 open Pure
 
+(** The local logger *)
+let log = Logging.reorder_decls_log
+
 type fun_id = { def_id : FunDeclId.id; rg_id : T.RegionGroupId.id option }
 [@@deriving show, ord]
 
@@ -69,6 +72,11 @@ let group_reorder_fun_decls (decls : fun_decl list) :
   let idl = List.map get_fun_id decls in
   let ids = FunIdSet.of_list idl in
 
+  log#ldebug
+    (lazy
+      ("group_reorder_fun_decls: ids:\n"
+      ^ (Print.list_to_string FunIdOrderedType.show_t) idl));
+
   (* Explore the bodies to compute the dependencies - we ignore the ids
      which refer to declarations not in the group we want to reorder *)
   let deps : (fun_id * FunIdSet.t) list =
@@ -79,50 +87,88 @@ let group_reorder_fun_decls (decls : fun_decl list) :
         | None -> (id, FunIdSet.empty)
         | Some body ->
             let deps = compute_body_fun_deps body.body in
-            (id, FunIdSet.inter deps ids))
+            (* Restrict the set dependencies *)
+            let deps = FunIdSet.inter deps ids in
+            (id, deps))
       decls
   in
 
   (*
    * Create the dependency graph
    *)
-  (* Convert the ids to vertices (i.e., injectively map ids to integers) *)
-  let id_to_vertex : int FunIdMap.t =
+  (* Convert the ids to vertices (i.e., injectively map ids to integers, and create
+     vertices labeled with those integers).
+
+     Rem.: [Graph.create] is *imperative*: it generates a new vertex every time
+     it is called (!!).
+  *)
+  let module Graph = Pack.Digraph in
+  let id_to_vertex : Graph.V.t FunIdMap.t =
     let cnt = ref 0 in
     FunIdMap.of_list
       (List.map
          (fun id ->
-           let v = !cnt in
+           let lbl = !cnt in
            cnt := !cnt + 1;
+           (* We create a vertex *)
+           let v = Graph.V.create lbl in
            (id, v))
          idl)
   in
   let vertex_to_id : fun_id IntMap.t =
     IntMap.of_list
-      (List.map (fun (fid, vid) -> (vid, fid)) (FunIdMap.bindings id_to_vertex))
+      (List.map
+         (fun (fid, v) -> (Graph.V.label v, fid))
+         (FunIdMap.bindings id_to_vertex))
   in
-  let to_v id = Pack.Graph.V.create (FunIdMap.find id id_to_vertex) in
-  let to_id v = IntMap.find (Pack.Graph.V.label v) vertex_to_id in
 
-  let g = Pack.Graph.create () in
-  (* First add the vertices *)
-  List.iter (fun id -> Pack.Graph.add_vertex g (to_v id)) idl;
+  let to_v id = FunIdMap.find id id_to_vertex in
+  let to_id v = IntMap.find (Graph.V.label v) vertex_to_id in
 
-  (* Then add the edges *)
+  let g = Graph.create () in
+
+  (* Add the edges, first from the vertices to themselves, then between vertices *)
   List.iter
     (fun (fun_id, deps) ->
-      FunIdSet.iter
-        (fun dep_id -> Pack.Graph.add_edge g (to_v fun_id) (to_v dep_id))
-        deps)
+      let v = to_v fun_id in
+      Graph.add_edge g v v;
+      FunIdSet.iter (fun dep_id -> Graph.add_edge g v (to_v dep_id)) deps)
     deps;
 
   (* Compute the SCCs *)
-  let module Comp = Components.Make (Pack.Graph) in
+  let module Comp = Components.Make (Graph) in
   let sccs = Comp.scc_list g in
 
   (* Convert the vertices to ids *)
   let sccs = List.map (List.map to_id) sccs in
 
+  log#ldebug
+    (lazy
+      ("group_reorder_fun_decls: SCCs:\n"
+      ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs
+      ));
+
+  (* Sanity check *)
+  let _ =
+    (* Check that the SCCs are pairwise disjoint *)
+    assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs));
+    (* Check that all the ids are in the sccs *)
+    let scc_ids = FunIdSet.of_list (List.concat sccs) in
+
+    log#ldebug
+      (lazy
+        ("group_reorder_fun_decls: sanity check:" ^ "\n- ids    : "
+       ^ FunIdSet.show ids ^ "\n- scc_ids: " ^ FunIdSet.show scc_ids));
+
+    assert (FunIdSet.equal scc_ids ids)
+  in
+
+  log#ldebug
+    (lazy
+      ("group_reorder_fun_decls: reordered SCCs:\n"
+      ^ Print.list_to_string (Print.list_to_string FunIdOrderedType.show_t) sccs
+      ));
+
   (* Reorder *)
   let module Reorder = SCC.Make (FunIdOrderedType) in
   let id_deps =
@@ -131,6 +177,16 @@ let group_reorder_fun_decls (decls : fun_decl list) :
   in
   let sccs = Reorder.reorder_sccs id_deps idl sccs in
 
+  (* Sanity check *)
+  let _ =
+    (* Check that the SCCs are pairwise disjoint *)
+    let sccs = List.map snd (SccId.Map.bindings sccs.sccs) in
+    assert (FunIdSet.pairwise_disjoint (List.map FunIdSet.of_list sccs));
+    (* Check that all the ids are in the sccs *)
+    let scc_ids = FunIdSet.of_list (List.concat sccs) in
+    assert (scc_ids = ids)
+  in
+
   (* Group the declarations *)
   let deps = FunIdMap.of_list deps in
   let decls = FunIdMap.of_list (List.map (fun d -> (get_fun_id d, d)) decls) in
diff --git a/compiler/SCC.ml b/compiler/SCC.ml
index 8b4cdb1f..889a972b 100644
--- a/compiler/SCC.ml
+++ b/compiler/SCC.ml
@@ -3,6 +3,9 @@
 open Collections
 module SccId = Identifiers.IdGen ()
 
+(** The local logger *)
+let log = Logging.scc_log
+
 (** A functor which provides functions to work on strongly connected components *)
 module Make (Id : OrderedType) = struct
   module IdMap = MakeMap (Id)
@@ -91,6 +94,15 @@ module Make (Id : OrderedType) = struct
   *)
   let reorder_sccs (id_deps : Id.t list IdMap.t) (ids : Id.t list)
       (sccs : Id.t list list) : sccs =
+    log#ldebug
+      (lazy
+        ("reorder_sccs:" ^ "\n- id_deps: "
+        ^ IdMap.show (Print.list_to_string Id.show_t) id_deps
+        ^ "\n- ids: "
+        ^ Print.list_to_string Id.show_t ids
+        ^ "\n- sccs: "
+        ^ Print.list_to_string (Print.list_to_string Id.show_t) sccs));
+
     (* Map the identifiers to the SCC indices *)
     let id_to_scc =
       IdMap.of_list
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 4f31b738..87862d6b 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -383,8 +383,8 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
     [is_rec]: [true] if the types are recursive. Necessarily [true] if there is
     > 1 type to extract.
  *)
-let export_types (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
-    (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit =
+let export_types_group (fmt : Format.formatter) (config : gen_config)
+    (ctx : gen_ctx) (is_rec : bool) (ids : Pure.TypeDeclId.id list) : unit =
   let export_type = export_type fmt config ctx in
   let export_type_decl kind id = export_type kind id true false in
   let export_type_extra_info kind id = export_type kind id false true in
@@ -432,30 +432,79 @@ let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
   then
     Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface
 
+(** Utility.
+
+    Export a group of functions. See [export_functions_group].
+
+    We need this because for every function in Rust we may generate several functions
+    in the translation (a forward function, several backward functions, loop
+    functions, etc.). Those functions might call each other in different
+    ways (in particular, they may be mutually recursive, in which case we might
+    be able to group them into several groups of mutually recursive definitions,
+    etc.). For this reason, [export_functions_group] computes the dependency
+    graph of the functions as well as their strongly connected components, and
+    gives each SCC at a time to [export_functions].
+
+    Rem.: this function only extracts the function *declarations*. It doesn't
+    extract the decrease clauses, nor does it extract the unit tests.
+
+    Rem.: this function doesn't check [config.extract_fun_decls]: it should have
+    been checked by the caller.
+ *)
+let export_functions_declarations (fmt : Format.formatter) (config : gen_config)
+    (ctx : gen_ctx) (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+  (* Utility to check a function has a decrease clause *)
+  let has_decreases_clause (def : Pure.fun_decl) : bool =
+    A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
+  in
+
+  (* Extract the function declarations *)
+  (* Check if the functions are mutually recursive - this really works
+   * to check if the forward and backward translations of a single
+   * recursive function are mutually recursive *)
+  let is_mut_rec = List.length decls > 1 in
+  assert ((not is_mut_rec) || is_rec);
+  let decls_length = List.length decls in
+  List.iteri
+    (fun i def ->
+      let is_opaque = Option.is_none def.Pure.body in
+      let kind =
+        if is_opaque then
+          if config.interface then ExtractBase.Declared else ExtractBase.Assumed
+        else if not is_rec then ExtractBase.SingleNonRec
+        else if is_mut_rec then
+          (* If the functions are mutually recursive, we need to distinguish:
+           * - the first of the group
+           * - the last of the group
+           * - the inner functions
+           *)
+          if i = 0 then ExtractBase.MutRecFirst
+          else if i = decls_length - 1 then ExtractBase.MutRecLast
+          else ExtractBase.MutRecInner
+        else ExtractBase.SingleRec
+      in
+      let has_decr_clause =
+        has_decreases_clause def && config.extract_decreases_clauses
+      in
+      (* Check if the definition needs to be filtered or not *)
+      if
+        ((not is_opaque) && config.extract_transparent)
+        || (is_opaque && config.extract_opaque)
+      then Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def)
+    decls
+
 (** Export a group of function declarations.
 
     In case of (non-mutually) recursive functions, we use a simple procedure to
     check if the forward and backward functions are mutually recursive.
  *)
-let export_functions (fmt : Format.formatter) (config : gen_config)
-    (ctx : gen_ctx) (is_rec : bool)
-    (pure_ls : (bool * pure_fun_translation) list) : unit =
+let export_functions_group (fmt : Format.formatter) (config : gen_config)
+    (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
   (* Utility to check a function has a decrease clause *)
   let has_decreases_clause (def : Pure.fun_decl) : bool =
     A.FunDeclId.Set.mem def.def_id ctx.functions_with_decreases_clause
   in
 
-  (* Concatenate the function definitions, filtering the useless forward
-   * functions. We also make pairs: (forward function, backward function)
-   * (the forward function contains useful information that we want to keep) *)
-  let fls =
-    List.concat
-      (List.map
-         (fun (keep_fwd, (fwd, back_ls)) ->
-           let back_ls = List.map (fun back -> (fwd, back)) back_ls in
-           if keep_fwd then (fwd, fwd) :: back_ls else back_ls)
-         pure_ls)
-  in
   (* Extract the decrease clauses template bodies *)
   if config.extract_template_decreases_clauses then
     List.iter
@@ -464,48 +513,28 @@ let export_functions (fmt : Format.formatter) (config : gen_config)
         if has_decr_clause then
           Extract.extract_template_decreases_clause ctx.extract_ctx fmt fwd)
       pure_ls;
+
+  (* Concatenate the function definitions, filtering the useless forward
+   * functions. *)
+  let decls =
+    List.concat
+      (List.map
+         (fun (keep_fwd, (fwd, back_ls)) ->
+           if keep_fwd then fwd :: back_ls else back_ls)
+         pure_ls)
+  in
+
   (* Extract the function definitions *)
   (if config.extract_fun_decls then
-   (* Check if the functions are mutually recursive - this really works
-    * to check if the forward and backward translations of a single
-    * recursive function are mutually recursive *)
-   let is_mut_rec =
-     if is_rec then
-       if List.length pure_ls <= 1 then
-         not (PureUtils.functions_not_mutually_recursive (List.map fst fls))
-       else true
-     else false
+   (* Group the mutually recursive definitions *)
+   let subgroups = ReorderDecls.group_reorder_fun_decls decls in
+
+   (* Extract the subgroups *)
+   let export_subgroup (is_rec : bool) (decls : Pure.fun_decl list) : unit =
+     export_functions_declarations fmt config ctx is_rec decls
    in
-   let fls_length = List.length fls in
-   List.iteri
-     (fun i (fwd_def, def) ->
-       let is_opaque = Option.is_none fwd_def.Pure.body in
-       let kind =
-         if is_opaque then
-           if config.interface then ExtractBase.Declared
-           else ExtractBase.Assumed
-         else if not is_rec then ExtractBase.SingleNonRec
-         else if is_mut_rec then
-           (* If the functions are mutually recursive, we need to distinguish:
-            * - the first of the group
-            * - the last of the group
-            * - the inner functions
-            *)
-           if i = 0 then ExtractBase.MutRecFirst
-           else if i = fls_length - 1 then ExtractBase.MutRecLast
-           else ExtractBase.MutRecInner
-         else ExtractBase.SingleRec
-       in
-       let has_decr_clause =
-         has_decreases_clause def && config.extract_decreases_clauses
-       in
-       (* Check if the definition needs to be filtered or not *)
-       if
-         ((not is_opaque) && config.extract_transparent)
-         || (is_opaque && config.extract_opaque)
-       then
-         Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause def)
-     fls);
+   List.iter (fun (is_rec, decls) -> export_subgroup is_rec decls) subgroups);
+
   (* Insert unit tests if necessary *)
   if config.test_trans_unit_functions then
     List.iter
@@ -525,9 +554,9 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
      - [extract_extra_info]: extra the extra type information (e.g.,
        the [Arguments] information in Coq).
   *)
-  let export_functions = export_functions fmt config ctx in
+  let export_functions_group = export_functions_group fmt config ctx in
   let export_global = export_global fmt config ctx in
-  let export_types = export_types fmt config ctx in
+  let export_types_group = export_types_group fmt config ctx in
 
   let export_state_type () : unit =
     let kind =
@@ -538,13 +567,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
 
   let export_decl_group (dg : A.declaration_group) : unit =
     match dg with
-    | Type (NonRec id) -> if config.extract_types then export_types false [ id ]
-    | Type (Rec ids) -> if config.extract_types then export_types true ids
+    | Type (NonRec id) ->
+        if config.extract_types then export_types_group false [ id ]
+    | Type (Rec ids) -> if config.extract_types then export_types_group true ids
     | Fun (NonRec id) ->
         (* Lookup *)
         let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in
         (* Translate *)
-        export_functions false [ pure_fun ]
+        export_functions_group [ pure_fun ]
     | Fun (Rec ids) ->
         (* General case of mutually recursive functions *)
         (* Lookup *)
@@ -552,7 +582,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
           List.map (fun id -> A.FunDeclId.Map.find id ctx.trans_funs) ids
         in
         (* Translate *)
-        export_functions true pure_funs
+        export_functions_group pure_funs
     | Global id -> export_global id
   in
 
-- 
cgit v1.2.3