summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Arith
diff options
context:
space:
mode:
authorSon Ho2023-09-18 12:25:40 +0200
committerSon Ho2023-09-18 12:25:40 +0200
commit1fc263ec0f527698b2f4d734d9757c9f723d0bee (patch)
tree5ad211f79e598a96dd9e867289fc601a2f7e2815 /backends/lean/Base/Arith
parent888565dd83636522564c6fcae1571735b32125da (diff)
Improve scalar_tac
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Arith/Int.lean8
-rw-r--r--backends/lean/Base/Arith/Scalar.lean17
2 files changed, 20 insertions, 5 deletions
diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean
index eb6701c2..3359ecdb 100644
--- a/backends/lean/Base/Arith/Int.lean
+++ b/backends/lean/Base/Arith/Int.lean
@@ -211,9 +211,11 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U
let _ ← introHasIntPropInstances
-- Extra preprocessing, before we split on the disjunctions
extraPreprocess
- -- Split
- let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
- splitOnAsms asms.toList
+ -- Split - note that the extra-preprocessing step might actually have
+ -- proven the goal (by doing simplifications for instance)
+ Tactic.allGoals do
+ let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
+ splitOnAsms asms.toList
elab "int_tac_preprocess" : tactic =>
intTacPreprocess (do pure ())
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index db672489..47751c8a 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -16,14 +16,15 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do
add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
- -- Reveal the concrete bounds
+ -- Reveal the concrete bounds, simplify calls to [ofInt]
Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max,
``Usize.min
- ] [] [] .wildcard
+ ] [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val] [] .wildcard
+
elab "scalar_tac_preprocess" : tactic =>
intTacPreprocess scalarTacExtraPreprocess
@@ -50,4 +51,16 @@ example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
example (x : U32 × U32) : 0 ≤ x.fst.val := by
scalar_tac
+-- Checking that we properly handle [ofInt]
+example : U32.ofInt 1 ≤ U32.max := by
+ scalar_tac
+
+example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) :
+ U32.ofInt x (by constructor <;> scalar_tac) ≤ U32.max := by
+ scalar_tac
+
+-- Not equal
+example (x : U32) (h0 : ¬ x = U32.ofInt 0) : 0 < x.val := by
+ scalar_tac
+
end Arith