From 79b0bf1fdb0283c2bd9cbca91794105dda88f03b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 26 Apr 2022 19:34:12 +0200 Subject: Introduce the Abs expression and continue updating the code --- src/Collections.ml | 5 + src/Identifiers.ml | 2 +- src/InterpreterExpressions.ml | 2 +- src/OfJsonBasic.ml | 2 +- src/Print.ml | 4 +- src/PrintPure.ml | 14 ++- src/Pure.ml | 1 + src/PureMicroPasses.ml | 223 ++++++++++++++---------------------------- src/PureUtils.ml | 82 +++++++++++----- src/SymbolicToPure.ml | 108 +++++++++++++++++--- 10 files changed, 249 insertions(+), 194 deletions(-) (limited to 'src') diff --git a/src/Collections.ml b/src/Collections.ml index 2d7a8787..614857e6 100644 --- a/src/Collections.ml +++ b/src/Collections.ml @@ -77,6 +77,11 @@ module List = struct match ls with | [ x ] -> x | _ -> raise (Failure "The list should have length exactly one") + + let pop (ls : 'a list) : 'a * 'a list = + match ls with + | x :: ls' -> (x, ls') + | _ -> raise (Failure "The list should have length > 0") end module type OrderedType = sig diff --git a/src/Identifiers.ml b/src/Identifiers.ml index 64a8ec03..61238aac 100644 --- a/src/Identifiers.ml +++ b/src/Identifiers.ml @@ -99,7 +99,7 @@ module IdGen () : Id = struct (* Identifiers should never overflow (because max_int is a really big * value - but we really want to make sure we detect overflows if * they happen *) - if x == max_int then raise (Errors.IntegerOverflow ()) else x + 1 + if x = max_int then raise (Errors.IntegerOverflow ()) else x + 1 let generator_from_incr_id id = incr id diff --git a/src/InterpreterExpressions.ml b/src/InterpreterExpressions.ml index c967688f..f4d97b9d 100644 --- a/src/InterpreterExpressions.ml +++ b/src/InterpreterExpressions.ml @@ -111,7 +111,7 @@ let rec operand_constant_value_to_typed_value (ctx : C.eval_ctx) (ty : T.ety) | T.Str, ConstantValue (String v) -> { V.value = V.Concrete (String v); ty } | T.Integer int_ty, ConstantValue (V.Scalar v) -> (* Check the type and the ranges *) - assert (int_ty == v.int_ty); + assert (int_ty = v.int_ty); assert (check_scalar_value_in_range v); { V.value = V.Concrete (V.Scalar v); ty } (* Remaining cases (invalid) - listing as much as we can on purpose diff --git a/src/OfJsonBasic.ml b/src/OfJsonBasic.ml index 9dbd521d..07daf03d 100644 --- a/src/OfJsonBasic.ml +++ b/src/OfJsonBasic.ml @@ -26,7 +26,7 @@ let int_of_json (js : json) : (int, string) result = let char_of_json (js : json) : (char, string) result = match js with | `String c -> - if String.length c == 1 then Ok c.[0] + if String.length c = 1 then Ok c.[0] else Error ("char_of_json: stricly more than one character in: " ^ show js) | _ -> Error ("char_of_json: not a char: " ^ show js) diff --git a/src/Print.ml b/src/Print.ml index 841fa9b2..98e9dd74 100644 --- a/src/Print.ml +++ b/src/Print.ml @@ -815,8 +815,8 @@ module LlbcAst = struct | E.Deref -> "*(" ^ s ^ ")" | E.DerefBox -> "deref_box(" ^ s ^ ")" | E.Field (E.ProjOption variant_id, fid) -> - assert (variant_id == T.option_some_id); - assert (fid == T.FieldId.zero); + assert (variant_id = T.option_some_id); + assert (fid = T.FieldId.zero); "(" ^ s ^ " as Option::Some)." ^ T.FieldId.to_string fid | E.Field (E.ProjTuple _, fid) -> "(" ^ s ^ ")." ^ T.FieldId.to_string fid diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 158d4c3c..652e461d 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -209,8 +209,8 @@ let rec projection_to_string (fmt : ast_formatter) (inside : string) let s = projection_to_string fmt inside p' in match pe.pkind with | E.ProjOption variant_id -> - assert (variant_id == T.option_some_id); - assert (pe.field_id == T.FieldId.zero); + assert (variant_id = T.option_some_id); + assert (pe.field_id = T.FieldId.zero); "(" ^ s ^ "as Option::Some)." ^ T.FieldId.to_string pe.field_id | E.ProjTuple _ -> "(" ^ s ^ ")." ^ T.FieldId.to_string pe.field_id | E.ProjAdt (adt_id, opt_variant_id) -> ( @@ -442,6 +442,10 @@ let rec texpression_to_string (fmt : ast_formatter) (inner : bool) let app, args = destruct_apps e in (* Convert to string *) app_to_string fmt inner indent indent_incr app args + | Abs _ -> + let xl, e = destruct_abs_list e in + let e = abs_to_string fmt indent indent_incr xl e in + if inner then "(" ^ e ^ ")" else e | Func _ -> (* Func without arguments *) app_to_string fmt inner indent indent_incr e [] @@ -499,6 +503,12 @@ and app_to_string (fmt : ast_formatter) (inner : bool) (indent : string) (* Add parentheses *) if all_args <> [] && inner then "(" ^ e ^ ")" else e +and abs_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) + (xl : typed_lvalue list) (e : texpression) : string = + let xl = List.map (typed_lvalue_to_string fmt) xl in + let e = texpression_to_string fmt false indent indent_incr e in + "λ " ^ String.concat " " xl ^ ". " ^ e + and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) (monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) : string = diff --git a/src/Pure.ml b/src/Pure.ml index ebc92258..6729a43d 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -497,6 +497,7 @@ type expression = field accesses with calls to projectors over fields (when there are clashes of field names, some provers like F* get pretty bad...) *) + | Abs of typed_lvalue * texpression (** Lambda abstraction: `fun x -> e` *) | Func of func (** A function - TODO: change to Qualifier *) | Let of bool * typed_lvalue * texpression * texpression (** Let binding. diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 198a4d89..aba9610d 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -9,12 +9,6 @@ module V = Values let log = L.pure_micro_passes_log type config = { - use_state_monad : bool; - (** If `true`, use a state-error monad. - If `false`, only use an error monad. - - Using a state-error monad is necessary when modelling I/O, for instance. - *) decompose_monadic_let_bindings : bool; (** Some provers like F* don't support the decomposition of return values in monadic let-bindings: @@ -447,6 +441,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx, arg = update_texpression app ctx in let e = App (app, arg) in (ctx, e) + | Abs (x, e) -> update_abs x e ctx | Func _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx @@ -459,6 +454,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx = add_opt_right_constraint mp v ctx in (ctx, Value (v, mp)) (* *) + and update_abs (x : typed_lvalue) (e : texpression) (ctx : pn_ctx) : + pn_ctx * expression = + (* We first add the left-constraint *) + let ctx = add_left_constraint x ctx in + (* Update the expression, and add additional constraints *) + let ctx, e = update_texpression e ctx in + (* Update the abstracted value *) + let x = update_typed_lvalue ctx x in + (* Put together *) + (ctx, Abs (x, e)) + (* *) and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = (* We first add the left-constraint *) @@ -797,6 +803,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func) match opt_destruct_function_call e with | Some (func1, args1) -> check_call func1 args1 | None -> false) + | Abs (_, e) -> self#visit_texpression env e | Func _ -> fun () -> false | Meta (_, e) -> self#visit_texpression env e | Switch (_, body) -> self#visit_switch_body env body @@ -875,7 +882,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with - | Value (_, _) | App _ | Switch (_, _) | Meta (_, _) -> + | Value (_, _) | App _ | Func _ | Switch (_, _) | Meta (_, _) | Abs _ -> super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) @@ -984,86 +991,6 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = then false else true -(** Add unit arguments (optionally) to functions with no arguments, and - change their output type to use `result` - - TODO: remove this - *) -let to_monadic (config : config) (def : fun_decl) : fun_decl = - (* Update the body *) - let obj = - object - inherit [_] map_expression as super - - method! visit_call env call = - match call.func with - | Regular (A.Regular _, _) -> - if call.args = [] && config.add_unit_args then - let args = [ mk_value_expression unit_rvalue None ] in - { call with args } - else (* Otherwise: nothing to do *) super#visit_call env call - | Regular (A.Assumed _, _) | Unop _ | Binop _ -> - (* Unops, binops and primitive functions don't have unit arguments *) - super#visit_call env call - end - in - let def = - match def.body with - | None -> def - | Some body -> - let body = { body with body = obj#visit_texpression () body.body } in - { def with body = Some body } - in - - (* Update the signature: first the input types *) - let def = - if def.signature.inputs = [] && config.add_unit_args then - let signature = { def.signature with inputs = [ unit_ty ] } in - let body = - match def.body with - | None -> None - | Some body -> - let var_cnt = get_body_min_var_counter body in - let id, _ = VarId.fresh var_cnt in - let var = { id; basename = None; ty = unit_ty } in - let inputs = [ var ] in - let input_lv = mk_typed_lvalue_from_var var None in - let inputs_lvs = [ input_lv ] in - Some { body with inputs; inputs_lvs } - in - { def with signature; body } - else def - in - (* Then the output type *) - let output_ty = - match (def.back_id, def.signature.outputs) with - | None, [ out_ty ] -> - (* Forward function: there is always exactly one output *) - (* We don't do the same thing if we use a state error monad or not: - * - error-monad: `result out_ty` - * - state-error: `state -> result (state & out_ty) - *) - if config.use_state_monad then - let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in - let ret = mk_arrow_ty mk_state_ty ret in - ret - else (* Simply wrap the type in `result` *) - mk_result_ty out_ty - | Some _, outputs -> - (* Backward function: we have to group them *) - (* We don't do the same thing if we use a state error monad or not *) - if config.use_state_monad then - let ret = mk_simpl_tuple_ty outputs in - let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in - let ret = mk_arrow_ty mk_state_ty ret in - ret - else mk_result_ty (mk_simpl_tuple_ty outputs) - | _ -> failwith "Unreachable" - in - let outputs = [ output_ty ] in - let signature = { def.signature with outputs } in - { def with signature } - (** Convert the unit variables to `()` if they are used as right-values or `_` if they are used as left values in patterns. *) let unit_vars_to_unit (def : fun_decl) : fun_decl = @@ -1110,38 +1037,51 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = object inherit [_] map_expression as super - method! visit_App env call = - match call.func with - | Regular (A.Assumed aid, rg_id) -> ( - match (aid, rg_id) with - | A.BoxNew, _ -> - let arg = Collections.List.to_cons_nil call.args in - arg.e - | A.BoxDeref, None -> - (* `Box::deref` forward is the identity *) - let arg = Collections.List.to_cons_nil call.args in - arg.e - | A.BoxDeref, Some _ -> - (* `Box::deref` backward is `()` (doesn't give back anything) *) - (mk_value_expression unit_rvalue None).e - | A.BoxDerefMut, None -> - (* `Box::deref_mut` forward is the identity *) - let arg = Collections.List.to_cons_nil call.args in - arg.e - | A.BoxDerefMut, Some _ -> - (* `Box::deref_mut` back is the identity *) - let arg = - match call.args with - | [ _; given_back ] -> given_back - | _ -> failwith "Unreachable" - in - arg.e - | A.BoxFree, _ -> (mk_value_expression unit_rvalue None).e - | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen - | A.VecIndex | A.VecIndexMut ), - _ ) -> - super#visit_App env call) - | _ -> super#visit_App env call + method! visit_texpression env e = + match opt_destruct_function_call e with + | Some (func, args) -> ( + match func.func with + | Regular (A.Assumed aid, rg_id) -> ( + (* Below, when dealing with the arguments: we consider the very + * general case, where functions could be boxed (meaning we + * could have: `box_new f x`) + * *) + match (aid, rg_id) with + | A.BoxNew, _ -> + assert (rg_id = None); + let arg, args = Collections.List.pop args in + mk_apps arg args + | A.BoxDeref, None -> + (* `Box::deref` forward is the identity *) + let arg, args = Collections.List.pop args in + mk_apps arg args + | A.BoxDeref, Some _ -> + (* `Box::deref` backward is `()` (doesn't give back anything) *) + assert (args = []); + mk_value_expression unit_rvalue None + | A.BoxDerefMut, None -> + (* `Box::deref_mut` forward is the identity *) + let arg, args = Collections.List.pop args in + mk_apps arg args + | A.BoxDerefMut, Some _ -> + (* `Box::deref_mut` back is almost the identity: + * let box_deref_mut (x_init : t) (x_back : t) : t = x_back + * *) + let arg, args = + match args with + | _ :: given_back :: args -> (given_back, args) + | _ -> failwith "Unreachable" + in + mk_apps arg args + | A.BoxFree, _ -> + assert (args = []); + mk_value_expression unit_rvalue None + | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen + | A.VecIndex | A.VecIndexMut ), + _ ) -> + super#visit_texpression env e) + | _ -> super#visit_texpression env e) + | _ -> super#visit_texpression env e end in (* Update the body *) @@ -1221,6 +1161,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) inherit [_] map_expression as super method! visit_Let state_var monadic lv re e = + (* TODO: we should use a monad "kind" instead of a boolean *) if not monadic then super#visit_Let state_var monadic lv re e else (* We don't do the same thing if we use a state-error monad or simply @@ -1228,30 +1169,18 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) * Note that some functions always live in the error monad (arithmetic * operations, for instance). *) - let re_call = - match re.e with - | App call -> call - | _ -> raise (Failure "Unreachable: expected a function call") - in (* TODO: this information should be computed in SymbolicToPure and - * store in an enum ("monadic" should be an enum, not a bool). - * Also: everything will be cleaner once we update the AST to make - * it more idiomatic lambda calculus... *) - let re_call_can_use_state = - match re_call.func with - | Regular (A.Regular _, _) -> true - | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false + * store in an enum ("monadic" should be an enum, not a bool). *) + let re_uses_state = + Option.is_some (opt_destruct_state_monad_result re.ty) in - if config.use_state_monad && re_call_can_use_state then - let re_call = - let call = re_call in + if re_uses_state then + (* Add the state argument on the right-expression *) + let re = let state_value = mk_typed_rvalue_from_var state_var in - let args = - call.args @ [ mk_value_expression state_value None ] - in - App { call with args } + let state_value = mk_value_expression state_value None in + mk_app re state_value in - let re = { re with e = re_call } in (* Create the match *) let fail_pat = mk_result_fail_lvalue lv.ty in let fail_value = mk_result_fail_rvalue e.ty in @@ -1273,6 +1202,8 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) let e = Switch (re, switch_body) in self#visit_expression state_var e else + let re_ty = Option.get (opt_destruct_result re.ty) in + assert (lv.ty = re_ty); let fail_pat = mk_result_fail_lvalue lv.ty in let fail_value = mk_result_fail_rvalue e.ty in let fail_branch = @@ -1312,9 +1243,9 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) else super#visit_Value state_var rv mp (** We also need to update values, in case this value is `Return ...`. - TODO: this is super ugly... We need to use the monadic functions - `fail` and `return` instead. - *) + TODO: this is super ugly... We need to use the monadic functions + fail` and `return` instead. + *) end in (* Update the body *) @@ -1386,14 +1317,6 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : match def with | None -> None | Some def -> - (* Add unit arguments for functions with no arguments, and change their return type. - * **Rk.**: from now onwards, the types in the AST are correct (until now, - * functions had return type `t` where they should have return type `result t`). - * TODO: this is not true with the state-error monad, unless we unfold the monadic binds. - * Also, from now onwards, the outputs list has length 1. *) - let def = to_monadic config def in - log#ldebug (lazy ("to_monadic:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - (* Convert the unit variables to `()` if they are used as right-values or * `_` if they are used as left values. *) let def = unit_vars_to_unit def in diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 0045cc1d..bcf93c3c 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -266,7 +266,7 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool = *) let rec let_group_requires_parentheses (e : texpression) : bool = match e.e with - | Value _ | App _ | Func _ -> false + | Value _ | App _ | Func _ | Abs _ -> false | Let (monadic, _, _, next_e) -> if monadic then true else let_group_requires_parentheses next_e | Switch (_, _) -> false @@ -404,30 +404,68 @@ let destruct_apps (e : texpression) : texpression * texpression list = in aux [] e +(** Make an `App (app, arg)` expression *) +let mk_app (app : texpression) (arg : texpression) : texpression = + match app.ty with + | Arrow (ty0, ty1) -> + (* Sanity check *) + assert (ty0 = arg.ty); + let e = App (app, arg) in + let ty = ty1 in + { e; ty } + | _ -> raise (Failure "Expected an arrow type") + (** The reverse of [destruct_app] *) -let mk_apps (e : texpression) (args : texpression list) : texpression = - (* Reverse the arguments *) - let args = List.rev args in - (* Apply *) - let rec aux (e : texpression) (args : texpression list) : texpression = - match args with - | [] -> e - | arg :: args' -> ( - let e' = aux e args' in - match e'.ty with - | Arrow (ty0, ty1) -> - (* Sanity check *) - assert (ty0 == arg.ty); - let e'' = App (e', arg) in - let ty'' = ty1 in - { e = e''; ty = ty'' } - | _ -> raise (Failure "Expected an arrow type")) - in - aux e args +let mk_apps (app : texpression) (args : texpression list) : texpression = + List.fold_left (fun app arg -> mk_app app arg) app args -(* Destruct an expression into a function identifier and a list of arguments, - * if possible *) +(** Destruct an expression into a function identifier and a list of arguments, + * if possible *) let opt_destruct_function_call (e : texpression) : (func * texpression list) option = let app, args = destruct_apps e in match app.e with Func func -> Some (func, args) | _ -> None + +(** Destruct an expression into a function identifier and a list of arguments *) +let destruct_function_call (e : texpression) : func * texpression list = + Option.get (opt_destruct_function_call e) + +let opt_destruct_result (ty : ty) : ty option = + match ty with + | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys) + | _ -> None + +let opt_destruct_tuple (ty : ty) : ty list option = + match ty with Adt (Tuple, tys) -> Some tys | _ -> None + +let opt_destruct_state_monad_result (ty : ty) : ty option = + (* Checking: + * ty == state -> result (state & _) ? *) + match ty with + | Arrow (ty0, ty1) -> + (* ty == ty0 -> ty1 + * Checking: ty0 == state ? + * *) + if ty0 = mk_state_ty then + (* Checking: ty1 == result (state & _) *) + match opt_destruct_result ty1 with + | None -> None + | Some ty2 -> ( + (* Checking: ty2 == state & _ *) + match opt_destruct_tuple ty2 with + | Some [ ty3; ty4 ] -> if ty3 = mk_state_ty then Some ty4 else None + | _ -> None) + else None + | _ -> None + +let mk_abs (x : typed_lvalue) (e : texpression) : texpression = + let ty = Arrow (x.ty, e.ty) in + let e = Abs (x, e) in + { e; ty } + +let rec destruct_abs_list (e : texpression) : typed_lvalue list * texpression = + match e.e with + | Abs (x, e') -> + let xl, e'' = destruct_abs_list e' in + (x :: xl, e'') + | _ -> ([], e) diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 18e2b873..b25b7309 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -38,6 +38,12 @@ type config = { Note that we later filter the useless *forward* calls in the micro-passes, where it is more natural to do. *) + use_state_monad : bool; + (** If `true`, use a state-error monad. + If `false`, only use an error monad. + + Using a state-error monad is necessary when modelling I/O, for instance. + *) } type type_context = { @@ -950,6 +956,15 @@ let fun_is_monadic (fun_id : A.fun_id) : bool = | A.Regular _ -> true | A.Assumed aid -> Assumed.assumed_is_monadic aid +let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty = + if monadic then + if config.use_state_monad then + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else mk_result_ty out_ty + else out_ty + let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) : texpression = match e with @@ -967,7 +982,7 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx | Meta (meta, e) -> translate_meta config meta e ctx -and translate_return (_config : config) (opt_v : V.typed_value option) +and translate_return (config : config) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: - either we are translating a forward function, in which case the optional @@ -979,13 +994,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option) (* Forward function *) let v = Option.get opt_v in let v = typed_value_to_rvalue ctx v in - (* TODO: we need to use a `return` function (otherwise we have problems - * with the state-error monad). We also need to update the type when using - * a state-error monad. *) - let v = mk_result_return_rvalue v in - let e = Value (v, None) in - let ty = v.ty in - { e; ty } + (* We don't synthesize the same expression depending on the monad we use: + * - error-monad: Return x + * - state-error monad: fun state -> Return (state, x) + * *) + (* TODO: we should use a `return` function, it would be cleaner *) + if config.use_state_monad then + let _, state_var = fresh_var (Some "state") mk_state_ty ctx in + let state_rvalue = mk_typed_rvalue_from_var state_var in + let v = + mk_result_return_rvalue (mk_simpl_tuple_rvalue [ state_rvalue; v ]) + in + let e = Value (v, None) in + let ty = v.ty in + let e = { e; ty } in + let state_var = mk_typed_lvalue_from_var state_var None in + mk_abs state_var e + else + let v = mk_result_return_rvalue v in + let e = Value (v, None) in + let ty = v.ty in + { e; ty } | Some bid -> (* Backward function *) (* Sanity check *) @@ -997,11 +1026,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option) T.RegionGroupId.Map.find bid ctx.backward_outputs in let field_values = List.map mk_typed_rvalue_from_var backward_outputs in - let ret_value = mk_simpl_tuple_rvalue field_values in - let ret_value = mk_result_return_rvalue ret_value in - let e = Value (ret_value, None) in - let ty = ret_value.ty in - { e; ty } + (* See the comment about the monads, for the forward function case *) + (* TODO: we should use a `fail` function, it would be cleaner *) + if config.use_state_monad then + let _, state_var = fresh_var (Some "state") mk_state_ty ctx in + let state_rvalue = mk_typed_rvalue_from_var state_var in + let ret_value = mk_simpl_tuple_rvalue field_values in + let ret_value = + mk_result_return_rvalue + (mk_simpl_tuple_rvalue [ state_rvalue; ret_value ]) + in + let e = Value (ret_value, None) in + let ty = ret_value.ty in + let e = { e; ty } in + let state_var = mk_typed_lvalue_from_var state_var None in + mk_abs state_var e + else + let ret_value = mk_simpl_tuple_rvalue field_values in + let ret_value = mk_result_return_rvalue ret_value in + let e = Value (ret_value, None) in + let ty = ret_value.ty in + { e; ty } and translate_function_call (config : config) (call : S.call) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -1047,7 +1092,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression) let dest_v = mk_typed_lvalue_from_var dest dest_mplace in let func = { func; type_params } in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in + let ret_ty = mk_function_ret_ty config monadic dest_v.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { e = Func func; ty = func_ty } in let call = mk_apps func args in @@ -1180,7 +1225,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) in let monadic = fun_is_monadic fun_id in let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = if monadic then mk_result_ty output.ty else output.ty in + let ret_ty = mk_function_ret_ty config monadic output.ty in let func_ty = mk_arrows input_tys ret_ty in let func = { func; type_params } in let func = { e = Func func; ty = func_ty } in @@ -1470,6 +1515,7 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (* Translate the declaration *) let def_id = def.A.def_id in let basename = def.name in + (* Lookup the signature *) let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in (* Translate the body, if there is *) let body = @@ -1502,6 +1548,38 @@ let translate_fun_decl (config : config) (ctx : bs_ctx) (List.combine inputs signature.inputs)); Some { inputs; inputs_lvs; body } in + (* Make the signature monadic *) + let output_ty = + match (bid, signature.outputs) with + | None, [ out_ty ] -> + (* Forward function: there is always exactly one output *) + (* We don't do the same thing if we use a state error monad or not: + * - error-monad: `result out_ty` + * - state-error: `state -> result (state & out_ty) + *) + if config.use_state_monad then + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else (* Simply wrap the type in `result` *) + mk_result_ty out_ty + | Some _, outputs -> + (* Backward function: we have to group the list of outputs into a tuple + * (and similarly to the forward function, we don't do the same thing + * if we use a state error monad or not): + * - error-monad: `result (out_ty1 & .. out_tyn)` + * - state-error: `state -> result (out_ty1 & .. out_tyn)` + *) + if config.use_state_monad then + let ret = mk_simpl_tuple_ty outputs in + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else mk_result_ty (mk_simpl_tuple_ty outputs) + | _ -> failwith "Unreachable" + in + let outputs = [ output_ty ] in + let signature = { signature with outputs } in (* Assemble the declaration *) let def = { def_id; back_id = bid; basename; signature; body } in (* Debugging *) -- cgit v1.2.3