From 2d40d81b4b9fde44fd924bad5a44b7392a1c9f1e Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 28 Jan 2022 22:17:28 +0100 Subject: Make the pure expressions typed --- src/PrintPure.ml | 30 +++++----- src/Pure.ml | 19 +++--- src/PureMicroPasses.ml | 109 ++++++++++++++++++---------------- src/PureUtils.ml | 11 ++++ src/SymbolicToPure.ml | 155 +++++++++++++++++++++++++++++++------------------ 5 files changed, 196 insertions(+), 128 deletions(-) (limited to 'src') diff --git a/src/PrintPure.ml b/src/PrintPure.ml index 44c0be24..b4816e14 100644 --- a/src/PrintPure.ml +++ b/src/PrintPure.ml @@ -365,9 +365,13 @@ let rec expression_to_string (fmt : ast_formatter) (indent : string) switch_to_string fmt indent indent_incr scrutinee body | Meta (meta, e) -> let meta = meta_to_string fmt meta in - let e = expression_to_string fmt indent indent_incr e in + let e = texpression_to_string fmt indent indent_incr e in indent ^ meta ^ "\n" ^ e +and texpression_to_string (fmt : ast_formatter) (indent : string) + (indent_incr : string) (e : texpression) : string = + expression_to_string fmt indent indent_incr e.e + and call_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) (call : call) : string = let ty_fmt = ast_to_type_formatter fmt in @@ -376,35 +380,35 @@ and call_to_string (fmt : ast_formatter) (indent : string) * those expressions will in most cases just be values) *) let indent1 = indent ^ indent_incr in let args = - List.map (expression_to_string fmt indent1 indent_incr) call.args + List.map (texpression_to_string fmt indent1 indent_incr) call.args in let all_args = List.append tys args in let fun_id = fun_id_to_string fmt call.func in if all_args = [] then fun_id else fun_id ^ " " ^ String.concat " " all_args and let_to_string (fmt : ast_formatter) (indent : string) (indent_incr : string) - (monadic : bool) (lv : typed_lvalue) (re : expression) (e : expression) : + (monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) : string = let indent1 = indent ^ indent_incr in let val_fmt = ast_to_value_formatter fmt in - let re = expression_to_string fmt indent1 indent_incr re in - let e = expression_to_string fmt indent indent_incr e in + let re = texpression_to_string fmt indent1 indent_incr re in + let e = texpression_to_string fmt indent indent_incr e in let lv = typed_lvalue_to_string val_fmt lv in if monadic then lv ^ " <-- " ^ re ^ ";\n" ^ indent ^ e else "let " ^ lv ^ " = " ^ re ^ " in\n" ^ indent ^ e and switch_to_string (fmt : ast_formatter) (indent : string) - (indent_incr : string) (scrutinee : expression) (body : switch_body) : + (indent_incr : string) (scrutinee : texpression) (body : switch_body) : string = let indent1 = indent ^ indent_incr in (* Printing can mess up on the scrutinee, because it is an expression - but * in most situations it will be a value or a function call, so it should be * ok*) - let scrut = expression_to_string fmt indent1 indent_incr scrutinee in + let scrut = texpression_to_string fmt indent1 indent_incr scrutinee in match body with | If (e_true, e_false) -> - let e_true = expression_to_string fmt indent1 indent_incr e_true in - let e_false = expression_to_string fmt indent1 indent_incr e_false in + let e_true = texpression_to_string fmt indent1 indent_incr e_true in + let e_false = texpression_to_string fmt indent1 indent_incr e_false in "if " ^ scrut ^ "\n" ^ indent ^ "then\n" ^ indent ^ e_true ^ "\n" ^ indent ^ "else\n" ^ indent ^ e_false | SwitchInt (_, branches, otherwise) -> @@ -412,12 +416,12 @@ and switch_to_string (fmt : ast_formatter) (indent : string) List.map (fun (v, be) -> indent ^ "| " ^ scalar_value_to_string v ^ " ->\n" ^ indent1 - ^ expression_to_string fmt indent1 indent_incr be) + ^ texpression_to_string fmt indent1 indent_incr be) branches in let otherwise = indent ^ "| _ ->\n" ^ indent1 - ^ expression_to_string fmt indent1 indent_incr otherwise + ^ texpression_to_string fmt indent1 indent_incr otherwise in let all_branches = List.append branches [ otherwise ] in "switch " ^ scrut ^ " with\n" ^ String.concat "\n" all_branches @@ -426,7 +430,7 @@ and switch_to_string (fmt : ast_formatter) (indent : string) let branch_to_string (b : match_branch) : string = let pat = typed_lvalue_to_string val_fmt b.pat in indent ^ "| " ^ pat ^ " ->\n" ^ indent1 - ^ expression_to_string fmt indent1 indent_incr b.branch + ^ texpression_to_string fmt indent1 indent_incr b.branch in let branches = List.map branch_to_string branches in "match " ^ scrut ^ " with\n" ^ String.concat "\n" branches @@ -439,5 +443,5 @@ let fun_def_to_string (fmt : ast_formatter) (def : fun_def) : string = let inputs = if inputs = [] then "" else " fun " ^ String.concat " " inputs ^ " ->\n" in - let body = expression_to_string fmt " " " " def.body in + let body = texpression_to_string fmt " " " " def.body in "let " ^ name ^ " :\n " ^ signature ^ " =\n" ^ inputs ^ " " ^ body diff --git a/src/Pure.ml b/src/Pure.ml index dc35a181..bc655b63 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -445,7 +445,7 @@ class virtual ['self] mapreduce_expression_base = type expression = | Value of typed_rvalue * mplace option | Call of call - | Let of bool * typed_lvalue * expression * expression + | Let of bool * typed_lvalue * texpression * texpression (** Let binding. The boolean controls whether the let is monadic or not. @@ -479,13 +479,13 @@ type expression = ... ``` *) - | Switch of expression * switch_body - | Meta of meta * expression (** Meta-information *) + | Switch of texpression * switch_body + | Meta of meta * texpression (** Meta-information *) and call = { func : fun_id; type_params : ty list; - args : expression list; + args : texpression list; (** Note that immediately after we converted the symbolic AST to a pure AST, some functions may have no arguments. For instance: ``` @@ -496,11 +496,14 @@ and call = { } and switch_body = - | If of expression * expression - | SwitchInt of T.integer_type * (scalar_value * expression) list * expression + | If of texpression * texpression + | SwitchInt of + T.integer_type * (scalar_value * texpression) list * texpression | Match of match_branch list -and match_branch = { pat : typed_lvalue; branch : expression } +and match_branch = { pat : typed_lvalue; branch : texpression } + +and texpression = { e : expression; ty : ty } [@@deriving visitors { @@ -570,5 +573,5 @@ type fun_def = { inputs_lvs : typed_lvalue list; (** The inputs seen as lvalues. Allows to make transformations, for example to replace unused variables by `_` *) - body : expression; + body : texpression; } diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 82e5ecd2..2a7293c8 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -195,14 +195,18 @@ let compute_pretty_names (def : fun_def) : fun_def = in (* *) - let rec update_expression (e : expression) (ctx : pn_ctx) : - pn_ctx * expression = - match e with - | Value (v, mp) -> update_value v mp ctx - | Call call -> update_call call ctx - | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx - | Switch (scrut, body) -> update_switch_body scrut body ctx - | Meta (meta, e) -> update_meta meta e ctx + let rec update_texpression (e : texpression) (ctx : pn_ctx) : + pn_ctx * texpression = + let ty = e.ty in + let ctx, e = + match e.e with + | Value (v, mp) -> update_value v mp ctx + | Call call -> update_call call ctx + | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx + | Switch (scrut, body) -> update_switch_body scrut body ctx + | Meta (meta, e) -> update_meta meta e ctx + in + (ctx, { e; ty }) (* *) and update_value (v : typed_rvalue) (mp : mplace option) (ctx : pn_ctx) : pn_ctx * expression = @@ -212,40 +216,40 @@ let compute_pretty_names (def : fun_def) : fun_def = and update_call (call : call) (ctx : pn_ctx) : pn_ctx * expression = let ctx, args = List.fold_left_map - (fun ctx arg -> update_expression arg ctx) + (fun ctx arg -> update_texpression arg ctx) ctx call.args in let call = { call with args } in (ctx, Call call) (* *) - and update_let (monadic : bool) (lv : typed_lvalue) (re : expression) - (e : expression) (ctx : pn_ctx) : pn_ctx * expression = + and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression) + (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = let ctx = add_left_constraint lv ctx in - let ctx, re = update_expression re ctx in - let ctx, e = update_expression e ctx in + let ctx, re = update_texpression re ctx in + let ctx, e = update_texpression e ctx in let lv = update_typed_lvalue ctx lv in (ctx, Let (monadic, lv, re, e)) (* *) - and update_switch_body (scrut : expression) (body : switch_body) + and update_switch_body (scrut : texpression) (body : switch_body) (ctx : pn_ctx) : pn_ctx * expression = - let ctx, scrut = update_expression scrut ctx in + let ctx, scrut = update_texpression scrut ctx in let ctx, body = match body with | If (e_true, e_false) -> - let ctx1, e_true = update_expression e_true ctx in - let ctx2, e_false = update_expression e_false ctx in + let ctx1, e_true = update_texpression e_true ctx in + let ctx2, e_false = update_texpression e_false ctx in let ctx = merge_ctxs ctx1 ctx2 in (ctx, If (e_true, e_false)) | SwitchInt (int_ty, branches, otherwise) -> let ctx_branches_ls = List.map (fun (v, br) -> - let ctx, br = update_expression br ctx in + let ctx, br = update_texpression br ctx in (ctx, (v, br))) branches in - let ctx, otherwise = update_expression otherwise ctx in + let ctx, otherwise = update_texpression otherwise ctx in let ctxs, branches = List.split ctx_branches_ls in let ctxs = merge_ctxs_ls ctxs in let ctx = merge_ctxs ctx ctxs in @@ -255,7 +259,7 @@ let compute_pretty_names (def : fun_def) : fun_def = List.map (fun br -> let ctx = add_left_constraint br.pat ctx in - let ctx, branch = update_expression br.branch ctx in + let ctx, branch = update_texpression br.branch ctx in let pat = update_typed_lvalue ctx br.pat in (ctx, { pat; branch })) branches @@ -266,12 +270,13 @@ let compute_pretty_names (def : fun_def) : fun_def = in (ctx, Switch (scrut, body)) (* *) - and update_meta (meta : meta) (e : expression) (ctx : pn_ctx) : + and update_meta (meta : meta) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = match meta with | Assignment (mp, rvalue) -> let ctx = add_right_constraint mp rvalue ctx in - update_expression e ctx + let ctx, e = update_texpression e ctx in + (ctx, e.e) in let input_names = @@ -281,7 +286,7 @@ let compute_pretty_names (def : fun_def) : fun_def = def.inputs in let ctx = VarId.Map.of_list input_names in - let _, body = update_expression def.body ctx in + let _, body = update_texpression def.body ctx in { def with body } (** Remove the meta-information *) @@ -290,10 +295,10 @@ let remove_meta (def : fun_def) : fun_def = object inherit [_] map_expression as super - method! visit_Meta env _ e = super#visit_expression env e + method! visit_Meta env _ e = super#visit_expression env e.e end in - let body = obj#visit_expression () def.body in + let body = obj#visit_texpression () def.body in { def with body } (** Inline the useless variable reassignments (a lot of variable assignments @@ -326,7 +331,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : * - the let-binding is not monadic * - the left-value is a variable * - the assigned expression is a value *) - match (monadic, lv.value, re) with + match (monadic, lv.value, re.e) with | false, LvVar (Var (lv_var, _)), Value (rv, _) -> ( (* Check that: * - the left variable is unnamed or that [inline_named] is true @@ -337,7 +342,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : (* Update the environment and explore the next expression * (dropping the currrent let) *) let env = add_subst lv_var.id var env in - super#visit_expression env e + super#visit_expression env e.e | _ -> super#visit_Let env monadic lv re e) | _ -> super#visit_Let env monadic lv re e (** Visit the let-bindings to filter the useless ones (and update @@ -353,7 +358,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : (** Visit the places used as rvalues, to substitute them if necessary *) end in - let body = obj#visit_expression VarId.Map.empty def.body in + let body = obj#visit_texpression VarId.Map.empty def.body in { def with body } (** Given a forward or backward function call, is there, for every execution @@ -382,7 +387,7 @@ let inline_useless_var_reassignments (inline_named : bool) (def : fun_def) : In this situation, we can remove the call `f x`. *) let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) - (e : expression) : bool = + (e : texpression) : bool = let check_call call1 : bool = (* Check the func_ids, to see if call1's function is a child of call0's function *) match (call0.func, call1.func) with @@ -428,7 +433,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) * meta-values* (which we need to ignore). We only consider the * case where both expressions are actually values. *) let input_eq (v0, v1) = - match (v0, v1) with + match (v0.e, v1.e) with | Value (v0, _), Value (v1, _) -> v0 = v1 | _ -> false in @@ -451,37 +456,43 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (call0 : call) method! visit_expression env e = match e with | Value (_, _) -> fun _ -> false - | Let (_, _, Call call1, e) -> + | Let (_, _, { e = Call call1; ty = _ }, e) -> let call_is_child = check_call call1 in if call_is_child then fun () -> true - else self#visit_expression env e + else self#visit_texpression env e | Let (_, _, re, e) -> fun () -> - self#visit_expression env re () && self#visit_expression env e () + self#visit_texpression env re () + && self#visit_texpression env e () | Call call1 -> fun () -> check_call call1 - | Meta (_, e) -> self#visit_expression env e + | Meta (_, e) -> self#visit_texpression env e | Switch (_, body) -> self#visit_switch_body env body (** We need to reimplement the way we compose the booleans *) + method! visit_texpression env e = + (* We take care not to visit the type *) + self#visit_expression env e.e + method! visit_switch_body env body = match body with | If (e1, e2) -> fun () -> - self#visit_expression env e1 () && self#visit_expression env e2 () + self#visit_texpression env e1 () + && self#visit_texpression env e2 () | SwitchInt (_, branches, otherwise) -> fun () -> List.for_all - (fun (_, br) -> self#visit_expression env br ()) + (fun (_, br) -> self#visit_texpression env br ()) branches - && self#visit_expression env otherwise () + && self#visit_texpression env otherwise () | Match branches -> fun () -> List.for_all - (fun br -> self#visit_expression env br.branch ()) + (fun br -> self#visit_texpression env br.branch ()) branches end in - visitor#visit_expression () e () + visitor#visit_texpression () e () (** Filter the unused assignments (removes the unused variables, filters the function calls) *) @@ -546,13 +557,13 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) - let e, used = self#visit_expression env e in + let e, used = self#visit_texpression env e in let used = used () in (* Filter the left values *) let lv, all_dummies = filter_typed_lvalue used lv in (* Small utility - called if we can't filter the let-binding *) let dont_filter () = - let re, used_re = self#visit_expression env re in + let re, used_re = self#visit_texpression env re in let used = VarId.Set.union used (used_re ()) in (Let (monadic, lv, re, e), fun _ -> used) in @@ -560,12 +571,12 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) if all_dummies then if not monadic then (* Not a monadic let-binding: simple case *) - (e, fun _ -> used) + (e.e, fun _ -> used) else (* Monadic let-binding: trickier. * We can filter if the right-expression is a function call, * under some conditions. *) - match (filter_monadic_calls, re) with + match (filter_monadic_calls, re.e) with | true, Call call -> (* We need to check if there is a child call - see * the comments for: @@ -574,7 +585,7 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) expression_contains_child_call_in_all_paths ctx call e in if has_child_call then (* Filter *) - (e, fun _ -> used) + (e.e, fun _ -> used) else (* No child call: don't filter *) dont_filter () | _ -> @@ -585,7 +596,7 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) end in (* Visit the body *) - let body, used_vars = expr_visitor#visit_expression () def.body in + let body, used_vars = expr_visitor#visit_texpression () def.body in (* Visit the parameters *) let used_vars = used_vars () in let inputs_lvs = @@ -604,12 +615,12 @@ let to_monadic (def : fun_def) : fun_def = method! visit_call env call = (* If no arguments, introduce unit *) if call.args = [] then - let args = [ Value (unit_rvalue, None) ] in + let args = [ mk_value_expression unit_rvalue None ] in { call with args } (* Otherwise: nothing to do *) else super#visit_call env call end in - let body = obj#visit_expression () def.body in + let body = obj#visit_texpression () def.body in let def = { def with body } in (* Update the signature: first the input types *) @@ -617,7 +628,7 @@ let to_monadic (def : fun_def) : fun_def = if def.inputs = [] then ( assert (def.signature.inputs = []); let signature = { def.signature with inputs = [ unit_ty ] } in - let var_cnt = get_expression_min_var_counter def.body in + let var_cnt = get_expression_min_var_counter def.body.e in let id, _ = VarId.fresh var_cnt in let var = { id; basename = None; ty = unit_ty } in let inputs = [ var ] in @@ -659,7 +670,7 @@ let unit_vars_to_unit (def : fun_def) : fun_def = end in (* Update the body *) - let body = obj#visit_expression () def.body in + let body = obj#visit_texpression () def.body in (* Update the input parameters *) let inputs_lvs = List.map (obj#visit_typed_lvalue ()) def.inputs_lvs in (* Return *) diff --git a/src/PureUtils.ml b/src/PureUtils.ml index 9596cd9b..dd072d23 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -72,6 +72,17 @@ let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue = let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) +let mk_value_expression (v : typed_rvalue) (mp : mplace option) : texpression = + let e = Value (v, mp) in + let ty = v.ty in + { e; ty } + +let mk_let (monadic : bool) (lv : typed_lvalue) (re : texpression) + (next_e : texpression) : texpression = + let e = Let (monadic, lv, re, next_e) in + let ty = next_e.ty in + { e; ty } + (** Type substitution *) let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = let obj = diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index f6f610dd..d48f732d 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -843,16 +843,23 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = let abs_ancestors = list_ancestor_abstractions ctx abs in (call_info.forward, abs_ancestors) -let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression = +let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with | S.Return opt_v -> translate_return opt_v ctx - | Panic -> Value (mk_result_fail_rvalue ctx.ret_ty, None) + | Panic -> + (* Here we use the function return type - note that it is ok because + * we don't match on panics which happen inside the function body - + * but it won't be true anymore once we translate individual blocks *) + let v = mk_result_fail_rvalue ctx.ret_ty in + let e = Value (v, None) in + let ty = v.ty in + { e; ty } | FunCall (call, e) -> translate_function_call call e ctx | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx | Meta (meta, e) -> translate_meta meta e ctx -and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression +and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: - either we are translating a forward function, in which case the optional @@ -864,7 +871,10 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression (* Forward function *) let v = Option.get opt_v in let v = typed_value_to_rvalue ctx v in - Value (mk_result_return_rvalue v, None) + let v = mk_result_return_rvalue v in + let e = Value (v, None) in + let ty = v.ty in + { e; ty } | Some bid -> (* Backward function *) (* Sanity check *) @@ -880,10 +890,13 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in let ret_ty = Adt (Tuple, ret_tys) in let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in - Value (mk_result_return_rvalue ret_value, None) + let ret_value = mk_result_return_rvalue ret_value in + let e = Value (ret_value, None) in + let ty = ret_value.ty in + { e; ty } and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : - expression = + texpression = (* Translate the function call *) let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in let args = List.map (typed_value_to_rvalue ctx) call.args in @@ -918,17 +931,22 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : | _ -> failwith "Unreachable") in let args = - List.map (fun (arg, mp) -> Value (arg, mp)) (List.combine args args_mplaces) + List.map + (fun (arg, mp) -> mk_value_expression arg mp) + (List.combine args args_mplaces) in + let dest_v = mk_typed_lvalue_from_var dest dest_mplace in let call = { func; type_params; args } in let call = Call call in + let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in + let call = { e = call; ty = call_ty } in (* Translate the next expression *) - let e = translate_expression e ctx in + let next_e = translate_expression e ctx in (* Put together *) - Let (monadic, mk_typed_lvalue_from_var dest dest_mplace, call, e) + mk_let monadic dest_v call next_e and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : - expression = + texpression = log#ldebug (lazy ("translate_end_abstraction: abstraction kind: " @@ -976,14 +994,16 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty)) variables_values; (* Translate the next expression *) - let e = translate_expression e ctx in + let next_e = translate_expression e ctx in (* Generate the assignemnts *) let monadic = false in List.fold_right - (fun (var, value) e -> - Let - (monadic, mk_typed_lvalue_from_var var None, Value (value, None), e)) - variables_values e + (fun (var, value) (e : texpression) -> + mk_let monadic + (mk_typed_lvalue_from_var var None) + (mk_value_expression value None) + e) + variables_values next_e | V.FunCall -> let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in let call = call_info.forward in @@ -1039,17 +1059,19 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : * if necessary *) let ctx, func = bs_ctx_register_backward_call abs ctx in (* Translate the next expression *) - let e = translate_expression e ctx in + let next_e = translate_expression e ctx in (* Put everything together *) let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map - (fun (arg, mp) -> Value (arg, mp)) + (fun (arg, mp) -> mk_value_expression arg mp) (List.combine inputs args_mplaces) in - let call = { func; type_params; args } in let monadic = true in - Let (monadic, output, Call call, e) + let call = { func; type_params; args } in + let call_ty = mk_result_ty output.ty in + let call = { e = Call call; ty = call_ty } in + mk_let monadic output call next_e | V.SynthRet -> (* If we end the abstraction which consumed the return value of the function * we are synthesizing, we get back the borrows which were inside. Those borrows @@ -1100,20 +1122,18 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : assert (given_back.ty = input.ty)) given_back_inputs; (* Translate the next expression *) - let e = translate_expression e ctx in + let next_e = translate_expression e ctx in (* Generate the assignments *) let monadic = false in List.fold_right (fun (given_back, input_var) e -> - Let - ( monadic, - given_back, - Value (mk_typed_rvalue_from_var input_var, None), - e )) - given_back_inputs e + mk_let monadic given_back + (mk_value_expression (mk_typed_rvalue_from_var input_var) None) + e) + given_back_inputs next_e and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) - (exp : S.expansion) (ctx : bs_ctx) : expression = + (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) let scrutinee_var = lookup_var_for_symbolic_value sv ctx in let scrutinee = mk_typed_rvalue_from_var scrutinee_var in @@ -1130,13 +1150,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* The (mut/shared) borrow type is extracted to identity: we thus simply * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in - let e = translate_expression e ctx in + let next_e = translate_expression e ctx in let monadic = false in - Let - ( monadic, - mk_typed_lvalue_from_var var None, - Value (scrutinee, scrutinee_mplace), - e ) + mk_let monadic + (mk_typed_lvalue_from_var var None) + (mk_value_expression scrutinee scrutinee_mplace) + next_e | SeAdt _ -> (* Should be in the [ExpandAdt] case *) failwith "Unreachable") @@ -1161,7 +1180,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) in let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in let monadic = false in - Let (monadic, lv, Value (scrutinee, scrutinee_mplace), branch) + + mk_let monadic lv + (mk_value_expression scrutinee scrutinee_mplace) + branch else (* This is not an enumeration: introduce let-bindings for every * field. @@ -1182,22 +1204,19 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.fold_right (fun (fid, var) e -> let field_proj = gen_field_proj fid var in - Let - ( monadic, - mk_typed_lvalue_from_var var None, - Value (field_proj, None), - e )) + mk_let monadic + (mk_typed_lvalue_from_var var None) + (mk_value_expression field_proj None) + e) id_var_pairs branch | T.Tuple -> let vars = List.map (fun x -> mk_typed_lvalue_from_var x None) vars in let monadic = false in - Let - ( monadic, - mk_tuple_lvalue vars, - Value (scrutinee, scrutinee_mplace), - branch ) + mk_let monadic (mk_tuple_lvalue vars) + (mk_value_expression scrutinee scrutinee_mplace) + branch | T.Assumed T.Box -> (* There should be exactly one variable *) let var = @@ -1206,11 +1225,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* We simply introduce an assignment - the box type is the * identity when extracted (`box a == a`) *) let monadic = false in - Let - ( monadic, - mk_typed_lvalue_from_var var None, - Value (scrutinee, scrutinee_mplace), - branch )) + mk_let monadic + (mk_typed_lvalue_from_var var None) + (mk_value_expression scrutinee scrutinee_mplace) + branch) | branches -> let translate_branch (variant_id : T.VariantId.id option) (svl : V.symbolic_value list) (branch : S.expression) : @@ -1229,16 +1247,30 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let branches = List.map (fun (vid, svl, e) -> translate_branch vid svl e) branches in - Switch (Value (scrutinee, scrutinee_mplace), Match branches)) + let e = + Switch + (mk_value_expression scrutinee scrutinee_mplace, Match branches) + in + (* There should be at least one branch *) + let branch = List.hd branches in + let ty = branch.branch.ty in + assert (List.for_all (fun br -> br.branch.ty = ty) branches); + { e; ty }) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any * new values/variables *) let true_e = translate_expression true_e ctx in let false_e = translate_expression false_e ctx in - Switch (Value (scrutinee, scrutinee_mplace), If (true_e, false_e)) + let e = + Switch + (mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e)) + in + let ty = true_e.ty in + assert (ty = false_e.ty); + { e; ty } | ExpandInt (int_ty, branches, otherwise) -> let translate_branch ((v, branch_e) : V.scalar_value * S.expression) : - scalar_value * expression = + scalar_value * texpression = (* We don't need to update the context: we don't introduce any * new values/variables *) let branch_e = translate_expression branch_e ctx in @@ -1246,13 +1278,18 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) in let branches = List.map translate_branch branches in let otherwise = translate_expression otherwise ctx in - Switch - ( Value (scrutinee, scrutinee_mplace), - SwitchInt (int_ty, branches, otherwise) ) + let e = + Switch + ( mk_value_expression scrutinee scrutinee_mplace, + SwitchInt (int_ty, branches, otherwise) ) + in + let ty = otherwise.ty in + assert (List.for_all (fun (_, br) -> br.ty = ty) branches); + { e; ty } and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : - expression = - let e = translate_expression e ctx in + texpression = + let next_e = translate_expression e ctx in let meta = match meta with | S.Assignment (p, rv) -> @@ -1260,7 +1297,9 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : let rv = typed_value_to_rvalue ctx rv in Assignment (p, rv) in - Meta (meta, e) + let e = Meta (meta, next_e) in + let ty = next_e.ty in + { e; ty } let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def = let def = ctx.fun_def in -- cgit v1.2.3