summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r--compiler/SymbolicToPure.ml173
1 files changed, 152 insertions, 21 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 8fa66f93..4ea7071b 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -54,8 +54,7 @@ type call_info = {
forward_inputs : texpression list;
(** Remember the list of inputs given to the forward function.
- Those inputs include the state input, if pertinent (in which case
- it is the last input).
+ Those inputs include the fuel and the state, if pertinent.
*)
backwards : (V.abs * texpression list) T.RegionGroupId.Map.t;
(** A map from region group id (i.e., backward function id) to
@@ -98,12 +97,23 @@ type bs_ctx = {
state may have been updated by the caller between the call to the
forward function and the call to the backward function.
*)
+ fuel0 : VarId.id;
+ (** The original fuel taken as input by the function (if we use fuel) *)
+ fuel : VarId.id;
+ (* The fuel to use for the recursive calls (if we use fuel) *)
forward_inputs : var list;
- (** The input parameters for the forward function *)
+ (** The input parameters for the forward function corresponding to the
+ translated Rust inputs (no fuel, no state).
+ *)
backward_inputs : var list T.RegionGroupId.Map.t;
- (** The input parameters for the backward functions *)
+ (** The additional input parameters for the backward functions coming
+ from the borrows consumed upon ending the lifetime (as a consequence
+ those don't include the backward state, if there is one).
+ *)
backward_outputs : var list T.RegionGroupId.Map.t;
- (** The variables that the backward functions will output *)
+ (** The variables that the backward functions will output, corresponding
+ to the borrows they give back (don't include the backward state)
+ *)
calls : call_info V.FunCallId.Map.t;
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
@@ -485,6 +495,19 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
let abs_ids = list_ancestor_abstractions_ids ctx abs in
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
+(** Small utility *)
+let function_uses_fuel (info : fun_effect_info) : bool =
+ !Config.use_fuel && info.can_diverge
+
+(** Small utility *)
+let mk_fuel_input_ty_as_list (info : fun_effect_info) : ty list =
+ if function_uses_fuel info then [ mk_fuel_ty ] else []
+
+(** Small utility *)
+let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
+ texpression list =
+ if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else []
+
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info =
@@ -495,12 +518,18 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
let stateful =
stateful_group && ((not !Config.backward_no_state_update) || gid = None)
in
- { can_fail = info.can_fail; stateful_group; stateful }
+ {
+ can_fail = info.can_fail;
+ stateful_group;
+ stateful;
+ can_diverge = info.divergent;
+ }
| A.Assumed aid ->
{
can_fail = Assumed.assumed_can_fail aid;
stateful_group = false;
stateful = false;
+ can_diverge = false;
}
(** Translate a function signature.
@@ -522,11 +551,15 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
let parents = list_ancestor_region_groups sg bid in
(Some bid, parents)
in
+ (* Is the function stateful, and can it fail? *)
+ let effect_info = get_fun_effect_info fun_infos fun_id bid in
(* List the inputs for:
+ * - the fuel
* - the forward function
* - the parent backward functions, in proper order
* - the current backward function (if it is a backward function)
*)
+ let fuel = mk_fuel_input_ty_as_list effect_info in
let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in
(* For the backward functions: for now we don't supported nested borrows,
* so just check that there aren't parent regions *)
@@ -561,8 +594,6 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
*)
List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
in
- (* Is the function stateful, and can it fail? *)
- let effect_info = get_fun_effect_info fun_infos fun_id bid in
(* If the function is stateful, the inputs are:
- forward: [fwd_ty0, ..., fwd_tyn, state]
- backward:
@@ -593,7 +624,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
* - backward state input
*)
let inputs =
- List.concat [ fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ]
+ List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ]
in
(* Outputs *)
let output_names, doutputs =
@@ -641,17 +672,23 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(* Type parameters *)
let type_params = sg.type_params in
(* Return *)
+ let has_fuel = fuel <> [] in
let num_fwd_inputs_no_state = List.length fwd_inputs in
+ let num_fwd_inputs_with_fuel_no_state =
+ (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
+ List.length fuel + num_fwd_inputs_no_state
+ in
let num_back_inputs_no_state =
if bid = None then None else Some (List.length back_inputs)
in
let info =
{
- num_fwd_inputs_no_state;
- num_fwd_inputs_with_state =
+ has_fuel;
+ num_fwd_inputs_with_fuel_no_state;
+ num_fwd_inputs_with_fuel_with_state =
(* We use the fact that [fwd_state_ty] has length 1 if there is a state,
and 0 otherwise *)
- num_fwd_inputs_no_state + List.length fwd_state_ty;
+ num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty;
num_back_inputs_no_state;
num_back_inputs_with_state =
(* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
@@ -1209,22 +1246,29 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
get_fun_effect_info ctx.fun_context.fun_infos fid None
in
(* If the function is stateful:
+ * - add the fuel
* - add the state input argument
* - generate a fresh state variable for the returned state
*)
let args, ctx, out_state =
+ let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
- (List.append args [ state_var ], ctx, Some nstate_var)
- else (args, ctx, None)
+ (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
+ else (List.concat [ fuel; args ], ctx, None)
in
(* Register the function call *)
let ctx = bs_ctx_register_forward_call call_id call args ctx in
(ctx, func, effect_info, args, out_state)
| S.Unop E.Not ->
let effect_info =
- { can_fail = false; stateful_group = false; stateful = false }
+ {
+ can_fail = false;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ }
in
(ctx, Unop Not, effect_info, args, None)
| S.Unop E.Neg -> (
@@ -1234,14 +1278,24 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(* Note that negation can lead to an overflow and thus fail (it
* is thus monadic) *)
let effect_info =
- { can_fail = true; stateful_group = false; stateful = false }
+ {
+ can_fail = true;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ }
in
(ctx, Unop (Neg int_ty), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast (src_ty, tgt_ty)) ->
(* Note that cast can fail *)
let effect_info =
- { can_fail = true; stateful_group = false; stateful = false }
+ {
+ can_fail = true;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ }
in
(ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
| S.Binop binop -> (
@@ -1255,6 +1309,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
can_fail = ExpressionsUtils.binop_can_fail binop;
stateful_group = false;
stateful = false;
+ can_diverge = false;
}
in
(ctx, Binop (binop, int_ty0), effect_info, args, None)
@@ -1353,9 +1408,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
let _forward, backwards = get_abs_ancestors ctx abs in
(* Retrieve the values consumed when we called the forward function and
* ended the parent backward functions: those give us part of the input
- * values (rmk: for now, as we disallow nested lifetimes, there can't be
+ * values (rem: for now, as we disallow nested lifetimes, there can't be
* parent backward functions).
- * Note that the forward inputs include the input state (if there is one). *)
+ * Note that the forward inputs **include the fuel and the input state**
+ * (if we use those). *)
let fwd_inputs = call_info.forward_inputs in
let back_ancestors_inputs =
List.concat (List.map (fun (_abs, args) -> args) backwards)
@@ -1762,6 +1818,71 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
let ty = next_e.ty in
{ e; ty }
+(** Wrap a function body in a match over the fuel to control termination. *)
+let wrap_in_match_fuel (body : texpression) (ctx : bs_ctx) : texpression =
+ let fuel0_var : var = mk_fuel_var ctx.fuel0 in
+ let fuel0 = mk_texpression_from_var fuel0_var in
+ let nfuel_var : var = mk_fuel_var ctx.fuel in
+ let nfuel_pat = mk_typed_pattern_from_var nfuel_var None in
+ let fail_branch =
+ mk_result_fail_texpression_with_error_id error_out_of_fuel_id body.ty
+ in
+ match !Config.backend with
+ | FStar ->
+ (* Generate an expression:
+ {[
+ if fuel0 = 0 then Fail OutOfFuel
+ else
+ let fuel = decrease fuel0 in
+ ...
+ }]
+ *)
+ (* Create the expression: [fuel0 = 0] *)
+ let check_fuel =
+ let func = { id = FunOrOp (Fun (Pure FuelEqZero)); type_args = [] } in
+ let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in
+ let func = { e = Qualif func; ty = func_ty } in
+ mk_app func fuel0
+ in
+ (* Create the expression: [decrease fuel0] *)
+ let decrease_fuel =
+ let func = { id = FunOrOp (Fun (Pure FuelDecrease)); type_args = [] } in
+ let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in
+ let func = { e = Qualif func; ty = func_ty } in
+ mk_app func fuel0
+ in
+
+ (* Create the success branch *)
+ let monadic = false in
+ let success_branch = mk_let monadic nfuel_pat decrease_fuel body in
+
+ (* Put everything together *)
+ let match_e = Switch (check_fuel, If (fail_branch, success_branch)) in
+ let match_ty = body.ty in
+ { e = match_e; ty = match_ty }
+ | Coq ->
+ (* Generate an expression:
+ {[
+ match fuel0 with
+ | O -> Fail OutOfFuel
+ | S fuel ->
+ ...
+ }]
+ *)
+ (* Create the fail branch *)
+ let fail_pat = mk_adt_pattern mk_fuel_ty (Some fuel_zero_id) [] in
+ let fail_branch = { pat = fail_pat; branch = fail_branch } in
+ (* Create the success branch *)
+ let success_pat =
+ mk_adt_pattern mk_fuel_ty (Some fuel_succ_id) [ nfuel_pat ]
+ in
+ let success_branch = body in
+ let success_branch = { pat = success_pat; branch = success_branch } in
+ (* Put everything together *)
+ let match_ty = body.ty in
+ let match_e = Switch (fuel0, Match [ fail_branch; success_branch ]) in
+ { e = match_e; ty = match_ty }
+
let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Translate *)
let def = ctx.fun_decl in
@@ -1784,12 +1905,22 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
match body with
| None -> None
| Some body ->
- let body = translate_expression body ctx in
let effect_info =
get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid
in
+ let body = translate_expression body ctx in
+ (* Add a match over the fuel, if necessary *)
+ let body =
+ if function_uses_fuel effect_info then wrap_in_match_fuel body ctx
+ else body
+ in
(* Sanity check *)
type_check_texpression ctx body;
+ (* Introduce the fuel parameter, if necessary *)
+ let fuel =
+ if function_uses_fuel effect_info then [ mk_fuel_var ctx.fuel0 ]
+ else []
+ in
(* Introduce the forward input state (the state at call site of the
* *forward* function), if necessary. *)
let fwd_state =
@@ -1821,7 +1952,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Group the inputs together *)
let inputs =
List.concat
- [ ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
+ [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
in
let inputs_lvs =
List.map (fun v -> mk_typed_pattern_from_var v None) inputs