diff options
author | Son Ho | 2022-02-04 13:33:45 +0100 |
---|---|---|
committer | Son Ho | 2022-02-04 13:33:45 +0100 |
commit | 3ead957cf13ddd3b48ee85c008c6d56e44726eb4 (patch) | |
tree | 51e694665a3623cea8250bb0c3e4523c321fada1 /src | |
parent | 25200ad9664980d3276dd7462b4845a1a21c3e64 (diff) |
Work on decomposition of monadic let-bindings for F*
Diffstat (limited to '')
-rw-r--r-- | src/ExtractToFStar.ml | 12 | ||||
-rw-r--r-- | src/Identifiers.ml | 8 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 87 | ||||
-rw-r--r-- | src/main.ml | 3 |
4 files changed, 104 insertions, 6 deletions
diff --git a/src/ExtractToFStar.ml b/src/ExtractToFStar.ml index 7e4a11fe..1d45b239 100644 --- a/src/ExtractToFStar.ml +++ b/src/ExtractToFStar.ml @@ -824,15 +824,23 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_space fmt (); (* Open a box for the branch *) F.pp_open_hvbox fmt 0; + (* Print the `begin` if necessary *) let parenth = PureUtils.expression_requires_parentheses e_branch in - if parenth then F.pp_print_string fmt "("; + if parenth then ( + F.pp_print_string fmt "begin"; + F.pp_print_space fmt ()); + (* Print the branch expression *) extract_texpression ctx fmt false e_branch; - if parenth then F.pp_print_string fmt ")"; + (* Close the `begin ... end ` *) + if parenth then ( + F.pp_print_space fmt (); + F.pp_print_string fmt "end"); (* Close the box for the branch *) F.pp_close_box fmt (); (* Close the box for the then/else+branch *) F.pp_close_box fmt () in + extract_branch true e_then; extract_branch false e_else; (* Close the box for the whole `if ... then ... else ...` *) diff --git a/src/Identifiers.ml b/src/Identifiers.ml index dfcbb631..825b4ad9 100644 --- a/src/Identifiers.ml +++ b/src/Identifiers.ml @@ -20,6 +20,8 @@ module type Id = sig val fresh_stateful_generator : unit -> generator ref * (unit -> id) + val mk_stateful_generator : generator -> generator ref * (unit -> id) + val incr : id -> id (* TODO: this should be stateful! - but we may want to be able to duplicate @@ -103,8 +105,8 @@ module IdGen () : Id = struct let id = incr id in id - let fresh_stateful_generator () = - let g = ref 0 in + let mk_stateful_generator g = + let g = ref g in let fresh () = let id = !g in g := incr id; @@ -112,6 +114,8 @@ module IdGen () : Id = struct in (g, fresh) + let fresh_stateful_generator () = mk_stateful_generator 0 + let fresh gen = (gen, incr gen) let to_string = string_of_int diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index d8bfe4cd..ceee71dd 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -8,14 +8,38 @@ open TranslateCore let log = L.pure_micro_passes_log type config = { + decompose_monadic_let_bindings : bool; + (** Some provers like F* don't support the decomposition of return values + in monadic let-bindings: + ``` + // NOT supported in F* + let (x, y) <-- f (); + ... + ``` + + In such situations, we might want to introduce an intermediate + assignment: + ``` + let tmp <-- f (); + let (x, y) = tmp in + ... + ``` + *) unfold_monadic_let_bindings : bool; (** Controls the unfolding of monadic let-bindings to explicit matches: + `y <-- f x; ...` + becomes: + `match f x with | Failure -> Failure | Return y -> ...` - + + This is useful when extracting to F*: the support for monadic definitions is not super powerful. + Note that when [undolf_monadic_let_bindings] is true, setting + [decompose_monadic_let_bindings] to true and only makes the code + more verbose. *) filter_unused_monadic_calls : bool; (** Controls whether we try to filter the calls to monadic functions @@ -714,6 +738,51 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_def) : fun_def = let body = obj#visit_texpression () def.body in { def with body } +(** Decompose the monadic let-bindings. + + See the explanations in [config]. + *) +let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def + = + (* Set up the var id generator *) + let cnt = get_expression_min_var_counter def.body.e in + let _, fresh_id = VarId.mk_stateful_generator cnt in + (* It is a very simple map *) + let obj = + object (self) + inherit [_] map_expression as super + + method! visit_Let env monadic lv re next_e = + if not monadic then super#visit_Let env monadic lv re next_e + else + (* If monadic, we need to check if the left-value is a variable: + * - if yes, don't decompose + * - if not, make the decomposition in two steps + *) + match lv.value with + | LvVar _ -> + (* Variable: nothing to do *) + super#visit_Let env monadic lv re next_e + | _ -> + (* Not a variable: decompose *) + (* Introduce a temporary variable to receive the value of the + * monadic binding *) + let vid = fresh_id () in + let tmp : var = { id = vid; basename = None; ty = lv.ty } in + let ltmp = mk_typed_lvalue_from_var tmp None in + let rtmp = mk_typed_rvalue_from_var tmp in + let rtmp = mk_value_expression rtmp None in + (* Visit the next expression *) + let next_e = self#visit_texpression env next_e in + (* Create the let-bindings *) + (mk_let true ltmp re (mk_let false lv rtmp next_e)).e + end + in + (* Update the body *) + let body = obj#visit_texpression () def.body in + (* Return *) + { def with body } + (** Unfold the monadic let-bindings to explicit matches. *) let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_def) : fun_def = (* It is a very simple map *) @@ -813,6 +882,22 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : log#ldebug (lazy ("filter_unused:\n\n" ^ fun_def_to_string ctx def ^ "\n")); + (* Decompose the monadic let-bindings *) + let def = + if config.decompose_monadic_let_bindings then ( + let def = decompose_monadic_let_bindings ctx def in + log#ldebug + (lazy + ("decompose_monadic_let_bindings:\n\n" ^ fun_def_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy + "ignoring decompose_monadic_let_bindings due to the configuration\n"); + def) + in + (* Unfold the monadic let-bindings *) let def = if config.unfold_monadic_let_bindings then ( diff --git a/src/main.ml b/src/main.ml index 7f7ed96c..0eaff02d 100644 --- a/src/main.ml +++ b/src/main.ml @@ -88,7 +88,8 @@ let () = let test_unit_functions = true in let micro_passes_config = { - Micro.unfold_monadic_let_bindings = true; + Micro.decompose_monadic_let_bindings = false; + unfold_monadic_let_bindings = true; filter_unused_monadic_calls = true; } in |