summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Primitives/Scalar.lean
diff options
context:
space:
mode:
authorSon Ho2024-06-17 06:16:43 +0200
committerSon Ho2024-06-17 06:16:43 +0200
commite57e6f08e5cc34bf4e9237650f5ecbab440b9ea2 (patch)
tree1e48b2d23719d72f39282213a1806591cc35c3b8 /backends/lean/Base/Primitives/Scalar.lean
parentf3b22b5cca9bc1154f55a81c9a82dc491074067d (diff)
parent85098d7caf5e3196c2e8f92411efd2814bfed1ea (diff)
Merge branch 'son/update-lean' into has-int-pred
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean154
1 files changed, 40 insertions, 114 deletions
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 8fb067e1..9f809ead 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -1,6 +1,5 @@
import Lean
import Lean.Meta.Tactic.Simp
-import Mathlib.Tactic.Linarith
import Base.Primitives.Base
import Base.Primitives.Core
import Base.Diverge.Base
@@ -9,6 +8,9 @@ import Base.Arith.Int
namespace Primitives
+-- Deactivate the warnings which appear when we use `#assert`
+set_option linter.hashCommand false
+
----------------------
-- MACHINE INTEGERS --
----------------------
@@ -279,11 +281,11 @@ theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by
theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by
have := Scalar.cMin_bound ty
- linarith
+ omega
theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by
have := Scalar.cMax_bound ty
- linarith
+ omega
/-- The scalar type.
@@ -310,40 +312,15 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
:=
λ h => by
- apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith
+ apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> omega
-/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns.
- This is particularly useful once we introduce notations like `#u32` (which
- desugards to `Scalar.ofIntCore`) as it allows to write expressions like this:
- Example:
- ```
- match x with
- | 0#u32 => ...
- | 1#u32 => ...
- | ...
- ```
- -/
-@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
+def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
(h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty :=
{ val := x, hmin := h.left, hmax := h.right }
--- The definitions below are used later to introduce nice syntax for constants,
--- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284
-
-class InBounds (ty : ScalarTy) (x : Int) :=
- hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty
-
--- This trick to trigger reduction for decidable propositions comes from
--- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807
-class Decide (p : Prop) [Decidable p] : Prop where
- isTrue : p
-instance : @Decide p (.isTrue h) := @Decide.mk p (_) h
-
-instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where
- hInBounds := Decide.isTrue
-
-@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty :=
- Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds)
+@[reducible] def Scalar.ofInt {ty : ScalarTy} (x : Int)
+ (hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty := by decide) : Scalar ty :=
+ Scalar.ofIntCore x (Scalar.bound_suffices ty x hInBounds)
@[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop :=
Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
@@ -351,10 +328,17 @@ instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty
@[simp] abbrev Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool :=
(Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty)
+/- Discussion:
+ This coercion can be slightly annoying at times, because if we write
+ something like `u = 3` (where `u` is, for instance, as `U32`), then instead of
+ coercing `u` to `Int`, Lean will lift `3` to `U32`).
+ For now we deactivate it.
+
-- TODO(raitobezarius): the inbounds constraint is a bit ugly as we can pretty trivially
-- discharge the lhs on ≥ 0.
instance {ty: ScalarTy} [InBounds ty (Int.ofNat n)]: OfNat (Scalar ty) (n: â„•) where
ofNat := Scalar.ofInt n
+-/
theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int}
(h: Scalar.check_bounds ty x) :
@@ -363,7 +347,7 @@ theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int}
have ⟨ hmin, hmax ⟩ := h
have hbmin := Scalar.cMin_bound ty
have hbmax := Scalar.cMax_bound ty
- cases hmin <;> cases hmax <;> apply And.intro <;> linarith
+ cases hmin <;> cases hmax <;> apply And.intro <;> omega
theorem Scalar.check_bounds_eq_in_bounds (ty : ScalarTy) (x : Int) :
Scalar.check_bounds ty x ↔ Scalar.in_bounds ty x := by
@@ -405,9 +389,8 @@ theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) :
simp [tryMk, ofOption, tryMkOpt]
split_ifs <;> simp
-instance (ty: ScalarTy) : InBounds ty 0 where
- hInBounds := by
- induction ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
+@[simp] theorem zero_in_cbounds {ty : ScalarTy} : Scalar.cMin ty ≤ 0 ∧ 0 ≤ Scalar.cMax ty := by
+ cases ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
@@ -749,7 +732,6 @@ theorem Scalar.add_spec {ty} {x y : Scalar ty}
(∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y) := by
have h := @add_equiv ty x y
split at h <;> simp_all
- apply h
theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
(hmax : ↑x + ↑y ≤ Scalar.max ty) :
@@ -757,7 +739,7 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
have hmin : Scalar.min ty ≤ ↑x + ↑y := by
have hx := x.hmin
have hy := y.hmin
- cases ty <;> simp [min, ScalarTy.isSigned] at * <;> linarith
+ cases ty <;> simp [min, ScalarTy.isSigned] at * <;> omega
apply add_spec <;> assumption
/- Fine-grained theorems -/
@@ -844,7 +826,6 @@ theorem Scalar.sub_spec {ty} {x y : Scalar ty}
∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by
have h := @sub_equiv ty x y
split at h <;> simp_all
- apply h
theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned)
{x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x - ↑y) :
@@ -853,7 +834,7 @@ theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned)
have hx := x.hmin
have hxm := x.hmax
have hy := y.hmin
- cases ty <;> simp [min, max, ScalarTy.isSigned] at * <;> linarith
+ cases ty <;> simp [min, max, ScalarTy.isSigned] at * <;> omega
intros
apply sub_spec <;> assumption
@@ -1049,11 +1030,11 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
have hx := x.hmin
have hy := y.hmin
simp [h] at hx hy
- have hmin : 0 ≤ ↑x / ↑y := Int.ediv_nonneg hx hy
+ have hmin : 0 ≤ x.val / y.val := Int.ediv_nonneg hx hy
have hmax : ↑x / ↑y ≤ Scalar.max ty := by
have := Int.ediv_le_self ↑y hx
have := x.hmax
- linarith
+ omega
have hs := @div_spec ty x y hnz
simp [*] at hs
apply hs
@@ -1170,7 +1151,7 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
have h : (0 : Int) < y := by int_tac
have h := Int.emod_lt_of_pos ↑x h
have := y.hmax
- linarith
+ omega
have hs := @rem_spec ty x y hnz
simp [*] at hs
simp [*]
@@ -1261,73 +1242,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128
-- ofInt
-- TODO: typeclass?
-@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize
-@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8
-@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16
-@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32
-@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64
-@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128
-@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize
-@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8
-@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16
-@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32
-@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64
-@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128
-
-postfix:max "#isize" => Isize.ofInt
-postfix:max "#i8" => I8.ofInt
-postfix:max "#i16" => I16.ofInt
-postfix:max "#i32" => I32.ofInt
-postfix:max "#i64" => I64.ofInt
-postfix:max "#i128" => I128.ofInt
-postfix:max "#usize" => Usize.ofInt
-postfix:max "#u8" => U8.ofInt
-postfix:max "#u16" => U16.ofInt
-postfix:max "#u32" => U32.ofInt
-postfix:max "#u64" => U64.ofInt
-postfix:max "#u128" => U128.ofInt
-
-/- Testing the notations -/
-example := 0#u32
-example := 1#u32
-example := 1#i32
-example := 0#isize
-example := (-1)#isize
-example (x : U32) : Bool :=
- match x with
- | 0#u32 => true
- | _ => false
-
-example (x : U32) : Bool :=
- match x with
- | 1#u32 => true
- | _ => false
-
-example (x : I32) : Bool :=
- match x with
- | (-1)#i32 => true
- | _ => false
-
--- Notation for pattern matching
--- We make the precedence looser than the negation.
-notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
-
-example {ty} (x : Scalar ty) : ℤ :=
- match x with
- | v#scalar => v
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | 1#scalar => true
- | _ => false
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | -(1 : Int)#scalar => true
- | _ => false
-
--- Testing the notations
-example : Result Usize := 0#usize + 1#usize
+abbrev Isize.ofInt := @Scalar.ofInt .Isize
+abbrev I8.ofInt := @Scalar.ofInt .I8
+abbrev I16.ofInt := @Scalar.ofInt .I16
+abbrev I32.ofInt := @Scalar.ofInt .I32
+abbrev I64.ofInt := @Scalar.ofInt .I64
+abbrev I128.ofInt := @Scalar.ofInt .I128
+abbrev Usize.ofInt := @Scalar.ofInt .Usize
+abbrev U8.ofInt := @Scalar.ofInt .U8
+abbrev U16.ofInt := @Scalar.ofInt .U16
+abbrev U32.ofInt := @Scalar.ofInt .U32
+abbrev U64.ofInt := @Scalar.ofInt .U64
+abbrev U128.ofInt := @Scalar.ofInt .U128
-- TODO: factor those lemmas out
@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by
@@ -1457,18 +1383,18 @@ theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max (â†
-- Max theory
-- TODO: do the min theory later on.
-theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 ≤ x := by
+theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 (by simp) ≤ x := by
apply (Scalar.le_equiv _ _).2
convert x.hmin
cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min]
@[simp]
theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
- Max.max (Scalar.ofInt 0) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
+ Max.max (Scalar.ofInt 0 (by simp)) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
@[simp]
theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
- Max.max x (Scalar.ofInt 0) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
+ Max.max x (Scalar.ofInt 0 (by simp)) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
-- Leading zeros
def core.num.Usize.leading_zeros (x : Usize) : U32 := sorry