summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Arith/Int.lean10
-rw-r--r--backends/lean/Base/Arith/Scalar.lean17
-rw-r--r--backends/lean/Base/Diverge/Base.lean35
-rw-r--r--backends/lean/Base/Diverge/Elab.lean2
-rw-r--r--backends/lean/Base/IList/IList.lean5
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean226
-rw-r--r--backends/lean/Base/Progress/Base.lean57
-rw-r--r--backends/lean/Base/Progress/Progress.lean35
-rw-r--r--backends/lean/Base/Utils.lean4
9 files changed, 322 insertions, 69 deletions
diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean
index 531ec94f..3359ecdb 100644
--- a/backends/lean/Base/Arith/Int.lean
+++ b/backends/lean/Base/Arith/Int.lean
@@ -211,9 +211,11 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U
let _ ← introHasIntPropInstances
-- Extra preprocessing, before we split on the disjunctions
extraPreprocess
- -- Split
- let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
- splitOnAsms asms.toList
+ -- Split - note that the extra-preprocessing step might actually have
+ -- proven the goal (by doing simplifications for instance)
+ Tactic.allGoals do
+ let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
+ splitOnAsms asms.toList
elab "int_tac_preprocess" : tactic =>
intTacPreprocess (do pure ())
@@ -238,7 +240,7 @@ def intTac (splitGoalConjs : Bool) (extraPreprocess : Tactic.TacticM Unit) : Ta
-- the goal. I think before leads to a smaller proof term?
Tactic.allGoals (intTacPreprocess extraPreprocess)
-- More preprocessing
- Tactic.allGoals (Utils.simpAt [] [``nat_zero_eq_int_zero] [] .wildcard)
+ Tactic.allGoals (Utils.tryTac (Utils.simpAt [] [``nat_zero_eq_int_zero] [] .wildcard))
-- Split the conjunctions in the goal
if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget)
-- Call linarith
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index db672489..47751c8a 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -16,14 +16,15 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do
add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
- -- Reveal the concrete bounds
+ -- Reveal the concrete bounds, simplify calls to [ofInt]
Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max,
``Usize.min
- ] [] [] .wildcard
+ ] [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val] [] .wildcard
+
elab "scalar_tac_preprocess" : tactic =>
intTacPreprocess scalarTacExtraPreprocess
@@ -50,4 +51,16 @@ example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
example (x : U32 × U32) : 0 ≤ x.fst.val := by
scalar_tac
+-- Checking that we properly handle [ofInt]
+example : U32.ofInt 1 ≤ U32.max := by
+ scalar_tac
+
+example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) :
+ U32.ofInt x (by constructor <;> scalar_tac) ≤ U32.max := by
+ scalar_tac
+
+-- Not equal
+example (x : U32) (h0 : ¬ x = U32.ofInt 0) : 0 < x.val := by
+ scalar_tac
+
end Arith
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 1d548389..6a52387d 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -270,7 +270,7 @@ namespace Fix
simp [karrow_rel, result_rel]
have hg := hg Hrgh; simp [result_rel] at hg
cases heq0: g fg <;> simp_all
- rename_i y _
+ rename_i _ y
have hh := hh y Hrgh; simp [result_rel] at hh
simp_all
@@ -546,7 +546,7 @@ namespace FixI
termination_by for_all_fin_aux n _ m h => n - m
decreasing_by
simp_wf
- apply Nat.sub_add_lt_sub <;> simp
+ apply Nat.sub_add_lt_sub <;> try simp
simp_all [Arith.add_one_le_iff_le_ne]
def for_all_fin {n : Nat} (f : Fin n → Prop) := for_all_fin_aux f 0 (by simp)
@@ -569,7 +569,6 @@ namespace FixI
intro i h3; cases i; simp_all
linarith
case succ k hi =>
- simp_all
intro m hk hmn
intro hf i hmi
have hne: m ≠ n := by
@@ -580,7 +579,6 @@ namespace FixI
-- Yes: simply use the `for_all_fin_aux` hyp
unfold for_all_fin_aux at hf
simp_all
- tauto
else
-- No: use the induction hypothesis
have hlt: m < i := by simp_all [Nat.lt_iff_le_and_ne]
@@ -726,8 +724,8 @@ namespace Ex1
theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by
intro k x
simp [list_nth_body]
- split <;> simp
- split <;> simp
+ split <;> try simp
+ split <;> try simp
def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i)
@@ -767,8 +765,8 @@ namespace Ex2
theorem list_nth_body_is_valid: ∀ k x, is_valid_p k (λ k => @list_nth_body a k x) := by
intro k x
simp [list_nth_body]
- split <;> simp
- split <;> simp
+ split <;> try simp
+ split <;> try simp
apply is_valid_p_bind <;> intros <;> simp_all
def list_nth (ls : List a) (i : Int) : Result a := fix list_nth_body (ls, i)
@@ -845,7 +843,7 @@ namespace Ex3
∀ k x, is_valid_p k (λ k => is_even_is_odd_body k x) := by
intro k x
simp [is_even_is_odd_body]
- split <;> simp <;> split <;> simp
+ split <;> (try simp) <;> split <;> try simp
apply is_valid_p_bind; simp
intros; split <;> simp
apply is_valid_p_bind; simp
@@ -878,7 +876,7 @@ namespace Ex3
-- inductives on the fly).
-- The simplest is to repeatedly split then simplify (we identify
-- the outer match or monadic let-binding, and split on its scrutinee)
- split <;> simp
+ split <;> try simp
cases H0 : fix is_even_is_odd_body (Sum.inr (i - 1)) <;> simp
rename_i v
split <;> simp
@@ -891,7 +889,7 @@ namespace Ex3
simp [is_even, is_odd]
conv => lhs; rw [Heq]; simp; rw [is_even_is_odd_body]; simp
-- Same remark as for `even`
- split <;> simp
+ split <;> try simp
cases H0 : fix is_even_is_odd_body (Sum.inl (i - 1)) <;> simp
rename_i v
split <;> simp
@@ -938,7 +936,7 @@ namespace Ex4
intro k
apply (Funs.is_valid_p_is_valid_p tys)
simp [Funs.is_valid_p]
- (repeat (apply And.intro)) <;> intro x <;> simp at x <;>
+ (repeat (apply And.intro)) <;> intro x <;> (try simp at x) <;>
simp only [is_even_body, is_odd_body]
-- Prove the validity of the individual bodies
. split <;> simp
@@ -995,9 +993,9 @@ namespace Ex5
(ls : List a) :
is_valid_p k (λ k => map (f k) ls) := by
induction ls <;> simp [map]
- apply is_valid_p_bind <;> simp_all
+ apply is_valid_p_bind <;> try simp_all
intros
- apply is_valid_p_bind <;> simp_all
+ apply is_valid_p_bind <;> try simp_all
/- An example which uses map -/
inductive Tree (a : Type) :=
@@ -1016,8 +1014,8 @@ namespace Ex5
∀ k x, is_valid_p k (λ k => @id_body a k x) := by
intro k x
simp only [id_body]
- split <;> simp
- apply is_valid_p_bind <;> simp [*]
+ split <;> try simp
+ apply is_valid_p_bind <;> try simp [*]
-- We have to show that `map k tl` is valid
apply map_is_valid;
-- Remark: if we don't do the intro, then the last step is expensive:
@@ -1077,12 +1075,11 @@ namespace Ex6
intro k
apply (Funs.is_valid_p_is_valid_p tys)
simp [Funs.is_valid_p]
- (repeat (apply And.intro)); intro x; simp at x
+ (repeat (apply And.intro)); intro x; try simp at x
simp only [list_nth_body]
-- Prove the validity of the individual bodies
intro k x
- simp [list_nth_body]
- split <;> simp
+ split <;> try simp
split <;> simp
-- Writing the proof terms explicitly
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index f109e847..c6628486 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -1089,7 +1089,7 @@ namespace Tests
intro i hpos h
-- We can directly use `rw [list_nth]`!
rw [list_nth]; simp
- split <;> simp [*]
+ split <;> try simp [*]
. tauto
. -- TODO: we shouldn't have to do that
have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all
diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean
index f10ec4e7..79de93d5 100644
--- a/backends/lean/Base/IList/IList.lean
+++ b/backends/lean/Base/IList/IList.lean
@@ -294,7 +294,6 @@ open Arith in
have := tl.len_pos
linarith
else
- simp at hineq
have : 0 < i := by int_tac
simp [*]
apply hi
@@ -419,8 +418,8 @@ theorem index_itake_append_end [Inhabited α] (i j : Int) (l0 l1 : List α)
match l0 with
| [] => by
simp at *
- have := index_itake i j l1 (by simp_all) (by simp_all) (by simp_all; int_tac)
- simp [*]
+ have := index_itake i j l1 (by simp_all) (by simp_all) (by int_tac)
+ try simp [*]
| hd :: tl =>
have : ¬ i = 0 := by simp at *; int_tac
if hj : j = 0 then by simp_all; int_tac -- Contradiction
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 9e65d3c0..ec9665a5 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -533,6 +533,36 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
apply Scalar.add_unsigned_spec <;> simp only [Scalar.max, *]
+@[cepspec] theorem Isize.add_spec {x y : Isize}
+ (hmin : Isize.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ Isize.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val :=
+ Scalar.add_spec hmin hmax
+
+@[cepspec] theorem I8.add_spec {x y : I8}
+ (hmin : I8.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I8.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val :=
+ Scalar.add_spec hmin hmax
+
+@[cepspec] theorem I16.add_spec {x y : I16}
+ (hmin : I16.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I16.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val :=
+ Scalar.add_spec hmin hmax
+
+@[cepspec] theorem I32.add_spec {x y : I32}
+ (hmin : I32.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I32.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val :=
+ Scalar.add_spec hmin hmax
+
+@[cepspec] theorem I64.add_spec {x y : I64}
+ (hmin : I64.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I64.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val :=
+ Scalar.add_spec hmin hmax
+
+@[cepspec] theorem I128.add_spec {x y : I128}
+ (hmin : I128.min ≤ x.val + y.val) (hmax : x.val + y.val ≤ I128.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val :=
+ Scalar.add_spec hmin hmax
+
-- Generic theorem - shouldn't be used much
@[cpspec]
theorem Scalar.sub_spec {ty} {x y : Scalar ty}
@@ -582,6 +612,36 @@ theorem Scalar.sub_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
∃ z, x - y = ret z ∧ z.val = x.val - y.val := by
apply Scalar.sub_unsigned_spec <;> simp only [Scalar.min, *]
+@[cepspec] theorem Isize.sub_spec {x y : Isize} (hmin : Isize.min ≤ x.val - y.val)
+ (hmax : x.val - y.val ≤ Isize.max) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val :=
+ Scalar.sub_spec hmin hmax
+
+@[cepspec] theorem I8.sub_spec {x y : I8} (hmin : I8.min ≤ x.val - y.val)
+ (hmax : x.val - y.val ≤ I8.max) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val :=
+ Scalar.sub_spec hmin hmax
+
+@[cepspec] theorem I16.sub_spec {x y : I16} (hmin : I16.min ≤ x.val - y.val)
+ (hmax : x.val - y.val ≤ I16.max) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val :=
+ Scalar.sub_spec hmin hmax
+
+@[cepspec] theorem I32.sub_spec {x y : I32} (hmin : I32.min ≤ x.val - y.val)
+ (hmax : x.val - y.val ≤ I32.max) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val :=
+ Scalar.sub_spec hmin hmax
+
+@[cepspec] theorem I64.sub_spec {x y : I64} (hmin : I64.min ≤ x.val - y.val)
+ (hmax : x.val - y.val ≤ I64.max) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val :=
+ Scalar.sub_spec hmin hmax
+
+@[cepspec] theorem I128.sub_spec {x y : I128} (hmin : I128.min ≤ x.val - y.val)
+ (hmax : x.val - y.val ≤ I128.max) :
+ ∃ z, x - y = ret z ∧ z.val = x.val - y.val :=
+ Scalar.sub_spec hmin hmax
+
-- Generic theorem - shouldn't be used much
theorem Scalar.mul_spec {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ x.val * y.val)
@@ -628,6 +688,36 @@ theorem Scalar.mul_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
∃ z, x * y = ret z ∧ z.val = x.val * y.val := by
apply Scalar.mul_unsigned_spec <;> simp only [Scalar.max, *]
+@[cepspec] theorem Isize.mul_spec {x y : Isize} (hmin : Isize.min ≤ x.val * y.val)
+ (hmax : x.val * y.val ≤ Isize.max) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val :=
+ Scalar.mul_spec hmin hmax
+
+@[cepspec] theorem I8.mul_spec {x y : I8} (hmin : I8.min ≤ x.val * y.val)
+ (hmax : x.val * y.val ≤ I8.max) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val :=
+ Scalar.mul_spec hmin hmax
+
+@[cepspec] theorem I16.mul_spec {x y : I16} (hmin : I16.min ≤ x.val * y.val)
+ (hmax : x.val * y.val ≤ I16.max) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val :=
+ Scalar.mul_spec hmin hmax
+
+@[cepspec] theorem I32.mul_spec {x y : I32} (hmin : I32.min ≤ x.val * y.val)
+ (hmax : x.val * y.val ≤ I32.max) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val :=
+ Scalar.mul_spec hmin hmax
+
+@[cepspec] theorem I64.mul_spec {x y : I64} (hmin : I64.min ≤ x.val * y.val)
+ (hmax : x.val * y.val ≤ I64.max) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val :=
+ Scalar.mul_spec hmin hmax
+
+@[cepspec] theorem I128.mul_spec {x y : I128} (hmin : I128.min ≤ x.val * y.val)
+ (hmax : x.val * y.val ≤ I128.max) :
+ ∃ z, x * y = ret z ∧ z.val = x.val * y.val :=
+ Scalar.mul_spec hmin hmax
+
-- Generic theorem - shouldn't be used much
@[cpspec]
theorem Scalar.div_spec {ty} {x y : Scalar ty}
@@ -681,6 +771,48 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
∃ z, x / y = ret z ∧ z.val = x.val / y.val := by
apply Scalar.div_unsigned_spec <;> simp [Scalar.max, *]
+@[cepspec] theorem Isize.div_spec (x : Isize) {y : Isize}
+ (hnz : y.val ≠ 0)
+ (hmin : Isize.min ≤ scalar_div x.val y.val)
+ (hmax : scalar_div x.val y.val ≤ Isize.max):
+ ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val :=
+ Scalar.div_spec hnz hmin hmax
+
+@[cepspec] theorem I8.div_spec (x : I8) {y : I8}
+ (hnz : y.val ≠ 0)
+ (hmin : I8.min ≤ scalar_div x.val y.val)
+ (hmax : scalar_div x.val y.val ≤ I8.max):
+ ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val :=
+ Scalar.div_spec hnz hmin hmax
+
+@[cepspec] theorem I16.div_spec (x : I16) {y : I16}
+ (hnz : y.val ≠ 0)
+ (hmin : I16.min ≤ scalar_div x.val y.val)
+ (hmax : scalar_div x.val y.val ≤ I16.max):
+ ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val :=
+ Scalar.div_spec hnz hmin hmax
+
+@[cepspec] theorem I32.div_spec (x : I32) {y : I32}
+ (hnz : y.val ≠ 0)
+ (hmin : I32.min ≤ scalar_div x.val y.val)
+ (hmax : scalar_div x.val y.val ≤ I32.max):
+ ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val :=
+ Scalar.div_spec hnz hmin hmax
+
+@[cepspec] theorem I64.div_spec (x : I64) {y : I64}
+ (hnz : y.val ≠ 0)
+ (hmin : I64.min ≤ scalar_div x.val y.val)
+ (hmax : scalar_div x.val y.val ≤ I64.max):
+ ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val :=
+ Scalar.div_spec hnz hmin hmax
+
+@[cepspec] theorem I128.div_spec (x : I128) {y : I128}
+ (hnz : y.val ≠ 0)
+ (hmin : I128.min ≤ scalar_div x.val y.val)
+ (hmax : scalar_div x.val y.val ≤ I128.max):
+ ∃ z, x / y = ret z ∧ z.val = scalar_div x.val y.val :=
+ Scalar.div_spec hnz hmin hmax
+
-- Generic theorem - shouldn't be used much
@[cpspec]
theorem Scalar.rem_spec {ty} {x y : Scalar ty}
@@ -734,6 +866,41 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
∃ z, x % y = ret z ∧ z.val = x.val % y.val := by
apply Scalar.rem_unsigned_spec <;> simp [Scalar.max, *]
+@[cepspec] theorem I8.rem_spec (x : I8) {y : I8}
+ (hnz : y.val ≠ 0)
+ (hmin : I8.min ≤ scalar_rem x.val y.val)
+ (hmax : scalar_rem x.val y.val ≤ I8.max):
+ ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val :=
+ Scalar.rem_spec hnz hmin hmax
+
+@[cepspec] theorem I16.rem_spec (x : I16) {y : I16}
+ (hnz : y.val ≠ 0)
+ (hmin : I16.min ≤ scalar_rem x.val y.val)
+ (hmax : scalar_rem x.val y.val ≤ I16.max):
+ ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val :=
+ Scalar.rem_spec hnz hmin hmax
+
+@[cepspec] theorem I32.rem_spec (x : I32) {y : I32}
+ (hnz : y.val ≠ 0)
+ (hmin : I32.min ≤ scalar_rem x.val y.val)
+ (hmax : scalar_rem x.val y.val ≤ I32.max):
+ ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val :=
+ Scalar.rem_spec hnz hmin hmax
+
+@[cepspec] theorem I64.rem_spec (x : I64) {y : I64}
+ (hnz : y.val ≠ 0)
+ (hmin : I64.min ≤ scalar_rem x.val y.val)
+ (hmax : scalar_rem x.val y.val ≤ I64.max):
+ ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val :=
+ Scalar.rem_spec hnz hmin hmax
+
+@[cepspec] theorem I128.rem_spec (x : I128) {y : I128}
+ (hnz : y.val ≠ 0)
+ (hmin : I128.min ≤ scalar_rem x.val y.val)
+ (hmax : scalar_rem x.val y.val ≤ I128.max):
+ ∃ z, x % y = ret z ∧ z.val = scalar_rem x.val y.val :=
+ Scalar.rem_spec hnz hmin hmax
+
-- ofIntCore
-- TODO: typeclass?
def Isize.ofIntCore := @Scalar.ofIntCore .Isize
@@ -751,33 +918,34 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128
-- ofInt
-- TODO: typeclass?
-def Isize.ofInt := @Scalar.ofInt .Isize
-def I8.ofInt := @Scalar.ofInt .I8
-def I16.ofInt := @Scalar.ofInt .I16
-def I32.ofInt := @Scalar.ofInt .I32
-def I64.ofInt := @Scalar.ofInt .I64
-def I128.ofInt := @Scalar.ofInt .I128
-def Usize.ofInt := @Scalar.ofInt .Usize
-def U8.ofInt := @Scalar.ofInt .U8
-def U16.ofInt := @Scalar.ofInt .U16
-def U32.ofInt := @Scalar.ofInt .U32
-def U64.ofInt := @Scalar.ofInt .U64
-def U128.ofInt := @Scalar.ofInt .U128
-
-postfix:74 "%isize" => Isize.ofInt
-postfix:74 "%i8" => I8.ofInt
-postfix:74 "%i16" => I16.ofInt
-postfix:74 "%i32" => I32.ofInt
-postfix:74 "%i64" => I64.ofInt
-postfix:74 "%i128" => I128.ofInt
-postfix:74 "%usize" => Usize.ofInt
-postfix:74 "%u8" => U8.ofInt
-postfix:74 "%u16" => U16.ofInt
-postfix:74 "%u32" => U32.ofInt
-postfix:74 "%u64" => U64.ofInt
-postfix:74 "%u128" => U128.ofInt
-
-example : Result U32 := 1%u32 + 2%u32
+abbrev Isize.ofInt := @Scalar.ofInt .Isize
+abbrev I8.ofInt := @Scalar.ofInt .I8
+abbrev I16.ofInt := @Scalar.ofInt .I16
+abbrev I32.ofInt := @Scalar.ofInt .I32
+abbrev I64.ofInt := @Scalar.ofInt .I64
+abbrev I128.ofInt := @Scalar.ofInt .I128
+abbrev Usize.ofInt := @Scalar.ofInt .Usize
+abbrev U8.ofInt := @Scalar.ofInt .U8
+abbrev U16.ofInt := @Scalar.ofInt .U16
+abbrev U32.ofInt := @Scalar.ofInt .U32
+abbrev U64.ofInt := @Scalar.ofInt .U64
+abbrev U128.ofInt := @Scalar.ofInt .U128
+
+postfix:max "#isize" => Isize.ofInt
+postfix:max "#i8" => I8.ofInt
+postfix:max "#i16" => I16.ofInt
+postfix:max "#i32" => I32.ofInt
+postfix:max "#i64" => I64.ofInt
+postfix:max "#i128" => I128.ofInt
+postfix:max "#usize" => Usize.ofInt
+postfix:max "#u8" => U8.ofInt
+postfix:max "#u16" => U16.ofInt
+postfix:max "#u32" => U32.ofInt
+postfix:max "#u64" => U64.ofInt
+postfix:max "#u128" => U128.ofInt
+
+-- Testing the notations
+example : Result Usize := 0#usize + 1#usize
-- TODO: factor those lemmas out
@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofInt x h).val = x := by
@@ -819,7 +987,6 @@ example : Result U32 := 1%u32 + 2%u32
@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofInt x h).val = x := by
apply Scalar.ofInt_val_eq h
-
-- Comparisons
instance {ty} : LT (Scalar ty) where
lt a b := LT.lt a.val b.val
@@ -847,6 +1014,9 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) :=
instance (ty : ScalarTy) : CoeOut (Scalar ty) Int where
coe := λ v => v.val
+@[simp] theorem Scalar.neq_to_neq_val {ty} : ∀ {i j : Scalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by
+ intro i j; cases i; cases j; simp
+
-- -- We now define a type class that subsumes the various machine integer types, so
-- -- as to write a concise definition for scalar_cast, rather than exhaustively
-- -- enumerating all of the possible pairs. We remark that Rust has sane semantics
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index 6f820a84..76a92795 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -167,7 +167,8 @@ structure PSpecClassExprAttr where
deriving Inhabited
-- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR?
-def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : IO (MapDeclarationExtension α) :=
+def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) :
+ IO (MapDeclarationExtension α) :=
registerSimplePersistentEnvExtension {
name := name,
addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty,
@@ -175,6 +176,54 @@ def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%
toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1)
}
+-- Declare an extension of maps of maps (using [RBMap]).
+-- The important point is that we need to merge the bound values (which are maps).
+def mkMapMapDeclarationExtension [Inhabited β] (ord : α → α → Ordering)
+ (name : Name := by exact decl_name%) :
+ IO (MapDeclarationExtension (RBMap α β ord)) :=
+ registerSimplePersistentEnvExtension {
+ name := name,
+ addImportedFn := fun a =>
+ a.foldl (fun s a => a.foldl (
+ -- We need to merge the maps
+ fun s (k0, k1_to_v) =>
+ match s.find? k0 with
+ | none =>
+ -- No binding: insert one
+ s.insert k0 k1_to_v
+ | some m =>
+ -- There is already a binding: merge
+ let m := RBMap.fold (fun m k v => m.insert k v) m k1_to_v
+ s.insert k0 m)
+ s) RBMap.empty,
+ addEntryFn := fun s n => s.insert n.1 n.2 ,
+ toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1)
+ }
+
+-- Declare an extension of maps of maps (using [HashMap]).
+-- The important point is that we need to merge the bound values (which are maps).
+def mkMapHashMapDeclarationExtension [BEq α] [Hashable α] [Inhabited β]
+ (name : Name := by exact decl_name%) :
+ IO (MapDeclarationExtension (HashMap α β)) :=
+ registerSimplePersistentEnvExtension {
+ name := name,
+ addImportedFn := fun a =>
+ a.foldl (fun s a => a.foldl (
+ -- We need to merge the maps
+ fun s (k0, k1_to_v) =>
+ match s.find? k0 with
+ | none =>
+ -- No binding: insert one
+ s.insert k0 k1_to_v
+ | some m =>
+ -- There is already a binding: merge
+ let m := HashMap.fold (fun m k v => m.insert k v) m k1_to_v
+ s.insert k0 m)
+ s) RBMap.empty,
+ addEntryFn := fun s n => s.insert n.1 n.2 ,
+ toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1)
+ }
+
/- The persistent map from function to pspec theorems. -/
initialize pspecAttr : PSpecAttr ← do
let ext ← mkMapDeclarationExtension `pspecMap
@@ -200,7 +249,8 @@ initialize pspecAttr : PSpecAttr ← do
/- The persistent map from type classes to pspec theorems -/
initialize pspecClassAttr : PSpecClassAttr ← do
- let ext : MapDeclarationExtension (NameMap Name) ← mkMapDeclarationExtension `pspecClassMap
+ let ext : MapDeclarationExtension (NameMap Name) ←
+ mkMapMapDeclarationExtension Name.quickCmp `pspecClassMap
let attrImpl : AttributeImpl := {
name := `cpspec
descr := "Marks theorems to use for type classes with the `progress` tactic"
@@ -231,7 +281,8 @@ initialize pspecClassAttr : PSpecClassAttr ← do
/- The 2nd persistent map from type classes to pspec theorems -/
initialize pspecClassExprAttr : PSpecClassExprAttr ← do
- let ext : MapDeclarationExtension (HashMap Expr Name) ← mkMapDeclarationExtension `pspecClassExprMap
+ let ext : MapDeclarationExtension (HashMap Expr Name) ←
+ mkMapHashMapDeclarationExtension `pspecClassExprMap
let attrImpl : AttributeImpl := {
name := `cepspec
descr := "Marks theorems to use for type classes with the `progress` tactic"
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index 6a4729dc..8b0759c5 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -110,8 +110,9 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
-- then continue splitting the post-condition
splitEqAndPost fun hEq hPost ids => do
trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}"
- simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
- [hEq.fvarId!] (.targets #[] true)
+ tryTac (
+ simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
+ [hEq.fvarId!] (.targets #[] true))
-- Clear the equality, unless the user requests not to do so
let mgoal ← do
if keep.isSome then getMainGoal
@@ -242,21 +243,26 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec do
-- It failed: try to lookup a *class* expr spec theorem (those are more
-- specific than class spec theorems)
+ trace[Progress] "Failed using a pspec theorem: trying to lookup a pspec class expr theorem"
let pspecClassExpr ← do
match getFirstArg args with
| none => pure none
| some arg => do
+ trace[Progress] "Using: f:{fName}, arg: {arg}"
let thName ← pspecClassExprAttr.find? fName arg
pure (thName.map fun th => .Theorem th)
tryLookupApply keep ids splitPost asmTac fExpr "pspec class expr theorem" pspecClassExpr do
-- It failed: try to lookup a *class* spec theorem
+ trace[Progress] "Failed using a pspec class expr theorem: trying to lookup a pspec class theorem"
let pspecClass ← do
match ← getFirstArgAppName args with
| none => pure none
| some argName => do
+ trace[Progress] "Using: f: {fName}, arg: {argName}"
let thName ← pspecClassAttr.find? fName argName
pure (thName.map fun th => .Theorem th)
tryLookupApply keep ids splitPost asmTac fExpr "pspec class theorem" pspecClass do
+ trace[Progress] "Failed using a pspec class theorem: trying to use a recursive assumption"
-- Try a recursive call - we try the assumptions of kind "auxDecl"
let ctx ← Lean.MonadLCtx.getLCtx
let decls ← ctx.getAllDecls
@@ -314,12 +320,14 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
else pure none
let ids :=
let args := asArgs.getArgs
- let args := (args.get! 2).getSepArgs
- args.map (λ s => if s.isIdent then some s.getId else none)
+ if args.size > 2 then
+ let args := (args.get! 2).getSepArgs
+ args.map (λ s => if s.isIdent then some s.getId else none)
+ else #[]
trace[Progress] "User-provided ids: {ids}"
let splitPost : Bool :=
let args := asArgs.getArgs
- (args.get! 3).getArgs.size > 0
+ args.size > 3 ∧ (args.get! 3).getArgs.size > 0
trace[Progress] "Split post: {splitPost}"
/- For scalarTac we have a fast track: if the goal is not a linear
arithmetic goal, we skip (note that otherwise, scalarTac would try
@@ -343,11 +351,14 @@ elab "progress" args:progressArgs : tactic =>
namespace Test
open Primitives Result
- set_option trace.Progress true
- set_option pp.rawOnError true
+ -- Show the traces
+ -- set_option trace.Progress true
+ -- set_option pp.rawOnError true
- #eval showStoredPSpec
- #eval showStoredPSpecClass
+ -- The following commands display the databases of theorems
+ -- #eval showStoredPSpec
+ -- #eval showStoredPSpecClass
+ -- #eval showStoredPSpecExprClass
example {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ x.val + y.val)
@@ -363,6 +374,12 @@ namespace Test
progress keep h with Scalar.add_spec as ⟨ z ⟩
simp [*, h]
+ example {x y : U32}
+ (hmax : x.val + y.val ≤ U32.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
+ progress keep _ as ⟨ z, h1 .. ⟩
+ simp [*, h1]
+
/- Checking that universe instantiation works: the original spec uses
`α : Type u` where u is quantified, while here we use `α : Type 0` -/
example {α : Type} (v: Vec α) (i: Usize) (x : α)
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 1f8f1455..5224e1c3 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -301,6 +301,10 @@ example : Nat := by
example (x : Bool) : Nat := by
cases x <;> custom_let x := 3 <;> apply x
+-- Attempt to apply a tactic
+def tryTac (tac : TacticM Unit) : TacticM Unit := do
+ let _ ← tryTactic tac
+
-- Repeatedly apply a tactic
partial def repeatTac (tac : TacticM Unit) : TacticM Unit := do
try