diff options
Diffstat (limited to 'src/PureMicroPasses.ml')
-rw-r--r-- | src/PureMicroPasses.ml | 223 |
1 files changed, 73 insertions, 150 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 198a4d89..aba9610d 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -9,12 +9,6 @@ module V = Values let log = L.pure_micro_passes_log type config = { - use_state_monad : bool; - (** If `true`, use a state-error monad. - If `false`, only use an error monad. - - Using a state-error monad is necessary when modelling I/O, for instance. - *) decompose_monadic_let_bindings : bool; (** Some provers like F* don't support the decomposition of return values in monadic let-bindings: @@ -447,6 +441,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx, arg = update_texpression app ctx in let e = App (app, arg) in (ctx, e) + | Abs (x, e) -> update_abs x e ctx | Func _ -> (* nothing to do *) (ctx, e.e) | Let (monadic, lb, re, e) -> update_let monadic lb re e ctx | Switch (scrut, body) -> update_switch_body scrut body ctx @@ -459,6 +454,17 @@ let compute_pretty_names (def : fun_decl) : fun_decl = let ctx = add_opt_right_constraint mp v ctx in (ctx, Value (v, mp)) (* *) + and update_abs (x : typed_lvalue) (e : texpression) (ctx : pn_ctx) : + pn_ctx * expression = + (* We first add the left-constraint *) + let ctx = add_left_constraint x ctx in + (* Update the expression, and add additional constraints *) + let ctx, e = update_texpression e ctx in + (* Update the abstracted value *) + let x = update_typed_lvalue ctx x in + (* Put together *) + (ctx, Abs (x, e)) + (* *) and update_let (monadic : bool) (lv : typed_lvalue) (re : texpression) (e : texpression) (ctx : pn_ctx) : pn_ctx * expression = (* We first add the left-constraint *) @@ -797,6 +803,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx) (func0 : func) match opt_destruct_function_call e with | Some (func1, args1) -> check_call func1 args1 | None -> false) + | Abs (_, e) -> self#visit_texpression env e | Func _ -> fun () -> false | Meta (_, e) -> self#visit_texpression env e | Switch (_, body) -> self#visit_switch_body env body @@ -875,7 +882,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx) method! visit_expression env e = match e with - | Value (_, _) | App _ | Switch (_, _) | Meta (_, _) -> + | Value (_, _) | App _ | Func _ | Switch (_, _) | Meta (_, _) | Abs _ -> super#visit_expression env e | Let (monadic, lv, re, e) -> (* Compute the set of values used in the next expression *) @@ -984,86 +991,6 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = then false else true -(** Add unit arguments (optionally) to functions with no arguments, and - change their output type to use `result` - - TODO: remove this - *) -let to_monadic (config : config) (def : fun_decl) : fun_decl = - (* Update the body *) - let obj = - object - inherit [_] map_expression as super - - method! visit_call env call = - match call.func with - | Regular (A.Regular _, _) -> - if call.args = [] && config.add_unit_args then - let args = [ mk_value_expression unit_rvalue None ] in - { call with args } - else (* Otherwise: nothing to do *) super#visit_call env call - | Regular (A.Assumed _, _) | Unop _ | Binop _ -> - (* Unops, binops and primitive functions don't have unit arguments *) - super#visit_call env call - end - in - let def = - match def.body with - | None -> def - | Some body -> - let body = { body with body = obj#visit_texpression () body.body } in - { def with body = Some body } - in - - (* Update the signature: first the input types *) - let def = - if def.signature.inputs = [] && config.add_unit_args then - let signature = { def.signature with inputs = [ unit_ty ] } in - let body = - match def.body with - | None -> None - | Some body -> - let var_cnt = get_body_min_var_counter body in - let id, _ = VarId.fresh var_cnt in - let var = { id; basename = None; ty = unit_ty } in - let inputs = [ var ] in - let input_lv = mk_typed_lvalue_from_var var None in - let inputs_lvs = [ input_lv ] in - Some { body with inputs; inputs_lvs } - in - { def with signature; body } - else def - in - (* Then the output type *) - let output_ty = - match (def.back_id, def.signature.outputs) with - | None, [ out_ty ] -> - (* Forward function: there is always exactly one output *) - (* We don't do the same thing if we use a state error monad or not: - * - error-monad: `result out_ty` - * - state-error: `state -> result (state & out_ty) - *) - if config.use_state_monad then - let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in - let ret = mk_arrow_ty mk_state_ty ret in - ret - else (* Simply wrap the type in `result` *) - mk_result_ty out_ty - | Some _, outputs -> - (* Backward function: we have to group them *) - (* We don't do the same thing if we use a state error monad or not *) - if config.use_state_monad then - let ret = mk_simpl_tuple_ty outputs in - let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in - let ret = mk_arrow_ty mk_state_ty ret in - ret - else mk_result_ty (mk_simpl_tuple_ty outputs) - | _ -> failwith "Unreachable" - in - let outputs = [ output_ty ] in - let signature = { def.signature with outputs } in - { def with signature } - (** Convert the unit variables to `()` if they are used as right-values or `_` if they are used as left values in patterns. *) let unit_vars_to_unit (def : fun_decl) : fun_decl = @@ -1110,38 +1037,51 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = object inherit [_] map_expression as super - method! visit_App env call = - match call.func with - | Regular (A.Assumed aid, rg_id) -> ( - match (aid, rg_id) with - | A.BoxNew, _ -> - let arg = Collections.List.to_cons_nil call.args in - arg.e - | A.BoxDeref, None -> - (* `Box::deref` forward is the identity *) - let arg = Collections.List.to_cons_nil call.args in - arg.e - | A.BoxDeref, Some _ -> - (* `Box::deref` backward is `()` (doesn't give back anything) *) - (mk_value_expression unit_rvalue None).e - | A.BoxDerefMut, None -> - (* `Box::deref_mut` forward is the identity *) - let arg = Collections.List.to_cons_nil call.args in - arg.e - | A.BoxDerefMut, Some _ -> - (* `Box::deref_mut` back is the identity *) - let arg = - match call.args with - | [ _; given_back ] -> given_back - | _ -> failwith "Unreachable" - in - arg.e - | A.BoxFree, _ -> (mk_value_expression unit_rvalue None).e - | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen - | A.VecIndex | A.VecIndexMut ), - _ ) -> - super#visit_App env call) - | _ -> super#visit_App env call + method! visit_texpression env e = + match opt_destruct_function_call e with + | Some (func, args) -> ( + match func.func with + | Regular (A.Assumed aid, rg_id) -> ( + (* Below, when dealing with the arguments: we consider the very + * general case, where functions could be boxed (meaning we + * could have: `box_new f x`) + * *) + match (aid, rg_id) with + | A.BoxNew, _ -> + assert (rg_id = None); + let arg, args = Collections.List.pop args in + mk_apps arg args + | A.BoxDeref, None -> + (* `Box::deref` forward is the identity *) + let arg, args = Collections.List.pop args in + mk_apps arg args + | A.BoxDeref, Some _ -> + (* `Box::deref` backward is `()` (doesn't give back anything) *) + assert (args = []); + mk_value_expression unit_rvalue None + | A.BoxDerefMut, None -> + (* `Box::deref_mut` forward is the identity *) + let arg, args = Collections.List.pop args in + mk_apps arg args + | A.BoxDerefMut, Some _ -> + (* `Box::deref_mut` back is almost the identity: + * let box_deref_mut (x_init : t) (x_back : t) : t = x_back + * *) + let arg, args = + match args with + | _ :: given_back :: args -> (given_back, args) + | _ -> failwith "Unreachable" + in + mk_apps arg args + | A.BoxFree, _ -> + assert (args = []); + mk_value_expression unit_rvalue None + | ( ( A.Replace | A.VecNew | A.VecPush | A.VecInsert | A.VecLen + | A.VecIndex | A.VecIndexMut ), + _ ) -> + super#visit_texpression env e) + | _ -> super#visit_texpression env e) + | _ -> super#visit_texpression env e end in (* Update the body *) @@ -1221,6 +1161,7 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) inherit [_] map_expression as super method! visit_Let state_var monadic lv re e = + (* TODO: we should use a monad "kind" instead of a boolean *) if not monadic then super#visit_Let state_var monadic lv re e else (* We don't do the same thing if we use a state-error monad or simply @@ -1228,30 +1169,18 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) * Note that some functions always live in the error monad (arithmetic * operations, for instance). *) - let re_call = - match re.e with - | App call -> call - | _ -> raise (Failure "Unreachable: expected a function call") - in (* TODO: this information should be computed in SymbolicToPure and - * store in an enum ("monadic" should be an enum, not a bool). - * Also: everything will be cleaner once we update the AST to make - * it more idiomatic lambda calculus... *) - let re_call_can_use_state = - match re_call.func with - | Regular (A.Regular _, _) -> true - | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false + * store in an enum ("monadic" should be an enum, not a bool). *) + let re_uses_state = + Option.is_some (opt_destruct_state_monad_result re.ty) in - if config.use_state_monad && re_call_can_use_state then - let re_call = - let call = re_call in + if re_uses_state then + (* Add the state argument on the right-expression *) + let re = let state_value = mk_typed_rvalue_from_var state_var in - let args = - call.args @ [ mk_value_expression state_value None ] - in - App { call with args } + let state_value = mk_value_expression state_value None in + mk_app re state_value in - let re = { re with e = re_call } in (* Create the match *) let fail_pat = mk_result_fail_lvalue lv.ty in let fail_value = mk_result_fail_rvalue e.ty in @@ -1273,6 +1202,8 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) let e = Switch (re, switch_body) in self#visit_expression state_var e else + let re_ty = Option.get (opt_destruct_result re.ty) in + assert (lv.ty = re_ty); let fail_pat = mk_result_fail_lvalue lv.ty in let fail_value = mk_result_fail_rvalue e.ty in let fail_branch = @@ -1312,9 +1243,9 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) else super#visit_Value state_var rv mp (** We also need to update values, in case this value is `Return ...`. - TODO: this is super ugly... We need to use the monadic functions - `fail` and `return` instead. - *) + TODO: this is super ugly... We need to use the monadic functions + fail` and `return` instead. + *) end in (* Update the body *) @@ -1386,14 +1317,6 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) : match def with | None -> None | Some def -> - (* Add unit arguments for functions with no arguments, and change their return type. - * **Rk.**: from now onwards, the types in the AST are correct (until now, - * functions had return type `t` where they should have return type `result t`). - * TODO: this is not true with the state-error monad, unless we unfold the monadic binds. - * Also, from now onwards, the outputs list has length 1. *) - let def = to_monadic config def in - log#ldebug (lazy ("to_monadic:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - (* Convert the unit variables to `()` if they are used as right-values or * `_` if they are used as left values. *) let def = unit_vars_to_unit def in |