summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-05-04 13:54:45 +0200
committerSon Ho2022-05-04 13:54:45 +0200
commit37f80fd592f703ab9b14a9d3d5d638b9c335997f (patch)
tree2ccb9ccc445e181f354b5cb3093c425f9f666560
parent593ffae18cf647457121470c371ba9effbc55f5d (diff)
Start updating the way the function return type (with errors and states)
are handled
Diffstat (limited to '')
-rw-r--r--src/PrintPure.ml23
-rw-r--r--src/Pure.ml81
-rw-r--r--src/PureUtils.ml41
-rw-r--r--src/SymbolicToPure.ml315
4 files changed, 278 insertions, 182 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index f21329ed..07144d3e 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -383,30 +383,15 @@ let fun_sig_to_string (fmt : ast_formatter) (sg : fun_sig) : string =
let ty_fmt = ast_to_type_formatter fmt in
let type_params = List.map type_var_to_string sg.type_params in
let inputs = List.map (ty_to_string ty_fmt) sg.inputs in
- let outputs = List.map (ty_to_string ty_fmt) sg.outputs in
- let outputs =
- match outputs with
- | [] ->
- (* Can happen with backward functions which don't give back
- * anything (shared borrows only) *)
- "()"
- | [ out ] -> out
- | outputs -> "(" ^ String.concat " * " outputs ^ ")"
- in
- let all_types = List.concat [ type_params; inputs; [ outputs ] ] in
+ let output = ty_to_string ty_fmt sg.output in
+ let all_types = List.concat [ type_params; inputs; [ output ] ] in
String.concat " -> " all_types
let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string =
let ty_fmt = ast_to_type_formatter fmt in
let inputs = List.map (ty_to_string ty_fmt) sg.inputs in
- let outputs = List.map (ty_to_string ty_fmt) sg.outputs in
- let outputs =
- match outputs with
- | [] -> "()"
- | [ out ] -> out
- | outputs -> "(" ^ String.concat " * " outputs ^ ")"
- in
- let all_types = List.append inputs [ outputs ] in
+ let output = ty_to_string ty_fmt sg.output in
+ let all_types = List.append inputs [ output ] in
String.concat " -> " all_types
let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : A.fun_id) : string
diff --git a/src/Pure.ml b/src/Pure.ml
index e2362338..d8e1cafc 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -505,28 +505,79 @@ and meta =
nude = true (* Don't inherit [VisitorsRuntime.iter] *);
}]
+type fun_sig_info = {
+ num_fwd_inputs : int;
+ (** The number of input types for forward computation *)
+ num_back_inputs : int option;
+ (** The number of additional inputs for the backward computation (if pertinent) *)
+ input_state : bool; (** `true` if the function takes a state as input *)
+ output_state : bool;
+ (** `true` if the function outputs a state (it then lives
+ in a state monad) *)
+ can_fail : bool; (** `true` if the return type is a `result` *)
+}
+(** Meta information about a function signature *)
+
type fun_sig = {
type_params : type_var list;
inputs : ty list;
- outputs : ty list;
- (** The list of outputs.
-
- Immediately after the translation from symbolic to pure we have this
- the following:
- In case of a forward function, the list will have length = 1.
- However, in case of backward function, the list may have length > 1.
- If the length is > 1, it gets extracted to a tuple type. Followingly,
- we could not use a list because we can encode tuples, but here we
- want to account for the fact that we immediately deconstruct the tuple
- upon calling the backward function (because the backward function is
- called to update a set of values in the environment).
+ output : ty;
+ doutputs : ty list;
+ (** The "decomposed" list of outputs.
+
+ In case of a forward function, the list has length = 1, for the
+ type of the returned value.
+
+ In case of backward function, the list contains all the types of
+ all the given back values (there is at most one type per forward
+ input argument).
+
+ Ex.:
+ ```
+ fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T;
+ ```
+ Decomposed outputs:
+ - forward function: [T]
+ - backward function: [T; T] (for "x" and "y")
- After the "to monadic" pass, the list has size exactly one (and we
- use the `Result` type).
*)
+ info : fun_sig_info; (** Additional information *)
}
+(** A function signature.
+
+ We have the following cases:
+ - forward function:
+ `in_ty0 -> ... -> in_tyn -> out_ty` (* pure function *)
+ `in_ty0 -> ... -> in_tyn -> result out_ty` (* error-monad *)
+ `in_ty0 -> ... -> in_tyn -> state -> result (state & out_ty)` (* state-error *)
+ - backward function:
+ `in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm -> (back_out0 & ... & back_outp)` (* pure function *)
+ `in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm ->
+ result (back_out0 & ... & back_outp)` (* error-monad *)
+ `in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm ->
+ result (back_out0 & ... & back_outp)` (* state-error *)
+
+ Note that a backward function never returns (i.e., updates) a state: only
+ forward functions do so. Also, the state input parameter is *betwee*
+ the forward inputs and the backward inputs.
+
+ The function's type should be given by `mk_arrows sig.inputs sig.output`.
+ We provide additional meta-information:
+ - we divide between forward inputs and backward inputs (i.e., inputs specific
+ to the forward functions, and additional inputs necessary if the signature is
+ for a backward function)
+ - we have booleans to give us the fact that the function takes a state as
+ input, or can fail, etc. without having to inspect the signature
+ - etc.
+ *)
-type inst_fun_sig = { inputs : ty list; outputs : ty list }
+type inst_fun_sig = {
+ inputs : ty list;
+ output : ty;
+ doutputs : ty list;
+ info : fun_sig_info;
+}
+(** An instantiated function signature. See [fun_sig] *)
type fun_body = {
inputs : var list;
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 8651679f..a1af3396 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -125,8 +125,10 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) :
inst_fun_sig =
let subst = ty_substitute tsubst in
let inputs = List.map subst sg.inputs in
- let outputs = List.map subst sg.outputs in
- { inputs; outputs }
+ let output = subst sg.output in
+ let doutputs = List.map subst sg.doutputs in
+ let info = sg.info in
+ { inputs; output; doutputs; info }
(** Return true if a list of functions are *not* mutually recursive, false otherwise.
This function is meant to be applied on a set of (forward, backwards) functions
@@ -478,38 +480,3 @@ let opt_destruct_state_monad_result (ty : ty) : ty option =
let opt_unmeta_mplace (e : texpression) : mplace option * texpression =
match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e)
-
-(** Utility function, used for type checking - TODO: move *)
-let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
- (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) :
- ty list =
- match type_id with
- | Tuple ->
- (* Tuple *)
- assert (variant_id = None);
- tys
- | AdtId def_id ->
- (* "Regular" ADT *)
- let def = TypeDeclId.Map.find def_id type_decls in
- type_decl_get_instantiated_fields_types def variant_id tys
- | Assumed aty -> (
- (* Assumed type *)
- match aty with
- | State ->
- (* `State` is opaque *)
- raise (Failure "Unreachable: `State` values are opaque")
- | Result ->
- let ty = Collections.List.to_cons_nil tys in
- let variant_id = Option.get variant_id in
- if variant_id = result_return_id then [ ty ]
- else if variant_id = result_fail_id then []
- else
- raise (Failure "Unreachable: improper variant id for result type")
- | Option ->
- let ty = Collections.List.to_cons_nil tys in
- let variant_id = Option.get variant_id in
- if variant_id = option_some_id then [ ty ]
- else if variant_id = option_none_id then []
- else
- raise (Failure "Unreachable: improper variant id for result type")
- | Vec -> raise (Failure "Unreachable: `Vector` values are opaque"))
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 4754d237..466e5562 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -78,12 +78,22 @@ type fun_context = {
type call_info = {
forward : S.call;
- backwards : V.abs T.RegionGroupId.Map.t;
- (** TODO: not sure we need this anymore *)
+ forward_inputs : texpression list;
+ (** Remember the list of inputs given to the forward function.
+
+ Those inputs include the state input, if pertinent (in which case
+ it is the last input).
+ *)
+ backwards : (V.abs * texpression list) T.RegionGroupId.Map.t;
+ (** A map from region group id (i.e., backward function id) to
+ pairs (abstraction, additional arguments received by the backward function)
+
+ TODO: remove? it is also in the bs_ctx ("abstractions" field)
+ *)
}
(** Whenever we translate a function call or an ended abstraction, we
store the related information (this is useful when translating ended
- children abstractions)
+ children abstractions).
*)
type bs_ctx = {
@@ -96,8 +106,10 @@ type bs_ctx = {
(** Whenever we encounter a new symbolic value (introduced because of
a symbolic expansion or upon ending an abstraction, for instance)
we introduce a new variable (with a let-binding).
- *)
+ *)
var_counter : VarId.generator;
+ state_var : VarId.id;
+ (** The current state variable, in case we use a state *)
forward_inputs : var list;
(** The input parameters for the forward function *)
backward_inputs : var list T.RegionGroupId.Map.t;
@@ -106,8 +118,8 @@ type bs_ctx = {
(** The variables that the backward functions will output *)
calls : call_info V.FunCallId.Map.t;
(** The function calls we encountered so far *)
- abstractions : V.abs V.AbstractionId.Map.t;
- (** The ended abstractions we encountered so far *)
+ abstractions : (V.abs * texpression list) V.AbstractionId.Map.t;
+ (** The ended abstractions we encountered so far, with their additional input arguments *)
}
(** Body synthesis context *)
@@ -197,26 +209,33 @@ let bs_ctx_lookup_local_function_sig (def_id : FunDeclId.id)
(RegularFunIdMap.find id ctx.fun_context.fun_sigs).sg
let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (ctx : bs_ctx) : bs_ctx =
+ (args : texpression list) (ctx : bs_ctx) : bs_ctx =
let calls = ctx.calls in
assert (not (V.FunCallId.Map.mem call_id calls));
- let info = { forward; backwards = T.RegionGroupId.Map.empty } in
+ let info =
+ { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
+ in
let calls = V.FunCallId.Map.add call_id info calls in
{ ctx with calls }
-let bs_ctx_register_backward_call (abs : V.abs) (ctx : bs_ctx) : bs_ctx * fun_id
- =
+(** [back_args]: the *additional* list of inputs received by the backward function *)
+let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list)
+ (ctx : bs_ctx) : bs_ctx * fun_id =
(* Insert the abstraction in the call informations *)
let back_id = abs.back_id in
let info = V.FunCallId.Map.find abs.call_id ctx.calls in
assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
- let backwards = T.RegionGroupId.Map.add back_id abs info.backwards in
+ let backwards =
+ T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
+ in
let info = { info with backwards } in
let calls = V.FunCallId.Map.add abs.call_id info ctx.calls in
(* Insert the abstraction in the abstractions map *)
let abstractions = ctx.abstractions in
assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
- let abstractions = V.AbstractionId.Map.add abs.abs_id abs abstractions in
+ let abstractions =
+ V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
+ in
(* Retrieve the fun_id *)
let fun_id =
match info.forward.call_id with
@@ -438,7 +457,7 @@ let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) :
if V.AbstractionId.Set.mem abs_id !abs_set then ()
else (
abs_set := V.AbstractionId.Set.add abs_id !abs_set;
- let abs = V.AbstractionId.Map.find abs_id ctx.abstractions in
+ let abs, _ = V.AbstractionId.Map.find abs_id ctx.abstractions in
List.iter gather abs.original_parents)
in
List.iter gather abs.original_parents;
@@ -449,7 +468,8 @@ let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) :
(fun id -> V.AbstractionId.Set.mem id ids)
call_info.forward.abstractions
-let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : V.abs list =
+let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) :
+ (V.abs * texpression list) list =
let abs_ids = list_ancestor_abstractions_ids ctx abs in
List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids
@@ -460,9 +480,9 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : V.abs list =
name (outputs for backward functions come from borrows in the inputs
of the forward function).
*)
-let translate_fun_sig (types_infos : TA.type_infos) (sg : A.fun_sig)
- (input_names : string option list) (bid : T.RegionGroupId.id option) :
- fun_sig_named_outputs =
+let translate_fun_sig (config : config) (types_infos : TA.type_infos)
+ (sg : A.fun_sig) (input_names : string option list)
+ (bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
(* Retrieve the list of parent backward functions *)
let gid, parents =
match bid with
@@ -510,9 +530,25 @@ let translate_fun_sig (types_infos : TA.type_infos) (sg : A.fun_sig)
*)
List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
in
- let inputs = List.append fwd_inputs back_inputs in
+ (* Does the function take a state as input, does it return a state and can
+ * it fail? *)
+ (* For now, all the translated functions can fail *)
+ let can_fail = true in
+ (* For now, all translated functions have an input state if we setup
+ * the translation to use states *)
+ let input_state = config.use_state_monad in
+ (* Only the forward functions return a state *)
+ let output_state = input_state && bid = None in
+ (* *)
+ let state_ty = if input_state then [ mk_state_ty ] else [] in
+ (* Concatenate the inputs, in the following order:
+ * - forward inputs
+ * - state input
+ * - backward inputs
+ *)
+ let inputs = List.concat [ fwd_inputs; state_ty; back_inputs ] in
(* Outputs *)
- let output_names, outputs =
+ let output_names, doutputs =
match gid with
| None ->
(* This is a forward function: there is one (unnamed) output *)
@@ -542,12 +578,45 @@ let translate_fun_sig (types_infos : TA.type_infos) (sg : A.fun_sig)
in
List.split outputs
in
+ (* Create the return type *)
+ let output =
+ (* Group the outputs together *)
+ let output = mk_simpl_tuple_ty doutputs in
+ (* Add the output state *)
+ let output =
+ if output_state then mk_simpl_tuple_ty [ mk_state_ty; output ] else output
+ in
+ (* Wrap in a result type *)
+ if can_fail then mk_result_ty output else output
+ in
(* Type parameters *)
let type_params = sg.type_params in
(* Return *)
- let sg = { type_params; inputs; outputs } in
+ let info =
+ {
+ num_fwd_inputs = List.length fwd_inputs;
+ num_back_inputs =
+ (if bid = None then None else Some (List.length back_inputs));
+ input_state;
+ output_state;
+ can_fail;
+ }
+ in
+ let sg = { type_params; inputs; output; doutputs; info } in
{ sg; output_names }
+let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
+ (* Generate the fresh variable *)
+ let id, var_counter = VarId.fresh ctx.var_counter in
+ let var =
+ { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
+ in
+ let state_var = mk_typed_pattern_from_var var None in
+ (* Update the context *)
+ let ctx = { ctx with var_counter; state_var = id } in
+ (* Return *)
+ (ctx, state_var)
+
let fresh_named_var_for_symbolic_value (basename : string option)
(sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var =
(* Generate the fresh variable *)
@@ -966,40 +1035,35 @@ let abs_to_given_back_no_mp (abs : V.abs) (ctx : bs_ctx) :
Is used for instance when collecting the input values given to all the
parent functions, in order to properly instantiate an
*)
-let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =
+let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) :
+ S.call * (V.abs * texpression list) list =
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
let abs_ancestors = list_ancestor_abstractions ctx abs in
(call_info.forward, abs_ancestors)
-(** Small utility.
-
- Return: (function is monadic, if monadic, function uses state monad)
-
- Note that all functions are monadic except some assumed functions.
-
+type fun_effect_info = {
+ can_fail : bool;
+ input_state : bool;
+ output_state : bool;
+}
+(** TODO: factorize with fun_sig_info?
TODO: use an enumeration
*)
-let fun_is_monadic (fun_id : A.fun_id) : bool * bool =
- match fun_id with
- | A.Regular _ -> (true, true)
- | A.Assumed aid -> (Assumed.assumed_is_monadic aid, false)
-(** Utility for function return types.
-
- A function return type can have the shape:
- - ty
- - result ty (* error-monad *)
- - state -> result (state & ty) (* state-error monad *)
- *)
-let mk_function_ret_ty (config : config) (monadic : bool) (state_monad : bool)
- (out_ty : ty) : ty =
- if monadic then
- if config.use_state_monad && state_monad then
- let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
- let ret = mk_arrow mk_state_ty ret in
- ret
- else mk_result_ty out_ty
- else out_ty
+(** Small utility. *)
+let get_fun_effect_info (config : config) (fun_id : A.fun_id)
+ (gid : T.RegionGroupId.id option) : fun_effect_info =
+ match fun_id with
+ | A.Regular _ ->
+ let input_state = config.use_state_monad in
+ let output_state = input_state && gid = None in
+ { can_fail = true; input_state; output_state }
+ | A.Assumed aid ->
+ {
+ can_fail = Assumed.assumed_is_monadic aid;
+ input_state = false;
+ output_state = false;
+ }
let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
: texpression =
@@ -1089,27 +1153,58 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
(ctx : bs_ctx) : texpression =
(* Translate the function call *)
let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- let args = List.map (typed_value_to_texpression ctx) call.args in
- let args_mplaces = List.map translate_opt_mplace call.args_places in
+ let args =
+ let args = List.map (typed_value_to_texpression ctx) call.args in
+ let args_mplaces = List.map translate_opt_mplace call.args_places in
+ List.map
+ (fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
+ (List.combine args args_mplaces)
+ in
let dest_mplace = translate_opt_mplace call.dest_place in
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, fun_id, monadic, state_monad =
+ let ctx, fun_id, effect_info, args, out_state =
match call.call_id with
| S.Fun (fid, call_id) ->
- let ctx = bs_ctx_register_forward_call call_id call ctx in
+ (* Regular function call *)
let func = Regular (fid, None) in
- let monadic, state_monad = fun_is_monadic fid in
- (ctx, func, monadic, state_monad)
- | S.Unop E.Not -> (ctx, Unop Not, false, false)
+ (* Retrieve the effect information about this function (can fail,
+ * takes a state as input, etc.) *)
+ let effect_info = get_fun_effect_info config fid None in
+ (* Add the state input argument *)
+ let args =
+ if effect_info.input_state then
+ let state_var = { e = Var ctx.state_var; ty = mk_state_ty } in
+ List.append args [ state_var ]
+ else args
+ in
+ (* Generate a fresh state variable if the function call introduces
+ * a new variable *)
+ let ctx, out_state =
+ if effect_info.input_state then
+ let ctx, var = bs_ctx_fresh_state_var ctx in
+ (ctx, Some var)
+ else (ctx, None)
+ in
+ (* Register the function call *)
+ let ctx = bs_ctx_register_forward_call call_id call args ctx in
+ (ctx, func, effect_info, args, out_state)
+ | S.Unop E.Not ->
+ let effect_info =
+ { can_fail = false; input_state = false; output_state = false }
+ in
+ (ctx, Unop Not, effect_info, args, None)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
let int_ty = ty_as_integer arg.ty in
(* Note that negation can lead to an overflow and thus fail (it
* is thus monadic) *)
- (ctx, Unop (Neg int_ty), true, false)
+ let effect_info =
+ { can_fail = true; input_state = false; output_state = false }
+ in
+ (ctx, Unop (Neg int_ty), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
| S.Binop binop -> (
match args with
@@ -1117,26 +1212,34 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let int_ty0 = ty_as_integer arg0.ty in
let int_ty1 = ty_as_integer arg1.ty in
assert (int_ty0 = int_ty1);
- let monadic = binop_can_fail binop in
- (ctx, Binop (binop, int_ty0), monadic, false)
+ let effect_info =
+ {
+ can_fail = binop_can_fail binop;
+ input_state = false;
+ output_state = false;
+ }
+ in
+ (ctx, Binop (binop, int_ty0), effect_info, args, None)
| _ -> raise (Failure "Unreachable"))
in
- let args =
- List.map
- (fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
- (List.combine args args_mplaces)
+ let dest_v =
+ let dest = mk_typed_pattern_from_var dest dest_mplace in
+ match out_state with
+ | None -> dest
+ | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
in
- let dest_v = mk_typed_pattern_from_var dest dest_mplace in
let func = { id = Func fun_id; type_args } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = mk_function_ret_ty config monadic state_monad dest_v.ty in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty
+ in
let func_ty = mk_arrows input_tys ret_ty in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* Translate the next expression *)
let next_e = translate_expression config e ctx in
(* Put together *)
- mk_let monadic dest_v call next_e
+ mk_let effect_info.can_fail dest_v call next_e
and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(ctx : bs_ctx) : texpression =
@@ -1198,14 +1301,15 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
let call = call_info.forward in
let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- (* Retrive the orignal call and the parent abstractions *)
- let forward, backwards = get_abs_ancestors ctx abs in
+ (* Retrieve the original call and the parent abstractions *)
+ let _forward, backwards = get_abs_ancestors ctx abs in
(* Retrieve the values consumed when we called the forward function and
* ended the parent backward functions: those give us part of the input
- * values *)
- let fwd_inputs = List.map (typed_value_to_texpression ctx) forward.args in
+ * values (rmk: for now, as we disallow nested lifetimes, there can't be
+ * parent backward functions) *)
+ let fwd_inputs = call_info.forward_inputs in
let back_ancestors_inputs =
- List.concat (List.map (abs_to_consumed ctx) backwards)
+ List.concat (List.map (fun (_abs, args) -> args) backwards)
in
(* Retrieve the values consumed upon ending the loans inside this
* abstraction: those give us the remaining input values *)
@@ -1221,6 +1325,8 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
List.append (List.map translate_opt_mplace call.args_places) [ None ]
in
let ctx, outputs = abs_to_given_back output_mpl abs ctx in
+ (* Group the output values together (note that for now, backward functions
+ * never return an output state) *)
let output = mk_simpl_tuple_pattern outputs in
(* Sanity check: the inputs and outputs have the proper number and the proper type *)
let fun_id =
@@ -1242,13 +1348,13 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
("\n- outputs: "
^ string_of_int (List.length outputs)
^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.outputs)));
+ ^ string_of_int (List.length inst_sg.doutputs)));
List.iter
(fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.outputs);
+ (List.combine outputs inst_sg.doutputs);
(* Retrieve the function id, and register the function call in the context
* if necessary *)
- let ctx, func = bs_ctx_register_backward_call abs ctx in
+ let ctx, func = bs_ctx_register_backward_call abs back_inputs ctx in
(* Translate the next expression *)
let next_e = translate_expression config e ctx in
(* Put everything together *)
@@ -1258,9 +1364,11 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- let monadic, state_monad = fun_is_monadic fun_id in
+ let effect_info = get_fun_effect_info config fun_id (Some abs.back_id) in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = mk_function_ret_ty config monadic state_monad output.ty in
+ let ret_ty =
+ if effect_info.can_fail then mk_result_ty output.ty else output.ty
+ in
let func_ty = mk_arrows input_tys ret_ty in
let func = { id = Func func; type_args } in
let func = { e = Qualif func; ty = func_ty } in
@@ -1278,7 +1386,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
* a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
- else mk_let monadic output call next_e
+ else mk_let effect_info.can_fail output call next_e
| V.SynthRet ->
(* If we end the abstraction which consumed the return value of the function
* we are synthesizing, we get back the borrows which were inside. Those borrows
@@ -1568,6 +1676,19 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
let body = translate_expression config body ctx in
(* Sanity check *)
type_check_texpression ctx body;
+ (* Introduce the input state, if necessary *)
+ let effect_info = get_fun_effect_info config (Regular def_id) bid in
+ let input_state =
+ if effect_info.input_state then
+ [
+ {
+ id = ctx.state_var;
+ basename = Some ConstStrings.state_basename;
+ ty = mk_state_ty;
+ };
+ ]
+ else []
+ in
(* Compute the list of (properly ordered) input variables *)
let backward_inputs : var list =
match bid with
@@ -1582,7 +1703,9 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs)
backward_ids)
in
- let inputs = List.append ctx.forward_inputs backward_inputs in
+ let inputs =
+ List.concat [ ctx.forward_inputs; input_state; backward_inputs ]
+ in
let inputs_lvs =
List.map (fun v -> mk_typed_pattern_from_var v None) inputs
in
@@ -1593,38 +1716,6 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in
- (* Make the signature monadic *)
- let output_ty =
- match (bid, signature.outputs) with
- | None, [ out_ty ] ->
- (* Forward function: there is always exactly one output *)
- (* We don't do the same thing if we use a state error monad or not:
- * - error-monad: `result out_ty`
- * - state-error: `state -> result (state & out_ty)
- *)
- if config.use_state_monad then
- let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
- let ret = mk_arrow mk_state_ty ret in
- ret
- else (* Simply wrap the type in `result` *)
- mk_result_ty out_ty
- | Some _, outputs ->
- (* Backward function: we have to group the list of outputs into a tuple
- * (and similarly to the forward function, we don't do the same thing
- * if we use a state error monad or not):
- * - error-monad: `result (out_ty1 & .. out_tyn)`
- * - state-error: `state -> result (out_ty1 & .. out_tyn)`
- *)
- if config.use_state_monad then
- let ret = mk_simpl_tuple_ty outputs in
- let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in
- let ret = mk_arrow mk_state_ty ret in
- ret
- else mk_result_ty (mk_simpl_tuple_ty outputs)
- | _ -> raise (Failure "Unreachable")
- in
- let outputs = [ output_ty ] in
- let signature = { signature with outputs } in
(* Assemble the declaration *)
let def = { def_id; back_id = bid; basename; signature; body } in
(* Debugging *)
@@ -1650,7 +1741,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
- optional names for the outputs values (we derive them for the backward
functions)
*)
-let translate_fun_signatures (types_infos : TA.type_infos)
+let translate_fun_signatures (config : config) (types_infos : TA.type_infos)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdMap.t =
(* For every function, translate the signatures of:
@@ -1660,13 +1751,15 @@ let translate_fun_signatures (types_infos : TA.type_infos)
let translate_one (fun_id : A.fun_id) (input_names : string option list)
(sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list =
(* The forward function *)
- let fwd_sg = translate_fun_sig types_infos sg input_names None in
+ let fwd_sg = translate_fun_sig config types_infos sg input_names None in
let fwd_id = (fun_id, None) in
(* The backward functions *)
let back_sgs =
List.map
(fun (rg : T.region_var_group) ->
- let tsg = translate_fun_sig types_infos sg input_names (Some rg.id) in
+ let tsg =
+ translate_fun_sig config types_infos sg input_names (Some rg.id)
+ in
let id = (fun_id, Some rg.id) in
(id, tsg))
sg.regions_hierarchy