summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
authorSon HO2023-12-23 01:46:58 +0100
committerGitHub2023-12-23 01:46:58 +0100
commit15a7d7b7322a1cd0ebeb328fde214060e23fa8b4 (patch)
tree6cce7d76969870f5bc18c5a7cd585e8873a1c0dc /compiler/SymbolicToPure.ml
parentc3e0b90e422cbd902ee6d2b47073940c0017b7fb (diff)
parent63ccbd914d5d44aa30dee38a6fcc019310ab640b (diff)
Merge pull request #64 from AeneasVerif/son/merge_back
Merge the forward/backward functions
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml1902
1 files changed, 1294 insertions, 608 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 84f09280..3a50e495 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -15,7 +15,7 @@ module PP = PrintPure
(** The local logger *)
let log = Logging.symbolic_to_pure_log
-type type_context = {
+type type_ctx = {
llbc_type_decls : T.type_decl TypeDeclId.Map.t;
type_decls : type_decl TypeDeclId.Map.t;
(** We use this for type-checking (for sanity checks) when translating
@@ -43,19 +43,18 @@ type fun_sig_named_outputs = {
}
[@@deriving show]
-type fun_context = {
+type fun_ctx = {
llbc_fun_decls : A.fun_decl A.FunDeclId.Map.t;
- fun_sigs : fun_sig_named_outputs RegularFunIdNotLoopMap.t; (** *)
fun_infos : fun_info A.FunDeclId.Map.t;
regions_hierarchies : T.region_var_groups FunIdMap.t;
}
[@@deriving show]
-type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
+type global_ctx = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
[@@deriving show]
-type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
-type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
+type trait_decls_ctx = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
+type trait_impls_ctx = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
(** Whenever we translate a function call or an ended abstraction, we
store the related information (this is useful when translating ended
@@ -68,11 +67,21 @@ type call_info = {
Those inputs include the fuel and the state, if pertinent.
*)
- backwards : (V.abs * texpression list) T.RegionGroupId.Map.t;
- (** A map from region group id (i.e., backward function id) to
- pairs (abstraction, additional arguments received by the backward function)
-
- TODO: remove? it is also in the bs_ctx ("abstractions" field)
+ back_funs : texpression option RegionGroupId.Map.t option;
+ (** If we do not split between the forward/backward functions: the
+ variables we introduced for the backward functions.
+
+ Example:
+ {[
+ let x, back = Vec.index_mut n v in
+ ^^^^
+ here
+ ...
+ ]}
+
+ The expression might be [None] in case the backward function
+ has to be filtered (because it does nothing - the backward
+ functions for shared borrows for instance).
*)
}
[@@deriving show]
@@ -114,33 +123,61 @@ type loop_info = {
generics : generic_args;
forward_inputs : texpression list option;
(** The forward inputs are initialized at [None] *)
- forward_output_no_state_no_result : var option;
+ forward_output_no_state_no_result : texpression option;
(** The forward outputs are initialized at [None] *)
+ back_outputs : ty list RegionGroupId.Map.t;
+ (** The map from region group ids to the types of the values given back
+ by the corresponding loop abstractions.
+ *)
+ back_funs : texpression option RegionGroupId.Map.t option;
+ (** Same as {!call_info.back_funs}.
+ Initialized with [None], gets updated to [Some] only if we merge
+ the fwd/back functions.
+ *)
+ fwd_effect_info : fun_effect_info;
+ back_effect_infos : fun_effect_info RegionGroupId.Map.t;
}
[@@deriving show]
(** Body synthesis context *)
type bs_ctx = {
- type_context : type_context;
- fun_context : fun_context;
- global_context : global_context;
- trait_decls_ctx : trait_decls_context;
- trait_impls_ctx : trait_impls_context;
+ (* TODO: there are a lot of duplications with the various decls ctx *)
+ decls_ctx : C.decls_ctx;
+ type_ctx : type_ctx;
+ fun_ctx : fun_ctx;
+ global_ctx : global_ctx;
+ trait_decls_ctx : trait_decls_ctx;
+ trait_impls_ctx : trait_impls_ctx;
+ fun_dsigs : decomposed_fun_sig FunDeclId.Map.t;
fun_decl : A.fun_decl;
- bid : T.RegionGroupId.id option; (** TODO: rename *)
- sg : fun_sig;
- (** The function signature - useful in particular to translate [Panic] *)
- fwd_sg : fun_sig; (** The signature of the forward function *)
+ bid : RegionGroupId.id option;
+ (** TODO: rename
+
+ The id of the group region we are currently translating.
+ If we split the forward/backward functions, we set this id at the
+ very beginning of the translation.
+ If we don't split, we set it to `None`, then update it when we enter
+ an expression which is specific to a backward function.
+ *)
+ sg : decomposed_fun_sig;
+ (** Information about the function signature - useful in particular to
+ translate [Panic] *)
sv_to_var : var V.SymbolicValueId.Map.t;
(** Whenever we encounter a new symbolic value (introduced because of
a symbolic expansion or upon ending an abstraction, for instance)
we introduce a new variable (with a let-binding).
*)
- var_counter : VarId.generator;
+ var_counter : VarId.generator ref;
+ (** Using a ref to make sure all the variables identifiers are unique.
+ TODO: this is not very clean, and the code was initially written without
+ a reference (and it's shape hasn't changed). We should use DeBruijn indices.
+ *)
state_var : VarId.id;
(** The current state variable, in case the function is stateful *)
- back_state_var : VarId.id;
- (** The additional input state variable received by a stateful backward function.
+ back_state_vars : VarId.id RegionGroupId.Map.t;
+ (** The additional input state variable received by a stateful backward
+ function, **in case we are splitting the forward/backward functions**.
+
When generating stateful functions, we generate code of the following
form:
@@ -153,7 +190,9 @@ type bs_ctx = {
When translating a backward function, we need at some point to update
[state_var] with [back_state_var], to account for the fact that the
state may have been updated by the caller between the call to the
- forward function and the call to the backward function.
+ forward function and the call to the backward function. We also need
+ to make sure we use the same variable in all the branches (because
+ this variable is quantified at the definition level).
*)
fuel0 : VarId.id;
(** The original fuel taken as input by the function (if we use fuel) *)
@@ -163,28 +202,50 @@ type bs_ctx = {
(** The input parameters for the forward function corresponding to the
translated Rust inputs (no fuel, no state).
*)
- backward_inputs : var list T.RegionGroupId.Map.t;
+ backward_inputs_no_state : var list RegionGroupId.Map.t;
(** The additional input parameters for the backward functions coming
from the borrows consumed upon ending the lifetime (as a consequence
those don't include the backward state, if there is one).
- *)
- backward_outputs : var list T.RegionGroupId.Map.t;
- (** The variables that the backward functions will output, corresponding
- to the borrows they give back (don't include the backward state)
- *)
- loop_backward_outputs : var list T.RegionGroupId.Map.t option;
- (** Same as {!backward_outputs}, but for loops (if we entered a loop).
- [None] if we are not inside a loop, [Some] otherwise (and whatever
- the kind of function we are translating: it will be [Some] even
- though we are synthesizing a forward function).
+ If we split the forward/backward functions: we initialize this map
+ when initializing the bs_ctx, because those variables are quantified
+ at the definition level. Otherwise, we initialize it upon diving
+ into the expressions which are specific to the backward functions.
+ *)
+ backward_inputs_with_state : var list RegionGroupId.Map.t;
+ (** All the additional input parameters for the backward functions.
- TODO: move to {!loop_info}
+ Same remarks as for {!backward_inputs_no_state}.
+ *)
+ backward_outputs : var list option;
+ (** The variables that the backward functions will output, corresponding
+ to the borrows they give back (don't include the backward state).
+
+ The translation is done as follows:
+ - when we detect the ended input abstraction which corresponds
+ to the backward function of the LLBC function we are translating,
+ and which consumed the values [consumed_i] (that we need to give
+ back to the caller), we introduce:
+ {[
+ let v_i = consumed_i in
+ ...
+ ]}
+ where the [v_i] are fresh, and are stored in the [backward_output].
+ - Then, upon reaching the [Return] node, we introduce:
+ {[
+ return (v_i)
+ ]}
+
+ The option is [None] before we detect the ended input abstraction,
+ and [Some] afterwards.
*)
calls : call_info V.FunCallId.Map.t;
(** The function calls we encountered so far *)
abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
- (** The ended abstractions we encountered so far, with their additional input arguments *)
+ (** The ended abstractions we encountered so far, with their additional
+ input arguments. We store it here and not in {!call_info} because
+ we need a map from abstraction id to abstraction (and not
+ from call id + region group id to abstraction). *)
loop_ids_map : LoopId.id V.LoopId.Map.t; (** Ids to use for the loops *)
loops : loop_info LoopId.Map.t;
(** The loops we encountered so far.
@@ -209,9 +270,9 @@ type bs_ctx = {
(* TODO: move *)
let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env =
- let type_decls = ctx.type_context.llbc_type_decls in
- let fun_decls = ctx.fun_context.llbc_fun_decls in
- let global_decls = ctx.global_context.llbc_global_decls in
+ let type_decls = ctx.type_ctx.llbc_type_decls in
+ let fun_decls = ctx.fun_ctx.llbc_fun_decls in
+ let global_decls = ctx.global_ctx.llbc_global_decls in
let trait_decls = ctx.trait_decls_ctx in
let trait_impls = ctx.trait_impls_ctx in
let { regions; types; const_generics; trait_clauses } : T.generic_params =
@@ -233,9 +294,9 @@ let bs_ctx_to_fmt_env (ctx : bs_ctx) : Print.fmt_env =
}
let bs_ctx_to_pure_fmt_env (ctx : bs_ctx) : PrintPure.fmt_env =
- let type_decls = ctx.type_context.llbc_type_decls in
- let fun_decls = ctx.fun_context.llbc_fun_decls in
- let global_decls = ctx.global_context.llbc_global_decls in
+ let type_decls = ctx.type_ctx.llbc_type_decls in
+ let fun_decls = ctx.fun_ctx.llbc_fun_decls in
+ let global_decls = ctx.global_ctx.llbc_global_decls in
let trait_decls = ctx.trait_decls_ctx in
let trait_impls = ctx.trait_impls_ctx in
let generics = ctx.sg.generics in
@@ -300,6 +361,17 @@ let typed_pattern_to_string (ctx : bs_ctx) (p : Pure.typed_pattern) : string =
let env = bs_ctx_to_pure_fmt_env ctx in
PrintPure.typed_pattern_to_string env p
+let ctx_get_effect_info_for_bid (ctx : bs_ctx) (bid : RegionGroupId.id option) :
+ fun_effect_info =
+ match bid with
+ | None -> ctx.sg.fwd_info.effect_info
+ | Some bid ->
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ back_sg.effect_info
+
+let ctx_get_effect_info (ctx : bs_ctx) : fun_effect_info =
+ ctx_get_effect_info_for_bid ctx ctx.bid
+
(* TODO: move *)
let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
let env = bs_ctx_to_fmt_env ctx in
@@ -308,33 +380,13 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
let indent_incr = " " in
Print.Values.abs_to_string env verbose indent indent_incr abs
-let get_instantiated_fun_sig (fun_id : A.fun_id)
- (back_id : T.RegionGroupId.id option) (generics : generic_args)
- (ctx : bs_ctx) : inst_fun_sig =
- (* Lookup the non-instantiated function signature *)
- let sg =
- (RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg
- in
- (* Create the substitution *)
- (* There shouldn't be any reference to Self *)
- let tr_self = UnknownTrait __FUNCTION__ in
- let subst = make_subst_from_generics sg.generics generics tr_self in
- (* Apply *)
- fun_sig_substitute subst sg
-
let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) :
T.type_decl =
- TypeDeclId.Map.find id ctx.type_context.llbc_type_decls
+ TypeDeclId.Map.find id ctx.type_ctx.llbc_type_decls
let bs_ctx_lookup_llbc_fun_decl (id : A.FunDeclId.id) (ctx : bs_ctx) :
A.fun_decl =
- A.FunDeclId.Map.find id ctx.fun_context.llbc_fun_decls
-
-(* TODO: move *)
-let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id)
- (back_id : T.RegionGroupId.id option) (ctx : bs_ctx) : fun_sig =
- let id = (E.FRegular def_id, back_id) in
- (RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg
+ A.FunDeclId.Map.find id ctx.fun_ctx.llbc_fun_decls
(* Some generic translation functions (we need to translate different "flavours"
of types: forward types, backward types, etc.) *)
@@ -601,13 +653,13 @@ and translate_fwd_trait_instance_id (type_infos : type_infos)
(** Simply calls [translate_fwd_ty] *)
let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : T.ty) : ty =
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
translate_fwd_ty type_infos ty
(** Simply calls [translate_fwd_generic_args] *)
let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : T.generic_args) :
generic_args =
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
translate_fwd_generic_args type_infos generics
(** Translate a type, when some regions may have ended.
@@ -682,17 +734,21 @@ let rec translate_back_ty (type_infos : type_infos)
None
| TTraitType (trait_ref, generics, type_name) ->
assert (generics.regions = []);
- (* Translate the trait ref and the generics as "forward" generics -
- we do not want to filter any type *)
- let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
- let generics = translate_fwd_generic_args type_infos generics in
- Some (TTraitType (trait_ref, generics, type_name))
+ assert (
+ AssociatedTypes.trait_instance_id_is_local_clause trait_ref.trait_id);
+ if inside_mut then
+ (* Translate the trait ref and the generics as "forward" generics -
+ we do not want to filter any type *)
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (TTraitType (trait_ref, generics, type_name))
+ else None
| TArrow _ -> raise (Failure "TODO")
(** Simply calls [translate_back_ty] *)
let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
(inside_mut : bool) (ty : T.ty) : ty option =
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
translate_back_ty type_infos keep_region inside_mut ty
let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx =
@@ -705,8 +761,8 @@ let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx =
in
let env = VarId.Map.empty in
{
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
+ PureTypeCheck.type_decls = ctx.type_ctx.type_decls;
+ global_decls = ctx.global_ctx.llbc_global_decls;
env;
const_generics;
}
@@ -726,31 +782,36 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
match id with
| FunId fun_id -> FunId fun_id
| TraitMethod (trait_ref, method_name, fun_decl_id) ->
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
TraitMethod (trait_ref, method_name, fun_decl_id)
let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ (args : texpression list)
+ (back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) :
+ bs_ctx =
let calls = ctx.calls in
assert (not (V.FunCallId.Map.mem call_id calls));
- let info =
- { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
- in
+ let info = { forward; forward_inputs = args; back_funs } in
let calls = V.FunCallId.Map.add call_id info calls in
{ ctx with calls }
-(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
- (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
- : bs_ctx * fun_or_op_id =
+(** [inherit_args]: the list of inputs inherited from the forward function and
+ the ancestors backward functions, if pertinent.
+ [back_args]: the *additional* list of inputs received by the backward function,
+ including the state.
+
+ Returns the updated context and the expression corresponding to the function
+ that we need to call. This function may be [None] if it has to be ignored
+ (because it does nothing).
+ *)
+let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
+ (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
+ (inherited_args : texpression list) (back_args : texpression list)
+ (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
+ bs_ctx * texpression option =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
- assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
- let backwards =
- T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
- in
- let info = { info with backwards } in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
(* Insert the abstraction in the abstractions map *)
let abstractions = ctx.abstractions in
@@ -758,16 +819,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
let abstractions =
V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
in
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) ->
- let fid = translate_fun_id_or_trait_method_ref ctx fid in
- Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ (* Compute the expression corresponding to the function *)
+ let func =
+ if !Config.return_back_funs then
+ (* Lookup the variable introduced for the backward function *)
+ RegionGroupId.Map.find back_id (Option.get info.back_funs)
+ else
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ let args = List.append inherited_args back_args in
+ let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty output_ty else output_ty
+ in
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = { id = FunOrOp fun_id; generics } in
+ Some { e = Qualif func; ty = func_ty }
in
(* Update the context and return *)
- ({ ctx with calls; abstractions }, fun_id)
+ ({ ctx with calls; abstractions }, func)
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
@@ -821,7 +897,7 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
if function_uses_fuel info then [ mk_fuel_texpression ctx.fuel ] else []
(** Small utility. *)
-let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
+let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
(fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
@@ -848,33 +924,53 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
is_rec = false;
}
-(** Translate a function signature.
+(** TODO: not very clean. *)
+let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref)
+ (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) :
+ fun_effect_info =
+ match lid with
+ | None -> (
+ match fun_id with
+ | TraitMethod (_, _, fid) | FunId (FRegular fid) ->
+ let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in
+ let info =
+ match gid with
+ | None -> dsg.fwd_info.effect_info
+ | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info
+ in
+ { info with is_rec = info.is_rec || Option.is_some lid }
+ | FunId (FAssumed _) ->
+ compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid)
+ | Some lid -> (
+ (* This is necessarily for the current function *)
+ match fun_id with
+ | FunId (FRegular fid) -> (
+ assert (fid = ctx.fun_decl.def_id);
+ (* Lookup the loop *)
+ let lid = V.LoopId.Map.find lid ctx.loop_ids_map in
+ let loop_info = LoopId.Map.find lid ctx.loops in
+ match gid with
+ | None -> loop_info.fwd_effect_info
+ | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos)
+ | _ -> raise (Failure "Unreachable"))
+
+(** Translate a function signature to a decomposed function signature.
Note that the function also takes a list of names for the inputs, and
computes, for every output for the backward functions, a corresponding
name (outputs for backward functions come from borrows in the inputs
of the forward function) which we use as hints to generate pretty names
in the extracted code.
+
+ We use [bid] ("backward function id") only if we split the forward
+ and the backward functions.
*)
-let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
- (sg : A.fun_sig) (input_names : string option list)
- (bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
+let translate_fun_sig_with_regions_hierarchy_to_decomposed
+ (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref)
+ (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig)
+ (input_names : string option list) : decomposed_fun_sig =
let fun_infos = decls_ctx.fun_ctx.fun_infos in
let type_infos = decls_ctx.type_ctx.type_infos in
- (* Retrieve the list of parent backward functions *)
- let regions_hierarchy =
- FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies
- in
- let gid, parents =
- match bid with
- | None -> (None, T.RegionGroupId.Set.empty)
- | Some bid ->
- let parents = list_ancestor_region_groups regions_hierarchy bid in
- (Some bid, parents)
- in
- (* Is the function stateful, and can it fail? *)
- let lid = None in
- let effect_info = get_fun_effect_info fun_infos (FunId fun_id) lid bid in
(* We need an evaluation context to normalize the types (to normalize the
associated types, etc. - for instance it may happen that the types
refer to the types associated to a trait ref, but where the trait ref
@@ -885,7 +981,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
List.map (fun (g : T.region_var_group) -> g.id) regions_hierarchy
in
let ctx =
- InterpreterUtils.initialize_eval_context decls_ctx region_groups
+ InterpreterUtils.initialize_eval_ctx decls_ctx region_groups
sg.generics.types sg.generics.const_generics
in
(* Compute the normalization map for the *sty* types and add it to the context *)
@@ -902,17 +998,28 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
{ sg with A.inputs; output }
in
- (* List the inputs for:
- * - the fuel
- * - the forward function
- * - the parent backward functions, in proper order
- * - the current backward function (if it is a backward function)
- *)
- let fuel = mk_fuel_input_ty_as_list effect_info in
- let fwd_inputs = List.map (translate_fwd_ty type_infos) sg.inputs in
- (* For the backward functions: for now we don't supported nested borrows,
- * so just check that there aren't parent regions *)
- assert (T.RegionGroupId.Set.is_empty parents);
+ (* Is the forward function stateful, and can it fail? *)
+ let fwd_effect_info =
+ compute_raw_fun_effect_info fun_infos fun_id None None
+ in
+ (* Compute the forward inputs *)
+ let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in
+ let fwd_inputs_no_fuel_no_state =
+ List.map (translate_fwd_ty type_infos) sg.inputs
+ in
+ (* State input for the forward function *)
+ let fwd_state_ty =
+ (* For the forward state, we check if the *whole group* is stateful.
+ See {!effect_info}. *)
+ if fwd_effect_info.stateful_group then [ mk_state_ty ] else []
+ in
+ let fwd_inputs =
+ List.concat [ fwd_fuel; fwd_inputs_no_fuel_no_state; fwd_state_ty ]
+ in
+ (* Compute the backward output, without the effect information *)
+ let fwd_output = translate_fwd_ty type_infos sg.output in
+
+ (* Compute the type information for the backward function *)
(* Small helper to translate types for backward functions *)
let translate_back_ty_for_gid (gid : T.RegionGroupId.id) (ty : T.ty) :
ty option =
@@ -939,163 +1046,336 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
let inside_mut = false in
translate_back_ty type_infos keep_region inside_mut ty
in
- (* Compute the additinal inputs for the current function, if it is a backward
- * function *)
- let back_inputs =
- match gid with
- | None -> []
- | Some gid ->
- (* For now, we don't allow nested borrows, so the additional inputs to the
- backward function can only come from borrows that were returned like
- in (for the backward function we introduce for 'a):
- {[
- fn f<'a>(...) -> &'a mut u32;
- ]}
- Upon ending the abstraction for 'a, we need to get back the borrow
- the function returned.
- *)
- List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
- in
- (* If the function is stateful, the inputs are:
- - forward: [fwd_ty0, ..., fwd_tyn, state]
- - backward:
- - if {!Config.backward_no_state_update}: [fwd_ty0, ..., fwd_tyn, state, back_ty, state]
- - otherwise: [fwd_ty0, ..., fwd_tyn, state, back_ty]
-
- The backward takes the same state as input as the forward function,
- together with the state at the point where it gets called, if it is
- stateful.
-
- See the comments for {!Config.backward_no_state_update}
- *)
- let fwd_state_ty =
- (* For the forward state, we check if the *whole group* is stateful.
- See {!effect_info}. *)
- if effect_info.stateful_group then [ mk_state_ty ] else []
+ let translate_back_inputs_for_gid (gid : T.RegionGroupId.id) : ty list =
+ (* For now we don't supported nested borrows, so we check that there
+ aren't parent regions *)
+ let parents = list_ancestor_region_groups regions_hierarchy gid in
+ assert (T.RegionGroupId.Set.is_empty parents);
+ (* For now, we don't allow nested borrows, so the additional inputs to the
+ backward function can only come from borrows that were returned like
+ in (for the backward function we introduce for 'a):
+ {[
+ fn f<'a>(...) -> &'a mut u32;
+ ]}
+ Upon ending the abstraction for 'a, we need to get back the borrow
+ the function returned.
+ *)
+ let inputs =
+ List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
+ in
+ log#ldebug
+ (lazy
+ (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in
+ let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in
+ let output = Print.Types.ty_to_string ctx sg.output in
+ let inputs =
+ Print.list_to_string (PrintPure.ty_to_string pctx false) inputs
+ in
+ "translate_back_inputs_for_gid:" ^ "\n- gid: "
+ ^ RegionGroupId.to_string gid
+ ^ "\n- output: " ^ output ^ "\n- back inputs: " ^ inputs ^ "\n"));
+ inputs
+ in
+ let compute_back_outputs_for_gid (gid : RegionGroupId.id) :
+ string option list * ty list =
+ (* The outputs are the borrows inside the regions of the abstractions
+ and which are present in the input values. For instance, see:
+ {[
+ fn f<'a>(x : &'a mut u32) -> ...;
+ ]}
+ Upon ending the abstraction for 'a, we give back the borrow which
+ was consumed through the [x] parameter.
+ *)
+ let outputs =
+ List.map
+ (fun (name, input_ty) -> (name, translate_back_ty_for_gid gid input_ty))
+ (List.combine input_names sg.inputs)
+ in
+ (* Filter *)
+ let outputs =
+ List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs
+ in
+ let outputs =
+ List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs
+ in
+ let names, outputs = List.split outputs in
+ log#ldebug
+ (lazy
+ (let ctx = Print.Contexts.decls_ctx_to_fmt_env decls_ctx in
+ let pctx = PrintPure.decls_ctx_to_fmt_env decls_ctx in
+ let inputs =
+ Print.list_to_string (Print.Types.ty_to_string ctx) sg.inputs
+ in
+ let outputs =
+ Print.list_to_string (PrintPure.ty_to_string pctx false) outputs
+ in
+ "compute_back_outputs_for_gid:" ^ "\n- gid: "
+ ^ RegionGroupId.to_string gid
+ ^ "\n- inputs: " ^ inputs ^ "\n- back outputs: " ^ outputs ^ "\n"));
+ (names, outputs)
+ in
+ let compute_back_info_for_group (rg : T.region_var_group) :
+ RegionGroupId.id * back_sg_info =
+ let gid = rg.id in
+ let back_effect_info =
+ compute_raw_fun_effect_info fun_infos fun_id None (Some gid)
+ in
+ let inputs_no_state = translate_back_inputs_for_gid gid in
+ let inputs_no_state =
+ List.map (fun ty -> (Some "ret", ty)) inputs_no_state
+ in
+ (* In case we merge the forward/backward functions:
+ we consider the backward function as stateful and potentially failing
+ **only if it has inputs** (for the "potentially failing": if it has
+ not inputs, we directly evaluate it in the body of the forward function).
+ *)
+ let back_effect_info =
+ if !Config.return_back_funs then
+ let b = inputs_no_state <> [] in
+ {
+ back_effect_info with
+ stateful = back_effect_info.stateful && b;
+ can_fail = back_effect_info.can_fail && b;
+ }
+ else back_effect_info
+ in
+ let state =
+ if back_effect_info.stateful then [ (None, mk_state_ty) ] else []
+ in
+ let inputs = inputs_no_state @ state in
+ let output_names, outputs = compute_back_outputs_for_gid gid in
+ let filter =
+ !Config.simplify_merged_fwd_backs
+ && !Config.return_back_funs && inputs = [] && outputs = []
+ in
+ let info =
+ {
+ inputs;
+ inputs_no_state;
+ outputs;
+ output_names;
+ effect_info = back_effect_info;
+ filter;
+ }
+ in
+ (gid, info)
in
- let back_state_ty =
- (* For the backward state, we check if the function is a backward function,
- and it is stateful *)
- if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else []
+ let back_sg =
+ RegionGroupId.Map.of_list
+ (List.map compute_back_info_for_group regions_hierarchy)
in
- (* Concatenate the inputs, in the following order:
- * - forward inputs
- * - forward state input
- * - backward inputs
- * - backward state input
- *)
- let inputs =
- List.concat [ fuel; fwd_inputs; fwd_state_ty; back_inputs; back_state_ty ]
- in
- (* Outputs *)
- let output_names, doutputs =
- match gid with
- | None ->
- (* This is a forward function: there is one (unnamed) output *)
- ([ None ], [ translate_fwd_ty type_infos sg.output ])
- | Some gid ->
- (* This is a backward function: there might be several outputs.
- The outputs are the borrows inside the regions of the abstractions
- and which are present in the input values. For instance, see:
- {[
- fn f<'a>(x : &'a mut u32) -> ...;
- ]}
- Upon ending the abstraction for 'a, we give back the borrow which
- was consumed through the [x] parameter.
- *)
- let outputs =
- List.map
- (fun (name, input_ty) ->
- (name, translate_back_ty_for_gid gid input_ty))
- (List.combine input_names sg.inputs)
- in
- (* Filter *)
- let outputs =
- List.filter (fun (_, opt_ty) -> Option.is_some opt_ty) outputs
- in
- let outputs =
- List.map (fun (name, opt_ty) -> (name, Option.get opt_ty)) outputs
- in
- List.split outputs
- in
- (* Create the return type *)
- let output =
- (* Group the outputs together *)
- let output = mk_simpl_tuple_ty doutputs in
- (* Add the output state *)
- let output =
- if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
+ (* The additional information about the forward function *)
+ let fwd_info =
+ (* *)
+ let has_fuel = fwd_fuel <> [] in
+ let num_inputs_no_fuel_no_state = List.length fwd_inputs_no_fuel_no_state in
+ let num_inputs_with_fuel_no_state =
+ (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
+ List.length fwd_fuel + num_inputs_no_fuel_no_state
+ in
+ let fwd_info : inputs_info =
+ {
+ has_fuel;
+ num_inputs_no_fuel_no_state;
+ num_inputs_with_fuel_no_state;
+ num_inputs_with_fuel_with_state =
+ (* We use the fact that [fwd_state_ty] has length 1 if there is a state,
+ and 0 otherwise *)
+ num_inputs_with_fuel_no_state + List.length fwd_state_ty;
+ }
+ in
+ let ignore_output =
+ if !Config.return_back_funs && !Config.simplify_merged_fwd_backs then
+ ty_is_unit fwd_output
+ && List.exists
+ (fun (info : back_sg_info) -> not info.filter)
+ (RegionGroupId.Map.values back_sg)
+ else false
in
- (* Wrap in a result type *)
- if effect_info.can_fail then mk_result_ty output else output
+ let info = { fwd_info; effect_info = fwd_effect_info; ignore_output } in
+ assert (fun_sig_info_is_wf info);
+ info
in
+
(* Generic parameters *)
let generics = translate_generic_params sg.generics in
+
(* Return *)
- let has_fuel = fuel <> [] in
- let num_fwd_inputs_no_state = List.length fwd_inputs in
- let num_fwd_inputs_with_fuel_no_state =
- (* We use the fact that [fuel] has length 1 if we use some fuel, 0 otherwise *)
- List.length fuel + num_fwd_inputs_no_state
- in
- let num_back_inputs_no_state =
- if bid = None then None else Some (List.length back_inputs)
+ let preds = translate_predicates sg.preds in
+ {
+ generics;
+ llbc_generics = sg.generics;
+ preds;
+ fwd_inputs;
+ fwd_output;
+ back_sg;
+ fwd_info;
+ }
+
+let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx)
+ (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list)
+ : decomposed_fun_sig =
+ (* Retrieve the list of parent backward functions *)
+ let regions_hierarchy =
+ FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies
in
- let info =
- {
- has_fuel;
- num_fwd_inputs_with_fuel_no_state;
- num_fwd_inputs_with_fuel_with_state =
- (* We use the fact that [fwd_state_ty] has length 1 if there is a state,
- and 0 otherwise *)
- num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty;
- num_back_inputs_no_state;
- num_back_inputs_with_state =
- (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
- Option.map
- (fun n -> n + List.length back_state_ty)
- num_back_inputs_no_state;
- effect_info;
- }
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ (FunId (FRegular fun_id)) regions_hierarchy sg input_names
+
+let translate_fun_sig_from_decl_to_decomposed (decls_ctx : C.decls_ctx)
+ (fdef : LlbcAst.fun_decl) : decomposed_fun_sig =
+ let input_names =
+ match fdef.body with
+ | None -> List.map (fun _ -> None) fdef.signature.inputs
+ | Some body ->
+ List.map
+ (fun (v : LlbcAst.var) -> v.name)
+ (LlbcAstUtils.fun_body_get_input_vars body)
in
- let preds = translate_predicates sg.preds in
let sg =
- {
- generics;
- llbc_generics = sg.generics;
- preds;
- inputs;
- output;
- doutputs;
- info;
- }
+ translate_fun_sig_to_decomposed decls_ctx fdef.def_id fdef.signature
+ input_names
+ in
+ log#ldebug
+ (lazy
+ ("translate_fun_sig_from_decl_to_decomposed:" ^ "\n- name: "
+ ^ T.show_name fdef.name ^ "\n- sg:\n" ^ show_decomposed_fun_sig sg ^ "\n"));
+ sg
+
+let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
+ =
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty
in
- { sg; output_names }
+ if effect_info.can_fail then mk_result_ty output else output
-let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
+let mk_back_output_ty_from_effect_info (effect_info : fun_effect_info)
+ (inputs : ty list) (ty : ty) : ty =
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; ty ] else ty
+ in
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output else output
+
+(** Compute the arrow types for all the backward functions.
+
+ If a backward function has no inputs/outputs we filter it.
+ *)
+let compute_back_tys_with_info (dsg : Pure.decomposed_fun_sig)
+ (subst : (generic_args * trait_instance_id) option) :
+ (back_sg_info * ty) option list =
+ List.map
+ (fun (back_sg : back_sg_info) ->
+ let effect_info = back_sg.effect_info in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = back_sg.outputs in
+ (* Filter if necessary *)
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then
+ None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ (* Substitute - TODO: normalize *)
+ let ty =
+ match subst with
+ | None -> ty
+ | Some (generics, tr_self) ->
+ let subst =
+ make_subst_from_generics dsg.generics generics tr_self
+ in
+ ty_substitute subst ty
+ in
+ Some (back_sg, ty))
+ (RegionGroupId.Map.values dsg.back_sg)
+
+let compute_back_tys (dsg : Pure.decomposed_fun_sig)
+ (subst : (generic_args * trait_instance_id) option) : ty option list =
+ List.map (Option.map snd) (compute_back_tys_with_info dsg subst)
+
+(** In case we merge the fwd/back functions: compute the output type of
+ a function, from a decomposed signature. *)
+let compute_output_ty_from_decomposed (dsg : Pure.decomposed_fun_sig) : ty =
+ assert !Config.return_back_funs;
+ (* Compute the arrow types for all the backward functions *)
+ let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
+ (* Group the forward output and the types of the backward functions *)
+ let effect_info = dsg.fwd_info.effect_info in
+ let output =
+ (* We might need to ignore the output of the forward function
+ (if it is unit for instance) *)
+ let tys =
+ if dsg.fwd_info.ignore_output then back_tys
+ else dsg.fwd_output :: back_tys
+ in
+ mk_simpl_tuple_ty tys
+ in
+ mk_output_ty_from_effect_info effect_info output
+
+let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
+ (gid : RegionGroupId.id option) : fun_sig =
+ let generics = dsg.generics in
+ let llbc_generics = dsg.llbc_generics in
+ let preds = dsg.preds in
+ (* Compute the effects info *)
+ let fwd_info = dsg.fwd_info in
+ let back_effect_info =
+ RegionGroupId.Map.of_list
+ (List.map
+ (fun ((gid, info) : RegionGroupId.id * back_sg_info) ->
+ (gid, info.effect_info))
+ (RegionGroupId.Map.bindings dsg.back_sg))
+ in
+ let mk_output_ty = mk_output_ty_from_effect_info in
+ let inputs, output =
+ (* Two cases depending on whether we split the forward/backward functions or not *)
+ if !Config.return_back_funs then (
+ assert (gid = None);
+ let output = compute_output_ty_from_decomposed dsg in
+ let inputs = dsg.fwd_inputs in
+ (inputs, output))
+ else
+ match gid with
+ | None ->
+ let effect_info = dsg.fwd_info.effect_info in
+ let output = mk_output_ty effect_info dsg.fwd_output in
+ (dsg.fwd_inputs, output)
+ | Some gid ->
+ let back_sg = RegionGroupId.Map.find gid dsg.back_sg in
+ let effect_info = back_sg.effect_info in
+ let inputs = dsg.fwd_inputs @ List.map snd back_sg.inputs in
+ let output = mk_simpl_tuple_ty back_sg.outputs in
+ let output = mk_output_ty effect_info output in
+ (inputs, output)
+ in
+ { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info }
+
+let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * var * typed_pattern =
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let state_var =
{ id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
in
let state_pat = mk_typed_pattern_from_var state_var None in
(* Update the context *)
- let ctx = { ctx with var_counter; state_var = id } in
+ ctx.var_counter := var_counter;
+ let ctx = { ctx with state_var = id } in
(* Return *)
- (ctx, state_pat)
+ (ctx, state_var, state_pat)
(** WARNING: do not call this function directly.
Call [fresh_named_var_for_symbolic_value] instead. *)
let fresh_var_llbc_ty (basename : string option) (ty : T.ty) (ctx : bs_ctx) :
bs_ctx * var =
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let ty = ctx_translate_fwd_ty ctx ty in
let var = { id; basename; ty } in
(* Update the context *)
- let ctx = { ctx with var_counter } in
+ ctx.var_counter := var_counter;
(* Return *)
(ctx, var)
@@ -1129,10 +1409,10 @@ let fresh_named_vars_for_symbolic_values
let fresh_var (basename : string option) (ty : ty) (ctx : bs_ctx) : bs_ctx * var
=
(* Generate the fresh variable *)
- let id, var_counter = VarId.fresh ctx.var_counter in
+ let id, var_counter = VarId.fresh !(ctx.var_counter) in
let var = { id; basename; ty } in
(* Update the context *)
- let ctx = { ctx with var_counter } in
+ ctx.var_counter := var_counter;
(* Return *)
(ctx, var)
@@ -1140,6 +1420,58 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) :
bs_ctx * var list =
List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars
+let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) :
+ bs_ctx * var option list =
+ List.fold_left_map
+ (fun ctx var ->
+ match var with
+ | None -> (ctx, None)
+ | Some (name, ty) ->
+ let ctx, var = fresh_var name ty ctx in
+ (ctx, Some var))
+ ctx vars
+
+(* Introduce variables for the backward functions *)
+let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list =
+ (* We lookup the LLBC definition in an attempt to derive pretty names
+ for the backward functions. *)
+ let back_var_names =
+ let def_id = ctx.fun_decl.def_id in
+ let sg = ctx.fun_decl.signature in
+ let regions_hierarchy =
+ LlbcAstUtils.FunIdMap.find (FRegular def_id)
+ ctx.fun_ctx.regions_hierarchies
+ in
+ List.map
+ (fun (gid, _) ->
+ let rg = RegionGroupId.nth regions_hierarchy gid in
+ let region_names =
+ List.map
+ (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
+ rg.regions
+ in
+ let name =
+ match region_names with
+ | [] -> "back"
+ | [ Some r ] -> "back" ^ r
+ | _ ->
+ (* Concatenate all the region names *)
+ "back"
+ ^ String.concat "" (List.filter_map (fun x -> x) region_names)
+ in
+ Some name)
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
+ let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in
+ let back_vars =
+ List.map
+ (fun (name, ty) ->
+ match ty with None -> None | Some ty -> Some (name, ty))
+ back_vars
+ in
+ fresh_opt_vars back_vars ctx
+
+(** IMPORTANT: do not use this one directly, but rather {!symbolic_value_to_texpression} *)
let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with
| Some v -> v
@@ -1158,12 +1490,22 @@ let rec unbox_typed_value (v : V.typed_value) : V.typed_value =
| _ -> raise (Failure "Unreachable"))
| _ -> v
-(** Translate a symbolic value *)
+(** Translate a symbolic value.
+
+ Because we do not necessarily introduce variables for the symbolic values
+ of (translated) type unit, it is important that we do not lookup variables
+ in case the symbolic value has type unit.
+ *)
let symbolic_value_to_texpression (ctx : bs_ctx) (sv : V.symbolic_value) :
texpression =
(* Translate the type *)
- let var = lookup_var_for_symbolic_value sv ctx in
- mk_texpression_from_var var
+ let ty = ctx_translate_fwd_ty ctx sv.sv_ty in
+ (* If the type is unit, directly return unit *)
+ if ty_is_unit ty then mk_unit_rvalue
+ else
+ (* Otherwise lookup the variable *)
+ let var = lookup_var_for_symbolic_value sv ctx in
+ mk_texpression_from_var var
(** Translate a typed value.
@@ -1342,13 +1684,11 @@ and aproj_to_consumed (ctx : bs_ctx) (aproj : V.aproj) : texpression option =
match aproj with
| V.AEndedProjLoans (msv, []) ->
(* The symbolic value was left unchanged *)
- let var = lookup_var_for_symbolic_value msv ctx in
- Some (mk_texpression_from_var var)
+ Some (symbolic_value_to_texpression ctx msv)
| V.AEndedProjLoans (_, [ (mnv, child_aproj) ]) ->
assert (child_aproj = AIgnoredProjBorrows);
(* The symbolic value was updated *)
- let var = lookup_var_for_symbolic_value mnv ctx in
- Some (mk_texpression_from_var var)
+ Some (symbolic_value_to_texpression ctx mnv)
| V.AEndedProjLoans (_, _) ->
(* The symbolic value was updated, and the given back values come from sevearl
* abstractions *)
@@ -1543,7 +1883,9 @@ let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list)
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
- | S.Return (ectx, opt_v) -> translate_return ectx opt_v ctx
+ | S.Return (ectx, opt_v) ->
+ (* Remark: we can't get there if we are inside a loop *)
+ translate_return ectx opt_v ctx
| ReturnWithLoop (loop_id, is_continue) ->
translate_return_with_loop loop_id is_continue ctx
| Panic -> translate_panic ctx
@@ -1565,32 +1907,56 @@ and translate_panic (ctx : bs_ctx) : texpression =
* but it won't be true anymore once we translate individual blocks *)
(* If we use a state monad, we need to add a lambda for the state variable *)
(* Note that only forward functions return a state *)
- let output_ty =
- if ctx.inside_loop && Option.is_some ctx.bid then
- (* We are synthesizing the backward function of a loop body *)
- let bid = Option.get ctx.bid in
- let back_vars =
- T.RegionGroupId.Map.find bid (Option.get ctx.loop_backward_outputs)
- in
- let tys = List.map (fun (v : var) -> v.ty) back_vars in
- mk_simpl_tuple_ty tys
- else
- (* Regular function, or forward function (the forward translation for
- a loop has the same return type as the parent function)
- *)
- mk_simpl_tuple_ty ctx.sg.doutputs
- in
+ let effect_info = ctx_get_effect_info ctx in
(* TODO: we should use a [Fail] function *)
- if ctx.sg.info.effect_info.stateful then
- (* Create the [Fail] value *)
- let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
- let ret_v =
- mk_result_fail_texpression_with_error_id error_failure_id ret_ty
- in
- ret_v
- else mk_result_fail_texpression_with_error_id error_failure_id output_ty
+ let mk_output output_ty =
+ if effect_info.stateful then
+ (* Create the [Fail] value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; output_ty ] in
+ let ret_v =
+ mk_result_fail_texpression_with_error_id error_failure_id ret_ty
+ in
+ ret_v
+ else mk_result_fail_texpression_with_error_id error_failure_id output_ty
+ in
+ if ctx.inside_loop && Option.is_some ctx.bid then
+ (* We are synthesizing the backward function of a loop body *)
+ let bid = Option.get ctx.bid in
+ let loop_id = Option.get ctx.loop_id in
+ let loop = LoopId.Map.find loop_id ctx.loops in
+ let tys = RegionGroupId.Map.find bid loop.back_outputs in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
+ else
+ (* Regular function, or forward function (the forward translation for
+ a loop has the same return type as the parent function)
+ *)
+ match ctx.bid with
+ | None ->
+ if !Config.return_back_funs then
+ let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
+ let tys =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty tys in
+ mk_output output
+ else mk_output ctx.sg.fwd_output
+ | Some bid ->
+ let output =
+ mk_simpl_tuple_ty (RegionGroupId.Map.find bid ctx.sg.back_sg).outputs
+ in
+ mk_output output
+
+(** [opt_v]: the value to return, in case we translate a forward body.
-(** [opt_v]: the value to return, in case we translate a forward body *)
+ Remark: for now, we can't get there if we are inside a loop.
+ If inside a loop, we use {!translate_return_with_loop}.
+
+ Remark: in case we merge the forward/backward functions, we introduce
+ those in [translate_forward_end].
+*)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
@@ -1599,22 +1965,20 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
- or we are translating a backward function, in which case it should be [None]
*)
(* Compute the values that we should return *without the state and the result
- * wrapper* *)
+ wrapper* *)
let output =
match ctx.bid with
| None ->
(* Forward function *)
let v = Option.get opt_v in
typed_value_to_texpression ctx ectx v
- | Some bid ->
+ | Some _ ->
(* Backward function *)
(* Sanity check *)
assert (opt_v = None);
(* Group the variables in which we stored the values we need to give back.
- * See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs =
- T.RegionGroupId.Map.find bid ctx.backward_outputs
- in
+ See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
+ let backward_outputs = Option.get ctx.backward_outputs in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
in
@@ -1622,7 +1986,7 @@ and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
* - error-monad: Return x
* - state-error: Return (state, x)
* *)
- let effect_info = ctx.sg.info.effect_info in
+ let effect_info = ctx_get_effect_info ctx in
let output =
if effect_info.stateful then
let state_rvalue = mk_state_texpression ctx.state_var in
@@ -1647,24 +2011,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
*)
let output =
match ctx.bid with
- | None ->
- (* Forward *)
- mk_texpression_from_var
- (Option.get loop_info.forward_output_no_state_no_result)
- | Some bid ->
+ | None -> Option.get loop_info.forward_output_no_state_no_result
+ | Some _ ->
(* Backward *)
(* Group the variables in which we stored the values we need to give back.
* See the explanations for the [SynthInput] case in [translate_end_abstraction] *)
- let backward_outputs =
- let map =
- if ctx.inside_loop then
- (* We are synthesizing a loop body *)
- Option.get ctx.loop_backward_outputs
- else (* Regular function *)
- ctx.backward_outputs
- in
- T.RegionGroupId.Map.find bid map
- in
+ let backward_outputs = Option.get ctx.backward_outputs in
let field_values = List.map mk_texpression_from_var backward_outputs in
mk_simpl_tuple_texpression field_values
in
@@ -1676,7 +2028,7 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
* effect - in particular, one manipulates a state iff the other does
* the same.
* *)
- let effect_info = ctx.sg.info.effect_info in
+ let effect_info = ctx_get_effect_info ctx in
let output =
if effect_info.stateful then
let state_rvalue = mk_state_texpression ctx.state_var in
@@ -1684,13 +2036,15 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
else output
in
(* Wrap in a result - TODO: check effect_info.can_fail to not always wrap *)
- mk_result_return_texpression output
+ mk_emeta (Tag "return_with_loop") (mk_result_return_texpression output)
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
texpression =
log#ldebug
(lazy
- ("translate_function_call:\n"
+ ("translate_function_call:\n" ^ "\n- call.call_id:"
+ ^ S.show_call_id call.call_id
+ ^ "\n\n- call.generics:\n"
^ ctx_generic_args_to_string ctx call.generics));
(* Translate the function call *)
let generics = ctx_translate_fwd_generic_args ctx call.generics in
@@ -1702,10 +2056,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(List.combine args args_mplaces)
in
let dest_mplace = translate_opt_mplace call.dest_place in
- let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, fun_id, effect_info, args, out_state =
+ let ctx, fun_id, effect_info, args, dest_v =
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
@@ -1713,25 +2066,150 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
- let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None None
- in
+ let effect_info = get_fun_effect_info ctx fid None None in
(* Depending on the function effects:
- * - add the fuel
- * - add the state input argument
- * - generate a fresh state variable for the returned state
+ - add the fuel
+ - add the state input argument
+ - generate a fresh state variable for the returned state
*)
let args, ctx, out_state =
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_var = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_var = bs_ctx_fresh_state_var ctx in
(List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var)
else (List.concat [ fuel; args ], ctx, None)
in
+ (* If we do not split the forward/backward functions: generate the
+ variables for the backward functions returned by the forward
+ function. *)
+ let ctx, ignore_fwd_output, back_funs_map, back_funs =
+ if !Config.return_back_funs then (
+ (* We need to compute the signatures of the backward functions. *)
+ let sg = Option.get call.sg in
+ let decls_ctx = ctx.decls_ctx in
+ let dsg =
+ translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx
+ fid call.regions_hierarchy sg
+ (List.map (fun _ -> None) sg.inputs)
+ in
+ log#ldebug
+ (lazy ("dsg.generics:\n" ^ show_generic_params dsg.generics));
+ let tr_self, all_generics =
+ match call.trait_method_generics with
+ | None -> (UnknownTrait __FUNCTION__, generics)
+ | Some (all_generics, tr_self) ->
+ let all_generics =
+ ctx_translate_fwd_generic_args ctx all_generics
+ in
+ let tr_self =
+ translate_fwd_trait_instance_id ctx.type_ctx.type_infos
+ tr_self
+ in
+ (tr_self, all_generics)
+ in
+ let back_tys =
+ compute_back_tys_with_info dsg (Some (all_generics, tr_self))
+ in
+ (* Introduce variables for the backward functions *)
+ (* Compute a proper basename for the variables *)
+ let back_fun_name =
+ let name =
+ match fid with
+ | FunId (FAssumed fid) -> (
+ match fid with
+ | BoxNew -> "box_new"
+ | BoxFree -> "box_free"
+ | ArrayRepeat -> "array_repeat"
+ | ArrayIndexShared -> "index_shared"
+ | ArrayIndexMut -> "index_mut"
+ | ArrayToSliceShared -> "to_slice_shared"
+ | ArrayToSliceMut -> "to_slice_mut"
+ | SliceIndexShared -> "index_shared"
+ | SliceIndexMut -> "index_mut")
+ | FunId (FRegular fid) | TraitMethod (_, _, fid) -> (
+ let decl =
+ FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls
+ in
+ match Collections.List.last decl.name with
+ | PeIdent (s, _) -> s
+ | PeImpl _ ->
+ (* We shouldn't get there *)
+ raise (Failure "Unexpected"))
+ in
+ name ^ "_back"
+ in
+ let ctx, back_vars =
+ fresh_opt_vars
+ (List.map
+ (fun ty ->
+ match ty with
+ | None -> None
+ | Some (back_sg, ty) ->
+ (* We insert a name for the variable only if the function
+ can fail: if it can fail, it means the call returns a backward
+ function. Otherwise, we it directly returns the value given
+ back by the backward function, which means we shouldn't
+ give it a name like "back..." (it doesn't make sense) *)
+ let name =
+ if back_sg.effect_info.can_fail then
+ Some back_fun_name
+ else None
+ in
+ Some (name, ty))
+ back_tys)
+ ctx
+ in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids =
+ List.map
+ (fun (g : T.region_var_group) -> g.id)
+ call.regions_hierarchy
+ in
+ let back_vars =
+ List.map (Option.map mk_texpression_from_var) back_vars
+ in
+ let back_funs_map =
+ RegionGroupId.Map.of_list (List.combine gids back_vars)
+ in
+ (ctx, dsg.fwd_info.ignore_output, Some back_funs_map, back_funs))
+ else (ctx, false, None, [])
+ in
+ (* Compute the pattern for the destination *)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ let dest =
+ (* Here there is something subtle: as we might ignore the output
+ of the forward function (because it translates to unit) we doNOT
+ necessarily introduce in the let-binding the variable to which we
+ map the symbolic value which was introduced for the output of the
+ function call. This would be problematic if later we need to
+ translate this symbolic value, but we implemented
+ {!symbolic_value_to_texpression} so that it doesn't perform any
+ lookups if the symbolic value has type unit.
+ *)
+ let vars =
+ if ignore_fwd_output then back_funs else dest :: back_funs
+ in
+ mk_simpl_tuple_pattern vars
+ in
+ let dest =
+ match out_state with
+ | None -> dest
+ | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
+ in
(* Register the function call *)
- let ctx = bs_ctx_register_forward_call call_id call args ctx in
- (ctx, func, effect_info, args, out_state)
+ let ctx =
+ bs_ctx_register_forward_call call_id call args back_funs_map ctx
+ in
+ (ctx, func, effect_info, args, dest)
| S.Unop E.Not ->
let effect_info =
{
@@ -1742,7 +2220,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop Not, effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop Not, effect_info, args, dest)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
@@ -1758,7 +2238,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Neg int_ty), effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop (Neg int_ty), effect_info, args, dest)
| _ -> raise (Failure "Unreachable"))
| S.Unop (E.Cast cast_kind) -> (
match cast_kind with
@@ -1773,7 +2255,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, dest)
| CastFnPtr _ -> raise (Failure "TODO: function casts"))
| S.Binop binop -> (
match args with
@@ -1793,15 +2277,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
is_rec = false;
}
in
- (ctx, Binop (binop, int_ty0), effect_info, args, None)
+ let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ (ctx, Binop (binop, int_ty0), effect_info, args, dest)
| _ -> raise (Failure "Unreachable"))
in
- let dest_v =
- let dest = mk_typed_pattern_from_var dest dest_mplace in
- match out_state with
- | None -> dest
- | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
- in
let func = { id = FunOrOp fun_id; generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
@@ -1846,45 +2326,38 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
^ abs_to_string ctx abs ^ "\n"));
(* When we end an input abstraction, this input abstraction gets back
- * the borrows which it introduced in the context through the input
- * values: by listing those values, we get the values which are given
- * back by one of the backward functions we are synthesizing. *)
- (* Note that we don't support nested borrows for now: if we find
- * an ended synthesized input abstraction, it must be the one corresponding
- * to the backward function wer are synthesizing, it can't be the one
- * for a parent backward function.
- *)
+ the borrows which it introduced in the context through the input
+ values: by listing those values, we get the values which are given
+ back by one of the backward functions we are synthesizing.
+
+ Note that we don't support nested borrows for now: if we find
+ an ended synthesized input abstraction, it must be the one corresponding
+ to the backward function wer are synthesizing, it can't be the one
+ for a parent backward function.
+ *)
let bid = Option.get ctx.bid in
assert (rg_id = bid);
- (* The translation is done as follows:
- * - for a given backward function, we choose a set of variables [v_i]
- * - when we detect the ended input abstraction which corresponds
- * to the backward function, and which consumed the values [consumed_i],
- * we introduce:
- * {[
- * let v_i = consumed_i in
- * ...
- * ]}
- * Then, when we reach the [Return] node, we introduce:
- * {[
- * (v_i)
- * ]}
- * *)
- (* First, get the given back variables.
+ (* First, introduce the given back variables.
We don't use the same given back variables if we translate a loop or
the standard body of a function.
*)
- let given_back_variables =
- let map =
+ let ctx, given_back_variables =
+ let vars =
if ctx.inside_loop then
(* We are synthesizing a loop body *)
- Option.get ctx.loop_backward_outputs
- else (* Regular function body *)
- ctx.backward_outputs
+ let loop_id = Option.get ctx.loop_id in
+ let loop = LoopId.Map.find loop_id ctx.loops in
+ let tys = RegionGroupId.Map.find bid loop.back_outputs in
+ List.map (fun ty -> (None, ty)) tys
+ else
+ (* Regular function body *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ List.combine back_sg.output_names back_sg.outputs
in
- T.RegionGroupId.Map.find bid map
+ let ctx, vars = fresh_vars vars ctx in
+ ({ ctx with backward_outputs = Some vars }, vars)
in
(* Get the list of values consumed by the abstraction upon ending *)
@@ -1933,9 +2406,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(* Those don't have backward functions *)
raise (Failure "Unreachable")
in
- let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id)
- in
+ let effect_info = get_fun_effect_info ctx fun_id None (Some rg_id) in
let generics = ctx_translate_fwd_generic_args ctx call.generics in
(* Retrieve the original call and the parent abstractions *)
let _forward, backwards = get_abs_ancestors ctx abs call_id in
@@ -1959,14 +2430,16 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
(* Concatenate all the inpus *)
- let inputs =
- List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ]
+ let inherited_inputs =
+ if !Config.return_back_funs then []
+ else List.concat [ fwd_inputs; back_ancestors_inputs ]
in
+ let back_inputs = List.append back_inputs back_state in
(* Retrieve the values given back by this function: those are the output
* values. We rely on the fact that there are no nested borrows to use the
* meta-place information from the input values given to the forward function
@@ -1983,78 +2456,61 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| None -> output
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
- (* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
- (if (* TODO: normalize the types *) !Config.type_check_pure_code then
- match fun_id with
- | FunId fun_id ->
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some rg_id) generics ctx
- in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", "
- (List.map (pure_ty_to_string ctx) inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- | _ -> (* TODO: trait methods *) ());
(* Retrieve the function id, and register the function call in the context
- * if necessary *)
+ if necessary.Arith_status *)
let ctx, func =
- bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx
+ bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs
+ back_inputs generics output.ty ctx
in
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Put everything together *)
+ let inputs = List.append inherited_inputs back_inputs in
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty =
- if effect_info.can_fail then mk_result_ty output.ty else output.ty
- in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp func; generics } in
- let func = { e = Qualif func; ty = func_ty } in
- let call = mk_apps func args in
(* **Optimization**:
- * =================
- * We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call.
- * See the comment in {!Config.filter_useless_monadic_calls}.
- *
- * TODO: use an option to disallow backward functions from updating the state.
- * TODO: a backward function which only gives back shared borrows shouldn't
- * update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance).
- *)
- if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then (
+ =================
+ We do a small optimization here if we split the forward/backward functions.
+ If the backward function doesn't have any output, we don't introduce any function
+ call.
+ See the comment in {!Config.filter_useless_monadic_calls}.
+
+ TODO: use an option to disallow backward functions from updating the state.
+ TODO: a backward function which only gives back shared borrows shouldn't
+ update the state (state updates should only be used for mutable borrows,
+ with objects like Rc for instance).
+ *)
+ if
+ (not !Config.return_back_funs)
+ && !Config.filter_useless_monadic_calls
+ && outputs = [] && nstate = None
+ then (
(* No outputs - we do a small sanity check: the backward function
- * should have exactly the same number of inputs as the forward:
- * this number can be different only if the forward function returned
- * a value containing mutable borrows, which can't be the case... *)
+ should have exactly the same number of inputs as the forward:
+ this number can be different only if the forward function returned
+ a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
- else mk_let effect_info.can_fail output call next_e
+ else
+ (* The backward function might also have been filtered if we do not
+ split the forward/backward functions *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ log#ldebug
+ (lazy
+ (let args = List.map (texpression_to_string ctx) args in
+ "func: "
+ ^ texpression_to_string ctx func
+ ^ "\nfunc type: "
+ ^ pure_ty_to_string ctx func.ty
+ ^ "\n\nargs:\n" ^ String.concat "\n" args));
+ let call = mk_apps func args in
+ mk_let effect_info.can_fail output call next_e
and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -2095,15 +2551,15 @@ and translate_end_abstraction_synth_ret (ectx : C.eval_ctx) (abs : V.abs)
let-binding:
{[
let id_back x nx =
- let s = nx in // the name [s] is not important (only collision matters)
- ...
+ let s = nx in // the name [s] is not important (only collision matters)
+ ...
]}
This let-binding later gets inlined, during a micro-pass.
*)
(* First, retrieve the list of variables used for the inputs for the
* backward function *)
- let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs in
+ let inputs = T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: as there are no nested borrows, there should be none. *)
let consumed = abs_to_consumed ctx ectx abs in
@@ -2150,11 +2606,12 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
| V.LoopSynthInput ->
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
- | V.LoopCall ->
+ | V.LoopCall -> (
+ (* We need to introduce a call to the backward function corresponding
+ to a forward call which happened earlier *)
let fun_id = E.FRegular ctx.fun_decl.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (FunId fun_id)
- (Some vloop_id) (Some rg_id)
+ get_fun_effect_info ctx (FunId fun_id) (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
let generics = loop_info.generics in
@@ -2165,7 +2622,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
values consumed upon ending the abstraction (i.e., we don't use
[abs_to_consumed]) *)
let back_inputs_vars =
- T.RegionGroupId.Map.find rg_id ctx.backward_inputs
+ T.RegionGroupId.Map.find rg_id ctx.backward_inputs_no_state
in
let back_inputs = List.map mk_texpression_from_var back_inputs_vars in
(* If the function is stateful:
@@ -2175,12 +2632,15 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let back_state, ctx, nstate =
if effect_info.stateful then
let back_state = mk_state_texpression ctx.state_var in
- let ctx, nstate = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate = bs_ctx_fresh_state_var ctx in
([ back_state ], ctx, Some nstate)
else ([], ctx, None)
in
(* Concatenate all the inputs *)
- let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in
+ let inputs =
+ if !Config.return_back_funs then List.concat [ back_inputs; back_state ]
+ else List.concat [ fwd_inputs; back_inputs; back_state ]
+ in
(* Retrieve the values given back by this function *)
let ctx, outputs = abs_to_given_back None abs ctx in
(* Group the output values together: first the updated inputs *)
@@ -2204,68 +2664,88 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let ret_ty =
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
- let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; generics } in
- let func = { e = Qualif func; ty = func_ty } in
- let call = mk_apps func args in
+ (* Create the expression for the function:
+ - it is either a call to a top-level function, if we split the
+ forward/backward functions
+ - or a call to the variable we introduced for the backward function,
+ if we merge the forward/backward functions *)
+ let func =
+ if !Config.return_back_funs then
+ RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs)
+ else
+ let func_ty = mk_arrows input_tys ret_ty in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
+ let func = { id = FunOrOp func; generics } in
+ Some { e = Qualif func; ty = func_ty }
+ in
(* **Optimization**:
- * =================
- * We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call.
- * See the comment in {!Config.filter_useless_monadic_calls}.
- *
- * TODO: use an option to disallow backward functions from updating the state.
- * TODO: a backward function which only gives back shared borrows shouldn't
- * update the state (state updates should only be used for mutable borrows,
- * with objects like Rc for instance).
- *)
- if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None
+ =================
+ We do a small optimization here in case we split the forward/backward
+ functions.
+ If the backward function doesn't have any output, we don't introduce
+ any function call.
+ See the comment in {!Config.filter_useless_monadic_calls}.
+
+ TODO: use an option to disallow backward functions from updating the state.
+ TODO: a backward function which only gives back shared borrows shouldn't
+ update the state (state updates should only be used for mutable borrows,
+ with objects like Rc for instance).
+ *)
+ if
+ (not !Config.return_back_funs)
+ && !Config.filter_useless_monadic_calls
+ && outputs = [] && nstate = None
then (
(* No outputs - we do a small sanity check: the backward function
- * should have exactly the same number of inputs as the forward:
- * this number can be different only if the forward function returned
- * a value containing mutable borrows, which can't be the case... *)
+ should have exactly the same number of inputs as the forward:
+ this number can be different only if the forward function returned
+ a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else
- (* Add meta-information - this is slightly hacky: we look at the
- values consumed by the abstraction (note that those come from
- *before* we applied the fixed-point context) and use them to
- guide the naming of the output vars.
-
- Also, we need to convert the backward outputs from patterns to
- variables.
-
- Finally, in practice, this works well only for loop bodies:
- we do this only in this case.
- TODO: improve the heuristics, to give weight to the hints for
- instance.
- *)
- let next_e =
- if ctx.inside_loop then
- let consumed_values = abs_to_consumed ctx ectx abs in
- let var_values = List.combine outputs consumed_values in
- let var_values =
- List.filter_map
- (fun (var, v) ->
- match var.Pure.value with
- | PatVar (var, _) -> Some (var, v)
- | _ -> None)
- var_values
+ (* In case we merge the fwd/back functions we filter the backward
+ functions elsewhere *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ let call = mk_apps func args in
+ (* Add meta-information - this is slightly hacky: we look at the
+ values consumed by the abstraction (note that those come from
+ *before* we applied the fixed-point context) and use them to
+ guide the naming of the output vars.
+
+ Also, we need to convert the backward outputs from patterns to
+ variables.
+
+ Finally, in practice, this works well only for loop bodies:
+ we do this only in this case.
+ TODO: improve the heuristics, to give weight to the hints for
+ instance.
+ *)
+ let next_e =
+ if ctx.inside_loop then
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ let var_values = List.combine outputs consumed_values in
+ let var_values =
+ List.filter_map
+ (fun (var, v) ->
+ match var.Pure.value with
+ | PatVar (var, _) -> Some (var, v)
+ | _ -> None)
+ var_values
+ in
+ let vars, values = List.split var_values in
+ mk_emeta_symbolic_assignments vars values next_e
+ else next_e
in
- let vars, values = List.split var_values in
- mk_emeta_symbolic_assignments vars values next_e
- else next_e
- in
- (* Create the let-binding *)
- mk_let effect_info.can_fail output call next_e
+ (* Create the let-binding *)
+ mk_let effect_info.can_fail output call next_e)
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
let ctx, var = fresh_var_for_symbolic_value sval ctx in
- let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in
+ let decl = A.GlobalDeclId.Map.find gid ctx.global_ctx.llbc_global_decls in
let global_expr = { id = Global gid; generics = empty_generic_args } in
(* We use translate_fwd_ty to translate the global type *)
let ty = ctx_translate_fwd_ty ctx decl.ty in
@@ -2290,8 +2770,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value)
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(exp : S.expansion) (ctx : bs_ctx) : texpression =
(* Translate the scrutinee *)
- let scrutinee_var = lookup_var_for_symbolic_value sv ctx in
- let scrutinee = mk_texpression_from_var scrutinee_var in
+ let scrutinee = symbolic_value_to_texpression ctx sv in
let scrutinee_mplace = translate_opt_mplace p in
(* Translate the branches *)
match exp with
@@ -2370,7 +2849,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
If (true_e, false_e) )
in
let ty = true_e.ty in
- assert (ty = false_e.ty);
+ log#ldebug
+ (lazy
+ ("true_e.ty: "
+ ^ pure_ty_to_string ctx true_e.ty
+ ^ "\n\nfalse_e.ty: "
+ ^ pure_ty_to_string ctx false_e.ty));
+ if !Config.fail_hard then assert (ty = false_e.ty);
{ e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
@@ -2441,7 +2926,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
- if we forbid using field projectors.
*)
let is_rec_def =
- T.TypeDeclId.Set.mem adt_id ctx.type_context.recursive_decls
+ T.TypeDeclId.Set.mem adt_id ctx.type_ctx.recursive_decls
in
let use_let_with_cons =
is_enum
@@ -2454,7 +2939,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
like Coq don't, in which case we have to deconstruct the whole ADT
at once (`let (a, b, c) = x in`) *)
|| TypesUtils.type_decl_from_type_id_is_tuple_struct
- ctx.type_context.type_infos type_id
+ ctx.type_ctx.type_infos type_id
&& not (Config.backend_has_tuple_projectors ())
in
if use_let_with_cons then
@@ -2547,7 +3032,7 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
{ e = StructUpdate su; ty = var.ty }
| VaCgValue cg_id -> { e = CVar cg_id; ty = var.ty }
| VaTraitConstValue (trait_ref, generics, const_name) ->
- let type_infos = ctx.type_context.type_infos in
+ let type_infos = ctx.type_ctx.type_infos in
let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
let generics = translate_fwd_generic_args type_infos generics in
let qualif_id = TraitConst (trait_ref, generics, const_name) in
@@ -2562,22 +3047,169 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
and translate_forward_end (ectx : C.eval_ctx)
(loop_input_values : V.typed_value S.symbolic_value_id_map option)
- (e : S.expression) (back_e : S.expression S.region_group_id_map)
+ (fwd_e : S.expression) (back_e : S.expression S.region_group_id_map)
(ctx : bs_ctx) : texpression =
- (* Update the current state with the additional state received by the backward
- function, if needs be, and lookup the proper expression *)
- let translate_end ctx =
+ let translate_one_end ctx (bid : RegionGroupId.id option) =
+ let ctx = { ctx with bid } in
(* Update the current state with the additional state received by the backward
function, if needs be, and lookup the proper expression *)
- let ctx, e =
- match ctx.bid with
- | None -> (ctx, e)
+ let ctx, e, finish =
+ match bid with
+ | None ->
+ (* We are translating the forward function - nothing to do *)
+ (ctx, fwd_e, fun e -> e)
| Some bid ->
- let ctx = { ctx with state_var = ctx.back_state_var } in
+ (* There are two cases here:
+ - if we split the fwd/backward functions, we simply need to update
+ the state.
+ - if we don't split, we also need to wrap the expression in a
+ lambda, which introduces the additional inputs of the backward
+ function
+ *)
+ let ctx =
+ (* Introduce variables for the inputs and the state variable
+ and update the context. *)
+ if !Config.return_back_funs then
+ (* If the forward/backward functions are not split, we need
+ to introduce fresh variables for the additional inputs,
+ because they are locally introduced in a lambda *)
+ let back_sg = RegionGroupId.Map.find bid ctx.sg.back_sg in
+ let ctx, backward_inputs_no_state =
+ fresh_vars back_sg.inputs_no_state ctx
+ in
+ let ctx, backward_inputs_with_state =
+ if back_sg.effect_info.stateful then
+ let ctx, var, _ = bs_ctx_fresh_state_var ctx in
+ (ctx, backward_inputs_no_state @ [ var ])
+ else (ctx, backward_inputs_no_state)
+ in
+ {
+ ctx with
+ backward_inputs_no_state =
+ RegionGroupId.Map.add bid backward_inputs_no_state
+ ctx.backward_inputs_no_state;
+ backward_inputs_with_state =
+ RegionGroupId.Map.add bid backward_inputs_with_state
+ ctx.backward_inputs_with_state;
+ }
+ else
+ (* Update the state variable *)
+ let back_state_var =
+ RegionGroupId.Map.find bid ctx.back_state_vars
+ in
+ { ctx with state_var = back_state_var }
+ in
+
let e = T.RegionGroupId.Map.find bid back_e in
- (ctx, e)
+ let finish e =
+ (* Wrap in lambdas if necessary *)
+ if !Config.return_back_funs then
+ let inputs =
+ RegionGroupId.Map.find bid ctx.backward_inputs_with_state
+ in
+ let places = List.map (fun _ -> None) inputs in
+ mk_lambdas_from_vars inputs places e
+ else e
+ in
+ (ctx, e, finish)
in
- translate_expression e ctx
+ let e = translate_expression e ctx in
+ finish e
+ in
+
+ (* There are two cases, depending on whether we are splitting the forward/backward
+ functions or not.
+
+ - if we split, then we simply need to translate the proper "end" expression,
+ that is the end of the forward function, or of the backward function we
+ are currently translating.
+ - if we don't split, then we need to translate the end of the forward
+ function (this is the value we will return) and generate the bodies
+ of the backward functions (which we will also return).
+
+ Update the current state with the additional state received by the backward
+ function, if needs be, and lookup the proper expression.
+ *)
+ let translate_end ctx =
+ if !Config.return_back_funs then
+ (* Compute the output of the forward function *)
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let ctx, pure_fwd_var = fresh_var None ctx.sg.fwd_output ctx in
+ let fwd_e = translate_one_end ctx None in
+
+ (* Introduce the backward functions. *)
+ let back_el =
+ List.map
+ (fun ((gid, _) : RegionGroupId.id * back_sg_info) ->
+ translate_one_end ctx (Some gid))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg)
+ in
+
+ (* Compute whether the backward expressions should be evaluated straight
+ away or not (i.e., if we should bind them with monadic let-bindings
+ or not). We evaluate them straight away if they can fail and have no
+ inputs. *)
+ let evaluate_backs =
+ List.map
+ (fun (sg : back_sg_info) ->
+ if !Config.simplify_merged_fwd_backs then
+ sg.inputs = [] && sg.effect_info.can_fail
+ else false)
+ (RegionGroupId.Map.values ctx.sg.back_sg)
+ in
+
+ (* Introduce variables for the backward functions.
+ We lookup the LLBC definition in an attempt to derive pretty names
+ for those functions. *)
+ let _, back_vars = fresh_back_vars_for_current_fun ctx in
+
+ (* Create the return expressions *)
+ let vars =
+ let back_vars = List.filter_map (fun x -> x) back_vars in
+ if ctx.sg.fwd_info.ignore_output then back_vars
+ else pure_fwd_var :: back_vars
+ in
+ let vars = List.map mk_texpression_from_var vars in
+ let ret = mk_simpl_tuple_texpression vars in
+
+ (* Introduce a fresh input state variable for the forward expression *)
+ let _ctx, state_var, state_pat =
+ if fwd_effect_info.stateful then
+ let ctx, var, pat = bs_ctx_fresh_state_var ctx in
+ (ctx, [ var ], [ pat ])
+ else (ctx, [], [])
+ in
+
+ let state_var = List.map mk_texpression_from_var state_var in
+ let ret = mk_simpl_tuple_texpression (state_var @ [ ret ]) in
+ let ret = mk_result_return_texpression ret in
+
+ (* Introduce all the let-bindings *)
+
+ (* Combine:
+ - the backward variables
+ - whether we should evaluate the expression for the backward function
+ (i.e., should we use a monadic let-binding or not - we do if the
+ backward functions don't have inputs and can fail)
+ - the expressions for the backward functions
+ *)
+ let back_vars_els =
+ List.filter_map
+ (fun (v, (eval, el)) ->
+ match v with None -> None | Some v -> Some (v, eval, el))
+ (List.combine back_vars (List.combine evaluate_backs back_el))
+ in
+ let e =
+ List.fold_right
+ (fun (var, evaluate, back_e) e ->
+ mk_let evaluate (mk_typed_pattern_from_var var None) back_e e)
+ back_vars_els ret
+ in
+ (* Bind the expression for the forward output *)
+ let fwd_var = mk_typed_pattern_from_var pure_fwd_var None in
+ let pat = mk_simpl_tuple_pattern (state_pat @ [ fwd_var ]) in
+ mk_let fwd_effect_info.can_fail pat fwd_e e
+ else translate_one_end ctx ctx.bid
in
(* If we are (re-)entering a loop, we need to introduce a call to the
@@ -2624,17 +3256,53 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Lookup the effect info for the loop function *)
let fid = E.FRegular ctx.fun_decl.def_id in
- let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (FunId fid) None ctx.bid
- in
+ let effect_info = get_fun_effect_info ctx (FunId fid) None ctx.bid in
(* Introduce a fresh output value for the forward function *)
- let ctx, output_var =
- let output_ty = mk_simpl_tuple_ty ctx.fwd_sg.doutputs in
- fresh_var None output_ty ctx
+ let ctx, fwd_output, output_pat =
+ if ctx.sg.fwd_info.ignore_output then
+ (* Note that we still need the forward output (which is unit),
+ because even though the loop function will ignore the forward output,
+ the forward expression will still compute an output (which
+ will have type unit - otherwise we can't ignore it). *)
+ (ctx, mk_unit_rvalue, [])
+ else
+ let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in
+ ( ctx,
+ mk_texpression_from_var output_var,
+ [ mk_typed_pattern_from_var output_var None ] )
+ in
+
+ (* Introduce fresh variables for the backward functions of the loop.
+
+ For now, the backward functions of the loop are the same as the
+ backward functions of the outer function.
+ *)
+ let ctx, back_funs_map, back_funs =
+ if !Config.return_back_funs then
+ let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
+ let back_funs =
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
+ in
+ let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
+ let back_funs_map =
+ RegionGroupId.Map.of_list
+ (List.combine gids
+ (List.map (Option.map mk_texpression_from_var) back_vars))
+ in
+ (ctx, Some back_funs_map, back_funs)
+ else (ctx, None, [])
in
+
+ (* Introduce patterns *)
let args, ctx, out_pats =
- let output_pat = mk_typed_pattern_from_var output_var None in
+ (* Add the returned backward functions (they might be empty) *)
+ let output_pat = mk_simpl_tuple_pattern (output_pat @ back_funs) in
(* Depending on the function effects:
* - add the fuel
@@ -2644,7 +3312,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let fuel = mk_fuel_input_as_list ctx effect_info in
if effect_info.stateful then
let state_var = mk_state_texpression ctx.state_var in
- let ctx, nstate_pat = bs_ctx_fresh_state_var ctx in
+ let ctx, _, nstate_pat = bs_ctx_fresh_state_var ctx in
( List.concat [ fuel; args; [ state_var ] ],
ctx,
[ nstate_pat; output_pat ] )
@@ -2656,7 +3324,8 @@ and translate_forward_end (ectx : C.eval_ctx)
{
loop_info with
forward_inputs = Some args;
- forward_output_no_state_no_result = Some output_var;
+ forward_output_no_state_no_result = Some fwd_output;
+ back_funs = back_funs_map;
}
in
let ctx =
@@ -2753,35 +3422,107 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Compute the backward outputs *)
let ctx = ref ctx in
- let loop_backward_outputs =
- T.RegionGroupId.Map.map
+ let rg_to_given_back_tys =
+ RegionGroupId.Map.map
(fun (_, tys) ->
(* The types shouldn't contain borrows - we can translate them as forward types *)
- let vars =
- List.map
- (fun ty ->
- assert (
- not (TypesUtils.ty_has_borrows !ctx.type_context.type_infos ty));
- (None, ctx_translate_fwd_ty !ctx ty))
- tys
- in
- (* Introduce fresh variables *)
- let ctx', vars = fresh_vars vars !ctx in
- ctx := ctx';
- vars)
+ List.map
+ (fun ty ->
+ assert (not (TypesUtils.ty_has_borrows !ctx.type_ctx.type_infos ty));
+ ctx_translate_fwd_ty !ctx ty)
+ tys)
loop.rg_to_given_back_tys
in
let ctx = !ctx in
- let back_output_tys =
- match ctx.bid with
- | None -> None
- | Some rg_id ->
- let back_outputs =
- T.RegionGroupId.Map.find rg_id loop_backward_outputs
- in
- let back_output_tys = List.map (fun (v : var) -> v.ty) back_outputs in
- Some back_output_tys
+ (* The output type of the loop function *)
+ let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
+ let back_effect_infos, output_ty =
+ if !Config.return_back_funs then
+ (* The loop backward functions consume the same additional inputs as the parent
+ function, but have custom outputs *)
+ let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
+ let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ let back_info_tys =
+ List.map
+ (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+ (* Remark: the effect info of the backward function for the loop
+ is almost the same as for the backward function of the parent function.
+ Quite importantly, the fact that the function is stateful and/or can fail
+ mostly depends on whether it has inputs or not, and the backward functions
+ for the loops have the same inputs as the backward functions for the parent
+ function.
+ *)
+ let effect_info = back_sg.effect_info in
+ let effect_info = { effect_info with is_rec = true } in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = given_back in
+ (* Filter if necessary *)
+ let ty =
+ if
+ !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty
+ in
+ ((id, effect_info), ty))
+ (List.combine back_sgs given_back_tys)
+ in
+ let back_info = List.map fst back_info_tys in
+ let back_info = RegionGroupId.Map.of_list back_info in
+ let back_tys = List.filter_map snd back_info_tys in
+ let output =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty output in
+ let effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ in
+ (back_info, output)
+ else
+ let back_info =
+ RegionGroupId.Map.of_list
+ (List.map
+ (fun ((id, back_sg) : _ * back_sg_info) ->
+ (id, { back_sg.effect_info with is_rec = true }))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg))
+ in
+ let output =
+ match ctx.bid with
+ | None ->
+ (* Forward function: same type as the parent function *)
+ (translate_fun_sig_from_decomposed ctx.sg None).output
+ | Some rg_id ->
+ (* Backward function: custom return type *)
+ let doutputs =
+ T.RegionGroupId.Map.find rg_id rg_to_given_back_tys
+ in
+ let output = mk_simpl_tuple_ty doutputs in
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if fwd_effect_info.stateful then
+ mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if fwd_effect_info.can_fail then mk_result_ty output else output
+ in
+ output
+ in
+ (back_info, output)
in
(* Add the loop information in the context *)
@@ -2823,6 +3564,10 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
generics;
forward_inputs = None;
forward_output_no_state_no_result = None;
+ back_outputs = rg_to_given_back_tys;
+ back_funs = None;
+ fwd_effect_info;
+ back_effect_infos;
}
in
let loops = LoopId.Map.add loop_id loop_info ctx.loops in
@@ -2830,13 +3575,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
(* Update the context to translate the function end *)
- let ctx_end =
- {
- ctx with
- loop_id = Some loop_id;
- loop_backward_outputs = Some loop_backward_outputs;
- }
- in
+ let ctx_end = { ctx with loop_id = Some loop_id } in
let fun_end = translate_expression loop.end_expr ctx_end in
(* Update the context for the loop body *)
@@ -2844,7 +3583,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Add the input state *)
let input_state =
- if ctx.sg.info.effect_info.stateful then Some ctx.state_var else None
+ if (ctx_get_effect_info ctx).stateful then Some ctx.state_var else None
in
(* Translate the loop body *)
@@ -2862,7 +3601,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
}
in
@@ -2982,10 +3721,10 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let def_id = def.def_id in
let llbc_name = def.name in
let name = name_to_string ctx llbc_name in
- (* Retrieve the signature *)
- let signature = ctx.sg in
+ (* Translate the signature *)
+ let signature = translate_fun_sig_from_decomposed ctx.sg ctx.bid in
let regions_hierarchy =
- FunIdMap.find (FRegular def_id) ctx.fun_context.regions_hierarchies
+ FunIdMap.find (FRegular def_id) ctx.fun_ctx.regions_hierarchies
in
(* Translate the body, if there is *)
let body =
@@ -2993,8 +3732,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos
- (FunId (FRegular def_id)) None bid
+ get_fun_effect_info ctx (FunId (FRegular def_id)) None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -3027,20 +3765,25 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
match bid with
| None -> []
| Some back_id ->
+ assert (not !Config.return_back_funs);
let parents_ids =
list_ordered_ancestor_region_groups regions_hierarchy back_id
in
let backward_ids = List.append parents_ids [ back_id ] in
List.concat
(List.map
- (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs)
+ (fun id ->
+ T.RegionGroupId.Map.find id ctx.backward_inputs_no_state)
backward_ids)
in
(* Introduce the backward input state (the state at call site of the
* *backward* function), if necessary *)
let back_state =
if effect_info.stateful && Option.is_some bid then
- [ mk_state_var ctx.back_state_var ]
+ let state_var =
+ RegionGroupId.Map.find (Option.get bid) ctx.back_state_vars
+ in
+ [ mk_state_var state_var ]
else []
in
(* Group the inputs together *)
@@ -3114,63 +3857,6 @@ let translate_type_decls (ctx : Contexts.decls_ctx) : type_decl list =
List.map (translate_type_decl ctx)
(TypeDeclId.Map.values ctx.type_ctx.type_decls)
-(** Translates function signatures.
-
- Takes as input a list of function information containing:
- - the function id
- - a list of optional names for the inputs
- - the function signature
-
- Returns a map from forward/backward functions identifiers to:
- - translated function signatures
- - optional names for the outputs values (we derive them for the backward
- functions)
- *)
-let translate_fun_signatures (decls_ctx : C.decls_ctx)
- (functions : (A.fun_id * string option list * A.fun_sig) list) :
- fun_sig_named_outputs RegularFunIdNotLoopMap.t =
- (* For every function, translate the signatures of:
- - the forward function
- - the backward functions
- *)
- let translate_one (fun_id : A.fun_id) (input_names : string option list)
- (sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list
- =
- log#ldebug
- (lazy
- ("Translating signature of function: "
- ^ Print.Expressions.fun_id_to_string
- (Print.Contexts.decls_ctx_to_fmt_env decls_ctx)
- fun_id));
- (* Retrieve the regions hierarchy *)
- let regions_hierarchy =
- FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies
- in
- (* The forward function *)
- let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in
- let fwd_id = (fun_id, None) in
- (* The backward functions *)
- let back_sgs =
- List.map
- (fun (rg : T.region_var_group) ->
- let tsg =
- translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
- in
- let id = (fun_id, Some rg.id) in
- (id, tsg))
- regions_hierarchy
- in
- (* Return *)
- (fwd_id, fwd_sg) :: back_sgs
- in
- let translated =
- List.concat
- (List.map (fun (id, names, sg) -> translate_one id names sg) functions)
- in
- List.fold_left
- (fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m)
- RegularFunIdNotLoopMap.empty translated
-
let translate_trait_decl (ctx : Contexts.decls_ctx) (trait_decl : A.trait_decl)
: trait_decl =
let {