diff options
| author | Son Ho | 2022-11-14 09:06:15 +0100 | 
|---|---|---|
| committer | Son HO | 2022-11-14 14:21:04 +0100 | 
| commit | 3eba613a9ff9d5c265fbe2676f6bd324728d9ca4 (patch) | |
| tree | eeac1d917f398906ab4aeaa3627561d980f0492a | |
| parent | 2a0ecfbef81231a394df71817a4cd9e81582b7de (diff) | |
Implement a pass to decompose nested patterns in let-bindings
| -rw-r--r-- | Makefile | 8 | ||||
| -rw-r--r-- | compiler/Config.ml | 21 | ||||
| -rw-r--r-- | compiler/Driver.ml | 34 | ||||
| -rw-r--r-- | compiler/PureMicroPasses.ml | 170 | ||||
| -rw-r--r-- | tests/coq/misc/External__Funs.v | 5 | 
5 files changed, 179 insertions, 59 deletions
@@ -146,7 +146,7 @@ tfstarp-betree_main: OPTIONS += $(BETREE_FSTAR_OPTIONS)  # This generates very ugly code, but is good to test the translation.  .PHONY: test-transp-betree_main  test-transp-betree_main: transp-betree_main -test-transp-betree_main: OPTIONS += -backend fstar -unfold-monads -test-trans-units +test-transp-betree_main: OPTIONS += -backend fstar -test-trans-units  test-transp-betree_main: OPTIONS += $(BETREE_FSTAR_OPTIONS)  test-transp-betree_main: BACKEND_SUBDIR := "fstar"  test-transp-betree_main: SUBDIR:=betree_back_stateful @@ -186,14 +186,14 @@ transp-%: gen-llbcp-% tfstarp-%  	echo "# Test $* done"  .PHONY: tfstar-% -tfstar-%: OPTIONS += -backend fstar -unfold-monads -test-trans-units +tfstar-%: OPTIONS += -backend fstar -test-trans-units  tfstar-%: BACKEND_SUBDIR := fstar  tfstar-%:  	$(AENEAS_CMD)  # "p" stands for "Polonius"  .PHONY: tfstarp-% -tfstarp-%: OPTIONS += -backend fstar -unfold-monads -test-trans-units +tfstarp-%: OPTIONS += -backend fstar -test-trans-units  tfstarp-%: BACKEND_SUBDIR := fstar  tfstarp-%:  	$(AENEAS_CMD) @@ -201,7 +201,7 @@ tfstarp-%:  # TODO: -test-trans-units  # It doesn't work on vec_push_fwd, I don't understand why.  .PHONY: tcoq-% -tcoq-%: OPTIONS += -backend coq -decompose-monads +tcoq-%: OPTIONS += -backend coq  tcoq-%: BACKEND_SUBDIR := coq  tcoq-%:  	$(AENEAS_CMD) diff --git a/compiler/Config.ml b/compiler/Config.ml index 95442761..28218b7b 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -178,20 +178,37 @@ let extract_template_decreases_clauses = ref false      in monadic let-bindings:      {[        (* NOT supported in F*/Coq *) -      let (x, y) <-- f (); +      (x, y) <-- f ();        ...      ]}      In such situations, we might want to introduce an intermediate      assignment:      {[ -      let tmp <-- f (); +      tmp <-- f ();        let (x, y) = tmp in        ...      ]}   *)  let decompose_monadic_let_bindings = ref false +(** Some provers like Coq don't support nested patterns in let-bindings: +    {[ +      (* NOT supported in Coq *) +      (st, (x1, x2)) <-- f (); +      ... +    ]} + +    In such situations, we might want to introduce intermediate +    assignments: +    {[ +      (st, tmp) <-- f (); +      let (x1, x2) = tmp in +      ... +    ]} + *) +let decompose_nested_let_patterns = ref false +  (** Controls the unfolding of monadic let-bindings to explicit matches:      [y <-- f x; ...] diff --git a/compiler/Driver.ml b/compiler/Driver.ml index 5089cb8e..05a40ad1 100644 --- a/compiler/Driver.ml +++ b/compiler/Driver.ml @@ -37,20 +37,6 @@ let () =          Arg.Symbol (backend_names, set_backend),          " Specify the backend to which to extract" );        ("-dest", Arg.Set_string dest_dir, " Specify the output directory"); -      ( "-decompose-monads", -        Arg.Set decompose_monadic_let_bindings, -        " Decompose the monadic let-bindings.\n\n\ -        \         Introduces a temporary variable which is later decomposed,\n\ -        \         when the pattern on the left of the monadic let is not a \n\ -        \         variable.\n\ -        \         \n\ -        \         Example:\n\ -        \         `(x, y) <-- f (); ...` ~~>\n\ -        \         `tmp <-- f (); let (x, y) = tmp in ...`\n\ -        \        " ); -      ( "-unfold-monads", -        Arg.Set unfold_monadic_let_bindings, -        " Unfold the monadic let-bindings to matches" );        ( "-no-filter-useless-calls",          Arg.Clear filter_useless_monadic_calls,          " Do not filter the useless function calls, when possible" ); @@ -111,17 +97,23 @@ let () =          fail ()    in -  (* In the case of Coq, we forbid using field projectors (see the comments for -     [dont_use_field_projectors]). -     Also, we always decompose ADT values with matches (decomposing with -     let-bindings is not supported). -  *) +  (* Set some options depending on the backend *)    let _ =      match !backend with -    | FStar -> () +    | FStar -> +        (* F* supports monadic notations, but the encoding loses information *) +        unfold_monadic_let_bindings := true      | Coq -> +        (* In the case of Coq, we forbid using field projectors (see the comments for +           [dont_use_field_projectors]). +           Also, we always decompose ADT values with matches (decomposing with +           let-bindings is not supported). +        *)          dont_use_field_projectors := true; -        always_deconstruct_adts_with_matches := true +        always_deconstruct_adts_with_matches := true; +        (* Some patterns are not supported *) +        decompose_monadic_let_bindings := true; +        decompose_nested_let_patterns := true    in    (* Retrieve and check the filename *) diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index 7b261516..1cb35613 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -1081,11 +1081,16 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =        let body = Some { body with body = obj#visit_texpression () body.body } in        { def with body } -(** Decompose the monadic let-bindings. +(** Decompose let-bindings by introducing intermediate let-bindings. -    See the explanations in {!val:Config.decompose_monadic_let_bindings} +    This is a utility function: see {!decompose_monadic_let_bindings} and +    {!decompose_nested_let_patterns}. + +    [decompose_monadic]: always decompose a monadic let-binding +    [decompose_nested_pats]: decompose the nested patterns   *) -let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : +let decompose_let_bindings (decompose_monadic : bool) +    (decompose_nested_pats : bool) (_ctx : trans_ctx) (def : fun_decl) :      fun_decl =    match def.body with    | None -> def @@ -1093,41 +1098,131 @@ let decompose_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) :        (* Set up the var id generator *)        let cnt = get_body_min_var_counter body in        let _, fresh_id = VarId.mk_stateful_generator cnt in +      let mk_fresh (ty : ty) : typed_pattern * texpression = +        let vid = fresh_id () in +        let tmp : var = { id = vid; basename = None; ty } in +        let ltmp = mk_typed_pattern_from_var tmp None in +        let rtmp = mk_texpression_from_var tmp in +        (ltmp, rtmp) +      in + +      (* Utility function - returns the patterns to introduce, from the last to +         the first. + +         For instance, if it returns: +         {[ +           ([ +              ((x3, x4), x1); +              ((x1, x2), tmp) +            ], +            (x0, tmp)) +         ]} +         then we need to introduce: +         {[ +           let (x0, tmp) = original_term in +           let (x1, x2) = tmp in +           let (x3, x4) = x1 in +           ... +         }] +      *) +      let decompose_pat (lv : typed_pattern) : +          (typed_pattern * texpression) list * typed_pattern = +        let patterns = ref [] in + +        (* We decompose patterns *inside* other patterns. +           The boolean [inside] allows us to remember if we dived into a +           pattern already *) +        let visit_pats = +          object +            inherit [_] map_typed_pattern as super + +            method! visit_typed_pattern (inside : bool) (pat : typed_pattern) +                : typed_pattern = +              match pat.value with +              | PatConstant _ | PatVar _ | PatDummy -> pat +              | PatAdt _ -> +                  if not inside then super#visit_typed_pattern true pat +                  else +                    let ltmp, rtmp = mk_fresh pat.ty in +                    let pat = super#visit_typed_pattern false pat in +                    patterns := (pat, rtmp) :: !patterns; +                    ltmp +          end +        in + +        let inside = false in +        let lv = visit_pats#visit_typed_pattern inside lv in +        (!patterns, lv) +      in +        (* It is a very simple map *) -      let obj = +      let visit_lets =          object (self)            inherit [_] map_expression as super            method! visit_Let env monadic lv re next_e = -            if not monadic then super#visit_Let env monadic lv re next_e -            else -              (* If monadic, we need to check if the left-value is a variable: -               * - if yes, don't decompose -               * - if not, make the decomposition in two steps -               *) -              match lv.value with -              | PatVar _ -> -                  (* Variable: nothing to do *) -                  super#visit_Let env monadic lv re next_e -              | _ -> -                  (* Not a variable: decompose *) -                  (* Introduce a temporary variable to receive the value of the -                   * monadic binding *) -                  let vid = fresh_id () in -                  let tmp : var = { id = vid; basename = None; ty = lv.ty } in -                  let ltmp = mk_typed_pattern_from_var tmp None in -                  let rtmp = mk_texpression_from_var tmp in -                  (* Visit the next expression *) -                  let next_e = self#visit_texpression env next_e in -                  (* Create the let-bindings *) -                  (mk_let true ltmp re (mk_let false lv rtmp next_e)).e +            (* Decompose the monadic let-bindings *) +            let monadic, lv, re, next_e = +              if (not monadic) || not decompose_monadic then +                (monadic, lv, re, next_e) +              else +                (* If monadic, we need to check if the left-value is a variable: +                 * - if yes, don't decompose +                 * - if not, make the decomposition in two steps +                 *) +                match lv.value with +                | PatVar _ -> +                    (* Variable: nothing to do *) +                    (monadic, lv, re, next_e) +                | _ -> +                    (* Not a variable: decompose if required *) +                    (* Introduce a temporary variable to receive the value of the +                     * monadic binding *) +                    let ltmp, rtmp = mk_fresh lv.ty in +                    (* Visit the next expression *) +                    let next_e = self#visit_texpression env next_e in +                    (* Create the let-bindings *) +                    (true, ltmp, re, mk_let false lv rtmp next_e) +            in +            (* Decompose the nested let-patterns *) +            let lv, next_e = +              if not decompose_nested_pats then (lv, next_e) +              else +                let pats, first_pat = decompose_pat lv in +                let e = +                  List.fold_left +                    (fun next_e (lpat, rv) -> mk_let false lpat rv next_e) +                    next_e pats +                in +                (first_pat, e) +            in +            (* Continue *) +            super#visit_Let env monadic lv re next_e          end        in        (* Update the body *) -      let body = Some { body with body = obj#visit_texpression () body.body } in +      let body = +        Some { body with body = visit_lets#visit_texpression () body.body } +      in        (* Return *)        { def with body } +(** Decompose monadic let-bindings. + +    See the explanations in {!val:Config.decompose_monadic_let_bindings} + *) +let decompose_monadic_let_bindings (ctx : trans_ctx) (def : fun_decl) : fun_decl +    = +  decompose_let_bindings true false ctx def + +(** Decompose the nested let patterns. + +    See the explanations in {!val:Config.decompose_nested_let_patterns} + *) +let decompose_nested_let_patterns (ctx : trans_ctx) (def : fun_decl) : fun_decl +    = +  decompose_let_bindings false true ctx def +  (** Unfold the monadic let-bindings to explicit matches. *)  let unfold_monadic_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl =    match def.body with @@ -1253,7 +1348,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =           Ex.:           {[ -           type struct = { f0 : nat; f1 : nat } +           (* type struct = { f0 : nat; f1 : nat } *)             Mkstruct x.f0 x.f1 ~~> x           ]} @@ -1262,8 +1357,7 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =        log#ldebug          (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); -      (* Decompose the monadic let-bindings - F* specific -       * TODO: remove? *) +      (* Decompose the monadic let-bindings - used by Coq *)        let def =          if !Config.decompose_monadic_let_bindings then (            let def = decompose_monadic_let_bindings ctx def in @@ -1279,6 +1373,22 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl option =            def)        in +      (* Decompose nested let-patterns *) +      let def = +        if !Config.decompose_nested_let_patterns then ( +          let def = decompose_nested_let_patterns ctx def in +          log#ldebug +            (lazy +              ("decompose_nested_let_patterns:\n\n" ^ fun_decl_to_string ctx def +             ^ "\n")); +          def) +        else ( +          log#ldebug +            (lazy +              "ignoring decompose_nested_let_patterns due to the configuration\n"); +          def) +      in +        (* Unfold the monadic let-bindings *)        let def =          if !Config.unfold_monadic_let_bindings then ( diff --git a/tests/coq/misc/External__Funs.v b/tests/coq/misc/External__Funs.v index df35f7eb..021acd6e 100644 --- a/tests/coq/misc/External__Funs.v +++ b/tests/coq/misc/External__Funs.v @@ -84,7 +84,7 @@ Definition test_custom_swap_back    result (state * (u32 * u32))    :=    p <- custom_swap_back u32 x y st (1 %u32) st0; -  let (st1, (x0, y0)) := p in Return (st1, (x0, y0)) +  let (st1, p0) := p in let (x0, y0) := p0 in Return (st1, (x0, y0))    .  (** [external::test_swap_non_zero] *) @@ -93,7 +93,8 @@ Definition test_swap_non_zero_fwd    p <- swap_fwd u32 x (0 %u32) st;    let (st0, _) := p in    p0 <- swap_back u32 x (0 %u32) st st0; -  let (st1, (x0, _)) := p0 in if x0 s= 0 %u32 then Fail_ else Return (st1, x0) +  let (st1, p1) := p0 in +  let (x0, _) := p1 in if x0 s= 0 %u32 then Fail_ else Return (st1, x0)    .  End External__Funs .  | 
