diff options
author | Son Ho | 2022-11-14 14:05:26 +0100 |
---|---|---|
committer | Son HO | 2022-11-14 14:21:04 +0100 |
commit | e5bd97f4ad08b277057a23094f2cc76abbeeaddb (patch) | |
tree | e729f7616e6aced7f78fb1b6f5beaec3f1d2b22f /compiler/SymbolicToPure.ml | |
parent | 5a96e28b8706ed945ccbb569881ca1888cd73ace (diff) |
Add a `-use-fuel` option
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r-- | compiler/SymbolicToPure.ml | 173 |
1 files changed, 152 insertions, 21 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 8fa66f93..4ea7071b 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -54,8 +54,7 @@ type call_info = { forward_inputs : texpression list; (** Remember the list of inputs given to the forward function. - Those inputs include the state input, if pertinent (in which case - it is the last input). + Those inputs include the fuel and the state, if pertinent. *) backwards : (V.abs * texpression list) T.RegionGroupId.Map.t; (** A map from region group id (i.e., backward function id) to @@ -98,12 +97,23 @@ type bs_ctx = { state may have been updated by the caller between the call to the forward function and the call to the backward function. *) + fuel0 : VarId.id; + (** The original fuel taken as input by the function (if we use fuel) *) + fuel : VarId.id; + (* The fuel to use for the recursive calls (if we use fuel) *) forward_inputs : var list; - (** The input parameters for the forward function *) + (** The input parameters for the forward function corresponding to the + translated Rust inputs (no fuel, no state). + *) backward_inputs : var list T.RegionGroupId.Map.t; - (** The input parameters for the backward functions *) + (** The additional input parameters for the backward functions coming + from the borrows consumed upon ending the lifetime (as a consequence + those don't include the backward state, if there is one). + *) backward_outputs : var list T.RegionGroupId.Map.t; - (** The variables that the backward functions will output *) + (** The variables that the backward functions will output, corresponding + to the borrows they give back (don't include the backward state) + *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; @@ -485,6 +495,19 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : let abs_ids = list_ancestor_abstractions_ids ctx abs in List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids +(** Small utility *) +let function_uses_fuel (info : fun_effect_info) : bool = + !Config.use_fuel && info.can_diverge + +(** Small utility *) +let mk_fuel_input_ty_as_list (info : fun_effect_info) : ty list = + if function_uses_fuel info then [ mk_fuel_ty ] else [] + +(** Small utility *) +let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : + texpression list = + if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else [] + (** Small utility. *) let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = @@ -495,12 +518,18 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) let stateful = stateful_group && ((not !Config.backward_no_state_update) || gid = None) in - { can_fail = info.can_fail; stateful_group; stateful } + { + can_fail = info.can_fail; + stateful_group; + stateful; + can_diverge = info.divergent; + } | A.Assumed aid -> { can_fail = Assumed.assumed_can_fail aid; stateful_group = false; stateful = false; + can_diverge = false; } (** Translate a function signature. @@ -522,11 +551,15 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) let parents = list_ancestor_region_groups sg bid in (Some bid, parents) in + (* Is the function stateful, and can it fail? *) + let effect_info = get_fun_effect_info fun_infos fun_id bid in (* List the inputs for: + * - the fuel * - the forward function * - the parent backward functions, in proper order * - the current backward function (if it is a backward function) *) + let fuel = mk_fuel_input_ty_as_list effect_info in let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in (* For the backward functions: for now we don't supported nested borrows, * so just check that there aren't parent regions *) @@ -561,8 +594,6 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - (* Is the function stateful, and can it fail? *) - let effect_info = get_fun_effect_info fun_infos fun_id bid in (* If the function is stateful, the inputs are: - forward: [fwd_ty0, ..., fwd_tyn, state] - backward: @@ -593,7 +624,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) * - backward state input *) let inputs = - List.concat [ fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] + List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] in (* Outputs *) let output_names, doutputs = @@ -641,17 +672,23 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (* Type parameters *) let type_params = sg.type_params in (* Return *) + let has_fuel = fuel <> [] in let num_fwd_inputs_no_state = List.length fwd_inputs in + let num_fwd_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fuel + num_fwd_inputs_no_state + in let num_back_inputs_no_state = if bid = None then None else Some (List.length back_inputs) in let info = { - num_fwd_inputs_no_state; - num_fwd_inputs_with_state = + has_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state = (* We use the fact that [fwd_state_ty] has length 1 if there is a state, and 0 otherwise *) - num_fwd_inputs_no_state + List.length fwd_state_ty; + num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; num_back_inputs_no_state; num_back_inputs_with_state = (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) @@ -1209,22 +1246,29 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : get_fun_effect_info ctx.fun_context.fun_infos fid None in (* If the function is stateful: + * - add the fuel * - add the state input argument * - generate a fresh state variable for the returned state *) let args, ctx, out_state = + let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in let ctx, nstate_var = bs_ctx_fresh_state_var ctx in - (List.append args [ state_var ], ctx, Some nstate_var) - else (args, ctx, None) + (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) + else (List.concat [ fuel; args ], ctx, None) in (* Register the function call *) let ctx = bs_ctx_register_forward_call call_id call args ctx in (ctx, func, effect_info, args, out_state) | S.Unop E.Not -> let effect_info = - { can_fail = false; stateful_group = false; stateful = false } + { + can_fail = false; + stateful_group = false; + stateful = false; + can_diverge = false; + } in (ctx, Unop Not, effect_info, args, None) | S.Unop E.Neg -> ( @@ -1234,14 +1278,24 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Note that negation can lead to an overflow and thus fail (it * is thus monadic) *) let effect_info = - { can_fail = true; stateful_group = false; stateful = false } + { + can_fail = true; + stateful_group = false; + stateful = false; + can_diverge = false; + } in (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast (src_ty, tgt_ty)) -> (* Note that cast can fail *) let effect_info = - { can_fail = true; stateful_group = false; stateful = false } + { + can_fail = true; + stateful_group = false; + stateful = false; + can_diverge = false; + } in (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) | S.Binop binop -> ( @@ -1255,6 +1309,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : can_fail = ExpressionsUtils.binop_can_fail binop; stateful_group = false; stateful = false; + can_diverge = false; } in (ctx, Binop (binop, int_ty0), effect_info, args, None) @@ -1353,9 +1408,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : let _forward, backwards = get_abs_ancestors ctx abs in (* Retrieve the values consumed when we called the forward function and * ended the parent backward functions: those give us part of the input - * values (rmk: for now, as we disallow nested lifetimes, there can't be + * values (rem: for now, as we disallow nested lifetimes, there can't be * parent backward functions). - * Note that the forward inputs include the input state (if there is one). *) + * Note that the forward inputs **include the fuel and the input state** + * (if we use those). *) let fwd_inputs = call_info.forward_inputs in let back_ancestors_inputs = List.concat (List.map (fun (_abs, args) -> args) backwards) @@ -1762,6 +1818,71 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : let ty = next_e.ty in { e; ty } +(** Wrap a function body in a match over the fuel to control termination. *) +let wrap_in_match_fuel (body : texpression) (ctx : bs_ctx) : texpression = + let fuel0_var : var = mk_fuel_var ctx.fuel0 in + let fuel0 = mk_texpression_from_var fuel0_var in + let nfuel_var : var = mk_fuel_var ctx.fuel in + let nfuel_pat = mk_typed_pattern_from_var nfuel_var None in + let fail_branch = + mk_result_fail_texpression_with_error_id error_out_of_fuel_id body.ty + in + match !Config.backend with + | FStar -> + (* Generate an expression: + {[ + if fuel0 = 0 then Fail OutOfFuel + else + let fuel = decrease fuel0 in + ... + }] + *) + (* Create the expression: [fuel0 = 0] *) + let check_fuel = + let func = { id = FunOrOp (Fun (Pure FuelEqZero)); type_args = [] } in + let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in + let func = { e = Qualif func; ty = func_ty } in + mk_app func fuel0 + in + (* Create the expression: [decrease fuel0] *) + let decrease_fuel = + let func = { id = FunOrOp (Fun (Pure FuelDecrease)); type_args = [] } in + let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in + let func = { e = Qualif func; ty = func_ty } in + mk_app func fuel0 + in + + (* Create the success branch *) + let monadic = false in + let success_branch = mk_let monadic nfuel_pat decrease_fuel body in + + (* Put everything together *) + let match_e = Switch (check_fuel, If (fail_branch, success_branch)) in + let match_ty = body.ty in + { e = match_e; ty = match_ty } + | Coq -> + (* Generate an expression: + {[ + match fuel0 with + | O -> Fail OutOfFuel + | S fuel -> + ... + }] + *) + (* Create the fail branch *) + let fail_pat = mk_adt_pattern mk_fuel_ty (Some fuel_zero_id) [] in + let fail_branch = { pat = fail_pat; branch = fail_branch } in + (* Create the success branch *) + let success_pat = + mk_adt_pattern mk_fuel_ty (Some fuel_succ_id) [ nfuel_pat ] + in + let success_branch = body in + let success_branch = { pat = success_pat; branch = success_branch } in + (* Put everything together *) + let match_ty = body.ty in + let match_e = Switch (fuel0, Match [ fail_branch; success_branch ]) in + { e = match_e; ty = match_ty } + let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate *) let def = ctx.fun_decl in @@ -1784,12 +1905,22 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = match body with | None -> None | Some body -> - let body = translate_expression body ctx in let effect_info = get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid in + let body = translate_expression body ctx in + (* Add a match over the fuel, if necessary *) + let body = + if function_uses_fuel effect_info then wrap_in_match_fuel body ctx + else body + in (* Sanity check *) type_check_texpression ctx body; + (* Introduce the fuel parameter, if necessary *) + let fuel = + if function_uses_fuel effect_info then [ mk_fuel_var ctx.fuel0 ] + else [] + in (* Introduce the forward input state (the state at call site of the * *forward* function), if necessary. *) let fwd_state = @@ -1821,7 +1952,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Group the inputs together *) let inputs = List.concat - [ ctx.forward_inputs; fwd_state; backward_inputs; back_state ] + [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ] in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs |