summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Config.ml6
-rw-r--r--compiler/ConstStrings.ml3
-rw-r--r--compiler/Driver.ml8
-rw-r--r--compiler/Extract.ml29
-rw-r--r--compiler/ExtractBase.ml16
-rw-r--r--compiler/PrintPure.ml29
-rw-r--r--compiler/Pure.ml28
-rw-r--r--compiler/PureTypeCheck.ml9
-rw-r--r--compiler/PureUtils.ml16
-rw-r--r--compiler/SymbolicToPure.ml173
-rw-r--r--compiler/Translate.ml8
11 files changed, 272 insertions, 53 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index 28218b7b..f4280e80 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -110,7 +110,11 @@ let always_deconstruct_adts_with_matches = ref false
(** Controls whether we need to use a state to model the external world
(I/O, for instance).
*)
-let use_state = ref true (* TODO *)
+let use_state = ref true
+
+(** Controls whether we use fuel to control termination.
+ *)
+let use_fuel = ref false
(** Controls whether backward functions update the state, in case we use
a state ({!use_state}).
diff --git a/compiler/ConstStrings.ml b/compiler/ConstStrings.ml
index 6cf57fe4..b07950f4 100644
--- a/compiler/ConstStrings.ml
+++ b/compiler/ConstStrings.ml
@@ -8,3 +8,6 @@ let constructor_prefix = "Mk"
(** Basename for error variables *)
let error_basename = "e"
+
+(** Basename for the fuel variable *)
+let fuel_basename = "n"
diff --git a/compiler/Driver.ml b/compiler/Driver.ml
index 05a40ad1..15ad5a26 100644
--- a/compiler/Driver.ml
+++ b/compiler/Driver.ml
@@ -56,6 +56,9 @@ let () =
( "-no-state",
Arg.Clear use_state,
" Do not use state-error monads, simply use error monads" );
+ ( "-use-fuel",
+ Arg.Set use_fuel,
+ " Use a fuel parameter to control divergence" );
( "-backward-no-state-update",
Arg.Set backward_no_state_update,
" Forbid backward functions from updating the state" );
@@ -78,6 +81,11 @@ let () =
assert (!extract_decreases_clauses || not !extract_template_decreases_clauses);
(* Sanity check: -backward-no-state-update ==> not -no-state *)
assert ((not !backward_no_state_update) || !use_state);
+ (* Sanity check: the use of decrease clauses is not compatible with the use of fuel *)
+ assert (
+ (not !use_fuel)
+ || (not !extract_decreases_clauses)
+ && not !extract_template_decreases_clauses);
let spec = Arg.align spec in
let filenames = ref [] in
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 17b6aa54..6cd1462e 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -76,7 +76,6 @@ let keywords () =
"fn";
"val";
"int";
- "nat";
"list";
"FStar";
"FStar.Mul";
@@ -113,7 +112,6 @@ let keywords () =
"fun";
"type";
"int";
- "nat";
"match";
"with";
"assert";
@@ -130,6 +128,7 @@ let assumed_adts : (assumed_ty * string) list =
(State, "state");
(Result, "result");
(Error, "error");
+ (Fuel, "nat");
(Option, "option");
(Vec, "vec");
]
@@ -144,6 +143,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
(Result, result_fail_id, "Fail");
(Error, error_failure_id, "Failure");
(Error, error_out_of_fuel_id, "OutOfFuel");
+ (* No Fuel::Zero on purpose *)
+ (* No Fuel::Succ on purpose *)
(Option, option_some_id, "Some");
(Option, option_none_id, "None");
]
@@ -153,6 +154,8 @@ let assumed_variants () : (assumed_ty * VariantId.id * string) list =
(Result, result_fail_id, "Fail_");
(Error, error_failure_id, "Failure");
(Error, error_out_of_fuel_id, "OutOfFuel");
+ (Fuel, fuel_zero_id, "O");
+ (Fuel, fuel_succ_id, "S");
(Option, option_some_id, "Some");
(Option, option_none_id, "None");
]
@@ -177,8 +180,17 @@ let assumed_llbc_functions :
let assumed_pure_functions : (pure_assumed_fun_id * string) list =
match !backend with
- | FStar -> [ (Return, "return"); (Fail, "fail"); (Assert, "massert") ]
- | Coq -> [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ]
+ | FStar ->
+ [
+ (Return, "return");
+ (Fail, "fail");
+ (Assert, "massert");
+ (FuelDecrease, "decrease");
+ (FuelEqZero, "is_zero");
+ ]
+ | Coq ->
+ (* We don't provide [FuelDecrease] and [FuelEqZero] on purpose *)
+ [ (Return, "return_"); (Fail, "fail_"); (Assert, "massert") ]
let names_map_init () : names_map_init =
{
@@ -439,10 +451,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
(* The "pair" case is frequent enough to have its special treatment *)
if List.length tys = 2 then "p" else "t"
| Assumed Result -> "r"
- | Assumed Error -> "e"
+ | Assumed Error -> ConstStrings.error_basename
+ | Assumed Fuel -> ConstStrings.fuel_basename
| Assumed Option -> "opt"
| Assumed Vec -> "v"
- | Assumed State -> "st"
+ | Assumed State -> ConstStrings.state_basename
| AdtId adt_id ->
let def =
TypeDeclId.Map.find adt_id ctx.type_context.type_decls
@@ -1808,7 +1821,9 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
*)
let inputs_lvs =
let all_inputs = (Option.get def.body).inputs_lvs in
- let num_fwd_inputs = def.signature.info.num_fwd_inputs_with_state in
+ let num_fwd_inputs =
+ def.signature.info.num_fwd_inputs_with_fuel_with_state
+ in
Collections.List.prefix num_fwd_inputs all_inputs
in
let _ =
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 9690d9fc..b1901fca 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -413,7 +413,10 @@ type extraction_ctx = {
(** The indent increment we insert whenever we need to indent more *)
}
-(** Debugging function, used when communicating name collisions to the user *)
+(** Debugging function, used when communicating name collisions to the user,
+ and also to print ids for internal debugging (in case of lookup miss for
+ instance).
+ *)
let id_to_string (id : id) (ctx : extraction_ctx) : string =
let global_decls = ctx.trans_ctx.global_context.global_decls in
let fun_decls = ctx.trans_ctx.fun_context.fun_decls in
@@ -462,7 +465,6 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
let variant_name =
match id with
| Tuple -> raise (Failure "Unreachable")
- | Assumed State -> raise (Failure "Unreachable")
| Assumed Result ->
if variant_id = result_return_id then "@result::Return"
else if variant_id = result_fail_id then "@result::Fail"
@@ -475,7 +477,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
if variant_id = option_some_id then "@option::Some"
else if variant_id = option_none_id then "@option::None"
else raise (Failure "Unreachable")
- | Assumed Vec -> raise (Failure "Unreachable")
+ | Assumed (State | Vec | Fuel) -> raise (Failure "Unreachable")
| AdtId id -> (
let def = TypeDeclId.Map.find id type_decls in
match def.kind with
@@ -489,7 +491,7 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
let field_name =
match id with
| Tuple -> raise (Failure "Unreachable")
- | Assumed (State | Result | Error | Option) ->
+ | Assumed (State | Result | Error | Fuel | Option) ->
raise (Failure "Unreachable")
| Assumed Vec ->
(* We can't directly have access to the fields of a vector *)
@@ -509,10 +511,8 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
in
"field name: " ^ field_name
| UnknownId -> "keyword"
- | TypeVarId _ | VarId _ ->
- (* We should never get there: we add indices to make sure variable
- * names are unique *)
- raise (Failure "Unreachable")
+ | TypeVarId id -> "type_var_id: " ^ TypeVarId.to_string id
+ | VarId id -> "var_id: " ^ VarId.to_string id
let ctx_add (id : id) (name : string) (ctx : extraction_ctx) : extraction_ctx =
(* The id_to_string function to print nice debugging messages if there are
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 0879f553..726cc9a0 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -129,6 +129,7 @@ let type_id_to_string (fmt : type_formatter) (id : type_id) : string =
| State -> "State"
| Result -> "Result"
| Error -> "Error"
+ | Fuel -> "Fuel"
| Option -> "Option"
| Vec -> "Vec")
@@ -240,7 +241,7 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
(* Assumed type *)
match aty with
| State ->
- (* The [State] type is opaque: we can't get there *)
+ (* This type is opaque: we can't get there *)
raise (Failure "Unreachable")
| Result ->
let variant_id = Option.get variant_id in
@@ -253,6 +254,11 @@ let adt_variant_to_string (fmt : value_formatter) (adt_id : type_id)
if variant_id = error_failure_id then "@Error::Failure"
else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel"
else raise (Failure "Unreachable: improper variant id for error type")
+ | Fuel ->
+ let variant_id = Option.get variant_id in
+ if variant_id = fuel_zero_id then "@Fuel::Zero"
+ else if variant_id = fuel_succ_id then "@Fuel::Succ"
+ else raise (Failure "Unreachable: improper variant id for fuel type")
| Option ->
let variant_id = Option.get variant_id in
if variant_id = option_some_id then "@Option::Some "
@@ -278,7 +284,7 @@ let adt_field_to_string (fmt : value_formatter) (adt_id : type_id)
| Assumed aty -> (
(* Assumed type *)
match aty with
- | State | Vec ->
+ | State | Fuel | Vec ->
(* Opaque types: we can't get there *)
raise (Failure "Unreachable")
| Result | Error | Option ->
@@ -322,7 +328,7 @@ let adt_g_value_to_string (fmt : value_formatter)
(* Assumed type *)
match aty with
| State ->
- (* The [State] type is opaque: we can't get there *)
+ (* This type is opaque: we can't get there *)
raise (Failure "Unreachable")
| Result ->
let variant_id = Option.get variant_id in
@@ -342,6 +348,16 @@ let adt_g_value_to_string (fmt : value_formatter)
if variant_id = error_failure_id then "@Error::Failure"
else if variant_id = error_out_of_fuel_id then "@Error::OutOfFuel"
else raise (Failure "Unreachable: improper variant id for error type")
+ | Fuel ->
+ let variant_id = Option.get variant_id in
+ if variant_id = fuel_zero_id then (
+ assert (field_values = []);
+ "@Fuel::Zero")
+ else if variant_id = fuel_succ_id then
+ match field_values with
+ | [ v ] -> "@Fuel::Succ " ^ v
+ | _ -> raise (Failure "@Fuel::Succ takes exactly one value")
+ else raise (Failure "Unreachable: improper variant id for fuel type")
| Option ->
let variant_id = Option.get variant_id in
if variant_id = option_some_id then
@@ -419,7 +435,12 @@ let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string =
| A.VecIndexMut -> "core::ops::index::IndexMut<alloc::vec::Vec>::index_mut"
let pure_assumed_fun_id_to_string (fid : pure_assumed_fun_id) : string =
- match fid with Return -> "return" | Fail -> "fail" | Assert -> "assert"
+ match fid with
+ | Return -> "return"
+ | Fail -> "fail"
+ | Assert -> "assert"
+ | FuelDecrease -> "fuel_decrease"
+ | FuelEqZero -> "fuel_eq_zero"
let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
match fun_id with
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 6cc73bef..11f627d7 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -32,11 +32,15 @@ type integer_type = T.integer_type [@@deriving show, ord]
unified treatment of expressions (especially when we have to unfold the
monadic binds)
- [Error]: the kind of error, in case of failure (used by [Result])
+ - [Fuel]: the fuel, to control recursion (some theorem provers like Coq
+ don't support semantic termination, in which case we can use a fuel
+ parameter to do partial verification)
- [State]: the type of the state, when using state-error monads. Note that
this state is opaque to Aeneas (the user can define it, or leave it as
assumed)
*)
-type assumed_ty = State | Result | Error | Vec | Option [@@deriving show, ord]
+type assumed_ty = State | Result | Error | Fuel | Vec | Option
+[@@deriving show, ord]
(* TODO: we should never directly manipulate [Return] and [Fail], but rather
* the monadic functions [return] and [fail] (makes treatment of error and
@@ -48,6 +52,12 @@ let option_none_id = T.option_none_id
let error_failure_id = VariantId.of_int 0
let error_out_of_fuel_id = VariantId.of_int 1
+(* We don't always use those: it depends on the backend (we use natural numbers
+ for the fuel: in Coq they are enumerations, but in F* they are primitive)
+*)
+let fuel_zero_id = VariantId.of_int 0
+let fuel_succ_id = VariantId.of_int 1
+
type type_id = AdtId of TypeDeclId.id | Tuple | Assumed of assumed_ty
[@@deriving show, ord]
@@ -289,6 +299,9 @@ type pure_assumed_fun_id =
| Return (** The monadic return *)
| Fail (** The monadic fail *)
| Assert (** Assertion *)
+ | FuelDecrease
+ (** Decrease fuel, provided it is non zero (used for F* ) - TODO: this is ugly *)
+ | FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *)
[@@deriving show, ord]
(** A function identifier *)
@@ -510,14 +523,19 @@ type fun_effect_info = {
*)
stateful : bool; (** [true] if the function is stateful (updates a state) *)
can_fail : bool; (** [true] if the return type is a [result] *)
+ can_diverge : bool;
+ (** [true] if the function can diverge (i.e., not terminate) *)
}
(** Meta information about a function signature *)
type fun_sig_info = {
- num_fwd_inputs_no_state : int;
- (** The number of input types for forward computation, ignoring the state (if there is one) *)
- num_fwd_inputs_with_state : int;
- (** The number of input types for forward computation, with the state (if there is one) *)
+ has_fuel : bool;
+ (* TODO: add [num_fwd_inputs_no_fuel_no_state] *)
+ num_fwd_inputs_with_fuel_no_state : int;
+ (** The number of input types for forward computation, with the fuel (if used)
+ and ignoring the state (if used) *)
+ num_fwd_inputs_with_fuel_with_state : int;
+ (** The number of input types for forward computation, with fuel and state (if used) *)
num_back_inputs_no_state : int option;
(** The number of additional inputs for the backward computation (if pertinent),
ignoring the state (if there is one) *)
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index a1e4e834..fe4fb841 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -20,8 +20,8 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
(* Assumed type *)
match aty with
| State ->
- (* [State] is opaque *)
- raise (Failure "Unreachable: `State` values are opaque")
+ (* This type is opaque *)
+ raise (Failure "Unreachable: opaque type")
| Result ->
let ty = Collections.List.to_cons_nil tys in
let variant_id = Option.get variant_id in
@@ -35,6 +35,11 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
assert (
variant_id = error_failure_id || variant_id = error_out_of_fuel_id);
[]
+ | Fuel ->
+ let variant_id = Option.get variant_id in
+ if variant_id = fuel_zero_id then []
+ else if variant_id = fuel_succ_id then [ mk_fuel_ty ]
+ else raise (Failure "Unreachable: improper variant id for fuel type")
| Option ->
let ty = Collections.List.to_cons_nil tys in
let variant_id = Option.get variant_id in
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index f5c280fb..0f1d50f1 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -343,6 +343,7 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
let mk_simpl_tuple_ty (tys : ty list) : ty =
match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys)
+let mk_bool_ty : ty = Bool
let mk_unit_ty : ty = Adt (Tuple, [])
let mk_unit_rvalue : texpression =
@@ -422,6 +423,7 @@ let type_decl_is_enum (def : T.type_decl) : bool =
let mk_state_ty : ty = Adt (Assumed State, [])
let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
let mk_error_ty : ty = Adt (Assumed Error, [])
+let mk_fuel_ty : ty = Adt (Assumed Fuel, [])
let mk_error (error : VariantId.id) : texpression =
let ty = mk_error_ty in
@@ -488,8 +490,14 @@ let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
let opt_unmeta_mplace (e : texpression) : mplace option * texpression =
match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e)
-let mk_state_var (vid : VarId.id) : var =
- { id = vid; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
+let mk_state_var (id : VarId.id) : var =
+ { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
-let mk_state_texpression (vid : VarId.id) : texpression =
- { e = Var vid; ty = mk_state_ty }
+let mk_state_texpression (id : VarId.id) : texpression =
+ { e = Var id; ty = mk_state_ty }
+
+let mk_fuel_var (id : VarId.id) : var =
+ { id; basename = Some ConstStrings.fuel_basename; ty = mk_fuel_ty }
+
+let mk_fuel_texpression (id : VarId.id) : texpression =
+ { e = Var id; ty = mk_fuel_ty }
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 8fa66f93..4ea7071b 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -54,8 +54,7 @@ type call_info = {
forward_inputs : texpression list;
(** Remember the list of inputs given to the forward function.
- Those inputs include the state input, if pertinent (in which case
- it is the last input).
+ Those inputs include the fuel and the state, if pertinent.
*)
backwards : (V.abs * texpression list) T.RegionGroupId.Map.t;
(** A map from region group id (i.e., backward function id) to
@@ -98,12 +97,23 @@ type bs_ctx = {
state may have been updated by the caller between the call to the
forward function and the call to the backward function.
*)
+ fuel0 : VarId.id;
+ (** The original fuel taken as input by the function (if we use fuel) *)
+ fuel : VarId.id;
+ (* The fuel to use for the recursive calls (if we use fuel) *)
forward_inputs : var list;
- (** The input parameters for the forward function *)
+ (** The input parameters for the forward function corresponding to the
+ translated Rust inputs (no fuel, no state).
+ *)
backward_inputs : var list T.RegionGroupId.Map.t;
- (** The input parameters for the backward functions *)
+ (** The additional input parameters for the backward functions coming
+ from the borrows consumed upon ending the lifetime (as a consequence
+ those don't include the backward state, if there is one).
+ *)
backward_outputs : var list T.RegionGroupId.Map.t;
- (** The variables that the backward functions will output *)
+ (** The variables that the backward functions will output, corresponding
+ to the borrows they give back (don't include the backward state)
+ *)
calls : call_info V.FunCallId.Map.t;
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
@@ -485,6 +495,19 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
let abs_ids = list_ancestor_abstractions_ids ctx abs in
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
+(** Small utility *)
+let function_uses_fuel (info : fun_effect_info) : bool =
+ !Config.use_fuel && info.can_diverge
+
+(** Small utility *)
+let mk_fuel_input_ty_as_list (info : fun_effect_info) : ty list =
+ if function_uses_fuel info then [ mk_fuel_ty ] else []
+
+(** Small utility *)
+let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
+ texpression list =
+ if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else []
+
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info =
@@ -495,12 +518,18 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
let stateful =
stateful_group && ((not !Config.backward_no_state_update) || gid = None)
in
- { can_fail = info.can_fail; stateful_group; stateful }
+ {
+ can_fail = info.can_fail;
+ stateful_group;
+ stateful;
+ can_diverge = info.divergent;
+ }
| A.Assumed aid ->
{
can_fail = Assumed.assumed_can_fail aid;
stateful_group = false;
stateful = false;
+ can_diverge = false;
}
(** Translate a function signature.
@@ -522,11 +551,15 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
let parents = list_ancestor_region_groups sg bid in
(Some bid, parents)
in
+ (* Is the function stateful, and can it fail? *)
+ let effect_info = get_fun_effect_info fun_infos fun_id bid in
(* List the inputs for:
+ * - the fuel
* - the forward function
* - the parent backward functions, in proper order
* - the current backward function (if it is a backward function)
*)
+ let fuel = mk_fuel_input_ty_as_list effect_info in
let fwd_inputs = List.map (translate_fwd_ty types_infos) sg.inputs in
(* For the backward functions: for now we don't supported nested borrows,
* so just check that there aren't parent regions *)
@@ -561,8 +594,6 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
*)
List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
in
- (* Is the function stateful, and can it fail? *)
- let effect_info = get_fun_effect_info fun_infos fun_id bid in
(* If the function is stateful, the inputs are:
- forward: [fwd_ty0, ..., fwd_tyn, state]
- backward:
@@ -593,7 +624,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
* - backward state input
*)
let inputs =
- List.concat [ fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ]
+ List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ]
in
(* Outputs *)
let output_names, doutputs =
@@ -641,17 +672,23 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(* Type parameters *)
let type_params = sg.type_params in
(* Return *)
+ let has_fuel = fuel <> [] in
let num_fwd_inputs_no_state = List.length fwd_inputs in
+ let num_fwd_inputs_with_fuel_no_state =
+ (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
+ List.length fuel + num_fwd_inputs_no_state
+ in
let num_back_inputs_no_state =
if bid = None then None else Some (List.length back_inputs)
in
let info =
{
- num_fwd_inputs_no_state;
- num_fwd_inputs_with_state =
+ has_fuel;
+ num_fwd_inputs_with_fuel_no_state;
+ num_fwd_inputs_with_fuel_with_state =
(* We use the fact that [fwd_state_ty] has length 1 if there is a state,
and 0 otherwise *)
- num_fwd_inputs_no_state + List.length fwd_state_ty;
+ num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty;
num_back_inputs_no_state;
num_back_inputs_with_state =
(* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
@@ -1209,22 +1246,29 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
get_fun_effect_info ctx.fun_context.fun_infos fid None
in
(* If the function is stateful:
+ * - add the fuel
* - add the state input argument
* - generate a fresh state variable for the returned state
*)
let args, ctx, out_state =
+ let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
- (List.append args [ state_var ], ctx, Some nstate_var)
- else (args, ctx, None)
+ (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
+ else (List.concat [ fuel; args ], ctx, None)
in
(* Register the function call *)
let ctx = bs_ctx_register_forward_call call_id call args ctx in
(ctx, func, effect_info, args, out_state)
| S.Unop E.Not ->
let effect_info =
- { can_fail = false; stateful_group = false; stateful = false }
+ {
+ can_fail = false;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ }
in
(ctx, Unop Not, effect_info, args, None)
| S.Unop E.Neg -> (
@@ -1234,14 +1278,24 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(* Note that negation can lead to an overflow and thus fail (it
* is thus monadic) *)
let effect_info =
- { can_fail = true; stateful_group = false; stateful = false }
+ {
+ can_fail = true;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ }
in
(ctx, Unop (Neg int_ty), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast (src_ty, tgt_ty)) ->
(* Note that cast can fail *)
let effect_info =
- { can_fail = true; stateful_group = false; stateful = false }
+ {
+ can_fail = true;
+ stateful_group = false;
+ stateful = false;
+ can_diverge = false;
+ }
in
(ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
| S.Binop binop -> (
@@ -1255,6 +1309,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
can_fail = ExpressionsUtils.binop_can_fail binop;
stateful_group = false;
stateful = false;
+ can_diverge = false;
}
in
(ctx, Binop (binop, int_ty0), effect_info, args, None)
@@ -1353,9 +1408,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
let _forward, backwards = get_abs_ancestors ctx abs in
(* Retrieve the values consumed when we called the forward function and
* ended the parent backward functions: those give us part of the input
- * values (rmk: for now, as we disallow nested lifetimes, there can't be
+ * values (rem: for now, as we disallow nested lifetimes, there can't be
* parent backward functions).
- * Note that the forward inputs include the input state (if there is one). *)
+ * Note that the forward inputs **include the fuel and the input state**
+ * (if we use those). *)
let fwd_inputs = call_info.forward_inputs in
let back_ancestors_inputs =
List.concat (List.map (fun (_abs, args) -> args) backwards)
@@ -1762,6 +1818,71 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
let ty = next_e.ty in
{ e; ty }
+(** Wrap a function body in a match over the fuel to control termination. *)
+let wrap_in_match_fuel (body : texpression) (ctx : bs_ctx) : texpression =
+ let fuel0_var : var = mk_fuel_var ctx.fuel0 in
+ let fuel0 = mk_texpression_from_var fuel0_var in
+ let nfuel_var : var = mk_fuel_var ctx.fuel in
+ let nfuel_pat = mk_typed_pattern_from_var nfuel_var None in
+ let fail_branch =
+ mk_result_fail_texpression_with_error_id error_out_of_fuel_id body.ty
+ in
+ match !Config.backend with
+ | FStar ->
+ (* Generate an expression:
+ {[
+ if fuel0 = 0 then Fail OutOfFuel
+ else
+ let fuel = decrease fuel0 in
+ ...
+ }]
+ *)
+ (* Create the expression: [fuel0 = 0] *)
+ let check_fuel =
+ let func = { id = FunOrOp (Fun (Pure FuelEqZero)); type_args = [] } in
+ let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in
+ let func = { e = Qualif func; ty = func_ty } in
+ mk_app func fuel0
+ in
+ (* Create the expression: [decrease fuel0] *)
+ let decrease_fuel =
+ let func = { id = FunOrOp (Fun (Pure FuelDecrease)); type_args = [] } in
+ let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in
+ let func = { e = Qualif func; ty = func_ty } in
+ mk_app func fuel0
+ in
+
+ (* Create the success branch *)
+ let monadic = false in
+ let success_branch = mk_let monadic nfuel_pat decrease_fuel body in
+
+ (* Put everything together *)
+ let match_e = Switch (check_fuel, If (fail_branch, success_branch)) in
+ let match_ty = body.ty in
+ { e = match_e; ty = match_ty }
+ | Coq ->
+ (* Generate an expression:
+ {[
+ match fuel0 with
+ | O -> Fail OutOfFuel
+ | S fuel ->
+ ...
+ }]
+ *)
+ (* Create the fail branch *)
+ let fail_pat = mk_adt_pattern mk_fuel_ty (Some fuel_zero_id) [] in
+ let fail_branch = { pat = fail_pat; branch = fail_branch } in
+ (* Create the success branch *)
+ let success_pat =
+ mk_adt_pattern mk_fuel_ty (Some fuel_succ_id) [ nfuel_pat ]
+ in
+ let success_branch = body in
+ let success_branch = { pat = success_pat; branch = success_branch } in
+ (* Put everything together *)
+ let match_ty = body.ty in
+ let match_e = Switch (fuel0, Match [ fail_branch; success_branch ]) in
+ { e = match_e; ty = match_ty }
+
let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Translate *)
let def = ctx.fun_decl in
@@ -1784,12 +1905,22 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
match body with
| None -> None
| Some body ->
- let body = translate_expression body ctx in
let effect_info =
get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid
in
+ let body = translate_expression body ctx in
+ (* Add a match over the fuel, if necessary *)
+ let body =
+ if function_uses_fuel effect_info then wrap_in_match_fuel body ctx
+ else body
+ in
(* Sanity check *)
type_check_texpression ctx body;
+ (* Introduce the fuel parameter, if necessary *)
+ let fuel =
+ if function_uses_fuel effect_info then [ mk_fuel_var ctx.fuel0 ]
+ else []
+ in
(* Introduce the forward input state (the state at call site of the
* *forward* function), if necessary. *)
let fwd_state =
@@ -1821,7 +1952,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
(* Group the inputs together *)
let inputs =
List.concat
- [ ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
+ [ fuel; ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
in
let inputs_lvs =
List.map (fun v -> mk_typed_pattern_from_var v None) inputs
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index b2a28710..7ed9859a 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -71,6 +71,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
let var_counter = Pure.VarId.generator_zero in
let state_var, var_counter = Pure.VarId.fresh var_counter in
let back_state_var, var_counter = Pure.VarId.fresh var_counter in
+ let fuel0, var_counter = Pure.VarId.fresh var_counter in
+ let fuel, var_counter = Pure.VarId.fresh var_counter in
let calls = V.FunCallId.Map.empty in
let abstractions = V.AbstractionId.Map.empty in
let type_context =
@@ -100,6 +102,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
var_counter;
state_var;
back_state_var;
+ fuel0;
+ fuel;
type_context;
fun_context;
global_context;
@@ -171,7 +175,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
let backward_inputs =
let sg = backward_sg.sg in
(* We need to ignore the forward state and the backward state *)
- let num_forward_inputs = sg.info.num_fwd_inputs_with_state in
+ let num_forward_inputs =
+ sg.info.num_fwd_inputs_with_fuel_with_state
+ in
let num_back_inputs = Option.get sg.info.num_back_inputs_no_state in
Collections.List.subslice sg.inputs num_forward_inputs num_back_inputs
in