summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2023-06-06 15:53:46 +0200
committerSon Ho2023-06-06 15:53:46 +0200
commitacc09d5c69690f2c46cb1bacf290da5dcc268b24 (patch)
tree04384980d21b90b85ae047d65b4139824b1dd635 /backends/lean
parent53adf30fe440eb8b6f58ba89f4a4c0acc7877498 (diff)
Remove the sorries from Primitives.lean
Diffstat (limited to 'backends/lean')
-rw-r--r--backends/lean/Primitives.lean184
-rw-r--r--backends/lean/lakefile.lean1
2 files changed, 120 insertions, 65 deletions
diff --git a/backends/lean/Primitives.lean b/backends/lean/Primitives.lean
index 4a66a453..e7826fbf 100644
--- a/backends/lean/Primitives.lean
+++ b/backends/lean/Primitives.lean
@@ -2,9 +2,10 @@ import Lean
import Lean.Meta.Tactic.Simp
import Init.Data.List.Basic
import Mathlib.Tactic.RunCmd
+import Mathlib.Tactic.Linarith
--------------------
--- ASSERT COMMAND --
+-- ASSERT COMMAND --Std.
--------------------
open Lean Elab Command Term Meta
@@ -249,27 +250,53 @@ def Scalar.cMax (ty : ScalarTy) : Int :=
| .Usize => U32.max
| _ => Scalar.max ty
-theorem Scalar.cMin_bound ty : Scalar.min ty <= Scalar.cMin ty := by sorry
-theorem Scalar.cMax_bound ty : Scalar.min ty <= Scalar.cMin ty := by sorry
+theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by
+ cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at *
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;>
+ simp [*]
+
+theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by
+ cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;>
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;>
+ simp [*]
+
+theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by
+ cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;>
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;>
+ simp [*] at *
+ -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified
+ linarith
+
+theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by
+ cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;>
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;>
+ simp [*] at * <;>
+ -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified
+ linarith
structure Scalar (ty : ScalarTy) where
val : Int
- hmin : Scalar.min ty <= val
- hmax : val <= Scalar.max ty
+ hmin : Scalar.min ty ≤ val
+ hmax : val ≤ Scalar.max ty
theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
- Scalar.cMin ty <= x && x <= Scalar.cMax ty ->
- (decide (Scalar.min ty ≤ x) && decide (x ≤ Scalar.max ty)) = true
- := by sorry
+ Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty ->
+ Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
+ :=
+ λ h => by
+ apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith
def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
- (hmin : Scalar.min ty <= x) (hmax : x <= Scalar.max ty) : Scalar ty :=
+ (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty :=
{ val := x, hmin := hmin, hmax := hmax }
def Scalar.ofInt {ty : ScalarTy} (x : Int)
- (h : Scalar.min ty <= x && x <= Scalar.max ty) : Scalar ty :=
- let hmin: Scalar.min ty <= x := by sorry
- let hmax: x <= Scalar.max ty := by sorry
+ (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty :=
+ let ⟨ hmin, hmax ⟩ := h
Scalar.ofIntCore x hmin hmax
-- Further thoughts: look at what has been done here:
@@ -279,12 +306,15 @@ def Scalar.ofInt {ty : ScalarTy} (x : Int)
-- which both contain a fair amount of reasoning already!
def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) :=
-- TODO: write this with only one if then else
- if hmin_cons: Scalar.cMin ty <= x || Scalar.min ty <= x then
- if hmax_cons: x <= Scalar.cMax ty || x <= Scalar.max ty then
- let hmin: Scalar.min ty <= x := by sorry
- let hmax: x <= Scalar.max ty := by sorry
- return Scalar.ofIntCore x hmin hmax
- else fail integerOverflow
+ if h: (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) && (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) then
+ let h: Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by
+ simp at *
+ have ⟨ hmin, hmax ⟩ := h
+ have hbmin := Scalar.cMin_bound ty
+ have hbmax := Scalar.cMax_bound ty
+ cases hmin <;> cases hmax <;> apply And.intro <;> linarith
+ let ⟨ hmin, hmax ⟩ := h
+ return Scalar.ofIntCore x hmin hmax
else fail integerOverflow
def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
@@ -292,11 +322,39 @@ def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tr
def Scalar.div {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) :=
if y.val != 0 then Scalar.tryMk ty (x.val / y.val) else fail divisionByZero
--- Checking that the % operation in Lean computes the same as the remainder operation in Rust
-#assert 1 % 2 = (1:Int)
-#assert (-1) % 2 = -1
-#assert 1 % (-2) = 1
-#assert (-1) % (-2) = -1
+-- Our custom remainder operation, which satisfies the semantics of Rust
+-- TODO: is there a better way?
+def scalar_rem (x y : Int) : Int :=
+ if 0 ≤ x then |x| % |y|
+ else - (|x| % |y|)
+
+-- Our custom division operation, which satisfies the semantics of Rust
+-- TODO: is there a better way?
+def scalar_div (x y : Int) : Int :=
+ if 0 ≤ x && 0 ≤ y then |x| / |y|
+ else if 0 ≤ x && y < 0 then - (|x| / |y|)
+ else if x < 0 && 0 ≤ y then - (|x| / |y|)
+ else |x| / |y|
+
+-- Checking that the remainder operation is correct
+#assert scalar_rem 1 2 = 1
+#assert scalar_rem (-1) 2 = -1
+#assert scalar_rem 1 (-2) = 1
+#assert scalar_rem (-1) (-2) = -1
+#assert scalar_rem 7 3 = (1:Int)
+#assert scalar_rem (-7) 3 = -1
+#assert scalar_rem 7 (-3) = 1
+#assert scalar_rem (-7) (-3) = -1
+
+-- Checking that the division operation is correct
+#assert scalar_div 3 2 = 1
+#assert scalar_div (-3) 2 = -1
+#assert scalar_div 3 (-2) = -1
+#assert scalar_div (-3) (-2) = 1
+#assert scalar_div 7 3 = 2
+#assert scalar_div (-7) 3 = -2
+#assert scalar_div 7 (-3) = -2
+#assert scalar_div (-7) (-3) = 2
def Scalar.rem {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) :=
if y.val != 0 then Scalar.tryMk ty (x.val % y.val) else fail divisionByZero
@@ -479,20 +537,29 @@ macro_rules
-- VECTORS --
-------------
-def Vec (α : Type u) := { l : List α // List.length l <= Usize.max }
+def Vec (α : Type u) := { l : List α // List.length l ≤ Usize.max }
-def vec_new (α : Type u): Vec α := ⟨ [], by sorry ⟩
+def vec_new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩
def vec_len (α : Type u) (v : Vec α) : Usize :=
let ⟨ v, l ⟩ := v
- Usize.ofIntCore (List.length v) (by sorry) l
+ Usize.ofIntCore (List.length v) (by simp [Scalar.min]) l
def vec_push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := ()
def vec_push_back (α : Type u) (v : Vec α) (x : α) : Result (Vec α)
:=
- if h : List.length v.val <= U32.max || List.length v.val <= Usize.max then
- return ⟨ List.concat v.val x, by sorry ⟩
+ let nlen := List.length v.val + 1
+ if h : nlen ≤ U32.max || nlen ≤ Usize.max then
+ have h : nlen ≤ Usize.max := by
+ simp at *
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;>
+ simp [*] at * <;>
+ try assumption
+ cases h <;>
+ linarith
+ return ⟨ List.concat v.val x, by simp at *; assumption ⟩
else
fail maximumSizeExceeded
@@ -506,30 +573,28 @@ def vec_insert_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α
if i.val < List.length v.val then
-- TODO: maybe we should redefine a list library which uses integers
-- (instead of natural numbers)
- let i : Nat :=
- match i.val with
- | .ofNat n => n
- | .negSucc n => by sorry -- TODO: we can't get here
- let isLt: i < USize.size := by sorry
- let i : Fin USize.size := { val := i, isLt := isLt }
- .ret ⟨ List.set v.val i.val x, by
- have h: List.length v.val <= Usize.max := v.property
- rewrite [ List.length_set v.val i.val x ]
+ let i := i.val.toNat
+ .ret ⟨ List.set v.val i x, by
+ have h: List.length v.val ≤ Usize.max := v.property
+ simp [*] at *
assumption
else
.fail arrayOutOfBounds
+def vec_index_to_fin {α : Type u} {v: Vec α} {i: Usize} (h : i.val < List.length v.val) :
+ Fin (List.length v.val) :=
+ let j := i.val.toNat
+ let h: j < List.length v.val := by
+ have heq := @Int.toNat_lt (List.length v.val) i.val i.hmin
+ apply heq.mpr
+ assumption
+ ⟨j, h⟩
+
def vec_index_fwd (α : Type u) (v: Vec α) (i: Usize): Result α :=
- if i.val < List.length v.val then
- let i : Nat :=
- match i.val with
- | .ofNat n => n
- | .negSucc n => by sorry -- TODO: we can't get here
- let isLt: i < USize.size := by sorry
- let i : Fin USize.size := { val := i, isLt := isLt }
- let h: i < List.length v.val := by sorry
- .ret (List.get v.val ⟨i.val, h⟩)
+ if h: i.val < List.length v.val then
+ let i := vec_index_to_fin h
+ .ret (List.get v.val i)
else
.fail arrayOutOfBounds
@@ -540,29 +605,18 @@ def vec_index_back (α : Type u) (v: Vec α) (i: Usize) (_: α): Result Unit :=
.fail arrayOutOfBounds
def vec_index_mut_fwd (α : Type u) (v: Vec α) (i: Usize): Result α :=
- if i.val < List.length v.val then
- let i : Nat :=
- match i.val with
- | .ofNat n => n
- | .negSucc n => by sorry -- TODO: we can't get here
- let isLt: i < USize.size := by sorry
- let i : Fin USize.size := { val := i, isLt := isLt }
- let h: i < List.length v.val := by sorry
- .ret (List.get v.val ⟨i.val, h⟩)
+ if h: i.val < List.length v.val then
+ let i := vec_index_to_fin h
+ .ret (List.get v.val i)
else
.fail arrayOutOfBounds
def vec_index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) :=
- if i.val < List.length v.val then
- let i : Nat :=
- match i.val with
- | .ofNat n => n
- | .negSucc n => by sorry -- TODO: we can't get here
- let isLt: i < USize.size := by sorry
- let i : Fin USize.size := { val := i, isLt := isLt }
- .ret ⟨ List.set v.val i.val x, by
- have h: List.length v.val <= Usize.max := v.property
- rewrite [ List.length_set v.val i.val x ]
+ if h: i.val < List.length v.val then
+ let i := vec_index_to_fin h
+ .ret ⟨ List.set v.val i x, by
+ have h: List.length v.val ≤ Usize.max := v.property
+ simp [*] at *
assumption
else
diff --git a/backends/lean/lakefile.lean b/backends/lean/lakefile.lean
index 9633e1e8..c5e27d1c 100644
--- a/backends/lean/lakefile.lean
+++ b/backends/lean/lakefile.lean
@@ -1,6 +1,7 @@
import Lake
open Lake DSL
+-- Important: mathlib imports std4 and quote4: we mustn't add a `require std4` line
require mathlib from git
"https://github.com/leanprover-community/mathlib4.git"