summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-02-23 23:36:53 +0100
committerSon Ho2022-02-23 23:36:53 +0100
commit532b43ad73a4964cd75d8548d43eb894b7f225c1 (patch)
tree485fc8c35aebd2467878dc18e3f675a9e43175a1 /src
parente3430dcb5e944af0903b272669e6ddbb8e7d59c3 (diff)
Start working on generating code which uses a state-error monad
Diffstat (limited to 'src')
-rw-r--r--src/ExtractToFStar.ml7
-rw-r--r--src/Pure.ml5
-rw-r--r--src/PureMicroPasses.ml154
-rw-r--r--src/PureUtils.ml11
-rw-r--r--src/Translate.ml6
-rw-r--r--src/main.ml7
6 files changed, 154 insertions, 36 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 3840ea1e..dcaef438 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -95,7 +95,7 @@ let fstar_keywords =
List.concat [ named_unops; named_binops; misc ]
let fstar_assumed_adts : (assumed_ty * string) list =
- [ (Result, "result"); (Option, "option"); (Vec, "vec") ]
+ [ (State, "state"); (Result, "result"); (Option, "option"); (Vec, "vec") ]
let fstar_assumed_structs : (assumed_ty * string) list = []
@@ -394,12 +394,13 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(extract_ty ctx fmt true) tys;
F.pp_print_string fmt ")")
| AdtId _ | Assumed _ ->
- if inside then F.pp_print_string fmt "(";
+ let print_paren = inside && tys <> [] in
+ if print_paren then F.pp_print_string fmt "(";
F.pp_print_string fmt (ctx_get_type type_id ctx);
if tys <> [] then F.pp_print_space fmt ();
Collections.List.iter_link (F.pp_print_space fmt)
(extract_ty ctx fmt true) tys;
- if inside then F.pp_print_string fmt ")")
+ if print_paren then F.pp_print_string fmt ")")
| TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
| Bool -> F.pp_print_string fmt ctx.fmt.bool_name
| Char -> F.pp_print_string fmt ctx.fmt.char_name
diff --git a/src/Pure.ml b/src/Pure.ml
index 387d229f..96c7d211 100644
--- a/src/Pure.ml
+++ b/src/Pure.ml
@@ -471,6 +471,11 @@ type expression =
| Let of bool * typed_lvalue * texpression * texpression
(** Let binding.
+ TODO: the boolean should be replaced by an enum: sometimes we use
+ the error-monad, sometimes we use the state-error monad (and we
+ do this an a per-function basis! For instance, arithmetic functions
+ are always in the error monad).
+
The boolean controls whether the let is monadic or not.
For instance, in F*:
- non-monadic: `let x = ... in ...`
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 092e6b0d..61d247ea 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -8,6 +8,12 @@ open TranslateCore
let log = L.pure_micro_passes_log
type config = {
+ use_state_monad : bool;
+ (** If `true`, use a state-error monad.
+ If `false`, only use an error monad.
+
+ Using a state-error monad is necessary when modelling I/O, for instance.
+ *)
decompose_monadic_let_bindings : bool;
(** Some provers like F* don't support the decomposition of return values
in monadic let-bindings:
@@ -739,17 +745,22 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool =
(** Add unit arguments (optionally) to functions with no arguments, and
change their output type to use `result`
*)
-let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def =
+let to_monadic (config : config) (def : fun_def) : fun_def =
(* Update the body *)
let obj =
object
inherit [_] map_expression as super
method! visit_call env call =
- if call.args = [] && add_unit_args then
- let args = [ mk_value_expression unit_rvalue None ] in
- { call with args } (* Otherwise: nothing to do *)
- else super#visit_call env call
+ match call.func with
+ | Regular (A.Local _, _) ->
+ if call.args = [] && config.add_unit_args then
+ let args = [ mk_value_expression unit_rvalue None ] in
+ { call with args }
+ else (* Otherwise: nothing to do *) super#visit_call env call
+ | Regular (A.Assumed _, _) | Unop _ | Binop _ ->
+ (* Unops, binops and primitive functions don't have unit arguments *)
+ super#visit_call env call
end
in
let body = obj#visit_texpression () def.body in
@@ -757,7 +768,7 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def =
(* Update the signature: first the input types *)
let def =
- if def.inputs = [] && add_unit_args then (
+ if def.inputs = [] && config.add_unit_args then (
assert (def.signature.inputs = []);
let signature = { def.signature with inputs = [ unit_ty ] } in
let var_cnt = get_expression_min_var_counter def.body.e in
@@ -774,10 +785,25 @@ let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def =
match (def.back_id, def.signature.outputs) with
| None, [ out_ty ] ->
(* Forward function: there is always exactly one output *)
- mk_result_ty out_ty
+ (* We don't do the same thing if we use a state error monad or not:
+ * - error-monad: `result out_ty`
+ * - state-error: `state -> result (state & out_ty)
+ *)
+ if config.use_state_monad then
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; out_ty ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else (* Simply wrap the type in `result` *)
+ mk_result_ty out_ty
| Some _, outputs ->
(* Backward function: we have to group them *)
- mk_result_ty (mk_simpl_tuple_ty outputs)
+ (* We don't do the same thing if we use a state error monad or not *)
+ if config.use_state_monad then
+ let ret = mk_simpl_tuple_ty outputs in
+ let ret = mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; ret ]) in
+ let ret = mk_arrow_ty mk_state_ty ret in
+ ret
+ else mk_result_ty (mk_simpl_tuple_ty outputs)
| _ -> failwith "Unreachable"
in
let outputs = [ output_ty ] in
@@ -910,29 +936,102 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def
{ def with body }
(** Unfold the monadic let-bindings to explicit matches. *)
-let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def =
+let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
+ (def : fun_def) : fun_def =
+ (* We may need to introduce fresh variables for the state *)
+ let var_cnt = get_expression_min_var_counter def.body.e in
+ let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in
+ let fresh_state_var () =
+ let id = fresh_var_id () in
+ { id; basename = Some "st"; ty = mk_state_ty }
+ in
(* It is a very simple map *)
let obj =
object (self)
inherit [_] map_expression as super
- method! visit_Let env monadic lv re e =
- if not monadic then super#visit_Let env monadic lv re e
+ method! visit_Let state_var monadic lv re e =
+ if not monadic then super#visit_Let state_var monadic lv re e
else
- let fail_pat = mk_result_fail_lvalue lv.ty in
- let fail_value = mk_result_fail_rvalue e.ty in
- let fail_branch =
- { pat = fail_pat; branch = mk_value_expression fail_value None }
+ (* We don't do the same thing if we use a state-error monad or simply
+ * an error monad.
+ * Note that some functions always live in the error monad (arithmetic
+ * operations, for instance).
+ *)
+ let re_call =
+ match re.e with
+ | Call call -> call
+ | _ -> raise (Failure "Unreachable: expected a function call")
+ in
+ (* TODO: this information should be computed in SymbolicToPure and
+ * store in an enum ("monadic" should be an enum, not a bool).
+ * Also: everything will be cleaner once we update the AST to make
+ * it more idiomatic lambda calculus... *)
+ let re_call_can_use_state =
+ match re_call.func with
+ | Regular (A.Local _, _) -> true
+ | Regular (A.Assumed _, _) | Unop _ | Binop _ -> false
in
- let success_pat = mk_result_return_lvalue lv in
- let success_branch = { pat = success_pat; branch = e } in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- self#visit_expression env e
+ if config.use_state_monad && re_call_can_use_state then
+ let re_call =
+ let call = re_call in
+ let state_value = mk_typed_rvalue_from_var state_var in
+ let args = call.args @ [ mk_value_expression state_value None ] in
+ Call { call with args }
+ in
+ let re = { re with e = re_call } in
+ (* Create the match *)
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ { pat = fail_pat; branch = mk_value_expression fail_value None }
+ in
+ (* The `Success` branch introduces a fresh state variable *)
+ let state_var = fresh_state_var () in
+ let state_value = mk_typed_lvalue_from_var state_var None in
+ let success_pat =
+ mk_result_return_lvalue
+ (mk_simpl_tuple_lvalue [ state_value; lv ])
+ in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
+ else
+ let fail_pat = mk_result_fail_lvalue lv.ty in
+ let fail_value = mk_result_fail_rvalue e.ty in
+ let fail_branch =
+ { pat = fail_pat; branch = mk_value_expression fail_value None }
+ in
+ let success_pat = mk_result_return_lvalue lv in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ self#visit_expression state_var e
end
in
(* Update the body *)
- let body = obj#visit_texpression () def.body in
+ let input_state_var = fresh_state_var () in
+ let body = obj#visit_texpression input_state_var def.body in
+ let def = { def with body } in
+ (* We need to update the type if we revealed the state monad *)
+ let def =
+ if config.use_state_monad then
+ (* Update the signature *)
+ let sg = def.signature in
+ let sg_inputs = sg.inputs @ [ mk_state_ty ] in
+ let sg_outputs = Collections.List.to_cons_nil sg.outputs in
+ let _, sg_outputs = dest_arrow_ty sg_outputs in
+ let sg_outputs = [ sg_outputs ] in
+ let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
+ (* Update the inputs list *)
+ let inputs = def.inputs @ [ input_state_var ] in
+ let input_lv = mk_typed_lvalue_from_var input_state_var None in
+ let inputs_lvs = def.inputs_lvs @ [ input_lv ] in
+ (* Update the definition *)
+ { def with signature = sg; inputs; inputs_lvs }
+ else def
+ in
(* Return *)
{ def with body }
@@ -981,8 +1080,9 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
(* Add unit arguments for functions with no arguments, and change their return type.
* **Rk.**: from now onwards, the types in the AST are correct (until now,
* functions had return type `t` where they should have return type `result t`).
- * Also, from now onwards, the outputs list has length 1. x*)
- let def = to_monadic config.add_unit_args def in
+ * TODO: this is not true with the state-error monad, unless we unfold the monadic binds.
+ * Also, from now onwards, the outputs list has length 1. *)
+ let def = to_monadic config def in
log#ldebug (lazy ("to_monadic:\n\n" ^ fun_def_to_string ctx def ^ "\n"));
(* Convert the unit variables to `()` if they are used as right-values or
@@ -1014,9 +1114,13 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
log#ldebug
(lazy ("filter_useless:\n\n" ^ fun_def_to_string ctx def ^ "\n"));
- (* Decompose the monadic let-bindings *)
+ (* Decompose the monadic let-bindings - F* specific
+ * TODO: remove? With the state-error monad, it is becoming completely
+ * ad-hoc. *)
let def =
if config.decompose_monadic_let_bindings then (
+ (* TODO: we haven't updated the code to handle the state-error monad *)
+ assert (not config.use_state_monad);
let def = decompose_monadic_let_bindings ctx def in
log#ldebug
(lazy
@@ -1033,7 +1137,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
(* Unfold the monadic let-bindings *)
let def =
if config.unfold_monadic_let_bindings then (
- let def = unfold_monadic_let_bindings ctx def in
+ let def = unfold_monadic_let_bindings config ctx def in
log#ldebug
(lazy
("unfold_monadic_let_bindings:\n\n" ^ fun_def_to_string ctx def
diff --git a/src/PureUtils.ml b/src/PureUtils.ml
index e637b6ba..26dc6294 100644
--- a/src/PureUtils.ml
+++ b/src/PureUtils.ml
@@ -98,12 +98,14 @@ let mk_adt_lvalue (adt_ty : ty) (variant_id : VariantId.id)
{ value; ty = adt_ty }
let ty_as_integer (t : ty) : T.integer_type =
- match t with Integer int_ty -> int_ty | _ -> failwith "Unreachable"
+ match t with Integer int_ty -> int_ty | _ -> raise (Failure "Unreachable")
(* TODO: move *)
let type_def_is_enum (def : T.type_def) : bool =
match def.kind with T.Struct _ -> false | Enum _ -> true
+let mk_state_ty : ty = Adt (Assumed State, [])
+
let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ])
let mk_result_fail_rvalue (ty : ty) : typed_rvalue =
@@ -130,6 +132,13 @@ let mk_result_return_lvalue (v : typed_lvalue) : typed_lvalue =
in
{ value; ty }
+let mk_arrow_ty (arg_ty : ty) (ret_ty : ty) : ty = Arrow (arg_ty, ret_ty)
+
+let dest_arrow_ty (ty : ty) : ty * ty =
+ match ty with
+ | Arrow (arg_ty, ret_ty) -> (arg_ty, ret_ty)
+ | _ -> raise (Failure "Unreachable")
+
let compute_constant_value_ty (cv : constant_value) : ty =
match cv with
| V.Scalar sv -> Integer sv.V.int_ty
diff --git a/src/Translate.ml b/src/Translate.ml
index ac2ee38c..077cc32d 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -29,12 +29,6 @@ type config = {
let _ = assert_norm (FUNCTION () = Success ())
```
*)
- use_state_monad : bool;
- (** If `true`, use a state-error monad.
- If `false`, only use an error monad.
-
- Using a state-error monad is necessary when modelling I/O, for instance.
- *)
extract_decreases_clauses : bool;
(** If `true`, insert `decreases` clauses for all the recursive definitions.
diff --git a/src/main.ml b/src/main.ml
index df2d1b0c..86f15959 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -177,15 +177,20 @@ let () =
filter_useless_monadic_calls = !filter_useless_calls;
filter_useless_functions = !filter_useless_functions;
add_unit_args = false;
+ use_state_monad = not !no_state;
}
in
+ (* Small issue: the monadic `bind` only works for the error-monad, not
+ * the state-error monad (there are definitions to write and piping to do) *)
+ assert (
+ (not micro_passes_config.use_state_monad)
+ || micro_passes_config.unfold_monadic_let_bindings);
let trans_config =
{
Translate.eval_config;
mp_config = micro_passes_config;
split_files = not !no_split_files;
test_unit_functions;
- use_state_monad = not !no_state;
extract_decreases_clauses = not !no_decreases_clauses;
extract_template_decreases_clauses = !template_decreases_clauses;
}