From 19bde89b84619defc2a822c3bf96bdca9c97eee7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 28 Jun 2023 12:16:10 +0200 Subject: Reorganize backends/lean/Base --- backends/lean/Base/Diverge/Base.lean | 1105 ++++++++++++++++++++++++++++++ backends/lean/Base/Diverge/Elab.lean | 182 +++++ backends/lean/Base/Diverge/ElabBase.lean | 9 + 3 files changed, 1296 insertions(+) create mode 100644 backends/lean/Base/Diverge/Base.lean create mode 100644 backends/lean/Base/Diverge/Elab.lean create mode 100644 backends/lean/Base/Diverge/ElabBase.lean (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean new file mode 100644 index 00000000..0f92e682 --- /dev/null +++ b/backends/lean/Base/Diverge/Base.lean @@ -0,0 +1,1105 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Mathlib.Tactic.Linarith + +/- +TODO: +- we want an easier to use cases: + - keeps in the goal an equation of the shape: `t = case` + - if called on Prop terms, uses Classical.em + Actually, the cases from mathlib seems already quite powerful + (https://leanprover-community.github.io/mathlib_docs/tactics.html#cases) + For instance: cases h : e + Also: cases_matching +- better split tactic +- we need conversions to operate on the head of applications. + Actually, something like this works: + ``` + conv at Hl => + apply congr_fun + simp [fix_fuel_P] + ``` + Maybe we need a rpt ... ; focus? +- simplifier/rewriter have a strange behavior sometimes +-/ + + +/- TODO: this is very useful, but is there more? -/ +set_option profiler true +set_option profiler.threshold 100 + +namespace Diverge + +namespace Primitives +/-! # Copy-pasting from Primitives to make the file self-contained -/ + +inductive Error where + | assertionFailure: Error + | integerOverflow: Error + | divisionByZero: Error + | arrayOutOfBounds: Error + | maximumSizeExceeded: Error + | panic: Error +deriving Repr, BEq + +open Error + +inductive Result (α : Type u) where + | ret (v: α): Result α + | fail (e: Error): Result α + | div +deriving Repr, BEq + +open Result + +def bind (x: Result α) (f: α -> Result β) : Result β := + match x with + | ret v => f v + | fail v => fail v + | div => div + +@[simp] theorem bind_ret (x : α) (f : α → Result β) : bind (.ret x) f = f x := by simp [bind] +@[simp] theorem bind_fail (x : Error) (f : α → Result β) : bind (.fail x) f = .fail x := by simp [bind] +@[simp] theorem bind_div (f : α → Result β) : bind .div f = .div := by simp [bind] + +-- Allows using Result in do-blocks +instance : Bind Result where + bind := bind + +-- Allows using return x in do-blocks +instance : Pure Result where + pure := fun x => ret x + +@[simp] theorem bind_tc_ret (x : α) (f : α → Result β) : + (do let y ← .ret x; f y) = f x := by simp [Bind.bind, bind] + +@[simp] theorem bind_tc_fail (x : Error) (f : α → Result β) : + (do let y ← fail x; f y) = fail x := by simp [Bind.bind, bind] + +@[simp] theorem bind_tc_div (f : α → Result β) : + (do let y ← div; f y) = div := by simp [Bind.bind, bind] + +def div? {α: Type} (r: Result α): Bool := + match r with + | div => true + | ret _ | fail _ => false + +end Primitives + +namespace Fix + + open Primitives + open Result + + variable {a : Type} {b : a → Type} + variable {c d : Type} + + /-! # The least fixed point definition and its properties -/ + + def least_p (p : Nat → Prop) (n : Nat) : Prop := p n ∧ (∀ m, m < n → ¬ p m) + noncomputable def least (p : Nat → Prop) : Nat := + Classical.epsilon (least_p p) + + -- Auxiliary theorem for [least_spec]: if there exists an `n` satisfying `p`, + -- there there exists a least `m` satisfying `p`. + theorem least_spec_aux (p : Nat → Prop) : ∀ (n : Nat), (hn : p n) → ∃ m, least_p p m := by + apply Nat.strongRec' + intros n hi hn + -- Case disjunction on: is n the smallest n satisfying p? + match Classical.em (∀ m, m < n → ¬ p m) with + | .inl hlt => + -- Yes: trivial + exists n + | .inr hlt => + simp at * + let ⟨ m, ⟨ hmlt, hm ⟩ ⟩ := hlt + have hi := hi m hmlt hm + apply hi + + -- The specification of [least]: either `p` is never satisfied, or it is satisfied + -- by `least p` and no `n < least p` satisfies `p`. + theorem least_spec (p : Nat → Prop) : (∀ n, ¬ p n) ∨ (p (least p) ∧ ∀ n, n < least p → ¬ p n) := by + -- Case disjunction on the existence of an `n` which satisfies `p` + match Classical.em (∀ n, ¬ p n) with + | .inl h => + -- There doesn't exist: trivial + apply (Or.inl h) + | .inr h => + -- There exists: we simply use `least_spec_aux` in combination with the property + -- of the epsilon operator + simp at * + let ⟨ n, hn ⟩ := h + apply Or.inr + have hl := least_spec_aux p n hn + have he := Classical.epsilon_spec hl + apply he + + /-! # The fixed point definitions -/ + + def fix_fuel (n : Nat) (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : + Result (b x) := + match n with + | 0 => .div + | n + 1 => + f (fix_fuel n f) x + + @[simp] def fix_fuel_pred (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (x : a) (n : Nat) := + not (div? (fix_fuel n f x)) + + def fix_fuel_P (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (x : a) (n : Nat) : Prop := + fix_fuel_pred f x n + + noncomputable + def fix (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : Result (b x) := + fix_fuel (least (fix_fuel_P f x)) f x + + /-! # The validity property -/ + + -- Monotonicity relation over results + -- TODO: generalize (we should parameterize the definition by a relation over `a`) + def result_rel {a : Type u} (x1 x2 : Result a) : Prop := + match x1 with + | div => True + | fail _ => x2 = x1 + | ret _ => x2 = x1 -- TODO: generalize + + -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) + def karrow_rel (k1 k2 : (x:a) → Result (b x)) : Prop := + ∀ x, result_rel (k1 x) (k2 x) + + -- Monotonicity property for function bodies + def is_mono (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) : Prop := + ∀ {{k1 k2}}, karrow_rel k1 k2 → karrow_rel (f k1) (f k2) + + -- "Continuity" property. + -- We need this, and this looks a lot like continuity. Also see this paper: + -- https://inria.hal.science/file/index/docid/216187/filename/tarski.pdf + -- We define our "continuity" criteria so that it gives us what we need to + -- prove the fixed-point equation, and we can also easily manipulate it. + def is_cont (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) : Prop := + ∀ x, (Hdiv : ∀ n, fix_fuel (.succ n) f x = div) → f (fix f) x = div + + /-! # The proof of the fixed-point equation -/ + theorem fix_fuel_mono {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} + (Hmono : is_mono f) : + ∀ {{n m}}, n ≤ m → karrow_rel (fix_fuel n f) (fix_fuel m f) := by + intros n + induction n + case zero => simp [karrow_rel, fix_fuel, result_rel] + case succ n1 Hi => + intros m Hle x + simp [result_rel] + match m with + | 0 => + exfalso + zify at * + linarith + | Nat.succ m1 => + simp_arith at Hle + simp [fix_fuel] + have Hi := Hi Hle + have Hmono := Hmono Hi x + simp [result_rel] at Hmono + apply Hmono + + @[simp] theorem neg_fix_fuel_P + {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} {x : a} {n : Nat} : + ¬ fix_fuel_P f x n ↔ (fix_fuel n f x = div) := by + simp [fix_fuel_P, div?] + cases fix_fuel n f x <;> simp + + theorem fix_fuel_fix_mono {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} (Hmono : is_mono f) : + ∀ n, karrow_rel (fix_fuel n f) (fix f) := by + intros n x + simp [result_rel] + have Hl := least_spec (fix_fuel_P f x) + simp at Hl + match Hl with + | .inl Hl => simp [*] + | .inr ⟨ Hl, Hn ⟩ => + match Classical.em (fix_fuel n f x = div) with + | .inl Hd => + simp [*] + | .inr Hd => + have Hineq : least (fix_fuel_P f x) ≤ n := by + -- Proof by contradiction + cases Classical.em (least (fix_fuel_P f x) ≤ n) <;> simp [*] + simp at * + rename_i Hineq + have Hn := Hn n Hineq + contradiction + have Hfix : ¬ (fix f x = div) := by + simp [fix] + -- By property of the least upper bound + revert Hd Hl + -- TODO: there is no conversion to select the head of a function! + conv => lhs; apply congr_fun; apply congr_fun; apply congr_fun; simp [fix_fuel_P, div?] + cases fix_fuel (least (fix_fuel_P f x)) f x <;> simp + have Hmono := fix_fuel_mono Hmono Hineq x + simp [result_rel] at Hmono + simp [fix] at * + cases Heq: fix_fuel (least (fix_fuel_P f x)) f x <;> + cases Heq':fix_fuel n f x <;> + simp_all + + theorem fix_fuel_P_least {f : ((x:a) → Result (b x)) → (x:a) → Result (b x)} (Hmono : is_mono f) : + ∀ {{x n}}, fix_fuel_P f x n → fix_fuel_P f x (least (fix_fuel_P f x)) := by + intros x n Hf + have Hfmono := fix_fuel_fix_mono Hmono n x + -- TODO: there is no conversion to select the head of a function! + conv => apply congr_fun; simp [fix_fuel_P] + simp [fix_fuel_P] at Hf + revert Hf Hfmono + simp [div?, result_rel, fix] + cases fix_fuel n f x <;> simp_all + + -- Prove the fixed point equation in the case there exists some fuel for which + -- the execution terminates + theorem fix_fixed_eq_terminates (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (Hmono : is_mono f) + (x : a) (n : Nat) (He : fix_fuel_P f x n) : + fix f x = f (fix f) x := by + have Hl := fix_fuel_P_least Hmono He + -- TODO: better control of simplification + conv at Hl => + apply congr_fun + simp [fix_fuel_P] + -- The least upper bound is > 0 + have ⟨ n, Hsucc ⟩ : ∃ n, least (fix_fuel_P f x) = Nat.succ n := by + revert Hl + simp [div?] + cases least (fix_fuel_P f x) <;> simp [fix_fuel] + simp [Hsucc] at Hl + revert Hl + simp [*, div?, fix, fix_fuel] + -- Use the monotonicity + have Hfixmono := fix_fuel_fix_mono Hmono n + have Hvm := Hmono Hfixmono x + -- Use functional extensionality + simp [result_rel, fix] at Hvm + revert Hvm + split <;> simp [*] <;> intros <;> simp [*] + + theorem fix_fixed_eq_forall {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hmono : is_mono f) (Hcont : is_cont f) : + ∀ x, fix f x = f (fix f) x := by + intros x + -- Case disjunction: is there a fuel such that the execution successfully execute? + match Classical.em (∃ n, fix_fuel_P f x n) with + | .inr He => + -- No fuel: the fixed point evaluates to `div` + --simp [fix] at * + simp at * + conv => lhs; simp [fix] + have Hel := He (Nat.succ (least (fix_fuel_P f x))); simp [*, fix_fuel] at *; clear Hel + -- Use the "continuity" of `f` + have He : ∀ n, fix_fuel (.succ n) f x = div := by intros; simp [*] + have Hcont := Hcont x He + simp [Hcont] + | .inl ⟨ n, He ⟩ => apply fix_fixed_eq_terminates f Hmono x n He + + -- The final fixed point equation + theorem fix_fixed_eq {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hmono : is_mono f) (Hcont : is_cont f) : + fix f = f (fix f) := by + have Heq := fix_fixed_eq_forall Hmono Hcont + have Heq1 : fix f = (λ x => fix f x) := by simp + rw [Heq1] + conv => lhs; ext; simp [Heq] + + /-! # Making the proofs of validity manageable (and automatable) -/ + + -- Monotonicity property for expressions + def is_mono_p (e : ((x:a) → Result (b x)) → Result c) : Prop := + ∀ {{k1 k2}}, karrow_rel k1 k2 → result_rel (e k1) (e k2) + + theorem is_mono_p_same (x : Result c) : + @is_mono_p a b c (λ _ => x) := by + simp [is_mono_p, karrow_rel, result_rel] + split <;> simp + + theorem is_mono_p_rec (x : a) : + @is_mono_p a b (b x) (λ f => f x) := by + simp_all [is_mono_p, karrow_rel, result_rel] + + -- The important lemma about `is_mono_p` + theorem is_mono_p_bind + (g : ((x:a) → Result (b x)) → Result c) + (h : c → ((x:a) → Result (b x)) → Result d) : + is_mono_p g → + (∀ y, is_mono_p (h y)) → + @is_mono_p a b d (λ k => do let y ← g k; h y k) := by + intro hg hh + simp [is_mono_p] + intro fg fh Hrgh + simp [karrow_rel, result_rel] + have hg := hg Hrgh; simp [result_rel] at hg + cases heq0: g fg <;> simp_all + rename_i y _ + have hh := hh y Hrgh; simp [result_rel] at hh + simp_all + + -- Continuity property for expressions - note that we take the continuation + -- as parameter + def is_cont_p (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (e : ((x:a) → Result (b x)) → Result c) : Prop := + (Hc : ∀ n, e (fix_fuel n k) = .div) → + e (fix k) = .div + + theorem is_cont_p_same (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (x : Result c) : + is_cont_p k (λ _ => x) := by + simp [is_cont_p] + + theorem is_cont_p_rec (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : + is_cont_p f (λ f => f x) := by + simp_all [is_cont_p, fix] + + -- The important lemma about `is_cont_p` + theorem is_cont_p_bind + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (Hkmono : is_mono k) + (g : ((x:a) → Result (b x)) → Result c) + (h : c → ((x:a) → Result (b x)) → Result d) : + is_mono_p g → + is_cont_p k g → + (∀ y, is_mono_p (h y)) → + (∀ y, is_cont_p k (h y)) → + is_cont_p k (λ k => do let y ← g k; h y k) := by + intro Hgmono Hgcont Hhmono Hhcont + simp [is_cont_p] + intro Hdiv + -- Case on `g (fix... k)`: is there an n s.t. it terminates? + cases Classical.em (∀ n, g (fix_fuel n k) = .div) <;> rename_i Hn + . -- Case 1: g diverges + have Hgcont := Hgcont Hn + simp_all + . -- Case 2: g doesn't diverge + simp at Hn + let ⟨ n, Hn ⟩ := Hn + have Hdivn := Hdiv n + have Hffmono := fix_fuel_fix_mono Hkmono n + have Hgeq := Hgmono Hffmono + simp [result_rel] at Hgeq + cases Heq: g (fix_fuel n k) <;> rename_i y <;> simp_all + -- Remains the .ret case + -- Use Hdiv to prove that: ∀ n, h y (fix_fuel n f) = div + -- We do this in two steps: first we prove it for m ≥ n + have Hhdiv: ∀ m, h y (fix_fuel m k) = .div := by + have Hhdiv : ∀ m, n ≤ m → h y (fix_fuel m k) = .div := by + -- We use the fact that `g (fix_fuel n f) = .div`, combined with Hdiv + intro m Hle + have Hdivm := Hdiv m + -- Monotonicity of g + have Hffmono := fix_fuel_mono Hkmono Hle + have Hgmono := Hgmono Hffmono + -- We need to clear Hdiv because otherwise simp_all rewrites Hdivm with Hdiv + clear Hdiv + simp_all [result_rel] + intro m + -- TODO: we shouldn't need the excluded middle here because it is decidable + cases Classical.em (n ≤ m) <;> rename_i Hl + . apply Hhdiv; assumption + . simp at Hl + -- Make a case disjunction on `h y (fix_fuel m k)`: if it is not equal + -- to div, use the monotonicity of `h y` + have Hle : m ≤ n := by linarith + have Hffmono := fix_fuel_mono Hkmono Hle + have Hmono := Hhmono y Hffmono + simp [result_rel] at Hmono + cases Heq: h y (fix_fuel m k) <;> simp_all + -- We can now use the continuity hypothesis for h + apply Hhcont; assumption + + -- The validity property for an expression + def is_valid_p (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (e : ((x:a) → Result (b x)) → Result c) : Prop := + is_mono_p e ∧ + (is_mono k → is_cont_p k e) + + @[simp] theorem is_valid_p_same + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : Result c) : + is_valid_p k (λ _ => x) := by + simp [is_valid_p, is_mono_p_same, is_cont_p_same] + + @[simp] theorem is_valid_p_rec + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : + is_valid_p k (λ k => k x) := by + simp_all [is_valid_p, is_mono_p_rec, is_cont_p_rec] + + -- Lean is good at unification: we can write a very general version + -- (in particular, it will manage to figure out `g` and `h` when we + -- apply the lemma) + theorem is_valid_p_bind + {{k : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + {{g : ((x:a) → Result (b x)) → Result c}} + {{h : c → ((x:a) → Result (b x)) → Result d}} + (Hgvalid : is_valid_p k g) + (Hhvalid : ∀ y, is_valid_p k (h y)) : + is_valid_p k (λ k => do let y ← g k; h y k) := by + let ⟨ Hgmono, Hgcont ⟩ := Hgvalid + simp [is_valid_p, forall_and] at Hhvalid + have ⟨ Hhmono, Hhcont ⟩ := Hhvalid + simp [← imp_forall_iff] at Hhcont + simp [is_valid_p]; constructor + . -- Monotonicity + apply is_mono_p_bind <;> assumption + . -- Continuity + intro Hkmono + have Hgcont := Hgcont Hkmono + have Hhcont := Hhcont Hkmono + apply is_cont_p_bind <;> assumption + + def is_valid (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) : Prop := + ∀ k x, is_valid_p k (λ k => f k x) + + theorem is_valid_p_imp_is_valid {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hvalid : is_valid f) : + is_mono f ∧ is_cont f := by + have Hmono : is_mono f := by + intro f h Hr x + have Hmono := Hvalid (λ _ _ => .div) x + have Hmono := Hmono.left + apply Hmono; assumption + have Hcont : is_cont f := by + intro x Hdiv + have Hcont := (Hvalid f x).right Hmono + simp [is_cont_p] at Hcont + apply Hcont + intro n + have Hdiv := Hdiv n + simp [fix_fuel] at Hdiv + simp [*] + simp [*] + + theorem is_valid_fix_fixed_eq {{f : ((x:a) → Result (b x)) → (x:a) → Result (b x)}} + (Hvalid : is_valid f) : + fix f = f (fix f) := by + have ⟨ Hmono, Hcont ⟩ := is_valid_p_imp_is_valid Hvalid + exact fix_fixed_eq Hmono Hcont + +end Fix + +namespace FixI + /- Indexed fixed-point: definitions with indexed types, convenient to use for mutually + recursive definitions. We simply port the definitions and proofs from Fix to a more + specific case. + -/ + open Primitives Fix + + -- The index type + variable {id : Type} + + -- The input/output types + variable {a b : id → Type} + + -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) + def karrow_rel (k1 k2 : (i:id) → a i → Result (b i)) : Prop := + ∀ i x, result_rel (k1 i x) (k2 i x) + + def kk_to_gen (k : (i:id) → a i → Result (b i)) : + (x: (i:id) × a i) → Result (b x.fst) := + λ ⟨ i, x ⟩ => k i x + + def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst)) : + (i:id) → a i → Result (b i) := + λ i x => k ⟨ i, x ⟩ + + def k_to_gen (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : + ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst) := + λ kk => kk_to_gen (k (kk_of_gen kk)) + + def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst)) : + ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i) := + λ kk => kk_of_gen (k (kk_to_gen kk)) + + def e_to_gen (e : ((i:id) → a i → Result (b i)) → Result c) : + ((x: (i:id) × a i) → Result (b x.fst)) → Result c := + λ k => e (kk_of_gen k) + + def is_valid_p (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) + (e : ((i:id) → a i → Result (b i)) → Result c) : Prop := + Fix.is_valid_p (k_to_gen k) (e_to_gen e) + + def is_valid (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : Prop := + ∀ k i x, is_valid_p k (λ k => f k i x) + + noncomputable def fix + (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : + (i:id) → a i → Result (b i) := + kk_of_gen (Fix.fix (k_to_gen f)) + + theorem is_valid_fix_fixed_eq + {{f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}} + (Hvalid : is_valid f) : + fix f = f (fix f) := by + have Hvalid' : Fix.is_valid (k_to_gen f) := by + intro k x + simp only [is_valid, is_valid_p] at Hvalid + let ⟨ i, x ⟩ := x + have Hvalid := Hvalid (k_of_gen k) i x + simp only [k_to_gen, k_of_gen, kk_to_gen, kk_of_gen] at Hvalid + refine Hvalid + have Heq := Fix.is_valid_fix_fixed_eq Hvalid' + simp [fix] + conv => lhs; rw [Heq] + + /- Some utilities to define the mutually recursive functions -/ + + -- TODO: use more + @[simp] def kk_ty (id : Type) (a b : id → Type) := (i:id) → a i → Result (b i) + @[simp] def k_ty (id : Type) (a b : id → Type) := kk_ty id a b → kk_ty id a b + + -- Initially, we had left out the parameters id, a and b. + -- However, by parameterizing Funs with those parameters, we can state + -- and prove lemmas like Funs.is_valid_p_is_valid_p + inductive Funs (id : Type) (a b : id → Type) : + List (Type u) → List (Type u) → Type (u + 1) := + | Nil : Funs id a b [] [] + | Cons {ity oty : Type u} {itys otys : List (Type u)} + (f : kk_ty id a b → ity → Result oty) (tl : Funs id a b itys otys) : + Funs id a b (ity :: itys) (oty :: otys) + + theorem Funs.length_eq {itys otys : List (Type)} (fl : Funs id a b itys otys) : + otys.length = itys.length := + match fl with + | .Nil => by simp + | .Cons f tl => + have h:= Funs.length_eq tl + by simp [h] + + def fin_cast {n m : Nat} (h : m = n) (i : Fin n) : Fin m := + ⟨ i.val, by have h1:= i.isLt; simp_all ⟩ + + @[simp] def Funs.cast_fin {itys otys : List (Type)} + (fl : Funs id a b itys otys) (i : Fin itys.length) : Fin otys.length := + fin_cast (fl.length_eq) i + + def get_fun {itys otys : List (Type)} (fl : Funs id a b itys otys) : + (i : Fin itys.length) → kk_ty id a b → itys.get i → Result (otys.get (fl.cast_fin i)) := + match fl with + | .Nil => λ i => by have h:= i.isLt; simp at h + | @Funs.Cons id a b ity oty itys1 otys1 f tl => + λ i => + if h: i.val = 0 then + Eq.mp (by cases i; simp_all [List.get]) f + else + let j := i.val - 1 + have Hj: j < itys1.length := by + have Hi := i.isLt + simp at Hi + revert Hi + cases Heq: i.val <;> simp_all + simp_arith + let j: Fin itys1.length := ⟨ j, Hj ⟩ + Eq.mp + (by + cases Heq: i; rename_i val isLt; + cases Heq': j; rename_i val' isLt; + cases val <;> simp_all [List.get, fin_cast]) + (get_fun tl j) + + -- TODO: move + theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by + -- Damn, those proofs on natural numbers are hard - I wish Omega was in mathlib4... + simp [Nat.add_one_le_iff] + simp [Nat.lt_iff_le_and_ne] + simp_all + + def for_all_fin_aux {n : Nat} (f : Fin n → Prop) (m : Nat) (h : m ≤ n) : Prop := + if heq: m = n then True + else + f ⟨ m, by simp_all [Nat.lt_iff_le_and_ne] ⟩ ∧ + for_all_fin_aux f (m + 1) (by simp_all [add_one_le_iff_le_ne]) + termination_by for_all_fin_aux n _ m h => n - m + decreasing_by + simp_wf + apply Nat.sub_add_lt_sub <;> simp + simp_all [add_one_le_iff_le_ne] + + def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp) + + theorem for_all_fin_aux_imp_forall {n : Nat} (f : Fin n → Prop) (m : Nat) : + (h : m ≤ n) → + for_all_fin_aux f m h → ∀ i, m ≤ i.val → f i + := by + generalize h: (n - m) = k + revert m + induction k -- TODO: induction h rather? + case zero => + simp_all + intro m h1 h2 + have h: n = m := by + linarith + unfold for_all_fin_aux; simp_all + simp_all + -- There is no i s.t. m ≤ i + intro i h3; cases i; simp_all + linarith + case succ k hi => + simp_all + intro m hk hmn + intro hf i hmi + have hne: m ≠ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + linarith + -- m = i? + if heq: m = i then + -- Yes: simply use the `for_all_fin_aux` hyp + unfold for_all_fin_aux at hf + simp_all + tauto + else + -- No: use the induction hypothesis + have hlt: m < i := by simp_all [Nat.lt_iff_le_and_ne] + have hineq: m + 1 ≤ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + simp [*, Nat.add_one_le_iff] + have heq1: n - (m + 1) = k := by + -- TODO: very annoying arithmetic proof + simp [Nat.sub_eq_iff_eq_add hineq] + have hineq1: m ≤ n := by linarith + simp [Nat.sub_eq_iff_eq_add hineq1] at hk + simp_arith [hk] + have hi := hi (m + 1) heq1 hineq + apply hi <;> simp_all + . unfold for_all_fin_aux at hf + simp_all + . simp_all [add_one_le_iff_le_ne] + + -- TODO: this is not necessary anymore + theorem for_all_fin_imp_forall (n : Nat) (f : Fin n → Prop) : + for_all_fin f → ∀ i, f i + := by + intro Hf i + apply for_all_fin_aux_imp_forall <;> try assumption + simp + + /- Automating the proofs -/ + @[simp] theorem is_valid_p_same + (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (x : Result c) : + is_valid_p k (λ _ => x) := by + simp [is_valid_p, k_to_gen, e_to_gen] + + @[simp] theorem is_valid_p_rec + (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (i : id) (x : a i) : + is_valid_p k (λ k => k i x) := by + simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] + + theorem is_valid_p_bind + {{k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}} + {{g : ((i:id) → a i → Result (b i)) → Result c}} + {{h : c → ((i:id) → a i → Result (b i)) → Result d}} + (Hgvalid : is_valid_p k g) + (Hhvalid : ∀ y, is_valid_p k (h y)) : + is_valid_p k (λ k => do let y ← g k; h y k) := by + apply Fix.is_valid_p_bind + . apply Hgvalid + . apply Hhvalid + + def Funs.is_valid_p + (k : k_ty id a b) + (fl : Funs id a b itys otys) : + Prop := + match fl with + | .Nil => True + | .Cons f fl => (∀ x, FixI.is_valid_p k (λ k => f k x)) ∧ fl.is_valid_p k + + def Funs.is_valid_p_is_valid_p_aux + {k : k_ty id a b} + {itys otys : List Type} + (Heq : List.length otys = List.length itys) + (fl : Funs id a b itys otys) (Hvalid : is_valid_p k fl) : + ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x) := by + -- Prepare the induction + have ⟨ n, Hn ⟩ : { n : Nat // itys.length = n } := ⟨ itys.length, by rfl ⟩ + revert itys otys Heq fl Hvalid + induction n + -- + case zero => + intro itys otys Heq fl Hvalid Hlen; + have Heq: itys = [] := by cases itys <;> simp_all + have Heq: otys = [] := by cases otys <;> simp_all + intro i x + simp_all + have Hi := i.isLt + simp_all + case succ n Hn => + intro itys otys Heq fl Hvalid Hlen i x; + cases fl <;> simp at Hlen i x Heq Hvalid + rename_i ity oty itys otys f fl + have ⟨ Hvf, Hvalid ⟩ := Hvalid + have Hvf1: is_valid_p k fl := by + simp [Hvalid, Funs.is_valid_p] + have Hn := @Hn itys otys (by simp[*]) fl Hvf1 (by simp [*]) + -- Case disjunction on i + match i with + | ⟨ 0, _ ⟩ => + simp at x + simp [get_fun] + apply (Hvf x) + | ⟨ .succ j, HiLt ⟩ => + simp_arith at HiLt + simp at x + let j : Fin (List.length itys) := ⟨ j, by simp_arith [HiLt] ⟩ + have Hn := Hn j x + apply Hn + + def Funs.is_valid_p_is_valid_p + (itys otys : List (Type)) (Heq: otys.length = itys.length := by decide) + (k : k_ty (Fin (List.length itys)) (List.get itys) fun i => List.get otys (fin_cast Heq i)) + (fl : Funs (Fin itys.length) itys.get (λ i => otys.get (fin_cast Heq i)) itys otys) : + fl.is_valid_p k → + ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x) + := by + intro Hvalid + apply is_valid_p_is_valid_p_aux <;> simp [*] + +end FixI + +namespace Ex1 + /- An example of use of the fixed-point -/ + open Primitives Fix + + variable {a : Type} (k : (List a × Int) → Result a) + + def list_nth_body (x : (List a × Int)) : Result a := + let (ls, i) := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else k (tl, i - 1) + + theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by + intro k x + simp [list_nth_body] + split <;> simp + split <;> simp + + noncomputable + def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i) + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := by + have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + simp [list_nth] + conv => lhs; rw [Heq] + +end Ex1 + +namespace Ex2 + /- Same as Ex1, but we make the body of nth non tail-rec (this is mostly + to see what happens when there are let-bindings) -/ + open Primitives Fix + + variable {a : Type} (k : (List a × Int) → Result a) + + def list_nth_body (x : (List a × Int)) : Result a := + let (ls, i) := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else + do + let y ← k (tl, i - 1) + .ret y + + theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by + intro k x + simp [list_nth_body] + split <;> simp + split <;> simp + apply is_valid_p_bind <;> intros <;> simp_all + + noncomputable + def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i) + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + (list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else + do + let y ← list_nth tl (i - 1) + .ret y) + := by + have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + simp [list_nth] + conv => lhs; rw [Heq] + +end Ex2 + +namespace Ex3 + /- Mutually recursive functions - first encoding (see Ex4 for a better encoding) -/ + open Primitives Fix + + /- Because we have mutually recursive functions, we use a sum for the inputs + and the output types: + - inputs: the sum allows to select the function to call in the recursive + calls (and the functions may not have the same input types) + - outputs: this case is degenerate because `even` and `odd` have the same + return type `Bool`, but generally speaking we need a sum type because + the functions in the mutually recursive group may have different + return types. + -/ + variable (k : (Int ⊕ Int) → Result (Bool ⊕ Bool)) + + def is_even_is_odd_body (x : (Int ⊕ Int)) : Result (Bool ⊕ Bool) := + match x with + | .inl i => + -- Body of `is_even` + if i = 0 + then .ret (.inl true) -- We use .inl because this is `is_even` + else + do + let b ← + do + -- Call `odd`: we need to wrap the input value in `.inr`, then + -- extract the output value + let r ← k (.inr (i- 1)) + match r with + | .inl _ => .fail .panic -- Invalid output + | .inr b => .ret b + -- Wrap the return value + .ret (.inl b) + | .inr i => + -- Body of `is_odd` + if i = 0 + then .ret (.inr false) -- We use .inr because this is `is_odd` + else + do + let b ← + do + -- Call `is_even`: we need to wrap the input value in .inr, then + -- extract the output value + let r ← k (.inl (i- 1)) + match r with + | .inl b => .ret b + | .inr _ => .fail .panic -- Invalid output + -- Wrap the return value + .ret (.inr b) + + theorem is_even_is_odd_body_is_valid: + ∀ k x, is_valid_p k (λ k => is_even_is_odd_body k x) := by + intro k x + simp [is_even_is_odd_body] + split <;> simp <;> split <;> simp + apply is_valid_p_bind; simp + intros; split <;> simp + apply is_valid_p_bind; simp + intros; split <;> simp + + noncomputable + def is_even (i : Int): Result Bool := + do + let r ← fix is_even_is_odd_body (.inl i) + match r with + | .inl b => .ret b + | .inr _ => .fail .panic + + noncomputable + def is_odd (i : Int): Result Bool := + do + let r ← fix is_even_is_odd_body (.inr i) + match r with + | .inl _ => .fail .panic + | .inr b => .ret b + + -- The unfolding equation for `is_even` - diverges if `i < 0` + theorem is_even_eq (i : Int) : + is_even i = (if i = 0 then .ret true else is_odd (i - 1)) + := by + have Heq := is_valid_fix_fixed_eq is_even_is_odd_body_is_valid + simp [is_even, is_odd] + conv => lhs; rw [Heq]; simp; rw [is_even_is_odd_body]; simp + -- Very annoying: we need to swap the matches + -- Doing this with rewriting lemmas is hard generally speaking + -- (especially as we may have to generate lemmas for user-defined + -- inductives on the fly). + -- The simplest is to repeatedly split then simplify (we identify + -- the outer match or monadic let-binding, and split on its scrutinee) + split <;> simp + cases H0 : fix is_even_is_odd_body (Sum.inr (i - 1)) <;> simp + rename_i v + split <;> simp + + -- The unfolding equation for `is_odd` - diverges if `i < 0` + theorem is_odd_eq (i : Int) : + is_odd i = (if i = 0 then .ret false else is_even (i - 1)) + := by + have Heq := is_valid_fix_fixed_eq is_even_is_odd_body_is_valid + simp [is_even, is_odd] + conv => lhs; rw [Heq]; simp; rw [is_even_is_odd_body]; simp + -- Same remark as for `even` + split <;> simp + cases H0 : fix is_even_is_odd_body (Sum.inl (i - 1)) <;> simp + rename_i v + split <;> simp + +end Ex3 + +namespace Ex4 + /- Mutually recursive functions - 2nd encoding -/ + open Primitives FixI + + attribute [local simp] List.get + + /- We make the input type and output types dependent on a parameter -/ + @[simp] def input_ty (i : Fin 2) : Type := + [Int, Int].get i + + @[simp] def output_ty (i : Fin 2) : Type := + [Bool, Bool].get i + + /- The continuation -/ + variable (k : (i : Fin 2) → input_ty i → Result (output_ty i)) + + /- The bodies are more natural -/ + def is_even_body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i : Int) : Result Bool := + if i = 0 + then .ret true + else do + let b ← k 1 (i - 1) + .ret b + + def is_odd_body (i : Int) : Result Bool := + if i = 0 + then .ret false + else do + let b ← k 0 (i - 1) + .ret b + + @[simp] def bodies : + Funs (Fin 2) input_ty output_ty [Int, Int] [Bool, Bool] := + Funs.Cons (is_even_body) (Funs.Cons (is_odd_body) Funs.Nil) + + def body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i: Fin 2) : + input_ty i → Result (output_ty i) := get_fun bodies i k + + theorem body_is_valid : is_valid body := by + -- Split the proof into proofs of validity of the individual bodies + rw [is_valid] + simp only [body] + intro k + apply (Funs.is_valid_p_is_valid_p [Int, Int] [Bool, Bool]) + simp [Funs.is_valid_p] + (repeat (apply And.intro)) <;> intro x <;> simp at x <;> + simp only [is_even_body, is_odd_body] + -- Prove the validity of the individual bodies + . split <;> simp + apply is_valid_p_bind <;> simp + . split <;> simp + apply is_valid_p_bind <;> simp + + theorem body_fix_eq : fix body = body (fix body) := + is_valid_fix_fixed_eq body_is_valid + + noncomputable def is_even (i : Int) : Result Bool := fix body 0 i + noncomputable def is_odd (i : Int) : Result Bool := fix body 1 i + + theorem is_even_eq (i : Int) : is_even i = + (if i = 0 + then .ret true + else do + let b ← is_odd (i - 1) + .ret b) := by + simp [is_even, is_odd]; + conv => lhs; rw [body_fix_eq] + + theorem is_odd_eq (i : Int) : is_odd i = + (if i = 0 + then .ret false + else do + let b ← is_even (i - 1) + .ret b) := by + simp [is_even, is_odd]; + conv => lhs; rw [body_fix_eq] + +end Ex4 + +namespace Ex5 + /- Higher-order example -/ + open Primitives Fix + + variable {a b : Type} + + /- An auxiliary function, which doesn't require the fixed-point -/ + def map (f : a → Result b) (ls : List a) : Result (List b) := + match ls with + | [] => .ret [] + | hd :: tl => + do + let hd ← f hd + let tl ← map f tl + .ret (hd :: tl) + + /- The validity theorem for `map`, generic in `f` -/ + theorem map_is_valid + {{f : (a → Result b) → a → Result c}} + (Hfvalid : ∀ k x, is_valid_p k (λ k => f k x)) + (k : (a → Result b) → a → Result b) + (ls : List a) : + is_valid_p k (λ k => map (f k) ls) := by + induction ls <;> simp [map] + apply is_valid_p_bind <;> simp_all + intros + apply is_valid_p_bind <;> simp_all + + /- An example which uses map -/ + inductive Tree (a : Type) := + | leaf (x : a) + | node (tl : List (Tree a)) + + def id_body (k : Tree a → Result (Tree a)) (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map k tl + .ret (.node tl) + + theorem id_body_is_valid : + ∀ k x, is_valid_p k (λ k => @id_body a k x) := by + intro k x + simp only [id_body] + split <;> simp + apply is_valid_p_bind <;> simp [*] + -- We have to show that `map k tl` is valid + apply map_is_valid; + -- Remark: if we don't do the intro, then the last step is expensive: + -- "typeclass inference of Nonempty took 119ms" + intro k x + simp only [is_valid_p_same, is_valid_p_rec] + + noncomputable def id (t : Tree a) := fix id_body t + + -- The unfolding equation + theorem id_eq (t : Tree a) : + (id t = + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl)) + := by + have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a) + simp [id] + conv => lhs; rw [Heq]; simp; rw [id_body] + +end Ex5 diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean new file mode 100644 index 00000000..313c5a79 --- /dev/null +++ b/backends/lean/Base/Diverge/Elab.lean @@ -0,0 +1,182 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Base.Diverge.Base +import Base.Diverge.ElabBase + +namespace Diverge + +/- Automating the generation of the encoding and the proofs so as to use nice + syntactic sugar. -/ + +syntax (name := divergentDef) + declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command + +open Lean Elab Term Meta Primitives + +initialize registerTraceClass `Diverge.divRecursion (inherited := true) + +set_option trace.Diverge.divRecursion true + +/- The following was copied from the `wfRecursion` function. -/ + +open WF in +def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do + let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) + logInfo ("divRecursion: defs: " ++ msg) + + -- CHANGE HERE This function should add definitions with these names/types/values ^^ + -- Temporarily add the predefinitions as axioms + for preDef in preDefs do + addAsAxiom preDef + + -- TODO: what is this? + for preDef in preDefs do + applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation + + -- Process the definitions + addAndCompilePartialRec preDefs + +-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main +-- This is the only part where we make actual changes and hook into the equation compiler. +-- (I've removed all the well-founded stuff to make it easier to read though.) + +open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDefs + addAndCompilePartial addAsAxioms from Lean.Elab.PreDefinition.Main + +def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do + for preDef in preDefs do + trace[Elab.definition.body] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef + let preDefs ← betaReduceLetRecApps preDefs + let cliques := partitionPreDefs preDefs + let mut hasErrors := false + for preDefs in cliques do + trace[Elab.definition.scc] "{preDefs.map (·.declName)}" + try + logInfo "calling divRecursion" + withRef (preDefs[0]!.ref) do + divRecursion preDefs + logInfo "divRecursion succeeded" + catch ex => + -- If it failed, we + logInfo "divRecursion failed" + hasErrors := true + logException ex + let s ← saveState + try + if preDefs.all fun preDef => preDef.kind == DefKind.def || + preDefs.all fun preDef => preDef.kind == DefKind.abbrev then + -- try to add as partial definition + try + addAndCompilePartial preDefs (useSorry := true) + catch _ => + -- Compilation failed try again just as axiom + s.restore + addAsAxioms preDefs + else return () + catch _ => s.restore + +-- The following two functions are copy&pasted from Lean.Elab.MutualDef + +open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues + instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef + +def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do + let scopeLevelNames ← getLevelNames + let headers ← elabHeaders views + let headers ← levelMVarToParamHeaders views headers + let allUserLevelNames := getAllUserLevelNames headers + withFunLocalDecls headers fun funFVars => do + for view in views, funFVar in funFVars do + addLocalVarInfo view.declId funFVar + let values ← + try + let values ← elabFunValues headers + Term.synthesizeSyntheticMVarsNoPostponing + values.mapM (instantiateMVars ·) + catch ex => + logException ex + headers.mapM fun header => mkSorry header.type (synthetic := true) + let headers ← headers.mapM instantiateMVarsAtHeader + let letRecsToLift ← getLetRecsToLift + let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift + checkLetRecsToLiftTypes funFVars letRecsToLift + withUsed vars headers values letRecsToLift fun vars => do + let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift + for preDef in preDefs do + trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs + let preDefs ← instantiateMVarsAtPreDecls preDefs + let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames + for preDef in preDefs do + trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + checkForHiddenUnivLevels allUserLevelNames preDefs + addPreDefinitions preDefs + +open Command in +def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do + let views ← ds.mapM fun d => do + let `($mods:declModifiers divergent def $id:declId $sig:optDeclSig $val:declVal) := d + | throwUnsupportedSyntax + let modifiers ← elabModifiers mods + let (binders, type) := expandOptDeclSig sig + let deriving? := none + pure { ref := d, kind := DefKind.def, modifiers, + declId := id, binders, type? := type, value := val, deriving? } + runTermElabM fun vars => Term.elabMutualDef vars views + +-- Special command so that we don't fall back to the built-in mutual when we produce an error. +local syntax "_divergent" Parser.Command.mutual : command +elab_rules : command | `(_divergent mutual $decls* end) => Command.elabMutualDef decls + +macro_rules + | `(mutual $decls* end) => do + unless !decls.isEmpty && decls.all (·.1.getKind == ``divergentDef) do + Macro.throwUnsupported + `(command| _divergent mutual $decls* end) + +open private setDeclIdName from Lean.Elab.Declaration +elab_rules : command + | `($mods:declModifiers divergent%$tk def $id:declId $sig:optDeclSig $val:declVal) => do + let (name, _) := expandDeclIdCore id + if (`_root_).isPrefixOf name then throwUnsupportedSyntax + let view := extractMacroScopes name + let .str ns shortName := view.name | throwUnsupportedSyntax + let shortName' := { view with name := shortName }.review + let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end) + if ns matches .anonymous then + Command.elabCommand cmd + else + Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) + +mutual + divergent def is_even (i : Int) : Result Bool := + if i = 0 then return true else return (← is_odd (i - 1)) + + divergent def is_odd (i : Int) : Result Bool := + if i = 0 then return false else return (← is_even (i - 1)) +end + +example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by + induction i + unfold is_even + sorry + +divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + +mutual + divergent def foo (i : Int) : Result Nat := + if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 + + divergent def bar (i : Int) : Result Nat := + if i > 20 then foo (i / 20) else .ret 42 +end + +end Diverge diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean new file mode 100644 index 00000000..e693dce2 --- /dev/null +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -0,0 +1,9 @@ +import Lean + +namespace Diverge + +open Lean + +initialize registerTraceClass `Diverge.divRecursion (inherited := true) + +end Diverge -- cgit v1.2.3 From a6de153f3bfda7feb27d16fcdf2131d37f99c7a3 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 11:22:32 +0200 Subject: Start working on Elab.lean --- backends/lean/Base/Diverge/Base.lean | 3 + backends/lean/Base/Diverge/Elab.lean | 138 ++++++++++++++++++++++++++++--- backends/lean/Base/Diverge/ElabBase.lean | 75 ++++++++++++++++- 3 files changed, 203 insertions(+), 13 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 0f92e682..2e60f6e8 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -4,6 +4,9 @@ import Init.Data.List.Basic import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith +-- For debugging +import Base.Diverge.ElabBase + /- TODO: - we want an easier to use cases: diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 313c5a79..22e0039f 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -15,16 +15,53 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives -initialize registerTraceClass `Diverge.divRecursion (inherited := true) - -set_option trace.Diverge.divRecursion true +set_option trace.Diverge.def true /- The following was copied from the `wfRecursion` function. -/ open WF in + + + +-- Replace the recursive calls by a call to the continuation +-- def replace_rec_calls + +#check Lean.Meta.forallTelescope +#check Expr +#check withRef +#check MonadRef.withRef +#check Nat +#check Array +#check Lean.Meta.inferType +#check Nat +#check Int + +#check (0, 1) +#check Prod +#check () +#check Unit +#check Sigma + +-- print_decl is_even_body +#check instOfNatNat +#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... +#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ... +#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat + + +-- TODO: is there already such a utility somewhere? +-- TODO: change to mkSigmas +def mkProds (tys : List Expr) : MetaM Expr := + match tys with + | [] => do return (Expr.const ``PUnit.unit []) + | [ty] => do return ty + | ty :: tys => do + let pty ← mkProds tys + mkAppM ``Prod.mk #[ty, pty] + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) - logInfo ("divRecursion: defs: " ++ msg) + trace[Diverge.def] ("divRecursion: defs: " ++ msg) -- CHANGE HERE This function should add definitions with these names/types/values ^^ -- Temporarily add the predefinitions as axioms @@ -35,6 +72,85 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation + -- Retrieve the name of the first definition, that we will use as the namespace + -- for the definitions common to the group + let def0 := preDefs[0]! + let grName := def0.declName + trace[Diverge.def] "group name: {grName}" + + /- Compute the type of the continuation. + + We do the following + - we make sure all the definitions have the same universe parameters + (we can make this more general later) + - we group all the type parameters together, make sure all the + definitions have the same type parameters, and enforce + a uniform polymorphism (we can also lift this later). + This would require generalizing a bit our indexed fixed point to + make the output type parametric in the input. + - we group all the non-type parameters: we parameterize the continuation + by those + -/ + let grLvlParams := def0.levelParams + trace[Diverge.def] "def0 type: {def0.type}" + + -- Small utility: compute the list of type parameters + let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := + Lean.Meta.forallTelescope ty fun tys out_ty => do + trace[Diverge.def] "types: {tys}" +/- let (_, params) ← StateT.run (do + for x in tys do + let ty ← Lean.Meta.inferType x + match ty with + | .sort _ => do + let st ← StateT.get + StateT.set (ty :: st) + | _ => do break + ) ([] : List Expr) + let params := params.reverse + trace[Diverge.def] " type parameters {params}" + return params -/ + let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) := + match ls with + | x :: tl => do + let ty ← Lean.Meta.inferType x + match ty with + | .sort _ => do + let (ty_params, params) ← get_params tl + return (x :: ty_params, params) + | _ => do return ([], ls) + | _ => do return ([], []) + let (ty_params, params) ← get_params tys.toList + trace[Diverge.def] " parameters: {ty_params}; {params}" + return (ty_params, params, out_ty) + let (grTyParams, _, _) ← do + getTypeParams def0.type + + -- Compute the input types and the output types + let all_tys ← preDefs.mapM fun preDef => do + let (tyParams, params, ret_ty) ← getTypeParams preDef.type + -- TODO: this is not complete, there are more checks to perform + if tyParams.length ≠ grTyParams.length then + throwError "Non-uniform polymorphism" + return (params, ret_ty) + + -- TODO: I think there are issues with the free variables + let (input_tys, output_tys) := List.unzip all_tys.toList + let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) + + trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" + + -- Compute the names set + let names := preDefs.map PreDefinition.declName + let names := HashSet.empty.insertMany names + + -- + for preDef in preDefs do + trace[Diverge.def] "about to explore: {preDef.declName}" + explore_term "" preDef.value + + -- Compute the bodies + -- Process the definitions addAndCompilePartialRec preDefs @@ -47,21 +163,21 @@ open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDe def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do for preDef in preDefs do - trace[Elab.definition.body] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef let preDefs ← betaReduceLetRecApps preDefs let cliques := partitionPreDefs preDefs let mut hasErrors := false for preDefs in cliques do - trace[Elab.definition.scc] "{preDefs.map (·.declName)}" + trace[Diverge.elab] "{preDefs.map (·.declName)}" try - logInfo "calling divRecursion" + trace[Diverge.elab] "calling divRecursion" withRef (preDefs[0]!.ref) do divRecursion preDefs - logInfo "divRecursion succeeded" + trace[Diverge.elab] "divRecursion succeeded" catch ex => -- If it failed, we - logInfo "divRecursion failed" + trace[Diverge.elab] "divRecursion failed" hasErrors := true logException ex let s ← saveState @@ -106,12 +222,12 @@ def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM U withUsed vars headers values letRecsToLift fun vars => do let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift for preDef in preDefs do - trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs let preDefs ← instantiateMVarsAtPreDecls preDefs let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames for preDef in preDefs do - trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" checkForHiddenUnivLevels allUserLevelNames preDefs addPreDefinitions preDefs diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index e693dce2..84b73a30 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -2,8 +2,79 @@ import Lean namespace Diverge -open Lean +open Lean Elab Term Meta -initialize registerTraceClass `Diverge.divRecursion (inherited := true) +initialize registerTraceClass `Diverge.elab (inherited := true) +initialize registerTraceClass `Diverge.def (inherited := true) + +-- TODO: move +-- TODO: small helper +def explore_term (incr : String) (e : Expr) : TermElabM Unit := + match e with + | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return () + | .fvar _ => do logInfo m!"{incr}fvar: {e}"; return () + | .mvar _ => do logInfo m!"{incr}mvar: {e}"; return () + | .sort _ => do logInfo m!"{incr}sort: {e}"; return () + | .const _ _ => do logInfo m!"{incr}const: {e}"; return () + | .app fn arg => do + logInfo m!"{incr}app: {e}" + explore_term (incr ++ " ") fn + explore_term (incr ++ " ") arg + | .lam _bName bTy body _binfo => do + logInfo m!"{incr}lam: {e}" + explore_term (incr ++ " ") bTy + explore_term (incr ++ " ") body + | .forallE _bName bTy body _bInfo => do + logInfo m!"{incr}forallE: {e}" + explore_term (incr ++ " ") bTy + explore_term (incr ++ " ") body + | .letE _dName ty val body _nonDep => do + logInfo m!"{incr}letE: {e}" + explore_term (incr ++ " ") ty + explore_term (incr ++ " ") val + explore_term (incr ++ " ") body + | .lit _ => do logInfo m!"{incr}lit: {e}"; return () + | .mdata _ e => do + logInfo m!"{incr}mdata: {e}" + explore_term (incr ++ " ") e + | .proj _ _ struct => do + logInfo m!"{incr}proj: {e}" + explore_term (incr ++ " ") struct + +def explore_decl (n : Name) : TermElabM Unit := do + logInfo m!"Name: {n}" + let env ← getEnv + let decl := env.constants.find! n + match decl with + | .defnInfo val => + logInfo m!"About to explore defn: {decl.name}" + logInfo m!"# Type:" + explore_term "" val.type + logInfo m!"# Value:" + explore_term "" val.value + | .axiomInfo _ => throwError m!"axiom: {n}" + | .thmInfo _ => throwError m!"thm: {n}" + | .opaqueInfo _ => throwError m!"opaque: {n}" + | .quotInfo _ => throwError m!"quot: {n}" + | .inductInfo _ => throwError m!"induct: {n}" + | .ctorInfo _ => throwError m!"ctor: {n}" + | .recInfo _ => throwError m!"rec: {n}" + +syntax (name := printDecl) "print_decl " ident : command + +open Lean.Elab.Command + +@[command_elab printDecl] def elabPrintDecl : CommandElab := fun stx => do + liftTermElabM do + let id := stx[1] + addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none + let cs ← resolveGlobalConstWithInfos id + explore_decl cs[0]! + +private def test1 : Nat := 0 +private def test2 (x : Nat) : Nat := x + +print_decl test1 +print_decl test2 end Diverge -- cgit v1.2.3 From 0cee49de70bec6d3ec2221b64a532d19ad71e5e0 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 14:51:53 +0200 Subject: Generalize a bit FixI and add an example --- backends/lean/Base/Diverge/Base.lean | 260 ++++++++++++++++++++--------------- 1 file changed, 151 insertions(+), 109 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 2e60f6e8..630c0bf6 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -57,7 +57,7 @@ deriving Repr, BEq open Result -def bind (x: Result α) (f: α -> Result β) : Result β := +def bind {α : Type u} {β : Type v} (x: Result α) (f: α -> Result β) : Result β := match x with | ret v => f v | fail v => fail v @@ -84,7 +84,7 @@ instance : Pure Result where @[simp] theorem bind_tc_div (f : α → Result β) : (do let y ← div; f y) = div := by simp [Bind.bind, bind] -def div? {α: Type} (r: Result α): Bool := +def div? {α: Type u} (r: Result α): Bool := match r with | div => true | ret _ | fail _ => false @@ -96,8 +96,8 @@ namespace Fix open Primitives open Result - variable {a : Type} {b : a → Type} - variable {c d : Type} + variable {a : Type u} {b : a → Type v} + variable {c d : Type w} -- TODO: why do we have to make them both : Type w? /-! # The least fixed point definition and its properties -/ @@ -334,7 +334,8 @@ namespace Fix (h : c → ((x:a) → Result (b x)) → Result d) : is_mono_p g → (∀ y, is_mono_p (h y)) → - @is_mono_p a b d (λ k => do let y ← g k; h y k) := by + @is_mono_p a b d (λ k => @Bind.bind Result _ c d (g k) (fun y => h y k)) := by +-- @is_mono_p a b d (λ k => do let (y : c) ← g k; h y k) := by intro hg hh simp [is_mono_p] intro fg fh Hrgh @@ -494,49 +495,49 @@ namespace FixI open Primitives Fix -- The index type - variable {id : Type} + variable {id : Type u} -- The input/output types - variable {a b : id → Type} + variable {a : id → Type v} {b : (i:id) → a i → Type w} -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) - def karrow_rel (k1 k2 : (i:id) → a i → Result (b i)) : Prop := + def karrow_rel (k1 k2 : (i:id) → (x:a i) → Result (b i x)) : Prop := ∀ i x, result_rel (k1 i x) (k2 i x) - def kk_to_gen (k : (i:id) → a i → Result (b i)) : - (x: (i:id) × a i) → Result (b x.fst) := + def kk_to_gen (k : (i:id) → (x:a i) → Result (b i x)) : + (x: (i:id) × a i) → Result (b x.fst x.snd) := λ ⟨ i, x ⟩ => k i x - def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst)) : - (i:id) → a i → Result (b i) := + def kk_of_gen (k : (x: (i:id) × a i) → Result (b x.fst x.snd)) : + (i:id) → (x:a i) → Result (b i x) := λ i x => k ⟨ i, x ⟩ - def k_to_gen (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : - ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst) := + def k_to_gen (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : + ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd) := λ kk => kk_to_gen (k (kk_of_gen kk)) - def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst)) → (x: (i:id) × a i) → Result (b x.fst)) : - ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i) := + def k_of_gen (k : ((x: (i:id) × a i) → Result (b x.fst x.snd)) → (x: (i:id) × a i) → Result (b x.fst x.snd)) : + ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x) := λ kk => kk_of_gen (k (kk_to_gen kk)) - def e_to_gen (e : ((i:id) → a i → Result (b i)) → Result c) : - ((x: (i:id) × a i) → Result (b x.fst)) → Result c := + def e_to_gen (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) : + ((x: (i:id) × a i) → Result (b x.fst x.snd)) → Result c := λ k => e (kk_of_gen k) - def is_valid_p (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) - (e : ((i:id) → a i → Result (b i)) → Result c) : Prop := + def is_valid_p (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (e : ((i:id) → (x:a i) → Result (b i x)) → Result c) : Prop := Fix.is_valid_p (k_to_gen k) (e_to_gen e) - def is_valid (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : Prop := + def is_valid (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : Prop := ∀ k i x, is_valid_p k (λ k => f k i x) noncomputable def fix - (f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) : - (i:id) → a i → Result (b i) := + (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : + (i:id) → (x:a i) → Result (b i x) := kk_of_gen (Fix.fix (k_to_gen f)) theorem is_valid_fix_fixed_eq - {{f : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}} + {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} (Hvalid : is_valid f) : fix f = f (fix f) := by have Hvalid' : Fix.is_valid (k_to_gen f) := by @@ -553,57 +554,43 @@ namespace FixI /- Some utilities to define the mutually recursive functions -/ -- TODO: use more - @[simp] def kk_ty (id : Type) (a b : id → Type) := (i:id) → a i → Result (b i) - @[simp] def k_ty (id : Type) (a b : id → Type) := kk_ty id a b → kk_ty id a b + @[simp] def kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + (i:id) → (x:a i) → Result (b i x) + @[simp] def k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + kk_ty id a b → kk_ty id a b + + def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) + @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : + in_out_ty := + Sigma.mk in_ty out_ty -- Initially, we had left out the parameters id, a and b. -- However, by parameterizing Funs with those parameters, we can state -- and prove lemmas like Funs.is_valid_p_is_valid_p - inductive Funs (id : Type) (a b : id → Type) : - List (Type u) → List (Type u) → Type (u + 1) := - | Nil : Funs id a b [] [] - | Cons {ity oty : Type u} {itys otys : List (Type u)} - (f : kk_ty id a b → ity → Result oty) (tl : Funs id a b itys otys) : - Funs id a b (ity :: itys) (oty :: otys) - - theorem Funs.length_eq {itys otys : List (Type)} (fl : Funs id a b itys otys) : - otys.length = itys.length := - match fl with - | .Nil => by simp - | .Cons f tl => - have h:= Funs.length_eq tl - by simp [h] - - def fin_cast {n m : Nat} (h : m = n) (i : Fin n) : Fin m := - ⟨ i.val, by have h1:= i.isLt; simp_all ⟩ - - @[simp] def Funs.cast_fin {itys otys : List (Type)} - (fl : Funs id a b itys otys) (i : Fin itys.length) : Fin otys.length := - fin_cast (fl.length_eq) i - - def get_fun {itys otys : List (Type)} (fl : Funs id a b itys otys) : - (i : Fin itys.length) → kk_ty id a b → itys.get i → Result (otys.get (fl.cast_fin i)) := + inductive Funs (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) : + List in_out_ty.{v, w} → Type (max (u + 1) (max (v + 1) (w + 1))) := + | Nil : Funs id a b [] + | Cons {ity : Type v} {oty : ity → Type w} {tys : List in_out_ty} + (f : kk_ty id a b → (x:ity) → Result (oty x)) (tl : Funs id a b tys) : + Funs id a b (⟨ ity, oty ⟩ :: tys) + + def get_fun {tys : List in_out_ty} (fl : Funs id a b tys) : + (i : Fin tys.length) → kk_ty id a b → (x : (tys.get i).fst) → + Result ((tys.get i).snd x) := match fl with | .Nil => λ i => by have h:= i.isLt; simp at h - | @Funs.Cons id a b ity oty itys1 otys1 f tl => - λ i => - if h: i.val = 0 then - Eq.mp (by cases i; simp_all [List.get]) f - else - let j := i.val - 1 - have Hj: j < itys1.length := by - have Hi := i.isLt - simp at Hi - revert Hi - cases Heq: i.val <;> simp_all + | @Funs.Cons id a b ity oty tys1 f tl => + λ ⟨ i, iLt ⟩ => + match i with + | 0 => + Eq.mp (by simp [List.get]) f + | .succ j => + have jLt: j < tys1.length := by + simp at iLt + revert iLt simp_arith - let j: Fin itys1.length := ⟨ j, Hj ⟩ - Eq.mp - (by - cases Heq: i; rename_i val isLt; - cases Heq': j; rename_i val' isLt; - cases val <;> simp_all [List.get, fin_cast]) - (get_fun tl j) + let j: Fin tys1.length := ⟨ j, jLt ⟩ + Eq.mp (by simp) (get_fun tl j) -- TODO: move theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by @@ -683,19 +670,19 @@ namespace FixI /- Automating the proofs -/ @[simp] theorem is_valid_p_same - (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (x : Result c) : + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (x : Result c) : is_valid_p k (λ _ => x) := by simp [is_valid_p, k_to_gen, e_to_gen] @[simp] theorem is_valid_p_rec - (k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)) (i : id) (x : a i) : + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (i : id) (x : a i) : is_valid_p k (λ k => k i x) := by simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] theorem is_valid_p_bind - {{k : ((i:id) → a i → Result (b i)) → (i:id) → a i → Result (b i)}} - {{g : ((i:id) → a i → Result (b i)) → Result c}} - {{h : c → ((i:id) → a i → Result (b i)) → Result d}} + {{k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} + {{g : ((i:id) → (x:a i) → Result (b i x)) → Result c}} + {{h : c → ((i:id) → (x:a i) → Result (b i x)) → Result d}} (Hgvalid : is_valid_p k g) (Hhvalid : ∀ y, is_valid_p k (h y)) : is_valid_p k (λ k => do let y ← g k; h y k) := by @@ -705,7 +692,7 @@ namespace FixI def Funs.is_valid_p (k : k_ty id a b) - (fl : Funs id a b itys otys) : + (fl : Funs id a b tys) : Prop := match fl with | .Nil => True @@ -713,31 +700,29 @@ namespace FixI def Funs.is_valid_p_is_valid_p_aux {k : k_ty id a b} - {itys otys : List Type} - (Heq : List.length otys = List.length itys) - (fl : Funs id a b itys otys) (Hvalid : is_valid_p k fl) : - ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x) := by + {tys : List in_out_ty} + (fl : Funs id a b tys) (Hvalid : is_valid_p k fl) : + ∀ (i : Fin tys.length) (x : (tys.get i).fst), FixI.is_valid_p k (fun k => get_fun fl i k x) := by -- Prepare the induction - have ⟨ n, Hn ⟩ : { n : Nat // itys.length = n } := ⟨ itys.length, by rfl ⟩ - revert itys otys Heq fl Hvalid + have ⟨ n, Hn ⟩ : { n : Nat // tys.length = n } := ⟨ tys.length, by rfl ⟩ + revert tys fl Hvalid induction n -- case zero => - intro itys otys Heq fl Hvalid Hlen; - have Heq: itys = [] := by cases itys <;> simp_all - have Heq: otys = [] := by cases otys <;> simp_all + intro tys fl Hvalid Hlen; + have Heq: tys = [] := by cases tys <;> simp_all intro i x simp_all have Hi := i.isLt simp_all case succ n Hn => - intro itys otys Heq fl Hvalid Hlen i x; - cases fl <;> simp at Hlen i x Heq Hvalid - rename_i ity oty itys otys f fl + intro tys fl Hvalid Hlen i x; + cases fl <;> simp at Hlen i x Hvalid + rename_i ity oty tys f fl have ⟨ Hvf, Hvalid ⟩ := Hvalid have Hvf1: is_valid_p k fl := by simp [Hvalid, Funs.is_valid_p] - have Hn := @Hn itys otys (by simp[*]) fl Hvf1 (by simp [*]) + have Hn := @Hn tys fl Hvf1 (by simp [*]) -- Case disjunction on i match i with | ⟨ 0, _ ⟩ => @@ -747,19 +732,20 @@ namespace FixI | ⟨ .succ j, HiLt ⟩ => simp_arith at HiLt simp at x - let j : Fin (List.length itys) := ⟨ j, by simp_arith [HiLt] ⟩ + let j : Fin (List.length tys) := ⟨ j, by simp_arith [HiLt] ⟩ have Hn := Hn j x apply Hn def Funs.is_valid_p_is_valid_p - (itys otys : List (Type)) (Heq: otys.length = itys.length := by decide) - (k : k_ty (Fin (List.length itys)) (List.get itys) fun i => List.get otys (fin_cast Heq i)) - (fl : Funs (Fin itys.length) itys.get (λ i => otys.get (fin_cast Heq i)) itys otys) : + (tys : List in_out_ty) + (k : k_ty (Fin (List.length tys)) (λ i => (tys.get i).fst) (fun i x => (List.get tys i).snd x)) + (fl : Funs (Fin tys.length) (λ i => (tys.get i).fst) (λ i x => (tys.get i).snd x) tys) : fl.is_valid_p k → - ∀ (i : Fin itys.length) (x : itys.get i), FixI.is_valid_p k (fun k => get_fun fl i k x) + ∀ (i : Fin tys.length) (x : (tys.get i).fst), + FixI.is_valid_p k (fun k => get_fun fl i k x) := by intro Hvalid - apply is_valid_p_is_valid_p_aux <;> simp [*] + apply is_valid_p_is_valid_p_aux; simp [*] end FixI @@ -960,27 +946,21 @@ namespace Ex4 /- Mutually recursive functions - 2nd encoding -/ open Primitives FixI - attribute [local simp] List.get - /- We make the input type and output types dependent on a parameter -/ - @[simp] def input_ty (i : Fin 2) : Type := - [Int, Int].get i - - @[simp] def output_ty (i : Fin 2) : Type := - [Bool, Bool].get i - - /- The continuation -/ - variable (k : (i : Fin 2) → input_ty i → Result (output_ty i)) + @[simp] def tys : List in_out_ty := [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)] + @[simp] def input_ty (i : Fin 2) : Type := (tys.get i).fst + @[simp] def output_ty (i : Fin 2) (x : input_ty i) : Type := + (tys.get i).snd x /- The bodies are more natural -/ - def is_even_body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i : Int) : Result Bool := + def is_even_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := if i = 0 then .ret true else do let b ← k 1 (i - 1) .ret b - def is_odd_body (i : Int) : Result Bool := + def is_odd_body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i : Int) : Result Bool := if i = 0 then .ret false else do @@ -988,18 +968,19 @@ namespace Ex4 .ret b @[simp] def bodies : - Funs (Fin 2) input_ty output_ty [Int, Int] [Bool, Bool] := + Funs (Fin 2) input_ty output_ty + [mk_in_out_ty Int (λ _ => Bool), mk_in_out_ty Int (λ _ => Bool)] := Funs.Cons (is_even_body) (Funs.Cons (is_odd_body) Funs.Nil) - def body (k : (i : Fin 2) → input_ty i → Result (output_ty i)) (i: Fin 2) : - input_ty i → Result (output_ty i) := get_fun bodies i k + def body (k : (i : Fin 2) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 2) : + (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k theorem body_is_valid : is_valid body := by -- Split the proof into proofs of validity of the individual bodies rw [is_valid] simp only [body] intro k - apply (Funs.is_valid_p_is_valid_p [Int, Int] [Bool, Bool]) + apply (Funs.is_valid_p_is_valid_p tys) simp [Funs.is_valid_p] (repeat (apply And.intro)) <;> intro x <;> simp at x <;> simp only [is_even_body, is_odd_body] @@ -1106,3 +1087,64 @@ namespace Ex5 conv => lhs; rw [Heq]; simp; rw [id_body] end Ex5 + +namespace Ex6 + /- `list_nth` again, but this time we use FixI -/ + open Primitives FixI + + @[simp] def tys.{u} : List in_out_ty := + [mk_in_out_ty ((a:Type u) × (List a × Int)) (λ ⟨ a, _ ⟩ => a)] + + @[simp] def input_ty (i : Fin 1) := (tys.get i).fst + @[simp] def output_ty (i : Fin 1) (x : input_ty i) := + (tys.get i).snd x + + def list_nth_body.{u} (k : (i:Fin 1) → (x:input_ty i) → Result (output_ty i x)) + (x : (a : Type u) × List a × Int) : Result x.fst := + let ⟨ a, ls, i ⟩ := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else k 0 ⟨ a, tl, i - 1 ⟩ + + @[simp] def bodies : + Funs (Fin 1) input_ty output_ty tys := + Funs.Cons list_nth_body Funs.Nil + + def body (k : (i : Fin 1) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 1) : + (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k + + theorem list_nth_body_is_valid: is_valid body := by + -- Split the proof into proofs of validity of the individual bodies + rw [is_valid] + simp only [body] + intro k + apply (Funs.is_valid_p_is_valid_p tys) + simp [Funs.is_valid_p] + (repeat (apply And.intro)); intro x; simp at x + simp only [list_nth_body] + -- Prove the validity of the individual bodies + intro k x + simp [list_nth_body] + split <;> simp + split <;> simp + + noncomputable + def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := + fix body 0 ⟨ a, ls , i ⟩ + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := by + have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + simp [list_nth] + conv => lhs; rw [Heq] + +end Ex6 -- cgit v1.2.3 From fdc8693772ecb1978873018c790061854f00a015 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 23:15:20 +0200 Subject: Write function to compute the input/output types --- backends/lean/Base/Diverge/Base.lean | 3 +- backends/lean/Base/Diverge/Elab.lean | 154 ++++++++++++++++++++++++------- backends/lean/Base/Diverge/ElabBase.lean | 1 + 3 files changed, 126 insertions(+), 32 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 630c0bf6..22b59bd0 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -560,6 +560,7 @@ namespace FixI kk_ty id a b → kk_ty id a b def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) + -- TODO: remove? @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : in_out_ty := Sigma.mk in_ty out_ty @@ -1143,7 +1144,7 @@ namespace Ex6 if i = 0 then .ret hd else list_nth tl (i - 1) := by - have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + have Heq := is_valid_fix_fixed_eq list_nth_body_is_valid simp [list_nth] conv => lhs; rw [Heq] diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 22e0039f..116c5d8b 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -13,7 +13,7 @@ namespace Diverge syntax (name := divergentDef) declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command -open Lean Elab Term Meta Primitives +open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true @@ -21,27 +21,9 @@ set_option trace.Diverge.def true open WF in - - -- Replace the recursive calls by a call to the continuation -- def replace_rec_calls -#check Lean.Meta.forallTelescope -#check Expr -#check withRef -#check MonadRef.withRef -#check Nat -#check Array -#check Lean.Meta.inferType -#check Nat -#check Int - -#check (0, 1) -#check Prod -#check () -#check Unit -#check Sigma - -- print_decl is_even_body #check instOfNatNat #check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... @@ -59,6 +41,100 @@ def mkProds (tys : List Expr) : MetaM Expr := let pty ← mkProds tys mkAppM ``Prod.mk #[ty, pty] +/- Generate the input type of a function body, which is a sigma type (i.e., a + dependent tuple) which groups all its inputs. + + Example: + - xl = [(a:Type), (ls:List a), (i:Int)] + + Generates: + `(a:Type) × (ls:List a) × (i:Int)` + + -/ +def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" + return (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" + let ty ← Lean.Meta.inferType x + return ty + | x :: xl => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" + let alpha ← Lean.Meta.inferType x + let sty ← mkSigmasTypesOfTypes xl + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" + let beta ← mkLambdaFVars #[x] sty + trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + +def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index + +/- Generate the out_ty of the body of a function, which from an input (a sigma + type generated by `mkSigmasTypesOfTypes`) gives the output type of the function. + + Example: + - xl = `[a:Type, ls:List a, i:Int]` + - out_ty = `a` + - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables + + Generates: + ``` + match scrut0 with + | Sigma.mk x scrut1 => + match scrut1 with + | Sigma.mk ls i => + a + ``` +-/ +def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkSigmasOutType: empyt list of input parameters" + | [x] => do + -- In the explanations above: inner match case + trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]" + mkLambdaFVars #[x] out_ty + | fst :: xl => do + -- In the explanations above: outer match case + -- Remark: for the naming purposes, we use the same convention as for the + -- fields and parameters in `Sigma.casesOn and `Sigma.mk + trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let snd_ty ← mkSigmasTypesOfTypes xl + let beta ← mkLambdaFVars #[fst] snd_ty + let snd ← mkSigmasOutType xl out_ty (index + 1) + let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) + withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do + let mk ← mkLambdaFVars #[fst] snd + trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})" + let motive ← mkLambdaFVars #[scrut] (← inferType out_ty) + trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + let out ← mkLambdaFVars #[scrut] out + trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}" + return out + +/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/ +private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := + @Sigma.casesOn (List a) + (fun (_ls : List a) => Int) + (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) + scrut1 + (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) + +private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => + @Sigma (List a) (fun (_ls : List a) => Int))) := + @Sigma.casesOn (Type) + (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int)) + (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type) + scrut0 + (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) => + list_nth_out_ty2 a scrut1) +/- -/ + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) @@ -94,7 +170,23 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grLvlParams := def0.levelParams trace[Diverge.def] "def0 type: {def0.type}" - -- Small utility: compute the list of type parameters + -- Compute the list of pairs: (input type × output type) + let inOutTys : Array (Expr × Expr) ← + preDefs.mapM (fun preDef => do + -- Check the universe parameters - TODO: I'm not sure what the best thing + -- to do is. In practice, all the type parameters should be in Type 0, so + -- we shouldn't have universe issues. + if preDef.levelParams ≠ grLvlParams then + throwError "Non-uniform polymorphism in the universes" + forallTelescope preDef.type (fun in_tys out_ty => do + let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) + let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty) + return (in_ty, out_ty) + ) + ) + trace[Diverge.def] "inOutTys: {inOutTys}" + +/- -- Small utility: compute the list of type parameters let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := Lean.Meta.forallTelescope ty fun tys out_ty => do trace[Diverge.def] "types: {tys}" @@ -138,16 +230,16 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let (input_tys, output_tys) := List.unzip all_tys.toList let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) - trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" + trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/ -- Compute the names set let names := preDefs.map PreDefinition.declName let names := HashSet.empty.insertMany names -- - for preDef in preDefs do - trace[Diverge.def] "about to explore: {preDef.declName}" - explore_term "" preDef.value + -- for preDef in preDefs do + -- trace[Diverge.def] "about to explore: {preDef.declName}" + -- explore_term "" preDef.value -- Compute the bodies @@ -267,6 +359,13 @@ elab_rules : command else Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) +divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) @@ -280,13 +379,6 @@ example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ unfold is_even sorry -divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := - match ls with - | [] => .fail .panic - | x :: ls => - if i = 0 then return x - else return (← list_nth ls (i - 1)) - mutual divergent def foo (i : Int) : Result Nat := if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 84b73a30..441b25f0 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -5,6 +5,7 @@ namespace Diverge open Lean Elab Term Meta initialize registerTraceClass `Diverge.elab (inherited := true) +initialize registerTraceClass `Diverge.def.sigmas (inherited := true) initialize registerTraceClass `Diverge.def (inherited := true) -- TODO: move -- cgit v1.2.3 From 1c9331ce92b68b9a83c601212149a6c24591708f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 30 Jun 2023 15:53:39 +0200 Subject: Generate the fixed-point bodies in Elab.lean --- backends/lean/Base/Diverge/Base.lean | 8 +- backends/lean/Base/Diverge/Elab.lean | 451 +++++++++++++++++++++++-------- backends/lean/Base/Diverge/ElabBase.lean | 47 +++- 3 files changed, 391 insertions(+), 115 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 22b59bd0..aa0539ba 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -554,14 +554,14 @@ namespace FixI /- Some utilities to define the mutually recursive functions -/ -- TODO: use more - @[simp] def kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + abbrev kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := (i:id) → (x:a i) → Result (b i x) - @[simp] def k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + abbrev k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := kk_ty id a b → kk_ty id a b - def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) + abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) -- TODO: remove? - @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : + abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : in_out_ty := Sigma.mk in_ty out_ty diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 116c5d8b..f7de7518 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,31 +16,62 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true +-- set_option trace.Diverge.def.sigmas true /- The following was copied from the `wfRecursion` function. -/ open WF in --- Replace the recursive calls by a call to the continuation --- def replace_rec_calls +def mkList (xl : List Expr) (ty : Expr) : MetaM Expr := + match xl with + | [] => + mkAppOptM ``List.nil #[some ty] + | x :: tl => do + let tl ← mkList tl ty + mkAppOptM ``List.cons #[some ty, some x, some tl] --- print_decl is_even_body -#check instOfNatNat -#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... -#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ... -#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat +def mkProd (x y : Expr) : MetaM Expr := + mkAppM ``Prod.mk #[x, y] +def mkInOutTy (x y : Expr) : MetaM Expr := + mkAppM ``FixI.mk_in_out_ty #[x, y] -- TODO: is there already such a utility somewhere? -- TODO: change to mkSigmas def mkProds (tys : List Expr) : MetaM Expr := match tys with - | [] => do return (Expr.const ``PUnit.unit []) - | [ty] => do return ty + | [] => do pure (Expr.const ``PUnit.unit []) + | [ty] => do pure ty | ty :: tys => do let pty ← mkProds tys mkAppM ``Prod.mk #[ty, pty] +-- Return the `a` in `Return a` +def get_result_ty (ty : Expr) : MetaM Expr := + ty.withApp fun f args => do + if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then + throwError "Invalid argument to get_result_ty: {ty}" + else + pure (args.get! 0) + +-- Group a list of expressions into a dependent tuple +def mkSigmas (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmas: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmas: [{x}]" + pure x + | fst :: xl => do + trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let snd ← mkSigmas xl + let snd_ty ← inferType snd + let beta ← mkLambdaFVars #[fst] snd_ty + trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + /- Generate the input type of a function body, which is a sigma type (i.e., a dependent tuple) which groups all its inputs. @@ -55,11 +86,11 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" - return (Expr.const ``PUnit.unit []) + pure (Expr.const ``PUnit.unit []) | [x] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" let ty ← Lean.Meta.inferType x - return ty + pure ty | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x @@ -71,15 +102,26 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index -/- Generate the out_ty of the body of a function, which from an input (a sigma - type generated by `mkSigmasTypesOfTypes`) gives the output type of the function. +/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous + `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following + expression: + ``` + fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple + match x with + | (x0, ..., xn) => out + ``` + + The `index` parameter is used for naming purposes: we use it to numerotate the + bound variables that we introduce. Example: + ======== + More precisely: - xl = `[a:Type, ls:List a, i:Int]` - - out_ty = `a` - - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables + - out = `a` + - index = 0 - Generates: + generates: ``` match scrut0 with | Sigma.mk x scrut1 => @@ -88,36 +130,47 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index a ``` -/ -def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr := +partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := match xl with | [] => do -- This would be unexpected - throwError "mkSigmasOutType: empyt list of input parameters" + throwError "mkSigmasMatch: empyt list of input parameters" | [x] => do -- In the explanations above: inner match case - trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]" - mkLambdaFVars #[x] out_ty + trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" + mkLambdaFVars #[x] out | fst :: xl => do -- In the explanations above: outer match case -- Remark: for the naming purposes, we use the same convention as for the -- fields and parameters in `Sigma.casesOn and `Sigma.mk - trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst let snd_ty ← mkSigmasTypesOfTypes xl let beta ← mkLambdaFVars #[fst] snd_ty - let snd ← mkSigmasOutType xl out_ty (index + 1) + let snd ← mkSigmasMatch xl out (index + 1) let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do let mk ← mkLambdaFVars #[fst] snd - trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})" - let motive ← mkLambdaFVars #[scrut] (← inferType out_ty) - trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" - let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] - let out ← mkLambdaFVars #[scrut] out - trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}" - return out - -/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/ + trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + match out_ty with + | .sort _ | .lit _ | .const .. => + -- The type of the motive doesn't depend on the scrutinee + mkLambdaFVars #[scrut] out_ty + | _ => + -- The type of the motive *may* depend on the scrutinee + -- TODO: make this more efficient (we could change the output type of + -- mkSigmasMatch + mkSigmasMatch (fst :: xl) out_ty + trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + let sm ← mkLambdaFVars #[scrut] sm + trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" + pure sm + +/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) (fun (_ls : List a) => Int) @@ -135,14 +188,199 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => list_nth_out_ty2 a scrut1) /- -/ +-- TODO: move +-- TODO: we can use Array.mapIdx +@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β + | [] => [] + | a::as => f i a :: mapiAux (i+1) f as + +@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f + +#check Array.map +-- Return the expression: `Fin n` +-- TODO: use more +def mkFin (n : Nat) : Expr := + mkAppN (.const ``Fin []) #[.lit (.natVal n)] + +-- Return the expression: `i : Fin n` +def mkFinVal (n i : Nat) : MetaM Expr := do + let n_lit : Expr := .lit (.natVal (n - 1)) + let i_lit : Expr := .lit (.natVal i) + -- We could use `trySynthInstance`, but as we know the instance that we are + -- going to use, we can save the lookup + let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] + mkAppOptM ``OfNat.ofNat #[none, none, ofNat] + +-- TODO: remove? +def mkFinValOld (n i : Nat) : MetaM Expr := do + let finTy := mkFin n + let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)] + match ← trySynthInstance ofNat with + | LOption.some x => + mkAppOptM ``OfNat.ofNat #[none, none, x] + | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} " + +/- Generate and declare as individual definitions the bodies for the individual funcions: + - replace the recursive calls with calls to the continutation `k` + - make those bodies take one single dependent tuple as input + + We name the declarations: "[original_name].body". + We return the new declarations. + -/ +def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) + (preDefs : Array PreDefinition) : + MetaM (Array Expr) := do + let grSize := preDefs.size + + -- Compute the map from name to index - the continuation has an indexed type: + -- we use the index (a finite number of type `Fin`) to control the function + -- we call at the recursive call + let nameToId : HashMap Name Nat := + let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList + HashMap.ofList namesIds + + trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" + + -- Auxiliary function to explore the function bodies and replace the + -- recursive calls + let visit_e (e : Expr) : MetaM Expr := do + trace[Diverge.def.genBody] "visiting expression: {e}" + match e with + | .app .. => do + e.withApp fun f args => do + trace[Diverge.def.genBody] "this is an app: {f} {args}" + -- Check if this is a recursive call + if f.isConst then + let name := f.constName! + match nameToId.find? name with + | none => pure e + | some id => + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Put the arguments in one big dependent tuple + let args ← mkSigmas args.toList + mkAppM' k_var #[i, args] + else + -- Not a recursive call: do nothing + pure e + | .const name _ => + -- Sanity check: we eliminated all the recursive calls + if (nameToId.find? name).isSome then + throwError "mkUnaryBodies: a recursive call was not eliminated" + else pure e + | _ => pure e + + -- Explore the bodies + preDefs.mapM fun preDef => do + -- Replace the recursive calls + let body ← mapVisit visit_e preDef.value + + -- Change the type + lambdaLetTelescope body fun args body => do + let body ← mkSigmasMatch args.toList body 0 + + -- Add the declaration + let value ← mkLambdaFVars #[k_var] body + let name := preDef.declName.append "body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value -- TODO: change the type + value := value + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) value + 1) + safety := .safe + all := [name] + } + addDecl decl + trace[Diverge.def] "individual body of {preDef.declName}: {body}" + -- Return the constant + let body := Lean.mkConst name (levelParams.map .param) + -- let body ← mkAppM' body #[k_var] + trace[Diverge.def] "individual body (after decl): {body}" + pure body + +-- Generate a unique function body from the bodies of the mutually recursive group, +-- and add it as a declaration in the context +def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) + (i_var k_var : Expr) + (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) + (bodies : Array Expr) : MetaM Expr := do + -- Generate the body + let grSize := bodies.size + let finTypeExpr := mkFin grSize + -- TODO: not very clean + let inOutTyType ← do + let (x, y) := inOutTys.get! 0 + inferType (← mkInOutTy x y) + let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := + match inOutTys, bl with + | [], [] => + mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] + | (ity, oty) :: inOutTys, b :: bl => do + -- Retrieving ity and oty - this is not very clean + let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType + let fl ← mkFuns inOutTys bl + mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] + | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length" + let bodyFuns ← mkFuns inOutTys bodies.toList + -- Wrap in `get_fun` + let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var] + -- Add the index `i` and the continuation `k` as a variables + let body ← mkLambdaFVars #[k_var, i_var] body + trace[Diverge.def] "mkDeclareMutualBody: body: {body}" + -- Add the declaration + let name := grName.append "mutrec_body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType body + value := body + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) body + 1) + safety := .safe + all := [name] + } + addDecl decl + -- Return the constant + pure (Lean.mkConst name (levelParams.map .param)) + +-- Generate the final definions by using the mutual body and the fixed point operator. +def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : + TermElabM Unit := do + let grSize := preDefs.size + let _ ← preDefs.mapIdxM fun idx preDef => do + lambdaLetTelescope preDef.value fun xs _ => do + -- Create the index + let idx ← mkFinVal grSize idx.val + -- Group the inputs into a dependent tuple + let input ← mkSigmas xs.toList + -- Apply the fixed point + let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input] + let fixedBody ← mkLambdaFVars xs fixedBody + -- Create the declaration + let name := preDef.declName + let decl := Declaration.defnDecl { + name := name + levelParams := preDef.levelParams + type := preDef.type + value := fixedBody + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) fixedBody + 1) + safety := .safe + all := [name] + } + addDecl decl + pure () + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) -- CHANGE HERE This function should add definitions with these names/types/values ^^ -- Temporarily add the predefinitions as axioms - for preDef in preDefs do - addAsAxiom preDef + -- for preDef in preDefs do + -- addAsAxiom preDef -- TODO: what is this? for preDef in preDefs do @@ -154,25 +392,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grName := def0.declName trace[Diverge.def] "group name: {grName}" - /- Compute the type of the continuation. - - We do the following - - we make sure all the definitions have the same universe parameters - (we can make this more general later) - - we group all the type parameters together, make sure all the - definitions have the same type parameters, and enforce - a uniform polymorphism (we can also lift this later). - This would require generalizing a bit our indexed fixed point to - make the output type parametric in the input. - - we group all the non-type parameters: we parameterize the continuation - by those - -/ + /- # Compute the input/output types of the continuation `k`. -/ let grLvlParams := def0.levelParams - trace[Diverge.def] "def0 type: {def0.type}" + trace[Diverge.def] "def0 universe levels: {def0.levelParams}" - -- Compute the list of pairs: (input type × output type) + -- We first compute the list of pairs: (input type × output type) let inOutTys : Array (Expr × Expr) ← preDefs.mapM (fun preDef => do + withRef preDef.ref do -- is the withRef useful? -- Check the universe parameters - TODO: I'm not sure what the best thing -- to do is. In practice, all the type parameters should be in Type 0, so -- we shouldn't have universe issues. @@ -180,68 +407,74 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) - let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty) - return (in_ty, out_ty) + -- Retrieve the type in the "Result" + let out_ty ← get_result_ty out_ty + let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) + pure (in_ty, out_ty) ) ) trace[Diverge.def] "inOutTys: {inOutTys}" - -/- -- Small utility: compute the list of type parameters - let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := - Lean.Meta.forallTelescope ty fun tys out_ty => do - trace[Diverge.def] "types: {tys}" -/- let (_, params) ← StateT.run (do - for x in tys do - let ty ← Lean.Meta.inferType x - match ty with - | .sort _ => do - let st ← StateT.get - StateT.set (ty :: st) - | _ => do break - ) ([] : List Expr) - let params := params.reverse - trace[Diverge.def] " type parameters {params}" - return params -/ - let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) := - match ls with - | x :: tl => do - let ty ← Lean.Meta.inferType x - match ty with - | .sort _ => do - let (ty_params, params) ← get_params tl - return (x :: ty_params, params) - | _ => do return ([], ls) - | _ => do return ([], []) - let (ty_params, params) ← get_params tys.toList - trace[Diverge.def] " parameters: {ty_params}; {params}" - return (ty_params, params, out_ty) - let (grTyParams, _, _) ← do - getTypeParams def0.type - - -- Compute the input types and the output types - let all_tys ← preDefs.mapM fun preDef => do - let (tyParams, params, ret_ty) ← getTypeParams preDef.type - -- TODO: this is not complete, there are more checks to perform - if tyParams.length ≠ grTyParams.length then - throwError "Non-uniform polymorphism" - return (params, ret_ty) - - -- TODO: I think there are issues with the free variables - let (input_tys, output_tys) := List.unzip all_tys.toList - let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) - - trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/ - - -- Compute the names set - let names := preDefs.map PreDefinition.declName - let names := HashSet.empty.insertMany names - - -- - -- for preDef in preDefs do - -- trace[Diverge.def] "about to explore: {preDef.declName}" - -- explore_term "" preDef.value - - -- Compute the bodies + -- Turn the list of input/output type pairs into an expresion + let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) + let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0)) + + -- From the list of pairs of input/output types, actually compute the + -- type of the continuation `k`. + -- We first introduce the index `i : Fin n` where `n` is the number of + -- functions in the group. + let i_var_ty := mkFin preDefs.size + withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do + let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] + trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" + -- Add an auxiliary definition for `in_out_ty` + let in_out_ty ← do + let value ← mkLambdaFVars #[i_var] in_out_ty + let name := grName.append "in_out_ty" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value + value := value + hints := .abbrev + safety := .safe + all := [name] + } + addDecl decl + -- Return the constant + let in_out_ty := Lean.mkConst name (levelParams.map .param) + mkAppM' in_out_ty #[i_var] + trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" + let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] + trace[Diverge.def] "in_ty: {in_ty}" + withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do + let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] + trace[Diverge.def] "out_ty: {out_ty}" + + -- Introduce the continuation `k` + let in_ty ← mkLambdaFVars #[i_var] in_ty + let out_ty ← mkLambdaFVars #[i_var, input] out_ty + let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] -- + trace[Diverge.def] "k_var_ty: {k_var_ty}" + withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do + trace[Diverge.def] "k_var: {k_var}" + + -- Replace the recursive calls in all the function bodies by calls to the + -- continuation `k` and and generate for those bodies declarations + let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs + -- Generate the mutually recursive body + let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies + trace[Diverge.def] "mut rec body (after decl): {body}" + + -- Prove that the mut rec body satisfies the validity criteria required by + -- our fixed-point + -- TODO + + -- Generate the final definitions + let defs ← mkDeclareFixDefs body preDefs + + -- Prove the unfolding equations + -- TODO -- Process the definitions addAndCompilePartialRec preDefs @@ -366,6 +599,10 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := if i = 0 then return x else return (← list_nth ls (i - 1)) +#print list_nth.in_out_ty +#check list_nth.body +#print list_nth + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 441b25f0..82f79f94 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -4,13 +4,14 @@ namespace Diverge open Lean Elab Term Meta -initialize registerTraceClass `Diverge.elab (inherited := true) -initialize registerTraceClass `Diverge.def.sigmas (inherited := true) -initialize registerTraceClass `Diverge.def (inherited := true) +initialize registerTraceClass `Diverge.elab +initialize registerTraceClass `Diverge.def +initialize registerTraceClass `Diverge.def.sigmas +initialize registerTraceClass `Diverge.def.genBody -- TODO: move -- TODO: small helper -def explore_term (incr : String) (e : Expr) : TermElabM Unit := +def explore_term (incr : String) (e : Expr) : MetaM Unit := match e with | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return () | .fvar _ => do logInfo m!"{incr}fvar: {e}"; return () @@ -78,4 +79,42 @@ private def test2 (x : Nat) : Nat := x print_decl test1 print_decl test2 +-- We adapted this from AbstractNestedProofs.visit +-- A map visitor function for expressions +partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do + let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do + let localInstances ← getLocalInstances + let mut lctx ← getLCtx + for x in xs do + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + let type ← mapVisit k localDecl.type + let localDecl := localDecl.setType type + let localDecl ← match localDecl.value? with + | some value => let value ← mapVisit k value; pure <| localDecl.setValue value + | none => pure localDecl + lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl + withLCtx lctx localInstances k2 + -- TODO: use a cache? (Lean.checkCache) + -- Explore + let e ← k e + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => pure e + | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (mapVisit k)) + | .lam .. => + lambdaLetTelescope e fun xs b => + mapVisitBinders xs do mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .forallE .. => do + forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← mapVisit k b) + | .letE .. => do + lambdaLetTelescope e fun xs b => mapVisitBinders xs do + mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .mdata _ b => return e.updateMData! (← mapVisit k b) + | .proj _ _ b => return e.updateProj! (← mapVisit k b) + end Diverge -- cgit v1.2.3 From 37e5d5501e024869037bf0ea1559229a8be62da7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 16:24:44 +0200 Subject: Generate the proofs of validity in Elab.lean --- backends/lean/Base/Diverge/Base.lean | 76 +++++- backends/lean/Base/Diverge/Elab.lean | 403 ++++++++++++++++++++++++++++--- backends/lean/Base/Diverge/ElabBase.lean | 1 + 3 files changed, 446 insertions(+), 34 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index aa0539ba..89365d25 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -434,6 +434,23 @@ namespace Fix is_valid_p k (λ k => k x) := by simp_all [is_valid_p, is_mono_p_rec, is_cont_p_rec] + theorem is_valid_p_ite + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((x:a) → Result (b x)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (ite cond e1 e2) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (cond : Prop) [h : Decidable cond] + {e1 : cond → ((x:a) → Result (b x)) → Result c} + {e2 : Not cond → ((x:a) → Result (b x)) → Result c} + (he1: ∀ x, is_valid_p k (e1 x)) (he2 : ∀ x, is_valid_p k (e2 x)) : + is_valid_p k (dite cond e1 e2) := by + split <;> simp [*] + -- Lean is good at unification: we can write a very general version -- (in particular, it will manage to figure out `g` and `h` when we -- apply the lemma) @@ -680,6 +697,24 @@ namespace FixI is_valid_p k (λ k => k i x) := by simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] + theorem is_valid_p_ite + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((i:id) → (x:a i) → Result (b i x)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (λ k => ite cond (e1 k) (e2 k)) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (cond : Prop) [h : Decidable cond] + {e1 : ((i:id) → (x:a i) → Result (b i x)) → cond → Result c} + {e2 : ((i:id) → (x:a i) → Result (b i x)) → Not cond → Result c} + (he1: ∀ x, is_valid_p k (λ k => e1 k x)) + (he2 : ∀ x, is_valid_p k (λ k => e2 k x)) : + is_valid_p k (λ k => dite cond (e1 k) (e2 k)) := by + split <;> simp [*] + theorem is_valid_p_bind {{k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} {{g : ((i:id) → (x:a i) → Result (b i x)) → Result c}} @@ -699,6 +734,9 @@ namespace FixI | .Nil => True | .Cons f fl => (∀ x, FixI.is_valid_p k (λ k => f k x)) ∧ fl.is_valid_p k + theorem Funs.is_valid_p_Nil (k : k_ty id a b) : + Funs.is_valid_p k Funs.Nil := by simp [Funs.is_valid_p] + def Funs.is_valid_p_is_valid_p_aux {k : k_ty id a b} {tys : List in_out_ty} @@ -1116,7 +1154,7 @@ namespace Ex6 def body (k : (i : Fin 1) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 1) : (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k - theorem list_nth_body_is_valid: is_valid body := by + theorem body_is_valid: is_valid body := by -- Split the proof into proofs of validity of the individual bodies rw [is_valid] simp only [body] @@ -1131,6 +1169,20 @@ namespace Ex6 split <;> simp split <;> simp + -- Writing the proof terms explicitly + theorem list_nth_body_is_valid' (k : k_ty (Fin 1) input_ty output_ty) + (x : (a : Type u) × List a × Int) : is_valid_p k (fun k => list_nth_body k x) := + let ⟨ a, ls, i ⟩ := x + match ls with + | [] => is_valid_p_same k (.fail .panic) + | hd :: tl => + is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 ⟨a, tl, i-1⟩) + + theorem body_is_valid' : is_valid body := + fun k => + Funs.is_valid_p_is_valid_p tys k bodies + (And.intro (list_nth_body_is_valid' k) (Funs.is_valid_p_Nil k)) + noncomputable def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := fix body 0 ⟨ a, ls , i ⟩ @@ -1144,8 +1196,28 @@ namespace Ex6 if i = 0 then .ret hd else list_nth tl (i - 1) := by - have Heq := is_valid_fix_fixed_eq list_nth_body_is_valid + have Heq := is_valid_fix_fixed_eq body_is_valid simp [list_nth] conv => lhs; rw [Heq] + -- Write the proof term explicitly: the generation of the proof term (without tactics) + -- is automatable, and the proof term is actually a lot simpler and smaller when we + -- don't use tactics. + theorem list_nth_eq'.{u} {a : Type u} (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := + -- Use the fixed-point equation + have Heq := is_valid_fix_fixed_eq body_is_valid.{u} + -- Add the index + have Heqi := congr_fun Heq 0 + -- Add the input + have Heqix := congr_fun Heqi { fst := a, snd := (ls, i) } + -- Done + Heqix + end Ex6 diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index f7de7518..cf40ea8f 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,6 +16,7 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true +set_option trace.Diverge.def.valid true -- set_option trace.Diverge.def.sigmas true /- The following was copied from the `wfRecursion` function. -/ @@ -196,7 +197,6 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => @[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f -#check Array.map -- Return the expression: `Fin n` -- TODO: use more def mkFin (n : Nat) : Expr := @@ -227,7 +227,7 @@ def mkFinValOld (n i : Nat) : MetaM Expr := do We name the declarations: "[original_name].body". We return the new declarations. -/ -def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) +def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size @@ -260,7 +260,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) let i ← mkFinVal grSize id -- Put the arguments in one big dependent tuple let args ← mkSigmas args.toList - mkAppM' k_var #[i, args] + mkAppM' kk_var #[i, args] else -- Not a recursive call: do nothing pure e @@ -281,8 +281,8 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) let body ← mkSigmasMatch args.toList body 0 -- Add the declaration - let value ← mkLambdaFVars #[k_var] body - let name := preDef.declName.append "body" + let value ← mkLambdaFVars #[kk_var] body + let name := preDef.declName.append "sbody" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -297,16 +297,17 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant let body := Lean.mkConst name (levelParams.map .param) - -- let body ← mkAppM' body #[k_var] + -- let body ← mkAppM' body #[kk_var] trace[Diverge.def] "individual body (after decl): {body}" pure body -- Generate a unique function body from the bodies of the mutually recursive group, --- and add it as a declaration in the context -def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) - (i_var k_var : Expr) +-- and add it as a declaration in the context. +-- We return the list of bodies (of type `Funs ...`) and the mutually recursive body. +def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) + (kk_var i_var : Expr) (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) - (bodies : Array Expr) : MetaM Expr := do + (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize @@ -323,15 +324,15 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType let fl ← mkFuns inOutTys bl mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] - | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length" + | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" let bodyFuns ← mkFuns inOutTys bodies.toList -- Wrap in `get_fun` - let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var] + let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables - let body ← mkLambdaFVars #[k_var, i_var] body - trace[Diverge.def] "mkDeclareMutualBody: body: {body}" + let body ← mkLambdaFVars #[kk_var, i_var] body + trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" -- Add the declaration - let name := grName.append "mutrec_body" + let name := grName.append "mut_rec_body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -344,10 +345,348 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) } addDecl decl -- Return the constant - pure (Lean.mkConst name (levelParams.map .param)) + pure (bodyFuns, Lean.mkConst name (levelParams.map .param)) + +def isCasesExpr (e : Expr) : MetaM Bool := do + let e := e.getAppFn + if e.isConst then + return isCasesOnRecursor (← getEnv) e.constName + else return false + +structure MatchInfo where + matcherName : Name + matcherLevels : Array Level + params : Array Expr + motive : Expr + scruts : Array Expr + branchesNumParams : Array Nat + branches : Array Expr + +instance : ToMessageData MatchInfo where + -- This is not a very clean formatting, but we don't need more + toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" + +-- An expression which doesn't use the continuation kk is valid +def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" + let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] + trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + +mutual + +partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveValid: {e}" + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => throwError "Unimplemented" + | .lam .. => throwError "Unimplemented" + | .forallE .. => throwError "Unreachable" -- Shouldn't get there + | .letE .. => throwError "TODO" + -- lambdaLetTelescope e fun xs b => mapVisitBinders xs do + -- mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .mdata _ b => proveExprIsValid k_var kk_var b + | .proj _ _ _ => + -- The projection shouldn't use the continuation + proveNoKExprIsValid k_var e + | .app .. => + e.withApp fun f args => do + -- There are several cases: first, check if this is a match/if + -- The expression is a (dependent) if then else + let isIte := e.isIte + if isIte || e.isDIte then do + e.withApp fun f args => do + trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" + if args.size ≠ 5 then + throwError "Wrong number of parameters for {f}: {args}" + let cond := args.get! 1 + let dec := args.get! 2 + -- Prove that the branches are valid + let br0 := args.get! 3 + let br1 := args.get! 4 + let proveBranchValid (br : Expr) : MetaM Expr := + if isIte then proveExprIsValid k_var kk_var br + else do + -- There is a lambda -- TODO: how do we remove exacly *one* lambda? + lambdaLetTelescope br fun xs br => do + let x := xs.get! 0 + let xs := xs.extract 1 xs.size + let br ← mkLambdaFVars xs br + let brValid ← proveExprIsValid k_var kk_var br + mkLambdaFVars #[x] brValid + let br0Valid ← proveBranchValid br0 + let br1Valid ← proveBranchValid br1 + let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite + let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] + trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + -- The expression is a match (this case is for when the elaborator + -- introduces auxiliary definitions to hide the match behind syntactic + -- sugar) + else if let some me := ← matchMatcherApp? e then do + trace[Diverge.def.valid] + "matcherApp: + - params: {me.params} + - motive: {me.motive} + - discrs: {me.discrs} + - altNumParams: {me.altNumParams} + - alts: {me.alts} + - remaining: {me.remaining}" + -- matchMatcherApp has already done the work for us + if me.remaining.size ≠ 0 then + throwError "MatcherApp: non empty remaining array: {me.remaining}" + let me : MatchInfo := { + matcherName := me.matcherName + matcherLevels := me.matcherLevels + params := me.params + motive := me.motive + scruts := me.discrs + branchesNumParams := me.altNumParams + branches := me.alts + } + proveMatchIsValid k_var kk_var me + -- The expression is a raw match (this case is for when the expression + -- is a direct call to the primitive `casesOn` function, without + -- syntactic sugar) + else if ← isCasesExpr f then do + trace[Diverge.def.valid] "rawMatch: {e}" + -- The casesOn definition is always of the following shape: + -- input parameters (implicit parameters), then motive (implicit), + -- scrutinee (explicit), branches (explicit). + let matcherName := f.constName! + let matcherLevels := f.constLevels!.toArray + -- Find the first explicit parameter: this is the scrutinee + forallTelescope (← inferType f) fun xs _ => do + let rec findFirstExplicit (i : Nat) : MetaM Nat := do + if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" + else + let x := xs.get! i + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + match localDecl.binderInfo with + | .default => pure i + | _ => findFirstExplicit (i + 1) + let scrutIdx ← findFirstExplicit 0 + -- Split the arguments + let params := args.extract 0 (scrutIdx - 1) + let motive := args.get! (scrutIdx - 1) + let scrut := args.get! scrutIdx + let branches := args.extract (scrutIdx + 1) args.size + -- Compute the number of parameters for the branches: for this we use + -- the type of the uninstantiated casesOn constant + let branchesNumParams : Array Nat ← do + let env ← getEnv + let decl := env.constants.find! matcherName + let ty := decl.type + forallTelescope ty fun xs _ => do + let xs := xs.extract (scrutIdx + 1) xs.size + xs.mapM fun x => do + let xty ← inferType x + forallTelescope xty fun ys _ => do + pure ys.size + let me : MatchInfo := { + matcherName, + matcherLevels, + params, + motive, + scruts := #[scrut], + branchesNumParams, + branches, + } + proveMatchIsValid k_var kk_var me + -- Monadic let-binding + else if f.isConstOf ``Bind.bind then do + trace[Diverge.def.valid] "bind:\n{args}" + let x := args.get! 4 + let y := args.get! 5 + -- Prove that the subexpressions are valid + let xValid ← proveExprIsValid k_var kk_var x + trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" + let yValid ← do + -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? + lambdaLetTelescope y fun xs y => do + let x := xs.get! 0 + let xs := xs.extract 1 xs.size + let y ← mkLambdaFVars xs y + trace[Diverge.def.valid] "bind: y: {y}" + let yValid ← proveExprIsValid k_var kk_var y + trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" + trace[Diverge.def.valid] "bind: yValid: x: {x}" + let yValid ← mkLambdaFVars #[x] yValid + trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" + pure yValid + -- Put everything together + trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" + mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] + -- Recursive call + else if f.isFVarOf kk_var.fvarId! then do + trace[Diverge.def.valid] "rec: args: \n{args}" + if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" + let i_arg := args.get! 0 + let x_arg := args.get! 1 + let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg] + trace[Diverge.def.valid] "rec: result: \n{eIsValid}" + pure eIsValid + else do + -- Remaining case: normal application. + -- It shouldn't use the continuation + proveNoKExprIsValid k_var e + +partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do + trace[Diverge.def.valid] "proveMatchIsValid: {me}" + -- Prove the validity of the branch expressions + let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do + -- Go inside the lambdas - note that we have to be careful: some of the + -- binders might come from the match, and some of the binders might come + -- from the fact that the expression in the match is a lambda expression: + -- we use the branchesNumParams field for this reason + lambdaLetTelescope br fun xs br => do + let numParams := me.branchesNumParams.get! idx + let xs_beg := xs.extract 0 numParams + let xs_end := xs.extract numParams xs.size + let br ← mkLambdaFVars xs_end br + -- Prove that the branch expression is valid + let brValid ← proveExprIsValid k_var kk_var br + -- Reconstruct the lambda expression + mkLambdaFVars xs_beg brValid + trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" + -- Put together: compute the motive. + -- It must be of the shape: + -- ``` + -- λ scrut => is_valid_p k (λ k => match scrut with ...) + -- ``` + let validMotive : Expr ← do + -- The motive is a function of the scrutinees (i.e., a lambda expression): + -- introduce binders for the scrutinees + let declInfos := me.scruts.mapIdx fun idx scrut => + let name : Name := (.num (.str .anonymous "scrut") idx) + let ty := λ (_ : Array Expr) => inferType scrut + (name, ty) + withLocalDeclsD declInfos fun scrutVars => do + -- Create a match expression but where the scrutinees have been replaced + -- by variables + let params : Array (Option Expr) := me.params.map some + let motive : Option Expr := some me.motive + let scruts : Array (Option Expr) := scrutVars.map some + let branches : Array (Option Expr) := me.branches.map some + let args := params ++ [motive] ++ scruts ++ branches + let matchE ← mkAppOptM me.matcherName args + -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args) + -- Wrap in the `is_valid_p` predicate + let matchE ← mkLambdaFVars #[kk_var] matchE + let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] + -- Abstract away the scrutinee variables + mkLambdaFVars scrutVars validMotive + trace[Diverge.def.valid] "valid motive: {validMotive}" + -- Put together + let valid ← do + let params : Array (Option Expr) := me.params.map (λ _ => none) + let motive := some validMotive + let scruts := me.scruts.map some + let branches := branchesValid.map some + let args := params ++ [motive] ++ scruts ++ branches + mkAppOptM me.matcherName args + trace[Diverge.def.valid] "proveMatchIsValid:\n{valid}:\n{← inferType valid}" + pure valid + +end + +-- Prove that a single body (in the mutually recursive group) is valid +partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : + MetaM Expr := do + trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}" + -- Lookup the definition (`bodyConst` is the definition of the body, we want + -- to retrieve the value itself to dive inside) + let name := bodyConst.constName! + let env ← getEnv + let body := (env.constants.find! name).value! + trace[Diverge.def.valid] "body: {body}" + lambdaLetTelescope body fun xs body => do + assert! xs.size = 2 + let kk_var := xs.get! 0 + let x_var := xs.get! 1 + -- State the type of the theorem to prove + let thmTy ← mkAppM ``FixI.is_valid_p + #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])] + trace[Diverge.def.valid] "thmTy: {thmTy}" + -- Prove that the body is valid + let proof ← proveExprIsValid k_var kk_var body + let proof ← mkLambdaFVars #[k_var, x_var] proof + trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" + -- The target type (we don't have to do this: this is simply a sanity check, + -- and this allows a nicer debugging output) + let thmTy ← do + let body ← mkAppM' bodyConst #[kk_var, x_var] + let body ← mkLambdaFVars #[kk_var] body + let ty ← mkAppM ``FixI.is_valid_p #[k_var, body] + mkForallFVars #[k_var, x_var] ty + trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" + -- Save the theorem + let name := preDef.declName ++ "sbody_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveSingleBodyIsValid: added thm: {name}" + -- Return the theorem + pure (Expr.const name (preDef.levelParams.map .param)) + +partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) + (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do + -- Create the big "and" expression, which groups the validity proof of the individual bodies + let rec mkValidConj (i : Nat) : MetaM Expr := do + if i = bodiesValid.size then + -- We reached the end + mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var] + else do + -- We haven't reached the end: introduce a conjunction + let valid := bodiesValid.get! i + let valid ← mkAppM' valid #[k_var] + mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)] + let andExpr ← mkValidConj 0 + -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation + let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] + mkLambdaFVars #[k_var] isValid + +-- Prove that the mut rec body is valid +-- TODO: maybe this function should introduce k_var itself +def proveMutRecIsValid + (grName : Name) (grLvlParams : List Name) + (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) + (k_var : Expr) (preDefs : Array PreDefinition) + (bodies : Array Expr) : MetaM Expr := do + -- First prove that the individual bodies are valid + let bodiesValid ← + bodies.mapIdxM fun idx body => do + let preDef := preDefs.get! idx + proveSingleBodyIsValid k_var preDef body + -- Then prove that the mut rec body is valid + let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid + -- Save the theorem + let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] + let name := grName ++ "mut_rec_body_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := grLvlParams + type := thmTy + value := isValid + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveFunsBodyIsValid: added thm: {name}:\n{thmTy}" + -- Return the theorem + pure (Expr.const name (grLvlParams.map .param)) -- Generate the final definions by using the mutual body and the fixed point operator. -def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM Unit := do let grSize := preDefs.size let _ ← preDefs.mapIdxM fun idx preDef => do @@ -357,7 +696,7 @@ def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : -- Group the inputs into a dependent tuple let input ← mkSigmas xs.toList -- Apply the fixed point - let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input] + let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody -- Create the declaration let name := preDef.declName @@ -454,24 +793,26 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Introduce the continuation `k` let in_ty ← mkLambdaFVars #[i_var] in_ty let out_ty ← mkLambdaFVars #[i_var, input] out_ty - let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] -- - trace[Diverge.def] "k_var_ty: {k_var_ty}" - withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do - trace[Diverge.def] "k_var: {k_var}" + let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] + trace[Diverge.def] "kk_var_ty: {kk_var_ty}" + withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do + trace[Diverge.def] "kk_var: {kk_var}" -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations - let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs -- Generate the mutually recursive body - let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies - trace[Diverge.def] "mut rec body (after decl): {body}" + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies + trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point - -- TODO + let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] + withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do + let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions - let defs ← mkDeclareFixDefs body preDefs + let defs ← mkDeclareFixDefs mutRecBody preDefs -- Prove the unfolding equations -- TODO @@ -496,13 +837,10 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC for preDefs in cliques do trace[Diverge.elab] "{preDefs.map (·.declName)}" try - trace[Diverge.elab] "calling divRecursion" withRef (preDefs[0]!.ref) do divRecursion preDefs - trace[Diverge.elab] "divRecursion succeeded" catch ex => - -- If it failed, we - trace[Diverge.elab] "divRecursion failed" + -- If it failed, we add the functions as partial functions hasErrors := true logException ex let s ← saveState @@ -600,7 +938,8 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := else return (← list_nth ls (i - 1)) #print list_nth.in_out_ty -#check list_nth.body +#check list_nth.sbody +#check list_nth.mut_rec_body #print list_nth mutual diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 82f79f94..281dbd6c 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -8,6 +8,7 @@ initialize registerTraceClass `Diverge.elab initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas initialize registerTraceClass `Diverge.def.genBody +initialize registerTraceClass `Diverge.def.valid -- TODO: move -- TODO: small helper -- cgit v1.2.3 From 7ceab6a725e5bd17c05bfd381753e453b15afaf7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 16:46:59 +0200 Subject: Add a missing case in the validity proofs --- backends/lean/Base/Diverge/Elab.lean | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index cf40ea8f..063480a2 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -378,17 +378,22 @@ mutual partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveValid: {e}" match e with + | .const _ _ => throwError "Unimplemented" -- Shouldn't get there? | .bvar _ | .fvar _ - | .mvar _ - | .sort _ | .lit _ - | .const _ _ => throwError "Unimplemented" + | .mvar _ + | .sort _ => throwError "Unreachable" | .lam .. => throwError "Unimplemented" | .forallE .. => throwError "Unreachable" -- Shouldn't get there - | .letE .. => throwError "TODO" - -- lambdaLetTelescope e fun xs b => mapVisitBinders xs do - -- mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .letE dName dTy dValue body _nonDep => do + -- Introduce a local declaration for the let-binding + withLetDecl dName dTy dValue fun decl => do + let isValid ← proveExprIsValid k_var kk_var body + -- Add the let-binding around (rem.: the let-binding should be + -- *inside* the `is_valid_p`, not outside, but because it reduces + -- in the end it doesn't matter) + mkLetFVars #[decl] isValid | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => -- The projection shouldn't use the continuation @@ -963,4 +968,12 @@ mutual if i > 20 then foo (i / 20) else .ret 42 end +-- Testing dependent branching and let-bindings +-- TODO: why the linter warning? +divergent def is_non_zero (i : Int) : Result Bool := + if _h:i = 0 then return false + else + let b := true + return b + end Diverge -- cgit v1.2.3 From 9214484c471ad931924865855687f9a2ffe255dd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 18:02:52 +0200 Subject: Automate the proofs of the unfolding theorems for Diverge --- backends/lean/Base/Diverge/Elab.lean | 107 +++++++++++++++++++++++++------ backends/lean/Base/Diverge/ElabBase.lean | 1 + 2 files changed, 89 insertions(+), 19 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 063480a2..91c51a31 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,8 +16,9 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true -set_option trace.Diverge.def.valid true +-- set_option trace.Diverge.def.valid true -- set_option trace.Diverge.def.sigmas true +set_option trace.Diverge.def.unfold true /- The following was copied from the `wfRecursion` function. -/ @@ -390,9 +391,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- Introduce a local declaration for the let-binding withLetDecl dName dTy dValue fun decl => do let isValid ← proveExprIsValid k_var kk_var body - -- Add the let-binding around (rem.: the let-binding should be - -- *inside* the `is_valid_p`, not outside, but because it reduces - -- in the end it doesn't matter) + -- Add the let-binding around. + -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, + -- but because it reduces in the end it doesn't matter. More precisely: + -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. mkLetFVars #[decl] isValid | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => @@ -692,9 +694,9 @@ def proveMutRecIsValid -- Generate the final definions by using the mutual body and the fixed point operator. def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : - TermElabM Unit := do + TermElabM (Array Name) := do let grSize := preDefs.size - let _ ← preDefs.mapIdxM fun idx preDef => do + let defs ← preDefs.mapIdxM fun idx preDef => do lambdaLetTelescope preDef.value fun xs _ => do -- Create the index let idx ← mkFinVal grSize idx.val @@ -715,7 +717,58 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : all := [name] } addDecl decl - pure () + pure name + pure defs + +-- Prove the equations that we will use as unfolding theorems +partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition) + (decls : Array Name) : MetaM Unit := do + let grSize := preDefs.size + let proveIdx (i : Nat) : MetaM Unit := do + let preDef := preDefs.get! i + let defName := decls.get! i + -- Retrieve the arguments + lambdaLetTelescope preDef.value fun xs body => do + trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" + trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" + -- The theorem statement + let thmTy ← do + -- The equation: the declaration gives the lhs, the pre-def gives the rhs + let lhs ← mkAppOptM defName (xs.map some) + let rhs := body + let eq ← mkAppM ``Eq #[lhs, rhs] + mkForallFVars xs eq + trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}" + -- The proof + -- Use the fixed-point equation + let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm] + -- Add the index + let idx ← mkFinVal grSize i + let proof ← mkAppM ``congr_fun #[proof, idx] + -- Add the input argument + let arg ← mkSigmas xs.toList + let proof ← mkAppM ``congr_fun #[proof, arg] + -- Abstract the arguments away + let proof ← mkLambdaFVars xs proof + trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" + -- Declare the theorem + let name := preDef.declName ++ "unfold" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + trace[Diverge.def.unfold] "proveUnfoldingThms: added thm: {name}:\n{thmTy}" + let rec prove (i : Nat) : MetaM Unit := do + if i = preDefs.size then pure () + else do + proveIdx i + prove (i + 1) + -- + prove 0 def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) @@ -817,12 +870,12 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions - let defs ← mkDeclareFixDefs mutRecBody preDefs + let decls ← mkDeclareFixDefs mutRecBody preDefs - -- Prove the unfolding equations - -- TODO + -- Prove the unfolding theorems + proveUnfoldingThms isValidThm preDefs decls - -- Process the definitions + -- Process the definitions - TODO addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -942,10 +995,23 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := if i = 0 then return x else return (← list_nth ls (i - 1)) -#print list_nth.in_out_ty -#check list_nth.sbody -#check list_nth.mut_rec_body -#print list_nth +example {a: Type} (ls : List a) : + ∀ (i : Int), + 0 ≤ i → i < ls.length → + ∃ x, list_nth ls i = .ret x := by + induction ls + . intro i hpos h; simp at h; linarith + . rename_i hd tl ih + intro i hpos h + rw [list_nth.unfold]; simp + split <;> simp [*] + . tauto + . -- TODO: we shouldn't have to do that + have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all + simp at h + have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + simp [ih] + tauto mutual divergent def is_even (i : Int) : Result Bool := @@ -955,10 +1021,8 @@ mutual if i = 0 then return false else return (← is_even (i - 1)) end -example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by - induction i - unfold is_even - sorry +#print is_even.unfold +#print is_odd.unfold mutual divergent def foo (i : Int) : Result Nat := @@ -968,6 +1032,9 @@ mutual if i > 20 then foo (i / 20) else .ret 42 end +#print foo.unfold +#print bar.unfold + -- Testing dependent branching and let-bindings -- TODO: why the linter warning? divergent def is_non_zero (i : Int) : Result Bool := @@ -976,4 +1043,6 @@ divergent def is_non_zero (i : Int) : Result Bool := let b := true return b +#print is_non_zero.unfold + end Diverge diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 281dbd6c..fd95291e 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -9,6 +9,7 @@ initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas initialize registerTraceClass `Diverge.def.genBody initialize registerTraceClass `Diverge.def.valid +initialize registerTraceClass `Diverge.def.unfold -- TODO: move -- TODO: small helper -- cgit v1.2.3 From 75fae6384716f24fe137283d4a41836782b9aec7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 19:26:27 +0200 Subject: Cleanup a bit Diverge/Elab.lean --- backends/lean/Base/Diverge/Elab.lean | 366 +++++++++++++++++++---------------- 1 file changed, 197 insertions(+), 169 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 91c51a31..cc580265 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -15,39 +15,16 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta -set_option trace.Diverge.def true --- set_option trace.Diverge.def.valid true --- set_option trace.Diverge.def.sigmas true -set_option trace.Diverge.def.unfold true - /- The following was copied from the `wfRecursion` function. -/ open WF in -def mkList (xl : List Expr) (ty : Expr) : MetaM Expr := - match xl with - | [] => - mkAppOptM ``List.nil #[some ty] - | x :: tl => do - let tl ← mkList tl ty - mkAppOptM ``List.cons #[some ty, some x, some tl] - def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] def mkInOutTy (x y : Expr) : MetaM Expr := mkAppM ``FixI.mk_in_out_ty #[x, y] --- TODO: is there already such a utility somewhere? --- TODO: change to mkSigmas -def mkProds (tys : List Expr) : MetaM Expr := - match tys with - | [] => do pure (Expr.const ``PUnit.unit []) - | [ty] => do pure ty - | ty :: tys => do - let pty ← mkProds tys - mkAppM ``Prod.mk #[ty, pty] - -- Return the `a` in `Return a` def get_result_ty (ty : Expr) : MetaM Expr := ty.withApp fun f args => do @@ -56,26 +33,31 @@ def get_result_ty (ty : Expr) : MetaM Expr := else pure (args.get! 0) --- Group a list of expressions into a dependent tuple -def mkSigmas (xl : List Expr) : MetaM Expr := +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + -/ +def mkSigmasVal (xl : List Expr) : MetaM Expr := match xl with | [] => do - trace[Diverge.def.sigmas] "mkSigmas: []" + trace[Diverge.def.sigmas] "mkSigmasVal: []" pure (Expr.const ``PUnit.unit []) | [x] => do - trace[Diverge.def.sigmas] "mkSigmas: [{x}]" + trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" pure x | fst :: xl => do - trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst - let snd ← mkSigmas xl + let snd ← mkSigmasVal xl let snd_ty ← inferType snd let beta ← mkLambdaFVars #[fst] snd_ty - trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}" + trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] -/- Generate the input type of a function body, which is a sigma type (i.e., a - dependent tuple) which groups all its inputs. +/- Generate a Sigma type from a list of expressions. Example: - xl = [(a:Type), (ls:List a), (i:Int)] @@ -84,7 +66,7 @@ def mkSigmas (xl : List Expr) : MetaM Expr := `(a:Type) × (ls:List a) × (i:Int)` -/ -def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := +def mkSigmasType (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" @@ -96,15 +78,16 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x - let sty ← mkSigmasTypesOfTypes xl + let sty ← mkSigmasType xl trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" let beta ← mkLambdaFVars #[x] sty trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] -def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index +def mkAnonymous (s : String) (i : Nat) : Name := + .num (.str .anonymous s) i -/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous +/- Given a list of values `[x0:ty0, ..., xn:ty1]`, where every `xi` might use the previous `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following expression: ``` @@ -112,20 +95,22 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index match x with | (x0, ..., xn) => out ``` - + The `index` parameter is used for naming purposes: we use it to numerotate the bound variables that we introduce. + We use this function to currify functions (the function bodies given to the + fixed-point operator must be unary functions). + Example: ======== - More precisely: - xl = `[a:Type, ls:List a, i:Int]` - out = `a` - index = 0 - generates: + generates (getting rid of most of the syntactic sugar): ``` - match scrut0 with + λ scrut0 => match scrut0 with | Sigma.mk x scrut1 => match scrut1 with | Sigma.mk ls i => @@ -138,21 +123,30 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met -- This would be unexpected throwError "mkSigmasMatch: empyt list of input parameters" | [x] => do - -- In the explanations above: inner match case + -- In the example given for the explanations: this is the inner match case trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" mkLambdaFVars #[x] out | fst :: xl => do - -- In the explanations above: outer match case + -- In the example given for the explanations: this is the outer match case -- Remark: for the naming purposes, we use the same convention as for the - -- fields and parameters in `Sigma.casesOn and `Sigma.mk + -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + -- those definitions might help) + -- + -- We want to build the match expression: + -- ``` + -- λ scrut => + -- match scrut with + -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail + -- ``` trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst - let snd_ty ← mkSigmasTypesOfTypes xl + let snd_ty ← mkSigmasType xl let beta ← mkLambdaFVars #[fst] snd_ty let snd ← mkSigmasMatch xl out (index + 1) - let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) - withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkSigmasType (fst :: xl) + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient let motive ← do @@ -166,38 +160,32 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met -- TODO: make this more efficient (we could change the output type of -- mkSigmasMatch mkSigmasMatch (fst :: xl) out_ty + -- The final expression: putting everything together trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable let sm ← mkLambdaFVars #[scrut] sm trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" pure sm /- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ -private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := +private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) (fun (_ls : List a) => Int) (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) scrut1 (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) -private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => +private def list_nth_out_ty_outer (scrut0 : @Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) := @Sigma.casesOn (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int)) (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type) scrut0 (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) => - list_nth_out_ty2 a scrut1) + list_nth_out_ty_inner a scrut1) /- -/ --- TODO: move --- TODO: we can use Array.mapIdx -@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β - | [] => [] - | a::as => f i a :: mapiAux (i+1) f as - -@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f - -- Return the expression: `Fin n` -- TODO: use more def mkFin (n : Nat) : Expr := @@ -212,15 +200,6 @@ def mkFinVal (n i : Nat) : MetaM Expr := do let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] mkAppOptM ``OfNat.ofNat #[none, none, ofNat] --- TODO: remove? -def mkFinValOld (n i : Nat) : MetaM Expr := do - let finTy := mkFin n - let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)] - match ← trySynthInstance ofNat with - | LOption.some x => - mkAppOptM ``OfNat.ofNat #[none, none, x] - | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} " - /- Generate and declare as individual definitions the bodies for the individual funcions: - replace the recursive calls with calls to the continutation `k` - make those bodies take one single dependent tuple as input @@ -234,11 +213,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let grSize := preDefs.size -- Compute the map from name to index - the continuation has an indexed type: - -- we use the index (a finite number of type `Fin`) to control the function - -- we call at the recursive call + -- we use the index (a finite number of type `Fin`) to control which function + -- we call at the recursive call site. let nameToId : HashMap Name Nat := - let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList - HashMap.ofList namesIds + let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val)) + HashMap.ofList namesIds.toList trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" @@ -260,7 +239,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Compute the index let i ← mkFinVal grSize id -- Put the arguments in one big dependent tuple - let args ← mkSigmas args.toList + let args ← mkSigmasVal args.toList mkAppM' kk_var #[i, args] else -- Not a recursive call: do nothing @@ -277,13 +256,14 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Replace the recursive calls let body ← mapVisit visit_e preDef.value - -- Change the type + -- Currify the function by grouping the arguments into a dependent tuple + -- (over which we match to retrieve the individual arguments). lambdaLetTelescope body fun args body => do let body ← mkSigmasMatch args.toList body 0 -- Add the declaration let value ← mkLambdaFVars #[kk_var] body - let name := preDef.declName.append "sbody" + let name := preDef.declName.append "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -304,7 +284,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Generate a unique function body from the bodies of the mutually recursive group, -- and add it as a declaration in the context. --- We return the list of bodies (of type `Funs ...`) and the mutually recursive body. +-- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) @@ -322,7 +302,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] | (ity, oty) :: inOutTys, b :: bl => do -- Retrieving ity and oty - this is not very clean - let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType + let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) let fl ← mkFuns inOutTys bl mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" @@ -345,7 +325,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) all := [name] } addDecl decl - -- Return the constant + -- Return the bodies and the constant pure (bodyFuns, Lean.mkConst name (levelParams.map .param)) def isCasesExpr (e : Expr) : MetaM Bool := do @@ -367,7 +347,8 @@ instance : ToMessageData MatchInfo where -- This is not a very clean formatting, but we don't need more toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" --- An expression which doesn't use the continuation kk is valid +-- Small helper: prove that an expression which doesn't use the continuation `kk` +-- is valid, and return the proof. def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] @@ -376,6 +357,14 @@ def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do mutual +/- Prove that an expression is valid, and return the proof. + + More precisely, if `e` is an expression which potentially uses the continution + `kk`, return an expression of type: + ``` + is_valid_p k (λ kk => e) + ``` + -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveValid: {e}" match e with @@ -403,7 +392,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .app .. => e.withApp fun f args => do -- There are several cases: first, check if this is a match/if - -- The expression is a (dependent) if then else + -- Check if the expression is a (dependent) if then else. + -- We treat the if then else expressions differently from the other matches, + -- and have dedicated theorems for them. let isIte := e.isIte if isIte || e.isDIte then do e.withApp fun f args => do @@ -431,9 +422,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid - -- The expression is a match (this case is for when the elaborator + -- Check if the expression is a match (this case is for when the elaborator -- introduces auxiliary definitions to hide the match behind syntactic - -- sugar) + -- sugar): else if let some me := ← matchMatcherApp? e then do trace[Diverge.def.valid] "matcherApp: @@ -443,7 +434,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do - altNumParams: {me.altNumParams} - alts: {me.alts} - remaining: {me.remaining}" - -- matchMatcherApp has already done the work for us + -- matchMatcherApp does all the work for us: we simply need to gather + -- the information and call the auxiliary helper `proveMatchIsValid` if me.remaining.size ≠ 0 then throwError "MatcherApp: non empty remaining array: {me.remaining}" let me : MatchInfo := { @@ -456,14 +448,21 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do branches := me.alts } proveMatchIsValid k_var kk_var me - -- The expression is a raw match (this case is for when the expression - -- is a direct call to the primitive `casesOn` function, without - -- syntactic sugar) + -- Check if the expression is a raw match (this case is for when the expression + -- is a direct call to the primitive `casesOn` function, without syntactic sugar). + -- We have to check this case because functions like `mkSigmasMatch`, which we + -- use to currify function bodies, introduce such raw matches. else if ← isCasesExpr f then do trace[Diverge.def.valid] "rawMatch: {e}" + -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + -- -- The casesOn definition is always of the following shape: - -- input parameters (implicit parameters), then motive (implicit), - -- scrutinee (explicit), branches (explicit). + -- - input parameters (implicit parameters) + -- - motive (implicit), -- the motive gives the return type of the match + -- - scrutinee (explicit) + -- - branches (explicit). + -- In particular, we notice that the scrutinee is the first *explicit* + -- parameter - this is how we spot it. let matcherName := f.constName! let matcherLevels := f.constLevels!.toArray -- Find the first explicit parameter: this is the scrutinee @@ -484,7 +483,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let scrut := args.get! scrutIdx let branches := args.extract (scrutIdx + 1) args.size -- Compute the number of parameters for the branches: for this we use - -- the type of the uninstantiated casesOn constant + -- the type of the uninstantiated casesOn constant (we can't just + -- destruct the lambdas in the branch expressions because the result + -- of a match might be a lambda expression). let branchesNumParams : Array Nat ← do let env ← getEnv let decl := env.constants.find! matcherName @@ -505,9 +506,11 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do branches, } proveMatchIsValid k_var kk_var me - -- Monadic let-binding + -- Check if this is a monadic let-binding else if f.isConstOf ``Bind.bind then do trace[Diverge.def.valid] "bind:\n{args}" + -- We simply need to prove that the subexpressions are valid, and call + -- the appropriate lemma. let x := args.get! 4 let y := args.get! 5 -- Prove that the subexpressions are valid @@ -529,7 +532,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- Put everything together trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] - -- Recursive call + -- Check if this is a recursive call, i.e., a call to the continuation `kk` else if f.isFVarOf kk_var.fvarId! then do trace[Diverge.def.valid] "rec: args: \n{args}" if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" @@ -540,9 +543,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do pure eIsValid else do -- Remaining case: normal application. - -- It shouldn't use the continuation + -- It shouldn't use the continuation. proveNoKExprIsValid k_var e +-- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do trace[Diverge.def.valid] "proveMatchIsValid: {me}" -- Prove the validity of the branch expressions @@ -561,16 +565,18 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- Reconstruct the lambda expression mkLambdaFVars xs_beg brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" - -- Put together: compute the motive. - -- It must be of the shape: + -- Compute the motive, which has the following shape: -- ``` -- λ scrut => is_valid_p k (λ k => match scrut with ...) + -- ^^^^^^^^^^^^^^^^^^^^ + -- this is the original match expression, with the + -- the difference that the scrutinee(s) is a variable -- ``` let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees let declInfos := me.scruts.mapIdx fun idx scrut => - let name : Name := (.num (.str .anonymous "scrut") idx) + let name : Name := mkAnonymous "scrut" idx let ty := λ (_ : Array Expr) => inferType scrut (name, ty) withLocalDeclsD declInfos fun scrutVars => do @@ -582,7 +588,6 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp let branches : Array (Option Expr) := me.branches.map some let args := params ++ [motive] ++ scruts ++ branches let matchE ← mkAppOptM me.matcherName args - -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args) -- Wrap in the `is_valid_p` predicate let matchE ← mkLambdaFVars #[kk_var] matchE let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] @@ -591,6 +596,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp trace[Diverge.def.valid] "valid motive: {validMotive}" -- Put together let valid ← do + -- We let Lean infer the parameters let params : Array (Option Expr) := me.params.map (λ _ => none) let motive := some validMotive let scruts := me.scruts.map some @@ -602,12 +608,16 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp end --- Prove that a single body (in the mutually recursive group) is valid -partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : +-- Prove that a single body (in the mutually recursive group) is valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `is_even.body` and `is_odd.body` are valid. +partial def proveSingleBodyIsValid + (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}" - -- Lookup the definition (`bodyConst` is the definition of the body, we want - -- to retrieve the value itself to dive inside) + -- Lookup the definition (`bodyConst` is a const, we want to retrieve its + -- definition to dive inside) let name := bodyConst.constName! let env ← getEnv let body := (env.constants.find! name).value! @@ -633,7 +643,7 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body mkForallFVars #[k_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem - let name := preDef.declName ++ "sbody_is_valid" + let name := preDef.declName ++ "body_is_valid" let decl := Declaration.thmDecl { name levelParams := preDef.levelParams @@ -646,6 +656,11 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body -- Return the theorem pure (Expr.const name (preDef.levelParams.map .param)) +-- Prove that the list of bodies are valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is +-- valid. partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do -- Create the big "and" expression, which groups the validity proof of the individual bodies @@ -663,7 +678,13 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] mkLambdaFVars #[k_var] isValid --- Prove that the mut rec body is valid +-- Prove that the mut rec body (i.e., the unary body which groups the bodies +-- of all the functions in the mutually recursive group and on which we will +-- apply the fixed-point operator) is valid. +-- +-- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", +-- which we return. +-- -- TODO: maybe this function should introduce k_var itself def proveMutRecIsValid (grName : Name) (grLvlParams : List Name) @@ -693,6 +714,12 @@ def proveMutRecIsValid pure (Expr.const name (grLvlParams.map .param)) -- Generate the final definions by using the mutual body and the fixed point operator. +-- +-- For instance: +-- ``` +-- def is_even (i : Int) : Result Bool := mut_rec_body 0 i +-- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i +-- ``` def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size @@ -701,7 +728,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple - let input ← mkSigmas xs.toList + let input ← mkSigmasVal xs.toList -- Apply the fixed point let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody @@ -746,7 +773,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input argument - let arg ← mkSigmas xs.toList + let arg ← mkSigmasVal xs.toList let proof ← mkAppM ``congr_fun #[proof, arg] -- Abstract the arguments away let proof ← mkLambdaFVars xs proof @@ -774,11 +801,6 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) - -- CHANGE HERE This function should add definitions with these names/types/values ^^ - -- Temporarily add the predefinitions as axioms - -- for preDef in preDefs do - -- addAsAxiom preDef - -- TODO: what is this? for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation @@ -803,7 +825,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do if preDef.levelParams ≠ grLvlParams then throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do - let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) + let in_ty ← liftM (mkSigmasType in_tys.toList) -- Retrieve the type in the "Result" let out_ty ← get_result_ty out_ty let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) @@ -813,14 +835,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do trace[Diverge.def] "inOutTys: {inOutTys}" -- Turn the list of input/output type pairs into an expresion let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) - let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0)) + let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList -- From the list of pairs of input/output types, actually compute the -- type of the continuation `k`. -- We first introduce the index `i : Fin n` where `n` is the number of -- functions in the group. let i_var_ty := mkFin preDefs.size - withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do + withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" -- Add an auxiliary definition for `in_out_ty` @@ -844,7 +866,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] trace[Diverge.def] "in_ty: {in_ty}" - withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do + withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] trace[Diverge.def] "out_ty: {out_ty}" @@ -853,7 +875,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let out_ty ← mkLambdaFVars #[i_var, input] out_ty let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] trace[Diverge.def] "kk_var_ty: {kk_var_ty}" - withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do + withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do trace[Diverge.def] "kk_var: {kk_var}" -- Replace the recursive calls in all the function bodies by calls to the @@ -866,7 +888,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] - withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do + withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions @@ -915,7 +937,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC else return () catch _ => s.restore --- The following two functions are copy&pasted from Lean.Elab.MutualDef +-- The following two functions are copy-pasted from Lean.Elab.MutualDef open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef @@ -988,61 +1010,67 @@ elab_rules : command else Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) -divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := - match ls with - | [] => .fail .panic - | x :: ls => - if i = 0 then return x - else return (← list_nth ls (i - 1)) - -example {a: Type} (ls : List a) : - ∀ (i : Int), - 0 ≤ i → i < ls.length → - ∃ x, list_nth ls i = .ret x := by - induction ls - . intro i hpos h; simp at h; linarith - . rename_i hd tl ih - intro i hpos h - rw [list_nth.unfold]; simp - split <;> simp [*] - . tauto - . -- TODO: we shouldn't have to do that - have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all - simp at h - have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) - simp [ih] - tauto - -mutual - divergent def is_even (i : Int) : Result Bool := - if i = 0 then return true else return (← is_odd (i - 1)) - - divergent def is_odd (i : Int) : Result Bool := - if i = 0 then return false else return (← is_even (i - 1)) -end - -#print is_even.unfold -#print is_odd.unfold - -mutual - divergent def foo (i : Int) : Result Nat := - if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 - - divergent def bar (i : Int) : Result Nat := - if i > 20 then foo (i / 20) else .ret 42 -end - -#print foo.unfold -#print bar.unfold - --- Testing dependent branching and let-bindings --- TODO: why the linter warning? -divergent def is_non_zero (i : Int) : Result Bool := - if _h:i = 0 then return false - else - let b := true - return b +namespace Tests + /- Some examples of partial functions -/ + + divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + + #check list_nth.unfold + + example {a: Type} (ls : List a) : + ∀ (i : Int), + 0 ≤ i → i < ls.length → + ∃ x, list_nth ls i = .ret x := by + induction ls + . intro i hpos h; simp at h; linarith + . rename_i hd tl ih + intro i hpos h + rw [list_nth.unfold]; simp + split <;> simp [*] + . tauto + . -- TODO: we shouldn't have to do that + have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all + simp at h + have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + simp [ih] + tauto + + mutual + divergent def is_even (i : Int) : Result Bool := + if i = 0 then return true else return (← is_odd (i - 1)) + + divergent def is_odd (i : Int) : Result Bool := + if i = 0 then return false else return (← is_even (i - 1)) + end + + #check is_even.unfold + #check is_odd.unfold + + mutual + divergent def foo (i : Int) : Result Nat := + if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 + + divergent def bar (i : Int) : Result Nat := + if i > 20 then foo (i / 20) else .ret 42 + end + + #check foo.unfold + #check bar.unfold + + -- Testing dependent branching and let-bindings + -- TODO: why the linter warning? + divergent def is_non_zero (i : Int) : Result Bool := + if _h:i = 0 then return false + else + let b := true + return b -#print is_non_zero.unfold + #check is_non_zero.unfold +end Tests end Diverge -- cgit v1.2.3 From 40e21034fa9e955734351b78a8cc5f16315418bd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 12:13:09 +0200 Subject: Add an implemented_by attribute to fix --- backends/lean/Base/Diverge/Base.lean | 29 ++++++++++++++++++----------- 1 file changed, 18 insertions(+), 11 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 89365d25..a8503107 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -57,6 +57,12 @@ deriving Repr, BEq open Result +instance Result_Inhabited (α : Type u) : Inhabited (Result α) := + Inhabited.mk (fail panic) + +instance Result_Nonempty (α : Type u) : Nonempty (Result α) := + Nonempty.intro div + def bind {α : Type u} {β : Type v} (x: Result α) (f: α -> Result β) : Result β := match x with | ret v => f v @@ -156,7 +162,14 @@ namespace Fix (x : a) (n : Nat) : Prop := fix_fuel_pred f x n - noncomputable + partial + def fixImpl (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : Result (b x) := + f (fixImpl f) x + + -- The fact that `fix` is implemented by `fixImpl` allows us to not mark the + -- functions defined with the fixed-point as noncomputable. One big advantage + -- is that it allows us to evaluate those functions, for instance with #eval. + @[implemented_by fixImpl] def fix (f : ((x:a) → Result (b x)) → (x:a) → Result (b x)) (x : a) : Result (b x) := fix_fuel (least (fix_fuel_P f x)) f x @@ -548,7 +561,7 @@ namespace FixI def is_valid (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : Prop := ∀ k i x, is_valid_p k (λ k => f k i x) - noncomputable def fix + def fix (f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) : (i:id) → (x:a i) → Result (b i x) := kk_of_gen (Fix.fix (k_to_gen f)) @@ -808,7 +821,6 @@ namespace Ex1 split <;> simp split <;> simp - noncomputable def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i) -- The unfolding equation - diverges if `i < 0` @@ -851,7 +863,6 @@ namespace Ex2 split <;> simp apply is_valid_p_bind <;> intros <;> simp_all - noncomputable def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i) -- The unfolding equation - diverges if `i < 0` @@ -932,7 +943,6 @@ namespace Ex3 apply is_valid_p_bind; simp intros; split <;> simp - noncomputable def is_even (i : Int): Result Bool := do let r ← fix is_even_is_odd_body (.inl i) @@ -940,7 +950,6 @@ namespace Ex3 | .inl b => .ret b | .inr _ => .fail .panic - noncomputable def is_odd (i : Int): Result Bool := do let r ← fix is_even_is_odd_body (.inr i) @@ -1032,8 +1041,8 @@ namespace Ex4 theorem body_fix_eq : fix body = body (fix body) := is_valid_fix_fixed_eq body_is_valid - noncomputable def is_even (i : Int) : Result Bool := fix body 0 i - noncomputable def is_odd (i : Int) : Result Bool := fix body 1 i + def is_even (i : Int) : Result Bool := fix body 0 i + def is_odd (i : Int) : Result Bool := fix body 1 i theorem is_even_eq (i : Int) : is_even i = (if i = 0 @@ -1052,7 +1061,6 @@ namespace Ex4 .ret b) := by simp [is_even, is_odd]; conv => lhs; rw [body_fix_eq] - end Ex4 namespace Ex5 @@ -1109,7 +1117,7 @@ namespace Ex5 intro k x simp only [is_valid_p_same, is_valid_p_rec] - noncomputable def id (t : Tree a) := fix id_body t + def id (t : Tree a) := fix id_body t -- The unfolding equation theorem id_eq (t : Tree a) : @@ -1183,7 +1191,6 @@ namespace Ex6 Funs.is_valid_p_is_valid_p tys k bodies (And.intro (list_nth_body_is_valid' k) (Funs.is_valid_p_Nil k)) - noncomputable def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := fix body 0 ⟨ a, ls , i ⟩ -- cgit v1.2.3 From 4fd17e4bb91eb46d4704643dfbfbbf0874837b07 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 12:49:37 +0200 Subject: Make Diverge use Primitives --- backends/lean/Base/Diverge/Base.lean | 65 +------------------------------- backends/lean/Base/Diverge/Elab.lean | 2 +- backends/lean/Base/Diverge/ElabBase.lean | 8 ++-- 3 files changed, 6 insertions(+), 69 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index a8503107..e22eb914 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -4,8 +4,7 @@ import Init.Data.List.Basic import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith --- For debugging -import Base.Diverge.ElabBase +import Base.Primitives /- TODO: @@ -35,68 +34,6 @@ set_option profiler.threshold 100 namespace Diverge -namespace Primitives -/-! # Copy-pasting from Primitives to make the file self-contained -/ - -inductive Error where - | assertionFailure: Error - | integerOverflow: Error - | divisionByZero: Error - | arrayOutOfBounds: Error - | maximumSizeExceeded: Error - | panic: Error -deriving Repr, BEq - -open Error - -inductive Result (α : Type u) where - | ret (v: α): Result α - | fail (e: Error): Result α - | div -deriving Repr, BEq - -open Result - -instance Result_Inhabited (α : Type u) : Inhabited (Result α) := - Inhabited.mk (fail panic) - -instance Result_Nonempty (α : Type u) : Nonempty (Result α) := - Nonempty.intro div - -def bind {α : Type u} {β : Type v} (x: Result α) (f: α -> Result β) : Result β := - match x with - | ret v => f v - | fail v => fail v - | div => div - -@[simp] theorem bind_ret (x : α) (f : α → Result β) : bind (.ret x) f = f x := by simp [bind] -@[simp] theorem bind_fail (x : Error) (f : α → Result β) : bind (.fail x) f = .fail x := by simp [bind] -@[simp] theorem bind_div (f : α → Result β) : bind .div f = .div := by simp [bind] - --- Allows using Result in do-blocks -instance : Bind Result where - bind := bind - --- Allows using return x in do-blocks -instance : Pure Result where - pure := fun x => ret x - -@[simp] theorem bind_tc_ret (x : α) (f : α → Result β) : - (do let y ← .ret x; f y) = f x := by simp [Bind.bind, bind] - -@[simp] theorem bind_tc_fail (x : Error) (f : α → Result β) : - (do let y ← fail x; f y) = fail x := by simp [Bind.bind, bind] - -@[simp] theorem bind_tc_div (f : α → Result β) : - (do let y ← div; f y) = div := by simp [Bind.bind, bind] - -def div? {α: Type u} (r: Result α): Bool := - match r with - | div => true - | ret _ | fail _ => false - -end Primitives - namespace Fix open Primitives diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index cc580265..41209021 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -174,7 +174,7 @@ private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : (fun (_ls : List a) => Int) (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) scrut1 - (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) + (fun (_ls : List a) (_i : Int) => Primitives.Result a) private def list_nth_out_ty_outer (scrut0 : @Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) := diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index fd95291e..1c1062c0 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -4,6 +4,7 @@ namespace Diverge open Lean Elab Term Meta +-- We can't define and use trace classes in the same file initialize registerTraceClass `Diverge.elab initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas @@ -11,8 +12,8 @@ initialize registerTraceClass `Diverge.def.genBody initialize registerTraceClass `Diverge.def.valid initialize registerTraceClass `Diverge.def.unfold --- TODO: move --- TODO: small helper +-- Useful helper to explore definitions and figure out the variant +-- of their sub-expressions. def explore_term (incr : String) (e : Expr) : MetaM Unit := match e with | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return () @@ -81,8 +82,7 @@ private def test2 (x : Nat) : Nat := x print_decl test1 print_decl test2 --- We adapted this from AbstractNestedProofs.visit --- A map visitor function for expressions +-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do let localInstances ← getLocalInstances -- cgit v1.2.3 From bd873499f9a8d517cc948c6336a5c6ce856d846d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 17:30:35 +0200 Subject: Fix some issues with the extraction to Lean --- backends/lean/Base/Diverge/Elab.lean | 63 +++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 16 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 41209021..4b08fe44 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -255,10 +255,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) preDefs.mapM fun preDef => do -- Replace the recursive calls let body ← mapVisit visit_e preDef.value + trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" -- Currify the function by grouping the arguments into a dependent tuple -- (over which we match to retrieve the individual arguments). - lambdaLetTelescope body fun args body => do + lambdaTelescope body fun args body => do let body ← mkSigmasMatch args.toList body 0 -- Add the declaration @@ -376,15 +377,18 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .sort _ => throwError "Unreachable" | .lam .. => throwError "Unimplemented" | .forallE .. => throwError "Unreachable" -- Shouldn't get there - | .letE dName dTy dValue body _nonDep => do - -- Introduce a local declaration for the let-binding - withLetDecl dName dTy dValue fun decl => do + | .letE .. => do + -- Telescope all the let-bindings (remark: this also telescopes the lambdas) + lambdaLetTelescope e fun xs body => do + -- Note that we don't visit the bound values: there shouldn't be + -- recursive calls, lambda expressions, etc. inside + -- Prove that the body is valid let isValid ← proveExprIsValid k_var kk_var body - -- Add the let-binding around. + -- Add the let-bindings around. -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, -- but because it reduces in the end it doesn't matter. More precisely: -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. - mkLetFVars #[decl] isValid + mkLambdaFVars xs isValid (usedLetOnly := false) | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => -- The projection shouldn't use the continuation @@ -410,7 +414,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do if isIte then proveExprIsValid k_var kk_var br else do -- There is a lambda -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let br ← mkLambdaFVars xs br @@ -518,7 +522,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" let yValid ← do -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope y fun xs y => do + lambdaTelescope y fun xs y => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let y ← mkLambdaFVars xs y @@ -555,7 +559,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- binders might come from the match, and some of the binders might come -- from the fact that the expression in the match is a lambda expression: -- we use the branchesNumParams field for this reason - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let numParams := me.branchesNumParams.get! idx let xs_beg := xs.extract 0 numParams let xs_end := xs.extract numParams xs.size @@ -622,7 +626,7 @@ partial def proveSingleBodyIsValid let env ← getEnv let body := (env.constants.find! name).value! trace[Diverge.def.valid] "body: {body}" - lambdaLetTelescope body fun xs body => do + lambdaTelescope body fun xs body => do assert! xs.size = 2 let kk_var := xs.get! 0 let x_var := xs.get! 1 @@ -695,8 +699,10 @@ def proveMutRecIsValid let bodiesValid ← bodies.mapIdxM fun idx body => do let preDef := preDefs.get! idx + trace[Diverge.def.valid] "## Proving that the body {body} is valid" proveSingleBodyIsValid k_var preDef body -- Then prove that the mut rec body is valid + trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid -- Save the theorem let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] @@ -724,7 +730,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do - lambdaLetTelescope preDef.value fun xs _ => do + lambdaTelescope preDef.value fun xs _ => do -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple @@ -755,7 +761,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let preDef := preDefs.get! i let defName := decls.get! i -- Retrieve the arguments - lambdaLetTelescope preDef.value fun xs body => do + lambdaTelescope preDef.value fun xs body => do trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" -- The theorem statement @@ -799,7 +805,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) - trace[Diverge.def] ("divRecursion: defs: " ++ msg) + trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) -- TODO: what is this? for preDef in preDefs do @@ -880,8 +886,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations + trace[Diverge.def] "# Generating the unary bodies" let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs + trace[Diverge.def] "Unary bodies (after decl): {bodies}" -- Generate the mutually recursive body + trace[Diverge.def] "# Generating the mut rec body" let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" @@ -889,15 +898,18 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- our fixed-point let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do + trace[Diverge.def] "# Proving that the mut rec body is valid" let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions + trace[Diverge.def] "# Generating the final definitions" let decls ← mkDeclareFixDefs mutRecBody preDefs -- Prove the unfolding theorems + trace[Diverge.def] "# Proving the unfolding theorems" proveUnfoldingThms isValidThm preDefs decls - -- Process the definitions - TODO + -- Generating code -- TODO addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -1064,13 +1076,32 @@ namespace Tests -- Testing dependent branching and let-bindings -- TODO: why the linter warning? - divergent def is_non_zero (i : Int) : Result Bool := + divergent def isNonZero (i : Int) : Result Bool := if _h:i = 0 then return false else let b := true return b - #check is_non_zero.unfold + #check isNonZero.unfold + + -- Testing let-bindings + divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool := + let i0 := ls.length + if i < i0 + then Result.ret True + else Result.ret False + + #check iInBounds.unfold + + divergent def isCons + {a : Type} (ls : List a) : Result Bool := + let ls1 := ls + match ls1 with + | [] => Result.ret False + | x :: tl => Result.ret True + + #check isCons.unfold + end Tests end Diverge -- cgit v1.2.3 From 442caaf62e4a217b9a10116c4e529c49f83c4efd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 22:45:02 +0200 Subject: Fix an issue with mkSigmasVal --- backends/lean/Base/Diverge/Elab.lean | 228 +++++++++++++++++++------------ backends/lean/Base/Diverge/ElabBase.lean | 47 ++++--- 2 files changed, 169 insertions(+), 106 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 4b08fe44..1af06fea 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -26,38 +26,42 @@ def mkInOutTy (x y : Expr) : MetaM Expr := mkAppM ``FixI.mk_in_out_ty #[x, y] -- Return the `a` in `Return a` -def get_result_ty (ty : Expr) : MetaM Expr := +def getResultTy (ty : Expr) : MetaM Expr := ty.withApp fun f args => do if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then - throwError "Invalid argument to get_result_ty: {ty}" + throwError "Invalid argument to getResultTy: {ty}" else pure (args.get! 0) -/- Group a list of expressions into a dependent tuple. +/- Deconstruct a sigma type. - Example: - xl = [`a : Type`, `ls : List a`] - returns: - `⟨ (a:Type), (ls: List a) ⟩` + For instance, deconstructs `(a : Type) × List a` into + `Type` and `λ a => List a`. -/ -def mkSigmasVal (xl : List Expr) : MetaM Expr := - match xl with - | [] => do - trace[Diverge.def.sigmas] "mkSigmasVal: []" - pure (Expr.const ``PUnit.unit []) - | [x] => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" - pure x - | fst :: xl => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst - let snd ← mkSigmasVal xl - let snd_ty ← inferType snd - let beta ← mkLambdaFVars #[fst] snd_ty - trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" - mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] - -/- Generate a Sigma type from a list of expressions. +def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do + ty.withApp fun f args => do + if ¬ f.isConstOf ``Sigma ∨ args.size ≠ 2 then + throwError "Invalid argument to getSigmaTypes: {ty}" + else + pure (args.get! 0, args.get! 1) + +/- Like `lambdaTelescopeN` but only destructs a fixed number of lambdas -/ +def lambdaTelescopeN (e : Expr) (n : Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := + lambdaTelescope e fun xs body => do + if xs.size < n then throwError "lambdaTelescopeN: not enough lambdas"; + let xs := xs.extract 0 n + let ys := xs.extract n xs.size + let body ← mkLambdaFVars ys body + k xs body + +/- Like `lambdaTelescope`, but only destructs one lambda + TODO: is there an equivalent of this function somewhere in the + standard library? -/ +def lambdaOne (e : Expr) (k : Expr → Expr → MetaM α) : MetaM α := + lambdaTelescopeN e 1 λ xs b => k (xs.get! 0) b + +/- Generate a Sigma type from a list of *variables* (all the expressions + must be variables). Example: - xl = [(a:Type), (ls:List a), (i:Int)] @@ -84,6 +88,53 @@ def mkSigmasType (xl : List Expr) : MetaM Expr := trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] +/- Apply a lambda expression to some arguments, simplifying the lambdas -/ +def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do + lambdaTelescopeN e xs.size fun vars body => + -- Create the substitution + let s : HashMap FVarId Expr := HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList) + -- Substitute in the body + pure (body.replace fun e => + match e with + | Expr.fvar fvarId => match s.find? fvarId with + | none => e + | some v => v + | _ => none) + +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + + We need the type argument because as the elements in the tuple are + "concrete", we can't in all generality figure out the type of the tuple. + + Example: + `⟨ True, 3 ⟩ : (x : Bool) × (if x then Int else Unit)` + -/ +def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasVal: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" + pure x + | fst :: xl => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" + -- Deconstruct the type + let (alpha, beta) ← getSigmaTypes ty + -- Compute the "second" field + -- Specialize beta for fst + let nty ← applyLambdaToArgs beta #[fst] + -- Recursive call + let snd ← mkSigmasVal nty xl + -- Put everything together + trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + def mkAnonymous (s : String) (i : Nat) : Name := .num (.str .anonymous s) i @@ -208,52 +259,57 @@ def mkFinVal (n i : Nat) : MetaM Expr := do We return the new declarations. -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (preDefs : Array PreDefinition) : + (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - -- Compute the map from name to index - the continuation has an indexed type: - -- we use the index (a finite number of type `Fin`) to control which function - -- we call at the recursive call site. - let nameToId : HashMap Name Nat := - let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val)) - HashMap.ofList namesIds.toList + -- Compute the map from name to (index × input type). + -- Remark: the continuation has an indexed type; we use the index (a finite number of + -- type `Fin`) to control which function we call at the recursive call site. + let nameToInfo : HashMap Name (Nat × Expr) := + let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst)) + HashMap.ofList bl.toList - trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" + trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" -- Auxiliary function to explore the function bodies and replace the -- recursive calls - let visit_e (e : Expr) : MetaM Expr := do - trace[Diverge.def.genBody] "visiting expression: {e}" - match e with - | .app .. => do - e.withApp fun f args => do - trace[Diverge.def.genBody] "this is an app: {f} {args}" - -- Check if this is a recursive call - if f.isConst then - let name := f.constName! - match nameToId.find? name with - | none => pure e - | some id => - -- This is a recursive call: replace it - -- Compute the index - let i ← mkFinVal grSize id - -- Put the arguments in one big dependent tuple - let args ← mkSigmasVal args.toList - mkAppM' kk_var #[i, args] - else - -- Not a recursive call: do nothing - pure e - | .const name _ => - -- Sanity check: we eliminated all the recursive calls - if (nameToId.find? name).isSome then - throwError "mkUnaryBodies: a recursive call was not eliminated" - else pure e - | _ => pure e + let visit_e (i : Nat) (e : Expr) : MetaM Expr := do + trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}" + let ne ← do + match e with + | .app .. => do + e.withApp fun f args => do + trace[Diverge.def.genBody] "this is an app: {f} {args}" + -- Check if this is a recursive call + if f.isConst then + let name := f.constName! + match nameToInfo.find? name with + | none => pure e + | some (id, in_ty) => + trace[Diverge.def.genBody] "this is a recursive call" + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Put the arguments in one big dependent tuple + let args ← mkSigmasVal in_ty args.toList + mkAppM' kk_var #[i, args] + else + -- Not a recursive call: do nothing + pure e + | .const name _ => + -- Sanity check: we eliminated all the recursive calls + if (nameToInfo.find? name).isSome then + throwError "mkUnaryBodies: a recursive call was not eliminated" + else pure e + | _ => pure e + trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}" + pure ne -- Explore the bodies preDefs.mapM fun preDef => do -- Replace the recursive calls + trace[Diverge.def.genBody] "About to replace recursive calls in {preDef.declName}" let body ← mapVisit visit_e preDef.value trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" @@ -413,11 +469,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let proveBranchValid (br : Expr) : MetaM Expr := if isIte then proveExprIsValid k_var kk_var br else do - -- There is a lambda -- TODO: how do we remove exacly *one* lambda? - lambdaTelescope br fun xs br => do - let x := xs.get! 0 - let xs := xs.extract 1 xs.size - let br ← mkLambdaFVars xs br + -- There is a lambda + lambdaOne br fun x br => do let brValid ← proveExprIsValid k_var kk_var br mkLambdaFVars #[x] brValid let br0Valid ← proveBranchValid br0 @@ -521,11 +574,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let xValid ← proveExprIsValid k_var kk_var x trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" let yValid ← do - -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? - lambdaTelescope y fun xs y => do - let x := xs.get! 0 - let xs := xs.extract 1 xs.size - let y ← mkLambdaFVars xs y + -- This is a lambda expression + lambdaOne y fun x y => do trace[Diverge.def.valid] "bind: y: {y}" let yValid ← proveExprIsValid k_var kk_var y trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" @@ -559,15 +609,12 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- binders might come from the match, and some of the binders might come -- from the fact that the expression in the match is a lambda expression: -- we use the branchesNumParams field for this reason - lambdaTelescope br fun xs br => do let numParams := me.branchesNumParams.get! idx - let xs_beg := xs.extract 0 numParams - let xs_end := xs.extract numParams xs.size - let br ← mkLambdaFVars xs_end br + lambdaTelescopeN br numParams fun xs br => do -- Prove that the branch expression is valid let brValid ← proveExprIsValid k_var kk_var br -- Reconstruct the lambda expression - mkLambdaFVars xs_beg brValid + mkLambdaFVars xs brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" -- Compute the motive, which has the following shape: -- ``` @@ -726,15 +773,17 @@ def proveMutRecIsValid -- def is_even (i : Int) : Result Bool := mut_rec_body 0 i -- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i -- ``` -def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do + -- Retrieve the input type + let in_ty := (inOutTys.get! idx.val).fst -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple - let input ← mkSigmasVal xs.toList + let input ← mkSigmasVal in_ty xs.toList -- Apply the fixed point let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody @@ -754,8 +803,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : pure defs -- Prove the equations that we will use as unfolding theorems -partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition) - (decls : Array Name) : MetaM Unit := do +partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr)) + (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do let preDef := preDefs.get! i @@ -779,7 +828,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input argument - let arg ← mkSigmasVal xs.toList + let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList let proof ← mkAppM ``congr_fun #[proof, arg] -- Abstract the arguments away let proof ← mkLambdaFVars xs proof @@ -833,7 +882,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do forallTelescope preDef.type (fun in_tys out_ty => do let in_ty ← liftM (mkSigmasType in_tys.toList) -- Retrieve the type in the "Result" - let out_ty ← get_result_ty out_ty + let out_ty ← getResultTy out_ty let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) pure (in_ty, out_ty) ) @@ -886,8 +935,8 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations - trace[Diverge.def] "# Generating the unary bodies" - let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs + trace[Diverge.def] "# Generating the unary bodies" + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs trace[Diverge.def] "Unary bodies (after decl): {bodies}" -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" @@ -903,11 +952,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Generate the final definitions trace[Diverge.def] "# Generating the final definitions" - let decls ← mkDeclareFixDefs mutRecBody preDefs + let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs -- Prove the unfolding theorems trace[Diverge.def] "# Proving the unfolding theorems" - proveUnfoldingThms isValidThm preDefs decls + proveUnfoldingThms isValidThm inOutTys preDefs decls -- Generating code -- TODO addAndCompilePartialRec preDefs @@ -1102,6 +1151,15 @@ namespace Tests #check isCons.unfold + -- Testing what happens when we use concrete arguments in dependent tuples + divergent def test1 + (_ : Option Bool) (_ : Unit) : + Result Unit + := + test1 Option.none () + + #check test1.unfold + end Tests end Diverge diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 1c1062c0..aaaea6f7 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -83,7 +83,10 @@ print_decl test1 print_decl test2 -- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) -partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do +-- The continuation takes as parameters: +-- - the current depth of the expression (useful for printing/debugging) +-- - the expression to explore +partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do let localInstances ← getLocalInstances let mut lctx ← getLCtx @@ -98,25 +101,27 @@ partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl withLCtx lctx localInstances k2 -- TODO: use a cache? (Lean.checkCache) - -- Explore - let e ← k e - match e with - | .bvar _ - | .fvar _ - | .mvar _ - | .sort _ - | .lit _ - | .const _ _ => pure e - | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (mapVisit k)) - | .lam .. => - lambdaLetTelescope e fun xs b => - mapVisitBinders xs do mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) - | .forallE .. => do - forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← mapVisit k b) - | .letE .. => do - lambdaLetTelescope e fun xs b => mapVisitBinders xs do - mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) - | .mdata _ b => return e.updateMData! (← mapVisit k b) - | .proj _ _ b => return e.updateProj! (← mapVisit k b) + let rec visit (i : Nat) (e : Expr) : MetaM Expr := do + -- Explore + let e ← k i e + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => pure e + | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1))) + | .lam .. => + lambdaLetTelescope e fun xs b => + mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) + | .forallE .. => do + forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b) + | .letE .. => do + lambdaLetTelescope e fun xs b => mapVisitBinders xs do + mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) + | .mdata _ b => return e.updateMData! (← visit (i + 1) b) + | .proj _ _ b => return e.updateProj! (← visit (i + 1) b) + visit 0 e end Diverge -- cgit v1.2.3 From 2496a08691809683e256af7c479588a2fae8e3d7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 6 Jul 2023 14:23:21 +0200 Subject: Register the unfolding theorems in the Lean equation compilers and solve a "unused variable" warning --- backends/lean/Base/Diverge/Elab.lean | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 1af06fea..e5b39440 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -843,6 +843,8 @@ partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Ex all := [name] } addDecl decl + -- Add the unfolding theorem to the equation compiler + eqnsAttribute.add preDef.declName #[name] trace[Diverge.def.unfold] "proveUnfoldingThms: added thm: {name}:\n{thmTy}" let rec prove (i : Nat) : MetaM Unit := do if i = preDefs.size then pure () @@ -1011,6 +1013,13 @@ def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM U withFunLocalDecls headers fun funFVars => do for view in views, funFVar in funFVars do addLocalVarInfo view.declId funFVar + -- Add fake use site to prevent "unused variable" warning (if the + -- function is actually not recursive, Lean would print this warning). + -- Remark: we could detect this case and encode the function without + -- using the fixed-point. In practice it shouldn't happen however: + -- we define non-recursive functions with the `divergent` keyword + -- only for testing purposes. + addTermInfo' view.declId funFVar let values ← try let values ← elabFunValues headers @@ -1091,7 +1100,8 @@ namespace Tests . intro i hpos h; simp at h; linarith . rename_i hd tl ih intro i hpos h - rw [list_nth.unfold]; simp + -- We can directly use `rw [list_nth]`! + rw [list_nth]; simp split <;> simp [*] . tauto . -- TODO: we shouldn't have to do that @@ -1147,7 +1157,7 @@ namespace Tests let ls1 := ls match ls1 with | [] => Result.ret False - | x :: tl => Result.ret True + | _ :: _ => Result.ret True #check isCons.unfold -- cgit v1.2.3 From 9515bbad5b58ed1c51ac6d9fc9d7a4e5884b6273 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 6 Jul 2023 15:23:53 +0200 Subject: Reorganize a bit the lean backend files --- backends/lean/Base/Diverge/Elab.lean | 2 ++ 1 file changed, 2 insertions(+) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index e5b39440..96f7abc0 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -2,6 +2,7 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic import Mathlib.Tactic.RunCmd +import Base.Utils import Base.Diverge.Base import Base.Diverge.ElabBase @@ -13,6 +14,7 @@ namespace Diverge syntax (name := divergentDef) declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command +open Utils open Lean Elab Term Meta Primitives Lean.Meta /- The following was copied from the `wfRecursion` function. -/ -- cgit v1.2.3 From 0d1ac53f88f947ae94f6afb57b2a7e18a77460a7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sun, 9 Jul 2023 10:11:13 +0200 Subject: Make progress on the int tactic --- backends/lean/Base/Diverge/ElabBase.lean | 112 ------------------------------- 1 file changed, 112 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index aaaea6f7..fedb1c74 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -12,116 +12,4 @@ initialize registerTraceClass `Diverge.def.genBody initialize registerTraceClass `Diverge.def.valid initialize registerTraceClass `Diverge.def.unfold --- Useful helper to explore definitions and figure out the variant --- of their sub-expressions. -def explore_term (incr : String) (e : Expr) : MetaM Unit := - match e with - | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return () - | .fvar _ => do logInfo m!"{incr}fvar: {e}"; return () - | .mvar _ => do logInfo m!"{incr}mvar: {e}"; return () - | .sort _ => do logInfo m!"{incr}sort: {e}"; return () - | .const _ _ => do logInfo m!"{incr}const: {e}"; return () - | .app fn arg => do - logInfo m!"{incr}app: {e}" - explore_term (incr ++ " ") fn - explore_term (incr ++ " ") arg - | .lam _bName bTy body _binfo => do - logInfo m!"{incr}lam: {e}" - explore_term (incr ++ " ") bTy - explore_term (incr ++ " ") body - | .forallE _bName bTy body _bInfo => do - logInfo m!"{incr}forallE: {e}" - explore_term (incr ++ " ") bTy - explore_term (incr ++ " ") body - | .letE _dName ty val body _nonDep => do - logInfo m!"{incr}letE: {e}" - explore_term (incr ++ " ") ty - explore_term (incr ++ " ") val - explore_term (incr ++ " ") body - | .lit _ => do logInfo m!"{incr}lit: {e}"; return () - | .mdata _ e => do - logInfo m!"{incr}mdata: {e}" - explore_term (incr ++ " ") e - | .proj _ _ struct => do - logInfo m!"{incr}proj: {e}" - explore_term (incr ++ " ") struct - -def explore_decl (n : Name) : TermElabM Unit := do - logInfo m!"Name: {n}" - let env ← getEnv - let decl := env.constants.find! n - match decl with - | .defnInfo val => - logInfo m!"About to explore defn: {decl.name}" - logInfo m!"# Type:" - explore_term "" val.type - logInfo m!"# Value:" - explore_term "" val.value - | .axiomInfo _ => throwError m!"axiom: {n}" - | .thmInfo _ => throwError m!"thm: {n}" - | .opaqueInfo _ => throwError m!"opaque: {n}" - | .quotInfo _ => throwError m!"quot: {n}" - | .inductInfo _ => throwError m!"induct: {n}" - | .ctorInfo _ => throwError m!"ctor: {n}" - | .recInfo _ => throwError m!"rec: {n}" - -syntax (name := printDecl) "print_decl " ident : command - -open Lean.Elab.Command - -@[command_elab printDecl] def elabPrintDecl : CommandElab := fun stx => do - liftTermElabM do - let id := stx[1] - addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none - let cs ← resolveGlobalConstWithInfos id - explore_decl cs[0]! - -private def test1 : Nat := 0 -private def test2 (x : Nat) : Nat := x - -print_decl test1 -print_decl test2 - --- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) --- The continuation takes as parameters: --- - the current depth of the expression (useful for printing/debugging) --- - the expression to explore -partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do - let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do - let localInstances ← getLocalInstances - let mut lctx ← getLCtx - for x in xs do - let xFVarId := x.fvarId! - let localDecl ← xFVarId.getDecl - let type ← mapVisit k localDecl.type - let localDecl := localDecl.setType type - let localDecl ← match localDecl.value? with - | some value => let value ← mapVisit k value; pure <| localDecl.setValue value - | none => pure localDecl - lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl - withLCtx lctx localInstances k2 - -- TODO: use a cache? (Lean.checkCache) - let rec visit (i : Nat) (e : Expr) : MetaM Expr := do - -- Explore - let e ← k i e - match e with - | .bvar _ - | .fvar _ - | .mvar _ - | .sort _ - | .lit _ - | .const _ _ => pure e - | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1))) - | .lam .. => - lambdaLetTelescope e fun xs b => - mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) - | .forallE .. => do - forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b) - | .letE .. => do - lambdaLetTelescope e fun xs b => mapVisitBinders xs do - mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) - | .mdata _ b => return e.updateMData! (← visit (i + 1) b) - | .proj _ _ b => return e.updateProj! (← visit (i + 1) b) - visit 0 e - end Diverge -- cgit v1.2.3 From 7206b48a73d6204baea99f4f4675be2518a8f8c2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 10 Jul 2023 15:06:12 +0200 Subject: Start working on the progress tactic --- backends/lean/Base/Diverge/Elab.lean | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 96f7abc0..f109e847 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -14,8 +14,8 @@ namespace Diverge syntax (name := divergentDef) declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command -open Utils open Lean Elab Term Meta Primitives Lean.Meta +open Utils /- The following was copied from the `wfRecursion` function. -/ @@ -47,21 +47,6 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do else pure (args.get! 0, args.get! 1) -/- Like `lambdaTelescopeN` but only destructs a fixed number of lambdas -/ -def lambdaTelescopeN (e : Expr) (n : Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := - lambdaTelescope e fun xs body => do - if xs.size < n then throwError "lambdaTelescopeN: not enough lambdas"; - let xs := xs.extract 0 n - let ys := xs.extract n xs.size - let body ← mkLambdaFVars ys body - k xs body - -/- Like `lambdaTelescope`, but only destructs one lambda - TODO: is there an equivalent of this function somewhere in the - standard library? -/ -def lambdaOne (e : Expr) (k : Expr → Expr → MetaM α) : MetaM α := - lambdaTelescopeN e 1 λ xs b => k (xs.get! 0) b - /- Generate a Sigma type from a list of *variables* (all the expressions must be variables). -- cgit v1.2.3 From a18d899a2c2b9bdd36f4a5a4b70472c12a835a96 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 12 Jul 2023 14:34:55 +0200 Subject: Finish a first version of the progress tactic --- backends/lean/Base/Diverge/Base.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index e22eb914..d2c91ff8 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -14,7 +14,7 @@ TODO: Actually, the cases from mathlib seems already quite powerful (https://leanprover-community.github.io/mathlib_docs/tactics.html#cases) For instance: cases h : e - Also: cases_matching + Also: **casesm** - better split tactic - we need conversions to operate on the head of applications. Actually, something like this works: -- cgit v1.2.3 From eb97bdb6761437e492bcf1a95b4fa43d2b69601b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 12 Jul 2023 18:04:19 +0200 Subject: Improve progress to use assumptions and start working on a nice syntax --- backends/lean/Base/Diverge/Base.lean | 22 ---------------------- 1 file changed, 22 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index d2c91ff8..0a9ea4c4 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -6,28 +6,6 @@ import Mathlib.Tactic.Linarith import Base.Primitives -/- -TODO: -- we want an easier to use cases: - - keeps in the goal an equation of the shape: `t = case` - - if called on Prop terms, uses Classical.em - Actually, the cases from mathlib seems already quite powerful - (https://leanprover-community.github.io/mathlib_docs/tactics.html#cases) - For instance: cases h : e - Also: **casesm** -- better split tactic -- we need conversions to operate on the head of applications. - Actually, something like this works: - ``` - conv at Hl => - apply congr_fun - simp [fix_fuel_P] - ``` - Maybe we need a rpt ... ; focus? -- simplifier/rewriter have a strange behavior sometimes --/ - - /- TODO: this is very useful, but is there more? -/ set_option profiler true set_option profiler.threshold 100 -- cgit v1.2.3 From 2fa3cb8ee04dd7ff4184e3e1000fdc025abc50a4 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 17 Jul 2023 23:37:48 +0200 Subject: Start proving theorems for primitive definitions --- backends/lean/Base/Diverge/Base.lean | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 0a9ea4c4..4ff1d923 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -3,8 +3,7 @@ import Lean.Meta.Tactic.Simp import Init.Data.List.Basic import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith - -import Base.Primitives +import Base.Primitives.Base /- TODO: this is very useful, but is there more? -/ set_option profiler true -- cgit v1.2.3 From 0a8211041814b5eafac0b9e2dbcd956957a322b5 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 18 Jul 2023 18:02:03 +0200 Subject: Move an arithmetic lemma --- backends/lean/Base/Diverge/Base.lean | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 4ff1d923..1d548389 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -4,6 +4,7 @@ import Init.Data.List.Basic import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith import Base.Primitives.Base +import Base.Arith.Base /- TODO: this is very useful, but is there more? -/ set_option profiler true @@ -537,23 +538,16 @@ namespace FixI let j: Fin tys1.length := ⟨ j, jLt ⟩ Eq.mp (by simp) (get_fun tl j) - -- TODO: move - theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by - -- Damn, those proofs on natural numbers are hard - I wish Omega was in mathlib4... - simp [Nat.add_one_le_iff] - simp [Nat.lt_iff_le_and_ne] - simp_all - def for_all_fin_aux {n : Nat} (f : Fin n → Prop) (m : Nat) (h : m ≤ n) : Prop := if heq: m = n then True else f ⟨ m, by simp_all [Nat.lt_iff_le_and_ne] ⟩ ∧ - for_all_fin_aux f (m + 1) (by simp_all [add_one_le_iff_le_ne]) + for_all_fin_aux f (m + 1) (by simp_all [Arith.add_one_le_iff_le_ne]) termination_by for_all_fin_aux n _ m h => n - m decreasing_by simp_wf apply Nat.sub_add_lt_sub <;> simp - simp_all [add_one_le_iff_le_ne] + simp_all [Arith.add_one_le_iff_le_ne] def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp) @@ -603,7 +597,7 @@ namespace FixI apply hi <;> simp_all . unfold for_all_fin_aux at hf simp_all - . simp_all [add_one_le_iff_le_ne] + . simp_all [Arith.add_one_le_iff_le_ne] -- TODO: this is not necessary anymore theorem for_all_fin_imp_forall (n : Nat) (f : Fin n → Prop) : -- cgit v1.2.3