diff options
author | Son Ho | 2023-06-06 15:53:46 +0200 |
---|---|---|
committer | Son Ho | 2023-06-06 15:53:46 +0200 |
commit | acc09d5c69690f2c46cb1bacf290da5dcc268b24 (patch) | |
tree | 04384980d21b90b85ae047d65b4139824b1dd635 /backends/lean | |
parent | 53adf30fe440eb8b6f58ba89f4a4c0acc7877498 (diff) |
Remove the sorries from Primitives.lean
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Primitives.lean | 184 | ||||
-rw-r--r-- | backends/lean/lakefile.lean | 1 |
2 files changed, 120 insertions, 65 deletions
diff --git a/backends/lean/Primitives.lean b/backends/lean/Primitives.lean index 4a66a453..e7826fbf 100644 --- a/backends/lean/Primitives.lean +++ b/backends/lean/Primitives.lean @@ -2,9 +2,10 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic import Mathlib.Tactic.RunCmd +import Mathlib.Tactic.Linarith -------------------- --- ASSERT COMMAND -- +-- ASSERT COMMAND --Std. -------------------- open Lean Elab Command Term Meta @@ -249,27 +250,53 @@ def Scalar.cMax (ty : ScalarTy) : Int := | .Usize => U32.max | _ => Scalar.max ty -theorem Scalar.cMin_bound ty : Scalar.min ty <= Scalar.cMin ty := by sorry -theorem Scalar.cMax_bound ty : Scalar.min ty <= Scalar.cMin ty := by sorry +theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by + cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> + simp [*] + +theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by + cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;> + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> + simp [*] + +theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by + cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;> + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> + simp [*] at * + -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified + linarith + +theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by + cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;> + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> + simp [*] at * <;> + -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified + linarith structure Scalar (ty : ScalarTy) where val : Int - hmin : Scalar.min ty <= val - hmax : val <= Scalar.max ty + hmin : Scalar.min ty ≤ val + hmax : val ≤ Scalar.max ty theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : - Scalar.cMin ty <= x && x <= Scalar.cMax ty -> - (decide (Scalar.min ty ≤ x) && decide (x ≤ Scalar.max ty)) = true - := by sorry + Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty -> + Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty + := + λ h => by + apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith def Scalar.ofIntCore {ty : ScalarTy} (x : Int) - (hmin : Scalar.min ty <= x) (hmax : x <= Scalar.max ty) : Scalar ty := + (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty := { val := x, hmin := hmin, hmax := hmax } def Scalar.ofInt {ty : ScalarTy} (x : Int) - (h : Scalar.min ty <= x && x <= Scalar.max ty) : Scalar ty := - let hmin: Scalar.min ty <= x := by sorry - let hmax: x <= Scalar.max ty := by sorry + (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty := + let ⟨ hmin, hmax ⟩ := h Scalar.ofIntCore x hmin hmax -- Further thoughts: look at what has been done here: @@ -279,12 +306,15 @@ def Scalar.ofInt {ty : ScalarTy} (x : Int) -- which both contain a fair amount of reasoning already! def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := -- TODO: write this with only one if then else - if hmin_cons: Scalar.cMin ty <= x || Scalar.min ty <= x then - if hmax_cons: x <= Scalar.cMax ty || x <= Scalar.max ty then - let hmin: Scalar.min ty <= x := by sorry - let hmax: x <= Scalar.max ty := by sorry - return Scalar.ofIntCore x hmin hmax - else fail integerOverflow + if h: (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) && (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) then + let h: Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by + simp at * + have ⟨ hmin, hmax ⟩ := h + have hbmin := Scalar.cMin_bound ty + have hbmax := Scalar.cMax_bound ty + cases hmin <;> cases hmax <;> apply And.intro <;> linarith + let ⟨ hmin, hmax ⟩ := h + return Scalar.ofIntCore x hmin hmax else fail integerOverflow def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) @@ -292,11 +322,39 @@ def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tr def Scalar.div {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := if y.val != 0 then Scalar.tryMk ty (x.val / y.val) else fail divisionByZero --- Checking that the % operation in Lean computes the same as the remainder operation in Rust -#assert 1 % 2 = (1:Int) -#assert (-1) % 2 = -1 -#assert 1 % (-2) = 1 -#assert (-1) % (-2) = -1 +-- Our custom remainder operation, which satisfies the semantics of Rust +-- TODO: is there a better way? +def scalar_rem (x y : Int) : Int := + if 0 ≤ x then |x| % |y| + else - (|x| % |y|) + +-- Our custom division operation, which satisfies the semantics of Rust +-- TODO: is there a better way? +def scalar_div (x y : Int) : Int := + if 0 ≤ x && 0 ≤ y then |x| / |y| + else if 0 ≤ x && y < 0 then - (|x| / |y|) + else if x < 0 && 0 ≤ y then - (|x| / |y|) + else |x| / |y| + +-- Checking that the remainder operation is correct +#assert scalar_rem 1 2 = 1 +#assert scalar_rem (-1) 2 = -1 +#assert scalar_rem 1 (-2) = 1 +#assert scalar_rem (-1) (-2) = -1 +#assert scalar_rem 7 3 = (1:Int) +#assert scalar_rem (-7) 3 = -1 +#assert scalar_rem 7 (-3) = 1 +#assert scalar_rem (-7) (-3) = -1 + +-- Checking that the division operation is correct +#assert scalar_div 3 2 = 1 +#assert scalar_div (-3) 2 = -1 +#assert scalar_div 3 (-2) = -1 +#assert scalar_div (-3) (-2) = 1 +#assert scalar_div 7 3 = 2 +#assert scalar_div (-7) 3 = -2 +#assert scalar_div 7 (-3) = -2 +#assert scalar_div (-7) (-3) = 2 def Scalar.rem {ty : ScalarTy} (x : Scalar ty) (y : Scalar ty) : Result (Scalar ty) := if y.val != 0 then Scalar.tryMk ty (x.val % y.val) else fail divisionByZero @@ -479,20 +537,29 @@ macro_rules -- VECTORS -- ------------- -def Vec (α : Type u) := { l : List α // List.length l <= Usize.max } +def Vec (α : Type u) := { l : List α // List.length l ≤ Usize.max } -def vec_new (α : Type u): Vec α := ⟨ [], by sorry ⟩ +def vec_new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ def vec_len (α : Type u) (v : Vec α) : Usize := let ⟨ v, l ⟩ := v - Usize.ofIntCore (List.length v) (by sorry) l + Usize.ofIntCore (List.length v) (by simp [Scalar.min]) l def vec_push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := () def vec_push_back (α : Type u) (v : Vec α) (x : α) : Result (Vec α) := - if h : List.length v.val <= U32.max || List.length v.val <= Usize.max then - return ⟨ List.concat v.val x, by sorry ⟩ + let nlen := List.length v.val + 1 + if h : nlen ≤ U32.max || nlen ≤ Usize.max then + have h : nlen ≤ Usize.max := by + simp at * + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> + simp [*] at * <;> + try assumption + cases h <;> + linarith + return ⟨ List.concat v.val x, by simp at *; assumption ⟩ else fail maximumSizeExceeded @@ -506,30 +573,28 @@ def vec_insert_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α if i.val < List.length v.val then -- TODO: maybe we should redefine a list library which uses integers -- (instead of natural numbers) - let i : Nat := - match i.val with - | .ofNat n => n - | .negSucc n => by sorry -- TODO: we can't get here - let isLt: i < USize.size := by sorry - let i : Fin USize.size := { val := i, isLt := isLt } - .ret ⟨ List.set v.val i.val x, by - have h: List.length v.val <= Usize.max := v.property - rewrite [ List.length_set v.val i.val x ] + let i := i.val.toNat + .ret ⟨ List.set v.val i x, by + have h: List.length v.val ≤ Usize.max := v.property + simp [*] at * assumption ⟩ else .fail arrayOutOfBounds +def vec_index_to_fin {α : Type u} {v: Vec α} {i: Usize} (h : i.val < List.length v.val) : + Fin (List.length v.val) := + let j := i.val.toNat + let h: j < List.length v.val := by + have heq := @Int.toNat_lt (List.length v.val) i.val i.hmin + apply heq.mpr + assumption + ⟨j, h⟩ + def vec_index_fwd (α : Type u) (v: Vec α) (i: Usize): Result α := - if i.val < List.length v.val then - let i : Nat := - match i.val with - | .ofNat n => n - | .negSucc n => by sorry -- TODO: we can't get here - let isLt: i < USize.size := by sorry - let i : Fin USize.size := { val := i, isLt := isLt } - let h: i < List.length v.val := by sorry - .ret (List.get v.val ⟨i.val, h⟩) + if h: i.val < List.length v.val then + let i := vec_index_to_fin h + .ret (List.get v.val i) else .fail arrayOutOfBounds @@ -540,29 +605,18 @@ def vec_index_back (α : Type u) (v: Vec α) (i: Usize) (_: α): Result Unit := .fail arrayOutOfBounds def vec_index_mut_fwd (α : Type u) (v: Vec α) (i: Usize): Result α := - if i.val < List.length v.val then - let i : Nat := - match i.val with - | .ofNat n => n - | .negSucc n => by sorry -- TODO: we can't get here - let isLt: i < USize.size := by sorry - let i : Fin USize.size := { val := i, isLt := isLt } - let h: i < List.length v.val := by sorry - .ret (List.get v.val ⟨i.val, h⟩) + if h: i.val < List.length v.val then + let i := vec_index_to_fin h + .ret (List.get v.val i) else .fail arrayOutOfBounds def vec_index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) := - if i.val < List.length v.val then - let i : Nat := - match i.val with - | .ofNat n => n - | .negSucc n => by sorry -- TODO: we can't get here - let isLt: i < USize.size := by sorry - let i : Fin USize.size := { val := i, isLt := isLt } - .ret ⟨ List.set v.val i.val x, by - have h: List.length v.val <= Usize.max := v.property - rewrite [ List.length_set v.val i.val x ] + if h: i.val < List.length v.val then + let i := vec_index_to_fin h + .ret ⟨ List.set v.val i x, by + have h: List.length v.val ≤ Usize.max := v.property + simp [*] at * assumption ⟩ else diff --git a/backends/lean/lakefile.lean b/backends/lean/lakefile.lean index 9633e1e8..c5e27d1c 100644 --- a/backends/lean/lakefile.lean +++ b/backends/lean/lakefile.lean @@ -1,6 +1,7 @@ import Lake open Lake DSL +-- Important: mathlib imports std4 and quote4: we mustn't add a `require std4` line require mathlib from git "https://github.com/leanprover-community/mathlib4.git" |