summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2022-11-14 09:06:15 +0100
committerSon HO2022-11-14 14:21:04 +0100
commit3eba613a9ff9d5c265fbe2676f6bd324728d9ca4 (patch)
treeeeac1d917f398906ab4aeaa3627561d980f0492a /compiler
parent2a0ecfbef81231a394df71817a4cd9e81582b7de (diff)
Implement a pass to decompose nested patterns in let-bindings
Diffstat (limited to '')
-rw-r--r--compiler/Config.ml21
-rw-r--r--compiler/Driver.ml34
-rw-r--r--compiler/PureMicroPasses.ml170
3 files changed, 172 insertions, 53 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index 95442761..28218b7b 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -178,20 +178,37 @@ let extract_template_decreases_clauses = ref false
in monadic let-bindings:
{[
(* NOT supported in F*/Coq *)
- let (x, y) <-- f ();
+ (x, y) <-- f ();
...
]}
In such situations, we might want to introduce an intermediate
assignment:
{[
- let tmp <-- f ();
+ tmp <-- f ();
let (x, y) = tmp in
...
]}
*)
let decompose_monadic_let_bindings = ref false
+(** Some provers like Coq don't support nested patterns in let-bindings:
+ {[
+ (* NOT supported in Coq *)
+ (st, (x1, x2)) <-- f ();
+ ...
+ ]}
+
+ In such situations, we might want to introduce intermediate
+ assignments:
+ {[
+ (st, tmp) <-- f ();
+ let (x1, x2) = tmp in
+ ...
+ ]}
+ *)
+let decompose_nested_let_patterns = ref false
+
(** Controls the unfolding of monadic let-bindings to explicit matches:
[y <-- f x; ...]
diff --git a/compiler/Driver.ml b/compiler/Driver.ml
index 5089cb8e..05a40ad1 100644
--- a/compiler/Driver.ml
+++ b/compiler/Driver.ml
@@ -37,20 +37,6 @@ let () =
Arg.Symbol (backend_names, set_backend),
" Specify the backend to which to extract" );
("-dest", Arg.Set_string dest_dir, " Specify the output directory");
- ( "-decompose-monads",
- Arg.Set decompose_monadic_let_bindings,
- " Decompose the monadic let-bindings.\n\n\
- \ Introduces a temporary variable which is later decomposed,\n\
- \ when the pattern on the left of the monadic let is not a \n\
- \ variable.\n\
- \ \n\
- \ Example:\n\
- \ `(x, y) <-- f (); ...` ~~>\n\
- \ `tmp <-- f (); let (x, y) = tmp in ...`\n\
- \ " );
- ( "-unfold-monads",
- Arg.Set unfold_monadic_let_bindings,
- " Unfold the monadic let-bindings to matches" );
( "-no-filter-useless-calls",
Arg.Clear filter_useless_monadic_calls,
" Do not filter the useless function calls, when possible" );
@@ -111,17 +97,23 @@ let () =
fail ()
in
- (* In the case of Coq, we forbid using field projectors (see the comments for
- [dont_use_field_projectors]).
- Also, we always decompose ADT values with matches (decomposing with
- let-bindings is not supported).
- *)
+ (* Set some options depending on the backend *)
let _ =
match !backend with
- | FStar -> ()
+ | FStar ->
+ (* F* supports monadic notations, but the encoding loses information *)
+ unfold_monadic_let_bindings := true
| Coq ->
+ (* In the case of Coq, we forbid using field projectors (see the comments for
+ [dont_use_field_projectors]).
+ Also, we always decompose ADT values with matches (decomposing with
+ let-bindings is not supported).
+ *)
dont_use_field_projectors := true;
- always_deconstruct_adts_with_matches := true
+ always_deconstruct_adts_with_matches := true;
+ (* Some patterns are not supported *)
+ decompose_monadic_let_bindings := true;
+ decompose_nested_let_patterns := true
in
(* Retrieve and check the filename *)
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 7b261516..1cb35613 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -1081,11 +1081,16 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
let body = Some { body with body = obj#visit_texpression () body.body } in
{ def with body }
-(** Decompose the monadic let-bindings.
+(** Decompose let-bindings by introducing intermediate let-bindings.
- See the explanations in {!val:Config.decompose_monadic_let_bindings}
+ This is a utility function: see {!decompose_monadic_let_bindings} and
+ {!decompose_nested_let_patterns}.
+
+ [decompose_monadic]: always decompose a monadic let-binding
+ [decompose_nested_pats]: decompose the nested patterns
*)
-let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :
+let decompose_let_bindings (decompose_monadic : bool)
+ (decompose_nested_pats : bool) (_ctx : trans_ctx) (def : fun_decl) :
fun_decl =
match def.body with
| None -> def
@@ -1093,41 +1098,131 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :
(* Set up the var id generator *)
let cnt = get_body_min_var_counter body in
let _, fresh_id = VarId.mk_stateful_generator cnt in
+ let mk_fresh (ty : ty) : typed_pattern * texpression =
+ let vid = fresh_id () in
+ let tmp : var = { id = vid; basename = None; ty } in
+ let ltmp = mk_typed_pattern_from_var tmp None in
+ let rtmp = mk_texpression_from_var tmp in
+ (ltmp, rtmp)
+ in
+
+ (* Utility function - returns the patterns to introduce, from the last to
+ the first.
+
+ For instance, if it returns:
+ {[
+ ([
+ ((x3, x4), x1);
+ ((x1, x2), tmp)
+ ],
+ (x0, tmp))
+ ]}
+ then we need to introduce:
+ {[
+ let (x0, tmp) = original_term in
+ let (x1, x2) = tmp in
+ let (x3, x4) = x1 in
+ ...
+ }]
+ *)
+ let decompose_pat (lv : typed_pattern) :
+ (typed_pattern * texpression) list * typed_pattern =
+ let patterns = ref [] in
+
+ (* We decompose patterns *inside* other patterns.
+ The boolean [inside] allows us to remember if we dived into a
+ pattern already *)
+ let visit_pats =
+ object
+ inherit [_] map_typed_pattern as super
+
+ method! visit_typed_pattern (inside : bool) (pat : typed_pattern)
+ : typed_pattern =
+ match pat.value with
+ | PatConstant _ | PatVar _ | PatDummy -> pat
+ | PatAdt _ ->
+ if not inside then super#visit_typed_pattern true pat
+ else
+ let ltmp, rtmp = mk_fresh pat.ty in
+ let pat = super#visit_typed_pattern false pat in
+ patterns := (pat, rtmp) :: !patterns;
+ ltmp
+ end
+ in
+
+ let inside = false in
+ let lv = visit_pats#visit_typed_pattern inside lv in
+ (!patterns, lv)
+ in
+
(* It is a very simple map *)
- let obj =
+ let visit_lets =
object (self)
inherit [_] map_expression as super
method! visit_Let env monadic lv re next_e =
- if not monadic then super#visit_Let env monadic lv re next_e
- else
- (* If monadic, we need to check if the left-value is a variable:
- * - if yes, don't decompose
- * - if not, make the decomposition in two steps
- *)
- match lv.value with
- | PatVar _ ->
- (* Variable: nothing to do *)
- super#visit_Let env monadic lv re next_e
- | _ ->
- (* Not a variable: decompose *)
- (* Introduce a temporary variable to receive the value of the
- * monadic binding *)
- let vid = fresh_id () in
- let tmp : var = { id = vid; basename = None; ty = lv.ty } in
- let ltmp = mk_typed_pattern_from_var tmp None in
- let rtmp = mk_texpression_from_var tmp in
- (* Visit the next expression *)
- let next_e = self#visit_texpression env next_e in
- (* Create the let-bindings *)
- (mk_let true ltmp re (mk_let false lv rtmp next_e)).e
+ (* Decompose the monadic let-bindings *)
+ let monadic, lv, re, next_e =
+ if (not monadic) || not decompose_monadic then
+ (monadic, lv, re, next_e)
+ else
+ (* If monadic, we need to check if the left-value is a variable:
+ * - if yes, don't decompose
+ * - if not, make the decomposition in two steps
+ *)
+ match lv.value with
+ | PatVar _ ->
+ (* Variable: nothing to do *)
+ (monadic, lv, re, next_e)
+ | _ ->
+ (* Not a variable: decompose if required *)
+ (* Introduce a temporary variable to receive the value of the
+ * monadic binding *)
+ let ltmp, rtmp = mk_fresh lv.ty in
+ (* Visit the next expression *)
+ let next_e = self#visit_texpression env next_e in
+ (* Create the let-bindings *)
+ (true, ltmp, re, mk_let false lv rtmp next_e)
+ in
+ (* Decompose the nested let-patterns *)
+ let lv, next_e =
+ if not decompose_nested_pats then (lv, next_e)
+ else
+ let pats, first_pat = decompose_pat lv in
+ let e =
+ List.fold_left
+ (fun next_e (lpat, rv) -> mk_let false lpat rv next_e)
+ next_e pats
+ in
+ (first_pat, e)
+ in
+ (* Continue *)
+ super#visit_Let env monadic lv re next_e
end
in
(* Update the body *)
- let body = Some { body with body = obj#visit_texpression () body.body } in
+ let body =
+ Some { body with body = visit_lets#visit_texpression () body.body }
+ in
(* Return *)
{ def with body }
+(** Decompose monadic let-bindings.
+
+ See the explanations in {!val:Config.decompose_monadic_let_bindings}
+ *)
+let decompose_monadic_let_bindings (ctx : trans_ctx) (def : fun_decl) : fun_decl
+ =
+ decompose_let_bindings true false ctx def
+
+(** Decompose the nested let patterns.
+
+ See the explanations in {!val:Config.decompose_nested_let_patterns}
+ *)
+let decompose_nested_let_patterns (ctx : trans_ctx) (def : fun_decl) : fun_decl
+ =
+ decompose_let_bindings false true ctx def
+
(** Unfold the monadic let-bindings to explicit matches. *)
let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match def.body with
@@ -1253,7 +1348,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
Ex.:
{[
- type struct = { f0 : nat; f1 : nat }
+ (* type struct = { f0 : nat; f1 : nat } *)
Mkstruct x.f0 x.f1 ~~> x
]}
@@ -1262,8 +1357,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
log#ldebug
(lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n"));
- (* Decompose the monadic let-bindings - F* specific
- * TODO: remove? *)
+ (* Decompose the monadic let-bindings - used by Coq *)
let def =
if !Config.decompose_monadic_let_bindings then (
let def = decompose_monadic_let_bindings ctx def in
@@ -1279,6 +1373,22 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =
def)
in
+ (* Decompose nested let-patterns *)
+ let def =
+ if !Config.decompose_nested_let_patterns then (
+ let def = decompose_nested_let_patterns ctx def in
+ log#ldebug
+ (lazy
+ ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def
+ ^ "\n"));
+ def)
+ else (
+ log#ldebug
+ (lazy
+ "ignoring decompose_nested_let_patterns due to the configuration\n");
+ def)
+ in
+
(* Unfold the monadic let-bindings *)
let def =
if !Config.unfold_monadic_let_bindings then (