diff options
-rw-r--r-- | backends/lean/Base/Primitives/Scalar.lean | 272 | ||||
-rw-r--r-- | compiler/ExtractBuiltin.ml | 31 |
2 files changed, 282 insertions, 21 deletions
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 98d695a4..8de2b3f2 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -325,33 +325,65 @@ instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty @[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty := Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds) -@[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := +@[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop := + Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty + +@[simp] abbrev Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) -theorem Scalar.check_bounds_prop {ty : ScalarTy} {x : Int} (h: Scalar.check_bounds ty x) : - Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by +theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int} + (h: Scalar.check_bounds ty x) : + Scalar.in_bounds ty x := by simp at * have ⟨ hmin, hmax ⟩ := h have hbmin := Scalar.cMin_bound ty have hbmax := Scalar.cMax_bound ty cases hmin <;> cases hmax <;> apply And.intro <;> linarith +theorem Scalar.check_bounds_eq_in_bounds (ty : ScalarTy) (x : Int) : + Scalar.check_bounds ty x ↔ Scalar.in_bounds ty x := by + constructor <;> intro h + . apply (check_bounds_imp_in_bounds h) + . simp_all + -- Further thoughts: look at what has been done here: -- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/Fin/Basic.lean -- and -- https://github.com/leanprover-community/mathlib4/blob/master/Mathlib/Data/UInt.lean -- which both contain a fair amount of reasoning already! -def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := +def Scalar.tryMkOpt (ty : ScalarTy) (x : Int) : Option (Scalar ty) := if h:Scalar.check_bounds ty x then -- If we do: -- ``` - -- let ⟨ hmin, hmax ⟩ := (Scalar.check_bounds_prop h) + -- let ⟨ hmin, hmax ⟩ := (Scalar.check_bounds_imp_in_bounds h) -- Scalar.ofIntCore x hmin hmax -- ``` -- then normalization blocks (for instance, some proofs which use reflexivity fail). -- However, the version below doesn't block reduction (TODO: investigate): - ok (Scalar.ofIntCore x (Scalar.check_bounds_prop h)) - else fail integerOverflow + some (Scalar.ofIntCore x (Scalar.check_bounds_imp_in_bounds h)) + else none + +def Result.ofOption {a : Type u} (x : Option a) (e : Error) : Result a := + match x with + | some x => ok x + | none => fail e + +def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := + Result.ofOption (tryMkOpt ty x) integerOverflow + +theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) : + match tryMk ty x with + | ok y => y.val = x ∧ in_bounds ty x + | fail _ => ¬ (in_bounds ty x) + | _ => False := by + simp [tryMk, ofOption, tryMkOpt, ofIntCore] + have h := check_bounds_eq_in_bounds ty x + split_ifs <;> simp_all + +@[simp] theorem Scalar.tryMk_eq_div (ty : ScalarTy) (x : Int) : + tryMk ty x = div ↔ False := by + simp [tryMk, ofOption, tryMkOpt] + split_ifs <;> simp def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) @@ -579,17 +611,121 @@ instance {ty} : HOr (Scalar ty) (Scalar ty) (Scalar ty) where instance {ty} : HAnd (Scalar ty) (Scalar ty) (Scalar ty) where hAnd x y := Scalar.and x y +-- core checked arithmetic operations + +/- A helper function that converts failure to none and success to some + TODO: move up to Base module? -/ +def Option.ofResult {a : Type u} (x : Result a) : + Option a := + match x with + | ok x => some x + | _ => none + +/- [core::num::{T}::checked_add] -/ +def core.num.checked_add (x y : Scalar ty) : Option (Scalar ty) := + Option.ofResult (x + y) + +def U8.checked_add (x y : U8) : Option U8 := core.num.checked_add x y +def U16.checked_add (x y : U16) : Option U16 := core.num.checked_add x y +def U32.checked_add (x y : U32) : Option U32 := core.num.checked_add x y +def U64.checked_add (x y : U64) : Option U64 := core.num.checked_add x y +def U128.checked_add (x y : U128) : Option U128 := core.num.checked_add x y +def Usize.checked_add (x y : Usize) : Option Usize := core.num.checked_add x y +def I8.checked_add (x y : I8) : Option I8 := core.num.checked_add x y +def I16.checked_add (x y : I16) : Option I16 := core.num.checked_add x y +def I32.checked_add (x y : I32) : Option I32 := core.num.checked_add x y +def I64.checked_add (x y : I64) : Option I64 := core.num.checked_add x y +def I128.checked_add (x y : I128) : Option I128 := core.num.checked_add x y +def Isize.checked_add (x y : Isize) : Option Isize := core.num.checked_add x y + +/- [core::num::{T}::checked_sub] -/ +def core.num.checked_sub (x y : Scalar ty) : Option (Scalar ty) := + Option.ofResult (x - y) + +def U8.checked_sub (x y : U8) : Option U8 := core.num.checked_sub x y +def U16.checked_sub (x y : U16) : Option U16 := core.num.checked_sub x y +def U32.checked_sub (x y : U32) : Option U32 := core.num.checked_sub x y +def U64.checked_sub (x y : U64) : Option U64 := core.num.checked_sub x y +def U128.checked_sub (x y : U128) : Option U128 := core.num.checked_sub x y +def Usize.checked_sub (x y : Usize) : Option Usize := core.num.checked_sub x y +def I8.checked_sub (x y : I8) : Option I8 := core.num.checked_sub x y +def I16.checked_sub (x y : I16) : Option I16 := core.num.checked_sub x y +def I32.checked_sub (x y : I32) : Option I32 := core.num.checked_sub x y +def I64.checked_sub (x y : I64) : Option I64 := core.num.checked_sub x y +def I128.checked_sub (x y : I128) : Option I128 := core.num.checked_sub x y +def Isize.checked_sub (x y : Isize) : Option Isize := core.num.checked_sub x y + +/- [core::num::{T}::checked_mul] -/ +def core.num.checked_mul (x y : Scalar ty) : Option (Scalar ty) := + Option.ofResult (x * y) + +def U8.checked_mul (x y : U8) : Option U8 := core.num.checked_mul x y +def U16.checked_mul (x y : U16) : Option U16 := core.num.checked_mul x y +def U32.checked_mul (x y : U32) : Option U32 := core.num.checked_mul x y +def U64.checked_mul (x y : U64) : Option U64 := core.num.checked_mul x y +def U128.checked_mul (x y : U128) : Option U128 := core.num.checked_mul x y +def Usize.checked_mul (x y : Usize) : Option Usize := core.num.checked_mul x y +def I8.checked_mul (x y : I8) : Option I8 := core.num.checked_mul x y +def I16.checked_mul (x y : I16) : Option I16 := core.num.checked_mul x y +def I32.checked_mul (x y : I32) : Option I32 := core.num.checked_mul x y +def I64.checked_mul (x y : I64) : Option I64 := core.num.checked_mul x y +def I128.checked_mul (x y : I128) : Option I128 := core.num.checked_mul x y +def Isize.checked_mul (x y : Isize) : Option Isize := core.num.checked_mul x y + +/- [core::num::{T}::checked_div] -/ +def core.num.checked_div (x y : Scalar ty) : Option (Scalar ty) := + Option.ofResult (x / y) + +def U8.checked_div (x y : U8) : Option U8 := core.num.checked_div x y +def U16.checked_div (x y : U16) : Option U16 := core.num.checked_div x y +def U32.checked_div (x y : U32) : Option U32 := core.num.checked_div x y +def U64.checked_div (x y : U64) : Option U64 := core.num.checked_div x y +def U128.checked_div (x y : U128) : Option U128 := core.num.checked_div x y +def Usize.checked_div (x y : Usize) : Option Usize := core.num.checked_div x y +def I8.checked_div (x y : I8) : Option I8 := core.num.checked_div x y +def I16.checked_div (x y : I16) : Option I16 := core.num.checked_div x y +def I32.checked_div (x y : I32) : Option I32 := core.num.checked_div x y +def I64.checked_div (x y : I64) : Option I64 := core.num.checked_div x y +def I128.checked_div (x y : I128) : Option I128 := core.num.checked_div x y +def Isize.checked_div (x y : Isize) : Option Isize := core.num.checked_div x y + +/- [core::num::{T}::checked_rem] -/ +def core.num.checked_rem (x y : Scalar ty) : Option (Scalar ty) := + Option.ofResult (x % y) + +def U8.checked_rem (x y : U8) : Option U8 := core.num.checked_rem x y +def U16.checked_rem (x y : U16) : Option U16 := core.num.checked_rem x y +def U32.checked_rem (x y : U32) : Option U32 := core.num.checked_rem x y +def U64.checked_rem (x y : U64) : Option U64 := core.num.checked_rem x y +def U128.checked_rem (x y : U128) : Option U128 := core.num.checked_rem x y +def Usize.checked_rem (x y : Usize) : Option Usize := core.num.checked_rem x y +def I8.checked_rem (x y : I8) : Option I8 := core.num.checked_rem x y +def I16.checked_rem (x y : I16) : Option I16 := core.num.checked_rem x y +def I32.checked_rem (x y : I32) : Option I32 := core.num.checked_rem x y +def I64.checked_rem (x y : I64) : Option I64 := core.num.checked_rem x y +def I128.checked_rem (x y : I128) : Option I128 := core.num.checked_rem x y +def Isize.checked_rem (x y : Isize) : Option Isize := core.num.checked_rem x y + +theorem Scalar.add_equiv {ty} {x y : Scalar ty} : + match x + y with + | ok z => Scalar.in_bounds ty (↑x + ↑y) ∧ (↑z : Int) = ↑x + ↑y + | fail _ => ¬ (Scalar.in_bounds ty (↑x + ↑y)) + | _ => ⊥ := by + -- Applying the unfoldings only inside the match + conv in _ + _ => unfold HAdd.hAdd instHAddScalarResult; simp [add] + have h := tryMk_eq ty (↑x + ↑y) + simp [in_bounds] at h + split at h <;> simp_all [check_bounds_eq_in_bounds] + -- Generic theorem - shouldn't be used much @[pspec] theorem Scalar.add_spec {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x + y.val) (hmax : ↑x + ↑y ≤ Scalar.max ty) : (∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y) := by - -- Applying the unfoldings only on the left - conv => congr; ext; lhs; unfold HAdd.hAdd instHAddScalarResult; simp [add, tryMk] - split - . simp [pure]; rfl - . tauto + have h := @add_equiv ty x y + split at h <;> simp_all + apply h theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} (hmax : ↑x + ↑y ≤ Scalar.max ty) : @@ -655,17 +791,36 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} ∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y := Scalar.add_spec hmin hmax +theorem core.num.checked_add_spec {ty} {x y : Scalar ty} : + match core.num.checked_add x y with + | some z => Scalar.in_bounds ty (↑x + ↑y) ∧ ↑z = (↑x + ↑y : Int) + | none => ¬ (Scalar.in_bounds ty (↑x + ↑y)) := by + have h := Scalar.tryMk_eq ty (↑x + ↑y) + simp only [checked_add, Option.ofResult] + cases heq: x + y <;> simp_all <;> simp [HAdd.hAdd, Scalar.add] at heq + <;> simp [Add.add] at heq + <;> simp_all + +theorem Scalar.sub_equiv {ty} {x y : Scalar ty} : + match x - y with + | ok z => Scalar.in_bounds ty (↑x - ↑y) ∧ (↑z : Int) = ↑x - ↑y + | fail _ => ¬ (Scalar.in_bounds ty (↑x - ↑y)) + | _ => ⊥ := by + -- Applying the unfoldings only inside the match + conv in _ - _ => unfold HSub.hSub instHSubScalarResult; simp [sub] + have h := tryMk_eq ty (↑x - ↑y) + simp [in_bounds] at h + split at h <;> simp_all [check_bounds_eq_in_bounds] + -- Generic theorem - shouldn't be used much @[pspec] theorem Scalar.sub_spec {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x - ↑y) (hmax : ↑x - ↑y ≤ Scalar.max ty) : ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by - conv => congr; ext; lhs; simp [HSub.hSub, sub, tryMk, Sub.sub] - split - . simp [pure] - rfl - . tauto + have h := @sub_equiv ty x y + split at h <;> simp_all + apply h theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned) {x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x - ↑y) : @@ -739,12 +894,33 @@ theorem Scalar.mul_spec {ty} {x y : Scalar ty} (hmax : ↑x * ↑y ≤ Scalar.max ty) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by conv => congr; ext; lhs; simp [HMul.hMul] - simp [mul, tryMk] - split + simp [mul, tryMk, tryMkOpt, ofOption] + split_ifs . simp [pure] rfl . tauto +theorem core.num.checked_sub_spec {ty} {x y : Scalar ty} : + match core.num.checked_sub x y with + | some z => Scalar.in_bounds ty (↑x - ↑y) ∧ ↑z = (↑x - ↑y : Int) + | none => ¬ (Scalar.in_bounds ty (↑x - ↑y)) := by + have h := Scalar.tryMk_eq ty (↑x - ↑y) + simp only [checked_sub, Option.ofResult] + have add_neg_eq : x.val + (-y.val) = x.val - y.val := by omega -- TODO: why do we need this?? + cases heq: x - y <;> simp_all <;> simp only [HSub.hSub, Scalar.sub, Sub.sub, Int.sub] at heq + <;> simp_all + +theorem Scalar.mul_equiv {ty} {x y : Scalar ty} : + match x * y with + | ok z => Scalar.in_bounds ty (↑x * ↑y) ∧ (↑z : Int) = ↑x * ↑y + | fail _ => ¬ (Scalar.in_bounds ty (↑x * ↑y)) + | _ => ⊥ := by + -- Applying the unfoldings only inside the match + conv in _ * _ => unfold HMul.hMul instHMulScalarResult; simp [mul] + have h := tryMk_eq ty (↑x * ↑y) + simp [in_bounds] at h + split at h <;> simp_all [check_bounds_eq_in_bounds] + theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} (hmax : ↑x * ↑y ≤ Scalar.max ty) : ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := by @@ -809,6 +985,28 @@ theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} ∃ z, x * y = ok z ∧ (↑z : Int) = ↑x * ↑y := Scalar.mul_spec hmin hmax +theorem core.num.checked_mul_spec {ty} {x y : Scalar ty} : + match core.num.checked_mul x y with + | some z => Scalar.in_bounds ty (↑x * ↑y) ∧ ↑z = (↑x * ↑y : Int) + | none => ¬ (Scalar.in_bounds ty (↑x * ↑y)) := by + have h := Scalar.tryMk_eq ty (↑x * ↑y) + simp only [checked_mul, Option.ofResult] + have : Int.mul ↑x ↑y = ↑x * ↑y := by simp -- TODO: why do we need this?? + cases heq: x * y <;> simp_all <;> simp only [HMul.hMul, Scalar.mul, Mul.mul] at heq + <;> simp_all + +theorem Scalar.div_equiv {ty} {x y : Scalar ty} : + match x / y with + | ok z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y) ∧ (↑z : Int) = scalar_div ↑x ↑y + | fail _ => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y)) + | _ => ⊥ := by + -- Applying the unfoldings only inside the match + conv in _ / _ => unfold HDiv.hDiv instHDivScalarResult; simp [div] + have h := tryMk_eq ty (scalar_div ↑x ↑y) + simp [in_bounds] at h + split_ifs <;> simp <;> + split at h <;> simp_all [check_bounds_eq_in_bounds] + -- Generic theorem - shouldn't be used much @[pspec] theorem Scalar.div_spec {ty} {x y : Scalar ty} @@ -817,7 +1015,7 @@ theorem Scalar.div_spec {ty} {x y : Scalar ty} (hmax : scalar_div ↑x ↑y ≤ Scalar.max ty) : ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := by simp [HDiv.hDiv, div, Div.div] - simp [tryMk, *] + simp [tryMk, tryMkOpt, ofOption, *] rfl theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty} @@ -903,6 +1101,28 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S ∃ z, x / y = ok z ∧ (↑z : Int) = scalar_div ↑x ↑y := Scalar.div_spec hnz hmin hmax +theorem core.num.checked_div_spec {ty} {x y : Scalar ty} : + match core.num.checked_div x y with + | some z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y) ∧ ↑z = (scalar_div ↑x ↑y : Int) + | none => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_div ↑x ↑y)) := by + have h := Scalar.tryMk_eq ty (scalar_div ↑x ↑y) + simp only [checked_div, Option.ofResult] + cases heq0: (y.val = 0 : Bool) <;> + cases heq1: x / y <;> simp_all <;> simp only [HDiv.hDiv, Scalar.div, Div.div] at heq1 + <;> simp_all + +theorem Scalar.rem_equiv {ty} {x y : Scalar ty} : + match x % y with + | ok z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y) ∧ (↑z : Int) = scalar_rem ↑x ↑y + | fail _ => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y)) + | _ => ⊥ := by + -- Applying the unfoldings only inside the match + conv in _ % _ => unfold HMod.hMod instHModScalarResult; simp [rem] + have h := tryMk_eq ty (scalar_rem ↑x ↑y) + simp [in_bounds] at h + split_ifs <;> simp <;> + split at h <;> simp_all [check_bounds_eq_in_bounds] + -- Generic theorem - shouldn't be used much @[pspec] theorem Scalar.rem_spec {ty} {x y : Scalar ty} @@ -911,7 +1131,7 @@ theorem Scalar.rem_spec {ty} {x y : Scalar ty} (hmax : scalar_rem ↑x ↑y ≤ Scalar.max ty) : ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := by simp [HMod.hMod, rem] - simp [tryMk, *] + simp [tryMk, tryMkOpt, ofOption, *] rfl theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : Scalar ty} @@ -990,6 +1210,16 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S ∃ z, x % y = ok z ∧ (↑z : Int) = scalar_rem ↑x ↑y := Scalar.rem_spec hnz hmin hmax +theorem core.num.checked_rem_spec {ty} {x y : Scalar ty} : + match core.num.checked_rem x y with + | some z => y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y) ∧ ↑z = (scalar_rem ↑x ↑y : Int) + | none => ¬ (y.val ≠ 0 ∧ Scalar.in_bounds ty (scalar_rem ↑x ↑y)) := by + have h := Scalar.tryMk_eq ty (scalar_rem ↑x ↑y) + simp only [checked_rem, Option.ofResult] + cases heq0: (y.val = 0 : Bool) <;> + cases heq1: x % y <;> simp_all <;> simp only [HMod.hMod, Scalar.rem, Mod.mod] at heq1 + <;> simp_all + -- ofIntCore -- TODO: typeclass? def Isize.ofIntCore := @Scalar.ofIntCore .Isize diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index 401d0137..a2983573 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -240,6 +240,27 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info) list = let f = { extract_name = basename } in (rust_name, filter, f) in + let mk_scalar_fun (rust_name_prefix : string) (rust_name_suffix : string) + (extract_name : string option) (filter : bool list option) : + (pattern * bool list option * builtin_fun_info) list = + List.map + (fun ty -> + mk_fun (rust_name_prefix ^ ty ^ rust_name_suffix) extract_name filter) + [ + "usize"; + "u8"; + "u16"; + "u32"; + "u64"; + "u128"; + "isize"; + "i8"; + "i16"; + "i32"; + "i64"; + "i128"; + ] + in [ mk_fun "core::mem::replace" None None; mk_fun "core::slice::{[@T]}::len" @@ -325,6 +346,16 @@ let builtin_funs () : (pattern * bool list option * builtin_fun_info) list = [@T]>}::index_mut" (Some "core_slice_index_Slice_index_mut") None; ] + @ mk_scalar_fun "core::num::{" "}::checked_add" (Some "core.num.checked_add") + None + @ mk_scalar_fun "core::num::{" "}::checked_sub" (Some "core.num.checked_sub") + None + @ mk_scalar_fun "core::num::{" "}::checked_mul" (Some "core.num.checked_mul") + None + @ mk_scalar_fun "core::num::{" "}::checked_div" (Some "core.num.checked_div") + None + @ mk_scalar_fun "core::num::{" "}::checked_rem" (Some "core.num.checked_rem") + None let mk_builtin_funs_map () = let m = |