diff options
author | Son HO | 2023-07-31 16:15:58 +0200 |
---|---|---|
committer | GitHub | 2023-07-31 16:15:58 +0200 |
commit | 887d0ef1efc8912c6273b5ebcf979384e9d7fa97 (patch) | |
tree | 92d6021eb549f7cc25501856edd58859786b7e90 /backends/lean/Base/Diverge | |
parent | 53adf30fe440eb8b6f58ba89f4a4c0acc7877498 (diff) | |
parent | 9b3a58e423333fc9a4a5a264c3beb0a3d951e86b (diff) |
Merge pull request #31 from AeneasVerif/son_lean_backend
Improve the Lean backend
Diffstat (limited to 'backends/lean/Base/Diverge')
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 1138 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 1162 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/ElabBase.lean | 15 |
3 files changed, 2315 insertions, 0 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean new file mode 100644 index 00000000..1d548389 --- /dev/null +++ b/backends/lean/Base/Diverge/Base.lean @@ -0,0 +1,1138 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Mathlib.Tactic.Linarith +import Base.Primitives.Base +import Base.Arith.Base + +/- TODO: this is very useful, but is there more? -/ +set_option profiler true +set_option profiler.threshold 100 + +namespace Diverge + +namespace Fix + + open Primitives + open Result + + variable {a : Type u} {b : a → Type v} + variable {c d : Type w} -- TODO: why do we have to make them both : Type w? + + /-! # The least fixed point definition and its properties -/ + + def least_p (p : Nat → Prop) (n : Nat) : Prop := p n ∧ (∀ m, m < n → ¬ p m) + noncomputable def least (p : Nat → Prop) : Nat := + Classical.epsilon (least_p p) + + -- Auxiliary theorem for [least_spec]: if there exists an `n` satisfying `p`, + -- there there exists a least `m` satisfying `p`. + theorem least_spec_aux (p : Nat → Prop) : ∀ (n : Nat), (hn : p n) → ∃ m, least_p p m := by + apply Nat.strongRec' + intros n hi hn + -- Case disjunction on: is n the smallest n satisfying p? + match Classical.em (∀ m, m < n → ¬ p m) with + | .inl hlt => + -- Yes: trivial + exists n + | .inr hlt => + simp at * + let ⟨ m, ⟨ hmlt, hm ⟩ ⟩ := hlt + have hi := hi m hmlt hm + apply hi + + -- The specification of [least]: either `p` is never satisfied, or it is satisfied + -- by `least p` and no `n < least p` satisfies `p`. + theorem least_spec (p : Nat → Prop) : (∀ n, ¬ p n) ∨ (p (least p) ∧ ∀ n, n < least p → ¬ p n) := by + -- Case disjunction on the existence of an `n` which satisfies `p` + match Classical.em (∀ n, ¬ p n) with + | .inl h => + -- There doesn't exist: trivial + apply (Or.inl h) + | .inr h => + -- There exists: we simply use `least_spec_aux` in combination with the property + -- of the epsilon operator + simp at * + let ⟨ n, hn ⟩ := h + apply Or.inr + have hl := least_spec_aux p n hn + have he := Classical.epsilon_spec hl + apply he + + /-! # The fixed point definitions -/ + + def fix_fuel (n : Nat) (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : + Result (b x) := + match n with + | 0 => .div + | n + 1 => + f (fix_fuel n f) x + + @[simp] def fix_fuel_pred (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (x : a) (n : Nat) := + not (div? (fix_fuel n f x)) + + def fix_fuel_P (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (x : a) (n : Nat) : Prop := + fix_fuel_pred f x n + + partial + def fixImpl (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : Result (b x) := + f (fixImpl f) x + + -- The fact that `fix` is implemented by `fixImpl` allows us to not mark the + -- functions defined with the fixed-point as noncomputable. One big advantage + -- is that it allows us to evaluate those functions, for instance with #eval. + @[implemented_by fixImpl] + def fix (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : Result (b x) := + fix_fuel (least (fix_fuel_P f x)) f x + + /-! # The validity property -/ + + -- Monotonicity relation over results + -- TODO: generalize (we should parameterize the definition by a relation over `a`) + def result_rel {a : Type u} (x1 x2 : Result a) : Prop := + match x1 with + | div => True + | fail _ => x2 = x1 + | ret _ => x2 = x1 -- TODO: generalize + + -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) + def karrow_rel (k1 k2 : (x:a) → Result (b x)) : Prop := + ∀ x, result_rel (k1 x) (k2 x) + + -- Monotonicity property for function bodies + def is_mono (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) : Prop := + ∀ {{k1 k2}}, karrow_rel k1 k2 → karrow_rel (f k1) (f k2) + + -- "Continuity" property. + -- We need this, and this looks a lot like continuity. Also see this paper: + -- https://inria.hal.science/file/index/docid/216187/filename/tarski.pdf + -- We define our "continuity" criteria so that it gives us what we need to + -- prove the fixed-point equation, and we can also easily manipulate it. + def is_cont (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) : Prop := + ∀ x, (Hdiv : ∀ n, fix_fuel (.succ n) f x = div) → f (fix f) x = div + + /-! # The proof of the fixed-point equation -/ + theorem fix_fuel_mono {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} + (Hmono : is_mono f) : + ∀ {{n m}}, n ≤ m → karrow_rel (fix_fuel n f) (fix_fuel m f) := by + intros n + induction n + case zero => simp [karrow_rel, fix_fuel, result_rel] + case succ n1 Hi => + intros m Hle x + simp [result_rel] + match m with + | 0 => + exfalso + zify at * + linarith + | Nat.succ m1 => + simp_arith at Hle + simp [fix_fuel] + have Hi := Hi Hle + have Hmono := Hmono Hi x + simp [result_rel] at Hmono + apply Hmono + + @[simp] theorem neg_fix_fuel_P + {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} {x : a} {n : Nat} : + ¬ fix_fuel_P f x n ↔ (fix_fuel n f x = div) := by + simp [fix_fuel_P, div?] + cases fix_fuel n f x <;> simp + + theorem fix_fuel_fix_mono {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} (Hmono : is_mono f) : + ∀ n, karrow_rel (fix_fuel n f) (fix f) := by + intros n x + simp [result_rel] + have Hl := least_spec (fix_fuel_P f x) + simp at Hl + match Hl with + | .inl Hl => simp [*] + | .inr ⟨ Hl, Hn ⟩ => + match Classical.em (fix_fuel n f x = div) with + | .inl Hd => + simp [*] + | .inr Hd => + have Hineq : least (fix_fuel_P f x) ≤ n := by + -- Proof by contradiction + cases Classical.em (least (fix_fuel_P f x) ≤ n) <;> simp [*] + simp at * + rename_i Hineq + have Hn := Hn n Hineq + contradiction + have Hfix : ¬ (fix f x = div) := by + simp [fix] + -- By property of the least upper bound + revert Hd Hl + -- TODO: there is no conversion to select the head of a function! + conv => lhs; apply congr_fun; apply congr_fun; apply congr_fun; simp [fix_fuel_P, div?] + cases fix_fuel (least (fix_fuel_P f x)) f x <;> simp + have Hmono := fix_fuel_mono Hmono Hineq x + simp [result_rel] at Hmono + simp [fix] at * + cases Heq: fix_fuel (least (fix_fuel_P f x)) f x <;> + cases Heq':fix_fuel n f x <;> + simp_all + + theorem fix_fuel_P_least {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} (Hmono : is_mono f) : + ∀ {{x n}}, fix_fuel_P f x n → fix_fuel_P f x (least (fix_fuel_P f x)) := by + intros x n Hf + have Hfmono := fix_fuel_fix_mono Hmono n x + -- TODO: there is no conversion to select the head of a function! + conv => apply congr_fun; simp [fix_fuel_P] + simp [fix_fuel_P] at Hf + revert Hf Hfmono + simp [div?, result_rel, fix] + cases fix_fuel n f x <;> simp_all + + -- Prove the fixed point equation in the case there exists some fuel for which + -- the execution terminates + theorem fix_fixed_eq_terminates (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (Hmono : is_mono f) + (x : a) (n : Nat) (He : fix_fuel_P f x n) : + fix f x = f (fix f) x := by + have Hl := fix_fuel_P_least Hmono He + -- TODO: better control of simplification + conv at Hl => + apply congr_fun + simp [fix_fuel_P] + -- The least upper bound is > 0 + have ⟨ n, Hsucc ⟩ : ∃ n, least (fix_fuel_P f x) = Nat.succ n := by + revert Hl + simp [div?] + cases least (fix_fuel_P f x) <;> simp [fix_fuel] + simp [Hsucc] at Hl + revert Hl + simp [*, div?, fix, fix_fuel] + -- Use the monotonicity + have Hfixmono := fix_fuel_fix_mono Hmono n + have Hvm := Hmono Hfixmono x + -- Use functional extensionality + simp [result_rel, fix] at Hvm + revert Hvm + split <;> simp [*] <;> intros <;> simp [*] + + theorem fix_fixed_eq_forall {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hmono : is_mono f) (Hcont : is_cont f) : + ∀ x, fix f x = f (fix f) x := by + intros x + -- Case disjunction: is there a fuel such that the execution successfully execute? + match Classical.em (∃ n, fix_fuel_P f x n) with + | .inr He => + -- No fuel: the fixed point evaluates to `div` + --simp [fix] at * + simp at * + conv => lhs; simp [fix] + have Hel := He (Nat.succ (least (fix_fuel_P f x))); simp [*, fix_fuel] at *; clear Hel + -- Use the "continuity" of `f` + have He : ∀ n, fix_fuel (.succ n) f x = div := by intros; simp [*] + have Hcont := Hcont x He + simp [Hcont] + | .inl ⟨ n, He ⟩ => apply fix_fixed_eq_terminates f Hmono x n He + + -- The final fixed point equation + theorem fix_fixed_eq {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hmono : is_mono f) (Hcont : is_cont f) : + fix f = f (fix f) := by + have Heq := fix_fixed_eq_forall Hmono Hcont + have Heq1 : fix f = (λ x => fix f x) := by simp + rw [Heq1] + conv => lhs; ext; simp [Heq] + + /-! # Making the proofs of validity manageable (and automatable) -/ + + -- Monotonicity property for expressions + def is_mono_p (e : ((x:a) → Result (b x)) → Result c) : Prop := + ∀ {{k1 k2}}, karrow_rel k1 k2 → result_rel (e k1) (e k2) + + theorem is_mono_p_same (x : Result c) : + @is_mono_p a b c (λ _ => x) := by + simp [is_mono_p, karrow_rel, result_rel] + split <;> simp + + theorem is_mono_p_rec (x : a) : + @is_mono_p a b (b x) (λ f => f x) := by + simp_all [is_mono_p, karrow_rel, result_rel] + + -- The important lemma about `is_mono_p` + theorem is_mono_p_bind + (g : ((x:a) → Result (b x)) → Result c) + (h : c → ((x:a) → Result (b x)) → Result d) : + is_mono_p g → + (∀ y, is_mono_p (h y)) → + @is_mono_p a b d (λ k => @Bind.bind Result _ c d (g k) (fun y => h y k)) := by +-- @is_mono_p a b d (λ k => do let (y : c) ← g k; h y k) := by + intro hg hh + simp [is_mono_p] + intro fg fh Hrgh + simp [karrow_rel, result_rel] + have hg := hg Hrgh; simp [result_rel] at hg + cases heq0: g fg <;> simp_all + rename_i y _ + have hh := hh y Hrgh; simp [result_rel] at hh + simp_all + + -- Continuity property for expressions - note that we take the continuation + -- as parameter + def is_cont_p (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (e : ((x:a) → Result (b x)) → Result c) : Prop := + (Hc : ∀ n, e (fix_fuel n k) = .div) → + e (fix k) = .div + + theorem is_cont_p_same (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (x : Result c) : + is_cont_p k (λ _ => x) := by + simp [is_cont_p] + + theorem is_cont_p_rec (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : + is_cont_p f (λ f => f x) := by + simp_all [is_cont_p, fix] + + -- The important lemma about `is_cont_p` + theorem is_cont_p_bind + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (Hkmono : is_mono k) + (g : ((x:a) → Result (b x)) → Result c) + (h : c → ((x:a) → Result (b x)) → Result d) : + is_mono_p g → + is_cont_p k g → + (∀ y, is_mono_p (h y)) → + (∀ y, is_cont_p k (h y)) → + is_cont_p k (λ k => do let y ← g k; h y k) := by + intro Hgmono Hgcont Hhmono Hhcont + simp [is_cont_p] + intro Hdiv + -- Case on `g (fix... k)`: is there an n s.t. it terminates? + cases Classical.em (∀ n, g (fix_fuel n k) = .div) <;> rename_i Hn + . -- Case 1: g diverges + have Hgcont := Hgcont Hn + simp_all + . -- Case 2: g doesn't diverge + simp at Hn + let ⟨ n, Hn ⟩ := Hn + have Hdivn := Hdiv n + have Hffmono := fix_fuel_fix_mono Hkmono n + have Hgeq := Hgmono Hffmono + simp [result_rel] at Hgeq + cases Heq: g (fix_fuel n k) <;> rename_i y <;> simp_all + -- Remains the .ret case + -- Use Hdiv to prove that: ∀ n, h y (fix_fuel n f) = div + -- We do this in two steps: first we prove it for m ≥ n + have Hhdiv: ∀ m, h y (fix_fuel m k) = .div := by + have Hhdiv : ∀ m, n ≤ m → h y (fix_fuel m k) = .div := by + -- We use the fact that `g (fix_fuel n f) = .div`, combined with Hdiv + intro m Hle + have Hdivm := Hdiv m + -- Monotonicity of g + have Hffmono := fix_fuel_mono Hkmono Hle + have Hgmono := Hgmono Hffmono + -- We need to clear Hdiv because otherwise simp_all rewrites Hdivm with Hdiv + clear Hdiv + simp_all [result_rel] + intro m + -- TODO: we shouldn't need the excluded middle here because it is decidable + cases Classical.em (n ≤ m) <;> rename_i Hl + . apply Hhdiv; assumption + . simp at Hl + -- Make a case disjunction on `h y (fix_fuel m k)`: if it is not equal + -- to div, use the monotonicity of `h y` + have Hle : m ≤ n := by linarith + have Hffmono := fix_fuel_mono Hkmono Hle + have Hmono := Hhmono y Hffmono + simp [result_rel] at Hmono + cases Heq: h y (fix_fuel m k) <;> simp_all + -- We can now use the continuity hypothesis for h + apply Hhcont; assumption + + -- The validity property for an expression + def is_valid_p (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (e : ((x:a) → Result (b x)) → Result c) : Prop := + is_mono_p e ∧ + (is_mono k → is_cont_p k e) + + @[simp] theorem is_valid_p_same + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : Result c) : + is_valid_p k (λ _ => x) := by + simp [is_valid_p, is_mono_p_same, is_cont_p_same] + + @[simp] theorem is_valid_p_rec + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : + is_valid_p k (λ k => k x) := by + simp_all [is_valid_p, is_mono_p_rec, is_cont_p_rec] + + theorem is_valid_p_ite + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((x:a) → Result (b x)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (ite cond e1 e2) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (cond : Prop) [h : Decidable cond] + {e1 : cond → ((x:a) → Result (b x)) → Result c} + {e2 : Not cond → ((x:a) → Result (b x)) → Result c} + (he1: ∀ x, is_valid_p k (e1 x)) (he2 : ∀ x, is_valid_p k (e2 x)) : + is_valid_p k (dite cond e1 e2) := by + split <;> simp [*] + + -- Lean is good at unification: we can write a very general version + -- (in particular, it will manage to figure out `g` and `h` when we + -- apply the lemma) + theorem is_valid_p_bind + {{k : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + {{g : ((x:a) → Result (b x)) → Result c}} + {{h : c → ((x:a) → Result (b x)) → Result d}} + (Hgvalid : is_valid_p k g) + (Hhvalid : ∀ y, is_valid_p k (h y)) : + is_valid_p k (λ k => do let y ← g k; h y k) := by + let ⟨ Hgmono, Hgcont ⟩ := Hgvalid + simp [is_valid_p, forall_and] at Hhvalid + have ⟨ Hhmono, Hhcont ⟩ := Hhvalid + simp [← imp_forall_iff] at Hhcont + simp [is_valid_p]; constructor + . -- Monotonicity + apply is_mono_p_bind <;> assumption + . -- Continuity + intro Hkmono + have Hgcont := Hgcont Hkmono + have Hhcont := Hhcont Hkmono + apply is_cont_p_bind <;> assumption + + def is_valid (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) : Prop := + ∀ k x, is_valid_p k (λ k => f k x) + + theorem is_valid_p_imp_is_valid {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hvalid : is_valid f) : + is_mono f ∧ is_cont f := by + have Hmono : is_mono f := by + intro f h Hr x + have Hmono := Hvalid (λ _ _ => .div) x + have Hmono := Hmono.left + apply Hmono; assumption + have Hcont : is_cont f := by + intro x Hdiv + have Hcont := (Hvalid f x).right Hmono + simp [is_cont_p] at Hcont + apply Hcont + intro n + have Hdiv := Hdiv n + simp [fix_fuel] at Hdiv + simp [*] + simp [*] + + theorem is_valid_fix_fixed_eq {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hvalid : is_valid f) : + fix f = f (fix f) := by + have ⟨ Hmono, Hcont ⟩ := is_valid_p_imp_is_valid Hvalid + exact fix_fixed_eq Hmono Hcont + +end Fix + +namespace FixI + /- Indexed fixed-point: definitions with indexed types, convenient to use for mutually + recursive definitions. We simply port the definitions and proofs from Fix to a more + specific case. + -/ + open Primitives Fix + + -- The index type + variable {id : Type u} + + -- The input/output types + variable {a : id → Type v} {b : (i:id) → a i → Type w} + + -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) + def karrow_rel (k1 k2 : (i:id) → (x:a i) → Result (b i x)) : Prop := + ∀ i x, result_rel (k1 i x) (k2 i x) + + def kk_to_gen (k : (i:id) → (x:a i) → Result (b i x)) : + (x: (i:id) × a i) → Result (b x.fst x.snd) := + λ ⟨ i, x ⟩ => k i x + + def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst x.snd)) : + (i:id) → (x:a i) → Result (b i x) := + λ i x => k ⟨ i, x ⟩ + + def k_to_gen (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : + ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd) := + λ kk => kk_to_gen (k (kk_of_gen kk)) + + def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd)) : + ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x) := + λ kk => kk_of_gen (k (kk_to_gen kk)) + + def e_to_gen (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) : + ((x: (i:id) × a i) → Result (b x.fst x.snd)) → Result c := + λ k => e (kk_of_gen k) + + def is_valid_p (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) : Prop := + Fix.is_valid_p (k_to_gen k) (e_to_gen e) + + def is_valid (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : Prop := + ∀ k i x, is_valid_p k (λ k => f k i x) + + def fix + (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : + (i:id) → (x:a i) → Result (b i x) := + kk_of_gen (Fix.fix (k_to_gen f)) + + theorem is_valid_fix_fixed_eq + {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} + (Hvalid : is_valid f) : + fix f = f (fix f) := by + have Hvalid' : Fix.is_valid (k_to_gen f) := by + intro k x + simp only [is_valid, is_valid_p] at Hvalid + let ⟨ i, x ⟩ := x + have Hvalid := Hvalid (k_of_gen k) i x + simp only [k_to_gen, k_of_gen, kk_to_gen, kk_of_gen] at Hvalid + refine Hvalid + have Heq := Fix.is_valid_fix_fixed_eq Hvalid' + simp [fix] + conv => lhs; rw [Heq] + + /- Some utilities to define the mutually recursive functions -/ + + -- TODO: use more + abbrev kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + (i:id) → (x:a i) → Result (b i x) + abbrev k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + kk_ty id a b → kk_ty id a b + + abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) + -- TODO: remove? + abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : + in_out_ty := + Sigma.mk in_ty out_ty + + -- Initially, we had left out the parameters id, a and b. + -- However, by parameterizing Funs with those parameters, we can state + -- and prove lemmas like Funs.is_valid_p_is_valid_p + inductive Funs (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) : + List in_out_ty.{v, w} → Type (max (u + 1) (max (v + 1) (w + 1))) := + | Nil : Funs id a b [] + | Cons {ity : Type v} {oty : ity → Type w} {tys : List in_out_ty} + (f : kk_ty id a b → (x:ity) → Result (oty x)) (tl : Funs id a b tys) : + Funs id a b (⟨ ity, oty ⟩ :: tys) + + def get_fun {tys : List in_out_ty} (fl : Funs id a b tys) : + (i : Fin tys.length) → kk_ty id a b → (x : (tys.get i).fst) → + Result ((tys.get i).snd x) := + match fl with + | .Nil => λ i => by have h:= i.isLt; simp at h + | @Funs.Cons id a b ity oty tys1 f tl => + λ ⟨ i, iLt ⟩ => + match i with + | 0 => + Eq.mp (by simp [List.get]) f + | .succ j => + have jLt: j < tys1.length := by + simp at iLt + revert iLt + simp_arith + let j: Fin tys1.length := ⟨ j, jLt ⟩ + Eq.mp (by simp) (get_fun tl j) + + def for_all_fin_aux {n : Nat} (f : Fin n → Prop) (m : Nat) (h : m ≤ n) : Prop := + if heq: m = n then True + else + f ⟨ m, by simp_all [Nat.lt_iff_le_and_ne] ⟩ ∧ + for_all_fin_aux f (m + 1) (by simp_all [Arith.add_one_le_iff_le_ne]) + termination_by for_all_fin_aux n _ m h => n - m + decreasing_by + simp_wf + apply Nat.sub_add_lt_sub <;> simp + simp_all [Arith.add_one_le_iff_le_ne] + + def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp) + + theorem for_all_fin_aux_imp_forall {n : Nat} (f : Fin n → Prop) (m : Nat) : + (h : m ≤ n) → + for_all_fin_aux f m h → ∀ i, m ≤ i.val → f i + := by + generalize h: (n - m) = k + revert m + induction k -- TODO: induction h rather? + case zero => + simp_all + intro m h1 h2 + have h: n = m := by + linarith + unfold for_all_fin_aux; simp_all + simp_all + -- There is no i s.t. m ≤ i + intro i h3; cases i; simp_all + linarith + case succ k hi => + simp_all + intro m hk hmn + intro hf i hmi + have hne: m ≠ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + linarith + -- m = i? + if heq: m = i then + -- Yes: simply use the `for_all_fin_aux` hyp + unfold for_all_fin_aux at hf + simp_all + tauto + else + -- No: use the induction hypothesis + have hlt: m < i := by simp_all [Nat.lt_iff_le_and_ne] + have hineq: m + 1 ≤ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + simp [*, Nat.add_one_le_iff] + have heq1: n - (m + 1) = k := by + -- TODO: very annoying arithmetic proof + simp [Nat.sub_eq_iff_eq_add hineq] + have hineq1: m ≤ n := by linarith + simp [Nat.sub_eq_iff_eq_add hineq1] at hk + simp_arith [hk] + have hi := hi (m + 1) heq1 hineq + apply hi <;> simp_all + . unfold for_all_fin_aux at hf + simp_all + . simp_all [Arith.add_one_le_iff_le_ne] + + -- TODO: this is not necessary anymore + theorem for_all_fin_imp_forall (n : Nat) (f : Fin n → Prop) : + for_all_fin f → ∀ i, f i + := by + intro Hf i + apply for_all_fin_aux_imp_forall <;> try assumption + simp + + /- Automating the proofs -/ + @[simp] theorem is_valid_p_same + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (x : Result c) : + is_valid_p k (λ _ => x) := by + simp [is_valid_p, k_to_gen, e_to_gen] + + @[simp] theorem is_valid_p_rec + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (i : id) (x : a i) : + is_valid_p k (λ k => k i x) := by + simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] + + theorem is_valid_p_ite + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((i:id) → (x:a i) → Result (b i x)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (λ k => ite cond (e1 k) (e2 k)) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (cond : Prop) [h : Decidable cond] + {e1 : ((i:id) → (x:a i) → Result (b i x)) → cond → Result c} + {e2 : ((i:id) → (x:a i) → Result (b i x)) → Not cond → Result c} + (he1: ∀ x, is_valid_p k (λ k => e1 k x)) + (he2 : ∀ x, is_valid_p k (λ k => e2 k x)) : + is_valid_p k (λ k => dite cond (e1 k) (e2 k)) := by + split <;> simp [*] + + theorem is_valid_p_bind + {{k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} + {{g : ((i:id) → (x:a i) → Result (b i x)) → Result c}} + {{h : c → ((i:id) → (x:a i) → Result (b i x)) → Result d}} + (Hgvalid : is_valid_p k g) + (Hhvalid : ∀ y, is_valid_p k (h y)) : + is_valid_p k (λ k => do let y ← g k; h y k) := by + apply Fix.is_valid_p_bind + . apply Hgvalid + . apply Hhvalid + + def Funs.is_valid_p + (k : k_ty id a b) + (fl : Funs id a b tys) : + Prop := + match fl with + | .Nil => True + | .Cons f fl => (∀ x, FixI.is_valid_p k (λ k => f k x)) ∧ fl.is_valid_p k + + theorem Funs.is_valid_p_Nil (k : k_ty id a b) : + Funs.is_valid_p k Funs.Nil := by simp [Funs.is_valid_p] + + def Funs.is_valid_p_is_valid_p_aux + {k : k_ty id a b} + {tys : List in_out_ty} + (fl : Funs id a b tys) (Hvalid : is_valid_p k fl) : + ∀ (i : Fin tys.length) (x : (tys.get i).fst), FixI.is_valid_p k (fun k => get_fun fl i k x) := by + -- Prepare the induction + have ⟨ n, Hn ⟩ : { n : Nat // tys.length = n } := ⟨ tys.length, by rfl ⟩ + revert tys fl Hvalid + induction n + -- + case zero => + intro tys fl Hvalid Hlen; + have Heq: tys = [] := by cases tys <;> simp_all + intro i x + simp_all + have Hi := i.isLt + simp_all + case succ n Hn => + intro tys fl Hvalid Hlen i x; + cases fl <;> simp at Hlen i x Hvalid + rename_i ity oty tys f fl + have ⟨ Hvf, Hvalid ⟩ := Hvalid + have Hvf1: is_valid_p k fl := by + simp [Hvalid, Funs.is_valid_p] + have Hn := @Hn tys fl Hvf1 (by simp [*]) + -- Case disjunction on i + match i with + | ⟨ 0, _ ⟩ => + simp at x + simp [get_fun] + apply (Hvf x) + | ⟨ .succ j, HiLt ⟩ => + simp_arith at HiLt + simp at x + let j : Fin (List.length tys) := ⟨ j, by simp_arith [HiLt] ⟩ + have Hn := Hn j x + apply Hn + + def Funs.is_valid_p_is_valid_p + (tys : List in_out_ty) + (k : k_ty (Fin (List.length tys)) (λ i => (tys.get i).fst) (fun i x => (List.get tys i).snd x)) + (fl : Funs (Fin tys.length) (λ i => (tys.get i).fst) (λ i x => (tys.get i).snd x) tys) : + fl.is_valid_p k → + ∀ (i : Fin tys.length) (x : (tys.get i).fst), + FixI.is_valid_p k (fun k => get_fun fl i k x) + := by + intro Hvalid + apply is_valid_p_is_valid_p_aux; simp [*] + +end FixI + +namespace Ex1 + /- An example of use of the fixed-point -/ + open Primitives Fix + + variable {a : Type} (k : (List a × Int) → Result a) + + def list_nth_body (x : (List a × Int)) : Result a := + let (ls, i) := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else k (tl, i - 1) + + theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by + intro k x + simp [list_nth_body] + split <;> simp + split <;> simp + + def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i) + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := by + have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + simp [list_nth] + conv => lhs; rw [Heq] + +end Ex1 + +namespace Ex2 + /- Same as Ex1, but we make the body of nth non tail-rec (this is mostly + to see what happens when there are let-bindings) -/ + open Primitives Fix + + variable {a : Type} (k : (List a × Int) → Result a) + + def list_nth_body (x : (List a × Int)) : Result a := + let (ls, i) := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else + do + let y ← k (tl, i - 1) + .ret y + + theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by + intro k x + simp [list_nth_body] + split <;> simp + split <;> simp + apply is_valid_p_bind <;> intros <;> simp_all + + def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i) + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + (list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else + do + let y ← list_nth tl (i - 1) + .ret y) + := by + have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + simp [list_nth] + conv => lhs; rw [Heq] + +end Ex2 + +namespace Ex3 + /- Mutually recursive functions - first encoding (see Ex4 for a better encoding) -/ + open Primitives Fix + + /- Because we have mutually recursive functions, we use a sum for the inputs + and the output types: + - inputs: the sum allows to select the function to call in the recursive + calls (and the functions may not have the same input types) + - outputs: this case is degenerate because `even` and `odd` have the same + return type `Bool`, but generally speaking we need a sum type because + the functions in the mutually recursive group may have different + return types. + -/ + variable (k : (Int ⊕ Int) → Result (Bool ⊕ Bool)) + + def is_even_is_odd_body (x : (Int ⊕ Int)) : Result (Bool ⊕ Bool) := + match x with + | .inl i => + -- Body of `is_even` + if i = 0 + then .ret (.inl true) -- We use .inl because this is `is_even` + else + do + let b ← + do + -- Call `odd`: we need to wrap the input value in `.inr`, then + -- extract the output value + let r ← k (.inr (i- 1)) + match r with + | .inl _ => .fail .panic -- Invalid output + | .inr b => .ret b + -- Wrap the return value + .ret (.inl b) + | .inr i => + -- Body of `is_odd` + if i = 0 + then .ret (.inr false) -- We use .inr because this is `is_odd` + else + do + let b ← + do + -- Call `is_even`: we need to wrap the input value in .inr, then + -- extract the output value + let r ← k (.inl (i- 1)) + match r with + | .inl b => .ret b + | .inr _ => .fail .panic -- Invalid output + -- Wrap the return value + .ret (.inr b) + + theorem is_even_is_odd_body_is_valid: + ∀ k x, is_valid_p k (λ k => is_even_is_odd_body k x) := by + intro k x + simp [is_even_is_odd_body] + split <;> simp <;> split <;> simp + apply is_valid_p_bind; simp + intros; split <;> simp + apply is_valid_p_bind; simp + intros; split <;> simp + + def is_even (i : Int): Result Bool := + do + let r ← fix is_even_is_odd_body (.inl i) + match r with + | .inl b => .ret b + | .inr _ => .fail .panic + + def is_odd (i : Int): Result Bool := + do + let r ← fix is_even_is_odd_body (.inr i) + match r with + | .inl _ => .fail .panic + | .inr b => .ret b + + -- The unfolding equation for `is_even` - diverges if `i < 0` + theorem is_even_eq (i : Int) : + is_even i = (if i = 0 then .ret true else is_odd (i - 1)) + := by + have Heq := is_valid_fix_fixed_eq is_even_is_odd_body_is_valid + simp [is_even, is_odd] + conv => lhs; rw [Heq]; simp; rw [is_even_is_odd_body]; simp + -- Very annoying: we need to swap the matches + -- Doing this with rewriting lemmas is hard generally speaking + -- (especially as we may have to generate lemmas for user-defined + -- inductives on the fly). + -- The simplest is to repeatedly split then simplify (we identify + -- the outer match or monadic let-binding, and split on its scrutinee) + split <;> simp + cases H0 : fix is_even_is_odd_body (Sum.inr (i - 1)) <;> simp + rename_i v + split <;> simp + + -- The unfolding equation for `is_odd` - diverges if `i < 0` + theorem is_odd_eq (i : Int) : + is_odd i = (if i = 0 then .ret false else is_even (i - 1)) + := by + have Heq := is_valid_fix_fixed_eq is_even_is_odd_body_is_valid + simp [is_even, is_odd] + conv => lhs; rw [Heq]; simp; rw [is_even_is_odd_body]; simp + -- Same remark as for `even` + split <;> simp + cases H0 : fix is_even_is_odd_body (Sum.inl (i - 1)) <;> simp + rename_i v + split <;> simp + +end Ex3 + +namespace Ex4 + /- Mutually recursive functions - 2nd encoding -/ + open Primitives FixI + + /- We make the input type and output types dependent on a parameter -/ + @[simp] def tys : List in_out_ty := [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)] + @[simp] def input_ty (i : Fin 2) : Type := (tys.get i).fst + @[simp] def output_ty (i : Fin 2) (x : input_ty i) : Type := + (tys.get i).snd x + + /- The bodies are more natural -/ + def is_even_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := + if i = 0 + then .ret true + else do + let b ← k 1 (i - 1) + .ret b + + def is_odd_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := + if i = 0 + then .ret false + else do + let b ← k 0 (i - 1) + .ret b + + @[simp] def bodies : + Funs (Fin 2) input_ty output_ty + [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)] := + Funs.Cons (is_even_body) (Funs.Cons (is_odd_body) Funs.Nil) + + def body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 2) : + (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k + + theorem body_is_valid : is_valid body := by + -- Split the proof into proofs of validity of the individual bodies + rw [is_valid] + simp only [body] + intro k + apply (Funs.is_valid_p_is_valid_p tys) + simp [Funs.is_valid_p] + (repeat (apply And.intro)) <;> intro x <;> simp at x <;> + simp only [is_even_body, is_odd_body] + -- Prove the validity of the individual bodies + . split <;> simp + apply is_valid_p_bind <;> simp + . split <;> simp + apply is_valid_p_bind <;> simp + + theorem body_fix_eq : fix body = body (fix body) := + is_valid_fix_fixed_eq body_is_valid + + def is_even (i : Int) : Result Bool := fix body 0 i + def is_odd (i : Int) : Result Bool := fix body 1 i + + theorem is_even_eq (i : Int) : is_even i = + (if i = 0 + then .ret true + else do + let b ← is_odd (i - 1) + .ret b) := by + simp [is_even, is_odd]; + conv => lhs; rw [body_fix_eq] + + theorem is_odd_eq (i : Int) : is_odd i = + (if i = 0 + then .ret false + else do + let b ← is_even (i - 1) + .ret b) := by + simp [is_even, is_odd]; + conv => lhs; rw [body_fix_eq] +end Ex4 + +namespace Ex5 + /- Higher-order example -/ + open Primitives Fix + + variable {a b : Type} + + /- An auxiliary function, which doesn't require the fixed-point -/ + def map (f : a → Result b) (ls : List a) : Result (List b) := + match ls with + | [] => .ret [] + | hd :: tl => + do + let hd ← f hd + let tl ← map f tl + .ret (hd :: tl) + + /- The validity theorem for `map`, generic in `f` -/ + theorem map_is_valid + {{f : (a → Result b) → a → Result c}} + (Hfvalid : ∀ k x, is_valid_p k (λ k => f k x)) + (k : (a → Result b) → a → Result b) + (ls : List a) : + is_valid_p k (λ k => map (f k) ls) := by + induction ls <;> simp [map] + apply is_valid_p_bind <;> simp_all + intros + apply is_valid_p_bind <;> simp_all + + /- An example which uses map -/ + inductive Tree (a : Type) := + | leaf (x : a) + | node (tl : List (Tree a)) + + def id_body (k : Tree a → Result (Tree a)) (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map k tl + .ret (.node tl) + + theorem id_body_is_valid : + ∀ k x, is_valid_p k (λ k => @id_body a k x) := by + intro k x + simp only [id_body] + split <;> simp + apply is_valid_p_bind <;> simp [*] + -- We have to show that `map k tl` is valid + apply map_is_valid; + -- Remark: if we don't do the intro, then the last step is expensive: + -- "typeclass inference of Nonempty took 119ms" + intro k x + simp only [is_valid_p_same, is_valid_p_rec] + + def id (t : Tree a) := fix id_body t + + -- The unfolding equation + theorem id_eq (t : Tree a) : + (id t = + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl)) + := by + have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a) + simp [id] + conv => lhs; rw [Heq]; simp; rw [id_body] + +end Ex5 + +namespace Ex6 + /- `list_nth` again, but this time we use FixI -/ + open Primitives FixI + + @[simp] def tys.{u} : List in_out_ty := + [mk_in_out_ty ((a:Type u) × (List a × Int)) (λ ⟨ a, _ ⟩ => a)] + + @[simp] def input_ty (i : Fin 1) := (tys.get i).fst + @[simp] def output_ty (i : Fin 1) (x : input_ty i) := + (tys.get i).snd x + + def list_nth_body.{u} (k : (i:Fin 1) → (x:input_ty i) → Result (output_ty i x)) + (x : (a : Type u) × List a × Int) : Result x.fst := + let ⟨ a, ls, i ⟩ := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else k 0 ⟨ a, tl, i - 1 ⟩ + + @[simp] def bodies : + Funs (Fin 1) input_ty output_ty tys := + Funs.Cons list_nth_body Funs.Nil + + def body (k : (i : Fin 1) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 1) : + (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k + + theorem body_is_valid: is_valid body := by + -- Split the proof into proofs of validity of the individual bodies + rw [is_valid] + simp only [body] + intro k + apply (Funs.is_valid_p_is_valid_p tys) + simp [Funs.is_valid_p] + (repeat (apply And.intro)); intro x; simp at x + simp only [list_nth_body] + -- Prove the validity of the individual bodies + intro k x + simp [list_nth_body] + split <;> simp + split <;> simp + + -- Writing the proof terms explicitly + theorem list_nth_body_is_valid' (k : k_ty (Fin 1) input_ty output_ty) + (x : (a : Type u) × List a × Int) : is_valid_p k (fun k => list_nth_body k x) := + let ⟨ a, ls, i ⟩ := x + match ls with + | [] => is_valid_p_same k (.fail .panic) + | hd :: tl => + is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 ⟨a, tl, i-1⟩) + + theorem body_is_valid' : is_valid body := + fun k => + Funs.is_valid_p_is_valid_p tys k bodies + (And.intro (list_nth_body_is_valid' k) (Funs.is_valid_p_Nil k)) + + def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := + fix body 0 ⟨ a, ls , i ⟩ + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := by + have Heq := is_valid_fix_fixed_eq body_is_valid + simp [list_nth] + conv => lhs; rw [Heq] + + -- Write the proof term explicitly: the generation of the proof term (without tactics) + -- is automatable, and the proof term is actually a lot simpler and smaller when we + -- don't use tactics. + theorem list_nth_eq'.{u} {a : Type u} (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := + -- Use the fixed-point equation + have Heq := is_valid_fix_fixed_eq body_is_valid.{u} + -- Add the index + have Heqi := congr_fun Heq 0 + -- Add the input + have Heqix := congr_fun Heqi { fst := a, snd := (ls, i) } + -- Done + Heqix + +end Ex6 diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean new file mode 100644 index 00000000..f109e847 --- /dev/null +++ b/backends/lean/Base/Diverge/Elab.lean @@ -0,0 +1,1162 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Base.Utils +import Base.Diverge.Base +import Base.Diverge.ElabBase + +namespace Diverge + +/- Automating the generation of the encoding and the proofs so as to use nice + syntactic sugar. -/ + +syntax (name := divergentDef) + declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command + +open Lean Elab Term Meta Primitives Lean.Meta +open Utils + +/- The following was copied from the `wfRecursion` function. -/ + +open WF in + +def mkProd (x y : Expr) : MetaM Expr := + mkAppM ``Prod.mk #[x, y] + +def mkInOutTy (x y : Expr) : MetaM Expr := + mkAppM ``FixI.mk_in_out_ty #[x, y] + +-- Return the `a` in `Return a` +def getResultTy (ty : Expr) : MetaM Expr := + ty.withApp fun f args => do + if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then + throwError "Invalid argument to getResultTy: {ty}" + else + pure (args.get! 0) + +/- Deconstruct a sigma type. + + For instance, deconstructs `(a : Type) × List a` into + `Type` and `λ a => List a`. + -/ +def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do + ty.withApp fun f args => do + if ¬ f.isConstOf ``Sigma ∨ args.size ≠ 2 then + throwError "Invalid argument to getSigmaTypes: {ty}" + else + pure (args.get! 0, args.get! 1) + +/- Generate a Sigma type from a list of *variables* (all the expressions + must be variables). + + Example: + - xl = [(a:Type), (ls:List a), (i:Int)] + + Generates: + `(a:Type) × (ls:List a) × (i:Int)` + + -/ +def mkSigmasType (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" + let ty ← Lean.Meta.inferType x + pure ty + | x :: xl => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" + let alpha ← Lean.Meta.inferType x + let sty ← mkSigmasType xl + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" + let beta ← mkLambdaFVars #[x] sty + trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + +/- Apply a lambda expression to some arguments, simplifying the lambdas -/ +def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do + lambdaTelescopeN e xs.size fun vars body => + -- Create the substitution + let s : HashMap FVarId Expr := HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList) + -- Substitute in the body + pure (body.replace fun e => + match e with + | Expr.fvar fvarId => match s.find? fvarId with + | none => e + | some v => v + | _ => none) + +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + + We need the type argument because as the elements in the tuple are + "concrete", we can't in all generality figure out the type of the tuple. + + Example: + `⟨ True, 3 ⟩ : (x : Bool) × (if x then Int else Unit)` + -/ +def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasVal: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" + pure x + | fst :: xl => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" + -- Deconstruct the type + let (alpha, beta) ← getSigmaTypes ty + -- Compute the "second" field + -- Specialize beta for fst + let nty ← applyLambdaToArgs beta #[fst] + -- Recursive call + let snd ← mkSigmasVal nty xl + -- Put everything together + trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + +def mkAnonymous (s : String) (i : Nat) : Name := + .num (.str .anonymous s) i + +/- Given a list of values `[x0:ty0, ..., xn:ty1]`, where every `xi` might use the previous + `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following + expression: + ``` + fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple + match x with + | (x0, ..., xn) => out + ``` + + The `index` parameter is used for naming purposes: we use it to numerotate the + bound variables that we introduce. + + We use this function to currify functions (the function bodies given to the + fixed-point operator must be unary functions). + + Example: + ======== + - xl = `[a:Type, ls:List a, i:Int]` + - out = `a` + - index = 0 + + generates (getting rid of most of the syntactic sugar): + ``` + λ scrut0 => match scrut0 with + | Sigma.mk x scrut1 => + match scrut1 with + | Sigma.mk ls i => + a + ``` +-/ +partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkSigmasMatch: empyt list of input parameters" + | [x] => do + -- In the example given for the explanations: this is the inner match case + trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" + mkLambdaFVars #[x] out + | fst :: xl => do + -- In the example given for the explanations: this is the outer match case + -- Remark: for the naming purposes, we use the same convention as for the + -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + -- those definitions might help) + -- + -- We want to build the match expression: + -- ``` + -- λ scrut => + -- match scrut with + -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail + -- ``` + trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let snd_ty ← mkSigmasType xl + let beta ← mkLambdaFVars #[fst] snd_ty + let snd ← mkSigmasMatch xl out (index + 1) + let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkSigmasType (fst :: xl) + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do + trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + match out_ty with + | .sort _ | .lit _ | .const .. => + -- The type of the motive doesn't depend on the scrutinee + mkLambdaFVars #[scrut] out_ty + | _ => + -- The type of the motive *may* depend on the scrutinee + -- TODO: make this more efficient (we could change the output type of + -- mkSigmasMatch + mkSigmasMatch (fst :: xl) out_ty + -- The final expression: putting everything together + trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable + let sm ← mkLambdaFVars #[scrut] sm + trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" + pure sm + +/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ +private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := + @Sigma.casesOn (List a) + (fun (_ls : List a) => Int) + (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) + scrut1 + (fun (_ls : List a) (_i : Int) => Primitives.Result a) + +private def list_nth_out_ty_outer (scrut0 : @Sigma (Type) (fun (a:Type) => + @Sigma (List a) (fun (_ls : List a) => Int))) := + @Sigma.casesOn (Type) + (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int)) + (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type) + scrut0 + (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) => + list_nth_out_ty_inner a scrut1) +/- -/ + +-- Return the expression: `Fin n` +-- TODO: use more +def mkFin (n : Nat) : Expr := + mkAppN (.const ``Fin []) #[.lit (.natVal n)] + +-- Return the expression: `i : Fin n` +def mkFinVal (n i : Nat) : MetaM Expr := do + let n_lit : Expr := .lit (.natVal (n - 1)) + let i_lit : Expr := .lit (.natVal i) + -- We could use `trySynthInstance`, but as we know the instance that we are + -- going to use, we can save the lookup + let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] + mkAppOptM ``OfNat.ofNat #[none, none, ofNat] + +/- Generate and declare as individual definitions the bodies for the individual funcions: + - replace the recursive calls with calls to the continutation `k` + - make those bodies take one single dependent tuple as input + + We name the declarations: "[original_name].body". + We return the new declarations. + -/ +def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) + (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : + MetaM (Array Expr) := do + let grSize := preDefs.size + + -- Compute the map from name to (index × input type). + -- Remark: the continuation has an indexed type; we use the index (a finite number of + -- type `Fin`) to control which function we call at the recursive call site. + let nameToInfo : HashMap Name (Nat × Expr) := + let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst)) + HashMap.ofList bl.toList + + trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" + + -- Auxiliary function to explore the function bodies and replace the + -- recursive calls + let visit_e (i : Nat) (e : Expr) : MetaM Expr := do + trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}" + let ne ← do + match e with + | .app .. => do + e.withApp fun f args => do + trace[Diverge.def.genBody] "this is an app: {f} {args}" + -- Check if this is a recursive call + if f.isConst then + let name := f.constName! + match nameToInfo.find? name with + | none => pure e + | some (id, in_ty) => + trace[Diverge.def.genBody] "this is a recursive call" + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Put the arguments in one big dependent tuple + let args ← mkSigmasVal in_ty args.toList + mkAppM' kk_var #[i, args] + else + -- Not a recursive call: do nothing + pure e + | .const name _ => + -- Sanity check: we eliminated all the recursive calls + if (nameToInfo.find? name).isSome then + throwError "mkUnaryBodies: a recursive call was not eliminated" + else pure e + | _ => pure e + trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}" + pure ne + + -- Explore the bodies + preDefs.mapM fun preDef => do + -- Replace the recursive calls + trace[Diverge.def.genBody] "About to replace recursive calls in {preDef.declName}" + let body ← mapVisit visit_e preDef.value + trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" + + -- Currify the function by grouping the arguments into a dependent tuple + -- (over which we match to retrieve the individual arguments). + lambdaTelescope body fun args body => do + let body ← mkSigmasMatch args.toList body 0 + + -- Add the declaration + let value ← mkLambdaFVars #[kk_var] body + let name := preDef.declName.append "body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value -- TODO: change the type + value := value + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) value + 1) + safety := .safe + all := [name] + } + addDecl decl + trace[Diverge.def] "individual body of {preDef.declName}: {body}" + -- Return the constant + let body := Lean.mkConst name (levelParams.map .param) + -- let body ← mkAppM' body #[kk_var] + trace[Diverge.def] "individual body (after decl): {body}" + pure body + +-- Generate a unique function body from the bodies of the mutually recursive group, +-- and add it as a declaration in the context. +-- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. +def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) + (kk_var i_var : Expr) + (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) + (bodies : Array Expr) : MetaM (Expr × Expr) := do + -- Generate the body + let grSize := bodies.size + let finTypeExpr := mkFin grSize + -- TODO: not very clean + let inOutTyType ← do + let (x, y) := inOutTys.get! 0 + inferType (← mkInOutTy x y) + let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := + match inOutTys, bl with + | [], [] => + mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] + | (ity, oty) :: inOutTys, b :: bl => do + -- Retrieving ity and oty - this is not very clean + let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) + let fl ← mkFuns inOutTys bl + mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] + | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" + let bodyFuns ← mkFuns inOutTys bodies.toList + -- Wrap in `get_fun` + let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var] + -- Add the index `i` and the continuation `k` as a variables + let body ← mkLambdaFVars #[kk_var, i_var] body + trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" + -- Add the declaration + let name := grName.append "mut_rec_body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType body + value := body + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) body + 1) + safety := .safe + all := [name] + } + addDecl decl + -- Return the bodies and the constant + pure (bodyFuns, Lean.mkConst name (levelParams.map .param)) + +def isCasesExpr (e : Expr) : MetaM Bool := do + let e := e.getAppFn + if e.isConst then + return isCasesOnRecursor (← getEnv) e.constName + else return false + +structure MatchInfo where + matcherName : Name + matcherLevels : Array Level + params : Array Expr + motive : Expr + scruts : Array Expr + branchesNumParams : Array Nat + branches : Array Expr + +instance : ToMessageData MatchInfo where + -- This is not a very clean formatting, but we don't need more + toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" + +-- Small helper: prove that an expression which doesn't use the continuation `kk` +-- is valid, and return the proof. +def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" + let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] + trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + +mutual + +/- Prove that an expression is valid, and return the proof. + + More precisely, if `e` is an expression which potentially uses the continution + `kk`, return an expression of type: + ``` + is_valid_p k (λ kk => e) + ``` + -/ +partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveValid: {e}" + match e with + | .const _ _ => throwError "Unimplemented" -- Shouldn't get there? + | .bvar _ + | .fvar _ + | .lit _ + | .mvar _ + | .sort _ => throwError "Unreachable" + | .lam .. => throwError "Unimplemented" + | .forallE .. => throwError "Unreachable" -- Shouldn't get there + | .letE .. => do + -- Telescope all the let-bindings (remark: this also telescopes the lambdas) + lambdaLetTelescope e fun xs body => do + -- Note that we don't visit the bound values: there shouldn't be + -- recursive calls, lambda expressions, etc. inside + -- Prove that the body is valid + let isValid ← proveExprIsValid k_var kk_var body + -- Add the let-bindings around. + -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, + -- but because it reduces in the end it doesn't matter. More precisely: + -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. + mkLambdaFVars xs isValid (usedLetOnly := false) + | .mdata _ b => proveExprIsValid k_var kk_var b + | .proj _ _ _ => + -- The projection shouldn't use the continuation + proveNoKExprIsValid k_var e + | .app .. => + e.withApp fun f args => do + -- There are several cases: first, check if this is a match/if + -- Check if the expression is a (dependent) if then else. + -- We treat the if then else expressions differently from the other matches, + -- and have dedicated theorems for them. + let isIte := e.isIte + if isIte || e.isDIte then do + e.withApp fun f args => do + trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" + if args.size ≠ 5 then + throwError "Wrong number of parameters for {f}: {args}" + let cond := args.get! 1 + let dec := args.get! 2 + -- Prove that the branches are valid + let br0 := args.get! 3 + let br1 := args.get! 4 + let proveBranchValid (br : Expr) : MetaM Expr := + if isIte then proveExprIsValid k_var kk_var br + else do + -- There is a lambda + lambdaOne br fun x br => do + let brValid ← proveExprIsValid k_var kk_var br + mkLambdaFVars #[x] brValid + let br0Valid ← proveBranchValid br0 + let br1Valid ← proveBranchValid br1 + let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite + let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] + trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + -- Check if the expression is a match (this case is for when the elaborator + -- introduces auxiliary definitions to hide the match behind syntactic + -- sugar): + else if let some me := ← matchMatcherApp? e then do + trace[Diverge.def.valid] + "matcherApp: + - params: {me.params} + - motive: {me.motive} + - discrs: {me.discrs} + - altNumParams: {me.altNumParams} + - alts: {me.alts} + - remaining: {me.remaining}" + -- matchMatcherApp does all the work for us: we simply need to gather + -- the information and call the auxiliary helper `proveMatchIsValid` + if me.remaining.size ≠ 0 then + throwError "MatcherApp: non empty remaining array: {me.remaining}" + let me : MatchInfo := { + matcherName := me.matcherName + matcherLevels := me.matcherLevels + params := me.params + motive := me.motive + scruts := me.discrs + branchesNumParams := me.altNumParams + branches := me.alts + } + proveMatchIsValid k_var kk_var me + -- Check if the expression is a raw match (this case is for when the expression + -- is a direct call to the primitive `casesOn` function, without syntactic sugar). + -- We have to check this case because functions like `mkSigmasMatch`, which we + -- use to currify function bodies, introduce such raw matches. + else if ← isCasesExpr f then do + trace[Diverge.def.valid] "rawMatch: {e}" + -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + -- + -- The casesOn definition is always of the following shape: + -- - input parameters (implicit parameters) + -- - motive (implicit), -- the motive gives the return type of the match + -- - scrutinee (explicit) + -- - branches (explicit). + -- In particular, we notice that the scrutinee is the first *explicit* + -- parameter - this is how we spot it. + let matcherName := f.constName! + let matcherLevels := f.constLevels!.toArray + -- Find the first explicit parameter: this is the scrutinee + forallTelescope (← inferType f) fun xs _ => do + let rec findFirstExplicit (i : Nat) : MetaM Nat := do + if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" + else + let x := xs.get! i + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + match localDecl.binderInfo with + | .default => pure i + | _ => findFirstExplicit (i + 1) + let scrutIdx ← findFirstExplicit 0 + -- Split the arguments + let params := args.extract 0 (scrutIdx - 1) + let motive := args.get! (scrutIdx - 1) + let scrut := args.get! scrutIdx + let branches := args.extract (scrutIdx + 1) args.size + -- Compute the number of parameters for the branches: for this we use + -- the type of the uninstantiated casesOn constant (we can't just + -- destruct the lambdas in the branch expressions because the result + -- of a match might be a lambda expression). + let branchesNumParams : Array Nat ← do + let env ← getEnv + let decl := env.constants.find! matcherName + let ty := decl.type + forallTelescope ty fun xs _ => do + let xs := xs.extract (scrutIdx + 1) xs.size + xs.mapM fun x => do + let xty ← inferType x + forallTelescope xty fun ys _ => do + pure ys.size + let me : MatchInfo := { + matcherName, + matcherLevels, + params, + motive, + scruts := #[scrut], + branchesNumParams, + branches, + } + proveMatchIsValid k_var kk_var me + -- Check if this is a monadic let-binding + else if f.isConstOf ``Bind.bind then do + trace[Diverge.def.valid] "bind:\n{args}" + -- We simply need to prove that the subexpressions are valid, and call + -- the appropriate lemma. + let x := args.get! 4 + let y := args.get! 5 + -- Prove that the subexpressions are valid + let xValid ← proveExprIsValid k_var kk_var x + trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" + let yValid ← do + -- This is a lambda expression + lambdaOne y fun x y => do + trace[Diverge.def.valid] "bind: y: {y}" + let yValid ← proveExprIsValid k_var kk_var y + trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" + trace[Diverge.def.valid] "bind: yValid: x: {x}" + let yValid ← mkLambdaFVars #[x] yValid + trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" + pure yValid + -- Put everything together + trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" + mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] + -- Check if this is a recursive call, i.e., a call to the continuation `kk` + else if f.isFVarOf kk_var.fvarId! then do + trace[Diverge.def.valid] "rec: args: \n{args}" + if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" + let i_arg := args.get! 0 + let x_arg := args.get! 1 + let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg] + trace[Diverge.def.valid] "rec: result: \n{eIsValid}" + pure eIsValid + else do + -- Remaining case: normal application. + -- It shouldn't use the continuation. + proveNoKExprIsValid k_var e + +-- Prove that a match expression is valid. +partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do + trace[Diverge.def.valid] "proveMatchIsValid: {me}" + -- Prove the validity of the branch expressions + let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do + -- Go inside the lambdas - note that we have to be careful: some of the + -- binders might come from the match, and some of the binders might come + -- from the fact that the expression in the match is a lambda expression: + -- we use the branchesNumParams field for this reason + let numParams := me.branchesNumParams.get! idx + lambdaTelescopeN br numParams fun xs br => do + -- Prove that the branch expression is valid + let brValid ← proveExprIsValid k_var kk_var br + -- Reconstruct the lambda expression + mkLambdaFVars xs brValid + trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" + -- Compute the motive, which has the following shape: + -- ``` + -- λ scrut => is_valid_p k (λ k => match scrut with ...) + -- ^^^^^^^^^^^^^^^^^^^^ + -- this is the original match expression, with the + -- the difference that the scrutinee(s) is a variable + -- ``` + let validMotive : Expr ← do + -- The motive is a function of the scrutinees (i.e., a lambda expression): + -- introduce binders for the scrutinees + let declInfos := me.scruts.mapIdx fun idx scrut => + let name : Name := mkAnonymous "scrut" idx + let ty := λ (_ : Array Expr) => inferType scrut + (name, ty) + withLocalDeclsD declInfos fun scrutVars => do + -- Create a match expression but where the scrutinees have been replaced + -- by variables + let params : Array (Option Expr) := me.params.map some + let motive : Option Expr := some me.motive + let scruts : Array (Option Expr) := scrutVars.map some + let branches : Array (Option Expr) := me.branches.map some + let args := params ++ [motive] ++ scruts ++ branches + let matchE ← mkAppOptM me.matcherName args + -- Wrap in the `is_valid_p` predicate + let matchE ← mkLambdaFVars #[kk_var] matchE + let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] + -- Abstract away the scrutinee variables + mkLambdaFVars scrutVars validMotive + trace[Diverge.def.valid] "valid motive: {validMotive}" + -- Put together + let valid ← do + -- We let Lean infer the parameters + let params : Array (Option Expr) := me.params.map (λ _ => none) + let motive := some validMotive + let scruts := me.scruts.map some + let branches := branchesValid.map some + let args := params ++ [motive] ++ scruts ++ branches + mkAppOptM me.matcherName args + trace[Diverge.def.valid] "proveMatchIsValid:\n{valid}:\n{← inferType valid}" + pure valid + +end + +-- Prove that a single body (in the mutually recursive group) is valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `is_even.body` and `is_odd.body` are valid. +partial def proveSingleBodyIsValid + (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : + MetaM Expr := do + trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}" + -- Lookup the definition (`bodyConst` is a const, we want to retrieve its + -- definition to dive inside) + let name := bodyConst.constName! + let env ← getEnv + let body := (env.constants.find! name).value! + trace[Diverge.def.valid] "body: {body}" + lambdaTelescope body fun xs body => do + assert! xs.size = 2 + let kk_var := xs.get! 0 + let x_var := xs.get! 1 + -- State the type of the theorem to prove + let thmTy ← mkAppM ``FixI.is_valid_p + #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])] + trace[Diverge.def.valid] "thmTy: {thmTy}" + -- Prove that the body is valid + let proof ← proveExprIsValid k_var kk_var body + let proof ← mkLambdaFVars #[k_var, x_var] proof + trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" + -- The target type (we don't have to do this: this is simply a sanity check, + -- and this allows a nicer debugging output) + let thmTy ← do + let body ← mkAppM' bodyConst #[kk_var, x_var] + let body ← mkLambdaFVars #[kk_var] body + let ty ← mkAppM ``FixI.is_valid_p #[k_var, body] + mkForallFVars #[k_var, x_var] ty + trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" + -- Save the theorem + let name := preDef.declName ++ "body_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveSingleBodyIsValid: added thm: {name}" + -- Return the theorem + pure (Expr.const name (preDef.levelParams.map .param)) + +-- Prove that the list of bodies are valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is +-- valid. +partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) + (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do + -- Create the big "and" expression, which groups the validity proof of the individual bodies + let rec mkValidConj (i : Nat) : MetaM Expr := do + if i = bodiesValid.size then + -- We reached the end + mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var] + else do + -- We haven't reached the end: introduce a conjunction + let valid := bodiesValid.get! i + let valid ← mkAppM' valid #[k_var] + mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)] + let andExpr ← mkValidConj 0 + -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation + let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] + mkLambdaFVars #[k_var] isValid + +-- Prove that the mut rec body (i.e., the unary body which groups the bodies +-- of all the functions in the mutually recursive group and on which we will +-- apply the fixed-point operator) is valid. +-- +-- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", +-- which we return. +-- +-- TODO: maybe this function should introduce k_var itself +def proveMutRecIsValid + (grName : Name) (grLvlParams : List Name) + (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) + (k_var : Expr) (preDefs : Array PreDefinition) + (bodies : Array Expr) : MetaM Expr := do + -- First prove that the individual bodies are valid + let bodiesValid ← + bodies.mapIdxM fun idx body => do + let preDef := preDefs.get! idx + trace[Diverge.def.valid] "## Proving that the body {body} is valid" + proveSingleBodyIsValid k_var preDef body + -- Then prove that the mut rec body is valid + trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" + let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid + -- Save the theorem + let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] + let name := grName ++ "mut_rec_body_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := grLvlParams + type := thmTy + value := isValid + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveFunsBodyIsValid: added thm: {name}:\n{thmTy}" + -- Return the theorem + pure (Expr.const name (grLvlParams.map .param)) + +-- Generate the final definions by using the mutual body and the fixed point operator. +-- +-- For instance: +-- ``` +-- def is_even (i : Int) : Result Bool := mut_rec_body 0 i +-- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i +-- ``` +def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : + TermElabM (Array Name) := do + let grSize := preDefs.size + let defs ← preDefs.mapIdxM fun idx preDef => do + lambdaTelescope preDef.value fun xs _ => do + -- Retrieve the input type + let in_ty := (inOutTys.get! idx.val).fst + -- Create the index + let idx ← mkFinVal grSize idx.val + -- Group the inputs into a dependent tuple + let input ← mkSigmasVal in_ty xs.toList + -- Apply the fixed point + let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] + let fixedBody ← mkLambdaFVars xs fixedBody + -- Create the declaration + let name := preDef.declName + let decl := Declaration.defnDecl { + name := name + levelParams := preDef.levelParams + type := preDef.type + value := fixedBody + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) fixedBody + 1) + safety := .safe + all := [name] + } + addDecl decl + pure name + pure defs + +-- Prove the equations that we will use as unfolding theorems +partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr)) + (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do + let grSize := preDefs.size + let proveIdx (i : Nat) : MetaM Unit := do + let preDef := preDefs.get! i + let defName := decls.get! i + -- Retrieve the arguments + lambdaTelescope preDef.value fun xs body => do + trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" + trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" + -- The theorem statement + let thmTy ← do + -- The equation: the declaration gives the lhs, the pre-def gives the rhs + let lhs ← mkAppOptM defName (xs.map some) + let rhs := body + let eq ← mkAppM ``Eq #[lhs, rhs] + mkForallFVars xs eq + trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}" + -- The proof + -- Use the fixed-point equation + let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm] + -- Add the index + let idx ← mkFinVal grSize i + let proof ← mkAppM ``congr_fun #[proof, idx] + -- Add the input argument + let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList + let proof ← mkAppM ``congr_fun #[proof, arg] + -- Abstract the arguments away + let proof ← mkLambdaFVars xs proof + trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" + -- Declare the theorem + let name := preDef.declName ++ "unfold" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + -- Add the unfolding theorem to the equation compiler + eqnsAttribute.add preDef.declName #[name] + trace[Diverge.def.unfold] "proveUnfoldingThms: added thm: {name}:\n{thmTy}" + let rec prove (i : Nat) : MetaM Unit := do + if i = preDefs.size then pure () + else do + proveIdx i + prove (i + 1) + -- + prove 0 + +def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do + let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) + trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) + + -- TODO: what is this? + for preDef in preDefs do + applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation + + -- Retrieve the name of the first definition, that we will use as the namespace + -- for the definitions common to the group + let def0 := preDefs[0]! + let grName := def0.declName + trace[Diverge.def] "group name: {grName}" + + /- # Compute the input/output types of the continuation `k`. -/ + let grLvlParams := def0.levelParams + trace[Diverge.def] "def0 universe levels: {def0.levelParams}" + + -- We first compute the list of pairs: (input type × output type) + let inOutTys : Array (Expr × Expr) ← + preDefs.mapM (fun preDef => do + withRef preDef.ref do -- is the withRef useful? + -- Check the universe parameters - TODO: I'm not sure what the best thing + -- to do is. In practice, all the type parameters should be in Type 0, so + -- we shouldn't have universe issues. + if preDef.levelParams ≠ grLvlParams then + throwError "Non-uniform polymorphism in the universes" + forallTelescope preDef.type (fun in_tys out_ty => do + let in_ty ← liftM (mkSigmasType in_tys.toList) + -- Retrieve the type in the "Result" + let out_ty ← getResultTy out_ty + let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) + pure (in_ty, out_ty) + ) + ) + trace[Diverge.def] "inOutTys: {inOutTys}" + -- Turn the list of input/output type pairs into an expresion + let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) + let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList + + -- From the list of pairs of input/output types, actually compute the + -- type of the continuation `k`. + -- We first introduce the index `i : Fin n` where `n` is the number of + -- functions in the group. + let i_var_ty := mkFin preDefs.size + withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do + let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] + trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" + -- Add an auxiliary definition for `in_out_ty` + let in_out_ty ← do + let value ← mkLambdaFVars #[i_var] in_out_ty + let name := grName.append "in_out_ty" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value + value := value + hints := .abbrev + safety := .safe + all := [name] + } + addDecl decl + -- Return the constant + let in_out_ty := Lean.mkConst name (levelParams.map .param) + mkAppM' in_out_ty #[i_var] + trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" + let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] + trace[Diverge.def] "in_ty: {in_ty}" + withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do + let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] + trace[Diverge.def] "out_ty: {out_ty}" + + -- Introduce the continuation `k` + let in_ty ← mkLambdaFVars #[i_var] in_ty + let out_ty ← mkLambdaFVars #[i_var, input] out_ty + let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] + trace[Diverge.def] "kk_var_ty: {kk_var_ty}" + withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do + trace[Diverge.def] "kk_var: {kk_var}" + + -- Replace the recursive calls in all the function bodies by calls to the + -- continuation `k` and and generate for those bodies declarations + trace[Diverge.def] "# Generating the unary bodies" + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs + trace[Diverge.def] "Unary bodies (after decl): {bodies}" + -- Generate the mutually recursive body + trace[Diverge.def] "# Generating the mut rec body" + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies + trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" + + -- Prove that the mut rec body satisfies the validity criteria required by + -- our fixed-point + let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] + withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do + trace[Diverge.def] "# Proving that the mut rec body is valid" + let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies + + -- Generate the final definitions + trace[Diverge.def] "# Generating the final definitions" + let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs + + -- Prove the unfolding theorems + trace[Diverge.def] "# Proving the unfolding theorems" + proveUnfoldingThms isValidThm inOutTys preDefs decls + + -- Generating code -- TODO + addAndCompilePartialRec preDefs + +-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main +-- This is the only part where we make actual changes and hook into the equation compiler. +-- (I've removed all the well-founded stuff to make it easier to read though.) + +open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDefs + addAndCompilePartial addAsAxioms from Lean.Elab.PreDefinition.Main + +def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do + for preDef in preDefs do + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef + let preDefs ← betaReduceLetRecApps preDefs + let cliques := partitionPreDefs preDefs + let mut hasErrors := false + for preDefs in cliques do + trace[Diverge.elab] "{preDefs.map (·.declName)}" + try + withRef (preDefs[0]!.ref) do + divRecursion preDefs + catch ex => + -- If it failed, we add the functions as partial functions + hasErrors := true + logException ex + let s ← saveState + try + if preDefs.all fun preDef => preDef.kind == DefKind.def || + preDefs.all fun preDef => preDef.kind == DefKind.abbrev then + -- try to add as partial definition + try + addAndCompilePartial preDefs (useSorry := true) + catch _ => + -- Compilation failed try again just as axiom + s.restore + addAsAxioms preDefs + else return () + catch _ => s.restore + +-- The following two functions are copy-pasted from Lean.Elab.MutualDef + +open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues + instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef + +def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do + let scopeLevelNames ← getLevelNames + let headers ← elabHeaders views + let headers ← levelMVarToParamHeaders views headers + let allUserLevelNames := getAllUserLevelNames headers + withFunLocalDecls headers fun funFVars => do + for view in views, funFVar in funFVars do + addLocalVarInfo view.declId funFVar + -- Add fake use site to prevent "unused variable" warning (if the + -- function is actually not recursive, Lean would print this warning). + -- Remark: we could detect this case and encode the function without + -- using the fixed-point. In practice it shouldn't happen however: + -- we define non-recursive functions with the `divergent` keyword + -- only for testing purposes. + addTermInfo' view.declId funFVar + let values ← + try + let values ← elabFunValues headers + Term.synthesizeSyntheticMVarsNoPostponing + values.mapM (instantiateMVars ·) + catch ex => + logException ex + headers.mapM fun header => mkSorry header.type (synthetic := true) + let headers ← headers.mapM instantiateMVarsAtHeader + let letRecsToLift ← getLetRecsToLift + let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift + checkLetRecsToLiftTypes funFVars letRecsToLift + withUsed vars headers values letRecsToLift fun vars => do + let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift + for preDef in preDefs do + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs + let preDefs ← instantiateMVarsAtPreDecls preDefs + let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames + for preDef in preDefs do + trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + checkForHiddenUnivLevels allUserLevelNames preDefs + addPreDefinitions preDefs + +open Command in +def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do + let views ← ds.mapM fun d => do + let `($mods:declModifiers divergent def $id:declId $sig:optDeclSig $val:declVal) := d + | throwUnsupportedSyntax + let modifiers ← elabModifiers mods + let (binders, type) := expandOptDeclSig sig + let deriving? := none + pure { ref := d, kind := DefKind.def, modifiers, + declId := id, binders, type? := type, value := val, deriving? } + runTermElabM fun vars => Term.elabMutualDef vars views + +-- Special command so that we don't fall back to the built-in mutual when we produce an error. +local syntax "_divergent" Parser.Command.mutual : command +elab_rules : command | `(_divergent mutual $decls* end) => Command.elabMutualDef decls + +macro_rules + | `(mutual $decls* end) => do + unless !decls.isEmpty && decls.all (·.1.getKind == ``divergentDef) do + Macro.throwUnsupported + `(command| _divergent mutual $decls* end) + +open private setDeclIdName from Lean.Elab.Declaration +elab_rules : command + | `($mods:declModifiers divergent%$tk def $id:declId $sig:optDeclSig $val:declVal) => do + let (name, _) := expandDeclIdCore id + if (`_root_).isPrefixOf name then throwUnsupportedSyntax + let view := extractMacroScopes name + let .str ns shortName := view.name | throwUnsupportedSyntax + let shortName' := { view with name := shortName }.review + let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end) + if ns matches .anonymous then + Command.elabCommand cmd + else + Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) + +namespace Tests + /- Some examples of partial functions -/ + + divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + + #check list_nth.unfold + + example {a: Type} (ls : List a) : + ∀ (i : Int), + 0 ≤ i → i < ls.length → + ∃ x, list_nth ls i = .ret x := by + induction ls + . intro i hpos h; simp at h; linarith + . rename_i hd tl ih + intro i hpos h + -- We can directly use `rw [list_nth]`! + rw [list_nth]; simp + split <;> simp [*] + . tauto + . -- TODO: we shouldn't have to do that + have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all + simp at h + have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + simp [ih] + tauto + + mutual + divergent def is_even (i : Int) : Result Bool := + if i = 0 then return true else return (← is_odd (i - 1)) + + divergent def is_odd (i : Int) : Result Bool := + if i = 0 then return false else return (← is_even (i - 1)) + end + + #check is_even.unfold + #check is_odd.unfold + + mutual + divergent def foo (i : Int) : Result Nat := + if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 + + divergent def bar (i : Int) : Result Nat := + if i > 20 then foo (i / 20) else .ret 42 + end + + #check foo.unfold + #check bar.unfold + + -- Testing dependent branching and let-bindings + -- TODO: why the linter warning? + divergent def isNonZero (i : Int) : Result Bool := + if _h:i = 0 then return false + else + let b := true + return b + + #check isNonZero.unfold + + -- Testing let-bindings + divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool := + let i0 := ls.length + if i < i0 + then Result.ret True + else Result.ret False + + #check iInBounds.unfold + + divergent def isCons + {a : Type} (ls : List a) : Result Bool := + let ls1 := ls + match ls1 with + | [] => Result.ret False + | _ :: _ => Result.ret True + + #check isCons.unfold + + -- Testing what happens when we use concrete arguments in dependent tuples + divergent def test1 + (_ : Option Bool) (_ : Unit) : + Result Unit + := + test1 Option.none () + + #check test1.unfold + +end Tests + +end Diverge diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean new file mode 100644 index 00000000..fedb1c74 --- /dev/null +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -0,0 +1,15 @@ +import Lean + +namespace Diverge + +open Lean Elab Term Meta + +-- We can't define and use trace classes in the same file +initialize registerTraceClass `Diverge.elab +initialize registerTraceClass `Diverge.def +initialize registerTraceClass `Diverge.def.sigmas +initialize registerTraceClass `Diverge.def.genBody +initialize registerTraceClass `Diverge.def.valid +initialize registerTraceClass `Diverge.def.unfold + +end Diverge |