diff options
-rw-r--r-- | src/ExtractToFStar.ml | 2 | ||||
-rw-r--r-- | src/Pure.ml | 15 | ||||
-rw-r--r-- | src/PureUtils.ml | 19 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 27 |
4 files changed, 28 insertions, 35 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 1c93c9da..c5d078f9 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -1406,7 +1406,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) *) let num_fwd_inputs = def.signature.info.num_fwd_inputs in let num_fwd_inputs = - if def.signature.info.input_state then 1 + num_fwd_inputs + if def.signature.info.effect_info.input_state then 1 + num_fwd_inputs else num_fwd_inputs in Collections.List.prefix num_fwd_inputs all_inputs diff --git a/src/Pure.ml b/src/Pure.ml index d8e1cafc..f5bed43d 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -505,17 +505,22 @@ and meta = nude = true (* Don't inherit [VisitorsRuntime.iter] *); }] -type fun_sig_info = { - num_fwd_inputs : int; - (** The number of input types for forward computation *) - num_back_inputs : int option; - (** The number of additional inputs for the backward computation (if pertinent) *) +type fun_effect_info = { input_state : bool; (** `true` if the function takes a state as input *) output_state : bool; (** `true` if the function outputs a state (it then lives in a state monad) *) can_fail : bool; (** `true` if the return type is a `result` *) } +(** Information about the "effect" of a function *) + +type fun_sig_info = { + num_fwd_inputs : int; + (** The number of input types for forward computation *) + num_back_inputs : int option; + (** The number of additional inputs for the backward computation (if pertinent) *) + effect_info : fun_effect_info; +} (** Meta information about a function signature *) type fun_sig = { diff --git a/src/PureUtils.ml b/src/PureUtils.ml index a1af3396..7d298f13 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -61,11 +61,6 @@ let mk_typed_pattern_from_constant_value (cv : constant_value) : typed_pattern = let ty = compute_constant_value_ty cv in { value = PatConcrete cv; ty } -(*let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression = - let e = Value (v, mp) in - let ty = v.ty in - { e; ty }*) - let mk_let (monadic : bool) (lv : typed_pattern) (re : texpression) (next_e : texpression) : texpression = let e = Let (monadic, lv, re, next_e) in @@ -480,3 +475,17 @@ let opt_destruct_state_monad_result (ty : ty) : ty option = let opt_unmeta_mplace (e : texpression) : mplace option * texpression = match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e) + +let get_fun_effect_info (use_state : bool) (fun_id : A.fun_id) + (gid : T.RegionGroupId.id option) : fun_effect_info = + match fun_id with + | A.Regular _ -> + let input_state = use_state in + let output_state = input_state && gid = None in + { can_fail = true; input_state; output_state } + | A.Assumed aid -> + { + can_fail = Assumed.assumed_is_monadic aid; + input_state = false; + output_state = false; + } diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 83c045df..01cc37eb 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -474,29 +474,10 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : let abs_ids = list_ancestor_abstractions_ids ctx abs in List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids -type fun_effect_info = { - can_fail : bool; - input_state : bool; - output_state : bool; -} -(** TODO: factorize with fun_sig_info? - TODO: use an enumeration - *) - (** Small utility. *) let get_fun_effect_info (config : config) (fun_id : A.fun_id) (gid : T.RegionGroupId.id option) : fun_effect_info = - match fun_id with - | A.Regular _ -> - let input_state = config.use_state_monad in - let output_state = input_state && gid = None in - { can_fail = true; input_state; output_state } - | A.Assumed aid -> - { - can_fail = Assumed.assumed_is_monadic aid; - input_state = false; - output_state = false; - } + PureUtils.get_fun_effect_info config.use_state_monad fun_id gid (** Translate a function signature. @@ -618,9 +599,7 @@ let translate_fun_sig (config : config) (fun_id : A.fun_id) num_fwd_inputs = List.length fwd_inputs; num_back_inputs = (if bid = None then None else Some (List.length back_inputs)); - input_state = effect_info.input_state; - output_state = effect_info.output_state; - can_fail = effect_info.can_fail; + effect_info; } in let sg = { type_params; inputs; output; doutputs; info } in @@ -1079,7 +1058,7 @@ and translate_panic (ctx : bs_ctx) : texpression = (* If we use a state monad, we need to add a lambda for the state variable *) (* Note that only forward functions return a state *) let output_ty = mk_simpl_tuple_ty ctx.sg.doutputs in - if ctx.sg.info.output_state then + if ctx.sg.info.effect_info.output_state then (* Create the `Fail` value *) let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in let ret_v = mk_result_fail_texpression ret_ty in |