summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-05-04 14:13:20 +0200
committerSon Ho2022-05-04 14:13:20 +0200
commitfb6fdfd0c57de1ce16fb6bc373d5593c9446b0bb (patch)
treed3da4628c0cabd07ac740c484805fbce0e1fc6c6 /src
parent37f80fd592f703ab9b14a9d3d5d638b9c335997f (diff)
Make progress updating the code
Diffstat (limited to '')
-rw-r--r--src/ExtractToFStar.ml5
-rw-r--r--src/PureMicroPasses.ml191
-rw-r--r--src/SymbolicToPure.ml11
-rw-r--r--src/Translate.ml56
-rw-r--r--src/main.ml7
5 files changed, 76 insertions, 194 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml
index 068448e9..4baa1fd6 100644
--- a/src/ExtractToFStar.ml
+++ b/src/ExtractToFStar.ml
@@ -1368,8 +1368,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
if has_decreases_clause then (
F.pp_print_string fmt "Tot";
F.pp_print_space fmt ());
- extract_ty ctx fmt has_decreases_clause
- (Collections.List.to_cons_nil def.signature.outputs);
+ extract_ty ctx fmt has_decreases_clause def.signature.output;
(* Close the box for the return type *)
F.pp_close_box fmt ();
(* Print the decrease clause - rk.: a function with a decreases clause
@@ -1476,7 +1475,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
if
sg.type_params = []
&& (sg.inputs = [ mk_unit_ty ] || sg.inputs = [])
- && sg.outputs = [ mk_result_ty mk_unit_ty ]
+ && sg.output = mk_result_ty mk_unit_ty
then (
(* Add a break before *)
F.pp_print_break fmt 0 0;
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index f76dd2f4..0c371420 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -63,7 +63,6 @@ type config = {
borrows as inputs, it can't return mutable borrows; we actually
dynamically check for that).
*)
- use_state_monad : bool; (** TODO: remove *)
}
(** A configuration to control the application of the passes *)
@@ -920,15 +919,9 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
*)
let filter_if_backward_with_no_outputs (config : config) (def : fun_decl) :
fun_decl option =
- let return_ty =
- if config.use_state_monad then
- mk_arrow mk_state_ty
- (mk_result_ty (mk_simpl_tuple_ty [ mk_state_ty; mk_unit_ty ]))
- else mk_result_ty (mk_simpl_tuple_ty [ mk_unit_ty ])
- in
if
config.filter_useless_functions && Option.is_some def.back_id
- && def.signature.outputs = [ return_ty ]
+ && def.signature.output = mk_result_ty mk_unit_ty
then None
else Some def
@@ -954,7 +947,7 @@ let keep_forward (config : config) (trans : pure_fun_translation) : bool =
* they should be lists of length 1. *)
if
config.filter_useless_functions
- && fwd.signature.outputs = [ mk_result_ty mk_unit_ty ]
+ && fwd.signature.output = mk_result_ty mk_unit_ty
&& backs <> []
then false
else true
@@ -1108,85 +1101,27 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :
{ def with body }
(** Unfold the monadic let-bindings to explicit matches. *)
-let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
- (def : fun_decl) : fun_decl =
+let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match def.body with
| None -> def
| Some body ->
- (* We may need to introduce fresh variables for the state *)
- let fresh_var_id =
- let var_cnt = get_body_min_var_counter body in
- let _, fresh_var_id = VarId.mk_stateful_generator var_cnt in
- fresh_var_id
- in
- let fresh_state_var () =
- let id = fresh_var_id () in
- { id; basename = Some ConstStrings.state_basename; ty = mk_state_ty }
- in
(* It is a very simple map *)
let obj =
object (_self)
inherit [_] map_expression as super
- method! visit_Switch env scrut switch_body =
- (* We transform the switches the following way (if their branches
- * are stateful):
- * ```
- * match x with
- * | Pati -> branchi
- *
- * ~~>
- *
- * fun st ->
- * match x with
- * | Pati -> branchi st
- * ```
- *
- * The reason is that after unfolding the monadic lets, we often
- * have this: `(match x with | ...) st`, and we want to "push" the
- * `st` variable inside.
- *)
- let sb_ty = get_switch_body_ty switch_body in
- if Option.is_some (opt_destruct_state_monad_result sb_ty) then
- (* Generate a fresh state variable *)
- let state_var = fresh_state_var () in
- let state_value = mk_texpression_from_var state_var in
- let state_lvar = mk_typed_pattern_from_var state_var None in
- (* Apply in all the branches and reconstruct the switch *)
- let mk_app e = mk_app e state_value in
- let switch_body = map_switch_body_branches mk_app switch_body in
- let e = mk_switch scrut switch_body in
- let e = mk_abs state_lvar e in
- (* Introduce the lambda and continue
- * Rk.: we will revisit the switch, but won't loop because its
- * type has now changed (the `state -> ...` disappeared) *)
- super#visit_Abs env state_lvar e
- else super#visit_Switch env scrut switch_body
-
method! visit_Let env monadic lv re e =
- (* For now, we do the following transformation:
+ (* We simply do the following transformation:
* ```
- * x <-- re; e
+ * pat <-- re; e
*
* ~~>
*
- * (fun st ->
- * match re st with
- * | Return (st', x) -> e st'
- * | Fail err -> Fail err)
+ * match re with
+ * | Fail err -> Fail err
+ * | Return pat -> e
* ```
- *
- * We rely on the simplification pass which comes later to normalize
- * away expressions like `(fun x -> e) y`.
- *
- * TODO: fix the use of state-error monads (with the bakward functions,
- * we apply some updates twice...
- * It would be better if symbolic to pure generated code of the
- * following shape:
- * `(st1, x) <-- e st0`
- * Then, this micro-pass would only expand the monadic let-bindings
- * (we wouldn't need to introduce state variables).
- * *)
+ *)
(* TODO: we should use a monad "kind" instead of a boolean *)
if not monadic then super#visit_Let env monadic lv re e
else
@@ -1197,95 +1132,24 @@ let unfold_monadic_let_bindings (config : config) (_ctx : trans_ctx)
*)
(* TODO: this information should be computed in SymbolicToPure and
* store in an enum ("monadic" should be an enum, not a bool). *)
- let re_uses_state =
- Option.is_some (opt_destruct_state_monad_result re.ty)
- in
- if re_uses_state then (
- let e0 = e in
- (* Create a fresh state variable *)
- let state_var = fresh_state_var () in
- (* The type of `e` is: `state -> e_no_arrow_ty` *)
- let _, e_no_arrow_ty = destruct_arrow e.ty in
- let e_no_monad_ty = destruct_result e_no_arrow_ty in
- let _, re_no_arrow_ty = destruct_arrow re.ty in
- let re_no_monad_ty = destruct_result re_no_arrow_ty in
- (* Add the state argument on the right-expression *)
- let re =
- let state_value = mk_texpression_from_var state_var in
- mk_app re state_value
- in
- (* Create the match *)
- let fail_pat = mk_result_fail_pattern re_no_monad_ty in
- let fail_value = mk_result_fail_texpression e_no_monad_ty in
- let fail_branch = { pat = fail_pat; branch = fail_value } in
- (* The `Success` branch introduces a fresh state variable *)
- let pat_state_var = fresh_state_var () in
- let pat_state_pattern =
- mk_typed_pattern_from_var pat_state_var None
- in
- let success_pat =
- mk_result_return_pattern
- (mk_simpl_tuple_pattern [ pat_state_pattern; lv ])
- in
- let pat_state_rvalue = mk_texpression_from_var pat_state_var in
- (* TODO: write a utility to create matches (and perform
- * type-checking, etc.) *)
- let success_branch =
- { pat = success_pat; branch = mk_app e pat_state_rvalue }
- in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- let e = { e; ty = e_no_arrow_ty } in
- (* Add the lambda to introduce the state variable *)
- let e = mk_abs (mk_typed_pattern_from_var state_var None) e in
- (* Sanity check *)
- assert (e0.ty = e.ty);
- assert (fail_branch.branch.ty = success_branch.branch.ty);
- (* Continue *)
- super#visit_expression env e.e)
- else
- let re_ty = Option.get (opt_destruct_result re.ty) in
- assert (lv.ty = re_ty);
- let fail_pat = mk_result_fail_pattern lv.ty in
- let fail_value = mk_result_fail_texpression e.ty in
- let fail_branch = { pat = fail_pat; branch = fail_value } in
- let success_pat = mk_result_return_pattern lv in
- let success_branch = { pat = success_pat; branch = e } in
- let switch_body = Match [ fail_branch; success_branch ] in
- let e = Switch (re, switch_body) in
- (* Continue *)
- super#visit_expression env e
+ let re_ty = Option.get (opt_destruct_result re.ty) in
+ assert (lv.ty = re_ty);
+ let fail_pat = mk_result_fail_pattern lv.ty in
+ let fail_value = mk_result_fail_texpression e.ty in
+ let fail_branch = { pat = fail_pat; branch = fail_value } in
+ let success_pat = mk_result_return_pattern lv in
+ let success_branch = { pat = success_pat; branch = e } in
+ let switch_body = Match [ fail_branch; success_branch ] in
+ let e = Switch (re, switch_body) in
+ (* Continue *)
+ super#visit_expression env e
end
in
- (* Update the body: add *)
- let body, signature =
- let state_var = fresh_state_var () in
- (* First, unfold the expressions inside the body *)
- let body_e = obj#visit_texpression () body.body in
- (* Then, add a "state" input variable if necessary: *)
- if config.use_state_monad then
- (* - in the body *)
- let state_rvalue = mk_texpression_from_var state_var in
- let body_e = mk_app body_e state_rvalue in
- (* - in the signature *)
- let sg = def.signature in
- (* Input types *)
- let sg_inputs = sg.inputs @ [ mk_state_ty ] in
- (* Output types *)
- let sg_outputs = Collections.List.to_cons_nil sg.outputs in
- let _, sg_outputs = dest_arrow_ty sg_outputs in
- let sg_outputs = [ sg_outputs ] in
- let sg = { sg with inputs = sg_inputs; outputs = sg_outputs } in
- (* Input list *)
- let inputs = body.inputs @ [ state_var ] in
- let input_lv = mk_typed_pattern_from_var state_var None in
- let inputs_lvs = body.inputs_lvs @ [ input_lv ] in
- let body = { body = body_e; inputs; inputs_lvs } in
- (body, sg)
- else ({ body with body = body_e }, def.signature)
- in
+ (* Update the body *)
+ let body_e = obj#visit_texpression () body.body in
+ let body = { body with body = body_e } in
(* Return *)
- { def with body = Some body; signature }
+ { def with body = Some body }
(** Apply all the micro-passes to a function.
@@ -1359,12 +1223,9 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) :
(lazy ("filter_useless:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
(* Decompose the monadic let-bindings - F* specific
- * TODO: remove? With the state-error monad, it is becoming completely
- * ad-hoc. *)
+ * TODO: remove? *)
let def =
if config.decompose_monadic_let_bindings then (
- (* TODO: we haven't updated the code to handle the state-error monad *)
- assert (not config.use_state_monad);
let def = decompose_monadic_let_bindings ctx def in
log#ldebug
(lazy
@@ -1381,7 +1242,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_decl) :
(* Unfold the monadic let-bindings *)
let def =
if config.unfold_monadic_let_bindings then (
- let def = unfold_monadic_let_bindings config ctx def in
+ let def = unfold_monadic_let_bindings ctx def in
log#ldebug
(lazy
("unfold_monadic_let_bindings:\n\n" ^ fun_decl_to_string ctx def
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 466e5562..fa482b8e 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -101,7 +101,12 @@ type bs_ctx = {
fun_context : fun_context;
fun_decl : A.fun_decl;
bid : T.RegionGroupId.id option; (** TODO: rename *)
- ret_ty : ty; (** The return type - we use it to translate `Panic` *)
+ output_ty : ty;
+ (** The output type - we use it to translate `Panic`.
+
+ This should be the directly translated output type (i.e., no state,
+ no result).
+ *)
sv_to_var : var V.SymbolicValueId.Map.t;
(** Whenever we encounter a new symbolic value (introduced because of
a symbolic expansion or upon ending an abstraction, for instance)
@@ -1082,7 +1087,7 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression =
(* If we use a state monad, we need to add a lambda for the state variable *)
if config.use_state_monad then
(* Create the `Fail` value *)
- let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.ret_ty ] in
+ let ret_ty = mk_simpl_tuple_ty [ mk_state_ty; ctx.output_ty ] in
let ret_v = mk_result_fail_texpression ret_ty in
(* Add the lambda *)
let _, state_var =
@@ -1090,7 +1095,7 @@ and translate_panic (config : config) (ctx : bs_ctx) : texpression =
in
let state_pattern = mk_typed_pattern_from_var state_var None in
mk_abs state_pattern ret_v
- else mk_result_fail_texpression ctx.ret_ty
+ else mk_result_fail_texpression ctx.output_ty
and translate_return (config : config) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
diff --git a/src/Translate.ml b/src/Translate.ml
index d9b42f6b..d69f1379 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -15,6 +15,12 @@ let log = TranslateCore.log
type config = {
eval_config : Contexts.partial_config;
mp_config : Micro.config;
+ use_state_monad : bool;
+ (** If `true`, use a state-error monad.
+ If `false`, only use an error monad.
+
+ Using a state-error monad is necessary when modelling I/O, for instance.
+ *)
split_files : bool;
(** Controls whether we split the generated definitions between different
files for the types, clauses and functions, or if we group them in
@@ -92,7 +98,7 @@ let translate_function_to_symbolics (config : C.partial_config)
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (config : C.partial_config)
- (mp_config : Micro.config) (trans_ctx : trans_ctx)
+ (mp_config : Micro.config) (use_state_monad : bool) (trans_ctx : trans_ctx)
(fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
(pure_type_decls : Pure.type_decl Pure.TypeDeclId.Map.t) (fdef : A.fun_decl)
: pure_fun_translation =
@@ -116,13 +122,14 @@ let translate_function_to_pure (config : C.partial_config)
(* Initialize the context *)
let forward_sig = RegularFunIdMap.find (A.Regular def_id, None) fun_sigs in
- let forward_ret_ty =
- match forward_sig.sg.outputs with
+ let forward_output_ty =
+ match forward_sig.sg.doutputs with
| [ ty ] -> ty
| _ -> failwith "Unreachable"
in
let sv_to_var = V.SymbolicValueId.Map.empty in
let var_counter = Pure.VarId.generator_zero in
+ let state_var, var_counter = Pure.VarId.fresh var_counter in
let calls = V.FunCallId.Map.empty in
let abstractions = V.AbstractionId.Map.empty in
let type_context =
@@ -139,10 +146,11 @@ let translate_function_to_pure (config : C.partial_config)
{
SymbolicToPure.bid = None;
(* Dummy for now *)
- ret_ty = forward_ret_ty;
+ output_ty = forward_output_ty;
(* Will need to be updated for the backward functions *)
sv_to_var;
var_counter;
+ state_var;
type_context;
fun_context;
fun_decl = fdef;
@@ -179,7 +187,7 @@ let translate_function_to_pure (config : C.partial_config)
{
SymbolicToPure.filter_useless_back_calls =
mp_config.filter_useless_monadic_calls;
- use_state_monad = mp_config.use_state_monad;
+ use_state_monad;
}
in
@@ -208,8 +216,10 @@ let translate_function_to_pure (config : C.partial_config)
let backward_sg =
RegularFunIdMap.find (A.Regular def_id, Some back_id) fun_sigs
in
- let backward_ret_ty = mk_simpl_tuple_ty backward_sg.sg.outputs in
- let ctx = { ctx with bid = Some back_id; ret_ty = backward_ret_ty } in
+ let backward_output_ty = mk_simpl_tuple_ty backward_sg.sg.doutputs in
+ let ctx =
+ { ctx with bid = Some back_id; output_ty = backward_output_ty }
+ in
(* Translate *)
SymbolicToPure.translate_fun_decl sp_config ctx None
@@ -238,7 +248,7 @@ let translate_function_to_pure (config : C.partial_config)
* present in the input values of the rust function: for those we reuse
* the names of the input values. *)
let backward_outputs =
- List.combine backward_sg.output_names backward_sg.sg.outputs
+ List.combine backward_sg.output_names backward_sg.sg.doutputs
in
let ctx, backward_outputs =
SymbolicToPure.fresh_vars backward_outputs ctx
@@ -246,7 +256,7 @@ let translate_function_to_pure (config : C.partial_config)
let backward_output_tys =
List.map (fun (v : Pure.var) -> v.ty) backward_outputs
in
- let backward_ret_ty = mk_simpl_tuple_ty backward_output_tys in
+ let backward_output_ty = mk_simpl_tuple_ty backward_output_tys in
let backward_inputs =
T.RegionGroupId.Map.singleton back_id backward_inputs
in
@@ -259,7 +269,7 @@ let translate_function_to_pure (config : C.partial_config)
{
ctx with
bid = Some back_id;
- ret_ty = backward_ret_ty;
+ output_ty = backward_output_ty;
backward_inputs;
backward_outputs;
}
@@ -276,7 +286,7 @@ let translate_function_to_pure (config : C.partial_config)
(pure_forward, pure_backwards)
let translate_module_to_pure (config : C.partial_config)
- (mp_config : Micro.config) (m : M.llbc_module) :
+ (mp_config : Micro.config) (use_state_monad : bool) (m : M.llbc_module) :
trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list =
(* Debug *)
log#ldebug (lazy "translate_module_to_pure");
@@ -316,15 +326,23 @@ let translate_module_to_pure (config : C.partial_config)
m.functions
in
let sigs = List.append assumed_sigs local_sigs in
+ let sp_config =
+ {
+ SymbolicToPure.filter_useless_back_calls =
+ mp_config.filter_useless_monadic_calls;
+ use_state_monad;
+ }
+ in
let fun_sigs =
- SymbolicToPure.translate_fun_signatures type_context.type_infos sigs
+ SymbolicToPure.translate_fun_signatures sp_config type_context.type_infos
+ sigs
in
(* Translate all the *transparent* functions *)
let pure_translations =
List.map
- (translate_function_to_pure config mp_config trans_ctx fun_sigs
- type_decls_map)
+ (translate_function_to_pure config mp_config use_state_monad trans_ctx
+ fun_sigs type_decls_map)
m.functions
in
@@ -349,6 +367,7 @@ type gen_ctx = {
type gen_config = {
mp_config : Micro.config;
+ use_state_monad : bool;
extract_types : bool;
extract_decreases_clauses : bool;
extract_template_decreases_clauses : bool;
@@ -475,7 +494,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
(* Is there an input parameter "visible" for the state used in
* the state error monad (if we use a state error monad)? *)
let has_state_param =
- config.mp_config.use_state_monad
+ config.use_state_monad
&& config.mp_config.unfold_monadic_let_bindings
in
(* Check if the definition needs to be filtered or not *)
@@ -599,7 +618,8 @@ let translate_module (filename : string) (dest_dir : string) (config : config)
(m : M.llbc_module) : unit =
(* Translate the module to the pure AST *)
let trans_ctx, trans_types, trans_funs =
- translate_module_to_pure config.eval_config config.mp_config m
+ translate_module_to_pure config.eval_config config.mp_config
+ config.use_state_monad m
in
(* Initialize the extraction context - for now we extract only to F* *)
@@ -711,13 +731,14 @@ let translate_module (filename : string) (dest_dir : string) (config : config)
}
in
- let use_state = config.mp_config.use_state_monad in
+ let use_state = config.use_state_monad in
(* Extract one or several files, depending on the configuration *)
if config.split_files then (
let base_gen_config =
{
mp_config = config.mp_config;
+ use_state_monad = use_state;
extract_types = false;
extract_decreases_clauses = config.extract_decreases_clauses;
extract_template_decreases_clauses = false;
@@ -807,6 +828,7 @@ let translate_module (filename : string) (dest_dir : string) (config : config)
let gen_config =
{
mp_config = config.mp_config;
+ use_state_monad = use_state;
extract_types = true;
extract_decreases_clauses = config.extract_decreases_clauses;
extract_template_decreases_clauses =
diff --git a/src/main.ml b/src/main.ml
index 9d038f12..e635d910 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -177,14 +177,8 @@ let () =
unfold_monadic_let_bindings = !unfold_monads;
filter_useless_monadic_calls = !filter_useless_calls;
filter_useless_functions = !filter_useless_functions;
- use_state_monad = not !no_state;
}
in
- (* Small issue: the monadic `bind` only works for the error-monad, not
- * the state-error monad (there are definitions to write and piping to do) *)
- assert (
- (not micro_passes_config.use_state_monad)
- || micro_passes_config.unfold_monadic_let_bindings);
let trans_config =
{
Translate.eval_config;
@@ -193,6 +187,7 @@ let () =
test_unit_functions;
extract_decreases_clauses = not !no_decreases_clauses;
extract_template_decreases_clauses = !template_decreases_clauses;
+ use_state_monad = not !no_state;
}
in
Translate.translate_module filename dest_dir trans_config m