summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Primitives/Base.lean
diff options
context:
space:
mode:
authorSon Ho2024-03-07 17:43:55 +0100
committerSon Ho2024-03-07 17:43:55 +0100
commit124ee77181c4255e2c8f730305b0b1b7802b9a58 (patch)
tree96097be160795705d479a39ccd165977a3bb9f1d /backends/lean/Base/Primitives/Base.lean
parent305f916c602457b0a1fa8ce5569c6c0bf26d6f8e (diff)
Add a notation for tuple field accesses in Lean
Diffstat (limited to 'backends/lean/Base/Primitives/Base.lean')
-rw-r--r--backends/lean/Base/Primitives/Base.lean51
1 files changed, 51 insertions, 0 deletions
diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index 9dbaf133..adec9a8b 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -123,6 +123,57 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } :=
simp [Bind.bind]
cases e <;> simp
+-------------------------------
+-- Tuple field access syntax --
+-------------------------------
+-- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple
+-- The `noWs` parser is used to ensure there is no whitespace.
+syntax term noWs ".#" noWs num : term
+
+open Lean Meta Elab Term
+
+-- Auxliary function for computing the number of elements in a tuple (`Prod`) type.
+def getArity (type : Expr) : Nat :=
+ match type with
+ | .app (.app (.const ``Prod _) _) as => getArity as + 1
+ | _ => 1 -- It is not product
+
+-- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element
+def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do
+ match i with
+ | 0 => mkAppM ``Prod.fst #[tuple]
+ | i+1 =>
+ if n = 2 then
+ -- If the tuple has only two elements and `i` is not `0`,
+ -- we just return the second element.
+ mkAppM ``Prod.snd #[tuple]
+ else
+ -- Otherwise, we continue with the rest of the tuple.
+ let tuple ← mkAppM ``Prod.snd #[tuple]
+ mkGetIdx tuple (n-1) i
+
+-- Now, we define the elaboration function for the new syntax `a#i`
+elab_rules : term
+| `($a:term.#$i:num) => do
+ -- Convert `i : Syntax` into a natural number
+ let i := i.getNat
+ -- Return error if it is 0.
+ unless i ≥ 0 do
+ throwError "tuple index must be greater or equal to 0"
+ -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type
+ let tuple ← elabTerm a none
+ let type ← inferType tuple
+ -- Instantiate assigned metavariable occurring in `type`
+ let type ← instantiateMVars type
+ -- Ensure `tuple`'s type is a `Prod`uct.
+ unless type.isAppOf ``Prod do
+ throwError "tuple expected{indentExpr type}"
+ let n := getArity type
+ -- Ensure `i` is a valid index
+ unless i < n do
+ throwError "invalid tuple access at {i}, tuple has {n} elements"
+ mkGetIdx tuple n i
+
----------
-- MISC --
----------