summaryrefslogtreecommitdiff
path: root/src/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--src/SymbolicToPure.ml108
1 files changed, 93 insertions, 15 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 18e2b873..b25b7309 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -38,6 +38,12 @@ type config = {
Note that we later filter the useless *forward* calls in the micro-passes,
where it is more natural to do.
*)
+ use_state_monad : bool;
+ (** If `true`, use a state-error monad.
+ If `false`, only use an error monad.
+
+ Using a state-error monad is necessary when modelling I/O, for instance.
+ *)
}
type type_context = {
@@ -950,6 +956,15 @@ let fun_is_monadic (fun_id : A.fun_id) : bool =
| A.Regular _ -> true
| A.Assumed aid -> Assumed.assumed_is_monadic aid
+let mk_function_ret_ty (config : config) (monadic : bool) (out_ty : ty) : ty =
+ if monadic then
+ if config.use_state_monad then
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else mk_result_ty out_ty
+ else out_ty
+
let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
: texpression =
match e with
@@ -967,7 +982,7 @@ let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
| Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx
| Meta (meta, e) -> translate_meta config meta e ctx
-and translate_return (_config : config) (opt_v : V.typed_value option)
+and translate_return (config : config) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
- either we are translating a forward function, in which case the optional
@@ -979,13 +994,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option)
(* Forward function *)
let v = Option.get opt_v in
let v = typed_value_to_rvalue ctx v in
- (* TODO: we need to use a `return` function (otherwise we have problems
- * with the state-error monad). We also need to update the type when using
- * a state-error monad. *)
- let v = mk_result_return_rvalue v in
- let e = Value (v, None) in
- let ty = v.ty in
- { e; ty }
+ (* We don't synthesize the same expression depending on the monad we use:
+ * - error-monad: Return x
+ * - state-error monad: fun state -> Return (state, x)
+ * *)
+ (* TODO: we should use a `return` function, it would be cleaner *)
+ if config.use_state_monad then
+ let _, state_var = fresh_var (Some "state") mk_state_ty ctx in
+ let state_rvalue = mk_typed_rvalue_from_var state_var in
+ let v =
+ mk_result_return_rvalue (mk_simpl_tuple_rvalue [ state_rvalue; v ])
+ in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ let e = { e; ty } in
+ let state_var = mk_typed_lvalue_from_var state_var None in
+ mk_abs state_var e
+ else
+ let v = mk_result_return_rvalue v in
+ let e = Value (v, None) in
+ let ty = v.ty in
+ { e; ty }
| Some bid ->
(* Backward function *)
(* Sanity check *)
@@ -997,11 +1026,27 @@ and translate_return (_config : config) (opt_v : V.typed_value option)
T.RegionGroupId.Map.find bid ctx.backward_outputs
in
let field_values = List.map mk_typed_rvalue_from_var backward_outputs in
- let ret_value = mk_simpl_tuple_rvalue field_values in
- let ret_value = mk_result_return_rvalue ret_value in
- let e = Value (ret_value, None) in
- let ty = ret_value.ty in
- { e; ty }
+ (* See the comment about the monads, for the forward function case *)
+ (* TODO: we should use a `fail` function, it would be cleaner *)
+ if config.use_state_monad then
+ let _, state_var = fresh_var (Some "state") mk_state_ty ctx in
+ let state_rvalue = mk_typed_rvalue_from_var state_var in
+ let ret_value = mk_simpl_tuple_rvalue field_values in
+ let ret_value =
+ mk_result_return_rvalue
+ (mk_simpl_tuple_rvalue [ state_rvalue; ret_value ])
+ in
+ let e = Value (ret_value, None) in
+ let ty = ret_value.ty in
+ let e = { e; ty } in
+ let state_var = mk_typed_lvalue_from_var state_var None in
+ mk_abs state_var e
+ else
+ let ret_value = mk_simpl_tuple_rvalue field_values in
+ let ret_value = mk_result_return_rvalue ret_value in
+ let e = Value (ret_value, None) in
+ let ty = ret_value.ty in
+ { e; ty }
and translate_function_call (config : config) (call : S.call) (e : S.expression)
(ctx : bs_ctx) : texpression =
@@ -1047,7 +1092,7 @@ and translate_function_call (config : config) (call : S.call) (e : S.expression)
let dest_v = mk_typed_lvalue_from_var dest dest_mplace in
let func = { func; type_params } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in
+ let ret_ty = mk_function_ret_ty config monadic dest_v.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { e = Func func; ty = func_ty } in
let call = mk_apps func args in
@@ -1180,7 +1225,7 @@ and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
in
let monadic = fun_is_monadic fun_id in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
- let ret_ty = if monadic then mk_result_ty output.ty else output.ty in
+ let ret_ty = mk_function_ret_ty config monadic output.ty in
let func_ty = mk_arrows input_tys ret_ty in
let func = { func; type_params } in
let func = { e = Func func; ty = func_ty } in
@@ -1470,6 +1515,7 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(* Translate the declaration *)
let def_id = def.A.def_id in
let basename = def.name in
+ (* Lookup the signature *)
let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in
(* Translate the body, if there is *)
let body =
@@ -1502,6 +1548,38 @@ let translate_fun_decl (config : config) (ctx : bs_ctx)
(List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in
+ (* Make the signature monadic *)
+ let output_ty =
+ match (bid, signature.outputs) with
+ | None, [ out_ty ] ->
+ (* Forward function: there is always exactly one output *)
+ (* We don't do the same thing if we use a state error monad or not:
+ * - error-monad: `result out_ty`
+ * - state-error: `state -> result (state & out_ty)
+ *)
+ if config.use_state_monad then
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else (* Simply wrap the type in `result` *)
+ mk_result_ty out_ty
+ | Some _, outputs ->
+ (* Backward function: we have to group the list of outputs into a tuple
+ * (and similarly to the forward function, we don't do the same thing
+ * if we use a state error monad or not):
+ * - error-monad: `result (out_ty1 & .. out_tyn)`
+ * - state-error: `state -> result (out_ty1 & .. out_tyn)`
+ *)
+ if config.use_state_monad then
+ let ret = mk_simpl_tuple_ty outputs in
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else mk_result_ty (mk_simpl_tuple_ty outputs)
+ | _ -> failwith "Unreachable"
+ in
+ let outputs = [ output_ty ] in
+ let signature = { signature with outputs } in
(* Assemble the declaration *)
let def = { def_id; back_id = bid; basename; signature; body } in
(* Debugging *)