diff options
author | Son Ho | 2022-02-23 23:36:53 +0100 |
---|---|---|
committer | Son Ho | 2022-02-23 23:36:53 +0100 |
commit | 532b43ad73a4964cd75d8548d43eb894b7f225c1 (patch) | |
tree | 485fc8c35aebd2467878dc18e3f675a9e43175a1 | |
parent | e3430dcb5e944af0903b272669e6ddbb8e7d59c3 (diff) |
Start working on generating code which uses a state-error monad
-rw-r--r-- | src/ExtractToFStar.ml | 7 | ||||
-rw-r--r-- | src/Pure.ml | 5 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 154 | ||||
-rw-r--r-- | src/PureUtils.ml | 11 | ||||
-rw-r--r-- | src/Translate.ml | 6 | ||||
-rw-r--r-- | src/main.ml | 7 |
6 files changed, 154 insertions, 36 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 3840ea1e..dcaef438 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -95,7 +95,7 @@ let fstar_keywords = List.concat [ named_unops; named_binops; misc ] let fstar_assumed_adts : (assumed_ty * string) list = - [ (Result, "result"); (Option, "option"); (Vec, "vec") ] + [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ] let fstar_assumed_structs : (assumed_ty * string) list = [] @@ -394,12 +394,13 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (extract_ty ctx fmt true) tys; F.pp_print_string fmt ")") | AdtId _ | Assumed _ -> - if inside then F.pp_print_string fmt "("; + let print_paren = inside && tys <> [] in + if print_paren then F.pp_print_string fmt "("; F.pp_print_string fmt (ctx_get_type type_id ctx); if tys <> [] then F.pp_print_space fmt (); Collections.List.iter_link (F.pp_print_space fmt) (extract_ty ctx fmt true) tys; - if inside then F.pp_print_string fmt ")") + if print_paren then F.pp_print_string fmt ")") | TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx) | Bool -> F.pp_print_string fmt ctx.fmt.bool_name | Char -> F.pp_print_string fmt ctx.fmt.char_name diff --git a/src/Pure.ml b/src/Pure.ml index 387d229f..96c7d211 100644 --- a/src/Pure.ml +++ b/src/Pure.ml @@ -471,6 +471,11 @@ type expression = | Let of bool * typed_lvalue * texpression * texpression (** Let binding. + TODO: the boolean should be replaced by an enum: sometimes we use + the error-monad, sometimes we use the state-error monad (and we + do this an a per-function basis! For instance, arithmetic functions + are always in the error monad). + The boolean controls whether the let is monadic or not. For instance, in F*: - non-monadic: `let x = ... in ...` diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 092e6b0d..61d247ea 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -8,6 +8,12 @@ open TranslateCore let log = L.pure_micro_passes_log type config = { + use_state_monad : bool; + (** If `true`, use a state-error monad. + If `false`, only use an error monad. + + Using a state-error monad is necessary when modelling I/O, for instance. + *) decompose_monadic_let_bindings : bool; (** Some provers like F* don't support the decomposition of return values in monadic let-bindings: @@ -739,17 +745,22 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool = (** Add unit arguments (optionally) to functions with no arguments, and change their output type to use `result` *) -let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def = +let to_monadic (config : config) (def : fun_def) : fun_def = (* Update the body *) let obj = object inherit [_] map_expression as super method! visit_call env call = - if call.args = [] && add_unit_args then - let args = [ mk_value_expression unit_rvalue None ] in - { call with args } (* Otherwise: nothing to do *) - else super#visit_call env call + match call.func with + | Regular (A.Local _, _) -> + if call.args = [] && config.add_unit_args then + let args = [ mk_value_expression unit_rvalue None ] in + { call with args } + else (* Otherwise: nothing to do *) super#visit_call env call + | Regular (A.Assumed _, _) | Unop _ | Binop _ -> + (* Unops, binops and primitive functions don't have unit arguments *) + super#visit_call env call end in let body = obj#visit_texpression () def.body in @@ -757,7 +768,7 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def = (* Update the signature: first the input types *) let def = - if def.inputs = [] && add_unit_args then ( + if def.inputs = [] && config.add_unit_args then ( assert (def.signature.inputs = []); let signature = { def.signature with inputs = [ unit_ty ] } in let var_cnt = get_expression_min_var_counter def.body.e in @@ -774,10 +785,25 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def = match (def.back_id, def.signature.outputs) with | None, [ out_ty ] -> (* Forward function: there is always exactly one output *) - mk_result_ty out_ty + (* We don't do the same thing if we use a state error monad or not: + * - error-monad: `result out_ty` + * - state-error: `state -> result (state & out_ty) + *) + if config.use_state_monad then + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else (* Simply wrap the type in `result` *) + mk_result_ty out_ty | Some _, outputs -> (* Backward function: we have to group them *) - mk_result_ty (mk_simpl_tuple_ty outputs) + (* We don't do the same thing if we use a state error monad or not *) + if config.use_state_monad then + let ret = mk_simpl_tuple_ty outputs in + let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in + let ret = mk_arrow_ty mk_state_ty ret in + ret + else mk_result_ty (mk_simpl_tuple_ty outputs) | _ -> failwith "Unreachable" in let outputs = [ output_ty ] in @@ -910,29 +936,102 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def { def with body } (** Unfold the monadic let-bindings to explicit matches. *) -let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def = +let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx) + (def : fun_def) : fun_def = + (* We may need to introduce fresh variables for the state *) + let var_cnt = get_expression_min_var_counter def.body.e in + let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in + let fresh_state_var () = + let id = fresh_var_id () in + { id; basename = Some "st"; ty = mk_state_ty } + in (* It is a very simple map *) let obj = object (self) inherit [_] map_expression as super - method! visit_Let env monadic lv re e = - if not monadic then super#visit_Let env monadic lv re e + method! visit_Let state_var monadic lv re e = + if not monadic then super#visit_Let state_var monadic lv re e else - let fail_pat = mk_result_fail_lvalue lv.ty in - let fail_value = mk_result_fail_rvalue e.ty in - let fail_branch = - { pat = fail_pat; branch = mk_value_expression fail_value None } + (* We don't do the same thing if we use a state-error monad or simply + * an error monad. + * Note that some functions always live in the error monad (arithmetic + * operations, for instance). + *) + let re_call = + match re.e with + | Call call -> call + | _ -> raise (Failure "Unreachable: expected a function call") + in + (* TODO: this information should be computed in SymbolicToPure and + * store in an enum ("monadic" should be an enum, not a bool). + * Also: everything will be cleaner once we update the AST to make + * it more idiomatic lambda calculus... *) + let re_call_can_use_state = + match re_call.func with + | Regular (A.Local _, _) -> true + | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false in - let success_pat = mk_result_return_lvalue lv in - let success_branch = { pat = success_pat; branch = e } in - let switch_body = Match [ fail_branch; success_branch ] in - let e = Switch (re, switch_body) in - self#visit_expression env e + if config.use_state_monad && re_call_can_use_state then + let re_call = + let call = re_call in + let state_value = mk_typed_rvalue_from_var state_var in + let args = call.args @ [ mk_value_expression state_value None ] in + Call { call with args } + in + let re = { re with e = re_call } in + (* Create the match *) + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { pat = fail_pat; branch = mk_value_expression fail_value None } + in + (* The `Success` branch introduces a fresh state variable *) + let state_var = fresh_state_var () in + let state_value = mk_typed_lvalue_from_var state_var None in + let success_pat = + mk_result_return_lvalue + (mk_simpl_tuple_lvalue [ state_value; lv ]) + in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e + else + let fail_pat = mk_result_fail_lvalue lv.ty in + let fail_value = mk_result_fail_rvalue e.ty in + let fail_branch = + { pat = fail_pat; branch = mk_value_expression fail_value None } + in + let success_pat = mk_result_return_lvalue lv in + let success_branch = { pat = success_pat; branch = e } in + let switch_body = Match [ fail_branch; success_branch ] in + let e = Switch (re, switch_body) in + self#visit_expression state_var e end in (* Update the body *) - let body = obj#visit_texpression () def.body in + let input_state_var = fresh_state_var () in + let body = obj#visit_texpression input_state_var def.body in + let def = { def with body } in + (* We need to update the type if we revealed the state monad *) + let def = + if config.use_state_monad then + (* Update the signature *) + let sg = def.signature in + let sg_inputs = sg.inputs @ [ mk_state_ty ] in + let sg_outputs = Collections.List.to_cons_nil sg.outputs in + let _, sg_outputs = dest_arrow_ty sg_outputs in + let sg_outputs = [ sg_outputs ] in + let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in + (* Update the inputs list *) + let inputs = def.inputs @ [ input_state_var ] in + let input_lv = mk_typed_lvalue_from_var input_state_var None in + let inputs_lvs = def.inputs_lvs @ [ input_lv ] in + (* Update the definition *) + { def with signature = sg; inputs; inputs_lvs } + else def + in (* Return *) { def with body } @@ -981,8 +1080,9 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : (* Add unit arguments for functions with no arguments, and change their return type. * **Rk.**: from now onwards, the types in the AST are correct (until now, * functions had return type `t` where they should have return type `result t`). - * Also, from now onwards, the outputs list has length 1. x*) - let def = to_monadic config.add_unit_args def in + * TODO: this is not true with the state-error monad, unless we unfold the monadic binds. + * Also, from now onwards, the outputs list has length 1. *) + let def = to_monadic config def in log#ldebug (lazy ("to_monadic:\n\n" ^ fun_def_to_string ctx def ^ "\n")); (* Convert the unit variables to `()` if they are used as right-values or @@ -1014,9 +1114,13 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : log#ldebug (lazy ("filter_useless:\n\n" ^ fun_def_to_string ctx def ^ "\n")); - (* Decompose the monadic let-bindings *) + (* Decompose the monadic let-bindings - F* specific + * TODO: remove? With the state-error monad, it is becoming completely + * ad-hoc. *) let def = if config.decompose_monadic_let_bindings then ( + (* TODO: we haven't updated the code to handle the state-error monad *) + assert (not config.use_state_monad); let def = decompose_monadic_let_bindings ctx def in log#ldebug (lazy @@ -1033,7 +1137,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : (* Unfold the monadic let-bindings *) let def = if config.unfold_monadic_let_bindings then ( - let def = unfold_monadic_let_bindings ctx def in + let def = unfold_monadic_let_bindings config ctx def in log#ldebug (lazy ("unfold_monadic_let_bindings:\n\n" ^ fun_def_to_string ctx def diff --git a/src/PureUtils.ml b/src/PureUtils.ml index e637b6ba..26dc6294 100644 --- a/src/PureUtils.ml +++ b/src/PureUtils.ml @@ -98,12 +98,14 @@ let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id) { value; ty = adt_ty } let ty_as_integer (t : ty) : T.integer_type = - match t with Integer int_ty -> int_ty | _ -> failwith "Unreachable" + match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable") (* TODO: move *) let type_def_is_enum (def : T.type_def) : bool = match def.kind with T.Struct _ -> false | Enum _ -> true +let mk_state_ty : ty = Adt (Assumed State, []) + let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ]) let mk_result_fail_rvalue (ty : ty) : typed_rvalue = @@ -130,6 +132,13 @@ let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue = in { value; ty } +let mk_arrow_ty (arg_ty : ty) (ret_ty : ty) : ty = Arrow (arg_ty, ret_ty) + +let dest_arrow_ty (ty : ty) : ty * ty = + match ty with + | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty) + | _ -> raise (Failure "Unreachable") + let compute_constant_value_ty (cv : constant_value) : ty = match cv with | V.Scalar sv -> Integer sv.V.int_ty diff --git a/src/Translate.ml b/src/Translate.ml index ac2ee38c..077cc32d 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -29,12 +29,6 @@ type config = { let _ = assert_norm (FUNCTION () = Success ()) ``` *) - use_state_monad : bool; - (** If `true`, use a state-error monad. - If `false`, only use an error monad. - - Using a state-error monad is necessary when modelling I/O, for instance. - *) extract_decreases_clauses : bool; (** If `true`, insert `decreases` clauses for all the recursive definitions. diff --git a/src/main.ml b/src/main.ml index df2d1b0c..86f15959 100644 --- a/src/main.ml +++ b/src/main.ml @@ -177,15 +177,20 @@ let () = filter_useless_monadic_calls = !filter_useless_calls; filter_useless_functions = !filter_useless_functions; add_unit_args = false; + use_state_monad = not !no_state; } in + (* Small issue: the monadic `bind` only works for the error-monad, not + * the state-error monad (there are definitions to write and piping to do) *) + assert ( + (not micro_passes_config.use_state_monad) + || micro_passes_config.unfold_monadic_let_bindings); let trans_config = { Translate.eval_config; mp_config = micro_passes_config; split_files = not !no_split_files; test_unit_functions; - use_state_monad = not !no_state; extract_decreases_clauses = not !no_decreases_clauses; extract_template_decreases_clauses = !template_decreases_clauses; } |