diff options
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r-- | src/SymbolicToPure.ml | 100 |
1 files changed, 77 insertions, 23 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 5c0250f7..d65e929f 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -15,9 +15,18 @@ module PP = PrintPure (** The local logger *) let log = L.symbolic_to_pure_log +(* TODO : move *) +let binop_can_fail (binop : E.binop) : bool = + match binop with + | BitXor | BitAnd | BitOr | Eq | Lt | Le | Ne | Ge | Gt -> false + | Div | Rem | Add | Sub | Mul -> true + | Shl | Shr -> raise Unimplemented + (* TODO: move *) let mk_place_from_var (v : var) : place = { var = v.id; projection = [] } +let mk_tuple_ty (tys : ty list) : ty = Adt (Tuple, tys) + let mk_typed_rvalue_from_var (v : var) : typed_rvalue = let value = RvPlace (mk_place_from_var v) in let ty = v.ty in @@ -31,7 +40,7 @@ let mk_typed_lvalue_from_var (v : var) (mp : mplace option) : typed_lvalue = let mk_tuple_lvalue (vl : typed_lvalue list) : typed_lvalue = let tys = List.map (fun (v : typed_lvalue) -> v.ty) vl in - let ty = Adt (T.Tuple, tys) in + let ty = Adt (Tuple, tys) in let value = LvAdt { variant_id = None; field_values = vl } in { value; ty } @@ -47,6 +56,18 @@ let ty_as_integer (t : ty) : T.integer_type = let type_def_is_enum (def : T.type_def) : bool = match def.kind with T.Struct _ -> false | Enum _ -> true +let mk_result_fail_rvalue (ty : ty) : typed_rvalue = + let ty = Adt (Assumed Result, [ ty ]) in + let value = RvAdt { variant_id = Some result_fail_id; field_values = [] } in + { value; ty } + +let mk_result_return_rvalue (v : typed_rvalue) : typed_rvalue = + let ty = Adt (Assumed Result, [ v.ty ]) in + let value = + RvAdt { variant_id = Some result_return_id; field_values = [ v ] } + in + { value; ty } + (** Type substitution *) let ty_substitute (tsubst : TypeVarId.id -> ty) (ty : ty) : ty = let obj = @@ -132,6 +153,7 @@ type bs_ctx = { fun_context : fun_context; fun_def : A.fun_def; bid : T.RegionGroupId.id option; (** TODO: rename *) + ret_ty : ty; (** The return type - we use it to translate `Panic` *) sv_to_var : var V.SymbolicValueId.Map.t; (** Whenever we encounter a new symbolic value (introduced because of a symbolic expansion or upon ending an abstraction, for instance) @@ -253,11 +275,17 @@ let bs_ctx_register_backward_call (abs : V.abs) (ctx : bs_ctx) : bs_ctx * fun_id let rec translate_sty (ty : T.sty) : ty = let translate = translate_sty in match ty with - | T.Adt (type_id, regions, tys) -> + | T.Adt (type_id, regions, tys) -> ( (* Can't translate types with regions for now *) assert (regions = []); let tys = List.map translate tys in - Adt (type_id, tys) + match type_id with + | T.AdtId adt_id -> Adt (AdtId adt_id, tys) + | T.Tuple -> Adt (Tuple, tys) + | T.Assumed T.Box -> ( + match tys with + | [ ty ] -> ty + | _ -> failwith "Box type with incorrect number of arguments")) | TypeVar vid -> TypeVar vid | Bool -> Bool | Char -> Char @@ -321,7 +349,8 @@ let rec translate_fwd_ty (types_infos : TA.type_infos) (ty : 'r T.ty) : ty = let tys = List.map translate tys in (* Eliminate boxes *) match type_id with - | T.AdtId _ | Tuple -> Adt (type_id, tys) + | AdtId adt_id -> Adt (AdtId adt_id, tys) + | Tuple -> Adt (Tuple, tys) | T.Assumed T.Box -> ( match tys with | [ bty ] -> bty @@ -363,6 +392,11 @@ let rec translate_back_ty (types_infos : TA.type_infos) | T.AdtId _ -> (* Don't accept ADTs (which are not tuples) with borrows for now *) assert (not (TypesUtils.ty_has_borrows types_infos ty)); + let type_id = + match type_id with + | T.AdtId id -> AdtId id + | T.Tuple | T.Assumed T.Box -> failwith "Unreachable" + in if inside_mut then let tys_t = List.filter_map translate tys in Some (Adt (type_id, tys_t)) @@ -378,7 +412,7 @@ let rec translate_back_ty (types_infos : TA.type_infos) | T.Tuple -> ( (* Tuples can contain borrows (which we eliminated) *) let tys_t = List.filter_map translate tys in - match tys_t with [] -> None | _ -> Some (Adt (T.Tuple, tys_t)))) + match tys_t with [] -> None | _ -> Some (Adt (Tuple, tys_t)))) | TypeVar vid -> wrap (TypeVar vid) | Bool -> wrap Bool | Char -> wrap Char @@ -886,7 +920,7 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = let rec translate_expression (e : S.expression) (ctx : bs_ctx) : expression = match e with | S.Return opt_v -> translate_return opt_v ctx - | Panic -> Fail + | Panic -> Value (mk_result_fail_rvalue ctx.ret_ty, None) | FunCall (call, e) -> translate_function_call call e ctx | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx @@ -904,7 +938,7 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression (* Forward function *) let v = Option.get opt_v in let v = typed_value_to_rvalue ctx v in - Return v + Value (mk_result_return_rvalue v, None) | Some bid -> (* Backward function *) (* Sanity check *) @@ -918,9 +952,9 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : expression let field_values = List.map mk_typed_rvalue_from_var backward_outputs in let ret_value = RvAdt { variant_id = None; field_values } in let ret_tys = List.map (fun (v : typed_rvalue) -> v.ty) field_values in - let ret_ty = Adt (T.Tuple, ret_tys) in + let ret_ty = Adt (Tuple, ret_tys) in let ret_value : typed_rvalue = { value = ret_value; ty = ret_ty } in - Return ret_value + Value (mk_result_return_rvalue ret_value, None) and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : expression = @@ -932,18 +966,20 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, func = + let ctx, func, monadic = match call.call_id with | S.Fun (fid, call_id) -> let ctx = bs_ctx_register_forward_call call_id call ctx in let func = Regular (fid, None) in - (ctx, func) - | S.Unop E.Not -> (ctx, Unop Not) + (ctx, func, true) + | S.Unop E.Not -> (ctx, Unop Not, false) | S.Unop E.Neg -> ( match args with | [ arg ] -> let int_ty = ty_as_integer arg.ty in - (ctx, Unop (Neg int_ty)) + (* Note that negation can lead to an overflow and thus fail (it + * is thus monadic) *) + (ctx, Unop (Neg int_ty), true) | _ -> failwith "Unreachable") | S.Binop binop -> ( match args with @@ -951,7 +987,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let int_ty0 = ty_as_integer arg0.ty in let int_ty1 = ty_as_integer arg1.ty in assert (int_ty0 = int_ty1); - (ctx, Binop (binop, int_ty0)) + let monadic = binop_can_fail binop in + (ctx, Binop (binop, int_ty0), monadic) | _ -> failwith "Unreachable") in let args = @@ -962,7 +999,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (* Translate the next expression *) let e = translate_expression e ctx in (* Put together *) - Let (mk_typed_lvalue_from_var dest dest_mplace, call, e) + Let (monadic, mk_typed_lvalue_from_var dest dest_mplace, call, e) and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : expression = @@ -1015,9 +1052,11 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* Translate the next expression *) let e = translate_expression e ctx in (* Generate the assignemnts *) + let monadic = false in List.fold_right (fun (var, value) e -> - Let (mk_typed_lvalue_from_var var None, Value (value, None), e)) + Let + (monadic, mk_typed_lvalue_from_var var None, Value (value, None), e)) variables_values e | V.FunCall -> let call_info = V.FunCallId.Map.find abs.call_id ctx.calls in @@ -1078,7 +1117,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (List.combine inputs args_mplaces) in let call = { func; type_params; args } in - Let (output, Call call, e) + let monadic = true in + Let (monadic, output, Call call, e) | V.SynthRet -> (* If we end the abstraction which consumed the return value of the function * we are synthesizing, we get back the borrows which were inside. Those borrows @@ -1129,9 +1169,14 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* Translate the next expression *) let e = translate_expression e ctx in (* Generate the assignments *) + let monadic = false in List.fold_right (fun (given_back, input_var) e -> - Let (given_back, Value (mk_typed_rvalue_from_var input_var, None), e)) + Let + ( monadic, + given_back, + Value (mk_typed_rvalue_from_var input_var, None), + e )) given_back_inputs e and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) @@ -1153,8 +1198,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in let e = translate_expression e ctx in + let monadic = false in Let - ( mk_typed_lvalue_from_var var None, + ( monadic, + mk_typed_lvalue_from_var var None, Value (scrutinee, scrutinee_mplace), e ) | SeAdt _ -> @@ -1180,7 +1227,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.map (fun v -> mk_typed_lvalue_from_var v None) vars in let lv = mk_adt_lvalue scrutinee.ty variant_id lvars in - Let (lv, Value (scrutinee, scrutinee_mplace), branch) + let monadic = false in + Let (monadic, lv, Value (scrutinee, scrutinee_mplace), branch) else (* This is not an enumeration: introduce let-bindings for every * field. @@ -1197,11 +1245,13 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) { value; ty } in let id_var_pairs = FieldId.mapi (fun fid v -> (fid, v)) vars in + let monadic = false in List.fold_right (fun (fid, var) e -> let field_proj = gen_field_proj fid var in Let - ( mk_typed_lvalue_from_var var None, + ( monadic, + mk_typed_lvalue_from_var var None, Value (field_proj, None), e )) id_var_pairs branch @@ -1209,8 +1259,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) let vars = List.map (fun x -> mk_typed_lvalue_from_var x None) vars in + let monadic = false in Let - ( mk_tuple_lvalue vars, + ( monadic, + mk_tuple_lvalue vars, Value (scrutinee, scrutinee_mplace), branch ) | T.Assumed T.Box -> @@ -1220,8 +1272,10 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) in (* We simply introduce an assignment - the box type is the * identity when extracted (`box a == a`) *) + let monadic = false in Let - ( mk_typed_lvalue_from_var var None, + ( monadic, + mk_typed_lvalue_from_var var None, Value (scrutinee, scrutinee_mplace), branch )) | branches -> |