summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/ExtractToFStar.ml2
-rw-r--r--src/Pure.ml15
-rw-r--r--src/PureUtils.ml19
-rw-r--r--src/SymbolicToPure.ml27
4 files changed, 28 insertions, 35 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 1c93c9da..c5d078f9 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -1406,7 +1406,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
*)
let num_fwd_inputs = def.signature.info.num_fwd_inputs in
let num_fwd_inputs =
- if def.signature.info.input_state then 1 + num_fwd_inputs
+ if def.signature.info.effect_info.input_state then 1 + num_fwd_inputs
else num_fwd_inputs
in
Collections.List.prefix num_fwd_inputs all_inputs
diff --git a/src/Pure.ml b/src/Pure.ml
index d8e1cafc..f5bed43d 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -505,17 +505,22 @@ and meta =
nude = true (* Don't inherit [VisitorsRuntime.iter] *);
}]
-type fun_sig_info = {
- num_fwd_inputs : int;
- (** The number of input types for forward computation *)
- num_back_inputs : int option;
- (** The number of additional inputs for the backward computation (if pertinent) *)
+type fun_effect_info = {
input_state : bool; (** `true` if the function takes a state as input *)
output_state : bool;
(** `true` if the function outputs a state (it then lives
in a state monad) *)
can_fail : bool; (** `true` if the return type is a `result` *)
}
+(** Information about the "effect" of a function *)
+
+type fun_sig_info = {
+ num_fwd_inputs : int;
+ (** The number of input types for forward computation *)
+ num_back_inputs : int option;
+ (** The number of additional inputs for the backward computation (if pertinent) *)
+ effect_info : fun_effect_info;
+}
(** Meta information about a function signature *)
type fun_sig = {
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index a1af3396..7d298f13 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -61,11 +61,6 @@ let mk_typed_pattern_from_constant_value (cv : constant_value) : typed_pattern =
let ty = compute_constant_value_ty cv in
{ value = PatConcrete cv; ty }
-(*let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression =
- let e = Value (v, mp) in
- let ty = v.ty in
- { e; ty }*)
-
let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression)
(next_e : texpression) : texpression =
let e = Let (monadic, lv, re, next_e) in
@@ -480,3 +475,17 @@ let opt_destruct_state_monad_result (ty : ty) : ty option =
let opt_unmeta_mplace (e : texpression) : mplace option * texpression =
match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e)
+
+let get_fun_effect_info (use_state : bool) (fun_id : A.fun_id)
+ (gid : T.RegionGroupId.id option) : fun_effect_info =
+ match fun_id with
+ | A.Regular _ ->
+ let input_state = use_state in
+ let output_state = input_state && gid = None in
+ { can_fail = true; input_state; output_state }
+ | A.Assumed aid ->
+ {
+ can_fail = Assumed.assumed_is_monadic aid;
+ input_state = false;
+ output_state = false;
+ }
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 83c045df..01cc37eb 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -474,29 +474,10 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
let abs_ids = list_ancestor_abstractions_ids ctx abs in
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
-type fun_effect_info = {
- can_fail : bool;
- input_state : bool;
- output_state : bool;
-}
-(** TODO: factorize with fun_sig_info?
- TODO: use an enumeration
- *)
-
(** Small utility. *)
let get_fun_effect_info (config : config) (fun_id : A.fun_id)
(gid : T.RegionGroupId.id option) : fun_effect_info =
- match fun_id with
- | A.Regular _ ->
- let input_state = config.use_state_monad in
- let output_state = input_state && gid = None in
- { can_fail = true; input_state; output_state }
- | A.Assumed aid ->
- {
- can_fail = Assumed.assumed_is_monadic aid;
- input_state = false;
- output_state = false;
- }
+ PureUtils.get_fun_effect_info config.use_state_monad fun_id gid
(** Translate a function signature.
@@ -618,9 +599,7 @@ let translate_fun_sig (config : config) (fun_id : A.fun_id)
num_fwd_inputs = List.length fwd_inputs;
num_back_inputs =
(if bid = None then None else Some (List.length back_inputs));
- input_state = effect_info.input_state;
- output_state = effect_info.output_state;
- can_fail = effect_info.can_fail;
+ effect_info;
}
in
let sg = { type_params; inputs; output; doutputs; info } in
@@ -1079,7 +1058,7 @@ and translate_panic (ctx : bs_ctx) : texpression =
(* If we use a state monad, we need to add a lambda for the state variable *)
(* Note that only forward functions return a state *)
let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in
- if ctx.sg.info.output_state then
+ if ctx.sg.info.effect_info.output_state then
(* Create the `Fail` value *)
let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
let ret_v = mk_result_fail_texpression ret_ty in