summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
authorSon Ho2024-01-26 00:17:59 +0100
committerSon Ho2024-01-26 00:17:59 +0100
commit9f0e4605e1c8816dbf5ed3e9e893b25e9a2be4a3 (patch)
treea1bd3885305af5598cab879a0551660a6bcf0492 /backends/lean/Base
parent202f0153dc51983e6bc0eddb65d22c763579850c (diff)
Improve the Lean backend
Diffstat (limited to 'backends/lean/Base')
-rw-r--r--backends/lean/Base/Arith/Scalar.lean11
-rw-r--r--backends/lean/Base/Primitives/Base.lean7
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean11
-rw-r--r--backends/lean/Base/Progress/Base.lean20
-rw-r--r--backends/lean/Base/Progress/Progress.lean13
5 files changed, 52 insertions, 10 deletions
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index 2342cce6..43fd2766 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -17,13 +17,20 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
-- Reveal the concrete bounds, simplify calls to [ofInt]
- Utils.simpAt true [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
+ Utils.simpAt true
+ -- Unfoldings
+ [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max,
``Usize.min
- ] [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val] [] .wildcard
+ ]
+ -- Simp lemmas
+ [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val,
+ ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv]
+ -- Hypotheses
+ [] .wildcard
elab "scalar_tac_preprocess" : tactic =>
diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index 3d70c84a..9dbaf133 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -116,6 +116,13 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } :=
@[simp] theorem bind_tc_div (f : α → Result β) :
(do let y ← div; f y) = div := by simp [Bind.bind, bind]
+@[simp] theorem bind_assoc_eq {a b c : Type u}
+ (e : Result a) (g : a → Result b) (h : b → Result c) :
+ (Bind.bind (Bind.bind e g) h) =
+ (Bind.bind e (λ x => Bind.bind (g x) h)) := by
+ simp [Bind.bind]
+ cases e <;> simp
+
----------
-- MISC --
----------
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index a8eda6d5..2c34774b 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -1038,6 +1038,17 @@ instance {ty} : LT (Scalar ty) where
instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val
+theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) :
+ x < y ↔ x.val < y.val := by simp [LT.lt]
+
+theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) :
+ x ≤ y ↔ x.val ≤ y.val := by simp [LE.le]
+
+-- Not marking this one with @[simp] on purpose
+theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) :
+ x = y ↔ x.val = y.val := by
+ cases x; cases y; simp_all
+
instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt ..
instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe ..
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index 0ad16ab6..935af3f5 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -82,17 +82,21 @@ section Methods
-- Destruct the equality
let (mExpr, ret) ← destEq th.consumeMData
trace[Progress] "After splitting the equality:\n- lhs: {th}\n- rhs: {ret}"
- -- Destruct the monadic application to dive into the bind, if necessary (this
- -- is for when we use `withPSpec` inside of the `progress` tactic), and
- -- destruct the application to get the function name
- mExpr.consumeMData.withApp fun mf margs => do
- trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}"
- let (fArgsExpr, f, args) ← do
+ -- Recursively destruct the monadic application to dive into the binds,
+ -- if necessary (this is for when we use `withPSpec` inside of the `progress` tactic),
+ -- and destruct the application to get the function name
+ let rec strip_monad mExpr := do
+ mExpr.consumeMData.withApp fun mf margs => do
+ trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}"
if mf.isConst ∧ mf.constName = ``Bind.bind then do
-- Dive into the bind
let fExpr := (margs.get! 4).consumeMData
- fExpr.withApp fun f args => pure (fExpr, f, args)
- else pure (mExpr, mf, margs)
+ -- Recursve
+ strip_monad fExpr
+ else
+ -- No bind
+ pure (mExpr, mf, margs)
+ let (fArgsExpr, f, args) ← strip_monad mExpr
trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}"
if ¬ f.isConst then throwError "Not a constant: {f}"
-- *Sanity check* (activated if we are analyzing a theorem to register it in a DB)
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index a6a4e82a..0fb276aa 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -421,6 +421,19 @@ namespace Test
progress
simp [*]
+ /- Checking that progress can handle nested blocks -/
+ example {α : Type} (v: Vec α) (i: Usize) (x : α)
+ (hbounds : i.val < v.length) :
+ ∃ nv,
+ (do
+ (do
+ let _ ← v.update_usize i x
+ .ret ())
+ .ret ()) = ret nv
+ := by
+ progress
+ simp [*]
+
end Test
end Progress