summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-05-05 17:04:25 +0200
committerSon Ho2022-05-05 17:04:25 +0200
commit678b057f231f8eb99d3dc70ceb99c7a90a854d4d (patch)
treec00a79bc8935f1dc9478f67e77ef1182c857f3da /src
parent30085b15a3ef07bc7179a60cd42085270dbc9351 (diff)
Update the translation so that we use a state only in the functions
which need one
Diffstat (limited to 'src')
-rw-r--r--src/SymbolicToPure.ml49
-rw-r--r--src/Translate.ml28
2 files changed, 42 insertions, 35 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 41b4cdeb..3ac68365 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -8,6 +8,7 @@ module S = SymbolicAst
module TA = TypesAnalysis
module L = Logging
module PP = PrintPure
+module FA = FunsAnalysis
(** The local logger *)
let log = L.symbolic_to_pure_log
@@ -38,10 +39,6 @@ type config = {
Note that we later filter the useless *forward* calls in the micro-passes,
where it is more natural to do.
*)
- use_state : bool;
- (** Controls whether we need to use a state to model the external world
- (I/O, for instance).
- *)
}
type type_context = {
@@ -72,6 +69,7 @@ type fun_sig_named_outputs = {
type fun_context = {
llbc_fun_decls : A.fun_decl FunDeclId.Map.t;
fun_sigs : fun_sig_named_outputs RegularFunIdMap.t; (** *)
+ fun_infos : FA.fun_info FunDeclId.Map.t;
}
type call_info = {
@@ -473,11 +471,12 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
(** Small utility. *)
-let get_fun_effect_info (config : config) (fun_id : A.fun_id)
- (gid : T.RegionGroupId.id option) : fun_effect_info =
+let get_fun_effect_info (fun_infos : FA.fun_info FunDeclId.Map.t)
+ (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
- | A.Regular _ ->
- let input_state = config.use_state in
+ | A.Regular fid ->
+ let info = FunDeclId.Map.find fid fun_infos in
+ let input_state = info.stateful in
let output_state = input_state && gid = None in
{ can_fail = true; input_state; output_state }
| A.Assumed aid ->
@@ -494,8 +493,8 @@ let get_fun_effect_info (config : config) (fun_id : A.fun_id)
name (outputs for backward functions come from borrows in the inputs
of the forward function).
*)
-let translate_fun_sig (config : config) (fun_id : A.fun_id)
- (types_infos : TA.type_infos) (sg : A.fun_sig)
+let translate_fun_sig (fun_infos : FA.fun_info FunDeclId.Map.t)
+ (fun_id : A.fun_id) (types_infos : TA.type_infos) (sg : A.fun_sig)
(input_names : string option list) (bid : T.RegionGroupId.id option) :
fun_sig_named_outputs =
(* Retrieve the list of parent backward functions *)
@@ -547,7 +546,7 @@ let translate_fun_sig (config : config) (fun_id : A.fun_id)
in
(* Does the function take a state as input, does it return a state and can
* it fail? *)
- let effect_info = get_fun_effect_info config fun_id bid in
+ let effect_info = get_fun_effect_info fun_infos fun_id bid in
(* *)
let state_ty = if effect_info.input_state then [ mk_state_ty ] else [] in
(* Concatenate the inputs, in the following order:
@@ -1052,7 +1051,7 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) :
let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
: texpression =
match e with
- | S.Return opt_v -> translate_return config opt_v ctx
+ | S.Return opt_v -> translate_return opt_v ctx
| Panic -> translate_panic ctx
| FunCall (call, e) -> translate_function_call config call e ctx
| EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx
@@ -1073,8 +1072,8 @@ and translate_panic (ctx : bs_ctx) : texpression =
ret_v
else mk_result_fail_texpression output_ty
-and translate_return (config : config) (opt_v : V.typed_value option)
- (ctx : bs_ctx) : texpression =
+and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
+ =
(* There are two cases:
- either we are translating a forward function, in which case the optional
value should be `Some` (it is the returned value)
@@ -1089,7 +1088,7 @@ and translate_return (config : config) (opt_v : V.typed_value option)
* - error-monad: Return x
* - state-error: Return (state, x)
* *)
- if config.use_state then
+ if ctx.sg.info.effect_info.output_state then
let state_var =
{
id = ctx.state_var;
@@ -1105,6 +1104,7 @@ and translate_return (config : config) (opt_v : V.typed_value option)
(* Backward function *)
(* Sanity check *)
assert (opt_v = None);
+ assert (not ctx.sg.info.effect_info.output_state);
(* We simply need to return the variables in which we stored the values
* we need to give back.
* See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
@@ -1140,7 +1140,9 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let func = Regular (fid, None) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
- let effect_info = get_fun_effect_info config fid None in
+ let effect_info =
+ get_fun_effect_info ctx.fun_context.fun_infos fid None
+ in
(* Add the state input argument *)
let args =
if effect_info.input_state then
@@ -1344,7 +1346,9 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let effect_info = get_fun_effect_info config fun_id (Some abs.back_id) in
+ let effect_info =
+ get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some abs.back_id)
+ in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty output.ty else output.ty
@@ -1657,7 +1661,9 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(* Sanity check *)
type_check_texpression ctx body;
(* Introduce the input state, if necessary *)
- let effect_info = get_fun_effect_info config (Regular def_id) bid in
+ let effect_info =
+ get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) bid
+ in
let input_state =
if effect_info.input_state then
[
@@ -1731,7 +1737,8 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
- optional names for the outputs values (we derive them for the backward
functions)
*)
-let translate_fun_signatures (config : config) (types_infos : TA.type_infos)
+let translate_fun_signatures (fun_infos : FA.fun_info FunDeclId.Map.t)
+ (types_infos : TA.type_infos)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdMap.t =
(* For every function, translate the signatures of:
@@ -1742,7 +1749,7 @@ let translate_fun_signatures (config : config) (types_infos : TA.type_infos)
(sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list =
(* The forward function *)
let fwd_sg =
- translate_fun_sig config fun_id types_infos sg input_names None
+ translate_fun_sig fun_infos fun_id types_infos sg input_names None
in
let fwd_id = (fun_id, None) in
(* The backward functions *)
@@ -1750,7 +1757,7 @@ let translate_fun_signatures (config : config) (types_infos : TA.type_infos)
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig config fun_id types_infos sg input_names
+ translate_fun_sig fun_infos fun_id types_infos sg input_names
(Some rg.id)
in
let id = (fun_id, Some rg.id) in
diff --git a/src/Translate.ml b/src/Translate.ml
index 13715865..857f0f69 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -97,7 +97,7 @@ let translate_function_to_symbolics (config : C.partial_config)
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (config : C.partial_config)
- (mp_config : Micro.config) (use_state : bool) (trans_ctx : trans_ctx)
+ (mp_config : Micro.config) (trans_ctx : trans_ctx)
(fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
(pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl)
: pure_fun_translation =
@@ -134,7 +134,11 @@ let translate_function_to_pure (config : C.partial_config)
}
in
let fun_context =
- { SymbolicToPure.llbc_fun_decls = fun_context.fun_decls; fun_sigs }
+ {
+ SymbolicToPure.llbc_fun_decls = fun_context.fun_decls;
+ fun_sigs;
+ fun_infos = fun_context.fun_infos;
+ }
in
let ctx =
{
@@ -181,7 +185,6 @@ let translate_function_to_pure (config : C.partial_config)
{
SymbolicToPure.filter_useless_back_calls =
mp_config.filter_useless_monadic_calls;
- use_state;
}
in
@@ -224,9 +227,13 @@ let translate_function_to_pure (config : C.partial_config)
RegularFunIdMap.find (A.Regular def_id, Some back_id) fun_sigs
in
(* We need to ignore the forward inputs, and the state input (if there is) *)
+ let fun_info =
+ SymbolicToPure.get_fun_effect_info fun_context.fun_infos
+ (A.Regular def_id) (Some back_id)
+ in
let _, backward_inputs =
Collections.List.split_at backward_sg.sg.inputs
- (num_forward_inputs + if use_state then 1 else 0)
+ (num_forward_inputs + if fun_info.input_state then 1 else 0)
in
(* As we forbid nested borrows, the additional inputs for the backward
* functions come from the borrows in the return value of the rust function:
@@ -317,22 +324,15 @@ let translate_module_to_pure (config : C.partial_config)
m.functions
in
let sigs = List.append assumed_sigs local_sigs in
- let sp_config =
- {
- SymbolicToPure.filter_useless_back_calls =
- mp_config.filter_useless_monadic_calls;
- use_state;
- }
- in
let fun_sigs =
- SymbolicToPure.translate_fun_signatures sp_config type_context.type_infos
- sigs
+ SymbolicToPure.translate_fun_signatures fun_context.fun_infos
+ type_context.type_infos sigs
in
(* Translate all the *transparent* functions *)
let pure_translations =
List.map
- (translate_function_to_pure config mp_config use_state trans_ctx fun_sigs
+ (translate_function_to_pure config mp_config trans_ctx fun_sigs
type_decls_map)
m.functions
in