summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2022-01-28 22:17:28 +0100
committerSon Ho2022-01-28 22:17:28 +0100
commit2d40d81b4b9fde44fd924bad5a44b7392a1c9f1e (patch)
treeae92f07a33a48bb794a633647d7efa111edb2a94
parent0b145dd4b0ab0ac5ed56121663a25801f20bed67 (diff)
Make the pure expressions typed
-rw-r--r--src/PrintPure.ml30
-rw-r--r--src/Pure.ml19
-rw-r--r--src/PureMicroPasses.ml109
-rw-r--r--src/PureUtils.ml11
-rw-r--r--src/SymbolicToPure.ml155
5 files changed, 196 insertions, 128 deletions
diff --git a/src/PrintPure.ml b/src/PrintPure.ml
index 44c0be24..b4816e14 100644
--- a/src/PrintPure.ml
+++ b/src/PrintPure.ml
@@ -365,9 +365,13 @@ let rec expression_to_string (fmt : ast_formatter) (indent : string)
switch_to_string fmt indent indent_incr scrutinee body
| Meta (meta, e) ->
let meta = meta_to_string fmt meta in
- let e = expression_to_string fmt indent indent_incr e in
+ let e = texpression_to_string fmt indent indent_incr e in
indent ^ meta ^ "\n" ^ e
+and texpression_to_string (fmt : ast_formatter) (indent : string)
+ (indent_incr : string) (e : texpression) : string =
+ expression_to_string fmt indent indent_incr e.e
+
and call_to_string (fmt : ast_formatter) (indent : string)
(indent_incr : string) (call : call) : string =
let ty_fmt = ast_to_type_formatter fmt in
@@ -376,35 +380,35 @@ and call_to_string (fmt : ast_formatter) (indent : string)
* those expressions will in most cases just be values) *)
let indent1 = indent ^ indent_incr in
let args =
- List.map (expression_to_string fmt indent1 indent_incr) call.args
+ List.map (texpression_to_string fmt indent1 indent_incr) call.args
in
let all_args = List.append tys args in
let fun_id = fun_id_to_string fmt call.func in
if all_args = [] then fun_id else fun_id ^ " " ^ String.concat " " all_args
and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string)
- (monadic : bool) (lv : typed_lvalue) (re : expression) (e : expression) :
+ (monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) :
string =
let indent1 = indent ^ indent_incr in
let val_fmt = ast_to_value_formatter fmt in
- let re = expression_to_string fmt indent1 indent_incr re in
- let e = expression_to_string fmt indent indent_incr e in
+ let re = texpression_to_string fmt indent1 indent_incr re in
+ let e = texpression_to_string fmt indent indent_incr e in
let lv = typed_lvalue_to_string val_fmt lv in
if monadic then lv ^ " <-- " ^ re ^ ";\n" ^ indent ^ e
else "let " ^ lv ^ " = " ^ re ^ " in\n" ^ indent ^ e
and switch_to_string (fmt : ast_formatter) (indent : string)
- (indent_incr : string) (scrutinee : expression) (body : switch_body) :
+ (indent_incr : string) (scrutinee : texpression) (body : switch_body) :
string =
let indent1 = indent ^ indent_incr in
(* Printing can mess up on the scrutinee, because it is an expression - but
* in most situations it will be a value or a function call, so it should be
* ok*)
- let scrut = expression_to_string fmt indent1 indent_incr scrutinee in
+ let scrut = texpression_to_string fmt indent1 indent_incr scrutinee in
match body with
| If (e_true, e_false) ->
- let e_true = expression_to_string fmt indent1 indent_incr e_true in
- let e_false = expression_to_string fmt indent1 indent_incr e_false in
+ let e_true = texpression_to_string fmt indent1 indent_incr e_true in
+ let e_false = texpression_to_string fmt indent1 indent_incr e_false in
"if " ^ scrut ^ "\n" ^ indent ^ "then\n" ^ indent ^ e_true ^ "\n" ^ indent
^ "else\n" ^ indent ^ e_false
| SwitchInt (_, branches, otherwise) ->
@@ -412,12 +416,12 @@ and switch_to_string (fmt : ast_formatter) (indent : string)
List.map
(fun (v, be) ->
indent ^ "| " ^ scalar_value_to_string v ^ " ->\n" ^ indent1
- ^ expression_to_string fmt indent1 indent_incr be)
+ ^ texpression_to_string fmt indent1 indent_incr be)
branches
in
let otherwise =
indent ^ "| _ ->\n" ^ indent1
- ^ expression_to_string fmt indent1 indent_incr otherwise
+ ^ texpression_to_string fmt indent1 indent_incr otherwise
in
let all_branches = List.append branches [ otherwise ] in
"switch " ^ scrut ^ " with\n" ^ String.concat "\n" all_branches
@@ -426,7 +430,7 @@ and switch_to_string (fmt : ast_formatter) (indent : string)
let branch_to_string (b : match_branch) : string =
let pat = typed_lvalue_to_string val_fmt b.pat in
indent ^ "| " ^ pat ^ " ->\n" ^ indent1
- ^ expression_to_string fmt indent1 indent_incr b.branch
+ ^ texpression_to_string fmt indent1 indent_incr b.branch
in
let branches = List.map branch_to_string branches in
"match " ^ scrut ^ " with\n" ^ String.concat "\n" branches
@@ -439,5 +443,5 @@ let fun_def_to_string (fmt : ast_formatter) (def : fun_def) : string =
let inputs =
if inputs = [] then "" else " fun " ^ String.concat " " inputs ^ " ->\n"
in
- let body = expression_to_string fmt " " " " def.body in
+ let body = texpression_to_string fmt " " " " def.body in
"let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body
diff --git a/src/Pure.ml b/src/Pure.ml
index dc35a181..bc655b63 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -445,7 +445,7 @@ class virtual ['self] mapreduce_expression_base =
type expression =
| Value of typed_rvalue * mplace option
| Call of call
- | Let of bool * typed_lvalue * expression * expression
+ | Let of bool * typed_lvalue * texpression * texpression
(** Let binding.
The boolean controls whether the let is monadic or not.
@@ -479,13 +479,13 @@ type expression =
...
```
*)
- | Switch of expression * switch_body
- | Meta of meta * expression (** Meta-information *)
+ | Switch of texpression * switch_body
+ | Meta of meta * texpression (** Meta-information *)
and call = {
func : fun_id;
type_params : ty list;
- args : expression list;
+ args : texpression list;
(** Note that immediately after we converted the symbolic AST to a pure AST,
some functions may have no arguments. For instance:
```
@@ -496,11 +496,14 @@ and call = {
}
and switch_body =
- | If of expression * expression
- | SwitchInt of T.integer_type * (scalar_value * expression) list * expression
+ | If of texpression * texpression
+ | SwitchInt of
+ T.integer_type * (scalar_value * texpression) list * texpression
| Match of match_branch list
-and match_branch = { pat : typed_lvalue; branch : expression }
+and match_branch = { pat : typed_lvalue; branch : texpression }
+
+and texpression = { e : expression; ty : ty }
[@@deriving
visitors
{
@@ -570,5 +573,5 @@ type fun_def = {
inputs_lvs : typed_lvalue list;
(** The inputs seen as lvalues. Allows to make transformations, for example
to replace unused variables by `_` *)
- body : expression;
+ body : texpression;
}
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 82e5ecd2..2a7293c8 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -195,14 +195,18 @@ let compute_pretty_names (def : fun_def) : fun_def =
in
(* *)
- let rec update_expression (e : expression) (ctx : pn_ctx) :
- pn_ctx * expression =
- match e with
- | Value (v, mp) -> update_value v mp ctx
- | Call call -> update_call call ctx
- | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
- | Switch (scrut, body) -> update_switch_body scrut body ctx
- | Meta (meta, e) -> update_meta meta e ctx
+ let rec update_texpression (e : texpression) (ctx : pn_ctx) :
+ pn_ctx * texpression =
+ let ty = e.ty in
+ let ctx, e =
+ match e.e with
+ | Value (v, mp) -> update_value v mp ctx
+ | Call call -> update_call call ctx
+ | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
+ | Switch (scrut, body) -> update_switch_body scrut body ctx
+ | Meta (meta, e) -> update_meta meta e ctx
+ in
+ (ctx, { e; ty })
(* *)
and update_value (v : typed_rvalue) (mp : mplace option) (ctx : pn_ctx) :
pn_ctx * expression =
@@ -212,40 +216,40 @@ let compute_pretty_names (def : fun_def) : fun_def =
and update_call (call : call) (ctx : pn_ctx) : pn_ctx * expression =
let ctx, args =
List.fold_left_map
- (fun ctx arg -> update_expression arg ctx)
+ (fun ctx arg -> update_texpression arg ctx)
ctx call.args
in
let call = { call with args } in
(ctx, Call call)
(* *)
- and update_let (monadic : bool) (lv : typed_lvalue) (re : expression)
- (e : expression) (ctx : pn_ctx) : pn_ctx * expression =
+ and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression)
+ (e : texpression) (ctx : pn_ctx) : pn_ctx * expression =
let ctx = add_left_constraint lv ctx in
- let ctx, re = update_expression re ctx in
- let ctx, e = update_expression e ctx in
+ let ctx, re = update_texpression re ctx in
+ let ctx, e = update_texpression e ctx in
let lv = update_typed_lvalue ctx lv in
(ctx, Let (monadic, lv, re, e))
(* *)
- and update_switch_body (scrut : expression) (body : switch_body)
+ and update_switch_body (scrut : texpression) (body : switch_body)
(ctx : pn_ctx) : pn_ctx * expression =
- let ctx, scrut = update_expression scrut ctx in
+ let ctx, scrut = update_texpression scrut ctx in
let ctx, body =
match body with
| If (e_true, e_false) ->
- let ctx1, e_true = update_expression e_true ctx in
- let ctx2, e_false = update_expression e_false ctx in
+ let ctx1, e_true = update_texpression e_true ctx in
+ let ctx2, e_false = update_texpression e_false ctx in
let ctx = merge_ctxs ctx1 ctx2 in
(ctx, If (e_true, e_false))
| SwitchInt (int_ty, branches, otherwise) ->
let ctx_branches_ls =
List.map
(fun (v, br) ->
- let ctx, br = update_expression br ctx in
+ let ctx, br = update_texpression br ctx in
(ctx, (v, br)))
branches
in
- let ctx, otherwise = update_expression otherwise ctx in
+ let ctx, otherwise = update_texpression otherwise ctx in
let ctxs, branches = List.split ctx_branches_ls in
let ctxs = merge_ctxs_ls ctxs in
let ctx = merge_ctxs ctx ctxs in
@@ -255,7 +259,7 @@ let compute_pretty_names (def : fun_def) : fun_def =
List.map
(fun br ->
let ctx = add_left_constraint br.pat ctx in
- let ctx, branch = update_expression br.branch ctx in
+ let ctx, branch = update_texpression br.branch ctx in
let pat = update_typed_lvalue ctx br.pat in
(ctx, { pat; branch }))
branches
@@ -266,12 +270,13 @@ let compute_pretty_names (def : fun_def) : fun_def =
in
(ctx, Switch (scrut, body))
(* *)
- and update_meta (meta : meta) (e : expression) (ctx : pn_ctx) :
+ and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) :
pn_ctx * expression =
match meta with
| Assignment (mp, rvalue) ->
let ctx = add_right_constraint mp rvalue ctx in
- update_expression e ctx
+ let ctx, e = update_texpression e ctx in
+ (ctx, e.e)
in
let input_names =
@@ -281,7 +286,7 @@ let compute_pretty_names (def : fun_def) : fun_def =
def.inputs
in
let ctx = VarId.Map.of_list input_names in
- let _, body = update_expression def.body ctx in
+ let _, body = update_texpression def.body ctx in
{ def with body }
(** Remove the meta-information *)
@@ -290,10 +295,10 @@ let remove_meta (def : fun_def) : fun_def =
object
inherit [_] map_expression as super
- method! visit_Meta env _ e = super#visit_expression env e
+ method! visit_Meta env _ e = super#visit_expression env e.e
end
in
- let body = obj#visit_expression () def.body in
+ let body = obj#visit_texpression () def.body in
{ def with body }
(** Inline the useless variable reassignments (a lot of variable assignments
@@ -326,7 +331,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) :
* - the let-binding is not monadic
* - the left-value is a variable
* - the assigned expression is a value *)
- match (monadic, lv.value, re) with
+ match (monadic, lv.value, re.e) with
| false, LvVar (Var (lv_var, _)), Value (rv, _) -> (
(* Check that:
* - the left variable is unnamed or that [inline_named] is true
@@ -337,7 +342,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) :
(* Update the environment and explore the next expression
* (dropping the currrent let) *)
let env = add_subst lv_var.id var env in
- super#visit_expression env e
+ super#visit_expression env e.e
| _ -> super#visit_Let env monadic lv re e)
| _ -> super#visit_Let env monadic lv re e
(** Visit the let-bindings to filter the useless ones (and update
@@ -353,7 +358,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) :
(** Visit the places used as rvalues, to substitute them if necessary *)
end
in
- let body = obj#visit_expression VarId.Map.empty def.body in
+ let body = obj#visit_texpression VarId.Map.empty def.body in
{ def with body }
(** Given a forward or backward function call, is there, for every execution
@@ -382,7 +387,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) :
In this situation, we can remove the call `f x`.
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call)
- (e : expression) : bool =
+ (e : texpression) : bool =
let check_call call1 : bool =
(* Check the func_ids, to see if call1's function is a child of call0's function *)
match (call0.func, call1.func) with
@@ -428,7 +433,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call)
* meta-values* (which we need to ignore). We only consider the
* case where both expressions are actually values. *)
let input_eq (v0, v1) =
- match (v0, v1) with
+ match (v0.e, v1.e) with
| Value (v0, _), Value (v1, _) -> v0 = v1
| _ -> false
in
@@ -451,37 +456,43 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call)
method! visit_expression env e =
match e with
| Value (_, _) -> fun _ -> false
- | Let (_, _, Call call1, e) ->
+ | Let (_, _, { e = Call call1; ty = _ }, e) ->
let call_is_child = check_call call1 in
if call_is_child then fun () -> true
- else self#visit_expression env e
+ else self#visit_texpression env e
| Let (_, _, re, e) ->
fun () ->
- self#visit_expression env re () && self#visit_expression env e ()
+ self#visit_texpression env re ()
+ && self#visit_texpression env e ()
| Call call1 -> fun () -> check_call call1
- | Meta (_, e) -> self#visit_expression env e
+ | Meta (_, e) -> self#visit_texpression env e
| Switch (_, body) -> self#visit_switch_body env body
(** We need to reimplement the way we compose the booleans *)
+ method! visit_texpression env e =
+ (* We take care not to visit the type *)
+ self#visit_expression env e.e
+
method! visit_switch_body env body =
match body with
| If (e1, e2) ->
fun () ->
- self#visit_expression env e1 () && self#visit_expression env e2 ()
+ self#visit_texpression env e1 ()
+ && self#visit_texpression env e2 ()
| SwitchInt (_, branches, otherwise) ->
fun () ->
List.for_all
- (fun (_, br) -> self#visit_expression env br ())
+ (fun (_, br) -> self#visit_texpression env br ())
branches
- && self#visit_expression env otherwise ()
+ && self#visit_texpression env otherwise ()
| Match branches ->
fun () ->
List.for_all
- (fun br -> self#visit_expression env br.branch ())
+ (fun br -> self#visit_texpression env br.branch ())
branches
end
in
- visitor#visit_expression () e ()
+ visitor#visit_texpression () e ()
(** Filter the unused assignments (removes the unused variables, filters
the function calls) *)
@@ -546,13 +557,13 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx)
super#visit_expression env e
| Let (monadic, lv, re, e) ->
(* Compute the set of values used in the next expression *)
- let e, used = self#visit_expression env e in
+ let e, used = self#visit_texpression env e in
let used = used () in
(* Filter the left values *)
let lv, all_dummies = filter_typed_lvalue used lv in
(* Small utility - called if we can't filter the let-binding *)
let dont_filter () =
- let re, used_re = self#visit_expression env re in
+ let re, used_re = self#visit_texpression env re in
let used = VarId.Set.union used (used_re ()) in
(Let (monadic, lv, re, e), fun _ -> used)
in
@@ -560,12 +571,12 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx)
if all_dummies then
if not monadic then
(* Not a monadic let-binding: simple case *)
- (e, fun _ -> used)
+ (e.e, fun _ -> used)
else
(* Monadic let-binding: trickier.
* We can filter if the right-expression is a function call,
* under some conditions. *)
- match (filter_monadic_calls, re) with
+ match (filter_monadic_calls, re.e) with
| true, Call call ->
(* We need to check if there is a child call - see
* the comments for:
@@ -574,7 +585,7 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx)
expression_contains_child_call_in_all_paths ctx call e
in
if has_child_call then (* Filter *)
- (e, fun _ -> used)
+ (e.e, fun _ -> used)
else (* No child call: don't filter *)
dont_filter ()
| _ ->
@@ -585,7 +596,7 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx)
end
in
(* Visit the body *)
- let body, used_vars = expr_visitor#visit_expression () def.body in
+ let body, used_vars = expr_visitor#visit_texpression () def.body in
(* Visit the parameters *)
let used_vars = used_vars () in
let inputs_lvs =
@@ -604,12 +615,12 @@ let to_monadic (def : fun_def) : fun_def =
method! visit_call env call =
(* If no arguments, introduce unit *)
if call.args = [] then
- let args = [ Value (unit_rvalue, None) ] in
+ let args = [ mk_value_expression unit_rvalue None ] in
{ call with args } (* Otherwise: nothing to do *)
else super#visit_call env call
end
in
- let body = obj#visit_expression () def.body in
+ let body = obj#visit_texpression () def.body in
let def = { def with body } in
(* Update the signature: first the input types *)
@@ -617,7 +628,7 @@ let to_monadic (def : fun_def) : fun_def =
if def.inputs = [] then (
assert (def.signature.inputs = []);
let signature = { def.signature with inputs = [ unit_ty ] } in
- let var_cnt = get_expression_min_var_counter def.body in
+ let var_cnt = get_expression_min_var_counter def.body.e in
let id, _ = VarId.fresh var_cnt in
let var = { id; basename = None; ty = unit_ty } in
let inputs = [ var ] in
@@ -659,7 +670,7 @@ let unit_vars_to_unit (def : fun_def) : fun_def =
end
in
(* Update the body *)
- let body = obj#visit_expression () def.body in
+ let body = obj#visit_texpression () def.body in
(* Update the input parameters *)
let inputs_lvs = List.map (obj#visit_typed_lvalue ()) def.inputs_lvs in
(* Return *)
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index 9596cd9b..dd072d23 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -72,6 +72,17 @@ let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue =
let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
+let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression =
+ let e = Value (v, mp) in
+ let ty = v.ty in
+ { e; ty }
+
+let mk_let (monadic : bool) (lv : typed_lvalue) (re : texpression)
+ (next_e : texpression) : texpression =
+ let e = Let (monadic, lv, re, next_e) in
+ let ty = next_e.ty in
+ { e; ty }
+
(** Type substitution *)
let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty =
let obj =
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index f6f610dd..d48f732d 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -843,16 +843,23 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =
let abs_ancestors = list_ancestor_abstractions ctx abs in
(call_info.forward, abs_ancestors)
-let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression =
+let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
| S.Return opt_v -> translate_return opt_v ctx
- | Panic -> Value (mk_result_fail_rvalue ctx.ret_ty, None)
+ | Panic ->
+ (* Here we use the function return type - note that it is ok because
+ * we don't match on panics which happen inside the function body -
+ * but it won't be true anymore once we translate individual blocks *)
+ let v = mk_result_fail_rvalue ctx.ret_ty in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
| FunCall (call, e) -> translate_function_call call e ctx
| EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx
| Expansion (p, sv, exp) -> translate_expansion p sv exp ctx
| Meta (meta, e) -> translate_meta meta e ctx
-and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
+and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
=
(* There are two cases:
- either we are translating a forward function, in which case the optional
@@ -864,7 +871,10 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
(* Forward function *)
let v = Option.get opt_v in
let v = typed_value_to_rvalue ctx v in
- Value (mk_result_return_rvalue v, None)
+ let v = mk_result_return_rvalue v in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
| Some bid ->
(* Backward function *)
(* Sanity check *)
@@ -880,10 +890,13 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in
let ret_ty = Adt (Tuple, ret_tys) in
let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in
- Value (mk_result_return_rvalue ret_value, None)
+ let ret_value = mk_result_return_rvalue ret_value in
+ let e = Value (ret_value, None) in
+ let ty = ret_value.ty in
+ { e; ty }
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
- expression =
+ texpression =
(* Translate the function call *)
let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in
let args = List.map (typed_value_to_rvalue ctx) call.args in
@@ -918,17 +931,22 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| _ -> failwith "Unreachable")
in
let args =
- List.map (fun (arg, mp) -> Value (arg, mp)) (List.combine args args_mplaces)
+ List.map
+ (fun (arg, mp) -> mk_value_expression arg mp)
+ (List.combine args args_mplaces)
in
+ let dest_v = mk_typed_lvalue_from_var dest dest_mplace in
let call = { func; type_params; args } in
let call = Call call in
+ let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in
+ let call = { e = call; ty = call_ty } in
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Put together *)
- Let (monadic, mk_typed_lvalue_from_var dest dest_mplace, call, e)
+ mk_let monadic dest_v call next_e
and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
- expression =
+ texpression =
log#ldebug
(lazy
("translate_end_abstraction: abstraction kind: "
@@ -976,14 +994,16 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty))
variables_values;
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Generate the assignemnts *)
let monadic = false in
List.fold_right
- (fun (var, value) e ->
- Let
- (monadic, mk_typed_lvalue_from_var var None, Value (value, None), e))
- variables_values e
+ (fun (var, value) (e : texpression) ->
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression value None)
+ e)
+ variables_values next_e
| V.FunCall ->
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
let call = call_info.forward in
@@ -1039,17 +1059,19 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
* if necessary *)
let ctx, func = bs_ctx_register_backward_call abs ctx in
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Put everything together *)
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
- (fun (arg, mp) -> Value (arg, mp))
+ (fun (arg, mp) -> mk_value_expression arg mp)
(List.combine inputs args_mplaces)
in
- let call = { func; type_params; args } in
let monadic = true in
- Let (monadic, output, Call call, e)
+ let call = { func; type_params; args } in
+ let call_ty = mk_result_ty output.ty in
+ let call = { e = Call call; ty = call_ty } in
+ mk_let monadic output call next_e
| V.SynthRet ->
(* If we end the abstraction which consumed the return value of the function
* we are synthesizing, we get back the borrows which were inside. Those borrows
@@ -1100,20 +1122,18 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
assert (given_back.ty = input.ty))
given_back_inputs;
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Generate the assignments *)
let monadic = false in
List.fold_right
(fun (given_back, input_var) e ->
- Let
- ( monadic,
- given_back,
- Value (mk_typed_rvalue_from_var input_var, None),
- e ))
- given_back_inputs e
+ mk_let monadic given_back
+ (mk_value_expression (mk_typed_rvalue_from_var input_var) None)
+ e)
+ given_back_inputs next_e
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
- (exp : S.expansion) (ctx : bs_ctx) : expression =
+ (exp : S.expansion) (ctx : bs_ctx) : texpression =
(* Translate the scrutinee *)
let scrutinee_var = lookup_var_for_symbolic_value sv ctx in
let scrutinee = mk_typed_rvalue_from_var scrutinee_var in
@@ -1130,13 +1150,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* The (mut/shared) borrow type is extracted to identity: we thus simply
* introduce an reassignment *)
let ctx, var = fresh_var_for_symbolic_value nsv ctx in
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
let monadic = false in
- Let
- ( monadic,
- mk_typed_lvalue_from_var var None,
- Value (scrutinee, scrutinee_mplace),
- e )
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression scrutinee scrutinee_mplace)
+ next_e
| SeAdt _ ->
(* Should be in the [ExpandAdt] case *)
failwith "Unreachable")
@@ -1161,7 +1180,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
in
let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in
let monadic = false in
- Let (monadic, lv, Value (scrutinee, scrutinee_mplace), branch)
+
+ mk_let monadic lv
+ (mk_value_expression scrutinee scrutinee_mplace)
+ branch
else
(* This is not an enumeration: introduce let-bindings for every
* field.
@@ -1182,22 +1204,19 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.fold_right
(fun (fid, var) e ->
let field_proj = gen_field_proj fid var in
- Let
- ( monadic,
- mk_typed_lvalue_from_var var None,
- Value (field_proj, None),
- e ))
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression field_proj None)
+ e)
id_var_pairs branch
| T.Tuple ->
let vars =
List.map (fun x -> mk_typed_lvalue_from_var x None) vars
in
let monadic = false in
- Let
- ( monadic,
- mk_tuple_lvalue vars,
- Value (scrutinee, scrutinee_mplace),
- branch )
+ mk_let monadic (mk_tuple_lvalue vars)
+ (mk_value_expression scrutinee scrutinee_mplace)
+ branch
| T.Assumed T.Box ->
(* There should be exactly one variable *)
let var =
@@ -1206,11 +1225,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* We simply introduce an assignment - the box type is the
* identity when extracted (`box a == a`) *)
let monadic = false in
- Let
- ( monadic,
- mk_typed_lvalue_from_var var None,
- Value (scrutinee, scrutinee_mplace),
- branch ))
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression scrutinee scrutinee_mplace)
+ branch)
| branches ->
let translate_branch (variant_id : T.VariantId.id option)
(svl : V.symbolic_value list) (branch : S.expression) :
@@ -1229,16 +1247,30 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
let branches =
List.map (fun (vid, svl, e) -> translate_branch vid svl e) branches
in
- Switch (Value (scrutinee, scrutinee_mplace), Match branches))
+ let e =
+ Switch
+ (mk_value_expression scrutinee scrutinee_mplace, Match branches)
+ in
+ (* There should be at least one branch *)
+ let branch = List.hd branches in
+ let ty = branch.branch.ty in
+ assert (List.for_all (fun br -> br.branch.ty = ty) branches);
+ { e; ty })
| ExpandBool (true_e, false_e) ->
(* We don't need to update the context: we don't introduce any
* new values/variables *)
let true_e = translate_expression true_e ctx in
let false_e = translate_expression false_e ctx in
- Switch (Value (scrutinee, scrutinee_mplace), If (true_e, false_e))
+ let e =
+ Switch
+ (mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e))
+ in
+ let ty = true_e.ty in
+ assert (ty = false_e.ty);
+ { e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
- scalar_value * expression =
+ scalar_value * texpression =
(* We don't need to update the context: we don't introduce any
* new values/variables *)
let branch_e = translate_expression branch_e ctx in
@@ -1246,13 +1278,18 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
in
let branches = List.map translate_branch branches in
let otherwise = translate_expression otherwise ctx in
- Switch
- ( Value (scrutinee, scrutinee_mplace),
- SwitchInt (int_ty, branches, otherwise) )
+ let e =
+ Switch
+ ( mk_value_expression scrutinee scrutinee_mplace,
+ SwitchInt (int_ty, branches, otherwise) )
+ in
+ let ty = otherwise.ty in
+ assert (List.for_all (fun (_, br) -> br.ty = ty) branches);
+ { e; ty }
and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
- expression =
- let e = translate_expression e ctx in
+ texpression =
+ let next_e = translate_expression e ctx in
let meta =
match meta with
| S.Assignment (p, rv) ->
@@ -1260,7 +1297,9 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
let rv = typed_value_to_rvalue ctx rv in
Assignment (p, rv)
in
- Meta (meta, e)
+ let e = Meta (meta, next_e) in
+ let ty = next_e.ty in
+ { e; ty }
let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =
let def = ctx.fun_def in