summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml808
1 files changed, 587 insertions, 221 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index dd662074..006fdda7 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -65,6 +65,46 @@ type call_info = {
*)
}
+(** Contains information about a loop we entered.
+
+ Note that a path in a translated function body can have at most one call to
+ a loop, because the loop function takes care of the end of the execution
+ (and always happen at the end of the function). To be more precise, if we
+ translate a function body which contains a loop, one of the leaves will be a
+ call to the loop translation. The same happens for loop bodies.
+
+ For instance, if in Rust we have:
+ {[
+ fn get(...) {
+ let x = f(...);
+
+ loop {
+ ...
+ }
+ }
+ ]}
+
+ Then in the translation we have:
+ {[
+ let get_fwd ... =
+ let x = f_fwd ... in
+ (* We end the function by calling the loop translation *)
+ get_fwd_loop ...
+ ]}
+
+ The various input and output fields are for this unique loop call, if
+ there is one.
+ *)
+type loop_info = {
+ loop_id : LoopId.id;
+ input_svl : V.symbolic_value list;
+ type_args : ty list;
+ forward_inputs : texpression list option;
+ (** The forward inputs are initialized at [None] *)
+ forward_output_no_state : var option;
+ (** The forward outputs are initialized at [None] *)
+}
+
(** Body synthesis context *)
type bs_ctx = {
type_context : type_context;
@@ -119,7 +159,14 @@ type bs_ctx = {
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
(** The ended abstractions we encountered so far, with their additional input arguments *)
- loop_id : V.LoopId.id option;
+ loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *)
+ loops : loop_info LoopId.Map.t;
+ (** The loops we encountered so far.
+
+ We are using a map to be general - in practice we will fail if we encounter
+ more than one loop on a single path.
+ *)
+ loop_id : LoopId.id option;
(** [Some] if we reached a loop (we are synthesizing a function, and reached a loop, or are
synthesizing the loop body itself)
*)
@@ -535,7 +582,8 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info =
+ (fun_id : A.fun_id) (lid : V.LoopId.id option)
+ (gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
| A.Regular fid ->
let info = A.FunDeclId.Map.find fid fun_infos in
@@ -548,9 +596,10 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
stateful_group;
stateful;
can_diverge = info.can_diverge;
- is_rec = info.is_rec;
+ is_rec = info.is_rec || Option.is_some lid;
}
| A.Assumed aid ->
+ assert (lid = None);
{
can_fail = Assumed.assumed_can_fail aid;
stateful_group = false;
@@ -579,7 +628,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(Some bid, parents)
in
(* Is the function stateful, and can it fail? *)
- let effect_info = get_fun_effect_info fun_infos fun_id bid in
+ let lid = None in
+ let effect_info = get_fun_effect_info fun_infos fun_id lid bid in
(* List the inputs for:
* - the fuel
* - the forward function
@@ -728,28 +778,37 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
let sg = { type_params; inputs; output; doutputs; info } in
{ sg; output_names }
-let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
+let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern =
(* Generate the fresh variable *)
let id, var_counter = VarId.fresh ctx.var_counter in
- let var =
+ let state_var =
{ id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
in
- let state_var = mk_typed_pattern_from_var var None in
+ let state_pat = mk_typed_pattern_from_var state_var None in
(* Update the context *)
let ctx = { ctx with var_counter; state_var = id } in
(* Return *)
- (ctx, state_var)
+ (ctx, state_var, state_pat)
-let fresh_named_var_for_symbolic_value (basename : string option)
- (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var =
+let fresh_var (basename : string option) (ty : 'r T.ty) (ctx : bs_ctx) :
+ bs_ctx * var =
(* Generate the fresh variable *)
let id, var_counter = VarId.fresh ctx.var_counter in
- let ty = ctx_translate_fwd_ty ctx sv.sv_ty in
+ let ty = ctx_translate_fwd_ty ctx ty in
let var = { id; basename; ty } in
+ (* Update the context *)
+ let ctx = { ctx with var_counter } in
+ (* Return *)
+ (ctx, var)
+
+let fresh_named_var_for_symbolic_value (basename : string option)
+ (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var =
+ (* Generate the fresh variable *)
+ let ctx, var = fresh_var basename sv.sv_ty ctx in
(* Insert in the map *)
let sv_to_var = V.SymbolicValueId.Map.add sv.sv_id var ctx.sv_to_var in
(* Update the context *)
- let ctx = { ctx with var_counter; sv_to_var } in
+ let ctx = { ctx with sv_to_var } in
(* Return *)
(ctx, var)
@@ -1136,9 +1195,13 @@ and aproj_to_given_back (mp : mplace option) (aproj : V.aproj) (ctx : bs_ctx) :
See [typed_avalue_to_given_back].
*)
-let abs_to_given_back (mpl : mplace option list) (abs : V.abs) (ctx : bs_ctx) :
- bs_ctx * typed_pattern list =
- let avalues = List.combine mpl abs.avalues in
+let abs_to_given_back (mpl : mplace option list option) (abs : V.abs)
+ (ctx : bs_ctx) : bs_ctx * typed_pattern list =
+ let avalues =
+ match mpl with
+ | None -> List.map (fun av -> (None, av)) abs.avalues
+ | Some mpl -> List.combine mpl abs.avalues
+ in
let ctx, values =
List.fold_left_map
(fun ctx (mp, av) -> typed_avalue_to_given_back mp av ctx)
@@ -1151,7 +1214,7 @@ let abs_to_given_back (mpl : mplace option list) (abs : V.abs) (ctx : bs_ctx) :
let abs_to_given_back_no_mp (abs : V.abs) (ctx : bs_ctx) :
bs_ctx * typed_pattern list =
let mpl = List.map (fun _ -> None) abs.avalues in
- abs_to_given_back mpl abs ctx
+ abs_to_given_back (Some mpl) abs ctx
(** Return the ordered list of the (transitive) parents of a given abstraction.
@@ -1167,6 +1230,8 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) (call_id : V.FunCallId.id) :
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
| S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx
+ | ReturnWithLoop (loop_id, is_continue) ->
+ translate_return_with_loop loop_id is_continue ctx
| Panic -> translate_panic ctx
| FunCall (call, e) -> translate_function_call call e ctx
| EndAbstraction (ectx, abs, e) -> translate_end_abstraction ectx abs e ctx
@@ -1174,19 +1239,8 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
| Assertion (ectx, v, e) -> translate_assertion ectx v e ctx
| Expansion (p, sv, exp) -> translate_expansion p sv exp ctx
| Meta (meta, e) -> translate_meta meta e ctx
- | ForwardEnd (loop_input_values, e, back_e) ->
- assert (loop_input_values = None);
- (* Update the current state with the additional state received by the backward
- function, if needs be, and lookup the proper expression *)
- let ctx, e =
- match ctx.bid with
- | None -> (ctx, e)
- | Some bid ->
- let ctx = { ctx with state_var = ctx.back_state_var } in
- let e = T.RegionGroupId.Map.find bid back_e in
- (ctx, e)
- in
- translate_expression e ctx
+ | ForwardEnd (ectx, loop_input_values, e, back_e) ->
+ translate_forward_end ectx loop_input_values e back_e ctx
| Loop loop -> translate_loop loop ctx
and translate_panic (ctx : bs_ctx) : texpression =
@@ -1206,12 +1260,12 @@ and translate_panic (ctx : bs_ctx) : texpression =
ret_v
else mk_result_fail_texpression_with_error_id error_failure_id output_ty
-(** [opt_v]: the value to return, in case we translate a forward function *)
+(** [opt_v]: the value to return, in case we translate a forward body *)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
- - either we are translating a forward function, in which case the optional
- value should be [Some] (it is the returned value)
+ - either we reach the return of a forward function or a forward loop body,
+ in which case the optional value should be [Some] (it is the returned value)
- or we are translating a backward function, in which case it should be [None]
*)
(* Compute the values that we should return *without the state and the result
@@ -1246,7 +1300,52 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
else output
in
(* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
- (* TODO: we should use a [Return] function *)
+ mk_result_return_texpression output
+
+and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
+ (ctx : bs_ctx) : texpression =
+ assert (is_continue = ctx.inside_loop);
+ let loop_id = V.LoopId.Map.find loop_id ctx.loop_ids_map in
+ assert (loop_id = Option.get ctx.loop_id);
+
+ (* Lookup the loop information *)
+ let loop_id = Option.get ctx.loop_id in
+ let loop_info = LoopId.Map.find loop_id ctx.loops in
+
+ (* There are two cases depending on whether we translate a backward function
+ or not.
+ *)
+ let output =
+ match ctx.bid with
+ | None ->
+ (* Forward *)
+ mk_texpression_from_var (Option.get loop_info.forward_output_no_state)
+ | Some bid ->
+ (* Backward *)
+ (* Group the variables in which we stored the values we need to give back.
+ * See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
+ let backward_outputs =
+ T.RegionGroupId.Map.find bid ctx.backward_outputs
+ in
+ let field_values = List.map mk_texpression_from_var backward_outputs in
+ mk_simpl_tuple_texpression field_values
+ in
+
+ (* We may need to return a state
+ * - error-monad: Return x
+ * - state-error: Return (state, x)
+ * Note that the loop function and the parent function live in the same
+ * effect - in particular, one manipulates a state iff the other does
+ * the same.
+ * *)
+ let effect_info = ctx.sg.info.effect_info in
+ let output =
+ if effect_info.stateful then
+ let state_rvalue = mk_state_texpression ctx.state_var in
+ mk_simpl_tuple_texpression [ state_rvalue; output ]
+ else output
+ in
+ (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
mk_result_return_texpression output
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
@@ -1272,18 +1371,18 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None
+ get_fun_effect_info ctx.fun_context.fun_infos fid None None
in
- (* If the function is stateful:
- * - add the fuel
- * - add the state input argument
- * - generate a fresh state variable for the returned state
- *)
+ (* Depending on the function effects:
+ * - add the fuel
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ *)
let args, ctx, out_state =
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
@@ -1375,80 +1474,281 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
^ V.show_abs_kind abs.kind));
match abs.kind with
| V.SynthInput rg_id ->
- (* When we end an input abstraction, this input abstraction gets back
- * the borrows which it introduced in the context through the input
- * values: by listing those values, we get the values which are given
- * back by one of the backward functions we are synthesizing. *)
- (* Note that we don't support nested borrows for now: if we find
- * an ended synthesized input abstraction, it must be the one corresponding
- * to the backward function wer are synthesizing, it can't be the one
- * for a parent backward function.
- *)
- let bid = Option.get ctx.bid in
- assert (rg_id = bid);
-
- (* The translation is done as follows:
- * - for a given backward function, we choose a set of variables [v_i]
- * - when we detect the ended input abstraction which corresponds
- * to the backward function, and which consumed the values [consumed_i],
- * we introduce:
- * {[
+ translate_end_abstraction_synth_input ectx abs e ctx rg_id
+ | V.FunCall (call_id, rg_id) ->
+ translate_end_abstraction_fun_call ectx abs e ctx call_id rg_id
+ | V.SynthRet rg_id -> translate_end_abstraction_synth_ret ectx abs e ctx rg_id
+ | Loop (loop_id, rg_id, abs_kind) ->
+ translate_end_abstraction_loop ectx abs e ctx loop_id rg_id abs_kind
+
+and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (rg_id : T.RegionGroupId.id) : texpression
+ =
+ (* When we end an input abstraction, this input abstraction gets back
+ * the borrows which it introduced in the context through the input
+ * values: by listing those values, we get the values which are given
+ * back by one of the backward functions we are synthesizing. *)
+ (* Note that we don't support nested borrows for now: if we find
+ * an ended synthesized input abstraction, it must be the one corresponding
+ * to the backward function wer are synthesizing, it can't be the one
+ * for a parent backward function.
+ *)
+ let bid = Option.get ctx.bid in
+ assert (rg_id = bid);
+
+ (* The translation is done as follows:
+ * - for a given backward function, we choose a set of variables [v_i]
+ * - when we detect the ended input abstraction which corresponds
+ * to the backward function, and which consumed the values [consumed_i],
+ * we introduce:
+ * {[
* let v_i = consumed_i in
* ...
- * ]}
- * Then, when we reach the [Return] node, we introduce:
- * {[
+ * ]}
+ * Then, when we reach the [Return] node, we introduce:
+ * {[
* (v_i)
- * ]}
- * *)
- (* First, get the given back variables *)
- let given_back_variables =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
- in
- (* Get the list of values consumed by the abstraction upon ending *)
- let consumed_values = abs_to_consumed ctx ectx abs in
- (* Group the two lists *)
- let variables_values =
- List.combine given_back_variables consumed_values
- in
- (* Sanity check: the two lists match (same types) *)
- List.iter
- (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
- variables_values;
- (* Translate the next expression *)
- let next_e = translate_expression e ctx in
- (* Generate the assignemnts *)
- let monadic = false in
- List.fold_right
- (fun (var, value) (e : texpression) ->
- mk_let monadic (mk_typed_pattern_from_var var None) value e)
- variables_values next_e
- | V.FunCall (call_id, rg_id) ->
- let call_info = V.FunCallId.Map.find call_id ctx.calls in
- let call = call_info.forward in
- let fun_id =
- match call.call_id with
- | S.Fun (fun_id, _) -> fun_id
- | Unop _ | Binop _ ->
- (* Those don't have backward functions *)
- raise (Failure "Unreachable")
- in
+ * ]}
+ * *)
+ (* First, get the given back variables *)
+ let given_back_variables =
+ T.RegionGroupId.Map.find bid ctx.backward_outputs
+ in
+ (* Get the list of values consumed by the abstraction upon ending *)
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ (* Group the two lists *)
+ let variables_values = List.combine given_back_variables consumed_values in
+ (* Sanity check: the two lists match (same types) *)
+ List.iter
+ (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
+ variables_values;
+ (* Translate the next expression *)
+ let next_e = translate_expression e ctx in
+ (* Generate the assignemnts *)
+ let monadic = false in
+ List.fold_right
+ (fun (var, value) (e : texpression) ->
+ mk_let monadic (mk_typed_pattern_from_var var None) value e)
+ variables_values next_e
+
+and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (call_id : V.FunCallId.id)
+ (rg_id : T.RegionGroupId.id) : texpression =
+ let call_info = V.FunCallId.Map.find call_id ctx.calls in
+ let call = call_info.forward in
+ let fun_id =
+ match call.call_id with
+ | S.Fun (fun_id, _) -> fun_id
+ | Unop _ | Binop _ ->
+ (* Those don't have backward functions *)
+ raise (Failure "Unreachable")
+ in
+ let effect_info =
+ get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id)
+ in
+ let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
+ (* Retrieve the original call and the parent abstractions *)
+ let _forward, backwards = get_abs_ancestors ctx abs call_id in
+ (* Retrieve the values consumed when we called the forward function and
+ * ended the parent backward functions: those give us part of the input
+ * values (rem: for now, as we disallow nested lifetimes, there can't be
+ * parent backward functions).
+ * Note that the forward inputs **include the fuel and the input state**
+ * (if we use those). *)
+ let fwd_inputs = call_info.forward_inputs in
+ let back_ancestors_inputs =
+ List.concat (List.map (fun (_abs, args) -> args) backwards)
+ in
+ (* Retrieve the values consumed upon ending the loans inside this
+ * abstraction: those give us the remaining input values *)
+ let back_inputs = abs_to_consumed ctx ectx abs in
+ (* If the function is stateful:
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ *)
+ let back_state, ctx, nstate =
+ if effect_info.stateful then
+ let back_state = mk_state_texpression ctx.state_var in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
+ ([ back_state ], ctx, Some nstate)
+ else ([], ctx, None)
+ in
+ (* Concatenate all the inpus *)
+ let inputs =
+ List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
+ in
+ (* Retrieve the values given back by this function: those are the output
+ * values. We rely on the fact that there are no nested borrows to use the
+ * meta-place information from the input values given to the forward function
+ * (we need to add [None] for the return avalue) *)
+ let output_mpl =
+ List.append (List.map translate_opt_mplace call.args_places) [ None ]
+ in
+ let ctx, outputs = abs_to_given_back (Some output_mpl) abs ctx in
+ (* Group the output values together: first the updated inputs *)
+ let output = mk_simpl_tuple_pattern outputs in
+ (* Add the returned state if the function is stateful *)
+ let output =
+ match nstate with
+ | None -> output
+ | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
+ in
+ (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
+ let _ =
+ let inst_sg = get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_texpression inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : texpression).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ log#ldebug
+ (lazy
+ ("\n- outputs: "
+ ^ string_of_int (List.length outputs)
+ ^ "\n- expected outputs: "
+ ^ string_of_int (List.length inst_sg.doutputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
+ (List.combine outputs inst_sg.doutputs)
+ in
+ (* Retrieve the function id, and register the function call in the context
+ * if necessary *)
+ let ctx, func =
+ bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
+ in
+ (* Translate the next expression *)
+ let next_e = translate_expression e ctx in
+ (* Put everything together *)
+ let args_mplaces = List.map (fun _ -> None) inputs in
+ let args =
+ List.map
+ (fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
+ (List.combine inputs args_mplaces)
+ in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty output.ty else output.ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { id = FunOrOp func; type_args } in
+ let func = { e = Qualif func; ty = func_ty } in
+ let call = mk_apps func args in
+ (* **Optimization**:
+ * =================
+ * We do a small optimization here: if the backward function doesn't
+ * have any output, we don't introduce any function call.
+ * See the comment in {!Config.filter_useless_monadic_calls}.
+ *
+ * TODO: use an option to disallow backward functions from updating the state.
+ * TODO: a backward function which only gives back shared borrows shouldn't
+ * update the state (state updates should only be used for mutable borrows,
+ * with objects like Rc for instance).
+ *)
+ if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then (
+ (* No outputs - we do a small sanity check: the backward function
+ * should have exactly the same number of inputs as the forward:
+ * this number can be different only if the forward function returned
+ * a value containing mutable borrows, which can't be the case... *)
+ assert (List.length inputs = List.length fwd_inputs);
+ next_e)
+ else mk_let effect_info.can_fail output call next_e
+
+and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (rg_id : T.RegionGroupId.id) : texpression
+ =
+ (* If we end the abstraction which consumed the return value of the function
+ we are synthesizing, we get back the borrows which were inside. Those borrows
+ are actually input arguments of the backward function we are synthesizing.
+ So we simply need to introduce proper let bindings.
+
+ For instance:
+ {[
+ fn id<'a>(x : &'a mut u32) -> &'a mut u32 {
+ x
+ }
+ ]}
+
+ Upon ending the return abstraction for 'a, we get back the borrow for [x].
+ This new value is the second argument of the backward function:
+ {[
+ let id_back x nx = nx
+ ]}
+
+ In practice, upon ending this abstraction we introduce a useless
+ let-binding:
+ {[
+ let id_back x nx =
+ let s = nx in // the name [s] is not important (only collision matters)
+ ...
+ ]}
+
+ This let-binding later gets inlined, during a micro-pass.
+ *)
+ (* First, retrieve the list of variables used for the inputs for the
+ * backward function *)
+ let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
+ (* Retrieve the values consumed upon ending the loans inside this
+ * abstraction: as there are no nested borrows, there should be none. *)
+ let consumed = abs_to_consumed ctx ectx abs in
+ assert (consumed = []);
+ (* Retrieve the values given back upon ending this abstraction - note that
+ * we don't provide meta-place information, because those assignments will
+ * be inlined anyway... *)
+ log#ldebug (lazy ("abs: " ^ abs_to_string ctx abs));
+ let ctx, given_back = abs_to_given_back_no_mp abs ctx in
+ (* Link the inputs to those given back values - note that this also
+ * checks we have the same number of values, of course *)
+ let given_back_inputs = List.combine given_back inputs in
+ (* Sanity check *)
+ List.iter
+ (fun ((given_back, input) : typed_pattern * var) ->
+ log#ldebug
+ (lazy
+ ("\n- given_back ty: "
+ ^ ty_to_string ctx given_back.ty
+ ^ "\n- sig input ty: " ^ ty_to_string ctx input.ty));
+ assert (given_back.ty = input.ty))
+ given_back_inputs;
+ (* Translate the next expression *)
+ let next_e = translate_expression e ctx in
+ (* Generate the assignments *)
+ let monadic = false in
+ List.fold_right
+ (fun (given_back, input_var) e ->
+ mk_let monadic given_back (mk_texpression_from_var input_var) e)
+ given_back_inputs next_e
+
+and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
+ (e : S.expression) (ctx : bs_ctx) (loop_id : V.LoopId.id)
+ (rg_id : T.RegionGroupId.id option) (abs_kind : V.loop_abs_kind) :
+ texpression =
+ let vloop_id = loop_id in
+ let loop_id = V.LoopId.Map.find loop_id ctx.loop_ids_map in
+ assert (loop_id = Option.get ctx.loop_id);
+ let rg_id = Option.get rg_id in
+ (* There are two cases depending on the [abs_kind] (whether this is a
+ synth input or a regular loop call) *)
+ match abs_kind with
+ | V.LoopSynthInput ->
+ (* Actually the same case as [SynthInput] *)
+ translate_end_abstraction_synth_input ectx abs e ctx rg_id
+ | V.LoopCall ->
+ let fun_id = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some rg_id)
- in
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- (* Retrieve the original call and the parent abstractions *)
- let _forward, backwards = get_abs_ancestors ctx abs call_id in
- (* Retrieve the values consumed when we called the forward function and
- * ended the parent backward functions: those give us part of the input
- * values (rem: for now, as we disallow nested lifetimes, there can't be
- * parent backward functions).
- * Note that the forward inputs **include the fuel and the input state**
- * (if we use those). *)
- let fwd_inputs = call_info.forward_inputs in
- let back_ancestors_inputs =
- List.concat (List.map (fun (_abs, args) -> args) backwards)
+ get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id)
+ (Some rg_id)
in
+ let loop_info = LoopId.Map.find loop_id ctx.loops in
+ let type_args = loop_info.type_args in
+ let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: those give us the remaining input values *)
let back_inputs = abs_to_consumed ctx ectx abs in
@@ -1459,23 +1759,14 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
(* Concatenate all the inpus *)
- let inputs =
- List.concat
- [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
- in
- (* Retrieve the values given back by this function: those are the output
- * values. We rely on the fact that there are no nested borrows to use the
- * meta-place information from the input values given to the forward function
- * (we need to add [None] for the return avalue) *)
- let output_mpl =
- List.append (List.map translate_opt_mplace call.args_places) [ None ]
- in
- let ctx, outputs = abs_to_given_back output_mpl abs ctx in
+ let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in
+ (* Retrieve the values given back by this function *)
+ let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
let output = mk_simpl_tuple_pattern outputs in
(* Add the returned state if the function is stateful *)
@@ -1484,39 +1775,6 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
| None -> output
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
- (* Sanity check: the inputs and outputs have the proper number and the proper type *)
- let _ =
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some rg_id) type_args ctx
- in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map show_texpression inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- in
- (* Retrieve the function id, and register the function call in the context
- * if necessary *)
- let ctx, func =
- bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
- in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
@@ -1531,6 +1789,7 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
+ let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; type_args } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
@@ -1543,7 +1802,7 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
* TODO: use an option to disallow backward functions from updating the state.
* TODO: a backward function which only gives back shared borrows shouldn't
* update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance.
+ * with objects like Rc for instance).
*)
if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None
then (
@@ -1554,69 +1813,6 @@ and translate_end_abstraction (ectx : C.eval_ctx) (abs : V.abs)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else mk_let effect_info.can_fail output call next_e
- | V.SynthRet rg_id ->
- (* If we end the abstraction which consumed the return value of the function
- we are synthesizing, we get back the borrows which were inside. Those borrows
- are actually input arguments of the backward function we are synthesizing.
- So we simply need to introduce proper let bindings.
-
- For instance:
- {[
- fn id<'a>(x : &'a mut u32) -> &'a mut u32 {
- x
- }
- ]}
-
- Upon ending the return abstraction for 'a, we get back the borrow for [x].
- This new value is the second argument of the backward function:
- {[
- let id_back x nx = nx
- ]}
-
- In practice, upon ending this abstraction we introduce a useless
- let-binding:
- {[
- let id_back x nx =
- let s = nx in // the name [s] is not important (only collision matters)
- ...
- ]}
-
- This let-binding later gets inlined, during a micro-pass.
- *)
- (* First, retrieve the list of variables used for the inputs for the
- * backward function *)
- let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
- (* Retrieve the values consumed upon ending the loans inside this
- * abstraction: as there are no nested borrows, there should be none. *)
- let consumed = abs_to_consumed ctx ectx abs in
- assert (consumed = []);
- (* Retrieve the values given back upon ending this abstraction - note that
- * we don't provide meta-place information, because those assignments will
- * be inlined anyway... *)
- log#ldebug (lazy ("abs: " ^ abs_to_string ctx abs));
- let ctx, given_back = abs_to_given_back_no_mp abs ctx in
- (* Link the inputs to those given back values - note that this also
- * checks we have the same number of values, of course *)
- let given_back_inputs = List.combine given_back inputs in
- (* Sanity check *)
- List.iter
- (fun ((given_back, input) : typed_pattern * var) ->
- log#ldebug
- (lazy
- ("\n- given_back ty: "
- ^ ty_to_string ctx given_back.ty
- ^ "\n- sig input ty: " ^ ty_to_string ctx input.ty));
- assert (given_back.ty = input.ty))
- given_back_inputs;
- (* Translate the next expression *)
- let next_e = translate_expression e ctx in
- (* Generate the assignments *)
- let monadic = false in
- List.fold_right
- (fun (given_back, input_var) e ->
- mk_let monadic given_back (mk_texpression_from_var input_var) e)
- given_back_inputs next_e
- | Loop _ -> raise (Failure "Unimplemented")
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -1841,8 +2037,177 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches);
{ e; ty }
+and translate_forward_end (ectx : C.eval_ctx)
+ (loop_input_values : V.typed_value S.symbolic_value_id_map option)
+ (e : S.expression) (back_e : S.expression S.region_group_id_map)
+ (ctx : bs_ctx) : texpression =
+ (* Update the current state with the additional state received by the backward
+ function, if needs be, and lookup the proper expression *)
+ let translate_end ctx =
+ (* Update the current state with the additional state received by the backward
+ function, if needs be, and lookup the proper expression *)
+ let ctx, e =
+ match ctx.bid with
+ | None -> (ctx, e)
+ | Some bid ->
+ let ctx = { ctx with state_var = ctx.back_state_var } in
+ let e = T.RegionGroupId.Map.find bid back_e in
+ (ctx, e)
+ in
+ translate_expression e ctx
+ in
+
+ (* If we entered/are entering a loop, we need to introduce a call to the
+ forward translation of the loop. *)
+ match loop_input_values with
+ | None ->
+ (* "Regular" case: not a loop *)
+ assert (ctx.loop_id = None);
+ translate_end ctx
+ | Some loop_input_values ->
+ (* Loop *)
+ let loop_id = Option.get ctx.loop_id in
+
+ (* Lookup the loop information *)
+ let loop_info = LoopId.Map.find loop_id ctx.loops in
+
+ (* Translate the input values *)
+ let loop_input_values =
+ List.map
+ (fun sv -> V.SymbolicValueId.Map.find sv.V.sv_id loop_input_values)
+ loop_info.input_svl
+ in
+ let args =
+ List.map (typed_value_to_texpression ctx ectx) loop_input_values
+ in
+
+ (* Lookup the effect info for the loop function *)
+ let fid = A.Regular ctx.fun_decl.A.def_id in
+ let effect_info =
+ get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid
+ in
+
+ (* Introduce a fresh output value for the forward function *)
+ let ctx, output_var =
+ let output_ty = ctx.sg.output in
+ fresh_var None output_ty ctx
+ in
+ let args, ctx, out_pats =
+ let output_pat = mk_typed_pattern_from_var output_var None in
+
+ (* Depending on the function effects:
+ * - add the fuel
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ * TODO: we do exactly the same thing in {!translate_function_call}
+ *)
+ let fuel = mk_fuel_input_as_list ctx effect_info in
+ if effect_info.stateful then
+ let state_var = mk_state_texpression ctx.state_var in
+ let ctx, _nstate_var, nstate_pat = bs_ctx_fresh_state_var ctx in
+ ( List.concat [ fuel; args; [ state_var ] ],
+ ctx,
+ [ nstate_pat; output_pat ] )
+ else (List.concat [ fuel; args ], ctx, [ output_pat ])
+ in
+
+ (* Update the loop information in the context *)
+ let loop_info =
+ {
+ loop_info with
+ forward_inputs = Some args;
+ forward_output_no_state = Some output_var;
+ }
+ in
+ let ctx =
+ { ctx with loops = LoopId.Map.add loop_id loop_info ctx.loops }
+ in
+
+ (* Translate the end of the function *)
+ let next_e = translate_end ctx in
+
+ (* Introduce the call to the loop in the generated AST *)
+ let out_pat = mk_simpl_tuple_pattern out_pats in
+ let loop_call =
+ let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in
+ let func = { id = FunOrOp fun_id; type_args = loop_info.type_args } in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { e = Qualif func; ty = func_ty } in
+ let call = mk_apps func args in
+ call
+ in
+ mk_let effect_info.can_fail out_pat loop_call next_e
+
and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
- raise (Failure "Unreachable")
+ let loop_id = V.LoopId.Map.find loop.loop_id ctx.loop_ids_map in
+
+ (* Translate the loop inputs *)
+ let inputs =
+ List.map
+ (fun sv -> V.SymbolicValueId.Map.find sv.V.sv_id ctx.sv_to_var)
+ loop.input_svalues
+ in
+ let inputs_lvs =
+ List.map (fun var -> mk_typed_pattern_from_var var None) inputs
+ in
+
+ (* Add the loop information in the context *)
+ let ctx =
+ assert (not (LoopId.Map.mem loop_id ctx.loops));
+
+ (* Note that we will retrieve the input values later in the [ForwardEnd]
+ (and will introduce the outputs at that moment, together with the actual
+ call to the loop forward function *)
+ let type_args =
+ List.map (fun ty -> TypeVar ty.T.index) ctx.sg.type_params
+ in
+
+ let loop_info =
+ {
+ loop_id;
+ input_svl = loop.input_svalues;
+ type_args;
+ forward_inputs = None;
+ forward_output_no_state = None;
+ }
+ in
+ let loops = LoopId.Map.add loop_id loop_info ctx.loops in
+ { ctx with loops }
+ in
+
+ (* Update the context to translate the function end *)
+ let ctx_end = { ctx with loop_id = Some loop_id } in
+ let fun_end = translate_expression loop.end_expr ctx_end in
+
+ (* Update the context for the loop body *)
+ let ctx_loop = { ctx_end with inside_loop = true } in
+ (* We also need to introduce variables for the symbolic values which are
+ introduced in the fixed point (we have to filter the list of symbolic
+ values, to remove the not fresh ones - the fixed point introduces some
+ symbolic values and keeps some others)... *)
+ let ctx_loop =
+ let svl =
+ List.filter
+ (fun (sv : V.symbolic_value) ->
+ V.SymbolicValueId.Set.mem sv.sv_id loop.fresh_svalues)
+ loop.input_svalues
+ in
+ let ctx_loop, _ = fresh_vars_for_symbolic_values svl ctx_loop in
+ ctx_loop
+ in
+
+ (* Translate the loop body *)
+ let loop_body = translate_expression loop.loop_expr ctx_loop in
+
+ (* Create the loop node and return *)
+ let loop = Loop { fun_end; loop_id; inputs; inputs_lvs; loop_body } in
+ assert (fun_end.ty = loop_body.ty);
+ let ty = fun_end.ty in
+ { e = loop; ty }
and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
texpression =
@@ -1947,7 +2312,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid
+ get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None
+ bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)