From b1e57277baf539f1f009f7c927a1a7445ce6ea45 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 14 Dec 2022 17:24:37 +0100 Subject: Add loop ids to the pure functions identifiers --- compiler/Extract.ml | 14 ++++++++++---- compiler/ExtractBase.ml | 27 ++++++++++++++++++--------- compiler/PrintPure.ml | 27 ++++++++++++++++++++------- compiler/Pure.ml | 7 ++++++- compiler/PureMicroPasses.ml | 15 ++++++++------- compiler/ReorderDecls.ml | 12 ++++++++---- compiler/SymbolicToPure.ml | 11 +++++++++-- compiler/Translate.ml | 4 +--- 8 files changed, 80 insertions(+), 37 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index ce0609f5..fd44fab4 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1907,7 +1907,9 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (kind : decl_kind) (has_decreases_clause : bool) (def : fun_decl) : unit = assert (not def.is_global_decl_body); (* Retrieve the function name *) - let def_name = ctx_get_local_function def.def_id def.back_id ctx in + let def_name = + ctx_get_local_function def.def_id def.loop_id def.back_id ctx + in (* Add a break before *) F.pp_print_break fmt 0 0; (* Print a comment to link the extracted type to its original rust definition *) @@ -2141,7 +2143,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_name = ctx_get_global global.def_id ctx in let body_name = - ctx_get_function (FromLlbc (Regular global.body_id, None)) ctx + ctx_get_function (FromLlbc (Regular global.body_id, None, None)) ctx in let decl_ty, body_ty = @@ -2211,7 +2213,9 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "assert_norm"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = ctx_get_local_function def.def_id def.back_id ctx in + let fun_name = + ctx_get_local_function def.def_id def.loop_id def.back_id ctx + in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); @@ -2225,7 +2229,9 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt "Check"; F.pp_print_space fmt (); F.pp_print_string fmt "("; - let fun_name = ctx_get_local_function def.def_id def.back_id ctx in + let fun_name = + ctx_get_local_function def.def_id def.loop_id def.back_id ctx + in F.pp_print_string fmt fun_name; if sg.inputs <> [] then ( F.pp_print_space fmt (); diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index b1901fca..06c71236 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -436,7 +436,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = "global name: " ^ Print.global_name_to_string name | FunId fid -> ( match fid with - | FromLlbc (fid, rg_id) -> + | FromLlbc (fid, lp_id, rg_id) -> let fun_name = match fid with | Regular fid -> @@ -444,12 +444,19 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = (A.FunDeclId.Map.find fid fun_decls).name | Assumed aid -> A.show_assumed_fun_id aid in - let fun_kind = + + let lp_kind = + match lp_id with + | None -> "" + | Some lp_id -> "loop " ^ V.LoopId.to_string lp_id ^ ", " + in + + let fwd_back_kind = match rg_id with | None -> "forward" | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id in - "fun name (" ^ fun_kind ^ "): " ^ fun_name + "fun name (" ^ lp_kind ^ fwd_back_kind ^ "): " ^ fun_name | Pure fid -> PrintPure.pure_assumed_fun_id_to_string fid) | DecreasesClauseId fid -> let fun_name = @@ -534,9 +541,9 @@ let ctx_get_global (id : A.GlobalDeclId.id) (ctx : extraction_ctx) : string = let ctx_get_function (id : fun_id) (ctx : extraction_ctx) : string = ctx_get (FunId id) ctx -let ctx_get_local_function (id : A.FunDeclId.id) (rg : RegionGroupId.id option) - (ctx : extraction_ctx) : string = - ctx_get_function (FromLlbc (Regular id, rg)) ctx +let ctx_get_local_function (id : A.FunDeclId.id) (lp : V.LoopId.id option) + (rg : RegionGroupId.id option) (ctx : extraction_ctx) : string = + ctx_get_function (FromLlbc (Regular id, lp, rg)) ctx let ctx_get_type (id : type_id) (ctx : extraction_ctx) : string = assert (id <> Tuple); @@ -662,7 +669,7 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) : extraction_ctx = let name = ctx.fmt.global_name def.name in let decl = GlobalId def.def_id in - let body = FunId (FromLlbc (Regular def.body_id, None)) in + let body = FunId (FromLlbc (Regular def.body_id, None, None)) in let ctx = ctx_add decl (name ^ "_c") ctx in let ctx = ctx_add body (name ^ "_body") ctx in ctx @@ -699,7 +706,9 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation) let name = ctx.fmt.fun_name def.basename num_rgs rg_info (keep_fwd, num_backs) in - ctx_add (FunId (FromLlbc (A.Regular def_id, def.back_id))) name ctx + ctx_add + (FunId (FromLlbc (A.Regular def_id, def.loop_id, def.back_id))) + name ctx type names_map_init = { keywords : string list; @@ -760,7 +769,7 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map = in let assumed_functions = List.map - (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, rg), name)) + (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, None, rg), name)) init.assumed_llbc_functions @ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions in diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 726cc9a0..3113347c 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -415,10 +415,21 @@ let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string = let all_types = List.append inputs [ output ] in String.concat " -> " all_types -let fun_suffix (rg_id : T.RegionGroupId.id option) : string = - match rg_id with - | None -> "" - | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id +let fun_suffix (lp_id : V.LoopId.id option) (rg_id : T.RegionGroupId.id option) + : string = + let lp_suff = + match lp_id with + | None -> "" + | Some lp_id -> "^loop^" ^ V.LoopId.to_string lp_id + in + + let rg_suff = + match rg_id with + | None -> "" + | Some rg_id -> "@" ^ T.RegionGroupId.to_string rg_id + in + + lp_suff ^ rg_suff let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = match fid with @@ -444,13 +455,13 @@ let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string = match fun_id with - | FromLlbc (fid, rg_id) -> + | FromLlbc (fid, lp_id, rg_id) -> let f = match fid with | Regular fid -> fmt.fun_decl_id_to_string fid | Assumed fid -> llbc_assumed_fun_id_to_string fid in - f ^ fun_suffix rg_id + f ^ fun_suffix lp_id rg_id | Pure fid -> pure_assumed_fun_id_to_string fid let unop_to_string (unop : unop) : string = @@ -620,7 +631,9 @@ and meta_to_string (fmt : ast_formatter) (meta : meta) : string = let fun_decl_to_string (fmt : ast_formatter) (def : fun_decl) : string = let type_fmt = ast_to_type_formatter fmt in - let name = fun_name_to_string def.basename ^ fun_suffix def.back_id in + let name = + fun_name_to_string def.basename ^ fun_suffix def.loop_id def.back_id + in let signature = fun_sig_to_string fmt def.signature in match def.body with | None -> "val " ^ name ^ " :\n " ^ signature diff --git a/compiler/Pure.ml b/compiler/Pure.ml index d9d3a404..2578273d 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -306,9 +306,12 @@ type pure_assumed_fun_id = (** A function identifier *) type fun_id = - | FromLlbc of A.fun_id * T.RegionGroupId.id option + | FromLlbc of A.fun_id * V.LoopId.id option * T.RegionGroupId.id option (** A function coming from LLBC. + The loop id is [None] if the function is actually the auxiliary function + generated from a loop. + The region group id is the backward id:: [Some] if the function is a backward function, [None] if it is a forward function. *) @@ -624,6 +627,8 @@ type fun_body = { type fun_decl = { def_id : FunDeclId.id; + loop_id : V.LoopId.id option; + (** [Some] if this definition was generated for a loop *) back_id : T.RegionGroupId.id option; basename : fun_name; (** The "base" name of the function. diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index c5eb3c64..2e4a534e 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -621,15 +621,16 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool) In this situation, we can remove the call [f@fwd x]. *) let expression_contains_child_call_in_all_paths (ctx : trans_ctx) - (id0 : A.fun_id) (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list) + (id0 : A.fun_id) (lp_id0 : V.LoopId.id option) + (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list) (args0 : texpression list) (e : texpression) : bool = let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list) (args1 : texpression list) : bool = (* Check the fun_ids, to see if call1's function is a child of call0's function *) match fun_id1 with - | Fun (FromLlbc (id1, rg_id1)) -> + | Fun (FromLlbc (id1, lp_id1, rg_id1)) -> (* Both are "regular" calls: check if they come from the same rust function *) - if id0 = id1 then + if id0 = id1 && lp_id0 = lp_id1 then (* Same rust functions: check the regions hierarchy *) let call1_is_child = match (rg_id0, rg_id1) with @@ -801,13 +802,13 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) * We can filter if the right-expression is a function call, * under some conditions. *) match (filter_monadic_calls, opt_destruct_function_call re) with - | true, Some (Fun (FromLlbc (fid, rg_id)), tys, args) -> + | true, Some (Fun (FromLlbc (fid, lp_id, rg_id)), tys, args) -> (* We need to check if there is a child call - see * the comments for: * [expression_contains_child_call_in_all_paths] *) let has_child_call = - expression_contains_child_call_in_all_paths ctx fid rg_id - tys args e + expression_contains_child_call_in_all_paths ctx fid lp_id + rg_id tys args e in if has_child_call then (* Filter *) (e.e, fun _ -> used) @@ -1031,7 +1032,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match opt_destruct_function_call e with | Some (fun_id, _tys, args) -> ( match fun_id with - | Fun (FromLlbc (A.Assumed aid, rg_id)) -> ( + | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> ( (* Below, when dealing with the arguments: we consider the very * general case, where functions could be boxed (meaning we * could have: [box_new f x]) diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml index 2b78c570..4f02eb81 100644 --- a/compiler/ReorderDecls.ml +++ b/compiler/ReorderDecls.ml @@ -6,7 +6,11 @@ open Pure (** The local logger *) let log = Logging.reorder_decls_log -type fun_id = { def_id : FunDeclId.id; rg_id : T.RegionGroupId.id option } +type fun_id = { + def_id : FunDeclId.id; + lp_id : V.LoopId.id option; + rg_id : T.RegionGroupId.id option; +} [@@deriving show, ord] module FunIdOrderedType : OrderedType with type t = fun_id = struct @@ -38,11 +42,11 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t = | FunOrOp (Fun fid) -> ( match fid with | Pure _ -> () - | FromLlbc (fid, rg_id) -> ( + | FromLlbc (fid, lp_id, rg_id) -> ( match fid with | Assumed _ -> () | Regular fid -> - let id = { def_id = fid; rg_id } in + let id = { def_id = fid; lp_id; rg_id } in ids := FunIdSet.add id !ids)) end in @@ -66,7 +70,7 @@ let group_reorder_fun_decls (decls : fun_decl list) : (bool * fun_decl list) list = let module IntMap = MakeMap (OrderedInt) in let get_fun_id (decl : fun_decl) : fun_id = - { def_id = decl.def_id; rg_id = decl.back_id } + { def_id = decl.def_id; lp_id = decl.loop_id; rg_id = decl.back_id } in (* Compute the list/set of identifiers *) let idl = List.map get_fun_id decls in diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 8a6f82f9..bba3326b 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -269,7 +269,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) (* Retrieve the fun_id *) let fun_id = match info.forward.call_id with - | S.Fun (fid, _) -> Fun (FromLlbc (fid, Some back_id)) + | S.Fun (fid, _) -> Fun (FromLlbc (fid, None, Some back_id)) | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") in (* Update the context and return *) @@ -1268,7 +1268,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) - let func = Fun (FromLlbc (fid, None)) in + let func = Fun (FromLlbc (fid, None, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = @@ -2027,10 +2027,17 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (List.combine inputs signature.inputs)); Some { inputs; inputs_lvs; body } in + + (* Note that for now, the loops are still *inside* the function body: we will + extract them from there later, in {!PureMicroPasses} (by "splitting" the definition). + *) + let loop_id = None in + (* Assemble the declaration *) let def = { def_id; + loop_id; back_id = bid; basename; signature; diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 87862d6b..75aeb37c 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -459,9 +459,7 @@ let export_functions_declarations (fmt : Format.formatter) (config : gen_config) in (* Extract the function declarations *) - (* Check if the functions are mutually recursive - this really works - * to check if the forward and backward translations of a single - * recursive function are mutually recursive *) + (* Check if the functions are mutually recursive *) let is_mut_rec = List.length decls > 1 in assert ((not is_mut_rec) || is_rec); let decls_length = List.length decls in -- cgit v1.2.3