summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
authorSon Ho2023-06-29 14:51:53 +0200
committerSon Ho2023-06-29 14:51:53 +0200
commit0cee49de70bec6d3ec2221b64a532d19ad71e5e0 (patch)
tree0868ecebb7419dbcca2d282e9d28249a83875773 /backends/lean/Base
parenta6de153f3bfda7feb27d16fcdf2131d37f99c7a3 (diff)
Generalize a bit FixI and add an example
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Base.lean260
1 files changed, 151 insertions, 109 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 2e60f6e8..630c0bf6 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -57,7 +57,7 @@ deriving Repr, BEq
open Result
-def bind (x: Result α) (f: α -> Result β) : Result β :=
+def bind {α : Type u} {β : Type v} (x: Result α) (f: α -> Result β) : Result β :=
match x with
| ret v => f v
| fail v => fail v
@@ -84,7 +84,7 @@ instance : Pure Result where
@[simp] theorem bind_tc_div (f : α → Result β) :
(do let y ← div; f y) = div := by simp [Bind.bind, bind]
-def div? {α: Type} (r: Result α): Bool :=
+def div? {α: Type u} (r: Result α): Bool :=
match r with
| div => true
| ret _ | fail _ => false
@@ -96,8 +96,8 @@ namespace Fix
open Primitives
open Result
- variable {a : Type} {b : a → Type}
- variable {c d : Type}
+ variable {a : Type u} {b : a → Type v}
+ variable {c d : Type w} -- TODO: why do we have to make them both : Type w?
/-! # The least fixed point definition and its properties -/
@@ -334,7 +334,8 @@ namespace Fix
(h : c → ((x:a) → Result (b x)) → Result d) :
is_mono_p g →
(∀ y, is_mono_p (h y)) →
- @is_mono_p a b d (λ k => do let y ← g k; h y k) := by
+ @is_mono_p a b d (λ k => @Bind.bind Result _ c d (g k) (fun y => h y k)) := by
+-- @is_mono_p a b d (λ k => do let (y : c) ← g k; h y k) := by
intro hg hh
simp [is_mono_p]
intro fg fh Hrgh
@@ -494,49 +495,49 @@ namespace FixI
open Primitives Fix
-- The index type
- variable {id : Type}
+ variable {id : Type u}
-- The input/output types
- variable {a b : id → Type}
+ variable {a : id → Type v} {b : (i:id) → a i → Type w}
-- Monotonicity relation over monadic arrows (i.e., Kleisli arrows)
- def karrow_rel (k1 k2 : (i:id) → a i → Result (b i)) : Prop :=
+ def karrow_rel (k1 k2 : (i:id) → (x:a i) → Result (b i x)) : Prop :=
∀ i x, result_rel (k1 i x) (k2 i x)
- def kk_to_gen (k : (i:id) → a i → Result (b i)) :
- (x: (i:id) × a i) → Result (b x.fst) :=
+ def kk_to_gen (k : (i:id) → (x:a i) → Result (b i x)) :
+ (x: (i:id) × a i) → Result (b x.fst x.snd) :=
λ ⟨ i, x ⟩ => k i x
- def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst)) :
- (i:id) → a i → Result (b i) :=
+ def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst x.snd)) :
+ (i:id) → (x:a i) → Result (b i x) :=
λ i x => k ⟨ i, x ⟩
- def k_to_gen (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) :
- ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst) :=
+ def k_to_gen (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) :
+ ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd) :=
λ kk => kk_to_gen (k (kk_of_gen kk))
- def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst)) :
- ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i) :=
+ def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd)) :
+ ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x) :=
λ kk => kk_of_gen (k (kk_to_gen kk))
- def e_to_gen (e : ((i:id) → a i → Result (b i)) → Result c) :
- ((x: (i:id) × a i) → Result (b x.fst)) → Result c :=
+ def e_to_gen (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) :
+ ((x: (i:id) × a i) → Result (b x.fst x.snd)) → Result c :=
λ k => e (kk_of_gen k)
- def is_valid_p (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i))
- (e : ((i:id) → a i → Result (b i)) → Result c) : Prop :=
+ def is_valid_p (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x))
+ (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) : Prop :=
Fix.is_valid_p (k_to_gen k) (e_to_gen e)
- def is_valid (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : Prop :=
+ def is_valid (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : Prop :=
∀ k i x, is_valid_p k (λ k => f k i x)
noncomputable def fix
- (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) :
- (i:id) → a i → Result (b i) :=
+ (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) :
+ (i:id) → (x:a i) → Result (b i x) :=
kk_of_gen (Fix.fix (k_to_gen f))
theorem is_valid_fix_fixed_eq
- {{f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}}
+ {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}}
(Hvalid : is_valid f) :
fix f = f (fix f) := by
have Hvalid' : Fix.is_valid (k_to_gen f) := by
@@ -553,57 +554,43 @@ namespace FixI
/- Some utilities to define the mutually recursive functions -/
-- TODO: use more
- @[simp] def kk_ty (id : Type) (a b : id → Type) := (i:id) → a i → Result (b i)
- @[simp] def k_ty (id : Type) (a b : id → Type) := kk_ty id a b → kk_ty id a b
+ @[simp] def kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=
+ (i:id) → (x:a i) → Result (b i x)
+ @[simp] def k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=
+ kk_ty id a b → kk_ty id a b
+
+ def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v)
+ @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) :
+ in_out_ty :=
+ Sigma.mk in_ty out_ty
-- Initially, we had left out the parameters id, a and b.
-- However, by parameterizing Funs with those parameters, we can state
-- and prove lemmas like Funs.is_valid_p_is_valid_p
- inductive Funs (id : Type) (a b : id → Type) :
- List (Type u) → List (Type u) → Type (u + 1) :=
- | Nil : Funs id a b [] []
- | Cons {ity oty : Type u} {itys otys : List (Type u)}
- (f : kk_ty id a b → ity → Result oty) (tl : Funs id a b itys otys) :
- Funs id a b (ity :: itys) (oty :: otys)
-
- theorem Funs.length_eq {itys otys : List (Type)} (fl : Funs id a b itys otys) :
- otys.length = itys.length :=
- match fl with
- | .Nil => by simp
- | .Cons f tl =>
- have h:= Funs.length_eq tl
- by simp [h]
-
- def fin_cast {n m : Nat} (h : m = n) (i : Fin n) : Fin m :=
- ⟨ i.val, by have h1:= i.isLt; simp_all ⟩
-
- @[simp] def Funs.cast_fin {itys otys : List (Type)}
- (fl : Funs id a b itys otys) (i : Fin itys.length) : Fin otys.length :=
- fin_cast (fl.length_eq) i
-
- def get_fun {itys otys : List (Type)} (fl : Funs id a b itys otys) :
- (i : Fin itys.length) → kk_ty id a b → itys.get i → Result (otys.get (fl.cast_fin i)) :=
+ inductive Funs (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :
+ List in_out_ty.{v, w} → Type (max (u + 1) (max (v + 1) (w + 1))) :=
+ | Nil : Funs id a b []
+ | Cons {ity : Type v} {oty : ity → Type w} {tys : List in_out_ty}
+ (f : kk_ty id a b → (x:ity) → Result (oty x)) (tl : Funs id a b tys) :
+ Funs id a b (⟨ ity, oty ⟩ :: tys)
+
+ def get_fun {tys : List in_out_ty} (fl : Funs id a b tys) :
+ (i : Fin tys.length) → kk_ty id a b → (x : (tys.get i).fst) →
+ Result ((tys.get i).snd x) :=
match fl with
| .Nil => λ i => by have h:= i.isLt; simp at h
- | @Funs.Cons id a b ity oty itys1 otys1 f tl =>
- λ i =>
- if h: i.val = 0 then
- Eq.mp (by cases i; simp_all [List.get]) f
- else
- let j := i.val - 1
- have Hj: j < itys1.length := by
- have Hi := i.isLt
- simp at Hi
- revert Hi
- cases Heq: i.val <;> simp_all
+ | @Funs.Cons id a b ity oty tys1 f tl =>
+ λ ⟨ i, iLt ⟩ =>
+ match i with
+ | 0 =>
+ Eq.mp (by simp [List.get]) f
+ | .succ j =>
+ have jLt: j < tys1.length := by
+ simp at iLt
+ revert iLt
simp_arith
- let j: Fin itys1.length := ⟨ j, Hj ⟩
- Eq.mp
- (by
- cases Heq: i; rename_i val isLt;
- cases Heq': j; rename_i val' isLt;
- cases val <;> simp_all [List.get, fin_cast])
- (get_fun tl j)
+ let j: Fin tys1.length := ⟨ j, jLt ⟩
+ Eq.mp (by simp) (get_fun tl j)
-- TODO: move
theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by
@@ -683,19 +670,19 @@ namespace FixI
/- Automating the proofs -/
@[simp] theorem is_valid_p_same
- (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (x : Result c) :
+ (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (x : Result c) :
is_valid_p k (λ _ => x) := by
simp [is_valid_p, k_to_gen, e_to_gen]
@[simp] theorem is_valid_p_rec
- (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (i : id) (x : a i) :
+ (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (i : id) (x : a i) :
is_valid_p k (λ k => k i x) := by
simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen]
theorem is_valid_p_bind
- {{k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}}
- {{g : ((i:id) → a i → Result (b i)) → Result c}}
- {{h : c → ((i:id) → a i → Result (b i)) → Result d}}
+ {{k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}}
+ {{g : ((i:id) → (x:a i) → Result (b i x)) → Result c}}
+ {{h : c → ((i:id) → (x:a i) → Result (b i x)) → Result d}}
(Hgvalid : is_valid_p k g)
(Hhvalid : ∀ y, is_valid_p k (h y)) :
is_valid_p k (λ k => do let y ← g k; h y k) := by
@@ -705,7 +692,7 @@ namespace FixI
def Funs.is_valid_p
(k : k_ty id a b)
- (fl : Funs id a b itys otys) :
+ (fl : Funs id a b tys) :
Prop :=
match fl with
| .Nil => True
@@ -713,31 +700,29 @@ namespace FixI
def Funs.is_valid_p_is_valid_p_aux
{k : k_ty id a b}
- {itys otys : List Type}
- (Heq : List.length otys = List.length itys)
- (fl : Funs id a b itys otys) (Hvalid : is_valid_p k fl) :
- ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x) := by
+ {tys : List in_out_ty}
+ (fl : Funs id a b tys) (Hvalid : is_valid_p k fl) :
+ ∀ (i : Fin tys.length) (x : (tys.get i).fst), FixI.is_valid_p k (fun k => get_fun fl i k x) := by
-- Prepare the induction
- have ⟨ n, Hn ⟩ : { n : Nat // itys.length = n } := ⟨ itys.length, by rfl ⟩
- revert itys otys Heq fl Hvalid
+ have ⟨ n, Hn ⟩ : { n : Nat // tys.length = n } := ⟨ tys.length, by rfl ⟩
+ revert tys fl Hvalid
induction n
--
case zero =>
- intro itys otys Heq fl Hvalid Hlen;
- have Heq: itys = [] := by cases itys <;> simp_all
- have Heq: otys = [] := by cases otys <;> simp_all
+ intro tys fl Hvalid Hlen;
+ have Heq: tys = [] := by cases tys <;> simp_all
intro i x
simp_all
have Hi := i.isLt
simp_all
case succ n Hn =>
- intro itys otys Heq fl Hvalid Hlen i x;
- cases fl <;> simp at Hlen i x Heq Hvalid
- rename_i ity oty itys otys f fl
+ intro tys fl Hvalid Hlen i x;
+ cases fl <;> simp at Hlen i x Hvalid
+ rename_i ity oty tys f fl
have ⟨ Hvf, Hvalid ⟩ := Hvalid
have Hvf1: is_valid_p k fl := by
simp [Hvalid, Funs.is_valid_p]
- have Hn := @Hn itys otys (by simp[*]) fl Hvf1 (by simp [*])
+ have Hn := @Hn tys fl Hvf1 (by simp [*])
-- Case disjunction on i
match i with
| ⟨ 0, _ ⟩ =>
@@ -747,19 +732,20 @@ namespace FixI
| ⟨ .succ j, HiLt ⟩ =>
simp_arith at HiLt
simp at x
- let j : Fin (List.length itys) := ⟨ j, by simp_arith [HiLt] ⟩
+ let j : Fin (List.length tys) := ⟨ j, by simp_arith [HiLt] ⟩
have Hn := Hn j x
apply Hn
def Funs.is_valid_p_is_valid_p
- (itys otys : List (Type)) (Heq: otys.length = itys.length := by decide)
- (k : k_ty (Fin (List.length itys)) (List.get itys) fun i => List.get otys (fin_cast Heq i))
- (fl : Funs (Fin itys.length) itys.get (λ i => otys.get (fin_cast Heq i)) itys otys) :
+ (tys : List in_out_ty)
+ (k : k_ty (Fin (List.length tys)) (λ i => (tys.get i).fst) (fun i x => (List.get tys i).snd x))
+ (fl : Funs (Fin tys.length) (λ i => (tys.get i).fst) (λ i x => (tys.get i).snd x) tys) :
fl.is_valid_p k →
- ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x)
+ ∀ (i : Fin tys.length) (x : (tys.get i).fst),
+ FixI.is_valid_p k (fun k => get_fun fl i k x)
:= by
intro Hvalid
- apply is_valid_p_is_valid_p_aux <;> simp [*]
+ apply is_valid_p_is_valid_p_aux; simp [*]
end FixI
@@ -960,27 +946,21 @@ namespace Ex4
/- Mutually recursive functions - 2nd encoding -/
open Primitives FixI
- attribute [local simp] List.get
-
/- We make the input type and output types dependent on a parameter -/
- @[simp] def input_ty (i : Fin 2) : Type :=
- [Int, Int].get i
-
- @[simp] def output_ty (i : Fin 2) : Type :=
- [Bool, Bool].get i
-
- /- The continuation -/
- variable (k : (i : Fin 2) → input_ty i → Result (output_ty i))
+ @[simp] def tys : List in_out_ty := [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)]
+ @[simp] def input_ty (i : Fin 2) : Type := (tys.get i).fst
+ @[simp] def output_ty (i : Fin 2) (x : input_ty i) : Type :=
+ (tys.get i).snd x
/- The bodies are more natural -/
- def is_even_body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i : Int) : Result Bool :=
+ def is_even_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool :=
if i = 0
then .ret true
else do
let b ← k 1 (i - 1)
.ret b
- def is_odd_body (i : Int) : Result Bool :=
+ def is_odd_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool :=
if i = 0
then .ret false
else do
@@ -988,18 +968,19 @@ namespace Ex4
.ret b
@[simp] def bodies :
- Funs (Fin 2) input_ty output_ty [Int, Int] [Bool, Bool] :=
+ Funs (Fin 2) input_ty output_ty
+ [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)] :=
Funs.Cons (is_even_body) (Funs.Cons (is_odd_body) Funs.Nil)
- def body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i: Fin 2) :
- input_ty i → Result (output_ty i) := get_fun bodies i k
+ def body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 2) :
+ (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k
theorem body_is_valid : is_valid body := by
-- Split the proof into proofs of validity of the individual bodies
rw [is_valid]
simp only [body]
intro k
- apply (Funs.is_valid_p_is_valid_p [Int, Int] [Bool, Bool])
+ apply (Funs.is_valid_p_is_valid_p tys)
simp [Funs.is_valid_p]
(repeat (apply And.intro)) <;> intro x <;> simp at x <;>
simp only [is_even_body, is_odd_body]
@@ -1106,3 +1087,64 @@ namespace Ex5
conv => lhs; rw [Heq]; simp; rw [id_body]
end Ex5
+
+namespace Ex6
+ /- `list_nth` again, but this time we use FixI -/
+ open Primitives FixI
+
+ @[simp] def tys.{u} : List in_out_ty :=
+ [mk_in_out_ty ((a:Type u) × (List a × Int)) (λ ⟨ a, _ ⟩ => a)]
+
+ @[simp] def input_ty (i : Fin 1) := (tys.get i).fst
+ @[simp] def output_ty (i : Fin 1) (x : input_ty i) :=
+ (tys.get i).snd x
+
+ def list_nth_body.{u} (k : (i:Fin 1) → (x:input_ty i) → Result (output_ty i x))
+ (x : (a : Type u) × List a × Int) : Result x.fst :=
+ let ⟨ a, ls, i ⟩ := x
+ match ls with
+ | [] => .fail .panic
+ | hd :: tl =>
+ if i = 0 then .ret hd
+ else k 0 ⟨ a, tl, i - 1 ⟩
+
+ @[simp] def bodies :
+ Funs (Fin 1) input_ty output_ty tys :=
+ Funs.Cons list_nth_body Funs.Nil
+
+ def body (k : (i : Fin 1) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 1) :
+ (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k
+
+ theorem list_nth_body_is_valid: is_valid body := by
+ -- Split the proof into proofs of validity of the individual bodies
+ rw [is_valid]
+ simp only [body]
+ intro k
+ apply (Funs.is_valid_p_is_valid_p tys)
+ simp [Funs.is_valid_p]
+ (repeat (apply And.intro)); intro x; simp at x
+ simp only [list_nth_body]
+ -- Prove the validity of the individual bodies
+ intro k x
+ simp [list_nth_body]
+ split <;> simp
+ split <;> simp
+
+ noncomputable
+ def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
+ fix body 0 ⟨ a, ls , i ⟩
+
+ -- The unfolding equation - diverges if `i < 0`
+ theorem list_nth_eq (ls : List a) (i : Int) :
+ list_nth ls i =
+ match ls with
+ | [] => .fail .panic
+ | hd :: tl =>
+ if i = 0 then .ret hd
+ else list_nth tl (i - 1)
+ := by
+ have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a)
+ simp [list_nth]
+ conv => lhs; rw [Heq]
+
+end Ex6