diff options
author | Son Ho | 2023-06-26 17:38:49 +0200 |
---|---|---|
committer | Son Ho | 2023-06-26 17:38:49 +0200 |
commit | ffdc2f47bc4b21df491e1a2efb6cd19637fb399b (patch) | |
tree | fc861c403e78a3ef217eadd7b94cdd2e44c2f523 | |
parent | 6b319ece09b0f8a02529dd98bc20ffcb843020d6 (diff) |
Start working on a better encoding of mut rec defs for Diverge
-rw-r--r-- | backends/lean/Base/Diverge.lean | 102 |
1 files changed, 100 insertions, 2 deletions
diff --git a/backends/lean/Base/Diverge.lean b/backends/lean/Base/Diverge.lean index 1ff34516..a5cf3459 100644 --- a/backends/lean/Base/Diverge.lean +++ b/backends/lean/Base/Diverge.lean @@ -550,7 +550,7 @@ namespace Ex2 end Ex2 namespace Ex3 - /- Mutually recursive functions -/ + /- Mutually recursive functions - first encoding (see Ex4 for a better encoding) -/ open Primitives Fix /- Because we have mutually recursive functions, we use a sum for the inputs @@ -671,6 +671,104 @@ namespace Ex3 end Ex3 namespace Ex4 + /- Mutually recursive functions - 2nd encoding -/ + open Primitives Fix + + attribute [local simp] List.get + + /- We make the input type and output types dependent on a parameter -/ + @[simp] def input_ty (i : Fin 2) : Type := + [Int, Int].get i + + @[simp] def output_ty (i : Fin 2) : Type := + [Bool, Bool].get i + + /- The continuation -/ + variable (k : (i : Fin 2) → input_ty i → Result (output_ty i)) + + /- The bodies are more natural -/ + def is_even_body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i : Int) : Result Bool := + if i = 0 + then .ret true + else do + let b ← k 1 (i - 1) + .ret b + + def is_odd_body (i : Int) : Result Bool := + if i = 0 + then .ret false + else do + let b ← k 0 (i - 1) + .ret b + + inductive Funs : List (Type 0) → List (Type 0) → Type 1 := + | Nil : Funs [] [] + | Cons {ity oty : Type 0} {itys otys : List (Type 0)} (f : ity → Result oty) (tl : Funs itys otys) : Funs (ity :: itys) (oty :: otys) + + theorem Funs.length_eq {itys otys : List (Type 0)} (fl : Funs itys otys) : itys.length = otys.length := + match fl with + | .Nil => by simp + | .Cons f tl => + have h:= Funs.length_eq tl + by simp [h] + + @[simp] def Funs.cast_fin {itys otys : List (Type 0)} (fl : Funs itys otys) (i : Fin itys.length) : Fin otys.length := + ⟨ i.val, by have h:= fl.length_eq; have h1:= i.isLt; simp_all ⟩ + + @[simp] def bodies (k : (i : Fin 2) → input_ty i → Result (output_ty i)) : Funs [Int, Int] [Bool, Bool] := + Funs.Cons (is_even_body k) (Funs.Cons (is_odd_body k) Funs.Nil) + + @[simp] def get_fun {itys otys : List (Type 0)} (fl : Funs itys otys) : + (i : Fin itys.length) → itys.get i → Result (otys.get (fl.cast_fin i)) := + match fl with + | .Nil => λ i => by have h:= i.isLt; simp at h + | @Funs.Cons ity oty itys1 otys1 f tl => + λ i => + if h: i.val = 0 then + Eq.mp (by cases i; simp_all [List.get]) f + else + let j := i.val - 1 + have Hj: j < itys1.length := by + have Hi := i.isLt + simp at Hi + revert Hi + cases Heq: i.val <;> simp_all + simp_arith + let j: Fin itys1.length := ⟨ j, Hj ⟩ + Eq.mp (by cases Heq: i; rename_i val isLt; cases Heq': j; rename_i val' isLt; cases val <;> simp_all [List.get]) (get_fun tl j) + + def body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i: Fin 2) : input_ty i → Result (output_ty i) := get_fun (bodies k) i + + def fix_ {n : Nat} {ity oty : Fin n → Type 0} (f : ((i:Fin n) → ity i → Result (oty i)) → (i:Fin n) → ity i → Result (oty i)) : + (i:Fin n) → ity i → Result (oty i) := + sorry + + theorem body_fix_eq : fix_ body = body (fix_ body) := sorry + + def is_even (i : Int) : Result Bool := fix_ body 0 i + def is_odd (i : Int) : Result Bool := fix_ body 1 i + + theorem is_even_eq (i : Int) : is_even i = + (if i = 0 + then .ret true + else do + let b ← is_odd (i - 1) + .ret b) := by + simp [is_even, is_odd]; + conv => lhs; rw [body_fix_eq] + + theorem is_odd_eq (i : Int) : is_odd i = + (if i = 0 + then .ret false + else do + let b ← is_even (i - 1) + .ret b) := by + simp [is_even, is_odd]; + conv => lhs; rw [body_fix_eq] + +end Ex4 + +namespace Ex5 /- Higher-order example -/ open Primitives Fix @@ -736,6 +834,6 @@ namespace Ex4 simp [id] conv => lhs; rw [Heq]; simp; rw [id_body] -end Ex4 +end Ex5 end Diverge |