summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-05-04 15:39:05 +0200
committerSon Ho2022-05-04 15:39:05 +0200
commit15d90db02086f8ecae9a93ebf39c3c0ae8caa50f (patch)
tree3eb303b96c9233fd7745f06a0d4e2d211373ca03 /src
parentfb6fdfd0c57de1ce16fb6bc373d5593c9446b0bb (diff)
Fix some issues when using states
Diffstat (limited to '')
-rw-r--r--src/SymbolicToPure.ml149
-rw-r--r--src/Translate.ml4
-rw-r--r--src/main.ml3
3 files changed, 84 insertions, 72 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index fa482b8e..66f4d608 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -478,6 +478,30 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
let abs_ids = list_ancestor_abstractions_ids ctx abs in
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
+type fun_effect_info = {
+ can_fail : bool;
+ input_state : bool;
+ output_state : bool;
+}
+(** TODO: factorize with fun_sig_info?
+ TODO: use an enumeration
+ *)
+
+(** Small utility. *)
+let get_fun_effect_info (config : config) (fun_id : A.fun_id)
+ (gid : T.RegionGroupId.id option) : fun_effect_info =
+ match fun_id with
+ | A.Regular _ ->
+ let input_state = config.use_state_monad in
+ let output_state = input_state && gid = None in
+ { can_fail = true; input_state; output_state }
+ | A.Assumed aid ->
+ {
+ can_fail = Assumed.assumed_is_monadic aid;
+ input_state = false;
+ output_state = false;
+ }
+
(** Translate a function signature.
Note that the function also takes a list of names for the inputs, and
@@ -485,9 +509,10 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
name (outputs for backward functions come from borrows in the inputs
of the forward function).
*)
-let translate_fun_sig (config : config) (types_infos : TA.type_infos)
- (sg : A.fun_sig) (input_names : string option list)
- (bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
+let translate_fun_sig (config : config) (fun_id : A.fun_id)
+ (types_infos : TA.type_infos) (sg : A.fun_sig)
+ (input_names : string option list) (bid : T.RegionGroupId.id option) :
+ fun_sig_named_outputs =
(* Retrieve the list of parent backward functions *)
let gid, parents =
match bid with
@@ -537,15 +562,9 @@ let translate_fun_sig (config : config) (types_infos : TA.type_infos)
in
(* Does the function take a state as input, does it return a state and can
* it fail? *)
- (* For now, all the translated functions can fail *)
- let can_fail = true in
- (* For now, all translated functions have an input state if we setup
- * the translation to use states *)
- let input_state = config.use_state_monad in
- (* Only the forward functions return a state *)
- let output_state = input_state && bid = None in
+ let effect_info = get_fun_effect_info config fun_id bid in
(* *)
- let state_ty = if input_state then [ mk_state_ty ] else [] in
+ let state_ty = if effect_info.input_state then [ mk_state_ty ] else [] in
(* Concatenate the inputs, in the following order:
* - forward inputs
* - state input
@@ -589,10 +608,11 @@ let translate_fun_sig (config : config) (types_infos : TA.type_infos)
let output = mk_simpl_tuple_ty doutputs in
(* Add the output state *)
let output =
- if output_state then mk_simpl_tuple_ty [ mk_state_ty; output ] else output
+ if effect_info.output_state then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
in
(* Wrap in a result type *)
- if can_fail then mk_result_ty output else output
+ if effect_info.can_fail then mk_result_ty output else output
in
(* Type parameters *)
let type_params = sg.type_params in
@@ -602,9 +622,9 @@ let translate_fun_sig (config : config) (types_infos : TA.type_infos)
num_fwd_inputs = List.length fwd_inputs;
num_back_inputs =
(if bid = None then None else Some (List.length back_inputs));
- input_state;
- output_state;
- can_fail;
+ input_state = effect_info.input_state;
+ output_state = effect_info.output_state;
+ can_fail = effect_info.can_fail;
}
in
let sg = { type_params; inputs; output; doutputs; info } in
@@ -1046,30 +1066,6 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) :
let abs_ancestors = list_ancestor_abstractions ctx abs in
(call_info.forward, abs_ancestors)
-type fun_effect_info = {
- can_fail : bool;
- input_state : bool;
- output_state : bool;
-}
-(** TODO: factorize with fun_sig_info?
- TODO: use an enumeration
- *)
-
-(** Small utility. *)
-let get_fun_effect_info (config : config) (fun_id : A.fun_id)
- (gid : T.RegionGroupId.id option) : fun_effect_info =
- match fun_id with
- | A.Regular _ ->
- let input_state = config.use_state_monad in
- let output_state = input_state && gid = None in
- { can_fail = true; input_state; output_state }
- | A.Assumed aid ->
- {
- can_fail = Assumed.assumed_is_monadic aid;
- input_state = false;
- output_state = false;
- }
-
let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
: texpression =
match e with
@@ -1085,7 +1081,8 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression =
* we don't match on panics which happen inside the function body -
* but it won't be true anymore once we translate individual blocks *)
(* If we use a state monad, we need to add a lambda for the state variable *)
- if config.use_state_monad then
+ (* Note that only forward functions return a state *)
+ if config.use_state_monad && ctx.bid <> None then
(* Create the `Fail` value *)
let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.output_ty ] in
let ret_v = mk_result_fail_texpression ret_ty in
@@ -1109,22 +1106,21 @@ and translate_return (config : config) (opt_v : V.typed_value option)
(* Forward function *)
let v = Option.get opt_v in
let v = typed_value_to_texpression ctx v in
- (* We don't synthesize the same expression depending on the monad we use:
+ (* We may need to return a state
* - error-monad: Return x
- * - state-error monad: fun state -> Return (state, x)
+ * - state-error: Return (state, x)
* *)
- (* TODO: we should use a `return` function, it would be cleaner *)
if config.use_state_monad then
- let _, state_var =
- fresh_var (Some ConstStrings.state_basename) mk_state_ty ctx
+ let state_var =
+ {
+ id = ctx.state_var;
+ basename = Some ConstStrings.state_basename;
+ ty = mk_state_ty;
+ }
in
let state_rvalue = mk_texpression_from_var state_var in
- let ret_v =
- mk_result_return_texpression
- (mk_simpl_tuple_texpression [ state_rvalue; v ])
- in
- let state_var = mk_typed_pattern_from_var state_var None in
- mk_abs state_var ret_v
+ mk_result_return_texpression
+ (mk_simpl_tuple_texpression [ state_rvalue; v ])
else mk_result_return_texpression v
| Some bid ->
(* Backward function *)
@@ -1137,22 +1133,11 @@ and translate_return (config : config) (opt_v : V.typed_value option)
T.RegionGroupId.Map.find bid ctx.backward_outputs
in
let field_values = List.map mk_texpression_from_var backward_outputs in
- (* See the comment about the monads, for the forward function case *)
+ (* Backward functions never return a state *)
(* TODO: we should use a `fail` function, it would be cleaner *)
- if config.use_state_monad then
- let _, state_var = fresh_var (Some "st") mk_state_ty ctx in
- let state_rvalue = mk_texpression_from_var state_var in
- let ret_value = mk_simpl_tuple_texpression field_values in
- let ret_value =
- mk_result_return_texpression
- (mk_simpl_tuple_texpression [ state_rvalue; ret_value ])
- in
- let state_var = mk_typed_pattern_from_var state_var None in
- mk_abs state_var ret_value
- else
- let ret_value = mk_simpl_tuple_texpression field_values in
- let ret_value = mk_result_return_texpression ret_value in
- ret_value
+ let ret_value = mk_simpl_tuple_texpression field_values in
+ let ret_value = mk_result_return_texpression ret_value in
+ ret_value
and translate_function_call (config : config) (call : S.call) (e : S.expression)
(ctx : bs_ctx) : texpression =
@@ -1311,7 +1296,8 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(* Retrieve the values consumed when we called the forward function and
* ended the parent backward functions: those give us part of the input
* values (rmk: for now, as we disallow nested lifetimes, there can't be
- * parent backward functions) *)
+ * parent backward functions).
+ * Note that the forward inputs include the input state (if there is one). *)
let fwd_inputs = call_info.forward_inputs in
let back_ancestors_inputs =
List.concat (List.map (fun (_abs, args) -> args) backwards)
@@ -1345,6 +1331,16 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
let inst_sg =
get_instantiated_fun_sig fun_id (Some abs.back_id) type_args ctx
in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_texpression inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map show_ty inst_sg.inputs)));
List.iter
(fun (x, ty) -> assert ((x : texpression).ty = ty))
(List.combine inputs inst_sg.inputs);
@@ -1715,6 +1711,16 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
List.map (fun v -> mk_typed_pattern_from_var v None) inputs
in
(* Sanity check *)
+ log#ldebug
+ (lazy
+ ("SymbolicToPure.translate_fun_decl:" ^ "\n- forward_inputs: "
+ ^ String.concat ", " (List.map show_var ctx.forward_inputs)
+ ^ "\n- input_state: "
+ ^ String.concat ", " (List.map show_var input_state)
+ ^ "\n- backward_inputs: "
+ ^ String.concat ", " (List.map show_var backward_inputs)
+ ^ "\n- signature.inputs: "
+ ^ String.concat ", " (List.map show_ty signature.inputs)));
assert (
List.for_all
(fun (var, ty) -> (var : var).ty = ty)
@@ -1756,14 +1762,17 @@ let translate_fun_signatures (config : config) (types_infos : TA.type_infos)
let translate_one (fun_id : A.fun_id) (input_names : string option list)
(sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list =
(* The forward function *)
- let fwd_sg = translate_fun_sig config types_infos sg input_names None in
+ let fwd_sg =
+ translate_fun_sig config fun_id types_infos sg input_names None
+ in
let fwd_id = (fun_id, None) in
(* The backward functions *)
let back_sgs =
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig config types_infos sg input_names (Some rg.id)
+ translate_fun_sig config fun_id types_infos sg input_names
+ (Some rg.id)
in
let id = (fun_id, Some rg.id) in
(id, tsg))
diff --git a/src/Translate.ml b/src/Translate.ml
index d69f1379..92261dba 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -232,8 +232,10 @@ let translate_function_to_pure (config : C.partial_config)
let backward_sg =
RegularFunIdMap.find (A.Regular def_id, Some back_id) fun_sigs
in
+ (* We need to ignore the forward inputs, and the state input (if there is) *)
let _, backward_inputs =
- Collections.List.split_at backward_sg.sg.inputs num_forward_inputs
+ Collections.List.split_at backward_sg.sg.inputs
+ (num_forward_inputs + if use_state_monad then 1 else 0)
in
(* As we forbid nested borrows, the additional inputs for the backward
* functions come from the borrows in the return value of the rust function:
diff --git a/src/main.ml b/src/main.ml
index e635d910..8afbc0cd 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -122,7 +122,8 @@ let () =
(* Set up the logging - for now we use default values - TODO: use the
* command-line arguments *)
- Easy_logging.Handlers.set_level main_logger_handler EL.Info;
+ (* By setting a level for the main_logger_handler, we filter everything *)
+ Easy_logging.Handlers.set_level main_logger_handler EL.Debug;
main_log#set_level EL.Info;
llbc_of_json_logger#set_level EL.Info;
pre_passes_log#set_level EL.Info;