summaryrefslogtreecommitdiff
path: root/src/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r--src/SymbolicToPure.ml155
1 files changed, 97 insertions, 58 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index f6f610dd..d48f732d 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -843,16 +843,23 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =
let abs_ancestors = list_ancestor_abstractions ctx abs in
(call_info.forward, abs_ancestors)
-let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression =
+let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
| S.Return opt_v -> translate_return opt_v ctx
- | Panic -> Value (mk_result_fail_rvalue ctx.ret_ty, None)
+ | Panic ->
+ (* Here we use the function return type - note that it is ok because
+ * we don't match on panics which happen inside the function body -
+ * but it won't be true anymore once we translate individual blocks *)
+ let v = mk_result_fail_rvalue ctx.ret_ty in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
| FunCall (call, e) -> translate_function_call call e ctx
| EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx
| Expansion (p, sv, exp) -> translate_expansion p sv exp ctx
| Meta (meta, e) -> translate_meta meta e ctx
-and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
+and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
=
(* There are two cases:
- either we are translating a forward function, in which case the optional
@@ -864,7 +871,10 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
(* Forward function *)
let v = Option.get opt_v in
let v = typed_value_to_rvalue ctx v in
- Value (mk_result_return_rvalue v, None)
+ let v = mk_result_return_rvalue v in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
| Some bid ->
(* Backward function *)
(* Sanity check *)
@@ -880,10 +890,13 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in
let ret_ty = Adt (Tuple, ret_tys) in
let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in
- Value (mk_result_return_rvalue ret_value, None)
+ let ret_value = mk_result_return_rvalue ret_value in
+ let e = Value (ret_value, None) in
+ let ty = ret_value.ty in
+ { e; ty }
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
- expression =
+ texpression =
(* Translate the function call *)
let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in
let args = List.map (typed_value_to_rvalue ctx) call.args in
@@ -918,17 +931,22 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| _ -> failwith "Unreachable")
in
let args =
- List.map (fun (arg, mp) -> Value (arg, mp)) (List.combine args args_mplaces)
+ List.map
+ (fun (arg, mp) -> mk_value_expression arg mp)
+ (List.combine args args_mplaces)
in
+ let dest_v = mk_typed_lvalue_from_var dest dest_mplace in
let call = { func; type_params; args } in
let call = Call call in
+ let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in
+ let call = { e = call; ty = call_ty } in
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Put together *)
- Let (monadic, mk_typed_lvalue_from_var dest dest_mplace, call, e)
+ mk_let monadic dest_v call next_e
and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
- expression =
+ texpression =
log#ldebug
(lazy
("translate_end_abstraction: abstraction kind: "
@@ -976,14 +994,16 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty))
variables_values;
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Generate the assignemnts *)
let monadic = false in
List.fold_right
- (fun (var, value) e ->
- Let
- (monadic, mk_typed_lvalue_from_var var None, Value (value, None), e))
- variables_values e
+ (fun (var, value) (e : texpression) ->
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression value None)
+ e)
+ variables_values next_e
| V.FunCall ->
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
let call = call_info.forward in
@@ -1039,17 +1059,19 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
* if necessary *)
let ctx, func = bs_ctx_register_backward_call abs ctx in
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Put everything together *)
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
List.map
- (fun (arg, mp) -> Value (arg, mp))
+ (fun (arg, mp) -> mk_value_expression arg mp)
(List.combine inputs args_mplaces)
in
- let call = { func; type_params; args } in
let monadic = true in
- Let (monadic, output, Call call, e)
+ let call = { func; type_params; args } in
+ let call_ty = mk_result_ty output.ty in
+ let call = { e = Call call; ty = call_ty } in
+ mk_let monadic output call next_e
| V.SynthRet ->
(* If we end the abstraction which consumed the return value of the function
* we are synthesizing, we get back the borrows which were inside. Those borrows
@@ -1100,20 +1122,18 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
assert (given_back.ty = input.ty))
given_back_inputs;
(* Translate the next expression *)
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
(* Generate the assignments *)
let monadic = false in
List.fold_right
(fun (given_back, input_var) e ->
- Let
- ( monadic,
- given_back,
- Value (mk_typed_rvalue_from_var input_var, None),
- e ))
- given_back_inputs e
+ mk_let monadic given_back
+ (mk_value_expression (mk_typed_rvalue_from_var input_var) None)
+ e)
+ given_back_inputs next_e
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
- (exp : S.expansion) (ctx : bs_ctx) : expression =
+ (exp : S.expansion) (ctx : bs_ctx) : texpression =
(* Translate the scrutinee *)
let scrutinee_var = lookup_var_for_symbolic_value sv ctx in
let scrutinee = mk_typed_rvalue_from_var scrutinee_var in
@@ -1130,13 +1150,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* The (mut/shared) borrow type is extracted to identity: we thus simply
* introduce an reassignment *)
let ctx, var = fresh_var_for_symbolic_value nsv ctx in
- let e = translate_expression e ctx in
+ let next_e = translate_expression e ctx in
let monadic = false in
- Let
- ( monadic,
- mk_typed_lvalue_from_var var None,
- Value (scrutinee, scrutinee_mplace),
- e )
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression scrutinee scrutinee_mplace)
+ next_e
| SeAdt _ ->
(* Should be in the [ExpandAdt] case *)
failwith "Unreachable")
@@ -1161,7 +1180,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
in
let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in
let monadic = false in
- Let (monadic, lv, Value (scrutinee, scrutinee_mplace), branch)
+
+ mk_let monadic lv
+ (mk_value_expression scrutinee scrutinee_mplace)
+ branch
else
(* This is not an enumeration: introduce let-bindings for every
* field.
@@ -1182,22 +1204,19 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.fold_right
(fun (fid, var) e ->
let field_proj = gen_field_proj fid var in
- Let
- ( monadic,
- mk_typed_lvalue_from_var var None,
- Value (field_proj, None),
- e ))
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression field_proj None)
+ e)
id_var_pairs branch
| T.Tuple ->
let vars =
List.map (fun x -> mk_typed_lvalue_from_var x None) vars
in
let monadic = false in
- Let
- ( monadic,
- mk_tuple_lvalue vars,
- Value (scrutinee, scrutinee_mplace),
- branch )
+ mk_let monadic (mk_tuple_lvalue vars)
+ (mk_value_expression scrutinee scrutinee_mplace)
+ branch
| T.Assumed T.Box ->
(* There should be exactly one variable *)
let var =
@@ -1206,11 +1225,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* We simply introduce an assignment - the box type is the
* identity when extracted (`box a == a`) *)
let monadic = false in
- Let
- ( monadic,
- mk_typed_lvalue_from_var var None,
- Value (scrutinee, scrutinee_mplace),
- branch ))
+ mk_let monadic
+ (mk_typed_lvalue_from_var var None)
+ (mk_value_expression scrutinee scrutinee_mplace)
+ branch)
| branches ->
let translate_branch (variant_id : T.VariantId.id option)
(svl : V.symbolic_value list) (branch : S.expression) :
@@ -1229,16 +1247,30 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
let branches =
List.map (fun (vid, svl, e) -> translate_branch vid svl e) branches
in
- Switch (Value (scrutinee, scrutinee_mplace), Match branches))
+ let e =
+ Switch
+ (mk_value_expression scrutinee scrutinee_mplace, Match branches)
+ in
+ (* There should be at least one branch *)
+ let branch = List.hd branches in
+ let ty = branch.branch.ty in
+ assert (List.for_all (fun br -> br.branch.ty = ty) branches);
+ { e; ty })
| ExpandBool (true_e, false_e) ->
(* We don't need to update the context: we don't introduce any
* new values/variables *)
let true_e = translate_expression true_e ctx in
let false_e = translate_expression false_e ctx in
- Switch (Value (scrutinee, scrutinee_mplace), If (true_e, false_e))
+ let e =
+ Switch
+ (mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e))
+ in
+ let ty = true_e.ty in
+ assert (ty = false_e.ty);
+ { e; ty }
| ExpandInt (int_ty, branches, otherwise) ->
let translate_branch ((v, branch_e) : V.scalar_value * S.expression) :
- scalar_value * expression =
+ scalar_value * texpression =
(* We don't need to update the context: we don't introduce any
* new values/variables *)
let branch_e = translate_expression branch_e ctx in
@@ -1246,13 +1278,18 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
in
let branches = List.map translate_branch branches in
let otherwise = translate_expression otherwise ctx in
- Switch
- ( Value (scrutinee, scrutinee_mplace),
- SwitchInt (int_ty, branches, otherwise) )
+ let e =
+ Switch
+ ( mk_value_expression scrutinee scrutinee_mplace,
+ SwitchInt (int_ty, branches, otherwise) )
+ in
+ let ty = otherwise.ty in
+ assert (List.for_all (fun (_, br) -> br.ty = ty) branches);
+ { e; ty }
and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
- expression =
- let e = translate_expression e ctx in
+ texpression =
+ let next_e = translate_expression e ctx in
let meta =
match meta with
| S.Assignment (p, rv) ->
@@ -1260,7 +1297,9 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
let rv = typed_value_to_rvalue ctx rv in
Assignment (p, rv)
in
- Meta (meta, e)
+ let e = Meta (meta, next_e) in
+ let ty = next_e.ty in
+ { e; ty }
let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =
let def = ctx.fun_def in