summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2022-11-09 18:04:03 +0100
committerSon HO2022-11-10 11:35:30 +0100
commit8b6f8e5fb85bcd1b3257550270c4c857d4ee7f55 (patch)
treeb0090425eb850af3b5c8dc1d4f6aa1eafe2c8e1a /compiler
parentb970183881379ff676b232e47e353e924de8cfdd (diff)
Implement the generation of stateful backward functions (controlled by an option)
Diffstat (limited to 'compiler')
-rw-r--r--compiler/Driver.ml12
-rw-r--r--compiler/ExtractToFStar.ml8
-rw-r--r--compiler/FunsAnalysis.ml2
-rw-r--r--compiler/Interpreter.ml11
-rw-r--r--compiler/Pure.ml35
-rw-r--r--compiler/PureUtils.ml6
-rw-r--r--compiler/SymbolicAst.ml8
-rw-r--r--compiler/SymbolicToPure.ml383
-rw-r--r--compiler/SynthesizeSymbolic.ml66
-rw-r--r--compiler/Translate.ml63
10 files changed, 386 insertions, 208 deletions
diff --git a/compiler/Driver.ml b/compiler/Driver.ml
index d19aca93..6f0e8074 100644
--- a/compiler/Driver.ml
+++ b/compiler/Driver.ml
@@ -38,6 +38,8 @@ let () =
let test_trans_units = ref false in
let no_decreases_clauses = ref false in
let no_state = ref false in
+ (* [backward_no_state_update]: see the comment for {!Translate.config.backward_no_state_update} *)
+ let backward_no_state_update = ref false in
let template_decreases_clauses = ref false in
let no_split_files = ref false in
let no_check_inv = ref false in
@@ -78,6 +80,9 @@ let () =
( "-no-state",
Arg.Set no_state,
" Do not use state-error monads, simply use error monads" );
+ ( "-backward-no-state-update",
+ Arg.Set backward_no_state_update,
+ " Forbid backward functions from updating the state" );
( "-template-clauses",
Arg.Set template_decreases_clauses,
" Generate templates for the required decreases clauses, in a\n\
@@ -95,6 +100,8 @@ let () =
in
(* Sanity check: -template-clauses ==> not -no-decrease-clauses *)
assert ((not !no_decreases_clauses) || not !template_decreases_clauses);
+ (* Sanity check: -backward-no-state-update ==> not -no-state *)
+ assert ((not !backward_no_state_update) || not !no_state);
let spec = Arg.align spec in
let filenames = ref [] in
@@ -110,10 +117,10 @@ let () =
| [ f ] ->
(* TODO: update the extension *)
if not (Filename.check_suffix f ".llbc") then (
- print_string "Unrecognized file extension";
+ print_string ("Unrecognized file extension: " ^ f ^ "\n");
fail ())
else if not (Sys.file_exists f) then (
- print_string "File not found";
+ print_string ("File not found: " ^ f ^ "\n");
fail ())
else f
| _ ->
@@ -198,6 +205,7 @@ let () =
extract_decreases_clauses = not !no_decreases_clauses;
extract_template_decreases_clauses = !template_decreases_clauses;
use_state = not !no_state;
+ backward_no_state_update = !backward_no_state_update;
}
in
Translate.translate_module filename dest_dir trans_config m;
diff --git a/compiler/ExtractToFStar.ml b/compiler/ExtractToFStar.ml
index 2a7d6a6c..a995d4a6 100644
--- a/compiler/ExtractToFStar.ml
+++ b/compiler/ExtractToFStar.ml
@@ -1451,17 +1451,19 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
* function (the additional input values "given back" to the
* backward functions have no influence on termination: we thus
* share the decrease clauses between the forward and the backward
- * functions).
+ * functions - we also ignore the additional state received by the
+ * backward function, if there is one).
*)
let inputs_lvs =
let all_inputs = (Option.get def.body).inputs_lvs in
(* We have to count:
* - the forward inputs
- * - the state
+ * - the state (if there is one)
*)
let num_fwd_inputs = def.signature.info.num_fwd_inputs in
let num_fwd_inputs =
- if def.signature.info.effect_info.input_state then 1 + num_fwd_inputs
+ if def.signature.info.effect_info.stateful_group then
+ 1 + num_fwd_inputs
else num_fwd_inputs
in
Collections.List.prefix num_fwd_inputs all_inputs
diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml
index 9413bd6a..4d33056b 100644
--- a/compiler/FunsAnalysis.ml
+++ b/compiler/FunsAnalysis.ml
@@ -16,7 +16,7 @@ module EU = ExpressionsUtils
*)
type fun_info = {
can_fail : bool;
- (* Not used yet: all the extracted functions use an error monad *)
+ (* Not used yet: all the extracted functions use an error monad *)
stateful : bool;
divergent : bool; (* Not used yet *)
}
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index d3b3c7e6..e752594e 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -253,8 +253,15 @@ let evaluate_function_symbolic (config : C.partial_config) (synthesize : bool)
cf_pop cf_return ctx
| Some back_id ->
(* Backward translation *)
- evaluate_function_symbolic_synthesize_backward_from_return config
- fdef inst_sg back_id ctx
+ let e =
+ evaluate_function_symbolic_synthesize_backward_from_return
+ config fdef inst_sg back_id ctx
+ in
+ (* We insert a delimiter to indicate the point where we switch
+ * from the part which is common to all the functions (forwards
+ * and backwards) and the part specific to this backward function.
+ *)
+ S.synthesize_forward_end e
else None
| Panic ->
(* Note that as we explore all the execution branches, one of
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index a50dd5f9..cc29469a 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -496,19 +496,26 @@ and meta =
(** Information about the "effect" of a function *)
type fun_effect_info = {
- input_state : bool; (** [true] if the function takes a state as input *)
- output_state : bool;
- (** [true] if the function outputs a state (it then lives
- in a state monad) *)
+ stateful_group : bool;
+ (** [true] if the function group is stateful. By *function group*, we mean
+ the set { forward function } U { backward functions }.
+
+ We need this because the option {!Translate.eval_config.backward_no_state_update}:
+ if it is [true], then in case of a backward function {!stateful} is [false],
+ but we might need to know whether the corresponding forward function
+ is stateful or not.
+ *)
+ stateful : bool; (** [true] if the function is stateful (updates a state) *)
can_fail : bool; (** [true] if the return type is a [result] *)
}
(** Meta information about a function signature *)
type fun_sig_info = {
num_fwd_inputs : int;
- (** The number of input types for forward computation *)
+ (** The number of input types for forward computation, ignoring the state *)
num_back_inputs : int option;
- (** The number of additional inputs for the backward computation (if pertinent) *)
+ (** The number of additional inputs for the backward computation (if pertinent),
+ ignoring the state *)
effect_info : fun_effect_info;
}
@@ -523,12 +530,18 @@ type fun_sig_info = {
`in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm -> (back_out0 & ... & back_outp)` (* pure function *)
`in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm ->
result (back_out0 & ... & back_outp)` (* error-monad *)
- `in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm ->
- result (back_out0 & ... & back_outp)` (* state-error *)
+ `in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> state ->
+ result (state & (back_out0 & ... & back_outp))` (* state-error *)
- Note that a backward function never returns (i.e., updates) a state: only
- forward functions do so. Also, the state input parameter is *betwee*
- the forward inputs and the backward inputs.
+ Note that a stateful backward function takes two states as inputs: the
+ state received by the associated forward function, and the state at which
+ the backward is called. This leads to code of the following shape:
+
+ {[
+ (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd
+ ... // the state may be updated
+ (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back
+ ]}
The function's type should be given by `mk_arrows sig.inputs sig.output`.
We provide additional meta-information:
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index ff379bf5..1ab3439c 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -456,3 +456,9 @@ let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
let opt_unmeta_mplace (e : texpression) : mplace option * texpression =
match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e)
+
+let mk_state_var (vid : VarId.id) : var =
+ { id = vid; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
+
+let mk_state_texpression (vid : VarId.id) : texpression =
+ { e = Var vid; ty = mk_state_ty }
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 528d8255..9d9adf4f 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -77,6 +77,14 @@ type expression =
We use it to compute meaningful names for the variables we introduce,
to prettify the generated code.
*)
+ | ForwardEnd of expression
+ (** We use this delimiter to indicate at which point we switch to the
+ generation of code specific to the backward function(s).
+
+ TODO: use this to factorize the generation of the forward and backward
+ functions (today we replay the *whole* symbolic execution once per
+ generated function).
+ *)
| Meta of meta * expression (** Meta information *)
and expansion =
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 6d01614d..9d249cfb 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -12,6 +12,10 @@ module FA = FunsAnalysis
(** The local logger *)
let log = L.symbolic_to_pure_log
+(* TODO: carrying configurations everywhere is super annoying.
+ Group everything in references in a [Config.ml] file (put aside the execution
+ mode, maybe).
+*)
type config = {
filter_useless_back_calls : bool;
(** If [true], filter the useless calls to backward functions.
@@ -39,6 +43,12 @@ type config = {
Note that we later filter the useless *forward* calls in the micro-passes,
where it is more natural to do.
*)
+ backward_no_state_update : bool;
+ (** Controls whether backward functions update the state, in case we use
+ a state ({!use_state}).
+
+ See {!Translate.config.backward_no_state_update}.
+ *)
}
type type_context = {
@@ -110,7 +120,23 @@ type bs_ctx = {
*)
var_counter : VarId.generator;
state_var : VarId.id;
- (** The current state variable, in case we use a state *)
+ (** The current state variable, in case the function is stateful *)
+ back_state_var : VarId.id;
+ (** The additional input state variable received by a stateful backward function.
+ When generating stateful functions, we generate code of the following
+ form:
+
+ {[
+ (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd
+ ... // the state may be updated
+ (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back
+ ]}
+
+ When translating a backward function, we need at some point to update
+ [state_var] with [back_state_var], to account for the fact that the
+ state may have been updated by the caller between the call to the
+ forward function and the call to the backward function.
+ *)
forward_inputs : var list;
(** The input parameters for the forward function *)
backward_inputs : var list T.RegionGroupId.Map.t;
@@ -498,20 +524,26 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
let abs_ids = list_ancestor_abstractions_ids ctx abs in
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
-(** Small utility. *)
-let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info =
+(** Small utility.
+
+ [backward_no_state_update]: see {!config}
+ *)
+let get_fun_effect_info (backward_no_state_update : bool)
+ (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id)
+ (gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
| A.Regular fid ->
let info = A.FunDeclId.Map.find fid fun_infos in
- let input_state = info.stateful in
- let output_state = input_state && gid = None in
- { can_fail = info.can_fail; input_state; output_state }
+ let stateful_group = info.stateful in
+ let stateful =
+ stateful_group && ((not backward_no_state_update) || gid = None)
+ in
+ { can_fail = info.can_fail; stateful_group; stateful }
| A.Assumed aid ->
{
can_fail = Assumed.assumed_can_fail aid;
- input_state = false;
- output_state = false;
+ stateful_group = false;
+ stateful = false;
}
(** Translate a function signature.
@@ -519,10 +551,11 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
Note that the function also takes a list of names for the inputs, and
computes, for every output for the backward functions, a corresponding
name (outputs for backward functions come from borrows in the inputs
- of the forward function).
+ of the forward function) which we use as hints to generate pretty names.
*)
-let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig)
+let translate_fun_sig (backward_no_state_update : bool)
+ (fun_infos : FA.fun_info A.FunDeclId.Map.t) (fun_id : A.fun_id)
+ (types_infos : TA.type_infos) (sg : A.fun_sig)
(input_names : string option list) (bid : T.RegionGroupId.id option) :
fun_sig_named_outputs =
(* Retrieve the list of parent backward functions *)
@@ -572,17 +605,42 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
*)
List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
in
- (* Does the function take a state as input, does it return a state and can
- * it fail? *)
- let effect_info = get_fun_effect_info fun_infos fun_id bid in
- (* *)
- let state_ty = if effect_info.input_state then [ mk_state_ty ] else [] in
+ (* Is the function stateful, and can it fail? *)
+ let effect_info =
+ get_fun_effect_info backward_no_state_update fun_infos fun_id bid
+ in
+ (* If the function is stateful, the inputs are:
+ - forward: [fwd_ty0, ..., fwd_tyn, state]
+ - backward:
+ - if config.no_backward_state: [fwd_ty0, ..., fwd_tyn, state, back_ty, state]
+ - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty]
+
+ The backward takes the same state as input as the forward function,
+ together with the state at the point where it gets called, if it is
+ stateful.
+
+ See the comments for {!Translate.config.backward_no_state_update}
+ *)
+ let fwd_state_ty =
+ (* For the forward state, we check if the *whole group* is stateful.
+ See {!effect_info}. *)
+ if effect_info.stateful_group then [ mk_state_ty ] else []
+ in
+ let back_state_ty =
+ (* For the backward state, we check if the function is a backward function,
+ and it is stateful *)
+ if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else []
+ in
+
(* Concatenate the inputs, in the following order:
* - forward inputs
- * - state input
+ * - forward state input
* - backward inputs
+ * - backward state input
*)
- let inputs = List.concat [ fwd_inputs; state_ty; back_inputs ] in
+ let inputs =
+ List.concat [ fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ]
+ in
(* Outputs *)
let output_names, doutputs =
match gid with
@@ -620,7 +678,7 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
let output = mk_simpl_tuple_ty doutputs in
(* Add the output state *)
let output =
- if effect_info.output_state then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
else output
in
(* Wrap in a result type *)
@@ -1087,6 +1145,15 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
| Assertion (v, e) -> translate_assertion config v e ctx
| Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx
| Meta (meta, e) -> translate_meta config meta e ctx
+ | ForwardEnd e ->
+ (* Update the current state with the additional state received by the backward
+ function, if needs be *)
+ let ctx =
+ match ctx.bid with
+ | None -> ctx
+ | Some _ -> { ctx with state_var = ctx.back_state_var }
+ in
+ translate_expression config e ctx
and translate_panic (ctx : bs_ctx) : texpression =
(* Here we use the function return type - note that it is ok because
@@ -1095,13 +1162,15 @@ and translate_panic (ctx : bs_ctx) : texpression =
(* If we use a state monad, we need to add a lambda for the state variable *)
(* Note that only forward functions return a state *)
let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in
- if ctx.sg.info.effect_info.output_state then
+ (* TODO: we should use a [Fail] function *)
+ if ctx.sg.info.effect_info.stateful then
(* Create the [Fail] value *)
let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
let ret_v = mk_result_fail_texpression ret_ty in
ret_v
else mk_result_fail_texpression output_ty
+(** [opt_v]: the value to return, in case we translate a forward function *)
and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
=
(* There are two cases:
@@ -1109,44 +1178,40 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
value should be [Some] (it is the returned value)
- or we are translating a backward function, in which case it should be [None]
*)
- match ctx.bid with
- | None ->
- (* Forward function *)
- let v = Option.get opt_v in
- let v = typed_value_to_texpression ctx v in
- (* We may need to return a state
- * - error-monad: Return x
- * - state-error: Return (state, x)
- * *)
- if ctx.sg.info.effect_info.output_state then
- let state_var =
- {
- id = ctx.state_var;
- basename = Some ConstStrings.state_basename;
- ty = mk_state_ty;
- }
+ (* Compute the values that we should return *without the state and the result
+ * wrapper* *)
+ let output =
+ match ctx.bid with
+ | None ->
+ (* Forward function *)
+ let v = Option.get opt_v in
+ typed_value_to_texpression ctx v
+ | Some bid ->
+ (* Backward function *)
+ (* Sanity check *)
+ assert (opt_v = None);
+ (* Group the variables in which we stored the values we need to give back.
+ * See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
+ let backward_outputs =
+ T.RegionGroupId.Map.find bid ctx.backward_outputs
in
- let state_rvalue = mk_texpression_from_var state_var in
- mk_result_return_texpression
- (mk_simpl_tuple_texpression [ state_rvalue; v ])
- else mk_result_return_texpression v
- | Some bid ->
- (* Backward function *)
- (* Sanity check *)
- assert (opt_v = None);
- assert (not ctx.sg.info.effect_info.output_state);
- (* We simply need to return the variables in which we stored the values
- * we need to give back.
- * See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
- in
- let field_values = List.map mk_texpression_from_var backward_outputs in
- (* Backward functions never return a state *)
- (* TODO: we should use a [fail] function, it would be cleaner *)
- let ret_value = mk_simpl_tuple_texpression field_values in
- let ret_value = mk_result_return_texpression ret_value in
- ret_value
+ let field_values = List.map mk_texpression_from_var backward_outputs in
+ mk_simpl_tuple_texpression field_values
+ in
+ (* We may need to return a state
+ * - error-monad: Return x
+ * - state-error: Return (state, x)
+ * *)
+ let effect_info = ctx.sg.info.effect_info in
+ let output =
+ if effect_info.stateful then
+ let state_rvalue = mk_state_texpression ctx.state_var in
+ mk_simpl_tuple_texpression [ state_rvalue; output ]
+ else output
+ in
+ (* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
+ (* TODO: we should use a [Return] function *)
+ mk_result_return_texpression output
and translate_function_call (config : config) (call : S.call) (e : S.expression)
(ctx : bs_ctx) : texpression =
@@ -1171,29 +1236,26 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None
- in
- (* Add the state input argument *)
- let args =
- if effect_info.input_state then
- let state_var = { e = Var ctx.state_var; ty = mk_state_ty } in
- List.append args [ state_var ]
- else args
+ get_fun_effect_info config.backward_no_state_update
+ ctx.fun_context.fun_infos fid None
in
- (* Generate a fresh state variable if the function call introduces
- * a new variable *)
- let ctx, out_state =
- if effect_info.input_state then
- let ctx, var = bs_ctx_fresh_state_var ctx in
- (ctx, Some var)
- else (ctx, None)
+ (* If the function is stateful:
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ *)
+ let args, ctx, out_state =
+ if effect_info.stateful then
+ let state_var = mk_state_texpression ctx.state_var in
+ let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
+ (List.append args [ state_var ], ctx, Some nstate_var)
+ else (args, ctx, None)
in
(* Register the function call *)
let ctx = bs_ctx_register_forward_call call_id call args ctx in
(ctx, func, effect_info, args, out_state)
| S.Unop E.Not ->
let effect_info =
- { can_fail = false; input_state = false; output_state = false }
+ { can_fail = false; stateful_group = false; stateful = false }
in
(ctx, Unop Not, effect_info, args, None)
| S.Unop E.Neg -> (
@@ -1203,14 +1265,14 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
(* Note that negation can lead to an overflow and thus fail (it
* is thus monadic) *)
let effect_info =
- { can_fail = true; input_state = false; output_state = false }
+ { can_fail = true; stateful_group = false; stateful = false }
in
(ctx, Unop (Neg int_ty), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast (src_ty, tgt_ty)) ->
(* Note that cast can fail *)
let effect_info =
- { can_fail = true; input_state = false; output_state = false }
+ { can_fail = true; stateful_group = false; stateful = false }
in
(ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
| S.Binop binop -> (
@@ -1222,8 +1284,8 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let effect_info =
{
can_fail = ExpressionsUtils.binop_can_fail binop;
- input_state = false;
- output_state = false;
+ stateful_group = false;
+ stateful = false;
}
in
(ctx, Binop (binop, int_ty0), effect_info, args, None)
@@ -1307,6 +1369,17 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
| V.FunCall ->
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
let call = call_info.forward in
+ let fun_id =
+ match call.call_id with
+ | S.Fun (fun_id, _) -> fun_id
+ | Unop _ | Binop _ ->
+ (* Those don't have backward functions *)
+ raise (Failure "Unreachable")
+ in
+ let effect_info =
+ get_fun_effect_info config.backward_no_state_update
+ ctx.fun_context.fun_infos fun_id (Some abs.back_id)
+ in
let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
(* Retrieve the original call and the parent abstractions *)
let _forward, backwards = get_abs_ancestors ctx abs in
@@ -1322,8 +1395,21 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: those give us the remaining input values *)
let back_inputs = abs_to_consumed ctx abs in
+ (* If the function is stateful:
+ * - add the state input argument
+ * - generate a fresh state variable for the returned state
+ *)
+ let back_state, ctx, nstate =
+ if effect_info.stateful then
+ let back_state = mk_state_texpression ctx.state_var in
+ let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ ([ back_state ], ctx, Some nstate)
+ else ([], ctx, None)
+ in
+ (* Concatenate all the inpus *)
let inputs =
- List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs ]
+ List.concat
+ [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
@@ -1333,43 +1419,42 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
List.append (List.map translate_opt_mplace call.args_places) [ None ]
in
let ctx, outputs = abs_to_given_back output_mpl abs ctx in
- (* Group the output values together (note that for now, backward functions
- * never return an output state) *)
+ (* Group the output values together: first the updated inputs *)
let output = mk_simpl_tuple_pattern outputs in
- (* Sanity check: the inputs and outputs have the proper number and the proper type *)
- let fun_id =
- match call.call_id with
- | S.Fun (fun_id, _) -> fun_id
- | Unop _ | Binop _ ->
- (* Those don't have backward functions *)
- raise (Failure "Unreachable")
+ (* Add the returned state if the function is stateful *)
+ let output =
+ match nstate with
+ | None -> output
+ | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
-
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx
+ (* Sanity check: the inputs and outputs have the proper number and the proper type *)
+ let _ =
+ let inst_sg =
+ get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx
+ in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_texpression inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : texpression).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ log#ldebug
+ (lazy
+ ("\n- outputs: "
+ ^ string_of_int (List.length outputs)
+ ^ "\n- expected outputs: "
+ ^ string_of_int (List.length inst_sg.doutputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
+ (List.combine outputs inst_sg.doutputs)
in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map show_texpression inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs);
(* Retrieve the function id, and register the function call in the context
* if necessary *)
let ctx, func = bs_ctx_register_backward_call abs back_inputs ctx in
@@ -1382,9 +1467,6 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some abs.back_id)
- in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty output.ty else output.ty
@@ -1398,8 +1480,13 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
* We do a small optimization here: if the backward function doesn't
* have any output, we don't introduce any function call.
* See the comment in [config].
+ *
+ * TODO: use an option to disallow backward functions from updating the state.
+ * TODO: a backward function which only gives back shared borrows shouldn't
+ * update the state (state updates should only be used for mutable borrows,
+ * with objects like Rc for instance.
*)
- if config.filter_useless_back_calls && outputs = [] then (
+ if config.filter_useless_back_calls && outputs = [] && nstate = None then (
(* No outputs - we do a small sanity check: the backward function
* should have exactly the same number of inputs as the forward:
* this number can be different only if the forward function returned
@@ -1708,32 +1795,28 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(* Translate the declaration *)
let def_id = def.A.def_id in
let basename = def.name in
- (* Lookup the signature *)
- let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in
+ (* Retrieve the signature *)
+ let signature = ctx.sg in
(* Translate the body, if there is *)
let body =
match body with
| None -> None
| Some body ->
let body = translate_expression config body ctx in
- (* Sanity check *)
- type_check_texpression ctx body;
- (* Introduce the input state, if necessary *)
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid
+ get_fun_effect_info config.backward_no_state_update
+ ctx.fun_context.fun_infos (Regular def_id) bid
in
- let input_state =
- if effect_info.input_state then
- [
- {
- id = ctx.state_var;
- basename = Some ConstStrings.state_basename;
- ty = mk_state_ty;
- };
- ]
+ (* Sanity check *)
+ type_check_texpression ctx body;
+ (* Introduce the forward input state (the state at call site of the
+ * *forward* function), if necessary. *)
+ let fwd_state =
+ (* We check if the *whole group* is stateful. See {!effect_info} *)
+ if effect_info.stateful_group then [ mk_state_var ctx.state_var ]
else []
in
- (* Compute the list of (properly ordered) input variables *)
+ (* Compute the list of (properly ordered) backward input variables *)
let backward_inputs : var list =
match bid with
| None -> []
@@ -1747,8 +1830,17 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs)
backward_ids)
in
+ (* Introduce the backward input state (the state at call site of the
+ * *backward* function), if necessary *)
+ let back_state =
+ if effect_info.stateful && Option.is_some bid then
+ [ mk_state_var ctx.back_state_var ]
+ else []
+ in
+ (* Group the inputs together *)
let inputs =
- List.concat [ ctx.forward_inputs; input_state; backward_inputs ]
+ List.concat
+ [ ctx.forward_inputs; fwd_state; backward_inputs; back_state ]
in
let inputs_lvs =
List.map (fun v -> mk_typed_pattern_from_var v None) inputs
@@ -1756,12 +1848,18 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(* Sanity check *)
log#ldebug
(lazy
- ("SymbolicToPure.translate_fun_decl:" ^ "\n- forward_inputs: "
+ ("SymbolicToPure.translate_fun_decl: "
+ ^ Print.fun_name_to_string def.A.name
+ ^ " ("
+ ^ Print.option_to_string T.RegionGroupId.to_string bid
+ ^ ")" ^ "\n- forward_inputs: "
^ String.concat ", " (List.map show_var ctx.forward_inputs)
- ^ "\n- input_state: "
- ^ String.concat ", " (List.map show_var input_state)
+ ^ "\n- fwd_state: "
+ ^ String.concat ", " (List.map show_var fwd_state)
^ "\n- backward_inputs: "
^ String.concat ", " (List.map show_var backward_inputs)
+ ^ "\n- back_state: "
+ ^ String.concat ", " (List.map show_var back_state)
^ "\n- signature.inputs: "
^ String.concat ", " (List.map show_ty signature.inputs)));
assert (
@@ -1804,8 +1902,8 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
- optional names for the outputs values (we derive them for the backward
functions)
*)
-let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (types_infos : TA.type_infos)
+let translate_fun_signatures (backward_no_state_update : bool)
+ (fun_infos : FA.fun_info A.FunDeclId.Map.t) (types_infos : TA.type_infos)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdMap.t =
(* For every function, translate the signatures of:
@@ -1816,7 +1914,8 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list =
(* The forward function *)
let fwd_sg =
- translate_fun_sig fun_infos fun_id types_infos sg input_names None
+ translate_fun_sig backward_no_state_update fun_infos fun_id types_infos sg
+ input_names None
in
let fwd_id = (fun_id, None) in
(* The backward functions *)
@@ -1824,8 +1923,8 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig fun_infos fun_id types_infos sg input_names
- (Some rg.id)
+ translate_fun_sig backward_no_state_update fun_infos fun_id
+ types_infos sg input_names (Some rg.id)
in
let id = (fun_id, Some rg.id) in
(id, tsg))
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index c74a831e..8d4dac82 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -12,7 +12,7 @@ let mk_mplace (p : E.place) (ctx : Contexts.eval_ctx) : mplace =
let mk_opt_mplace (p : E.place option) (ctx : Contexts.eval_ctx) : mplace option
=
- match p with None -> None | Some p -> Some (mk_mplace p ctx)
+ Option.map (fun p -> mk_mplace p ctx) p
let mk_opt_place_from_op (op : E.operand) (ctx : Contexts.eval_ctx) :
mplace option =
@@ -22,11 +22,11 @@ let mk_opt_place_from_op (op : E.operand) (ctx : Contexts.eval_ctx) :
let synthesize_symbolic_expansion (sv : V.symbolic_value)
(place : mplace option) (seel : V.symbolic_expansion option list)
- (exprl : expression list option) : expression option =
- match exprl with
+ (el : expression list option) : expression option =
+ match el with
| None -> None
- | Some exprl ->
- let ls = List.combine seel exprl in
+ | Some el ->
+ let ls = List.combine seel el in
(* Match on the symbolic value type to know which can of expansion happened *)
let expansion =
match sv.V.sv_ty with
@@ -89,19 +89,18 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
Some (Expansion (place, sv, expansion))
let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value)
- (place : mplace option) (see : V.symbolic_expansion)
- (expr : expression option) : expression option =
- let exprl = match expr with None -> None | Some expr -> Some [ expr ] in
- synthesize_symbolic_expansion sv place [ Some see ] exprl
+ (place : mplace option) (see : V.symbolic_expansion) (e : expression option)
+ : expression option =
+ let el = Option.map (fun e -> [ e ]) e in
+ synthesize_symbolic_expansion sv place [ Some see ] el
let synthesize_function_call (call_id : call_id)
(abstractions : V.AbstractionId.id list) (type_params : T.ety list)
(args : V.typed_value list) (args_places : mplace option list)
(dest : V.symbolic_value) (dest_place : mplace option)
- (expr : expression option) : expression option =
- match expr with
- | None -> None
- | Some expr ->
+ (e : expression option) : expression option =
+ Option.map
+ (fun e ->
let call =
{
call_id;
@@ -113,48 +112,45 @@ let synthesize_function_call (call_id : call_id)
dest_place;
}
in
- Some (FunCall (call, expr))
+ FunCall (call, e))
+ e
let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value)
- (expr : expression option) : expression option =
- match expr with None -> None | Some e -> Some (EvalGlobal (gid, dest, e))
+ (e : expression option) : expression option =
+ Option.map (fun e -> EvalGlobal (gid, dest, e)) e
let synthesize_regular_function_call (fun_id : A.fun_id)
(call_id : V.FunCallId.id) (abstractions : V.AbstractionId.id list)
(type_params : T.ety list) (args : V.typed_value list)
(args_places : mplace option list) (dest : V.symbolic_value)
- (dest_place : mplace option) (expr : expression option) : expression option
- =
+ (dest_place : mplace option) (e : expression option) : expression option =
synthesize_function_call
(Fun (fun_id, call_id))
- abstractions type_params args args_places dest dest_place expr
+ abstractions type_params args args_places dest dest_place e
let synthesize_unary_op (unop : E.unop) (arg : V.typed_value)
(arg_place : mplace option) (dest : V.symbolic_value)
- (dest_place : mplace option) (expr : expression option) : expression option
- =
+ (dest_place : mplace option) (e : expression option) : expression option =
synthesize_function_call (Unop unop) [] [] [ arg ] [ arg_place ] dest
- dest_place expr
+ dest_place e
let synthesize_binary_op (binop : E.binop) (arg0 : V.typed_value)
(arg0_place : mplace option) (arg1 : V.typed_value)
(arg1_place : mplace option) (dest : V.symbolic_value)
- (dest_place : mplace option) (expr : expression option) : expression option
- =
+ (dest_place : mplace option) (e : expression option) : expression option =
synthesize_function_call (Binop binop) [] [] [ arg0; arg1 ]
- [ arg0_place; arg1_place ] dest dest_place expr
+ [ arg0_place; arg1_place ] dest dest_place e
-let synthesize_end_abstraction (abs : V.abs) (expr : expression option) :
+let synthesize_end_abstraction (abs : V.abs) (e : expression option) :
expression option =
- match expr with
- | None -> None
- | Some expr -> Some (EndAbstraction (abs, expr))
+ Option.map (fun e -> EndAbstraction (abs, e)) e
let synthesize_assignment (lplace : mplace) (rvalue : V.typed_value)
- (rplace : mplace option) (expr : expression option) : expression option =
- match expr with
- | None -> None
- | Some expr -> Some (Meta (Assignment (lplace, rvalue, rplace), expr))
+ (rplace : mplace option) (e : expression option) : expression option =
+ Option.map (fun e -> Meta (Assignment (lplace, rvalue, rplace), e)) e
+
+let synthesize_assertion (v : V.typed_value) (e : expression option) =
+ Option.map (fun e -> Assertion (v, e)) e
-let synthesize_assertion (v : V.typed_value) (expr : expression option) =
- match expr with None -> None | Some expr -> Some (Assertion (v, expr))
+let synthesize_forward_end (e : expression option) =
+ Option.map (fun e -> ForwardEnd e) e
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index d7cc9155..72322c73 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -18,6 +18,30 @@ type config = {
(** Controls whether we need to use a state to model the external world
(I/O, for instance).
*)
+ backward_no_state_update : bool;
+ (** Controls whether backward functions update the state, in case we use
+ a state ({!use_state}).
+
+ If they update the state, we generate code of the following style:
+ {[
+ (st1, y) <-- f_fwd x st0; // st0 is the state upon calling f_fwd
+ ...
+ (st3, x') <-- f_back x st0 y' st2; // st2 is the state upon calling f_back
+ }]
+
+ Otherwise, we generate code of the following shape:
+ {[
+ (st1, y) <-- f_fwd x st0;
+ ...
+ x' <-- f_back x st0 y';
+ }]
+
+ The second format is easier to reason about, but the first one is
+ necessary to properly handle some Rust functions which use internal
+ mutability such as {{:https://doc.rust-lang.org/std/cell/struct.RefCell.html#method.try_borrow_mut} [RefCell::try_mut_borrow]}:
+ in order to model this behaviour we would need a state, and calling the backward
+ function would update the state by reinserting the updated value in it.
+ *)
split_files : bool;
(** Controls whether we split the generated definitions between different
files for the types, clauses and functions, or if we group them in
@@ -96,7 +120,8 @@ let translate_function_to_symbolics (config : C.partial_config)
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (config : C.partial_config)
- (mp_config : Micro.config) (trans_ctx : trans_ctx)
+ (mp_config : Micro.config) (backward_no_state_update : bool)
+ (trans_ctx : trans_ctx)
(fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
(pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl)
: pure_fun_translation =
@@ -123,6 +148,7 @@ let translate_function_to_pure (config : C.partial_config)
let sv_to_var = V.SymbolicValueId.Map.empty in
let var_counter = Pure.VarId.generator_zero in
let state_var, var_counter = Pure.VarId.fresh var_counter in
+ let back_state_var, var_counter = Pure.VarId.fresh var_counter in
let calls = V.FunCallId.Map.empty in
let abstractions = V.AbstractionId.Map.empty in
let type_context =
@@ -151,6 +177,7 @@ let translate_function_to_pure (config : C.partial_config)
sv_to_var;
var_counter;
state_var;
+ back_state_var;
type_context;
fun_context;
global_context;
@@ -188,6 +215,7 @@ let translate_function_to_pure (config : C.partial_config)
{
SymbolicToPure.filter_useless_back_calls =
mp_config.filter_useless_monadic_calls;
+ backward_no_state_update;
}
in
@@ -231,12 +259,22 @@ let translate_function_to_pure (config : C.partial_config)
in
(* We need to ignore the forward inputs, and the state input (if there is) *)
let fun_info =
- SymbolicToPure.get_fun_effect_info fun_context.fun_infos
- (A.Regular def_id) (Some back_id)
+ SymbolicToPure.get_fun_effect_info backward_no_state_update
+ fun_context.fun_infos (A.Regular def_id) (Some back_id)
in
- let _, backward_inputs =
- Collections.List.split_at backward_sg.sg.inputs
- (num_forward_inputs + if fun_info.input_state then 1 else 0)
+ let backward_inputs =
+ (* We need to ignore the forward state and the backward state *)
+ (* TODO: this is ad-hoc and error-prone. We should group all this
+ * information in the signature information. *)
+ let fwd_state_n = if fun_info.stateful_group then 1 else 0 in
+ let num_forward_inputs = num_forward_inputs + fwd_state_n in
+ let back_state_n = if fun_info.stateful then 1 else 0 in
+ let num_back_inputs =
+ List.length backward_sg.sg.inputs
+ - num_forward_inputs - back_state_n
+ in
+ Collections.List.subslice backward_sg.sg.inputs num_forward_inputs
+ num_back_inputs
in
(* As we forbid nested borrows, the additional inputs for the backward
* functions come from the borrows in the return value of the rust function:
@@ -285,7 +323,8 @@ let translate_function_to_pure (config : C.partial_config)
(pure_forward, pure_backwards)
let translate_module_to_pure (config : C.partial_config)
- (mp_config : Micro.config) (use_state : bool) (crate : A.crate) :
+ (mp_config : Micro.config) (use_state : bool)
+ (backward_no_state_update : bool) (crate : A.crate) :
trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list =
(* Debug *)
log#ldebug (lazy "translate_module_to_pure");
@@ -333,15 +372,15 @@ let translate_module_to_pure (config : C.partial_config)
in
let sigs = List.append assumed_sigs local_sigs in
let fun_sigs =
- SymbolicToPure.translate_fun_signatures fun_context.fun_infos
- type_context.type_infos sigs
+ SymbolicToPure.translate_fun_signatures backward_no_state_update
+ fun_context.fun_infos type_context.type_infos sigs
in
(* Translate all the *transparent* functions *)
let pure_translations =
List.map
- (translate_function_to_pure config mp_config trans_ctx fun_sigs
- type_decls_map)
+ (translate_function_to_pure config mp_config backward_no_state_update
+ trans_ctx fun_sigs type_decls_map)
crate.functions
in
@@ -631,7 +670,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config)
(* Translate the module to the pure AST *)
let trans_ctx, trans_types, trans_funs =
translate_module_to_pure config.eval_config config.mp_config
- config.use_state crate
+ config.use_state config.backward_no_state_update crate
in
(* Initialize the extraction context - for now we extract only to F*.