summaryrefslogtreecommitdiff
path: root/backends
diff options
context:
space:
mode:
authorSon Ho2024-04-04 15:29:11 +0200
committerSon Ho2024-04-04 15:30:05 +0200
commitfc0b39c4fe48fdbab06d5fd32e0a2b7dcae674e6 (patch)
treed91a1e767d99c25d903a50e6c1af98784151b6cc /backends
parentb4f5719a10427dfc168f1210b05397599e761f9a (diff)
Fix the coerce notation for scalars and update some lemmas
Diffstat (limited to 'backends')
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean19
1 files changed, 15 insertions, 4 deletions
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 3d90f1a5..7668bc59 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -265,6 +265,14 @@ theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty
have := Scalar.cMax_bound ty
linarith
+/-- The scalar type.
+
+ We could use a subtype, but it using a custom structure type allows us
+ to have more control over the coercions and the simplifications (we tried
+ using a subtype and it caused issues especially as we had to make the Scalar
+ type non-reducible, so that we could have more control, but leading to
+ some natural equalities not being obvious to the simplifier anymore).
+ -/
structure Scalar (ty : ScalarTy) where
val : Int
hmin : Scalar.min ty ≤ val
@@ -274,6 +282,9 @@ deriving Repr
instance (ty : ScalarTy) : CoeOut (Scalar ty) Int where
coe := λ v => v.val
+/- Activate the ↑ notation -/
+attribute [coe] Scalar.val
+
theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty ->
Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
@@ -1119,19 +1130,19 @@ theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) :
-- This is sometimes useful when rewriting the goal with the local assumptions
@[simp] theorem Scalar.eq_imp {ty : ScalarTy} (x y : Scalar ty) :
- x = y → (↑x : Int) = ↑y := (eq_equiv x y).mp
+ (↑x : Int) = ↑y → x = y := (eq_equiv x y).mpr
theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) :
x < y ↔ (↑x : Int) < ↑y := by simp [LT.lt]
@[simp] theorem Scalar.lt_imp {ty : ScalarTy} (x y : Scalar ty) :
- x < y → (↑x : Int) < ↑y := (lt_equiv x y).mp
+ (↑x : Int) < (↑y) → x < y := (lt_equiv x y).mpr
theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) :
x ≤ y ↔ (↑x : Int) ≤ ↑y := by simp [LE.le]
@[simp] theorem Scalar.le_imp {ty : ScalarTy} (x y : Scalar ty) :
- x ≤ y → (↑x : Int) ≤ ↑y := (le_equiv x y).mp
+ (↑x : Int) ≤ ↑y → x ≤ y := (le_equiv x y).mpr
instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt ..
instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe ..
@@ -1152,6 +1163,6 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) :=
| isFalse h => isFalse (Scalar.ne_of_val_ne h)
@[simp] theorem Scalar.neq_to_neq_val {ty} : ∀ {i j : Scalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by
- intro i j; cases i; cases j; simp
+ simp [eq_equiv]
end Primitives