From 59e4a06480b5365f48dc68de80f44841f94094ed Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 12 Jul 2023 15:41:29 +0200 Subject: Improve the handling of arithmetic bounds --- backends/lean/Base/Arith/Arith.lean | 8 +- backends/lean/Base/Primitives.lean | 228 +++++++++++++++--------------- backends/lean/Base/Progress/Progress.lean | 5 +- 3 files changed, 126 insertions(+), 115 deletions(-) (limited to 'backends') diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean index 3557d350..20420f36 100644 --- a/backends/lean/Base/Arith/Arith.lean +++ b/backends/lean/Base/Arith/Arith.lean @@ -339,13 +339,17 @@ def scalarTac : Tactic.TacticM Unit := do let add (e : Expr) : Tactic.TacticM Unit := do let ty ← inferType e let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false) - add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Usize []]) add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) -- Reveal the concrete bounds - TODO: not too sure about that. -- Maybe we should reveal the "concrete" bounds (after normalization) - Utils.simpAt [``Scalar.max, ``Scalar.min, ``Scalar.cMin, ``Scalar.cMax] [] [] .wildcard + Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, + ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, + ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, + ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, + ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max + ] [] [] .wildcard -- Apply the integer tactic intTac diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean index 6210688d..37abdede 100644 --- a/backends/lean/Base/Primitives.lean +++ b/backends/lean/Base/Primitives.lean @@ -149,54 +149,78 @@ open System.Platform.getNumBits @[simp] def size_num_bits : Nat := (System.Platform.getNumBits ()).val -- Remark: Lean seems to use < for the comparisons with the upper bounds by convention. --- We keep the F* convention for now. -@[simp] def Isize.min : Int := - (HPow.hPow 2 (size_num_bits - 1)) -@[simp] def Isize.max : Int := (HPow.hPow 2 (size_num_bits - 1)) - 1 -@[simp] def I8.min : Int := - (HPow.hPow 2 7) -@[simp] def I8.max : Int := HPow.hPow 2 7 - 1 -@[simp] def I16.min : Int := - (HPow.hPow 2 15) -@[simp] def I16.max : Int := HPow.hPow 2 15 - 1 -@[simp] def I32.min : Int := -(HPow.hPow 2 31) -@[simp] def I32.max : Int := HPow.hPow 2 31 - 1 -@[simp] def I64.min : Int := -(HPow.hPow 2 63) -@[simp] def I64.max : Int := HPow.hPow 2 63 - 1 -@[simp] def I128.min : Int := -(HPow.hPow 2 127) -@[simp] def I128.max : Int := HPow.hPow 2 127 - 1 -@[simp] def Usize.min : Int := 0 -@[simp] def Usize.max : Int := HPow.hPow 2 size_num_bits - 1 -@[simp] def U8.min : Int := 0 -@[simp] def U8.max : Int := HPow.hPow 2 8 - 1 -@[simp] def U16.min : Int := 0 -@[simp] def U16.max : Int := HPow.hPow 2 16 - 1 -@[simp] def U32.min : Int := 0 -@[simp] def U32.max : Int := HPow.hPow 2 32 - 1 -@[simp] def U64.min : Int := 0 -@[simp] def U64.max : Int := HPow.hPow 2 64 - 1 -@[simp] def U128.min : Int := 0 -@[simp] def U128.max : Int := HPow.hPow 2 128 - 1 - --- The normalized bounds -@[simp] def I8.norm_min := -128 -@[simp] def I8.norm_max := 127 -@[simp] def I16.norm_min := -32768 -@[simp] def I16.norm_max := 32767 -@[simp] def I32.norm_min := -2147483648 -@[simp] def I32.norm_max := 2147483647 -@[simp] def I64.norm_min := -9223372036854775808 -@[simp] def I64.norm_max := 9223372036854775807 -@[simp] def I128.norm_min := -170141183460469231731687303715884105728 -@[simp] def I128.norm_max := 170141183460469231731687303715884105727 -@[simp] def U8.norm_min := 0 -@[simp] def U8.norm_max := 255 -@[simp] def U16.norm_min := 0 -@[simp] def U16.norm_max := 65535 -@[simp] def U32.norm_min := 0 -@[simp] def U32.norm_max := 4294967295 -@[simp] def U64.norm_min := 0 -@[simp] def U64.norm_max := 18446744073709551615 -@[simp] def U128.norm_min := 0 -@[simp] def U128.norm_max := 340282366920938463463374607431768211455 - + +-- The "structured" bounds +def Isize.smin : Int := - (HPow.hPow 2 (size_num_bits - 1)) +def Isize.smax : Int := (HPow.hPow 2 (size_num_bits - 1)) - 1 +def I8.smin : Int := - (HPow.hPow 2 7) +def I8.smax : Int := HPow.hPow 2 7 - 1 +def I16.smin : Int := - (HPow.hPow 2 15) +def I16.smax : Int := HPow.hPow 2 15 - 1 +def I32.smin : Int := -(HPow.hPow 2 31) +def I32.smax : Int := HPow.hPow 2 31 - 1 +def I64.smin : Int := -(HPow.hPow 2 63) +def I64.smax : Int := HPow.hPow 2 63 - 1 +def I128.smin : Int := -(HPow.hPow 2 127) +def I128.smax : Int := HPow.hPow 2 127 - 1 +def Usize.smin : Int := 0 +def Usize.smax : Int := HPow.hPow 2 size_num_bits - 1 +def U8.smin : Int := 0 +def U8.smax : Int := HPow.hPow 2 8 - 1 +def U16.smin : Int := 0 +def U16.smax : Int := HPow.hPow 2 16 - 1 +def U32.smin : Int := 0 +def U32.smax : Int := HPow.hPow 2 32 - 1 +def U64.smin : Int := 0 +def U64.smax : Int := HPow.hPow 2 64 - 1 +def U128.smin : Int := 0 +def U128.smax : Int := HPow.hPow 2 128 - 1 + +-- The "normalized" bounds, that we use in practice +def I8.min := -128 +def I8.max := 127 +def I16.min := -32768 +def I16.max := 32767 +def I32.min := -2147483648 +def I32.max := 2147483647 +def I64.min := -9223372036854775808 +def I64.max := 9223372036854775807 +def I128.min := -170141183460469231731687303715884105728 +def I128.max := 170141183460469231731687303715884105727 +@[simp] def U8.min := 0 +def U8.max := 255 +@[simp] def U16.min := 0 +def U16.max := 65535 +@[simp] def U32.min := 0 +def U32.max := 4294967295 +@[simp] def U64.min := 0 +def U64.max := 18446744073709551615 +@[simp] def U128.min := 0 +def U128.max := 340282366920938463463374607431768211455 +@[simp] def Usize.min := 0 + +def Isize.refined_min : { n:Int // n = I32.min ∨ n = I64.min } := + ⟨ Isize.smin, by + simp [Isize.smin] + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> simp [*] ⟩ + +def Isize.refined_max : { n:Int // n = I32.max ∨ n = I64.max } := + ⟨ Isize.smax, by + simp [Isize.smax] + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> simp [*] ⟩ + +def Usize.refined_max : { n:Int // n = U32.max ∨ n = U64.max } := + ⟨ Usize.smax, by + simp [Usize.smax] + cases System.Platform.numBits_eq <;> + unfold System.Platform.numBits at * <;> simp [*] ⟩ + +def Isize.min := Isize.refined_min.val +def Isize.max := Isize.refined_max.val +def Usize.max := Usize.refined_max.val + inductive ScalarTy := | Isize | I8 @@ -211,6 +235,36 @@ inductive ScalarTy := | U64 | U128 +def Scalar.smin (ty : ScalarTy) : Int := + match ty with + | .Isize => Isize.smin + | .I8 => I8.smin + | .I16 => I16.smin + | .I32 => I32.smin + | .I64 => I64.smin + | .I128 => I128.smin + | .Usize => Usize.smin + | .U8 => U8.smin + | .U16 => U16.smin + | .U32 => U32.smin + | .U64 => U64.smin + | .U128 => U128.smin + +def Scalar.smax (ty : ScalarTy) : Int := + match ty with + | .Isize => Isize.smax + | .I8 => I8.smax + | .I16 => I16.smax + | .I32 => I32.smax + | .I64 => I64.smax + | .I128 => I128.smax + | .Usize => Usize.smax + | .U8 => U8.smax + | .U16 => U16.smax + | .U32 => U32.smax + | .U64 => U64.smax + | .U128 => U128.smax + def Scalar.min (ty : ScalarTy) : Int := match ty with | .Isize => Isize.min @@ -241,44 +295,10 @@ def Scalar.max (ty : ScalarTy) : Int := | .U64 => U64.max | .U128 => U128.max -@[simp] def Scalar.norm_min (ty : ScalarTy) : Int := - match ty with - -- We can't normalize the bounds for isize/usize - | .Isize => Isize.min - | .Usize => Usize.min - -- - | .I8 => I8.norm_min - | .I16 => I16.norm_min - | .I32 => I32.norm_min - | .I64 => I64.norm_min - | .I128 => I128.norm_min - | .U8 => U8.norm_min - | .U16 => U16.norm_min - | .U32 => U32.norm_min - | .U64 => U64.norm_min - | .U128 => U128.norm_min - -@[simp] def Scalar.norm_max (ty : ScalarTy) : Int := - match ty with - -- We can't normalize the bounds for isize/usize - | .Isize => Isize.max - | .Usize => Usize.max - -- - | .I8 => I8.norm_max - | .I16 => I16.norm_max - | .I32 => I32.norm_max - | .I64 => I64.norm_max - | .I128 => I128.norm_max - | .U8 => U8.norm_max - | .U16 => U16.norm_max - | .U32 => U32.norm_max - | .U64 => U64.norm_max - | .U128 => U128.norm_max - -def Scalar.norm_min_eq (ty : ScalarTy) : Scalar.min ty = Scalar.norm_min ty := by +def Scalar.smin_eq (ty : ScalarTy) : Scalar.min ty = Scalar.smin ty := by cases ty <;> rfl -def Scalar.norm_max_eq (ty : ScalarTy) : Scalar.max ty = Scalar.norm_max ty := by +def Scalar.smax_eq (ty : ScalarTy) : Scalar.max ty = Scalar.smax ty := by cases ty <;> rfl -- "Conservative" bounds @@ -301,30 +321,22 @@ def Scalar.cMax (ty : ScalarTy) : Int := theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> - simp [*] + have h := Isize.refined_min.property + cases h <;> simp [*, Isize.min] theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by - cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;> - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> - simp [*] + cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * + . have h := Isize.refined_max.property + cases h <;> simp [*, Isize.max] + . have h := Usize.refined_max.property + cases h <;> simp [*, Usize.max] theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by - cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;> - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> - simp [*] at * - -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified + have := Scalar.cMin_bound ty linarith theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by - cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;> - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> - simp [*] at * <;> - -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified + have := Scalar.cMax_bound ty linarith structure Scalar (ty : ScalarTy) where @@ -609,7 +621,7 @@ def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usiz def Vec.len (α : Type u) (v : Vec α) : Usize := let ⟨ v, l ⟩ := v - Usize.ofIntCore (List.length v) (by simp [Scalar.min]) l + Usize.ofIntCore (List.length v) (by simp [Scalar.min, Usize.min]) l -- This shouldn't be used def Vec.push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := () @@ -620,13 +632,9 @@ def Vec.push (α : Type u) (v : Vec α) (x : α) : Result (Vec α) let nlen := List.length v.val + 1 if h : nlen ≤ U32.max || nlen ≤ Usize.max then have h : nlen ≤ Usize.max := by - simp at * - cases System.Platform.numBits_eq <;> - unfold System.Platform.numBits at * <;> - simp [*] at * <;> - try assumption - cases h <;> - linarith + simp [Usize.max] at * + have hm := Usize.refined_max.property + cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try linarith return ⟨ List.concat v.val x, by simp at *; assumption ⟩ else fail maximumSizeExceeded @@ -647,7 +655,6 @@ def Vec.insert (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) := .ret ⟨ List.set v.val i x, by have h: List.length v.val ≤ Usize.max := v.property simp [*] at * - assumption ⟩ else .fail arrayOutOfBounds @@ -688,7 +695,6 @@ def Vec.index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec .ret ⟨ List.set v.val i x, by have h: List.length v.val ≤ Usize.max := v.property simp [*] at * - assumption ⟩ else .fail arrayOutOfBounds diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index 4c68b3bd..a4df5c96 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -92,9 +92,10 @@ def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do if ← isConj (← inferType h) then splitConjTac h (fun h _ => k h) else k h - -- Simplify the target by using the equality + -- Simplify the target by using the equality and some monad simplifications splitConj fun h => do - simpAt [] [] [h.fvarId!] (.targets #[] true) + simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] + [h.fvarId!] (.targets #[] true) -- Clear the equality let mgoal ← getMainGoal let mgoal ← mgoal.tryClearMany #[h.fvarId!] -- cgit v1.2.3