summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--backends/lean/Base/Primitives.lean1
-rw-r--r--backends/lean/Base/Primitives/Base.lean51
-rw-r--r--backends/lean/Base/Tuples.lean80
3 files changed, 81 insertions, 51 deletions
diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean
index 613b6076..7196d2ec 100644
--- a/backends/lean/Base/Primitives.lean
+++ b/backends/lean/Base/Primitives.lean
@@ -1,4 +1,5 @@
import Base.Primitives.Base
+import Base.Tuples
import Base.Primitives.Scalar
import Base.Primitives.ArraySlice
import Base.Primitives.Vec
diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index adec9a8b..9dbaf133 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -123,57 +123,6 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } :=
simp [Bind.bind]
cases e <;> simp
--------------------------------
--- Tuple field access syntax --
--------------------------------
--- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple
--- The `noWs` parser is used to ensure there is no whitespace.
-syntax term noWs ".#" noWs num : term
-
-open Lean Meta Elab Term
-
--- Auxliary function for computing the number of elements in a tuple (`Prod`) type.
-def getArity (type : Expr) : Nat :=
- match type with
- | .app (.app (.const ``Prod _) _) as => getArity as + 1
- | _ => 1 -- It is not product
-
--- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element
-def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do
- match i with
- | 0 => mkAppM ``Prod.fst #[tuple]
- | i+1 =>
- if n = 2 then
- -- If the tuple has only two elements and `i` is not `0`,
- -- we just return the second element.
- mkAppM ``Prod.snd #[tuple]
- else
- -- Otherwise, we continue with the rest of the tuple.
- let tuple ← mkAppM ``Prod.snd #[tuple]
- mkGetIdx tuple (n-1) i
-
--- Now, we define the elaboration function for the new syntax `a#i`
-elab_rules : term
-| `($a:term.#$i:num) => do
- -- Convert `i : Syntax` into a natural number
- let i := i.getNat
- -- Return error if it is 0.
- unless i ≥ 0 do
- throwError "tuple index must be greater or equal to 0"
- -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type
- let tuple ← elabTerm a none
- let type ← inferType tuple
- -- Instantiate assigned metavariable occurring in `type`
- let type ← instantiateMVars type
- -- Ensure `tuple`'s type is a `Prod`uct.
- unless type.isAppOf ``Prod do
- throwError "tuple expected{indentExpr type}"
- let n := getArity type
- -- Ensure `i` is a valid index
- unless i < n do
- throwError "invalid tuple access at {i}, tuple has {n} elements"
- mkGetIdx tuple n i
-
----------
-- MISC --
----------
diff --git a/backends/lean/Base/Tuples.lean b/backends/lean/Base/Tuples.lean
new file mode 100644
index 00000000..d8e4a843
--- /dev/null
+++ b/backends/lean/Base/Tuples.lean
@@ -0,0 +1,80 @@
+import Lean
+import Base.Utils
+
+namespace Primitives
+
+-------------------------------
+-- Tuple field access syntax --
+-------------------------------
+-- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple
+-- The `noWs` parser is used to ensure there is no whitespace.
+syntax term noWs ".#" noWs num : term
+
+open Lean Meta Elab Term
+
+-- Auxliary function for computing the number of elements in a tuple (`Prod`) type.
+def getArity (type : Expr) : Nat :=
+ match type with
+ | .app (.app (.const ``Prod _) _) as => getArity as + 1
+ | _ => 1 -- It is not product
+
+-- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element
+def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do
+ match i with
+ | 0 => mkAppM ``Prod.fst #[tuple]
+ | i+1 =>
+ if n = 2 then
+ -- If the tuple has only two elements and `i` is not `0`,
+ -- we just return the second element.
+ mkAppM ``Prod.snd #[tuple]
+ else
+ -- Otherwise, we continue with the rest of the tuple.
+ let tuple ← mkAppM ``Prod.snd #[tuple]
+ mkGetIdx tuple (n-1) i
+
+-- Now, we define the elaboration function for the new syntax `a#i`
+elab_rules : term
+| `($a:term.#$i:num) => do
+ -- Convert `i : Syntax` into a natural number
+ let i := i.getNat
+ -- Return error if it is 0.
+ unless i ≥ 0 do
+ throwError "tuple index must be greater or equal to 0"
+ -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type
+ let tuple ← elabTerm a none
+ let type ← inferType tuple
+ -- Instantiate assigned metavariable occurring in `type`
+ let type ← instantiateMVars type
+ /- In case we are indexing into a type abbreviation, we need to unfold the type.
+
+ TODO: we have to be careful about not unfolding too much,
+ for instance because of the following code:
+ ```
+ def Pair T U := T × U
+ def Tuple T U V := T × Pair U V
+ ```
+ We have to make sure that, given `x : Tuple T U V`, `x.1` evaluates
+ to the pair (an element of type `Pair T U`), not to the first field
+ of the pair (an element of type `T`).
+
+ We have a similar issue below if we generate code from the following Rust definition:
+ ```
+ struct Tuple(u32, (u32, u32));
+ ```
+ The issue is that in Rust, field 1 of `Tuple` is a pair `(u32, u32)`, but
+ in Lean there is no difference between `A × B × C` and `A × (B × C)`.
+
+ In case such situations happen we probably need to resort to chaining
+ the pair projectors, like in: `x.snd.fst`.
+ -/
+ let type ← whnf type
+ -- Ensure `tuple`'s type is a `Prod`uct.
+ unless type.isAppOf ``Prod do
+ throwError "tuple expected{indentExpr type}"
+ let n := getArity type
+ -- Ensure `i` is a valid index
+ unless i < n do
+ throwError "invalid tuple access at {i}, tuple has {n} elements"
+ mkGetIdx tuple n i
+
+end Primitives