summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-04-26 19:34:12 +0200
committerSon Ho2022-04-26 19:34:12 +0200
commit79b0bf1fdb0283c2bd9cbca91794105dda88f03b (patch)
tree4a51f6cf3fcb208f074a306b80cb1a89ddbbbd63 /src
parent732e3305cba3a628d9408a048978151e4ef2fcc2 (diff)
Introduce the Abs expression and continue updating the code
Diffstat (limited to '')
-rw-r--r--src/Collections.ml5
-rw-r--r--src/Identifiers.ml2
-rw-r--r--src/InterpreterExpressions.ml2
-rw-r--r--src/OfJsonBasic.ml2
-rw-r--r--src/Print.ml4
-rw-r--r--src/PrintPure.ml14
-rw-r--r--src/Pure.ml1
-rw-r--r--src/PureMicroPasses.ml223
-rw-r--r--src/PureUtils.ml82
-rw-r--r--src/SymbolicToPure.ml108
10 files changed, 249 insertions, 194 deletions
diff --git a/src/Collections.ml b/src/Collections.ml
index 2d7a8787..614857e6 100644
--- a/src/Collections.ml
+++ b/src/Collections.ml
@@ -77,6 +77,11 @@ module List = struct
match ls with
| [ x ] -> x
| _ -> raise (Failure "The list should have length exactly one")
+
+ let pop (ls : 'a list) : 'a * 'a list =
+ match ls with
+ | x :: ls' -> (x, ls')
+ | _ -> raise (Failure "The list should have length > 0")
end
module type OrderedType = sig
diff --git a/src/Identifiers.ml b/src/Identifiers.ml
index 64a8ec03..61238aac 100644
--- a/src/Identifiers.ml
+++ b/src/Identifiers.ml
@@ -99,7 +99,7 @@ module IdGen () : Id = struct
(* Identifiers should never overflow (because max_int is a really big
* value - but we really want to make sure we detect overflows if
* they happen *)
- if x == max_int then raise (Errors.IntegerOverflow ()) else x + 1
+ if x = max_int then raise (Errors.IntegerOverflow ()) else x + 1
let generator_from_incr_id id = incr id
diff --git a/src/InterpreterExpressions.ml b/src/InterpreterExpressions.ml
index c967688f..f4d97b9d 100644
--- a/src/InterpreterExpressions.ml
+++ b/src/InterpreterExpressions.ml
@@ -111,7 +111,7 @@ let rec operand_constant_value_to_typed_value (ctx : C.eval_ctx) (ty : T.ety)
| T.Str, ConstantValue (String v) -> { V.value = V.Concrete (String v); ty }
| T.Integer int_ty, ConstantValue (V.Scalar v) ->
(* Check the type and the ranges *)
- assert (int_ty == v.int_ty);
+ assert (int_ty = v.int_ty);
assert (check_scalar_value_in_range v);
{ V.value = V.Concrete (V.Scalar v); ty }
(* Remaining cases (invalid) - listing as much as we can on purpose
diff --git a/src/OfJsonBasic.ml b/src/OfJsonBasic.ml
index 9dbd521d..07daf03d 100644
--- a/src/OfJsonBasic.ml
+++ b/src/OfJsonBasic.ml
@@ -26,7 +26,7 @@ let int_of_json (js : json) : (int, string) result =
let char_of_json (js : json) : (char, string) result =
match js with
| `String c ->
- if String.length c == 1 then Ok c.[0]
+ if String.length c = 1 then Ok c.[0]
else Error ("char_of_json: stricly more than one character in: " ^ show js)
| _ -> Error ("char_of_json: not a char: " ^ show js)
diff --git a/src/Print.ml b/src/Print.ml
index 841fa9b2..98e9dd74 100644
--- a/src/Print.ml
+++ b/src/Print.ml
@@ -815,8 +815,8 @@ module LlbcAst = struct
| E.Deref -> "*(" ^ s ^ ")"
| E.DerefBox -> "deref_box(" ^ s ^ ")"
| E.Field (E.ProjOption variant_id, fid) ->
- assert (variant_id == T.option_some_id);
- assert (fid == T.FieldId.zero);
+ assert (variant_id = T.option_some_id);
+ assert (fid = T.FieldId.zero);
"(" ^ s ^ " as Option::Some)." ^ T.FieldId.to_string fid
| E.Field (E.ProjTuple _, fid) ->
"(" ^ s ^ ")." ^ T.FieldId.to_string fid
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index 158d4c3c..652e461d 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -209,8 +209,8 @@ let rec projection_to_string (fmt : ast_formatter) (inside : string)
let s = projection_to_string fmt inside p' in
match pe.pkind with
| E.ProjOption variant_id ->
- assert (variant_id == T.option_some_id);
- assert (pe.field_id == T.FieldId.zero);
+ assert (variant_id = T.option_some_id);
+ assert (pe.field_id = T.FieldId.zero);
"(" ^ s ^ "as Option::Some)." ^ T.FieldId.to_string pe.field_id
| E.ProjTuple _ -> "(" ^ s ^ ")." ^ T.FieldId.to_string pe.field_id
| E.ProjAdt (adt_id, opt_variant_id) -> (
@@ -442,6 +442,10 @@ let rec texpression_to_string (fmt : ast_formatter) (inner : bool)
let app, args = destruct_apps e in
(* Convert to string *)
app_to_string fmt inner indent indent_incr app args
+ | Abs _ ->
+ let xl, e = destruct_abs_list e in
+ let e = abs_to_string fmt indent indent_incr xl e in
+ if inner then "(" ^ e ^ ")" else e
| Func _ ->
(* Func without arguments *)
app_to_string fmt inner indent indent_incr e []
@@ -499,6 +503,12 @@ and app_to_string (fmt : ast_formatter) (inner : bool) (indent : string)
(* Add parentheses *)
if all_args <> [] && inner then "(" ^ e ^ ")" else e
+and abs_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string)
+ (xl : typed_lvalue list) (e : texpression) : string =
+ let xl = List.map (typed_lvalue_to_string fmt) xl in
+ let e = texpression_to_string fmt false indent indent_incr e in
+ "λ " ^ String.concat " " xl ^ ". " ^ e
+
and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string)
(monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) :
string =
diff --git a/src/Pure.ml b/src/Pure.ml
index ebc92258..6729a43d 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -497,6 +497,7 @@ type expression =
field accesses with calls to projectors over fields (when there
are clashes of field names, some provers like F* get pretty bad...)
*)
+ | Abs of typed_lvalue * texpression (** Lambda abstraction: `fun x -> e` *)
| Func of func (** A function - TODO: change to Qualifier *)
| Let of bool * typed_lvalue * texpression * texpression
(** Let binding.
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 198a4d89..aba9610d 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -9,12 +9,6 @@ module V = Values
let log = L.pure_micro_passes_log
type config = {
- use_state_monad : bool;
- (** If `true`, use a state-error monad.
- If `false`, only use an error monad.
-
- Using a state-error monad is necessary when modelling I/O, for instance.
- *)
decompose_monadic_let_bindings : bool;
(** Some provers like F* don't support the decomposition of return values
in monadic let-bindings:
@@ -447,6 +441,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ctx, arg = update_texpression app ctx in
let e = App (app, arg) in
(ctx, e)
+ | Abs (x, e) -> update_abs x e ctx
| Func _ -> (* nothing to do *) (ctx, e.e)
| Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
| Switch (scrut, body) -> update_switch_body scrut body ctx
@@ -459,6 +454,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ctx = add_opt_right_constraint mp v ctx in
(ctx, Value (v, mp))
(* *)
+ and update_abs (x : typed_lvalue) (e : texpression) (ctx : pn_ctx) :
+ pn_ctx * expression =
+ (* We first add the left-constraint *)
+ let ctx = add_left_constraint x ctx in
+ (* Update the expression, and add additional constraints *)
+ let ctx, e = update_texpression e ctx in
+ (* Update the abstracted value *)
+ let x = update_typed_lvalue ctx x in
+ (* Put together *)
+ (ctx, Abs (x, e))
+ (* *)
and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression)
(e : texpression) (ctx : pn_ctx) : pn_ctx * expression =
(* We first add the left-constraint *)
@@ -797,6 +803,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func)
match opt_destruct_function_call e with
| Some (func1, args1) -> check_call func1 args1
| None -> false)
+ | Abs (_, e) -> self#visit_texpression env e
| Func _ -> fun () -> false
| Meta (_, e) -> self#visit_texpression env e
| Switch (_, body) -> self#visit_switch_body env body
@@ -875,7 +882,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
method! visit_expression env e =
match e with
- | Value (_, _) | App _ | Switch (_, _) | Meta (_, _) ->
+ | Value (_, _) | App _ | Func _ | Switch (_, _) | Meta (_, _) | Abs _ ->
super#visit_expression env e
| Let (monadic, lv, re, e) ->
(* Compute the set of values used in the next expression *)
@@ -984,86 +991,6 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool =
then false
else true
-(** Add unit arguments (optionally) to functions with no arguments, and
- change their output type to use `result`
-
- TODO: remove this
- *)
-let to_monadic (config : config) (def : fun_decl) : fun_decl =
- (* Update the body *)
- let obj =
- object
- inherit [_] map_expression as super
-
- method! visit_call env call =
- match call.func with
- | Regular (A.Regular _, _) ->
- if call.args = [] && config.add_unit_args then
- let args = [ mk_value_expression unit_rvalue None ] in
- { call with args }
- else (* Otherwise: nothing to do *) super#visit_call env call
- | Regular (A.Assumed _, _) | Unop _ | Binop _ ->
- (* Unops, binops and primitive functions don't have unit arguments *)
- super#visit_call env call
- end
- in
- let def =
- match def.body with
- | None -> def
- | Some body ->
- let body = { body with body = obj#visit_texpression () body.body } in
- { def with body = Some body }
- in
-
- (* Update the signature: first the input types *)
- let def =
- if def.signature.inputs = [] && config.add_unit_args then
- let signature = { def.signature with inputs = [ unit_ty ] } in
- let body =
- match def.body with
- | None -> None
- | Some body ->
- let var_cnt = get_body_min_var_counter body in
- let id, _ = VarId.fresh var_cnt in
- let var = { id; basename = None; ty = unit_ty } in
- let inputs = [ var ] in
- let input_lv = mk_typed_lvalue_from_var var None in
- let inputs_lvs = [ input_lv ] in
- Some { body with inputs; inputs_lvs }
- in
- { def with signature; body }
- else def
- in
- (* Then the output type *)
- let output_ty =
- match (def.back_id, def.signature.outputs) with
- | None, [ out_ty ] ->
- (* Forward function: there is always exactly one output *)
- (* We don't do the same thing if we use a state error monad or not:
- * - error-monad: `result out_ty`
- * - state-error: `state -> result (state & out_ty)
- *)
- if config.use_state_monad then
- let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
- let ret = mk_arrow_ty mk_state_ty ret in
- ret
- else (* Simply wrap the type in `result` *)
- mk_result_ty out_ty
- | Some _, outputs ->
- (* Backward function: we have to group them *)
- (* We don't do the same thing if we use a state error monad or not *)
- if config.use_state_monad then
- let ret = mk_simpl_tuple_ty outputs in
- let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in
- let ret = mk_arrow_ty mk_state_ty ret in
- ret
- else mk_result_ty (mk_simpl_tuple_ty outputs)
- | _ -> failwith "Unreachable"
- in
- let outputs = [ output_ty ] in
- let signature = { def.signature with outputs } in
- { def with signature }
-
(** Convert the unit variables to `()` if they are used as right-values or
`_` if they are used as left values in patterns. *)
let unit_vars_to_unit (def : fun_decl) : fun_decl =
@@ -1110,38 +1037,51 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
object
inherit [_] map_expression as super
- method! visit_App env call =
- match call.func with
- | Regular (A.Assumed aid, rg_id) -> (
- match (aid, rg_id) with
- | A.BoxNew, _ ->
- let arg = Collections.List.to_cons_nil call.args in
- arg.e
- | A.BoxDeref, None ->
- (* `Box::deref` forward is the identity *)
- let arg = Collections.List.to_cons_nil call.args in
- arg.e
- | A.BoxDeref, Some _ ->
- (* `Box::deref` backward is `()` (doesn't give back anything) *)
- (mk_value_expression unit_rvalue None).e
- | A.BoxDerefMut, None ->
- (* `Box::deref_mut` forward is the identity *)
- let arg = Collections.List.to_cons_nil call.args in
- arg.e
- | A.BoxDerefMut, Some _ ->
- (* `Box::deref_mut` back is the identity *)
- let arg =
- match call.args with
- | [ _; given_back ] -> given_back
- | _ -> failwith "Unreachable"
- in
- arg.e
- | A.BoxFree, _ -> (mk_value_expression unit_rvalue None).e
- | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen
- | A.VecIndex | A.VecIndexMut ),
- _ ) ->
- super#visit_App env call)
- | _ -> super#visit_App env call
+ method! visit_texpression env e =
+ match opt_destruct_function_call e with
+ | Some (func, args) -> (
+ match func.func with
+ | Regular (A.Assumed aid, rg_id) -> (
+ (* Below, when dealing with the arguments: we consider the very
+ * general case, where functions could be boxed (meaning we
+ * could have: `box_new f x`)
+ * *)
+ match (aid, rg_id) with
+ | A.BoxNew, _ ->
+ assert (rg_id = None);
+ let arg, args = Collections.List.pop args in
+ mk_apps arg args
+ | A.BoxDeref, None ->
+ (* `Box::deref` forward is the identity *)
+ let arg, args = Collections.List.pop args in
+ mk_apps arg args
+ | A.BoxDeref, Some _ ->
+ (* `Box::deref` backward is `()` (doesn't give back anything) *)
+ assert (args = []);
+ mk_value_expression unit_rvalue None
+ | A.BoxDerefMut, None ->
+ (* `Box::deref_mut` forward is the identity *)
+ let arg, args = Collections.List.pop args in
+ mk_apps arg args
+ | A.BoxDerefMut, Some _ ->
+ (* `Box::deref_mut` back is almost the identity:
+ * let box_deref_mut (x_init : t) (x_back : t) : t = x_back
+ * *)
+ let arg, args =
+ match args with
+ | _ :: given_back :: args -> (given_back, args)
+ | _ -> failwith "Unreachable"
+ in
+ mk_apps arg args
+ | A.BoxFree, _ ->
+ assert (args = []);
+ mk_value_expression unit_rvalue None
+ | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen
+ | A.VecIndex | A.VecIndexMut ),
+ _ ) ->
+ super#visit_texpression env e)
+ | _ -> super#visit_texpression env e)
+ | _ -> super#visit_texpression env e
end
in
(* Update the body *)
@@ -1221,6 +1161,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
inherit [_] map_expression as super
method! visit_Let state_var monadic lv re e =
+ (* TODO: we should use a monad "kind" instead of a boolean *)
if not monadic then super#visit_Let state_var monadic lv re e
else
(* We don't do the same thing if we use a state-error monad or simply
@@ -1228,30 +1169,18 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
* Note that some functions always live in the error monad (arithmetic
* operations, for instance).
*)
- let re_call =
- match re.e with
- | App call -> call
- | _ -> raise (Failure "Unreachable: expected a function call")
- in
(* TODO: this information should be computed in SymbolicToPure and
- * store in an enum ("monadic" should be an enum, not a bool).
- * Also: everything will be cleaner once we update the AST to make
- * it more idiomatic lambda calculus... *)
- let re_call_can_use_state =
- match re_call.func with
- | Regular (A.Regular _, _) -> true
- | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
+ * store in an enum ("monadic" should be an enum, not a bool). *)
+ let re_uses_state =
+ Option.is_some (opt_destruct_state_monad_result re.ty)
in
- if config.use_state_monad && re_call_can_use_state then
- let re_call =
- let call = re_call in
+ if re_uses_state then
+ (* Add the state argument on the right-expression *)
+ let re =
let state_value = mk_typed_rvalue_from_var state_var in
- let args =
- call.args @ [ mk_value_expression state_value None ]
- in
- App { call with args }
+ let state_value = mk_value_expression state_value None in
+ mk_app re state_value
in
- let re = { re with e = re_call } in
(* Create the match *)
let fail_pat = mk_result_fail_lvalue lv.ty in
let fail_value = mk_result_fail_rvalue e.ty in
@@ -1273,6 +1202,8 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
let e = Switch (re, switch_body) in
self#visit_expression state_var e
else
+ let re_ty = Option.get (opt_destruct_result re.ty) in
+ assert (lv.ty = re_ty);
let fail_pat = mk_result_fail_lvalue lv.ty in
let fail_value = mk_result_fail_rvalue e.ty in
let fail_branch =
@@ -1312,9 +1243,9 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
else super#visit_Value state_var rv mp
(** We also need to update values, in case this value is `Return ...`.
- TODO: this is super ugly... We need to use the monadic functions
- `fail` and `return` instead.
- *)
+ TODO: this is super ugly... We need to use the monadic functions
+ fail` and `return` instead.
+ *)
end
in
(* Update the body *)
@@ -1386,14 +1317,6 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) :
match def with
| None -> None
| Some def ->
- (* Add unit arguments for functions with no arguments, and change their return type.
- * **Rk.**: from now onwards, the types in the AST are correct (until now,
- * functions had return type `t` where they should have return type `result t`).
- * TODO: this is not true with the state-error monad, unless we unfold the monadic binds.
- * Also, from now onwards, the outputs list has length 1. *)
- let def = to_monadic config def in
- log#ldebug (lazy ("to_monadic:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
(* Convert the unit variables to `()` if they are used as right-values or
* `_` if they are used as left values. *)
let def = unit_vars_to_unit def in
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 0045cc1d..bcf93c3c 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -266,7 +266,7 @@ let functions_not_mutually_recursive (funs : fun_decl list) : bool =
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
- | Value _ | App _ | Func _ -> false
+ | Value _ | App _ | Func _ | Abs _ -> false
| Let (monadic, _, _, next_e) ->
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
@@ -404,30 +404,68 @@ let destruct_apps (e : texpression) : texpression * texpression list =
in
aux [] e
+(** Make an `App (app, arg)` expression *)
+let mk_app (app : texpression) (arg : texpression) : texpression =
+ match app.ty with
+ | Arrow (ty0, ty1) ->
+ (* Sanity check *)
+ assert (ty0 = arg.ty);
+ let e = App (app, arg) in
+ let ty = ty1 in
+ { e; ty }
+ | _ -> raise (Failure "Expected an arrow type")
+
(** The reverse of [destruct_app] *)
-let mk_apps (e : texpression) (args : texpression list) : texpression =
- (* Reverse the arguments *)
- let args = List.rev args in
- (* Apply *)
- let rec aux (e : texpression) (args : texpression list) : texpression =
- match args with
- | [] -> e
- | arg :: args' -> (
- let e' = aux e args' in
- match e'.ty with
- | Arrow (ty0, ty1) ->
- (* Sanity check *)
- assert (ty0 == arg.ty);
- let e'' = App (e', arg) in
- let ty'' = ty1 in
- { e = e''; ty = ty'' }
- | _ -> raise (Failure "Expected an arrow type"))
- in
- aux e args
+let mk_apps (app : texpression) (args : texpression list) : texpression =
+ List.fold_left (fun app arg -> mk_app app arg) app args
-(* Destruct an expression into a function identifier and a list of arguments,
- * if possible *)
+(** Destruct an expression into a function identifier and a list of arguments,
+ * if possible *)
let opt_destruct_function_call (e : texpression) :
(func * texpression list) option =
let app, args = destruct_apps e in
match app.e with Func func -> Some (func, args) | _ -> None
+
+(** Destruct an expression into a function identifier and a list of arguments *)
+let destruct_function_call (e : texpression) : func * texpression list =
+ Option.get (opt_destruct_function_call e)
+
+let opt_destruct_result (ty : ty) : ty option =
+ match ty with
+ | Adt (Assumed Result, tys) -> Some (Collections.List.to_cons_nil tys)
+ | _ -> None
+
+let opt_destruct_tuple (ty : ty) : ty list option =
+ match ty with Adt (Tuple, tys) -> Some tys | _ -> None
+
+let opt_destruct_state_monad_result (ty : ty) : ty option =
+ (* Checking:
+ * ty == state -> result (state & _) ? *)
+ match ty with
+ | Arrow (ty0, ty1) ->
+ (* ty == ty0 -> ty1
+ * Checking: ty0 == state ?
+ * *)
+ if ty0 = mk_state_ty then
+ (* Checking: ty1 == result (state & _) *)
+ match opt_destruct_result ty1 with
+ | None -> None
+ | Some ty2 -> (
+ (* Checking: ty2 == state & _ *)
+ match opt_destruct_tuple ty2 with
+ | Some [ ty3; ty4 ] -> if ty3 = mk_state_ty then Some ty4 else None
+ | _ -> None)
+ else None
+ | _ -> None
+
+let mk_abs (x : typed_lvalue) (e : texpression) : texpression =
+ let ty = Arrow (x.ty, e.ty) in
+ let e = Abs (x, e) in
+ { e; ty }
+
+let rec destruct_abs_list (e : texpression) : typed_lvalue list * texpression =
+ match e.e with
+ | Abs (x, e') ->
+ let xl, e'' = destruct_abs_list e' in
+ (x :: xl, e'')
+ | _ -> ([], e)
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 18e2b873..b25b7309 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -38,6 +38,12 @@ type config = {
Note that we later filter the useless *forward* calls in the micro-passes,
where it is more natural to do.
*)
+ use_state_monad : bool;
+ (** If `true`, use a state-error monad.
+ If `false`, only use an error monad.
+
+ Using a state-error monad is necessary when modelling I/O, for instance.
+ *)
}
type type_context = {
@@ -950,6 +956,15 @@ let fun_is_monadic (fun_id : A.fun_id) : bool =
| A.Regular _ -> true
| A.Assumed aid -> Assumed.assumed_is_monadic aid
+let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty =
+ if monadic then
+ if config.use_state_monad then
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else mk_result_ty out_ty
+ else out_ty
+
let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
: texpression =
match e with
@@ -967,7 +982,7 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
| Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx
| Meta (meta, e) -> translate_meta config meta e ctx
-and translate_return (_config : config) (opt_v : V.typed_value option)
+and translate_return (config : config) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
- either we are translating a forward function, in which case the optional
@@ -979,13 +994,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option)
(* Forward function *)
let v = Option.get opt_v in
let v = typed_value_to_rvalue ctx v in
- (* TODO: we need to use a `return` function (otherwise we have problems
- * with the state-error monad). We also need to update the type when using
- * a state-error monad. *)
- let v = mk_result_return_rvalue v in
- let e = Value (v, None) in
- let ty = v.ty in
- { e; ty }
+ (* We don't synthesize the same expression depending on the monad we use:
+ * - error-monad: Return x
+ * - state-error monad: fun state -> Return (state, x)
+ * *)
+ (* TODO: we should use a `return` function, it would be cleaner *)
+ if config.use_state_monad then
+ let _, state_var = fresh_var (Some "state") mk_state_ty ctx in
+ let state_rvalue = mk_typed_rvalue_from_var state_var in
+ let v =
+ mk_result_return_rvalue (mk_simpl_tuple_rvalue [ state_rvalue; v ])
+ in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ let e = { e; ty } in
+ let state_var = mk_typed_lvalue_from_var state_var None in
+ mk_abs state_var e
+ else
+ let v = mk_result_return_rvalue v in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
| Some bid ->
(* Backward function *)
(* Sanity check *)
@@ -997,11 +1026,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option)
T.RegionGroupId.Map.find bid ctx.backward_outputs
in
let field_values = List.map mk_typed_rvalue_from_var backward_outputs in
- let ret_value = mk_simpl_tuple_rvalue field_values in
- let ret_value = mk_result_return_rvalue ret_value in
- let e = Value (ret_value, None) in
- let ty = ret_value.ty in
- { e; ty }
+ (* See the comment about the monads, for the forward function case *)
+ (* TODO: we should use a `fail` function, it would be cleaner *)
+ if config.use_state_monad then
+ let _, state_var = fresh_var (Some "state") mk_state_ty ctx in
+ let state_rvalue = mk_typed_rvalue_from_var state_var in
+ let ret_value = mk_simpl_tuple_rvalue field_values in
+ let ret_value =
+ mk_result_return_rvalue
+ (mk_simpl_tuple_rvalue [ state_rvalue; ret_value ])
+ in
+ let e = Value (ret_value, None) in
+ let ty = ret_value.ty in
+ let e = { e; ty } in
+ let state_var = mk_typed_lvalue_from_var state_var None in
+ mk_abs state_var e
+ else
+ let ret_value = mk_simpl_tuple_rvalue field_values in
+ let ret_value = mk_result_return_rvalue ret_value in
+ let e = Value (ret_value, None) in
+ let ty = ret_value.ty in
+ { e; ty }
and translate_function_call (config : config) (call : S.call) (e : S.expression)
(ctx : bs_ctx) : texpression =
@@ -1047,7 +1092,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let dest_v = mk_typed_lvalue_from_var dest dest_mplace in
let func = { func; type_params } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in
+ let ret_ty = mk_function_ret_ty config monadic dest_v.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { e = Func func; ty = func_ty } in
let call = mk_apps func args in
@@ -1180,7 +1225,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
in
let monadic = fun_is_monadic fun_id in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = if monadic then mk_result_ty output.ty else output.ty in
+ let ret_ty = mk_function_ret_ty config monadic output.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { func; type_params } in
let func = { e = Func func; ty = func_ty } in
@@ -1470,6 +1515,7 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(* Translate the declaration *)
let def_id = def.A.def_id in
let basename = def.name in
+ (* Lookup the signature *)
let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in
(* Translate the body, if there is *)
let body =
@@ -1502,6 +1548,38 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in
+ (* Make the signature monadic *)
+ let output_ty =
+ match (bid, signature.outputs) with
+ | None, [ out_ty ] ->
+ (* Forward function: there is always exactly one output *)
+ (* We don't do the same thing if we use a state error monad or not:
+ * - error-monad: `result out_ty`
+ * - state-error: `state -> result (state & out_ty)
+ *)
+ if config.use_state_monad then
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else (* Simply wrap the type in `result` *)
+ mk_result_ty out_ty
+ | Some _, outputs ->
+ (* Backward function: we have to group the list of outputs into a tuple
+ * (and similarly to the forward function, we don't do the same thing
+ * if we use a state error monad or not):
+ * - error-monad: `result (out_ty1 & .. out_tyn)`
+ * - state-error: `state -> result (out_ty1 & .. out_tyn)`
+ *)
+ if config.use_state_monad then
+ let ret = mk_simpl_tuple_ty outputs in
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else mk_result_ty (mk_simpl_tuple_ty outputs)
+ | _ -> failwith "Unreachable"
+ in
+ let outputs = [ output_ty ] in
+ let signature = { signature with outputs } in
(* Assemble the declaration *)
let def = { def_id; back_id = bid; basename; signature; body } in
(* Debugging *)