diff options
author | Son Ho | 2024-04-11 20:31:16 +0200 |
---|---|---|
committer | Son Ho | 2024-04-11 20:31:16 +0200 |
commit | b6421bc01df278f625a8c95b4ea36ad2e4355718 (patch) | |
tree | 6246ef2b038560e3deae41e4fa700f14390cd14f /backends/lean/Base/Diverge | |
parent | 44065f447dc3a2f4b1441b97b9687d1c1b85afbf (diff) | |
parent | 2f8aa9b47acb5c98aed91c29b04f71099452e781 (diff) |
Merge branch 'son/clean' into checked-ops
Diffstat (limited to 'backends/lean/Base/Diverge')
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 106 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 51 |
2 files changed, 77 insertions, 80 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index e40432bd..0f20125f 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -1,7 +1,6 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic -import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith import Base.Primitives.Base import Base.Arith.Base @@ -39,8 +38,7 @@ namespace Lemmas case zero => simp_all intro m h1 h2 - have h: n = m := by - linarith + have h: n = m := by omega unfold for_all_fin_aux; simp_all simp_all -- There is no i s.t. m ≤ i @@ -169,7 +167,7 @@ namespace Fix match x1 with | div => True | fail _ => x2 = x1 - | ret _ => x2 = x1 -- TODO: generalize + | ok _ => x2 = x1 -- TODO: generalize -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) def karrow_rel (k1 k2 : (x:a) → Result (b x)) : Prop := @@ -388,7 +386,7 @@ namespace Fix have Hgeq := Hgmono Hffmono simp [result_rel] at Hgeq cases Heq: g (fix_fuel n k) <;> rename_i y <;> simp_all - -- Remains the .ret case + -- Remains the .ok case -- Use Hdiv to prove that: ∀ n, h y (fix_fuel n f) = div -- We do this in two steps: first we prove it for m ≥ n have Hhdiv: ∀ m, h y (fix_fuel m k) = .div := by @@ -509,7 +507,7 @@ namespace FixI specific case. Remark: the index designates the function in the mutually recursive group - (it should be a finite type). We make the return type depend on the input + (it should be a finite type). We make the output type depend on the input type because we group the type parameters in the input type. -/ open Primitives Fix @@ -945,7 +943,7 @@ namespace Ex1 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else k (tl, i - 1) theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by @@ -962,7 +960,7 @@ namespace Ex1 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else list_nth tl (i - 1) := by have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) @@ -983,11 +981,11 @@ namespace Ex2 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else do let y ← k (tl, i - 1) - .ret y + .ok y theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by intro k x @@ -1004,11 +1002,11 @@ namespace Ex2 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else do let y ← list_nth tl (i - 1) - .ret y) + .ok y) := by have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) simp [list_nth] @@ -1025,9 +1023,9 @@ namespace Ex3 - inputs: the sum allows to select the function to call in the recursive calls (and the functions may not have the same input types) - outputs: this case is degenerate because `even` and `odd` have the same - return type `Bool`, but generally speaking we need a sum type because + output type `Bool`, but generally speaking we need a sum type because the functions in the mutually recursive group may have different - return types. + output types. -/ variable (k : (Int ⊕ Int) → Result (Bool ⊕ Bool)) @@ -1036,7 +1034,7 @@ namespace Ex3 | .inl i => -- Body of `is_even` if i = 0 - then .ret (.inl true) -- We use .inl because this is `is_even` + then .ok (.inl true) -- We use .inl because this is `is_even` else do let b ← @@ -1046,13 +1044,13 @@ namespace Ex3 let r ← k (.inr (i- 1)) match r with | .inl _ => .fail .panic -- Invalid output - | .inr b => .ret b - -- Wrap the return value - .ret (.inl b) + | .inr b => .ok b + -- Wrap the output value + .ok (.inl b) | .inr i => -- Body of `is_odd` if i = 0 - then .ret (.inr false) -- We use .inr because this is `is_odd` + then .ok (.inr false) -- We use .inr because this is `is_odd` else do let b ← @@ -1061,10 +1059,10 @@ namespace Ex3 -- extract the output value let r ← k (.inl (i- 1)) match r with - | .inl b => .ret b + | .inl b => .ok b | .inr _ => .fail .panic -- Invalid output - -- Wrap the return value - .ret (.inr b) + -- Wrap the output value + .ok (.inr b) theorem is_even_is_odd_body_is_valid: ∀ k x, is_valid_p k (λ k => is_even_is_odd_body k x) := by @@ -1080,7 +1078,7 @@ namespace Ex3 do let r ← fix is_even_is_odd_body (.inl i) match r with - | .inl b => .ret b + | .inl b => .ok b | .inr _ => .fail .panic def is_odd (i : Int): Result Bool := @@ -1088,11 +1086,11 @@ namespace Ex3 let r ← fix is_even_is_odd_body (.inr i) match r with | .inl _ => .fail .panic - | .inr b => .ret b + | .inr b => .ok b -- The unfolding equation for `is_even` - diverges if `i < 0` theorem is_even_eq (i : Int) : - is_even i = (if i = 0 then .ret true else is_odd (i - 1)) + is_even i = (if i = 0 then .ok true else is_odd (i - 1)) := by have Heq := is_valid_fix_fixed_eq is_even_is_odd_body_is_valid simp [is_even, is_odd] @@ -1110,7 +1108,7 @@ namespace Ex3 -- The unfolding equation for `is_odd` - diverges if `i < 0` theorem is_odd_eq (i : Int) : - is_odd i = (if i = 0 then .ret false else is_even (i - 1)) + is_odd i = (if i = 0 then .ok false else is_even (i - 1)) := by have Heq := is_valid_fix_fixed_eq is_even_is_odd_body_is_valid simp [is_even, is_odd] @@ -1136,17 +1134,17 @@ namespace Ex4 /- The bodies are more natural -/ def is_even_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := if i = 0 - then .ret true + then .ok true else do let b ← k 1 (i - 1) - .ret b + .ok b def is_odd_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := if i = 0 - then .ret false + then .ok false else do let b ← k 0 (i - 1) - .ret b + .ok b @[simp] def bodies : Funs (Fin 2) input_ty output_ty @@ -1179,19 +1177,19 @@ namespace Ex4 theorem is_even_eq (i : Int) : is_even i = (if i = 0 - then .ret true + then .ok true else do let b ← is_odd (i - 1) - .ret b) := by + .ok b) := by simp [is_even, is_odd]; conv => lhs; rw [body_fix_eq] theorem is_odd_eq (i : Int) : is_odd i = (if i = 0 - then .ret false + then .ok false else do let b ← is_even (i - 1) - .ret b) := by + .ok b) := by simp [is_even, is_odd]; conv => lhs; rw [body_fix_eq] end Ex4 @@ -1205,12 +1203,12 @@ namespace Ex5 /- An auxiliary function, which doesn't require the fixed-point -/ def map (f : a → Result b) (ls : List a) : Result (List b) := match ls with - | [] => .ret [] + | [] => .ok [] | hd :: tl => do let hd ← f hd let tl ← map f tl - .ret (hd :: tl) + .ok (hd :: tl) /- The validity theorem for `map`, generic in `f` -/ theorem map_is_valid @@ -1231,11 +1229,11 @@ namespace Ex5 def id_body (k : Tree a → Result (Tree a)) (t : Tree a) : Result (Tree a) := match t with - | .leaf x => .ret (.leaf x) + | .leaf x => .ok (.leaf x) | .node tl => do let tl ← map k tl - .ret (.node tl) + .ok (.node tl) theorem id_body_is_valid : ∀ k x, is_valid_p k (λ k => @id_body a k x) := by @@ -1256,11 +1254,11 @@ namespace Ex5 theorem id_eq (t : Tree a) : (id t = match t with - | .leaf x => .ret (.leaf x) + | .leaf x => .ok (.leaf x) | .node tl => do let tl ← map id tl - .ret (.node tl)) + .ok (.node tl)) := by have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a) simp [id] @@ -1285,7 +1283,7 @@ namespace Ex6 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else k 0 ⟨ a, tl, i - 1 ⟩ @[simp] def bodies : @@ -1316,7 +1314,7 @@ namespace Ex6 match ls with | [] => is_valid_p_same k (.fail .panic) | hd :: tl => - is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 ⟨a, tl, i-1⟩) + is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ok hd)) (is_valid_p_rec k 0 ⟨a, tl, i-1⟩) theorem body_is_valid' : is_valid body := fun k => @@ -1332,7 +1330,7 @@ namespace Ex6 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else list_nth tl (i - 1) := by have Heq := is_valid_fix_fixed_eq body_is_valid @@ -1347,7 +1345,7 @@ namespace Ex6 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else list_nth tl (i - 1) := -- Use the fixed-point equation @@ -1378,7 +1376,7 @@ namespace Ex7 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else k 0 a ⟨ tl, i - 1 ⟩ @[simp] def bodies : @@ -1409,7 +1407,7 @@ namespace Ex7 match ls with | [] => is_valid_p_same k (.fail .panic) | hd :: tl => - is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 a ⟨tl, i-1⟩) + is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ok hd)) (is_valid_p_rec k 0 a ⟨tl, i-1⟩) theorem body_is_valid' : is_valid body := fun k => @@ -1425,7 +1423,7 @@ namespace Ex7 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else list_nth tl (i - 1) := by have Heq := is_valid_fix_fixed_eq body_is_valid @@ -1440,7 +1438,7 @@ namespace Ex7 match ls with | [] => .fail .panic | hd :: tl => - if i = 0 then .ret hd + if i = 0 then .ok hd else list_nth tl (i - 1) := -- Use the fixed-point equation @@ -1466,12 +1464,12 @@ namespace Ex8 /- An auxiliary function, which doesn't require the fixed-point -/ def map {a : Type y} {b : Type z} (f : a → Result b) (ls : List a) : Result (List b) := match ls with - | [] => .ret [] + | [] => .ok [] | hd :: tl => do let hd ← f hd let tl ← map f tl - .ret (hd :: tl) + .ok (hd :: tl) /- The validity theorems for `map`, generic in `f` -/ @@ -1520,11 +1518,11 @@ namespace Ex9 def id_body.{u} (k : (i:Fin 1) → (t:ty i) → input_ty i t → Result (output_ty i t)) (a : Type u) (t : Tree a) : Result (Tree a) := match t with - | .leaf x => .ret (.leaf x) + | .leaf x => .ok (.leaf x) | .node tl => do let tl ← map (k 0 a) tl - .ret (.node tl) + .ok (.node tl) @[simp] def bodies : Funs (Fin 1) ty input_ty output_ty tys := @@ -1558,11 +1556,11 @@ namespace Ex9 theorem id_eq' {a : Type u} (t : Tree a) : id t = (match t with - | .leaf x => .ret (.leaf x) + | .leaf x => .ok (.leaf x) | .node tl => do let tl ← map id tl - .ret (.node tl)) + .ok (.node tl)) := -- The unfolding equation have Heq := is_valid_fix_fixed_eq body_is_valid.{u} diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index f30148dc..5db8ffed 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -1,7 +1,6 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic -import Mathlib.Tactic.RunCmd import Base.Utils import Base.Diverge.Base import Base.Diverge.ElabBase @@ -36,7 +35,7 @@ def mkProd (x y : Expr) : MetaM Expr := def mkInOutTy (x y z : Expr) : MetaM Expr := do mkAppM ``FixII.mk_in_out_ty #[x, y, z] --- Return the `a` in `Return a` +-- Return the `a` in `Result a` def getResultTy (ty : Expr) : MetaM Expr := ty.withApp fun f args => do if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then @@ -412,7 +411,7 @@ structure TypeInfo where For `list_nth`: `λ a => List a × Int` -/ in_ty : Expr - /- The output type, without the `Return`. This is a function taking + /- The output type, without the `Result`. This is a function taking as input a value of type `params_ty`. For `list_nth`: `λ a => a` @@ -1480,9 +1479,9 @@ namespace Tests divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with | [] => .fail .panic - | x :: ls => - if i = 0 then return x - else return (← list_nth ls (i - 1)) + | x :: ls => do + if i = 0 then pure x + else pure (← list_nth ls (i - 1)) --set_option trace.Diverge false @@ -1491,7 +1490,7 @@ namespace Tests example {a: Type} (ls : List a) : ∀ (i : Int), 0 ≤ i → i < ls.length → - ∃ x, list_nth ls i = .ret x := by + ∃ x, list_nth ls i = .ok x := by induction ls . intro i hpos h; simp at h; linarith . rename_i hd tl ih @@ -1539,7 +1538,7 @@ namespace Tests if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 divergent def bar (i : Int) : Result Nat := - if i > 20 then foo (i / 20) else .ret 42 + if i > 20 then foo (i / 20) else .ok 42 end #check foo.unfold @@ -1558,8 +1557,8 @@ namespace Tests divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool := let i0 := ls.length if i < i0 - then Result.ret True - else Result.ret False + then Result.ok True + else Result.ok False #check iInBounds.unfold @@ -1567,8 +1566,8 @@ namespace Tests {a : Type} (ls : List a) : Result Bool := let ls1 := ls match ls1 with - | [] => Result.ret False - | _ :: _ => Result.ret True + | [] => Result.ok False + | _ :: _ => Result.ok True #check isCons.unfold @@ -1585,7 +1584,7 @@ namespace Tests divergent def infinite_loop : Result Unit := do let _ ← infinite_loop - Result.ret () + Result.ok () #check infinite_loop.unfold @@ -1605,51 +1604,51 @@ namespace Tests divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := match t with - | .leaf x => .ret (.leaf x) + | .leaf x => .ok (.leaf x) | .node tl => do let tl ← map id tl - .ret (.node tl) + .ok (.node tl) #check id.unfold divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := match t with - | .leaf x => .ret (.leaf x) + | .leaf x => .ok (.leaf x) | .node tl => do let tl ← map (fun x => id1 x) tl - .ret (.node tl) + .ok (.node tl) #check id1.unfold divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := match t with - | .leaf x => .ret (.leaf x) + | .leaf x => .ok (.leaf x) | .node tl => do let tl ← map (fun x => do let _ ← id2 x; id2 x) tl - .ret (.node tl) + .ok (.node tl) #check id2.unfold divergent def incr (t : Tree Nat) : Result (Tree Nat) := match t with - | .leaf x => .ret (.leaf (x + 1)) + | .leaf x => .ok (.leaf (x + 1)) | .node tl => do let tl ← map incr tl - .ret (.node tl) + .ok (.node tl) -- We handle this by inlining the let-binding divergent def id3 (t : Tree Nat) : Result (Tree Nat) := match t with - | .leaf x => .ret (.leaf (x + 1)) + | .leaf x => .ok (.leaf (x + 1)) | .node tl => do let f := id3 let tl ← map f tl - .ret (.node tl) + .ok (.node tl) #check id3.unfold @@ -1659,12 +1658,12 @@ namespace Tests -- be parameterized by something). divergent def id4 (t : Tree Nat) : Result (Tree Nat) := match t with - | .leaf x => .ret (.leaf (x + 1)) + | .leaf x => .ok (.leaf (x + 1)) | .node tl => do - let f ← .ret id4 + let f ← .ok id4 let tl ← map f tl - .ret (.node tl) + .ok (.node tl) #check id4.unfold -/ |