summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-06-26 17:38:49 +0200
committerSon Ho2023-06-26 17:38:49 +0200
commitffdc2f47bc4b21df491e1a2efb6cd19637fb399b (patch)
treefc861c403e78a3ef217eadd7b94cdd2e44c2f523
parent6b319ece09b0f8a02529dd98bc20ffcb843020d6 (diff)
Start working on a better encoding of mut rec defs for Diverge
-rw-r--r--backends/lean/Base/Diverge.lean102
1 files changed, 100 insertions, 2 deletions
diff --git a/backends/lean/Base/Diverge.lean b/backends/lean/Base/Diverge.lean
index 1ff34516..a5cf3459 100644
--- a/backends/lean/Base/Diverge.lean
+++ b/backends/lean/Base/Diverge.lean
@@ -550,7 +550,7 @@ namespace Ex2
end Ex2
namespace Ex3
- /- Mutually recursive functions -/
+ /- Mutually recursive functions - first encoding (see Ex4 for a better encoding) -/
open Primitives Fix
/- Because we have mutually recursive functions, we use a sum for the inputs
@@ -671,6 +671,104 @@ namespace Ex3
end Ex3
namespace Ex4
+ /- Mutually recursive functions - 2nd encoding -/
+ open Primitives Fix
+
+ attribute [local simp] List.get
+
+ /- We make the input type and output types dependent on a parameter -/
+ @[simp] def input_ty (i : Fin 2) : Type :=
+ [Int, Int].get i
+
+ @[simp] def output_ty (i : Fin 2) : Type :=
+ [Bool, Bool].get i
+
+ /- The continuation -/
+ variable (k : (i : Fin 2) → input_ty i → Result (output_ty i))
+
+ /- The bodies are more natural -/
+ def is_even_body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i : Int) : Result Bool :=
+ if i = 0
+ then .ret true
+ else do
+ let b ← k 1 (i - 1)
+ .ret b
+
+ def is_odd_body (i : Int) : Result Bool :=
+ if i = 0
+ then .ret false
+ else do
+ let b ← k 0 (i - 1)
+ .ret b
+
+ inductive Funs : List (Type 0) → List (Type 0) → Type 1 :=
+ | Nil : Funs [] []
+ | Cons {ity oty : Type 0} {itys otys : List (Type 0)} (f : ity → Result oty) (tl : Funs itys otys) : Funs (ity :: itys) (oty :: otys)
+
+ theorem Funs.length_eq {itys otys : List (Type 0)} (fl : Funs itys otys) : itys.length = otys.length :=
+ match fl with
+ | .Nil => by simp
+ | .Cons f tl =>
+ have h:= Funs.length_eq tl
+ by simp [h]
+
+ @[simp] def Funs.cast_fin {itys otys : List (Type 0)} (fl : Funs itys otys) (i : Fin itys.length) : Fin otys.length :=
+ ⟨ i.val, by have h:= fl.length_eq; have h1:= i.isLt; simp_all ⟩
+
+ @[simp] def bodies (k : (i : Fin 2) → input_ty i → Result (output_ty i)) : Funs [Int, Int] [Bool, Bool] :=
+ Funs.Cons (is_even_body k) (Funs.Cons (is_odd_body k) Funs.Nil)
+
+ @[simp] def get_fun {itys otys : List (Type 0)} (fl : Funs itys otys) :
+ (i : Fin itys.length) → itys.get i → Result (otys.get (fl.cast_fin i)) :=
+ match fl with
+ | .Nil => λ i => by have h:= i.isLt; simp at h
+ | @Funs.Cons ity oty itys1 otys1 f tl =>
+ λ i =>
+ if h: i.val = 0 then
+ Eq.mp (by cases i; simp_all [List.get]) f
+ else
+ let j := i.val - 1
+ have Hj: j < itys1.length := by
+ have Hi := i.isLt
+ simp at Hi
+ revert Hi
+ cases Heq: i.val <;> simp_all
+ simp_arith
+ let j: Fin itys1.length := ⟨ j, Hj ⟩
+ Eq.mp (by cases Heq: i; rename_i val isLt; cases Heq': j; rename_i val' isLt; cases val <;> simp_all [List.get]) (get_fun tl j)
+
+ def body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i: Fin 2) : input_ty i → Result (output_ty i) := get_fun (bodies k) i
+
+ def fix_ {n : Nat} {ity oty : Fin n → Type 0} (f : ((i:Fin n) → ity i → Result (oty i)) → (i:Fin n) → ity i → Result (oty i)) :
+ (i:Fin n) → ity i → Result (oty i) :=
+ sorry
+
+ theorem body_fix_eq : fix_ body = body (fix_ body) := sorry
+
+ def is_even (i : Int) : Result Bool := fix_ body 0 i
+ def is_odd (i : Int) : Result Bool := fix_ body 1 i
+
+ theorem is_even_eq (i : Int) : is_even i =
+ (if i = 0
+ then .ret true
+ else do
+ let b ← is_odd (i - 1)
+ .ret b) := by
+ simp [is_even, is_odd];
+ conv => lhs; rw [body_fix_eq]
+
+ theorem is_odd_eq (i : Int) : is_odd i =
+ (if i = 0
+ then .ret false
+ else do
+ let b ← is_even (i - 1)
+ .ret b) := by
+ simp [is_even, is_odd];
+ conv => lhs; rw [body_fix_eq]
+
+end Ex4
+
+namespace Ex5
/- Higher-order example -/
open Primitives Fix
@@ -736,6 +834,6 @@ namespace Ex4
simp [id]
conv => lhs; rw [Heq]; simp; rw [id_body]
-end Ex4
+end Ex5
end Diverge