summaryrefslogtreecommitdiff
path: root/src/PureMicroPasses.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/PureMicroPasses.ml')
-rw-r--r--src/PureMicroPasses.ml223
1 files changed, 73 insertions, 150 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 198a4d89..aba9610d 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -9,12 +9,6 @@ module V = Values
let log = L.pure_micro_passes_log
type config = {
- use_state_monad : bool;
- (** If `true`, use a state-error monad.
- If `false`, only use an error monad.
-
- Using a state-error monad is necessary when modelling I/O, for instance.
- *)
decompose_monadic_let_bindings : bool;
(** Some provers like F* don't support the decomposition of return values
in monadic let-bindings:
@@ -447,6 +441,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ctx, arg = update_texpression app ctx in
let e = App (app, arg) in
(ctx, e)
+ | Abs (x, e) -> update_abs x e ctx
| Func _ -> (* nothing to do *) (ctx, e.e)
| Let (monadic, lb, re, e) -> update_let monadic lb re e ctx
| Switch (scrut, body) -> update_switch_body scrut body ctx
@@ -459,6 +454,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ctx = add_opt_right_constraint mp v ctx in
(ctx, Value (v, mp))
(* *)
+ and update_abs (x : typed_lvalue) (e : texpression) (ctx : pn_ctx) :
+ pn_ctx * expression =
+ (* We first add the left-constraint *)
+ let ctx = add_left_constraint x ctx in
+ (* Update the expression, and add additional constraints *)
+ let ctx, e = update_texpression e ctx in
+ (* Update the abstracted value *)
+ let x = update_typed_lvalue ctx x in
+ (* Put together *)
+ (ctx, Abs (x, e))
+ (* *)
and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression)
(e : texpression) (ctx : pn_ctx) : pn_ctx * expression =
(* We first add the left-constraint *)
@@ -797,6 +803,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func)
match opt_destruct_function_call e with
| Some (func1, args1) -> check_call func1 args1
| None -> false)
+ | Abs (_, e) -> self#visit_texpression env e
| Func _ -> fun () -> false
| Meta (_, e) -> self#visit_texpression env e
| Switch (_, body) -> self#visit_switch_body env body
@@ -875,7 +882,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
method! visit_expression env e =
match e with
- | Value (_, _) | App _ | Switch (_, _) | Meta (_, _) ->
+ | Value (_, _) | App _ | Func _ | Switch (_, _) | Meta (_, _) | Abs _ ->
super#visit_expression env e
| Let (monadic, lv, re, e) ->
(* Compute the set of values used in the next expression *)
@@ -984,86 +991,6 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool =
then false
else true
-(** Add unit arguments (optionally) to functions with no arguments, and
- change their output type to use `result`
-
- TODO: remove this
- *)
-let to_monadic (config : config) (def : fun_decl) : fun_decl =
- (* Update the body *)
- let obj =
- object
- inherit [_] map_expression as super
-
- method! visit_call env call =
- match call.func with
- | Regular (A.Regular _, _) ->
- if call.args = [] && config.add_unit_args then
- let args = [ mk_value_expression unit_rvalue None ] in
- { call with args }
- else (* Otherwise: nothing to do *) super#visit_call env call
- | Regular (A.Assumed _, _) | Unop _ | Binop _ ->
- (* Unops, binops and primitive functions don't have unit arguments *)
- super#visit_call env call
- end
- in
- let def =
- match def.body with
- | None -> def
- | Some body ->
- let body = { body with body = obj#visit_texpression () body.body } in
- { def with body = Some body }
- in
-
- (* Update the signature: first the input types *)
- let def =
- if def.signature.inputs = [] && config.add_unit_args then
- let signature = { def.signature with inputs = [ unit_ty ] } in
- let body =
- match def.body with
- | None -> None
- | Some body ->
- let var_cnt = get_body_min_var_counter body in
- let id, _ = VarId.fresh var_cnt in
- let var = { id; basename = None; ty = unit_ty } in
- let inputs = [ var ] in
- let input_lv = mk_typed_lvalue_from_var var None in
- let inputs_lvs = [ input_lv ] in
- Some { body with inputs; inputs_lvs }
- in
- { def with signature; body }
- else def
- in
- (* Then the output type *)
- let output_ty =
- match (def.back_id, def.signature.outputs) with
- | None, [ out_ty ] ->
- (* Forward function: there is always exactly one output *)
- (* We don't do the same thing if we use a state error monad or not:
- * - error-monad: `result out_ty`
- * - state-error: `state -> result (state & out_ty)
- *)
- if config.use_state_monad then
- let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
- let ret = mk_arrow_ty mk_state_ty ret in
- ret
- else (* Simply wrap the type in `result` *)
- mk_result_ty out_ty
- | Some _, outputs ->
- (* Backward function: we have to group them *)
- (* We don't do the same thing if we use a state error monad or not *)
- if config.use_state_monad then
- let ret = mk_simpl_tuple_ty outputs in
- let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in
- let ret = mk_arrow_ty mk_state_ty ret in
- ret
- else mk_result_ty (mk_simpl_tuple_ty outputs)
- | _ -> failwith "Unreachable"
- in
- let outputs = [ output_ty ] in
- let signature = { def.signature with outputs } in
- { def with signature }
-
(** Convert the unit variables to `()` if they are used as right-values or
`_` if they are used as left values in patterns. *)
let unit_vars_to_unit (def : fun_decl) : fun_decl =
@@ -1110,38 +1037,51 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
object
inherit [_] map_expression as super
- method! visit_App env call =
- match call.func with
- | Regular (A.Assumed aid, rg_id) -> (
- match (aid, rg_id) with
- | A.BoxNew, _ ->
- let arg = Collections.List.to_cons_nil call.args in
- arg.e
- | A.BoxDeref, None ->
- (* `Box::deref` forward is the identity *)
- let arg = Collections.List.to_cons_nil call.args in
- arg.e
- | A.BoxDeref, Some _ ->
- (* `Box::deref` backward is `()` (doesn't give back anything) *)
- (mk_value_expression unit_rvalue None).e
- | A.BoxDerefMut, None ->
- (* `Box::deref_mut` forward is the identity *)
- let arg = Collections.List.to_cons_nil call.args in
- arg.e
- | A.BoxDerefMut, Some _ ->
- (* `Box::deref_mut` back is the identity *)
- let arg =
- match call.args with
- | [ _; given_back ] -> given_back
- | _ -> failwith "Unreachable"
- in
- arg.e
- | A.BoxFree, _ -> (mk_value_expression unit_rvalue None).e
- | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen
- | A.VecIndex | A.VecIndexMut ),
- _ ) ->
- super#visit_App env call)
- | _ -> super#visit_App env call
+ method! visit_texpression env e =
+ match opt_destruct_function_call e with
+ | Some (func, args) -> (
+ match func.func with
+ | Regular (A.Assumed aid, rg_id) -> (
+ (* Below, when dealing with the arguments: we consider the very
+ * general case, where functions could be boxed (meaning we
+ * could have: `box_new f x`)
+ * *)
+ match (aid, rg_id) with
+ | A.BoxNew, _ ->
+ assert (rg_id = None);
+ let arg, args = Collections.List.pop args in
+ mk_apps arg args
+ | A.BoxDeref, None ->
+ (* `Box::deref` forward is the identity *)
+ let arg, args = Collections.List.pop args in
+ mk_apps arg args
+ | A.BoxDeref, Some _ ->
+ (* `Box::deref` backward is `()` (doesn't give back anything) *)
+ assert (args = []);
+ mk_value_expression unit_rvalue None
+ | A.BoxDerefMut, None ->
+ (* `Box::deref_mut` forward is the identity *)
+ let arg, args = Collections.List.pop args in
+ mk_apps arg args
+ | A.BoxDerefMut, Some _ ->
+ (* `Box::deref_mut` back is almost the identity:
+ * let box_deref_mut (x_init : t) (x_back : t) : t = x_back
+ * *)
+ let arg, args =
+ match args with
+ | _ :: given_back :: args -> (given_back, args)
+ | _ -> failwith "Unreachable"
+ in
+ mk_apps arg args
+ | A.BoxFree, _ ->
+ assert (args = []);
+ mk_value_expression unit_rvalue None
+ | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen
+ | A.VecIndex | A.VecIndexMut ),
+ _ ) ->
+ super#visit_texpression env e)
+ | _ -> super#visit_texpression env e)
+ | _ -> super#visit_texpression env e
end
in
(* Update the body *)
@@ -1221,6 +1161,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
inherit [_] map_expression as super
method! visit_Let state_var monadic lv re e =
+ (* TODO: we should use a monad "kind" instead of a boolean *)
if not monadic then super#visit_Let state_var monadic lv re e
else
(* We don't do the same thing if we use a state-error monad or simply
@@ -1228,30 +1169,18 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
* Note that some functions always live in the error monad (arithmetic
* operations, for instance).
*)
- let re_call =
- match re.e with
- | App call -> call
- | _ -> raise (Failure "Unreachable: expected a function call")
- in
(* TODO: this information should be computed in SymbolicToPure and
- * store in an enum ("monadic" should be an enum, not a bool).
- * Also: everything will be cleaner once we update the AST to make
- * it more idiomatic lambda calculus... *)
- let re_call_can_use_state =
- match re_call.func with
- | Regular (A.Regular _, _) -> true
- | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
+ * store in an enum ("monadic" should be an enum, not a bool). *)
+ let re_uses_state =
+ Option.is_some (opt_destruct_state_monad_result re.ty)
in
- if config.use_state_monad && re_call_can_use_state then
- let re_call =
- let call = re_call in
+ if re_uses_state then
+ (* Add the state argument on the right-expression *)
+ let re =
let state_value = mk_typed_rvalue_from_var state_var in
- let args =
- call.args @ [ mk_value_expression state_value None ]
- in
- App { call with args }
+ let state_value = mk_value_expression state_value None in
+ mk_app re state_value
in
- let re = { re with e = re_call } in
(* Create the match *)
let fail_pat = mk_result_fail_lvalue lv.ty in
let fail_value = mk_result_fail_rvalue e.ty in
@@ -1273,6 +1202,8 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
let e = Switch (re, switch_body) in
self#visit_expression state_var e
else
+ let re_ty = Option.get (opt_destruct_result re.ty) in
+ assert (lv.ty = re_ty);
let fail_pat = mk_result_fail_lvalue lv.ty in
let fail_value = mk_result_fail_rvalue e.ty in
let fail_branch =
@@ -1312,9 +1243,9 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
else super#visit_Value state_var rv mp
(** We also need to update values, in case this value is `Return ...`.
- TODO: this is super ugly... We need to use the monadic functions
- `fail` and `return` instead.
- *)
+ TODO: this is super ugly... We need to use the monadic functions
+ fail` and `return` instead.
+ *)
end
in
(* Update the body *)
@@ -1386,14 +1317,6 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) :
match def with
| None -> None
| Some def ->
- (* Add unit arguments for functions with no arguments, and change their return type.
- * **Rk.**: from now onwards, the types in the AST are correct (until now,
- * functions had return type `t` where they should have return type `result t`).
- * TODO: this is not true with the state-error monad, unless we unfold the monadic binds.
- * Also, from now onwards, the outputs list has length 1. *)
- let def = to_monadic config def in
- log#ldebug (lazy ("to_monadic:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
-
(* Convert the unit variables to `()` if they are used as right-values or
* `_` if they are used as left values. *)
let def = unit_vars_to_unit def in