summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean')
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean272
1 files changed, 251 insertions, 21 deletions
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 98d695a4..8de2b3f2 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -325,33 +325,65 @@ instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty
@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty :=
Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds)
-@[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool :=
+@[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop :=
+ Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
+
+@[simp] abbrev Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool :=
(Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty)
-theorem Scalar.check_bounds_prop {ty : ScalarTy} {x : Int} (h: Scalar.check_bounds ty x) :
- Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by
+theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int}
+ (h: Scalar.check_bounds ty x) :
+ Scalar.in_bounds ty x := by
simp at *
have ⟨ hmin, hmax ⟩ := h
have hbmin := Scalar.cMin_bound ty
have hbmax := Scalar.cMax_bound ty
cases hmin <;> cases hmax <;> apply And.intro <;> linarith
+theorem Scalar.check_bounds_eq_in_bounds (ty : ScalarTy) (x : Int) :
+ Scalar.check_bounds ty x ↔ Scalar.in_bounds ty x := by
+ constructor <;> intro h
+ . apply (check_bounds_imp_in_bounds h)
+ . simp_all
+
-- Further thoughts: look at what has been done here:
-- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/Fin/Basic.lean
-- and
-- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/UInt.lean
-- which both contain a fair amount of reasoning already!
-def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) :=
+def Scalar.tryMkOpt (ty : ScalarTy) (x : Int) : Option (Scalar ty) :=
if h:Scalar.check_bounds ty x then
-- If we do:
-- ```
- -- let ⟨ hmin, hmax ⟩ := (Scalar.check_bounds_prop h)
+ -- let ⟨ hmin, hmax ⟩ := (Scalar.check_bounds_imp_in_bounds h)
-- Scalar.ofIntCore x hmin hmax
-- ```
-- then normalization blocks (for instance, some proofs which use reflexivity fail).
-- However, the version below doesn't block reduction (TODO: investigate):
- ok (Scalar.ofIntCore x (Scalar.check_bounds_prop h))
- else fail integerOverflow
+ some (Scalar.ofIntCore x (Scalar.check_bounds_imp_in_bounds h))
+ else none
+
+def Result.ofOption {a : Type u} (x : Option a) (e : Error) : Result a :=
+ match x with
+ | some x => ok x
+ | none => fail e
+
+def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) :=
+ Result.ofOption (tryMkOpt ty x) integerOverflow
+
+theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) :
+ match tryMk ty x with
+ | ok y => y.val = x ∧ in_bounds ty x
+ | fail _ => ¬ (in_bounds ty x)
+ | _ => False := by
+ simp [tryMk, ofOption, tryMkOpt, ofIntCore]
+ have h := check_bounds_eq_in_bounds ty x
+ split_ifs <;> simp_all
+
+@[simp] theorem Scalar.tryMk_eq_div (ty : ScalarTy) (x : Int) :
+ tryMk ty x = div ↔ False := by
+ simp [tryMk, ofOption, tryMkOpt]
+ split_ifs <;> simp
def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
@@ -579,17 +611,121 @@ instance {ty} : HOr (Scalar ty) (Scalar ty) (Scalar ty) where
instance {ty} : HAnd (Scalar ty) (Scalar ty) (Scalar ty) where
hAnd x y := Scalar.and x y
+-- core checked arithmetic operations
+
+/- A helper function that converts failure to none and success to some
+ TODO: move up to Base module? -/
+def Option.ofResult {a : Type u} (x : Result a) :
+ Option a :=
+ match x with
+ | ok x => some x
+ | _ => none
+
+/- [core::num::{T}::checked_add] -/
+def core.num.checked_add (x y : Scalar ty) : Option (Scalar ty) :=
+ Option.ofResult (x + y)
+
+def U8.checked_add (x y : U8) : Option U8 := core.num.checked_add x y
+def U16.checked_add (x y : U16) : Option U16 := core.num.checked_add x y
+def U32.checked_add (x y : U32) : Option U32 := core.num.checked_add x y
+def U64.checked_add (x y : U64) : Option U64 := core.num.checked_add x y
+def U128.checked_add (x y : U128) : Option U128 := core.num.checked_add x y
+def Usize.checked_add (x y : Usize) : Option Usize := core.num.checked_add x y
+def I8.checked_add (x y : I8) : Option I8 := core.num.checked_add x y
+def I16.checked_add (x y : I16) : Option I16 := core.num.checked_add x y
+def I32.checked_add (x y : I32) : Option I32 := core.num.checked_add x y
+def I64.checked_add (x y : I64) : Option I64 := core.num.checked_add x y
+def I128.checked_add (x y : I128) : Option I128 := core.num.checked_add x y
+def Isize.checked_add (x y : Isize) : Option Isize := core.num.checked_add x y
+
+/- [core::num::{T}::checked_sub] -/
+def core.num.checked_sub (x y : Scalar ty) : Option (Scalar ty) :=
+ Option.ofResult (x - y)
+
+def U8.checked_sub (x y : U8) : Option U8 := core.num.checked_sub x y
+def U16.checked_sub (x y : U16) : Option U16 := core.num.checked_sub x y
+def U32.checked_sub (x y : U32) : Option U32 := core.num.checked_sub x y
+def U64.checked_sub (x y : U64) : Option U64 := core.num.checked_sub x y
+def U128.checked_sub (x y : U128) : Option U128 := core.num.checked_sub x y
+def Usize.checked_sub (x y : Usize) : Option Usize := core.num.checked_sub x y
+def I8.checked_sub (x y : I8) : Option I8 := core.num.checked_sub x y
+def I16.checked_sub (x y : I16) : Option I16 := core.num.checked_sub x y
+def I32.checked_sub (x y : I32) : Option I32 := core.num.checked_sub x y
+def I64.checked_sub (x y : I64) : Option I64 := core.num.checked_sub x y
+def I128.checked_sub (x y : I128) : Option I128 := core.num.checked_sub x y
+def Isize.checked_sub (x y : Isize) : Option Isize := core.num.checked_sub x y
+
+/- [core::num::{T}::checked_mul] -/
+def core.num.checked_mul (x y : Scalar ty) : Option (Scalar ty) :=
+ Option.ofResult (x * y)
+
+def U8.checked_mul (x y : U8) : Option U8 := core.num.checked_mul x y
+def U16.checked_mul (x y : U16) : Option U16 := core.num.checked_mul x y
+def U32.checked_mul (x y : U32) : Option U32 := core.num.checked_mul x y
+def U64.checked_mul (x y : U64) : Option U64 := core.num.checked_mul x y
+def U128.checked_mul (x y : U128) : Option U128 := core.num.checked_mul x y
+def Usize.checked_mul (x y : Usize) : Option Usize := core.num.checked_mul x y
+def I8.checked_mul (x y : I8) : Option I8 := core.num.checked_mul x y
+def I16.checked_mul (x y : I16) : Option I16 := core.num.checked_mul x y
+def I32.checked_mul (x y : I32) : Option I32 := core.num.checked_mul x y
+def I64.checked_mul (x y : I64) : Option I64 := core.num.checked_mul x y
+def I128.checked_mul (x y : I128) : Option I128 := core.num.checked_mul x y
+def Isize.checked_mul (x y : Isize) : Option Isize := core.num.checked_mul x y
+
+/- [core::num::{T}::checked_div] -/
+def core.num.checked_div (x y : Scalar ty) : Option (Scalar ty) :=
+ Option.ofResult (x / y)
+
+def U8.checked_div (x y : U8) : Option U8 := core.num.checked_div x y
+def U16.checked_div (x y : U16) : Option U16 := core.num.checked_div x y
+def U32.checked_div (x y : U32) : Option U32 := core.num.checked_div x y
+def U64.checked_div (x y : U64) : Option U64 := core.num.checked_div x y
+def U128.checked_div (x y : U128) : Option U128 := core.num.checked_div x y
+def Usize.checked_div (x y : Usize) : Option Usize := core.num.checked_div x y
+def I8.checked_div (x y : I8) : Option I8 := core.num.checked_div x y
+def I16.checked_div (x y : I16) : Option I16 := core.num.checked_div x y
+def I32.checked_div (x y : I32) : Option I32 := core.num.checked_div x y
+def I64.checked_div (x y : I64) : Option I64 := core.num.checked_div x y
+def I128.checked_div (x y : I128) : Option I128 := core.num.checked_div x y
+def Isize.checked_div (x y : Isize) : Option Isize := core.num.checked_div x y
+
+/- [core::num::{T}::checked_rem] -/
+def core.num.checked_rem (x y : Scalar ty) : Option (Scalar ty) :=
+ Option.ofResult (x % y)
+
+def U8.checked_rem (x y : U8) : Option U8 := core.num.checked_rem x y
+def U16.checked_rem (x y : U16) : Option U16 := core.num.checked_rem x y
+def U32.checked_rem (x y : U32) : Option U32 := core.num.checked_rem x y
+def U64.checked_rem (x y : U64) : Option U64 := core.num.checked_rem x y
+def U128.checked_rem (x y : U128) : Option U128 := core.num.checked_rem x y
+def Usize.checked_rem (x y : Usize) : Option Usize := core.num.checked_rem x y
+def I8.checked_rem (x y : I8) : Option I8 := core.num.checked_rem x y
+def I16.checked_rem (x y : I16) : Option I16 := core.num.checked_rem x y
+def I32.checked_rem (x y : I32) : Option I32 := core.num.checked_rem x y
+def I64.checked_rem (x y : I64) : Option I64 := core.num.checked_rem x y
+def I128.checked_rem (x y : I128) : Option I128 := core.num.checked_rem x y
+def Isize.checked_rem (x y : Isize) : Option Isize := core.num.checked_rem x y
+
+theorem Scalar.add_equiv {ty} {x y : Scalar ty} :
+ match x + y with
+ | ok z => Scalar.in_bounds ty (↑x + ↑y) ∧ (↑z : Int) = ↑x + ↑y
+ | fail _ => ¬ (Scalar.in_bounds ty (↑x + ↑y))
+ | _ => ⊥ := by
+ -- Applying the unfoldings only inside the match
+ conv in _ + _ => unfold HAdd.hAdd instHAddScalarResult; simp [add]
+ have h := tryMk_eq ty (↑x + ↑y)
+ simp [in_bounds] at h
+ split at h <;> simp_all [check_bounds_eq_in_bounds]
+
-- Generic theorem - shouldn't be used much
@[pspec]
theorem Scalar.add_spec {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ ↑x + y.val)
(hmax : ↑x + ↑y ≤ Scalar.max ty) :
(∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y) := by
- -- Applying the unfoldings only on the left
- conv => congr; ext; lhs; unfold HAdd.hAdd instHAddScalarResult; simp [add, tryMk]
- split
- . simp [pure]; rfl
- . tauto
+ have h := @add_equiv ty x y
+ split at h <;> simp_all
+ apply h
theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
(hmax : ↑x + ↑y ≤ Scalar.max ty) :
@@ -655,17 +791,36 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y :=
Scalar.add_spec hmin hmax
+theorem core.num.checked_add_spec {ty} {x y : Scalar ty} :
+ match core.num.checked_add x y with
+ | some z => Scalar.in_bounds ty (↑x + ↑y) ∧ ↑z = (↑x + ↑y : Int)
+ | none => ¬ (Scalar.in_bounds ty (↑x + ↑y)) := by
+ have h := Scalar.tryMk_eq ty (↑x + ↑y)
+ simp only [checked_add, Option.ofResult]
+ cases heq: x + y <;> simp_all <;> simp [HAdd.hAdd, Scalar.add] at heq
+ <;> simp [Add.add] at heq
+ <;> simp_all
+
+theorem Scalar.sub_equiv {ty} {x y : Scalar ty} :
+ match x - y with
+ | ok z => Scalar.in_bounds ty (↑x - ↑y) ∧ (↑z : Int) = ↑x - ↑y
+ | fail _ => ¬ (Scalar.in_bounds ty (↑x - ↑y))
+ | _ => ⊥ := by
+ -- Applying the unfoldings only inside the match
+ conv in _ - _ => unfold HSub.hSub instHSubScalarResult; simp [sub]
+ have h := tryMk_eq ty (↑x - ↑y)
+ simp [in_bounds] at h
+ split at h <;> simp_all [check_bounds_eq_in_bounds]
+
-- Generic theorem - shouldn't be used much
@[pspec]
theorem Scalar.sub_spec {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ ↑x - ↑y)
(hmax : ↑x - ↑y ≤ Scalar.max ty) :
∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by
- conv => congr; ext; lhs; simp [HSub.hSub, sub, tryMk, Sub.sub]
- split
- . simp [pure]
- rfl
- . tauto
+ have h := @sub_equiv ty x y
+ split at h <;> simp_all
+ apply h
theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned)
{x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x - ↑y) :
@@ -739,12 +894,33 @@ theorem Scalar.mul_spec {ty} {x y : Scalar ty}
(hmax : ↑x * ↑y ≤ Scalar.max ty) :
∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by
conv => congr; ext; lhs; simp [HMul.hMul]
- simp [mul, tryMk]
- split
+ simp [mul, tryMk, tryMkOpt, ofOption]
+ split_ifs
. simp [pure]
rfl
. tauto
+theorem core.num.checked_sub_spec {ty} {x y : Scalar ty} :
+ match core.num.checked_sub x y with
+ | some z => Scalar.in_bounds ty (↑x - ↑y) ∧ ↑z = (↑x - ↑y : Int)
+ | none => ¬ (Scalar.in_bounds ty (↑x - ↑y)) := by
+ have h := Scalar.tryMk_eq ty (↑x - ↑y)
+ simp only [checked_sub, Option.ofResult]
+ have add_neg_eq : x.val + (-y.val) = x.val - y.val := by omega -- TODO: why do we need this??
+ cases heq: x - y <;> simp_all <;> simp only [HSub.hSub, Scalar.sub, Sub.sub, Int.sub] at heq
+ <;> simp_all
+
+theorem Scalar.mul_equiv {ty} {x y : Scalar ty} :
+ match x * y with
+ | ok z => Scalar.in_bounds ty (↑x * ↑y) ∧ (↑z : Int) = ↑x * ↑y
+ | fail _ => ¬ (Scalar.in_bounds ty (↑x * ↑y))
+ | _ => ⊥ := by
+ -- Applying the unfoldings only inside the match
+ conv in _ * _ => unfold HMul.hMul instHMulScalarResult; simp [mul]
+ have h := tryMk_eq ty (↑x * ↑y)
+ simp [in_bounds] at h
+ split at h <;> simp_all [check_bounds_eq_in_bounds]
+
theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
(hmax : ↑x * ↑y ≤ Scalar.max ty) :
∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by
@@ -809,6 +985,28 @@ theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y :=
Scalar.mul_spec hmin hmax
+theorem core.num.checked_mul_spec {ty} {x y : Scalar ty} :
+ match core.num.checked_mul x y with
+ | some z => Scalar.in_bounds ty (↑x * ↑y) ∧ ↑z = (↑x * ↑y : Int)
+ | none => ¬ (Scalar.in_bounds ty (↑x * ↑y)) := by
+ have h := Scalar.tryMk_eq ty (↑x * ↑y)
+ simp only [checked_mul, Option.ofResult]
+ have : Int.mul ↑x ↑y = ↑x * ↑y := by simp -- TODO: why do we need this??
+ cases heq: x * y <;> simp_all <;> simp only [HMul.hMul, Scalar.mul, Mul.mul] at heq
+ <;> simp_all
+
+theorem Scalar.div_equiv {ty} {x y : Scalar ty} :
+ match x / y with
+ | ok z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y) ∧ (↑z : Int) = scalar_div ↑x ↑y
+ | fail _ => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y))
+ | _ => ⊥ := by
+ -- Applying the unfoldings only inside the match
+ conv in _ / _ => unfold HDiv.hDiv instHDivScalarResult; simp [div]
+ have h := tryMk_eq ty (scalar_div ↑x ↑y)
+ simp [in_bounds] at h
+ split_ifs <;> simp <;>
+ split at h <;> simp_all [check_bounds_eq_in_bounds]
+
-- Generic theorem - shouldn't be used much
@[pspec]
theorem Scalar.div_spec {ty} {x y : Scalar ty}
@@ -817,7 +1015,7 @@ theorem Scalar.div_spec {ty} {x y : Scalar ty}
(hmax : scalar_div ↑x ↑y ≤ Scalar.max ty) :
∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := by
simp [HDiv.hDiv, div, Div.div]
- simp [tryMk, *]
+ simp [tryMk, tryMkOpt, ofOption, *]
rfl
theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty}
@@ -903,6 +1101,28 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y :=
Scalar.div_spec hnz hmin hmax
+theorem core.num.checked_div_spec {ty} {x y : Scalar ty} :
+ match core.num.checked_div x y with
+ | some z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y) ∧ ↑z = (scalar_div ↑x ↑y : Int)
+ | none => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y)) := by
+ have h := Scalar.tryMk_eq ty (scalar_div ↑x ↑y)
+ simp only [checked_div, Option.ofResult]
+ cases heq0: (y.val = 0 : Bool) <;>
+ cases heq1: x / y <;> simp_all <;> simp only [HDiv.hDiv, Scalar.div, Div.div] at heq1
+ <;> simp_all
+
+theorem Scalar.rem_equiv {ty} {x y : Scalar ty} :
+ match x % y with
+ | ok z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y) ∧ (↑z : Int) = scalar_rem ↑x ↑y
+ | fail _ => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y))
+ | _ => ⊥ := by
+ -- Applying the unfoldings only inside the match
+ conv in _ % _ => unfold HMod.hMod instHModScalarResult; simp [rem]
+ have h := tryMk_eq ty (scalar_rem ↑x ↑y)
+ simp [in_bounds] at h
+ split_ifs <;> simp <;>
+ split at h <;> simp_all [check_bounds_eq_in_bounds]
+
-- Generic theorem - shouldn't be used much
@[pspec]
theorem Scalar.rem_spec {ty} {x y : Scalar ty}
@@ -911,7 +1131,7 @@ theorem Scalar.rem_spec {ty} {x y : Scalar ty}
(hmax : scalar_rem ↑x ↑y ≤ Scalar.max ty) :
∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := by
simp [HMod.hMod, rem]
- simp [tryMk, *]
+ simp [tryMk, tryMkOpt, ofOption, *]
rfl
theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty}
@@ -990,6 +1210,16 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y :=
Scalar.rem_spec hnz hmin hmax
+theorem core.num.checked_rem_spec {ty} {x y : Scalar ty} :
+ match core.num.checked_rem x y with
+ | some z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y) ∧ ↑z = (scalar_rem ↑x ↑y : Int)
+ | none => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y)) := by
+ have h := Scalar.tryMk_eq ty (scalar_rem ↑x ↑y)
+ simp only [checked_rem, Option.ofResult]
+ cases heq0: (y.val = 0 : Bool) <;>
+ cases heq1: x % y <;> simp_all <;> simp only [HMod.hMod, Scalar.rem, Mod.mod] at heq1
+ <;> simp_all
+
-- ofIntCore
-- TODO: typeclass?
def Isize.ofIntCore := @Scalar.ofIntCore .Isize