diff options
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r-- | compiler/SymbolicToPure.ml | 808 |
1 files changed, 587 insertions, 221 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index dd662074..006fdda7 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -65,6 +65,46 @@ type call_info = { *) } +(** Contains information about a loop we entered. + + Note that a path in a translated function body can have at most one call to + a loop, because the loop function takes care of the end of the execution + (and always happen at the end of the function). To be more precise, if we + translate a function body which contains a loop, one of the leaves will be a + call to the loop translation. The same happens for loop bodies. + + For instance, if in Rust we have: + {[ + fn get(...) { + let x = f(...); + + loop { + ... + } + } + ]} + + Then in the translation we have: + {[ + let get_fwd ... = + let x = f_fwd ... in + (* We end the function by calling the loop translation *) + get_fwd_loop ... + ]} + + The various input and output fields are for this unique loop call, if + there is one. + *) +type loop_info = { + loop_id : LoopId.id; + input_svl : V.symbolic_value list; + type_args : ty list; + forward_inputs : texpression list option; + (** The forward inputs are initialized at [None] *) + forward_output_no_state : var option; + (** The forward outputs are initialized at [None] *) +} + (** Body synthesis context *) type bs_ctx = { type_context : type_context; @@ -119,7 +159,14 @@ type bs_ctx = { (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; (** The ended abstractions we encountered so far, with their additional input arguments *) - loop_id : V.LoopId.id option; + loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *) + loops : loop_info LoopId.Map.t; + (** The loops we encountered so far. + + We are using a map to be general - in practice we will fail if we encounter + more than one loop on a single path. + *) + loop_id : LoopId.id option; (** [Some] if we reached a loop (we are synthesizing a function, and reached a loop, or are synthesizing the loop body itself) *) @@ -535,7 +582,8 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : (** Small utility. *) let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) - (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = + (fun_id : A.fun_id) (lid : V.LoopId.id option) + (gid : T.RegionGroupId.id option) : fun_effect_info = match fun_id with | A.Regular fid -> let info = A.FunDeclId.Map.find fid fun_infos in @@ -548,9 +596,10 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) stateful_group; stateful; can_diverge = info.can_diverge; - is_rec = info.is_rec; + is_rec = info.is_rec || Option.is_some lid; } | A.Assumed aid -> + assert (lid = None); { can_fail = Assumed.assumed_can_fail aid; stateful_group = false; @@ -579,7 +628,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (Some bid, parents) in (* Is the function stateful, and can it fail? *) - let effect_info = get_fun_effect_info fun_infos fun_id bid in + let lid = None in + let effect_info = get_fun_effect_info fun_infos fun_id lid bid in (* List the inputs for: * - the fuel * - the forward function @@ -728,28 +778,37 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) let sg = { type_params; inputs; output; doutputs; info } in { sg; output_names } -let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = +let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern = (* Generate the fresh variable *) let id, var_counter = VarId.fresh ctx.var_counter in - let var = + let state_var = { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } in - let state_var = mk_typed_pattern_from_var var None in + let state_pat = mk_typed_pattern_from_var state_var None in (* Update the context *) let ctx = { ctx with var_counter; state_var = id } in (* Return *) - (ctx, state_var) + (ctx, state_var, state_pat) -let fresh_named_var_for_symbolic_value (basename : string option) - (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var = +let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) : + bs_ctx * var = (* Generate the fresh variable *) let id, var_counter = VarId.fresh ctx.var_counter in - let ty = ctx_translate_fwd_ty ctx sv.sv_ty in + let ty = ctx_translate_fwd_ty ctx ty in let var = { id; basename; ty } in + (* Update the context *) + let ctx = { ctx with var_counter } in + (* Return *) + (ctx, var) + +let fresh_named_var_for_symbolic_value (basename : string option) + (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var = + (* Generate the fresh variable *) + let ctx, var = fresh_var basename sv.sv_ty ctx in (* Insert in the map *) let sv_to_var = V.SymbolicValueId.Map.add sv.sv_id var ctx.sv_to_var in (* Update the context *) - let ctx = { ctx with var_counter; sv_to_var } in + let ctx = { ctx with sv_to_var } in (* Return *) (ctx, var) @@ -1136,9 +1195,13 @@ and aproj_to_given_back (mp : mplace option) (aproj : V.aproj) (ctx : bs_ctx) : See [typed_avalue_to_given_back]. *) -let abs_to_given_back (mpl : mplace option list) (abs : V.abs) (ctx : bs_ctx) : - bs_ctx * typed_pattern list = - let avalues = List.combine mpl abs.avalues in +let abs_to_given_back (mpl : mplace option list option) (abs : V.abs) + (ctx : bs_ctx) : bs_ctx * typed_pattern list = + let avalues = + match mpl with + | None -> List.map (fun av -> (None, av)) abs.avalues + | Some mpl -> List.combine mpl abs.avalues + in let ctx, values = List.fold_left_map (fun ctx (mp, av) -> typed_avalue_to_given_back mp av ctx) @@ -1151,7 +1214,7 @@ let abs_to_given_back (mpl : mplace option list) (abs : V.abs) (ctx : bs_ctx) : let abs_to_given_back_no_mp (abs : V.abs) (ctx : bs_ctx) : bs_ctx * typed_pattern list = let mpl = List.map (fun _ -> None) abs.avalues in - abs_to_given_back mpl abs ctx + abs_to_given_back (Some mpl) abs ctx (** Return the ordered list of the (transitive) parents of a given abstraction. @@ -1167,6 +1230,8 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) (call_id : V.FunCallId.id) : let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx + | ReturnWithLoop (loop_id, is_continue) -> + translate_return_with_loop loop_id is_continue ctx | Panic -> translate_panic ctx | FunCall (call, e) -> translate_function_call call e ctx | EndAbstraction (ectx, abs, e) -> translate_end_abstraction ectx abs e ctx @@ -1174,19 +1239,8 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = | Assertion (ectx, v, e) -> translate_assertion ectx v e ctx | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx | Meta (meta, e) -> translate_meta meta e ctx - | ForwardEnd (loop_input_values, e, back_e) -> - assert (loop_input_values = None); - (* Update the current state with the additional state received by the backward - function, if needs be, and lookup the proper expression *) - let ctx, e = - match ctx.bid with - | None -> (ctx, e) - | Some bid -> - let ctx = { ctx with state_var = ctx.back_state_var } in - let e = T.RegionGroupId.Map.find bid back_e in - (ctx, e) - in - translate_expression e ctx + | ForwardEnd (ectx, loop_input_values, e, back_e) -> + translate_forward_end ectx loop_input_values e back_e ctx | Loop loop -> translate_loop loop ctx and translate_panic (ctx : bs_ctx) : texpression = @@ -1206,12 +1260,12 @@ and translate_panic (ctx : bs_ctx) : texpression = ret_v else mk_result_fail_texpression_with_error_id error_failure_id output_ty -(** [opt_v]: the value to return, in case we translate a forward function *) +(** [opt_v]: the value to return, in case we translate a forward body *) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: - - either we are translating a forward function, in which case the optional - value should be [Some] (it is the returned value) + - either we reach the return of a forward function or a forward loop body, + in which case the optional value should be [Some] (it is the returned value) - or we are translating a backward function, in which case it should be [None] *) (* Compute the values that we should return *without the state and the result @@ -1246,7 +1300,52 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) else output in (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) - (* TODO: we should use a [Return] function *) + mk_result_return_texpression output + +and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool) + (ctx : bs_ctx) : texpression = + assert (is_continue = ctx.inside_loop); + let loop_id = V.LoopId.Map.find loop_id ctx.loop_ids_map in + assert (loop_id = Option.get ctx.loop_id); + + (* Lookup the loop information *) + let loop_id = Option.get ctx.loop_id in + let loop_info = LoopId.Map.find loop_id ctx.loops in + + (* There are two cases depending on whether we translate a backward function + or not. + *) + let output = + match ctx.bid with + | None -> + (* Forward *) + mk_texpression_from_var (Option.get loop_info.forward_output_no_state) + | Some bid -> + (* Backward *) + (* Group the variables in which we stored the values we need to give back. + * See the explanations for the [SynthInput] case in [translate_end_abstraction] *) + let backward_outputs = + T.RegionGroupId.Map.find bid ctx.backward_outputs + in + let field_values = List.map mk_texpression_from_var backward_outputs in + mk_simpl_tuple_texpression field_values + in + + (* We may need to return a state + * - error-monad: Return x + * - state-error: Return (state, x) + * Note that the loop function and the parent function live in the same + * effect - in particular, one manipulates a state iff the other does + * the same. + * *) + let effect_info = ctx.sg.info.effect_info in + let output = + if effect_info.stateful then + let state_rvalue = mk_state_texpression ctx.state_var in + mk_simpl_tuple_texpression [ state_rvalue; output ] + else output + in + (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *) mk_result_return_texpression output and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : @@ -1272,18 +1371,18 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Retrieve the effect information about this function (can fail, * takes a state as input, etc.) *) let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fid None + get_fun_effect_info ctx.fun_context.fun_infos fid None None in - (* If the function is stateful: - * - add the fuel - * - add the state input argument - * - generate a fresh state variable for the returned state - *) + (* Depending on the function effects: + * - add the fuel + * - add the state input argument + * - generate a fresh state variable for the returned state + *) let args, ctx, out_state = let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in - let ctx, nstate_var = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in @@ -1375,80 +1474,281 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs) ^ V.show_abs_kind abs.kind)); match abs.kind with | V.SynthInput rg_id -> - (* When we end an input abstraction, this input abstraction gets back - * the borrows which it introduced in the context through the input - * values: by listing those values, we get the values which are given - * back by one of the backward functions we are synthesizing. *) - (* Note that we don't support nested borrows for now: if we find - * an ended synthesized input abstraction, it must be the one corresponding - * to the backward function wer are synthesizing, it can't be the one - * for a parent backward function. - *) - let bid = Option.get ctx.bid in - assert (rg_id = bid); - - (* The translation is done as follows: - * - for a given backward function, we choose a set of variables [v_i] - * - when we detect the ended input abstraction which corresponds - * to the backward function, and which consumed the values [consumed_i], - * we introduce: - * {[ + translate_end_abstraction_synth_input ectx abs e ctx rg_id + | V.FunCall (call_id, rg_id) -> + translate_end_abstraction_fun_call ectx abs e ctx call_id rg_id + | V.SynthRet rg_id -> translate_end_abstraction_synth_ret ectx abs e ctx rg_id + | Loop (loop_id, rg_id, abs_kind) -> + translate_end_abstraction_loop ectx abs e ctx loop_id rg_id abs_kind + +and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs) + (e : S.expression) (ctx : bs_ctx) (rg_id : T.RegionGroupId.id) : texpression + = + (* When we end an input abstraction, this input abstraction gets back + * the borrows which it introduced in the context through the input + * values: by listing those values, we get the values which are given + * back by one of the backward functions we are synthesizing. *) + (* Note that we don't support nested borrows for now: if we find + * an ended synthesized input abstraction, it must be the one corresponding + * to the backward function wer are synthesizing, it can't be the one + * for a parent backward function. + *) + let bid = Option.get ctx.bid in + assert (rg_id = bid); + + (* The translation is done as follows: + * - for a given backward function, we choose a set of variables [v_i] + * - when we detect the ended input abstraction which corresponds + * to the backward function, and which consumed the values [consumed_i], + * we introduce: + * {[ * let v_i = consumed_i in * ... - * ]} - * Then, when we reach the [Return] node, we introduce: - * {[ + * ]} + * Then, when we reach the [Return] node, we introduce: + * {[ * (v_i) - * ]} - * *) - (* First, get the given back variables *) - let given_back_variables = - T.RegionGroupId.Map.find bid ctx.backward_outputs - in - (* Get the list of values consumed by the abstraction upon ending *) - let consumed_values = abs_to_consumed ctx ectx abs in - (* Group the two lists *) - let variables_values = - List.combine given_back_variables consumed_values - in - (* Sanity check: the two lists match (same types) *) - List.iter - (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) - variables_values; - (* Translate the next expression *) - let next_e = translate_expression e ctx in - (* Generate the assignemnts *) - let monadic = false in - List.fold_right - (fun (var, value) (e : texpression) -> - mk_let monadic (mk_typed_pattern_from_var var None) value e) - variables_values next_e - | V.FunCall (call_id, rg_id) -> - let call_info = V.FunCallId.Map.find call_id ctx.calls in - let call = call_info.forward in - let fun_id = - match call.call_id with - | S.Fun (fun_id, _) -> fun_id - | Unop _ | Binop _ -> - (* Those don't have backward functions *) - raise (Failure "Unreachable") - in + * ]} + * *) + (* First, get the given back variables *) + let given_back_variables = + T.RegionGroupId.Map.find bid ctx.backward_outputs + in + (* Get the list of values consumed by the abstraction upon ending *) + let consumed_values = abs_to_consumed ctx ectx abs in + (* Group the two lists *) + let variables_values = List.combine given_back_variables consumed_values in + (* Sanity check: the two lists match (same types) *) + List.iter + (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty)) + variables_values; + (* Translate the next expression *) + let next_e = translate_expression e ctx in + (* Generate the assignemnts *) + let monadic = false in + List.fold_right + (fun (var, value) (e : texpression) -> + mk_let monadic (mk_typed_pattern_from_var var None) value e) + variables_values next_e + +and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) + (e : S.expression) (ctx : bs_ctx) (call_id : V.FunCallId.id) + (rg_id : T.RegionGroupId.id) : texpression = + let call_info = V.FunCallId.Map.find call_id ctx.calls in + let call = call_info.forward in + let fun_id = + match call.call_id with + | S.Fun (fun_id, _) -> fun_id + | Unop _ | Binop _ -> + (* Those don't have backward functions *) + raise (Failure "Unreachable") + in + let effect_info = + get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id) + in + let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in + (* Retrieve the original call and the parent abstractions *) + let _forward, backwards = get_abs_ancestors ctx abs call_id in + (* Retrieve the values consumed when we called the forward function and + * ended the parent backward functions: those give us part of the input + * values (rem: for now, as we disallow nested lifetimes, there can't be + * parent backward functions). + * Note that the forward inputs **include the fuel and the input state** + * (if we use those). *) + let fwd_inputs = call_info.forward_inputs in + let back_ancestors_inputs = + List.concat (List.map (fun (_abs, args) -> args) backwards) + in + (* Retrieve the values consumed upon ending the loans inside this + * abstraction: those give us the remaining input values *) + let back_inputs = abs_to_consumed ctx ectx abs in + (* If the function is stateful: + * - add the state input argument + * - generate a fresh state variable for the returned state + *) + let back_state, ctx, nstate = + if effect_info.stateful then + let back_state = mk_state_texpression ctx.state_var in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in + ([ back_state ], ctx, Some nstate) + else ([], ctx, None) + in + (* Concatenate all the inpus *) + let inputs = + List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] + in + (* Retrieve the values given back by this function: those are the output + * values. We rely on the fact that there are no nested borrows to use the + * meta-place information from the input values given to the forward function + * (we need to add [None] for the return avalue) *) + let output_mpl = + List.append (List.map translate_opt_mplace call.args_places) [ None ] + in + let ctx, outputs = abs_to_given_back (Some output_mpl) abs ctx in + (* Group the output values together: first the updated inputs *) + let output = mk_simpl_tuple_pattern outputs in + (* Add the returned state if the function is stateful *) + let output = + match nstate with + | None -> output + | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] + in + (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *) + let _ = + let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx in + log#ldebug + (lazy + ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" + ^ string_of_int (List.length inputs) + ^ "): " + ^ String.concat ", " (List.map show_texpression inputs) + ^ "\n- inst_sg.inputs (" + ^ string_of_int (List.length inst_sg.inputs) + ^ "): " + ^ String.concat ", " (List.map show_ty inst_sg.inputs))); + List.iter + (fun (x, ty) -> assert ((x : texpression).ty = ty)) + (List.combine inputs inst_sg.inputs); + log#ldebug + (lazy + ("\n- outputs: " + ^ string_of_int (List.length outputs) + ^ "\n- expected outputs: " + ^ string_of_int (List.length inst_sg.doutputs))); + List.iter + (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) + (List.combine outputs inst_sg.doutputs) + in + (* Retrieve the function id, and register the function call in the context + * if necessary *) + let ctx, func = + bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx + in + (* Translate the next expression *) + let next_e = translate_expression e ctx in + (* Put everything together *) + let args_mplaces = List.map (fun _ -> None) inputs in + let args = + List.map + (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) + (List.combine inputs args_mplaces) + in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = + if effect_info.can_fail then mk_result_ty output.ty else output.ty + in + let func_ty = mk_arrows input_tys ret_ty in + let func = { id = FunOrOp func; type_args } in + let func = { e = Qualif func; ty = func_ty } in + let call = mk_apps func args in + (* **Optimization**: + * ================= + * We do a small optimization here: if the backward function doesn't + * have any output, we don't introduce any function call. + * See the comment in {!Config.filter_useless_monadic_calls}. + * + * TODO: use an option to disallow backward functions from updating the state. + * TODO: a backward function which only gives back shared borrows shouldn't + * update the state (state updates should only be used for mutable borrows, + * with objects like Rc for instance). + *) + if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then ( + (* No outputs - we do a small sanity check: the backward function + * should have exactly the same number of inputs as the forward: + * this number can be different only if the forward function returned + * a value containing mutable borrows, which can't be the case... *) + assert (List.length inputs = List.length fwd_inputs); + next_e) + else mk_let effect_info.can_fail output call next_e + +and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs) + (e : S.expression) (ctx : bs_ctx) (rg_id : T.RegionGroupId.id) : texpression + = + (* If we end the abstraction which consumed the return value of the function + we are synthesizing, we get back the borrows which were inside. Those borrows + are actually input arguments of the backward function we are synthesizing. + So we simply need to introduce proper let bindings. + + For instance: + {[ + fn id<'a>(x : &'a mut u32) -> &'a mut u32 { + x + } + ]} + + Upon ending the return abstraction for 'a, we get back the borrow for [x]. + This new value is the second argument of the backward function: + {[ + let id_back x nx = nx + ]} + + In practice, upon ending this abstraction we introduce a useless + let-binding: + {[ + let id_back x nx = + let s = nx in // the name [s] is not important (only collision matters) + ... + ]} + + This let-binding later gets inlined, during a micro-pass. + *) + (* First, retrieve the list of variables used for the inputs for the + * backward function *) + let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in + (* Retrieve the values consumed upon ending the loans inside this + * abstraction: as there are no nested borrows, there should be none. *) + let consumed = abs_to_consumed ctx ectx abs in + assert (consumed = []); + (* Retrieve the values given back upon ending this abstraction - note that + * we don't provide meta-place information, because those assignments will + * be inlined anyway... *) + log#ldebug (lazy ("abs: " ^ abs_to_string ctx abs)); + let ctx, given_back = abs_to_given_back_no_mp abs ctx in + (* Link the inputs to those given back values - note that this also + * checks we have the same number of values, of course *) + let given_back_inputs = List.combine given_back inputs in + (* Sanity check *) + List.iter + (fun ((given_back, input) : typed_pattern * var) -> + log#ldebug + (lazy + ("\n- given_back ty: " + ^ ty_to_string ctx given_back.ty + ^ "\n- sig input ty: " ^ ty_to_string ctx input.ty)); + assert (given_back.ty = input.ty)) + given_back_inputs; + (* Translate the next expression *) + let next_e = translate_expression e ctx in + (* Generate the assignments *) + let monadic = false in + List.fold_right + (fun (given_back, input_var) e -> + mk_let monadic given_back (mk_texpression_from_var input_var) e) + given_back_inputs next_e + +and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) + (e : S.expression) (ctx : bs_ctx) (loop_id : V.LoopId.id) + (rg_id : T.RegionGroupId.id option) (abs_kind : V.loop_abs_kind) : + texpression = + let vloop_id = loop_id in + let loop_id = V.LoopId.Map.find loop_id ctx.loop_ids_map in + assert (loop_id = Option.get ctx.loop_id); + let rg_id = Option.get rg_id in + (* There are two cases depending on the [abs_kind] (whether this is a + synth input or a regular loop call) *) + match abs_kind with + | V.LoopSynthInput -> + (* Actually the same case as [SynthInput] *) + translate_end_abstraction_synth_input ectx abs e ctx rg_id + | V.LoopCall -> + let fun_id = A.Regular ctx.fun_decl.A.def_id in let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some rg_id) - in - let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in - (* Retrieve the original call and the parent abstractions *) - let _forward, backwards = get_abs_ancestors ctx abs call_id in - (* Retrieve the values consumed when we called the forward function and - * ended the parent backward functions: those give us part of the input - * values (rem: for now, as we disallow nested lifetimes, there can't be - * parent backward functions). - * Note that the forward inputs **include the fuel and the input state** - * (if we use those). *) - let fwd_inputs = call_info.forward_inputs in - let back_ancestors_inputs = - List.concat (List.map (fun (_abs, args) -> args) backwards) + get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id) + (Some rg_id) in + let loop_info = LoopId.Map.find loop_id ctx.loops in + let type_args = loop_info.type_args in + let fwd_inputs = Option.get loop_info.forward_inputs in (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) let back_inputs = abs_to_consumed ctx ectx abs in @@ -1459,23 +1759,14 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs) let back_state, ctx, nstate = if effect_info.stateful then let back_state = mk_state_texpression ctx.state_var in - let ctx, nstate = bs_ctx_fresh_state_var ctx in + let ctx, _, nstate = bs_ctx_fresh_state_var ctx in ([ back_state ], ctx, Some nstate) else ([], ctx, None) in (* Concatenate all the inpus *) - let inputs = - List.concat - [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] - in - (* Retrieve the values given back by this function: those are the output - * values. We rely on the fact that there are no nested borrows to use the - * meta-place information from the input values given to the forward function - * (we need to add [None] for the return avalue) *) - let output_mpl = - List.append (List.map translate_opt_mplace call.args_places) [ None ] - in - let ctx, outputs = abs_to_given_back output_mpl abs ctx in + let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in + (* Retrieve the values given back by this function *) + let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) let output = mk_simpl_tuple_pattern outputs in (* Add the returned state if the function is stateful *) @@ -1484,39 +1775,6 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs) | None -> output | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in - (* Sanity check: the inputs and outputs have the proper number and the proper type *) - let _ = - let inst_sg = - get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx - in - log#ldebug - (lazy - ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs (" - ^ string_of_int (List.length inputs) - ^ "): " - ^ String.concat ", " (List.map show_texpression inputs) - ^ "\n- inst_sg.inputs (" - ^ string_of_int (List.length inst_sg.inputs) - ^ "): " - ^ String.concat ", " (List.map show_ty inst_sg.inputs))); - List.iter - (fun (x, ty) -> assert ((x : texpression).ty = ty)) - (List.combine inputs inst_sg.inputs); - log#ldebug - (lazy - ("\n- outputs: " - ^ string_of_int (List.length outputs) - ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.doutputs))); - List.iter - (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.doutputs) - in - (* Retrieve the function id, and register the function call in the context - * if necessary *) - let ctx, func = - bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx - in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) @@ -1531,6 +1789,7 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs) if effect_info.can_fail then mk_result_ty output.ty else output.ty in let func_ty = mk_arrows input_tys ret_ty in + let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in let func = { id = FunOrOp func; type_args } in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in @@ -1543,7 +1802,7 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs) * TODO: use an option to disallow backward functions from updating the state. * TODO: a backward function which only gives back shared borrows shouldn't * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance. + * with objects like Rc for instance). *) if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then ( @@ -1554,69 +1813,6 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs) assert (List.length inputs = List.length fwd_inputs); next_e) else mk_let effect_info.can_fail output call next_e - | V.SynthRet rg_id -> - (* If we end the abstraction which consumed the return value of the function - we are synthesizing, we get back the borrows which were inside. Those borrows - are actually input arguments of the backward function we are synthesizing. - So we simply need to introduce proper let bindings. - - For instance: - {[ - fn id<'a>(x : &'a mut u32) -> &'a mut u32 { - x - } - ]} - - Upon ending the return abstraction for 'a, we get back the borrow for [x]. - This new value is the second argument of the backward function: - {[ - let id_back x nx = nx - ]} - - In practice, upon ending this abstraction we introduce a useless - let-binding: - {[ - let id_back x nx = - let s = nx in // the name [s] is not important (only collision matters) - ... - ]} - - This let-binding later gets inlined, during a micro-pass. - *) - (* First, retrieve the list of variables used for the inputs for the - * backward function *) - let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in - (* Retrieve the values consumed upon ending the loans inside this - * abstraction: as there are no nested borrows, there should be none. *) - let consumed = abs_to_consumed ctx ectx abs in - assert (consumed = []); - (* Retrieve the values given back upon ending this abstraction - note that - * we don't provide meta-place information, because those assignments will - * be inlined anyway... *) - log#ldebug (lazy ("abs: " ^ abs_to_string ctx abs)); - let ctx, given_back = abs_to_given_back_no_mp abs ctx in - (* Link the inputs to those given back values - note that this also - * checks we have the same number of values, of course *) - let given_back_inputs = List.combine given_back inputs in - (* Sanity check *) - List.iter - (fun ((given_back, input) : typed_pattern * var) -> - log#ldebug - (lazy - ("\n- given_back ty: " - ^ ty_to_string ctx given_back.ty - ^ "\n- sig input ty: " ^ ty_to_string ctx input.ty)); - assert (given_back.ty = input.ty)) - given_back_inputs; - (* Translate the next expression *) - let next_e = translate_expression e ctx in - (* Generate the assignments *) - let monadic = false in - List.fold_right - (fun (given_back, input_var) e -> - mk_let monadic given_back (mk_texpression_from_var input_var) e) - given_back_inputs next_e - | Loop _ -> raise (Failure "Unimplemented") and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1841,8 +2037,177 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches); { e; ty } +and translate_forward_end (ectx : C.eval_ctx) + (loop_input_values : V.typed_value S.symbolic_value_id_map option) + (e : S.expression) (back_e : S.expression S.region_group_id_map) + (ctx : bs_ctx) : texpression = + (* Update the current state with the additional state received by the backward + function, if needs be, and lookup the proper expression *) + let translate_end ctx = + (* Update the current state with the additional state received by the backward + function, if needs be, and lookup the proper expression *) + let ctx, e = + match ctx.bid with + | None -> (ctx, e) + | Some bid -> + let ctx = { ctx with state_var = ctx.back_state_var } in + let e = T.RegionGroupId.Map.find bid back_e in + (ctx, e) + in + translate_expression e ctx + in + + (* If we entered/are entering a loop, we need to introduce a call to the + forward translation of the loop. *) + match loop_input_values with + | None -> + (* "Regular" case: not a loop *) + assert (ctx.loop_id = None); + translate_end ctx + | Some loop_input_values -> + (* Loop *) + let loop_id = Option.get ctx.loop_id in + + (* Lookup the loop information *) + let loop_info = LoopId.Map.find loop_id ctx.loops in + + (* Translate the input values *) + let loop_input_values = + List.map + (fun sv -> V.SymbolicValueId.Map.find sv.V.sv_id loop_input_values) + loop_info.input_svl + in + let args = + List.map (typed_value_to_texpression ctx ectx) loop_input_values + in + + (* Lookup the effect info for the loop function *) + let fid = A.Regular ctx.fun_decl.A.def_id in + let effect_info = + get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid + in + + (* Introduce a fresh output value for the forward function *) + let ctx, output_var = + let output_ty = ctx.sg.output in + fresh_var None output_ty ctx + in + let args, ctx, out_pats = + let output_pat = mk_typed_pattern_from_var output_var None in + + (* Depending on the function effects: + * - add the fuel + * - add the state input argument + * - generate a fresh state variable for the returned state + * TODO: we do exactly the same thing in {!translate_function_call} + *) + let fuel = mk_fuel_input_as_list ctx effect_info in + if effect_info.stateful then + let state_var = mk_state_texpression ctx.state_var in + let ctx, _nstate_var, nstate_pat = bs_ctx_fresh_state_var ctx in + ( List.concat [ fuel; args; [ state_var ] ], + ctx, + [ nstate_pat; output_pat ] ) + else (List.concat [ fuel; args ], ctx, [ output_pat ]) + in + + (* Update the loop information in the context *) + let loop_info = + { + loop_info with + forward_inputs = Some args; + forward_output_no_state = Some output_var; + } + in + let ctx = + { ctx with loops = LoopId.Map.add loop_id loop_info ctx.loops } + in + + (* Translate the end of the function *) + let next_e = translate_end ctx in + + (* Introduce the call to the loop in the generated AST *) + let out_pat = mk_simpl_tuple_pattern out_pats in + let loop_call = + let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in + let func = { id = FunOrOp fun_id; type_args = loop_info.type_args } in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = + if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty + in + let func_ty = mk_arrows input_tys ret_ty in + let func = { e = Qualif func; ty = func_ty } in + let call = mk_apps func args in + call + in + mk_let effect_info.can_fail out_pat loop_call next_e + and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = - raise (Failure "Unreachable") + let loop_id = V.LoopId.Map.find loop.loop_id ctx.loop_ids_map in + + (* Translate the loop inputs *) + let inputs = + List.map + (fun sv -> V.SymbolicValueId.Map.find sv.V.sv_id ctx.sv_to_var) + loop.input_svalues + in + let inputs_lvs = + List.map (fun var -> mk_typed_pattern_from_var var None) inputs + in + + (* Add the loop information in the context *) + let ctx = + assert (not (LoopId.Map.mem loop_id ctx.loops)); + + (* Note that we will retrieve the input values later in the [ForwardEnd] + (and will introduce the outputs at that moment, together with the actual + call to the loop forward function *) + let type_args = + List.map (fun ty -> TypeVar ty.T.index) ctx.sg.type_params + in + + let loop_info = + { + loop_id; + input_svl = loop.input_svalues; + type_args; + forward_inputs = None; + forward_output_no_state = None; + } + in + let loops = LoopId.Map.add loop_id loop_info ctx.loops in + { ctx with loops } + in + + (* Update the context to translate the function end *) + let ctx_end = { ctx with loop_id = Some loop_id } in + let fun_end = translate_expression loop.end_expr ctx_end in + + (* Update the context for the loop body *) + let ctx_loop = { ctx_end with inside_loop = true } in + (* We also need to introduce variables for the symbolic values which are + introduced in the fixed point (we have to filter the list of symbolic + values, to remove the not fresh ones - the fixed point introduces some + symbolic values and keeps some others)... *) + let ctx_loop = + let svl = + List.filter + (fun (sv : V.symbolic_value) -> + V.SymbolicValueId.Set.mem sv.sv_id loop.fresh_svalues) + loop.input_svalues + in + let ctx_loop, _ = fresh_vars_for_symbolic_values svl ctx_loop in + ctx_loop + in + + (* Translate the loop body *) + let loop_body = translate_expression loop.loop_expr ctx_loop in + + (* Create the loop node and return *) + let loop = Loop { fun_end; loop_id; inputs; inputs_lvs; loop_body } in + assert (fun_end.ty = loop_body.ty); + let ty = fun_end.ty in + { e = loop; ty } and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1947,7 +2312,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = | None -> None | Some body -> let effect_info = - get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid + get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None + bid in let body = translate_expression body ctx in (* Add a match over the fuel, if necessary *) |