From 9f0e4605e1c8816dbf5ed3e9e893b25e9a2be4a3 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 26 Jan 2024 00:17:59 +0100 Subject: Improve the Lean backend --- backends/lean/Base/Arith/Scalar.lean | 11 +++++++++-- backends/lean/Base/Primitives/Base.lean | 7 +++++++ backends/lean/Base/Primitives/Scalar.lean | 11 +++++++++++ backends/lean/Base/Progress/Base.lean | 20 ++++++++++++-------- backends/lean/Base/Progress/Progress.lean | 13 +++++++++++++ 5 files changed, 52 insertions(+), 10 deletions(-) (limited to 'backends') diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index 2342cce6..43fd2766 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -17,13 +17,20 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) -- Reveal the concrete bounds, simplify calls to [ofInt] - Utils.simpAt true [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, + Utils.simpAt true + -- Unfoldings + [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max, ``Usize.min - ] [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val] [] .wildcard + ] + -- Simp lemmas + [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val, + ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv] + -- Hypotheses + [] .wildcard elab "scalar_tac_preprocess" : tactic => diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index 3d70c84a..9dbaf133 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -116,6 +116,13 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } := @[simp] theorem bind_tc_div (f : α → Result β) : (do let y ← div; f y) = div := by simp [Bind.bind, bind] +@[simp] theorem bind_assoc_eq {a b c : Type u} + (e : Result a) (g : a → Result b) (h : b → Result c) : + (Bind.bind (Bind.bind e g) h) = + (Bind.bind e (λ x => Bind.bind (g x) h)) := by + simp [Bind.bind] + cases e <;> simp + ---------- -- MISC -- ---------- diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index a8eda6d5..2c34774b 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1038,6 +1038,17 @@ instance {ty} : LT (Scalar ty) where instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val +theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) : + x < y ↔ x.val < y.val := by simp [LT.lt] + +theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) : + x ≤ y ↔ x.val ≤ y.val := by simp [LE.le] + +-- Not marking this one with @[simp] on purpose +theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) : + x = y ↔ x.val = y.val := by + cases x; cases y; simp_all + instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt .. instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe .. diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 0ad16ab6..935af3f5 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -82,17 +82,21 @@ section Methods -- Destruct the equality let (mExpr, ret) ← destEq th.consumeMData trace[Progress] "After splitting the equality:\n- lhs: {th}\n- rhs: {ret}" - -- Destruct the monadic application to dive into the bind, if necessary (this - -- is for when we use `withPSpec` inside of the `progress` tactic), and - -- destruct the application to get the function name - mExpr.consumeMData.withApp fun mf margs => do - trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}" - let (fArgsExpr, f, args) ← do + -- Recursively destruct the monadic application to dive into the binds, + -- if necessary (this is for when we use `withPSpec` inside of the `progress` tactic), + -- and destruct the application to get the function name + let rec strip_monad mExpr := do + mExpr.consumeMData.withApp fun mf margs => do + trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}" if mf.isConst ∧ mf.constName = ``Bind.bind then do -- Dive into the bind let fExpr := (margs.get! 4).consumeMData - fExpr.withApp fun f args => pure (fExpr, f, args) - else pure (mExpr, mf, margs) + -- Recursve + strip_monad fExpr + else + -- No bind + pure (mExpr, mf, margs) + let (fArgsExpr, f, args) ← strip_monad mExpr trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}" if ¬ f.isConst then throwError "Not a constant: {f}" -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB) diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index a6a4e82a..0fb276aa 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -421,6 +421,19 @@ namespace Test progress simp [*] + /- Checking that progress can handle nested blocks -/ + example {α : Type} (v: Vec α) (i: Usize) (x : α) + (hbounds : i.val < v.length) : + ∃ nv, + (do + (do + let _ ← v.update_usize i x + .ret ()) + .ret ()) = ret nv + := by + progress + simp [*] + end Test end Progress -- cgit v1.2.3