summaryrefslogtreecommitdiff
path: root/src/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r--src/SymbolicToPure.ml100
1 files changed, 77 insertions, 23 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 5c0250f7..d65e929f 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -15,9 +15,18 @@ module PP = PrintPure
(** The local logger *)
let log = L.symbolic_to_pure_log
+(* TODO : move *)
+let binop_can_fail (binop : E.binop) : bool =
+ match binop with
+ | BitXor | BitAnd | BitOr | Eq | Lt | Le | Ne | Ge | Gt -> false
+ | Div | Rem | Add | Sub | Mul -> true
+ | Shl | Shr -> raise Unimplemented
+
(* TODO: move *)
let mk_place_from_var (v : var) : place = { var = v.id; projection = [] }
+let mk_tuple_ty (tys : ty list) : ty = Adt (Tuple, tys)
+
let mk_typed_rvalue_from_var (v : var) : typed_rvalue =
let value = RvPlace (mk_place_from_var v) in
let ty = v.ty in
@@ -31,7 +40,7 @@ let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue =
let mk_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue =
let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in
- let ty = Adt (T.Tuple, tys) in
+ let ty = Adt (Tuple, tys) in
let value = LvAdt { variant_id = None; field_values = vl } in
{ value; ty }
@@ -47,6 +56,18 @@ let ty_as_integer (t : ty) : T.integer_type =
let type_def_is_enum (def : T.type_def) : bool =
match def.kind with T.Struct _ -> false | Enum _ -> true
+let mk_result_fail_rvalue (ty : ty) : typed_rvalue =
+ let ty = Adt (Assumed Result, [ ty ]) in
+ let value = RvAdt { variant_id = Some result_fail_id; field_values = [] } in
+ { value; ty }
+
+let mk_result_return_rvalue (v : typed_rvalue) : typed_rvalue =
+ let ty = Adt (Assumed Result, [ v.ty ]) in
+ let value =
+ RvAdt { variant_id = Some result_return_id; field_values = [ v ] }
+ in
+ { value; ty }
+
(** Type substitution *)
let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty =
let obj =
@@ -132,6 +153,7 @@ type bs_ctx = {
fun_context : fun_context;
fun_def : A.fun_def;
bid : T.RegionGroupId.id option; (** TODO: rename *)
+ ret_ty : ty; (** The return type - we use it to translate `Panic` *)
sv_to_var : var V.SymbolicValueId.Map.t;
(** Whenever we encounter a new symbolic value (introduced because of
a symbolic expansion or upon ending an abstraction, for instance)
@@ -253,11 +275,17 @@ let bs_ctx_register_backward_call (abs : V.abs) (ctx : bs_ctx) : bs_ctx * fun_id
let rec translate_sty (ty : T.sty) : ty =
let translate = translate_sty in
match ty with
- | T.Adt (type_id, regions, tys) ->
+ | T.Adt (type_id, regions, tys) -> (
(* Can't translate types with regions for now *)
assert (regions = []);
let tys = List.map translate tys in
- Adt (type_id, tys)
+ match type_id with
+ | T.AdtId adt_id -> Adt (AdtId adt_id, tys)
+ | T.Tuple -> Adt (Tuple, tys)
+ | T.Assumed T.Box -> (
+ match tys with
+ | [ ty ] -> ty
+ | _ -> failwith "Box type with incorrect number of arguments"))
| TypeVar vid -> TypeVar vid
| Bool -> Bool
| Char -> Char
@@ -321,7 +349,8 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty =
let tys = List.map translate tys in
(* Eliminate boxes *)
match type_id with
- | T.AdtId _ | Tuple -> Adt (type_id, tys)
+ | AdtId adt_id -> Adt (AdtId adt_id, tys)
+ | Tuple -> Adt (Tuple, tys)
| T.Assumed T.Box -> (
match tys with
| [ bty ] -> bty
@@ -363,6 +392,11 @@ let rec translate_back_ty (types_infos : TA.type_infos)
| T.AdtId _ ->
(* Don't accept ADTs (which are not tuples) with borrows for now *)
assert (not (TypesUtils.ty_has_borrows types_infos ty));
+ let type_id =
+ match type_id with
+ | T.AdtId id -> AdtId id
+ | T.Tuple | T.Assumed T.Box -> failwith "Unreachable"
+ in
if inside_mut then
let tys_t = List.filter_map translate tys in
Some (Adt (type_id, tys_t))
@@ -378,7 +412,7 @@ let rec translate_back_ty (types_infos : TA.type_infos)
| T.Tuple -> (
(* Tuples can contain borrows (which we eliminated) *)
let tys_t = List.filter_map translate tys in
- match tys_t with [] -> None | _ -> Some (Adt (T.Tuple, tys_t))))
+ match tys_t with [] -> None | _ -> Some (Adt (Tuple, tys_t))))
| TypeVar vid -> wrap (TypeVar vid)
| Bool -> wrap Bool
| Char -> wrap Char
@@ -886,7 +920,7 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression =
match e with
| S.Return opt_v -> translate_return opt_v ctx
- | Panic -> Fail
+ | Panic -> Value (mk_result_fail_rvalue ctx.ret_ty, None)
| FunCall (call, e) -> translate_function_call call e ctx
| EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx
| Expansion (p, sv, exp) -> translate_expansion p sv exp ctx
@@ -904,7 +938,7 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
(* Forward function *)
let v = Option.get opt_v in
let v = typed_value_to_rvalue ctx v in
- Return v
+ Value (mk_result_return_rvalue v, None)
| Some bid ->
(* Backward function *)
(* Sanity check *)
@@ -918,9 +952,9 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression
let field_values = List.map mk_typed_rvalue_from_var backward_outputs in
let ret_value = RvAdt { variant_id = None; field_values } in
let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in
- let ret_ty = Adt (T.Tuple, ret_tys) in
+ let ret_ty = Adt (Tuple, ret_tys) in
let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in
- Return ret_value
+ Value (mk_result_return_rvalue ret_value, None)
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
expression =
@@ -932,18 +966,20 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, func =
+ let ctx, func, monadic =
match call.call_id with
| S.Fun (fid, call_id) ->
let ctx = bs_ctx_register_forward_call call_id call ctx in
let func = Regular (fid, None) in
- (ctx, func)
- | S.Unop E.Not -> (ctx, Unop Not)
+ (ctx, func, true)
+ | S.Unop E.Not -> (ctx, Unop Not, false)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
let int_ty = ty_as_integer arg.ty in
- (ctx, Unop (Neg int_ty))
+ (* Note that negation can lead to an overflow and thus fail (it
+ * is thus monadic) *)
+ (ctx, Unop (Neg int_ty), true)
| _ -> failwith "Unreachable")
| S.Binop binop -> (
match args with
@@ -951,7 +987,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let int_ty0 = ty_as_integer arg0.ty in
let int_ty1 = ty_as_integer arg1.ty in
assert (int_ty0 = int_ty1);
- (ctx, Binop (binop, int_ty0))
+ let monadic = binop_can_fail binop in
+ (ctx, Binop (binop, int_ty0), monadic)
| _ -> failwith "Unreachable")
in
let args =
@@ -962,7 +999,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
(* Translate the next expression *)
let e = translate_expression e ctx in
(* Put together *)
- Let (mk_typed_lvalue_from_var dest dest_mplace, call, e)
+ Let (monadic, mk_typed_lvalue_from_var dest dest_mplace, call, e)
and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
expression =
@@ -1015,9 +1052,11 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Translate the next expression *)
let e = translate_expression e ctx in
(* Generate the assignemnts *)
+ let monadic = false in
List.fold_right
(fun (var, value) e ->
- Let (mk_typed_lvalue_from_var var None, Value (value, None), e))
+ Let
+ (monadic, mk_typed_lvalue_from_var var None, Value (value, None), e))
variables_values e
| V.FunCall ->
let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in
@@ -1078,7 +1117,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(List.combine inputs args_mplaces)
in
let call = { func; type_params; args } in
- Let (output, Call call, e)
+ let monadic = true in
+ Let (monadic, output, Call call, e)
| V.SynthRet ->
(* If we end the abstraction which consumed the return value of the function
* we are synthesizing, we get back the borrows which were inside. Those borrows
@@ -1129,9 +1169,14 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* Translate the next expression *)
let e = translate_expression e ctx in
(* Generate the assignments *)
+ let monadic = false in
List.fold_right
(fun (given_back, input_var) e ->
- Let (given_back, Value (mk_typed_rvalue_from_var input_var, None), e))
+ Let
+ ( monadic,
+ given_back,
+ Value (mk_typed_rvalue_from_var input_var, None),
+ e ))
given_back_inputs e
and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
@@ -1153,8 +1198,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
* introduce an reassignment *)
let ctx, var = fresh_var_for_symbolic_value nsv ctx in
let e = translate_expression e ctx in
+ let monadic = false in
Let
- ( mk_typed_lvalue_from_var var None,
+ ( monadic,
+ mk_typed_lvalue_from_var var None,
Value (scrutinee, scrutinee_mplace),
e )
| SeAdt _ ->
@@ -1180,7 +1227,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.map (fun v -> mk_typed_lvalue_from_var v None) vars
in
let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in
- Let (lv, Value (scrutinee, scrutinee_mplace), branch)
+ let monadic = false in
+ Let (monadic, lv, Value (scrutinee, scrutinee_mplace), branch)
else
(* This is not an enumeration: introduce let-bindings for every
* field.
@@ -1197,11 +1245,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
{ value; ty }
in
let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in
+ let monadic = false in
List.fold_right
(fun (fid, var) e ->
let field_proj = gen_field_proj fid var in
Let
- ( mk_typed_lvalue_from_var var None,
+ ( monadic,
+ mk_typed_lvalue_from_var var None,
Value (field_proj, None),
e ))
id_var_pairs branch
@@ -1209,8 +1259,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
let vars =
List.map (fun x -> mk_typed_lvalue_from_var x None) vars
in
+ let monadic = false in
Let
- ( mk_tuple_lvalue vars,
+ ( monadic,
+ mk_tuple_lvalue vars,
Value (scrutinee, scrutinee_mplace),
branch )
| T.Assumed T.Box ->
@@ -1220,8 +1272,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
in
(* We simply introduce an assignment - the box type is the
* identity when extracted (`box a == a`) *)
+ let monadic = false in
Let
- ( mk_typed_lvalue_from_var var None,
+ ( monadic,
+ mk_typed_lvalue_from_var var None,
Value (scrutinee, scrutinee_mplace),
branch ))
| branches ->