From 0cee49de70bec6d3ec2221b64a532d19ad71e5e0 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 14:51:53 +0200 Subject: Generalize a bit FixI and add an example --- backends/lean/Base/Diverge/Base.lean | 260 ++++++++++++++++++++--------------- 1 file changed, 151 insertions(+), 109 deletions(-) diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 2e60f6e8..630c0bf6 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -57,7 +57,7 @@ deriving Repr, BEq open Result -def bind (x: Result α) (f: α -> Result β) : Result β := +def bind {α : Type u} {β : Type v} (x: Result α) (f: α -> Result β) : Result β := match x with | ret v => f v | fail v => fail v @@ -84,7 +84,7 @@ instance : Pure Result where @[simp] theorem bind_tc_div (f : α → Result β) : (do let y ← div; f y) = div := by simp [Bind.bind, bind] -def div? {α: Type} (r: Result α): Bool := +def div? {α: Type u} (r: Result α): Bool := match r with | div => true | ret _ | fail _ => false @@ -96,8 +96,8 @@ namespace Fix open Primitives open Result - variable {a : Type} {b : a → Type} - variable {c d : Type} + variable {a : Type u} {b : a → Type v} + variable {c d : Type w} -- TODO: why do we have to make them both : Type w? /-! # The least fixed point definition and its properties -/ @@ -334,7 +334,8 @@ namespace Fix (h : c → ((x:a) → Result (b x)) → Result d) : is_mono_p g → (∀ y, is_mono_p (h y)) → - @is_mono_p a b d (λ k => do let y ← g k; h y k) := by + @is_mono_p a b d (λ k => @Bind.bind Result _ c d (g k) (fun y => h y k)) := by +-- @is_mono_p a b d (λ k => do let (y : c) ← g k; h y k) := by intro hg hh simp [is_mono_p] intro fg fh Hrgh @@ -494,49 +495,49 @@ namespace FixI open Primitives Fix -- The index type - variable {id : Type} + variable {id : Type u} -- The input/output types - variable {a b : id → Type} + variable {a : id → Type v} {b : (i:id) → a i → Type w} -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) - def karrow_rel (k1 k2 : (i:id) → a i → Result (b i)) : Prop := + def karrow_rel (k1 k2 : (i:id) → (x:a i) → Result (b i x)) : Prop := ∀ i x, result_rel (k1 i x) (k2 i x) - def kk_to_gen (k : (i:id) → a i → Result (b i)) : - (x: (i:id) × a i) → Result (b x.fst) := + def kk_to_gen (k : (i:id) → (x:a i) → Result (b i x)) : + (x: (i:id) × a i) → Result (b x.fst x.snd) := λ ⟨ i, x ⟩ => k i x - def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst)) : - (i:id) → a i → Result (b i) := + def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst x.snd)) : + (i:id) → (x:a i) → Result (b i x) := λ i x => k ⟨ i, x ⟩ - def k_to_gen (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : - ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst) := + def k_to_gen (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : + ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd) := λ kk => kk_to_gen (k (kk_of_gen kk)) - def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst)) : - ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i) := + def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd)) : + ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x) := λ kk => kk_of_gen (k (kk_to_gen kk)) - def e_to_gen (e : ((i:id) → a i → Result (b i)) → Result c) : - ((x: (i:id) × a i) → Result (b x.fst)) → Result c := + def e_to_gen (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) : + ((x: (i:id) × a i) → Result (b x.fst x.snd)) → Result c := λ k => e (kk_of_gen k) - def is_valid_p (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) - (e : ((i:id) → a i → Result (b i)) → Result c) : Prop := + def is_valid_p (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) : Prop := Fix.is_valid_p (k_to_gen k) (e_to_gen e) - def is_valid (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : Prop := + def is_valid (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : Prop := ∀ k i x, is_valid_p k (λ k => f k i x) noncomputable def fix - (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : - (i:id) → a i → Result (b i) := + (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : + (i:id) → (x:a i) → Result (b i x) := kk_of_gen (Fix.fix (k_to_gen f)) theorem is_valid_fix_fixed_eq - {{f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}} + {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} (Hvalid : is_valid f) : fix f = f (fix f) := by have Hvalid' : Fix.is_valid (k_to_gen f) := by @@ -553,57 +554,43 @@ namespace FixI /- Some utilities to define the mutually recursive functions -/ -- TODO: use more - @[simp] def kk_ty (id : Type) (a b : id → Type) := (i:id) → a i → Result (b i) - @[simp] def k_ty (id : Type) (a b : id → Type) := kk_ty id a b → kk_ty id a b + @[simp] def kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + (i:id) → (x:a i) → Result (b i x) + @[simp] def k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + kk_ty id a b → kk_ty id a b + + def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) + @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : + in_out_ty := + Sigma.mk in_ty out_ty -- Initially, we had left out the parameters id, a and b. -- However, by parameterizing Funs with those parameters, we can state -- and prove lemmas like Funs.is_valid_p_is_valid_p - inductive Funs (id : Type) (a b : id → Type) : - List (Type u) → List (Type u) → Type (u + 1) := - | Nil : Funs id a b [] [] - | Cons {ity oty : Type u} {itys otys : List (Type u)} - (f : kk_ty id a b → ity → Result oty) (tl : Funs id a b itys otys) : - Funs id a b (ity :: itys) (oty :: otys) - - theorem Funs.length_eq {itys otys : List (Type)} (fl : Funs id a b itys otys) : - otys.length = itys.length := - match fl with - | .Nil => by simp - | .Cons f tl => - have h:= Funs.length_eq tl - by simp [h] - - def fin_cast {n m : Nat} (h : m = n) (i : Fin n) : Fin m := - ⟨ i.val, by have h1:= i.isLt; simp_all ⟩ - - @[simp] def Funs.cast_fin {itys otys : List (Type)} - (fl : Funs id a b itys otys) (i : Fin itys.length) : Fin otys.length := - fin_cast (fl.length_eq) i - - def get_fun {itys otys : List (Type)} (fl : Funs id a b itys otys) : - (i : Fin itys.length) → kk_ty id a b → itys.get i → Result (otys.get (fl.cast_fin i)) := + inductive Funs (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) : + List in_out_ty.{v, w} → Type (max (u + 1) (max (v + 1) (w + 1))) := + | Nil : Funs id a b [] + | Cons {ity : Type v} {oty : ity → Type w} {tys : List in_out_ty} + (f : kk_ty id a b → (x:ity) → Result (oty x)) (tl : Funs id a b tys) : + Funs id a b (⟨ ity, oty ⟩ :: tys) + + def get_fun {tys : List in_out_ty} (fl : Funs id a b tys) : + (i : Fin tys.length) → kk_ty id a b → (x : (tys.get i).fst) → + Result ((tys.get i).snd x) := match fl with | .Nil => λ i => by have h:= i.isLt; simp at h - | @Funs.Cons id a b ity oty itys1 otys1 f tl => - λ i => - if h: i.val = 0 then - Eq.mp (by cases i; simp_all [List.get]) f - else - let j := i.val - 1 - have Hj: j < itys1.length := by - have Hi := i.isLt - simp at Hi - revert Hi - cases Heq: i.val <;> simp_all + | @Funs.Cons id a b ity oty tys1 f tl => + λ ⟨ i, iLt ⟩ => + match i with + | 0 => + Eq.mp (by simp [List.get]) f + | .succ j => + have jLt: j < tys1.length := by + simp at iLt + revert iLt simp_arith - let j: Fin itys1.length := ⟨ j, Hj ⟩ - Eq.mp - (by - cases Heq: i; rename_i val isLt; - cases Heq': j; rename_i val' isLt; - cases val <;> simp_all [List.get, fin_cast]) - (get_fun tl j) + let j: Fin tys1.length := ⟨ j, jLt ⟩ + Eq.mp (by simp) (get_fun tl j) -- TODO: move theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by @@ -683,19 +670,19 @@ namespace FixI /- Automating the proofs -/ @[simp] theorem is_valid_p_same - (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (x : Result c) : + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (x : Result c) : is_valid_p k (λ _ => x) := by simp [is_valid_p, k_to_gen, e_to_gen] @[simp] theorem is_valid_p_rec - (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (i : id) (x : a i) : + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (i : id) (x : a i) : is_valid_p k (λ k => k i x) := by simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] theorem is_valid_p_bind - {{k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}} - {{g : ((i:id) → a i → Result (b i)) → Result c}} - {{h : c → ((i:id) → a i → Result (b i)) → Result d}} + {{k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} + {{g : ((i:id) → (x:a i) → Result (b i x)) → Result c}} + {{h : c → ((i:id) → (x:a i) → Result (b i x)) → Result d}} (Hgvalid : is_valid_p k g) (Hhvalid : ∀ y, is_valid_p k (h y)) : is_valid_p k (λ k => do let y ← g k; h y k) := by @@ -705,7 +692,7 @@ namespace FixI def Funs.is_valid_p (k : k_ty id a b) - (fl : Funs id a b itys otys) : + (fl : Funs id a b tys) : Prop := match fl with | .Nil => True @@ -713,31 +700,29 @@ namespace FixI def Funs.is_valid_p_is_valid_p_aux {k : k_ty id a b} - {itys otys : List Type} - (Heq : List.length otys = List.length itys) - (fl : Funs id a b itys otys) (Hvalid : is_valid_p k fl) : - ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x) := by + {tys : List in_out_ty} + (fl : Funs id a b tys) (Hvalid : is_valid_p k fl) : + ∀ (i : Fin tys.length) (x : (tys.get i).fst), FixI.is_valid_p k (fun k => get_fun fl i k x) := by -- Prepare the induction - have ⟨ n, Hn ⟩ : { n : Nat // itys.length = n } := ⟨ itys.length, by rfl ⟩ - revert itys otys Heq fl Hvalid + have ⟨ n, Hn ⟩ : { n : Nat // tys.length = n } := ⟨ tys.length, by rfl ⟩ + revert tys fl Hvalid induction n -- case zero => - intro itys otys Heq fl Hvalid Hlen; - have Heq: itys = [] := by cases itys <;> simp_all - have Heq: otys = [] := by cases otys <;> simp_all + intro tys fl Hvalid Hlen; + have Heq: tys = [] := by cases tys <;> simp_all intro i x simp_all have Hi := i.isLt simp_all case succ n Hn => - intro itys otys Heq fl Hvalid Hlen i x; - cases fl <;> simp at Hlen i x Heq Hvalid - rename_i ity oty itys otys f fl + intro tys fl Hvalid Hlen i x; + cases fl <;> simp at Hlen i x Hvalid + rename_i ity oty tys f fl have ⟨ Hvf, Hvalid ⟩ := Hvalid have Hvf1: is_valid_p k fl := by simp [Hvalid, Funs.is_valid_p] - have Hn := @Hn itys otys (by simp[*]) fl Hvf1 (by simp [*]) + have Hn := @Hn tys fl Hvf1 (by simp [*]) -- Case disjunction on i match i with | ⟨ 0, _ ⟩ => @@ -747,19 +732,20 @@ namespace FixI | ⟨ .succ j, HiLt ⟩ => simp_arith at HiLt simp at x - let j : Fin (List.length itys) := ⟨ j, by simp_arith [HiLt] ⟩ + let j : Fin (List.length tys) := ⟨ j, by simp_arith [HiLt] ⟩ have Hn := Hn j x apply Hn def Funs.is_valid_p_is_valid_p - (itys otys : List (Type)) (Heq: otys.length = itys.length := by decide) - (k : k_ty (Fin (List.length itys)) (List.get itys) fun i => List.get otys (fin_cast Heq i)) - (fl : Funs (Fin itys.length) itys.get (λ i => otys.get (fin_cast Heq i)) itys otys) : + (tys : List in_out_ty) + (k : k_ty (Fin (List.length tys)) (λ i => (tys.get i).fst) (fun i x => (List.get tys i).snd x)) + (fl : Funs (Fin tys.length) (λ i => (tys.get i).fst) (λ i x => (tys.get i).snd x) tys) : fl.is_valid_p k → - ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x) + ∀ (i : Fin tys.length) (x : (tys.get i).fst), + FixI.is_valid_p k (fun k => get_fun fl i k x) := by intro Hvalid - apply is_valid_p_is_valid_p_aux <;> simp [*] + apply is_valid_p_is_valid_p_aux; simp [*] end FixI @@ -960,27 +946,21 @@ namespace Ex4 /- Mutually recursive functions - 2nd encoding -/ open Primitives FixI - attribute [local simp] List.get - /- We make the input type and output types dependent on a parameter -/ - @[simp] def input_ty (i : Fin 2) : Type := - [Int, Int].get i - - @[simp] def output_ty (i : Fin 2) : Type := - [Bool, Bool].get i - - /- The continuation -/ - variable (k : (i : Fin 2) → input_ty i → Result (output_ty i)) + @[simp] def tys : List in_out_ty := [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)] + @[simp] def input_ty (i : Fin 2) : Type := (tys.get i).fst + @[simp] def output_ty (i : Fin 2) (x : input_ty i) : Type := + (tys.get i).snd x /- The bodies are more natural -/ - def is_even_body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i : Int) : Result Bool := + def is_even_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := if i = 0 then .ret true else do let b ← k 1 (i - 1) .ret b - def is_odd_body (i : Int) : Result Bool := + def is_odd_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := if i = 0 then .ret false else do @@ -988,18 +968,19 @@ namespace Ex4 .ret b @[simp] def bodies : - Funs (Fin 2) input_ty output_ty [Int, Int] [Bool, Bool] := + Funs (Fin 2) input_ty output_ty + [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)] := Funs.Cons (is_even_body) (Funs.Cons (is_odd_body) Funs.Nil) - def body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i: Fin 2) : - input_ty i → Result (output_ty i) := get_fun bodies i k + def body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 2) : + (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k theorem body_is_valid : is_valid body := by -- Split the proof into proofs of validity of the individual bodies rw [is_valid] simp only [body] intro k - apply (Funs.is_valid_p_is_valid_p [Int, Int] [Bool, Bool]) + apply (Funs.is_valid_p_is_valid_p tys) simp [Funs.is_valid_p] (repeat (apply And.intro)) <;> intro x <;> simp at x <;> simp only [is_even_body, is_odd_body] @@ -1106,3 +1087,64 @@ namespace Ex5 conv => lhs; rw [Heq]; simp; rw [id_body] end Ex5 + +namespace Ex6 + /- `list_nth` again, but this time we use FixI -/ + open Primitives FixI + + @[simp] def tys.{u} : List in_out_ty := + [mk_in_out_ty ((a:Type u) × (List a × Int)) (λ ⟨ a, _ ⟩ => a)] + + @[simp] def input_ty (i : Fin 1) := (tys.get i).fst + @[simp] def output_ty (i : Fin 1) (x : input_ty i) := + (tys.get i).snd x + + def list_nth_body.{u} (k : (i:Fin 1) → (x:input_ty i) → Result (output_ty i x)) + (x : (a : Type u) × List a × Int) : Result x.fst := + let ⟨ a, ls, i ⟩ := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else k 0 ⟨ a, tl, i - 1 ⟩ + + @[simp] def bodies : + Funs (Fin 1) input_ty output_ty tys := + Funs.Cons list_nth_body Funs.Nil + + def body (k : (i : Fin 1) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 1) : + (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k + + theorem list_nth_body_is_valid: is_valid body := by + -- Split the proof into proofs of validity of the individual bodies + rw [is_valid] + simp only [body] + intro k + apply (Funs.is_valid_p_is_valid_p tys) + simp [Funs.is_valid_p] + (repeat (apply And.intro)); intro x; simp at x + simp only [list_nth_body] + -- Prove the validity of the individual bodies + intro k x + simp [list_nth_body] + split <;> simp + split <;> simp + + noncomputable + def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := + fix body 0 ⟨ a, ls , i ⟩ + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := by + have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + simp [list_nth] + conv => lhs; rw [Heq] + +end Ex6 -- cgit v1.2.3