diff options
author | Son Ho | 2022-05-04 13:54:45 +0200 |
---|---|---|
committer | Son Ho | 2022-05-04 13:54:45 +0200 |
commit | 37f80fd592f703ab9b14a9d3d5d638b9c335997f (patch) | |
tree | 2ccb9ccc445e181f354b5cb3093c425f9f666560 | |
parent | 593ffae18cf647457121470c371ba9effbc55f5d (diff) |
Start updating the way the function return type (with errors and states)
are handled
Diffstat (limited to '')
-rw-r--r-- | src/PrintPure.ml | 23 | ||||
-rw-r--r-- | src/Pure.ml | 81 | ||||
-rw-r--r-- | src/PureUtils.ml | 41 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 315 |
4 files changed, 278 insertions, 182 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml index f21329ed..07144d3e 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -383,30 +383,15 @@ let fun_sig_to_string (fmt : ast_formatter) (sg : fun_sig) : string = let ty_fmt = ast_to_type_formatter fmt in let type_params = List.map type_var_to_string sg.type_params in let inputs = List.map (ty_to_string ty_fmt) sg.inputs in - let outputs = List.map (ty_to_string ty_fmt) sg.outputs in - let outputs = - match outputs with - | [] -> - (* Can happen with backward functions which don't give back - * anything (shared borrows only) *) - "()" - | [ out ] -> out - | outputs -> "(" ^ String.concat " * " outputs ^ ")" - in - let all_types = List.concat [ type_params; inputs; [ outputs ] ] in + let output = ty_to_string ty_fmt sg.output in + let all_types = List.concat [ type_params; inputs; [ output ] ] in String.concat " -> " all_types let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string = let ty_fmt = ast_to_type_formatter fmt in let inputs = List.map (ty_to_string ty_fmt) sg.inputs in - let outputs = List.map (ty_to_string ty_fmt) sg.outputs in - let outputs = - match outputs with - | [] -> "()" - | [ out ] -> out - | outputs -> "(" ^ String.concat " * " outputs ^ ")" - in - let all_types = List.append inputs [ outputs ] in + let output = ty_to_string ty_fmt sg.output in + let all_types = List.append inputs [ output ] in String.concat " -> " all_types let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : A.fun_id) : string diff --git a/src/Pure.ml b/src/Pure.ml index e2362338..d8e1cafc 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -505,28 +505,79 @@ and meta = nude = true (* Don't inherit [VisitorsRuntime.iter] *); }] +type fun_sig_info = { + num_fwd_inputs : int; + (** The number of input types for forward computation *) + num_back_inputs : int option; + (** The number of additional inputs for the backward computation (if pertinent) *) + input_state : bool; (** `true` if the function takes a state as input *) + output_state : bool; + (** `true` if the function outputs a state (it then lives + in a state monad) *) + can_fail : bool; (** `true` if the return type is a `result` *) +} +(** Meta information about a function signature *) + type fun_sig = { type_params : type_var list; inputs : ty list; - outputs : ty list; - (** The list of outputs. - - Immediately after the translation from symbolic to pure we have this - the following: - In case of a forward function, the list will have length = 1. - However, in case of backward function, the list may have length > 1. - If the length is > 1, it gets extracted to a tuple type. Followingly, - we could not use a list because we can encode tuples, but here we - want to account for the fact that we immediately deconstruct the tuple - upon calling the backward function (because the backward function is - called to update a set of values in the environment). + output : ty; + doutputs : ty list; + (** The "decomposed" list of outputs. + + In case of a forward function, the list has length = 1, for the + type of the returned value. + + In case of backward function, the list contains all the types of + all the given back values (there is at most one type per forward + input argument). + + Ex.: + ``` + fn choose<'a, T>(b : bool, x : &'a mut T, y : &'a mut T) -> &'a mut T; + ``` + Decomposed outputs: + - forward function: [T] + - backward function: [T; T] (for "x" and "y") - After the "to monadic" pass, the list has size exactly one (and we - use the `Result` type). *) + info : fun_sig_info; (** Additional information *) } +(** A function signature. + + We have the following cases: + - forward function: + `in_ty0 -> ... -> in_tyn -> out_ty` (* pure function *) + `in_ty0 -> ... -> in_tyn -> result out_ty` (* error-monad *) + `in_ty0 -> ... -> in_tyn -> state -> result (state & out_ty)` (* state-error *) + - backward function: + `in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm -> (back_out0 & ... & back_outp)` (* pure function *) + `in_ty0 -> ... -> in_tyn -> back_in0 -> ... back_inm -> + result (back_out0 & ... & back_outp)` (* error-monad *) + `in_ty0 -> ... -> in_tyn -> state -> back_in0 -> ... back_inm -> + result (back_out0 & ... & back_outp)` (* state-error *) + + Note that a backward function never returns (i.e., updates) a state: only + forward functions do so. Also, the state input parameter is *betwee* + the forward inputs and the backward inputs. + + The function's type should be given by `mk_arrows sig.inputs sig.output`. + We provide additional meta-information: + - we divide between forward inputs and backward inputs (i.e., inputs specific + to the forward functions, and additional inputs necessary if the signature is + for a backward function) + - we have booleans to give us the fact that the function takes a state as + input, or can fail, etc. without having to inspect the signature + - etc. + *) -type inst_fun_sig = { inputs : ty list; outputs : ty list } +type inst_fun_sig = { + inputs : ty list; + output : ty; + doutputs : ty list; + info : fun_sig_info; +} +(** An instantiated function signature. See [fun_sig] *) type fun_body = { inputs : var list; diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 8651679f..a1af3396 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -125,8 +125,10 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty) (sg : fun_sig) : inst_fun_sig = let subst = ty_substitute tsubst in let inputs = List.map subst sg.inputs in - let outputs = List.map subst sg.outputs in - { inputs; outputs } + let output = subst sg.output in + let doutputs = List.map subst sg.doutputs in + let info = sg.info in + { inputs; output; doutputs; info } (** Return true if a list of functions are *not* mutually recursive, false otherwise. This function is meant to be applied on a set of (forward, backwards) functions @@ -478,38 +480,3 @@ let opt_destruct_state_monad_result (ty : ty) : ty option = let opt_unmeta_mplace (e : texpression) : mplace option * texpression = match e.e with Meta (MPlace mp, e) -> (Some mp, e) | _ -> (None, e) - -(** Utility function, used for type checking - TODO: move *) -let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t) - (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list) : - ty list = - match type_id with - | Tuple -> - (* Tuple *) - assert (variant_id = None); - tys - | AdtId def_id -> - (* "Regular" ADT *) - let def = TypeDeclId.Map.find def_id type_decls in - type_decl_get_instantiated_fields_types def variant_id tys - | Assumed aty -> ( - (* Assumed type *) - match aty with - | State -> - (* `State` is opaque *) - raise (Failure "Unreachable: `State` values are opaque") - | Result -> - let ty = Collections.List.to_cons_nil tys in - let variant_id = Option.get variant_id in - if variant_id = result_return_id then [ ty ] - else if variant_id = result_fail_id then [] - else - raise (Failure "Unreachable: improper variant id for result type") - | Option -> - let ty = Collections.List.to_cons_nil tys in - let variant_id = Option.get variant_id in - if variant_id = option_some_id then [ ty ] - else if variant_id = option_none_id then [] - else - raise (Failure "Unreachable: improper variant id for result type") - | Vec -> raise (Failure "Unreachable: `Vector` values are opaque")) diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 4754d237..466e5562 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -78,12 +78,22 @@ type fun_context = { type call_info = { forward : S.call; - backwards : V.abs T.RegionGroupId.Map.t; - (** TODO: not sure we need this anymore *) + forward_inputs : texpression list; + (** Remember the list of inputs given to the forward function. + + Those inputs include the state input, if pertinent (in which case + it is the last input). + *) + backwards : (V.abs * texpression list) T.RegionGroupId.Map.t; + (** A map from region group id (i.e., backward function id) to + pairs (abstraction, additional arguments received by the backward function) + + TODO: remove? it is also in the bs_ctx ("abstractions" field) + *) } (** Whenever we translate a function call or an ended abstraction, we store the related information (this is useful when translating ended - children abstractions) + children abstractions). *) type bs_ctx = { @@ -96,8 +106,10 @@ type bs_ctx = { (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) we introduce a new variable (with a let-binding). - *) + *) var_counter : VarId.generator; + state_var : VarId.id; + (** The current state variable, in case we use a state *) forward_inputs : var list; (** The input parameters for the forward function *) backward_inputs : var list T.RegionGroupId.Map.t; @@ -106,8 +118,8 @@ type bs_ctx = { (** The variables that the backward functions will output *) calls : call_info V.FunCallId.Map.t; (** The function calls we encountered so far *) - abstractions : V.abs V.AbstractionId.Map.t; - (** The ended abstractions we encountered so far *) + abstractions : (V.abs * texpression list) V.AbstractionId.Map.t; + (** The ended abstractions we encountered so far, with their additional input arguments *) } (** Body synthesis context *) @@ -197,26 +209,33 @@ let bs_ctx_lookup_local_function_sig (def_id : FunDeclId.id) (RegularFunIdMap.find id ctx.fun_context.fun_sigs).sg let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (ctx : bs_ctx) : bs_ctx = + (args : texpression list) (ctx : bs_ctx) : bs_ctx = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = { forward; backwards = T.RegionGroupId.Map.empty } in + let info = + { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty } + in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } -let bs_ctx_register_backward_call (abs : V.abs) (ctx : bs_ctx) : bs_ctx * fun_id - = +(** [back_args]: the *additional* list of inputs received by the backward function *) +let bs_ctx_register_backward_call (abs : V.abs) (back_args : texpression list) + (ctx : bs_ctx) : bs_ctx * fun_id = (* Insert the abstraction in the call informations *) let back_id = abs.back_id in let info = V.FunCallId.Map.find abs.call_id ctx.calls in assert (not (T.RegionGroupId.Map.mem back_id info.backwards)); - let backwards = T.RegionGroupId.Map.add back_id abs info.backwards in + let backwards = + T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards + in let info = { info with backwards } in let calls = V.FunCallId.Map.add abs.call_id info ctx.calls in (* Insert the abstraction in the abstractions map *) let abstractions = ctx.abstractions in assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions)); - let abstractions = V.AbstractionId.Map.add abs.abs_id abs abstractions in + let abstractions = + V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions + in (* Retrieve the fun_id *) let fun_id = match info.forward.call_id with @@ -438,7 +457,7 @@ let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) : if V.AbstractionId.Set.mem abs_id !abs_set then () else ( abs_set := V.AbstractionId.Set.add abs_id !abs_set; - let abs = V.AbstractionId.Map.find abs_id ctx.abstractions in + let abs, _ = V.AbstractionId.Map.find abs_id ctx.abstractions in List.iter gather abs.original_parents) in List.iter gather abs.original_parents; @@ -449,7 +468,8 @@ let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) : (fun id -> V.AbstractionId.Set.mem id ids) call_info.forward.abstractions -let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : V.abs list = +let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : + (V.abs * texpression list) list = let abs_ids = list_ancestor_abstractions_ids ctx abs in List.map (fun id -> V.AbstractionId.Map.find id ctx.abstractions) abs_ids @@ -460,9 +480,9 @@ let list_ancestor_abstractions (ctx : bs_ctx) (abs : V.abs) : V.abs list = name (outputs for backward functions come from borrows in the inputs of the forward function). *) -let translate_fun_sig (types_infos : TA.type_infos) (sg : A.fun_sig) - (input_names : string option list) (bid : T.RegionGroupId.id option) : - fun_sig_named_outputs = +let translate_fun_sig (config : config) (types_infos : TA.type_infos) + (sg : A.fun_sig) (input_names : string option list) + (bid : T.RegionGroupId.id option) : fun_sig_named_outputs = (* Retrieve the list of parent backward functions *) let gid, parents = match bid with @@ -510,9 +530,25 @@ let translate_fun_sig (types_infos : TA.type_infos) (sg : A.fun_sig) *) List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] in - let inputs = List.append fwd_inputs back_inputs in + (* Does the function take a state as input, does it return a state and can + * it fail? *) + (* For now, all the translated functions can fail *) + let can_fail = true in + (* For now, all translated functions have an input state if we setup + * the translation to use states *) + let input_state = config.use_state_monad in + (* Only the forward functions return a state *) + let output_state = input_state && bid = None in + (* *) + let state_ty = if input_state then [ mk_state_ty ] else [] in + (* Concatenate the inputs, in the following order: + * - forward inputs + * - state input + * - backward inputs + *) + let inputs = List.concat [ fwd_inputs; state_ty; back_inputs ] in (* Outputs *) - let output_names, outputs = + let output_names, doutputs = match gid with | None -> (* This is a forward function: there is one (unnamed) output *) @@ -542,12 +578,45 @@ let translate_fun_sig (types_infos : TA.type_infos) (sg : A.fun_sig) in List.split outputs in + (* Create the return type *) + let output = + (* Group the outputs together *) + let output = mk_simpl_tuple_ty doutputs in + (* Add the output state *) + let output = + if output_state then mk_simpl_tuple_ty [ mk_state_ty; output ] else output + in + (* Wrap in a result type *) + if can_fail then mk_result_ty output else output + in (* Type parameters *) let type_params = sg.type_params in (* Return *) - let sg = { type_params; inputs; outputs } in + let info = + { + num_fwd_inputs = List.length fwd_inputs; + num_back_inputs = + (if bid = None then None else Some (List.length back_inputs)); + input_state; + output_state; + can_fail; + } + in + let sg = { type_params; inputs; output; doutputs; info } in { sg; output_names } +let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern = + (* Generate the fresh variable *) + let id, var_counter = VarId.fresh ctx.var_counter in + let var = + { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty } + in + let state_var = mk_typed_pattern_from_var var None in + (* Update the context *) + let ctx = { ctx with var_counter; state_var = id } in + (* Return *) + (ctx, state_var) + let fresh_named_var_for_symbolic_value (basename : string option) (sv : V.symbolic_value) (ctx : bs_ctx) : bs_ctx * var = (* Generate the fresh variable *) @@ -966,40 +1035,35 @@ let abs_to_given_back_no_mp (abs : V.abs) (ctx : bs_ctx) : Is used for instance when collecting the input values given to all the parent functions, in order to properly instantiate an *) -let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = +let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : + S.call * (V.abs * texpression list) list = let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in let abs_ancestors = list_ancestor_abstractions ctx abs in (call_info.forward, abs_ancestors) -(** Small utility. - - Return: (function is monadic, if monadic, function uses state monad) - - Note that all functions are monadic except some assumed functions. - +type fun_effect_info = { + can_fail : bool; + input_state : bool; + output_state : bool; +} +(** TODO: factorize with fun_sig_info? TODO: use an enumeration *) -let fun_is_monadic (fun_id : A.fun_id) : bool * bool = - match fun_id with - | A.Regular _ -> (true, true) - | A.Assumed aid -> (Assumed.assumed_is_monadic aid, false) -(** Utility for function return types. - - A function return type can have the shape: - - ty - - result ty (* error-monad *) - - state -> result (state & ty) (* state-error monad *) - *) -let mk_function_ret_ty (config : config) (monadic : bool) (state_monad : bool) - (out_ty : ty) : ty = - if monadic then - if config.use_state_monad && state_monad then - let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in - let ret = mk_arrow mk_state_ty ret in - ret - else mk_result_ty out_ty - else out_ty +(** Small utility. *) +let get_fun_effect_info (config : config) (fun_id : A.fun_id) + (gid : T.RegionGroupId.id option) : fun_effect_info = + match fun_id with + | A.Regular _ -> + let input_state = config.use_state_monad in + let output_state = input_state && gid = None in + { can_fail = true; input_state; output_state } + | A.Assumed aid -> + { + can_fail = Assumed.assumed_is_monadic aid; + input_state = false; + output_state = false; + } let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1089,27 +1153,58 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = (* Translate the function call *) let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in - let args = List.map (typed_value_to_texpression ctx) call.args in - let args_mplaces = List.map translate_opt_mplace call.args_places in + let args = + let args = List.map (typed_value_to_texpression ctx) call.args in + let args_mplaces = List.map translate_opt_mplace call.args_places in + List.map + (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) + (List.combine args args_mplaces) + in let dest_mplace = translate_opt_mplace call.dest_place in let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, monadic, state_monad = + let ctx, fun_id, effect_info, args, out_state = match call.call_id with | S.Fun (fid, call_id) -> - let ctx = bs_ctx_register_forward_call call_id call ctx in + (* Regular function call *) let func = Regular (fid, None) in - let monadic, state_monad = fun_is_monadic fid in - (ctx, func, monadic, state_monad) - | S.Unop E.Not -> (ctx, Unop Not, false, false) + (* Retrieve the effect information about this function (can fail, + * takes a state as input, etc.) *) + let effect_info = get_fun_effect_info config fid None in + (* Add the state input argument *) + let args = + if effect_info.input_state then + let state_var = { e = Var ctx.state_var; ty = mk_state_ty } in + List.append args [ state_var ] + else args + in + (* Generate a fresh state variable if the function call introduces + * a new variable *) + let ctx, out_state = + if effect_info.input_state then + let ctx, var = bs_ctx_fresh_state_var ctx in + (ctx, Some var) + else (ctx, None) + in + (* Register the function call *) + let ctx = bs_ctx_register_forward_call call_id call args ctx in + (ctx, func, effect_info, args, out_state) + | S.Unop E.Not -> + let effect_info = + { can_fail = false; input_state = false; output_state = false } + in + (ctx, Unop Not, effect_info, args, None) | S.Unop E.Neg -> ( match args with | [ arg ] -> let int_ty = ty_as_integer arg.ty in (* Note that negation can lead to an overflow and thus fail (it * is thus monadic) *) - (ctx, Unop (Neg int_ty), true, false) + let effect_info = + { can_fail = true; input_state = false; output_state = false } + in + (ctx, Unop (Neg int_ty), effect_info, args, None) | _ -> raise (Failure "Unreachable")) | S.Binop binop -> ( match args with @@ -1117,26 +1212,34 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let int_ty0 = ty_as_integer arg0.ty in let int_ty1 = ty_as_integer arg1.ty in assert (int_ty0 = int_ty1); - let monadic = binop_can_fail binop in - (ctx, Binop (binop, int_ty0), monadic, false) + let effect_info = + { + can_fail = binop_can_fail binop; + input_state = false; + output_state = false; + } + in + (ctx, Binop (binop, int_ty0), effect_info, args, None) | _ -> raise (Failure "Unreachable")) in - let args = - List.map - (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) - (List.combine args args_mplaces) + let dest_v = + let dest = mk_typed_pattern_from_var dest dest_mplace in + match out_state with + | None -> dest + | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] in - let dest_v = mk_typed_pattern_from_var dest dest_mplace in let func = { id = Func fun_id; type_args } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = mk_function_ret_ty config monadic state_monad dest_v.ty in + let ret_ty = + if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty + in let func_ty = mk_arrows input_tys ret_ty in let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* Translate the next expression *) let next_e = translate_expression config e ctx in (* Put together *) - mk_let monadic dest_v call next_e + mk_let effect_info.can_fail dest_v call next_e and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1198,14 +1301,15 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in let call = call_info.forward in let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in - (* Retrive the orignal call and the parent abstractions *) - let forward, backwards = get_abs_ancestors ctx abs in + (* Retrieve the original call and the parent abstractions *) + let _forward, backwards = get_abs_ancestors ctx abs in (* Retrieve the values consumed when we called the forward function and * ended the parent backward functions: those give us part of the input - * values *) - let fwd_inputs = List.map (typed_value_to_texpression ctx) forward.args in + * values (rmk: for now, as we disallow nested lifetimes, there can't be + * parent backward functions) *) + let fwd_inputs = call_info.forward_inputs in let back_ancestors_inputs = - List.concat (List.map (abs_to_consumed ctx) backwards) + List.concat (List.map (fun (_abs, args) -> args) backwards) in (* Retrieve the values consumed upon ending the loans inside this * abstraction: those give us the remaining input values *) @@ -1221,6 +1325,8 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) List.append (List.map translate_opt_mplace call.args_places) [ None ] in let ctx, outputs = abs_to_given_back output_mpl abs ctx in + (* Group the output values together (note that for now, backward functions + * never return an output state) *) let output = mk_simpl_tuple_pattern outputs in (* Sanity check: the inputs and outputs have the proper number and the proper type *) let fun_id = @@ -1242,13 +1348,13 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) ("\n- outputs: " ^ string_of_int (List.length outputs) ^ "\n- expected outputs: " - ^ string_of_int (List.length inst_sg.outputs))); + ^ string_of_int (List.length inst_sg.doutputs))); List.iter (fun (x, ty) -> assert ((x : typed_pattern).ty = ty)) - (List.combine outputs inst_sg.outputs); + (List.combine outputs inst_sg.doutputs); (* Retrieve the function id, and register the function call in the context * if necessary *) - let ctx, func = bs_ctx_register_backward_call abs ctx in + let ctx, func = bs_ctx_register_backward_call abs back_inputs ctx in (* Translate the next expression *) let next_e = translate_expression config e ctx in (* Put everything together *) @@ -1258,9 +1364,11 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let monadic, state_monad = fun_is_monadic fun_id in + let effect_info = get_fun_effect_info config fun_id (Some abs.back_id) in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = mk_function_ret_ty config monadic state_monad output.ty in + let ret_ty = + if effect_info.can_fail then mk_result_ty output.ty else output.ty + in let func_ty = mk_arrows input_tys ret_ty in let func = { id = Func func; type_args } in let func = { e = Qualif func; ty = func_ty } in @@ -1278,7 +1386,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) * a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) - else mk_let monadic output call next_e + else mk_let effect_info.can_fail output call next_e | V.SynthRet -> (* If we end the abstraction which consumed the return value of the function * we are synthesizing, we get back the borrows which were inside. Those borrows @@ -1568,6 +1676,19 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) let body = translate_expression config body ctx in (* Sanity check *) type_check_texpression ctx body; + (* Introduce the input state, if necessary *) + let effect_info = get_fun_effect_info config (Regular def_id) bid in + let input_state = + if effect_info.input_state then + [ + { + id = ctx.state_var; + basename = Some ConstStrings.state_basename; + ty = mk_state_ty; + }; + ] + else [] + in (* Compute the list of (properly ordered) input variables *) let backward_inputs : var list = match bid with @@ -1582,7 +1703,9 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (fun id -> T.RegionGroupId.Map.find id ctx.backward_inputs) backward_ids) in - let inputs = List.append ctx.forward_inputs backward_inputs in + let inputs = + List.concat [ ctx.forward_inputs; input_state; backward_inputs ] + in let inputs_lvs = List.map (fun v -> mk_typed_pattern_from_var v None) inputs in @@ -1593,38 +1716,6 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (List.combine inputs signature.inputs)); Some { inputs; inputs_lvs; body } in - (* Make the signature monadic *) - let output_ty = - match (bid, signature.outputs) with - | None, [ out_ty ] -> - (* Forward function: there is always exactly one output *) - (* We don't do the same thing if we use a state error monad or not: - * - error-monad: `result out_ty` - * - state-error: `state -> result (state & out_ty) - *) - if config.use_state_monad then - let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in - let ret = mk_arrow mk_state_ty ret in - ret - else (* Simply wrap the type in `result` *) - mk_result_ty out_ty - | Some _, outputs -> - (* Backward function: we have to group the list of outputs into a tuple - * (and similarly to the forward function, we don't do the same thing - * if we use a state error monad or not): - * - error-monad: `result (out_ty1 & .. out_tyn)` - * - state-error: `state -> result (out_ty1 & .. out_tyn)` - *) - if config.use_state_monad then - let ret = mk_simpl_tuple_ty outputs in - let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in - let ret = mk_arrow mk_state_ty ret in - ret - else mk_result_ty (mk_simpl_tuple_ty outputs) - | _ -> raise (Failure "Unreachable") - in - let outputs = [ output_ty ] in - let signature = { signature with outputs } in (* Assemble the declaration *) let def = { def_id; back_id = bid; basename; signature; body } in (* Debugging *) @@ -1650,7 +1741,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list = - optional names for the outputs values (we derive them for the backward functions) *) -let translate_fun_signatures (types_infos : TA.type_infos) +let translate_fun_signatures (config : config) (types_infos : TA.type_infos) (functions : (A.fun_id * string option list * A.fun_sig) list) : fun_sig_named_outputs RegularFunIdMap.t = (* For every function, translate the signatures of: @@ -1660,13 +1751,15 @@ let translate_fun_signatures (types_infos : TA.type_infos) let translate_one (fun_id : A.fun_id) (input_names : string option list) (sg : A.fun_sig) : (regular_fun_id * fun_sig_named_outputs) list = (* The forward function *) - let fwd_sg = translate_fun_sig types_infos sg input_names None in + let fwd_sg = translate_fun_sig config types_infos sg input_names None in let fwd_id = (fun_id, None) in (* The backward functions *) let back_sgs = List.map (fun (rg : T.region_var_group) -> - let tsg = translate_fun_sig types_infos sg input_names (Some rg.id) in + let tsg = + translate_fun_sig config types_infos sg input_names (Some rg.id) + in let id = (fun_id, Some rg.id) in (id, tsg)) sg.regions_hierarchy |