summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean167
1 files changed, 161 insertions, 6 deletions
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 3f88caa2..aaa4027f 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -3,6 +3,8 @@ import Lean.Meta.Tactic.Simp
import Mathlib.Tactic.Linarith
import Base.Primitives.Base
import Base.Diverge.Base
+import Base.Progress.Base
+import Base.Arith.Int
namespace Primitives
@@ -122,6 +124,22 @@ inductive ScalarTy :=
| U64
| U128
+def ScalarTy.isSigned (ty : ScalarTy) : Bool :=
+ match ty with
+ | Isize
+ | I8
+ | I16
+ | I32
+ | I64
+ | I128 => true
+ | Usize
+ | U8
+ | U16
+ | U32
+ | U64
+ | U128 => false
+
+
def Scalar.smin (ty : ScalarTy) : Int :=
match ty with
| .Isize => Isize.smin
@@ -289,23 +307,30 @@ def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) :=
def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
-def Scalar.div {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) :=
- if y.val != 0 then Scalar.tryMk ty (x.val / y.val) else fail divisionByZero
-
-- Our custom remainder operation, which satisfies the semantics of Rust
-- TODO: is there a better way?
def scalar_rem (x y : Int) : Int :=
- if 0 ≤ x then |x| % |y|
+ if 0 ≤ x then x % y
else - (|x| % |y|)
+@[simp]
+def scalar_rem_nonneg {x y : Int} (hx : 0 ≤ x) : scalar_rem x y = x % y := by
+ intros
+ simp [*, scalar_rem]
+
-- Our custom division operation, which satisfies the semantics of Rust
-- TODO: is there a better way?
def scalar_div (x y : Int) : Int :=
- if 0 ≤ x && 0 ≤ y then |x| / |y|
+ if 0 ≤ x && 0 ≤ y then x / y
else if 0 ≤ x && y < 0 then - (|x| / |y|)
else if x < 0 && 0 ≤ y then - (|x| / |y|)
else |x| / |y|
+@[simp]
+def scalar_div_nonneg {x y : Int} (hx : 0 ≤ x) (hy : 0 ≤ y) : scalar_div x y = x / y := by
+ intros
+ simp [*, scalar_div]
+
-- Checking that the remainder operation is correct
#assert scalar_rem 1 2 = 1
#assert scalar_rem (-1) 2 = -1
@@ -326,8 +351,11 @@ def scalar_div (x y : Int) : Int :=
#assert scalar_div 7 (-3) = -2
#assert scalar_div (-7) (-3) = 2
+def Scalar.div {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) :=
+ if y.val != 0 then Scalar.tryMk ty (scalar_div x.val y.val) else fail divisionByZero
+
def Scalar.rem {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) :=
- if y.val != 0 then Scalar.tryMk ty (x.val % y.val) else fail divisionByZero
+ if y.val != 0 then Scalar.tryMk ty (scalar_rem x.val y.val) else fail divisionByZero
def Scalar.add {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) :=
Scalar.tryMk ty (x.val + y.val)
@@ -410,6 +438,133 @@ instance {ty} : HDiv (Scalar ty) (Scalar ty) (Result (Scalar ty)) where
instance {ty} : HMod (Scalar ty) (Scalar ty) (Result (Scalar ty)) where
hMod x y := Scalar.rem x y
+-- TODO: make progress work at a more fine grained level (see `Scalar.add_unsigned_spec`)
+@[cpspec]
+theorem Scalar.add_spec {ty} {x y : Scalar ty}
+ (hmin : Scalar.min ty ≤ x.val + y.val)
+ (hmax : x.val + y.val ≤ Scalar.max ty) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
+ simp [HAdd.hAdd, add, Add.add]
+ simp [tryMk]
+ split
+ . simp [pure]
+ rfl
+ . tauto
+
+theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
+ (hmax : x.val + y.val ≤ Scalar.max ty) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
+ have hmin : Scalar.min ty ≤ x.val + y.val := by
+ have hx := x.hmin
+ have hy := y.hmin
+ cases ty <;> simp [min] at * <;> linarith
+ apply add_spec <;> assumption
+
+-- TODO: make it finer grained
+@[cpspec]
+theorem Scalar.sub_spec {ty} {x y : Scalar ty}
+ (hmin : Scalar.min ty ≤ x.val - y.val)
+ (hmax : x.val - y.val ≤ Scalar.max ty) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by
+ simp [HSub.hSub, sub, Sub.sub]
+ simp [tryMk]
+ split
+ . simp [pure]
+ rfl
+ . tauto
+
+theorem Scalar.sub_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
+ (hmin : Scalar.min ty ≤ x.val - y.val) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val := by
+ have : x.val - y.val ≤ Scalar.max ty := by
+ have hx := x.hmin
+ have hxm := x.hmax
+ have hy := y.hmin
+ cases ty <;> simp [min, max] at * <;> linarith
+ intros
+ apply sub_spec <;> assumption
+
+-- TODO: make it finer grained
+@[cpspec]
+theorem Scalar.mul_spec {ty} {x y : Scalar ty}
+ (hmin : Scalar.min ty ≤ x.val * y.val)
+ (hmax : x.val * y.val ≤ Scalar.max ty) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by
+ simp [HMul.hMul, mul, Mul.mul]
+ simp [tryMk]
+ split
+ . simp [pure]
+ rfl
+ . tauto
+
+theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
+ (hmax : x.val * y.val ≤ Scalar.max ty) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val := by
+ have : Scalar.min ty ≤ x.val * y.val := by
+ have hx := x.hmin
+ have hy := y.hmin
+ cases ty <;> simp at * <;> apply mul_nonneg hx hy
+ apply mul_spec <;> assumption
+
+-- TODO: make it finer grained
+@[cpspec]
+theorem Scalar.div_spec {ty} {x y : Scalar ty}
+ (hnz : y.val ≠ 0)
+ (hmin : Scalar.min ty ≤ scalar_div x.val y.val)
+ (hmax : scalar_div x.val y.val ≤ Scalar.max ty) :
+ ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val := by
+ simp [HDiv.hDiv, div, Div.div]
+ simp [tryMk, *]
+ simp [pure]
+ rfl
+
+theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty}
+ (hnz : y.val ≠ 0) :
+ ∃ z, x / y = ret z ∧ z.val = x.val / y.val := by
+ have h : Scalar.min ty = 0 := by cases ty <;> simp at *
+ have hx := x.hmin
+ have hy := y.hmin
+ simp [h] at hx hy
+ have hmin : 0 ≤ x.val / y.val := Int.ediv_nonneg hx hy
+ have hmax : x.val / y.val ≤ Scalar.max ty := by
+ have := Int.ediv_le_self y.val hx
+ have := x.hmax
+ linarith
+ have hs := @div_spec ty x y hnz
+ simp [*] at hs
+ apply hs
+
+-- TODO: make it finer grained
+@[cpspec]
+theorem Scalar.rem_spec {ty} {x y : Scalar ty}
+ (hnz : y.val ≠ 0)
+ (hmin : Scalar.min ty ≤ scalar_rem x.val y.val)
+ (hmax : scalar_rem x.val y.val ≤ Scalar.max ty) :
+ ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := by
+ simp [HMod.hMod, rem]
+ simp [tryMk, *]
+ simp [pure]
+ rfl
+
+theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty}
+ (hnz : y.val ≠ 0) :
+ ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val := by
+ have h : Scalar.min ty = 0 := by cases ty <;> simp at *
+ have hx := x.hmin
+ have hy := y.hmin
+ simp [h] at hx hy
+ have hmin : 0 ≤ x.val % y.val := Int.emod_nonneg x.val hnz
+ have hmax : x.val % y.val ≤ Scalar.max ty := by
+ have h := @Int.ediv_emod_unique x.val y.val (x.val % y.val) (x.val / y.val)
+ simp at h
+ have : 0 < y.val := by int_tac
+ simp [*] at h
+ have := y.hmax
+ linarith
+ have hs := @rem_spec ty x y hnz
+ simp [*] at hs
+ simp [*]
+
-- ofIntCore
-- TODO: typeclass?
def Isize.ofIntCore := @Scalar.ofIntCore .Isize