From f7493580421d31b1d3798521b4f9154e69755f89 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Dec 2023 16:29:18 +0100 Subject: Reorganize a bit --- backends/lean/Base/Diverge/Elab.lean | 24 ++++++++++-- backends/lean/Base/Lookup/Base.lean | 70 +++++++++++++++++++++++++++++++++++ backends/lean/Base/Progress/Base.lean | 65 ++------------------------------ 3 files changed, 94 insertions(+), 65 deletions(-) create mode 100644 backends/lean/Base/Lookup/Base.lean (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index c6628486..0088fd16 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -418,7 +418,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .lit _ | .mvar _ | .sort _ => throwError "Unreachable" - | .lam .. => throwError "Unimplemented" + | .lam .. => throwError "Unimplemented" -- TODO | .forallE .. => throwError "Unreachable" -- Shouldn't get there | .letE .. => do -- Telescope all the let-bindings (remark: this also telescopes the lambdas) @@ -585,6 +585,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do else do -- Remaining case: normal application. -- It shouldn't use the continuation. + -- TODO: actually, it can proveNoKExprIsValid k_var e -- Prove that a match expression is valid. @@ -1087,17 +1088,33 @@ namespace Tests . intro i hpos h; simp at h; linarith . rename_i hd tl ih intro i hpos h - -- We can directly use `rw [list_nth]`! + -- We can directly use `rw [list_nth]` rw [list_nth]; simp split <;> try simp [*] . tauto - . -- TODO: we shouldn't have to do that + . -- We don't have to do this if we use scalar_tac have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all simp at h have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) simp [ih] tauto + -- Return a continuation + divergent def list_nth_with_back {a: Type} (ls : List a) (i : Int) : + Result (a × (a → Result (List a))) := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return (x, (λ ret => return (ret :: ls))) + else do + let (x, back) ← list_nth_with_back ls (i - 1) + return (x, + (λ ret => do + let ls ← back ret + return (x :: ls))) + + #check list_nth_with_back.unfold + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) @@ -1121,7 +1138,6 @@ namespace Tests #check bar.unfold -- Testing dependent branching and let-bindings - -- TODO: why the linter warning? divergent def isNonZero (i : Int) : Result Bool := if _h:i = 0 then return false else diff --git a/backends/lean/Base/Lookup/Base.lean b/backends/lean/Base/Lookup/Base.lean new file mode 100644 index 00000000..54d242ea --- /dev/null +++ b/backends/lean/Base/Lookup/Base.lean @@ -0,0 +1,70 @@ +/- Utilities to have databases of theorems -/ +import Lean +import Std.Lean.HashSet +import Base.Utils +import Base.Primitives.Base + +namespace Lookup + +open Lean Elab Term Meta +open Utils + +-- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR? +def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : + IO (MapDeclarationExtension α) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty, + addEntryFn := fun s n => s.insert n.1 n.2 , + toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) + } + +-- Declare an extension of maps of maps (using [RBMap]). +-- The important point is that we need to merge the bound values (which are maps). +def mkMapMapDeclarationExtension [Inhabited β] (ord : α → α → Ordering) + (name : Name := by exact decl_name%) : + IO (MapDeclarationExtension (RBMap α β ord)) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => + a.foldl (fun s a => a.foldl ( + -- We need to merge the maps + fun s (k0, k1_to_v) => + match s.find? k0 with + | none => + -- No binding: insert one + s.insert k0 k1_to_v + | some m => + -- There is already a binding: merge + let m := RBMap.fold (fun m k v => m.insert k v) m k1_to_v + s.insert k0 m) + s) RBMap.empty, + addEntryFn := fun s n => s.insert n.1 n.2 , + toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) + } + +-- Declare an extension of maps of maps (using [HashMap]). +-- The important point is that we need to merge the bound values (which are maps). +def mkMapHashMapDeclarationExtension [BEq α] [Hashable α] [Inhabited β] + (name : Name := by exact decl_name%) : + IO (MapDeclarationExtension (HashMap α β)) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => + a.foldl (fun s a => a.foldl ( + -- We need to merge the maps + fun s (k0, k1_to_v) => + match s.find? k0 with + | none => + -- No binding: insert one + s.insert k0 k1_to_v + | some m => + -- There is already a binding: merge + let m := HashMap.fold (fun m k v => m.insert k v) m k1_to_v + s.insert k0 m) + s) RBMap.empty, + addEntryFn := fun s n => s.insert n.1 n.2 , + toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) + } + +end Lookup diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 76a92795..cf2a3b70 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -2,6 +2,7 @@ import Lean import Std.Lean.HashSet import Base.Utils import Base.Primitives.Base +import Base.Lookup.Base namespace Progress @@ -166,67 +167,9 @@ structure PSpecClassExprAttr where ext : MapDeclarationExtension (HashMap Expr Name) deriving Inhabited --- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR? -def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension α) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - --- Declare an extension of maps of maps (using [RBMap]). --- The important point is that we need to merge the bound values (which are maps). -def mkMapMapDeclarationExtension [Inhabited β] (ord : α → α → Ordering) - (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension (RBMap α β ord)) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => - a.foldl (fun s a => a.foldl ( - -- We need to merge the maps - fun s (k0, k1_to_v) => - match s.find? k0 with - | none => - -- No binding: insert one - s.insert k0 k1_to_v - | some m => - -- There is already a binding: merge - let m := RBMap.fold (fun m k v => m.insert k v) m k1_to_v - s.insert k0 m) - s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - --- Declare an extension of maps of maps (using [HashMap]). --- The important point is that we need to merge the bound values (which are maps). -def mkMapHashMapDeclarationExtension [BEq α] [Hashable α] [Inhabited β] - (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension (HashMap α β)) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => - a.foldl (fun s a => a.foldl ( - -- We need to merge the maps - fun s (k0, k1_to_v) => - match s.find? k0 with - | none => - -- No binding: insert one - s.insert k0 k1_to_v - | some m => - -- There is already a binding: merge - let m := HashMap.fold (fun m k v => m.insert k v) m k1_to_v - s.insert k0 m) - s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - /- The persistent map from function to pspec theorems. -/ initialize pspecAttr : PSpecAttr ← do - let ext ← mkMapDeclarationExtension `pspecMap + let ext ← Lookup.mkMapDeclarationExtension `pspecMap let attrImpl : AttributeImpl := { name := `pspec descr := "Marks theorems to use with the `progress` tactic" @@ -250,7 +193,7 @@ initialize pspecAttr : PSpecAttr ← do /- The persistent map from type classes to pspec theorems -/ initialize pspecClassAttr : PSpecClassAttr ← do let ext : MapDeclarationExtension (NameMap Name) ← - mkMapMapDeclarationExtension Name.quickCmp `pspecClassMap + Lookup.mkMapMapDeclarationExtension Name.quickCmp `pspecClassMap let attrImpl : AttributeImpl := { name := `cpspec descr := "Marks theorems to use for type classes with the `progress` tactic" @@ -282,7 +225,7 @@ initialize pspecClassAttr : PSpecClassAttr ← do /- The 2nd persistent map from type classes to pspec theorems -/ initialize pspecClassExprAttr : PSpecClassExprAttr ← do let ext : MapDeclarationExtension (HashMap Expr Name) ← - mkMapHashMapDeclarationExtension `pspecClassExprMap + Lookup.mkMapHashMapDeclarationExtension `pspecClassExprMap let attrImpl : AttributeImpl := { name := `cepspec descr := "Marks theorems to use for type classes with the `progress` tactic" -- cgit v1.2.3 From b3ebef29fe3f13a9004b39fcb89afb33fbbfd248 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Dec 2023 20:52:04 +0100 Subject: Start working on a version of Diverge.FixI more suited to higher-order functions --- backends/lean/Base/Diverge/Base.lean | 570 +++++++++++++++++++++++++++---- backends/lean/Base/Diverge/ElabBase.lean | 6 + backends/lean/Base/Progress/Base.lean | 1 + 3 files changed, 510 insertions(+), 67 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 6a52387d..9d986f4e 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -12,6 +12,78 @@ set_option profiler.threshold 100 namespace Diverge +/- Auxiliary lemmas -/ +namespace Lemmas + -- TODO: not necessary anymore + def for_all_fin_aux {n : Nat} (f : Fin n → Prop) (m : Nat) (h : m ≤ n) : Prop := + if heq: m = n then True + else + f ⟨ m, by simp_all [Nat.lt_iff_le_and_ne] ⟩ ∧ + for_all_fin_aux f (m + 1) (by simp_all [Arith.add_one_le_iff_le_ne]) + termination_by for_all_fin_aux n _ m h => n - m + decreasing_by + simp_wf + apply Nat.sub_add_lt_sub <;> try simp + simp_all [Arith.add_one_le_iff_le_ne] + + def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp) + + theorem for_all_fin_aux_imp_forall {n : Nat} (f : Fin n → Prop) (m : Nat) : + (h : m ≤ n) → + for_all_fin_aux f m h → ∀ i, m ≤ i.val → f i + := by + generalize h: (n - m) = k + revert m + induction k -- TODO: induction h rather? + case zero => + simp_all + intro m h1 h2 + have h: n = m := by + linarith + unfold for_all_fin_aux; simp_all + simp_all + -- There is no i s.t. m ≤ i + intro i h3; cases i; simp_all + linarith + case succ k hi => + intro m hk hmn + intro hf i hmi + have hne: m ≠ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + linarith + -- m = i? + if heq: m = i then + -- Yes: simply use the `for_all_fin_aux` hyp + unfold for_all_fin_aux at hf + simp_all + else + -- No: use the induction hypothesis + have hlt: m < i := by simp_all [Nat.lt_iff_le_and_ne] + have hineq: m + 1 ≤ n := by + have hineq := Nat.lt_of_sub_eq_succ hk + simp [*, Nat.add_one_le_iff] + have heq1: n - (m + 1) = k := by + -- TODO: very annoying arithmetic proof + simp [Nat.sub_eq_iff_eq_add hineq] + have hineq1: m ≤ n := by linarith + simp [Nat.sub_eq_iff_eq_add hineq1] at hk + simp_arith [hk] + have hi := hi (m + 1) heq1 hineq + apply hi <;> simp_all + . unfold for_all_fin_aux at hf + simp_all + . simp_all [Arith.add_one_le_iff_le_ne] + + -- TODO: this is not necessary anymore + theorem for_all_fin_imp_forall (n : Nat) (f : Fin n → Prop) : + for_all_fin f → ∀ i, f i + := by + intro Hf i + apply for_all_fin_aux_imp_forall <;> try assumption + simp + +end Lemmas + namespace Fix open Primitives @@ -436,6 +508,10 @@ namespace FixI /- Indexed fixed-point: definitions with indexed types, convenient to use for mutually recursive definitions. We simply port the definitions and proofs from Fix to a more specific case. + + Remark: the index designates the function in the mutually recursive group + (it should be a finite type). We make the return type depend on the input + type because we group the type parameters in the input type. -/ open Primitives Fix @@ -538,73 +614,6 @@ namespace FixI let j: Fin tys1.length := ⟨ j, jLt ⟩ Eq.mp (by simp) (get_fun tl j) - def for_all_fin_aux {n : Nat} (f : Fin n → Prop) (m : Nat) (h : m ≤ n) : Prop := - if heq: m = n then True - else - f ⟨ m, by simp_all [Nat.lt_iff_le_and_ne] ⟩ ∧ - for_all_fin_aux f (m + 1) (by simp_all [Arith.add_one_le_iff_le_ne]) - termination_by for_all_fin_aux n _ m h => n - m - decreasing_by - simp_wf - apply Nat.sub_add_lt_sub <;> try simp - simp_all [Arith.add_one_le_iff_le_ne] - - def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp) - - theorem for_all_fin_aux_imp_forall {n : Nat} (f : Fin n → Prop) (m : Nat) : - (h : m ≤ n) → - for_all_fin_aux f m h → ∀ i, m ≤ i.val → f i - := by - generalize h: (n - m) = k - revert m - induction k -- TODO: induction h rather? - case zero => - simp_all - intro m h1 h2 - have h: n = m := by - linarith - unfold for_all_fin_aux; simp_all - simp_all - -- There is no i s.t. m ≤ i - intro i h3; cases i; simp_all - linarith - case succ k hi => - intro m hk hmn - intro hf i hmi - have hne: m ≠ n := by - have hineq := Nat.lt_of_sub_eq_succ hk - linarith - -- m = i? - if heq: m = i then - -- Yes: simply use the `for_all_fin_aux` hyp - unfold for_all_fin_aux at hf - simp_all - else - -- No: use the induction hypothesis - have hlt: m < i := by simp_all [Nat.lt_iff_le_and_ne] - have hineq: m + 1 ≤ n := by - have hineq := Nat.lt_of_sub_eq_succ hk - simp [*, Nat.add_one_le_iff] - have heq1: n - (m + 1) = k := by - -- TODO: very annoying arithmetic proof - simp [Nat.sub_eq_iff_eq_add hineq] - have hineq1: m ≤ n := by linarith - simp [Nat.sub_eq_iff_eq_add hineq1] at hk - simp_arith [hk] - have hi := hi (m + 1) heq1 hineq - apply hi <;> simp_all - . unfold for_all_fin_aux at hf - simp_all - . simp_all [Arith.add_one_le_iff_le_ne] - - -- TODO: this is not necessary anymore - theorem for_all_fin_imp_forall (n : Nat) (f : Fin n → Prop) : - for_all_fin f → ∀ i, f i - := by - intro Hf i - apply for_all_fin_aux_imp_forall <;> try assumption - simp - /- Automating the proofs -/ @[simp] theorem is_valid_p_same (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) (x : Result c) : @@ -707,6 +716,222 @@ namespace FixI end FixI +namespace FixII + /- Indexed fixed-point: definitions with indexed types, convenient to use for mutually + recursive definitions. We simply port the definitions and proofs from Fix to a more + specific case. + + Here however, we group the types into a parameter distinct from the inputs. + -/ + open Primitives Fix + + -- The index type + variable {id : Type u} + + -- The input/output types + variable {ty : id → Type v} {a : (i:id) → ty i → Type w} {b : (i:id) → ty i → Type x} + + -- Monotonicity relation over monadic arrows (i.e., Kleisli arrows) + def karrow_rel (k1 k2 : (i:id) → (t:ty i) → (a i t) → Result (b i t)) : Prop := + ∀ i t x, result_rel (k1 i t x) (k2 i t x) + + def kk_to_gen (k : (i:id) → (t:ty i) → (x:a i t) → Result (b i t)) : + (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst) := + λ ⟨ i, t, x ⟩ => k i t x + + def kk_of_gen (k : (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) : + (i:id) → (t:ty i) → a i t → Result (b i t) := + λ i t x => k ⟨ i, t, x ⟩ + + def k_to_gen (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) : + ((x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) → (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst) := + λ kk => kk_to_gen (k (kk_of_gen kk)) + + def k_of_gen (k : ((x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) → (x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) : + ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t) := + λ kk => kk_of_gen (k (kk_to_gen kk)) + + def e_to_gen (e : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c) : + ((x: (i:id) × (t:ty i) × (a i t)) → Result (b x.fst x.snd.fst)) → Result c := + λ k => e (kk_of_gen k) + + def is_valid_p (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (e : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c) : Prop := + Fix.is_valid_p (k_to_gen k) (e_to_gen e) + + def is_valid (f : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) : Prop := + ∀ k i t x, is_valid_p k (λ k => f k i t x) + + def fix + (f : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) : + (i:id) → (t:ty i) → a i t → Result (b i t) := + kk_of_gen (Fix.fix (k_to_gen f)) + + theorem is_valid_fix_fixed_eq + {{f : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)}} + (Hvalid : is_valid f) : + fix f = f (fix f) := by + have Hvalid' : Fix.is_valid (k_to_gen f) := by + intro k x + simp only [is_valid, is_valid_p] at Hvalid + let ⟨ i, t, x ⟩ := x + have Hvalid := Hvalid (k_of_gen k) i t x + simp only [k_to_gen, k_of_gen, kk_to_gen, kk_of_gen] at Hvalid + refine Hvalid + have Heq := Fix.is_valid_fix_fixed_eq Hvalid' + simp [fix] + conv => lhs; rw [Heq] + + /- Some utilities to define the mutually recursive functions -/ + + -- TODO: use more + abbrev kk_ty (id : Type u) (ty : id → Type v) (a : (i:id) → ty i → Type w) (b : (i:id) → ty i → Type x) := + (i:id) → (t:ty i) → a i t → Result (b i t) + abbrev k_ty (id : Type u) (ty : id → Type v) (a : (i:id) → ty i → Type w) (b : (i:id) → ty i → Type x) := + kk_ty id ty a b → kk_ty id ty a b + + abbrev in_out_ty : Type (imax (u + 1) (imax (v + 1) (w + 1))) := + (ty : Type u) × (ty → Type v) × (ty → Type w) + -- TODO: remove? + abbrev mk_in_out_ty (ty : Type u) (in_ty : ty → Type v) (out_ty : ty → Type w) : + in_out_ty := + Sigma.mk ty (Prod.mk in_ty out_ty) + + -- Initially, we had left out the parameters id, a and b. + -- However, by parameterizing Funs with those parameters, we can state + -- and prove lemmas like Funs.is_valid_p_is_valid_p + inductive Funs (id : Type u) (ty : id → Type v) + (a : (i:id) → ty i → Type w) (b : (i:id) → ty i → Type x) : + List in_out_ty.{v, w, x} → Type (max (u + 1) (max (v + 1) (max (w + 1) (x + 1)))) := + | Nil : Funs id ty a b [] + | Cons {it: Type v} {ity : it → Type w} {oty : it → Type x} {tys : List in_out_ty} + (f : kk_ty id ty a b → (i:it) → (x:ity i) → Result (oty i)) (tl : Funs id ty a b tys) : + Funs id ty a b (⟨ it, ity, oty ⟩ :: tys) + + def get_fun {tys : List in_out_ty} (fl : Funs id ty a b tys) : + (i : Fin tys.length) → kk_ty id ty a b → (t : (tys.get i).fst) → + ((tys.get i).snd.fst t) → Result ((tys.get i).snd.snd t) := + match fl with + | .Nil => λ i => by have h:= i.isLt; simp at h + | @Funs.Cons id ty a b it ity oty tys1 f tl => + λ ⟨ i, iLt ⟩ => + match i with + | 0 => + Eq.mp (by simp [List.get]) f + | .succ j => + have jLt: j < tys1.length := by + simp at iLt + revert iLt + simp_arith + let j: Fin tys1.length := ⟨ j, jLt ⟩ + Eq.mp (by simp) (get_fun tl j) + + /- Automating the proofs -/ + @[simp] theorem is_valid_p_same + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) (x : Result c) : + is_valid_p k (λ _ => x) := by + simp [is_valid_p, k_to_gen, e_to_gen] + + @[simp] theorem is_valid_p_rec + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) (i : id) (t : ty i) (x : a i t) : + is_valid_p k (λ k => k i t x) := by + simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] + + theorem is_valid_p_ite + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (λ k => ite cond (e1 k) (e2 k)) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (cond : Prop) [h : Decidable cond] + {e1 : ((i:id) → (t:ty i) → a i t → Result (b i t)) → cond → Result c} + {e2 : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Not cond → Result c} + (he1: ∀ x, is_valid_p k (λ k => e1 k x)) + (he2 : ∀ x, is_valid_p k (λ k => e2 k x)) : + is_valid_p k (λ k => dite cond (e1 k) (e2 k)) := by + split <;> simp [*] + + theorem is_valid_p_bind + {{k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)}} + {{g : ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result c}} + {{h : c → ((i:id) → (t:ty i) → a i t → Result (b i t)) → Result d}} + (Hgvalid : is_valid_p k g) + (Hhvalid : ∀ y, is_valid_p k (h y)) : + is_valid_p k (λ k => do let y ← g k; h y k) := by + apply Fix.is_valid_p_bind + . apply Hgvalid + . apply Hhvalid + + def Funs.is_valid_p + (k : k_ty id ty a b) + (fl : Funs id ty a b tys) : + Prop := + match fl with + | .Nil => True + | .Cons f fl => + (∀ i x, FixII.is_valid_p k (λ k => f k i x)) ∧ fl.is_valid_p k + + theorem Funs.is_valid_p_Nil (k : k_ty id ty a b) : + Funs.is_valid_p k Funs.Nil := by simp [Funs.is_valid_p] + + def Funs.is_valid_p_is_valid_p_aux + {k : k_ty id ty a b} + {tys : List in_out_ty} + (fl : Funs id ty a b tys) (Hvalid : is_valid_p k fl) : + ∀ (i : Fin tys.length) (t : (tys.get i).fst) (x : (tys.get i).snd.fst t), + FixII.is_valid_p k (fun k => get_fun fl i k t x) := by + -- Prepare the induction + have ⟨ n, Hn ⟩ : { n : Nat // tys.length = n } := ⟨ tys.length, by rfl ⟩ + revert tys fl Hvalid + induction n + -- + case zero => + intro tys fl Hvalid Hlen; + have Heq: tys = [] := by cases tys <;> simp_all + intro i x + simp_all + have Hi := i.isLt + simp_all + case succ n Hn => + intro tys fl Hvalid Hlen i x; + cases fl <;> simp at Hlen i x Hvalid + rename_i ity oty tys f fl + have ⟨ Hvf, Hvalid ⟩ := Hvalid + have Hvf1: is_valid_p k fl := by + simp [Hvalid, Funs.is_valid_p] + have Hn := @Hn tys fl Hvf1 (by simp [*]) + -- Case disjunction on i + match i with + | ⟨ 0, _ ⟩ => + simp at x + simp [get_fun] + apply (Hvf x) + | ⟨ .succ j, HiLt ⟩ => + simp_arith at HiLt + simp at x + let j : Fin (List.length tys) := ⟨ j, by simp_arith [HiLt] ⟩ + have Hn := Hn j x + apply Hn + + def Funs.is_valid_p_is_valid_p + (tys : List in_out_ty) + (k : k_ty (Fin (List.length tys)) (λ i => (tys.get i).fst) + (fun i t => (List.get tys i).snd.fst t) (fun i t => (List.get tys i).snd.snd t)) + (fl : Funs (Fin tys.length) (λ i => (tys.get i).fst) + (λ i t => (tys.get i).snd.fst t) (λ i t => (tys.get i).snd.snd t) tys) : + fl.is_valid_p k → + ∀ (i : Fin tys.length) (t : (tys.get i).fst) (x : (tys.get i).snd.fst t), + FixII.is_valid_p k (fun k => get_fun fl i k t x) + := by + intro Hvalid + apply is_valid_p_is_valid_p_aux; simp [*] + +end FixII + namespace Ex1 /- An example of use of the fixed-point -/ open Primitives Fix @@ -1133,3 +1358,214 @@ namespace Ex6 Heqix end Ex6 + +namespace Ex7 + /- `list_nth` again, but this time we use FixII -/ + open Primitives FixII + + @[simp] def tys.{u} : List in_out_ty := + [mk_in_out_ty (Type u) (λ a => List a × Int) (λ a => a)] + + @[simp] def ty (i : Fin 1) := (tys.get i).fst + @[simp] def input_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.fst t + @[simp] def output_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.snd t + + def list_nth_body.{u} (k : (i:Fin 1) → (t:ty i) → input_ty i t → Result (output_ty i t)) + (a : Type u) (x : List a × Int) : Result a := + let ⟨ ls, i ⟩ := x + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else k 0 a ⟨ tl, i - 1 ⟩ + + @[simp] def bodies : + Funs (Fin 1) ty input_ty output_ty tys := + Funs.Cons list_nth_body Funs.Nil + + def body (k : (i : Fin 1) → (t : ty i) → (x : input_ty i t) → Result (output_ty i t)) (i: Fin 1) : + (t : ty i) → (x : input_ty i t) → Result (output_ty i t) := get_fun bodies i k + + theorem body_is_valid: is_valid body := by + -- Split the proof into proofs of validity of the individual bodies + rw [is_valid] + simp only [body] + intro k + apply (Funs.is_valid_p_is_valid_p tys) + simp [Funs.is_valid_p] + (repeat (apply And.intro)); intro x; try simp at x + simp only [list_nth_body] + -- Prove the validity of the individual bodies + intro k x + split <;> try simp + split <;> simp + + -- Writing the proof terms explicitly + theorem list_nth_body_is_valid' (k : k_ty (Fin 1) ty input_ty output_ty) + (a : Type u) (x : List a × Int) : is_valid_p k (fun k => list_nth_body k a x) := + let ⟨ ls, i ⟩ := x + match ls with + | [] => is_valid_p_same k (.fail .panic) + | hd :: tl => + is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 a ⟨tl, i-1⟩) + + theorem body_is_valid' : is_valid body := + fun k => + Funs.is_valid_p_is_valid_p tys k bodies + (And.intro (list_nth_body_is_valid' k) (Funs.is_valid_p_Nil k)) + + def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := + fix body 0 a ⟨ ls , i ⟩ + + -- The unfolding equation - diverges if `i < 0` + theorem list_nth_eq (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := by + have Heq := is_valid_fix_fixed_eq body_is_valid + simp [list_nth] + conv => lhs; rw [Heq] + + -- Write the proof term explicitly: the generation of the proof term (without tactics) + -- is automatable, and the proof term is actually a lot simpler and smaller when we + -- don't use tactics. + theorem list_nth_eq'.{u} {a : Type u} (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := + -- Use the fixed-point equation + have Heq := is_valid_fix_fixed_eq body_is_valid.{u} + -- Add the index + have Heqi := congr_fun Heq 0 + -- Add the type parameter + have Heqia := congr_fun Heqi a + -- Add the input + have Heqix := congr_fun Heqia (ls, i) + -- Done + Heqix + +end Ex7 + +namespace Ex8 + /- Higher-order example, with FixII -/ + open Primitives FixII + + variable {id : Type u} {ty : id → Type v} + variable {a : (i:id) → ty i → Type w} {b : (i:id) → ty i → Type x} + + /- An auxiliary function, which doesn't require the fixed-point -/ + def map {a : Type y} {b : Type z} (f : a → Result b) (ls : List a) : Result (List b) := + match ls with + | [] => .ret [] + | hd :: tl => + do + let hd ← f hd + let tl ← map f tl + .ret (hd :: tl) + + /- The validity theorem for `map`, generic in `f` -/ + theorem map_is_valid + (i : id) (t : ty i) + {{f : (a i t → Result (b i t)) → (a i t) → Result c}} + (Hfvalid : ∀ k x, is_valid_p k (λ k => f (k i t) x)) + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (ls : List (a i t)) : + is_valid_p k (λ k => map (f (k i t)) ls) := by + induction ls <;> simp [map] + apply is_valid_p_bind <;> try simp_all + intros + apply is_valid_p_bind <;> try simp_all + + theorem map_is_valid' + (i : id) (t : ty i) + (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) + (ls : List (a i t)) : + is_valid_p k (λ k => map (k i t) ls) := by + induction ls <;> simp [map] + apply is_valid_p_bind <;> try simp_all + intros + apply is_valid_p_bind <;> try simp_all + +end Ex8 + +namespace Ex9 + /- An example which uses map -/ + open Primitives FixII Ex8 + + inductive Tree (a : Type u) := + | leaf (x : a) + | node (tl : List (Tree a)) + + @[simp] def tys.{u} : List in_out_ty := + [mk_in_out_ty (Type u) (λ a => Tree a) (λ a => Tree a)] + + @[simp] def ty (i : Fin 1) := (tys.get i).fst + @[simp] def input_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.fst t + @[simp] def output_ty (i : Fin 1) (t : ty i) : Type u := (tys.get i).snd.snd t + + def id_body.{u} (k : (i:Fin 1) → (t:ty i) → input_ty i t → Result (output_ty i t)) + (a : Type u) (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (k 0 a) tl + .ret (.node tl) + + @[simp] def bodies : + Funs (Fin 1) ty input_ty output_ty tys := + Funs.Cons id_body Funs.Nil + + theorem id_body_is_valid : + ∀ (k : ((i : Fin 1) → (t : ty i) → input_ty i t → Result (output_ty i t)) → (i : Fin 1) → (t : ty i) → input_ty i t → Result (output_ty i t)) + (a : Type u) (x : Tree a), + @is_valid_p (Fin 1) ty input_ty output_ty (output_ty 0 a) k (λ k => id_body k a x) := by + intro k a x + simp only [id_body] + split <;> try simp + apply is_valid_p_bind <;> try simp [*] + -- We have to show that `map k tl` is valid + -- Remark: `map_is_valid` doesn't work here, we need the specialized version + apply map_is_valid' + + def body (k : (i : Fin 1) → (t : ty i) → (x : input_ty i t) → Result (output_ty i t)) (i: Fin 1) : + (t : ty i) → (x : input_ty i t) → Result (output_ty i t) := get_fun bodies i k + + theorem body_is_valid : is_valid body := + fun k => + Funs.is_valid_p_is_valid_p tys k bodies + (And.intro (id_body_is_valid k) (Funs.is_valid_p_Nil k)) + + def id {a: Type u} (t : Tree a) : Result (Tree a) := + fix body 0 a t + + -- Writing the proof term explicitly + theorem id_eq' {a : Type u} (t : Tree a) : + id t = + (match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl)) + := + -- The unfolding equation + have Heq := is_valid_fix_fixed_eq body_is_valid.{u} + -- Add the index + have Heqi := congr_fun Heq 0 + -- Add the type parameter + have Heqia := congr_fun Heqi a + -- Add the input + have Heqix := congr_fun Heqia t + -- Done + Heqix + +end Ex9 diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index fedb1c74..1a0676e2 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -1,8 +1,11 @@ import Lean +import Base.Utils +import Base.Primitives.Base namespace Diverge open Lean Elab Term Meta +open Utils -- We can't define and use trace classes in the same file initialize registerTraceClass `Diverge.elab @@ -12,4 +15,7 @@ initialize registerTraceClass `Diverge.def.genBody initialize registerTraceClass `Diverge.def.valid initialize registerTraceClass `Diverge.def.unfold +-- For the attribute (for higher-order functions) +initialize registerTraceClass `Diverge.attr + end Diverge diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index cf2a3b70..573f0cc5 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -154,6 +154,7 @@ structure PSpecAttr where specs of the scalar operations, which is what I really need, but I'm not sure it applies well to other situations. A better way would probably to use type classes, but I couldn't get them to work on those cases. It is worth retrying. + UPDATE: use discrimination trees (`DiscrTree`) from core Lean -/ structure PSpecClassAttr where attr : AttributeImpl -- cgit v1.2.3 From 3c092169efcbc36a9b435c68c590b36f69204f94 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Dec 2023 12:38:55 +0100 Subject: Update the progress tactic to use discrimination trees --- backends/lean/Base/Primitives/Scalar.lean | 126 ++++++++--------- backends/lean/Base/Progress/Base.lean | 228 ++++++++---------------------- backends/lean/Base/Progress/Progress.lean | 93 ++++++------ 3 files changed, 175 insertions(+), 272 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index f74fecd4..db522df2 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -528,7 +528,7 @@ instance {ty} : HAnd (Scalar ty) (Scalar ty) (Scalar ty) where hAnd x y := Scalar.and x y -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.add_spec {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val + y.val) (hmax : x.val + y.val ≤ Scalar.max ty) : @@ -550,62 +550,62 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} apply add_spec <;> assumption /- Fine-grained theorems -/ -@[cepspec] theorem Usize.add_spec {x y : Usize} (hmax : x.val + y.val ≤ Usize.max) : +@[pspec] theorem Usize.add_spec {x y : Usize} (hmax : x.val + y.val ≤ Usize.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U8.add_spec {x y : U8} (hmax : x.val + y.val ≤ U8.max) : +@[pspec] theorem U8.add_spec {x y : U8} (hmax : x.val + y.val ≤ U8.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U16.add_spec {x y : U16} (hmax : x.val + y.val ≤ U16.max) : +@[pspec] theorem U16.add_spec {x y : U16} (hmax : x.val + y.val ≤ U16.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U32.add_spec {x y : U32} (hmax : x.val + y.val ≤ U32.max) : +@[pspec] theorem U32.add_spec {x y : U32} (hmax : x.val + y.val ≤ U32.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U64.add_spec {x y : U64} (hmax : x.val + y.val ≤ U64.max) : +@[pspec] theorem U64.add_spec {x y : U64} (hmax : x.val + y.val ≤ U64.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U128.add_spec {x y : U128} (hmax : x.val + y.val ≤ U128.max) : +@[pspec] theorem U128.add_spec {x y : U128} (hmax : x.val + y.val ≤ U128.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem Isize.add_spec {x y : Isize} +@[pspec] theorem Isize.add_spec {x y : Isize} (hmin : Isize.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ Isize.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I8.add_spec {x y : I8} +@[pspec] theorem I8.add_spec {x y : I8} (hmin : I8.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I8.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I16.add_spec {x y : I16} +@[pspec] theorem I16.add_spec {x y : I16} (hmin : I16.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I16.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I32.add_spec {x y : I32} +@[pspec] theorem I32.add_spec {x y : I32} (hmin : I32.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I32.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I64.add_spec {x y : I64} +@[pspec] theorem I64.add_spec {x y : I64} (hmin : I64.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I64.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -@[cepspec] theorem I128.add_spec {x y : I128} +@[pspec] theorem I128.add_spec {x y : I128} (hmin : I128.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I128.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := Scalar.add_spec hmin hmax -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.sub_spec {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val - y.val) (hmax : x.val - y.val ≤ Scalar.max ty) : @@ -629,56 +629,56 @@ theorem Scalar.sub_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} apply sub_spec <;> assumption /- Fine-grained theorems -/ -@[cepspec] theorem Usize.sub_spec {x y : Usize} (hmin : Usize.min ≤ x.val - y.val) : +@[pspec] theorem Usize.sub_spec {x y : Usize} (hmin : Usize.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U8.sub_spec {x y : U8} (hmin : U8.min ≤ x.val - y.val) : +@[pspec] theorem U8.sub_spec {x y : U8} (hmin : U8.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U16.sub_spec {x y : U16} (hmin : U16.min ≤ x.val - y.val) : +@[pspec] theorem U16.sub_spec {x y : U16} (hmin : U16.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U32.sub_spec {x y : U32} (hmin : U32.min ≤ x.val - y.val) : +@[pspec] theorem U32.sub_spec {x y : U32} (hmin : U32.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U64.sub_spec {x y : U64} (hmin : U64.min ≤ x.val - y.val) : +@[pspec] theorem U64.sub_spec {x y : U64} (hmin : U64.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem U128.sub_spec {x y : U128} (hmin : U128.min ≤ x.val - y.val) : +@[pspec] theorem U128.sub_spec {x y : U128} (hmin : U128.min ≤ x.val - y.val) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *] -@[cepspec] theorem Isize.sub_spec {x y : Isize} (hmin : Isize.min ≤ x.val - y.val) +@[pspec] theorem Isize.sub_spec {x y : Isize} (hmin : Isize.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ Isize.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I8.sub_spec {x y : I8} (hmin : I8.min ≤ x.val - y.val) +@[pspec] theorem I8.sub_spec {x y : I8} (hmin : I8.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I8.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I16.sub_spec {x y : I16} (hmin : I16.min ≤ x.val - y.val) +@[pspec] theorem I16.sub_spec {x y : I16} (hmin : I16.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I16.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I32.sub_spec {x y : I32} (hmin : I32.min ≤ x.val - y.val) +@[pspec] theorem I32.sub_spec {x y : I32} (hmin : I32.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I32.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I64.sub_spec {x y : I64} (hmin : I64.min ≤ x.val - y.val) +@[pspec] theorem I64.sub_spec {x y : I64} (hmin : I64.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I64.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax -@[cepspec] theorem I128.sub_spec {x y : I128} (hmin : I128.min ≤ x.val - y.val) +@[pspec] theorem I128.sub_spec {x y : I128} (hmin : I128.min ≤ x.val - y.val) (hmax : x.val - y.val ≤ I128.max) : ∃ z, x - y = ret z ∧ z.val = x.val - y.val := Scalar.sub_spec hmin hmax @@ -705,62 +705,62 @@ theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} apply mul_spec <;> assumption /- Fine-grained theorems -/ -@[cepspec] theorem Usize.mul_spec {x y : Usize} (hmax : x.val * y.val ≤ Usize.max) : +@[pspec] theorem Usize.mul_spec {x y : Usize} (hmax : x.val * y.val ≤ Usize.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U8.mul_spec {x y : U8} (hmax : x.val * y.val ≤ U8.max) : +@[pspec] theorem U8.mul_spec {x y : U8} (hmax : x.val * y.val ≤ U8.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U16.mul_spec {x y : U16} (hmax : x.val * y.val ≤ U16.max) : +@[pspec] theorem U16.mul_spec {x y : U16} (hmax : x.val * y.val ≤ U16.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U32.mul_spec {x y : U32} (hmax : x.val * y.val ≤ U32.max) : +@[pspec] theorem U32.mul_spec {x y : U32} (hmax : x.val * y.val ≤ U32.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U64.mul_spec {x y : U64} (hmax : x.val * y.val ≤ U64.max) : +@[pspec] theorem U64.mul_spec {x y : U64} (hmax : x.val * y.val ≤ U64.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem U128.mul_spec {x y : U128} (hmax : x.val * y.val ≤ U128.max) : +@[pspec] theorem U128.mul_spec {x y : U128} (hmax : x.val * y.val ≤ U128.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *] -@[cepspec] theorem Isize.mul_spec {x y : Isize} (hmin : Isize.min ≤ x.val * y.val) +@[pspec] theorem Isize.mul_spec {x y : Isize} (hmin : Isize.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ Isize.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I8.mul_spec {x y : I8} (hmin : I8.min ≤ x.val * y.val) +@[pspec] theorem I8.mul_spec {x y : I8} (hmin : I8.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I8.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I16.mul_spec {x y : I16} (hmin : I16.min ≤ x.val * y.val) +@[pspec] theorem I16.mul_spec {x y : I16} (hmin : I16.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I16.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I32.mul_spec {x y : I32} (hmin : I32.min ≤ x.val * y.val) +@[pspec] theorem I32.mul_spec {x y : I32} (hmin : I32.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I32.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I64.mul_spec {x y : I64} (hmin : I64.min ≤ x.val * y.val) +@[pspec] theorem I64.mul_spec {x y : I64} (hmin : I64.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I64.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -@[cepspec] theorem I128.mul_spec {x y : I128} (hmin : I128.min ≤ x.val * y.val) +@[pspec] theorem I128.mul_spec {x y : I128} (hmin : I128.min ≤ x.val * y.val) (hmax : x.val * y.val ≤ I128.max) : ∃ z, x * y = ret z ∧ z.val = x.val * y.val := Scalar.mul_spec hmin hmax -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.div_spec {ty} {x y : Scalar ty} (hnz : y.val ≠ 0) (hmin : Scalar.min ty ≤ scalar_div x.val y.val) @@ -788,66 +788,66 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S apply hs /- Fine-grained theorems -/ -@[cepspec] theorem Usize.div_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : +@[pspec] theorem Usize.div_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [*] -@[cepspec] theorem U8.div_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : +@[pspec] theorem U8.div_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U16.div_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : +@[pspec] theorem U16.div_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U32.div_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : +@[pspec] theorem U32.div_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U64.div_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : +@[pspec] theorem U64.div_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U128.div_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : +@[pspec] theorem U128.div_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem Isize.div_spec (x : Isize) {y : Isize} +@[pspec] theorem Isize.div_spec (x : Isize) {y : Isize} (hnz : y.val ≠ 0) (hmin : Isize.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ Isize.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I8.div_spec (x : I8) {y : I8} +@[pspec] theorem I8.div_spec (x : I8) {y : I8} (hnz : y.val ≠ 0) (hmin : I8.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I8.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I16.div_spec (x : I16) {y : I16} +@[pspec] theorem I16.div_spec (x : I16) {y : I16} (hnz : y.val ≠ 0) (hmin : I16.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I16.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I32.div_spec (x : I32) {y : I32} +@[pspec] theorem I32.div_spec (x : I32) {y : I32} (hnz : y.val ≠ 0) (hmin : I32.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I32.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I64.div_spec (x : I64) {y : I64} +@[pspec] theorem I64.div_spec (x : I64) {y : I64} (hnz : y.val ≠ 0) (hmin : I64.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I64.max): ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := Scalar.div_spec hnz hmin hmax -@[cepspec] theorem I128.div_spec (x : I128) {y : I128} +@[pspec] theorem I128.div_spec (x : I128) {y : I128} (hnz : y.val ≠ 0) (hmin : I128.min ≤ scalar_div x.val y.val) (hmax : scalar_div x.val y.val ≤ I128.max): @@ -855,7 +855,7 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S Scalar.div_spec hnz hmin hmax -- Generic theorem - shouldn't be used much -@[cpspec] +@[pspec] theorem Scalar.rem_spec {ty} {x y : Scalar ty} (hnz : y.val ≠ 0) (hmin : Scalar.min ty ≤ scalar_rem x.val y.val) @@ -883,59 +883,59 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S simp [*] at hs simp [*] -@[cepspec] theorem Usize.rem_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : +@[pspec] theorem Usize.rem_spec (x : Usize) {y : Usize} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [*] -@[cepspec] theorem U8.rem_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : +@[pspec] theorem U8.rem_spec (x : U8) {y : U8} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U16.rem_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : +@[pspec] theorem U16.rem_spec (x : U16) {y : U16} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U32.rem_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : +@[pspec] theorem U32.rem_spec (x : U32) {y : U32} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U64.rem_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : +@[pspec] theorem U64.rem_spec (x : U64) {y : U64} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem U128.rem_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : +@[pspec] theorem U128.rem_spec (x : U128) {y : U128} (hnz : y.val ≠ 0) : ∃ z, x % y = ret z ∧ z.val = x.val % y.val := by apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *] -@[cepspec] theorem I8.rem_spec (x : I8) {y : I8} +@[pspec] theorem I8.rem_spec (x : I8) {y : I8} (hnz : y.val ≠ 0) (hmin : I8.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I8.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I16.rem_spec (x : I16) {y : I16} +@[pspec] theorem I16.rem_spec (x : I16) {y : I16} (hnz : y.val ≠ 0) (hmin : I16.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I16.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I32.rem_spec (x : I32) {y : I32} +@[pspec] theorem I32.rem_spec (x : I32) {y : I32} (hnz : y.val ≠ 0) (hmin : I32.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I32.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I64.rem_spec (x : I64) {y : I64} +@[pspec] theorem I64.rem_spec (x : I64) {y : I64} (hnz : y.val ≠ 0) (hmin : I64.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I64.max): ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := Scalar.rem_spec hnz hmin hmax -@[cepspec] theorem I128.rem_spec (x : I128) {y : I128} +@[pspec] theorem I128.rem_spec (x : I128) {y : I128} (hnz : y.val ≠ 0) (hmin : I128.min ≤ scalar_rem x.val y.val) (hmax : scalar_rem x.val y.val ≤ I128.max): diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 573f0cc5..d50c357c 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -4,6 +4,9 @@ import Base.Utils import Base.Primitives.Base import Base.Lookup.Base +import Lean.Meta.DiscrTree +import Lean.Meta.Tactic.Simp + namespace Progress open Lean Elab Term Meta @@ -14,13 +17,35 @@ initialize registerTraceClass `Progress /- # Progress tactic -/ +/- Discrimination trees map expressions to values. When storing an expression + in a discrimination tree, the expression is first converted to an array + of `DiscrTree.Key`, which are the keys actually used by the discrimination + trees. The conversion operation is monadic, however, and extensions require + all the operations to be pure. For this reason, in the state extension, we + store the keys from *after* the transformation (i.e., the `DiscrTreeKey` + below). + -/ +abbrev DiscrTreeKey (simpleReduce : Bool) := Array (DiscrTree.Key simpleReduce) + +abbrev DiscrTreeExtension (α : Type) (simpleReduce : Bool) := + SimplePersistentEnvExtension (DiscrTreeKey simpleReduce × α) (DiscrTree α simpleReduce) + +def mkDiscrTreeExtention [Inhabited α] [BEq α] (name : Name := by exact decl_name%) (simpleReduce : Bool) : + IO (DiscrTreeExtension α simpleReduce) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insertCore k v) s) DiscrTree.empty, + addEntryFn := fun s n => s.insertCore n.1 n.2 , + } + structure PSpecDesc where -- The universally quantified variables fvars : Array Expr -- The existentially quantified variables evars : Array Expr + -- The function applied to its arguments + fArgsExpr : Expr -- The function - fExpr : Expr fName : Name -- The function arguments fLevels : List Level @@ -38,7 +63,7 @@ section Methods variable [MonadError m] variable {a : Type} - /- Analyze a pspec theorem to decompose its arguments. + /- Analyze a goal or a pspec theorem to decompose its arguments. PSpec theorems should be of the following shape: ``` @@ -57,12 +82,20 @@ section Methods TODO: generalize for when we do inductive proofs -/ partial - def withPSpec [Inhabited (m a)] [Nonempty (m a)] (th : Expr) (k : PSpecDesc → m a) - (sanityChecks : Bool := false) : + def withPSpec [Inhabited (m a)] [Nonempty (m a)] (sanityChecks : Bool := false) + (isGoal : Bool) (th : Expr) (k : PSpecDesc → m a) : m a := do trace[Progress] "Proposition: {th}" -- Dive into the quantified variables and the assumptions - forallTelescope th.consumeMData fun fvars th => do + -- Note that if we analyze a pspec theorem to register it in a database (i.e. + -- a discrimination tree), we need to introduce *meta-variables* for the + -- quantified variables. + let telescope (k : Array Expr → Expr → m a) : m a := + if isGoal then forallTelescope th.consumeMData (fun fvars th => k fvars th) + else do + let (fvars, _, th) ← forallMetaTelescope th.consumeMData + k fvars th + telescope fun fvars th => do trace[Progress] "Universally quantified arguments and assumptions: {fvars}" -- Dive into the existentials existsTelescope th.consumeMData fun evars th => do @@ -79,7 +112,7 @@ section Methods -- destruct the application to get the function name mExpr.consumeMData.withApp fun mf margs => do trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}" - let (fExpr, f, args) ← do + let (fArgsExpr, f, args) ← do if mf.isConst ∧ mf.constName = ``Bind.bind then do -- Dive into the bind let fExpr := (margs.get! 4).consumeMData @@ -101,11 +134,11 @@ section Methods let argsFVars := fvars.map (fun x => x.fvarId!) let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar) -- Return - trace[Progress] "Function: {f.constName!}"; + trace[Progress] "Function with arguments: {fArgsExpr}"; let thDesc := { fvars := fvars evars := evars - fExpr + fArgsExpr fName := f.constName! fLevels := f.constLevels! args := args @@ -117,60 +150,21 @@ section Methods end Methods -def getPSpecFunName (th : Expr) : MetaM Name := - withPSpec th (fun d => do pure d.fName) true +/-def getPSpecFunArgsExpr (th : Expr) : MetaM Expr := + withPSpec true th (fun d => do pure d.fArgsExpr) -def getPSpecClassFunNames (th : Expr) : MetaM (Name × Name) := - withPSpec th (fun d => do - let arg0 := d.args.get! 0 - arg0.withApp fun f _ => do - if ¬ f.isConst then throwError "Not a constant: {f}" - pure (d.fName, f.constName) - ) true - -def getPSpecClassFunNameArg (th : Expr) : MetaM (Name × Expr) := - withPSpec th (fun d => do - let arg0 := d.args.get! 0 - pure (d.fName, arg0) - ) true +def getPSpecFunName (th : Expr) : MetaM Name := + withPSpec true th (fun d => do pure d.fName)-/ --- "Regular" pspec attribute +-- pspec attribute structure PSpecAttr where attr : AttributeImpl - ext : MapDeclarationExtension Name - deriving Inhabited - -/- pspec attribute for type classes: we use the name of the type class to - lookup another map. We use the *first* argument of the type class to lookup - into this second map. - - Example: - ======== - We use type classes for addition. For instance, the addition between two - U32 is written (without syntactic sugar) as `HAdd.add (Scalar ty) x y`. As a consequence, - we store the theorem through the bindings: HAdd.add → Scalar → ... - - SH: TODO: this (and `PSpecClassExprAttr`) is a bit ad-hoc. For now it works for the - specs of the scalar operations, which is what I really need, but I'm not sure it - applies well to other situations. A better way would probably to use type classes, but - I couldn't get them to work on those cases. It is worth retrying. - UPDATE: use discrimination trees (`DiscrTree`) from core Lean --/ -structure PSpecClassAttr where - attr : AttributeImpl - ext : MapDeclarationExtension (NameMap Name) - deriving Inhabited - -/- Same as `PSpecClassAttr` but we use the full first argument (it works when it - is a constant). -/ -structure PSpecClassExprAttr where - attr : AttributeImpl - ext : MapDeclarationExtension (HashMap Expr Name) + ext : DiscrTreeExtension Name true deriving Inhabited /- The persistent map from function to pspec theorems. -/ initialize pspecAttr : PSpecAttr ← do - let ext ← Lookup.mkMapDeclarationExtension `pspecMap + let ext ← mkDiscrTreeExtention `pspecMap true let attrImpl : AttributeImpl := { name := `pspec descr := "Marks theorems to use with the `progress` tactic" @@ -182,130 +176,30 @@ initialize pspecAttr : PSpecAttr ← do -- Lookup the theorem let env ← getEnv let thDecl := env.constants.find! thName - let fName ← MetaM.run' (getPSpecFunName thDecl.type) - trace[Progress] "Registering spec theorem for {fName}" - let env := ext.addEntry env (fName, thName) - setEnv env - pure () - } - registerBuiltinAttribute attrImpl - pure { attr := attrImpl, ext := ext } - -/- The persistent map from type classes to pspec theorems -/ -initialize pspecClassAttr : PSpecClassAttr ← do - let ext : MapDeclarationExtension (NameMap Name) ← - Lookup.mkMapMapDeclarationExtension Name.quickCmp `pspecClassMap - let attrImpl : AttributeImpl := { - name := `cpspec - descr := "Marks theorems to use for type classes with the `progress` tactic" - add := fun thName stx attrKind => do - Attribute.Builtin.ensureNoArgs stx - -- TODO: use the attribute kind - unless attrKind == AttributeKind.global do - throwError "invalid attribute 'cpspec', must be global" - -- Lookup the theorem - let env ← getEnv - let thDecl := env.constants.find! thName - let (fName, argName) ← MetaM.run' (getPSpecClassFunNames thDecl.type) - trace[Progress] "Registering class spec theorem for ({fName}, {argName})" - -- Update the entry if there is one, add an entry if there is none - let env := - match (ext.getState (← getEnv)).find? fName with - | none => - let m := RBMap.ofList [(argName, thName)] - ext.addEntry env (fName, m) - | some m => - let m := m.insert argName thName - ext.addEntry env (fName, m) - setEnv env - pure () - } - registerBuiltinAttribute attrImpl - pure { attr := attrImpl, ext := ext } - -/- The 2nd persistent map from type classes to pspec theorems -/ -initialize pspecClassExprAttr : PSpecClassExprAttr ← do - let ext : MapDeclarationExtension (HashMap Expr Name) ← - Lookup.mkMapHashMapDeclarationExtension `pspecClassExprMap - let attrImpl : AttributeImpl := { - name := `cepspec - descr := "Marks theorems to use for type classes with the `progress` tactic" - add := fun thName stx attrKind => do - Attribute.Builtin.ensureNoArgs stx - -- TODO: use the attribute kind - unless attrKind == AttributeKind.global do - throwError "invalid attribute 'cpspec', must be global" - -- Lookup the theorem - let env ← getEnv - let thDecl := env.constants.find! thName - let (fName, arg) ← MetaM.run' (getPSpecClassFunNameArg thDecl.type) - -- Sanity check: no variables appear in the argument - MetaM.run' do - let fvars ← getFVarIds arg - if ¬ fvars.isEmpty then throwError "The first argument ({arg}) contains variables" - -- We store two bindings: - -- - arg to theorem name - -- - reduced arg to theorem name - let rarg ← MetaM.run' (reduceAll arg) - trace[Progress] "Registering class spec theorem for ({fName}, {arg}) and ({fName}, {rarg})" - -- Update the entry if there is one, add an entry if there is none - let env := - match (ext.getState (← getEnv)).find? fName with - | none => - let m := HashMap.ofList [(arg, thName), (rarg, thName)] - ext.addEntry env (fName, m) - | some m => - let m := m.insert arg thName - let m := m.insert rarg thName - ext.addEntry env (fName, m) + let isGoal := false + let fKey ← MetaM.run' (withPSpec true isGoal thDecl.type fun d => do + let fExpr := d.fArgsExpr + trace[Progress] "Registering spec theorem for {fExpr}" + -- Convert the function expression to a discrimination tree key + DiscrTree.mkPath fExpr) + let env := ext.addEntry env (fKey, thName) setEnv env pure () } registerBuiltinAttribute attrImpl pure { attr := attrImpl, ext := ext } +def PSpecAttr.find? (s : PSpecAttr) (e : Expr) : MetaM (Array Name) := do + (s.ext.getState (← getEnv)).getMatch e -def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do - return (s.ext.getState (← getEnv)).find? name - -def PSpecClassAttr.find? (s : PSpecClassAttr) (className argName : Name) : MetaM (Option Name) := do - match (s.ext.getState (← getEnv)).find? className with - | none => return none - | some map => return map.find? argName - -def PSpecClassExprAttr.find? (s : PSpecClassExprAttr) (className : Name) (arg : Expr) : MetaM (Option Name) := do - match (s.ext.getState (← getEnv)).find? className with - | none => return none - | some map => return map.find? arg - -def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do - pure (s.ext.getState (← getEnv)) - -def PSpecClassAttr.getState (s : PSpecClassAttr) : MetaM (NameMap (NameMap Name)) := do - pure (s.ext.getState (← getEnv)) - -def PSpecClassExprAttr.getState (s : PSpecClassExprAttr) : MetaM (NameMap (HashMap Expr Name)) := do +def PSpecAttr.getState (s : PSpecAttr) : MetaM (DiscrTree Name true) := do pure (s.ext.getState (← getEnv)) def showStoredPSpec : MetaM Unit := do let st ← pspecAttr.getState - let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" - IO.println s - -def showStoredPSpecClass : MetaM Unit := do - let st ← pspecClassAttr.getState - let s := st.toList.foldl (fun s (f, m) => - let ms := m.toList.foldl (fun s (f, th) => - f!"{s}\n {f} → {th}") f!"" - f!"{s}\n{f} → [{ms}]") f!"" - IO.println s - -def showStoredPSpecExprClass : MetaM Unit := do - let st ← pspecClassExprAttr.getState - let s := st.toList.foldl (fun s (f, m) => - let ms := m.toList.foldl (fun s (f, th) => - f!"{s}\n {f} → {th}") f!"" - f!"{s}\n{f} → [{ms}]") f!"" + -- TODO: how can we iterate over (at least) the values stored in the tree? + --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" + let s := f!"{st}" IO.println s end Progress diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index ba63f09d..93b7d7d5 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -204,11 +204,11 @@ def getFirstArg (args : Array Expr) : Option Expr := do if args.size = 0 then none else some (args.get! 0) -/- Helper: try to lookup a theorem and apply it, or continue with another tactic - if it fails -/ +/- Helper: try to lookup a theorem and apply it. + Return true if it succeeded. -/ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) (fExpr : Expr) - (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do + (kind : String) (th : Option TheoremOrLocal) : TacticM Bool := do let res ← do match th with | none => @@ -223,9 +223,9 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : pure (some res) catch _ => none match res with - | some .Ok => return () + | some .Ok => return true | some (.Error msg) => throwError msg - | none => x + | none => return false -- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) @@ -236,11 +236,19 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL let goalTy ← mgoal.getType trace[Progress] "goal: {goalTy}" -- Dive into the goal to lookup the theorem - let (fExpr, fName, args) ← do - withPSpec goalTy fun desc => - -- TODO: check that no quantified variables in the arguments - pure (desc.fExpr, desc.fName, desc.args) - trace[Progress] "Function: {fName}" + -- Remark: if we don't isolate the call to `withPSpec` to immediately "close" + -- the terms immediately, we may end up with the error: + -- "(kernel) declaration has free variables" + -- I'm not sure I understand why. + -- TODO: we should also check that no quantified variable appears in fExpr. + -- If such variables appear, we should just fail because the goal doesn't + -- have the proper shape. + let fExpr ← do + let isGoal := true + withPSpec false isGoal goalTy fun desc => do + let fExpr := desc.fArgsExpr + trace[Progress] "Expression to match: {fExpr}" + pure fExpr -- If the user provided a theorem/assumption: use it. -- Otherwise, lookup one. match withTh with @@ -258,36 +266,24 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL match res with | .Ok => return () | .Error msg => throwError msg - -- It failed: try to lookup a theorem - -- TODO: use a list of theorems, and try them one by one? - trace[Progress] "No assumption succeeded: trying to lookup a theorem" - let pspec ← do - let thName ← pspecAttr.find? fName - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec do - -- It failed: try to lookup a *class* expr spec theorem (those are more - -- specific than class spec theorems) - trace[Progress] "Failed using a pspec theorem: trying to lookup a pspec class expr theorem" - let pspecClassExpr ← do - match getFirstArg args with - | none => pure none - | some arg => do - trace[Progress] "Using: f:{fName}, arg: {arg}" - let thName ← pspecClassExprAttr.find? fName arg - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec class expr theorem" pspecClassExpr do - -- It failed: try to lookup a *class* spec theorem - trace[Progress] "Failed using a pspec class expr theorem: trying to lookup a pspec class theorem" - let pspecClass ← do - match ← getFirstArgAppName args with - | none => pure none - | some argName => do - trace[Progress] "Using: f: {fName}, arg: {argName}" - let thName ← pspecClassAttr.find? fName argName - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec class theorem" pspecClass do - trace[Progress] "Failed using a pspec class theorem: trying to use a recursive assumption" - -- Try a recursive call - we try the assumptions of kind "auxDecl" + -- It failed: lookup the pspec theorems which match the expression + trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem" + let pspecs : Array TheoremOrLocal ← do + let thNames ← pspecAttr.find? fExpr + -- TODO: because of reduction, there may be several valid theorems (for + -- instance for the scalars). We need to sort them from most specific to + -- least specific. For now, we assume the most specific theorems are at + -- the end. + let thNames := thNames.reverse + trace[Progress] "Looked up pspec theorems: {thNames}" + pure (thNames.map fun th => TheoremOrLocal.Theorem th) + -- Try the theorems one by one + for pspec in pspecs do + if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return () + else pure () + -- It failed: try to use the recursive assumptions + trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" + -- We try to apply the assumptions of kind "auxDecl" let ctx ← Lean.MonadLCtx.getLCtx let decls ← ctx.getAllDecls let decls := decls.filter (λ decl => match decl.kind with @@ -381,8 +377,6 @@ namespace Test -- The following commands display the databases of theorems -- #eval showStoredPSpec - -- #eval showStoredPSpecClass - -- #eval showStoredPSpecExprClass open alloc.vec example {ty} {x y : Scalar ty} @@ -392,6 +386,8 @@ namespace Test progress keep _ as ⟨ z, h1 .. ⟩ simp [*, h1] + set_option trace.Progress false + example {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val + y.val) (hmax : x.val + y.val ≤ Scalar.max ty) : @@ -399,6 +395,19 @@ namespace Test progress keep h with Scalar.add_spec as ⟨ z ⟩ simp [*, h] + example {x y : U32} + (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by + -- This spec theorem is suboptimal, but it is good to check that it works + progress with Scalar.add_spec as ⟨ z, h1 .. ⟩ + simp [*, h1] + + example {x y : U32} + (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by + progress with U32.add_spec as ⟨ z, h1 .. ⟩ + simp [*, h1] + example {x y : U32} (hmax : x.val + y.val ≤ U32.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by -- cgit v1.2.3 From 10a77d17ea06b732106348588bedc6a89766d56f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 11:31:43 +0100 Subject: Reactivate the sanity checks for the progress tactic --- backends/lean/Base/Progress/Base.lean | 40 +++++++++++++------------------ backends/lean/Base/Progress/Progress.lean | 4 +--- backends/lean/Base/Utils.lean | 6 +++++ 3 files changed, 24 insertions(+), 26 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index d50c357c..a72cd641 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -40,6 +40,7 @@ def mkDiscrTreeExtention [Inhabited α] [BEq α] (name : Name := by exact decl_n structure PSpecDesc where -- The universally quantified variables + -- Can be fvars or mvars fvars : Array Expr -- The existentially quantified variables evars : Array Expr @@ -50,8 +51,6 @@ structure PSpecDesc where -- The function arguments fLevels : List Level args : Array Expr - -- The universally quantified variables which appear in the function arguments - argsFVars : Array FVarId -- The returned value ret : Expr -- The postcondition (if there is) @@ -82,7 +81,7 @@ section Methods TODO: generalize for when we do inductive proofs -/ partial - def withPSpec [Inhabited (m a)] [Nonempty (m a)] (sanityChecks : Bool := false) + def withPSpec [Inhabited (m a)] [Nonempty (m a)] (isGoal : Bool) (th : Expr) (k : PSpecDesc → m a) : m a := do trace[Progress] "Proposition: {th}" @@ -120,19 +119,18 @@ section Methods else pure (mExpr, mf, margs) trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}" if ¬ f.isConst then throwError "Not a constant: {f}" - -- Compute the set of universally quantified variables which appear in the function arguments - let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty - -- Sanity check - if sanityChecks then - -- All the variables which appear in the inputs given to the function are - -- universally quantified (in particular, they are not *existentially* quantified) - let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!)) - let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar) - if ¬ filtArgsFVars.isEmpty then + -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB) + -- Check if some existentially quantified variables + let _ := do + -- Collect all the free variables in the arguments + let allArgsFVars := ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + -- Check if they intersect the fvars we introduced for the existentially quantified variables + let evarsSet : HashSet FVarId := HashSet.ofArray (evars.map (fun (x : Expr) => x.fvarId!)) + let filtArgsFVars := allArgsFVars.toArray.filter (fun var => evarsSet.contains var) + if filtArgsFVars.isEmpty then pure () + else let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId) throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}" - let argsFVars := fvars.map (fun x => x.fvarId!) - let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar) -- Return trace[Progress] "Function with arguments: {fArgsExpr}"; let thDesc := { @@ -142,7 +140,6 @@ section Methods fName := f.constName! fLevels := f.constLevels! args := args - argsFVars ret := ret post := post } @@ -150,11 +147,8 @@ section Methods end Methods -/-def getPSpecFunArgsExpr (th : Expr) : MetaM Expr := - withPSpec true th (fun d => do pure d.fArgsExpr) - -def getPSpecFunName (th : Expr) : MetaM Name := - withPSpec true th (fun d => do pure d.fName)-/ +def getPSpecFunArgsExpr (isGoal : Bool) (th : Expr) : MetaM Expr := + withPSpec isGoal th (fun d => do pure d.fArgsExpr) -- pspec attribute structure PSpecAttr where @@ -176,14 +170,14 @@ initialize pspecAttr : PSpecAttr ← do -- Lookup the theorem let env ← getEnv let thDecl := env.constants.find! thName - let isGoal := false - let fKey ← MetaM.run' (withPSpec true isGoal thDecl.type fun d => do - let fExpr := d.fArgsExpr + let fKey ← MetaM.run' (do + let fExpr ← getPSpecFunArgsExpr false thDecl.type trace[Progress] "Registering spec theorem for {fExpr}" -- Convert the function expression to a discrimination tree key DiscrTree.mkPath fExpr) let env := ext.addEntry env (fKey, thName) setEnv env + trace[Progress] "Saved the environment" pure () } registerBuiltinAttribute attrImpl diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index 93b7d7d5..a6a4e82a 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -245,7 +245,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL -- have the proper shape. let fExpr ← do let isGoal := true - withPSpec false isGoal goalTy fun desc => do + withPSpec isGoal goalTy fun desc => do let fExpr := desc.fArgsExpr trace[Progress] "Expression to match: {fExpr}" pure fExpr @@ -386,8 +386,6 @@ namespace Test progress keep _ as ⟨ z, h1 .. ⟩ simp [*, h1] - set_option trace.Progress false - example {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val + y.val) (hmax : x.val + y.val ≤ Scalar.max ty) : diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index b917a789..95b2c38b 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -381,6 +381,12 @@ partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM let hs := if body.isFVar then hs.insert body.fvarId! else hs args.foldlM (fun hs arg => getFVarIds arg hs) hs +-- Return the set of MVarIds in the expression +partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do + e.withApp fun body args => do + let hs := if body.isMVar then hs.insert body.mvarId! else hs + args.foldlM (fun hs arg => getMVarIds arg hs) hs + -- Tactic to split on a disjunction. -- The expression `h` should be an fvar. -- TODO: there must be simpler. Use use _root_.Lean.MVarId.cases for instance -- cgit v1.2.3 From cb332ffb55425e6e6bc3b0ef8da7e646b2174fdf Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 11:37:03 +0100 Subject: Reorganize a bit --- backends/lean/Base/Extensions.lean | 47 +++++++++++++++++++++++ backends/lean/Base/Lookup/Base.lean | 70 ----------------------------------- backends/lean/Base/Progress/Base.lean | 26 +------------ 3 files changed, 49 insertions(+), 94 deletions(-) create mode 100644 backends/lean/Base/Extensions.lean delete mode 100644 backends/lean/Base/Lookup/Base.lean (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Extensions.lean b/backends/lean/Base/Extensions.lean new file mode 100644 index 00000000..b34f41dc --- /dev/null +++ b/backends/lean/Base/Extensions.lean @@ -0,0 +1,47 @@ +import Lean +import Std.Lean.HashSet +import Base.Utils +import Base.Primitives.Base + +import Lean.Meta.DiscrTree +import Lean.Meta.Tactic.Simp + +/-! Various state extensions used in the library -/ +namespace Extensions + +open Lean Elab Term Meta +open Utils + +-- This is not used anymore but we keep it here. +-- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR? +def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : + IO (MapDeclarationExtension α) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty, + addEntryFn := fun s n => s.insert n.1 n.2 , + toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) + } + +/- Discrimination trees map expressions to values. When storing an expression + in a discrimination tree, the expression is first converted to an array + of `DiscrTree.Key`, which are the keys actually used by the discrimination + trees. The conversion operation is monadic, however, and extensions require + all the operations to be pure. For this reason, in the state extension, we + store the keys from *after* the transformation (i.e., the `DiscrTreeKey` + below). The transformation itself can be done elsewhere. + -/ +abbrev DiscrTreeKey (simpleReduce : Bool) := Array (DiscrTree.Key simpleReduce) + +abbrev DiscrTreeExtension (α : Type) (simpleReduce : Bool) := + SimplePersistentEnvExtension (DiscrTreeKey simpleReduce × α) (DiscrTree α simpleReduce) + +def mkDiscrTreeExtention [Inhabited α] [BEq α] (name : Name := by exact decl_name%) (simpleReduce : Bool) : + IO (DiscrTreeExtension α simpleReduce) := + registerSimplePersistentEnvExtension { + name := name, + addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insertCore k v) s) DiscrTree.empty, + addEntryFn := fun s n => s.insertCore n.1 n.2 , + } + +end Extensions diff --git a/backends/lean/Base/Lookup/Base.lean b/backends/lean/Base/Lookup/Base.lean deleted file mode 100644 index 54d242ea..00000000 --- a/backends/lean/Base/Lookup/Base.lean +++ /dev/null @@ -1,70 +0,0 @@ -/- Utilities to have databases of theorems -/ -import Lean -import Std.Lean.HashSet -import Base.Utils -import Base.Primitives.Base - -namespace Lookup - -open Lean Elab Term Meta -open Utils - --- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR? -def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension α) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - --- Declare an extension of maps of maps (using [RBMap]). --- The important point is that we need to merge the bound values (which are maps). -def mkMapMapDeclarationExtension [Inhabited β] (ord : α → α → Ordering) - (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension (RBMap α β ord)) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => - a.foldl (fun s a => a.foldl ( - -- We need to merge the maps - fun s (k0, k1_to_v) => - match s.find? k0 with - | none => - -- No binding: insert one - s.insert k0 k1_to_v - | some m => - -- There is already a binding: merge - let m := RBMap.fold (fun m k v => m.insert k v) m k1_to_v - s.insert k0 m) - s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - --- Declare an extension of maps of maps (using [HashMap]). --- The important point is that we need to merge the bound values (which are maps). -def mkMapHashMapDeclarationExtension [BEq α] [Hashable α] [Inhabited β] - (name : Name := by exact decl_name%) : - IO (MapDeclarationExtension (HashMap α β)) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => - a.foldl (fun s a => a.foldl ( - -- We need to merge the maps - fun s (k0, k1_to_v) => - match s.find? k0 with - | none => - -- No binding: insert one - s.insert k0 k1_to_v - | some m => - -- There is already a binding: merge - let m := HashMap.fold (fun m k v => m.insert k v) m k1_to_v - s.insert k0 m) - s) RBMap.empty, - addEntryFn := fun s n => s.insert n.1 n.2 , - toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1) - } - -end Lookup diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index a72cd641..fd80db8c 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -2,42 +2,20 @@ import Lean import Std.Lean.HashSet import Base.Utils import Base.Primitives.Base -import Base.Lookup.Base - +import Base.Extensions import Lean.Meta.DiscrTree import Lean.Meta.Tactic.Simp namespace Progress open Lean Elab Term Meta -open Utils +open Utils Extensions -- We can't define and use trace classes in the same file initialize registerTraceClass `Progress /- # Progress tactic -/ -/- Discrimination trees map expressions to values. When storing an expression - in a discrimination tree, the expression is first converted to an array - of `DiscrTree.Key`, which are the keys actually used by the discrimination - trees. The conversion operation is monadic, however, and extensions require - all the operations to be pure. For this reason, in the state extension, we - store the keys from *after* the transformation (i.e., the `DiscrTreeKey` - below). - -/ -abbrev DiscrTreeKey (simpleReduce : Bool) := Array (DiscrTree.Key simpleReduce) - -abbrev DiscrTreeExtension (α : Type) (simpleReduce : Bool) := - SimplePersistentEnvExtension (DiscrTreeKey simpleReduce × α) (DiscrTree α simpleReduce) - -def mkDiscrTreeExtention [Inhabited α] [BEq α] (name : Name := by exact decl_name%) (simpleReduce : Bool) : - IO (DiscrTreeExtension α simpleReduce) := - registerSimplePersistentEnvExtension { - name := name, - addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insertCore k v) s) DiscrTree.empty, - addEntryFn := fun s n => s.insertCore n.1 n.2 , - } - structure PSpecDesc where -- The universally quantified variables -- Can be fvars or mvars -- cgit v1.2.3 From c23a37617188a1bbf913b5c700522abc33bf39c9 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 17:00:01 +0100 Subject: Update Diverge/Elab.lean to use the more general FixII definitions --- backends/lean/Base/Diverge/Base.lean | 9 +- backends/lean/Base/Diverge/Elab.lean | 571 +++++++++++++++++++++---------- backends/lean/Base/Diverge/ElabBase.lean | 2 + 3 files changed, 387 insertions(+), 195 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 9d986f4e..a7107c1e 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -581,7 +581,6 @@ namespace FixI kk_ty id a b → kk_ty id a b abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) - -- TODO: remove? abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : in_out_ty := Sigma.mk in_ty out_ty @@ -717,11 +716,8 @@ namespace FixI end FixI namespace FixII - /- Indexed fixed-point: definitions with indexed types, convenient to use for mutually - recursive definitions. We simply port the definitions and proofs from Fix to a more - specific case. - - Here however, we group the types into a parameter distinct from the inputs. + /- Similar to FixI, but we split the input arguments between the type parameters + and the input values. -/ open Primitives Fix @@ -792,7 +788,6 @@ namespace FixII abbrev in_out_ty : Type (imax (u + 1) (imax (v + 1) (w + 1))) := (ty : Type u) × (ty → Type v) × (ty → Type w) - -- TODO: remove? abbrev mk_in_out_ty (ty : Type u) (in_ty : ty → Type v) (out_ty : ty → Type w) : in_out_ty := Sigma.mk ty (Prod.mk in_ty out_ty) diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 0088fd16..423a2514 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -24,8 +24,8 @@ open WF in def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] -def mkInOutTy (x y : Expr) : MetaM Expr := - mkAppM ``FixI.mk_in_out_ty #[x, y] +def mkInOutTy (x y z : Expr) : MetaM Expr := do + mkAppM ``FixII.mk_in_out_ty #[x, y, z] -- Return the `a` in `Return a` def getResultTy (ty : Expr) : MetaM Expr := @@ -60,21 +60,83 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do def mkSigmasType (xl : List Expr) : MetaM Expr := match xl with | [] => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" - pure (Expr.const ``PUnit.unit []) + trace[Diverge.def.sigmas] "mkSigmasType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) | [x] => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" + trace[Diverge.def.sigmas] "mkSigmasType: [{x}]" let ty ← Lean.Meta.inferType x pure ty | x :: xl => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x let sty ← mkSigmasType xl - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" + trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}" let beta ← mkLambdaFVars #[x] sty - trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" + trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] +/- Generate a product type from a list of *variables* (this is similar to `mkSigmas`). + + Example: + - xl = [(ls:List a), (i:Int)] + + Generates: + `List a × Int` + -/ +def mkProdsType (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.prods] "mkProdsType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) + | [x] => do + trace[Diverge.def.prods] "mkProdsType: [{x}]" + let ty ← Lean.Meta.inferType x + pure ty + | x :: xl => do + trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]" + let ty ← Lean.Meta.inferType x + let xl_ty ← mkProdsType xl + mkAppM ``Prod #[ty, xl_ty] + +/- Split the input arguments between the types and the "regular" arguments. + + We do something simple: we treat an input argument as an + input type iff it appears in the type of the following arguments. + + Note that what really matters is that we find the arguments which appear + in the output type. + + Also, we stop at the first input that we treat as an + input type. + -/ +def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × Array Expr) := do + -- Look for the first parameter which appears in the subsequent parameters + let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) := + match in_tys with + | [] => do + let fvars ← getFVarIds (← Lean.Meta.inferType out_ty) + pure (fvars, [], []) + | ty :: in_tys => do + let (fvars, in_tys, in_args) ← splitAux in_tys + -- Have we already found where to split between type variables/regular + -- variables? + if ¬ in_tys.isEmpty then + -- The fvars set is now useless: no need to update it anymore + pure (fvars, ty :: in_tys, in_args) + else + -- Check if ty appears in the set of free variables: + let ty_id := ty.fvarId! + if fvars.contains ty_id then + -- We must split here. Note that we don't need to update the fvars + -- set: it is not useful anymore + pure (fvars, [ty], in_args) + else + -- We must split later: update the fvars set + let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty)) + pure (fvars, [], ty :: in_args) + let (_, in_tys, in_args) ← splitAux in_tys.data + pure (Array.mk in_tys, Array.mk in_args) + /- Apply a lambda expression to some arguments, simplifying the lambdas -/ def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do lambdaTelescopeN e xs.size fun vars body => @@ -105,7 +167,7 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasVal: []" - pure (Expr.const ``PUnit.unit []) + pure (Expr.const ``PUnit.unit [Level.succ .zero]) | [x] => do trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" pure x @@ -122,6 +184,17 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] +/- Group a list of expressions into a (non-dependent) tuple -/ +def mkProdsVal (xl : List Expr) : MetaM Expr := + match xl with + | [] => + pure (Expr.const ``PUnit.unit [Level.succ .zero]) + | [x] => do + pure x + | x :: xl => do + let xl ← mkProdsVal xl + mkAppM ``Prod.mk #[x, xl] + def mkAnonymous (s : String) (i : Nat) : Name := .num (.str .anonymous s) i @@ -159,23 +232,23 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met match xl with | [] => do -- This would be unexpected - throwError "mkSigmasMatch: empyt list of input parameters" + throwError "mkSigmasMatch: empty list of input parameters" | [x] => do -- In the example given for the explanations: this is the inner match case trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" mkLambdaFVars #[x] out | fst :: xl => do - -- In the example given for the explanations: this is the outer match case - -- Remark: for the naming purposes, we use the same convention as for the - -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at - -- those definitions might help) - -- - -- We want to build the match expression: - -- ``` - -- λ scrut => - -- match scrut with - -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail - -- ``` + /- In the example given for the explanations: this is the outer match case + Remark: for the naming purposes, we use the same convention as for the + fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + those definitions might help) + + We want to build the match expression: + ``` + λ scrut => + match scrut with + | Sigma.mk x ... -- the hole is given by a recursive call on the tail + ``` -/ trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst let snd_ty ← mkSigmasType xl @@ -183,7 +256,7 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met let snd ← mkSigmasMatch xl out (index + 1) let mk ← mkLambdaFVars #[fst] snd -- Introduce the "scrut" variable - let scrut_ty ← mkSigmasType (fst :: xl) + let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient @@ -206,6 +279,67 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" pure sm +/- This is similar to `mkSigmasMatch`, but with non-dependent tuples + + Remark: factor out with `mkSigmasMatch`? This is extremely similar. +-/ +partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkProdsMatch: empty list of input parameters" + | [x] => do + -- In the example given for the explanations: this is the inner match case + trace[Diverge.def.prods] "mkProdsMatch: [{x}]" + mkLambdaFVars #[x] out + | fst :: xl => do + trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let beta ← mkProdsType xl + let snd ← mkProdsMatch xl out (index + 1) + let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do + trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + mkLambdaFVars #[scrut] out_ty + -- The final expression: putting everything together + trace[Diverge.def.prods] "mkProdsMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Prod.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable + let sm ← mkLambdaFVars #[scrut] sm + trace[Diverge.def.prods] "mkProdsMatch: sm: {sm}" + pure sm + +/- Same as `mkSigmasMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkSigmasMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkSigmasMatch xl out + +/- Same as `mkProdsMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkProdsMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkProdsMatch xl out + /- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) @@ -244,17 +378,23 @@ def mkFinVal (n i : Nat) : MetaM Expr := do We name the declarations: "[original_name].body". We return the new declarations. + + Inputs: + - `paramInOutTys`: (number of type parameters, sigma type grouping the type parameters, input type, output type) -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : + (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - -- Compute the map from name to (index × input type). - -- Remark: the continuation has an indexed type; we use the index (a finite number of - -- type `Fin`) to control which function we call at the recursive call site. - let nameToInfo : HashMap Name (Nat × Expr) := - let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst)) + /- Compute the map from name to (index, num type parameters, parameters type, input type). + Example for `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: `"list_nth" → (0, 1, α, (List α × Int))` + Remark: the continuation has an indexed type; we use the index (a finite number of + type `Fin`) to control which function we call at the recursive call site. -/ + let nameToInfo : HashMap Name (Nat × Nat × Expr × Expr) := + let bl := preDefs.mapIdx fun i d => + let (num_params, params_ty, in_ty, _) := paramInOutTys.get! i.val + (d.declName, (i.val, num_params, params_ty, in_ty)) HashMap.ofList bl.toList trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" @@ -262,25 +402,29 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Auxiliary function to explore the function bodies and replace the -- recursive calls let visit_e (i : Nat) (e : Expr) : MetaM Expr := do - trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}" + trace[Diverge.def.genBody.visit] "visiting expression (dept: {i}): {e}" let ne ← do match e with | .app .. => do e.withApp fun f args => do - trace[Diverge.def.genBody] "this is an app: {f} {args}" + trace[Diverge.def.genBody.visit] "this is an app: {f} {args}" -- Check if this is a recursive call if f.isConst then let name := f.constName! match nameToInfo.find? name with | none => pure e - | some (id, in_ty) => - trace[Diverge.def.genBody] "this is a recursive call" + | some (id, num_params, params_ty, _in_ty) => + trace[Diverge.def.genBody.visit] "this is a recursive call" -- This is a recursive call: replace it -- Compute the index let i ← mkFinVal grSize id - -- Put the arguments in one big dependent tuple - let args ← mkSigmasVal in_ty args.toList - mkAppM' kk_var #[i, args] + -- Split the arguments, and put them in two tuples (the first + -- one is a dependent tuple) + let (param_args, args) := args.toList.splitAt num_params + trace[Diverge.def.genBody.visit] "param_args: {param_args}, args: {args}" + let param_args ← mkSigmasVal params_ty param_args + let args ← mkProdsVal args + mkAppM' kk_var #[i, param_args, args] else -- Not a recursive call: do nothing pure e @@ -290,7 +434,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) throwError "mkUnaryBodies: a recursive call was not eliminated" else pure e | _ => pure e - trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}" + trace[Diverge.def.genBody.visit] "done with expression (depth: {i}): {e}" pure ne -- Explore the bodies @@ -300,13 +444,20 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let body ← mapVisit visit_e preDef.value trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" - -- Currify the function by grouping the arguments into a dependent tuple + -- Currify the function by grouping the arguments into dependent tuples -- (over which we match to retrieve the individual arguments). lambdaTelescope body fun args body => do - let body ← mkSigmasMatch args.toList body 0 + -- Split the arguments between the type parameters and the "regular" inputs + let (_, num_params, _, _) := nameToInfo.find! preDef.declName + let (param_args, args) := args.toList.splitAt num_params + let body ← mkProdsMatchOrUnit args body + trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}" + let body ← mkSigmasMatchOrUnit param_args body + trace[Diverge.def.genBody] "Body after mkSigmasMatchOrUnit: {body}" -- Add the declaration let value ← mkLambdaFVars #[kk_var] body + trace[Diverge.def.genBody] "Body after abstracting kk: {value}" let name := preDef.declName.append "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { @@ -318,6 +469,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) safety := .safe all := [name] } + trace[Diverge.def.genBody] "About to add decl" addDecl decl trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant @@ -326,33 +478,35 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) trace[Diverge.def] "individual body (after decl): {body}" pure body --- Generate a unique function body from the bodies of the mutually recursive group, --- and add it as a declaration in the context. --- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. +/- Generate a unique function body from the bodies of the mutually recursive group, + and add it as a declaration in the context. + We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. + -/ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) - (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) + (param_ty in_ty out_ty : Expr) (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize -- TODO: not very clean - let inOutTyType ← do - let (x, y) := inOutTys.get! 0 - inferType (← mkInOutTy x y) - let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := - match inOutTys, bl with + let paramInOutTyType ← do + let (_, x, y, z) := paramInOutTys.get! 0 + inferType (← mkInOutTy x y z) + let rec mkFuns (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bl : List Expr) : MetaM Expr := + match paramInOutTys, bl with | [], [] => - mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] - | (ity, oty) :: inOutTys, b :: bl => do + mkAppOptM ``FixII.Funs.Nil #[finTypeExpr, param_ty, in_ty, out_ty] + | (_, pty, ity, oty) :: paramInOutTys, b :: bl => do -- Retrieving ity and oty - this is not very clean - let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) - let fl ← mkFuns inOutTys bl - mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] + let paramInOutTysExpr ← mkListLit paramInOutTyType + (← paramInOutTys.mapM (λ (_, x, y, z) => mkInOutTy x y z)) + let fl ← mkFuns paramInOutTys bl + mkAppOptM ``FixII.Funs.Cons #[finTypeExpr, param_ty, in_ty, out_ty, pty, ity, oty, paramInOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" - let bodyFuns ← mkFuns inOutTys bodies.toList + let bodyFuns ← mkFuns paramInOutTys bodies.toList -- Wrap in `get_fun` - let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var] + let body ← mkAppM ``FixII.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables let body ← mkLambdaFVars #[kk_var, i_var] body trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" @@ -391,11 +545,11 @@ instance : ToMessageData MatchInfo where -- This is not a very clean formatting, but we don't need more toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" --- Small helper: prove that an expression which doesn't use the continuation `kk` --- is valid, and return the proof. +/- Small helper: prove that an expression which doesn't use the continuation `kk` + is valid, and return the proof. -/ def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" - let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] + let eIsValid ← mkAppM ``FixII.is_valid_p_same #[k_var, e] trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid @@ -462,8 +616,11 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do mkLambdaFVars #[x] brValid let br0Valid ← proveBranchValid br0 let br1Valid ← proveBranchValid br1 - let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite - let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] + let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite + let eIsValid ← + mkAppOptM const #[none, none, none, none, none, + some k_var, some cond, some dec, none, none, + some br0Valid, some br1Valid] trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid -- Check if the expression is a match (this case is for when the elaborator @@ -498,15 +655,16 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- use to currify function bodies, introduce such raw matches. else if ← isCasesExpr f then do trace[Diverge.def.valid] "rawMatch: {e}" - -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. - -- - -- The casesOn definition is always of the following shape: - -- - input parameters (implicit parameters) - -- - motive (implicit), -- the motive gives the return type of the match - -- - scrutinee (explicit) - -- - branches (explicit). - -- In particular, we notice that the scrutinee is the first *explicit* - -- parameter - this is how we spot it. + /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + + The casesOn definition is always of the following shape: + - input parameters (implicit parameters) + - motive (implicit), -- the motive gives the return type of the match + - scrutinee (explicit) + - branches (explicit). + In particular, we notice that the scrutinee is the first *explicit* + parameter - this is how we spot it. + -/ let matcherName := f.constName! let matcherLevels := f.constLevels!.toArray -- Find the first explicit parameter: this is the scrutinee @@ -526,10 +684,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let motive := args.get! (scrutIdx - 1) let scrut := args.get! scrutIdx let branches := args.extract (scrutIdx + 1) args.size - -- Compute the number of parameters for the branches: for this we use - -- the type of the uninstantiated casesOn constant (we can't just - -- destruct the lambdas in the branch expressions because the result - -- of a match might be a lambda expression). + /- Compute the number of parameters for the branches: for this we use + the type of the uninstantiated casesOn constant (we can't just + destruct the lambdas in the branch expressions because the result + of a match might be a lambda expression). -/ let branchesNumParams : Array Nat ← do let env ← getEnv let decl := env.constants.find! matcherName @@ -572,14 +730,15 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do pure yValid -- Put everything together trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" - mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] + mkAppM ``FixII.is_valid_p_bind #[xValid, yValid] -- Check if this is a recursive call, i.e., a call to the continuation `kk` else if f.isFVarOf kk_var.fvarId! then do trace[Diverge.def.valid] "rec: args: \n{args}" - if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" + if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}" let i_arg := args.get! 0 - let x_arg := args.get! 1 - let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg] + let t_arg := args.get! 1 + let x_arg := args.get! 2 + let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg] trace[Diverge.def.valid] "rec: result: \n{eIsValid}" pure eIsValid else do @@ -593,10 +752,10 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp trace[Diverge.def.valid] "proveMatchIsValid: {me}" -- Prove the validity of the branch expressions let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do - -- Go inside the lambdas - note that we have to be careful: some of the - -- binders might come from the match, and some of the binders might come - -- from the fact that the expression in the match is a lambda expression: - -- we use the branchesNumParams field for this reason + /- Go inside the lambdas - note that we have to be careful: some of the + binders might come from the match, and some of the binders might come + from the fact that the expression in the match is a lambda expression: + we use the branchesNumParams field for this reason. -/ let numParams := me.branchesNumParams.get! idx lambdaTelescopeN br numParams fun xs br => do -- Prove that the branch expression is valid @@ -604,13 +763,13 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- Reconstruct the lambda expression mkLambdaFVars xs brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" - -- Compute the motive, which has the following shape: - -- ``` - -- λ scrut => is_valid_p k (λ k => match scrut with ...) - -- ^^^^^^^^^^^^^^^^^^^^ - -- this is the original match expression, with the - -- the difference that the scrutinee(s) is a variable - -- ``` + /- Compute the motive, which has the following shape: + ``` + λ scrut => is_valid_p k (λ k => match scrut with ...) + ^^^^^^^^^^^^^^^^^^^^ + this is the original match expression, with the + the difference that the scrutinee(s) is a variable + ``` -/ let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees @@ -629,7 +788,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp let matchE ← mkAppOptM me.matcherName args -- Wrap in the `is_valid_p` predicate let matchE ← mkLambdaFVars #[kk_var] matchE - let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] + let validMotive ← mkAppM ``FixII.is_valid_p #[k_var, matchE] -- Abstract away the scrutinee variables mkLambdaFVars scrutVars validMotive trace[Diverge.def.valid] "valid motive: {validMotive}" @@ -647,10 +806,10 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp end --- Prove that a single body (in the mutually recursive group) is valid. --- --- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], --- we prove that `is_even.body` and `is_odd.body` are valid. +/- Prove that a single body (in the mutually recursive group) is valid. + + For instance, if we define the mutually recursive group [`is_even`, `is_odd`], + we prove that `is_even.body` and `is_odd.body` are valid. -/ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : MetaM Expr := do @@ -662,24 +821,28 @@ partial def proveSingleBodyIsValid let body := (env.constants.find! name).value! trace[Diverge.def.valid] "body: {body}" lambdaTelescope body fun xs body => do - assert! xs.size = 2 + trace[Diverge.def.valid] "xs: {xs}" + assert! xs.size = 3 let kk_var := xs.get! 0 - let x_var := xs.get! 1 + let t_var := xs.get! 1 + let x_var := xs.get! 2 -- State the type of the theorem to prove - let thmTy ← mkAppM ``FixI.is_valid_p - #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])] + trace[Diverge.def.valid] "bodyConst: {bodyConst} : {← inferType bodyConst}" + let bodyApp ← mkAppOptM' bodyConst #[.some kk_var, .some t_var, .some x_var] + trace[Diverge.def.valid] "bodyApp: {bodyApp}" + let bodyApp ← mkLambdaFVars #[kk_var] bodyApp + trace[Diverge.def.valid] "bodyApp: {bodyApp}" + let thmTy ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp] trace[Diverge.def.valid] "thmTy: {thmTy}" -- Prove that the body is valid let proof ← proveExprIsValid k_var kk_var body - let proof ← mkLambdaFVars #[k_var, x_var] proof + let proof ← mkLambdaFVars #[k_var, t_var, x_var] proof trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" -- The target type (we don't have to do this: this is simply a sanity check, -- and this allows a nicer debugging output) let thmTy ← do - let body ← mkAppM' bodyConst #[kk_var, x_var] - let body ← mkLambdaFVars #[kk_var] body - let ty ← mkAppM ``FixI.is_valid_p #[k_var, body] - mkForallFVars #[k_var, x_var] ty + let ty ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp] + mkForallFVars #[k_var, t_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem let name := preDef.declName ++ "body_is_valid" @@ -695,18 +858,18 @@ partial def proveSingleBodyIsValid -- Return the theorem pure (Expr.const name (preDef.levelParams.map .param)) --- Prove that the list of bodies are valid. --- --- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], --- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is --- valid. -partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) +/- Prove that the list of bodies are valid. + + For instance, if we define the mutually recursive group [`is_even`, `is_odd`], + we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is + valid. -/ +partial def proveFunsBodyIsValid (paramInOutTys: Expr) (bodyFuns : Expr) (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do -- Create the big "and" expression, which groups the validity proof of the individual bodies let rec mkValidConj (i : Nat) : MetaM Expr := do if i = bodiesValid.size then -- We reached the end - mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var] + mkAppM ``FixII.Funs.is_valid_p_Nil #[k_var] else do -- We haven't reached the end: introduce a conjunction let valid := bodiesValid.get! i @@ -714,20 +877,20 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)] let andExpr ← mkValidConj 0 -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation - let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] + let isValid ← mkAppM ``FixII.Funs.is_valid_p_is_valid_p #[paramInOutTys, k_var, bodyFuns, andExpr] mkLambdaFVars #[k_var] isValid --- Prove that the mut rec body (i.e., the unary body which groups the bodies --- of all the functions in the mutually recursive group and on which we will --- apply the fixed-point operator) is valid. --- --- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", --- which we return. --- --- TODO: maybe this function should introduce k_var itself +/- Prove that the mut rec body (i.e., the unary body which groups the bodies + of all the functions in the mutually recursive group and on which we will + apply the fixed-point operator) is valid. + + We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", + which we return. + + TODO: maybe this function should introduce k_var itself -/ def proveMutRecIsValid (grName : Name) (grLvlParams : List Name) - (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) + (paramInOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) (k_var : Expr) (preDefs : Array PreDefinition) (bodies : Array Expr) : MetaM Expr := do -- First prove that the individual bodies are valid @@ -738,9 +901,9 @@ def proveMutRecIsValid proveSingleBodyIsValid k_var preDef body -- Then prove that the mut rec body is valid trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" - let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid + let isValid ← proveFunsBodyIsValid paramInOutTys bodyFuns k_var bodiesValid -- Save the theorem - let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] + let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst] let name := grName ++ "mut_rec_body_is_valid" let decl := Declaration.thmDecl { name @@ -754,26 +917,29 @@ def proveMutRecIsValid -- Return the theorem pure (Expr.const name (grLvlParams.map .param)) --- Generate the final definions by using the mutual body and the fixed point operator. --- --- For instance: --- ``` --- def is_even (i : Int) : Result Bool := mut_rec_body 0 i --- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i --- ``` -def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : +/- Generate the final definions by using the mutual body and the fixed point operator. + + For instance: + ``` + def is_even (i : Int) : Result Bool := mut_rec_body 0 i + def is_odd (i : Int) : Result Bool := mut_rec_body 1 i + ``` + -/ +def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do - -- Retrieve the input type - let in_ty := (inOutTys.get! idx.val).fst + -- Retrieve the parameters info + let (num_params, param_ty, _, _) := paramInOutTys.get! idx.val -- Create the index let idx ← mkFinVal grSize idx.val - -- Group the inputs into a dependent tuple - let input ← mkSigmasVal in_ty xs.toList + -- Group the inputs into two tuples + let (params_args, input_args) := xs.toList.splitAt num_params + let params ← mkSigmasVal param_ty params_args + let input ← mkProdsVal input_args -- Apply the fixed point - let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] + let fixedBody ← mkAppM ``FixII.fix #[mutRecBody, idx, params, input] let fixedBody ← mkLambdaFVars xs fixedBody -- Create the declaration let name := preDef.declName @@ -791,7 +957,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preD pure defs -- Prove the equations that we will use as unfolding theorems -partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr)) +partial def proveUnfoldingThms (isValidThm : Expr) + (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do @@ -811,14 +978,18 @@ partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Ex trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}" -- The proof -- Use the fixed-point equation - let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm] + let proof ← mkAppM ``FixII.is_valid_fix_fixed_eq #[isValidThm] -- Add the index let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] - -- Add the input argument - let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList - let proof ← mkAppM ``congr_fun #[proof, arg] - -- Abstract the arguments away + -- Add the input arguments + let (num_params, param_ty, _, _) := paramInOutTys.get! i + let (params, args) := xs.toList.splitAt num_params + let params ← mkSigmasVal param_ty params + let args ← mkProdsVal args + let proof ← mkAppM ``congr_fun #[proof, params] + let proof ← mkAppM ``congr_fun #[proof, args] + -- Abstract all the arguments away let proof ← mkLambdaFVars xs proof trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" -- Declare the theorem @@ -846,7 +1017,9 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) - -- TODO: what is this? + -- Apply all the "attribute" functions (for instance, the function which + -- registers the theorem in the simp database if there is the `simp` attribute, + -- etc.) for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation @@ -860,40 +1033,52 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grLvlParams := def0.levelParams trace[Diverge.def] "def0 universe levels: {def0.levelParams}" - -- We first compute the list of pairs: (input type × output type) - let inOutTys : Array (Expr × Expr) ← - preDefs.mapM (fun preDef => do - withRef preDef.ref do -- is the withRef useful? - -- Check the universe parameters - TODO: I'm not sure what the best thing - -- to do is. In practice, all the type parameters should be in Type 0, so - -- we shouldn't have universe issues. - if preDef.levelParams ≠ grLvlParams then - throwError "Non-uniform polymorphism in the universes" - forallTelescope preDef.type (fun in_tys out_ty => do - let in_ty ← liftM (mkSigmasType in_tys.toList) - -- Retrieve the type in the "Result" - let out_ty ← getResultTy out_ty - let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) - pure (in_ty, out_ty) - ) - ) - trace[Diverge.def] "inOutTys: {inOutTys}" - -- Turn the list of input/output type pairs into an expresion - let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) - let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList - - -- From the list of pairs of input/output types, actually compute the - -- type of the continuation `k`. - -- We first introduce the index `i : Fin n` where `n` is the number of - -- functions in the group. + /- We first compute the tuples: (type parameters × input type × output type) + - type parameters: this is a sigma type + - input type: λ params_type => product type + - output type: λ params_type => output type + For instance, on the function: + `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: + we generate: + `(Type, λ α => List α × i, λ α => Result α)` + -/ + let paramInOutTys : Array (ℕ × Expr × Expr × Expr) ← + preDefs.mapM (fun preDef => do + -- Check the universe parameters - TODO: I'm not sure what the best thing + -- to do is. In practice, all the type parameters should be in Type 0, so + -- we shouldn't have universe issues. + if preDef.levelParams ≠ grLvlParams then + throwError "Non-uniform polymorphism in the universes" + forallTelescope preDef.type (fun in_tys out_ty => do + let (params, in_tys) ← splitInputArgs in_tys out_ty + trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}" + let num_params := params.size + let params_sigma ← mkSigmasType params.data + let in_tys ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data) + -- Retrieve the type in the "Result" + let out_ty ← getResultTy out_ty + let out_ty ← mkSigmasMatchOrUnit params.data out_ty + trace[Diverge.def] "inOutTy: {preDef.declName}: {params_sigma}, {in_tys}, {out_ty}" + pure (num_params, params_sigma, in_tys, out_ty))) + trace[Diverge.def] "paramInOutTys: {paramInOutTys}" + -- Turn the list of input types/input args/output type tuples into expressions + let paramInOutTysExpr ← paramInOutTys.mapM (λ (_, x, y, z) => do mkInOutTy x y z) + let paramInOutTysExpr ← mkListLit (← inferType (paramInOutTysExpr.get! 0)) paramInOutTysExpr.toList + trace[Diverge.def] "paramInOutTys: {paramInOutTys}" + + /- From the list of pairs of input/output types, actually compute the + type of the continuation `k`. + We first introduce the index `i : Fin n` where `n` is the number of + functions in the group. + -/ let i_var_ty := mkFin preDefs.size withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do - let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] - trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" - -- Add an auxiliary definition for `in_out_ty` - let in_out_ty ← do - let value ← mkLambdaFVars #[i_var] in_out_ty - let name := grName.append "in_out_ty" + let param_in_out_ty ← mkAppM ``List.get #[paramInOutTysExpr, i_var] + trace[Diverge.def] "param_in_out_ty := {param_in_out_ty} : {← inferType param_in_out_ty}" + -- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term) + let param_in_out_ty ← do + let value ← mkLambdaFVars #[i_var] param_in_out_ty + let name := grName.append "param_in_out_ty" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -906,19 +1091,28 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do } addDecl decl -- Return the constant - let in_out_ty := Lean.mkConst name (levelParams.map .param) - mkAppM' in_out_ty #[i_var] - trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" - let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] + let param_in_out_ty := Lean.mkConst name (levelParams.map .param) + mkAppM' param_in_out_ty #[i_var] + trace[Diverge.def] "param_in_out_ty (after decl) := {param_in_out_ty} : {← inferType param_in_out_ty}" + -- Decompose between: param_ty, in_ty, out_ty + let param_ty ← mkAppM ``Sigma.fst #[param_in_out_ty] + let in_out_ty ← mkAppM ``Sigma.snd #[param_in_out_ty] + let in_ty ← mkAppM ``Prod.fst #[in_out_ty] + let out_ty ← mkAppM ``Prod.snd #[in_out_ty] + trace[Diverge.def] "param_ty: {param_ty}" + trace[Diverge.def] "in_ty: {in_ty}" + trace[Diverge.def] "out_ty: {out_ty}" + withLocalDeclD (mkAnonymous "t" 1) param_ty fun param => do + let in_ty ← mkAppM' in_ty #[param] + let out_ty ← mkAppM' out_ty #[param] trace[Diverge.def] "in_ty: {in_ty}" - withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do - let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] trace[Diverge.def] "out_ty: {out_ty}" -- Introduce the continuation `k` - let in_ty ← mkLambdaFVars #[i_var] in_ty - let out_ty ← mkLambdaFVars #[i_var, input] out_ty - let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] + let param_ty ← mkLambdaFVars #[i_var] param_ty + let in_ty ← mkLambdaFVars #[i_var, param] in_ty + let out_ty ← mkLambdaFVars #[i_var, param] out_ty + let kk_var_ty ← mkAppM ``FixII.kk_ty #[i_var_ty, param_ty, in_ty, out_ty] trace[Diverge.def] "kk_var_ty: {kk_var_ty}" withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do trace[Diverge.def] "kk_var: {kk_var}" @@ -926,29 +1120,30 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations trace[Diverge.def] "# Generating the unary bodies" - let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var paramInOutTys preDefs trace[Diverge.def] "Unary bodies (after decl): {bodies}" + -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" - let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys.toList bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point - let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] + let k_var_ty ← mkAppM ``FixII.k_ty #[i_var_ty, param_ty, in_ty, out_ty] withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do trace[Diverge.def] "# Proving that the mut rec body is valid" - let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies + let isValidThm ← proveMutRecIsValid grName grLvlParams paramInOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions trace[Diverge.def] "# Generating the final definitions" - let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs + let decls ← mkDeclareFixDefs mutRecBody paramInOutTys preDefs -- Prove the unfolding theorems trace[Diverge.def] "# Proving the unfolding theorems" - proveUnfoldingThms isValidThm inOutTys preDefs decls + proveUnfoldingThms isValidThm paramInOutTys preDefs decls - -- Generating code -- TODO + -- Generating code addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 1a0676e2..b818d5af 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -11,7 +11,9 @@ open Utils initialize registerTraceClass `Diverge.elab initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas +initialize registerTraceClass `Diverge.def.prods initialize registerTraceClass `Diverge.def.genBody +initialize registerTraceClass `Diverge.def.genBody.visit initialize registerTraceClass `Diverge.def.valid initialize registerTraceClass `Diverge.def.unfold -- cgit v1.2.3 From ee669c4dbf8be12a3dd7249c645fd7092ba3e8eb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 17:07:34 +0100 Subject: Cleanup a bit --- backends/lean/Base/Diverge/Elab.lean | 63 +++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 16 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 423a2514..3c23db64 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -21,6 +21,9 @@ open Utils open WF in +def mkProdType (x y : Expr) : MetaM Expr := + mkAppM ``Prod #[x, y] + def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] @@ -47,6 +50,17 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do else pure (args.get! 0, args.get! 1) +/- Make a sigma type. + + `x` should be a variable, and `ty` and type which (might) uses `x` + -/ +def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do + trace[Diverge.def.sigmas] "mkSigmaType: {x} {sty}" + let alpha ← inferType x + let beta ← mkLambdaFVars #[x] sty + trace[Diverge.def.sigmas] "mkSigmaType: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + /- Generate a Sigma type from a list of *variables* (all the expressions must be variables). @@ -64,16 +78,12 @@ def mkSigmasType (xl : List Expr) : MetaM Expr := pure (Expr.const ``PUnit [Level.succ .zero]) | [x] => do trace[Diverge.def.sigmas] "mkSigmasType: [{x}]" - let ty ← Lean.Meta.inferType x + let ty ← inferType x pure ty | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]" - let alpha ← Lean.Meta.inferType x let sty ← mkSigmasType xl - trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}" - let beta ← mkLambdaFVars #[x] sty - trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})" - mkAppOptM ``Sigma #[some alpha, some beta] + mkSigmaType x sty /- Generate a product type from a list of *variables* (this is similar to `mkSigmas`). @@ -90,11 +100,11 @@ def mkProdsType (xl : List Expr) : MetaM Expr := pure (Expr.const ``PUnit [Level.succ .zero]) | [x] => do trace[Diverge.def.prods] "mkProdsType: [{x}]" - let ty ← Lean.Meta.inferType x + let ty ← inferType x pure ty | x :: xl => do trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]" - let ty ← Lean.Meta.inferType x + let ty ← inferType x let xl_ty ← mkProdsType xl mkAppM ``Prod #[ty, xl_ty] @@ -114,7 +124,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) := match in_tys with | [] => do - let fvars ← getFVarIds (← Lean.Meta.inferType out_ty) + let fvars ← getFVarIds (← inferType out_ty) pure (fvars, [], []) | ty :: in_tys => do let (fvars, in_tys, in_args) ← splitAux in_tys @@ -132,7 +142,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × pure (fvars, [ty], in_args) else -- We must split later: update the fvars set - let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty)) + let fvars := fvars.insertMany (← getFVarIds (← inferType ty)) pure (fvars, [], ty :: in_args) let (_, in_tys, in_args) ← splitAux in_tys.data pure (Array.mk in_tys, Array.mk in_args) @@ -250,13 +260,13 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met | Sigma.mk x ... -- the hole is given by a recursive call on the tail ``` -/ trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst + let alpha ← inferType fst let snd_ty ← mkSigmasType xl let beta ← mkLambdaFVars #[fst] snd_ty let snd ← mkSigmasMatch xl out (index + 1) let mk ← mkLambdaFVars #[fst] snd -- Introduce the "scrut" variable - let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty + let scrut_ty ← mkSigmaType fst snd_ty withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient @@ -294,12 +304,12 @@ partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Meta mkLambdaFVars #[x] out | fst :: xl => do trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst + let alpha ← inferType fst let beta ← mkProdsType xl let snd ← mkProdsMatch xl out (index + 1) let mk ← mkLambdaFVars #[fst] snd -- Introduce the "scrut" variable - let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta + let scrut_ty ← mkProdType alpha beta withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient @@ -1265,8 +1275,29 @@ elab_rules : command namespace Tests /- Some examples of partial functions -/ - - divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + /- section HigherOrder + open FixI + + -- The index type + variable {id : Type u} + + -- The input/output types + variable {a : id → Type v} {b : (i:id) → a i → Type w} + + -- Example with a higher-order function + theorem map_is_valid + {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} + (Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x)) + (k : (a → Result b) → a → Result b) + (ls : List a) : + is_valid_p k (λ k => Ex5.map (f k) ls) := + induction ls <;> simp [map] + apply is_valid_p_bind <;> try simp_all + intros + apply is_valid_p_bind <;> try simp_all + end HigherOrder -/ + + divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with | [] => .fail .panic | x :: ls => -- cgit v1.2.3 From 78367ef21c147b26040e0f6062a907fceab1f390 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 18:34:10 +0100 Subject: Start working on higher-order examples for Diverge --- backends/lean/Base/Diverge/Base.lean | 3 + backends/lean/Base/Diverge/Elab.lean | 501 ++++++++++++++++++------------- backends/lean/Base/Diverge/ElabBase.lean | 63 +++- backends/lean/Base/Progress/Base.lean | 6 +- 4 files changed, 355 insertions(+), 218 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index a7107c1e..bdc3ed04 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -5,6 +5,7 @@ import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith import Base.Primitives.Base import Base.Arith.Base +import Base.Diverge.ElabBase /- TODO: this is very useful, but is there more? -/ set_option profiler true @@ -1467,6 +1468,7 @@ namespace Ex8 .ret (hd :: tl) /- The validity theorem for `map`, generic in `f` -/ + @[divspec] theorem map_is_valid (i : id) (t : ty i) {{f : (a i t → Result (b i t)) → (a i t) → Result c}} @@ -1479,6 +1481,7 @@ namespace Ex8 intros apply is_valid_p_bind <;> try simp_all + @[divspec] theorem map_is_valid' (i : id) (t : ty i) (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 3c23db64..97364d14 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -21,6 +21,10 @@ open Utils open WF in +-- TODO: use those +def UnitType := Expr.const ``PUnit [Level.succ .zero] +def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero] + def mkProdType (x y : Expr) : MetaM Expr := mkAppM ``Prod #[x, y] @@ -382,29 +386,71 @@ def mkFinVal (n i : Nat) : MetaM Expr := do let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] mkAppOptM ``OfNat.ofNat #[none, none, ofNat] +/- Information about the type of a function in a declaration group. + + In the comments about the fields, we take as example the + `list_nth (α : Type) (ls : List α) (i : Int) : Result α` function. + -/ +structure TypeInfo where + /- The total number of input arguments. + + For list_nth: 3 + -/ + total_num_args : ℕ + /- The number of type parameters (they should be a prefix of the input arguments). + + For `list_nth`: 1 + -/ + num_params : ℕ + /- The type of the dependent tuple grouping the type parameters. + + For `list_nth`: `Type` + -/ + params_ty : Expr + /- The type of the tuple grouping the input values. This is a function taking + as input a value of type `params_ty`. + + For `list_nth`: `λ a => List a × Int` + -/ + in_ty : Expr + /- The output type, without the `Return`. This is a function taking + as input a value of type `params_ty`. + + For `list_nth`: `λ a => a` + -/ + out_ty : Expr + +def mkInOutTyFromTypeInfo (info : TypeInfo) : MetaM Expr := do + mkInOutTy info.params_ty info.in_ty info.out_ty + +instance : Inhabited TypeInfo := + { default := { total_num_args := 0, num_params := 0, params_ty := UnitType, + in_ty := UnitType, out_ty := UnitType } } + +instance : ToMessageData TypeInfo := + ⟨ λ ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩ => + f!"\{ total_num_args: {total_num_args}, num_params: {num_params}, params_ty: {params_ty}, in_ty: {in_ty}, out_ty: {out_ty} }}" + ⟩ + /- Generate and declare as individual definitions the bodies for the individual funcions: - replace the recursive calls with calls to the continutation `k` - make those bodies take one single dependent tuple as input We name the declarations: "[original_name].body". We return the new declarations. - - Inputs: - - `paramInOutTys`: (number of type parameters, sigma type grouping the type parameters, input type, output type) -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : + (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - /- Compute the map from name to (index, num type parameters, parameters type, input type). - Example for `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: `"list_nth" → (0, 1, α, (List α × Int))` + /- Compute the map from name to (index, type info). + Remark: the continuation has an indexed type; we use the index (a finite number of type `Fin`) to control which function we call at the recursive call site. -/ - let nameToInfo : HashMap Name (Nat × Nat × Expr × Expr) := + let nameToInfo : HashMap Name (Nat × TypeInfo) := let bl := preDefs.mapIdx fun i d => - let (num_params, params_ty, in_ty, _) := paramInOutTys.get! i.val - (d.declName, (i.val, num_params, params_ty, in_ty)) + (d.declName, (i.val, paramInOutTys.get! i.val)) HashMap.ofList bl.toList trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" @@ -423,18 +469,29 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let name := f.constName! match nameToInfo.find? name with | none => pure e - | some (id, num_params, params_ty, _in_ty) => + | some (id, type_info) => trace[Diverge.def.genBody.visit] "this is a recursive call" -- This is a recursive call: replace it -- Compute the index let i ← mkFinVal grSize id + -- It can happen that there are no input values given to the + -- recursive calls, and only type parameters. + let num_args := args.size + if num_args ≠ type_info.total_num_args ∧ num_args ≠ type_info.num_params then + throwError "Invalid number of arguments for the recursive call: {e}" -- Split the arguments, and put them in two tuples (the first -- one is a dependent tuple) - let (param_args, args) := args.toList.splitAt num_params + let (param_args, args) := args.toList.splitAt type_info.num_params trace[Diverge.def.genBody.visit] "param_args: {param_args}, args: {args}" - let param_args ← mkSigmasVal params_ty param_args - let args ← mkProdsVal args - mkAppM' kk_var #[i, param_args, args] + let param_args ← mkSigmasVal type_info.params_ty param_args + -- Check if there are input values + if num_args = type_info.total_num_args then do + trace[Diverge.def.genBody.visit] "Recursive call with input values" + let args ← mkProdsVal args + mkAppM' kk_var #[i, param_args, args] + else do + trace[Diverge.def.genBody.visit] "Recursive call without input values" + mkAppM' kk_var #[i, param_args] else -- Not a recursive call: do nothing pure e @@ -458,8 +515,8 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- (over which we match to retrieve the individual arguments). lambdaTelescope body fun args body => do -- Split the arguments between the type parameters and the "regular" inputs - let (_, num_params, _, _) := nameToInfo.find! preDef.declName - let (param_args, args) := args.toList.splitAt num_params + let (_, type_info) := nameToInfo.find! preDef.declName + let (param_args, args) := args.toList.splitAt type_info.num_params let body ← mkProdsMatchOrUnit args body trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}" let body ← mkSigmasMatchOrUnit param_args body @@ -494,27 +551,30 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -/ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) - (param_ty in_ty out_ty : Expr) (paramInOutTys : List (Nat × Expr × Expr × Expr)) + (param_ty in_ty out_ty : Expr) (paramInOutTys : Array TypeInfo) (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize -- TODO: not very clean let paramInOutTyType ← do - let (_, x, y, z) := paramInOutTys.get! 0 - inferType (← mkInOutTy x y z) - let rec mkFuns (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bl : List Expr) : MetaM Expr := + let info := paramInOutTys.get! 0 + inferType (← mkInOutTyFromTypeInfo info) + let rec mkFuns (paramInOutTys : List TypeInfo) (bl : List Expr) : MetaM Expr := match paramInOutTys, bl with | [], [] => mkAppOptM ``FixII.Funs.Nil #[finTypeExpr, param_ty, in_ty, out_ty] - | (_, pty, ity, oty) :: paramInOutTys, b :: bl => do + | info :: paramInOutTys, b :: bl => do + let pty := info.params_ty + let ity := info.in_ty + let oty := info.out_ty -- Retrieving ity and oty - this is not very clean let paramInOutTysExpr ← mkListLit paramInOutTyType - (← paramInOutTys.mapM (λ (_, x, y, z) => mkInOutTy x y z)) + (← paramInOutTys.mapM mkInOutTyFromTypeInfo) let fl ← mkFuns paramInOutTys bl mkAppOptM ``FixII.Funs.Cons #[finTypeExpr, param_ty, in_ty, out_ty, pty, ity, oty, paramInOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" - let bodyFuns ← mkFuns paramInOutTys bodies.toList + let bodyFuns ← mkFuns paramInOutTys.toList bodies.toList -- Wrap in `get_fun` let body ← mkAppM ``FixII.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables @@ -574,7 +634,7 @@ mutual ``` -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do - trace[Diverge.def.valid] "proveValid: {e}" + trace[Diverge.def.valid] "proveExprIsValid: {e}" match e with | .const _ _ => throwError "Unimplemented" -- Shouldn't get there? | .bvar _ @@ -602,160 +662,173 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do proveNoKExprIsValid k_var e | .app .. => e.withApp fun f args => do - -- There are several cases: first, check if this is a match/if - -- Check if the expression is a (dependent) if then else. - -- We treat the if then else expressions differently from the other matches, - -- and have dedicated theorems for them. - let isIte := e.isIte - if isIte || e.isDIte then do - e.withApp fun f args => do - trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" - if args.size ≠ 5 then - throwError "Wrong number of parameters for {f}: {args}" - let cond := args.get! 1 - let dec := args.get! 2 - -- Prove that the branches are valid - let br0 := args.get! 3 - let br1 := args.get! 4 - let proveBranchValid (br : Expr) : MetaM Expr := - if isIte then proveExprIsValid k_var kk_var br - else do - -- There is a lambda - lambdaOne br fun x br => do - let brValid ← proveExprIsValid k_var kk_var br - mkLambdaFVars #[x] brValid - let br0Valid ← proveBranchValid br0 - let br1Valid ← proveBranchValid br1 - let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite - let eIsValid ← - mkAppOptM const #[none, none, none, none, none, - some k_var, some cond, some dec, none, none, - some br0Valid, some br1Valid] - trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" - pure eIsValid - -- Check if the expression is a match (this case is for when the elaborator - -- introduces auxiliary definitions to hide the match behind syntactic - -- sugar): - else if let some me := ← matchMatcherApp? e then do - trace[Diverge.def.valid] - "matcherApp: - - params: {me.params} - - motive: {me.motive} - - discrs: {me.discrs} - - altNumParams: {me.altNumParams} - - alts: {me.alts} - - remaining: {me.remaining}" - -- matchMatcherApp does all the work for us: we simply need to gather - -- the information and call the auxiliary helper `proveMatchIsValid` - if me.remaining.size ≠ 0 then - throwError "MatcherApp: non empty remaining array: {me.remaining}" - let me : MatchInfo := { - matcherName := me.matcherName - matcherLevels := me.matcherLevels - params := me.params - motive := me.motive - scruts := me.discrs - branchesNumParams := me.altNumParams - branches := me.alts - } - proveMatchIsValid k_var kk_var me - -- Check if the expression is a raw match (this case is for when the expression - -- is a direct call to the primitive `casesOn` function, without syntactic sugar). - -- We have to check this case because functions like `mkSigmasMatch`, which we - -- use to currify function bodies, introduce such raw matches. - else if ← isCasesExpr f then do - trace[Diverge.def.valid] "rawMatch: {e}" - /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. - - The casesOn definition is always of the following shape: - - input parameters (implicit parameters) - - motive (implicit), -- the motive gives the return type of the match - - scrutinee (explicit) - - branches (explicit). - In particular, we notice that the scrutinee is the first *explicit* - parameter - this is how we spot it. - -/ - let matcherName := f.constName! - let matcherLevels := f.constLevels!.toArray - -- Find the first explicit parameter: this is the scrutinee - forallTelescope (← inferType f) fun xs _ => do - let rec findFirstExplicit (i : Nat) : MetaM Nat := do - if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" - else - let x := xs.get! i - let xFVarId := x.fvarId! - let localDecl ← xFVarId.getDecl - match localDecl.binderInfo with - | .default => pure i - | _ => findFirstExplicit (i + 1) - let scrutIdx ← findFirstExplicit 0 - -- Split the arguments - let params := args.extract 0 (scrutIdx - 1) - let motive := args.get! (scrutIdx - 1) - let scrut := args.get! scrutIdx - let branches := args.extract (scrutIdx + 1) args.size - /- Compute the number of parameters for the branches: for this we use - the type of the uninstantiated casesOn constant (we can't just - destruct the lambdas in the branch expressions because the result - of a match might be a lambda expression). -/ - let branchesNumParams : Array Nat ← do - let env ← getEnv - let decl := env.constants.find! matcherName - let ty := decl.type - forallTelescope ty fun xs _ => do - let xs := xs.extract (scrutIdx + 1) xs.size - xs.mapM fun x => do - let xty ← inferType x - forallTelescope xty fun ys _ => do - pure ys.size - let me : MatchInfo := { - matcherName, - matcherLevels, - params, - motive, - scruts := #[scrut], - branchesNumParams, - branches, - } - proveMatchIsValid k_var kk_var me - -- Check if this is a monadic let-binding - else if f.isConstOf ``Bind.bind then do - trace[Diverge.def.valid] "bind:\n{args}" - -- We simply need to prove that the subexpressions are valid, and call - -- the appropriate lemma. - let x := args.get! 4 - let y := args.get! 5 - -- Prove that the subexpressions are valid - let xValid ← proveExprIsValid k_var kk_var x - trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" - let yValid ← do - -- This is a lambda expression - lambdaOne y fun x y => do - trace[Diverge.def.valid] "bind: y: {y}" - let yValid ← proveExprIsValid k_var kk_var y - trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" - trace[Diverge.def.valid] "bind: yValid: x: {x}" - let yValid ← mkLambdaFVars #[x] yValid - trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" - pure yValid - -- Put everything together - trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" - mkAppM ``FixII.is_valid_p_bind #[xValid, yValid] - -- Check if this is a recursive call, i.e., a call to the continuation `kk` - else if f.isFVarOf kk_var.fvarId! then do - trace[Diverge.def.valid] "rec: args: \n{args}" - if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}" - let i_arg := args.get! 0 - let t_arg := args.get! 1 - let x_arg := args.get! 2 - let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg] - trace[Diverge.def.valid] "rec: result: \n{eIsValid}" - pure eIsValid + proveAppIsValid k_var kk_var e f args + +partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr): MetaM Expr := do + trace[Diverge.def.valid] "proveAppIsValid: {f} {args}" + /- There are several cases: first, check if this is a match/if + Check if the expression is a (dependent) if then else. + We treat the if then else expressions differently from the other matches, + and have dedicated theorems for them. -/ + let isIte := e.isIte + if isIte || e.isDIte then do + e.withApp fun f args => do + trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" + if args.size ≠ 5 then + throwError "Wrong number of parameters for {f}: {args}" + let cond := args.get! 1 + let dec := args.get! 2 + -- Prove that the branches are valid + let br0 := args.get! 3 + let br1 := args.get! 4 + let proveBranchValid (br : Expr) : MetaM Expr := + if isIte then proveExprIsValid k_var kk_var br else do - -- Remaining case: normal application. - -- It shouldn't use the continuation. - -- TODO: actually, it can - proveNoKExprIsValid k_var e + -- There is a lambda + lambdaOne br fun x br => do + let brValid ← proveExprIsValid k_var kk_var br + mkLambdaFVars #[x] brValid + let br0Valid ← proveBranchValid br0 + let br1Valid ← proveBranchValid br1 + let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite + let eIsValid ← + mkAppOptM const #[none, none, none, none, none, + some k_var, some cond, some dec, none, none, + some br0Valid, some br1Valid] + trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + /- Check if the expression is a match (this case is for when the elaborator + introduces auxiliary definitions to hide the match behind syntactic + sugar): -/ + else if let some me := ← matchMatcherApp? e then do + trace[Diverge.def.valid] + "matcherApp: + - params: {me.params} + - motive: {me.motive} + - discrs: {me.discrs} + - altNumParams: {me.altNumParams} + - alts: {me.alts} + - remaining: {me.remaining}" + -- matchMatcherApp does all the work for us: we simply need to gather + -- the information and call the auxiliary helper `proveMatchIsValid` + if me.remaining.size ≠ 0 then + throwError "MatcherApp: non empty remaining array: {me.remaining}" + let me : MatchInfo := { + matcherName := me.matcherName + matcherLevels := me.matcherLevels + params := me.params + motive := me.motive + scruts := me.discrs + branchesNumParams := me.altNumParams + branches := me.alts + } + proveMatchIsValid k_var kk_var me + /- Check if the expression is a raw match (this case is for when the expression + is a direct call to the primitive `casesOn` function, without syntactic sugar). + We have to check this case because functions like `mkSigmasMatch`, which we + use to currify function bodies, introduce such raw matches. -/ + else if ← isCasesExpr f then do + trace[Diverge.def.valid] "rawMatch: {e}" + /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + + The casesOn definition is always of the following shape: + - input parameters (implicit parameters) + - motive (implicit), -- the motive gives the return type of the match + - scrutinee (explicit) + - branches (explicit). + In particular, we notice that the scrutinee is the first *explicit* + parameter - this is how we spot it. + -/ + let matcherName := f.constName! + let matcherLevels := f.constLevels!.toArray + -- Find the first explicit parameter: this is the scrutinee + forallTelescope (← inferType f) fun xs _ => do + let rec findFirstExplicit (i : Nat) : MetaM Nat := do + if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" + else + let x := xs.get! i + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + match localDecl.binderInfo with + | .default => pure i + | _ => findFirstExplicit (i + 1) + let scrutIdx ← findFirstExplicit 0 + -- Split the arguments + let params := args.extract 0 (scrutIdx - 1) + let motive := args.get! (scrutIdx - 1) + let scrut := args.get! scrutIdx + let branches := args.extract (scrutIdx + 1) args.size + /- Compute the number of parameters for the branches: for this we use + the type of the uninstantiated casesOn constant (we can't just + destruct the lambdas in the branch expressions because the result + of a match might be a lambda expression). -/ + let branchesNumParams : Array Nat ← do + let env ← getEnv + let decl := env.constants.find! matcherName + let ty := decl.type + forallTelescope ty fun xs _ => do + let xs := xs.extract (scrutIdx + 1) xs.size + xs.mapM fun x => do + let xty ← inferType x + forallTelescope xty fun ys _ => do + pure ys.size + let me : MatchInfo := { + matcherName, + matcherLevels, + params, + motive, + scruts := #[scrut], + branchesNumParams, + branches, + } + proveMatchIsValid k_var kk_var me + -- Check if this is a monadic let-binding + else if f.isConstOf ``Bind.bind then do + trace[Diverge.def.valid] "bind:\n{args}" + -- We simply need to prove that the subexpressions are valid, and call + -- the appropriate lemma. + let x := args.get! 4 + let y := args.get! 5 + -- Prove that the subexpressions are valid + let xValid ← proveExprIsValid k_var kk_var x + trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" + let yValid ← do + -- This is a lambda expression + lambdaOne y fun x y => do + trace[Diverge.def.valid] "bind: y: {y}" + let yValid ← proveExprIsValid k_var kk_var y + trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" + trace[Diverge.def.valid] "bind: yValid: x: {x}" + let yValid ← mkLambdaFVars #[x] yValid + trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" + pure yValid + -- Put everything together + trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" + mkAppM ``FixII.is_valid_p_bind #[xValid, yValid] + -- Check if this is a recursive call, i.e., a call to the continuation `kk` + else if f.isFVarOf kk_var.fvarId! then do + trace[Diverge.def.valid] "rec: args: \n{args}" + if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}" + let i_arg := args.get! 0 + let t_arg := args.get! 1 + let x_arg := args.get! 2 + let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg] + trace[Diverge.def.valid] "rec: result: \n{eIsValid}" + pure eIsValid + else do + /- Remaining case: normal application. + Check if the arguments use the continuation: + - if no: this is simple + - if yes: we have to lookup theorems in div spec database and continue -/ + trace[Diverge.def.valid] "regular app: {e}" + let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + if ¬ allArgsFVars.contains kk_var.fvarId! then do + -- Simple case + trace[Diverge.def.valid] "kk doesn't appear in the arguments" + proveNoKExprIsValid k_var e + else do + -- Lookup in the database for suitable theorems + throwError "TODO: {e}" -- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do @@ -912,6 +985,7 @@ def proveMutRecIsValid -- Then prove that the mut rec body is valid trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" let isValid ← proveFunsBodyIsValid paramInOutTys bodyFuns k_var bodiesValid + trace[Diverge.def.valid] "Generated the term: {isValid}" -- Save the theorem let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst] let name := grName ++ "mut_rec_body_is_valid" @@ -935,18 +1009,18 @@ def proveMutRecIsValid def is_odd (i : Int) : Result Bool := mut_rec_body 1 i ``` -/ -def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do -- Retrieve the parameters info - let (num_params, param_ty, _, _) := paramInOutTys.get! idx.val + let type_info := paramInOutTys.get! idx.val -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into two tuples - let (params_args, input_args) := xs.toList.splitAt num_params - let params ← mkSigmasVal param_ty params_args + let (params_args, input_args) := xs.toList.splitAt type_info.num_params + let params ← mkSigmasVal type_info.params_ty params_args let input ← mkProdsVal input_args -- Apply the fixed point let fixedBody ← mkAppM ``FixII.fix #[mutRecBody, idx, params, input] @@ -968,7 +1042,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × -- Prove the equations that we will use as unfolding theorems partial def proveUnfoldingThms (isValidThm : Expr) - (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) + (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do @@ -993,9 +1067,9 @@ partial def proveUnfoldingThms (isValidThm : Expr) let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input arguments - let (num_params, param_ty, _, _) := paramInOutTys.get! i - let (params, args) := xs.toList.splitAt num_params - let params ← mkSigmasVal param_ty params + let type_info := paramInOutTys.get! i + let (params, args) := xs.toList.splitAt type_info.num_params + let params ← mkSigmasVal type_info.params_ty params let args ← mkProdsVal args let proof ← mkAppM ``congr_fun #[proof, params] let proof ← mkAppM ``congr_fun #[proof, args] @@ -1052,7 +1126,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do we generate: `(Type, λ α => List α × i, λ α => Result α)` -/ - let paramInOutTys : Array (ℕ × Expr × Expr × Expr) ← + let paramInOutTys : Array TypeInfo ← preDefs.mapM (fun preDef => do -- Check the universe parameters - TODO: I'm not sure what the best thing -- to do is. In practice, all the type parameters should be in Type 0, so @@ -1060,19 +1134,20 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do if preDef.levelParams ≠ grLvlParams then throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do + let total_num_args := in_tys.size let (params, in_tys) ← splitInputArgs in_tys out_ty trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}" let num_params := params.size - let params_sigma ← mkSigmasType params.data - let in_tys ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data) + let params_ty ← mkSigmasType params.data + let in_ty ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data) -- Retrieve the type in the "Result" let out_ty ← getResultTy out_ty let out_ty ← mkSigmasMatchOrUnit params.data out_ty - trace[Diverge.def] "inOutTy: {preDef.declName}: {params_sigma}, {in_tys}, {out_ty}" - pure (num_params, params_sigma, in_tys, out_ty))) + trace[Diverge.def] "inOutTy: {preDef.declName}: {params_ty}, {in_tys}, {out_ty}" + pure ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩)) trace[Diverge.def] "paramInOutTys: {paramInOutTys}" -- Turn the list of input types/input args/output type tuples into expressions - let paramInOutTysExpr ← paramInOutTys.mapM (λ (_, x, y, z) => do mkInOutTy x y z) + let paramInOutTysExpr ← liftM (paramInOutTys.mapM mkInOutTyFromTypeInfo) let paramInOutTysExpr ← mkListLit (← inferType (paramInOutTysExpr.get! 0)) paramInOutTysExpr.toList trace[Diverge.def] "paramInOutTys: {paramInOutTys}" @@ -1135,7 +1210,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" - let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys.toList bodies + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by @@ -1275,27 +1350,27 @@ elab_rules : command namespace Tests /- Some examples of partial functions -/ - /- section HigherOrder - open FixI - - -- The index type - variable {id : Type u} - - -- The input/output types - variable {a : id → Type v} {b : (i:id) → a i → Type w} - - -- Example with a higher-order function - theorem map_is_valid - {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} - (Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x)) - (k : (a → Result b) → a → Result b) - (ls : List a) : - is_valid_p k (λ k => Ex5.map (f k) ls) := - induction ls <;> simp [map] - apply is_valid_p_bind <;> try simp_all - intros - apply is_valid_p_bind <;> try simp_all - end HigherOrder -/ + section HigherOrder + open Ex8 + + inductive Tree (a : Type u) := + | leaf (x : a) + | node (tl : List (Tree a)) + + set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl) + + set_option trace.Diverge.def false + + end HigherOrder divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index b818d5af..0d33e9d2 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -1,13 +1,15 @@ import Lean import Base.Utils import Base.Primitives.Base +import Base.Extensions namespace Diverge open Lean Elab Term Meta -open Utils +open Utils Extensions -- We can't define and use trace classes in the same file +initialize registerTraceClass `Diverge initialize registerTraceClass `Diverge.elab initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas @@ -20,4 +22,63 @@ initialize registerTraceClass `Diverge.def.unfold -- For the attribute (for higher-order functions) initialize registerTraceClass `Diverge.attr +-- Attribute + +-- divspec attribute +structure DivSpecAttr where + attr : AttributeImpl + ext : DiscrTreeExtension Name true + deriving Inhabited + +/- The persistent map from expressions to divspec theorems. -/ +initialize divspecAttr : DivSpecAttr ← do + let ext ← mkDiscrTreeExtention `divspecMap true + let attrImpl : AttributeImpl := { + name := `divspec + descr := "Marks theorems to use with the `divergent` encoding" + add := fun thName stx attrKind => do + Attribute.Builtin.ensureNoArgs stx + -- TODO: use the attribute kind + unless attrKind == AttributeKind.global do + throwError "invalid attribute divspec, must be global" + -- Lookup the theorem + let env ← getEnv + let thDecl := env.constants.find! thName + let fKey : Array (DiscrTree.Key true) ← MetaM.run' (do + /- The theorem should have the shape: + `∀ ..., is_valid_p k (λ k => ...)` + + Dive into the ∀: + -/ + let (_, _, fExpr) ← forallMetaTelescope thDecl.type.consumeMData + /- Dive into the argument of `is_valid_p`: -/ + fExpr.consumeMData.withApp fun _ args => do + if args.size ≠ 7 then throwError "Invalid number of arguments to is_valid_p" + let fExpr := args.get! 6 + /- Dive into the lambda: -/ + let (_, _, fExpr) ← lambdaMetaTelescope fExpr.consumeMData + trace[Diverge] "Registering divspec theorem for {fExpr}" + -- Convert the function expression to a discrimination tree key + DiscrTree.mkPath fExpr) + let env := ext.addEntry env (fKey, thName) + setEnv env + trace[Diverge] "Saved the environment" + pure () + } + registerBuiltinAttribute attrImpl + pure { attr := attrImpl, ext := ext } + +def DivSpecAttr.find? (s : DivSpecAttr) (e : Expr) : MetaM (Array Name) := do + (s.ext.getState (← getEnv)).getMatch e + +def DivSpecAttr.getState (s : DivSpecAttr) : MetaM (DiscrTree Name true) := do + pure (s.ext.getState (← getEnv)) + +def showStoredDivSpec : MetaM Unit := do + let st ← divspecAttr.getState + -- TODO: how can we iterate over (at least) the values stored in the tree? + --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" + let s := f!"{st}" + IO.println s + end Diverge diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index fd80db8c..0ad16ab6 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -3,8 +3,6 @@ import Std.Lean.HashSet import Base.Utils import Base.Primitives.Base import Base.Extensions -import Lean.Meta.DiscrTree -import Lean.Meta.Tactic.Simp namespace Progress @@ -101,7 +99,7 @@ section Methods -- Check if some existentially quantified variables let _ := do -- Collect all the free variables in the arguments - let allArgsFVars := ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty -- Check if they intersect the fvars we introduced for the existentially quantified variables let evarsSet : HashSet FVarId := HashSet.ofArray (evars.map (fun (x : Expr) => x.fvarId!)) let filtArgsFVars := allArgsFVars.toArray.filter (fun var => evarsSet.contains var) @@ -134,7 +132,7 @@ structure PSpecAttr where ext : DiscrTreeExtension Name true deriving Inhabited -/- The persistent map from function to pspec theorems. -/ +/- The persistent map from expressions to pspec theorems. -/ initialize pspecAttr : PSpecAttr ← do let ext ← mkDiscrTreeExtention `pspecMap true let attrImpl : AttributeImpl := { -- cgit v1.2.3 From 24c5289d0ca039c1c64081285d7d120a04f40699 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 19:48:42 +0100 Subject: Update the validity proofs for higher-order functions --- backends/lean/Base/Diverge/Elab.lean | 101 ++++++++++++++++++++++++++++++++--- backends/lean/Base/Utils.lean | 6 +-- 2 files changed, 96 insertions(+), 11 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 97364d14..e555b3e3 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -820,15 +820,88 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Check if the arguments use the continuation: - if no: this is simple - if yes: we have to lookup theorems in div spec database and continue -/ - trace[Diverge.def.valid] "regular app: {e}" + trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}" let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}" if ¬ allArgsFVars.contains kk_var.fvarId! then do -- Simple case trace[Diverge.def.valid] "kk doesn't appear in the arguments" proveNoKExprIsValid k_var e else do -- Lookup in the database for suitable theorems - throwError "TODO: {e}" + trace[Diverge.def.valid] "kk appears in the arguments" + let thms ← divspecAttr.find? e + trace[Diverge.def.valid] "Looked up theorems: {thms}" + -- Try the theorems one by one + proveAppIsValidApplyThms k_var kk_var e f args thms.toList + +partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) + (f : Expr) (args : Array Expr) (thms : List Name) : MetaM Expr := do + trace[Diverge.def.valid] "thms: {thms}" + match thms with + | [] => throwError "Could not prove that the following expression is valid: {e}" + | thName :: thms => + -- Lookup the theorem itself + let env ← getEnv + let thDecl := env.constants.find! thName + -- Introduce fresh meta-variables for the universes + let ul : List (Name × Level) ← + thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar)) + let ulMap : HashMap Name Level := HashMap.ofList ul + let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x) + -- Introduce meta variables for the universally quantified variables + let (mvars, _binders, thTy) ← forallMetaTelescope thTy + -- thTy should now be of the shape: `is_valid_p k (λ kk => ...)` + /-thTy.consumeMData.withApp fun _ args => do + if args.size ≠ 7 then throwError "Invalid number of arguments (expected 7): {thTy}" + let thTermToMatch := args.get! 6 -/ + let thTermToMatch := thTy + trace[Diverge.def.valid] "thTermToMatch: {thTermToMatch}" + -- Create the term: `is_valid_p k (λ kk => e)` + let termToMatch ← mkLambdaFVars #[kk_var] e + let termToMatch ← mkAppM ``FixII.is_valid_p #[k_var, termToMatch] + trace[Diverge.def.valid] "termToMatch: {termToMatch}" + -- Attempt to match + let ok ← isDefEq thTermToMatch termToMatch + if ¬ ok then + -- Failure: attempt with the other theorems + proveAppIsValidApplyThms k_var kk_var e f args thms + else do + -- Success: continue with this theorem + -- Instantiate the meta variables (some of them will not be instantiated: + -- they are new subgoals) + let mvars ← mvars.mapM instantiateMVars + let th ← mkAppOptM thName (Array.map some mvars) + trace[Diverge.def.valid] "Instantiated theorm: {th}\n{← inferType th}" + -- Filter the meta variables between the instantiated ones + for mvar in mvars do + if mvar.isMVar then do + -- Prove the subgoal (i.e., the precondition of the theorem) + let mvarId := mvar.mvarId! + let mvarDecl ← mvarId.getDecl + -- Dive in the type + forallTelescope mvarDecl.type fun forall_vars mvar_e => do + -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)` + -- We need to retrieve the new `k` variable, and dive into the + -- `λ kk => ...` + mvar_e.consumeMData.withApp fun is_valid args => do + if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 2 then + throwError "Invalid precondition: {mvar_e}" + else do + let k_var := args.get! 0 + let e_lam := args.get! 1 + lambdaTelescope e_lam.consumeMData fun lvars e => do + if lvars.size ≠ 1 then throwError "Invalid number of lambdas (expected 1): {e_lam}" + let kk_var := lvars.get! 0 + -- Continue + let e_valid ← proveExprIsValid k_var kk_var e + let e_valid ← mkForallFVars forall_vars e_valid + -- Assign the meta variable + mvarId.assign e_valid + else + -- Nothing to do + pure () + pure th -- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do @@ -852,7 +925,8 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp ^^^^^^^^^^^^^^^^^^^^ this is the original match expression, with the the difference that the scrutinee(s) is a variable - ``` -/ + ``` + -/ let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees @@ -905,7 +979,7 @@ partial def proveSingleBodyIsValid trace[Diverge.def.valid] "body: {body}" lambdaTelescope body fun xs body => do trace[Diverge.def.valid] "xs: {xs}" - assert! xs.size = 3 + if xs.size ≠ 3 then throwError "Invalid number of lambdas: {xs} (expected 3)" let kk_var := xs.get! 0 let t_var := xs.get! 1 let x_var := xs.get! 2 @@ -1349,6 +1423,7 @@ elab_rules : command Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) namespace Tests + /- Some examples of partial functions -/ section HigherOrder open Ex8 @@ -1357,9 +1432,6 @@ namespace Tests | leaf (x : a) | node (tl : List (Tree a)) - set_option trace.Diverge.def true - -- set_option trace.Diverge.def.genBody true - set_option trace.Diverge.def.valid true divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := match t with | .leaf x => .ret (.leaf x) @@ -1368,7 +1440,20 @@ namespace Tests let tl ← map id tl .ret (.node tl) - set_option trace.Diverge.def false + #check id.unfold + + /-set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def incr (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let tl ← map incr tl + .ret (.node tl) + + set_option trace.Diverge.def false-/ end HigherOrder diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 95b2c38b..2366e800 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -371,19 +371,19 @@ def splitConjTarget : TacticM Unit := do -- Destruct an equaliy and return the two sides def destEq (e : Expr) : MetaM (Expr × Expr) := do - e.withApp fun f args => + e.consumeMData.withApp fun f args => if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2) else throwError "Not an equality: {e}" -- Return the set of FVarIds in the expression partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do - e.withApp fun body args => do + e.consumeMData.withApp fun body args => do let hs := if body.isFVar then hs.insert body.fvarId! else hs args.foldlM (fun hs arg => getFVarIds arg hs) hs -- Return the set of MVarIds in the expression partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do - e.withApp fun body args => do + e.consumeMData.withApp fun body args => do let hs := if body.isMVar then hs.insert body.mvarId! else hs args.foldlM (fun hs arg => getMVarIds arg hs) hs -- cgit v1.2.3 From c23f317f55801a4b7e3f808326b0fbd82f454d76 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 12 Dec 2023 14:57:48 +0100 Subject: Implement a map-reduce visitor for expressions and fix issues with get{M,F}VarIds --- backends/lean/Base/Utils.lean | 112 ++++++++++++++++++++++++++++++------------ 1 file changed, 81 insertions(+), 31 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 2366e800..b0032281 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -159,47 +159,96 @@ elab "print_ctx_decls" : tactic => do let decls ← ctx.getDecls printDecls decls --- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) +-- A map-reduce visitor function for expressions (adapted from `AbstractNestedProofs.visit`) -- The continuation takes as parameters: -- - the current depth of the expression (useful for printing/debugging) -- - the expression to explore -partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do - let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do +partial def mapreduceVisit {a : Type} (k : Nat → a → Expr → MetaM (a × Expr)) + (state : a) (e : Expr) : MetaM (a × Expr) := do + let mapreduceVisitBinders (state : a) (xs : Array Expr) (k2 : a → MetaM (a × Expr)) : + MetaM (a × Expr) := do let localInstances ← getLocalInstances - let mut lctx ← getLCtx - for x in xs do - let xFVarId := x.fvarId! - let localDecl ← xFVarId.getDecl - let type ← mapVisit k localDecl.type - let localDecl := localDecl.setType type - let localDecl ← match localDecl.value? with - | some value => let value ← mapVisit k value; pure <| localDecl.setValue value - | none => pure localDecl - lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl - withLCtx lctx localInstances k2 + -- Update the local declarations for the bindings in context `lctx` + let rec visit_xs (lctx : LocalContext) (state : a) (xs : List Expr) : MetaM (LocalContext × a) := do + match xs with + | [] => pure (lctx, state) + | x :: xs => do + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + let (state, type) ← mapreduceVisit k state localDecl.type + let localDecl := localDecl.setType type + let (state, localDecl) ← match localDecl.value? with + | some value => + let (state, value) ← mapreduceVisit k state value + pure (state, localDecl.setValue value) + | none => pure (state, localDecl) + let lctx := lctx.modifyLocalDecl xFVarId fun _ => localDecl + -- Recursive call + visit_xs lctx state xs + let (lctx, state) ← visit_xs (← getLCtx) state xs.toList + -- Call the continuation with the updated context + withLCtx lctx localInstances (k2 state) -- TODO: use a cache? (Lean.checkCache) - let rec visit (i : Nat) (e : Expr) : MetaM Expr := do + let rec visit (i : Nat) (state : a) (e : Expr) : MetaM (a × Expr) := do -- Explore - let e ← k i e + let (state, e) ← k i state e match e with | .bvar _ | .fvar _ | .mvar _ | .sort _ | .lit _ - | .const _ _ => pure e - | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1))) + | .const _ _ => pure (state, e) + | .app .. => do e.withApp fun f args => do + let (state, args) ← args.foldlM (fun (state, args) arg => do let (state, arg) ← visit (i + 1) state arg; pure (state, arg :: args)) (state, []) + let args := args.reverse + let (state, f) ← visit (i + 1) state f + let e' := mkAppN f (Array.mk args) + return (state, e') | .lam .. => lambdaLetTelescope e fun xs b => - mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkLambdaFVars xs b (usedLetOnly := false) + return (state, e') | .forallE .. => do - forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b) + forallTelescope e fun xs b => + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkForallFVars xs b + return (state, e') | .letE .. => do - lambdaLetTelescope e fun xs b => mapVisitBinders xs do - mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) - | .mdata _ b => return e.updateMData! (← visit (i + 1) b) - | .proj _ _ b => return e.updateProj! (← visit (i + 1) b) - visit 0 e + lambdaLetTelescope e fun xs b => + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkLambdaFVars xs b (usedLetOnly := false) + return (state, e') + | .mdata _ b => do + let (state, b) ← visit (i + 1) state b + return (state, e.updateMData! b) + | .proj _ _ b => do + let (state, b) ← visit (i + 1) state b + return (state, e.updateProj! b) + visit 0 state e + +-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) +-- The continuation takes as parameters: +-- - the current depth of the expression (useful for printing/debugging) +-- - the expression to explore +partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do + let k' i (_ : Unit) e := do + let e ← k i e + pure ((), e) + let (_, e) ← mapreduceVisit k' () e + pure e + +-- A reduce visitor +partial def reduceVisit {a : Type} (k : Nat → a → Expr → MetaM a) (s : a) (e : Expr) : MetaM a := do + let k' i (s : a) e := do + let s ← k i s e + pure (s, e) + let (s, _) ← mapreduceVisit k' s e + pure s -- Generate a fresh user name for an anonymous proposition to introduce in the -- assumptions @@ -376,16 +425,17 @@ def destEq (e : Expr) : MetaM (Expr × Expr) := do else throwError "Not an equality: {e}" -- Return the set of FVarIds in the expression +-- TODO: this collects fvars introduced in the inner bindings partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do - e.consumeMData.withApp fun body args => do - let hs := if body.isFVar then hs.insert body.fvarId! else hs - args.foldlM (fun hs arg => getFVarIds arg hs) hs + reduceVisit (fun _ (hs : HashSet FVarId) e => + if e.isFVar then pure (hs.insert e.fvarId!) else pure hs) + hs e -- Return the set of MVarIds in the expression partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do - e.consumeMData.withApp fun body args => do - let hs := if body.isMVar then hs.insert body.mvarId! else hs - args.foldlM (fun hs arg => getMVarIds arg hs) hs + reduceVisit (fun _ (hs : HashSet MVarId) e => + if e.isMVar then pure (hs.insert e.mvarId!) else pure hs) + hs e -- Tactic to split on a disjunction. -- The expression `h` should be an fvar. -- cgit v1.2.3 From dba35a41e66019e586502f563ce7c629356fb2d7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 12 Dec 2023 17:28:34 +0100 Subject: Make progress on supporting higher-order divergent functions --- backends/lean/Base/Diverge/Base.lean | 22 +++--- backends/lean/Base/Diverge/Elab.lean | 126 +++++++++++++++++++++++------------ 2 files changed, 95 insertions(+), 53 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index bdc3ed04..9458c926 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -1467,26 +1467,28 @@ namespace Ex8 let tl ← map f tl .ret (hd :: tl) - /- The validity theorem for `map`, generic in `f` -/ + /- The validity theorems for `map`, generic in `f` -/ + + -- This is not the most general lemma, but we keep it to test the `divergence` encoding on a simple case @[divspec] - theorem map_is_valid + theorem map_is_valid_simple (i : id) (t : ty i) - {{f : (a i t → Result (b i t)) → (a i t) → Result c}} - (Hfvalid : ∀ k x, is_valid_p k (λ k => f (k i t) x)) (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) (ls : List (a i t)) : - is_valid_p k (λ k => map (f (k i t)) ls) := by + is_valid_p k (λ k => map (k i t) ls) := by induction ls <;> simp [map] apply is_valid_p_bind <;> try simp_all intros apply is_valid_p_bind <;> try simp_all @[divspec] - theorem map_is_valid' - (i : id) (t : ty i) + theorem map_is_valid + (d : Type y) + {{f : ((i:id) → (t : ty i) → a i t → Result (b i t)) → d → Result c}} (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) - (ls : List (a i t)) : - is_valid_p k (λ k => map (k i t) ls) := by + (Hfvalid : ∀ x1, is_valid_p k (fun kk1 => f kk1 x1)) + (ls : List d) : + is_valid_p k (λ k => map (f k) ls) := by induction ls <;> simp [map] apply is_valid_p_bind <;> try simp_all intros @@ -1532,7 +1534,7 @@ namespace Ex9 apply is_valid_p_bind <;> try simp [*] -- We have to show that `map k tl` is valid -- Remark: `map_is_valid` doesn't work here, we need the specialized version - apply map_is_valid' + apply map_is_valid_simple def body (k : (i : Fin 1) → (t : ty i) → (x : input_ty i t) → Result (output_ty i t)) (i: Fin 1) : (t : ty i) → (x : input_ty i t) → Result (output_ty i t) := get_fun bodies i k diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index e555b3e3..08d21e42 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -541,7 +541,6 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant let body := Lean.mkConst name (levelParams.map .param) - -- let body ← mkAppM' body #[kk_var] trace[Diverge.def] "individual body (after decl): {body}" pure body @@ -665,7 +664,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do proveAppIsValid k_var kk_var e f args partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr): MetaM Expr := do - trace[Diverge.def.valid] "proveAppIsValid: {f} {args}" + trace[Diverge.def.valid] "proveAppIsValid: {e}\nDecomposed: {f} {args}" /- There are several cases: first, check if this is a match/if Check if the expression is a (dependent) if then else. We treat the if then else expressions differently from the other matches, @@ -821,7 +820,8 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : - if no: this is simple - if yes: we have to lookup theorems in div spec database and continue -/ trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}" - let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + let argsFVars ← args.mapM getFVarIds + let allArgsFVars := argsFVars.foldl (fun hs fvars => hs.insertMany fvars) HashSet.empty trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}" if ¬ allArgsFVars.contains kk_var.fvarId! then do -- Simple case @@ -837,7 +837,6 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr) (thms : List Name) : MetaM Expr := do - trace[Diverge.def.valid] "thms: {thms}" match thms with | [] => throwError "Could not prove that the following expression is valid: {e}" | thName :: thms => @@ -849,58 +848,67 @@ partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar)) let ulMap : HashMap Name Level := HashMap.ofList ul let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x) + trace[Diverge.def.valid] "Trying with theorem {thName}: {thTy}" -- Introduce meta variables for the universally quantified variables - let (mvars, _binders, thTy) ← forallMetaTelescope thTy - -- thTy should now be of the shape: `is_valid_p k (λ kk => ...)` - /-thTy.consumeMData.withApp fun _ args => do - if args.size ≠ 7 then throwError "Invalid number of arguments (expected 7): {thTy}" - let thTermToMatch := args.get! 6 -/ - let thTermToMatch := thTy + let (mvars, _binders, thTyBody) ← forallMetaTelescope thTy + let thTermToMatch := thTyBody trace[Diverge.def.valid] "thTermToMatch: {thTermToMatch}" -- Create the term: `is_valid_p k (λ kk => e)` let termToMatch ← mkLambdaFVars #[kk_var] e let termToMatch ← mkAppM ``FixII.is_valid_p #[k_var, termToMatch] trace[Diverge.def.valid] "termToMatch: {termToMatch}" -- Attempt to match - let ok ← isDefEq thTermToMatch termToMatch + trace[Diverge.def.valid] "Matching terms:\n\n{termToMatch}\n\n{thTermToMatch}" + let ok ← isDefEq termToMatch thTermToMatch if ¬ ok then -- Failure: attempt with the other theorems proveAppIsValidApplyThms k_var kk_var e f args thms else do - -- Success: continue with this theorem - -- Instantiate the meta variables (some of them will not be instantiated: - -- they are new subgoals) + /- Success: continue with this theorem + + Instantiate the meta variables (some of them will not be instantiated: + they are new subgoals) + -/ let mvars ← mvars.mapM instantiateMVars let th ← mkAppOptM thName (Array.map some mvars) - trace[Diverge.def.valid] "Instantiated theorm: {th}\n{← inferType th}" - -- Filter the meta variables between the instantiated ones - for mvar in mvars do - if mvar.isMVar then do - -- Prove the subgoal (i.e., the precondition of the theorem) - let mvarId := mvar.mvarId! - let mvarDecl ← mvarId.getDecl - -- Dive in the type - forallTelescope mvarDecl.type fun forall_vars mvar_e => do - -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)` - -- We need to retrieve the new `k` variable, and dive into the - -- `λ kk => ...` - mvar_e.consumeMData.withApp fun is_valid args => do - if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 2 then - throwError "Invalid precondition: {mvar_e}" - else do - let k_var := args.get! 0 - let e_lam := args.get! 1 - lambdaTelescope e_lam.consumeMData fun lvars e => do - if lvars.size ≠ 1 then throwError "Invalid number of lambdas (expected 1): {e_lam}" - let kk_var := lvars.get! 0 - -- Continue - let e_valid ← proveExprIsValid k_var kk_var e - let e_valid ← mkForallFVars forall_vars e_valid - -- Assign the meta variable - mvarId.assign e_valid - else - -- Nothing to do - pure () + trace[Diverge.def.valid] "Instantiated theorem: {th}\n{← inferType th}" + -- Filter the instantiated meta variables + let mvars := mvars.filter (fun v => v.isMVar) + let mvars := mvars.map (fun v => v.mvarId!) + trace[Diverge.def.valid] "Remaining subgoals: {mvars}" + for mvarId in mvars do + -- Prove the subgoal (i.e., the precondition of the theorem) + let mvarDecl ← mvarId.getDecl + let declType ← instantiateMVars mvarDecl.type + -- Reduce the subgoal before diving in, if necessary + trace[Diverge.def.valid] "Subgoal: {declType}" + -- Dive in the type + forallTelescope declType fun forall_vars mvar_e => do + trace[Diverge.def.valid] "forall_vars: {forall_vars}" + -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)` + -- We need to retrieve the new `k` variable, and dive into the + -- `λ kk => ...` + mvar_e.consumeMData.withApp fun is_valid args => do + if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 7 then + throwError "Invalid precondition: {mvar_e}" + else do + let k_var := args.get! 5 + let e_lam := args.get! 6 + trace[Diverge.def.valid] "k_var: {k_var}\ne_lam: {e_lam}" + -- The outer lambda should be for the kk_var + lambdaOne e_lam.consumeMData fun kk_var e => do + -- Continue + trace[Diverge.def.valid] "kk_var: {kk_var}\ne: {e}" + -- We sometimes need to reduce the term + let e ← whnf e + trace[Diverge.def.valid] "e (after reduction): {e}" + let e_valid ← proveExprIsValid k_var kk_var e + trace[Diverge.def.valid] "e_valid (for e): {e_valid}" + let e_valid ← mkLambdaFVars forall_vars e_valid + trace[Diverge.def.valid] "e_valid (with foralls): {e_valid}" + let _ ← inferType e_valid -- Sanity check + -- Assign the meta variable + mvarId.assign e_valid pure th -- Prove that a match expression is valid. @@ -1442,6 +1450,38 @@ namespace Tests #check id.unfold + -- set_option pp.explicit true + -- set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => id1 x) tl + .ret (.node tl) + + #check id1.unfold + + /-set_option trace.Diverge.def false + + -- set_option pp.explicit true + -- set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => do let _ ← id2 x; id2 x) tl + .ret (.node tl) + + #check id2.unfold + + set_option trace.Diverge.def false -/ + /-set_option trace.Diverge.def true -- set_option trace.Diverge.def.genBody true set_option trace.Diverge.def.valid true -- cgit v1.2.3 From 698f631e7addb92eb270a75607f1f6ffd8b2414f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 12 Dec 2023 17:47:58 +0100 Subject: Fix minor issues with the divergence encoding --- backends/lean/Base/Diverge/Elab.lean | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 08d21e42..e9544f38 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -899,8 +899,11 @@ partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) lambdaOne e_lam.consumeMData fun kk_var e => do -- Continue trace[Diverge.def.valid] "kk_var: {kk_var}\ne: {e}" - -- We sometimes need to reduce the term - let e ← whnf e + -- We sometimes need to reduce the term - TODO: this is really dangerous + let e ← do + let updt_config config := + { config with transparency := .reducible, zetaNonDep := false } + withConfig updt_config (whnf e) trace[Diverge.def.valid] "e (after reduction): {e}" let e_valid ← proveExprIsValid k_var kk_var e trace[Diverge.def.valid] "e_valid (for e): {e_valid}" @@ -1450,10 +1453,6 @@ namespace Tests #check id.unfold - -- set_option pp.explicit true - -- set_option trace.Diverge.def true - -- set_option trace.Diverge.def.genBody true - set_option trace.Diverge.def.valid true divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := match t with | .leaf x => .ret (.leaf x) @@ -1464,12 +1463,6 @@ namespace Tests #check id1.unfold - /-set_option trace.Diverge.def false - - -- set_option pp.explicit true - -- set_option trace.Diverge.def true - -- set_option trace.Diverge.def.genBody true - set_option trace.Diverge.def.valid true divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := match t with | .leaf x => .ret (.leaf x) @@ -1480,8 +1473,6 @@ namespace Tests #check id2.unfold - set_option trace.Diverge.def false -/ - /-set_option trace.Diverge.def true -- set_option trace.Diverge.def.genBody true set_option trace.Diverge.def.valid true -- cgit v1.2.3 From 91f5cd49660b5f012a2faeaf00c49455c548734a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 12 Dec 2023 17:59:12 +0100 Subject: Fix a minor issue with the divergent encoding --- backends/lean/Base/Diverge/Elab.lean | 133 +++++++++++++++++++++-------------- 1 file changed, 80 insertions(+), 53 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index e9544f38..ff680c07 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -496,9 +496,24 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Not a recursive call: do nothing pure e | .const name _ => - -- Sanity check: we eliminated all the recursive calls - if (nameToInfo.find? name).isSome then - throwError "mkUnaryBodies: a recursive call was not eliminated" + /- This might refer to the one of the top-level functions if we use + it without arguments (if we give it to a higher-order + function for instance) and there are actually no type parameters. + -/ + if (nameToInfo.find? name).isSome then do + -- Checking the type information + match nameToInfo.find? name with + | none => pure e + | some (id, type_info) => + trace[Diverge.def.genBody.visit] "this is a recursive call" + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Check that there are no type parameters + if type_info.num_params ≠ 0 then throwError "mkUnaryBodies: a recursive call was not eliminated" + -- Introduce the call to the continuation + let param_args ← mkSigmasVal type_info.params_ty [] + mkAppM' kk_var #[i, param_args] else pure e | _ => pure e trace[Diverge.def.genBody.visit] "done with expression (depth: {i}): {e}" @@ -1436,57 +1451,11 @@ elab_rules : command namespace Tests /- Some examples of partial functions -/ - section HigherOrder - open Ex8 - - inductive Tree (a : Type u) := - | leaf (x : a) - | node (tl : List (Tree a)) - - divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := - match t with - | .leaf x => .ret (.leaf x) - | .node tl => - do - let tl ← map id tl - .ret (.node tl) - - #check id.unfold - - divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := - match t with - | .leaf x => .ret (.leaf x) - | .node tl => - do - let tl ← map (fun x => id1 x) tl - .ret (.node tl) - - #check id1.unfold - - divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := - match t with - | .leaf x => .ret (.leaf x) - | .node tl => - do - let tl ← map (fun x => do let _ ← id2 x; id2 x) tl - .ret (.node tl) - - #check id2.unfold - - /-set_option trace.Diverge.def true - -- set_option trace.Diverge.def.genBody true - set_option trace.Diverge.def.valid true - divergent def incr (t : Tree Nat) : Result (Tree Nat) := - match t with - | .leaf x => .ret (.leaf (x + 1)) - | .node tl => - do - let tl ← map incr tl - .ret (.node tl) - - set_option trace.Diverge.def false-/ - end HigherOrder + --set_option trace.Diverge.def true + --set_option trace.Diverge.def.genBody true + --set_option trace.Diverge.def.valid true + --set_option trace.Diverge.def.genBody.visit true divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with @@ -1495,6 +1464,8 @@ namespace Tests if i = 0 then return x else return (← list_nth ls (i - 1)) + --set_option trace.Diverge false + #check list_nth.unfold example {a: Type} (ls : List a) : @@ -1590,6 +1561,62 @@ namespace Tests #check test1.unfold + /- Tests with higher-order functions -/ + section HigherOrder + open Ex8 + + inductive Tree (a : Type u) := + | leaf (x : a) + | node (tl : List (Tree a)) + + divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl) + + #check id.unfold + + divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => id1 x) tl + .ret (.node tl) + + #check id1.unfold + + divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => do let _ ← id2 x; id2 x) tl + .ret (.node tl) + + #check id2.unfold + + --set_option trace.Diverge.def true + --set_option trace.Diverge.def.genBody true + --set_option trace.Diverge.def.valid true + --set_option trace.Diverge.def.genBody.visit true + divergent def incr (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let tl ← map incr tl + .ret (.node tl) + + --set_option trace.Diverge false + + #check incr.unfold + + end HigherOrder + end Tests end Diverge -- cgit v1.2.3 From c14e3e5ffa261e4ed6e5539b06c182b371939ccf Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 12 Dec 2023 19:48:50 +0100 Subject: Inline the let-bindings in the validity proofs --- backends/lean/Base/Diverge/Elab.lean | 47 ++++++++++++++++++++++++++++++------ 1 file changed, 40 insertions(+), 7 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index ff680c07..6115b13b 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -17,6 +17,8 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta open Utils +def normalize_let_bindings := true + /- The following was copied from the `wfRecursion` function. -/ open WF in @@ -649,6 +651,15 @@ mutual -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveExprIsValid: {e}" + -- Normalize to eliminate the lambdas - TODO: this is slightly dangerous. + let e ← do + if e.isLet ∧ normalize_let_bindings then do + let updt_config config := + { config with transparency := .reducible, zetaNonDep := false } + let e ← withConfig updt_config (whnf e) + trace[Diverge.def.valid] "e (after normalization): {e}" + pure e + else pure e match e with | .const _ _ => throwError "Unimplemented" -- Shouldn't get there? | .bvar _ @@ -659,6 +670,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .lam .. => throwError "Unimplemented" -- TODO | .forallE .. => throwError "Unreachable" -- Shouldn't get there | .letE .. => do + -- Remark: this branch is not taken if we normalize the expressions (above) -- Telescope all the let-bindings (remark: this also telescopes the lambdas) lambdaLetTelescope e fun xs body => do -- Note that we don't visit the bound values: there shouldn't be @@ -919,7 +931,7 @@ partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) let updt_config config := { config with transparency := .reducible, zetaNonDep := false } withConfig updt_config (whnf e) - trace[Diverge.def.valid] "e (after reduction): {e}" + trace[Diverge.def.valid] "e (after normalization): {e}" let e_valid ← proveExprIsValid k_var kk_var e trace[Diverge.def.valid] "e_valid (for e): {e_valid}" let e_valid ← mkLambdaFVars forall_vars e_valid @@ -1018,6 +1030,7 @@ partial def proveSingleBodyIsValid let thmTy ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp] trace[Diverge.def.valid] "thmTy: {thmTy}" -- Prove that the body is valid + trace[Diverge.def.valid] "body: {body}" let proof ← proveExprIsValid k_var kk_var body let proof ← mkLambdaFVars #[k_var, t_var, x_var] proof trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" @@ -1599,10 +1612,6 @@ namespace Tests #check id2.unfold - --set_option trace.Diverge.def true - --set_option trace.Diverge.def.genBody true - --set_option trace.Diverge.def.valid true - --set_option trace.Diverge.def.genBody.visit true divergent def incr (t : Tree Nat) : Result (Tree Nat) := match t with | .leaf x => .ret (.leaf (x + 1)) @@ -1611,9 +1620,33 @@ namespace Tests let tl ← map incr tl .ret (.node tl) - --set_option trace.Diverge false + -- We handle this by inlining the let-binding + divergent def id3 (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let f := id3 + let tl ← map f tl + .ret (.node tl) + + #check id3.unfold + + /- + -- This is not handled yet: we can only do it if we introduce "general" + -- relations for the input types and output types (result_rel should + -- be parameterized by something). + divergent def id4 (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let f ← .ret id4 + let tl ← map f tl + .ret (.node tl) - #check incr.unfold + #check id4.unfold + -/ end HigherOrder -- cgit v1.2.3