diff options
author | Son Ho | 2023-12-23 00:41:25 +0100 |
---|---|---|
committer | Son Ho | 2023-12-23 00:41:25 +0100 |
commit | b6ef8ee33802e75409c3bd2b82e7b5ad22f1d053 (patch) | |
tree | 8c649905852a0fe17782985c6b88300e15288450 | |
parent | aa5e25785738a779ca5fd89191c85d6ab828c142 (diff) |
Improve the micro passes to eliminate pattern `let f := fun x => g x`
-rw-r--r-- | compiler/PureMicroPasses.ml | 45 | ||||
-rw-r--r-- | tests/coq/misc/Loops.v | 26 | ||||
-rw-r--r-- | tests/fstar/misc/Loops.Funs.fst | 25 | ||||
-rw-r--r-- | tests/lean/Loops.lean | 25 |
4 files changed, 67 insertions, 54 deletions
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml index e7e9d5e1..fa025d93 100644 --- a/compiler/PureMicroPasses.ml +++ b/compiler/PureMicroPasses.ml @@ -684,6 +684,15 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl = let y1 = x1 in ... ]} + + Simplify arrows: + {[ + let f := fun x => g x in + ... + ~~> + let f := g in + ... + ]} *) let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = let obj = @@ -739,6 +748,23 @@ let simplify_let_bindings (_ctx : trans_ctx) (def : fun_decl) : fun_decl = super#visit_expression env e.e | _ -> super#visit_Let env monadic lv rv next else super#visit_Let env monadic lv rv next + | Lambda _ -> + if not monadic then + (* Arrow case *) + let pats, e = destruct_lambdas rv in + let g, args = destruct_apps e in + if List.length pats = List.length args then + (* Check if the arguments are exactly the lambdas *) + let check_pat_arg ((pat, arg) : typed_pattern * texpression) = + match (pat.value, arg.e) with + | PatVar (v, _), Var vid -> v.id = vid + | _ -> false + in + if List.for_all check_pat_arg (List.combine pats args) then + self#visit_Let env monadic lv g next + else super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next + else super#visit_Let env monadic lv rv next | _ -> super#visit_Let env monadic lv rv next end in @@ -1934,9 +1960,10 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = (* Inline the useless variable reassignments *) let inline_named_vars = true in let inline_pure = true in - let def = - inline_useless_var_reassignments ctx inline_named_vars inline_pure def + let inline_useless_var_reassignments ctx = + inline_useless_var_reassignments ctx inline_named_vars inline_pure in + let def = inline_useless_var_reassignments ctx def in log#ldebug (lazy ("inline_useless_var_assignments:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); @@ -1982,6 +2009,20 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl = log#ldebug (lazy ("simplify_aggregates:\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + (* Simplify the let-bindings - some simplifications may have been unlocked by + the pass above (for instance, the lambda simplification) *) + let def = simplify_let_bindings ctx def in + log#ldebug + (lazy + ("simplify_let_bindings (pass 2):\n\n" ^ fun_decl_to_string ctx def ^ "\n")); + + (* Inline the useless vars again *) + let def = inline_useless_var_reassignments ctx def in + log#ldebug + (lazy + ("inline_useless_var_assignments (pass 2):\n\n" + ^ fun_decl_to_string ctx def ^ "\n")); + (* Decompose the monadic let-bindings - used by Coq *) let def = if !Config.decompose_monadic_let_bindings then ( diff --git a/tests/coq/misc/Loops.v b/tests/coq/misc/Loops.v index 313c2cfd..cc76f359 100644 --- a/tests/coq/misc/Loops.v +++ b/tests/coq/misc/Loops.v @@ -160,10 +160,7 @@ Definition list_nth_mut_loop (T : Type) (n : nat) (ls : List_t T) (i : u32) : result (T * (T -> result (List_t T))) := - p <- list_nth_mut_loop_loop T n ls i; - let (t, back) := p in - let back1 := fun (ret : T) => back ret in - Return (t, back1) + p <- list_nth_mut_loop_loop T n ls i; let (t, back) := p in Return (t, back) . (** [loops::list_nth_shared_loop]: loop 0: @@ -265,7 +262,7 @@ Definition id_mut (T : Type) (ls : List_t T) : result ((List_t T) * (List_t T -> result (List_t T))) := - let back := fun (ret : List_t T) => Return ret in Return (ls, back) + Return (ls, Return) . (** [loops::id_shared]: @@ -382,9 +379,7 @@ Definition list_nth_mut_loop_pair := t <- list_nth_mut_loop_pair_loop T n ls0 ls1 i; let (p, back_'a, back_'b) := t in - let back_'a1 := fun (ret : T) => back_'a ret in - let back_'b1 := fun (ret : T) => back_'b ret in - Return (p, back_'a1, back_'b1) + Return (p, back_'a, back_'b) . (** [loops::list_nth_shared_loop_pair]: loop 0: @@ -465,8 +460,7 @@ Definition list_nth_mut_loop_pair_merge := p <- list_nth_mut_loop_pair_merge_loop T n ls0 ls1 i; let (p1, back_'a) := p in - let back_'a1 := fun (ret : (T * T)) => back_'a ret in - Return (p1, back_'a1) + Return (p1, back_'a) . (** [loops::list_nth_shared_loop_pair_merge]: loop 0: @@ -542,8 +536,7 @@ Definition list_nth_mut_shared_loop_pair := p <- list_nth_mut_shared_loop_pair_loop T n ls0 ls1 i; let (p1, back_'a) := p in - let back_'a1 := fun (ret : T) => back_'a ret in - Return (p1, back_'a1) + Return (p1, back_'a) . (** [loops::list_nth_mut_shared_loop_pair_merge]: loop 0: @@ -585,8 +578,7 @@ Definition list_nth_mut_shared_loop_pair_merge := p <- list_nth_mut_shared_loop_pair_merge_loop T n ls0 ls1 i; let (p1, back_'a) := p in - let back_'a1 := fun (ret : T) => back_'a ret in - Return (p1, back_'a1) + Return (p1, back_'a) . (** [loops::list_nth_shared_mut_loop_pair]: loop 0: @@ -628,8 +620,7 @@ Definition list_nth_shared_mut_loop_pair := p <- list_nth_shared_mut_loop_pair_loop T n ls0 ls1 i; let (p1, back_'b) := p in - let back_'b1 := fun (ret : T) => back_'b ret in - Return (p1, back_'b1) + Return (p1, back_'b) . (** [loops::list_nth_shared_mut_loop_pair_merge]: loop 0: @@ -671,8 +662,7 @@ Definition list_nth_shared_mut_loop_pair_merge := p <- list_nth_shared_mut_loop_pair_merge_loop T n ls0 ls1 i; let (p1, back_'a) := p in - let back_'a1 := fun (ret : T) => back_'a ret in - Return (p1, back_'a1) + Return (p1, back_'a) . End Loops. diff --git a/tests/fstar/misc/Loops.Funs.fst b/tests/fstar/misc/Loops.Funs.fst index a047c170..88389300 100644 --- a/tests/fstar/misc/Loops.Funs.fst +++ b/tests/fstar/misc/Loops.Funs.fst @@ -121,9 +121,7 @@ let list_nth_mut_loop (t : Type0) (ls : list_t t) (i : u32) : result (t & (t -> result (list_t t))) = - let* (x, back) = list_nth_mut_loop_loop t ls i in - let back1 = fun ret -> back ret in - Return (x, back1) + let* (x, back) = list_nth_mut_loop_loop t ls i in Return (x, back) (** [loops::list_nth_shared_loop]: loop 0: Source: 'src/loops.rs', lines 91:0-101:1 *) @@ -201,7 +199,7 @@ let id_mut (t : Type0) (ls : list_t t) : result ((list_t t) & (list_t t -> result (list_t t))) = - let back = fun ret -> Return ret in Return (ls, back) + Return (ls, Return) (** [loops::id_shared]: Source: 'src/loops.rs', lines 139:0-139:45 *) @@ -296,9 +294,7 @@ let list_nth_mut_loop_pair result ((t & t) & (t -> result (list_t t)) & (t -> result (list_t t))) = let* (p, back_'a, back_'b) = list_nth_mut_loop_pair_loop t ls0 ls1 i in - let back_'a1 = fun ret -> back_'a ret in - let back_'b1 = fun ret -> back_'b ret in - Return (p, back_'a1, back_'b1) + Return (p, back_'a, back_'b) (** [loops::list_nth_shared_loop_pair]: loop 0: Source: 'src/loops.rs', lines 198:0-219:1 *) @@ -362,8 +358,7 @@ let list_nth_mut_loop_pair_merge result ((t & t) & ((t & t) -> result ((list_t t) & (list_t t)))) = let* (p, back_'a) = list_nth_mut_loop_pair_merge_loop t ls0 ls1 i in - let back_'a1 = fun ret -> back_'a ret in - Return (p, back_'a1) + Return (p, back_'a) (** [loops::list_nth_shared_loop_pair_merge]: loop 0: Source: 'src/loops.rs', lines 241:0-256:1 *) @@ -425,8 +420,7 @@ let list_nth_mut_shared_loop_pair result ((t & t) & (t -> result (list_t t))) = let* (p, back_'a) = list_nth_mut_shared_loop_pair_loop t ls0 ls1 i in - let back_'a1 = fun ret -> back_'a ret in - Return (p, back_'a1) + Return (p, back_'a) (** [loops::list_nth_mut_shared_loop_pair_merge]: loop 0: Source: 'src/loops.rs', lines 278:0-293:1 *) @@ -462,8 +456,7 @@ let list_nth_mut_shared_loop_pair_merge result ((t & t) & (t -> result (list_t t))) = let* (p, back_'a) = list_nth_mut_shared_loop_pair_merge_loop t ls0 ls1 i in - let back_'a1 = fun ret -> back_'a ret in - Return (p, back_'a1) + Return (p, back_'a) (** [loops::list_nth_shared_mut_loop_pair]: loop 0: Source: 'src/loops.rs', lines 297:0-312:1 *) @@ -498,8 +491,7 @@ let list_nth_shared_mut_loop_pair result ((t & t) & (t -> result (list_t t))) = let* (p, back_'b) = list_nth_shared_mut_loop_pair_loop t ls0 ls1 i in - let back_'b1 = fun ret -> back_'b ret in - Return (p, back_'b1) + Return (p, back_'b) (** [loops::list_nth_shared_mut_loop_pair_merge]: loop 0: Source: 'src/loops.rs', lines 316:0-331:1 *) @@ -535,6 +527,5 @@ let list_nth_shared_mut_loop_pair_merge result ((t & t) & (t -> result (list_t t))) = let* (p, back_'a) = list_nth_shared_mut_loop_pair_merge_loop t ls0 ls1 i in - let back_'a1 = fun ret -> back_'a ret in - Return (p, back_'a1) + Return (p, back_'a) diff --git a/tests/lean/Loops.lean b/tests/lean/Loops.lean index 805ecabc..fbb4616f 100644 --- a/tests/lean/Loops.lean +++ b/tests/lean/Loops.lean @@ -122,8 +122,7 @@ def list_nth_mut_loop (T : Type) (ls : List T) (i : U32) : Result (T × (T → Result (List T))) := do let (t, back) ← list_nth_mut_loop_loop T ls i - let back1 := fun ret => back ret - Result.ret (t, back1) + Result.ret (t, back) /- [loops::list_nth_shared_loop]: loop 0: Source: 'src/loops.rs', lines 91:0-101:1 -/ @@ -207,8 +206,7 @@ def id_mut (T : Type) (ls : List T) : Result ((List T) × (List T → Result (List T))) := - let back := fun ret => Result.ret ret - Result.ret (ls, back) + Result.ret (ls, Result.ret) /- [loops::id_shared]: Source: 'src/loops.rs', lines 139:0-139:45 -/ @@ -308,9 +306,7 @@ def list_nth_mut_loop_pair := do let (p, back_'a, back_'b) ← list_nth_mut_loop_pair_loop T ls0 ls1 i - let back_'a1 := fun ret => back_'a ret - let back_'b1 := fun ret => back_'b ret - Result.ret (p, back_'a1, back_'b1) + Result.ret (p, back_'a, back_'b) /- [loops::list_nth_shared_loop_pair]: loop 0: Source: 'src/loops.rs', lines 198:0-219:1 -/ @@ -372,8 +368,7 @@ def list_nth_mut_loop_pair_merge := do let (p, back_'a) ← list_nth_mut_loop_pair_merge_loop T ls0 ls1 i - let back_'a1 := fun ret => back_'a ret - Result.ret (p, back_'a1) + Result.ret (p, back_'a) /- [loops::list_nth_shared_loop_pair_merge]: loop 0: Source: 'src/loops.rs', lines 241:0-256:1 -/ @@ -432,8 +427,7 @@ def list_nth_mut_shared_loop_pair := do let (p, back_'a) ← list_nth_mut_shared_loop_pair_loop T ls0 ls1 i - let back_'a1 := fun ret => back_'a ret - Result.ret (p, back_'a1) + Result.ret (p, back_'a) /- [loops::list_nth_mut_shared_loop_pair_merge]: loop 0: Source: 'src/loops.rs', lines 278:0-293:1 -/ @@ -470,8 +464,7 @@ def list_nth_mut_shared_loop_pair_merge := do let (p, back_'a) ← list_nth_mut_shared_loop_pair_merge_loop T ls0 ls1 i - let back_'a1 := fun ret => back_'a ret - Result.ret (p, back_'a1) + Result.ret (p, back_'a) /- [loops::list_nth_shared_mut_loop_pair]: loop 0: Source: 'src/loops.rs', lines 297:0-312:1 -/ @@ -507,8 +500,7 @@ def list_nth_shared_mut_loop_pair := do let (p, back_'b) ← list_nth_shared_mut_loop_pair_loop T ls0 ls1 i - let back_'b1 := fun ret => back_'b ret - Result.ret (p, back_'b1) + Result.ret (p, back_'b) /- [loops::list_nth_shared_mut_loop_pair_merge]: loop 0: Source: 'src/loops.rs', lines 316:0-331:1 -/ @@ -545,7 +537,6 @@ def list_nth_shared_mut_loop_pair_merge := do let (p, back_'a) ← list_nth_shared_mut_loop_pair_merge_loop T ls0 ls1 i - let back_'a1 := fun ret => back_'a ret - Result.ret (p, back_'a1) + Result.ret (p, back_'a) end loops |