summaryrefslogtreecommitdiff
path: root/src/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r--src/SymbolicToPure.ml74
1 files changed, 50 insertions, 24 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 4e15d921..49bf3559 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -948,17 +948,26 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =
(** Small utility.
- Return true if a function return type is monadic.
- Always true, at the exception of some assumed functions.
+ Return: (function is monadic, function uses state monad)
+
+ Note that all functions are monadic except some assumed functions.
*)
-let fun_is_monadic (fun_id : A.fun_id) : bool =
+let fun_is_monadic (fun_id : A.fun_id) : bool * bool =
match fun_id with
- | A.Regular _ -> true
- | A.Assumed aid -> Assumed.assumed_is_monadic aid
+ | A.Regular _ -> (true, true)
+ | A.Assumed aid -> (Assumed.assumed_is_monadic aid, false)
+
+(** Utility for function return types.
-let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty =
+ A function return type can have the shape:
+ - ty
+ - result ty (* error-monad *)
+ - state -> result (state & ty) (* state-error monad *)
+ *)
+let mk_function_ret_ty (config : config) (monadic : bool) (state_monad : bool)
+ (out_ty : ty) : ty =
if monadic then
- if config.use_state_monad then
+ if config.use_state_monad && state_monad then
let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
let ret = mk_arrow_ty mk_state_ty ret in
ret
@@ -969,19 +978,34 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
: texpression =
match e with
| S.Return opt_v -> translate_return config opt_v ctx
- | Panic ->
- (* Here we use the function return type - note that it is ok because
- * we don't match on panics which happen inside the function body -
- * but it won't be true anymore once we translate individual blocks *)
- let v = mk_result_fail_rvalue ctx.ret_ty in
- let e = Value (v, None) in
- let ty = v.ty in
- { e; ty }
+ | Panic -> translate_panic config ctx
| FunCall (call, e) -> translate_function_call config call e ctx
| EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx
| Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx
| Meta (meta, e) -> translate_meta config meta e ctx
+and translate_panic (config : config) (ctx : bs_ctx) : texpression =
+ (* Here we use the function return type - note that it is ok because
+ * we don't match on panics which happen inside the function body -
+ * but it won't be true anymore once we translate individual blocks *)
+ (* If we use a state monad, we need to add a lambda for the state variable *)
+ if config.use_state_monad then
+ (* Create the `Fail` value *)
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in
+ let v = mk_result_fail_rvalue ret_ty in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ let e = { e; ty } in
+ (* Add the lambda *)
+ let _, state_var = fresh_var (Some "st") mk_state_ty ctx in
+ let state_lvalue = mk_typed_lvalue_from_var state_var None in
+ mk_abs state_lvalue e
+ else
+ let v = mk_result_fail_rvalue ctx.ret_ty in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
+
and translate_return (config : config) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
@@ -1058,21 +1082,21 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in
(* Retrieve the function id, and register the function call in the context
* if necessary. *)
- let ctx, func, monadic =
+ let ctx, func, monadic, state_monad =
match call.call_id with
| S.Fun (fid, call_id) ->
let ctx = bs_ctx_register_forward_call call_id call ctx in
let func = Regular (fid, None) in
- let monadic = fun_is_monadic fid in
- (ctx, func, monadic)
- | S.Unop E.Not -> (ctx, Unop Not, false)
+ let monadic, state_monad = fun_is_monadic fid in
+ (ctx, func, monadic, state_monad)
+ | S.Unop E.Not -> (ctx, Unop Not, false, false)
| S.Unop E.Neg -> (
match args with
| [ arg ] ->
let int_ty = ty_as_integer arg.ty in
(* Note that negation can lead to an overflow and thus fail (it
* is thus monadic) *)
- (ctx, Unop (Neg int_ty), true)
+ (ctx, Unop (Neg int_ty), true, false)
| _ -> failwith "Unreachable")
| S.Binop binop -> (
match args with
@@ -1081,7 +1105,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let int_ty1 = ty_as_integer arg1.ty in
assert (int_ty0 = int_ty1);
let monadic = binop_can_fail binop in
- (ctx, Binop (binop, int_ty0), monadic)
+ (ctx, Binop (binop, int_ty0), monadic, false)
| _ -> failwith "Unreachable")
in
let args =
@@ -1092,7 +1116,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let dest_v = mk_typed_lvalue_from_var dest dest_mplace in
let func = { func; type_params } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = mk_function_ret_ty config monadic dest_v.ty in
+ let ret_ty = mk_function_ret_ty config monadic state_monad dest_v.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { e = Func func; ty = func_ty } in
let call = mk_apps func args in
@@ -1223,9 +1247,9 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
(fun (arg, mp) -> mk_value_expression arg mp)
(List.combine inputs args_mplaces)
in
- let monadic = fun_is_monadic fun_id in
+ let monadic, state_monad = fun_is_monadic fun_id in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = mk_function_ret_ty config monadic output.ty in
+ let ret_ty = mk_function_ret_ty config monadic state_monad output.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { func; type_params } in
let func = { e = Func func; ty = func_ty } in
@@ -1444,7 +1468,9 @@ and translate_expansion (config : config) (p : S.mplace option)
(* There should be at least one branch *)
let branch = List.hd branches in
let ty = branch.branch.ty in
+ (* Sanity check *)
assert (List.for_all (fun br -> br.branch.ty = ty) branches);
+ (* Return *)
{ e; ty })
| ExpandBool (true_e, false_e) ->
(* We don't need to update the context: we don't introduce any