diff options
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/Config.ml | 6 | ||||
-rw-r--r-- | compiler/ConstStrings.ml | 3 | ||||
-rw-r--r-- | compiler/Driver.ml | 8 | ||||
-rw-r--r-- | compiler/Extract.ml | 29 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 16 | ||||
-rw-r--r-- | compiler/PrintPure.ml | 29 | ||||
-rw-r--r-- | compiler/Pure.ml | 28 | ||||
-rw-r--r-- | compiler/PureTypeCheck.ml | 9 | ||||
-rw-r--r-- | compiler/PureUtils.ml | 16 | ||||
-rw-r--r-- | compiler/SymbolicToPure.ml | 173 | ||||
-rw-r--r-- | compiler/Translate.ml | 8 |
11 files changed, 272 insertions, 53 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml index 28218b7b..f4280e80 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -110,7 +110,11 @@ let always_deconstruct_adts_with_matches = ref false (** Controls whether we need to use a state to model the external world (I/O, for instance). *) -let use_state = ref true (* TODO *) +let use_state = ref true + +(** Controls whether we use fuel to control termination. + *) +let use_fuel = ref false (** Controls whether backward functions update the state, in case we use a state ({!use_state}). diff --git a/compiler/ConstStrings.ml b/compiler/ConstStrings.ml index 6cf57fe4..b07950f4 100644 --- a/compiler/ConstStrings.ml +++ b/compiler/ConstStrings.ml @@ -8,3 +8,6 @@ let constructor_prefix = "Mk" (** Basename for error variables *) let error_basename = "e" + +(** Basename for the fuel variable *) +let fuel_basename = "n" diff --git a/compiler/Driver.ml b/compiler/Driver.ml index 05a40ad1..15ad5a26 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -56,6 +56,9 @@ let () = ( "-no-state", Arg.Clear use_state, " Do not use state-error monads, simply use error monads" ); + ( "-use-fuel", + Arg.Set use_fuel, + " Use a fuel parameter to control divergence" ); ( "-backward-no-state-update", Arg.Set backward_no_state_update, " Forbid backward functions from updating the state" ); @@ -78,6 +81,11 @@ let () = assert (!extract_decreases_clauses || not !extract_template_decreases_clauses); (* Sanity check: -backward-no-state-update ==> not -no-state *) assert ((not !backward_no_state_update) || !use_state); + (* Sanity check: the use of decrease clauses is not compatible with the use of fuel *) + assert ( + (not !use_fuel) + || (not !extract_decreases_clauses) + && not !extract_template_decreases_clauses); let spec = Arg.align spec in let filenames = ref [] in diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 17b6aa54..6cd1462e 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -76,7 +76,6 @@ let keywords () = "fn"; "val"; "int"; - "nat"; "list"; "FStar"; "FStar.Mul"; @@ -113,7 +112,6 @@ let keywords () = "fun"; "type"; "int"; - "nat"; "match"; "with"; "assert"; @@ -130,6 +128,7 @@ let assumed_adts : (assumed_ty * string) list = (State, "state"); (Result, "result"); (Error, "error"); + (Fuel, "nat"); (Option, "option"); (Vec, "vec"); ] @@ -144,6 +143,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Result, result_fail_id, "Fail"); (Error, error_failure_id, "Failure"); (Error, error_out_of_fuel_id, "OutOfFuel"); + (* No Fuel::Zero on purpose *) + (* No Fuel::Succ on purpose *) (Option, option_some_id, "Some"); (Option, option_none_id, "None"); ] @@ -153,6 +154,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list = (Result, result_fail_id, "Fail_"); (Error, error_failure_id, "Failure"); (Error, error_out_of_fuel_id, "OutOfFuel"); + (Fuel, fuel_zero_id, "O"); + (Fuel, fuel_succ_id, "S"); (Option, option_some_id, "Some"); (Option, option_none_id, "None"); ] @@ -177,8 +180,17 @@ let assumed_llbc_functions : let assumed_pure_functions : (pure_assumed_fun_id * string) list = match !backend with - | FStar -> [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ] - | Coq -> [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ] + | FStar -> + [ + (Return, "return"); + (Fail, "fail"); + (Assert, "massert"); + (FuelDecrease, "decrease"); + (FuelEqZero, "is_zero"); + ] + | Coq -> + (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *) + [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ] let names_map_init () : names_map_init = { @@ -439,10 +451,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* The "pair" case is frequent enough to have its special treatment *) if List.length tys = 2 then "p" else "t" | Assumed Result -> "r" - | Assumed Error -> "e" + | Assumed Error -> ConstStrings.error_basename + | Assumed Fuel -> ConstStrings.fuel_basename | Assumed Option -> "opt" | Assumed Vec -> "v" - | Assumed State -> "st" + | Assumed State -> ConstStrings.state_basename | AdtId adt_id -> let def = TypeDeclId.Map.find adt_id ctx.type_context.type_decls @@ -1808,7 +1821,9 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) *) let inputs_lvs = let all_inputs = (Option.get def.body).inputs_lvs in - let num_fwd_inputs = def.signature.info.num_fwd_inputs_with_state in + let num_fwd_inputs = + def.signature.info.num_fwd_inputs_with_fuel_with_state + in Collections.List.prefix num_fwd_inputs all_inputs in let _ = diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 9690d9fc..b1901fca 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -413,7 +413,10 @@ type extraction_ctx = { (** The indent increment we insert whenever we need to indent more *) } -(** Debugging function, used when communicating name collisions to the user *) +(** Debugging function, used when communicating name collisions to the user, + and also to print ids for internal debugging (in case of lookup miss for + instance). + *) let id_to_string (id : id) (ctx : extraction_ctx) : string = let global_decls = ctx.trans_ctx.global_context.global_decls in let fun_decls = ctx.trans_ctx.fun_context.fun_decls in @@ -462,7 +465,6 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let variant_name = match id with | Tuple -> raise (Failure "Unreachable") - | Assumed State -> raise (Failure "Unreachable") | Assumed Result -> if variant_id = result_return_id then "@result::Return" else if variant_id = result_fail_id then "@result::Fail" @@ -475,7 +477,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = if variant_id = option_some_id then "@option::Some" else if variant_id = option_none_id then "@option::None" else raise (Failure "Unreachable") - | Assumed Vec -> raise (Failure "Unreachable") + | Assumed (State | Vec | Fuel) -> raise (Failure "Unreachable") | AdtId id -> ( let def = TypeDeclId.Map.find id type_decls in match def.kind with @@ -489,7 +491,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = let field_name = match id with | Tuple -> raise (Failure "Unreachable") - | Assumed (State | Result | Error | Option) -> + | Assumed (State | Result | Error | Fuel | Option) -> raise (Failure "Unreachable") | Assumed Vec -> (* We can't directly have access to the fields of a vector *) @@ -509,10 +511,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = in "field name: " ^ field_name | UnknownId -> "keyword" - | TypeVarId _ | VarId _ -> - (* We should never get there: we add indices to make sure variable - * names are unique *) - raise (Failure "Unreachable") + | TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id + | VarId id -> "var_id: " ^ VarId.to_string id let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx = (* The id_to_string function to print nice debugging messages if there are diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml index 0879f553..726cc9a0 100644 --- a/compiler/PrintPure.ml +++ b/compiler/PrintPure.ml @@ -129,6 +129,7 @@ let type_id_to_string (fmt : type_formatter) (id : type_id) : string = | State -> "State" | Result -> "Result" | Error -> "Error" + | Fuel -> "Fuel" | Option -> "Option" | Vec -> "Vec") @@ -240,7 +241,7 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) (* Assumed type *) match aty with | State -> - (* The [State] type is opaque: we can't get there *) + (* This type is opaque: we can't get there *) raise (Failure "Unreachable") | Result -> let variant_id = Option.get variant_id in @@ -253,6 +254,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id) if variant_id = error_failure_id then "@Error::Failure" else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel" else raise (Failure "Unreachable: improper variant id for error type") + | Fuel -> + let variant_id = Option.get variant_id in + if variant_id = fuel_zero_id then "@Fuel::Zero" + else if variant_id = fuel_succ_id then "@Fuel::Succ" + else raise (Failure "Unreachable: improper variant id for fuel type") | Option -> let variant_id = Option.get variant_id in if variant_id = option_some_id then "@Option::Some " @@ -278,7 +284,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id) | Assumed aty -> ( (* Assumed type *) match aty with - | State | Vec -> + | State | Fuel | Vec -> (* Opaque types: we can't get there *) raise (Failure "Unreachable") | Result | Error | Option -> @@ -322,7 +328,7 @@ let adt_g_value_to_string (fmt : value_formatter) (* Assumed type *) match aty with | State -> - (* The [State] type is opaque: we can't get there *) + (* This type is opaque: we can't get there *) raise (Failure "Unreachable") | Result -> let variant_id = Option.get variant_id in @@ -342,6 +348,16 @@ let adt_g_value_to_string (fmt : value_formatter) if variant_id = error_failure_id then "@Error::Failure" else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel" else raise (Failure "Unreachable: improper variant id for error type") + | Fuel -> + let variant_id = Option.get variant_id in + if variant_id = fuel_zero_id then ( + assert (field_values = []); + "@Fuel::Zero") + else if variant_id = fuel_succ_id then + match field_values with + | [ v ] -> "@Fuel::Succ " ^ v + | _ -> raise (Failure "@Fuel::Succ takes exactly one value") + else raise (Failure "Unreachable: improper variant id for fuel type") | Option -> let variant_id = Option.get variant_id in if variant_id = option_some_id then @@ -419,7 +435,12 @@ let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string = | A.VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut" let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string = - match fid with Return -> "return" | Fail -> "fail" | Assert -> "assert" + match fid with + | Return -> "return" + | Fail -> "fail" + | Assert -> "assert" + | FuelDecrease -> "fuel_decrease" + | FuelEqZero -> "fuel_eq_zero" let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string = match fun_id with diff --git a/compiler/Pure.ml b/compiler/Pure.ml index 6cc73bef..11f627d7 100644 --- a/compiler/Pure.ml +++ b/compiler/Pure.ml @@ -32,11 +32,15 @@ type integer_type = T.integer_type [@@deriving show, ord] unified treatment of expressions (especially when we have to unfold the monadic binds) - [Error]: the kind of error, in case of failure (used by [Result]) + - [Fuel]: the fuel, to control recursion (some theorem provers like Coq + don't support semantic termination, in which case we can use a fuel + parameter to do partial verification) - [State]: the type of the state, when using state-error monads. Note that this state is opaque to Aeneas (the user can define it, or leave it as assumed) *) -type assumed_ty = State | Result | Error | Vec | Option [@@deriving show, ord] +type assumed_ty = State | Result | Error | Fuel | Vec | Option +[@@deriving show, ord] (* TODO: we should never directly manipulate [Return] and [Fail], but rather * the monadic functions [return] and [fail] (makes treatment of error and @@ -48,6 +52,12 @@ let option_none_id = T.option_none_id let error_failure_id = VariantId.of_int 0 let error_out_of_fuel_id = VariantId.of_int 1 +(* We don't always use those: it depends on the backend (we use natural numbers + for the fuel: in Coq they are enumerations, but in F* they are primitive) +*) +let fuel_zero_id = VariantId.of_int 0 +let fuel_succ_id = VariantId.of_int 1 + type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty [@@deriving show, ord] @@ -289,6 +299,9 @@ type pure_assumed_fun_id = | Return (** The monadic return *) | Fail (** The monadic fail *) | Assert (** Assertion *) + | FuelDecrease + (** Decrease fuel, provided it is non zero (used for F* ) - TODO: this is ugly *) + | FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *) [@@deriving show, ord] (** A function identifier *) @@ -510,14 +523,19 @@ type fun_effect_info = { *) stateful : bool; (** [true] if the function is stateful (updates a state) *) can_fail : bool; (** [true] if the return type is a [result] *) + can_diverge : bool; + (** [true] if the function can diverge (i.e., not terminate) *) } (** Meta information about a function signature *) type fun_sig_info = { - num_fwd_inputs_no_state : int; - (** The number of input types for forward computation, ignoring the state (if there is one) *) - num_fwd_inputs_with_state : int; - (** The number of input types for forward computation, with the state (if there is one) *) + has_fuel : bool; + (* TODO: add [num_fwd_inputs_no_fuel_no_state] *) + num_fwd_inputs_with_fuel_no_state : int; + (** The number of input types for forward computation, with the fuel (if used) + and ignoring the state (if used) *) + num_fwd_inputs_with_fuel_with_state : int; + (** The number of input types for forward computation, with fuel and state (if used) *) num_back_inputs_no_state : int option; (** The number of additional inputs for the backward computation (if pertinent), ignoring the state (if there is one) *) diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml index a1e4e834..fe4fb841 100644 --- a/compiler/PureTypeCheck.ml +++ b/compiler/PureTypeCheck.ml @@ -20,8 +20,8 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) (* Assumed type *) match aty with | State -> - (* [State] is opaque *) - raise (Failure "Unreachable: `State` values are opaque") + (* This type is opaque *) + raise (Failure "Unreachable: opaque type") | Result -> let ty = Collections.List.to_cons_nil tys in let variant_id = Option.get variant_id in @@ -35,6 +35,11 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) assert ( variant_id = error_failure_id || variant_id = error_out_of_fuel_id); [] + | Fuel -> + let variant_id = Option.get variant_id in + if variant_id = fuel_zero_id then [] + else if variant_id = fuel_succ_id then [ mk_fuel_ty ] + else raise (Failure "Unreachable: improper variant id for fuel type") | Option -> let ty = Collections.List.to_cons_nil tys in let variant_id = Option.get variant_id in diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml index f5c280fb..0f1d50f1 100644 --- a/compiler/PureUtils.ml +++ b/compiler/PureUtils.ml @@ -343,6 +343,7 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression = let mk_simpl_tuple_ty (tys : ty list) : ty = match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys) +let mk_bool_ty : ty = Bool let mk_unit_ty : ty = Adt (Tuple, []) let mk_unit_rvalue : texpression = @@ -422,6 +423,7 @@ let type_decl_is_enum (def : T.type_decl) : bool = let mk_state_ty : ty = Adt (Assumed State, []) let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) let mk_error_ty : ty = Adt (Assumed Error, []) +let mk_fuel_ty : ty = Adt (Assumed Fuel, []) let mk_error (error : VariantId.id) : texpression = let ty = mk_error_ty in @@ -488,8 +490,14 @@ let mk_result_return_pattern (v : typed_pattern) : typed_pattern = let opt_unmeta_mplace (e : texpression) : mplace option * texpression = match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e) -let mk_state_var (vid : VarId.id) : var = - { id = vid; basename = Some ConstStrings.state_basename; ty = mk_state_ty } +let mk_state_var (id : VarId.id) : var = + { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } -let mk_state_texpression (vid : VarId.id) : texpression = - { e = Var vid; ty = mk_state_ty } +let mk_state_texpression (id : VarId.id) : texpression = + { e = Var id; ty = mk_state_ty } + +let mk_fuel_var (id : VarId.id) : var = + { id; basename = Some ConstStrings.fuel_basename; ty = mk_fuel_ty } + +let mk_fuel_texpression (id : VarId.id) : texpression = + { e = Var id; ty = mk_fuel_ty } diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 8fa66f93..4ea7071b 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -54,8 +54,7 @@ type call_info = { forward_inputs : texpression list; (** Remember the list of inputs given to the forward function. - Those inputs include the state input, if pertinent (in which case - it is the last input). + Those inputs include the fuel and the state, if pertinent. *) backwards : (V.abs * texpression list) T.RegionGroupId.Map.t; (** A map from region group id (i.e., backward function id) to @@ -98,12 +97,23 @@ type bs_ctx = { state may have been updated by the caller between the call to the forward function and the call to the backward function. *) + fuel0 : VarId.id; + (** The original fuel taken as input by the function (if we use fuel) *) + fuel : VarId.id; + (* The fuel to use for the recursive calls (if we use fuel) *) forward_inputs : var list; - (** The input parameters for the forward function *) + (** The input parameters for the forward function corresponding to the + translated Rust inputs (no fuel, no state). + *) backward_inputs : var list T.RegionGroupId.Map.t; - (** The input parameters for the backward functions *) + (** The additional input parameters for the backward functions coming + from the borrows consumed upon ending the lifetime (as a consequence + those don't include the backward state, if there is one). + *) backward_outputs : var list T.RegionGroupId.Map.t; - (** The variables that the backward functions will output *) + (** The variables that the backward functions will output, corresponding + to the borrows they give back (don't include the backward state) + *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; @@ -485,6 +495,19 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : let abs_ids = list_ancestor_abstractions_ids ctx abs in List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids +(** Small utility *) +let function_uses_fuel (info : fun_effect_info) : bool = + !Config.use_fuel && info.can_diverge + +(** Small utility *) +let mk_fuel_input_ty_as_list (info : fun_effect_info) : ty list = + if function_uses_fuel info then [ mk_fuel_ty ] else [] + +(** Small utility *) +let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) : + texpression list = + if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else [] + (** Small utility. *) let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = @@ -495,12 +518,18 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t) let stateful = stateful_group && ((not !Config.backward_no_state_update) || gid = None) in - { can_fail = info.can_fail; stateful_group; stateful } + { + can_fail = info.can_fail; + stateful_group; + stateful; + can_diverge = info.divergent; + } | A.Assumed aid -> { can_fail = Assumed.assumed_can_fail aid; stateful_group = false; stateful = false; + can_diverge = false; } (** Translate a function signature. @@ -522,11 +551,15 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) let parents = list_ancestor_region_groups sg bid in (Some bid, parents) in + (* Is the function stateful, and can it fail? *) + let effect_info = get_fun_effect_info fun_infos fun_id bid in (* List the inputs for: + * - the fuel * - the forward function * - the parent backward functions, in proper order * - the current backward function (if it is a backward function) *) + let fuel = mk_fuel_input_ty_as_list effect_info in let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in (* For the backward functions: for now we don't supported nested borrows, * so just check that there aren't parent regions *) @@ -561,8 +594,6 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - (* Is the function stateful, and can it fail? *) - let effect_info = get_fun_effect_info fun_infos fun_id bid in (* If the function is stateful, the inputs are: - forward: [fwd_ty0, ..., fwd_tyn, state] - backward: @@ -593,7 +624,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) * - backward state input *) let inputs = - List.concat [ fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] + List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ] in (* Outputs *) let output_names, doutputs = @@ -641,17 +672,23 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t) (* Type parameters *) let type_params = sg.type_params in (* Return *) + let has_fuel = fuel <> [] in let num_fwd_inputs_no_state = List.length fwd_inputs in + let num_fwd_inputs_with_fuel_no_state = + (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *) + List.length fuel + num_fwd_inputs_no_state + in let num_back_inputs_no_state = if bid = None then None else Some (List.length back_inputs) in let info = { - num_fwd_inputs_no_state; - num_fwd_inputs_with_state = + has_fuel; + num_fwd_inputs_with_fuel_no_state; + num_fwd_inputs_with_fuel_with_state = (* We use the fact that [fwd_state_ty] has length 1 if there is a state, and 0 otherwise *) - num_fwd_inputs_no_state + List.length fwd_state_ty; + num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; num_back_inputs_no_state; num_back_inputs_with_state = (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) @@ -1209,22 +1246,29 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : get_fun_effect_info ctx.fun_context.fun_infos fid None in (* If the function is stateful: + * - add the fuel * - add the state input argument * - generate a fresh state variable for the returned state *) let args, ctx, out_state = + let fuel = mk_fuel_input_as_list ctx effect_info in if effect_info.stateful then let state_var = mk_state_texpression ctx.state_var in let ctx, nstate_var = bs_ctx_fresh_state_var ctx in - (List.append args [ state_var ], ctx, Some nstate_var) - else (args, ctx, None) + (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) + else (List.concat [ fuel; args ], ctx, None) in (* Register the function call *) let ctx = bs_ctx_register_forward_call call_id call args ctx in (ctx, func, effect_info, args, out_state) | S.Unop E.Not -> let effect_info = - { can_fail = false; stateful_group = false; stateful = false } + { + can_fail = false; + stateful_group = false; + stateful = false; + can_diverge = false; + } in (ctx, Unop Not, effect_info, args, None) | S.Unop E.Neg -> ( @@ -1234,14 +1278,24 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Note that negation can lead to an overflow and thus fail (it * is thus monadic) *) let effect_info = - { can_fail = true; stateful_group = false; stateful = false } + { + can_fail = true; + stateful_group = false; + stateful = false; + can_diverge = false; + } in (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast (src_ty, tgt_ty)) -> (* Note that cast can fail *) let effect_info = - { can_fail = true; stateful_group = false; stateful = false } + { + can_fail = true; + stateful_group = false; + stateful = false; + can_diverge = false; + } in (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) | S.Binop binop -> ( @@ -1255,6 +1309,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : can_fail = ExpressionsUtils.binop_can_fail binop; stateful_group = false; stateful = false; + can_diverge = false; } in (ctx, Binop (binop, int_ty0), effect_info, args, None) @@ -1353,9 +1408,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : let _forward, backwards = get_abs_ancestors ctx abs in (* Retrieve the values consumed when we called the forward function and * ended the parent backward functions: those give us part of the input - * values (rmk: for now, as we disallow nested lifetimes, there can't be + * values (rem: for now, as we disallow nested lifetimes, there can't be * parent backward functions). - * Note that the forward inputs include the input state (if there is one). *) + * Note that the forward inputs **include the fuel and the input state** + * (if we use those). *) let fwd_inputs = call_info.forward_inputs in let back_ancestors_inputs = List.concat (List.map (fun (_abs, args) -> args) backwards) @@ -1762,6 +1818,71 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : let ty = next_e.ty in { e; ty } +(** Wrap a function body in a match over the fuel to control termination. *) +let wrap_in_match_fuel (body : texpression) (ctx : bs_ctx) : texpression = + let fuel0_var : var = mk_fuel_var ctx.fuel0 in + let fuel0 = mk_texpression_from_var fuel0_var in + let nfuel_var : var = mk_fuel_var ctx.fuel in + let nfuel_pat = mk_typed_pattern_from_var nfuel_var None in + let fail_branch = + mk_result_fail_texpression_with_error_id error_out_of_fuel_id body.ty + in + match !Config.backend with + | FStar -> + (* Generate an expression: + {[ + if fuel0 = 0 then Fail OutOfFuel + else + let fuel = decrease fuel0 in + ... + }] + *) + (* Create the expression: [fuel0 = 0] *) + let check_fuel = + let func = { id = FunOrOp (Fun (Pure FuelEqZero)); type_args = [] } in + let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in + let func = { e = Qualif func; ty = func_ty } in + mk_app func fuel0 + in + (* Create the expression: [decrease fuel0] *) + let decrease_fuel = + let func = { id = FunOrOp (Fun (Pure FuelDecrease)); type_args = [] } in + let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in + let func = { e = Qualif func; ty = func_ty } in + mk_app func fuel0 + in + + (* Create the success branch *) + let monadic = false in + let success_branch = mk_let monadic nfuel_pat decrease_fuel body in + + (* Put everything together *) + let match_e = Switch (check_fuel, If (fail_branch, success_branch)) in + let match_ty = body.ty in + { e = match_e; ty = match_ty } + | Coq -> + (* Generate an expression: + {[ + match fuel0 with + | O -> Fail OutOfFuel + | S fuel -> + ... + }] + *) + (* Create the fail branch *) + let fail_pat = mk_adt_pattern mk_fuel_ty (Some fuel_zero_id) [] in + let fail_branch = { pat = fail_pat; branch = fail_branch } in + (* Create the success branch *) + let success_pat = + mk_adt_pattern mk_fuel_ty (Some fuel_succ_id) [ nfuel_pat ] + in + let success_branch = body in + let success_branch = { pat = success_pat; branch = success_branch } in + (* Put everything together *) + let match_ty = body.ty in + let match_e = Switch (fuel0, Match [ fail_branch; success_branch ]) in + { e = match_e; ty = match_ty } + let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Translate *) let def = ctx.fun_decl in @@ -1784,12 +1905,22 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = match body with | None -> None | Some body -> - let body = translate_expression body ctx in let effect_info = get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid in + let body = translate_expression body ctx in + (* Add a match over the fuel, if necessary *) + let body = + if function_uses_fuel effect_info then wrap_in_match_fuel body ctx + else body + in (* Sanity check *) type_check_texpression ctx body; + (* Introduce the fuel parameter, if necessary *) + let fuel = + if function_uses_fuel effect_info then [ mk_fuel_var ctx.fuel0 ] + else [] + in (* Introduce the forward input state (the state at call site of the * *forward* function), if necessary. *) let fwd_state = @@ -1821,7 +1952,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl = (* Group the inputs together *) let inputs = List.concat - [ ctx.forward_inputs; fwd_state; backward_inputs; back_state ] + [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ] in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs diff --git a/compiler/Translate.ml b/compiler/Translate.ml index b2a28710..7ed9859a 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -71,6 +71,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let var_counter = Pure.VarId.generator_zero in let state_var, var_counter = Pure.VarId.fresh var_counter in let back_state_var, var_counter = Pure.VarId.fresh var_counter in + let fuel0, var_counter = Pure.VarId.fresh var_counter in + let fuel, var_counter = Pure.VarId.fresh var_counter in let calls = V.FunCallId.Map.empty in let abstractions = V.AbstractionId.Map.empty in let type_context = @@ -100,6 +102,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx) var_counter; state_var; back_state_var; + fuel0; + fuel; type_context; fun_context; global_context; @@ -171,7 +175,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) let backward_inputs = let sg = backward_sg.sg in (* We need to ignore the forward state and the backward state *) - let num_forward_inputs = sg.info.num_fwd_inputs_with_state in + let num_forward_inputs = + sg.info.num_fwd_inputs_with_fuel_with_state + in let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in Collections.List.subslice sg.inputs num_forward_inputs num_back_inputs in |