summaryrefslogtreecommitdiff
path: root/src/PureUtils.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/PureUtils.ml')
-rw-r--r--src/PureUtils.ml55
1 files changed, 55 insertions, 0 deletions
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 1a227e51..e0703d30 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -18,6 +18,21 @@ end
module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType)
+module FunIdOrderedType = struct
+ type t = fun_id
+
+ let compare = compare_fun_id
+
+ let to_string = show_fun_id
+
+ let pp_t = pp_fun_id
+
+ let show_t = show_fun_id
+end
+
+module FunIdMap = Collections.MakeMap (FunIdOrderedType)
+module FunIdSet = Collections.MakeSet (FunIdOrderedType)
+
(* TODO : move *)
let binop_can_fail (binop : E.binop) : bool =
match binop with
@@ -128,3 +143,43 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
let inputs = List.map subst sg.inputs in
let outputs = List.map subst sg.outputs in
{ inputs; outputs }
+
+(** Return true if a list of functions are *not* mutually recursive, false otherwise.
+ This function is meant to be applied on a set of (forward, backwards) functions
+ generated for one recursive function.
+ The way we do the test is very simple: if any function body references another
+ function from the set, we consider the whole set is mutually recursive. Otherwise,
+ we consider it is not the case. Note that this check is conservative, making
+ it sound (also note that if the test is wrong, the code generated by the synthesis
+ will not be valid anyway...)
+ *)
+let functions_not_mutually_recursive (funs : fun_def list) : bool =
+ (* Compute the set of function identifiers in the group *)
+ let ids =
+ FunIdSet.of_list
+ (List.map
+ (fun (f : fun_def) -> Regular (A.Local f.def_id, f.back_id))
+ funs)
+ in
+ (* Explore every body *)
+ let body_only_calls_itself (fdef : fun_def) : bool =
+ let other_ids =
+ FunIdSet.remove (Regular (A.Local fdef.def_id, fdef.back_id)) ids
+ in
+
+ let obj =
+ object
+ inherit [_] iter_expression as super
+
+ method! visit_call env call =
+ if FunIdSet.mem call.func other_ids then raise Utils.Found
+ else super#visit_call env call
+ end
+ in
+
+ try
+ obj#visit_texpression () fdef.body;
+ true
+ with Utils.Found -> false
+ in
+ List.for_all body_only_calls_itself funs