diff options
Diffstat (limited to 'src/PureUtils.ml')
-rw-r--r-- | src/PureUtils.ml | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 1a227e51..e0703d30 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -18,6 +18,21 @@ end module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType) +module FunIdOrderedType = struct + type t = fun_id + + let compare = compare_fun_id + + let to_string = show_fun_id + + let pp_t = pp_fun_id + + let show_t = show_fun_id +end + +module FunIdMap = Collections.MakeMap (FunIdOrderedType) +module FunIdSet = Collections.MakeSet (FunIdOrderedType) + (* TODO : move *) let binop_can_fail (binop : E.binop) : bool = match binop with @@ -128,3 +143,43 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : let inputs = List.map subst sg.inputs in let outputs = List.map subst sg.outputs in { inputs; outputs } + +(** Return true if a list of functions are *not* mutually recursive, false otherwise. + This function is meant to be applied on a set of (forward, backwards) functions + generated for one recursive function. + The way we do the test is very simple: if any function body references another + function from the set, we consider the whole set is mutually recursive. Otherwise, + we consider it is not the case. Note that this check is conservative, making + it sound (also note that if the test is wrong, the code generated by the synthesis + will not be valid anyway...) + *) +let functions_not_mutually_recursive (funs : fun_def list) : bool = + (* Compute the set of function identifiers in the group *) + let ids = + FunIdSet.of_list + (List.map + (fun (f : fun_def) -> Regular (A.Local f.def_id, f.back_id)) + funs) + in + (* Explore every body *) + let body_only_calls_itself (fdef : fun_def) : bool = + let other_ids = + FunIdSet.remove (Regular (A.Local fdef.def_id, fdef.back_id)) ids + in + + let obj = + object + inherit [_] iter_expression as super + + method! visit_call env call = + if FunIdSet.mem call.func other_ids then raise Utils.Found + else super#visit_call env call + end + in + + try + obj#visit_texpression () fdef.body; + true + with Utils.Found -> false + in + List.for_all body_only_calls_itself funs |