From 83c5be42e1750d329ad31bc9151d7b0446af5a0f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 15 Dec 2023 12:14:01 +0100 Subject: Make progress on generalizing the signature information --- compiler/Extract.ml | 8 +- compiler/Pure.ml | 130 ++++++++++++------ compiler/PureMicroPasses.ml | 10 +- compiler/PureUtils.ml | 10 +- compiler/SymbolicToPure.ml | 313 ++++++++++++++++++++++---------------------- 5 files changed, 253 insertions(+), 218 deletions(-) diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 93fcf416..1ea26d79 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1472,7 +1472,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* TODO: *) assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.fwd_info.num_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in Collections.List.prefix num_fwd_inputs all_inputs in @@ -1520,7 +1520,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* TODO: *) assert (not !Config.return_back_funs); let num_fwd_inputs = - def.signature.info.fwd_info.num_inputs_with_fuel_with_state + def.signature.fwd_info.fwd_info.num_inputs_with_fuel_with_state in let vars = Collections.List.prefix num_fwd_inputs all_vars in @@ -1798,7 +1798,6 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) assert body.is_global_decl_body; assert (Option.is_none body.back_id); assert (body.signature.inputs = []); - assert (List.length body.signature.doutputs = 1); assert (body.signature.generics = empty_generic_params); (* Add a break then the name of the corresponding LLBC declaration *) @@ -1817,7 +1816,8 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) let decl_ty, body_ty = let ty = body.signature.output in - if body.signature.info.effect_info.can_fail then (unwrap_result_ty ty, ty) + if body.signature.fwd_info.effect_info.can_fail then + (unwrap_result_ty ty, ty) else (ty, mk_result_ty ty) in match body.body with diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 34f3ef72..fb0509f4 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -907,12 +907,88 @@ type back_inputs_info = (inputs_info option, inputs_info) back_info type fun_sig_info = { fwd_info : inputs_info; (** Information about the inputs of the forward function *) - back_info : back_inputs_info; - (** Information about the inputs of the backward functions. *) effect_info : fun_effect_info; } [@@deriving show] +type back_sg_info = { + inputs : ty list; (** The additional inputs of the backward function *) + input_names : string option list; + (** The optional names for the additional inputs *) + outputs : ty list; + (** The "decomposed" list of outputs. + + The list contains all the types of + all the given back values (there is at most one type per forward + input argument). + + Ex.: + {[ + fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; + ]} + Decomposed outputs: + - forward function: [[T]] + - backward function: [[T; T]] (for "x" and "y") + + Non-decomposed ouputs (if the function can fail, but is not stateful): + - [result T] + - [[result (T * T)]] + *) + output_names : string option list; + (** The optional names for the backward outputs. + We derive those from the names of the inputs of the original LLBC + function. *) + effect_info : fun_effect_info; +} +[@@deriving show] + +(** A *decomposed* function signature. *) +type decomposed_fun_sig = { + generics : generic_params; + (** TODO: we should analyse the signature to make the type parameters implicit whenever possible *) + llbc_generics : Types.generic_params; + (** We use the LLBC generics to generate "pretty" names, for instance + for the variables we introduce for the trait clauses: we derive + those names from the types, and when doing so it is more meaningful + to derive them from the original LLBC types from before the + simplification of types like boxes and references. *) + preds : predicates; + fwd_inputs : ty list; + (** The types of the inputs of the forward function. + + Note that those input types take include the [fuel] parameter, + if the function uses fuel for termination, and the [state] parameter, + if the function is stateful. + + For instance, if we have the following Rust function: + {[ + fn f(x : int); + ]} + + If we translate it to a stateful function which uses fuel we get: + {[ + val f : nat -> int -> state -> result (state * unit); + ]} + + In particular, the list of input types is: [[nat; int; state]]. + *) + fwd_output : ty; + (** The "pure" output type of the forward function. + + Note that this type doesn't contain the "effect" of the function (i.e., + we haven't added the [state] if it is a stateful function and haven't + wrapped the type in a [result]). Also, this output type is only about + the forward function (it doesn't contain the type of the closures we + return for the backward functions, in case we merge the forward and + backward functions). + *) + back_sg : back_sg_info RegionGroupId.Map.t; + (** Information about the backward functions *) + fwd_info : fun_sig_info; + (** Additional information about the forward function *) +} +[@@deriving show] + (** A function signature. We have the following cases: @@ -927,15 +1003,15 @@ type fun_sig_info = { [in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> state -> result (state & (back_out0 & ... & back_outp))] (* state-error *) - Note that a stateful backward function may take two states as inputs: the - state received by the associated forward function, and the state at which - the backward is called. This leads to code of the following shape: + Note that a stateful backward function may take two states as inputs: the + state received by the associated forward function, and the state at which + the backward is called. This leads to code of the following shape: - {[ - (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd - ... // the state may be updated - (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back - ]} + {[ + (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd + ... // the state may be updated + (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back + ]} The function's type should be given by [mk_arrows sig.inputs sig.output]. We provide additional meta-information with {!fun_sig.info}: @@ -983,40 +1059,14 @@ type fun_sig = { be a tuple with a [state] if the function is stateful, and will be wrapped in a [result] if the function can fail. *) - doutputs : ty list; - (** The "decomposed" list of outputs. - - In case of a forward function, the list has length = 1, for the - type of the returned value. - - In case of backward function, the list contains all the types of - all the given back values (there is at most one type per forward - input argument). - - Ex.: - {[ - fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; - ]} - Decomposed outputs: - - forward function: [[T]] - - backward function: [[T; T]] (for "x" and "y") - - Non-decomposed ouputs (if the function can fail, but is not stateful): - - [result T] - - [[result (T * T)]] - *) - info : fun_sig_info; (** Additional information *) + fwd_info : fun_sig_info; + (** Additional information about the forward function. *) + back_effect_info : fun_effect_info RegionGroupId.Map.t; } [@@deriving show] (** An instantiated function signature. See {!fun_sig} *) -type inst_fun_sig = { - inputs : ty list; - output : ty; - doutputs : ty list; - info : fun_sig_info; -} -[@@deriving show] +type inst_fun_sig = { inputs : ty list; output : ty } [@@deriving show] type fun_body = { inputs : var list; diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 7f122f15..d92b3de0 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1358,11 +1358,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list = } in - { - fwd_info; - back_info = fun_sig_info.back_info; - effect_info = loop_effect_info; - } + { fwd_info; effect_info = loop_effect_info } in assert (fun_sig_info_is_wf loop_sig_info); @@ -2187,7 +2183,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : in (* TODO: *) assert (not !Config.return_back_funs); - let { fwd_info; back_info; effect_info } = info in + let { fwd_info; effect_info } = info in let { has_fuel; @@ -2212,7 +2208,7 @@ let filter_loop_inputs (transl : pure_fun_translation list) : } in - let info = { fwd_info; back_info; effect_info } in + let info = { fwd_info; effect_info } in assert (fun_sig_info_is_wf info); let signature = { generics; llbc_generics; preds; inputs; output; doutputs; info } diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index 3c038149..dfea255a 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -73,12 +73,6 @@ let inputs_info_is_wf (info : inputs_info) : bool = let fun_sig_info_is_wf (info : fun_sig_info) : bool = inputs_info_is_wf info.fwd_info - && - match info.back_info with - | SingleBack None -> true - | SingleBack (Some info) -> inputs_info_is_wf info - | AllBacks infos -> - List.for_all inputs_info_is_wf (RegionGroupId.Map.values infos) let dest_arrow_ty (ty : ty) : ty * ty = match ty with @@ -210,9 +204,7 @@ let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig = let subst = ty_substitute subst in let inputs = List.map subst sg.inputs in let output = subst sg.output in - let doutputs = List.map subst sg.doutputs in - let info = sg.info in - { inputs; output; doutputs; info } + { inputs; output } (** We use this to check whether we need to add parentheses around expressions. We only look for outer monadic let-bindings. diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index eba44e3e..456ec0f6 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -128,9 +128,9 @@ type bs_ctx = { trait_impls_ctx : trait_impls_context; fun_decl : A.fun_decl; bid : T.RegionGroupId.id option; (** TODO: rename *) - sg : fun_sig; - (** The function signature - useful in particular to translate [Panic] *) - fwd_sg : fun_sig; (** The signature of the forward function *) + sg : decomposed_fun_sig; + (** Information about the function signature - useful in particular to + translate [Panic] *) sv_to_var : var V.SymbolicValueId.Map.t; (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) @@ -828,7 +828,7 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) is_rec = false; } -(** Translate a function signature. +(** Translate a function signature to a decomposed function signature. Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding @@ -839,26 +839,15 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) We use [bid] ("backward function id") only if we split the forward and the backward functions. *) -let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) - (sg : A.fun_sig) (input_names : string option list) - (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = - assert (Option.is_none bid || not !Config.return_back_funs); +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) : + decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in (* Retrieve the list of parent backward functions *) let regions_hierarchy = FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies in - let gid, parents = - match bid with - | None -> (None, T.RegionGroupId.Set.empty) - | Some bid -> - let parents = list_ancestor_region_groups regions_hierarchy bid in - (Some bid, parents) - in - (* Is the function stateful, and can it fail? *) - let lid = None in - let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -886,17 +875,52 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) { sg with A.inputs; output } in - (* List the inputs for: - * - the fuel - * - the forward function - * - the parent backward functions, in proper order - * - the current backward function (if it is a backward function) - *) - let fuel = mk_fuel_input_ty_as_list effect_info in - let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in - (* For the backward functions: for now we don't supported nested borrows, - * so just check that there aren't parent regions *) - assert (T.RegionGroupId.Set.is_empty parents); + (* Is the forward function stateful, and can it fail? *) + let fwd_effect_info = + get_fun_effect_info fun_infos (FunId fun_id) None None + in + (* Compute the forward inputs *) + let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in + let fwd_inputs_no_fuel_no_state = + List.map (translate_fwd_ty type_infos) sg.inputs + in + (* State input for the forward function *) + let fwd_state_ty = + (* For the forward state, we check if the *whole group* is stateful. + See {!effect_info}. *) + if fwd_effect_info.stateful_group then [ mk_state_ty ] else [] + in + let fwd_inputs = + List.concat [ fwd_fuel; fwd_inputs_no_fuel_no_state; fwd_state_ty ] + in + (* Compute the backward output, without the effect information *) + let fwd_output = translate_fwd_ty type_infos sg.output in + (* The additinoal information *) + let fwd_info = + (* *) + let has_fuel = fwd_fuel <> [] in + let num_inputs_no_fuel_no_state = List.length fwd_inputs in + let num_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fwd_fuel + num_inputs_no_fuel_no_state + in + let fwd_info : inputs_info = + { + has_fuel; + num_inputs_no_fuel_no_state; + num_inputs_with_fuel_no_state; + num_inputs_with_fuel_with_state = + (* We use the fact that [fwd_state_ty] has length 1 if there is a state, + and 0 otherwise *) + num_inputs_with_fuel_no_state + List.length fwd_state_ty; + } + in + let info = { fwd_info; effect_info = fwd_effect_info } in + assert (fun_sig_info_is_wf info); + info + in + + (* Compute the type information for the backward function *) (* Small helper to translate types for backward functions *) let translate_back_ty_for_gid (gid : T.RegionGroupId.id) (ty : T.ty) : ty option = @@ -923,7 +947,11 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) let inside_mut = false in translate_back_ty type_infos keep_region inside_mut ty in - let translate_back_inputs_for_gid gid : ty list = + let translate_back_inputs_for_gid (gid : T.RegionGroupId.id) : ty list = + (* For now we don't supported nested borrows, so we check that there + aren't parent regions *) + let parents = list_ancestor_region_groups regions_hierarchy gid in + assert (T.RegionGroupId.Set.is_empty parents); (* For now, we don't allow nested borrows, so the additional inputs to the backward function can only come from borrows that were returned like in (for the backward function we introduce for 'a): @@ -935,45 +963,6 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - (* Compute the additinal inputs for the current function, if it is a backward - * function *) - let back_inputs = - match gid with None -> [] | Some gid -> translate_back_inputs_for_gid gid - in - (* If the function is stateful, the inputs are: - - forward: [fwd_ty0, ..., fwd_tyn, state] - - backward: - - if {!Config.backward_no_state_update}: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] - - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty] - - The backward takes the same state as input as the forward function, - together with the state at the point where it gets called, if it is - stateful. - - See the comments for {!Config.backward_no_state_update} - *) - let fwd_state_ty = - (* For the forward state, we check if the *whole group* is stateful. - See {!effect_info}. *) - if effect_info.stateful_group then [ mk_state_ty ] else [] - in - let mk_back_state_ty_for_gid (gid : RegionGroupId.id option) : ty list = - (* For the backward state, we check if the function is a backward function, - and it is stateful *) - if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] - in - let back_state_ty = mk_back_state_ty_for_gid gid in - - (* Concatenate the inputs, in the following order: - * - forward inputs - * - forward state input - * - backward inputs - * - backward state input - *) - let inputs = - List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] - in - (* Outputs *) let compute_back_outputs_for_gid (gid : RegionGroupId.id) : string option list * ty list = (* The outputs are the borrows inside the regions of the abstractions @@ -998,103 +987,111 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) in List.split outputs in - let output_names, doutputs = - match gid with - | None -> - (* This is a forward function: there is one (unnamed) output. - - If we merge the fwd/back functions we might need to compute - the information about the back outputs. - *) - (* TODO: *) - assert (not !Config.return_back_funs); - ([ None ], [ translate_fwd_ty type_infos sg.output ]) - | Some gid -> - (* This is a backward function: there might be several outputs. *) - compute_back_outputs_for_gid gid - in - (* Create the return type *) - let output = - (* Group the outputs together *) - let output = mk_simpl_tuple_ty doutputs in - (* Add the output state *) - let output = - if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] - else output + let compute_back_info_for_group (rg : T.region_var_group) : + RegionGroupId.id * back_sg_info = + let gid = rg.id in + let back_effect_info = + get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) in - (* Wrap in a result type *) - if effect_info.can_fail then mk_result_ty output else output + let inputs_no_state = translate_back_inputs_for_gid gid in + let inputs_no_state_names = + List.map (fun _ -> Some "ret") inputs_no_state + in + let state_ty, state_name = + if back_effect_info.stateful then ([ mk_state_ty ], [ None ]) else ([], []) + in + let inputs = inputs_no_state @ state_ty in + let input_names = inputs_no_state_names @ state_name in + let output_names, outputs = compute_back_outputs_for_gid gid in + let info = + { + inputs; + input_names; + outputs; + output_names; + effect_info = back_effect_info; + } + in + (gid, info) in + let back_sg = + RegionGroupId.Map.of_list + (List.map compute_back_info_for_group regions_hierarchy) + in + (* Generic parameters *) let generics = translate_generic_params sg.generics in + (* Return *) - (* TODO: *) - assert (not !Config.return_back_funs); - let has_fuel = fuel <> [] in - let num_fwd_inputs_no_fuel_no_state = List.length fwd_inputs in - let num_fwd_inputs_with_fuel_no_state = - (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) - List.length fuel + num_fwd_inputs_no_fuel_no_state - in - let num_back_inputs_no_state = - if bid = None then None else Some (List.length back_inputs) - in - let fwd_info : inputs_info = - { - has_fuel; - num_inputs_no_fuel_no_state = num_fwd_inputs_no_fuel_no_state; - num_inputs_with_fuel_no_state = num_fwd_inputs_with_fuel_no_state; - num_inputs_with_fuel_with_state = - (* We use the fact that [fwd_state_ty] has length 1 if there is a state, - and 0 otherwise *) - num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; - } + let preds = translate_predicates sg.preds in + { + generics; + llbc_generics = sg.generics; + preds; + fwd_inputs; + fwd_output; + back_sg; + fwd_info; + } + +let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id option) : fun_sig = + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) in - let compute_back_info (back_state_ty : ty list) - (num_back_inputs_no_state : int) : inputs_info = - let n = num_back_inputs_no_state in - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - } + (* Two cases depending on whether we split the forward/backward functions + or not *) + let mk_output_ty (effect_info : fun_effect_info) output = + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + if effect_info.can_fail then mk_result_ty output else output in - let back_info : back_inputs_info = - if !Config.return_back_funs then - (* Create the map *) - AllBacks - (RegionGroupId.Map.of_list - (List.map - (fun (rg : T.region_var_group) -> - ( rg.id, - let back_inputs = translate_back_inputs_for_gid rg.id in - let num_back_inputs = List.length back_inputs in - (* TODO: slightly overkill *) - let back_state_ty = mk_back_state_ty_for_gid (Some rg.id) in - compute_back_info back_state_ty num_back_inputs )) - regions_hierarchy)) + let inputs, output = + if !Config.return_back_funs then ( + assert (gid = None); + (* Compute the arrow types for all the backward functions *) + let back_tys = + List.map + (fun (back_sg : back_sg_info) -> + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + mk_arrows inputs output) + (RegionGroupId.Map.values dsg.back_sg) + in + (* Group the forward output and the types of the backward functions *) + let effect_info = dsg.fwd_info.effect_info in + let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in + let output = mk_output_ty effect_info output in + let inputs = dsg.fwd_inputs in + (inputs, output)) else - SingleBack - (Option.map (compute_back_info back_state_ty) num_back_inputs_no_state) - in - let info = { fwd_info; back_info; effect_info } in - assert (fun_sig_info_is_wf info); - let preds = translate_predicates sg.preds in - let sg = - { - generics; - llbc_generics = sg.generics; - preds; - inputs; - output; - doutputs; - info; - } - in - { sg; output_names } + match gid with + | None -> + let effect_info = dsg.fwd_info.effect_info in + let output = mk_output_ty effect_info dsg.fwd_output in + (dsg.fwd_inputs, output) + | Some gid -> + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + let inputs = dsg.fwd_inputs @ back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + (inputs, output) + in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = (* Generate the fresh variable *) -- cgit v1.2.3