summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon HO2024-06-22 15:07:14 +0200
committerGitHub2024-06-22 15:07:14 +0200
commit8719c17f1a363c0463d74b90e558b2aaa24645d6 (patch)
tree94cd2fb84f10912e76d1d1e8e89d8f9aee948f0c
parent8144c39f4d37aa1fa14a8a061eb7ed60e153fb4c (diff)
Do some cleanup in the Lean backend (#257)
-rw-r--r--backends/lean/Base/Core.lean17
-rw-r--r--backends/lean/Base/IList/IList.lean18
-rw-r--r--backends/lean/Base/Primitives/ArraySlice.lean8
-rw-r--r--backends/lean/Base/Primitives/Base.lean22
-rw-r--r--backends/lean/Base/Primitives/Core.lean13
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean9
-rw-r--r--backends/lean/Base/Primitives/Vec.lean4
-rw-r--r--tests/lean/Hashmap/Properties.lean76
8 files changed, 89 insertions, 78 deletions
diff --git a/backends/lean/Base/Core.lean b/backends/lean/Base/Core.lean
new file mode 100644
index 00000000..89dd199b
--- /dev/null
+++ b/backends/lean/Base/Core.lean
@@ -0,0 +1,17 @@
+
+import Lean
+
+/- This lemma is generally useful. It often happens that (because we
+ make a split on a condition for instance) we have `x ≠ y` in the context
+ and need to simplify `y ≠ x` somewhere. -/
+@[simp]
+theorem neq_imp {α : Type u} {x y : α} (h : ¬ x = y) : ¬ y = x := by intro; simp_all
+
+/- This is generally useful, and doing without is actually quite cumbersome.
+
+ Note that the following theorem does not seem to be necessary (we invert `x`
+ and `y` in the conclusion), probably because of `neq_imp`:
+ `¬ x = y → ¬ y == x`
+ -/
+@[simp]
+theorem neq_imp_nbeq [BEq α] [LawfulBEq α] (x y : α) (heq : ¬ x = y) : ¬ x == y := by simp [*]
diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean
index c77f075f..1b103bb3 100644
--- a/backends/lean/Base/IList/IList.lean
+++ b/backends/lean/Base/IList/IList.lean
@@ -3,13 +3,7 @@
import Base.Arith
import Base.Utils
-
--- TODO: move?
--- This lemma is generally useful. It often happens that (because we
--- make a split on a condition for instance) we have `x ≠ y` in the context
--- and need to simplify `y ≠ x` somewhere.
-@[simp]
-theorem neq_imp {α : Type u} {x y : α} (h : ¬ x = y) : ¬ y = x := by intro; simp_all
+import Base.Core
namespace List
@@ -134,7 +128,7 @@ def pairwise_rel
| [] => True
| hd :: tl => allP tl (rel hd) ∧ pairwise_rel rel tl
-section Lemmas
+section
variable {α : Type u}
@@ -578,6 +572,12 @@ theorem pairwise_rel_cons {α : Type u} (rel : α → α → Prop) (hd: α) (tl:
pairwise_rel rel (hd :: tl) ↔ allP tl (rel hd) ∧ pairwise_rel rel tl
:= by simp [pairwise_rel]
-end Lemmas
+theorem lookup_not_none_imp_len_pos [BEq α] (l : List (α × β)) (key : α)
+ (hLookup : l.lookup key ≠ none) :
+ 0 < l.len := by
+ induction l <;> simp_all
+ scalar_tac
+
+end
end List
diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean
index 899871af..3cc0f9c1 100644
--- a/backends/lean/Base/Primitives/ArraySlice.lean
+++ b/backends/lean/Base/Primitives/ArraySlice.lean
@@ -16,6 +16,10 @@ open Result Error core.ops.range
def Array (α : Type u) (n : Usize) := { l : List α // l.length = n.val }
+instance [BEq α] : BEq (Array α n) := SubtypeBEq _
+
+instance [BEq α] [LawfulBEq α] : LawfulBEq (Array α n) := SubtypeLawfulBEq _
+
instance (a : Type u) (n : Usize) : Arith.HasIntProp (Array a n) where
prop_ty := λ v => v.val.len = n.val
prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *]
@@ -109,6 +113,10 @@ theorem Array.index_mut_usize_spec {α : Type u} {n : Usize} [Inhabited α] (v:
def Slice (α : Type u) := { l : List α // l.length ≤ Usize.max }
+instance [BEq α] : BEq (Slice α) := SubtypeBEq _
+
+instance [BEq α] [LawfulBEq α] : LawfulBEq (Slice α) := SubtypeLawfulBEq _
+
instance (a : Type u) : Arith.HasIntProp (Slice a) where
prop_ty := λ v => 0 ≤ v.val.len ∧ v.val.len ≤ Scalar.max ScalarTy.Usize
prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *]
diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index c9237e65..63fbd8c0 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -134,18 +134,16 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ok x } :=
-- MISC --
----------
--- This acts like a swap effectively in a functional pure world.
--- We return the old value of `dst`, i.e. `dst` itself.
--- The new value of `dst` is `src`.
-@[simp] def core.mem.replace (a : Type) (dst : a) (src : a) : a × a := (dst, src)
-/- [core::option::Option::take] -/
-@[simp] def Option.take (T: Type) (self: Option T): Option T × Option T := (self, .none)
-/- [core::mem::swap] -/
-@[simp] def core.mem.swap (T: Type) (a b: T): T × T := (b, a)
-
-/-- Aeneas-translated function -- useful to reduce non-recursive definitions.
- Use with `simp [ aeneas ]` -/
-register_simp_attr aeneas
+instance SubtypeBEq [BEq α] (p : α → Prop) : BEq (Subtype p) where
+ beq v0 v1 := v0.val == v1.val
+
+instance SubtypeLawfulBEq [BEq α] (p : α → Prop) [LawfulBEq α] : LawfulBEq (Subtype p) where
+ eq_of_beq {a b} h := by cases a; cases b; simp_all [BEq.beq]
+ rfl := by intro a; cases a; simp [BEq.beq]
+
+------------------------------
+---- Misc Primitives Types ---
+------------------------------
-- We don't really use raw pointers for now
structure MutRawPtr (T : Type) where
diff --git a/backends/lean/Base/Primitives/Core.lean b/backends/lean/Base/Primitives/Core.lean
index 14a51bc1..aa4a7f28 100644
--- a/backends/lean/Base/Primitives/Core.lean
+++ b/backends/lean/Base/Primitives/Core.lean
@@ -59,4 +59,17 @@ def Option.unwrap (T : Type) (x : Option T) : Result T :=
end option -- core.option
+/- [core::option::Option::take] -/
+@[simp] def Option.take (T: Type) (self: Option T): Option T × Option T := (self, .none)
+
+/- [core::mem::replace]
+
+ This acts like a swap effectively in a functional pure world.
+ We return the old value of `dst`, i.e. `dst` itself.
+ The new value of `dst` is `src`. -/
+@[simp] def mem.replace (a : Type) (dst : a) (src : a) : a × a := (dst, src)
+
+/- [core::mem::swap] -/
+@[simp] def mem.swap (T: Type) (a b: T): T × T := (b, a)
+
end core
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 31038e0d..2359c140 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -299,7 +299,14 @@ structure Scalar (ty : ScalarTy) where
val : Int
hmin : Scalar.min ty ≤ val
hmax : val ≤ Scalar.max ty
-deriving Repr
+deriving Repr, BEq, DecidableEq
+
+instance {ty} : BEq (Scalar ty) where
+ beq a b := a.val = b.val
+
+instance {ty} : LawfulBEq (Scalar ty) where
+ eq_of_beq {a b} := by cases a; cases b; simp[BEq.beq]
+ rfl {a} := by cases a; simp [BEq.beq]
instance (ty : ScalarTy) : CoeOut (Scalar ty) Int where
coe := λ v => v.val
diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean
index 12789fa9..82ecb8ed 100644
--- a/backends/lean/Base/Primitives/Vec.lean
+++ b/backends/lean/Base/Primitives/Vec.lean
@@ -16,6 +16,10 @@ namespace alloc.vec
def Vec (α : Type u) := { l : List α // l.length ≤ Usize.max }
+instance [BEq α] : BEq (Vec α) := SubtypeBEq _
+
+instance [BEq α] [LawfulBEq α] : LawfulBEq (Vec α) := SubtypeLawfulBEq _
+
instance (a : Type u) : Arith.HasIntProp (Vec a) where
prop_ty := λ v => 0 ≤ v.val.len ∧ v.val.len ≤ Scalar.max ScalarTy.Usize
prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *]
diff --git a/tests/lean/Hashmap/Properties.lean b/tests/lean/Hashmap/Properties.lean
index d9be15dd..76bd2598 100644
--- a/tests/lean/Hashmap/Properties.lean
+++ b/tests/lean/Hashmap/Properties.lean
@@ -3,34 +3,19 @@ import Hashmap.Funs
open Primitives
open Result
-namespace List
-
--- TODO: we don't want to use the original List.lookup because it uses BEq
--- TODO: rewrite rule: match x == y with ... -> if x = y then ... else ... ? (actually doesn't work because of sugar)
--- TODO: move?
-@[simp]
-def lookup' {α : Type} (ls: List (Usize × α)) (key: Usize) : Option α :=
- match ls with
- | [] => none
- | (k, x) :: tl => if k = key then some x else lookup' tl key
-
-end List
-
namespace hashmap
namespace AList
+@[simp]
def v {α : Type} (ls: AList α) : List (Usize × α) :=
match ls with
| Nil => []
| Cons k x tl => (k, x) :: v tl
-@[simp] theorem v_nil (α : Type) : (Nil : AList α).v = [] := by rfl
-@[simp] theorem v_cons {α : Type} k x (tl: AList α) : (Cons k x tl).v = (k, x) :: v tl := by rfl
-
@[simp]
abbrev lookup {α : Type} (ls: AList α) (key: Usize) : Option α :=
- ls.v.lookup' key
+ ls.v.lookup key
@[simp]
abbrev len {α : Type} (ls : AList α) : Int := ls.v.len
@@ -39,20 +24,6 @@ end AList
namespace HashMap
-namespace List
-
-end List
-
--- TODO: move
-@[simp] theorem neq_imp_nbeq [BEq α] [LawfulBEq α] (x y : α) (heq : ¬ x = y) : ¬ x == y := by simp [*]
-@[simp] theorem neq_imp_nbeq_rev [BEq α] [LawfulBEq α] (x y : α) (heq : ¬ x = y) : ¬ y == x := by simp [*]
-
--- TODO: move
--- TODO: this doesn't work because of sugar
-theorem match_lawful_beq [BEq α] [LawfulBEq α] [DecidableEq α] (x y : α) :
- (x == y) = (if x = y then true else false) := by
- split <;> simp_all
-
def distinct_keys (ls : List (Usize × α)) := ls.pairwise_rel (λ x y => x.fst ≠ y.fst)
def hash_mod_key (k : Usize) (l : Int) : Int :=
@@ -168,6 +139,8 @@ def frame_load (hm nhm : HashMap α) : Prop :=
-- This rewriting lemma is problematic below
attribute [-simp] Bool.exists_bool
+attribute [local simp] List.lookup
+
-- The proofs below are a bit expensive, so we deactivate the heart bits limit
set_option maxHeartbeats 0
@@ -276,6 +249,9 @@ theorem new_spec (α : Type) :
progress as ⟨ hm ⟩
simp_all
+--set_option pp.all true
+example (key : Usize) : key == key := by simp [beq_iff_eq]
+
theorem insert_in_list_spec_aux {α : Type} (l : Int) (key: Usize) (value: α) (l0: AList α)
(hinv : slot_s_inv_hash l (hash_mod_key key l) l0.v)
(hdk : distinct_keys l0.v) :
@@ -307,14 +283,8 @@ theorem insert_in_list_spec_aux {α : Type} (l : Int) (key: Usize) (value: α) (
if h: k = key then
rw [insert_in_list]
rw [insert_in_list_loop]
- simp [h]
- exists false; simp only [true_and, exists_eq_left', List.lookup', ↓reduceIte, AList.v_cons] -- TODO: why do we need to do this?
- split_conjs
- . intros; simp [*]
- . simp_all [slot_s_inv_hash]
- . simp at hinv; tauto
- . simp_all [slot_s_inv_hash]
- . simp_all
+ simp [h, and_assoc]
+ split_conjs <;> simp_all [slot_s_inv_hash]
else
rw [insert_in_list]
rw [insert_in_list_loop]
@@ -448,10 +418,10 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value
slots := v }
exists nhm
have hupdt : lookup nhm key = some value := by
- simp [lookup, List.lookup] at *
+ simp [lookup] at *
simp_all
have hlkp : ∀ k, ¬ k = key → nhm.lookup k = hm.lookup k := by
- simp [lookup, List.lookup] at *
+ simp [lookup] at *
intro k hk
-- We have to make a case disjunction: either the hashes are different,
-- in which case we don't even lookup the same slots, or the hashes
@@ -476,7 +446,7 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value
match hm.lookup key with
| none => nhm.len_s = hm.len_s + 1
| some _ => nhm.len_s = hm.len_s := by
- simp only [lookup, List.lookup, len_s, al_v, HashMap.v, slots_s_lookup] at *
+ simp only [lookup, len_s, al_v, HashMap.v, slots_s_lookup] at *
-- We have to do a case disjunction
simp_all [List.map_update_eq]
-- TODO: dependent rewrites
@@ -508,7 +478,7 @@ theorem insert_no_resize_spec {α : Type} (hm : HashMap α) (key : Usize) (value
. simp_all [frame_load, inv_base, inv_load]
simp_all
-private theorem slot_allP_not_key_lookup (slot : AList T) (h : slot.v.allP fun (k', _) => ¬k = k') :
+private theorem slot_allP_not_key_lookup (slot : AList α) (h : slot.v.allP fun (k', _) => ¬k = k') :
slot.lookup k = none := by
induction slot <;> simp_all
@@ -624,7 +594,7 @@ private theorem slots_index_len_le_flatten_len (slots : List (AList α)) (i : In
(slots.index i).len ≤ (List.map AList.v slots).flatten.len := by
match slots with
| [] =>
- simp at *; scalar_tac
+ simp at *
| slot :: slots' =>
simp at *
if hi : i = 0 then
@@ -643,7 +613,7 @@ private theorem slots_inv_lookup_imp_eq (slots : Slots α) (hInv : slots_t_inv s
(slots.val.index i).lookup key ≠ none → i = key.val % slots.val.len := by
suffices hSlot : ∀ (slot : List (Usize × α)),
slot_s_inv slots.val.len i slot →
- slot.lookup' key ≠ none →
+ slot.lookup key ≠ none →
i = key.val % slots.val.len
from by
rw [slots_t_inv, slots_s_inv] at hInv
@@ -965,8 +935,8 @@ theorem try_resize_spec {α : Type} (hm : HashMap α) (hInv : hm.inv):
simp_all [lookup, al_v, v, alloc.vec.Vec.len]
intro key
replace hLookup := hLookup key
- cases h1: (ntable2.slots.val.index (key.val % ntable2.slots.val.len)).v.lookup' key <;>
- cases h2: (hm.slots.val.index (key.val % hm.slots.val.len)).v.lookup' key <;>
+ cases h1: (ntable2.slots.val.index (key.val % ntable2.slots.val.len)).v.lookup key <;>
+ cases h2: (hm.slots.val.index (key.val % hm.slots.val.len)).v.lookup key <;>
simp_all [Slots.lookup]
else
simp [hSmaller]
@@ -1002,7 +972,7 @@ theorem get_in_list_spec {α} (key : Usize) (slot : AList α) (hLookup : slot.lo
∃ v, get_in_list α key slot = ok v ∧ slot.lookup key = some v := by
induction slot <;>
rw [get_in_list, get_in_list_loop] <;>
- simp_all [AList.lookup]
+ simp_all
split <;> simp_all
@[pspec]
@@ -1038,7 +1008,7 @@ theorem get_mut_in_list_spec {α} (key : Usize) (slot : AList α)
:= by
induction slot <;>
rw [get_mut_in_list, get_mut_in_list_loop] <;>
- simp_all [AList.lookup]
+ simp_all
split
. -- Non-recursive case
simp_all [and_assoc, slot_t_inv]
@@ -1134,12 +1104,6 @@ theorem remove_from_list_spec {α} (key : Usize) (slot : AList α) {l i} (hInv :
simp_all
. cases v1 <;> simp_all
--- TODO: move?
-theorem lookup'_not_none_imp_len_pos (l : List (Usize × α)) (key : Usize) (hLookup : l.lookup' key ≠ none) :
- 0 < l.len := by
- induction l <;> simp_all
- scalar_tac
-
private theorem lookup_not_none_imp_len_s_pos (hm : HashMap α) (key : Usize) (hLookup : hm.lookup key ≠ none)
(hNotEmpty : 0 < hm.slots.val.len) :
0 < hm.len_s := by
@@ -1148,7 +1112,7 @@ private theorem lookup_not_none_imp_len_s_pos (hm : HashMap α) (key : Usize) (h
have : key.val % hm.slots.val.len < hm.slots.val.len := by -- TODO: automate
apply Int.emod_lt_of_pos; scalar_tac
have := List.len_index_le_len_flatten hm.v (key.val % hm.slots.val.len)
- have := lookup'_not_none_imp_len_pos (hm.slots.val.index (key.val % hm.slots.val.len)).v key
+ have := List.lookup_not_none_imp_len_pos (hm.slots.val.index (key.val % hm.slots.val.len)).v key
simp_all [lookup, len_s, al_v, v]
scalar_tac