From 79e19aa701086de9f080357d817284559f900bcc Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Wed, 12 Jun 2024 18:40:17 +0200
Subject: Update the scalar notations in Lean

---
 backends/lean/Base/Primitives.lean                 |   1 +
 backends/lean/Base/Primitives/Scalar.lean          | 123 ++++-----------------
 backends/lean/Base/Primitives/ScalarNotations.lean |  87 +++++++++++++++
 3 files changed, 109 insertions(+), 102 deletions(-)
 create mode 100644 backends/lean/Base/Primitives/ScalarNotations.lean

diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean
index f80c2004..93617049 100644
--- a/backends/lean/Base/Primitives.lean
+++ b/backends/lean/Base/Primitives.lean
@@ -1,6 +1,7 @@
 import Base.Primitives.Base
 import Base.Tuples
 import Base.Primitives.Scalar
+import Base.Primitives.ScalarNotations
 import Base.Primitives.ArraySlice
 import Base.Primitives.Vec
 import Base.Primitives.Alloc
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 157ade2c..f4264b9b 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -312,38 +312,13 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
   λ h => by
   apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith
 
-/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns.
-   This is particularly useful once we introduce notations like `#u32` (which
-   desugards to `Scalar.ofIntCore`) as it allows to write expressions like this:
-   Example:
-   ```
-   match x with
-   | 0#u32 => ...
-   | 1#u32 => ...
-   | ...
-   ```
- -/
-@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
+def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
   (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty :=
   { val := x, hmin := h.left, hmax := h.right }
 
--- The definitions below are used later to introduce nice syntax for constants,
--- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284
-
-class InBounds (ty : ScalarTy) (x : Int) :=
-  hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty
-
--- This trick to trigger reduction for decidable propositions comes from
--- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807
-class Decide (p : Prop) [Decidable p] : Prop where
-  isTrue : p
-instance : @Decide p (.isTrue h) := @Decide.mk p (_) h
-
-instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where
-  hInBounds := Decide.isTrue
-
-@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty :=
-  Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds)
+@[reducible] def Scalar.ofInt {ty : ScalarTy} (x : Int)
+  (hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty := by decide) : Scalar ty :=
+  Scalar.ofIntCore x (Scalar.bound_suffices ty x hInBounds)
 
 @[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop :=
   Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
@@ -412,9 +387,8 @@ theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) :
   simp [tryMk, ofOption, tryMkOpt]
   split_ifs <;> simp
 
-instance (ty: ScalarTy) : InBounds ty 0 where
-  hInBounds := by
-    induction ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
+@[simp] theorem zero_in_cbounds {ty : ScalarTy} : Scalar.cMin ty ≤ 0 ∧ 0 ≤ Scalar.cMax ty := by
+  cases ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
 
 def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
 
@@ -1268,73 +1242,18 @@ def U128.ofIntCore  := @Scalar.ofIntCore .U128
 
 --  ofInt
 -- TODO: typeclass?
-@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize
-@[match_pattern] abbrev I8.ofInt    := @Scalar.ofInt .I8
-@[match_pattern] abbrev I16.ofInt   := @Scalar.ofInt .I16
-@[match_pattern] abbrev I32.ofInt   := @Scalar.ofInt .I32
-@[match_pattern] abbrev I64.ofInt   := @Scalar.ofInt .I64
-@[match_pattern] abbrev I128.ofInt  := @Scalar.ofInt .I128
-@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize
-@[match_pattern] abbrev U8.ofInt    := @Scalar.ofInt .U8
-@[match_pattern] abbrev U16.ofInt   := @Scalar.ofInt .U16
-@[match_pattern] abbrev U32.ofInt   := @Scalar.ofInt .U32
-@[match_pattern] abbrev U64.ofInt   := @Scalar.ofInt .U64
-@[match_pattern] abbrev U128.ofInt  := @Scalar.ofInt .U128
-
-postfix:max "#isize" => Isize.ofInt
-postfix:max "#i8"    => I8.ofInt
-postfix:max "#i16"   => I16.ofInt
-postfix:max "#i32"   => I32.ofInt
-postfix:max "#i64"   => I64.ofInt
-postfix:max "#i128"  => I128.ofInt
-postfix:max "#usize" => Usize.ofInt
-postfix:max "#u8"    => U8.ofInt
-postfix:max "#u16"   => U16.ofInt
-postfix:max "#u32"   => U32.ofInt
-postfix:max "#u64"   => U64.ofInt
-postfix:max "#u128"  => U128.ofInt
-
-/- Testing the notations -/
-example := 0#u32
-example := 1#u32
-example := 1#i32
-example := 0#isize
-example := (-1)#isize
-example (x : U32) : Bool :=
-  match x with
-  | 0#u32 => true
-  | _ => false
-
-example (x : U32) : Bool :=
-  match x with
-  | 1#u32 => true
-  | _ => false
-
-example (x : I32) : Bool :=
-  match x with
-  | (-1)#i32 => true
-  | _ => false
-
--- Notation for pattern matching
--- We make the precedence looser than the negation.
-notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
-
-example {ty} (x : Scalar ty) : ℤ :=
-  match x with
-  | v#scalar => v
-
-example {ty} (x : Scalar ty) : Bool :=
-  match x with
-  | 1#scalar => true
-  | _ => false
-
-example {ty} (x : Scalar ty) : Bool :=
-  match x with
-  | -(1 : Int)#scalar => true
-  | _ => false
-
--- Testing the notations
-example : Result Usize := 0#usize + 1#usize
+abbrev Isize.ofInt := @Scalar.ofInt .Isize
+abbrev I8.ofInt    := @Scalar.ofInt .I8
+abbrev I16.ofInt   := @Scalar.ofInt .I16
+abbrev I32.ofInt   := @Scalar.ofInt .I32
+abbrev I64.ofInt   := @Scalar.ofInt .I64
+abbrev I128.ofInt  := @Scalar.ofInt .I128
+abbrev Usize.ofInt := @Scalar.ofInt .Usize
+abbrev U8.ofInt    := @Scalar.ofInt .U8
+abbrev U16.ofInt   := @Scalar.ofInt .U16
+abbrev U32.ofInt   := @Scalar.ofInt .U32
+abbrev U64.ofInt   := @Scalar.ofInt .U64
+abbrev U128.ofInt  := @Scalar.ofInt .U128
 
 -- TODO: factor those lemmas out
 @[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by
@@ -1464,18 +1383,18 @@ theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max (
 -- Max theory
 -- TODO: do the min theory later on.
 
-theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 ≤ x := by
+theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 (by simp) ≤ x := by
   apply (Scalar.le_equiv _ _).2
   convert x.hmin
   cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min]
 
 @[simp]
 theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
-  Max.max (Scalar.ofInt 0) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
+  Max.max (Scalar.ofInt 0 (by simp)) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
 
 @[simp]
 theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
-  Max.max x (Scalar.ofInt 0) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
+  Max.max x (Scalar.ofInt 0 (by simp)) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
 
 -- Leading zeros
 def core.num.Usize.leading_zeros (x : Usize) : U32 := sorry
diff --git a/backends/lean/Base/Primitives/ScalarNotations.lean b/backends/lean/Base/Primitives/ScalarNotations.lean
new file mode 100644
index 00000000..50d8c1b6
--- /dev/null
+++ b/backends/lean/Base/Primitives/ScalarNotations.lean
@@ -0,0 +1,87 @@
+import Lean
+import Lean.Meta.Tactic.Simp
+import Mathlib.Tactic.Linarith
+import Base.Primitives.Scalar
+import Base.Arith
+
+namespace Primitives
+
+open Lean Meta Elab Term
+
+macro x:term:max "#isize" : term => `(Isize.ofInt $x (by scalar_tac))
+macro x:term:max "#i8"    : term => `(I8.ofInt $x (by scalar_tac))
+macro x:term:max "#i16"   : term => `(I16.ofInt $x (by scalar_tac))
+macro x:term:max "#i32"   : term => `(I32.ofInt $x (by scalar_tac))
+macro x:term:max "#i64"   : term => `(I64.ofInt $x (by scalar_tac))
+macro x:term:max "#i128"  : term => `(I128.ofInt $x (by scalar_tac))
+macro x:term:max "#usize" : term => `(Usize.ofInt $x (by scalar_tac))
+macro x:term:max "#u8"    : term => `(U8.ofInt $x (by scalar_tac))
+macro x:term:max "#u16"   : term => `(U16.ofInt $x (by scalar_tac))
+macro x:term:max "#u32"   : term => `(U32.ofInt $x (by scalar_tac))
+macro x:term:max "#u64"   : term => `(U64.ofInt $x (by scalar_tac))
+macro x:term:max "#u128"  : term => `(U128.ofInt $x (by scalar_tac))
+
+macro x:term:max noWs "u32"   : term => `(U32.ofInt $x (by scalar_tac))
+
+-- Notation for pattern matching
+-- We make the precedence looser than the negation.
+notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
+
+/- Testing the notations -/
+example := 0#u32
+example := 1#u32
+example := 1#i32
+example := 0#isize
+example := (-1)#isize
+
+/-
+-- This doesn't work anymore
+example (x : U32) : Bool :=
+  match x with
+  | 0#u32 => true
+  | _ => false
+
+example (x : U32) : Bool :=
+  match x with
+  | 1#u32 => true
+  | _ => false
+
+example (x : I32) : Bool :=
+  match x with
+  | (-1)#i32 => true
+  | _ => false
+-/
+
+example (x : U32) : Bool :=
+  match x with
+  | 0#scalar => true
+  | _ => false
+
+example (x : U32) : Bool :=
+  match x with
+  | 1#scalar => true
+  | _ => false
+
+example (x : I32) : Bool :=
+  match x with
+  | (-1)#scalar => true
+  | _ => false
+
+example {ty} (x : Scalar ty) : ℤ :=
+  match x with
+  | v#scalar => v
+
+example {ty} (x : Scalar ty) : Bool :=
+  match x with
+  | 1#scalar => true
+  | _ => false
+
+example {ty} (x : Scalar ty) : Bool :=
+  match x with
+  | -(1 : Int)#scalar => true
+  | _ => false
+
+-- Testing the notations
+example : Result Usize := 0#usize + 1#usize
+
+end Primitives
-- 
cgit v1.2.3