diff options
author | Son Ho | 2022-11-14 09:06:15 +0100 |
---|---|---|
committer | Son HO | 2022-11-14 14:21:04 +0100 |
commit | 3eba613a9ff9d5c265fbe2676f6bd324728d9ca4 (patch) | |
tree | eeac1d917f398906ab4aeaa3627561d980f0492a | |
parent | 2a0ecfbef81231a394df71817a4cd9e81582b7de (diff) |
Implement a pass to decompose nested patterns in let-bindings
Diffstat (limited to '')
-rw-r--r-- | Makefile | 8 | ||||
-rw-r--r-- | compiler/Config.ml | 21 | ||||
-rw-r--r-- | compiler/Driver.ml | 34 | ||||
-rw-r--r-- | compiler/PureMicroPasses.ml | 170 | ||||
-rw-r--r-- | tests/coq/misc/External__Funs.v | 5 |
5 files changed, 179 insertions, 59 deletions
@@ -146,7 +146,7 @@ tfstarp-betree_main: OPTIONS += $(BETREE_FSTAR_OPTIONS) # This generates very ugly code, but is good to test the translation. .PHONY: test-transp-betree_main test-transp-betree_main: transp-betree_main -test-transp-betree_main: OPTIONS += -backend fstar -unfold-monads -test-trans-units +test-transp-betree_main: OPTIONS += -backend fstar -test-trans-units test-transp-betree_main: OPTIONS += $(BETREE_FSTAR_OPTIONS) test-transp-betree_main: BACKEND_SUBDIR := "fstar" test-transp-betree_main: SUBDIR:=betree_back_stateful @@ -186,14 +186,14 @@ transp-%: gen-llbcp-% tfstarp-% echo "# Test $* done" .PHONY: tfstar-% -tfstar-%: OPTIONS += -backend fstar -unfold-monads -test-trans-units +tfstar-%: OPTIONS += -backend fstar -test-trans-units tfstar-%: BACKEND_SUBDIR := fstar tfstar-%: $(AENEAS_CMD) # "p" stands for "Polonius" .PHONY: tfstarp-% -tfstarp-%: OPTIONS += -backend fstar -unfold-monads -test-trans-units +tfstarp-%: OPTIONS += -backend fstar -test-trans-units tfstarp-%: BACKEND_SUBDIR := fstar tfstarp-%: $(AENEAS_CMD) @@ -201,7 +201,7 @@ tfstarp-%: # TODO: -test-trans-units # It doesn't work on vec_push_fwd, I don't understand why. .PHONY: tcoq-% -tcoq-%: OPTIONS += -backend coq -decompose-monads +tcoq-%: OPTIONS += -backend coq tcoq-%: BACKEND_SUBDIR := coq tcoq-%: $(AENEAS_CMD) diff --git a/compiler/Config.ml b/compiler/Config.ml index 95442761..28218b7b 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -178,20 +178,37 @@ let extract_template_decreases_clauses = ref false in monadic let-bindings: {[ (* NOT supported in F*/Coq *) - let (x, y) <-- f (); + (x, y) <-- f (); ... ]} In such situations, we might want to introduce an intermediate assignment: {[ - let tmp <-- f (); + tmp <-- f (); let (x, y) = tmp in ... ]} *) let decompose_monadic_let_bindings = ref false +(** Some provers like Coq don't support nested patterns in let-bindings: + {[ + (* NOT supported in Coq *) + (st, (x1, x2)) <-- f (); + ... + ]} + + In such situations, we might want to introduce intermediate + assignments: + {[ + (st, tmp) <-- f (); + let (x1, x2) = tmp in + ... + ]} + *) +let decompose_nested_let_patterns = ref false + (** Controls the unfolding of monadic let-bindings to explicit matches: [y <-- f x; ...] diff --git a/compiler/Driver.ml b/compiler/Driver.ml index 5089cb8e..05a40ad1 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -37,20 +37,6 @@ let () = Arg.Symbol (backend_names, set_backend), " Specify the backend to which to extract" ); ("-dest", Arg.Set_string dest_dir, " Specify the output directory"); - ( "-decompose-monads", - Arg.Set decompose_monadic_let_bindings, - " Decompose the monadic let-bindings.\n\n\ - \ Introduces a temporary variable which is later decomposed,\n\ - \ when the pattern on the left of the monadic let is not a \n\ - \ variable.\n\ - \ \n\ - \ Example:\n\ - \ `(x, y) <-- f (); ...` ~~>\n\ - \ `tmp <-- f (); let (x, y) = tmp in ...`\n\ - \ " ); - ( "-unfold-monads", - Arg.Set unfold_monadic_let_bindings, - " Unfold the monadic let-bindings to matches" ); ( "-no-filter-useless-calls", Arg.Clear filter_useless_monadic_calls, " Do not filter the useless function calls, when possible" ); @@ -111,17 +97,23 @@ let () = fail () in - (* In the case of Coq, we forbid using field projectors (see the comments for - [dont_use_field_projectors]). - Also, we always decompose ADT values with matches (decomposing with - let-bindings is not supported). - *) + (* Set some options depending on the backend *) let _ = match !backend with - | FStar -> () + | FStar -> + (* F* supports monadic notations, but the encoding loses information *) + unfold_monadic_let_bindings := true | Coq -> + (* In the case of Coq, we forbid using field projectors (see the comments for + [dont_use_field_projectors]). + Also, we always decompose ADT values with matches (decomposing with + let-bindings is not supported). + *) dont_use_field_projectors := true; - always_deconstruct_adts_with_matches := true + always_deconstruct_adts_with_matches := true; + (* Some patterns are not supported *) + decompose_monadic_let_bindings := true; + decompose_nested_let_patterns := true in (* Retrieve and check the filename *) diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 7b261516..1cb35613 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1081,11 +1081,16 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl = let body = Some { body with body = obj#visit_texpression () body.body } in { def with body } -(** Decompose the monadic let-bindings. +(** Decompose let-bindings by introducing intermediate let-bindings. - See the explanations in {!val:Config.decompose_monadic_let_bindings} + This is a utility function: see {!decompose_monadic_let_bindings} and + {!decompose_nested_let_patterns}. + + [decompose_monadic]: always decompose a monadic let-binding + [decompose_nested_pats]: decompose the nested patterns *) -let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : +let decompose_let_bindings (decompose_monadic : bool) + (decompose_nested_pats : bool) (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with | None -> def @@ -1093,41 +1098,131 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : (* Set up the var id generator *) let cnt = get_body_min_var_counter body in let _, fresh_id = VarId.mk_stateful_generator cnt in + let mk_fresh (ty : ty) : typed_pattern * texpression = + let vid = fresh_id () in + let tmp : var = { id = vid; basename = None; ty } in + let ltmp = mk_typed_pattern_from_var tmp None in + let rtmp = mk_texpression_from_var tmp in + (ltmp, rtmp) + in + + (* Utility function - returns the patterns to introduce, from the last to + the first. + + For instance, if it returns: + {[ + ([ + ((x3, x4), x1); + ((x1, x2), tmp) + ], + (x0, tmp)) + ]} + then we need to introduce: + {[ + let (x0, tmp) = original_term in + let (x1, x2) = tmp in + let (x3, x4) = x1 in + ... + }] + *) + let decompose_pat (lv : typed_pattern) : + (typed_pattern * texpression) list * typed_pattern = + let patterns = ref [] in + + (* We decompose patterns *inside* other patterns. + The boolean [inside] allows us to remember if we dived into a + pattern already *) + let visit_pats = + object + inherit [_] map_typed_pattern as super + + method! visit_typed_pattern (inside : bool) (pat : typed_pattern) + : typed_pattern = + match pat.value with + | PatConstant _ | PatVar _ | PatDummy -> pat + | PatAdt _ -> + if not inside then super#visit_typed_pattern true pat + else + let ltmp, rtmp = mk_fresh pat.ty in + let pat = super#visit_typed_pattern false pat in + patterns := (pat, rtmp) :: !patterns; + ltmp + end + in + + let inside = false in + let lv = visit_pats#visit_typed_pattern inside lv in + (!patterns, lv) + in + (* It is a very simple map *) - let obj = + let visit_lets = object (self) inherit [_] map_expression as super method! visit_Let env monadic lv re next_e = - if not monadic then super#visit_Let env monadic lv re next_e - else - (* If monadic, we need to check if the left-value is a variable: - * - if yes, don't decompose - * - if not, make the decomposition in two steps - *) - match lv.value with - | PatVar _ -> - (* Variable: nothing to do *) - super#visit_Let env monadic lv re next_e - | _ -> - (* Not a variable: decompose *) - (* Introduce a temporary variable to receive the value of the - * monadic binding *) - let vid = fresh_id () in - let tmp : var = { id = vid; basename = None; ty = lv.ty } in - let ltmp = mk_typed_pattern_from_var tmp None in - let rtmp = mk_texpression_from_var tmp in - (* Visit the next expression *) - let next_e = self#visit_texpression env next_e in - (* Create the let-bindings *) - (mk_let true ltmp re (mk_let false lv rtmp next_e)).e + (* Decompose the monadic let-bindings *) + let monadic, lv, re, next_e = + if (not monadic) || not decompose_monadic then + (monadic, lv, re, next_e) + else + (* If monadic, we need to check if the left-value is a variable: + * - if yes, don't decompose + * - if not, make the decomposition in two steps + *) + match lv.value with + | PatVar _ -> + (* Variable: nothing to do *) + (monadic, lv, re, next_e) + | _ -> + (* Not a variable: decompose if required *) + (* Introduce a temporary variable to receive the value of the + * monadic binding *) + let ltmp, rtmp = mk_fresh lv.ty in + (* Visit the next expression *) + let next_e = self#visit_texpression env next_e in + (* Create the let-bindings *) + (true, ltmp, re, mk_let false lv rtmp next_e) + in + (* Decompose the nested let-patterns *) + let lv, next_e = + if not decompose_nested_pats then (lv, next_e) + else + let pats, first_pat = decompose_pat lv in + let e = + List.fold_left + (fun next_e (lpat, rv) -> mk_let false lpat rv next_e) + next_e pats + in + (first_pat, e) + in + (* Continue *) + super#visit_Let env monadic lv re next_e end in (* Update the body *) - let body = Some { body with body = obj#visit_texpression () body.body } in + let body = + Some { body with body = visit_lets#visit_texpression () body.body } + in (* Return *) { def with body } +(** Decompose monadic let-bindings. + + See the explanations in {!val:Config.decompose_monadic_let_bindings} + *) +let decompose_monadic_let_bindings (ctx : trans_ctx) (def : fun_decl) : fun_decl + = + decompose_let_bindings true false ctx def + +(** Decompose the nested let patterns. + + See the explanations in {!val:Config.decompose_nested_let_patterns} + *) +let decompose_nested_let_patterns (ctx : trans_ctx) (def : fun_decl) : fun_decl + = + decompose_let_bindings false true ctx def + (** Unfold the monadic let-bindings to explicit matches. *) let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = match def.body with @@ -1253,7 +1348,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = Ex.: {[ - type struct = { f0 : nat; f1 : nat } + (* type struct = { f0 : nat; f1 : nat } *) Mkstruct x.f0 x.f1 ~~> x ]} @@ -1262,8 +1357,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = log#ldebug (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); - (* Decompose the monadic let-bindings - F* specific - * TODO: remove? *) + (* Decompose the monadic let-bindings - used by Coq *) let def = if !Config.decompose_monadic_let_bindings then ( let def = decompose_monadic_let_bindings ctx def in @@ -1279,6 +1373,22 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option = def) in + (* Decompose nested let-patterns *) + let def = + if !Config.decompose_nested_let_patterns then ( + let def = decompose_nested_let_patterns ctx def in + log#ldebug + (lazy + ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def + ^ "\n")); + def) + else ( + log#ldebug + (lazy + "ignoring decompose_nested_let_patterns due to the configuration\n"); + def) + in + (* Unfold the monadic let-bindings *) let def = if !Config.unfold_monadic_let_bindings then ( diff --git a/tests/coq/misc/External__Funs.v b/tests/coq/misc/External__Funs.v index df35f7eb..021acd6e 100644 --- a/tests/coq/misc/External__Funs.v +++ b/tests/coq/misc/External__Funs.v @@ -84,7 +84,7 @@ Definition test_custom_swap_back result (state * (u32 * u32)) := p <- custom_swap_back u32 x y st (1 %u32) st0; - let (st1, (x0, y0)) := p in Return (st1, (x0, y0)) + let (st1, p0) := p in let (x0, y0) := p0 in Return (st1, (x0, y0)) . (** [external::test_swap_non_zero] *) @@ -93,7 +93,8 @@ Definition test_swap_non_zero_fwd p <- swap_fwd u32 x (0 %u32) st; let (st0, _) := p in p0 <- swap_back u32 x (0 %u32) st st0; - let (st1, (x0, _)) := p0 in if x0 s= 0 %u32 then Fail_ else Return (st1, x0) + let (st1, p1) := p0 in + let (x0, _) := p1 in if x0 s= 0 %u32 then Fail_ else Return (st1, x0) . End External__Funs . |