From 8719c17f1a363c0463d74b90e558b2aaa24645d6 Mon Sep 17 00:00:00 2001 From: Son HO Date: Sat, 22 Jun 2024 15:07:14 +0200 Subject: Do some cleanup in the Lean backend (#257) --- backends/lean/Base/Core.lean | 17 ++++++ backends/lean/Base/IList/IList.lean | 18 +++---- backends/lean/Base/Primitives/ArraySlice.lean | 8 +++ backends/lean/Base/Primitives/Base.lean | 22 ++++---- backends/lean/Base/Primitives/Core.lean | 13 +++++ backends/lean/Base/Primitives/Scalar.lean | 9 +++- backends/lean/Base/Primitives/Vec.lean | 4 ++ tests/lean/Hashmap/Properties.lean | 76 +++++++-------------------- 8 files changed, 89 insertions(+), 78 deletions(-) create mode 100644 backends/lean/Base/Core.lean diff --git a/backends/lean/Base/Core.lean b/backends/lean/Base/Core.lean new file mode 100644 index 00000000..89dd199b --- /dev/null +++ b/backends/lean/Base/Core.lean @@ -0,0 +1,17 @@ + +import Lean + +/- This lemma is generally useful. It often happens that (because we + make a split on a condition for instance) we have `x ≠ y` in the context + and need to simplify `y ≠ x` somewhere. -/ +@[simp] +theorem neq_imp {α : Type u} {x y : α} (h : ¬ x = y) : ¬ y = x := by intro; simp_all + +/- This is generally useful, and doing without is actually quite cumbersome. + + Note that the following theorem does not seem to be necessary (we invert `x` + and `y` in the conclusion), probably because of `neq_imp`: + `¬ x = y → ¬ y == x` + -/ +@[simp] +theorem neq_imp_nbeq [BEq α] [LawfulBEq α] (x y : α) (heq : ¬ x = y) : ¬ x == y := by simp [*] diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean index c77f075f..1b103bb3 100644 --- a/backends/lean/Base/IList/IList.lean +++ b/backends/lean/Base/IList/IList.lean @@ -3,13 +3,7 @@ import Base.Arith import Base.Utils - --- TODO: move? --- This lemma is generally useful. It often happens that (because we --- make a split on a condition for instance) we have `x ≠ y` in the context --- and need to simplify `y ≠ x` somewhere. -@[simp] -theorem neq_imp {α : Type u} {x y : α} (h : ¬ x = y) : ¬ y = x := by intro; simp_all +import Base.Core namespace List @@ -134,7 +128,7 @@ def pairwise_rel | [] => True | hd :: tl => allP tl (rel hd) ∧ pairwise_rel rel tl -section Lemmas +section variable {α : Type u} @@ -578,6 +572,12 @@ theorem pairwise_rel_cons {α : Type u} (rel : α → α → Prop) (hd: α) (tl: pairwise_rel rel (hd :: tl) ↔ allP tl (rel hd) ∧ pairwise_rel rel tl := by simp [pairwise_rel] -end Lemmas +theorem lookup_not_none_imp_len_pos [BEq α] (l : List (α × β)) (key : α) + (hLookup : l.lookup key ≠ none) : + 0 < l.len := by + induction l <;> simp_all + scalar_tac + +end end List diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean index 899871af..3cc0f9c1 100644 --- a/backends/lean/Base/Primitives/ArraySlice.lean +++ b/backends/lean/Base/Primitives/ArraySlice.lean @@ -16,6 +16,10 @@ open Result Error core.ops.range def Array (α : Type u) (n : Usize) := { l : List α // l.length = n.val } +instance [BEq α] : BEq (Array α n) := SubtypeBEq _ + +instance [BEq α] [LawfulBEq α] : LawfulBEq (Array α n) := SubtypeLawfulBEq _ + instance (a : Type u) (n : Usize) : Arith.HasIntProp (Array a n) where prop_ty := λ v => v.val.len = n.val prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *] @@ -109,6 +113,10 @@ theorem Array.index_mut_usize_spec {α : Type u} {n : Usize} [Inhabited α] (v: def Slice (α : Type u) := { l : List α // l.length ≤ Usize.max } +instance [BEq α] : BEq (Slice α) := SubtypeBEq _ + +instance [BEq α] [LawfulBEq α] : LawfulBEq (Slice α) := SubtypeLawfulBEq _ + instance (a : Type u) : Arith.HasIntProp (Slice a) where prop_ty := λ v => 0 ≤ v.val.len ∧ v.val.len ≤ Scalar.max ScalarTy.Usize prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *] diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index c9237e65..63fbd8c0 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -134,18 +134,16 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ok x } := -- MISC -- ---------- --- This acts like a swap effectively in a functional pure world. --- We return the old value of `dst`, i.e. `dst` itself. --- The new value of `dst` is `src`. -@[simp] def core.mem.replace (a : Type) (dst : a) (src : a) : a × a := (dst, src) -/- [core::option::Option::take] -/ -@[simp] def Option.take (T: Type) (self: Option T): Option T × Option T := (self, .none) -/- [core::mem::swap] -/ -@[simp] def core.mem.swap (T: Type) (a b: T): T × T := (b, a) - -/-- Aeneas-translated function -- useful to reduce non-recursive definitions. - Use with `simp [ aeneas ]` -/ -register_simp_attr aeneas +instance SubtypeBEq [BEq α] (p : α → Prop) : BEq (Subtype p) where + beq v0 v1 := v0.val == v1.val + +instance SubtypeLawfulBEq [BEq α] (p : α → Prop) [LawfulBEq α] : LawfulBEq (Subtype p) where + eq_of_beq {a b} h := by cases a; cases b; simp_all [BEq.beq] + rfl := by intro a; cases a; simp [BEq.beq] + +------------------------------ +---- Misc Primitives Types --- +------------------------------ -- We don't really use raw pointers for now structure MutRawPtr (T : Type) where diff --git a/backends/lean/Base/Primitives/Core.lean b/backends/lean/Base/Primitives/Core.lean index 14a51bc1..aa4a7f28 100644 --- a/backends/lean/Base/Primitives/Core.lean +++ b/backends/lean/Base/Primitives/Core.lean @@ -59,4 +59,17 @@ def Option.unwrap (T : Type) (x : Option T) : Result T := end option -- core.option +/- [core::option::Option::take] -/ +@[simp] def Option.take (T: Type) (self: Option T): Option T × Option T := (self, .none) + +/- [core::mem::replace] + + This acts like a swap effectively in a functional pure world. + We return the old value of `dst`, i.e. `dst` itself. + The new value of `dst` is `src`. -/ +@[simp] def mem.replace (a : Type) (dst : a) (src : a) : a × a := (dst, src) + +/- [core::mem::swap] -/ +@[simp] def mem.swap (T: Type) (a b: T): T × T := (b, a) + end core diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 31038e0d..2359c140 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -299,7 +299,14 @@ structure Scalar (ty : ScalarTy) where val : Int hmin : Scalar.min ty ≤ val hmax : val ≤ Scalar.max ty -deriving Repr +deriving Repr, BEq, DecidableEq + +instance {ty} : BEq (Scalar ty) where + beq a b := a.val = b.val + +instance {ty} : LawfulBEq (Scalar ty) where + eq_of_beq {a b} := by cases a; cases b; simp[BEq.beq] + rfl {a} := by cases a; simp [BEq.beq] instance (ty : ScalarTy) : CoeOut (Scalar ty) Int where coe := λ v => v.val diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean index 12789fa9..82ecb8ed 100644 --- a/backends/lean/Base/Primitives/Vec.lean +++ b/backends/lean/Base/Primitives/Vec.lean @@ -16,6 +16,10 @@ namespace alloc.vec def Vec (α : Type u) := { l : List α // l.length ≤ Usize.max } +instance [BEq α] : BEq (Vec α) := SubtypeBEq _ + +instance [BEq α] [LawfulBEq α] : LawfulBEq (Vec α) := SubtypeLawfulBEq _ + instance (a : Type u) : Arith.HasIntProp (Vec a) where prop_ty := λ v => 0 ≤ v.val.len ∧ v.val.len ≤ Scalar.max ScalarTy.Usize prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *] diff --git a/tests/lean/Hashmap/Properties.lean b/tests/lean/Hashmap/Properties.lean index d9be15dd..76bd2598 100644 --- a/tests/lean/Hashmap/Properties.lean +++ b/tests/lean/Hashmap/Properties.lean @@ -3,34 +3,19 @@ import Hashmap.Funs open Primitives open Result -namespace List - --- TODO: we don't want to use the original List.lookup because it uses BEq --- TODO: rewrite rule: match x == y with ... -> if x = y then ... else ... ? (actually doesn't work because of sugar) --- TODO: move? -@[simp] -def lookup' {α : Type} (ls: List (Usize × α)) (key: Usize) : Option α := - match ls with - | [] => none - | (k, x) :: tl => if k = key then some x else lookup' tl key - -end List - namespace hashmap namespace AList +@[simp] def v {α : Type} (ls: AList α) : List (Usize × α) := match ls with | Nil => [] | Cons k x tl => (k, x) :: v tl -@[simp] theorem v_nil (α : Type) : (Nil : AList α).v = [] := by rfl -@[simp] theorem v_cons {α : Type} k x (tl: AList α) : (Cons k x tl).v = (k, x) :: v tl := by rfl - @[simp] abbrev lookup {α : Type} (ls: AList α) (key: Usize) : Option α := - ls.v.lookup' key + ls.v.lookup key @[simp] abbrev len {α : Type} (ls : AList α) : Int := ls.v.len @@ -39,20 +24,6 @@ end AList namespace HashMap -namespace List - -end List - --- TODO: move -@[simp] theorem neq_imp_nbeq [BEq α] [LawfulBEq α] (x y : α) (heq : ¬ x = y) : ¬ x == y := by simp [*] -@[simp] theorem neq_imp_nbeq_rev [BEq α] [LawfulBEq α] (x y : α) (heq : ¬ x = y) : ¬ y == x := by simp [*] - --- TODO: move --- TODO: this doesn't work because of sugar -theorem match_lawful_beq [BEq α] [LawfulBEq α] [DecidableEq α] (x y : α) : - (x == y) = (if x = y then true else false) := by - split <;> simp_all - def distinct_keys (ls : List (Usize × α)) := ls.pairwise_rel (λ x y => x.fst ≠ y.fst) def hash_mod_key (k : Usize) (l : Int) : Int := @@ -168,6 +139,8 @@ def frame_load (hm nhm : HashMap α) : Prop := -- This rewriting lemma is problematic below attribute [-simp] Bool.exists_bool +attribute [local simp] List.lookup + -- The proofs below are a bit expensive, so we deactivate the heart bits limit set_option maxHeartbeats 0 @@ -276,6 +249,9 @@ theorem new_spec (α : Type) : progress as ⟨ hm ⟩ simp_all +--set_option pp.all true +example (key : Usize) : key == key := by simp [beq_iff_eq] + theorem insert_in_list_spec_aux {α : Type} (l : Int) (key: Usize) (value: α) (l0: AList α) (hinv : slot_s_inv_hash l (hash_mod_key key l) l0.v) (hdk : distinct_keys l0.v) : @@ -307,14 +283,8 @@ theorem insert_in_list_spec_aux {α : Type} (l : Int) (key: Usize) (value: α) ( if h: k = key then rw [insert_in_list] rw [insert_in_list_loop] - simp [h] - exists false; simp only [true_and, exists_eq_left', List.lookup', ↓reduceIte, AList.v_cons] -- TODO: why do we need to do this? - split_conjs - . intros; simp [*] - . simp_all [slot_s_inv_hash] - . simp at hinv; tauto - . simp_all [slot_s_inv_hash] - . simp_all + simp [h, and_assoc] + split_conjs <;> simp_all [slot_s_inv_hash] else rw [insert_in_list] rw [insert_in_list_loop] @@ -448,10 +418,10 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value slots := v } exists nhm have hupdt : lookup nhm key = some value := by - simp [lookup, List.lookup] at * + simp [lookup] at * simp_all have hlkp : ∀ k, ¬ k = key → nhm.lookup k = hm.lookup k := by - simp [lookup, List.lookup] at * + simp [lookup] at * intro k hk -- We have to make a case disjunction: either the hashes are different, -- in which case we don't even lookup the same slots, or the hashes @@ -476,7 +446,7 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value match hm.lookup key with | none => nhm.len_s = hm.len_s + 1 | some _ => nhm.len_s = hm.len_s := by - simp only [lookup, List.lookup, len_s, al_v, HashMap.v, slots_s_lookup] at * + simp only [lookup, len_s, al_v, HashMap.v, slots_s_lookup] at * -- We have to do a case disjunction simp_all [List.map_update_eq] -- TODO: dependent rewrites @@ -508,7 +478,7 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value . simp_all [frame_load, inv_base, inv_load] simp_all -private theorem slot_allP_not_key_lookup (slot : AList T) (h : slot.v.allP fun (k', _) => ¬k = k') : +private theorem slot_allP_not_key_lookup (slot : AList α) (h : slot.v.allP fun (k', _) => ¬k = k') : slot.lookup k = none := by induction slot <;> simp_all @@ -624,7 +594,7 @@ private theorem slots_index_len_le_flatten_len (slots : List (AList α)) (i : In (slots.index i).len ≤ (List.map AList.v slots).flatten.len := by match slots with | [] => - simp at *; scalar_tac + simp at * | slot :: slots' => simp at * if hi : i = 0 then @@ -643,7 +613,7 @@ private theorem slots_inv_lookup_imp_eq (slots : Slots α) (hInv : slots_t_inv s (slots.val.index i).lookup key ≠ none → i = key.val % slots.val.len := by suffices hSlot : ∀ (slot : List (Usize × α)), slot_s_inv slots.val.len i slot → - slot.lookup' key ≠ none → + slot.lookup key ≠ none → i = key.val % slots.val.len from by rw [slots_t_inv, slots_s_inv] at hInv @@ -965,8 +935,8 @@ theorem try_resize_spec {α : Type} (hm : HashMap α) (hInv : hm.inv): simp_all [lookup, al_v, v, alloc.vec.Vec.len] intro key replace hLookup := hLookup key - cases h1: (ntable2.slots.val.index (key.val % ntable2.slots.val.len)).v.lookup' key <;> - cases h2: (hm.slots.val.index (key.val % hm.slots.val.len)).v.lookup' key <;> + cases h1: (ntable2.slots.val.index (key.val % ntable2.slots.val.len)).v.lookup key <;> + cases h2: (hm.slots.val.index (key.val % hm.slots.val.len)).v.lookup key <;> simp_all [Slots.lookup] else simp [hSmaller] @@ -1002,7 +972,7 @@ theorem get_in_list_spec {α} (key : Usize) (slot : AList α) (hLookup : slot.lo ∃ v, get_in_list α key slot = ok v ∧ slot.lookup key = some v := by induction slot <;> rw [get_in_list, get_in_list_loop] <;> - simp_all [AList.lookup] + simp_all split <;> simp_all @[pspec] @@ -1038,7 +1008,7 @@ theorem get_mut_in_list_spec {α} (key : Usize) (slot : AList α) := by induction slot <;> rw [get_mut_in_list, get_mut_in_list_loop] <;> - simp_all [AList.lookup] + simp_all split . -- Non-recursive case simp_all [and_assoc, slot_t_inv] @@ -1134,12 +1104,6 @@ theorem remove_from_list_spec {α} (key : Usize) (slot : AList α) {l i} (hInv : simp_all . cases v1 <;> simp_all --- TODO: move? -theorem lookup'_not_none_imp_len_pos (l : List (Usize × α)) (key : Usize) (hLookup : l.lookup' key ≠ none) : - 0 < l.len := by - induction l <;> simp_all - scalar_tac - private theorem lookup_not_none_imp_len_s_pos (hm : HashMap α) (key : Usize) (hLookup : hm.lookup key ≠ none) (hNotEmpty : 0 < hm.slots.val.len) : 0 < hm.len_s := by @@ -1148,7 +1112,7 @@ private theorem lookup_not_none_imp_len_s_pos (hm : HashMap α) (key : Usize) (h have : key.val % hm.slots.val.len < hm.slots.val.len := by -- TODO: automate apply Int.emod_lt_of_pos; scalar_tac have := List.len_index_le_len_flatten hm.v (key.val % hm.slots.val.len) - have := lookup'_not_none_imp_len_pos (hm.slots.val.index (key.val % hm.slots.val.len)).v key + have := List.lookup_not_none_imp_len_pos (hm.slots.val.index (key.val % hm.slots.val.len)).v key simp_all [lookup, len_s, al_v, v] scalar_tac -- cgit v1.2.3