From b7189038d2df990b2dc0142b769510dcca507f82 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 2 Feb 2022 23:32:24 +0100 Subject: Implement detection of non-recursive forward/backward functions groups when extracting (non-mutually) recursive functions --- src/Pure.ml | 4 ++-- src/PureUtils.ml | 55 +++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/Translate.ml | 29 ++++++++++++++++++++--------- 3 files changed, 77 insertions(+), 11 deletions(-) (limited to 'src') diff --git a/src/Pure.ml b/src/Pure.ml index 53c053b0..32a1ca4c 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -348,7 +348,7 @@ and typed_rvalue = { value : rvalue; ty : ty } polymorphic = false; }] -type unop = Not | Neg of integer_type [@@deriving show] +type unop = Not | Neg of integer_type [@@deriving show, ord] (* TODO: redefine assumed_fun_id (we need to get rid of box! *) @@ -358,7 +358,7 @@ type fun_id = if it is a forward function *) | Unop of unop | Binop of E.binop * integer_type -[@@deriving show] +[@@deriving show, ord] (** Meta-information stored in the AST *) type meta = Assignment of mplace * typed_rvalue diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 1a227e51..e0703d30 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -18,6 +18,21 @@ end module RegularFunIdMap = Collections.MakeMap (RegularFunIdOrderedType) +module FunIdOrderedType = struct + type t = fun_id + + let compare = compare_fun_id + + let to_string = show_fun_id + + let pp_t = pp_fun_id + + let show_t = show_fun_id +end + +module FunIdMap = Collections.MakeMap (FunIdOrderedType) +module FunIdSet = Collections.MakeSet (FunIdOrderedType) + (* TODO : move *) let binop_can_fail (binop : E.binop) : bool = match binop with @@ -128,3 +143,43 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : let inputs = List.map subst sg.inputs in let outputs = List.map subst sg.outputs in { inputs; outputs } + +(** Return true if a list of functions are *not* mutually recursive, false otherwise. + This function is meant to be applied on a set of (forward, backwards) functions + generated for one recursive function. + The way we do the test is very simple: if any function body references another + function from the set, we consider the whole set is mutually recursive. Otherwise, + we consider it is not the case. Note that this check is conservative, making + it sound (also note that if the test is wrong, the code generated by the synthesis + will not be valid anyway...) + *) +let functions_not_mutually_recursive (funs : fun_def list) : bool = + (* Compute the set of function identifiers in the group *) + let ids = + FunIdSet.of_list + (List.map + (fun (f : fun_def) -> Regular (A.Local f.def_id, f.back_id)) + funs) + in + (* Explore every body *) + let body_only_calls_itself (fdef : fun_def) : bool = + let other_ids = + FunIdSet.remove (Regular (A.Local fdef.def_id, fdef.back_id)) ids + in + + let obj = + object + inherit [_] iter_expression as super + + method! visit_call env call = + if FunIdSet.mem call.func other_ids then raise Utils.Found + else super#visit_call env call + end + in + + try + obj#visit_texpression () fdef.body; + true + with Utils.Found -> false + in + List.for_all body_only_calls_itself funs diff --git a/src/Translate.ml b/src/Translate.ml index 95cae60b..cff814f4 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -350,17 +350,18 @@ let translate_module (filename : string) (config : C.partial_config) let def = Pure.TypeDefId.Map.find id trans_types in ExtractToFStar.extract_type_def extract_ctx fmt qualif def in - (* In case of recursive functions, we always extract the forward and - * backward functions as mutually recursive functions. - * There are many situations where they are actually not mutually recursive: - * we could detect such cases. TODO *) - let export_functions (is_rec : bool) (fls : Pure.fun_def list) : unit = + (* In case of (non-mutually) recursive functions, we use a simple procedure to + * check if the forward and backward functions are mutually recursive. + *) + let export_functions (is_rec : bool) (is_mut_rec : bool) + (fls : Pure.fun_def list) : unit = List.iteri (fun i def -> let qualif = if not is_rec then ExtractToFStar.Let - else if i = 0 then ExtractToFStar.LetRec - else ExtractToFStar.And + else if is_mut_rec then + if i = 0 then ExtractToFStar.LetRec else ExtractToFStar.And + else ExtractToFStar.LetRec in ExtractToFStar.extract_fun_def extract_ctx fmt qualif def) fls @@ -381,8 +382,18 @@ let translate_module (filename : string) (config : C.partial_config) let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in let fls = fwd :: back_ls in (* Translate *) - export_functions false fls + export_functions false false fls + | Fun (Rec [ id ]) -> + (* Simply recursive functions *) + (* Concatenate *) + let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in + let fls = fwd :: back_ls in + (* Check if mutually rec *) + let is_mut_rec = not (PureUtils.functions_not_mutually_recursive fls) in + (* Translate *) + export_functions true is_mut_rec fls | Fun (Rec ids) -> + (* General case of mutually recursive functions *) (* Concatenate *) let compute_fun_id_list (id : Pure.FunDefId.id) : Pure.fun_def list = let fwd, back_ls = Pure.FunDefId.Map.find id trans_funs in @@ -390,7 +401,7 @@ let translate_module (filename : string) (config : C.partial_config) in let fls = List.concat (List.map compute_fun_id_list ids) in (* Translate *) - export_functions true fls + export_functions true true fls in List.iter export_decl m.declarations; -- cgit v1.2.3