From 496a3849d1d6ba880bbd1e86c8ef5e2257bb702a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 13 Dec 2023 10:55:57 +0100 Subject: Add the num_fwd_inputs_no_fuel_no_state field in Pure.fun_sig --- compiler/SymbolicToPure.ml | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index bf4d26f2..2ef313e6 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1035,10 +1035,10 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let generics = translate_generic_params sg.generics in (* Return *) let has_fuel = fuel <> [] in - let num_fwd_inputs_no_state = List.length fwd_inputs in + let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in let num_fwd_inputs_with_fuel_no_state = (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fuel + num_fwd_inputs_no_state + List.length fuel + num_fwd_inputs_no_fuel_no_state in let num_back_inputs_no_state = if bid = None then None else Some (List.length back_inputs) @@ -1046,6 +1046,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let info = { has_fuel; + num_fwd_inputs_no_fuel_no_state; num_fwd_inputs_with_fuel_no_state; num_fwd_inputs_with_fuel_with_state = (* We use the fact that [fwd_state_ty] has length 1 if there is a state, -- cgit v1.2.3 From 0c814c97dd8e5167f24b0dbb14186d674e4d097b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 13 Dec 2023 11:44:58 +0100 Subject: Update Pure.fun_sig_info --- compiler/SymbolicToPure.ml | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 2ef313e6..971a8cbd 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1034,6 +1034,8 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (* Generic parameters *) let generics = translate_generic_params sg.generics in (* Return *) + (* TODO: *) + assert (not !Config.return_back_funs); let has_fuel = fuel <> [] in let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in let num_fwd_inputs_with_fuel_no_state = @@ -1043,24 +1045,32 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let num_back_inputs_no_state = if bid = None then None else Some (List.length back_inputs) in - let info = + let fwd_info : inputs_info = { has_fuel; - num_fwd_inputs_no_fuel_no_state; - num_fwd_inputs_with_fuel_no_state; - num_fwd_inputs_with_fuel_with_state = + num_inputs_no_fuel_no_state = num_fwd_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = (* We use the fact that [fwd_state_ty] has length 1 if there is a state, and 0 otherwise *) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; - num_back_inputs_no_state; - num_back_inputs_with_state = - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - Option.map - (fun n -> n + List.length back_state_ty) - num_back_inputs_no_state; - effect_info; } in + let back_info : inputs_info option = + Option.map + (fun n -> + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + }) + num_back_inputs_no_state + in + let info = { fwd_info; back_info; effect_info } in + assert (fun_sig_info_is_wf info); let preds = translate_predicates sg.preds in let sg = { -- cgit v1.2.3 From f69ac6a4a244c99a41a90ed57f74ea83b3835882 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 14 Dec 2023 17:11:01 +0100 Subject: Start updating Pure.fun_sig_info to handle merged forward and backward functions --- compiler/SymbolicToPure.ml | 62 +++++++++++++++++++++++++++++++--------------- 1 file changed, 42 insertions(+), 20 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 971a8cbd..59205f08 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -855,10 +855,14 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) name (outputs for backward functions come from borrows in the inputs of the forward function) which we use as hints to generate pretty names in the extracted code. + + We use [bid] ("backward function id") only if we split the forward + and the backward functions. *) let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = + assert (Option.is_none bid || not !Config.return_back_funs); let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) @@ -939,6 +943,18 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in + let translate_back_inputs_for_gid gid : ty list = + (* For now, we don't allow nested borrows, so the additional inputs to the + backward function can only come from borrows that were returned like + in (for the backward function we introduce for 'a): + {[ + fn f<'a>(...) -> &'a mut u32; + ]} + Upon ending the abstraction for 'a, we need to get back the borrow + the function returned. + *) + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in (* Compute the additinal inputs for the current function, if it is a backward * function *) let back_inputs = @@ -1056,18 +1072,22 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; } in - let back_info : inputs_info option = - Option.map - (fun n -> - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - }) - num_back_inputs_no_state + let back_info : back_inputs_info = + if !Config.return_back_funs then + SingleBack + (Option.map + (fun n -> + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + }) + num_back_inputs_no_state) + else (* Create the map *) + failwith "TODO" in let info = { fwd_info; back_info; effect_info } in assert (fun_sig_info_is_wf info); @@ -3162,14 +3182,16 @@ let translate_fun_signatures (decls_ctx : C.decls_ctx) let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy + if !Config.return_back_funs then [] + else + List.map + (fun (rg : T.region_var_group) -> + let tsg = + translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) + in + let id = (fun_id, Some rg.id) in + (id, tsg)) + regions_hierarchy in (* Return *) (fwd_id, fwd_sg) :: back_sgs -- cgit v1.2.3 From f1f41818fb14a6c46442ca42a49a3aab0a5b1aaf Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 14 Dec 2023 17:48:44 +0100 Subject: Make progress on generated merged fwd/back functions --- compiler/SymbolicToPure.ml | 56 ++++++++++++++++++++++++---------------------- 1 file changed, 29 insertions(+), 27 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 1fd4896e..86c80f87 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -958,19 +958,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (* Compute the additinal inputs for the current function, if it is a backward * function *) let back_inputs = - match gid with - | None -> [] - | Some gid -> - (* For now, we don't allow nested borrows, so the additional inputs to the - backward function can only come from borrows that were returned like - in (for the backward function we introduce for 'a): - {[ - fn f<'a>(...) -> &'a mut u32; - ]} - Upon ending the abstraction for 'a, we need to get back the borrow - the function returned. - *) - List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + match gid with None -> [] | Some gid -> translate_back_inputs_for_gid gid in (* If the function is stateful, the inputs are: - forward: [fwd_ty0, ..., fwd_tyn, state] @@ -989,11 +977,12 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) See {!effect_info}. *) if effect_info.stateful_group then [ mk_state_ty ] else [] in - let back_state_ty = + let mk_back_state_ty_for_gid (gid : RegionGroupId.id option) : ty list = (* For the backward state, we check if the function is a backward function, and it is stateful *) if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] in + let back_state_ty = mk_back_state_ty_for_gid gid in (* Concatenate the inputs, in the following order: * - forward inputs @@ -1072,22 +1061,35 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; } in + let compute_back_info (back_state_ty : ty list) + (num_back_inputs_no_state : int) : inputs_info = + let n = num_back_inputs_no_state in + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + } + in let back_info : back_inputs_info = if !Config.return_back_funs then + (* Create the map *) + AllBacks + (RegionGroupId.Map.of_list + (List.map + (fun (rg : T.region_var_group) -> + ( rg.id, + let back_inputs = translate_back_inputs_for_gid rg.id in + let num_back_inputs = List.length back_inputs in + (* TODO: slightly overkill *) + let back_state_ty = mk_back_state_ty_for_gid (Some rg.id) in + compute_back_info back_state_ty num_back_inputs )) + regions_hierarchy)) + else SingleBack - (Option.map - (fun n -> - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - }) - num_back_inputs_no_state) - else (* Create the map *) - failwith "TODO" + (Option.map (compute_back_info back_state_ty) num_back_inputs_no_state) in let info = { fwd_info; back_info; effect_info } in assert (fun_sig_info_is_wf info); -- cgit v1.2.3 From cf984f958da94154d0550060eb290a276ab52f23 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 10:17:06 +0100 Subject: Make minor modifications --- compiler/SymbolicToPure.ml | 108 ++++++++++++++------------------------------- 1 file changed, 33 insertions(+), 75 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 86c80f87..eba44e3e 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -308,20 +308,6 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let indent_incr = " " in Print.Values.abs_to_string env verbose indent indent_incr abs -let get_instantiated_fun_sig (fun_id : A.fun_id) - (back_id : T.RegionGroupId.id option) (generics : generic_args) - (ctx : bs_ctx) : inst_fun_sig = - (* Lookup the non-instantiated function signature *) - let sg = - (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg - in - (* Create the substitution *) - (* There shouldn't be any reference to Self *) - let tr_self = UnknownTrait __FUNCTION__ in - let subst = make_subst_from_generics sg.generics generics tr_self in - (* Apply *) - fun_sig_substitute subst sg - let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = TypeDeclId.Map.find id ctx.type_context.llbc_type_decls @@ -330,12 +316,6 @@ let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : A.fun_decl = A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls -(* TODO: move *) -let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id) - (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig = - let id = (E.FRegular def_id, back_id) in - (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg - (* Some generic translation functions (we need to translate different "flavours" of types: forward types, backward types, etc.) *) let rec translate_generic_args (translate_ty : T.ty -> ty) @@ -994,35 +974,44 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] in (* Outputs *) + let compute_back_outputs_for_gid (gid : RegionGroupId.id) : + string option list * ty list = + (* The outputs are the borrows inside the regions of the abstractions + and which are present in the input values. For instance, see: + {[ + fn f<'a>(x : &'a mut u32) -> ...; + ]} + Upon ending the abstraction for 'a, we give back the borrow which + was consumed through the [x] parameter. + *) + let outputs = + List.map + (fun (name, input_ty) -> (name, translate_back_ty_for_gid gid input_ty)) + (List.combine input_names sg.inputs) + in + (* Filter *) + let outputs = + List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs + in + let outputs = + List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs + in + List.split outputs + in let output_names, doutputs = match gid with | None -> - (* This is a forward function: there is one (unnamed) output *) + (* This is a forward function: there is one (unnamed) output. + + If we merge the fwd/back functions we might need to compute + the information about the back outputs. + *) + (* TODO: *) + assert (not !Config.return_back_funs); ([ None ], [ translate_fwd_ty type_infos sg.output ]) | Some gid -> - (* This is a backward function: there might be several outputs. - The outputs are the borrows inside the regions of the abstractions - and which are present in the input values. For instance, see: - {[ - fn f<'a>(x : &'a mut u32) -> ...; - ]} - Upon ending the abstraction for 'a, we give back the borrow which - was consumed through the [x] parameter. - *) - let outputs = - List.map - (fun (name, input_ty) -> - (name, translate_back_ty_for_gid gid input_ty)) - (List.combine input_names sg.inputs) - in - (* Filter *) - let outputs = - List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs - in - let outputs = - List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs - in - List.split outputs + (* This is a backward function: there might be several outputs. *) + compute_back_outputs_for_gid gid in (* Create the return type *) let output = @@ -2016,37 +2005,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | None -> output | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in - (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) - (if (* TODO: normalize the types *) !Config.type_check_pure_code then - match fun_id with - | FunId fun_id -> - let inst_sg = - get_instantiated_fun_sig fun_id (Some rg_id) generics ctx - in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map (texpression_to_string ctx) inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " - (List.map (pure_ty_to_string ctx) inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs) - | _ -> (* TODO: trait methods *) ()); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = -- cgit v1.2.3 From 83c5be42e1750d329ad31bc9151d7b0446af5a0f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 12:14:01 +0100 Subject: Make progress on generalizing the signature information --- compiler/SymbolicToPure.ml | 313 ++++++++++++++++++++++----------------------- 1 file changed, 155 insertions(+), 158 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index eba44e3e..456ec0f6 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -128,9 +128,9 @@ type bs_ctx = { trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; bid : T.RegionGroupId.id option; (** TODO: rename *) - sg : fun_sig; - (** The function signature - useful in particular to translate [Panic] *) - fwd_sg : fun_sig; (** The signature of the forward function *) + sg : decomposed_fun_sig; + (** Information about the function signature - useful in particular to + translate [Panic] *) sv_to_var : var V.SymbolicValueId.Map.t; (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) @@ -828,7 +828,7 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) is_rec = false; } -(** Translate a function signature. +(** Translate a function signature to a decomposed function signature. Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding @@ -839,26 +839,15 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) We use [bid] ("backward function id") only if we split the forward and the backward functions. *) -let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) - (sg : A.fun_sig) (input_names : string option list) - (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = - assert (Option.is_none bid || not !Config.return_back_funs); +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) : + decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) let regions_hierarchy = FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies in - let gid, parents = - match bid with - | None -> (None, T.RegionGroupId.Set.empty) - | Some bid -> - let parents = list_ancestor_region_groups regions_hierarchy bid in - (Some bid, parents) - in - (* Is the function stateful, and can it fail? *) - let lid = None in - let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -886,17 +875,52 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) { sg with A.inputs; output } in - (* List the inputs for: - * - the fuel - * - the forward function - * - the parent backward functions, in proper order - * - the current backward function (if it is a backward function) - *) - let fuel = mk_fuel_input_ty_as_list effect_info in - let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in - (* For the backward functions: for now we don't supported nested borrows, - * so just check that there aren't parent regions *) - assert (T.RegionGroupId.Set.is_empty parents); + (* Is the forward function stateful, and can it fail? *) + let fwd_effect_info = + get_fun_effect_info fun_infos (FunId fun_id) None None + in + (* Compute the forward inputs *) + let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in + let fwd_inputs_no_fuel_no_state = + List.map (translate_fwd_ty type_infos) sg.inputs + in + (* State input for the forward function *) + let fwd_state_ty = + (* For the forward state, we check if the *whole group* is stateful. + See {!effect_info}. *) + if fwd_effect_info.stateful_group then [ mk_state_ty ] else [] + in + let fwd_inputs = + List.concat [ fwd_fuel; fwd_inputs_no_fuel_no_state; fwd_state_ty ] + in + (* Compute the backward output, without the effect information *) + let fwd_output = translate_fwd_ty type_infos sg.output in + (* The additinoal information *) + let fwd_info = + (* *) + let has_fuel = fwd_fuel <> [] in + let num_inputs_no_fuel_no_state = List.length fwd_inputs in + let num_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fwd_fuel + num_inputs_no_fuel_no_state + in + let fwd_info : inputs_info = + { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = + (* We use the fact that [fwd_state_ty] has length 1 if there is a state, + and 0 otherwise *) + num_inputs_with_fuel_no_state + List.length fwd_state_ty; + } + in + let info = { fwd_info; effect_info = fwd_effect_info } in + assert (fun_sig_info_is_wf info); + info + in + + (* Compute the type information for the backward function *) (* Small helper to translate types for backward functions *) let translate_back_ty_for_gid (gid : T.RegionGroupId.id) (ty : T.ty) : ty option = @@ -923,7 +947,11 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in - let translate_back_inputs_for_gid gid : ty list = + let translate_back_inputs_for_gid (gid : T.RegionGroupId.id) : ty list = + (* For now we don't supported nested borrows, so we check that there + aren't parent regions *) + let parents = list_ancestor_region_groups regions_hierarchy gid in + assert (T.RegionGroupId.Set.is_empty parents); (* For now, we don't allow nested borrows, so the additional inputs to the backward function can only come from borrows that were returned like in (for the backward function we introduce for 'a): @@ -935,45 +963,6 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - (* Compute the additinal inputs for the current function, if it is a backward - * function *) - let back_inputs = - match gid with None -> [] | Some gid -> translate_back_inputs_for_gid gid - in - (* If the function is stateful, the inputs are: - - forward: [fwd_ty0, ..., fwd_tyn, state] - - backward: - - if {!Config.backward_no_state_update}: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] - - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty] - - The backward takes the same state as input as the forward function, - together with the state at the point where it gets called, if it is - stateful. - - See the comments for {!Config.backward_no_state_update} - *) - let fwd_state_ty = - (* For the forward state, we check if the *whole group* is stateful. - See {!effect_info}. *) - if effect_info.stateful_group then [ mk_state_ty ] else [] - in - let mk_back_state_ty_for_gid (gid : RegionGroupId.id option) : ty list = - (* For the backward state, we check if the function is a backward function, - and it is stateful *) - if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] - in - let back_state_ty = mk_back_state_ty_for_gid gid in - - (* Concatenate the inputs, in the following order: - * - forward inputs - * - forward state input - * - backward inputs - * - backward state input - *) - let inputs = - List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] - in - (* Outputs *) let compute_back_outputs_for_gid (gid : RegionGroupId.id) : string option list * ty list = (* The outputs are the borrows inside the regions of the abstractions @@ -998,103 +987,111 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) in List.split outputs in - let output_names, doutputs = - match gid with - | None -> - (* This is a forward function: there is one (unnamed) output. - - If we merge the fwd/back functions we might need to compute - the information about the back outputs. - *) - (* TODO: *) - assert (not !Config.return_back_funs); - ([ None ], [ translate_fwd_ty type_infos sg.output ]) - | Some gid -> - (* This is a backward function: there might be several outputs. *) - compute_back_outputs_for_gid gid - in - (* Create the return type *) - let output = - (* Group the outputs together *) - let output = mk_simpl_tuple_ty doutputs in - (* Add the output state *) - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output + let compute_back_info_for_group (rg : T.region_var_group) : + RegionGroupId.id * back_sg_info = + let gid = rg.id in + let back_effect_info = + get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) in - (* Wrap in a result type *) - if effect_info.can_fail then mk_result_ty output else output + let inputs_no_state = translate_back_inputs_for_gid gid in + let inputs_no_state_names = + List.map (fun _ -> Some "ret") inputs_no_state + in + let state_ty, state_name = + if back_effect_info.stateful then ([ mk_state_ty ], [ None ]) else ([], []) + in + let inputs = inputs_no_state @ state_ty in + let input_names = inputs_no_state_names @ state_name in + let output_names, outputs = compute_back_outputs_for_gid gid in + let info = + { + inputs; + input_names; + outputs; + output_names; + effect_info = back_effect_info; + } + in + (gid, info) in + let back_sg = + RegionGroupId.Map.of_list + (List.map compute_back_info_for_group regions_hierarchy) + in + (* Generic parameters *) let generics = translate_generic_params sg.generics in + (* Return *) - (* TODO: *) - assert (not !Config.return_back_funs); - let has_fuel = fuel <> [] in - let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in - let num_fwd_inputs_with_fuel_no_state = - (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fuel + num_fwd_inputs_no_fuel_no_state - in - let num_back_inputs_no_state = - if bid = None then None else Some (List.length back_inputs) - in - let fwd_info : inputs_info = - { - has_fuel; - num_inputs_no_fuel_no_state = num_fwd_inputs_no_fuel_no_state; - num_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state; - num_inputs_with_fuel_with_state = - (* We use the fact that [fwd_state_ty] has length 1 if there is a state, - and 0 otherwise *) - num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; - } + let preds = translate_predicates sg.preds in + { + generics; + llbc_generics = sg.generics; + preds; + fwd_inputs; + fwd_output; + back_sg; + fwd_info; + } + +let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id option) : fun_sig = + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) in - let compute_back_info (back_state_ty : ty list) - (num_back_inputs_no_state : int) : inputs_info = - let n = num_back_inputs_no_state in - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - } + (* Two cases depending on whether we split the forward/backward functions + or not *) + let mk_output_ty (effect_info : fun_effect_info) output = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + if effect_info.can_fail then mk_result_ty output else output in - let back_info : back_inputs_info = - if !Config.return_back_funs then - (* Create the map *) - AllBacks - (RegionGroupId.Map.of_list - (List.map - (fun (rg : T.region_var_group) -> - ( rg.id, - let back_inputs = translate_back_inputs_for_gid rg.id in - let num_back_inputs = List.length back_inputs in - (* TODO: slightly overkill *) - let back_state_ty = mk_back_state_ty_for_gid (Some rg.id) in - compute_back_info back_state_ty num_back_inputs )) - regions_hierarchy)) + let inputs, output = + if !Config.return_back_funs then ( + assert (gid = None); + (* Compute the arrow types for all the backward functions *) + let back_tys = + List.map + (fun (back_sg : back_sg_info) -> + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + mk_arrows inputs output) + (RegionGroupId.Map.values dsg.back_sg) + in + (* Group the forward output and the types of the backward functions *) + let effect_info = dsg.fwd_info.effect_info in + let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in + let output = mk_output_ty effect_info output in + let inputs = dsg.fwd_inputs in + (inputs, output)) else - SingleBack - (Option.map (compute_back_info back_state_ty) num_back_inputs_no_state) - in - let info = { fwd_info; back_info; effect_info } in - assert (fun_sig_info_is_wf info); - let preds = translate_predicates sg.preds in - let sg = - { - generics; - llbc_generics = sg.generics; - preds; - inputs; - output; - doutputs; - info; - } - in - { sg; output_names } + match gid with + | None -> + let effect_info = dsg.fwd_info.effect_info in + let output = mk_output_ty effect_info dsg.fwd_output in + (dsg.fwd_inputs, output) + | Some gid -> + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + (inputs, output) + in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Generate the fresh variable *) -- cgit v1.2.3 From 62cb926e76ef0c9fb048b0e340bdae5b9dd76a84 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 14:06:16 +0100 Subject: Make progress on updating SymbolicToPure --- compiler/SymbolicToPure.ml | 169 ++++++++++++++++++++++++++++++--------------- 1 file changed, 112 insertions(+), 57 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 456ec0f6..d62cc829 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -127,7 +127,15 @@ type bs_ctx = { trait_decls_ctx : trait_decls_context; trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; - bid : T.RegionGroupId.id option; (** TODO: rename *) + bid : RegionGroupId.id option; + (** TODO: rename + + The id of the group region we are currently translating. + If we split the forward/backward functions, we set this id at the + very beginning of the translation. + If we don't split, we set it to `None`, then update it when we enter + an expression which is specific to a backward function. + *) sg : decomposed_fun_sig; (** Information about the function signature - useful in particular to translate [Panic] *) @@ -139,7 +147,7 @@ type bs_ctx = { var_counter : VarId.generator; state_var : VarId.id; (** The current state variable, in case the function is stateful *) - back_state_var : VarId.id; + back_state_vars : VarId.id RegionGroupId.Map.t; (** The additional input state variable received by a stateful backward function. When generating stateful functions, we generate code of the following form: @@ -163,16 +171,16 @@ type bs_ctx = { (** The input parameters for the forward function corresponding to the translated Rust inputs (no fuel, no state). *) - backward_inputs : var list T.RegionGroupId.Map.t; + backward_inputs : var list RegionGroupId.Map.t; (** The additional input parameters for the backward functions coming from the borrows consumed upon ending the lifetime (as a consequence those don't include the backward state, if there is one). *) - backward_outputs : var list T.RegionGroupId.Map.t; + backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state) *) - loop_backward_outputs : var list T.RegionGroupId.Map.t option; + loop_backward_outputs : var list RegionGroupId.Map.t option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). [None] if we are not inside a loop, [Some] otherwise (and whatever @@ -300,6 +308,13 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string = let env = bs_ctx_to_pure_fmt_env ctx in PrintPure.typed_pattern_to_string env p +let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = + match ctx.bid with + | None -> ctx.sg.fwd_info.effect_info + | Some bid -> + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + back_sg.effect_info + (* TODO: move *) let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let env = bs_ctx_to_fmt_env ctx in @@ -1034,6 +1049,24 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) fwd_info; } +let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty + = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty + in + if effect_info.can_fail then mk_result_ty output else output + +(** Compute the arrow types for all the backward functions *) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = + List.map + (fun (back_sg : back_sg_info) -> + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty_from_effect_info effect_info output in + mk_arrows inputs output) + (RegionGroupId.Map.values dsg.back_sg) + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1050,27 +1083,13 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) in (* Two cases depending on whether we split the forward/backward functions or not *) - let mk_output_ty (effect_info : fun_effect_info) output = - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - if effect_info.can_fail then mk_result_ty output else output - in + let mk_output_ty = mk_output_ty_from_effect_info in + let inputs, output = if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = - List.map - (fun (back_sg : back_sg_info) -> - let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - mk_arrows inputs output) - (RegionGroupId.Map.values dsg.back_sg) - in + let back_tys = compute_back_tys dsg in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1584,30 +1603,43 @@ and translate_panic (ctx : bs_ctx) : texpression = * but it won't be true anymore once we translate individual blocks *) (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) - let output_ty = - if ctx.inside_loop && Option.is_some ctx.bid then - (* We are synthesizing the backward function of a loop body *) - let bid = Option.get ctx.bid in - let back_vars = - T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) - in - let tys = List.map (fun (v : var) -> v.ty) back_vars in - mk_simpl_tuple_ty tys - else - (* Regular function, or forward function (the forward translation for - a loop has the same return type as the parent function) - *) - mk_simpl_tuple_ty ctx.sg.doutputs - in + let effect_info = ctx_get_effect_info ctx in (* TODO: we should use a [Fail] function *) - if ctx.sg.info.effect_info.stateful then - (* Create the [Fail] value *) - let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in - let ret_v = - mk_result_fail_texpression_with_error_id error_failure_id ret_ty + let mk_output output_ty = + if effect_info.stateful then + (* Create the [Fail] value *) + let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in + let ret_v = + mk_result_fail_texpression_with_error_id error_failure_id ret_ty + in + ret_v + else mk_result_fail_texpression_with_error_id error_failure_id output_ty + in + if ctx.inside_loop && Option.is_some ctx.bid then + (* We are synthesizing the backward function of a loop body *) + let bid = Option.get ctx.bid in + let back_vars = + T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) in - ret_v - else mk_result_fail_texpression_with_error_id error_failure_id output_ty + let tys = List.map (fun (v : var) -> v.ty) back_vars in + let output = mk_simpl_tuple_ty tys in + mk_output output + else + (* Regular function, or forward function (the forward translation for + a loop has the same return type as the parent function) + *) + match ctx.bid with + | None -> + if !Config.return_back_funs then + let back_tys = compute_back_tys ctx.sg in + let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in + mk_output output + else mk_output ctx.sg.fwd_output + | Some bid -> + let output = + mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs + in + mk_output output (** [opt_v]: the value to return, in case we translate a forward body *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) @@ -1641,7 +1673,7 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) * - error-monad: Return x * - state-error: Return (state, x) * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -1695,7 +1727,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) * effect - in particular, one manipulates a state iff the other does * the same. * *) - let effect_info = ctx.sg.info.effect_info in + let effect_info = ctx_get_effect_info ctx in let output = if effect_info.stateful then let state_rvalue = mk_state_texpression ctx.state_var in @@ -2550,24 +2582,50 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) and translate_forward_end (ectx : C.eval_ctx) (loop_input_values : V.typed_value S.symbolic_value_id_map option) - (e : S.expression) (back_e : S.expression S.region_group_id_map) + (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map) (ctx : bs_ctx) : texpression = - (* Update the current state with the additional state received by the backward - function, if needs be, and lookup the proper expression *) - let translate_end ctx = + (* TODO: *) + assert (not !Config.return_back_funs); + + let translate_one_end ctx (bid : RegionGroupId.id option) = (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) let ctx, e = match ctx.bid with - | None -> (ctx, e) + | None -> + (* We are translating the forward function - nothing to do *) + (ctx, fwd_e) | Some bid -> - let ctx = { ctx with state_var = ctx.back_state_var } in + (* There are two cases here: + - if we split the fwd/backward functions, we simply need to update + the state + - if we don't split, we also need to wrap the expression in a + lambda, which introduces the additional inputs of the backward + function + *) + let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in + let ctx = { ctx with state_var = back_state_var } in let e = T.RegionGroupId.Map.find bid back_e in (ctx, e) in translate_expression e ctx in + (* There are two cases, depending on whether we are splitting the forward/backward + functions or not. + + - if we split, then we simply need to translate the proper "end" expression, + that is the end of the forward function, or of the backward function we + are currently translating. + - if we don't split, then we need to translate the end of the forward + function (this is the value we will return) and generate the bodies + of the backward functions (which we will also return). + + Update the current state with the additional state received by the backward + function, if needs be, and lookup the proper expression. + *) + let translate_end ctx = failwith "TODO" in + (* If we are (re-)entering a loop, we need to introduce a call to the forward translation of the loop. *) match loop_input_values with @@ -2617,10 +2675,7 @@ and translate_forward_end (ectx : C.eval_ctx) in (* Introduce a fresh output value for the forward function *) - let ctx, output_var = - let output_ty = mk_simpl_tuple_ty ctx.fwd_sg.doutputs in - fresh_var None output_ty ctx - in + let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in let args, ctx, out_pats = let output_pat = mk_typed_pattern_from_var output_var None in @@ -2832,7 +2887,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Add the input state *) let input_state = - if ctx.sg.info.effect_info.stateful then Some ctx.state_var else None + if (ctx_get_effect_info ctx).stateful then Some ctx.state_var else None in (* Translate the loop body *) -- cgit v1.2.3 From ea583d9f0f5e4a1a687b70f0e04e875969462157 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 17:20:30 +0100 Subject: Make good progress on updating SymbolicToPure --- compiler/SymbolicToPure.ml | 224 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 185 insertions(+), 39 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d62cc829..8e06db7c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -121,9 +121,9 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { - type_context : type_context; - fun_context : fun_context; - global_context : global_context; + type_context : type_context; (* TODO: rename *) + fun_context : fun_context; (* TODO: rename *) + global_context : global_context; (* TODO: rename *) trait_decls_ctx : trait_decls_context; trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; @@ -148,7 +148,9 @@ type bs_ctx = { state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; - (** The additional input state variable received by a stateful backward function. + (** The additional input state variable received by a stateful backward function, + **in case we are splitting the forward/backward functions**. + When generating stateful functions, we generate code of the following form: @@ -161,7 +163,9 @@ type bs_ctx = { When translating a backward function, we need at some point to update [state_var] with [back_state_var], to account for the fact that the state may have been updated by the caller between the call to the - forward function and the call to the backward function. + forward function and the call to the backward function. We also need + to make sure we use the same variable in all the branches (because + this variable is quantified at the definition level). *) fuel0 : VarId.id; (** The original fuel taken as input by the function (if we use fuel) *) @@ -171,10 +175,20 @@ type bs_ctx = { (** The input parameters for the forward function corresponding to the translated Rust inputs (no fuel, no state). *) - backward_inputs : var list RegionGroupId.Map.t; + backward_inputs_no_state : var list RegionGroupId.Map.t; (** The additional input parameters for the backward functions coming from the borrows consumed upon ending the lifetime (as a consequence those don't include the backward state, if there is one). + + If we split the forward/backward functions: we initialize this map + when initializing the bs_ctx, because those variables are quantified + at the definition level. Otherwise, we initialize it upon diving + into the expressions which are specific to the backward functions. + *) + backward_inputs_with_state : var list RegionGroupId.Map.t; + (** All the additional input parameters for the backward functions. + + Same remarks as for {!backward_inputs_no_state}. *) backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding @@ -308,13 +322,17 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string = let env = bs_ctx_to_pure_fmt_env ctx in PrintPure.typed_pattern_to_string env p -let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = - match ctx.bid with +let ctx_get_effect_info_for_bid (ctx : bs_ctx) (bid : RegionGroupId.id option) : + fun_effect_info = + match bid with | None -> ctx.sg.fwd_info.effect_info | Some bid -> let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in back_sg.effect_info +let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info = + ctx_get_effect_info_for_bid ctx ctx.bid + (* TODO: move *) let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let env = bs_ctx_to_fmt_env ctx in @@ -1009,19 +1027,18 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in - let inputs_no_state_names = - List.map (fun _ -> Some "ret") inputs_no_state + let inputs_no_state = + List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - let state_ty, state_name = - if back_effect_info.stateful then ([ mk_state_ty ], [ None ]) else ([], []) + let state = + if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] in - let inputs = inputs_no_state @ state_ty in - let input_names = inputs_no_state_names @ state_name in + let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let info = { inputs; - input_names; + inputs_no_state; outputs; output_names; effect_info = back_effect_info; @@ -1061,7 +1078,7 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in + let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in mk_arrows inputs output) @@ -1105,14 +1122,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) | Some gid -> let back_sg = RegionGroupId.Map.find gid dsg.back_sg in let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ back_sg.inputs in + let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty effect_info output in (inputs, output) in { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } -let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = +let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) let id, var_counter = VarId.fresh ctx.var_counter in let state_var = @@ -1122,7 +1139,7 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Update the context *) let ctx = { ctx with var_counter; state_var = id } in (* Return *) - (ctx, state_pat) + (ctx, state_var, state_pat) (** WARNING: do not call this function directly. Call [fresh_named_var_for_symbolic_value] instead. *) @@ -1776,7 +1793,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_var = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in @@ -2010,7 +2027,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in @@ -2115,15 +2132,15 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs) let-binding: {[ let id_back x nx = - let s = nx in // the name [s] is not important (only collision matters) - ... + let s = nx in // the name [s] is not important (only collision matters) + ... ]} This let-binding later gets inlined, during a micro-pass. *) (* First, retrieve the list of variables used for the inputs for the * backward function *) - let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in + let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in (* Retrieve the values consumed upon ending the loans inside this * abstraction: as there are no nested borrows, there should be none. *) let consumed = abs_to_consumed ctx ectx abs in @@ -2185,7 +2202,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) values consumed upon ending the abstraction (i.e., we don't use [abs_to_consumed]) *) let back_inputs_vars = - T.RegionGroupId.Map.find rg_id ctx.backward_inputs + T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in let back_inputs = List.map mk_texpression_from_var back_inputs_vars in (* If the function is stateful: @@ -2195,7 +2212,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in @@ -2590,25 +2607,69 @@ and translate_forward_end (ectx : C.eval_ctx) let translate_one_end ctx (bid : RegionGroupId.id option) = (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) - let ctx, e = + let ctx, e, finish = match ctx.bid with | None -> (* We are translating the forward function - nothing to do *) - (ctx, fwd_e) + (ctx, fwd_e, fun e -> e) | Some bid -> (* There are two cases here: - if we split the fwd/backward functions, we simply need to update - the state + the state. - if we don't split, we also need to wrap the expression in a lambda, which introduces the additional inputs of the backward function *) - let back_state_var = RegionGroupId.Map.find bid ctx.back_state_vars in - let ctx = { ctx with state_var = back_state_var } in + let ctx = + (* Introduce variables for the inputs and the state variable + and update the context. *) + if !Config.return_back_funs then + (* If the forward/backward functions are not split, we need + to introduce fresh variables for the additional inputs, + because they are locally introduced in a lambda *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let ctx = { ctx with bid = Some bid } in + let ctx, backward_inputs_no_state = + fresh_vars back_sg.inputs_no_state ctx + in + let ctx, backward_inputs_with_state = + if (ctx_get_effect_info ctx).stateful then + let ctx, var, _ = bs_ctx_fresh_state_var ctx in + (ctx, backward_inputs_no_state @ [ var ]) + else (ctx, backward_inputs_no_state) + in + { + ctx with + backward_inputs_no_state = + RegionGroupId.Map.add bid backward_inputs_no_state + ctx.backward_inputs_no_state; + backward_inputs_with_state = + RegionGroupId.Map.add bid backward_inputs_with_state + ctx.backward_inputs_with_state; + } + else + (* Update the state variable *) + let back_state_var = + RegionGroupId.Map.find bid ctx.back_state_vars + in + { ctx with state_var = back_state_var } + in + let e = T.RegionGroupId.Map.find bid back_e in - (ctx, e) + let finish e = + (* Wrap in lambdas if necessary *) + if !Config.return_back_funs then + let inputs = + RegionGroupId.Map.find bid ctx.backward_inputs_with_state + in + let places = List.map (fun _ -> None) inputs in + mk_lambdas_from_vars inputs places e + else e + in + (ctx, e, finish) in - translate_expression e ctx + let e = translate_expression e ctx in + finish e in (* There are two cases, depending on whether we are splitting the forward/backward @@ -2624,7 +2685,87 @@ and translate_forward_end (ectx : C.eval_ctx) Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression. *) - let translate_end ctx = failwith "TODO" in + let translate_end ctx = + if !Config.return_back_funs then + (* Compute the output of the forward function *) + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output_ty = + let ty = ctx.sg.fwd_output in + if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] + else ty + in + let ctx, fwd_var = fresh_var None output_ty ctx in + let ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in + let fwd_e = translate_one_end ctx None in + + (* Introduce the backward functions *) + let back_el = + List.map + (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> + translate_one_end ctx (Some gid)) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + (* Introduce variables for the backward functions. + We lookup the LLBC definition in an attempt to derive pretty names + for those functions. *) + let back_var_names = + let def_id = ctx.fun_decl.def_id in + let sg = ctx.fun_decl.signature in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.fun_context.regions_hierarchies + in + List.map + (fun (gid, _) -> + let rg = RegionGroupId.nth regions_hierarchy gid in + let region_names = + List.map + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) + rg.regions + in + let name = + match region_names with + | [] -> "back" + | [ Some r ] -> "back" ^ r + | _ -> + (* Concatenate all the region names *) + "back" + ^ String.concat "" (List.filter_map (fun x -> x) region_names) + in + Some name) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in + let _, back_vars = fresh_vars back_vars ctx in + + (* Create the return expressions *) + let vars = fwd_var :: back_vars in + let vars = List.map mk_texpression_from_var vars in + let ret = mk_simpl_tuple_texpression vars in + let state_var = List.map mk_texpression_from_var state_var in + let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in + let ret = mk_result_return_texpression ret in + + (* Bind the expressions for the backward function and the expression + for the computation of the forward output *) + let e = + List.fold_right + (fun (var, back_e) e -> + mk_let false (mk_typed_pattern_from_var var None) back_e e) + (List.combine back_vars back_el) + ret + in + (* Bind the expression for the forward output *) + let fwd_var = mk_typed_pattern_from_var fwd_var None in + let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in + mk_let fwd_effect_info.can_fail pat fwd_e e + else translate_one_end ctx ctx.bid + in (* If we are (re-)entering a loop, we need to introduce a call to the forward translation of the loop. *) @@ -2687,7 +2828,7 @@ and translate_forward_end (ectx : C.eval_ctx) let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_pat = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_pat = bs_ctx_fresh_state_var ctx in ( List.concat [ fuel; args; [ state_var ] ], ctx, [ nstate_pat; output_pat ] ) @@ -3025,8 +3166,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = let def_id = def.def_id in let llbc_name = def.name in let name = name_to_string ctx llbc_name in - (* Retrieve the signature *) - let signature = ctx.sg in + (* Translate the signature *) + let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in let regions_hierarchy = FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies in @@ -3070,20 +3211,25 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = match bid with | None -> [] | Some back_id -> + assert (not !Config.return_back_funs); let parents_ids = list_ordered_ancestor_region_groups regions_hierarchy back_id in let backward_ids = List.append parents_ids [ back_id ] in List.concat (List.map - (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) + (fun id -> + T.RegionGroupId.Map.find id ctx.backward_inputs_no_state) backward_ids) in (* Introduce the backward input state (the state at call site of the * *backward* function), if necessary *) let back_state = if effect_info.stateful && Option.is_some bid then - [ mk_state_var ctx.back_state_var ] + let state_var = + RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars + in + [ mk_state_var state_var ] else [] in (* Group the inputs together *) -- cgit v1.2.3 From 5fa83883b4d573cfd252478f7937c8bde0ec01f6 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 17:22:01 +0100 Subject: Minor fix --- compiler/SymbolicToPure.ml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 8e06db7c..08f9e950 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2605,10 +2605,11 @@ and translate_forward_end (ectx : C.eval_ctx) assert (not !Config.return_back_funs); let translate_one_end ctx (bid : RegionGroupId.id option) = + let ctx = { ctx with bid } in (* Update the current state with the additional state received by the backward function, if needs be, and lookup the proper expression *) let ctx, e, finish = - match ctx.bid with + match bid with | None -> (* We are translating the forward function - nothing to do *) (ctx, fwd_e, fun e -> e) @@ -2628,7 +2629,6 @@ and translate_forward_end (ectx : C.eval_ctx) to introduce fresh variables for the additional inputs, because they are locally introduced in a lambda *) let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in - let ctx = { ctx with bid = Some bid } in let ctx, backward_inputs_no_state = fresh_vars back_sg.inputs_no_state ctx in -- cgit v1.2.3 From 884edaa3ee975626f184249d491f343fc02a66e2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 18:54:06 +0100 Subject: Make progress on updating the code --- compiler/SymbolicToPure.ml | 79 +++++++--------------------------------------- 1 file changed, 12 insertions(+), 67 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 08f9e950..204fc399 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -45,7 +45,6 @@ type fun_sig_named_outputs = { type fun_context = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; - fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *) fun_infos : fun_info A.FunDeclId.Map.t; regions_hierarchies : T.region_var_groups FunIdMap.t; } @@ -144,7 +143,11 @@ type bs_ctx = { a symbolic expansion or upon ending an abstraction, for instance) we introduce a new variable (with a let-binding). *) - var_counter : VarId.generator; + var_counter : VarId.generator ref; + (** Using a ref to make sure all the variables identifiers are unique. + TODO: this is not very clean, and the code was initially written without + a reference (and it's shape hasn't changed). We should use DeBruijn indices. + *) state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; @@ -1131,13 +1134,14 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let state_var = { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } in let state_pat = mk_typed_pattern_from_var state_var None in (* Update the context *) - let ctx = { ctx with var_counter; state_var = id } in + ctx.var_counter := var_counter; + let ctx = { ctx with state_var = id } in (* Return *) (ctx, state_var, state_pat) @@ -1146,11 +1150,11 @@ let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = let fresh_var_llbc_ty (basename : string option) (ty : T.ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let ty = ctx_translate_fwd_ty ctx ty in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -1184,10 +1188,10 @@ let fresh_named_vars_for_symbolic_values let fresh_var (basename : string option) (ty : ty) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) - let id, var_counter = VarId.fresh ctx.var_counter in + let id, var_counter = VarId.fresh !(ctx.var_counter) in let var = { id; basename; ty } in (* Update the context *) - let ctx = { ctx with var_counter } in + ctx.var_counter := var_counter; (* Return *) (ctx, var) @@ -3303,65 +3307,6 @@ let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list = List.map (translate_type_decl ctx) (TypeDeclId.Map.values ctx.type_ctx.type_decls) -(** Translates function signatures. - - Takes as input a list of function information containing: - - the function id - - a list of optional names for the inputs - - the function signature - - Returns a map from forward/backward functions identifiers to: - - translated function signatures - - optional names for the outputs values (we derive them for the backward - functions) - *) -let translate_fun_signatures (decls_ctx : C.decls_ctx) - (functions : (A.fun_id * string option list * A.fun_sig) list) : - fun_sig_named_outputs RegularFunIdNotLoopMap.t = - (* For every function, translate the signatures of: - - the forward function - - the backward functions - *) - let translate_one (fun_id : A.fun_id) (input_names : string option list) - (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list - = - log#ldebug - (lazy - ("Translating signature of function: " - ^ Print.Expressions.fun_id_to_string - (Print.Contexts.decls_ctx_to_fmt_env decls_ctx) - fun_id)); - (* Retrieve the regions hierarchy *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in - (* The forward function *) - let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in - let fwd_id = (fun_id, None) in - (* The backward functions *) - let back_sgs = - if !Config.return_back_funs then [] - else - List.map - (fun (rg : T.region_var_group) -> - let tsg = - translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id) - in - let id = (fun_id, Some rg.id) in - (id, tsg)) - regions_hierarchy - in - (* Return *) - (fwd_id, fwd_sg) :: back_sgs - in - let translated = - List.concat - (List.map (fun (id, names, sg) -> translate_one id names sg) functions) - in - List.fold_left - (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m) - RegularFunIdNotLoopMap.empty translated - let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl) : trait_decl = let { -- cgit v1.2.3 From a49754a5b11e4de8793dc7e13c2962d139eb03b1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 10:21:08 +0100 Subject: Rename some definitions --- compiler/SymbolicToPure.ml | 78 +++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 39 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 204fc399..d8213317 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -15,7 +15,7 @@ module PP = PrintPure (** The local logger *) let log = Logging.symbolic_to_pure_log -type type_context = { +type type_ctx = { llbc_type_decls : T.type_decl TypeDeclId.Map.t; type_decls : type_decl TypeDeclId.Map.t; (** We use this for type-checking (for sanity checks) when translating @@ -43,18 +43,18 @@ type fun_sig_named_outputs = { } [@@deriving show] -type fun_context = { +type fun_ctx = { llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t; fun_infos : fun_info A.FunDeclId.Map.t; regions_hierarchies : T.region_var_groups FunIdMap.t; } [@@deriving show] -type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } +type global_ctx = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t } [@@deriving show] -type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] -type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show] +type trait_decls_ctx = A.trait_decl A.TraitDeclId.Map.t [@@deriving show] +type trait_impls_ctx = A.trait_impl A.TraitImplId.Map.t [@@deriving show] (** Whenever we translate a function call or an ended abstraction, we store the related information (this is useful when translating ended @@ -120,11 +120,11 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { - type_context : type_context; (* TODO: rename *) - fun_context : fun_context; (* TODO: rename *) - global_context : global_context; (* TODO: rename *) - trait_decls_ctx : trait_decls_context; - trait_impls_ctx : trait_impls_context; + type_ctx : type_ctx; + fun_ctx : fun_ctx; + global_ctx : global_ctx; + trait_decls_ctx : trait_decls_ctx; + trait_impls_ctx : trait_impls_ctx; fun_decl : A.fun_decl; bid : RegionGroupId.id option; (** TODO: rename @@ -234,9 +234,9 @@ type bs_ctx = { (* TODO: move *) let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env = - let type_decls = ctx.type_context.llbc_type_decls in - let fun_decls = ctx.fun_context.llbc_fun_decls in - let global_decls = ctx.global_context.llbc_global_decls in + let type_decls = ctx.type_ctx.llbc_type_decls in + let fun_decls = ctx.fun_ctx.llbc_fun_decls in + let global_decls = ctx.global_ctx.llbc_global_decls in let trait_decls = ctx.trait_decls_ctx in let trait_impls = ctx.trait_impls_ctx in let { regions; types; const_generics; trait_clauses } : T.generic_params = @@ -258,9 +258,9 @@ let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env = } let bs_ctx_to_pure_fmt_env (ctx : bs_ctx) : PrintPure.fmt_env = - let type_decls = ctx.type_context.llbc_type_decls in - let fun_decls = ctx.fun_context.llbc_fun_decls in - let global_decls = ctx.global_context.llbc_global_decls in + let type_decls = ctx.type_ctx.llbc_type_decls in + let fun_decls = ctx.fun_ctx.llbc_fun_decls in + let global_decls = ctx.global_ctx.llbc_global_decls in let trait_decls = ctx.trait_decls_ctx in let trait_impls = ctx.trait_impls_ctx in let generics = ctx.sg.generics in @@ -346,11 +346,11 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string = let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) : T.type_decl = - TypeDeclId.Map.find id ctx.type_context.llbc_type_decls + TypeDeclId.Map.find id ctx.type_ctx.llbc_type_decls let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) : A.fun_decl = - A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls + A.FunDeclId.Map.find id ctx.fun_ctx.llbc_fun_decls (* Some generic translation functions (we need to translate different "flavours" of types: forward types, backward types, etc.) *) @@ -617,13 +617,13 @@ and translate_fwd_trait_instance_id (type_infos : type_infos) (** Simply calls [translate_fwd_ty] *) let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : T.ty) : ty = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_fwd_ty type_infos ty (** Simply calls [translate_fwd_generic_args] *) let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : T.generic_args) : generic_args = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_fwd_generic_args type_infos generics (** Translate a type, when some regions may have ended. @@ -708,7 +708,7 @@ let rec translate_back_ty (type_infos : type_infos) (** Simply calls [translate_back_ty] *) let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool) (inside_mut : bool) (ty : T.ty) : ty option = - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in translate_back_ty type_infos keep_region inside_mut ty let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = @@ -721,8 +721,8 @@ let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx = in let env = VarId.Map.empty in { - PureTypeCheck.type_decls = ctx.type_context.type_decls; - global_decls = ctx.global_context.llbc_global_decls; + PureTypeCheck.type_decls = ctx.type_ctx.type_decls; + global_decls = ctx.global_ctx.llbc_global_decls; env; const_generics; } @@ -742,7 +742,7 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) match id with | FunId fun_id -> FunId fun_id | TraitMethod (trait_ref, method_name, fun_decl_id) -> - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in let trait_ref = translate_fwd_trait_ref type_infos trait_ref in TraitMethod (trait_ref, method_name, fun_decl_id) @@ -894,7 +894,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) List.map (fun (g : T.region_var_group) -> g.id) regions_hierarchy in let ctx = - InterpreterUtils.initialize_eval_context decls_ctx region_groups + InterpreterUtils.initialize_eval_ctx decls_ctx region_groups sg.generics.types sg.generics.const_generics in (* Compute the normalization map for the *sty* types and add it to the context *) @@ -1786,7 +1786,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None None + get_fun_effect_info ctx.fun_ctx.fun_infos fid None None in (* Depending on the function effects: * - add the fuel @@ -2006,7 +2006,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) raise (Failure "Unreachable") in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) + get_fun_effect_info ctx.fun_ctx.fun_infos fun_id None (Some rg_id) in let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) @@ -2194,8 +2194,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) | V.LoopCall -> let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id) - (Some vloop_id) (Some rg_id) + get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) + (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in let generics = loop_info.generics in @@ -2306,7 +2306,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = let ctx, var = fresh_var_for_symbolic_value sval ctx in - let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in + let decl = A.GlobalDeclId.Map.find gid ctx.global_ctx.llbc_global_decls in let global_expr = { id = Global gid; generics = empty_generic_args } in (* We use translate_fwd_ty to translate the global type *) let ty = ctx_translate_fwd_ty ctx decl.ty in @@ -2482,7 +2482,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) - if we forbid using field projectors. *) let is_rec_def = - T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls + T.TypeDeclId.Set.mem adt_id ctx.type_ctx.recursive_decls in let use_let_with_cons = is_enum @@ -2495,7 +2495,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value) like Coq don't, in which case we have to deconstruct the whole ADT at once (`let (a, b, c) = x in`) *) || TypesUtils.type_decl_from_type_id_is_tuple_struct - ctx.type_context.type_infos type_id + ctx.type_ctx.type_infos type_id && not (Config.backend_has_tuple_projectors ()) in if use_let_with_cons then @@ -2588,7 +2588,7 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option) { e = StructUpdate su; ty = var.ty } | VaCgValue cg_id -> { e = CVar cg_id; ty = var.ty } | VaTraitConstValue (trait_ref, generics, const_name) -> - let type_infos = ctx.type_context.type_infos in + let type_infos = ctx.type_ctx.type_infos in let trait_ref = translate_fwd_trait_ref type_infos trait_ref in let generics = translate_fwd_generic_args type_infos generics in let qualif_id = TraitConst (trait_ref, generics, const_name) in @@ -2722,7 +2722,7 @@ and translate_forward_end (ectx : C.eval_ctx) let sg = ctx.fun_decl.signature in let regions_hierarchy = LlbcAstUtils.FunIdMap.find (FRegular def_id) - ctx.fun_context.regions_hierarchies + ctx.fun_ctx.regions_hierarchies in List.map (fun (gid, _) -> @@ -2816,7 +2816,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Lookup the effect info for the loop function *) let fid = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid + get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) @@ -2949,7 +2949,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = List.map (fun ty -> assert ( - not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty)); + not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); (None, ctx_translate_fwd_ty !ctx ty)) tys in @@ -3173,7 +3173,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate the signature *) let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in let regions_hierarchy = - FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies + FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies in (* Translate the body, if there is *) let body = @@ -3181,8 +3181,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos - (FunId (FRegular def_id)) None bid + get_fun_effect_info ctx.fun_ctx.fun_infos (FunId (FRegular def_id)) + None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) -- cgit v1.2.3 From 17973e99e4784ff5e31565622d183ad89e3d9cd7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 11:40:44 +0100 Subject: Add some comments --- compiler/SymbolicToPure.ml | 47 ++++++++++++++++++++++++++++++---------------- 1 file changed, 31 insertions(+), 16 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d8213317..a79340b6 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -151,8 +151,8 @@ type bs_ctx = { state_var : VarId.id; (** The current state variable, in case the function is stateful *) back_state_vars : VarId.id RegionGroupId.Map.t; - (** The additional input state variable received by a stateful backward function, - **in case we are splitting the forward/backward functions**. + (** The additional input state variable received by a stateful backward + function, **in case we are splitting the forward/backward functions**. When generating stateful functions, we generate code of the following form: @@ -195,7 +195,22 @@ type bs_ctx = { *) backward_outputs : var list RegionGroupId.Map.t; (** The variables that the backward functions will output, corresponding - to the borrows they give back (don't include the backward state) + to the borrows they give back (don't include the backward state). + + The translation is done as follows: + - for a given backward function, we choose a set of variables [v_i] + - when we detect the ended input abstraction which corresponds + to the backward function of the LLBC function we are translating, + and which consumed the values [consumed_i] (that we need to give + back to the caller), we introduce: + {[ + let v_i = consumed_i in + ... + ]} + Then, upon reaching the [Return] node, we introduce: + {[ + (v_i) + ]} *) loop_backward_outputs : var list RegionGroupId.Map.t option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). @@ -1930,19 +1945,19 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) assert (rg_id = bid); (* The translation is done as follows: - * - for a given backward function, we choose a set of variables [v_i] - * - when we detect the ended input abstraction which corresponds - * to the backward function, and which consumed the values [consumed_i], - * we introduce: - * {[ - * let v_i = consumed_i in - * ... - * ]} - * Then, when we reach the [Return] node, we introduce: - * {[ - * (v_i) - * ]} - * *) + - for a given backward function, we choose a set of variables [v_i] + - when we detect the ended input abstraction which corresponds + to the backward function, and which consumed the values [consumed_i], + we introduce: + {[ + let v_i = consumed_i in + ... + ]} + Then, when we reach the [Return] node, we introduce: + {[ + (v_i) + ]} + *) (* First, get the given back variables. We don't use the same given back variables if we translate a loop or -- cgit v1.2.3 From 999f48d032107722aa6ca714da828ab2788ca412 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 12:06:07 +0100 Subject: Fix a minor mistake in SymbolicToPure --- compiler/SymbolicToPure.ml | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index a79340b6..7359f68a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -950,7 +950,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) let fwd_info = (* *) let has_fuel = fwd_fuel <> [] in - let num_inputs_no_fuel_no_state = List.length fwd_inputs in + let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in let num_inputs_with_fuel_no_state = (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) List.length fwd_fuel + num_inputs_no_fuel_no_state @@ -2620,9 +2620,6 @@ and translate_forward_end (ectx : C.eval_ctx) (loop_input_values : V.typed_value S.symbolic_value_id_map option) (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map) (ctx : bs_ctx) : texpression = - (* TODO: *) - assert (not !Config.return_back_funs); - let translate_one_end ctx (bid : RegionGroupId.id option) = let ctx = { ctx with bid } in (* Update the current state with the additional state received by the backward -- cgit v1.2.3 From 116b569d1b08a57c3ad66071979a1c966fdad3a2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Dec 2023 12:18:06 +0100 Subject: Remove the backwards field from SymbolicToPure.call_info --- compiler/SymbolicToPure.ml | 20 +++++--------------- 1 file changed, 5 insertions(+), 15 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 7359f68a..ea2082c7 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,12 +67,6 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) - backwards : (V.abs * texpression list) T.RegionGroupId.Map.t; - (** A map from region group id (i.e., backward function id) to - pairs (abstraction, additional arguments received by the backward function) - - TODO: remove? it is also in the bs_ctx ("abstractions" field) - *) } [@@deriving show] @@ -224,7 +218,10 @@ type bs_ctx = { calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; - (** The ended abstractions we encountered so far, with their additional input arguments *) + (** The ended abstractions we encountered so far, with their additional + input arguments. We store it here and not in {!call_info} because + we need a map from abstraction id to abstraction (and not + from call id + region group id to abstraction). *) loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *) loops : loop_info LoopId.Map.t; (** The loops we encountered so far. @@ -765,9 +762,7 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) (args : texpression list) (ctx : bs_ctx) : bs_ctx = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = - { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } - in + let info = { forward; forward_inputs = args } in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } @@ -777,11 +772,6 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) : bs_ctx * fun_or_op_id = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in - assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); - let backwards = - T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards - in - let info = { info with backwards } in let calls = V.FunCallId.Map.add call_id info ctx.calls in (* Insert the abstraction in the abstractions map *) let abstractions = ctx.abstractions in -- cgit v1.2.3 From 4f7bc41dcbc6187512111a81f968726452024d25 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 19 Dec 2023 12:54:40 +0100 Subject: Simplify SymbolicToPure.bs_ctx.{backward_outputs, loop_backward_outputs} --- compiler/SymbolicToPure.ml | 153 ++++++++++++++++++++------------------------- 1 file changed, 69 insertions(+), 84 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ea2082c7..93e6cb4e 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -109,6 +109,10 @@ type loop_info = { (** The forward inputs are initialized at [None] *) forward_output_no_state_no_result : var option; (** The forward outputs are initialized at [None] *) + back_outputs : ty list RegionGroupId.Map.t; + (** The map from region group ids to the types of the values given back + by the corresponding loop abstractions. + *) } [@@deriving show] @@ -187,12 +191,11 @@ type bs_ctx = { Same remarks as for {!backward_inputs_no_state}. *) - backward_outputs : var list RegionGroupId.Map.t; + backward_outputs : var list option; (** The variables that the backward functions will output, corresponding to the borrows they give back (don't include the backward state). The translation is done as follows: - - for a given backward function, we choose a set of variables [v_i] - when we detect the ended input abstraction which corresponds to the backward function of the LLBC function we are translating, and which consumed the values [consumed_i] (that we need to give @@ -201,14 +204,20 @@ type bs_ctx = { let v_i = consumed_i in ... ]} - Then, upon reaching the [Return] node, we introduce: + where the [v_i] are fresh, and are stored in the [backward_output]. + - Then, upon reaching the [Return] node, we introduce: {[ - (v_i) + return (v_i) ]} + + The option is [None] before we detect the ended input abstraction, + and [Some] afterwards. *) - loop_backward_outputs : var list RegionGroupId.Map.t option; + loop_backward_outputs : var list option; (** Same as {!backward_outputs}, but for loops (if we entered a loop). + TODO: merge with [backward_outputs]? + [None] if we are not inside a loop, [Some] otherwise (and whatever the kind of function we are translating: it will be [Some] even though we are synthesizing a forward function). @@ -1607,7 +1616,9 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list) let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with - | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx + | S.Return (ectx, opt_v) -> + (* Remark: we can't get there if we are inside a loop *) + translate_return ectx opt_v ctx | ReturnWithLoop (loop_id, is_continue) -> translate_return_with_loop loop_id is_continue ctx | Panic -> translate_panic ctx @@ -1644,10 +1655,9 @@ and translate_panic (ctx : bs_ctx) : texpression = if ctx.inside_loop && Option.is_some ctx.bid then (* We are synthesizing the backward function of a loop body *) let bid = Option.get ctx.bid in - let back_vars = - T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs) - in - let tys = List.map (fun (v : var) -> v.ty) back_vars in + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in let output = mk_simpl_tuple_ty tys in mk_output output else @@ -1667,7 +1677,11 @@ and translate_panic (ctx : bs_ctx) : texpression = in mk_output output -(** [opt_v]: the value to return, in case we translate a forward body *) +(** [opt_v]: the value to return, in case we translate a forward body. + + Remark: for now, we can't get there if we are inside a loop. + If inside a loop, we use {!translate_return_with_loop}. + *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1676,22 +1690,20 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) - or we are translating a backward function, in which case it should be [None] *) (* Compute the values that we should return *without the state and the result - * wrapper* *) + wrapper* *) let output = match ctx.bid with | None -> (* Forward function *) let v = Option.get opt_v in typed_value_to_texpression ctx ectx v - | Some bid -> + | Some _ -> (* Backward function *) (* Sanity check *) assert (opt_v = None); (* Group the variables in which we stored the values we need to give back. - * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs - in + See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values in @@ -1728,19 +1740,16 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (* Forward *) mk_texpression_from_var (Option.get loop_info.forward_output_no_state_no_result) - | Some bid -> + | Some _ -> (* Backward *) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) let backward_outputs = - let map = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function *) - ctx.backward_outputs - in - T.RegionGroupId.Map.find bid map + if ctx.inside_loop then + (* We are synthesizing a loop body *) + Option.get ctx.loop_backward_outputs + else (* Regular function *) + Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values @@ -1923,45 +1932,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) ^ abs_to_string ctx abs ^ "\n")); (* When we end an input abstraction, this input abstraction gets back - * the borrows which it introduced in the context through the input - * values: by listing those values, we get the values which are given - * back by one of the backward functions we are synthesizing. *) - (* Note that we don't support nested borrows for now: if we find - * an ended synthesized input abstraction, it must be the one corresponding - * to the backward function wer are synthesizing, it can't be the one - * for a parent backward function. - *) + the borrows which it introduced in the context through the input + values: by listing those values, we get the values which are given + back by one of the backward functions we are synthesizing. + + Note that we don't support nested borrows for now: if we find + an ended synthesized input abstraction, it must be the one corresponding + to the backward function wer are synthesizing, it can't be the one + for a parent backward function. + *) let bid = Option.get ctx.bid in assert (rg_id = bid); - (* The translation is done as follows: - - for a given backward function, we choose a set of variables [v_i] - - when we detect the ended input abstraction which corresponds - to the backward function, and which consumed the values [consumed_i], - we introduce: - {[ - let v_i = consumed_i in - ... - ]} - Then, when we reach the [Return] node, we introduce: - {[ - (v_i) - ]} - *) - (* First, get the given back variables. + (* First, introduce the given back variables. We don't use the same given back variables if we translate a loop or the standard body of a function. *) - let given_back_variables = - let map = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function body *) - ctx.backward_outputs - in - T.RegionGroupId.Map.find bid map + let ctx, given_back_variables = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in + let vars = List.map (fun ty -> (None, ty)) tys in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with loop_backward_outputs = Some vars }, vars) + else + (* Regular function body *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + let vars = List.combine back_sg.output_names back_sg.outputs in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with backward_outputs = Some vars }, vars) in (* Get the list of values consumed by the abstraction upon ending *) @@ -2943,22 +2945,15 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in - let loop_backward_outputs = + let rg_to_given_back_tys = T.RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) - let vars = - List.map - (fun ty -> - assert ( - not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); - (None, ctx_translate_fwd_ty !ctx ty)) - tys - in - (* Introduce fresh variables *) - let ctx', vars = fresh_vars vars !ctx in - ctx := ctx'; - vars) + List.map + (fun ty -> + assert (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty)); + ctx_translate_fwd_ty !ctx ty) + tys) loop.rg_to_given_back_tys in let ctx = !ctx in @@ -2966,12 +2961,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = let back_output_tys = match ctx.bid with | None -> None - | Some rg_id -> - let back_outputs = - T.RegionGroupId.Map.find rg_id loop_backward_outputs - in - let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in - Some back_output_tys + | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys) in (* Add the loop information in the context *) @@ -3013,6 +3003,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = generics; forward_inputs = None; forward_output_no_state_no_result = None; + back_outputs = rg_to_given_back_tys; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in @@ -3020,13 +3011,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in (* Update the context to translate the function end *) - let ctx_end = - { - ctx with - loop_id = Some loop_id; - loop_backward_outputs = Some loop_backward_outputs; - } - in + let ctx_end = { ctx with loop_id = Some loop_id } in let fun_end = translate_expression loop.end_expr ctx_end in (* Update the context for the loop body *) -- cgit v1.2.3 From 014c0668abf0834342b2b7076cf2f0634460e519 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 19 Dec 2023 13:24:53 +0100 Subject: Remove SymbolicToPure.bs_ctx.loop_backward_outputs --- compiler/SymbolicToPure.ml | 47 +++++++++++++++------------------------------- 1 file changed, 15 insertions(+), 32 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 93e6cb4e..e2787271 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -213,17 +213,6 @@ type bs_ctx = { The option is [None] before we detect the ended input abstraction, and [Some] afterwards. *) - loop_backward_outputs : var list option; - (** Same as {!backward_outputs}, but for loops (if we entered a loop). - - TODO: merge with [backward_outputs]? - - [None] if we are not inside a loop, [Some] otherwise (and whatever - the kind of function we are translating: it will be [Some] even - though we are synthesizing a forward function). - - TODO: move to {!loop_info} - *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; @@ -1744,13 +1733,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) (* Backward *) (* Group the variables in which we stored the values we need to give back. * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - Option.get ctx.loop_backward_outputs - else (* Regular function *) - Option.get ctx.backward_outputs - in + let backward_outputs = Option.get ctx.backward_outputs in let field_values = List.map mk_texpression_from_var backward_outputs in mk_simpl_tuple_texpression field_values in @@ -1950,20 +1933,20 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) the standard body of a function. *) let ctx, given_back_variables = - if ctx.inside_loop then - (* We are synthesizing a loop body *) - let loop_id = Option.get ctx.loop_id in - let loop = LoopId.Map.find loop_id ctx.loops in - let tys = RegionGroupId.Map.find bid loop.back_outputs in - let vars = List.map (fun ty -> (None, ty)) tys in - let ctx, vars = fresh_vars vars ctx in - ({ ctx with loop_backward_outputs = Some vars }, vars) - else - (* Regular function body *) - let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in - let vars = List.combine back_sg.output_names back_sg.outputs in - let ctx, vars = fresh_vars vars ctx in - ({ ctx with backward_outputs = Some vars }, vars) + let vars = + if ctx.inside_loop then + (* We are synthesizing a loop body *) + let loop_id = Option.get ctx.loop_id in + let loop = LoopId.Map.find loop_id ctx.loops in + let tys = RegionGroupId.Map.find bid loop.back_outputs in + List.map (fun ty -> (None, ty)) tys + else + (* Regular function body *) + let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in + List.combine back_sg.output_names back_sg.outputs + in + let ctx, vars = fresh_vars vars ctx in + ({ ctx with backward_outputs = Some vars }, vars) in (* Get the list of values consumed by the abstraction upon ending *) -- cgit v1.2.3 From 8835d87df111d09122267fadc9a32f16b52d234a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 14:37:43 +0100 Subject: Make good progress on merging the fwd/back functions --- compiler/SymbolicToPure.ml | 266 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 209 insertions(+), 57 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index e2787271..1ce6c698 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,6 +67,18 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) + back_funs : texpression RegionGroupId.Map.t option; + (** If we do not split between the forward/backward functions: the + variables we introduced for the backward functions. + + Example: + {[ + let x, back = Vec.index_mut n v in + ^^^^ + here + ... + ]} + *) } [@@deriving show] @@ -118,6 +130,8 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { + (* TODO: there are a lot of duplications with the various decls ctx *) + decls_ctx : C.decls_ctx; type_ctx : type_ctx; fun_ctx : fun_ctx; global_ctx : global_ctx; @@ -757,17 +771,27 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) TraitMethod (trait_ref, method_name, fun_decl_id) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (args : texpression list) (ctx : bs_ctx) : bs_ctx = + (args : texpression list) + (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx + = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = { forward; forward_inputs = args } in + let info = { forward; forward_inputs = args; back_funs } in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } -(** [back_args]: the *additional* list of inputs received by the backward function *) -let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) - (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) - : bs_ctx * fun_or_op_id = +(** [inherit_args]: the list of inputs inherited from the forward function and + the ancestors backward functions, if pertinent. + [back_args]: the *additional* list of inputs received by the backward function, + including the state. + + Returns the updated context and the expression corresponding to the function. + *) +let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) + (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) + (inherited_args : texpression list) (back_args : texpression list) + (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : + bs_ctx * texpression = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -777,16 +801,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + (* Compute the expression corresponding to the function *) + let func = + if !Config.return_back_funs then + (* Lookup the variable introduced for the backward function *) + RegionGroupId.Map.find back_id (Option.get info.back_funs) + else + (* Retrieve the fun_id *) + let fun_id = + match info.forward.call_id with + | S.Fun (fid, _) -> + let fid = translate_fun_id_or_trait_method_ref ctx fid in + Fun (FromLlbc (fid, None, Some back_id)) + | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + in + let args = List.append inherited_args back_args in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = + if effect_info.can_fail then mk_result_ty output_ty else output_ty + in + let func_ty = mk_arrows input_tys ret_ty in + let func = { id = FunOrOp fun_id; generics } in + { e = Qualif func; ty = func_ty } in (* Update the context and return *) - ({ ctx with calls; abstractions }, fun_id) + ({ ctx with calls; abstractions }, func) (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) @@ -878,15 +917,12 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) We use [bid] ("backward function id") only if we split the forward and the backward functions. *) -let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) - (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) : - decomposed_fun_sig = +let translate_fun_sig_with_regions_hierarchy_to_decomposed + (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig) + (input_names : string option list) : decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in - (* Retrieve the list of parent backward functions *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -915,9 +951,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) in (* Is the forward function stateful, and can it fail? *) - let fwd_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None None - in + let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in (* Compute the forward inputs *) let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in let fwd_inputs_no_fuel_no_state = @@ -1030,7 +1064,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) RegionGroupId.id * back_sg_info = let gid = rg.id in let back_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) + get_fun_effect_info fun_infos fun_id None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in let inputs_no_state = @@ -1072,6 +1106,16 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) fwd_info; } +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list) + : decomposed_fun_sig = + (* Retrieve the list of parent backward functions *) + let regions_hierarchy = + FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies + in + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + (FunId (FRegular fun_id)) regions_hierarchy sg input_names + let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = let output = @@ -1090,6 +1134,40 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = mk_arrows inputs output) (RegionGroupId.Map.values dsg.back_sg) +(** Return the pure signature of a backward function, in the case the + forward/backward functions are merged (i.e., the forward functions + return the backward functions). + + TODO: merge with {!translate_fun_sig_from_decomposed} + *) +let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id) : fun_sig = + assert !Config.return_back_funs; + + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) + in + (* Two cases depending on whether we split the forward/backward functions + or not *) + let mk_output_ty = mk_output_ty_from_effect_info in + + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + (* Do not prepend the forward inputs *) + let inputs = List.map snd back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1774,7 +1852,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, out_state = + let ctx, fun_id, effect_info, args, back_funs, out_state = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1798,9 +1876,80 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in + (* If we do not split the forward/backward functions: generate the + variables for the backward functions returned by the forward + function. *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + fid call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_sgs = + List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids + in + (* Introduce variables for the backward functions *) + let back_tys = + List.map + (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output) + back_sgs + in + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls + in + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) + in + name ^ "_back" + in + let ctx, back_vars = + fresh_vars + (List.map (fun ty -> (Some back_fun_name, ty)) back_tys) + ctx + in + let back_funs = + List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids (List.map mk_texpression_from_var back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) + in (* Register the function call *) - let ctx = bs_ctx_register_forward_call call_id call args ctx in - (ctx, func, effect_info, args, out_state) + let ctx = + bs_ctx_register_forward_call call_id call args back_funs_map ctx + in + (ctx, func, effect_info, args, back_funs, out_state) | S.Unop E.Not -> let effect_info = { @@ -1811,7 +1960,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, None) + (ctx, Unop Not, effect_info, args, [], None) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -1827,7 +1976,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, None) + (ctx, Unop (Neg int_ty), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -1842,7 +1991,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -1862,11 +2011,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, None) + (ctx, Binop (binop, int_ty0), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) in let dest_v = let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = mk_simpl_tuple_pattern (dest :: back_funs) in match out_state with | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] @@ -2026,9 +2176,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inpus *) - let inputs = - List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] + let inherited_inputs = + if !Config.return_back_funs then [] + else List.concat [ fwd_inputs; back_ancestors_inputs ] in + let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the * meta-place information from the input values given to the forward function @@ -2046,43 +2198,43 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in (* Retrieve the function id, and register the function call in the context - * if necessary *) + if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx + bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs + back_inputs generics output.ty ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) + let inputs = List.append inherited_inputs back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then ( + ================= + We do a small optimization here if we split the forward/backward functions. + If the backward function doesn't have any output, we don't introduce any function + call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None + then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else mk_let effect_info.can_fail output call next_e -- cgit v1.2.3 From a630b8a703d8761746f7258b6db54080aa974f53 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 14:49:37 +0100 Subject: Fix a minor issue --- compiler/SymbolicToPure.ml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 1ce6c698..3d955061 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1123,12 +1123,15 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions *) +(** Compute the arrow types for all the backward functions. + + TODO: merge with below? + *) let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in + let inputs = List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in mk_arrows inputs output) -- cgit v1.2.3 From 435fe4cf63869448e2b25486b564ede9efa9a34b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 15:17:28 +0100 Subject: Fix some issues in SymbolicToPure --- compiler/SymbolicToPure.ml | 51 +++++++++++++++++++++++----------------------- 1 file changed, 26 insertions(+), 25 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 3d955061..ef0a0bde 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1137,39 +1137,30 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = mk_arrows inputs output) (RegionGroupId.Map.values dsg.back_sg) -(** Return the pure signature of a backward function, in the case the - forward/backward functions are merged (i.e., the forward functions +(** Return the instantiated pure signature of a backward function, in the + case the forward/backward functions are merged (i.e., the forward functions return the backward functions). - - TODO: merge with {!translate_fun_sig_from_decomposed} *) -let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) - (gid : RegionGroupId.id) : fun_sig = +let translate_ret_back_inst_fun_sig_from_decomposed + (dsg : Pure.decomposed_fun_sig) (generics : generic_args) + (gid : RegionGroupId.id) : inst_fun_sig = assert !Config.return_back_funs; - - let generics = dsg.generics in - let llbc_generics = dsg.llbc_generics in - let preds = dsg.preds in - (* Compute the effects info *) - let fwd_info = dsg.fwd_info in - let back_effect_info = - RegionGroupId.Map.of_list - (List.map - (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> - (gid, info.effect_info)) - (RegionGroupId.Map.bindings dsg.back_sg)) - in - (* Two cases depending on whether we split the forward/backward functions - or not *) let mk_output_ty = mk_output_ty_from_effect_info in - + (* Lookup the signature information *) let back_sg = RegionGroupId.Map.find gid dsg.back_sg in let effect_info = back_sg.effect_info in (* Do not prepend the forward inputs *) let inputs = List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty effect_info output in - { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } + (* Substitute the types *) + let tr_self = UnknownTrait __FUNCTION__ in + let subst = make_subst_from_generics dsg.generics generics tr_self in + let subst = ty_substitute subst in + let inputs = List.map subst inputs in + let output = subst output in + (* Return *) + { inputs; output } let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = @@ -1898,12 +1889,14 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : call.regions_hierarchy in let back_sgs = - List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids + List.map + (translate_ret_back_inst_fun_sig_from_decomposed dsg generics) + gids in (* Introduce variables for the backward functions *) let back_tys = List.map - (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output) + (fun (sg : inst_fun_sig) -> mk_arrows sg.inputs sg.output) back_sgs in (* Compute a proper basename for the variables *) @@ -2216,6 +2209,14 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); let call = mk_apps func args in (* **Optimization**: ================= -- cgit v1.2.3 From cf3eea59ee61f2341daf7248664b8be878f128af Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 16:35:27 +0100 Subject: Update SymbolicToPure.ml for the loops --- compiler/SymbolicToPure.ml | 221 +++++++++++++++++++++++++-------------------- 1 file changed, 125 insertions(+), 96 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ef0a0bde..d3b0933c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -125,6 +125,11 @@ type loop_info = { (** The map from region group ids to the types of the values given back by the corresponding loop abstractions. *) + back_funs : texpression RegionGroupId.Map.t option; + (** Same as {!call_info.back_funs}. + Initialized with [None], gets updated to [Some] only if we merge + the fwd/back functions. + *) } [@@deriving show] @@ -1123,45 +1128,25 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions. - - TODO: merge with below? - *) -let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = +(** Compute the arrow types for all the backward functions. *) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in + (* Compute *) let inputs = List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in - mk_arrows inputs output) + let ty = mk_arrows inputs output in + (* Substitute - TODO: normalize *) + match subst with + | None -> ty + | Some (generics, tr_self) -> + let subst = make_subst_from_generics dsg.generics generics tr_self in + ty_substitute subst ty) (RegionGroupId.Map.values dsg.back_sg) -(** Return the instantiated pure signature of a backward function, in the - case the forward/backward functions are merged (i.e., the forward functions - return the backward functions). - *) -let translate_ret_back_inst_fun_sig_from_decomposed - (dsg : Pure.decomposed_fun_sig) (generics : generic_args) - (gid : RegionGroupId.id) : inst_fun_sig = - assert !Config.return_back_funs; - let mk_output_ty = mk_output_ty_from_effect_info in - (* Lookup the signature information *) - let back_sg = RegionGroupId.Map.find gid dsg.back_sg in - let effect_info = back_sg.effect_info in - (* Do not prepend the forward inputs *) - let inputs = List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - (* Substitute the types *) - let tr_self = UnknownTrait __FUNCTION__ in - let subst = make_subst_from_generics dsg.generics generics tr_self in - let subst = ty_substitute subst in - let inputs = List.map subst inputs in - let output = subst output in - (* Return *) - { inputs; output } - let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1184,7 +1169,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = compute_back_tys dsg in + let back_tys = compute_back_tys dsg None in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1274,6 +1259,40 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) : bs_ctx * var list = List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars +(* Introduce variables for the backward functions *) +let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = + (* We lookup the LLBC definition in an attempt to derive pretty names + for the backward functions. *) + let back_var_names = + let def_id = ctx.fun_decl.def_id in + let sg = ctx.fun_decl.signature in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.fun_ctx.regions_hierarchies + in + List.map + (fun (gid, _) -> + let rg = RegionGroupId.nth regions_hierarchy gid in + let region_names = + List.map + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) + rg.regions + in + let name = + match region_names with + | [] -> "back" + | [ Some r ] -> "back" ^ r + | _ -> + (* Concatenate all the region names *) + "back" + ^ String.concat "" (List.filter_map (fun x -> x) region_names) + in + Some name) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in + fresh_vars back_vars ctx + let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with | Some v -> v @@ -1728,7 +1747,7 @@ and translate_panic (ctx : bs_ctx) : texpression = match ctx.bid with | None -> if !Config.return_back_funs then - let back_tys = compute_back_tys ctx.sg in + let back_tys = compute_back_tys ctx.sg None in let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in mk_output output else mk_output ctx.sg.fwd_output @@ -1883,22 +1902,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : fid call.regions_hierarchy sg (List.map (fun _ -> None) sg.inputs) in - let gids = - List.map - (fun (g : T.region_var_group) -> g.id) - call.regions_hierarchy - in - let back_sgs = - List.map - (translate_ret_back_inst_fun_sig_from_decomposed dsg generics) - gids - in + let tr_self = UnknownTrait __FUNCTION__ in + let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in (* Introduce variables for the backward functions *) - let back_tys = - List.map - (fun (sg : inst_fun_sig) -> mk_arrows sg.inputs sg.output) - back_sgs - in (* Compute a proper basename for the variables *) let back_fun_name = let name = @@ -1934,6 +1940,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs = List.map (fun v -> mk_typed_pattern_from_var v None) back_vars in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in let back_funs_map = RegionGroupId.Map.of_list (List.combine gids (List.map mk_texpression_from_var back_vars)) @@ -2338,6 +2349,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id | V.LoopCall -> + (* We need to introduce a call to the backward function corresponding + to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) @@ -2367,7 +2380,10 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inputs *) - let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in + let inputs = + if !Config.return_back_funs then List.concat [ back_inputs; back_state ] + else List.concat [ fwd_inputs; back_inputs; back_state ] + in (* Retrieve the values given back by this function *) let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) @@ -2391,28 +2407,43 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let ret_ty = if effect_info.can_fail then mk_result_ty output.ty else output.ty in - let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in + (* Create the expression for the function: + - it is either a call to a top-level function, if we split the + forward/backward functions + - or a call to the variable we introduced for the backward function, + if we merge the forward/backward functions *) + let func = + if !Config.return_back_funs then + RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) + else + let func_ty = mk_arrows input_tys ret_ty in + let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in + let func = { id = FunOrOp func; generics } in + { e = Qualif func; ty = func_ty } + in let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None + ================= + We do a small optimization here in case we split the forward/backward + functions. + If the backward function doesn't have any output, we don't introduce + any function call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else @@ -2860,35 +2891,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce variables for the backward functions. We lookup the LLBC definition in an attempt to derive pretty names for those functions. *) - let back_var_names = - let def_id = ctx.fun_decl.def_id in - let sg = ctx.fun_decl.signature in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) - ctx.fun_ctx.regions_hierarchies - in - List.map - (fun (gid, _) -> - let rg = RegionGroupId.nth regions_hierarchy gid in - let region_names = - List.map - (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) - rg.regions - in - let name = - match region_names with - | [] -> "back" - | [ Some r ] -> "back" ^ r - | _ -> - (* Concatenate all the region names *) - "back" - ^ String.concat "" (List.filter_map (fun x -> x) region_names) - in - Some name) - (RegionGroupId.Map.bindings ctx.sg.back_sg) - in - let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in - let _, back_vars = fresh_vars back_vars ctx in + let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) let vars = fwd_var :: back_vars in @@ -2964,8 +2967,32 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce a fresh output value for the forward function *) let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + (* Introduce fresh variables for the backward functions of the loop. + + For now, the backward functions of the loop are the same as the + backward functions of the outer function. + *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + let ctx, back_vars = fresh_back_vars_for_current_fun ctx in + let back_funs = + List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + in + let gids = RegionGroupId.Map.keys ctx.sg.back_sg in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids (List.map mk_texpression_from_var back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) + in + + (* Introduce patterns *) let args, ctx, out_pats = + (* Create the pattern for the output value *) let output_pat = mk_typed_pattern_from_var output_var None in + (* Add the returned backward functions (they might be empty) *) + let output_pat = mk_simpl_tuple_pattern (output_pat :: back_funs) in (* Depending on the function effects: * - add the fuel @@ -2988,6 +3015,7 @@ and translate_forward_end (ectx : C.eval_ctx) loop_info with forward_inputs = Some args; forward_output_no_state_no_result = Some output_var; + back_funs = back_funs_map; } in let ctx = @@ -3143,6 +3171,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = forward_inputs = None; forward_output_no_state_no_result = None; back_outputs = rg_to_given_back_tys; + back_funs = None; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in -- cgit v1.2.3 From d4b3d0e6adae5bb9a2f62872dbcedc29aaa9fa30 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 17:00:52 +0100 Subject: Filter the useless backward functions --- compiler/SymbolicToPure.ml | 220 +++++++++++++++++++++++++++++---------------- 1 file changed, 145 insertions(+), 75 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d3b0933c..f37ea201 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,7 +67,7 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) - back_funs : texpression RegionGroupId.Map.t option; + back_funs : texpression option RegionGroupId.Map.t option; (** If we do not split between the forward/backward functions: the variables we introduced for the backward functions. @@ -78,6 +78,10 @@ type call_info = { here ... ]} + + The expression might be [None] in case the backward function + has to be filtered (because it does nothing - the backward + functions for shared borrows for instance). *) } [@@deriving show] @@ -125,7 +129,7 @@ type loop_info = { (** The map from region group ids to the types of the values given back by the corresponding loop abstractions. *) - back_funs : texpression RegionGroupId.Map.t option; + back_funs : texpression option RegionGroupId.Map.t option; (** Same as {!call_info.back_funs}. Initialized with [None], gets updated to [Some] only if we merge the fwd/back functions. @@ -777,8 +781,8 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) (args : texpression list) - (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx - = + (back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) : + bs_ctx = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); let info = { forward; forward_inputs = args; back_funs } in @@ -790,13 +794,15 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) [back_args]: the *additional* list of inputs received by the backward function, including the state. - Returns the updated context and the expression corresponding to the function. + Returns the updated context and the expression corresponding to the function + that we need to call. This function may be [None] if it has to be ignored + (because it does nothing). *) let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) (inherited_args : texpression list) (back_args : texpression list) (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : - bs_ctx * texpression = + bs_ctx * texpression option = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -827,7 +833,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) in let func_ty = mk_arrows input_tys ret_ty in let func = { id = FunOrOp fun_id; generics } in - { e = Qualif func; ty = func_ty } + Some { e = Qualif func; ty = func_ty } in (* Update the context and return *) ({ ctx with calls; abstractions }, func) @@ -1128,23 +1134,36 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions. *) +(** Compute the arrow types for all the backward functions. + + If a backward function has no inputs/outputs we filter it. + *) let compute_back_tys (dsg : Pure.decomposed_fun_sig) - (subst : (generic_args * trait_instance_id) option) : ty list = + (subst : (generic_args * trait_instance_id) option) : ty option list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - (* Compute *) + (* Compute the input/output types *) let inputs = List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty_from_effect_info effect_info output in - let ty = mk_arrows inputs output in - (* Substitute - TODO: normalize *) - match subst with - | None -> ty - | Some (generics, tr_self) -> - let subst = make_subst_from_generics dsg.generics generics tr_self in - ty_substitute subst ty) + let outputs = back_sg.outputs in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then + None + else + let output = mk_simpl_tuple_ty outputs in + let output = mk_output_ty_from_effect_info effect_info output in + let ty = mk_arrows inputs output in + (* Substitute - TODO: normalize *) + let ty = + match subst with + | None -> ty + | Some (generics, tr_self) -> + let subst = + make_subst_from_generics dsg.generics generics tr_self + in + ty_substitute subst ty + in + Some ty) (RegionGroupId.Map.values dsg.back_sg) let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) @@ -1169,7 +1188,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = compute_back_tys dsg None in + let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1259,8 +1278,19 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) : bs_ctx * var list = List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars +let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) : + bs_ctx * var option list = + List.fold_left_map + (fun ctx var -> + match var with + | None -> (ctx, None) + | Some (name, ty) -> + let ctx, var = fresh_var name ty ctx in + (ctx, Some var)) + ctx vars + (* Introduce variables for the backward functions *) -let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = +let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list = (* We lookup the LLBC definition in an attempt to derive pretty names for the backward functions. *) let back_var_names = @@ -1291,7 +1321,13 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = (RegionGroupId.Map.bindings ctx.sg.back_sg) in let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in - fresh_vars back_vars ctx + let back_vars = + List.map + (fun (name, ty) -> + match ty with None -> None | Some ty -> Some (name, ty)) + back_vars + in + fresh_opt_vars back_vars ctx let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with @@ -1748,6 +1784,7 @@ and translate_panic (ctx : bs_ctx) : texpression = | None -> if !Config.return_back_funs then let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in mk_output output else mk_output ctx.sg.fwd_output @@ -1933,21 +1970,33 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : name ^ "_back" in let ctx, back_vars = - fresh_vars - (List.map (fun ty -> (Some back_fun_name, ty)) back_tys) + fresh_opt_vars + (List.map + (fun ty -> + match ty with + | None -> None + | Some ty -> Some (Some back_fun_name, ty)) + back_tys) ctx in let back_funs = - List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars in let gids = List.map (fun (g : T.region_var_group) -> g.id) call.regions_hierarchy in + let back_vars = + List.map (Option.map mk_texpression_from_var) back_vars + in let back_funs_map = - RegionGroupId.Map.of_list - (List.combine gids (List.map mk_texpression_from_var back_vars)) + RegionGroupId.Map.of_list (List.combine gids back_vars) in (ctx, Some back_funs_map, back_funs) else (ctx, None, []) @@ -2220,15 +2269,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - log#ldebug - (lazy - (let args = List.map (texpression_to_string ctx) args in - "func: " - ^ texpression_to_string ctx func - ^ "\nfunc type: " - ^ pure_ty_to_string ctx func.ty - ^ "\n\nargs:\n" ^ String.concat "\n" args)); - let call = mk_apps func args in (* **Optimization**: ================= We do a small optimization here if we split the forward/backward functions. @@ -2252,7 +2292,22 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) - else mk_let effect_info.can_fail output call next_e + else + (* The backward function might also have been filtered if we do not + split the forward/backward functions *) + match func with + | None -> next_e + | Some func -> + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); + let call = mk_apps func args in + mk_let effect_info.can_fail output call next_e and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2348,7 +2403,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) | V.LoopSynthInput -> (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id - | V.LoopCall -> + | V.LoopCall -> ( (* We need to introduce a call to the backward function corresponding to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in @@ -2419,9 +2474,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let func_ty = mk_arrows input_tys ret_ty in let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in let func = { id = FunOrOp func; generics } in - { e = Qualif func; ty = func_ty } + Some { e = Qualif func; ty = func_ty } in - let call = mk_apps func args in (* **Optimization**: ================= We do a small optimization here in case we split the forward/backward @@ -2447,38 +2501,44 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) assert (List.length inputs = List.length fwd_inputs); next_e) else - (* Add meta-information - this is slightly hacky: we look at the - values consumed by the abstraction (note that those come from - *before* we applied the fixed-point context) and use them to - guide the naming of the output vars. - - Also, we need to convert the backward outputs from patterns to - variables. - - Finally, in practice, this works well only for loop bodies: - we do this only in this case. - TODO: improve the heuristics, to give weight to the hints for - instance. - *) - let next_e = - if ctx.inside_loop then - let consumed_values = abs_to_consumed ctx ectx abs in - let var_values = List.combine outputs consumed_values in - let var_values = - List.filter_map - (fun (var, v) -> - match var.Pure.value with - | PatVar (var, _) -> Some (var, v) - | _ -> None) - var_values + (* In case we merge the fwd/back functions we filter the backward + functions elsewhere *) + match func with + | None -> next_e + | Some func -> + let call = mk_apps func args in + (* Add meta-information - this is slightly hacky: we look at the + values consumed by the abstraction (note that those come from + *before* we applied the fixed-point context) and use them to + guide the naming of the output vars. + + Also, we need to convert the backward outputs from patterns to + variables. + + Finally, in practice, this works well only for loop bodies: + we do this only in this case. + TODO: improve the heuristics, to give weight to the hints for + instance. + *) + let next_e = + if ctx.inside_loop then + let consumed_values = abs_to_consumed ctx ectx abs in + let var_values = List.combine outputs consumed_values in + let var_values = + List.filter_map + (fun (var, v) -> + match var.Pure.value with + | PatVar (var, _) -> Some (var, v) + | _ -> None) + var_values + in + let vars, values = List.split var_values in + mk_emeta_symbolic_assignments vars values next_e + else next_e in - let vars, values = List.split var_values in - mk_emeta_symbolic_assignments vars values next_e - else next_e - in - (* Create the let-binding *) - mk_let effect_info.can_fail output call next_e + (* Create the let-binding *) + mk_let effect_info.can_fail output call next_e) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2894,7 +2954,7 @@ and translate_forward_end (ectx : C.eval_ctx) let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) - let vars = fwd_var :: back_vars in + let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in let state_var = List.map mk_texpression_from_var state_var in @@ -2903,12 +2963,16 @@ and translate_forward_end (ectx : C.eval_ctx) (* Bind the expressions for the backward function and the expression for the computation of the forward output *) + let back_vars_els = + List.filter_map + (fun (v, el) -> match v with None -> None | Some v -> Some (v, el)) + (List.combine back_vars back_el) + in let e = List.fold_right (fun (var, back_e) e -> mk_let false (mk_typed_pattern_from_var var None) back_e e) - (List.combine back_vars back_el) - ret + back_vars_els ret in (* Bind the expression for the forward output *) let fwd_var = mk_typed_pattern_from_var fwd_var None in @@ -2976,12 +3040,18 @@ and translate_forward_end (ectx : C.eval_ctx) if !Config.return_back_funs then let ctx, back_vars = fresh_back_vars_for_current_fun ctx in let back_funs = - List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars in let gids = RegionGroupId.Map.keys ctx.sg.back_sg in let back_funs_map = RegionGroupId.Map.of_list - (List.combine gids (List.map mk_texpression_from_var back_vars)) + (List.combine gids + (List.map (Option.map mk_texpression_from_var) back_vars)) in (ctx, Some back_funs_map, back_funs) else (ctx, None, []) -- cgit v1.2.3 From 2f681446b11739e650b1d6050b717da872be9022 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 19:23:29 +0100 Subject: Simplify the type of the merged fwd/back functions --- compiler/SymbolicToPure.ml | 159 +++++++++++++++++++++++++++++++++------------ 1 file changed, 116 insertions(+), 43 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index f37ea201..70a4e18d 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -979,30 +979,6 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in (* Compute the backward output, without the effect information *) let fwd_output = translate_fwd_ty type_infos sg.output in - (* The additinoal information *) - let fwd_info = - (* *) - let has_fuel = fwd_fuel <> [] in - let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in - let num_inputs_with_fuel_no_state = - (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fwd_fuel + num_inputs_no_fuel_no_state - in - let fwd_info : inputs_info = - { - has_fuel; - num_inputs_no_fuel_no_state; - num_inputs_with_fuel_no_state; - num_inputs_with_fuel_with_state = - (* We use the fact that [fwd_state_ty] has length 1 if there is a state, - and 0 otherwise *) - num_inputs_with_fuel_no_state + List.length fwd_state_ty; - } - in - let info = { fwd_info; effect_info = fwd_effect_info } in - assert (fun_sig_info_is_wf info); - info - in (* Compute the type information for the backward function *) (* Small helper to translate types for backward functions *) @@ -1086,6 +1062,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in + let filter = + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + in let info = { inputs; @@ -1093,6 +1072,7 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed outputs; output_names; effect_info = back_effect_info; + filter; } in (gid, info) @@ -1102,6 +1082,39 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed (List.map compute_back_info_for_group regions_hierarchy) in + (* The additional information about the forward function *) + let fwd_info = + (* *) + let has_fuel = fwd_fuel <> [] in + let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in + let num_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fwd_fuel + num_inputs_no_fuel_no_state + in + let fwd_info : inputs_info = + { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = + (* We use the fact that [fwd_state_ty] has length 1 if there is a state, + and 0 otherwise *) + num_inputs_with_fuel_no_state + List.length fwd_state_ty; + } + in + let ignore_output = + if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then + ty_is_unit fwd_output + && List.exists + (fun (info : back_sg_info) -> not info.filter) + (RegionGroupId.Map.values back_sg) + else false + in + let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in + assert (fun_sig_info_is_wf info); + info + in + (* Generic parameters *) let generics = translate_generic_params sg.generics in @@ -1134,6 +1147,13 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output +let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info) + (inputs : ty list) (ty : ty) : ty = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output else output + (** Compute the arrow types for all the backward functions. If a backward function has no inputs/outputs we filter it. @@ -1151,7 +1171,9 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) None else let output = mk_simpl_tuple_ty outputs in - let output = mk_output_ty_from_effect_info effect_info output in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in let ty = mk_arrows inputs output in (* Substitute - TODO: normalize *) let ty = @@ -1166,6 +1188,25 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) Some ty) (RegionGroupId.Map.values dsg.back_sg) +(** In case we merge the fwd/back functions: compute the output type of + a function, from a decomposed signature. *) +let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = + assert !Config.return_back_funs; + (* Compute the arrow types for all the backward functions *) + let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in + (* Group the forward output and the types of the backward functions *) + let effect_info = dsg.fwd_info.effect_info in + let output = + (* We might need to ignore the output of the forward function + (if it is unit for instance) *) + let tys = + if dsg.fwd_info.ignore_output then back_tys + else dsg.fwd_output :: back_tys + in + mk_simpl_tuple_ty tys + in + mk_output_ty_from_effect_info effect_info output + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1180,19 +1221,12 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid, info.effect_info)) (RegionGroupId.Map.bindings dsg.back_sg)) in - (* Two cases depending on whether we split the forward/backward functions - or not *) let mk_output_ty = mk_output_ty_from_effect_info in - let inputs, output = + (* Two cases depending on whether we split the forward/backward functions or not *) if !Config.return_back_funs then ( assert (gid = None); - (* Compute the arrow types for all the backward functions *) - let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in - (* Group the forward output and the types of the backward functions *) - let effect_info = dsg.fwd_info.effect_info in - let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in - let output = mk_output_ty effect_info output in + let output = compute_output_ty_from_decomposed dsg in let inputs = dsg.fwd_inputs in (inputs, output)) else @@ -1785,7 +1819,11 @@ and translate_panic (ctx : bs_ctx) : texpression = if !Config.return_back_funs then let back_tys = compute_back_tys ctx.sg None in let back_tys = List.filter_map (fun x -> x) back_tys in - let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in + let tys = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty tys in mk_output output else mk_output ctx.sg.fwd_output | Some bid -> @@ -1798,6 +1836,9 @@ and translate_panic (ctx : bs_ctx) : texpression = Remark: for now, we can't get there if we are inside a loop. If inside a loop, we use {!translate_return_with_loop}. + + Remark: in case we merge the forward/backward functions, we introduce + those in [translate_forward_end]. *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = @@ -2648,6 +2689,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) If (true_e, false_e) ) in let ty = true_e.ty in + log#ldebug + (lazy + ("true_e.ty: " + ^ pure_ty_to_string ctx true_e.ty + ^ "\n\nfalse_e.ty: " + ^ pure_ty_to_string ctx false_e.ty)); assert (ty = false_e.ty); { e; ty } | ExpandInt (int_ty, branches, otherwise) -> @@ -2941,37 +2988,63 @@ and translate_forward_end (ectx : C.eval_ctx) in let fwd_e = translate_one_end ctx None in - (* Introduce the backward functions *) + (* Introduce the backward functions. *) let back_el = List.map (fun ((gid, _) : RegionGroupId.id * back_sg_info) -> translate_one_end ctx (Some gid)) (RegionGroupId.Map.bindings ctx.sg.back_sg) in + + (* Compute whether the backward expressions should be evaluated straight + away or not (i.e., if we should bind them with monadic let-bindings + or not). We evaluate them straight away if they can fail and have no + inputs *) + let evaluate_backs = + List.map + (fun (sg : back_sg_info) -> + if !Config.simplify_merged_fwd_backs then + sg.inputs = [] && sg.effect_info.can_fail + else false) + (RegionGroupId.Map.values ctx.sg.back_sg) + in + (* Introduce variables for the backward functions. We lookup the LLBC definition in an attempt to derive pretty names for those functions. *) let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) - let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in + let vars = + let back_vars = List.filter_map (fun x -> x) back_vars in + if ctx.sg.fwd_info.ignore_output then back_vars + else fwd_var :: back_vars + in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in let state_var = List.map mk_texpression_from_var state_var in let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in let ret = mk_result_return_texpression ret in - (* Bind the expressions for the backward function and the expression - for the computation of the forward output *) + (* Introduce all the let-bindings *) + + (* Combine: + - the backward variables + - whether we should evaluate the expression for the backward function + (i.e., should we use a monadic let-binding or not - we do if the + backward functions don't have inputs and can fail) + - the expressions for the backward functions + *) let back_vars_els = List.filter_map - (fun (v, el) -> match v with None -> None | Some v -> Some (v, el)) - (List.combine back_vars back_el) + (fun (v, (eval, el)) -> + match v with None -> None | Some v -> Some (v, eval, el)) + (List.combine back_vars (List.combine evaluate_backs back_el)) in let e = List.fold_right - (fun (var, back_e) e -> - mk_let false (mk_typed_pattern_from_var var None) back_e e) + (fun (var, evaluate, back_e) e -> + mk_let evaluate (mk_typed_pattern_from_var var None) back_e e) back_vars_els ret in (* Bind the expression for the forward output *) -- cgit v1.2.3 From 266db04e97778911c93cfd1aac251de04bb25f53 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 22:17:11 +0100 Subject: Fix several issues --- compiler/SymbolicToPure.ml | 186 ++++++++++++++++++++++++++++++++------------- 1 file changed, 135 insertions(+), 51 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 70a4e18d..37f621e4 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -146,6 +146,7 @@ type bs_ctx = { global_ctx : global_ctx; trait_decls_ctx : trait_decls_ctx; trait_impls_ctx : trait_impls_ctx; + fun_dsigs : decomposed_fun_sig FunDeclId.Map.t; fun_decl : A.fun_decl; bid : RegionGroupId.id option; (** TODO: rename @@ -890,7 +891,7 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else [] (** Small utility. *) -let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) +let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with @@ -917,6 +918,22 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) is_rec = false; } +(** TODO: not very clean. *) +let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : + fun_effect_info = + match fun_id with + | TraitMethod (_, _, fid) | FunId (FRegular fid) -> + let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in + let info = + match gid with + | None -> dsg.fwd_info.effect_info + | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info + in + { info with is_rec = info.is_rec || Option.is_some lid } + | FunId (FAssumed _) -> + compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid + (** Translate a function signature to a decomposed function signature. Note that the function also takes a list of names for the inputs, and @@ -962,7 +979,9 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed in (* Is the forward function stateful, and can it fail? *) - let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in + let fwd_effect_info = + compute_raw_fun_effect_info fun_infos fun_id None None + in (* Compute the forward inputs *) let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in let fwd_inputs_no_fuel_no_state = @@ -1051,12 +1070,23 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed RegionGroupId.id * back_sg_info = let gid = rg.id in let back_effect_info = - get_fun_effect_info fun_infos fun_id None (Some gid) + compute_raw_fun_effect_info fun_infos fun_id None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in + (* We consider the backward function as stateful and potentially failing + **only if it has inputs** (for the "potentially failing": if it has + not inputs, we directly evaluate it in the body of the forward function). + *) + let back_effect_info = + { + back_effect_info with + stateful = back_effect_info.stateful && inputs_no_state <> []; + can_fail = back_effect_info.can_fail && inputs_no_state <> []; + } + in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] in @@ -1140,6 +1170,19 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx (FunId (FRegular fun_id)) regions_hierarchy sg input_names +let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx) + (fdef : LlbcAst.fun_decl) : decomposed_fun_sig = + let input_names = + match fdef.body with + | None -> List.map (fun _ -> None) fdef.signature.inputs + | Some body -> + List.map + (fun (v : LlbcAst.var) -> v.name) + (LlbcAstUtils.fun_body_get_input_vars body) + in + translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature + input_names + let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = let output = @@ -1158,8 +1201,9 @@ let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info) If a backward function has no inputs/outputs we filter it. *) -let compute_back_tys (dsg : Pure.decomposed_fun_sig) - (subst : (generic_args * trait_instance_id) option) : ty option list = +let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : + (back_sg_info * ty) option list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in @@ -1185,9 +1229,13 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) in ty_substitute subst ty in - Some ty) + Some (back_sg, ty)) (RegionGroupId.Map.values dsg.back_sg) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : ty option list = + List.map (Option.map snd) (compute_back_tys_with_info dsg subst) + (** In case we merge the fwd/back functions: compute the output type of a function, from a decomposed signature. *) let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty = @@ -1363,6 +1411,7 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list = in fresh_opt_vars back_vars ctx +(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *) let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with | Some v -> v @@ -1381,12 +1430,22 @@ let rec unbox_typed_value (v : V.typed_value) : V.typed_value = | _ -> raise (Failure "Unreachable")) | _ -> v -(** Translate a symbolic value *) +(** Translate a symbolic value. + + Because we do not necessarily introduce variables for the symbolic values + of (translated) type unit, it is important that we do not lookup variables + in case the symbolic value has type unit. + *) let symbolic_value_to_texpression (ctx : bs_ctx) (sv : V.symbolic_value) : texpression = (* Translate the type *) - let var = lookup_var_for_symbolic_value sv ctx in - mk_texpression_from_var var + let ty = ctx_translate_fwd_ty ctx sv.sv_ty in + (* If the type is unit, directly return unit *) + if ty_is_unit ty then mk_unit_rvalue + else + (* Otherwise lookup the variable *) + let var = lookup_var_for_symbolic_value sv ctx in + mk_texpression_from_var var (** Translate a typed value. @@ -1565,13 +1624,11 @@ and aproj_to_consumed (ctx : bs_ctx) (aproj : V.aproj) : texpression option = match aproj with | V.AEndedProjLoans (msv, []) -> (* The symbolic value was left unchanged *) - let var = lookup_var_for_symbolic_value msv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx msv) | V.AEndedProjLoans (_, [ (mnv, child_aproj) ]) -> assert (child_aproj = AIgnoredProjBorrows); (* The symbolic value was updated *) - let var = lookup_var_for_symbolic_value mnv ctx in - Some (mk_texpression_from_var var) + Some (symbolic_value_to_texpression ctx mnv) | V.AEndedProjLoans (_, _) -> (* The symbolic value was updated, and the given back values come from sevearl * abstractions *) @@ -1940,10 +1997,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.combine args args_mplaces) in let dest_mplace = translate_opt_mplace call.dest_place in - let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, back_funs, out_state = + let ctx, fun_id, effect_info, args, dest_v = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1951,13 +2007,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let func = Fun (FromLlbc (fid_t, None, None)) in (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos fid None None - in + let effect_info = get_fun_effect_info ctx fid None None in (* Depending on the function effects: - * - add the fuel - * - add the state input argument - * - generate a fresh state variable for the returned state + - add the fuel + - add the state input argument + - generate a fresh state variable for the returned state *) let args, ctx, out_state = let fuel = mk_fuel_input_as_list ctx effect_info in @@ -1970,7 +2024,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* If we do not split the forward/backward functions: generate the variables for the backward functions returned by the forward function. *) - let ctx, back_funs_map, back_funs = + let ctx, ignore_fwd_output, back_funs_map, back_funs = if !Config.return_back_funs then (* We need to compute the signatures of the backward functions. *) let sg = Option.get call.sg in @@ -1981,7 +2035,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.map (fun _ -> None) sg.inputs) in let tr_self = UnknownTrait __FUNCTION__ in - let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in + let back_tys = + compute_back_tys_with_info dsg (Some (generics, tr_self)) + in (* Introduce variables for the backward functions *) (* Compute a proper basename for the variables *) let back_fun_name = @@ -2016,7 +2072,18 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (fun ty -> match ty with | None -> None - | Some ty -> Some (Some back_fun_name, ty)) + | Some (back_sg, ty) -> + (* We insert a name for the variable only if the function + can fail: if it can fail, it means the call returns a backward + function. Otherwise, we it directly returns the value given + back by the backward function, which means we shouldn't + give it a name like "back..." (it doesn't make sense) *) + let name = + if back_sg.effect_info.can_fail then + Some back_fun_name + else None + in + Some (name, ty)) back_tys) ctx in @@ -2039,14 +2106,37 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs_map = RegionGroupId.Map.of_list (List.combine gids back_vars) in - (ctx, Some back_funs_map, back_funs) - else (ctx, None, []) + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) + else (ctx, false, None, []) + in + (* Compute the pattern for the destination *) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = + (* Here there is something subtle: as we might ignore the output + of the forward function (because it translates to unit) we doNOT + necessarily introduce in the let-binding the variable to which we + map the symbolic value which was introduced for the output of the + function call. This would be problematic if later we need to + translate this symbolic value, but we implemented + {!symbolic_value_to_texpression} so that it doesn't perform any + lookups if the symbolic value has type unit. + *) + let vars = + if ignore_fwd_output then back_funs else dest :: back_funs + in + mk_simpl_tuple_pattern vars + in + let dest = + match out_state with + | None -> dest + | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in (* Register the function call *) let ctx = bs_ctx_register_forward_call call_id call args back_funs_map ctx in - (ctx, func, effect_info, args, back_funs, out_state) + (ctx, func, effect_info, args, dest) | S.Unop E.Not -> let effect_info = { @@ -2057,7 +2147,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop Not, effect_info, args, dest) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -2073,7 +2165,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Neg int_ty), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -2088,7 +2182,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, dest) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -2108,16 +2204,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, [], None) + let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in + let dest = mk_typed_pattern_from_var dest dest_mplace in + (ctx, Binop (binop, int_ty0), effect_info, args, dest) | _ -> raise (Failure "Unreachable")) in - let dest_v = - let dest = mk_typed_pattern_from_var dest dest_mplace in - let dest = mk_simpl_tuple_pattern (dest :: back_funs) in - match out_state with - | None -> dest - | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] - in let func = { id = FunOrOp fun_id; generics } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = @@ -2242,9 +2333,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (* Those don't have backward functions *) raise (Failure "Unreachable") in - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos fun_id None (Some rg_id) - in + let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in let generics = ctx_translate_fwd_generic_args ctx call.generics in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs call_id in @@ -2449,8 +2538,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) - (Some rg_id) + get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id) in let loop_info = LoopId.Map.find loop_id ctx.loops in let generics = loop_info.generics in @@ -2609,8 +2697,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value) and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) - let scrutinee_var = lookup_var_for_symbolic_value sv ctx in - let scrutinee = mk_texpression_from_var scrutinee_var in + let scrutinee = symbolic_value_to_texpression ctx sv in let scrutinee_mplace = translate_opt_mplace p in (* Translate the branches *) match exp with @@ -2999,7 +3086,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Compute whether the backward expressions should be evaluated straight away or not (i.e., if we should bind them with monadic let-bindings or not). We evaluate them straight away if they can fail and have no - inputs *) + inputs. *) let evaluate_backs = List.map (fun (sg : back_sg_info) -> @@ -3098,9 +3185,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Lookup the effect info for the loop function *) let fid = E.FRegular ctx.fun_decl.def_id in - let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fid) None ctx.bid - in + let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in @@ -3479,8 +3564,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_ctx.fun_infos (FunId (FRegular def_id)) - None bid + get_fun_effect_info ctx (FunId (FRegular def_id)) None bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) -- cgit v1.2.3 From eae740d644f5ccd1ad2a7e853a9cdf303c8df61e Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 22:45:47 +0100 Subject: Fix issues when extracting stateful functions --- compiler/SymbolicToPure.ml | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 37f621e4..7eb75584 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -2782,7 +2782,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) ^ pure_ty_to_string ctx true_e.ty ^ "\n\nfalse_e.ty: " ^ pure_ty_to_string ctx false_e.ty)); - assert (ty = false_e.ty); + if !Config.fail_hard then assert (ty = false_e.ty); { e; ty } | ExpandInt (int_ty, branches, otherwise) -> let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : @@ -3005,7 +3005,7 @@ and translate_forward_end (ectx : C.eval_ctx) fresh_vars back_sg.inputs_no_state ctx in let ctx, backward_inputs_with_state = - if (ctx_get_effect_info ctx).stateful then + if back_sg.effect_info.stateful then let ctx, var, _ = bs_ctx_fresh_state_var ctx in (ctx, backward_inputs_no_state @ [ var ]) else (ctx, backward_inputs_no_state) @@ -3061,18 +3061,7 @@ and translate_forward_end (ectx : C.eval_ctx) if !Config.return_back_funs then (* Compute the output of the forward function *) let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output_ty = - let ty = ctx.sg.fwd_output in - if fwd_effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] - else ty - in - let ctx, fwd_var = fresh_var None output_ty ctx in - let ctx, state_var, state_pat = - if fwd_effect_info.stateful then - let ctx, var, pat = bs_ctx_fresh_state_var ctx in - (ctx, [ var ], [ pat ]) - else (ctx, [], []) - in + let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in let fwd_e = translate_one_end ctx None in (* Introduce the backward functions. *) @@ -3105,10 +3094,19 @@ and translate_forward_end (ectx : C.eval_ctx) let vars = let back_vars = List.filter_map (fun x -> x) back_vars in if ctx.sg.fwd_info.ignore_output then back_vars - else fwd_var :: back_vars + else pure_fwd_var :: back_vars in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in + + (* Introduce a fresh input state variable for the forward expression *) + let _ctx, state_var, state_pat = + if fwd_effect_info.stateful then + let ctx, var, pat = bs_ctx_fresh_state_var ctx in + (ctx, [ var ], [ pat ]) + else (ctx, [], []) + in + let state_var = List.map mk_texpression_from_var state_var in let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in let ret = mk_result_return_texpression ret in @@ -3135,7 +3133,7 @@ and translate_forward_end (ectx : C.eval_ctx) back_vars_els ret in (* Bind the expression for the forward output *) - let fwd_var = mk_typed_pattern_from_var fwd_var None in + let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in mk_let fwd_effect_info.can_fail pat fwd_e e else translate_one_end ctx ctx.bid -- cgit v1.2.3 From 774eb319e514a0ba02473f9c82ee9d3355de8a3d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 11:09:10 +0100 Subject: Fix an issue when merging the fwd/back functions of trait methods --- compiler/SymbolicToPure.ml | 26 +++++++++++++++++++++----- 1 file changed, 21 insertions(+), 5 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 7eb75584..41922cb5 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1985,7 +1985,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = log#ldebug (lazy - ("translate_function_call:\n" + ("translate_function_call:\n" ^ "\n- call.call_id:" + ^ S.show_call_id call.call_id + ^ "\n\n- call.generics:\n" ^ ctx_generic_args_to_string ctx call.generics)); (* Translate the function call *) let generics = ctx_translate_fwd_generic_args ctx call.generics in @@ -2025,7 +2027,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : variables for the backward functions returned by the forward function. *) let ctx, ignore_fwd_output, back_funs_map, back_funs = - if !Config.return_back_funs then + if !Config.return_back_funs then ( (* We need to compute the signatures of the backward functions. *) let sg = Option.get call.sg in let decls_ctx = ctx.decls_ctx in @@ -2034,9 +2036,23 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : fid call.regions_hierarchy sg (List.map (fun _ -> None) sg.inputs) in - let tr_self = UnknownTrait __FUNCTION__ in + log#ldebug + (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics)); + let tr_self, all_generics = + match call.trait_method_generics with + | None -> (UnknownTrait __FUNCTION__, generics) + | Some (all_generics, tr_self) -> + let all_generics = + ctx_translate_fwd_generic_args ctx all_generics + in + let tr_self = + translate_fwd_trait_instance_id ctx.type_ctx.type_infos + tr_self + in + (tr_self, all_generics) + in let back_tys = - compute_back_tys_with_info dsg (Some (generics, tr_self)) + compute_back_tys_with_info dsg (Some (all_generics, tr_self)) in (* Introduce variables for the backward functions *) (* Compute a proper basename for the variables *) @@ -2106,7 +2122,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs_map = RegionGroupId.Map.of_list (List.combine gids back_vars) in - (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs) + (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs)) else (ctx, false, None, []) in (* Compute the pattern for the destination *) -- cgit v1.2.3 From 3688596f27a1ba461f48e88446b8812ec73f1a2f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 19:09:16 +0100 Subject: Add an option to split the fwd/back functions and fix a minor issue --- compiler/SymbolicToPure.ml | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 41922cb5..4674b61c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -1076,16 +1076,20 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs_no_state = List.map (fun ty -> (Some "ret", ty)) inputs_no_state in - (* We consider the backward function as stateful and potentially failing + (* In case we merge the forward/backward functions: + we consider the backward function as stateful and potentially failing **only if it has inputs** (for the "potentially failing": if it has not inputs, we directly evaluate it in the body of the forward function). *) let back_effect_info = - { - back_effect_info with - stateful = back_effect_info.stateful && inputs_no_state <> []; - can_fail = back_effect_info.can_fail && inputs_no_state <> []; - } + if !Config.return_back_funs then + let b = inputs_no_state <> [] in + { + back_effect_info with + stateful = back_effect_info.stateful && b; + can_fail = back_effect_info.can_fail && b; + } + else back_effect_info in let state = if back_effect_info.stateful then [ (None, mk_state_ty) ] else [] @@ -1093,7 +1097,8 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let inputs = inputs_no_state @ state in let output_names, outputs = compute_back_outputs_for_gid gid in let filter = - !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + !Config.simplify_merged_fwd_backs + && !Config.return_back_funs && inputs = [] && outputs = [] in let info = { -- cgit v1.2.3 From b230ddacd44a1ca1804940bf89253bde8de7ffe1 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 20:12:00 +0100 Subject: Fix a minor issue with the extraction of loops when merging the fwd/back functions --- compiler/SymbolicToPure.ml | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 4674b61c..cd367d83 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -123,7 +123,7 @@ type loop_info = { generics : generic_args; forward_inputs : texpression list option; (** The forward inputs are initialized at [None] *) - forward_output_no_state_no_result : var option; + forward_output_no_state_no_result : texpression option; (** The forward outputs are initialized at [None] *) back_outputs : ty list RegionGroupId.Map.t; (** The map from region group ids to the types of the values given back @@ -1956,10 +1956,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) *) let output = match ctx.bid with - | None -> - (* Forward *) - mk_texpression_from_var - (Option.get loop_info.forward_output_no_state_no_result) + | None -> Option.get loop_info.forward_output_no_state_no_result | Some _ -> (* Backward *) (* Group the variables in which we stored the values we need to give back. @@ -1984,7 +1981,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) else output in (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - mk_result_return_texpression output + mk_emeta (Tag "return_with_loop") (mk_result_return_texpression output) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -3207,7 +3204,20 @@ and translate_forward_end (ectx : C.eval_ctx) let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in (* Introduce a fresh output value for the forward function *) - let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + let ctx, fwd_output, output_pat = + if ctx.sg.fwd_info.ignore_output then + (* Note that we still need the forward output (which is unit), + because even though the loop function will ignore the forward output, + the forward expression will still compute an output (which + will have type unit - otherwise we can't ignore it). *) + (ctx, mk_unit_rvalue, []) + else + let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + ( ctx, + mk_texpression_from_var output_var, + [ mk_typed_pattern_from_var output_var None ] ) + in + (* Introduce fresh variables for the backward functions of the loop. For now, the backward functions of the loop are the same as the @@ -3236,10 +3246,8 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce patterns *) let args, ctx, out_pats = - (* Create the pattern for the output value *) - let output_pat = mk_typed_pattern_from_var output_var None in (* Add the returned backward functions (they might be empty) *) - let output_pat = mk_simpl_tuple_pattern (output_pat :: back_funs) in + let output_pat = mk_simpl_tuple_pattern (output_pat @ back_funs) in (* Depending on the function effects: * - add the fuel @@ -3261,7 +3269,7 @@ and translate_forward_end (ectx : C.eval_ctx) { loop_info with forward_inputs = Some args; - forward_output_no_state_no_result = Some output_var; + forward_output_no_state_no_result = Some fwd_output; back_funs = back_funs_map; } in -- cgit v1.2.3 From 70d506d148e5ae1a3e4115034161f449aff666ed Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 21:03:17 +0100 Subject: Fix the output type of the loops backward functions --- compiler/SymbolicToPure.ml | 65 +++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 59 insertions(+), 6 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index cd367d83..bf92482a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -3368,7 +3368,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in let rg_to_given_back_tys = - T.RegionGroupId.Map.map + RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) List.map @@ -3380,10 +3380,63 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in let ctx = !ctx in - let back_output_tys = - match ctx.bid with - | None -> None - | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys) + (* The output type of the loop function *) + let output_ty = + if !Config.return_back_funs then + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_tys = + List.filter_map + (fun ((back_sg, given_back) : back_sg_info * ty list) -> + let effect_info = back_sg.effect_info in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty) + (List.combine back_sgs given_back_tys) + in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + else + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output in (* Add the loop information in the context *) @@ -3460,7 +3513,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in -- cgit v1.2.3 From dd7552bec1be1695682801fca6ba6dfcfa990fbb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 21:03:59 +0100 Subject: Update the computation of the effect info for the loops --- compiler/SymbolicToPure.ml | 141 ++++++++++++++++++++++++++++++--------------- 1 file changed, 95 insertions(+), 46 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index bf92482a..f0d1ca62 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -134,6 +134,8 @@ type loop_info = { Initialized with [None], gets updated to [Some] only if we merge the fwd/back functions. *) + fwd_effect_info : fun_effect_info; + back_effect_infos : fun_effect_info RegionGroupId.Map.t; } [@@deriving show] @@ -922,17 +924,31 @@ let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = - match fun_id with - | TraitMethod (_, _, fid) | FunId (FRegular fid) -> - let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in - let info = - match gid with - | None -> dsg.fwd_info.effect_info - | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info - in - { info with is_rec = info.is_rec || Option.is_some lid } - | FunId (FAssumed _) -> - compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid + match lid with + | None -> ( + match fun_id with + | TraitMethod (_, _, fid) | FunId (FRegular fid) -> + let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in + let info = + match gid with + | None -> dsg.fwd_info.effect_info + | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info + in + { info with is_rec = info.is_rec || Option.is_some lid } + | FunId (FAssumed _) -> + compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid) + | Some lid -> ( + (* This is necessarily for the current function *) + match fun_id with + | FunId (FRegular fid) -> ( + assert (fid = ctx.fun_decl.def_id); + (* Lookup the loop *) + let lid = V.LoopId.Map.find lid ctx.loop_ids_map in + let loop_info = LoopId.Map.find lid ctx.loops in + match gid with + | None -> loop_info.fwd_effect_info + | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos) + | _ -> raise (Failure "Unreachable")) (** Translate a function signature to a decomposed function signature. @@ -1901,7 +1917,7 @@ and translate_panic (ctx : bs_ctx) : texpression = Remark: in case we merge the forward/backward functions, we introduce those in [translate_forward_end]. - *) +*) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -3381,31 +3397,47 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = let ctx = !ctx in (* The output type of the loop function *) - let output_ty = + let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in + let back_effect_infos, output_ty = if !Config.return_back_funs then (* The loop backward functions consume the same additional inputs as the parent function, but have custom outputs *) - let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in + let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in - let back_tys = - List.filter_map - (fun ((back_sg, given_back) : back_sg_info * ty list) -> + let back_info_tys = + List.map + (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> + (* Remark: the effect info of the backward function for the loop + is almost the same as for the backward function of the parent function. + Quite importantly, the fact that the function is stateful and/or can fail + mostly depends on whether it has inputs or not, and the backward functions + for the loops have the same inputs as the backward functions for the parent + function. + *) let effect_info = back_sg.effect_info in + let effect_info = { effect_info with is_rec = true } in (* Compute the input/output types *) let inputs = List.map snd back_sg.inputs in let outputs = given_back in (* Filter if necessary *) - if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] - then None - else - let output = mk_simpl_tuple_ty outputs in - let output = - mk_back_output_ty_from_effect_info effect_info inputs output - in - let ty = mk_arrows inputs output in - Some ty) + let ty = + if + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty + in + ((id, effect_info), ty)) (List.combine back_sgs given_back_tys) in + let back_info = List.map fst back_info_tys in + let back_info = RegionGroupId.Map.of_list back_info in + let back_tys = List.filter_map snd back_info_tys in let output = if ctx.sg.fwd_info.ignore_output then back_tys else ctx.sg.fwd_output :: back_tys @@ -3416,27 +3448,42 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] else output in - if effect_info.can_fail && inputs <> [] then mk_result_ty output - else output + let output = + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + in + (back_info, output) else - match ctx.bid with - | None -> - (* Forward function: same type as the parent function *) - (translate_fun_sig_from_decomposed ctx.sg None).output - | Some rg_id -> - (* Backward function: custom return type *) - let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in - let output = mk_simpl_tuple_ty doutputs in - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output = - if fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if fwd_effect_info.can_fail then mk_result_ty output else output - in - output + let back_info = + RegionGroupId.Map.of_list + (List.map + (fun ((id, back_sg) : _ * back_sg_info) -> + (id, { back_sg.effect_info with is_rec = true })) + (RegionGroupId.Map.bindings ctx.sg.back_sg)) + in + let output = + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = + T.RegionGroupId.Map.find rg_id rg_to_given_back_tys + in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output + in + (back_info, output) in (* Add the loop information in the context *) @@ -3480,6 +3527,8 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = forward_output_no_state_no_result = None; back_outputs = rg_to_given_back_tys; back_funs = None; + fwd_effect_info; + back_effect_infos; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in -- cgit v1.2.3 From 9a8e43df626400aacdfcb9d2cf2eec38d71d2d73 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 22 Dec 2023 23:04:31 +0100 Subject: Fix minor issues --- compiler/SymbolicToPure.ml | 57 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 48 insertions(+), 9 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index f0d1ca62..3a50e495 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -734,11 +734,15 @@ let rec translate_back_ty (type_infos : type_infos) None | TTraitType (trait_ref, generics, type_name) -> assert (generics.regions = []); - (* Translate the trait ref and the generics as "forward" generics - - we do not want to filter any type *) - let trait_ref = translate_fwd_trait_ref type_infos trait_ref in - let generics = translate_fwd_generic_args type_infos generics in - Some (TTraitType (trait_ref, generics, type_name)) + assert ( + AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id); + if inside_mut then + (* Translate the trait ref and the generics as "forward" generics - + we do not want to filter any type *) + let trait_ref = translate_fwd_trait_ref type_infos trait_ref in + let generics = translate_fwd_generic_args type_infos generics in + Some (TTraitType (trait_ref, generics, type_name)) + else None | TArrow _ -> raise (Failure "TODO") (** Simply calls [translate_back_ty] *) @@ -1056,7 +1060,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed Upon ending the abstraction for 'a, we need to get back the borrow the function returned. *) - List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + let inputs = + List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let output = Print.Types.ty_to_string ctx sg.output in + let inputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) inputs + in + "translate_back_inputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- output: " ^ output ^ "\n- back inputs: " ^ inputs ^ "\n")); + inputs in let compute_back_outputs_for_gid (gid : RegionGroupId.id) : string option list * ty list = @@ -1080,7 +1098,21 @@ let translate_fun_sig_with_regions_hierarchy_to_decomposed let outputs = List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs in - List.split outputs + let names, outputs = List.split outputs in + log#ldebug + (lazy + (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in + let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in + let inputs = + Print.list_to_string (Print.Types.ty_to_string ctx) sg.inputs + in + let outputs = + Print.list_to_string (PrintPure.ty_to_string pctx false) outputs + in + "compute_back_outputs_for_gid:" ^ "\n- gid: " + ^ RegionGroupId.to_string gid + ^ "\n- inputs: " ^ inputs ^ "\n- back outputs: " ^ outputs ^ "\n")); + (names, outputs) in let compute_back_info_for_group (rg : T.region_var_group) : RegionGroupId.id * back_sg_info = @@ -1201,8 +1233,15 @@ let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx) (fun (v : LlbcAst.var) -> v.name) (LlbcAstUtils.fun_body_get_input_vars body) in - translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature - input_names + let sg = + translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature + input_names + in + log#ldebug + (lazy + ("translate_fun_sig_from_decl_to_decomposed:" ^ "\n- name: " + ^ T.show_name fdef.name ^ "\n- sg:\n" ^ show_decomposed_fun_sig sg ^ "\n")); + sg let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = -- cgit v1.2.3