From 8b6f8e5fb85bcd1b3257550270c4c857d4ee7f55 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 9 Nov 2022 18:04:03 +0100 Subject: Implement the generation of stateful backward functions (controlled by an option) --- compiler/Driver.ml | 12 +- compiler/ExtractToFStar.ml | 8 +- compiler/FunsAnalysis.ml | 2 +- compiler/Interpreter.ml | 11 +- compiler/Pure.ml | 35 ++-- compiler/PureUtils.ml | 6 + compiler/SymbolicAst.ml | 8 + compiler/SymbolicToPure.ml | 383 ++++++++++++++++++++++++++--------------- compiler/SynthesizeSymbolic.ml | 66 ++++--- compiler/Translate.ml | 63 +++++-- 10 files changed, 386 insertions(+), 208 deletions(-) (limited to 'compiler') diff --git a/compiler/Driver.ml b/compiler/Driver.ml index d19aca93..6f0e8074 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -38,6 +38,8 @@ let () = let test_trans_units = ref false in let no_decreases_clauses = ref false in let no_state = ref false in + (* [backward_no_state_update]: see the comment for {!Translate.config.backward_no_state_update} *) + let backward_no_state_update = ref false in let template_decreases_clauses = ref false in let no_split_files = ref false in let no_check_inv = ref false in @@ -78,6 +80,9 @@ let () = ( "-no-state", Arg.Set no_state, " Do not use state-error monads, simply use error monads" ); + ( "-backward-no-state-update", + Arg.Set backward_no_state_update, + " Forbid backward functions from updating the state" ); ( "-template-clauses", Arg.Set template_decreases_clauses, " Generate templates for the required decreases clauses, in a\n\ @@ -95,6 +100,8 @@ let () = in (* Sanity check: -template-clauses ==> not -no-decrease-clauses *) assert ((not !no_decreases_clauses) || not !template_decreases_clauses); + (* Sanity check: -backward-no-state-update ==> not -no-state *) + assert ((not !backward_no_state_update) || not !no_state); let spec = Arg.align spec in let filenames = ref [] in @@ -110,10 +117,10 @@ let () = | [ f ] -> (* TODO: update the extension *) if not (Filename.check_suffix f ".llbc") then ( - print_string "Unrecognized file extension"; + print_string ("Unrecognized file extension: " ^ f ^ "\n"); fail ()) else if not (Sys.file_exists f) then ( - print_string "File not found"; + print_string ("File not found: " ^ f ^ "\n"); fail ()) else f | _ -> @@ -198,6 +205,7 @@ let () = extract_decreases_clauses = not !no_decreases_clauses; extract_template_decreases_clauses = !template_decreases_clauses; use_state = not !no_state; + backward_no_state_update = !backward_no_state_update; } in Translate.translate_module filename dest_dir trans_config m; diff --git a/compiler/ExtractToFStar.ml b/compiler/ExtractToFStar.ml index 2a7d6a6c..a995d4a6 100644 --- a/compiler/ExtractToFStar.ml +++ b/compiler/ExtractToFStar.ml @@ -1451,17 +1451,19 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) * function (the additional input values "given back" to the * backward functions have no influence on termination: we thus * share the decrease clauses between the forward and the backward - * functions). + * functions - we also ignore the additional state received by the + * backward function, if there is one). *) let inputs_lvs = let all_inputs = (Option.get def.body).inputs_lvs in (* We have to count: * - the forward inputs - * - the state + * - the state (if there is one) *) let num_fwd_inputs = def.signature.info.num_fwd_inputs in let num_fwd_inputs = - if def.signature.info.effect_info.input_state then 1 + num_fwd_inputs + if def.signature.info.effect_info.stateful_group then + 1 + num_fwd_inputs else num_fwd_inputs in Collections.List.prefix num_fwd_inputs all_inputs diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml index 9413bd6a..4d33056b 100644 --- a/compiler/FunsAnalysis.ml +++ b/compiler/FunsAnalysis.ml @@ -16,7 +16,7 @@ module EU = ExpressionsUtils *) type fun_info = { can_fail : bool; - (* Not used yet: all the extracted functions use an error monad *) + (* Not used yet: all the extracted functions use an error monad *) stateful : bool; divergent : bool; (* Not used yet *) } diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml index d3b3c7e6..e752594e 100644 --- a/compiler/Interpreter.ml +++ b/compiler/Interpreter.ml @@ -253,8 +253,15 @@ let evaluate_function_symbolic (config : C.partial_config) (synthesize : bool) cf_pop cf_return ctx | Some back_id -> (* Backward translation *) - evaluate_function_symbolic_synthesize_backward_from_return config - fdef inst_sg back_id ctx + let e = + evaluate_function_symbolic_synthesize_backward_from_return + config fdef inst_sg back_id ctx + in + (* We insert a delimiter to indicate the point where we switch + * from the part which is common to all the functions (forwards + * and backwards) and the part specific to this backward function. + *) + S.synthesize_forward_end e else None | Panic -> (* Note that as we explore all the execution branches, one of diff --git a/compiler/Pure.ml b/compiler/Pure.ml index a50dd5f9..cc29469a 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -496,19 +496,26 @@ and meta = (** Information about the "effect" of a function *) type fun_effect_info = { - input_state : bool; (** [true] if the function takes a state as input *) - output_state : bool; - (** [true] if the function outputs a state (it then lives - in a state monad) *) + stateful_group : bool; + (** [true] if the function group is stateful. By *function group*, we mean + the set { forward function } U { backward functions }. + + We need this because the option {!Translate.eval_config.backward_no_state_update}: + if it is [true], then in case of a backward function {!stateful} is [false], + but we might need to know whether the corresponding forward function + is stateful or not. + *) + stateful : bool; (** [true] if the function is stateful (updates a state) *) can_fail : bool; (** [true] if the return type is a [result] *) } (** Meta information about a function signature *) type fun_sig_info = { num_fwd_inputs : int; - (** The number of input types for forward computation *) + (** The number of input types for forward computation, ignoring the state *) num_back_inputs : int option; - (** The number of additional inputs for the backward computation (if pertinent) *) + (** The number of additional inputs for the backward computation (if pertinent), + ignoring the state *) effect_info : fun_effect_info; } @@ -523,12 +530,18 @@ type fun_sig_info = { `in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm -> (back_out0 & ... & back_outp)` (* pure function *) `in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm -> result (back_out0 & ... & back_outp)` (* error-monad *) - `in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> - result (back_out0 & ... & back_outp)` (* state-error *) + `in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> state -> + result (state & (back_out0 & ... & back_outp))` (* state-error *) - Note that a backward function never returns (i.e., updates) a state: only - forward functions do so. Also, the state input parameter is *betwee* - the forward inputs and the backward inputs. + Note that a stateful backward function takes two states as inputs: the + state received by the associated forward function, and the state at which + the backward is called. This leads to code of the following shape: + + {[ + (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd + ... // the state may be updated + (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back + ]} The function's type should be given by `mk_arrows sig.inputs sig.output`. We provide additional meta-information: diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index ff379bf5..1ab3439c 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -456,3 +456,9 @@ let mk_result_return_pattern (v : typed_pattern) : typed_pattern = let opt_unmeta_mplace (e : texpression) : mplace option * texpression = match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e) + +let mk_state_var (vid : VarId.id) : var = + { id = vid; basename = Some ConstStrings.state_basename; ty = mk_state_ty } + +let mk_state_texpression (vid : VarId.id) : texpression = + { e = Var vid; ty = mk_state_ty } diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml index 528d8255..9d9adf4f 100644 --- a/compiler/SymbolicAst.ml +++ b/compiler/SymbolicAst.ml @@ -77,6 +77,14 @@ type expression = We use it to compute meaningful names for the variables we introduce, to prettify the generated code. *) + | ForwardEnd of expression + (** We use this delimiter to indicate at which point we switch to the + generation of code specific to the backward function(s). + + TODO: use this to factorize the generation of the forward and backward + functions (today we replay the *whole* symbolic execution once per + generated function). + *) | Meta of meta * expression (** Meta information *) and expansion = diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 6d01614d..9d249cfb 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -12,6 +12,10 @@ module FA = FunsAnalysis (** The local logger *) let log = L.symbolic_to_pure_log +(* TODO: carrying configurations everywhere is super annoying. + Group everything in references in a [Config.ml] file (put aside the execution + mode, maybe). +*) type config = { filter_useless_back_calls : bool; (** If [true], filter the useless calls to backward functions. @@ -39,6 +43,12 @@ type config = { Note that we later filter the useless *forward* calls in the micro-passes, where it is more natural to do. *) + backward_no_state_update : bool; + (** Controls whether backward functions update the state, in case we use + a state ({!use_state}). + + See {!Translate.config.backward_no_state_update}. + *) } type type_context = { @@ -110,7 +120,23 @@ type bs_ctx = { *) var_counter : VarId.generator; state_var : VarId.id; - (** The current state variable, in case we use a state *) + (** The current state variable, in case the function is stateful *) + back_state_var : VarId.id; + (** The additional input state variable received by a stateful backward function. + When generating stateful functions, we generate code of the following + form: + + {[ + (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd + ... // the state may be updated + (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back + ]} + + When translating a backward function, we need at some point to update + [state_var] with [back_state_var], to account for the fact that the + state may have been updated by the caller between the call to the + forward function and the call to the backward function. + *) forward_inputs : var list; (** The input parameters for the forward function *) backward_inputs : var list T.RegionGroupId.Map.t; @@ -498,20 +524,26 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : let abs_ids = list_ancestor_abstractions_ids ctx abs in List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids -(** Small utility. *) -let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = +(** Small utility. + + [backward_no_state_update]: see {!config} + *) +let get_fun_effect_info (backward_no_state_update : bool) + (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) + (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with | A.Regular fid -> let info = A.FunDeclId.Map.find fid fun_infos in - let input_state = info.stateful in - let output_state = input_state && gid = None in - { can_fail = info.can_fail; input_state; output_state } + let stateful_group = info.stateful in + let stateful = + stateful_group && ((not backward_no_state_update) || gid = None) + in + { can_fail = info.can_fail; stateful_group; stateful } | A.Assumed aid -> { can_fail = Assumed.assumed_can_fail aid; - input_state = false; - output_state = false; + stateful_group = false; + stateful = false; } (** Translate a function signature. @@ -519,10 +551,11 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) Note that the function also takes a list of names for the inputs, and computes, for every output for the backward functions, a corresponding name (outputs for backward functions come from borrows in the inputs - of the forward function). + of the forward function) which we use as hints to generate pretty names. *) -let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig) +let translate_fun_sig (backward_no_state_update : bool) + (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) + (types_infos : TA.type_infos) (sg : A.fun_sig) (input_names : string option list) (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = (* Retrieve the list of parent backward functions *) @@ -572,17 +605,42 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - (* Does the function take a state as input, does it return a state and can - * it fail? *) - let effect_info = get_fun_effect_info fun_infos fun_id bid in - (* *) - let state_ty = if effect_info.input_state then [ mk_state_ty ] else [] in + (* Is the function stateful, and can it fail? *) + let effect_info = + get_fun_effect_info backward_no_state_update fun_infos fun_id bid + in + (* If the function is stateful, the inputs are: + - forward: [fwd_ty0, ..., fwd_tyn, state] + - backward: + - if config.no_backward_state: [fwd_ty0, ..., fwd_tyn, state, back_ty, state] + - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty] + + The backward takes the same state as input as the forward function, + together with the state at the point where it gets called, if it is + stateful. + + See the comments for {!Translate.config.backward_no_state_update} + *) + let fwd_state_ty = + (* For the forward state, we check if the *whole group* is stateful. + See {!effect_info}. *) + if effect_info.stateful_group then [ mk_state_ty ] else [] + in + let back_state_ty = + (* For the backward state, we check if the function is a backward function, + and it is stateful *) + if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] + in + (* Concatenate the inputs, in the following order: * - forward inputs - * - state input + * - forward state input * - backward inputs + * - backward state input *) - let inputs = List.concat [ fwd_inputs; state_ty; back_inputs ] in + let inputs = + List.concat [ fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] + in (* Outputs *) let output_names, doutputs = match gid with @@ -620,7 +678,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) let output = mk_simpl_tuple_ty doutputs in (* Add the output state *) let output = - if effect_info.output_state then mk_simpl_tuple_ty [ mk_state_ty; output ] + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] else output in (* Wrap in a result type *) @@ -1087,6 +1145,15 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) | Assertion (v, e) -> translate_assertion config v e ctx | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx | Meta (meta, e) -> translate_meta config meta e ctx + | ForwardEnd e -> + (* Update the current state with the additional state received by the backward + function, if needs be *) + let ctx = + match ctx.bid with + | None -> ctx + | Some _ -> { ctx with state_var = ctx.back_state_var } + in + translate_expression config e ctx and translate_panic (ctx : bs_ctx) : texpression = (* Here we use the function return type - note that it is ok because @@ -1095,13 +1162,15 @@ and translate_panic (ctx : bs_ctx) : texpression = (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in - if ctx.sg.info.effect_info.output_state then + (* TODO: we should use a [Fail] function *) + if ctx.sg.info.effect_info.stateful then (* Create the [Fail] value *) let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in let ret_v = mk_result_fail_texpression ret_ty in ret_v else mk_result_fail_texpression output_ty +(** [opt_v]: the value to return, in case we translate a forward function *) and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -1109,44 +1178,40 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression value should be [Some] (it is the returned value) - or we are translating a backward function, in which case it should be [None] *) - match ctx.bid with - | None -> - (* Forward function *) - let v = Option.get opt_v in - let v = typed_value_to_texpression ctx v in - (* We may need to return a state - * - error-monad: Return x - * - state-error: Return (state, x) - * *) - if ctx.sg.info.effect_info.output_state then - let state_var = - { - id = ctx.state_var; - basename = Some ConstStrings.state_basename; - ty = mk_state_ty; - } + (* Compute the values that we should return *without the state and the result + * wrapper* *) + let output = + match ctx.bid with + | None -> + (* Forward function *) + let v = Option.get opt_v in + typed_value_to_texpression ctx v + | Some bid -> + (* Backward function *) + (* Sanity check *) + assert (opt_v = None); + (* Group the variables in which we stored the values we need to give back. + * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = + T.RegionGroupId.Map.find bid ctx.backward_outputs in - let state_rvalue = mk_texpression_from_var state_var in - mk_result_return_texpression - (mk_simpl_tuple_texpression [ state_rvalue; v ]) - else mk_result_return_texpression v - | Some bid -> - (* Backward function *) - (* Sanity check *) - assert (opt_v = None); - assert (not ctx.sg.info.effect_info.output_state); - (* We simply need to return the variables in which we stored the values - * we need to give back. - * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) - let backward_outputs = - T.RegionGroupId.Map.find bid ctx.backward_outputs - in - let field_values = List.map mk_texpression_from_var backward_outputs in - (* Backward functions never return a state *) - (* TODO: we should use a [fail] function, it would be cleaner *) - let ret_value = mk_simpl_tuple_texpression field_values in - let ret_value = mk_result_return_texpression ret_value in - ret_value + let field_values = List.map mk_texpression_from_var backward_outputs in + mk_simpl_tuple_texpression field_values + in + (* We may need to return a state + * - error-monad: Return x + * - state-error: Return (state, x) + * *) + let effect_info = ctx.sg.info.effect_info in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) + (* TODO: we should use a [Return] function *) + mk_result_return_texpression output and translate_function_call (config : config) (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1171,29 +1236,26 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None - in - (* Add the state input argument *) - let args = - if effect_info.input_state then - let state_var = { e = Var ctx.state_var; ty = mk_state_ty } in - List.append args [ state_var ] - else args + get_fun_effect_info config.backward_no_state_update + ctx.fun_context.fun_infos fid None in - (* Generate a fresh state variable if the function call introduces - * a new variable *) - let ctx, out_state = - if effect_info.input_state then - let ctx, var = bs_ctx_fresh_state_var ctx in - (ctx, Some var) - else (ctx, None) + (* If the function is stateful: + * - add the state input argument + * - generate a fresh state variable for the returned state + *) + let args, ctx, out_state = + if effect_info.stateful then + let state_var = mk_state_texpression ctx.state_var in + let ctx, nstate_var = bs_ctx_fresh_state_var ctx in + (List.append args [ state_var ], ctx, Some nstate_var) + else (args, ctx, None) in (* Register the function call *) let ctx = bs_ctx_register_forward_call call_id call args ctx in (ctx, func, effect_info, args, out_state) | S.Unop E.Not -> let effect_info = - { can_fail = false; input_state = false; output_state = false } + { can_fail = false; stateful_group = false; stateful = false } in (ctx, Unop Not, effect_info, args, None) | S.Unop E.Neg -> ( @@ -1203,14 +1265,14 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) (* Note that negation can lead to an overflow and thus fail (it * is thus monadic) *) let effect_info = - { can_fail = true; input_state = false; output_state = false } + { can_fail = true; stateful_group = false; stateful = false } in (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast (src_ty, tgt_ty)) -> (* Note that cast can fail *) let effect_info = - { can_fail = true; input_state = false; output_state = false } + { can_fail = true; stateful_group = false; stateful = false } in (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) | S.Binop binop -> ( @@ -1222,8 +1284,8 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let effect_info = { can_fail = ExpressionsUtils.binop_can_fail binop; - input_state = false; - output_state = false; + stateful_group = false; + stateful = false; } in (ctx, Binop (binop, int_ty0), effect_info, args, None) @@ -1307,6 +1369,17 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) | V.FunCall -> let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in let call = call_info.forward in + let fun_id = + match call.call_id with + | S.Fun (fun_id, _) -> fun_id + | Unop _ | Binop _ -> + (* Those don't have backward functions *) + raise (Failure "Unreachable") + in + let effect_info = + get_fun_effect_info config.backward_no_state_update + ctx.fun_context.fun_infos fun_id (Some abs.back_id) + in let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in (* Retrieve the original call and the parent abstractions *) let _forward, backwards = get_abs_ancestors ctx abs in @@ -1322,8 +1395,21 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) let back_inputs = abs_to_consumed ctx abs in + (* If the function is stateful: + * - add the state input argument + * - generate a fresh state variable for the returned state + *) + let back_state, ctx, nstate = + if effect_info.stateful then + let back_state = mk_state_texpression ctx.state_var in + let ctx, nstate = bs_ctx_fresh_state_var ctx in + ([ back_state ], ctx, Some nstate) + else ([], ctx, None) + in + (* Concatenate all the inpus *) let inputs = - List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ] + List.concat + [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the @@ -1333,43 +1419,42 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) List.append (List.map translate_opt_mplace call.args_places) [ None ] in let ctx, outputs = abs_to_given_back output_mpl abs ctx in - (* Group the output values together (note that for now, backward functions - * never return an output state) *) + (* Group the output values together: first the updated inputs *) let output = mk_simpl_tuple_pattern outputs in - (* Sanity check: the inputs and outputs have the proper number and the proper type *) - let fun_id = - match call.call_id with - | S.Fun (fun_id, _) -> fun_id - | Unop _ | Binop _ -> - (* Those don't have backward functions *) - raise (Failure "Unreachable") + (* Add the returned state if the function is stateful *) + let output = + match nstate with + | None -> output + | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in - - let inst_sg = - get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx + (* Sanity check: the inputs and outputs have the proper number and the proper type *) + let _ = + let inst_sg = + get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx + in + log#ldebug + (lazy + ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" + ^ string_of_int (List.length inputs) + ^ "): " + ^ String.concat ", " (List.map show_texpression inputs) + ^ "\n- inst_sg.inputs (" + ^ string_of_int (List.length inst_sg.inputs) + ^ "): " + ^ String.concat ", " (List.map show_ty inst_sg.inputs))); + List.iter + (fun (x, ty) -> assert ((x : texpression).ty = ty)) + (List.combine inputs inst_sg.inputs); + log#ldebug + (lazy + ("\n- outputs: " + ^ string_of_int (List.length outputs) + ^ "\n- expected outputs: " + ^ string_of_int (List.length inst_sg.doutputs))); + List.iter + (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) + (List.combine outputs inst_sg.doutputs) in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map show_texpression inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " (List.map show_ty inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs); (* Retrieve the function id, and register the function call in the context * if necessary *) let ctx, func = bs_ctx_register_backward_call abs back_inputs ctx in @@ -1382,9 +1467,6 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some abs.back_id) - in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in let ret_ty = if effect_info.can_fail then mk_result_ty output.ty else output.ty @@ -1398,8 +1480,13 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) * We do a small optimization here: if the backward function doesn't * have any output, we don't introduce any function call. * See the comment in [config]. + * + * TODO: use an option to disallow backward functions from updating the state. + * TODO: a backward function which only gives back shared borrows shouldn't + * update the state (state updates should only be used for mutable borrows, + * with objects like Rc for instance. *) - if config.filter_useless_back_calls && outputs = [] then ( + if config.filter_useless_back_calls && outputs = [] && nstate = None then ( (* No outputs - we do a small sanity check: the backward function * should have exactly the same number of inputs as the forward: * this number can be different only if the forward function returned @@ -1708,32 +1795,28 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (* Translate the declaration *) let def_id = def.A.def_id in let basename = def.name in - (* Lookup the signature *) - let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in + (* Retrieve the signature *) + let signature = ctx.sg in (* Translate the body, if there is *) let body = match body with | None -> None | Some body -> let body = translate_expression config body ctx in - (* Sanity check *) - type_check_texpression ctx body; - (* Introduce the input state, if necessary *) let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid + get_fun_effect_info config.backward_no_state_update + ctx.fun_context.fun_infos (Regular def_id) bid in - let input_state = - if effect_info.input_state then - [ - { - id = ctx.state_var; - basename = Some ConstStrings.state_basename; - ty = mk_state_ty; - }; - ] + (* Sanity check *) + type_check_texpression ctx body; + (* Introduce the forward input state (the state at call site of the + * *forward* function), if necessary. *) + let fwd_state = + (* We check if the *whole group* is stateful. See {!effect_info} *) + if effect_info.stateful_group then [ mk_state_var ctx.state_var ] else [] in - (* Compute the list of (properly ordered) input variables *) + (* Compute the list of (properly ordered) backward input variables *) let backward_inputs : var list = match bid with | None -> [] @@ -1747,8 +1830,17 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) backward_ids) in + (* Introduce the backward input state (the state at call site of the + * *backward* function), if necessary *) + let back_state = + if effect_info.stateful && Option.is_some bid then + [ mk_state_var ctx.back_state_var ] + else [] + in + (* Group the inputs together *) let inputs = - List.concat [ ctx.forward_inputs; input_state; backward_inputs ] + List.concat + [ ctx.forward_inputs; fwd_state; backward_inputs; back_state ] in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs @@ -1756,12 +1848,18 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (* Sanity check *) log#ldebug (lazy - ("SymbolicToPure.translate_fun_decl:" ^ "\n- forward_inputs: " + ("SymbolicToPure.translate_fun_decl: " + ^ Print.fun_name_to_string def.A.name + ^ " (" + ^ Print.option_to_string T.RegionGroupId.to_string bid + ^ ")" ^ "\n- forward_inputs: " ^ String.concat ", " (List.map show_var ctx.forward_inputs) - ^ "\n- input_state: " - ^ String.concat ", " (List.map show_var input_state) + ^ "\n- fwd_state: " + ^ String.concat ", " (List.map show_var fwd_state) ^ "\n- backward_inputs: " ^ String.concat ", " (List.map show_var backward_inputs) + ^ "\n- back_state: " + ^ String.concat ", " (List.map show_var back_state) ^ "\n- signature.inputs: " ^ String.concat ", " (List.map show_ty signature.inputs))); assert ( @@ -1804,8 +1902,8 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (types_infos : TA.type_infos) +let translate_fun_signatures (backward_no_state_update : bool) + (fun_infos : FA.fun_info A.FunDeclId.Map.t) (types_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdMap.t = (* For every function, translate the signatures of: @@ -1816,7 +1914,8 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) (sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list = (* The forward function *) let fwd_sg = - translate_fun_sig fun_infos fun_id types_infos sg input_names None + translate_fun_sig backward_no_state_update fun_infos fun_id types_infos sg + input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) @@ -1824,8 +1923,8 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t) List.map (fun (rg : T.region_var_group) -> let tsg = - translate_fun_sig fun_infos fun_id types_infos sg input_names - (Some rg.id) + translate_fun_sig backward_no_state_update fun_infos fun_id + types_infos sg input_names (Some rg.id) in let id = (fun_id, Some rg.id) in (id, tsg)) diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml index c74a831e..8d4dac82 100644 --- a/compiler/SynthesizeSymbolic.ml +++ b/compiler/SynthesizeSymbolic.ml @@ -12,7 +12,7 @@ let mk_mplace (p : E.place) (ctx : Contexts.eval_ctx) : mplace = let mk_opt_mplace (p : E.place option) (ctx : Contexts.eval_ctx) : mplace option = - match p with None -> None | Some p -> Some (mk_mplace p ctx) + Option.map (fun p -> mk_mplace p ctx) p let mk_opt_place_from_op (op : E.operand) (ctx : Contexts.eval_ctx) : mplace option = @@ -22,11 +22,11 @@ let mk_opt_place_from_op (op : E.operand) (ctx : Contexts.eval_ctx) : let synthesize_symbolic_expansion (sv : V.symbolic_value) (place : mplace option) (seel : V.symbolic_expansion option list) - (exprl : expression list option) : expression option = - match exprl with + (el : expression list option) : expression option = + match el with | None -> None - | Some exprl -> - let ls = List.combine seel exprl in + | Some el -> + let ls = List.combine seel el in (* Match on the symbolic value type to know which can of expansion happened *) let expansion = match sv.V.sv_ty with @@ -89,19 +89,18 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value) Some (Expansion (place, sv, expansion)) let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value) - (place : mplace option) (see : V.symbolic_expansion) - (expr : expression option) : expression option = - let exprl = match expr with None -> None | Some expr -> Some [ expr ] in - synthesize_symbolic_expansion sv place [ Some see ] exprl + (place : mplace option) (see : V.symbolic_expansion) (e : expression option) + : expression option = + let el = Option.map (fun e -> [ e ]) e in + synthesize_symbolic_expansion sv place [ Some see ] el let synthesize_function_call (call_id : call_id) (abstractions : V.AbstractionId.id list) (type_params : T.ety list) (args : V.typed_value list) (args_places : mplace option list) (dest : V.symbolic_value) (dest_place : mplace option) - (expr : expression option) : expression option = - match expr with - | None -> None - | Some expr -> + (e : expression option) : expression option = + Option.map + (fun e -> let call = { call_id; @@ -113,48 +112,45 @@ let synthesize_function_call (call_id : call_id) dest_place; } in - Some (FunCall (call, expr)) + FunCall (call, e)) + e let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value) - (expr : expression option) : expression option = - match expr with None -> None | Some e -> Some (EvalGlobal (gid, dest, e)) + (e : expression option) : expression option = + Option.map (fun e -> EvalGlobal (gid, dest, e)) e let synthesize_regular_function_call (fun_id : A.fun_id) (call_id : V.FunCallId.id) (abstractions : V.AbstractionId.id list) (type_params : T.ety list) (args : V.typed_value list) (args_places : mplace option list) (dest : V.symbolic_value) - (dest_place : mplace option) (expr : expression option) : expression option - = + (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Fun (fun_id, call_id)) - abstractions type_params args args_places dest dest_place expr + abstractions type_params args args_places dest dest_place e let synthesize_unary_op (unop : E.unop) (arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value) - (dest_place : mplace option) (expr : expression option) : expression option - = + (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Unop unop) [] [] [ arg ] [ arg_place ] dest - dest_place expr + dest_place e let synthesize_binary_op (binop : E.binop) (arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value) (arg1_place : mplace option) (dest : V.symbolic_value) - (dest_place : mplace option) (expr : expression option) : expression option - = + (dest_place : mplace option) (e : expression option) : expression option = synthesize_function_call (Binop binop) [] [] [ arg0; arg1 ] - [ arg0_place; arg1_place ] dest dest_place expr + [ arg0_place; arg1_place ] dest dest_place e -let synthesize_end_abstraction (abs : V.abs) (expr : expression option) : +let synthesize_end_abstraction (abs : V.abs) (e : expression option) : expression option = - match expr with - | None -> None - | Some expr -> Some (EndAbstraction (abs, expr)) + Option.map (fun e -> EndAbstraction (abs, e)) e let synthesize_assignment (lplace : mplace) (rvalue : V.typed_value) - (rplace : mplace option) (expr : expression option) : expression option = - match expr with - | None -> None - | Some expr -> Some (Meta (Assignment (lplace, rvalue, rplace), expr)) + (rplace : mplace option) (e : expression option) : expression option = + Option.map (fun e -> Meta (Assignment (lplace, rvalue, rplace), e)) e + +let synthesize_assertion (v : V.typed_value) (e : expression option) = + Option.map (fun e -> Assertion (v, e)) e -let synthesize_assertion (v : V.typed_value) (expr : expression option) = - match expr with None -> None | Some expr -> Some (Assertion (v, expr)) +let synthesize_forward_end (e : expression option) = + Option.map (fun e -> ForwardEnd e) e diff --git a/compiler/Translate.ml b/compiler/Translate.ml index d7cc9155..72322c73 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -18,6 +18,30 @@ type config = { (** Controls whether we need to use a state to model the external world (I/O, for instance). *) + backward_no_state_update : bool; + (** Controls whether backward functions update the state, in case we use + a state ({!use_state}). + + If they update the state, we generate code of the following style: + {[ + (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd + ... + (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back + }] + + Otherwise, we generate code of the following shape: + {[ + (st1, y) <-- f_fwd x st0; + ... + x' <-- f_back x st0 y'; + }] + + The second format is easier to reason about, but the first one is + necessary to properly handle some Rust functions which use internal + mutability such as {{:https://doc.rust-lang.org/std/cell/struct.RefCell.html#method.try_borrow_mut} [RefCell::try_mut_borrow]}: + in order to model this behaviour we would need a state, and calling the backward + function would update the state by reinserting the updated value in it. + *) split_files : bool; (** Controls whether we split the generated definitions between different files for the types, clauses and functions, or if we group them in @@ -96,7 +120,8 @@ let translate_function_to_symbolics (config : C.partial_config) TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (config : C.partial_config) - (mp_config : Micro.config) (trans_ctx : trans_ctx) + (mp_config : Micro.config) (backward_no_state_update : bool) + (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) (pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl) : pure_fun_translation = @@ -123,6 +148,7 @@ let translate_function_to_pure (config : C.partial_config) let sv_to_var = V.SymbolicValueId.Map.empty in let var_counter = Pure.VarId.generator_zero in let state_var, var_counter = Pure.VarId.fresh var_counter in + let back_state_var, var_counter = Pure.VarId.fresh var_counter in let calls = V.FunCallId.Map.empty in let abstractions = V.AbstractionId.Map.empty in let type_context = @@ -151,6 +177,7 @@ let translate_function_to_pure (config : C.partial_config) sv_to_var; var_counter; state_var; + back_state_var; type_context; fun_context; global_context; @@ -188,6 +215,7 @@ let translate_function_to_pure (config : C.partial_config) { SymbolicToPure.filter_useless_back_calls = mp_config.filter_useless_monadic_calls; + backward_no_state_update; } in @@ -231,12 +259,22 @@ let translate_function_to_pure (config : C.partial_config) in (* We need to ignore the forward inputs, and the state input (if there is) *) let fun_info = - SymbolicToPure.get_fun_effect_info fun_context.fun_infos - (A.Regular def_id) (Some back_id) + SymbolicToPure.get_fun_effect_info backward_no_state_update + fun_context.fun_infos (A.Regular def_id) (Some back_id) in - let _, backward_inputs = - Collections.List.split_at backward_sg.sg.inputs - (num_forward_inputs + if fun_info.input_state then 1 else 0) + let backward_inputs = + (* We need to ignore the forward state and the backward state *) + (* TODO: this is ad-hoc and error-prone. We should group all this + * information in the signature information. *) + let fwd_state_n = if fun_info.stateful_group then 1 else 0 in + let num_forward_inputs = num_forward_inputs + fwd_state_n in + let back_state_n = if fun_info.stateful then 1 else 0 in + let num_back_inputs = + List.length backward_sg.sg.inputs + - num_forward_inputs - back_state_n + in + Collections.List.subslice backward_sg.sg.inputs num_forward_inputs + num_back_inputs in (* As we forbid nested borrows, the additional inputs for the backward * functions come from the borrows in the return value of the rust function: @@ -285,7 +323,8 @@ let translate_function_to_pure (config : C.partial_config) (pure_forward, pure_backwards) let translate_module_to_pure (config : C.partial_config) - (mp_config : Micro.config) (use_state : bool) (crate : A.crate) : + (mp_config : Micro.config) (use_state : bool) + (backward_no_state_update : bool) (crate : A.crate) : trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list = (* Debug *) log#ldebug (lazy "translate_module_to_pure"); @@ -333,15 +372,15 @@ let translate_module_to_pure (config : C.partial_config) in let sigs = List.append assumed_sigs local_sigs in let fun_sigs = - SymbolicToPure.translate_fun_signatures fun_context.fun_infos - type_context.type_infos sigs + SymbolicToPure.translate_fun_signatures backward_no_state_update + fun_context.fun_infos type_context.type_infos sigs in (* Translate all the *transparent* functions *) let pure_translations = List.map - (translate_function_to_pure config mp_config trans_ctx fun_sigs - type_decls_map) + (translate_function_to_pure config mp_config backward_no_state_update + trans_ctx fun_sigs type_decls_map) crate.functions in @@ -631,7 +670,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config) (* Translate the module to the pure AST *) let trans_ctx, trans_types, trans_funs = translate_module_to_pure config.eval_config config.mp_config - config.use_state crate + config.use_state config.backward_no_state_update crate in (* Initialize the extraction context - for now we extract only to F*. -- cgit v1.2.3