From f69ac6a4a244c99a41a90ed57f74ea83b3835882 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 14 Dec 2023 17:11:01 +0100 Subject: Start updating Pure.fun_sig_info to handle merged forward and backward functions --- compiler/Config.ml | 63 ++++++++++++++++++++++++++++++++++++++++++++-- compiler/Pure.ml | 24 +++++++++++++++--- compiler/PureUtils.ml | 6 ++++- compiler/SymbolicToPure.ml | 62 ++++++++++++++++++++++++++++++--------------- 4 files changed, 128 insertions(+), 27 deletions(-) diff --git a/compiler/Config.ml b/compiler/Config.ml index 9cd1ebc2..b8af6c6d 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -93,8 +93,67 @@ let loop_fixed_point_max_num_iters = 2 (** {1 Translation} *) (** If true, do not define separate forward/backward functions, but make the - forward functions return the backward function. *) -let return_back_funs = ref false + forward functions return the backward function. + + Example: + {[ + (* Rust *) + pub fn list_nth<'a, T>(l: &'a mut List, i: u32) -> &'a mut T { + match l { + List::Nil => { + panic!() + } + List::Cons(x, tl) => { + if i == 0 { + x + } else { + list_nth(tl, i - 1) + } + } + } + } + + (* Translation, if return_back_funs = false *) + def list_nth (T : Type) (l : List T) (i : U32) : Result T := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret x + else do + let i0 ← i - 1#u32 + list_nth T tl i0 + | List.Nil => Result.fail .panic + + def list_nth_back + (T : Type) (l : List T) (i : U32) (ret : T) : Result (List T) := + match l with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (List.Cons ret tl) + else + do + let i0 ← i - 1#u32 + let tl0 ← list_nth_back T tl i0 ret + Result.ret (List.Cons x tl0) + | List.Nil => Result.fail .panic + + (* Translation, if return_back_funs = true *) + def list_nth (T: Type) (ls : List T) (i : U32) : + Result (T × (T → Result (List T))) := + match ls with + | List.Cons x tl => + if i = 0#u32 + then Result.ret (x, (λ ret => return (ret :: ls))) + else do + let i0 ← i - 1#u32 + let (x, back) ← list_nth ls i0 + Return.ret (x, + (λ ret => do + let ls ← back ret + return (x :: ls))) + ]} + *) +let return_back_funs = ref true (** Forbids using field projectors for structures. diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 80d8782b..bb522623 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -561,7 +561,7 @@ type fun_id_or_trait_method_ref = (** A function id for a non-assumed function *) type regular_fun_id = - fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option + fun_id_or_trait_method_ref * LoopId.id option * RegionGroupId.id option [@@deriving show, ord] (** A function identifier *) @@ -886,12 +886,28 @@ type inputs_info = { } [@@deriving show] +type 'a back_info = + | SingleBack of 'a option + (** Information about a single backward function, if pertinent. + + We use this variant if we split the forward and the backward functions. + *) + | AllBacks of 'a RegionGroupId.Map.t + (** Information about the various backward functions. + + We use this if we *do not* split the forward and the backward functions. + All the information is then carried by the forward function. + *) +[@@deriving show] + +type back_inputs_info = inputs_info back_info [@@deriving show] + (** Meta information about a function signature *) type fun_sig_info = { fwd_info : inputs_info; (** Information about the inputs of the forward function *) - back_info : inputs_info option; - (** Information about the inputs of the backward function, if pertinent *) + back_info : back_inputs_info; + (** Information about the inputs of the backward functions. *) effect_info : fun_effect_info; } [@@deriving show] @@ -1024,7 +1040,7 @@ type fun_decl = { *) loop_id : LoopId.id option; (** [Some] if this definition was generated for a loop *) - back_id : T.RegionGroupId.id option; + back_id : RegionGroupId.id option; llbc_name : llbc_name; (** The original LLBC name. *) name : string; (** We use the name only for printing purposes (for debugging): diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 23a41f0e..3c038149 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -74,7 +74,11 @@ let inputs_info_is_wf (info : inputs_info) : bool = let fun_sig_info_is_wf (info : fun_sig_info) : bool = inputs_info_is_wf info.fwd_info && - match info.back_info with None -> true | Some info -> inputs_info_is_wf info + match info.back_info with + | SingleBack None -> true + | SingleBack (Some info) -> inputs_info_is_wf info + | AllBacks infos -> + List.for_all inputs_info_is_wf (RegionGroupId.Map.values infos) let dest_arrow_ty (ty : ty) : ty * ty = match ty with diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 971a8cbd..59205f08 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -855,10 +855,14 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) name (outputs for backward functions come from borrows in the inputs of the forward function) which we use as hints to generate pretty names in the extracted code. + + We use [bid] ("backward function id") only if we split the forward + and the backward functions. *) let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = + assert (Option.is_none bid || not !Config.return_back_funs); let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) @@ -939,6 +943,18 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in + let translate_back_inputs_for_gid gid : ty list = + (* For now, we don't allow nested borrows, so the additional inputs to the + backward function can only come from borrows that were returned like + in (for the backward function we introduce for 'a): + {[ + fn f<'a>(...) -> &'a mut u32; + ]} + Upon ending the abstraction for 'a, we need to get back the borrow + the function returned. + *) + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in (* Compute the additinal inputs for the current function, if it is a backward * function *) let back_inputs = @@ -1056,18 +1072,22 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; } in - let back_info : inputs_info option = - Option.map - (fun n -> - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - }) - num_back_inputs_no_state + let back_info : back_inputs_info = + if !Config.return_back_funs then + SingleBack + (Option.map + (fun n -> + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + }) + num_back_inputs_no_state) + else (* Create the map *) + failwith "TODO" in let info = { fwd_info; back_info; effect_info } in assert (fun_sig_info_is_wf info); @@ -3162,14 +3182,16 @@ let translate_fun_signatures (decls_ctx : C.decls_ctx) let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy + if !Config.return_back_funs then [] + else + List.map + (fun (rg : T.region_var_group) -> + let tsg = + translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) + in + let id = (fun_id, Some rg.id) in + (id, tsg)) + regions_hierarchy in (* Return *) (fwd_id, fwd_sg) :: back_sgs -- cgit v1.2.3