summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-02-02 23:32:24 +0100
committerSon Ho2022-02-02 23:32:24 +0100
commitb7189038d2df990b2dc0142b769510dcca507f82 (patch)
treec16fe8f91ba867ee6411d28d4bd145072acb7b9e /src
parent8116c4cb6aa002595fd7fcc47a39c1577e820f8e (diff)
Implement detection of non-recursive forward/backward functions groups when
extracting (non-mutually) recursive functions
Diffstat (limited to 'src')
-rw-r--r--src/Pure.ml4
-rw-r--r--src/PureUtils.ml55
-rw-r--r--src/Translate.ml29
3 files changed, 77 insertions, 11 deletions
diff --git a/src/Pure.ml b/src/Pure.ml
index 53c053b0..32a1ca4c 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -348,7 +348,7 @@ and typed_rvalue = { value : rvalue; ty : ty }
polymorphic = false;
}]
-type unop = Not | Neg of integer_type [@@deriving show]
+type unop = Not | Neg of integer_type [@@deriving show, ord]
(* TODO: redefine assumed_fun_id (we need to get rid of box! *)
@@ -358,7 +358,7 @@ type fun_id =
if it is a forward function *)
| Unop of unop
| Binop of E.binop * integer_type
-[@@deriving show]
+[@@deriving show, ord]
(** Meta-information stored in the AST *)
type meta = Assignment of mplace * typed_rvalue
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 1a227e51..e0703d30 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -18,6 +18,21 @@ end
module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType)
+module FunIdOrderedType = struct
+ type t = fun_id
+
+ let compare = compare_fun_id
+
+ let to_string = show_fun_id
+
+ let pp_t = pp_fun_id
+
+ let show_t = show_fun_id
+end
+
+module FunIdMap = Collections.MakeMap (FunIdOrderedType)
+module FunIdSet = Collections.MakeSet (FunIdOrderedType)
+
(* TODO : move *)
let binop_can_fail (binop : E.binop) : bool =
match binop with
@@ -128,3 +143,43 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
let inputs = List.map subst sg.inputs in
let outputs = List.map subst sg.outputs in
{ inputs; outputs }
+
+(** Return true if a list of functions are *not* mutually recursive, false otherwise.
+ This function is meant to be applied on a set of (forward, backwards) functions
+ generated for one recursive function.
+ The way we do the test is very simple: if any function body references another
+ function from the set, we consider the whole set is mutually recursive. Otherwise,
+ we consider it is not the case. Note that this check is conservative, making
+ it sound (also note that if the test is wrong, the code generated by the synthesis
+ will not be valid anyway...)
+ *)
+let functions_not_mutually_recursive (funs : fun_def list) : bool =
+ (* Compute the set of function identifiers in the group *)
+ let ids =
+ FunIdSet.of_list
+ (List.map
+ (fun (f : fun_def) -> Regular (A.Local f.def_id, f.back_id))
+ funs)
+ in
+ (* Explore every body *)
+ let body_only_calls_itself (fdef : fun_def) : bool =
+ let other_ids =
+ FunIdSet.remove (Regular (A.Local fdef.def_id, fdef.back_id)) ids
+ in
+
+ let obj =
+ object
+ inherit [_] iter_expression as super
+
+ method! visit_call env call =
+ if FunIdSet.mem call.func other_ids then raise Utils.Found
+ else super#visit_call env call
+ end
+ in
+
+ try
+ obj#visit_texpression () fdef.body;
+ true
+ with Utils.Found -> false
+ in
+ List.for_all body_only_calls_itself funs
diff --git a/src/Translate.ml b/src/Translate.ml
index 95cae60b..cff814f4 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -350,17 +350,18 @@ let translate_module (filename : string) (config : C.partial_config)
let def = Pure.TypeDefId.Map.find id trans_types in
ExtractToFStar.extract_type_def extract_ctx fmt qualif def
in
- (* In case of recursive functions, we always extract the forward and
- * backward functions as mutually recursive functions.
- * There are many situations where they are actually not mutually recursive:
- * we could detect such cases. TODO *)
- let export_functions (is_rec : bool) (fls : Pure.fun_def list) : unit =
+ (* In case of (non-mutually) recursive functions, we use a simple procedure to
+ * check if the forward and backward functions are mutually recursive.
+ *)
+ let export_functions (is_rec : bool) (is_mut_rec : bool)
+ (fls : Pure.fun_def list) : unit =
List.iteri
(fun i def ->
let qualif =
if not is_rec then ExtractToFStar.Let
- else if i = 0 then ExtractToFStar.LetRec
- else ExtractToFStar.And
+ else if is_mut_rec then
+ if i = 0 then ExtractToFStar.LetRec else ExtractToFStar.And
+ else ExtractToFStar.LetRec
in
ExtractToFStar.extract_fun_def extract_ctx fmt qualif def)
fls
@@ -381,8 +382,18 @@ let translate_module (filename : string) (config : C.partial_config)
let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in
let fls = fwd :: back_ls in
(* Translate *)
- export_functions false fls
+ export_functions false false fls
+ | Fun (Rec [ id ]) ->
+ (* Simply recursive functions *)
+ (* Concatenate *)
+ let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in
+ let fls = fwd :: back_ls in
+ (* Check if mutually rec *)
+ let is_mut_rec = not (PureUtils.functions_not_mutually_recursive fls) in
+ (* Translate *)
+ export_functions true is_mut_rec fls
| Fun (Rec ids) ->
+ (* General case of mutually recursive functions *)
(* Concatenate *)
let compute_fun_id_list (id : Pure.FunDefId.id) : Pure.fun_def list =
let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in
@@ -390,7 +401,7 @@ let translate_module (filename : string) (config : C.partial_config)
in
let fls = List.concat (List.map compute_fun_id_list ids) in
(* Translate *)
- export_functions true fls
+ export_functions true true fls
in
List.iter export_decl m.declarations;