summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2024-03-08 08:50:50 +0100
committerSon Ho2024-03-08 08:50:50 +0100
commit46b126f4e0e86f14475bc310e150948434726dc7 (patch)
tree47f1ee50bfbb29c9c1d74c492f2c817f2709d2e0 /backends/lean
parentf74647773d7dd21580fd938dd9b4e300719b0234 (diff)
Update the handling of notations like #u32 or #isize
Diffstat (limited to 'backends/lean')
-rw-r--r--backends/lean/Base/Arith/Scalar.lean2
-rw-r--r--backends/lean/Base/Primitives/ArraySlice.lean2
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean168
-rw-r--r--backends/lean/Base/Primitives/Vec.lean2
4 files changed, 104 insertions, 70 deletions
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index 43fd2766..9441be86 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -74,7 +74,7 @@ example : U32.ofInt 1 ≤ U32.max := by
scalar_tac
example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) :
- U32.ofInt x (by constructor <;> scalar_tac) ≤ U32.max := by
+ U32.ofIntCore x (by constructor <;> scalar_tac) ≤ U32.max := by
scalar_tac
-- Not equal
diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean
index c90a85b8..e1a39d40 100644
--- a/backends/lean/Base/Primitives/ArraySlice.lean
+++ b/backends/lean/Base/Primitives/ArraySlice.lean
@@ -131,7 +131,7 @@ def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .
-- TODO: very annoying that the α is an explicit parameter
def Slice.len (α : Type u) (v : Slice α) : Usize :=
- Usize.ofIntCore v.val.len (by scalar_tac) (by scalar_tac)
+ Usize.ofIntCore v.val.len (by constructor <;> scalar_tac)
@[simp]
theorem Slice.len_val {α : Type u} (v : Slice α) : (Slice.len α v).val = v.length :=
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 3afd13d2..bf6b01a6 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -281,25 +281,38 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
λ h => by
apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith
-def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
- (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty :=
- { val := x, hmin := hmin, hmax := hmax }
-
--- Tactic to prove that integers are in bounds
--- TODO: use this: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam
-syntax "intlit" : tactic
-macro_rules
- | `(tactic| intlit) => `(tactic| apply Scalar.bound_suffices; decide)
-
-def Scalar.ofInt {ty : ScalarTy} (x : Int)
- (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by intlit) : Scalar ty :=
- -- Remark: we initially wrote:
- -- let ⟨ hmin, hmax ⟩ := h
- -- Scalar.ofIntCore x hmin hmax
- -- We updated to the line below because a similar pattern in `Scalar.tryMk`
- -- made reduction block. Both versions seem to work for `Scalar.ofInt`, though.
- -- TODO: investigate
- Scalar.ofIntCore x h.left h.right
+/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns.
+ This is particularly useful once we introduce notations like `#u32` (which
+ desugards to `Scalar.ofIntCore`) as it allows to write expressions like this:
+ Example:
+ ```
+ match x with
+ | 0#u32 => ...
+ | 1#u32 => ...
+ | ...
+ ```
+ -/
+@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
+ (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty :=
+ { val := x, hmin := h.left, hmax := h.right }
+
+-- The definitions below are used later to introduce nice syntax for constants,
+-- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284
+
+class InBounds (ty : ScalarTy) (x : Int) :=
+ hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty
+
+-- This trick to trigger reduction for decidable propositions comes from
+-- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807
+class Decide (p : Prop) [Decidable p] : Prop where
+ isTrue : p
+instance : @Decide p (.isTrue h) := @Decide.mk p (_) h
+
+instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where
+ hInBounds := Decide.isTrue
+
+@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty :=
+ Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds)
@[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool :=
(Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty)
@@ -326,7 +339,7 @@ def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) :=
-- ```
-- then normalization blocks (for instance, some proofs which use reflexivity fail).
-- However, the version below doesn't block reduction (TODO: investigate):
- return Scalar.ofInt x (Scalar.check_bounds_prop h)
+ return Scalar.ofIntCore x (Scalar.check_bounds_prop h)
else fail integerOverflow
def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
@@ -439,8 +452,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by
constructor; cases ty <;> apply (Scalar.ofInt 0)
-- TODO: reducible?
-@[reducible] def core_isize_min : Isize := Scalar.ofInt Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
-@[reducible] def core_isize_max : Isize := Scalar.ofInt Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
+@[reducible] def core_isize_min : Isize := Scalar.ofIntCore Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
+@[reducible] def core_isize_max : Isize := Scalar.ofIntCore Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize))
@[reducible] def core_i8_min : I8 := Scalar.ofInt I8.min
@[reducible] def core_i8_max : I8 := Scalar.ofInt I8.max
@[reducible] def core_i16_min : I16 := Scalar.ofInt I16.min
@@ -453,8 +466,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by
@[reducible] def core_i128_max : I128 := Scalar.ofInt I128.max
-- TODO: reducible?
-@[reducible] def core_usize_min : Usize := Scalar.ofInt Usize.min
-@[reducible] def core_usize_max : Usize := Scalar.ofInt Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize))
+@[reducible] def core_usize_min : Usize := Scalar.ofIntCore Usize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize))
+@[reducible] def core_usize_max : Usize := Scalar.ofIntCore Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize))
@[reducible] def core_u8_min : U8 := Scalar.ofInt U8.min
@[reducible] def core_u8_max : U8 := Scalar.ofInt U8.max
@[reducible] def core_u16_min : U16 := Scalar.ofInt U16.min
@@ -985,18 +998,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128
-- ofInt
-- TODO: typeclass?
-abbrev Isize.ofInt := @Scalar.ofInt .Isize
-abbrev I8.ofInt := @Scalar.ofInt .I8
-abbrev I16.ofInt := @Scalar.ofInt .I16
-abbrev I32.ofInt := @Scalar.ofInt .I32
-abbrev I64.ofInt := @Scalar.ofInt .I64
-abbrev I128.ofInt := @Scalar.ofInt .I128
-abbrev Usize.ofInt := @Scalar.ofInt .Usize
-abbrev U8.ofInt := @Scalar.ofInt .U8
-abbrev U16.ofInt := @Scalar.ofInt .U16
-abbrev U32.ofInt := @Scalar.ofInt .U32
-abbrev U64.ofInt := @Scalar.ofInt .U64
-abbrev U128.ofInt := @Scalar.ofInt .U128
+@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize
+@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8
+@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16
+@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32
+@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64
+@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128
+@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize
+@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8
+@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16
+@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32
+@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64
+@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128
postfix:max "#isize" => Isize.ofInt
postfix:max "#i8" => I8.ofInt
@@ -1011,47 +1024,86 @@ postfix:max "#u32" => U32.ofInt
postfix:max "#u64" => U64.ofInt
postfix:max "#u128" => U128.ofInt
+/- Testing the notations -/
+example := 0#u32
+example := 1#u32
+example := 1#i32
+example := 0#isize
+example := (-1)#isize
+example (x : U32) : Bool :=
+ match x with
+ | 0#u32 => true
+ | _ => false
+
+example (x : U32) : Bool :=
+ match x with
+ | 1#u32 => true
+ | _ => false
+
+example (x : I32) : Bool :=
+ match x with
+ | (-1)#i32 => true
+ | _ => false
+
+-- Notation for pattern matching
+-- We make the precedence looser than the negation.
+notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
+
+example {ty} (x : Scalar ty) : ℤ :=
+ match x with
+ | v#scalar => v
+
+example {ty} (x : Scalar ty) : Bool :=
+ match x with
+ | 1#scalar => true
+ | _ => false
+
+example {ty} (x : Scalar ty) : Bool :=
+ match x with
+ | -(1 : Int)#scalar => true
+ | _ => false
+
-- Testing the notations
example : Result Usize := 0#usize + 1#usize
-- TODO: factor those lemmas out
-@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofInt x h).val = x := by
+@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by
simp [Scalar.ofInt, Scalar.ofIntCore]
-@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofInt x h).val = x := by
+@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofInt x h).val = x := by
+@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofInt x h).val = x := by
+@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofInt x h).val = x := by
+@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofInt x h).val = x := by
+@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofInt x h).val = x := by
+@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofInt x h).val = x := by
+@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofInt x h).val = x := by
+@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofInt x h).val = x := by
+@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofInt x h).val = x := by
+@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofInt x h).val = x := by
+@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofInt x h).val = x := by
+@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofIntCore x h).val = x := by
apply Scalar.ofInt_val_eq h
-- Comparisons
@@ -1133,22 +1185,4 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) :=
-- else
-- .fail integerOverflow
--- Notation for pattern matching
--- We make the precedence looser than the negation.
-notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
-
-example {ty} (x : Scalar ty) : ℤ :=
- match x with
- | v#scalar => v
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | 1#scalar => true
- | _ => false
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | -(1 : Int)#scalar => true
- | _ => false
-
end Primitives
diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean
index b03de15b..65249c12 100644
--- a/backends/lean/Base/Primitives/Vec.lean
+++ b/backends/lean/Base/Primitives/Vec.lean
@@ -43,7 +43,7 @@ instance (α : Type u) : Inhabited (Vec α) := by
-- TODO: very annoying that the α is an explicit parameter
def Vec.len (α : Type u) (v : Vec α) : Usize :=
- Usize.ofIntCore v.val.len (by scalar_tac) (by scalar_tac)
+ Usize.ofIntCore v.val.len (by constructor <;> scalar_tac)
@[simp]
theorem Vec.len_val {α : Type u} (v : Vec α) : (Vec.len α v).val = v.length :=