From 23ce25c77052c02312f19f17c51fe0b61d6abc93 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Mar 2024 18:17:41 +0100 Subject: Introduce a notation for constant scalars in match patterns --- backends/lean/Base/Primitives/Scalar.lean | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'backends/lean/Base/Primitives/Scalar.lean') diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 285bc7fb..422cbc6a 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -488,6 +488,17 @@ class HNeg (α : Type u) (β : outParam (Type v)) where prefix:75 "-" => HNeg.hNeg +/- We need this, otherwise we break pattern matching like in: + + ``` + def is_minus_one (x : Int) : Bool := + match x with + | -1 => true + | _ => false + ``` +-/ +attribute [match_pattern] HNeg.hNeg + instance : HNeg Isize (Result Isize) where hNeg x := Scalar.neg x instance : HNeg I8 (Result I8) where hNeg x := Scalar.neg x instance : HNeg I16 (Result I16) where hNeg x := Scalar.neg x @@ -1113,4 +1124,22 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) := -- else -- .fail integerOverflow +-- Notation for pattern matching +-- We make the precedence looser than the negation. +notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ + +example {ty} (x : Scalar ty) : ℤ := + match x with + | v#scalar => v + +example {ty} (x : Scalar ty) : Bool := + match x with + | 1#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : Bool := + match x with + | -(1 : Int)#scalar => true + | _ => false + end Primitives -- cgit v1.2.3 From bc397dea5c5a67766c9c0381efad222524f68881 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 07:56:44 +0100 Subject: Update the notation for heterogeneous negation --- backends/lean/Base/Primitives/Scalar.lean | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) (limited to 'backends/lean/Base/Primitives/Scalar.lean') diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 422cbc6a..3afd13d2 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -478,15 +478,24 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by Remark: there is no heterogeneous negation in the Lean prelude: we thus introduce one here. -The notation typeclass for heterogeneous addition. -This enables the notation `- a : β` where `a : α`. +The notation typeclass for heterogeneous negation. -/ class HNeg (α : Type u) (β : outParam (Type v)) where /-- `- a` computes the negation of `a`. The meaning of this notation is type-dependent. -/ hNeg : α → β -prefix:75 "-" => HNeg.hNeg +/- Notation for heterogeneous negation. + + We initially used the notation "-" but it conflicted with the homogeneous + negation too much. In particular, it made terms like `-10` ambiguous, + and seemingly caused to backtracking in elaboration, leading to definitions + like arrays of constants to take an unreasonable time to get elaborated + and type-checked. + + TODO: PR to replace Neg with HNeg in Lean? + -/ +prefix:75 "-." => HNeg.hNeg /- We need this, otherwise we break pattern matching like in: -- cgit v1.2.3 From 46b126f4e0e86f14475bc310e150948434726dc7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 08:50:50 +0100 Subject: Update the handling of notations like #u32 or #isize --- backends/lean/Base/Primitives/Scalar.lean | 168 ++++++++++++++++++------------ 1 file changed, 101 insertions(+), 67 deletions(-) (limited to 'backends/lean/Base/Primitives/Scalar.lean') diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 3afd13d2..bf6b01a6 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -281,25 +281,38 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : λ h => by apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith -def Scalar.ofIntCore {ty : ScalarTy} (x : Int) - (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty := - { val := x, hmin := hmin, hmax := hmax } - --- Tactic to prove that integers are in bounds --- TODO: use this: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam -syntax "intlit" : tactic -macro_rules - | `(tactic| intlit) => `(tactic| apply Scalar.bound_suffices; decide) - -def Scalar.ofInt {ty : ScalarTy} (x : Int) - (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by intlit) : Scalar ty := - -- Remark: we initially wrote: - -- let ⟨ hmin, hmax ⟩ := h - -- Scalar.ofIntCore x hmin hmax - -- We updated to the line below because a similar pattern in `Scalar.tryMk` - -- made reduction block. Both versions seem to work for `Scalar.ofInt`, though. - -- TODO: investigate - Scalar.ofIntCore x h.left h.right +/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns. + This is particularly useful once we introduce notations like `#u32` (which + desugards to `Scalar.ofIntCore`) as it allows to write expressions like this: + Example: + ``` + match x with + | 0#u32 => ... + | 1#u32 => ... + | ... + ``` + -/ +@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int) + (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty := + { val := x, hmin := h.left, hmax := h.right } + +-- The definitions below are used later to introduce nice syntax for constants, +-- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284 + +class InBounds (ty : ScalarTy) (x : Int) := + hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty + +-- This trick to trigger reduction for decidable propositions comes from +-- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807 +class Decide (p : Prop) [Decidable p] : Prop where + isTrue : p +instance : @Decide p (.isTrue h) := @Decide.mk p (_) h + +instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where + hInBounds := Decide.isTrue + +@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty := + Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds) @[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) @@ -326,7 +339,7 @@ def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := -- ``` -- then normalization blocks (for instance, some proofs which use reflexivity fail). -- However, the version below doesn't block reduction (TODO: investigate): - return Scalar.ofInt x (Scalar.check_bounds_prop h) + return Scalar.ofIntCore x (Scalar.check_bounds_prop h) else fail integerOverflow def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) @@ -439,8 +452,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by constructor; cases ty <;> apply (Scalar.ofInt 0) -- TODO: reducible? -@[reducible] def core_isize_min : Isize := Scalar.ofInt Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) -@[reducible] def core_isize_max : Isize := Scalar.ofInt Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) +@[reducible] def core_isize_min : Isize := Scalar.ofIntCore Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) +@[reducible] def core_isize_max : Isize := Scalar.ofIntCore Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) @[reducible] def core_i8_min : I8 := Scalar.ofInt I8.min @[reducible] def core_i8_max : I8 := Scalar.ofInt I8.max @[reducible] def core_i16_min : I16 := Scalar.ofInt I16.min @@ -453,8 +466,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by @[reducible] def core_i128_max : I128 := Scalar.ofInt I128.max -- TODO: reducible? -@[reducible] def core_usize_min : Usize := Scalar.ofInt Usize.min -@[reducible] def core_usize_max : Usize := Scalar.ofInt Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize)) +@[reducible] def core_usize_min : Usize := Scalar.ofIntCore Usize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize)) +@[reducible] def core_usize_max : Usize := Scalar.ofIntCore Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize)) @[reducible] def core_u8_min : U8 := Scalar.ofInt U8.min @[reducible] def core_u8_max : U8 := Scalar.ofInt U8.max @[reducible] def core_u16_min : U16 := Scalar.ofInt U16.min @@ -985,18 +998,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128 -- ofInt -- TODO: typeclass? -abbrev Isize.ofInt := @Scalar.ofInt .Isize -abbrev I8.ofInt := @Scalar.ofInt .I8 -abbrev I16.ofInt := @Scalar.ofInt .I16 -abbrev I32.ofInt := @Scalar.ofInt .I32 -abbrev I64.ofInt := @Scalar.ofInt .I64 -abbrev I128.ofInt := @Scalar.ofInt .I128 -abbrev Usize.ofInt := @Scalar.ofInt .Usize -abbrev U8.ofInt := @Scalar.ofInt .U8 -abbrev U16.ofInt := @Scalar.ofInt .U16 -abbrev U32.ofInt := @Scalar.ofInt .U32 -abbrev U64.ofInt := @Scalar.ofInt .U64 -abbrev U128.ofInt := @Scalar.ofInt .U128 +@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize +@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8 +@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16 +@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32 +@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64 +@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128 +@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize +@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8 +@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16 +@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32 +@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64 +@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128 postfix:max "#isize" => Isize.ofInt postfix:max "#i8" => I8.ofInt @@ -1011,47 +1024,86 @@ postfix:max "#u32" => U32.ofInt postfix:max "#u64" => U64.ofInt postfix:max "#u128" => U128.ofInt +/- Testing the notations -/ +example := 0#u32 +example := 1#u32 +example := 1#i32 +example := 0#isize +example := (-1)#isize +example (x : U32) : Bool := + match x with + | 0#u32 => true + | _ => false + +example (x : U32) : Bool := + match x with + | 1#u32 => true + | _ => false + +example (x : I32) : Bool := + match x with + | (-1)#i32 => true + | _ => false + +-- Notation for pattern matching +-- We make the precedence looser than the negation. +notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ + +example {ty} (x : Scalar ty) : ℤ := + match x with + | v#scalar => v + +example {ty} (x : Scalar ty) : Bool := + match x with + | 1#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : Bool := + match x with + | -(1 : Int)#scalar => true + | _ => false + -- Testing the notations example : Result Usize := 0#usize + 1#usize -- TODO: factor those lemmas out -@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofInt x h).val = x := by +@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by simp [Scalar.ofInt, Scalar.ofIntCore] -@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofInt x h).val = x := by +@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofInt x h).val = x := by +@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofInt x h).val = x := by +@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofInt x h).val = x := by +@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofInt x h).val = x := by +@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofInt x h).val = x := by +@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofInt x h).val = x := by +@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofInt x h).val = x := by +@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofInt x h).val = x := by +@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofInt x h).val = x := by +@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofInt x h).val = x := by +@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofInt x h).val = x := by +@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -- Comparisons @@ -1133,22 +1185,4 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) := -- else -- .fail integerOverflow --- Notation for pattern matching --- We make the precedence looser than the negation. -notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ - -example {ty} (x : Scalar ty) : ℤ := - match x with - | v#scalar => v - -example {ty} (x : Scalar ty) : Bool := - match x with - | 1#scalar => true - | _ => false - -example {ty} (x : Scalar ty) : Bool := - match x with - | -(1 : Int)#scalar => true - | _ => false - end Primitives -- cgit v1.2.3 From b6f63f106baef03dd61f1100bd46c9bad7cb79e4 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 08:53:38 +0100 Subject: Remove some comments --- backends/lean/Base/Primitives/Scalar.lean | 31 ------------------------------- 1 file changed, 31 deletions(-) (limited to 'backends/lean/Base/Primitives/Scalar.lean') diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index bf6b01a6..3d90f1a5 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1154,35 +1154,4 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) := @[simp] theorem Scalar.neq_to_neq_val {ty} : ∀ {i j : Scalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by intro i j; cases i; cases j; simp --- -- We now define a type class that subsumes the various machine integer types, so --- -- as to write a concise definition for scalar_cast, rather than exhaustively --- -- enumerating all of the possible pairs. We remark that Rust has sane semantics --- -- and fails if a cast operation would involve a truncation or modulo. - --- class MachineInteger (t: Type) where --- size: Nat --- val: t -> Fin size --- ofNatCore: (n:Nat) -> LT.lt n size -> t - --- set_option hygiene false in --- run_cmd --- for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do --- Lean.Elab.Command.elabCommand (← `( --- namespace $typeName --- instance: MachineInteger $typeName where --- size := size --- val := val --- ofNatCore := ofNatCore --- end $typeName --- )) - --- -- Aeneas only instantiates the destination type (`src` is implicit). We rely on --- -- Lean to infer `src`. - --- def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): Result dst := --- if h: MachineInteger.val x < MachineInteger.size dst then --- .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h) --- else --- .fail integerOverflow - end Primitives -- cgit v1.2.3