diff options
author | Son HO | 2023-12-13 09:55:58 +0100 |
---|---|---|
committer | GitHub | 2023-12-13 09:55:58 +0100 |
commit | 22009543d86895b9f680d3a4abdea00302ad5f1e (patch) | |
tree | 82158f0f6716e932214d1eaee6701539bf7899c6 /backends/lean/Base | |
parent | e4798a8581cd29deab12e79f3d552635b2a7f60d (diff) | |
parent | 8645fcb01e13fb2b2630da952ec9384852dd0e6e (diff) |
Merge pull request #51 from AeneasVerif/son_merge_back2
Improve the `pspec` attribute and the `divergent` encoding
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 572 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 1155 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/ElabBase.lean | 69 | ||||
-rw-r--r-- | backends/lean/Base/Extensions.lean | 47 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Scalar.lean | 126 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Base.lean | 290 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 91 | ||||
-rw-r--r-- | backends/lean/Base/Utils.lean | 114 |
8 files changed, 1690 insertions, 774 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 6a52387d..9458c926 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -5,6 +5,7 @@ import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith import Base.Primitives.Base import Base.Arith.Base +import Base.Diverge.ElabBase /- TODO: this is very useful, but is there more? -/ set_option profiler true @@ -12,6 +13,78 @@ set_option profiler.threshold 100 namespace Diverge +/- Auxiliary lemmas -/ +namespace Lemmas + -- TODO: not necessary anymore + def for_all_fin_aux {n : Nat} (f : Fin n → Prop) (m : Nat) (h : m ≤ n) : Prop := + if heq: m = n then True + else + f ⟨ m, by simp_all [Nat.lt_iff_le_and_ne] ⟩ ∧ + for_all_fin_aux f (m + 1) (by simp_all [Arith.add_one_le_iff_le_ne]) + termination_by for_all_fin_aux n _ m h => n - m + decreasing_by + simp_wf + apply Nat.sub_add_lt_sub <;> try simp + simp_all [Arith.add_one_le_iff_le_ne] + + def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp) + + theorem for_all_fin_aux_imp_forall {n : Nat} (f : Fin n → Prop) (m : Nat) : + (h : m ≤ n) → + for_all_fin_aux f m h → ∀ i, m ≤ i.val → f i + := by + generalize h: (n - m) = k + revert m + induction k -- TODO: induction h rather? + case zero => + simp_all + intro m h1 h2 + have h: n = m := by + linarith + unfold for_all_fin_aux; simp_all + simp_all + -- There is no i s.t. m ≤ i + intro i h3; cases i; simp_all + linarith + case succ k hi => + intro m hk hmn + intro hf i hmi + have hne: m ≠ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + linarith + -- m = i? + if heq: m = i then + -- Yes: simply use the `for_all_fin_aux` hyp + unfold for_all_fin_aux at hf + simp_all + else + -- No: use the induction hypothesis + have hlt: m < i := by simp_all [Nat.lt_iff_le_and_ne] + have hineq: m + 1 ≤ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + simp [*, Nat.add_one_le_iff] + have heq1: n - (m + 1) = k := by + -- TODO: very annoying arithmetic proof + simp [Nat.sub_eq_iff_eq_add hineq] + have hineq1: m ≤ n := by linarith + simp [Nat.sub_eq_iff_eq_add hineq1] at hk + simp_arith [hk] + have hi := hi (m + 1) heq1 hineq + apply hi <;> simp_all + . unfold for_all_fin_aux at hf + simp_all + . simp_all [Arith.add_one_le_iff_le_ne] + + -- TODO: this is not necessary anymore + theorem for_all_fin_imp_forall (n : Nat) (f : Fin n → Prop) : + for_all_fin f → ∀ i, f i + := by + intro Hf i + apply for_all_fin_aux_imp_forall <;> try assumption + simp + +end Lemmas + namespace Fix open Primitives @@ -436,6 +509,10 @@ namespace FixI /- Indexed fixed-point: definitions with indexed types, convenient to use for mutually recursive definitions. We simply port the definitions and proofs from Fix to a more specific case. + + Remark: the index designates the function in the mutually recursive group + (it should be a finite type). We make the return type depend on the input + type because we group the type parameters in the input type. -/ open Primitives Fix @@ -505,7 +582,6 @@ namespace FixI kk_ty id a b → kk_ty id a b abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) - -- TODO: remove? abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : in_out_ty := Sigma.mk in_ty out_ty @@ -538,73 +614,6 @@ namespace FixI let j: Fin tys1.length := ⟨ j, jLt ⟩ Eq.mp (by simp) (get_fun tl j) - def for_all_fin_aux {n : Nat} (f : Fin n → Prop) (m : Nat) (h : m ≤ n) : Prop := - if heq: m = n then True - else - f ⟨ m, by simp_all [Nat.lt_iff_le_and_ne] ⟩ ∧ - for_all_fin_aux f (m + 1) (by simp_all [Arith.add_one_le_iff_le_ne]) - termination_by for_all_fin_aux n _ m h => n - m - decreasing_by - simp_wf - apply Nat.sub_add_lt_sub <;> try simp - simp_all [Arith.add_one_le_iff_le_ne] - - def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp) - - theorem for_all_fin_aux_imp_forall {n : Nat} (f : Fin n → Prop) (m : Nat) : - (h : m ≤ n) → - for_all_fin_aux f m h → ∀ i, m ≤ i.val → f i - := by - generalize h: (n - m) = k - revert m - induction k -- TODO: induction h rather? - case zero => - simp_all - intro m h1 h2 - have h: n = m := by - linarith - unfold for_all_fin_aux; simp_all - simp_all - -- There is no i s.t. m ≤ i - intro i h3; cases i; simp_all - linarith - case succ k hi => - intro m hk hmn - intro hf i hmi - have hne: m ≠ n := by - have hineq := Nat.lt_of_sub_eq_succ hk - linarith - -- m = i? - if heq: m = i then - -- Yes: simply use the `for_all_fin_aux` hyp - unfold for_all_fin_aux at hf - simp_all - else - -- No: use the induction hypothesis - have hlt: m < i := by simp_all [Nat.lt_iff_le_and_ne] - have hineq: m + 1 ≤ n := by - have hineq := Nat.lt_of_sub_eq_succ hk - simp [*, Nat.add_one_le_iff] - have heq1: n - (m + 1) = k := by - -- TODO: very annoying arithmetic proof - simp [Nat.sub_eq_iff_eq_add hineq] - have hineq1: m ≤ n := by linarith - simp [Nat.sub_eq_iff_eq_add hineq1] at hk - simp_arith [hk] - have hi := hi (m + 1) heq1 hineq - apply hi <;> simp_all - . unfold for_all_fin_aux at hf - simp_all - . simp_all [Arith.add_one_le_iff_le_ne] - - -- TODO: this is not necessary anymore - theorem for_all_fin_imp_forall (n : Nat) (f : Fin n → Prop) : - for_all_fin f → ∀ i, f i - := by - intro Hf i - apply for_all_fin_aux_imp_forall <;> try assumption - simp - /- Automating the proofs -/ @[simp] theorem is_valid_p_same (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (x : Result c) : @@ -707,6 +716,218 @@ namespace FixI end FixI +namespace FixII + /- Similar to FixI, but we split the input arguments between the type parameters + and the input values. + -/ + open Primitives Fix + + -- The index type + variable {id : Type u} + + -- The input/output types + variable {ty : id → Type v} {a : (i:id) → ty i → Type w} {b : (i:id) → ty i → Type x} + + -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) + def karrow_rel (k1 k2 : (i:id) → (t:ty i) → (a i t) → Result (b i t)) : Prop := + ∀ i t x, result_rel (k1 i t x) (k2 i t x) + + def kk_to_gen (k : (i:id) → (t:ty i) → (x:a i t) → Result (b i t)) : + (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst) := + λ ⟨ i, t, x ⟩ => k i t x + + def kk_of_gen (k : (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) : + (i:id) → (t:ty i) → a i t → Result (b i t) := + λ i t x => k ⟨ i, t, x ⟩ + + def k_to_gen (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) : + ((x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) → (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst) := + λ kk => kk_to_gen (k (kk_of_gen kk)) + + def k_of_gen (k : ((x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) → (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) : + ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t) := + λ kk => kk_of_gen (k (kk_to_gen kk)) + + def e_to_gen (e : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c) : + ((x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) → Result c := + λ k => e (kk_of_gen k) + + def is_valid_p (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (e : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c) : Prop := + Fix.is_valid_p (k_to_gen k) (e_to_gen e) + + def is_valid (f : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) : Prop := + ∀ k i t x, is_valid_p k (λ k => f k i t x) + + def fix + (f : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) : + (i:id) → (t:ty i) → a i t → Result (b i t) := + kk_of_gen (Fix.fix (k_to_gen f)) + + theorem is_valid_fix_fixed_eq + {{f : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)}} + (Hvalid : is_valid f) : + fix f = f (fix f) := by + have Hvalid' : Fix.is_valid (k_to_gen f) := by + intro k x + simp only [is_valid, is_valid_p] at Hvalid + let ⟨ i, t, x ⟩ := x + have Hvalid := Hvalid (k_of_gen k) i t x + simp only [k_to_gen, k_of_gen, kk_to_gen, kk_of_gen] at Hvalid + refine Hvalid + have Heq := Fix.is_valid_fix_fixed_eq Hvalid' + simp [fix] + conv => lhs; rw [Heq] + + /- Some utilities to define the mutually recursive functions -/ + + -- TODO: use more + abbrev kk_ty (id : Type u) (ty : id → Type v) (a : (i:id) → ty i → Type w) (b : (i:id) → ty i → Type x) := + (i:id) → (t:ty i) → a i t → Result (b i t) + abbrev k_ty (id : Type u) (ty : id → Type v) (a : (i:id) → ty i → Type w) (b : (i:id) → ty i → Type x) := + kk_ty id ty a b → kk_ty id ty a b + + abbrev in_out_ty : Type (imax (u + 1) (imax (v + 1) (w + 1))) := + (ty : Type u) × (ty → Type v) × (ty → Type w) + abbrev mk_in_out_ty (ty : Type u) (in_ty : ty → Type v) (out_ty : ty → Type w) : + in_out_ty := + Sigma.mk ty (Prod.mk in_ty out_ty) + + -- Initially, we had left out the parameters id, a and b. + -- However, by parameterizing Funs with those parameters, we can state + -- and prove lemmas like Funs.is_valid_p_is_valid_p + inductive Funs (id : Type u) (ty : id → Type v) + (a : (i:id) → ty i → Type w) (b : (i:id) → ty i → Type x) : + List in_out_ty.{v, w, x} → Type (max (u + 1) (max (v + 1) (max (w + 1) (x + 1)))) := + | Nil : Funs id ty a b [] + | Cons {it: Type v} {ity : it → Type w} {oty : it → Type x} {tys : List in_out_ty} + (f : kk_ty id ty a b → (i:it) → (x:ity i) → Result (oty i)) (tl : Funs id ty a b tys) : + Funs id ty a b (⟨ it, ity, oty ⟩ :: tys) + + def get_fun {tys : List in_out_ty} (fl : Funs id ty a b tys) : + (i : Fin tys.length) → kk_ty id ty a b → (t : (tys.get i).fst) → + ((tys.get i).snd.fst t) → Result ((tys.get i).snd.snd t) := + match fl with + | .Nil => λ i => by have h:= i.isLt; simp at h + | @Funs.Cons id ty a b it ity oty tys1 f tl => + λ ⟨ i, iLt ⟩ => + match i with + | 0 => + Eq.mp (by simp [List.get]) f + | .succ j => + have jLt: j < tys1.length := by + simp at iLt + revert iLt + simp_arith + let j: Fin tys1.length := ⟨ j, jLt ⟩ + Eq.mp (by simp) (get_fun tl j) + + /- Automating the proofs -/ + @[simp] theorem is_valid_p_same + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) (x : Result c) : + is_valid_p k (λ _ => x) := by + simp [is_valid_p, k_to_gen, e_to_gen] + + @[simp] theorem is_valid_p_rec + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) (i : id) (t : ty i) (x : a i t) : + is_valid_p k (λ k => k i t x) := by + simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] + + theorem is_valid_p_ite + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (λ k => ite cond (e1 k) (e2 k)) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (cond : Prop) [h : Decidable cond] + {e1 : ((i:id) → (t:ty i) → a i t → Result (b i t)) → cond → Result c} + {e2 : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Not cond → Result c} + (he1: ∀ x, is_valid_p k (λ k => e1 k x)) + (he2 : ∀ x, is_valid_p k (λ k => e2 k x)) : + is_valid_p k (λ k => dite cond (e1 k) (e2 k)) := by + split <;> simp [*] + + theorem is_valid_p_bind + {{k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)}} + {{g : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c}} + {{h : c → ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result d}} + (Hgvalid : is_valid_p k g) + (Hhvalid : ∀ y, is_valid_p k (h y)) : + is_valid_p k (λ k => do let y ← g k; h y k) := by + apply Fix.is_valid_p_bind + . apply Hgvalid + . apply Hhvalid + + def Funs.is_valid_p + (k : k_ty id ty a b) + (fl : Funs id ty a b tys) : + Prop := + match fl with + | .Nil => True + | .Cons f fl => + (∀ i x, FixII.is_valid_p k (λ k => f k i x)) ∧ fl.is_valid_p k + + theorem Funs.is_valid_p_Nil (k : k_ty id ty a b) : + Funs.is_valid_p k Funs.Nil := by simp [Funs.is_valid_p] + + def Funs.is_valid_p_is_valid_p_aux + {k : k_ty id ty a b} + {tys : List in_out_ty} + (fl : Funs id ty a b tys) (Hvalid : is_valid_p k fl) : + ∀ (i : Fin tys.length) (t : (tys.get i).fst) (x : (tys.get i).snd.fst t), + FixII.is_valid_p k (fun k => get_fun fl i k t x) := by + -- Prepare the induction + have ⟨ n, Hn ⟩ : { n : Nat // tys.length = n } := ⟨ tys.length, by rfl ⟩ + revert tys fl Hvalid + induction n + -- + case zero => + intro tys fl Hvalid Hlen; + have Heq: tys = [] := by cases tys <;> simp_all + intro i x + simp_all + have Hi := i.isLt + simp_all + case succ n Hn => + intro tys fl Hvalid Hlen i x; + cases fl <;> simp at Hlen i x Hvalid + rename_i ity oty tys f fl + have ⟨ Hvf, Hvalid ⟩ := Hvalid + have Hvf1: is_valid_p k fl := by + simp [Hvalid, Funs.is_valid_p] + have Hn := @Hn tys fl Hvf1 (by simp [*]) + -- Case disjunction on i + match i with + | ⟨ 0, _ ⟩ => + simp at x + simp [get_fun] + apply (Hvf x) + | ⟨ .succ j, HiLt ⟩ => + simp_arith at HiLt + simp at x + let j : Fin (List.length tys) := ⟨ j, by simp_arith [HiLt] ⟩ + have Hn := Hn j x + apply Hn + + def Funs.is_valid_p_is_valid_p + (tys : List in_out_ty) + (k : k_ty (Fin (List.length tys)) (λ i => (tys.get i).fst) + (fun i t => (List.get tys i).snd.fst t) (fun i t => (List.get tys i).snd.snd t)) + (fl : Funs (Fin tys.length) (λ i => (tys.get i).fst) + (λ i t => (tys.get i).snd.fst t) (λ i t => (tys.get i).snd.snd t) tys) : + fl.is_valid_p k → + ∀ (i : Fin tys.length) (t : (tys.get i).fst) (x : (tys.get i).snd.fst t), + FixII.is_valid_p k (fun k => get_fun fl i k t x) + := by + intro Hvalid + apply is_valid_p_is_valid_p_aux; simp [*] + +end FixII + namespace Ex1 /- An example of use of the fixed-point -/ open Primitives Fix @@ -1133,3 +1354,218 @@ namespace Ex6 Heqix end Ex6 + +namespace Ex7 + /- `list_nth` again, but this time we use FixII -/ + open Primitives FixII + + @[simp] def tys.{u} : List in_out_ty := + [mk_in_out_ty (Type u) (λ a => List a × Int) (λ a => a)] + + @[simp] def ty (i : Fin 1) := (tys.get i).fst + @[simp] def input_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.fst t + @[simp] def output_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.snd t + + def list_nth_body.{u} (k : (i:Fin 1) → (t:ty i) → input_ty i t → Result (output_ty i t)) + (a : Type u) (x : List a × Int) : Result a := + let ⟨ ls, i ⟩ := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else k 0 a ⟨ tl, i - 1 ⟩ + + @[simp] def bodies : + Funs (Fin 1) ty input_ty output_ty tys := + Funs.Cons list_nth_body Funs.Nil + + def body (k : (i : Fin 1) → (t : ty i) → (x : input_ty i t) → Result (output_ty i t)) (i: Fin 1) : + (t : ty i) → (x : input_ty i t) → Result (output_ty i t) := get_fun bodies i k + + theorem body_is_valid: is_valid body := by + -- Split the proof into proofs of validity of the individual bodies + rw [is_valid] + simp only [body] + intro k + apply (Funs.is_valid_p_is_valid_p tys) + simp [Funs.is_valid_p] + (repeat (apply And.intro)); intro x; try simp at x + simp only [list_nth_body] + -- Prove the validity of the individual bodies + intro k x + split <;> try simp + split <;> simp + + -- Writing the proof terms explicitly + theorem list_nth_body_is_valid' (k : k_ty (Fin 1) ty input_ty output_ty) + (a : Type u) (x : List a × Int) : is_valid_p k (fun k => list_nth_body k a x) := + let ⟨ ls, i ⟩ := x + match ls with + | [] => is_valid_p_same k (.fail .panic) + | hd :: tl => + is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 a ⟨tl, i-1⟩) + + theorem body_is_valid' : is_valid body := + fun k => + Funs.is_valid_p_is_valid_p tys k bodies + (And.intro (list_nth_body_is_valid' k) (Funs.is_valid_p_Nil k)) + + def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := + fix body 0 a ⟨ ls , i ⟩ + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := by + have Heq := is_valid_fix_fixed_eq body_is_valid + simp [list_nth] + conv => lhs; rw [Heq] + + -- Write the proof term explicitly: the generation of the proof term (without tactics) + -- is automatable, and the proof term is actually a lot simpler and smaller when we + -- don't use tactics. + theorem list_nth_eq'.{u} {a : Type u} (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := + -- Use the fixed-point equation + have Heq := is_valid_fix_fixed_eq body_is_valid.{u} + -- Add the index + have Heqi := congr_fun Heq 0 + -- Add the type parameter + have Heqia := congr_fun Heqi a + -- Add the input + have Heqix := congr_fun Heqia (ls, i) + -- Done + Heqix + +end Ex7 + +namespace Ex8 + /- Higher-order example, with FixII -/ + open Primitives FixII + + variable {id : Type u} {ty : id → Type v} + variable {a : (i:id) → ty i → Type w} {b : (i:id) → ty i → Type x} + + /- An auxiliary function, which doesn't require the fixed-point -/ + def map {a : Type y} {b : Type z} (f : a → Result b) (ls : List a) : Result (List b) := + match ls with + | [] => .ret [] + | hd :: tl => + do + let hd ← f hd + let tl ← map f tl + .ret (hd :: tl) + + /- The validity theorems for `map`, generic in `f` -/ + + -- This is not the most general lemma, but we keep it to test the `divergence` encoding on a simple case + @[divspec] + theorem map_is_valid_simple + (i : id) (t : ty i) + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (ls : List (a i t)) : + is_valid_p k (λ k => map (k i t) ls) := by + induction ls <;> simp [map] + apply is_valid_p_bind <;> try simp_all + intros + apply is_valid_p_bind <;> try simp_all + + @[divspec] + theorem map_is_valid + (d : Type y) + {{f : ((i:id) → (t : ty i) → a i t → Result (b i t)) → d → Result c}} + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (Hfvalid : ∀ x1, is_valid_p k (fun kk1 => f kk1 x1)) + (ls : List d) : + is_valid_p k (λ k => map (f k) ls) := by + induction ls <;> simp [map] + apply is_valid_p_bind <;> try simp_all + intros + apply is_valid_p_bind <;> try simp_all + +end Ex8 + +namespace Ex9 + /- An example which uses map -/ + open Primitives FixII Ex8 + + inductive Tree (a : Type u) := + | leaf (x : a) + | node (tl : List (Tree a)) + + @[simp] def tys.{u} : List in_out_ty := + [mk_in_out_ty (Type u) (λ a => Tree a) (λ a => Tree a)] + + @[simp] def ty (i : Fin 1) := (tys.get i).fst + @[simp] def input_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.fst t + @[simp] def output_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.snd t + + def id_body.{u} (k : (i:Fin 1) → (t:ty i) → input_ty i t → Result (output_ty i t)) + (a : Type u) (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (k 0 a) tl + .ret (.node tl) + + @[simp] def bodies : + Funs (Fin 1) ty input_ty output_ty tys := + Funs.Cons id_body Funs.Nil + + theorem id_body_is_valid : + ∀ (k : ((i : Fin 1) → (t : ty i) → input_ty i t → Result (output_ty i t)) → (i : Fin 1) → (t : ty i) → input_ty i t → Result (output_ty i t)) + (a : Type u) (x : Tree a), + @is_valid_p (Fin 1) ty input_ty output_ty (output_ty 0 a) k (λ k => id_body k a x) := by + intro k a x + simp only [id_body] + split <;> try simp + apply is_valid_p_bind <;> try simp [*] + -- We have to show that `map k tl` is valid + -- Remark: `map_is_valid` doesn't work here, we need the specialized version + apply map_is_valid_simple + + def body (k : (i : Fin 1) → (t : ty i) → (x : input_ty i t) → Result (output_ty i t)) (i: Fin 1) : + (t : ty i) → (x : input_ty i t) → Result (output_ty i t) := get_fun bodies i k + + theorem body_is_valid : is_valid body := + fun k => + Funs.is_valid_p_is_valid_p tys k bodies + (And.intro (id_body_is_valid k) (Funs.is_valid_p_Nil k)) + + def id {a: Type u} (t : Tree a) : Result (Tree a) := + fix body 0 a t + + -- Writing the proof term explicitly + theorem id_eq' {a : Type u} (t : Tree a) : + id t = + (match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl)) + := + -- The unfolding equation + have Heq := is_valid_fix_fixed_eq body_is_valid.{u} + -- Add the index + have Heqi := congr_fun Heq 0 + -- Add the type parameter + have Heqia := congr_fun Heqi a + -- Add the input + have Heqix := congr_fun Heqia t + -- Done + Heqix + +end Ex9 diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index c6628486..6115b13b 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -17,15 +17,24 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta open Utils +def normalize_let_bindings := true + /- The following was copied from the `wfRecursion` function. -/ open WF in +-- TODO: use those +def UnitType := Expr.const ``PUnit [Level.succ .zero] +def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero] + +def mkProdType (x y : Expr) : MetaM Expr := + mkAppM ``Prod #[x, y] + def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] -def mkInOutTy (x y : Expr) : MetaM Expr := - mkAppM ``FixI.mk_in_out_ty #[x, y] +def mkInOutTy (x y z : Expr) : MetaM Expr := do + mkAppM ``FixII.mk_in_out_ty #[x, y, z] -- Return the `a` in `Return a` def getResultTy (ty : Expr) : MetaM Expr := @@ -47,6 +56,17 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do else pure (args.get! 0, args.get! 1) +/- Make a sigma type. + + `x` should be a variable, and `ty` and type which (might) uses `x` + -/ +def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do + trace[Diverge.def.sigmas] "mkSigmaType: {x} {sty}" + let alpha ← inferType x + let beta ← mkLambdaFVars #[x] sty + trace[Diverge.def.sigmas] "mkSigmaType: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + /- Generate a Sigma type from a list of *variables* (all the expressions must be variables). @@ -60,20 +80,78 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do def mkSigmasType (xl : List Expr) : MetaM Expr := match xl with | [] => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" - pure (Expr.const ``PUnit.unit []) + trace[Diverge.def.sigmas] "mkSigmasType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) | [x] => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" - let ty ← Lean.Meta.inferType x + trace[Diverge.def.sigmas] "mkSigmasType: [{x}]" + let ty ← inferType x pure ty | x :: xl => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" - let alpha ← Lean.Meta.inferType x + trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]" let sty ← mkSigmasType xl - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" - let beta ← mkLambdaFVars #[x] sty - trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" - mkAppOptM ``Sigma #[some alpha, some beta] + mkSigmaType x sty + +/- Generate a product type from a list of *variables* (this is similar to `mkSigmas`). + + Example: + - xl = [(ls:List a), (i:Int)] + + Generates: + `List a × Int` + -/ +def mkProdsType (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.prods] "mkProdsType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) + | [x] => do + trace[Diverge.def.prods] "mkProdsType: [{x}]" + let ty ← inferType x + pure ty + | x :: xl => do + trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]" + let ty ← inferType x + let xl_ty ← mkProdsType xl + mkAppM ``Prod #[ty, xl_ty] + +/- Split the input arguments between the types and the "regular" arguments. + + We do something simple: we treat an input argument as an + input type iff it appears in the type of the following arguments. + + Note that what really matters is that we find the arguments which appear + in the output type. + + Also, we stop at the first input that we treat as an + input type. + -/ +def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × Array Expr) := do + -- Look for the first parameter which appears in the subsequent parameters + let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) := + match in_tys with + | [] => do + let fvars ← getFVarIds (← inferType out_ty) + pure (fvars, [], []) + | ty :: in_tys => do + let (fvars, in_tys, in_args) ← splitAux in_tys + -- Have we already found where to split between type variables/regular + -- variables? + if ¬ in_tys.isEmpty then + -- The fvars set is now useless: no need to update it anymore + pure (fvars, ty :: in_tys, in_args) + else + -- Check if ty appears in the set of free variables: + let ty_id := ty.fvarId! + if fvars.contains ty_id then + -- We must split here. Note that we don't need to update the fvars + -- set: it is not useful anymore + pure (fvars, [ty], in_args) + else + -- We must split later: update the fvars set + let fvars := fvars.insertMany (← getFVarIds (← inferType ty)) + pure (fvars, [], ty :: in_args) + let (_, in_tys, in_args) ← splitAux in_tys.data + pure (Array.mk in_tys, Array.mk in_args) /- Apply a lambda expression to some arguments, simplifying the lambdas -/ def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do @@ -105,7 +183,7 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasVal: []" - pure (Expr.const ``PUnit.unit []) + pure (Expr.const ``PUnit.unit [Level.succ .zero]) | [x] => do trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" pure x @@ -122,6 +200,17 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] +/- Group a list of expressions into a (non-dependent) tuple -/ +def mkProdsVal (xl : List Expr) : MetaM Expr := + match xl with + | [] => + pure (Expr.const ``PUnit.unit [Level.succ .zero]) + | [x] => do + pure x + | x :: xl => do + let xl ← mkProdsVal xl + mkAppM ``Prod.mk #[x, xl] + def mkAnonymous (s : String) (i : Nat) : Name := .num (.str .anonymous s) i @@ -159,31 +248,31 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met match xl with | [] => do -- This would be unexpected - throwError "mkSigmasMatch: empyt list of input parameters" + throwError "mkSigmasMatch: empty list of input parameters" | [x] => do -- In the example given for the explanations: this is the inner match case trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" mkLambdaFVars #[x] out | fst :: xl => do - -- In the example given for the explanations: this is the outer match case - -- Remark: for the naming purposes, we use the same convention as for the - -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at - -- those definitions might help) - -- - -- We want to build the match expression: - -- ``` - -- λ scrut => - -- match scrut with - -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail - -- ``` + /- In the example given for the explanations: this is the outer match case + Remark: for the naming purposes, we use the same convention as for the + fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + those definitions might help) + + We want to build the match expression: + ``` + λ scrut => + match scrut with + | Sigma.mk x ... -- the hole is given by a recursive call on the tail + ``` -/ trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst + let alpha ← inferType fst let snd_ty ← mkSigmasType xl let beta ← mkLambdaFVars #[fst] snd_ty let snd ← mkSigmasMatch xl out (index + 1) let mk ← mkLambdaFVars #[fst] snd -- Introduce the "scrut" variable - let scrut_ty ← mkSigmasType (fst :: xl) + let scrut_ty ← mkSigmaType fst snd_ty withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient @@ -206,6 +295,67 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" pure sm +/- This is similar to `mkSigmasMatch`, but with non-dependent tuples + + Remark: factor out with `mkSigmasMatch`? This is extremely similar. +-/ +partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkProdsMatch: empty list of input parameters" + | [x] => do + -- In the example given for the explanations: this is the inner match case + trace[Diverge.def.prods] "mkProdsMatch: [{x}]" + mkLambdaFVars #[x] out + | fst :: xl => do + trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]" + let alpha ← inferType fst + let beta ← mkProdsType xl + let snd ← mkProdsMatch xl out (index + 1) + let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkProdType alpha beta + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do + trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + mkLambdaFVars #[scrut] out_ty + -- The final expression: putting everything together + trace[Diverge.def.prods] "mkProdsMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Prod.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable + let sm ← mkLambdaFVars #[scrut] sm + trace[Diverge.def.prods] "mkProdsMatch: sm: {sm}" + pure sm + +/- Same as `mkSigmasMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkSigmasMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkSigmasMatch xl out + +/- Same as `mkProdsMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkProdsMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkProdsMatch xl out + /- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) @@ -238,6 +388,52 @@ def mkFinVal (n i : Nat) : MetaM Expr := do let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] mkAppOptM ``OfNat.ofNat #[none, none, ofNat] +/- Information about the type of a function in a declaration group. + + In the comments about the fields, we take as example the + `list_nth (α : Type) (ls : List α) (i : Int) : Result α` function. + -/ +structure TypeInfo where + /- The total number of input arguments. + + For list_nth: 3 + -/ + total_num_args : ℕ + /- The number of type parameters (they should be a prefix of the input arguments). + + For `list_nth`: 1 + -/ + num_params : ℕ + /- The type of the dependent tuple grouping the type parameters. + + For `list_nth`: `Type` + -/ + params_ty : Expr + /- The type of the tuple grouping the input values. This is a function taking + as input a value of type `params_ty`. + + For `list_nth`: `λ a => List a × Int` + -/ + in_ty : Expr + /- The output type, without the `Return`. This is a function taking + as input a value of type `params_ty`. + + For `list_nth`: `λ a => a` + -/ + out_ty : Expr + +def mkInOutTyFromTypeInfo (info : TypeInfo) : MetaM Expr := do + mkInOutTy info.params_ty info.in_ty info.out_ty + +instance : Inhabited TypeInfo := + { default := { total_num_args := 0, num_params := 0, params_ty := UnitType, + in_ty := UnitType, out_ty := UnitType } } + +instance : ToMessageData TypeInfo := + ⟨ λ ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩ => + f!"\{ total_num_args: {total_num_args}, num_params: {num_params}, params_ty: {params_ty}, in_ty: {in_ty}, out_ty: {out_ty} }}" + ⟩ + /- Generate and declare as individual definitions the bodies for the individual funcions: - replace the recursive calls with calls to the continutation `k` - make those bodies take one single dependent tuple as input @@ -246,15 +442,17 @@ def mkFinVal (n i : Nat) : MetaM Expr := do We return the new declarations. -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : + (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - -- Compute the map from name to (index × input type). - -- Remark: the continuation has an indexed type; we use the index (a finite number of - -- type `Fin`) to control which function we call at the recursive call site. - let nameToInfo : HashMap Name (Nat × Expr) := - let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst)) + /- Compute the map from name to (index, type info). + + Remark: the continuation has an indexed type; we use the index (a finite number of + type `Fin`) to control which function we call at the recursive call site. -/ + let nameToInfo : HashMap Name (Nat × TypeInfo) := + let bl := preDefs.mapIdx fun i d => + (d.declName, (i.val, paramInOutTys.get! i.val)) HashMap.ofList bl.toList trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" @@ -262,35 +460,65 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Auxiliary function to explore the function bodies and replace the -- recursive calls let visit_e (i : Nat) (e : Expr) : MetaM Expr := do - trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}" + trace[Diverge.def.genBody.visit] "visiting expression (dept: {i}): {e}" let ne ← do match e with | .app .. => do e.withApp fun f args => do - trace[Diverge.def.genBody] "this is an app: {f} {args}" + trace[Diverge.def.genBody.visit] "this is an app: {f} {args}" -- Check if this is a recursive call if f.isConst then let name := f.constName! match nameToInfo.find? name with | none => pure e - | some (id, in_ty) => - trace[Diverge.def.genBody] "this is a recursive call" + | some (id, type_info) => + trace[Diverge.def.genBody.visit] "this is a recursive call" -- This is a recursive call: replace it -- Compute the index let i ← mkFinVal grSize id - -- Put the arguments in one big dependent tuple - let args ← mkSigmasVal in_ty args.toList - mkAppM' kk_var #[i, args] + -- It can happen that there are no input values given to the + -- recursive calls, and only type parameters. + let num_args := args.size + if num_args ≠ type_info.total_num_args ∧ num_args ≠ type_info.num_params then + throwError "Invalid number of arguments for the recursive call: {e}" + -- Split the arguments, and put them in two tuples (the first + -- one is a dependent tuple) + let (param_args, args) := args.toList.splitAt type_info.num_params + trace[Diverge.def.genBody.visit] "param_args: {param_args}, args: {args}" + let param_args ← mkSigmasVal type_info.params_ty param_args + -- Check if there are input values + if num_args = type_info.total_num_args then do + trace[Diverge.def.genBody.visit] "Recursive call with input values" + let args ← mkProdsVal args + mkAppM' kk_var #[i, param_args, args] + else do + trace[Diverge.def.genBody.visit] "Recursive call without input values" + mkAppM' kk_var #[i, param_args] else -- Not a recursive call: do nothing pure e | .const name _ => - -- Sanity check: we eliminated all the recursive calls - if (nameToInfo.find? name).isSome then - throwError "mkUnaryBodies: a recursive call was not eliminated" + /- This might refer to the one of the top-level functions if we use + it without arguments (if we give it to a higher-order + function for instance) and there are actually no type parameters. + -/ + if (nameToInfo.find? name).isSome then do + -- Checking the type information + match nameToInfo.find? name with + | none => pure e + | some (id, type_info) => + trace[Diverge.def.genBody.visit] "this is a recursive call" + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Check that there are no type parameters + if type_info.num_params ≠ 0 then throwError "mkUnaryBodies: a recursive call was not eliminated" + -- Introduce the call to the continuation + let param_args ← mkSigmasVal type_info.params_ty [] + mkAppM' kk_var #[i, param_args] else pure e | _ => pure e - trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}" + trace[Diverge.def.genBody.visit] "done with expression (depth: {i}): {e}" pure ne -- Explore the bodies @@ -300,13 +528,20 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let body ← mapVisit visit_e preDef.value trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" - -- Currify the function by grouping the arguments into a dependent tuple + -- Currify the function by grouping the arguments into dependent tuples -- (over which we match to retrieve the individual arguments). lambdaTelescope body fun args body => do - let body ← mkSigmasMatch args.toList body 0 + -- Split the arguments between the type parameters and the "regular" inputs + let (_, type_info) := nameToInfo.find! preDef.declName + let (param_args, args) := args.toList.splitAt type_info.num_params + let body ← mkProdsMatchOrUnit args body + trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}" + let body ← mkSigmasMatchOrUnit param_args body + trace[Diverge.def.genBody] "Body after mkSigmasMatchOrUnit: {body}" -- Add the declaration let value ← mkLambdaFVars #[kk_var] body + trace[Diverge.def.genBody] "Body after abstracting kk: {value}" let name := preDef.declName.append "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { @@ -318,41 +553,46 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) safety := .safe all := [name] } + trace[Diverge.def.genBody] "About to add decl" addDecl decl trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant let body := Lean.mkConst name (levelParams.map .param) - -- let body ← mkAppM' body #[kk_var] trace[Diverge.def] "individual body (after decl): {body}" pure body --- Generate a unique function body from the bodies of the mutually recursive group, --- and add it as a declaration in the context. --- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. +/- Generate a unique function body from the bodies of the mutually recursive group, + and add it as a declaration in the context. + We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. + -/ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) - (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) + (param_ty in_ty out_ty : Expr) (paramInOutTys : Array TypeInfo) (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize -- TODO: not very clean - let inOutTyType ← do - let (x, y) := inOutTys.get! 0 - inferType (← mkInOutTy x y) - let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := - match inOutTys, bl with + let paramInOutTyType ← do + let info := paramInOutTys.get! 0 + inferType (← mkInOutTyFromTypeInfo info) + let rec mkFuns (paramInOutTys : List TypeInfo) (bl : List Expr) : MetaM Expr := + match paramInOutTys, bl with | [], [] => - mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] - | (ity, oty) :: inOutTys, b :: bl => do + mkAppOptM ``FixII.Funs.Nil #[finTypeExpr, param_ty, in_ty, out_ty] + | info :: paramInOutTys, b :: bl => do + let pty := info.params_ty + let ity := info.in_ty + let oty := info.out_ty -- Retrieving ity and oty - this is not very clean - let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) - let fl ← mkFuns inOutTys bl - mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] + let paramInOutTysExpr ← mkListLit paramInOutTyType + (← paramInOutTys.mapM mkInOutTyFromTypeInfo) + let fl ← mkFuns paramInOutTys bl + mkAppOptM ``FixII.Funs.Cons #[finTypeExpr, param_ty, in_ty, out_ty, pty, ity, oty, paramInOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" - let bodyFuns ← mkFuns inOutTys bodies.toList + let bodyFuns ← mkFuns paramInOutTys.toList bodies.toList -- Wrap in `get_fun` - let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var] + let body ← mkAppM ``FixII.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables let body ← mkLambdaFVars #[kk_var, i_var] body trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" @@ -391,11 +631,11 @@ instance : ToMessageData MatchInfo where -- This is not a very clean formatting, but we don't need more toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" --- Small helper: prove that an expression which doesn't use the continuation `kk` --- is valid, and return the proof. +/- Small helper: prove that an expression which doesn't use the continuation `kk` + is valid, and return the proof. -/ def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" - let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] + let eIsValid ← mkAppM ``FixII.is_valid_p_same #[k_var, e] trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid @@ -410,7 +650,16 @@ mutual ``` -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do - trace[Diverge.def.valid] "proveValid: {e}" + trace[Diverge.def.valid] "proveExprIsValid: {e}" + -- Normalize to eliminate the lambdas - TODO: this is slightly dangerous. + let e ← do + if e.isLet ∧ normalize_let_bindings then do + let updt_config config := + { config with transparency := .reducible, zetaNonDep := false } + let e ← withConfig updt_config (whnf e) + trace[Diverge.def.valid] "e (after normalization): {e}" + pure e + else pure e match e with | .const _ _ => throwError "Unimplemented" -- Shouldn't get there? | .bvar _ @@ -418,9 +667,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .lit _ | .mvar _ | .sort _ => throwError "Unreachable" - | .lam .. => throwError "Unimplemented" + | .lam .. => throwError "Unimplemented" -- TODO | .forallE .. => throwError "Unreachable" -- Shouldn't get there | .letE .. => do + -- Remark: this branch is not taken if we normalize the expressions (above) -- Telescope all the let-bindings (remark: this also telescopes the lambdas) lambdaLetTelescope e fun xs body => do -- Note that we don't visit the bound values: there shouldn't be @@ -438,164 +688,268 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do proveNoKExprIsValid k_var e | .app .. => e.withApp fun f args => do - -- There are several cases: first, check if this is a match/if - -- Check if the expression is a (dependent) if then else. - -- We treat the if then else expressions differently from the other matches, - -- and have dedicated theorems for them. - let isIte := e.isIte - if isIte || e.isDIte then do - e.withApp fun f args => do - trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" - if args.size ≠ 5 then - throwError "Wrong number of parameters for {f}: {args}" - let cond := args.get! 1 - let dec := args.get! 2 - -- Prove that the branches are valid - let br0 := args.get! 3 - let br1 := args.get! 4 - let proveBranchValid (br : Expr) : MetaM Expr := - if isIte then proveExprIsValid k_var kk_var br - else do - -- There is a lambda - lambdaOne br fun x br => do - let brValid ← proveExprIsValid k_var kk_var br - mkLambdaFVars #[x] brValid - let br0Valid ← proveBranchValid br0 - let br1Valid ← proveBranchValid br1 - let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite - let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] - trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" - pure eIsValid - -- Check if the expression is a match (this case is for when the elaborator - -- introduces auxiliary definitions to hide the match behind syntactic - -- sugar): - else if let some me := ← matchMatcherApp? e then do - trace[Diverge.def.valid] - "matcherApp: - - params: {me.params} - - motive: {me.motive} - - discrs: {me.discrs} - - altNumParams: {me.altNumParams} - - alts: {me.alts} - - remaining: {me.remaining}" - -- matchMatcherApp does all the work for us: we simply need to gather - -- the information and call the auxiliary helper `proveMatchIsValid` - if me.remaining.size ≠ 0 then - throwError "MatcherApp: non empty remaining array: {me.remaining}" - let me : MatchInfo := { - matcherName := me.matcherName - matcherLevels := me.matcherLevels - params := me.params - motive := me.motive - scruts := me.discrs - branchesNumParams := me.altNumParams - branches := me.alts - } - proveMatchIsValid k_var kk_var me - -- Check if the expression is a raw match (this case is for when the expression - -- is a direct call to the primitive `casesOn` function, without syntactic sugar). - -- We have to check this case because functions like `mkSigmasMatch`, which we - -- use to currify function bodies, introduce such raw matches. - else if ← isCasesExpr f then do - trace[Diverge.def.valid] "rawMatch: {e}" - -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. - -- - -- The casesOn definition is always of the following shape: - -- - input parameters (implicit parameters) - -- - motive (implicit), -- the motive gives the return type of the match - -- - scrutinee (explicit) - -- - branches (explicit). - -- In particular, we notice that the scrutinee is the first *explicit* - -- parameter - this is how we spot it. - let matcherName := f.constName! - let matcherLevels := f.constLevels!.toArray - -- Find the first explicit parameter: this is the scrutinee - forallTelescope (← inferType f) fun xs _ => do - let rec findFirstExplicit (i : Nat) : MetaM Nat := do - if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" - else - let x := xs.get! i - let xFVarId := x.fvarId! - let localDecl ← xFVarId.getDecl - match localDecl.binderInfo with - | .default => pure i - | _ => findFirstExplicit (i + 1) - let scrutIdx ← findFirstExplicit 0 - -- Split the arguments - let params := args.extract 0 (scrutIdx - 1) - let motive := args.get! (scrutIdx - 1) - let scrut := args.get! scrutIdx - let branches := args.extract (scrutIdx + 1) args.size - -- Compute the number of parameters for the branches: for this we use - -- the type of the uninstantiated casesOn constant (we can't just - -- destruct the lambdas in the branch expressions because the result - -- of a match might be a lambda expression). - let branchesNumParams : Array Nat ← do - let env ← getEnv - let decl := env.constants.find! matcherName - let ty := decl.type - forallTelescope ty fun xs _ => do - let xs := xs.extract (scrutIdx + 1) xs.size - xs.mapM fun x => do - let xty ← inferType x - forallTelescope xty fun ys _ => do - pure ys.size - let me : MatchInfo := { - matcherName, - matcherLevels, - params, - motive, - scruts := #[scrut], - branchesNumParams, - branches, - } - proveMatchIsValid k_var kk_var me - -- Check if this is a monadic let-binding - else if f.isConstOf ``Bind.bind then do - trace[Diverge.def.valid] "bind:\n{args}" - -- We simply need to prove that the subexpressions are valid, and call - -- the appropriate lemma. - let x := args.get! 4 - let y := args.get! 5 - -- Prove that the subexpressions are valid - let xValid ← proveExprIsValid k_var kk_var x - trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" - let yValid ← do - -- This is a lambda expression - lambdaOne y fun x y => do - trace[Diverge.def.valid] "bind: y: {y}" - let yValid ← proveExprIsValid k_var kk_var y - trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" - trace[Diverge.def.valid] "bind: yValid: x: {x}" - let yValid ← mkLambdaFVars #[x] yValid - trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" - pure yValid - -- Put everything together - trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" - mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] - -- Check if this is a recursive call, i.e., a call to the continuation `kk` - else if f.isFVarOf kk_var.fvarId! then do - trace[Diverge.def.valid] "rec: args: \n{args}" - if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" - let i_arg := args.get! 0 - let x_arg := args.get! 1 - let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg] - trace[Diverge.def.valid] "rec: result: \n{eIsValid}" - pure eIsValid + proveAppIsValid k_var kk_var e f args + +partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr): MetaM Expr := do + trace[Diverge.def.valid] "proveAppIsValid: {e}\nDecomposed: {f} {args}" + /- There are several cases: first, check if this is a match/if + Check if the expression is a (dependent) if then else. + We treat the if then else expressions differently from the other matches, + and have dedicated theorems for them. -/ + let isIte := e.isIte + if isIte || e.isDIte then do + e.withApp fun f args => do + trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" + if args.size ≠ 5 then + throwError "Wrong number of parameters for {f}: {args}" + let cond := args.get! 1 + let dec := args.get! 2 + -- Prove that the branches are valid + let br0 := args.get! 3 + let br1 := args.get! 4 + let proveBranchValid (br : Expr) : MetaM Expr := + if isIte then proveExprIsValid k_var kk_var br else do - -- Remaining case: normal application. - -- It shouldn't use the continuation. - proveNoKExprIsValid k_var e + -- There is a lambda + lambdaOne br fun x br => do + let brValid ← proveExprIsValid k_var kk_var br + mkLambdaFVars #[x] brValid + let br0Valid ← proveBranchValid br0 + let br1Valid ← proveBranchValid br1 + let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite + let eIsValid ← + mkAppOptM const #[none, none, none, none, none, + some k_var, some cond, some dec, none, none, + some br0Valid, some br1Valid] + trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + /- Check if the expression is a match (this case is for when the elaborator + introduces auxiliary definitions to hide the match behind syntactic + sugar): -/ + else if let some me := ← matchMatcherApp? e then do + trace[Diverge.def.valid] + "matcherApp: + - params: {me.params} + - motive: {me.motive} + - discrs: {me.discrs} + - altNumParams: {me.altNumParams} + - alts: {me.alts} + - remaining: {me.remaining}" + -- matchMatcherApp does all the work for us: we simply need to gather + -- the information and call the auxiliary helper `proveMatchIsValid` + if me.remaining.size ≠ 0 then + throwError "MatcherApp: non empty remaining array: {me.remaining}" + let me : MatchInfo := { + matcherName := me.matcherName + matcherLevels := me.matcherLevels + params := me.params + motive := me.motive + scruts := me.discrs + branchesNumParams := me.altNumParams + branches := me.alts + } + proveMatchIsValid k_var kk_var me + /- Check if the expression is a raw match (this case is for when the expression + is a direct call to the primitive `casesOn` function, without syntactic sugar). + We have to check this case because functions like `mkSigmasMatch`, which we + use to currify function bodies, introduce such raw matches. -/ + else if ← isCasesExpr f then do + trace[Diverge.def.valid] "rawMatch: {e}" + /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + + The casesOn definition is always of the following shape: + - input parameters (implicit parameters) + - motive (implicit), -- the motive gives the return type of the match + - scrutinee (explicit) + - branches (explicit). + In particular, we notice that the scrutinee is the first *explicit* + parameter - this is how we spot it. + -/ + let matcherName := f.constName! + let matcherLevels := f.constLevels!.toArray + -- Find the first explicit parameter: this is the scrutinee + forallTelescope (← inferType f) fun xs _ => do + let rec findFirstExplicit (i : Nat) : MetaM Nat := do + if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" + else + let x := xs.get! i + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + match localDecl.binderInfo with + | .default => pure i + | _ => findFirstExplicit (i + 1) + let scrutIdx ← findFirstExplicit 0 + -- Split the arguments + let params := args.extract 0 (scrutIdx - 1) + let motive := args.get! (scrutIdx - 1) + let scrut := args.get! scrutIdx + let branches := args.extract (scrutIdx + 1) args.size + /- Compute the number of parameters for the branches: for this we use + the type of the uninstantiated casesOn constant (we can't just + destruct the lambdas in the branch expressions because the result + of a match might be a lambda expression). -/ + let branchesNumParams : Array Nat ← do + let env ← getEnv + let decl := env.constants.find! matcherName + let ty := decl.type + forallTelescope ty fun xs _ => do + let xs := xs.extract (scrutIdx + 1) xs.size + xs.mapM fun x => do + let xty ← inferType x + forallTelescope xty fun ys _ => do + pure ys.size + let me : MatchInfo := { + matcherName, + matcherLevels, + params, + motive, + scruts := #[scrut], + branchesNumParams, + branches, + } + proveMatchIsValid k_var kk_var me + -- Check if this is a monadic let-binding + else if f.isConstOf ``Bind.bind then do + trace[Diverge.def.valid] "bind:\n{args}" + -- We simply need to prove that the subexpressions are valid, and call + -- the appropriate lemma. + let x := args.get! 4 + let y := args.get! 5 + -- Prove that the subexpressions are valid + let xValid ← proveExprIsValid k_var kk_var x + trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" + let yValid ← do + -- This is a lambda expression + lambdaOne y fun x y => do + trace[Diverge.def.valid] "bind: y: {y}" + let yValid ← proveExprIsValid k_var kk_var y + trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" + trace[Diverge.def.valid] "bind: yValid: x: {x}" + let yValid ← mkLambdaFVars #[x] yValid + trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" + pure yValid + -- Put everything together + trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" + mkAppM ``FixII.is_valid_p_bind #[xValid, yValid] + -- Check if this is a recursive call, i.e., a call to the continuation `kk` + else if f.isFVarOf kk_var.fvarId! then do + trace[Diverge.def.valid] "rec: args: \n{args}" + if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}" + let i_arg := args.get! 0 + let t_arg := args.get! 1 + let x_arg := args.get! 2 + let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg] + trace[Diverge.def.valid] "rec: result: \n{eIsValid}" + pure eIsValid + else do + /- Remaining case: normal application. + Check if the arguments use the continuation: + - if no: this is simple + - if yes: we have to lookup theorems in div spec database and continue -/ + trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}" + let argsFVars ← args.mapM getFVarIds + let allArgsFVars := argsFVars.foldl (fun hs fvars => hs.insertMany fvars) HashSet.empty + trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}" + if ¬ allArgsFVars.contains kk_var.fvarId! then do + -- Simple case + trace[Diverge.def.valid] "kk doesn't appear in the arguments" + proveNoKExprIsValid k_var e + else do + -- Lookup in the database for suitable theorems + trace[Diverge.def.valid] "kk appears in the arguments" + let thms ← divspecAttr.find? e + trace[Diverge.def.valid] "Looked up theorems: {thms}" + -- Try the theorems one by one + proveAppIsValidApplyThms k_var kk_var e f args thms.toList + +partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) + (f : Expr) (args : Array Expr) (thms : List Name) : MetaM Expr := do + match thms with + | [] => throwError "Could not prove that the following expression is valid: {e}" + | thName :: thms => + -- Lookup the theorem itself + let env ← getEnv + let thDecl := env.constants.find! thName + -- Introduce fresh meta-variables for the universes + let ul : List (Name × Level) ← + thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar)) + let ulMap : HashMap Name Level := HashMap.ofList ul + let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x) + trace[Diverge.def.valid] "Trying with theorem {thName}: {thTy}" + -- Introduce meta variables for the universally quantified variables + let (mvars, _binders, thTyBody) ← forallMetaTelescope thTy + let thTermToMatch := thTyBody + trace[Diverge.def.valid] "thTermToMatch: {thTermToMatch}" + -- Create the term: `is_valid_p k (λ kk => e)` + let termToMatch ← mkLambdaFVars #[kk_var] e + let termToMatch ← mkAppM ``FixII.is_valid_p #[k_var, termToMatch] + trace[Diverge.def.valid] "termToMatch: {termToMatch}" + -- Attempt to match + trace[Diverge.def.valid] "Matching terms:\n\n{termToMatch}\n\n{thTermToMatch}" + let ok ← isDefEq termToMatch thTermToMatch + if ¬ ok then + -- Failure: attempt with the other theorems + proveAppIsValidApplyThms k_var kk_var e f args thms + else do + /- Success: continue with this theorem + + Instantiate the meta variables (some of them will not be instantiated: + they are new subgoals) + -/ + let mvars ← mvars.mapM instantiateMVars + let th ← mkAppOptM thName (Array.map some mvars) + trace[Diverge.def.valid] "Instantiated theorem: {th}\n{← inferType th}" + -- Filter the instantiated meta variables + let mvars := mvars.filter (fun v => v.isMVar) + let mvars := mvars.map (fun v => v.mvarId!) + trace[Diverge.def.valid] "Remaining subgoals: {mvars}" + for mvarId in mvars do + -- Prove the subgoal (i.e., the precondition of the theorem) + let mvarDecl ← mvarId.getDecl + let declType ← instantiateMVars mvarDecl.type + -- Reduce the subgoal before diving in, if necessary + trace[Diverge.def.valid] "Subgoal: {declType}" + -- Dive in the type + forallTelescope declType fun forall_vars mvar_e => do + trace[Diverge.def.valid] "forall_vars: {forall_vars}" + -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)` + -- We need to retrieve the new `k` variable, and dive into the + -- `λ kk => ...` + mvar_e.consumeMData.withApp fun is_valid args => do + if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 7 then + throwError "Invalid precondition: {mvar_e}" + else do + let k_var := args.get! 5 + let e_lam := args.get! 6 + trace[Diverge.def.valid] "k_var: {k_var}\ne_lam: {e_lam}" + -- The outer lambda should be for the kk_var + lambdaOne e_lam.consumeMData fun kk_var e => do + -- Continue + trace[Diverge.def.valid] "kk_var: {kk_var}\ne: {e}" + -- We sometimes need to reduce the term - TODO: this is really dangerous + let e ← do + let updt_config config := + { config with transparency := .reducible, zetaNonDep := false } + withConfig updt_config (whnf e) + trace[Diverge.def.valid] "e (after normalization): {e}" + let e_valid ← proveExprIsValid k_var kk_var e + trace[Diverge.def.valid] "e_valid (for e): {e_valid}" + let e_valid ← mkLambdaFVars forall_vars e_valid + trace[Diverge.def.valid] "e_valid (with foralls): {e_valid}" + let _ ← inferType e_valid -- Sanity check + -- Assign the meta variable + mvarId.assign e_valid + pure th -- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do trace[Diverge.def.valid] "proveMatchIsValid: {me}" -- Prove the validity of the branch expressions let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do - -- Go inside the lambdas - note that we have to be careful: some of the - -- binders might come from the match, and some of the binders might come - -- from the fact that the expression in the match is a lambda expression: - -- we use the branchesNumParams field for this reason + /- Go inside the lambdas - note that we have to be careful: some of the + binders might come from the match, and some of the binders might come + from the fact that the expression in the match is a lambda expression: + we use the branchesNumParams field for this reason. -/ let numParams := me.branchesNumParams.get! idx lambdaTelescopeN br numParams fun xs br => do -- Prove that the branch expression is valid @@ -603,13 +957,14 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- Reconstruct the lambda expression mkLambdaFVars xs brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" - -- Compute the motive, which has the following shape: - -- ``` - -- λ scrut => is_valid_p k (λ k => match scrut with ...) - -- ^^^^^^^^^^^^^^^^^^^^ - -- this is the original match expression, with the - -- the difference that the scrutinee(s) is a variable - -- ``` + /- Compute the motive, which has the following shape: + ``` + λ scrut => is_valid_p k (λ k => match scrut with ...) + ^^^^^^^^^^^^^^^^^^^^ + this is the original match expression, with the + the difference that the scrutinee(s) is a variable + ``` + -/ let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees @@ -628,7 +983,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp let matchE ← mkAppOptM me.matcherName args -- Wrap in the `is_valid_p` predicate let matchE ← mkLambdaFVars #[kk_var] matchE - let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] + let validMotive ← mkAppM ``FixII.is_valid_p #[k_var, matchE] -- Abstract away the scrutinee variables mkLambdaFVars scrutVars validMotive trace[Diverge.def.valid] "valid motive: {validMotive}" @@ -646,10 +1001,10 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp end --- Prove that a single body (in the mutually recursive group) is valid. --- --- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], --- we prove that `is_even.body` and `is_odd.body` are valid. +/- Prove that a single body (in the mutually recursive group) is valid. + + For instance, if we define the mutually recursive group [`is_even`, `is_odd`], + we prove that `is_even.body` and `is_odd.body` are valid. -/ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : MetaM Expr := do @@ -661,24 +1016,29 @@ partial def proveSingleBodyIsValid let body := (env.constants.find! name).value! trace[Diverge.def.valid] "body: {body}" lambdaTelescope body fun xs body => do - assert! xs.size = 2 + trace[Diverge.def.valid] "xs: {xs}" + if xs.size ≠ 3 then throwError "Invalid number of lambdas: {xs} (expected 3)" let kk_var := xs.get! 0 - let x_var := xs.get! 1 + let t_var := xs.get! 1 + let x_var := xs.get! 2 -- State the type of the theorem to prove - let thmTy ← mkAppM ``FixI.is_valid_p - #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])] + trace[Diverge.def.valid] "bodyConst: {bodyConst} : {← inferType bodyConst}" + let bodyApp ← mkAppOptM' bodyConst #[.some kk_var, .some t_var, .some x_var] + trace[Diverge.def.valid] "bodyApp: {bodyApp}" + let bodyApp ← mkLambdaFVars #[kk_var] bodyApp + trace[Diverge.def.valid] "bodyApp: {bodyApp}" + let thmTy ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp] trace[Diverge.def.valid] "thmTy: {thmTy}" -- Prove that the body is valid + trace[Diverge.def.valid] "body: {body}" let proof ← proveExprIsValid k_var kk_var body - let proof ← mkLambdaFVars #[k_var, x_var] proof + let proof ← mkLambdaFVars #[k_var, t_var, x_var] proof trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" -- The target type (we don't have to do this: this is simply a sanity check, -- and this allows a nicer debugging output) let thmTy ← do - let body ← mkAppM' bodyConst #[kk_var, x_var] - let body ← mkLambdaFVars #[kk_var] body - let ty ← mkAppM ``FixI.is_valid_p #[k_var, body] - mkForallFVars #[k_var, x_var] ty + let ty ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp] + mkForallFVars #[k_var, t_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem let name := preDef.declName ++ "body_is_valid" @@ -694,18 +1054,18 @@ partial def proveSingleBodyIsValid -- Return the theorem pure (Expr.const name (preDef.levelParams.map .param)) --- Prove that the list of bodies are valid. --- --- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], --- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is --- valid. -partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) +/- Prove that the list of bodies are valid. + + For instance, if we define the mutually recursive group [`is_even`, `is_odd`], + we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is + valid. -/ +partial def proveFunsBodyIsValid (paramInOutTys: Expr) (bodyFuns : Expr) (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do -- Create the big "and" expression, which groups the validity proof of the individual bodies let rec mkValidConj (i : Nat) : MetaM Expr := do if i = bodiesValid.size then -- We reached the end - mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var] + mkAppM ``FixII.Funs.is_valid_p_Nil #[k_var] else do -- We haven't reached the end: introduce a conjunction let valid := bodiesValid.get! i @@ -713,20 +1073,20 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)] let andExpr ← mkValidConj 0 -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation - let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] + let isValid ← mkAppM ``FixII.Funs.is_valid_p_is_valid_p #[paramInOutTys, k_var, bodyFuns, andExpr] mkLambdaFVars #[k_var] isValid --- Prove that the mut rec body (i.e., the unary body which groups the bodies --- of all the functions in the mutually recursive group and on which we will --- apply the fixed-point operator) is valid. --- --- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", --- which we return. --- --- TODO: maybe this function should introduce k_var itself +/- Prove that the mut rec body (i.e., the unary body which groups the bodies + of all the functions in the mutually recursive group and on which we will + apply the fixed-point operator) is valid. + + We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", + which we return. + + TODO: maybe this function should introduce k_var itself -/ def proveMutRecIsValid (grName : Name) (grLvlParams : List Name) - (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) + (paramInOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) (k_var : Expr) (preDefs : Array PreDefinition) (bodies : Array Expr) : MetaM Expr := do -- First prove that the individual bodies are valid @@ -737,9 +1097,10 @@ def proveMutRecIsValid proveSingleBodyIsValid k_var preDef body -- Then prove that the mut rec body is valid trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" - let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid + let isValid ← proveFunsBodyIsValid paramInOutTys bodyFuns k_var bodiesValid + trace[Diverge.def.valid] "Generated the term: {isValid}" -- Save the theorem - let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] + let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst] let name := grName ++ "mut_rec_body_is_valid" let decl := Declaration.thmDecl { name @@ -753,26 +1114,29 @@ def proveMutRecIsValid -- Return the theorem pure (Expr.const name (grLvlParams.map .param)) --- Generate the final definions by using the mutual body and the fixed point operator. --- --- For instance: --- ``` --- def is_even (i : Int) : Result Bool := mut_rec_body 0 i --- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i --- ``` -def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : +/- Generate the final definions by using the mutual body and the fixed point operator. + + For instance: + ``` + def is_even (i : Int) : Result Bool := mut_rec_body 0 i + def is_odd (i : Int) : Result Bool := mut_rec_body 1 i + ``` + -/ +def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do - -- Retrieve the input type - let in_ty := (inOutTys.get! idx.val).fst + -- Retrieve the parameters info + let type_info := paramInOutTys.get! idx.val -- Create the index let idx ← mkFinVal grSize idx.val - -- Group the inputs into a dependent tuple - let input ← mkSigmasVal in_ty xs.toList + -- Group the inputs into two tuples + let (params_args, input_args) := xs.toList.splitAt type_info.num_params + let params ← mkSigmasVal type_info.params_ty params_args + let input ← mkProdsVal input_args -- Apply the fixed point - let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] + let fixedBody ← mkAppM ``FixII.fix #[mutRecBody, idx, params, input] let fixedBody ← mkLambdaFVars xs fixedBody -- Create the declaration let name := preDef.declName @@ -790,7 +1154,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preD pure defs -- Prove the equations that we will use as unfolding theorems -partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr)) +partial def proveUnfoldingThms (isValidThm : Expr) + (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do @@ -810,14 +1175,18 @@ partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Ex trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}" -- The proof -- Use the fixed-point equation - let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm] + let proof ← mkAppM ``FixII.is_valid_fix_fixed_eq #[isValidThm] -- Add the index let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] - -- Add the input argument - let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList - let proof ← mkAppM ``congr_fun #[proof, arg] - -- Abstract the arguments away + -- Add the input arguments + let type_info := paramInOutTys.get! i + let (params, args) := xs.toList.splitAt type_info.num_params + let params ← mkSigmasVal type_info.params_ty params + let args ← mkProdsVal args + let proof ← mkAppM ``congr_fun #[proof, params] + let proof ← mkAppM ``congr_fun #[proof, args] + -- Abstract all the arguments away let proof ← mkLambdaFVars xs proof trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" -- Declare the theorem @@ -845,7 +1214,9 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) - -- TODO: what is this? + -- Apply all the "attribute" functions (for instance, the function which + -- registers the theorem in the simp database if there is the `simp` attribute, + -- etc.) for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation @@ -859,40 +1230,53 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grLvlParams := def0.levelParams trace[Diverge.def] "def0 universe levels: {def0.levelParams}" - -- We first compute the list of pairs: (input type × output type) - let inOutTys : Array (Expr × Expr) ← - preDefs.mapM (fun preDef => do - withRef preDef.ref do -- is the withRef useful? - -- Check the universe parameters - TODO: I'm not sure what the best thing - -- to do is. In practice, all the type parameters should be in Type 0, so - -- we shouldn't have universe issues. - if preDef.levelParams ≠ grLvlParams then - throwError "Non-uniform polymorphism in the universes" - forallTelescope preDef.type (fun in_tys out_ty => do - let in_ty ← liftM (mkSigmasType in_tys.toList) - -- Retrieve the type in the "Result" - let out_ty ← getResultTy out_ty - let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) - pure (in_ty, out_ty) - ) - ) - trace[Diverge.def] "inOutTys: {inOutTys}" - -- Turn the list of input/output type pairs into an expresion - let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) - let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList - - -- From the list of pairs of input/output types, actually compute the - -- type of the continuation `k`. - -- We first introduce the index `i : Fin n` where `n` is the number of - -- functions in the group. + /- We first compute the tuples: (type parameters × input type × output type) + - type parameters: this is a sigma type + - input type: λ params_type => product type + - output type: λ params_type => output type + For instance, on the function: + `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: + we generate: + `(Type, λ α => List α × i, λ α => Result α)` + -/ + let paramInOutTys : Array TypeInfo ← + preDefs.mapM (fun preDef => do + -- Check the universe parameters - TODO: I'm not sure what the best thing + -- to do is. In practice, all the type parameters should be in Type 0, so + -- we shouldn't have universe issues. + if preDef.levelParams ≠ grLvlParams then + throwError "Non-uniform polymorphism in the universes" + forallTelescope preDef.type (fun in_tys out_ty => do + let total_num_args := in_tys.size + let (params, in_tys) ← splitInputArgs in_tys out_ty + trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}" + let num_params := params.size + let params_ty ← mkSigmasType params.data + let in_ty ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data) + -- Retrieve the type in the "Result" + let out_ty ← getResultTy out_ty + let out_ty ← mkSigmasMatchOrUnit params.data out_ty + trace[Diverge.def] "inOutTy: {preDef.declName}: {params_ty}, {in_tys}, {out_ty}" + pure ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩)) + trace[Diverge.def] "paramInOutTys: {paramInOutTys}" + -- Turn the list of input types/input args/output type tuples into expressions + let paramInOutTysExpr ← liftM (paramInOutTys.mapM mkInOutTyFromTypeInfo) + let paramInOutTysExpr ← mkListLit (← inferType (paramInOutTysExpr.get! 0)) paramInOutTysExpr.toList + trace[Diverge.def] "paramInOutTys: {paramInOutTys}" + + /- From the list of pairs of input/output types, actually compute the + type of the continuation `k`. + We first introduce the index `i : Fin n` where `n` is the number of + functions in the group. + -/ let i_var_ty := mkFin preDefs.size withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do - let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] - trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" - -- Add an auxiliary definition for `in_out_ty` - let in_out_ty ← do - let value ← mkLambdaFVars #[i_var] in_out_ty - let name := grName.append "in_out_ty" + let param_in_out_ty ← mkAppM ``List.get #[paramInOutTysExpr, i_var] + trace[Diverge.def] "param_in_out_ty := {param_in_out_ty} : {← inferType param_in_out_ty}" + -- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term) + let param_in_out_ty ← do + let value ← mkLambdaFVars #[i_var] param_in_out_ty + let name := grName.append "param_in_out_ty" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -905,19 +1289,28 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do } addDecl decl -- Return the constant - let in_out_ty := Lean.mkConst name (levelParams.map .param) - mkAppM' in_out_ty #[i_var] - trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" - let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] + let param_in_out_ty := Lean.mkConst name (levelParams.map .param) + mkAppM' param_in_out_ty #[i_var] + trace[Diverge.def] "param_in_out_ty (after decl) := {param_in_out_ty} : {← inferType param_in_out_ty}" + -- Decompose between: param_ty, in_ty, out_ty + let param_ty ← mkAppM ``Sigma.fst #[param_in_out_ty] + let in_out_ty ← mkAppM ``Sigma.snd #[param_in_out_ty] + let in_ty ← mkAppM ``Prod.fst #[in_out_ty] + let out_ty ← mkAppM ``Prod.snd #[in_out_ty] + trace[Diverge.def] "param_ty: {param_ty}" + trace[Diverge.def] "in_ty: {in_ty}" + trace[Diverge.def] "out_ty: {out_ty}" + withLocalDeclD (mkAnonymous "t" 1) param_ty fun param => do + let in_ty ← mkAppM' in_ty #[param] + let out_ty ← mkAppM' out_ty #[param] trace[Diverge.def] "in_ty: {in_ty}" - withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do - let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] trace[Diverge.def] "out_ty: {out_ty}" -- Introduce the continuation `k` - let in_ty ← mkLambdaFVars #[i_var] in_ty - let out_ty ← mkLambdaFVars #[i_var, input] out_ty - let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] + let param_ty ← mkLambdaFVars #[i_var] param_ty + let in_ty ← mkLambdaFVars #[i_var, param] in_ty + let out_ty ← mkLambdaFVars #[i_var, param] out_ty + let kk_var_ty ← mkAppM ``FixII.kk_ty #[i_var_ty, param_ty, in_ty, out_ty] trace[Diverge.def] "kk_var_ty: {kk_var_ty}" withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do trace[Diverge.def] "kk_var: {kk_var}" @@ -925,29 +1318,30 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations trace[Diverge.def] "# Generating the unary bodies" - let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var paramInOutTys preDefs trace[Diverge.def] "Unary bodies (after decl): {bodies}" + -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" - let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point - let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] + let k_var_ty ← mkAppM ``FixII.k_ty #[i_var_ty, param_ty, in_ty, out_ty] withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do trace[Diverge.def] "# Proving that the mut rec body is valid" - let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies + let isValidThm ← proveMutRecIsValid grName grLvlParams paramInOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions trace[Diverge.def] "# Generating the final definitions" - let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs + let decls ← mkDeclareFixDefs mutRecBody paramInOutTys preDefs -- Prove the unfolding theorems trace[Diverge.def] "# Proving the unfolding theorems" - proveUnfoldingThms isValidThm inOutTys preDefs decls + proveUnfoldingThms isValidThm paramInOutTys preDefs decls - -- Generating code -- TODO + -- Generating code addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -1068,15 +1462,23 @@ elab_rules : command Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) namespace Tests + /- Some examples of partial functions -/ - divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + --set_option trace.Diverge.def true + --set_option trace.Diverge.def.genBody true + --set_option trace.Diverge.def.valid true + --set_option trace.Diverge.def.genBody.visit true + + divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with | [] => .fail .panic | x :: ls => if i = 0 then return x else return (← list_nth ls (i - 1)) + --set_option trace.Diverge false + #check list_nth.unfold example {a: Type} (ls : List a) : @@ -1087,17 +1489,33 @@ namespace Tests . intro i hpos h; simp at h; linarith . rename_i hd tl ih intro i hpos h - -- We can directly use `rw [list_nth]`! + -- We can directly use `rw [list_nth]` rw [list_nth]; simp split <;> try simp [*] . tauto - . -- TODO: we shouldn't have to do that + . -- We don't have to do this if we use scalar_tac have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all simp at h have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) simp [ih] tauto + -- Return a continuation + divergent def list_nth_with_back {a: Type} (ls : List a) (i : Int) : + Result (a × (a → Result (List a))) := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return (x, (λ ret => return (ret :: ls))) + else do + let (x, back) ← list_nth_with_back ls (i - 1) + return (x, + (λ ret => do + let ls ← back ret + return (x :: ls))) + + #check list_nth_with_back.unfold + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) @@ -1121,7 +1539,6 @@ namespace Tests #check bar.unfold -- Testing dependent branching and let-bindings - -- TODO: why the linter warning? divergent def isNonZero (i : Int) : Result Bool := if _h:i = 0 then return false else @@ -1157,6 +1574,82 @@ namespace Tests #check test1.unfold + /- Tests with higher-order functions -/ + section HigherOrder + open Ex8 + + inductive Tree (a : Type u) := + | leaf (x : a) + | node (tl : List (Tree a)) + + divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl) + + #check id.unfold + + divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => id1 x) tl + .ret (.node tl) + + #check id1.unfold + + divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => do let _ ← id2 x; id2 x) tl + .ret (.node tl) + + #check id2.unfold + + divergent def incr (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let tl ← map incr tl + .ret (.node tl) + + -- We handle this by inlining the let-binding + divergent def id3 (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let f := id3 + let tl ← map f tl + .ret (.node tl) + + #check id3.unfold + + /- + -- This is not handled yet: we can only do it if we introduce "general" + -- relations for the input types and output types (result_rel should + -- be parameterized by something). + divergent def id4 (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let f ← .ret id4 + let tl ← map f tl + .ret (.node tl) + + #check id4.unfold + -/ + + end HigherOrder + end Tests end Diverge diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index fedb1c74..0d33e9d2 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -1,15 +1,84 @@ import Lean +import Base.Utils +import Base.Primitives.Base +import Base.Extensions namespace Diverge open Lean Elab Term Meta +open Utils Extensions -- We can't define and use trace classes in the same file +initialize registerTraceClass `Diverge initialize registerTraceClass `Diverge.elab initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas +initialize registerTraceClass `Diverge.def.prods initialize registerTraceClass `Diverge.def.genBody +initialize registerTraceClass `Diverge.def.genBody.visit initialize registerTraceClass `Diverge.def.valid initialize registerTraceClass `Diverge.def.unfold +-- For the attribute (for higher-order functions) +initialize registerTraceClass `Diverge.attr + +-- Attribute + +-- divspec attribute +structure DivSpecAttr where + attr : AttributeImpl + ext : DiscrTreeExtension Name true + deriving Inhabited + +/- The persistent map from expressions to divspec theorems. -/ +initialize divspecAttr : DivSpecAttr ← do + let ext ← mkDiscrTreeExtention `divspecMap true + let attrImpl : AttributeImpl := { + name := `divspec + descr := "Marks theorems to use with the `divergent` encoding" + add := fun thName stx attrKind => do + Attribute.Builtin.ensureNoArgs stx + -- TODO: use the attribute kind + unless attrKind == AttributeKind.global do + throwError "invalid attribute divspec, must be global" + -- Lookup the theorem + let env ← getEnv + let thDecl := env.constants.find! thName + let fKey : Array (DiscrTree.Key true) ← MetaM.run' (do + /- The theorem should have the shape: + `∀ ..., is_valid_p k (λ k => ...)` + + Dive into the ∀: + -/ + let (_, _, fExpr) ← forallMetaTelescope thDecl.type.consumeMData + /- Dive into the argument of `is_valid_p`: -/ + fExpr.consumeMData.withApp fun _ args => do + if args.size ≠ 7 then throwError "Invalid number of arguments to is_valid_p" + let fExpr := args.get! 6 + /- Dive into the lambda: -/ + let (_, _, fExpr) ← lambdaMetaTelescope fExpr.consumeMData + trace[Diverge] "Registering divspec theorem for {fExpr}" + -- Convert the function expression to a discrimination tree key + DiscrTree.mkPath fExpr) + let env := ext.addEntry env (fKey, thName) + setEnv env + trace[Diverge] "Saved the environment" + pure () + } + registerBuiltinAttribute attrImpl + pure { attr := attrImpl, ext := ext } + +def DivSpecAttr.find? (s : DivSpecAttr) (e : Expr) : MetaM (Array Name) := do + (s.ext.getState (← getEnv)).getMatch e + +def DivSpecAttr.getState (s : DivSpecAttr) : MetaM (DiscrTree Name true) := do + pure (s.ext.getState (← getEnv)) + +def showStoredDivSpec : MetaM Unit := do + let st ← divspecAttr.getState + -- TODO: how can we iterate over (at least) the values stored in the tree? + --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" + let s := f!"{st}" + IO.println s + end Diverge diff --git a/backends/lean/Base/Extensions.lean b/backends/lean/Base/Extensions.lean new file mode 100644 index 00000000..b34f41dc --- /dev/null +++ b/backends/lean/Base/Extensions.lean @@ -0,0 +1,47 @@ +import Lean +import Std.Lean.HashSet +import Base.Utils +import Base.Primitives.Base + +import Lean.Meta.DiscrTree +import Lean.Meta.Tactic.Simp + +/-! Various state extensions used in the library -/ +namespace Extensions + +open Lean Elab Term Meta +open Utils + +-- This is not used anymore but we keep it here. +-- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR? +def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : + IO (MapDeclarationExtension α) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty, + addEntryFn := fun s n => s.insert n.1 n.2 , + toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) + } + +/- Discrimination trees map expressions to values. When storing an expression + in a discrimination tree, the expression is first converted to an array + of `DiscrTree.Key`, which are the keys actually used by the discrimination + trees. The conversion operation is monadic, however, and extensions require + all the operations to be pure. For this reason, in the state extension, we + store the keys from *after* the transformation (i.e., the `DiscrTreeKey` + below). The transformation itself can be done elsewhere. + -/ +abbrev DiscrTreeKey (simpleReduce : Bool) := Array (DiscrTree.Key simpleReduce) + +abbrev DiscrTreeExtension (α : Type) (simpleReduce : Bool) := + SimplePersistentEnvExtension (DiscrTreeKey simpleReduce × α) (DiscrTree α simpleReduce) + +def mkDiscrTreeExtention [Inhabited α] [BEq α] (name : Name := by exact decl_name%) (simpleReduce : Bool) : + IO (DiscrTreeExtension α simpleReduce) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insertCore k v) s) DiscrTree.empty, + addEntryFn := fun s n => s.insertCore n.1 n.2 , + } + +end Extensions diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index f74fecd4..db522df2 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -528,7 +528,7 @@ instance {ty} : HAnd (Scalar ty) (Scalar ty) (Scalar ty) where hAnd x y := Scalar.and x y -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.add_spec {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val + y.val) (hmax : x.val + y.val ≤ Scalar.max ty) : @@ -550,62 +550,62 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} apply add_spec <;> assumption /- Fine-grained theorems -/ -@[cepspec] theorem Usize.add_spec {x y : Usize} (hmax : x.val + y.val ≤ Usize.max) : +@[pspec] theorem Usize.add_spec {x y : Usize} (hmax : x.val + y.val ≤ Usize.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U8.add_spec {x y : U8} (hmax : x.val + y.val ≤ U8.max) : +@[pspec] theorem U8.add_spec {x y : U8} (hmax : x.val + y.val ≤ U8.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U16.add_spec {x y : U16} (hmax : x.val + y.val ≤ U16.max) : +@[pspec] theorem U16.add_spec {x y : U16} (hmax : x.val + y.val ≤ U16.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U32.add_spec {x y : U32} (hmax : x.val + y.val ≤ U32.max) : +@[pspec] theorem U32.add_spec {x y : U32} (hmax : x.val + y.val ≤ U32.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U64.add_spec {x y : U64} (hmax : x.val + y.val ≤ U64.max) : +@[pspec] theorem U64.add_spec {x y : U64} (hmax : x.val + y.val ≤ U64.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U128.add_spec {x y : U128} (hmax : x.val + y.val ≤ U128.max) : +@[pspec] theorem U128.add_spec {x y : U128} (hmax : x.val + y.val ≤ U128.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem Isize.add_spec {x y : Isize} +@[pspec] theorem Isize.add_spec {x y : Isize} (hmin : Isize.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ Isize.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I8.add_spec {x y : I8} +@[pspec] theorem I8.add_spec {x y : I8} (hmin : I8.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I8.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I16.add_spec {x y : I16} +@[pspec] theorem I16.add_spec {x y : I16} (hmin : I16.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I16.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I32.add_spec {x y : I32} +@[pspec] theorem I32.add_spec {x y : I32} (hmin : I32.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I32.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I64.add_spec {x y : I64} +@[pspec] theorem I64.add_spec {x y : I64} (hmin : I64.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I64.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I128.add_spec {x y : I128} +@[pspec] theorem I128.add_spec {x y : I128} (hmin : I128.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I128.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.sub_spec {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val - y.val) (hmax : x.val - y.val ≤ Scalar.max ty) : @@ -629,56 +629,56 @@ theorem Scalar.sub_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} apply sub_spec <;> assumption /- Fine-grained theorems -/ -@[cepspec] theorem Usize.sub_spec {x y : Usize} (hmin : Usize.min ≤ x.val - y.val) : +@[pspec] theorem Usize.sub_spec {x y : Usize} (hmin : Usize.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U8.sub_spec {x y : U8} (hmin : U8.min ≤ x.val - y.val) : +@[pspec] theorem U8.sub_spec {x y : U8} (hmin : U8.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U16.sub_spec {x y : U16} (hmin : U16.min ≤ x.val - y.val) : +@[pspec] theorem U16.sub_spec {x y : U16} (hmin : U16.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U32.sub_spec {x y : U32} (hmin : U32.min ≤ x.val - y.val) : +@[pspec] theorem U32.sub_spec {x y : U32} (hmin : U32.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U64.sub_spec {x y : U64} (hmin : U64.min ≤ x.val - y.val) : +@[pspec] theorem U64.sub_spec {x y : U64} (hmin : U64.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U128.sub_spec {x y : U128} (hmin : U128.min ≤ x.val - y.val) : +@[pspec] theorem U128.sub_spec {x y : U128} (hmin : U128.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem Isize.sub_spec {x y : Isize} (hmin : Isize.min ≤ x.val - y.val) +@[pspec] theorem Isize.sub_spec {x y : Isize} (hmin : Isize.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ Isize.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I8.sub_spec {x y : I8} (hmin : I8.min ≤ x.val - y.val) +@[pspec] theorem I8.sub_spec {x y : I8} (hmin : I8.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I8.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I16.sub_spec {x y : I16} (hmin : I16.min ≤ x.val - y.val) +@[pspec] theorem I16.sub_spec {x y : I16} (hmin : I16.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I16.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I32.sub_spec {x y : I32} (hmin : I32.min ≤ x.val - y.val) +@[pspec] theorem I32.sub_spec {x y : I32} (hmin : I32.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I32.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I64.sub_spec {x y : I64} (hmin : I64.min ≤ x.val - y.val) +@[pspec] theorem I64.sub_spec {x y : I64} (hmin : I64.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I64.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I128.sub_spec {x y : I128} (hmin : I128.min ≤ x.val - y.val) +@[pspec] theorem I128.sub_spec {x y : I128} (hmin : I128.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I128.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax @@ -705,62 +705,62 @@ theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} apply mul_spec <;> assumption /- Fine-grained theorems -/ -@[cepspec] theorem Usize.mul_spec {x y : Usize} (hmax : x.val * y.val ≤ Usize.max) : +@[pspec] theorem Usize.mul_spec {x y : Usize} (hmax : x.val * y.val ≤ Usize.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U8.mul_spec {x y : U8} (hmax : x.val * y.val ≤ U8.max) : +@[pspec] theorem U8.mul_spec {x y : U8} (hmax : x.val * y.val ≤ U8.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U16.mul_spec {x y : U16} (hmax : x.val * y.val ≤ U16.max) : +@[pspec] theorem U16.mul_spec {x y : U16} (hmax : x.val * y.val ≤ U16.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U32.mul_spec {x y : U32} (hmax : x.val * y.val ≤ U32.max) : +@[pspec] theorem U32.mul_spec {x y : U32} (hmax : x.val * y.val ≤ U32.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U64.mul_spec {x y : U64} (hmax : x.val * y.val ≤ U64.max) : +@[pspec] theorem U64.mul_spec {x y : U64} (hmax : x.val * y.val ≤ U64.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U128.mul_spec {x y : U128} (hmax : x.val * y.val ≤ U128.max) : +@[pspec] theorem U128.mul_spec {x y : U128} (hmax : x.val * y.val ≤ U128.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem Isize.mul_spec {x y : Isize} (hmin : Isize.min ≤ x.val * y.val) +@[pspec] theorem Isize.mul_spec {x y : Isize} (hmin : Isize.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ Isize.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I8.mul_spec {x y : I8} (hmin : I8.min ≤ x.val * y.val) +@[pspec] theorem I8.mul_spec {x y : I8} (hmin : I8.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I8.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I16.mul_spec {x y : I16} (hmin : I16.min ≤ x.val * y.val) +@[pspec] theorem I16.mul_spec {x y : I16} (hmin : I16.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I16.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I32.mul_spec {x y : I32} (hmin : I32.min ≤ x.val * y.val) +@[pspec] theorem I32.mul_spec {x y : I32} (hmin : I32.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I32.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I64.mul_spec {x y : I64} (hmin : I64.min ≤ x.val * y.val) +@[pspec] theorem I64.mul_spec {x y : I64} (hmin : I64.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I64.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I128.mul_spec {x y : I128} (hmin : I128.min ≤ x.val * y.val) +@[pspec] theorem I128.mul_spec {x y : I128} (hmin : I128.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I128.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.div_spec {ty} {x y : Scalar ty} (hnz : y.val ≠ 0) (hmin : Scalar.min ty ≤ scalar_div x.val y.val) @@ -788,66 +788,66 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S apply hs /- Fine-grained theorems -/ -@[cepspec] theorem Usize.div_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : +@[pspec] theorem Usize.div_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [*] -@[cepspec] theorem U8.div_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : +@[pspec] theorem U8.div_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U16.div_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : +@[pspec] theorem U16.div_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U32.div_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : +@[pspec] theorem U32.div_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U64.div_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : +@[pspec] theorem U64.div_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U128.div_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : +@[pspec] theorem U128.div_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem Isize.div_spec (x : Isize) {y : Isize} +@[pspec] theorem Isize.div_spec (x : Isize) {y : Isize} (hnz : y.val ≠ 0) (hmin : Isize.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ Isize.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I8.div_spec (x : I8) {y : I8} +@[pspec] theorem I8.div_spec (x : I8) {y : I8} (hnz : y.val ≠ 0) (hmin : I8.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I8.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I16.div_spec (x : I16) {y : I16} +@[pspec] theorem I16.div_spec (x : I16) {y : I16} (hnz : y.val ≠ 0) (hmin : I16.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I16.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I32.div_spec (x : I32) {y : I32} +@[pspec] theorem I32.div_spec (x : I32) {y : I32} (hnz : y.val ≠ 0) (hmin : I32.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I32.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I64.div_spec (x : I64) {y : I64} +@[pspec] theorem I64.div_spec (x : I64) {y : I64} (hnz : y.val ≠ 0) (hmin : I64.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I64.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I128.div_spec (x : I128) {y : I128} +@[pspec] theorem I128.div_spec (x : I128) {y : I128} (hnz : y.val ≠ 0) (hmin : I128.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I128.max): @@ -855,7 +855,7 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S Scalar.div_spec hnz hmin hmax -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.rem_spec {ty} {x y : Scalar ty} (hnz : y.val ≠ 0) (hmin : Scalar.min ty ≤ scalar_rem x.val y.val) @@ -883,59 +883,59 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S simp [*] at hs simp [*] -@[cepspec] theorem Usize.rem_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : +@[pspec] theorem Usize.rem_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [*] -@[cepspec] theorem U8.rem_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : +@[pspec] theorem U8.rem_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U16.rem_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : +@[pspec] theorem U16.rem_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U32.rem_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : +@[pspec] theorem U32.rem_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U64.rem_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : +@[pspec] theorem U64.rem_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U128.rem_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : +@[pspec] theorem U128.rem_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem I8.rem_spec (x : I8) {y : I8} +@[pspec] theorem I8.rem_spec (x : I8) {y : I8} (hnz : y.val ≠ 0) (hmin : I8.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I8.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I16.rem_spec (x : I16) {y : I16} +@[pspec] theorem I16.rem_spec (x : I16) {y : I16} (hnz : y.val ≠ 0) (hmin : I16.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I16.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I32.rem_spec (x : I32) {y : I32} +@[pspec] theorem I32.rem_spec (x : I32) {y : I32} (hnz : y.val ≠ 0) (hmin : I32.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I32.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I64.rem_spec (x : I64) {y : I64} +@[pspec] theorem I64.rem_spec (x : I64) {y : I64} (hnz : y.val ≠ 0) (hmin : I64.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I64.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I128.rem_spec (x : I128) {y : I128} +@[pspec] theorem I128.rem_spec (x : I128) {y : I128} (hnz : y.val ≠ 0) (hmin : I128.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I128.max): diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 76a92795..0ad16ab6 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -2,11 +2,12 @@ import Lean import Std.Lean.HashSet import Base.Utils import Base.Primitives.Base +import Base.Extensions namespace Progress open Lean Elab Term Meta -open Utils +open Utils Extensions -- We can't define and use trace classes in the same file initialize registerTraceClass `Progress @@ -15,17 +16,17 @@ initialize registerTraceClass `Progress structure PSpecDesc where -- The universally quantified variables + -- Can be fvars or mvars fvars : Array Expr -- The existentially quantified variables evars : Array Expr + -- The function applied to its arguments + fArgsExpr : Expr -- The function - fExpr : Expr fName : Name -- The function arguments fLevels : List Level args : Array Expr - -- The universally quantified variables which appear in the function arguments - argsFVars : Array FVarId -- The returned value ret : Expr -- The postcondition (if there is) @@ -37,7 +38,7 @@ section Methods variable [MonadError m] variable {a : Type} - /- Analyze a pspec theorem to decompose its arguments. + /- Analyze a goal or a pspec theorem to decompose its arguments. PSpec theorems should be of the following shape: ``` @@ -56,12 +57,20 @@ section Methods TODO: generalize for when we do inductive proofs -/ partial - def withPSpec [Inhabited (m a)] [Nonempty (m a)] (th : Expr) (k : PSpecDesc → m a) - (sanityChecks : Bool := false) : + def withPSpec [Inhabited (m a)] [Nonempty (m a)] + (isGoal : Bool) (th : Expr) (k : PSpecDesc → m a) : m a := do trace[Progress] "Proposition: {th}" -- Dive into the quantified variables and the assumptions - forallTelescope th.consumeMData fun fvars th => do + -- Note that if we analyze a pspec theorem to register it in a database (i.e. + -- a discrimination tree), we need to introduce *meta-variables* for the + -- quantified variables. + let telescope (k : Array Expr → Expr → m a) : m a := + if isGoal then forallTelescope th.consumeMData (fun fvars th => k fvars th) + else do + let (fvars, _, th) ← forallMetaTelescope th.consumeMData + k fvars th + telescope fun fvars th => do trace[Progress] "Universally quantified arguments and assumptions: {fvars}" -- Dive into the existentials existsTelescope th.consumeMData fun evars th => do @@ -78,7 +87,7 @@ section Methods -- destruct the application to get the function name mExpr.consumeMData.withApp fun mf margs => do trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}" - let (fExpr, f, args) ← do + let (fArgsExpr, f, args) ← do if mf.isConst ∧ mf.constName = ``Bind.bind then do -- Dive into the bind let fExpr := (margs.get! 4).consumeMData @@ -86,29 +95,27 @@ section Methods else pure (mExpr, mf, margs) trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}" if ¬ f.isConst then throwError "Not a constant: {f}" - -- Compute the set of universally quantified variables which appear in the function arguments - let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty - -- Sanity check - if sanityChecks then - -- All the variables which appear in the inputs given to the function are - -- universally quantified (in particular, they are not *existentially* quantified) - let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!)) - let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar) - if ¬ filtArgsFVars.isEmpty then + -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB) + -- Check if some existentially quantified variables + let _ := do + -- Collect all the free variables in the arguments + let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + -- Check if they intersect the fvars we introduced for the existentially quantified variables + let evarsSet : HashSet FVarId := HashSet.ofArray (evars.map (fun (x : Expr) => x.fvarId!)) + let filtArgsFVars := allArgsFVars.toArray.filter (fun var => evarsSet.contains var) + if filtArgsFVars.isEmpty then pure () + else let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId) throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}" - let argsFVars := fvars.map (fun x => x.fvarId!) - let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar) -- Return - trace[Progress] "Function: {f.constName!}"; + trace[Progress] "Function with arguments: {fArgsExpr}"; let thDesc := { fvars := fvars evars := evars - fExpr + fArgsExpr fName := f.constName! fLevels := f.constLevels! args := args - argsFVars ret := ret post := post } @@ -116,117 +123,18 @@ section Methods end Methods -def getPSpecFunName (th : Expr) : MetaM Name := - withPSpec th (fun d => do pure d.fName) true +def getPSpecFunArgsExpr (isGoal : Bool) (th : Expr) : MetaM Expr := + withPSpec isGoal th (fun d => do pure d.fArgsExpr) -def getPSpecClassFunNames (th : Expr) : MetaM (Name × Name) := - withPSpec th (fun d => do - let arg0 := d.args.get! 0 - arg0.withApp fun f _ => do - if ¬ f.isConst then throwError "Not a constant: {f}" - pure (d.fName, f.constName) - ) true - -def getPSpecClassFunNameArg (th : Expr) : MetaM (Name × Expr) := - withPSpec th (fun d => do - let arg0 := d.args.get! 0 - pure (d.fName, arg0) - ) true - --- "Regular" pspec attribute +-- pspec attribute structure PSpecAttr where attr : AttributeImpl - ext : MapDeclarationExtension Name - deriving Inhabited - -/- pspec attribute for type classes: we use the name of the type class to - lookup another map. We use the *first* argument of the type class to lookup - into this second map. - - Example: - ======== - We use type classes for addition. For instance, the addition between two - U32 is written (without syntactic sugar) as `HAdd.add (Scalar ty) x y`. As a consequence, - we store the theorem through the bindings: HAdd.add → Scalar → ... - - SH: TODO: this (and `PSpecClassExprAttr`) is a bit ad-hoc. For now it works for the - specs of the scalar operations, which is what I really need, but I'm not sure it - applies well to other situations. A better way would probably to use type classes, but - I couldn't get them to work on those cases. It is worth retrying. --/ -structure PSpecClassAttr where - attr : AttributeImpl - ext : MapDeclarationExtension (NameMap Name) - deriving Inhabited - -/- Same as `PSpecClassAttr` but we use the full first argument (it works when it - is a constant). -/ -structure PSpecClassExprAttr where - attr : AttributeImpl - ext : MapDeclarationExtension (HashMap Expr Name) + ext : DiscrTreeExtension Name true deriving Inhabited --- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR? -def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension α) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - --- Declare an extension of maps of maps (using [RBMap]). --- The important point is that we need to merge the bound values (which are maps). -def mkMapMapDeclarationExtension [Inhabited β] (ord : α → α → Ordering) - (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension (RBMap α β ord)) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => - a.foldl (fun s a => a.foldl ( - -- We need to merge the maps - fun s (k0, k1_to_v) => - match s.find? k0 with - | none => - -- No binding: insert one - s.insert k0 k1_to_v - | some m => - -- There is already a binding: merge - let m := RBMap.fold (fun m k v => m.insert k v) m k1_to_v - s.insert k0 m) - s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - --- Declare an extension of maps of maps (using [HashMap]). --- The important point is that we need to merge the bound values (which are maps). -def mkMapHashMapDeclarationExtension [BEq α] [Hashable α] [Inhabited β] - (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension (HashMap α β)) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => - a.foldl (fun s a => a.foldl ( - -- We need to merge the maps - fun s (k0, k1_to_v) => - match s.find? k0 with - | none => - -- No binding: insert one - s.insert k0 k1_to_v - | some m => - -- There is already a binding: merge - let m := HashMap.fold (fun m k v => m.insert k v) m k1_to_v - s.insert k0 m) - s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - -/- The persistent map from function to pspec theorems. -/ +/- The persistent map from expressions to pspec theorems. -/ initialize pspecAttr : PSpecAttr ← do - let ext ← mkMapDeclarationExtension `pspecMap + let ext ← mkDiscrTreeExtention `pspecMap true let attrImpl : AttributeImpl := { name := `pspec descr := "Marks theorems to use with the `progress` tactic" @@ -238,130 +146,30 @@ initialize pspecAttr : PSpecAttr ← do -- Lookup the theorem let env ← getEnv let thDecl := env.constants.find! thName - let fName ← MetaM.run' (getPSpecFunName thDecl.type) - trace[Progress] "Registering spec theorem for {fName}" - let env := ext.addEntry env (fName, thName) - setEnv env - pure () - } - registerBuiltinAttribute attrImpl - pure { attr := attrImpl, ext := ext } - -/- The persistent map from type classes to pspec theorems -/ -initialize pspecClassAttr : PSpecClassAttr ← do - let ext : MapDeclarationExtension (NameMap Name) ← - mkMapMapDeclarationExtension Name.quickCmp `pspecClassMap - let attrImpl : AttributeImpl := { - name := `cpspec - descr := "Marks theorems to use for type classes with the `progress` tactic" - add := fun thName stx attrKind => do - Attribute.Builtin.ensureNoArgs stx - -- TODO: use the attribute kind - unless attrKind == AttributeKind.global do - throwError "invalid attribute 'cpspec', must be global" - -- Lookup the theorem - let env ← getEnv - let thDecl := env.constants.find! thName - let (fName, argName) ← MetaM.run' (getPSpecClassFunNames thDecl.type) - trace[Progress] "Registering class spec theorem for ({fName}, {argName})" - -- Update the entry if there is one, add an entry if there is none - let env := - match (ext.getState (← getEnv)).find? fName with - | none => - let m := RBMap.ofList [(argName, thName)] - ext.addEntry env (fName, m) - | some m => - let m := m.insert argName thName - ext.addEntry env (fName, m) + let fKey ← MetaM.run' (do + let fExpr ← getPSpecFunArgsExpr false thDecl.type + trace[Progress] "Registering spec theorem for {fExpr}" + -- Convert the function expression to a discrimination tree key + DiscrTree.mkPath fExpr) + let env := ext.addEntry env (fKey, thName) setEnv env + trace[Progress] "Saved the environment" pure () } registerBuiltinAttribute attrImpl pure { attr := attrImpl, ext := ext } -/- The 2nd persistent map from type classes to pspec theorems -/ -initialize pspecClassExprAttr : PSpecClassExprAttr ← do - let ext : MapDeclarationExtension (HashMap Expr Name) ← - mkMapHashMapDeclarationExtension `pspecClassExprMap - let attrImpl : AttributeImpl := { - name := `cepspec - descr := "Marks theorems to use for type classes with the `progress` tactic" - add := fun thName stx attrKind => do - Attribute.Builtin.ensureNoArgs stx - -- TODO: use the attribute kind - unless attrKind == AttributeKind.global do - throwError "invalid attribute 'cpspec', must be global" - -- Lookup the theorem - let env ← getEnv - let thDecl := env.constants.find! thName - let (fName, arg) ← MetaM.run' (getPSpecClassFunNameArg thDecl.type) - -- Sanity check: no variables appear in the argument - MetaM.run' do - let fvars ← getFVarIds arg - if ¬ fvars.isEmpty then throwError "The first argument ({arg}) contains variables" - -- We store two bindings: - -- - arg to theorem name - -- - reduced arg to theorem name - let rarg ← MetaM.run' (reduceAll arg) - trace[Progress] "Registering class spec theorem for ({fName}, {arg}) and ({fName}, {rarg})" - -- Update the entry if there is one, add an entry if there is none - let env := - match (ext.getState (← getEnv)).find? fName with - | none => - let m := HashMap.ofList [(arg, thName), (rarg, thName)] - ext.addEntry env (fName, m) - | some m => - let m := m.insert arg thName - let m := m.insert rarg thName - ext.addEntry env (fName, m) - setEnv env - pure () - } - registerBuiltinAttribute attrImpl - pure { attr := attrImpl, ext := ext } - - -def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do - return (s.ext.getState (← getEnv)).find? name - -def PSpecClassAttr.find? (s : PSpecClassAttr) (className argName : Name) : MetaM (Option Name) := do - match (s.ext.getState (← getEnv)).find? className with - | none => return none - | some map => return map.find? argName - -def PSpecClassExprAttr.find? (s : PSpecClassExprAttr) (className : Name) (arg : Expr) : MetaM (Option Name) := do - match (s.ext.getState (← getEnv)).find? className with - | none => return none - | some map => return map.find? arg - -def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do - pure (s.ext.getState (← getEnv)) +def PSpecAttr.find? (s : PSpecAttr) (e : Expr) : MetaM (Array Name) := do + (s.ext.getState (← getEnv)).getMatch e -def PSpecClassAttr.getState (s : PSpecClassAttr) : MetaM (NameMap (NameMap Name)) := do - pure (s.ext.getState (← getEnv)) - -def PSpecClassExprAttr.getState (s : PSpecClassExprAttr) : MetaM (NameMap (HashMap Expr Name)) := do +def PSpecAttr.getState (s : PSpecAttr) : MetaM (DiscrTree Name true) := do pure (s.ext.getState (← getEnv)) def showStoredPSpec : MetaM Unit := do let st ← pspecAttr.getState - let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" - IO.println s - -def showStoredPSpecClass : MetaM Unit := do - let st ← pspecClassAttr.getState - let s := st.toList.foldl (fun s (f, m) => - let ms := m.toList.foldl (fun s (f, th) => - f!"{s}\n {f} → {th}") f!"" - f!"{s}\n{f} → [{ms}]") f!"" - IO.println s - -def showStoredPSpecExprClass : MetaM Unit := do - let st ← pspecClassExprAttr.getState - let s := st.toList.foldl (fun s (f, m) => - let ms := m.toList.foldl (fun s (f, th) => - f!"{s}\n {f} → {th}") f!"" - f!"{s}\n{f} → [{ms}]") f!"" + -- TODO: how can we iterate over (at least) the values stored in the tree? + --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" + let s := f!"{st}" IO.println s end Progress diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index ba63f09d..a6a4e82a 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -204,11 +204,11 @@ def getFirstArg (args : Array Expr) : Option Expr := do if args.size = 0 then none else some (args.get! 0) -/- Helper: try to lookup a theorem and apply it, or continue with another tactic - if it fails -/ +/- Helper: try to lookup a theorem and apply it. + Return true if it succeeded. -/ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) (fExpr : Expr) - (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do + (kind : String) (th : Option TheoremOrLocal) : TacticM Bool := do let res ← do match th with | none => @@ -223,9 +223,9 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : pure (some res) catch _ => none match res with - | some .Ok => return () + | some .Ok => return true | some (.Error msg) => throwError msg - | none => x + | none => return false -- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) @@ -236,11 +236,19 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL let goalTy ← mgoal.getType trace[Progress] "goal: {goalTy}" -- Dive into the goal to lookup the theorem - let (fExpr, fName, args) ← do - withPSpec goalTy fun desc => - -- TODO: check that no quantified variables in the arguments - pure (desc.fExpr, desc.fName, desc.args) - trace[Progress] "Function: {fName}" + -- Remark: if we don't isolate the call to `withPSpec` to immediately "close" + -- the terms immediately, we may end up with the error: + -- "(kernel) declaration has free variables" + -- I'm not sure I understand why. + -- TODO: we should also check that no quantified variable appears in fExpr. + -- If such variables appear, we should just fail because the goal doesn't + -- have the proper shape. + let fExpr ← do + let isGoal := true + withPSpec isGoal goalTy fun desc => do + let fExpr := desc.fArgsExpr + trace[Progress] "Expression to match: {fExpr}" + pure fExpr -- If the user provided a theorem/assumption: use it. -- Otherwise, lookup one. match withTh with @@ -258,36 +266,24 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL match res with | .Ok => return () | .Error msg => throwError msg - -- It failed: try to lookup a theorem - -- TODO: use a list of theorems, and try them one by one? - trace[Progress] "No assumption succeeded: trying to lookup a theorem" - let pspec ← do - let thName ← pspecAttr.find? fName - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec do - -- It failed: try to lookup a *class* expr spec theorem (those are more - -- specific than class spec theorems) - trace[Progress] "Failed using a pspec theorem: trying to lookup a pspec class expr theorem" - let pspecClassExpr ← do - match getFirstArg args with - | none => pure none - | some arg => do - trace[Progress] "Using: f:{fName}, arg: {arg}" - let thName ← pspecClassExprAttr.find? fName arg - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec class expr theorem" pspecClassExpr do - -- It failed: try to lookup a *class* spec theorem - trace[Progress] "Failed using a pspec class expr theorem: trying to lookup a pspec class theorem" - let pspecClass ← do - match ← getFirstArgAppName args with - | none => pure none - | some argName => do - trace[Progress] "Using: f: {fName}, arg: {argName}" - let thName ← pspecClassAttr.find? fName argName - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec class theorem" pspecClass do - trace[Progress] "Failed using a pspec class theorem: trying to use a recursive assumption" - -- Try a recursive call - we try the assumptions of kind "auxDecl" + -- It failed: lookup the pspec theorems which match the expression + trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem" + let pspecs : Array TheoremOrLocal ← do + let thNames ← pspecAttr.find? fExpr + -- TODO: because of reduction, there may be several valid theorems (for + -- instance for the scalars). We need to sort them from most specific to + -- least specific. For now, we assume the most specific theorems are at + -- the end. + let thNames := thNames.reverse + trace[Progress] "Looked up pspec theorems: {thNames}" + pure (thNames.map fun th => TheoremOrLocal.Theorem th) + -- Try the theorems one by one + for pspec in pspecs do + if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return () + else pure () + -- It failed: try to use the recursive assumptions + trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" + -- We try to apply the assumptions of kind "auxDecl" let ctx ← Lean.MonadLCtx.getLCtx let decls ← ctx.getAllDecls let decls := decls.filter (λ decl => match decl.kind with @@ -381,8 +377,6 @@ namespace Test -- The following commands display the databases of theorems -- #eval showStoredPSpec - -- #eval showStoredPSpecClass - -- #eval showStoredPSpecExprClass open alloc.vec example {ty} {x y : Scalar ty} @@ -402,6 +396,19 @@ namespace Test example {x y : U32} (hmax : x.val + y.val ≤ U32.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by + -- This spec theorem is suboptimal, but it is good to check that it works + progress with Scalar.add_spec as ⟨ z, h1 .. ⟩ + simp [*, h1] + + example {x y : U32} + (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by + progress with U32.add_spec as ⟨ z, h1 .. ⟩ + simp [*, h1] + + example {x y : U32} + (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by progress keep _ as ⟨ z, h1 .. ⟩ simp [*, h1] diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index b917a789..b0032281 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -159,47 +159,96 @@ elab "print_ctx_decls" : tactic => do let decls ← ctx.getDecls printDecls decls --- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) +-- A map-reduce visitor function for expressions (adapted from `AbstractNestedProofs.visit`) -- The continuation takes as parameters: -- - the current depth of the expression (useful for printing/debugging) -- - the expression to explore -partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do - let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do +partial def mapreduceVisit {a : Type} (k : Nat → a → Expr → MetaM (a × Expr)) + (state : a) (e : Expr) : MetaM (a × Expr) := do + let mapreduceVisitBinders (state : a) (xs : Array Expr) (k2 : a → MetaM (a × Expr)) : + MetaM (a × Expr) := do let localInstances ← getLocalInstances - let mut lctx ← getLCtx - for x in xs do - let xFVarId := x.fvarId! - let localDecl ← xFVarId.getDecl - let type ← mapVisit k localDecl.type - let localDecl := localDecl.setType type - let localDecl ← match localDecl.value? with - | some value => let value ← mapVisit k value; pure <| localDecl.setValue value - | none => pure localDecl - lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl - withLCtx lctx localInstances k2 + -- Update the local declarations for the bindings in context `lctx` + let rec visit_xs (lctx : LocalContext) (state : a) (xs : List Expr) : MetaM (LocalContext × a) := do + match xs with + | [] => pure (lctx, state) + | x :: xs => do + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + let (state, type) ← mapreduceVisit k state localDecl.type + let localDecl := localDecl.setType type + let (state, localDecl) ← match localDecl.value? with + | some value => + let (state, value) ← mapreduceVisit k state value + pure (state, localDecl.setValue value) + | none => pure (state, localDecl) + let lctx := lctx.modifyLocalDecl xFVarId fun _ => localDecl + -- Recursive call + visit_xs lctx state xs + let (lctx, state) ← visit_xs (← getLCtx) state xs.toList + -- Call the continuation with the updated context + withLCtx lctx localInstances (k2 state) -- TODO: use a cache? (Lean.checkCache) - let rec visit (i : Nat) (e : Expr) : MetaM Expr := do + let rec visit (i : Nat) (state : a) (e : Expr) : MetaM (a × Expr) := do -- Explore - let e ← k i e + let (state, e) ← k i state e match e with | .bvar _ | .fvar _ | .mvar _ | .sort _ | .lit _ - | .const _ _ => pure e - | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1))) + | .const _ _ => pure (state, e) + | .app .. => do e.withApp fun f args => do + let (state, args) ← args.foldlM (fun (state, args) arg => do let (state, arg) ← visit (i + 1) state arg; pure (state, arg :: args)) (state, []) + let args := args.reverse + let (state, f) ← visit (i + 1) state f + let e' := mkAppN f (Array.mk args) + return (state, e') | .lam .. => lambdaLetTelescope e fun xs b => - mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkLambdaFVars xs b (usedLetOnly := false) + return (state, e') | .forallE .. => do - forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b) + forallTelescope e fun xs b => + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkForallFVars xs b + return (state, e') | .letE .. => do - lambdaLetTelescope e fun xs b => mapVisitBinders xs do - mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) - | .mdata _ b => return e.updateMData! (← visit (i + 1) b) - | .proj _ _ b => return e.updateProj! (← visit (i + 1) b) - visit 0 e + lambdaLetTelescope e fun xs b => + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkLambdaFVars xs b (usedLetOnly := false) + return (state, e') + | .mdata _ b => do + let (state, b) ← visit (i + 1) state b + return (state, e.updateMData! b) + | .proj _ _ b => do + let (state, b) ← visit (i + 1) state b + return (state, e.updateProj! b) + visit 0 state e + +-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) +-- The continuation takes as parameters: +-- - the current depth of the expression (useful for printing/debugging) +-- - the expression to explore +partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do + let k' i (_ : Unit) e := do + let e ← k i e + pure ((), e) + let (_, e) ← mapreduceVisit k' () e + pure e + +-- A reduce visitor +partial def reduceVisit {a : Type} (k : Nat → a → Expr → MetaM a) (s : a) (e : Expr) : MetaM a := do + let k' i (s : a) e := do + let s ← k i s e + pure (s, e) + let (s, _) ← mapreduceVisit k' s e + pure s -- Generate a fresh user name for an anonymous proposition to introduce in the -- assumptions @@ -371,15 +420,22 @@ def splitConjTarget : TacticM Unit := do -- Destruct an equaliy and return the two sides def destEq (e : Expr) : MetaM (Expr × Expr) := do - e.withApp fun f args => + e.consumeMData.withApp fun f args => if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2) else throwError "Not an equality: {e}" -- Return the set of FVarIds in the expression +-- TODO: this collects fvars introduced in the inner bindings partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do - e.withApp fun body args => do - let hs := if body.isFVar then hs.insert body.fvarId! else hs - args.foldlM (fun hs arg => getFVarIds arg hs) hs + reduceVisit (fun _ (hs : HashSet FVarId) e => + if e.isFVar then pure (hs.insert e.fvarId!) else pure hs) + hs e + +-- Return the set of MVarIds in the expression +partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do + reduceVisit (fun _ (hs : HashSet MVarId) e => + if e.isMVar then pure (hs.insert e.mvarId!) else pure hs) + hs e -- Tactic to split on a disjunction. -- The expression `h` should be an fvar. |