From 124ee77181c4255e2c8f730305b0b1b7802b9a58 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Thu, 7 Mar 2024 17:43:55 +0100
Subject: Add a notation for tuple field accesses in Lean

---
 backends/lean/Base/Primitives/Base.lean | 51 +++++++++++++++++++++++++++++++++
 1 file changed, 51 insertions(+)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index 9dbaf133..adec9a8b 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -123,6 +123,57 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } :=
   simp [Bind.bind]
   cases e <;> simp
 
+-------------------------------
+-- Tuple field access syntax --
+-------------------------------
+-- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple
+-- The `noWs` parser is used to ensure there is no whitespace.
+syntax term noWs ".#" noWs num : term
+
+open Lean Meta Elab Term
+
+-- Auxliary function for computing the number of elements in a tuple (`Prod`) type.
+def getArity (type : Expr) : Nat :=
+  match type with
+  | .app (.app (.const ``Prod _) _) as => getArity as + 1
+  | _ => 1 -- It is not product
+
+-- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element
+def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do
+  match i with
+  | 0 => mkAppM ``Prod.fst #[tuple]
+  | i+1 =>
+    if n = 2 then
+      -- If the tuple has only two elements and `i` is not `0`,
+      -- we just return the second element.
+      mkAppM ``Prod.snd #[tuple]
+    else
+      -- Otherwise, we continue with the rest of the tuple.
+      let tuple ← mkAppM ``Prod.snd #[tuple]
+      mkGetIdx tuple (n-1) i
+
+-- Now, we define the elaboration function for the new syntax `a#i`
+elab_rules : term
+| `($a:term.#$i:num) => do
+  -- Convert `i : Syntax` into a natural number
+  let i := i.getNat
+  -- Return error if it is 0.
+  unless i ≥ 0 do
+    throwError "tuple index must be greater or equal to 0"
+  -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type
+  let tuple ← elabTerm a none
+  let type ← inferType tuple
+  -- Instantiate assigned metavariable occurring in `type`
+  let type ← instantiateMVars type
+  -- Ensure `tuple`'s type is a `Prod`uct.
+  unless type.isAppOf ``Prod do
+    throwError "tuple expected{indentExpr type}"
+  let n := getArity type
+  -- Ensure `i` is a valid index
+  unless i < n do
+    throwError "invalid tuple access at {i}, tuple has {n} elements"
+  mkGetIdx tuple n i
+
 ----------
 -- MISC --
 ----------
-- 
cgit v1.2.3


From 23ce25c77052c02312f19f17c51fe0b61d6abc93 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Thu, 7 Mar 2024 18:17:41 +0100
Subject: Introduce a notation for constant scalars in match patterns

---
 backends/lean/Base/Primitives/Scalar.lean | 29 +++++++++++++++++++++++++++++
 1 file changed, 29 insertions(+)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 285bc7fb..422cbc6a 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -488,6 +488,17 @@ class HNeg (α : Type u) (β : outParam (Type v)) where
 
 prefix:75  "-"   => HNeg.hNeg
 
+/- We need this, otherwise we break pattern matching like in:
+
+   ```
+   def is_minus_one (x : Int) : Bool :=
+     match x with
+     | -1 => true
+     | _ => false
+   ```
+-/
+attribute [match_pattern] HNeg.hNeg
+
 instance : HNeg Isize (Result Isize) where hNeg x := Scalar.neg x
 instance : HNeg I8 (Result I8) where hNeg x := Scalar.neg x
 instance : HNeg I16 (Result I16) where hNeg x := Scalar.neg x
@@ -1113,4 +1124,22 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) :=
 --   else
 --     .fail integerOverflow
 
+-- Notation for pattern matching
+-- We make the precedence looser than the negation.
+notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
+
+example {ty} (x : Scalar ty) : ℤ :=
+  match x with
+  | v#scalar => v
+
+example {ty} (x : Scalar ty) : Bool :=
+  match x with
+  | 1#scalar => true
+  | _ => false
+
+example {ty} (x : Scalar ty) : Bool :=
+  match x with
+  | -(1 : Int)#scalar => true
+  | _ => false
+
 end Primitives
-- 
cgit v1.2.3


From bc397dea5c5a67766c9c0381efad222524f68881 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Fri, 8 Mar 2024 07:56:44 +0100
Subject: Update the notation for heterogeneous negation

---
 backends/lean/Base/Primitives/Scalar.lean | 15 ++++++++++++---
 1 file changed, 12 insertions(+), 3 deletions(-)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 422cbc6a..3afd13d2 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -478,15 +478,24 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by
 Remark: there is no heterogeneous negation in the Lean prelude: we thus introduce
 one here.
 
-The notation typeclass for heterogeneous addition.
-This enables the notation `- a : β` where `a : α`.
+The notation typeclass for heterogeneous negation.
 -/
 class HNeg (α : Type u) (β : outParam (Type v)) where
   /-- `- a` computes the negation of `a`.
   The meaning of this notation is type-dependent. -/
   hNeg : α → β
 
-prefix:75  "-"   => HNeg.hNeg
+/- Notation for heterogeneous negation.
+
+   We initially used the notation "-" but it conflicted with the homogeneous
+   negation too much. In particular, it made terms like `-10` ambiguous,
+   and seemingly caused to backtracking in elaboration, leading to definitions
+   like arrays of constants to take an unreasonable time to get elaborated
+   and type-checked.
+
+   TODO: PR to replace Neg with HNeg in Lean?
+ -/
+prefix:75  "-."   => HNeg.hNeg
 
 /- We need this, otherwise we break pattern matching like in:
 
-- 
cgit v1.2.3


From 46b126f4e0e86f14475bc310e150948434726dc7 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Fri, 8 Mar 2024 08:50:50 +0100
Subject: Update the handling of notations like #u32 or #isize

---
 backends/lean/Base/Arith/Scalar.lean          |   2 +-
 backends/lean/Base/Primitives/ArraySlice.lean |   2 +-
 backends/lean/Base/Primitives/Scalar.lean     | 168 ++++++++++++++++----------
 backends/lean/Base/Primitives/Vec.lean        |   2 +-
 4 files changed, 104 insertions(+), 70 deletions(-)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index 43fd2766..9441be86 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -74,7 +74,7 @@ example : U32.ofInt 1 ≤ U32.max := by
   scalar_tac
 
 example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) :
-  U32.ofInt x (by constructor <;> scalar_tac) ≤ U32.max := by
+  U32.ofIntCore x (by constructor <;> scalar_tac) ≤ U32.max := by
   scalar_tac
 
 -- Not equal
diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean
index c90a85b8..e1a39d40 100644
--- a/backends/lean/Base/Primitives/ArraySlice.lean
+++ b/backends/lean/Base/Primitives/ArraySlice.lean
@@ -131,7 +131,7 @@ def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .
 
 -- TODO: very annoying that the α is an explicit parameter
 def Slice.len (α : Type u) (v : Slice α) : Usize :=
-  Usize.ofIntCore v.val.len (by scalar_tac) (by scalar_tac)
+  Usize.ofIntCore v.val.len (by constructor <;> scalar_tac)
 
 @[simp]
 theorem Slice.len_val {α : Type u} (v : Slice α) : (Slice.len α v).val = v.length :=
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 3afd13d2..bf6b01a6 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -281,25 +281,38 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
   λ h => by
   apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith
 
-def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
-  (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty :=
-  { val := x, hmin := hmin, hmax := hmax }
-
--- Tactic to prove that integers are in bounds
--- TODO: use this: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam
-syntax "intlit" : tactic
-macro_rules
-  | `(tactic| intlit) => `(tactic| apply Scalar.bound_suffices; decide)
-
-def Scalar.ofInt {ty : ScalarTy} (x : Int)
-  (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by intlit) : Scalar ty :=
-  -- Remark: we initially wrote:
-  --  let ⟨ hmin, hmax ⟩ := h
-  --  Scalar.ofIntCore x hmin hmax
-  -- We updated to the line below because a similar pattern in `Scalar.tryMk`
-  -- made reduction block. Both versions seem to work for `Scalar.ofInt`, though.
-  -- TODO: investigate
-  Scalar.ofIntCore x h.left h.right
+/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns.
+   This is particularly useful once we introduce notations like `#u32` (which
+   desugards to `Scalar.ofIntCore`) as it allows to write expressions like this:
+   Example:
+   ```
+   match x with
+   | 0#u32 => ...
+   | 1#u32 => ...
+   | ...
+   ```
+ -/
+@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
+  (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty :=
+  { val := x, hmin := h.left, hmax := h.right }
+
+-- The definitions below are used later to introduce nice syntax for constants,
+-- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284
+
+class InBounds (ty : ScalarTy) (x : Int) :=
+  hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty
+
+-- This trick to trigger reduction for decidable propositions comes from
+-- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807
+class Decide (p : Prop) [Decidable p] : Prop where
+  isTrue : p
+instance : @Decide p (.isTrue h) := @Decide.mk p (_) h
+
+instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where
+  hInBounds := Decide.isTrue
+
+@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty :=
+  Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds)
 
 @[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool :=
   (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty)
@@ -326,7 +339,7 @@ def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) :=
     -- ```
     -- then normalization blocks (for instance, some proofs which use reflexivity fail).
     -- However, the version below doesn't block reduction (TODO: investigate):
-    return Scalar.ofInt x (Scalar.check_bounds_prop h)
+    return Scalar.ofIntCore x (Scalar.check_bounds_prop h)
   else fail integerOverflow
 
 def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
@@ -439,8 +452,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by
   constructor; cases ty <;> apply (Scalar.ofInt 0)
 
 -- TODO: reducible?
-@[reducible] def core_isize_min : Isize := Scalar.ofInt Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
-@[reducible] def core_isize_max : Isize := Scalar.ofInt Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
+@[reducible] def core_isize_min : Isize := Scalar.ofIntCore Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
+@[reducible] def core_isize_max : Isize := Scalar.ofIntCore Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
 @[reducible] def core_i8_min : I8 := Scalar.ofInt I8.min
 @[reducible] def core_i8_max : I8 := Scalar.ofInt I8.max
 @[reducible] def core_i16_min : I16 := Scalar.ofInt I16.min
@@ -453,8 +466,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by
 @[reducible] def core_i128_max : I128 := Scalar.ofInt I128.max
 
 -- TODO: reducible?
-@[reducible] def core_usize_min : Usize := Scalar.ofInt Usize.min
-@[reducible] def core_usize_max : Usize := Scalar.ofInt Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize))
+@[reducible] def core_usize_min : Usize := Scalar.ofIntCore Usize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize))
+@[reducible] def core_usize_max : Usize := Scalar.ofIntCore Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize))
 @[reducible] def core_u8_min : U8 := Scalar.ofInt U8.min
 @[reducible] def core_u8_max : U8 := Scalar.ofInt U8.max
 @[reducible] def core_u16_min : U16 := Scalar.ofInt U16.min
@@ -985,18 +998,18 @@ def U128.ofIntCore  := @Scalar.ofIntCore .U128
 
 --  ofInt
 -- TODO: typeclass?
-abbrev Isize.ofInt := @Scalar.ofInt .Isize
-abbrev I8.ofInt    := @Scalar.ofInt .I8
-abbrev I16.ofInt   := @Scalar.ofInt .I16
-abbrev I32.ofInt   := @Scalar.ofInt .I32
-abbrev I64.ofInt   := @Scalar.ofInt .I64
-abbrev I128.ofInt  := @Scalar.ofInt .I128
-abbrev Usize.ofInt := @Scalar.ofInt .Usize
-abbrev U8.ofInt    := @Scalar.ofInt .U8
-abbrev U16.ofInt   := @Scalar.ofInt .U16
-abbrev U32.ofInt   := @Scalar.ofInt .U32
-abbrev U64.ofInt   := @Scalar.ofInt .U64
-abbrev U128.ofInt  := @Scalar.ofInt .U128
+@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize
+@[match_pattern] abbrev I8.ofInt    := @Scalar.ofInt .I8
+@[match_pattern] abbrev I16.ofInt   := @Scalar.ofInt .I16
+@[match_pattern] abbrev I32.ofInt   := @Scalar.ofInt .I32
+@[match_pattern] abbrev I64.ofInt   := @Scalar.ofInt .I64
+@[match_pattern] abbrev I128.ofInt  := @Scalar.ofInt .I128
+@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize
+@[match_pattern] abbrev U8.ofInt    := @Scalar.ofInt .U8
+@[match_pattern] abbrev U16.ofInt   := @Scalar.ofInt .U16
+@[match_pattern] abbrev U32.ofInt   := @Scalar.ofInt .U32
+@[match_pattern] abbrev U64.ofInt   := @Scalar.ofInt .U64
+@[match_pattern] abbrev U128.ofInt  := @Scalar.ofInt .U128
 
 postfix:max "#isize" => Isize.ofInt
 postfix:max "#i8"    => I8.ofInt
@@ -1011,47 +1024,86 @@ postfix:max "#u32"   => U32.ofInt
 postfix:max "#u64"   => U64.ofInt
 postfix:max "#u128"  => U128.ofInt
 
+/- Testing the notations -/
+example := 0#u32
+example := 1#u32
+example := 1#i32
+example := 0#isize
+example := (-1)#isize
+example (x : U32) : Bool :=
+  match x with
+  | 0#u32 => true
+  | _ => false
+
+example (x : U32) : Bool :=
+  match x with
+  | 1#u32 => true
+  | _ => false
+
+example (x : I32) : Bool :=
+  match x with
+  | (-1)#i32 => true
+  | _ => false
+
+-- Notation for pattern matching
+-- We make the precedence looser than the negation.
+notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
+
+example {ty} (x : Scalar ty) : ℤ :=
+  match x with
+  | v#scalar => v
+
+example {ty} (x : Scalar ty) : Bool :=
+  match x with
+  | 1#scalar => true
+  | _ => false
+
+example {ty} (x : Scalar ty) : Bool :=
+  match x with
+  | -(1 : Int)#scalar => true
+  | _ => false
+
 -- Testing the notations
 example : Result Usize := 0#usize + 1#usize
 
 -- TODO: factor those lemmas out
-@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofInt x h).val = x := by
+@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by
   simp [Scalar.ofInt, Scalar.ofIntCore]
 
-@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofInt x h).val = x := by
+@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofInt x h).val = x := by
+@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofInt x h).val = x := by
+@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofInt x h).val = x := by
+@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofInt x h).val = x := by
+@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofInt x h).val = x := by
+@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofInt x h).val = x := by
+@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofInt x h).val = x := by
+@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofInt x h).val = x := by
+@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofInt x h).val = x := by
+@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofInt x h).val = x := by
+@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
-@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofInt x h).val = x := by
+@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofIntCore x h).val = x := by
   apply Scalar.ofInt_val_eq h
 
 -- Comparisons
@@ -1133,22 +1185,4 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) :=
 --   else
 --     .fail integerOverflow
 
--- Notation for pattern matching
--- We make the precedence looser than the negation.
-notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
-
-example {ty} (x : Scalar ty) : ℤ :=
-  match x with
-  | v#scalar => v
-
-example {ty} (x : Scalar ty) : Bool :=
-  match x with
-  | 1#scalar => true
-  | _ => false
-
-example {ty} (x : Scalar ty) : Bool :=
-  match x with
-  | -(1 : Int)#scalar => true
-  | _ => false
-
 end Primitives
diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean
index b03de15b..65249c12 100644
--- a/backends/lean/Base/Primitives/Vec.lean
+++ b/backends/lean/Base/Primitives/Vec.lean
@@ -43,7 +43,7 @@ instance (α : Type u) : Inhabited (Vec α) := by
 
 -- TODO: very annoying that the α is an explicit parameter
 def Vec.len (α : Type u) (v : Vec α) : Usize :=
-  Usize.ofIntCore v.val.len (by scalar_tac) (by scalar_tac)
+  Usize.ofIntCore v.val.len (by constructor <;> scalar_tac)
 
 @[simp]
 theorem Vec.len_val {α : Type u} (v : Vec α) : (Vec.len α v).val = v.length :=
-- 
cgit v1.2.3


From b6f63f106baef03dd61f1100bd46c9bad7cb79e4 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Fri, 8 Mar 2024 08:53:38 +0100
Subject: Remove some comments

---
 backends/lean/Base/Primitives/Scalar.lean | 31 -------------------------------
 1 file changed, 31 deletions(-)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index bf6b01a6..3d90f1a5 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -1154,35 +1154,4 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) :=
 @[simp] theorem Scalar.neq_to_neq_val {ty} : ∀ {i j : Scalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by
   intro i j; cases i; cases j; simp
 
--- -- We now define a type class that subsumes the various machine integer types, so
--- -- as to write a concise definition for scalar_cast, rather than exhaustively
--- -- enumerating all of the possible pairs. We remark that Rust has sane semantics
--- -- and fails if a cast operation would involve a truncation or modulo.
-
--- class MachineInteger (t: Type) where
---   size: Nat
---   val: t -> Fin size
---   ofNatCore: (n:Nat) -> LT.lt n size -> t
-
--- set_option hygiene false in
--- run_cmd
---   for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do
---   Lean.Elab.Command.elabCommand (← `(
---     namespace $typeName
---     instance: MachineInteger $typeName where
---       size := size
---       val := val
---       ofNatCore := ofNatCore
---     end $typeName
---   ))
-
--- -- Aeneas only instantiates the destination type (`src` is implicit). We rely on
--- -- Lean to infer `src`.
-
--- def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): Result dst :=
---   if h: MachineInteger.val x < MachineInteger.size dst then
---     .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h)
---   else
---     .fail integerOverflow
-
 end Primitives
-- 
cgit v1.2.3


From 41d6f78a0ad6bd272164894bead3258b2001ec0c Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Fri, 8 Mar 2024 09:22:08 +0100
Subject: Update the tuples notations

---
 backends/lean/Base/Primitives.lean      |  1 +
 backends/lean/Base/Primitives/Base.lean | 51 ---------------------
 backends/lean/Base/Tuples.lean          | 80 +++++++++++++++++++++++++++++++++
 3 files changed, 81 insertions(+), 51 deletions(-)
 create mode 100644 backends/lean/Base/Tuples.lean

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean
index 613b6076..7196d2ec 100644
--- a/backends/lean/Base/Primitives.lean
+++ b/backends/lean/Base/Primitives.lean
@@ -1,4 +1,5 @@
 import Base.Primitives.Base
+import Base.Tuples
 import Base.Primitives.Scalar
 import Base.Primitives.ArraySlice
 import Base.Primitives.Vec
diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index adec9a8b..9dbaf133 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -123,57 +123,6 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } :=
   simp [Bind.bind]
   cases e <;> simp
 
--------------------------------
--- Tuple field access syntax --
--------------------------------
--- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple
--- The `noWs` parser is used to ensure there is no whitespace.
-syntax term noWs ".#" noWs num : term
-
-open Lean Meta Elab Term
-
--- Auxliary function for computing the number of elements in a tuple (`Prod`) type.
-def getArity (type : Expr) : Nat :=
-  match type with
-  | .app (.app (.const ``Prod _) _) as => getArity as + 1
-  | _ => 1 -- It is not product
-
--- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element
-def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do
-  match i with
-  | 0 => mkAppM ``Prod.fst #[tuple]
-  | i+1 =>
-    if n = 2 then
-      -- If the tuple has only two elements and `i` is not `0`,
-      -- we just return the second element.
-      mkAppM ``Prod.snd #[tuple]
-    else
-      -- Otherwise, we continue with the rest of the tuple.
-      let tuple ← mkAppM ``Prod.snd #[tuple]
-      mkGetIdx tuple (n-1) i
-
--- Now, we define the elaboration function for the new syntax `a#i`
-elab_rules : term
-| `($a:term.#$i:num) => do
-  -- Convert `i : Syntax` into a natural number
-  let i := i.getNat
-  -- Return error if it is 0.
-  unless i ≥ 0 do
-    throwError "tuple index must be greater or equal to 0"
-  -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type
-  let tuple ← elabTerm a none
-  let type ← inferType tuple
-  -- Instantiate assigned metavariable occurring in `type`
-  let type ← instantiateMVars type
-  -- Ensure `tuple`'s type is a `Prod`uct.
-  unless type.isAppOf ``Prod do
-    throwError "tuple expected{indentExpr type}"
-  let n := getArity type
-  -- Ensure `i` is a valid index
-  unless i < n do
-    throwError "invalid tuple access at {i}, tuple has {n} elements"
-  mkGetIdx tuple n i
-
 ----------
 -- MISC --
 ----------
diff --git a/backends/lean/Base/Tuples.lean b/backends/lean/Base/Tuples.lean
new file mode 100644
index 00000000..d8e4a843
--- /dev/null
+++ b/backends/lean/Base/Tuples.lean
@@ -0,0 +1,80 @@
+import Lean
+import Base.Utils
+
+namespace Primitives
+
+-------------------------------
+-- Tuple field access syntax --
+-------------------------------
+-- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple
+-- The `noWs` parser is used to ensure there is no whitespace.
+syntax term noWs ".#" noWs num : term
+
+open Lean Meta Elab Term
+
+-- Auxliary function for computing the number of elements in a tuple (`Prod`) type.
+def getArity (type : Expr) : Nat :=
+  match type with
+  | .app (.app (.const ``Prod _) _) as => getArity as + 1
+  | _ => 1 -- It is not product
+
+-- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element
+def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do
+  match i with
+  | 0 => mkAppM ``Prod.fst #[tuple]
+  | i+1 =>
+    if n = 2 then
+      -- If the tuple has only two elements and `i` is not `0`,
+      -- we just return the second element.
+      mkAppM ``Prod.snd #[tuple]
+    else
+      -- Otherwise, we continue with the rest of the tuple.
+      let tuple ← mkAppM ``Prod.snd #[tuple]
+      mkGetIdx tuple (n-1) i
+
+-- Now, we define the elaboration function for the new syntax `a#i`
+elab_rules : term
+| `($a:term.#$i:num) => do
+  -- Convert `i : Syntax` into a natural number
+  let i := i.getNat
+  -- Return error if it is 0.
+  unless i ≥ 0 do
+    throwError "tuple index must be greater or equal to 0"
+  -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type
+  let tuple ← elabTerm a none
+  let type ← inferType tuple
+  -- Instantiate assigned metavariable occurring in `type`
+  let type ← instantiateMVars type
+  /- In case we are indexing into a type abbreviation, we need to unfold the type.
+
+     TODO: we have to be careful about not unfolding too much,
+     for instance because of the following code:
+     ```
+     def Pair T U := T × U
+     def Tuple T U V := T × Pair U V
+     ```
+     We have to make sure that, given `x : Tuple T U V`, `x.1` evaluates
+     to the pair (an element of type `Pair T U`), not to the first field
+     of the pair (an element of type `T`).
+
+     We have a similar issue below if we generate code from the following Rust definition:
+     ```
+     struct Tuple(u32, (u32, u32));
+     ```
+     The issue is that in Rust, field 1 of `Tuple` is a pair `(u32, u32)`, but
+     in Lean there is no difference between `A × B × C` and `A × (B × C)`.
+
+     In case such situations happen we probably need to resort to chaining
+     the pair projectors, like in: `x.snd.fst`.
+   -/
+  let type ← whnf type
+  -- Ensure `tuple`'s type is a `Prod`uct.
+  unless type.isAppOf ``Prod do
+    throwError "tuple expected{indentExpr type}"
+  let n := getArity type
+  -- Ensure `i` is a valid index
+  unless i < n do
+    throwError "invalid tuple access at {i}, tuple has {n} elements"
+  mkGetIdx tuple n i
+
+end Primitives
-- 
cgit v1.2.3


From 9d541d1ab6b91e59e4f78f4711af085a33ee4f82 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Fri, 8 Mar 2024 09:25:11 +0100
Subject: Update the tuples syntax

---
 backends/lean/Base/Tuples.lean | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Tuples.lean b/backends/lean/Base/Tuples.lean
index d8e4a843..4c59dac9 100644
--- a/backends/lean/Base/Tuples.lean
+++ b/backends/lean/Base/Tuples.lean
@@ -8,7 +8,9 @@ namespace Primitives
 -------------------------------
 -- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple
 -- The `noWs` parser is used to ensure there is no whitespace.
-syntax term noWs ".#" noWs num : term
+-- We use the maximum precedence to make the syntax work with function calls.
+-- Ex.: `f (0, 1).#0`
+syntax:max term noWs ".#" noWs num : term
 
 open Lean Meta Elab Term
 
-- 
cgit v1.2.3


From 44248ccfe3bfb8c45e5bb434d8dfb3dfa6e6b69c Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Fri, 8 Mar 2024 09:42:29 +0100
Subject: Update the generation of constant bodies for Lean

---
 backends/lean/Base/Primitives/Base.lean | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index 9dbaf133..0b9d9c39 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -69,7 +69,7 @@ def div? {α: Type u} (r: Result α): Bool :=
 def massert (b:Bool) : Result Unit :=
   if b then ret () else fail assertionFailure
 
-def eval_global {α: Type u} (x: Result α) (_: ret? x): α :=
+def eval_global {α: Type u} (x: Result α) (_: ret? x := by decide): α :=
   match x with
   | fail _ | div => by contradiction
   | ret x => x
@@ -78,7 +78,7 @@ def eval_global {α: Type u} (x: Result α) (_: ret? x): α :=
 
 def bind {α : Type u} {β : Type v} (x: Result α) (f: α → Result β) : Result β :=
   match x with
-  | ret v  => f v 
+  | ret v  => f v
   | fail v => fail v
   | div => div
 
-- 
cgit v1.2.3


From 5427563a8000f281ac614a2501fb9983beb44f21 Mon Sep 17 00:00:00 2001
From: Zyad Hassan
Date: Fri, 23 Feb 2024 16:37:58 -0800
Subject: Fix tuple indexing for Lean backend

---
 backends/lean/Base/IList/IList.lean | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

(limited to 'backends/lean/Base')

diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean
index 51457c20..ca5ee266 100644
--- a/backends/lean/Base/IList/IList.lean
+++ b/backends/lean/Base/IList/IList.lean
@@ -33,7 +33,7 @@ def indexOpt (ls : List α) (i : Int) : Option α :=
 @[simp] theorem indexOpt_zero_cons : indexOpt ((x :: tl) : List α) 0 = some x := by simp [indexOpt]
 @[simp] theorem indexOpt_nzero_cons (hne : i ≠ 0) : indexOpt ((x :: tl) : List α) i = indexOpt tl (i - 1) := by simp [*, indexOpt]
 
--- Remark: if i < 0, then the result is the defaul element
+-- Remark: if i < 0, then the result is the default element
 def index [Inhabited α] (ls : List α) (i : Int) : α :=
   match ls with
   | [] => Inhabited.default
-- 
cgit v1.2.3