summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon HO2024-01-27 21:51:38 +0100
committerGitHub2024-01-27 21:51:38 +0100
commit689954a5c84c29c9b86f02e5009f286d909c355c (patch)
tree0e801b5e01eda423d49bdb0a43cff11d65e78bb1
parent202f0153dc51983e6bc0eddb65d22c763579850c (diff)
parentd8247d99520738188bbd160be7de03550f8156ce (diff)
Merge pull request #66 from AeneasVerif/son/lean
Improve the Lean backend
-rw-r--r--backends/lean/Base/Arith/Scalar.lean11
-rw-r--r--backends/lean/Base/Primitives/Base.lean7
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean21
-rw-r--r--backends/lean/Base/Progress/Base.lean36
-rw-r--r--backends/lean/Base/Progress/Progress.lean191
5 files changed, 174 insertions, 92 deletions
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index 2342cce6..43fd2766 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -17,13 +17,20 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
-- Reveal the concrete bounds, simplify calls to [ofInt]
- Utils.simpAt true [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
+ Utils.simpAt true
+ -- Unfoldings
+ [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max,
``Usize.min
- ] [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val] [] .wildcard
+ ]
+ -- Simp lemmas
+ [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val,
+ ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv]
+ -- Hypotheses
+ [] .wildcard
elab "scalar_tac_preprocess" : tactic =>
diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index 3d70c84a..9dbaf133 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -116,6 +116,13 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } :=
@[simp] theorem bind_tc_div (f : α → Result β) :
(do let y ← div; f y) = div := by simp [Bind.bind, bind]
+@[simp] theorem bind_assoc_eq {a b c : Type u}
+ (e : Result a) (g : a → Result b) (h : b → Result c) :
+ (Bind.bind (Bind.bind e g) h) =
+ (Bind.bind e (λ x => Bind.bind (g x) h)) := by
+ simp [Bind.bind]
+ cases e <;> simp
+
----------
-- MISC --
----------
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index a8eda6d5..fe8dc8ec 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -1038,6 +1038,27 @@ instance {ty} : LT (Scalar ty) where
instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val
+-- Not marking this one with @[simp] on purpose
+theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) :
+ x = y ↔ x.val = y.val := by
+ cases x; cases y; simp_all
+
+-- This is sometimes useful when rewriting the goal with the local assumptions
+@[simp] theorem Scalar.eq_imp {ty : ScalarTy} (x y : Scalar ty) :
+ x = y → x.val = y.val := (eq_equiv x y).mp
+
+theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) :
+ x < y ↔ x.val < y.val := by simp [LT.lt]
+
+@[simp] theorem Scalar.lt_imp {ty : ScalarTy} (x y : Scalar ty) :
+ x < y → x.val < y.val := (lt_equiv x y).mp
+
+theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) :
+ x ≤ y ↔ x.val ≤ y.val := by simp [LE.le]
+
+@[simp] theorem Scalar.le_imp {ty : ScalarTy} (x y : Scalar ty) :
+ x ≤ y → x.val ≤ y.val := (le_equiv x y).mp
+
instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt ..
instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe ..
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index 0ad16ab6..a64212a5 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -22,8 +22,9 @@ structure PSpecDesc where
evars : Array Expr
-- The function applied to its arguments
fArgsExpr : Expr
- -- The function
- fName : Name
+ -- ⊤ if the function is a constant (must be if we are registering a theorem,
+ -- but is not necessarily the case if we are looking at a goal)
+ fIsConst : Bool
-- The function arguments
fLevels : List Level
args : Array Expr
@@ -82,19 +83,28 @@ section Methods
-- Destruct the equality
let (mExpr, ret) ← destEq th.consumeMData
trace[Progress] "After splitting the equality:\n- lhs: {th}\n- rhs: {ret}"
- -- Destruct the monadic application to dive into the bind, if necessary (this
- -- is for when we use `withPSpec` inside of the `progress` tactic), and
- -- destruct the application to get the function name
- mExpr.consumeMData.withApp fun mf margs => do
- trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}"
- let (fArgsExpr, f, args) ← do
+ -- Recursively destruct the monadic application to dive into the binds,
+ -- if necessary (this is for when we use `withPSpec` inside of the `progress` tactic),
+ -- and destruct the application to get the function name
+ let rec strip_monad mExpr := do
+ mExpr.consumeMData.withApp fun mf margs => do
+ trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}"
if mf.isConst ∧ mf.constName = ``Bind.bind then do
-- Dive into the bind
let fExpr := (margs.get! 4).consumeMData
- fExpr.withApp fun f args => pure (fExpr, f, args)
- else pure (mExpr, mf, margs)
+ -- Recursve
+ strip_monad fExpr
+ else
+ -- No bind
+ pure (mExpr, mf, margs)
+ let (fArgsExpr, f, args) ← strip_monad mExpr
trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}"
- if ¬ f.isConst then throwError "Not a constant: {f}"
+ let fLevels ← do
+ -- If we are registering a theorem, then the function must be a constant
+ if ¬ f.isConst then
+ if isGoal then pure []
+ else throwError "Not a constant: {f}"
+ else pure f.constLevels!
-- *Sanity check* (activated if we are analyzing a theorem to register it in a DB)
-- Check if some existentially quantified variables
let _ := do
@@ -113,8 +123,8 @@ section Methods
fvars := fvars
evars := evars
fArgsExpr
- fName := f.constName!
- fLevels := f.constLevels!
+ fIsConst := f.isConst
+ fLevels
args := args
ret := ret
post := post
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index a6a4e82a..dc30c441 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -106,7 +106,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
withMainContext do -- The context changed - TODO: remove once addDeclTac is updated
let ngoal ← getMainGoal
trace[Progress] "current goal: {ngoal}"
- trace[Progress] "current goal: {← ngoal.isAssigned}"
+ trace[Progress] "current goal is assigned: {← ngoal.isAssigned}"
-- The assumption should be of the shape:
-- `∃ x1 ... xn, f args = ... ∧ ...`
-- We introduce the existentially quantified variables and split the top-most
@@ -131,50 +131,59 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
-- then continue splitting the post-condition
splitEqAndPost fun hEq hPost ids => do
trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}"
- tryTac (
- simpAt true []
- [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
- [hEq.fvarId!] (.targets #[] true))
- -- TODO: remove this (some types get unfolded too much: we "fold" them back)
- tryTac (simpAt true [] scalar_eqs [] .wildcard_dep)
- -- Clear the equality, unless the user requests not to do so
- let mgoal ← do
- if keep.isSome then getMainGoal
- else do
- let mgoal ← getMainGoal
- mgoal.tryClearMany #[hEq.fvarId!]
- setGoals (mgoal :: (← getUnsolvedGoals))
- trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}"
- -- Continue splitting following the post following the user's instructions
- match hPost with
- | none =>
- -- Sanity check
- if ¬ ids.isEmpty then
- return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split")
- else return .Ok
- | some hPost => do
- let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do
- match ids0 with
- | [] =>
- /- We used all the user provided ids.
- Split the remaining conjunctions by using fresh ids if the user
- instructed to fully split the post-condition, otherwise stop -/
- if splitPost then
- splitFullConjTac true hPost (λ _ => pure .Ok)
- else pure .Ok
- | nid :: ids => do
- trace[Progress] "Splitting post: {← inferType hPost}"
- -- Split
- let nid ← do
- match nid with
- | none => mkFreshAnonPropUserName
- | some nid => pure nid
- trace[Progress] "\n- prevId: {prevId}\n- nid: {nid}\n- remaining ids: {ids}"
- if ← isConj (← inferType hPost) then
- splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids)
- else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition")
- let curPostId := (← hPost.fvarId!.getDecl).userName
- splitPostWithIds curPostId hPost ids
+ trace[Progress] "current goal: {← getMainGoal}"
+ Tactic.focus do
+ let _ ←
+ tryTac
+ (simpAt true []
+ [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
+ [hEq.fvarId!] (.targets #[] true))
+ -- It may happen that at this point the goal is already solved (though this is rare)
+ -- TODO: not sure this is the best way of checking it
+ if (← getUnsolvedGoals) == [] then pure .Ok
+ else
+ trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}"
+ -- TODO: remove this (some types get unfolded too much: we "fold" them back)
+ let _ ← tryTac (simpAt true [] scalar_eqs [] .wildcard_dep)
+ trace[Progress] "goal after folding back scalar types: {← getMainGoal}"
+ -- Clear the equality, unless the user requests not to do so
+ let mgoal ← do
+ if keep.isSome then getMainGoal
+ else do
+ let mgoal ← getMainGoal
+ mgoal.tryClearMany #[hEq.fvarId!]
+ setGoals (mgoal :: (← getUnsolvedGoals))
+ trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}"
+ -- Continue splitting following the post following the user's instructions
+ match hPost with
+ | none =>
+ -- Sanity check
+ if ¬ ids.isEmpty then
+ return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split")
+ else return .Ok
+ | some hPost => do
+ let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do
+ match ids0 with
+ | [] =>
+ /- We used all the user provided ids.
+ Split the remaining conjunctions by using fresh ids if the user
+ instructed to fully split the post-condition, otherwise stop -/
+ if splitPost then
+ splitFullConjTac true hPost (λ _ => pure .Ok)
+ else pure .Ok
+ | nid :: ids => do
+ trace[Progress] "Splitting post: {← inferType hPost}"
+ -- Split
+ let nid ← do
+ match nid with
+ | none => mkFreshAnonPropUserName
+ | some nid => pure nid
+ trace[Progress] "\n- prevId: {prevId}\n- nid: {nid}\n- remaining ids: {ids}"
+ if ← isConj (← inferType hPost) then
+ splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids)
+ else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition")
+ let curPostId := (← hPost.fvarId!.getDecl).userName
+ splitPostWithIds curPostId hPost ids
match res with
| .Error _ => return res -- Can we get there? We're using "return"
| .Ok =>
@@ -223,9 +232,9 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost :
pure (some res)
catch _ => none
match res with
- | some .Ok => return true
+ | some .Ok => pure true
| some (.Error msg) => throwError msg
- | none => return false
+ | none => pure false
-- The array of ids are identifiers to use when introducing fresh variables
def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal)
@@ -266,36 +275,42 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
match res with
| .Ok => return ()
| .Error msg => throwError msg
- -- It failed: lookup the pspec theorems which match the expression
- trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem"
- let pspecs : Array TheoremOrLocal ← do
- let thNames ← pspecAttr.find? fExpr
- -- TODO: because of reduction, there may be several valid theorems (for
- -- instance for the scalars). We need to sort them from most specific to
- -- least specific. For now, we assume the most specific theorems are at
- -- the end.
- let thNames := thNames.reverse
- trace[Progress] "Looked up pspec theorems: {thNames}"
- pure (thNames.map fun th => TheoremOrLocal.Theorem th)
- -- Try the theorems one by one
- for pspec in pspecs do
- if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return ()
- else pure ()
- -- It failed: try to use the recursive assumptions
- trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption"
- -- We try to apply the assumptions of kind "auxDecl"
- let ctx ← Lean.MonadLCtx.getLCtx
- let decls ← ctx.getAllDecls
- let decls := decls.filter (λ decl => match decl.kind with
- | .default | .implDetail => false | .auxDecl => true)
- for decl in decls.reverse do
- trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}"
- let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue
- match res with
- | .Ok => return ()
- | .Error msg => throwError msg
- -- Nothing worked: failed
- throwError "Progress failed"
+ -- It failed: lookup the pspec theorems which match the expression *only
+ -- if the function is a constant*
+ let fIsConst ← do
+ fExpr.consumeMData.withApp fun mf _ => do
+ pure mf.isConst
+ if ¬ fIsConst then throwError "Progress failed"
+ else do
+ trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem"
+ let pspecs : Array TheoremOrLocal ← do
+ let thNames ← pspecAttr.find? fExpr
+ -- TODO: because of reduction, there may be several valid theorems (for
+ -- instance for the scalars). We need to sort them from most specific to
+ -- least specific. For now, we assume the most specific theorems are at
+ -- the end.
+ let thNames := thNames.reverse
+ trace[Progress] "Looked up pspec theorems: {thNames}"
+ pure (thNames.map fun th => TheoremOrLocal.Theorem th)
+ -- Try the theorems one by one
+ for pspec in pspecs do
+ if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return ()
+ else pure ()
+ -- It failed: try to use the recursive assumptions
+ trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption"
+ -- We try to apply the assumptions of kind "auxDecl"
+ let ctx ← Lean.MonadLCtx.getLCtx
+ let decls ← ctx.getAllDecls
+ let decls := decls.filter (λ decl => match decl.kind with
+ | .default | .implDetail => false | .auxDecl => true)
+ for decl in decls.reverse do
+ trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}"
+ let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue
+ match res with
+ | .Ok => return ()
+ | .Error msg => throwError msg
+ -- Nothing worked: failed
+ throwError "Progress failed"
syntax progressArgs := ("keep" (ident <|> "_"))? ("with" ident)? ("as" " ⟨ " (ident <|> "_"),* " .."? " ⟩")?
@@ -421,6 +436,28 @@ namespace Test
progress
simp [*]
+ /- Checking that progress can handle nested blocks -/
+ example {α : Type} (v: Vec α) (i: Usize) (x : α)
+ (hbounds : i.val < v.length) :
+ ∃ nv,
+ (do
+ (do
+ let _ ← v.update_usize i x
+ .ret ())
+ .ret ()) = ret nv
+ := by
+ progress
+ simp [*]
+
+ /- Checking the case where simplifying the goal after instantiating the
+ pspec theorem the goal actually solves it, and where the function is
+ not a constant. We also test the case where the function under scrutinee
+ is not a constant. -/
+ example {x : U32}
+ (f : U32 → Result Unit) (h : ∀ x, f x = .ret ()) :
+ f x = ret () := by
+ progress
+
end Test
end Progress