summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
authorSon Ho2023-07-12 15:41:29 +0200
committerSon Ho2023-07-12 15:41:29 +0200
commit59e4a06480b5365f48dc68de80f44841f94094ed (patch)
tree00644dda183bff958381183939f4eb54d97ed242 /backends/lean/Base
parenta18d899a2c2b9bdd36f4a5a4b70472c12a835a96 (diff)
Improve the handling of arithmetic bounds
Diffstat (limited to 'backends/lean/Base')
-rw-r--r--backends/lean/Base/Arith/Arith.lean8
-rw-r--r--backends/lean/Base/Primitives.lean228
-rw-r--r--backends/lean/Base/Progress/Progress.lean5
3 files changed, 126 insertions, 115 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean
index 3557d350..20420f36 100644
--- a/backends/lean/Base/Arith/Arith.lean
+++ b/backends/lean/Base/Arith/Arith.lean
@@ -339,13 +339,17 @@ def scalarTac : Tactic.TacticM Unit := do
let add (e : Expr) : Tactic.TacticM Unit := do
let ty ← inferType e
let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false)
- add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Usize []])
add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
-- Reveal the concrete bounds - TODO: not too sure about that.
-- Maybe we should reveal the "concrete" bounds (after normalization)
- Utils.simpAt [``Scalar.max, ``Scalar.min, ``Scalar.cMin, ``Scalar.cMax] [] [] .wildcard
+ Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
+ ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
+ ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
+ ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
+ ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max
+ ] [] [] .wildcard
-- Apply the integer tactic
intTac
diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean
index 6210688d..37abdede 100644
--- a/backends/lean/Base/Primitives.lean
+++ b/backends/lean/Base/Primitives.lean
@@ -149,54 +149,78 @@ open System.Platform.getNumBits
@[simp] def size_num_bits : Nat := (System.Platform.getNumBits ()).val
-- Remark: Lean seems to use < for the comparisons with the upper bounds by convention.
--- We keep the F* convention for now.
-@[simp] def Isize.min : Int := - (HPow.hPow 2 (size_num_bits - 1))
-@[simp] def Isize.max : Int := (HPow.hPow 2 (size_num_bits - 1)) - 1
-@[simp] def I8.min : Int := - (HPow.hPow 2 7)
-@[simp] def I8.max : Int := HPow.hPow 2 7 - 1
-@[simp] def I16.min : Int := - (HPow.hPow 2 15)
-@[simp] def I16.max : Int := HPow.hPow 2 15 - 1
-@[simp] def I32.min : Int := -(HPow.hPow 2 31)
-@[simp] def I32.max : Int := HPow.hPow 2 31 - 1
-@[simp] def I64.min : Int := -(HPow.hPow 2 63)
-@[simp] def I64.max : Int := HPow.hPow 2 63 - 1
-@[simp] def I128.min : Int := -(HPow.hPow 2 127)
-@[simp] def I128.max : Int := HPow.hPow 2 127 - 1
-@[simp] def Usize.min : Int := 0
-@[simp] def Usize.max : Int := HPow.hPow 2 size_num_bits - 1
-@[simp] def U8.min : Int := 0
-@[simp] def U8.max : Int := HPow.hPow 2 8 - 1
-@[simp] def U16.min : Int := 0
-@[simp] def U16.max : Int := HPow.hPow 2 16 - 1
-@[simp] def U32.min : Int := 0
-@[simp] def U32.max : Int := HPow.hPow 2 32 - 1
-@[simp] def U64.min : Int := 0
-@[simp] def U64.max : Int := HPow.hPow 2 64 - 1
-@[simp] def U128.min : Int := 0
-@[simp] def U128.max : Int := HPow.hPow 2 128 - 1
-
--- The normalized bounds
-@[simp] def I8.norm_min := -128
-@[simp] def I8.norm_max := 127
-@[simp] def I16.norm_min := -32768
-@[simp] def I16.norm_max := 32767
-@[simp] def I32.norm_min := -2147483648
-@[simp] def I32.norm_max := 2147483647
-@[simp] def I64.norm_min := -9223372036854775808
-@[simp] def I64.norm_max := 9223372036854775807
-@[simp] def I128.norm_min := -170141183460469231731687303715884105728
-@[simp] def I128.norm_max := 170141183460469231731687303715884105727
-@[simp] def U8.norm_min := 0
-@[simp] def U8.norm_max := 255
-@[simp] def U16.norm_min := 0
-@[simp] def U16.norm_max := 65535
-@[simp] def U32.norm_min := 0
-@[simp] def U32.norm_max := 4294967295
-@[simp] def U64.norm_min := 0
-@[simp] def U64.norm_max := 18446744073709551615
-@[simp] def U128.norm_min := 0
-@[simp] def U128.norm_max := 340282366920938463463374607431768211455
-
+
+-- The "structured" bounds
+def Isize.smin : Int := - (HPow.hPow 2 (size_num_bits - 1))
+def Isize.smax : Int := (HPow.hPow 2 (size_num_bits - 1)) - 1
+def I8.smin : Int := - (HPow.hPow 2 7)
+def I8.smax : Int := HPow.hPow 2 7 - 1
+def I16.smin : Int := - (HPow.hPow 2 15)
+def I16.smax : Int := HPow.hPow 2 15 - 1
+def I32.smin : Int := -(HPow.hPow 2 31)
+def I32.smax : Int := HPow.hPow 2 31 - 1
+def I64.smin : Int := -(HPow.hPow 2 63)
+def I64.smax : Int := HPow.hPow 2 63 - 1
+def I128.smin : Int := -(HPow.hPow 2 127)
+def I128.smax : Int := HPow.hPow 2 127 - 1
+def Usize.smin : Int := 0
+def Usize.smax : Int := HPow.hPow 2 size_num_bits - 1
+def U8.smin : Int := 0
+def U8.smax : Int := HPow.hPow 2 8 - 1
+def U16.smin : Int := 0
+def U16.smax : Int := HPow.hPow 2 16 - 1
+def U32.smin : Int := 0
+def U32.smax : Int := HPow.hPow 2 32 - 1
+def U64.smin : Int := 0
+def U64.smax : Int := HPow.hPow 2 64 - 1
+def U128.smin : Int := 0
+def U128.smax : Int := HPow.hPow 2 128 - 1
+
+-- The "normalized" bounds, that we use in practice
+def I8.min := -128
+def I8.max := 127
+def I16.min := -32768
+def I16.max := 32767
+def I32.min := -2147483648
+def I32.max := 2147483647
+def I64.min := -9223372036854775808
+def I64.max := 9223372036854775807
+def I128.min := -170141183460469231731687303715884105728
+def I128.max := 170141183460469231731687303715884105727
+@[simp] def U8.min := 0
+def U8.max := 255
+@[simp] def U16.min := 0
+def U16.max := 65535
+@[simp] def U32.min := 0
+def U32.max := 4294967295
+@[simp] def U64.min := 0
+def U64.max := 18446744073709551615
+@[simp] def U128.min := 0
+def U128.max := 340282366920938463463374607431768211455
+@[simp] def Usize.min := 0
+
+def Isize.refined_min : { n:Int // n = I32.min ∨ n = I64.min } :=
+ ⟨ Isize.smin, by
+ simp [Isize.smin]
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;> simp [*] ⟩
+
+def Isize.refined_max : { n:Int // n = I32.max ∨ n = I64.max } :=
+ ⟨ Isize.smax, by
+ simp [Isize.smax]
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;> simp [*] ⟩
+
+def Usize.refined_max : { n:Int // n = U32.max ∨ n = U64.max } :=
+ ⟨ Usize.smax, by
+ simp [Usize.smax]
+ cases System.Platform.numBits_eq <;>
+ unfold System.Platform.numBits at * <;> simp [*] ⟩
+
+def Isize.min := Isize.refined_min.val
+def Isize.max := Isize.refined_max.val
+def Usize.max := Usize.refined_max.val
+
inductive ScalarTy :=
| Isize
| I8
@@ -211,6 +235,36 @@ inductive ScalarTy :=
| U64
| U128
+def Scalar.smin (ty : ScalarTy) : Int :=
+ match ty with
+ | .Isize => Isize.smin
+ | .I8 => I8.smin
+ | .I16 => I16.smin
+ | .I32 => I32.smin
+ | .I64 => I64.smin
+ | .I128 => I128.smin
+ | .Usize => Usize.smin
+ | .U8 => U8.smin
+ | .U16 => U16.smin
+ | .U32 => U32.smin
+ | .U64 => U64.smin
+ | .U128 => U128.smin
+
+def Scalar.smax (ty : ScalarTy) : Int :=
+ match ty with
+ | .Isize => Isize.smax
+ | .I8 => I8.smax
+ | .I16 => I16.smax
+ | .I32 => I32.smax
+ | .I64 => I64.smax
+ | .I128 => I128.smax
+ | .Usize => Usize.smax
+ | .U8 => U8.smax
+ | .U16 => U16.smax
+ | .U32 => U32.smax
+ | .U64 => U64.smax
+ | .U128 => U128.smax
+
def Scalar.min (ty : ScalarTy) : Int :=
match ty with
| .Isize => Isize.min
@@ -241,44 +295,10 @@ def Scalar.max (ty : ScalarTy) : Int :=
| .U64 => U64.max
| .U128 => U128.max
-@[simp] def Scalar.norm_min (ty : ScalarTy) : Int :=
- match ty with
- -- We can't normalize the bounds for isize/usize
- | .Isize => Isize.min
- | .Usize => Usize.min
- --
- | .I8 => I8.norm_min
- | .I16 => I16.norm_min
- | .I32 => I32.norm_min
- | .I64 => I64.norm_min
- | .I128 => I128.norm_min
- | .U8 => U8.norm_min
- | .U16 => U16.norm_min
- | .U32 => U32.norm_min
- | .U64 => U64.norm_min
- | .U128 => U128.norm_min
-
-@[simp] def Scalar.norm_max (ty : ScalarTy) : Int :=
- match ty with
- -- We can't normalize the bounds for isize/usize
- | .Isize => Isize.max
- | .Usize => Usize.max
- --
- | .I8 => I8.norm_max
- | .I16 => I16.norm_max
- | .I32 => I32.norm_max
- | .I64 => I64.norm_max
- | .I128 => I128.norm_max
- | .U8 => U8.norm_max
- | .U16 => U16.norm_max
- | .U32 => U32.norm_max
- | .U64 => U64.norm_max
- | .U128 => U128.norm_max
-
-def Scalar.norm_min_eq (ty : ScalarTy) : Scalar.min ty = Scalar.norm_min ty := by
+def Scalar.smin_eq (ty : ScalarTy) : Scalar.min ty = Scalar.smin ty := by
cases ty <;> rfl
-def Scalar.norm_max_eq (ty : ScalarTy) : Scalar.max ty = Scalar.norm_max ty := by
+def Scalar.smax_eq (ty : ScalarTy) : Scalar.max ty = Scalar.smax ty := by
cases ty <;> rfl
-- "Conservative" bounds
@@ -301,30 +321,22 @@ def Scalar.cMax (ty : ScalarTy) : Int :=
theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by
cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at *
- cases System.Platform.numBits_eq <;>
- unfold System.Platform.numBits at * <;>
- simp [*]
+ have h := Isize.refined_min.property
+ cases h <;> simp [*, Isize.min]
theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by
- cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;>
- cases System.Platform.numBits_eq <;>
- unfold System.Platform.numBits at * <;>
- simp [*]
+ cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at *
+ . have h := Isize.refined_max.property
+ cases h <;> simp [*, Isize.max]
+ . have h := Usize.refined_max.property
+ cases h <;> simp [*, Usize.max]
theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by
- cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;>
- cases System.Platform.numBits_eq <;>
- unfold System.Platform.numBits at * <;>
- simp [*] at *
- -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified
+ have := Scalar.cMin_bound ty
linarith
theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by
- cases ty <;> simp [Scalar.min, Scalar.max, Scalar.cMin, Scalar.cMax] at * <;>
- cases System.Platform.numBits_eq <;>
- unfold System.Platform.numBits at * <;>
- simp [*] at * <;>
- -- TODO: I would have expected terms like `-(1 + 1) ^ 63` to be simplified
+ have := Scalar.cMax_bound ty
linarith
structure Scalar (ty : ScalarTy) where
@@ -609,7 +621,7 @@ def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usiz
def Vec.len (α : Type u) (v : Vec α) : Usize :=
let ⟨ v, l ⟩ := v
- Usize.ofIntCore (List.length v) (by simp [Scalar.min]) l
+ Usize.ofIntCore (List.length v) (by simp [Scalar.min, Usize.min]) l
-- This shouldn't be used
def Vec.push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := ()
@@ -620,13 +632,9 @@ def Vec.push (α : Type u) (v : Vec α) (x : α) : Result (Vec α)
let nlen := List.length v.val + 1
if h : nlen ≤ U32.max || nlen ≤ Usize.max then
have h : nlen ≤ Usize.max := by
- simp at *
- cases System.Platform.numBits_eq <;>
- unfold System.Platform.numBits at * <;>
- simp [*] at * <;>
- try assumption
- cases h <;>
- linarith
+ simp [Usize.max] at *
+ have hm := Usize.refined_max.property
+ cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try linarith
return ⟨ List.concat v.val x, by simp at *; assumption ⟩
else
fail maximumSizeExceeded
@@ -647,7 +655,6 @@ def Vec.insert (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec α) :=
.ret ⟨ List.set v.val i x, by
have h: List.length v.val ≤ Usize.max := v.property
simp [*] at *
- assumption
else
.fail arrayOutOfBounds
@@ -688,7 +695,6 @@ def Vec.index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α): Result (Vec
.ret ⟨ List.set v.val i x, by
have h: List.length v.val ≤ Usize.max := v.property
simp [*] at *
- assumption
else
.fail arrayOutOfBounds
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index 4c68b3bd..a4df5c96 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -92,9 +92,10 @@ def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do
if ← isConj (← inferType h) then
splitConjTac h (fun h _ => k h)
else k h
- -- Simplify the target by using the equality
+ -- Simplify the target by using the equality and some monad simplifications
splitConj fun h => do
- simpAt [] [] [h.fvarId!] (.targets #[] true)
+ simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
+ [h.fvarId!] (.targets #[] true)
-- Clear the equality
let mgoal ← getMainGoal
let mgoal ← mgoal.tryClearMany #[h.fvarId!]