summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Progress
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base/Progress')
-rw-r--r--backends/lean/Base/Progress/Progress.lean55
1 files changed, 34 insertions, 21 deletions
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index 974a6364..1f734415 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -39,7 +39,7 @@ inductive ProgressError
| Error (msg : MessageData)
deriving Inhabited
-def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name)
+def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (keep : Option Name) (ids : Array Name)
(asmTac : TacticM Unit) : TacticM ProgressError := do
/- Apply the theorem
We try to match the theorem with the goal
@@ -88,7 +88,7 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name)
match th with
| .Theorem thName => mkAppOptM thName (mvars.map some)
| .Local decl => mkAppOptM' (mkFVar decl.fvarId) (mvars.map some)
- let asmName ← mkFreshUserName `h
+ let asmName ← do match keep with | none => mkFreshUserName `h | some n => do pure n
let thTy ← inferType th
let thAsm ← Utils.addDeclTac asmName th thTy (asLet := false)
withMainContext do -- The context changed - TODO: remove once addDeclTac is updated
@@ -109,7 +109,9 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name)
if ← isConj (← inferType h) then do
let hName := (← h.fvarId!.getDecl).userName
let (optId, ids) := listTryPopHead ids
- let optIds := match optId with | none => none | some id => some (hName, id)
+ let optIds ← match optId with
+ | none => do pure (some (hName, ← mkFreshUserName `h))
+ | some id => do pure (some (hName, id))
splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids)
else k h none ids
-- Simplify the target by using the equality and some monad simplifications,
@@ -118,9 +120,12 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name)
trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}"
simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
[hEq.fvarId!] (.targets #[] true)
- -- Clear the equality
- let mgoal ← getMainGoal
- let mgoal ← mgoal.tryClearMany #[hEq.fvarId!]
+ -- Clear the equality, unless the user requests not to do so
+ let mgoal ← do
+ if keep.isSome then getMainGoal
+ else do
+ let mgoal ← getMainGoal
+ mgoal.tryClearMany #[hEq.fvarId!]
setGoals (mgoal :: (← getUnsolvedGoals))
trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}"
-- Continue splitting following the ids provided by the user
@@ -170,7 +175,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do
/- Helper: try to lookup a theorem and apply it, or continue with another tactic
if it fails -/
-def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr)
+def tryLookupApply (keep : Option Name) (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr)
(kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do
let res ← do
match th with
@@ -182,7 +187,7 @@ def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr)
-- Apply the theorem
let res ← do
try
- let res ← progressWith fnExpr th ids asmTac
+ let res ← progressWith fnExpr th keep ids asmTac
pure (some res)
catch _ => none
match res with
@@ -191,7 +196,7 @@ def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr)
| none => x
-- The array of ids are identifiers to use when introducing fresh variables
-def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do
+def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do
withMainContext do
-- Retrieve the goal
let mgoal ← Tactic.getMainGoal
@@ -209,7 +214,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na
-- Otherwise, lookup one.
match withTh with
| some th => do
- match ← progressWith fnExpr th ids asmTac with
+ match ← progressWith fnExpr th keep ids asmTac with
| .Ok => return ()
| .Error msg => throwError msg
| none =>
@@ -218,7 +223,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na
let decls ← ctx.getDecls
for decl in decls.reverse do
trace[Progress] "Trying assumption: {decl.userName} : {decl.type}"
- let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue
+ let res ← do try progressWith fnExpr (.Local decl) keep ids asmTac catch _ => continue
match res with
| .Ok => return ()
| .Error msg => throwError msg
@@ -228,7 +233,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na
let pspec ← do
let thName ← pspecAttr.find? fName
pure (thName.map fun th => .Theorem th)
- tryLookupApply ids asmTac fnExpr "pspec theorem" pspec do
+ tryLookupApply keep ids asmTac fnExpr "pspec theorem" pspec do
-- It failed: try to lookup a *class* expr spec theorem (those are more
-- specific than class spec theorems)
let pspecClassExpr ← do
@@ -237,7 +242,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na
| some arg => do
let thName ← pspecClassExprAttr.find? fName arg
pure (thName.map fun th => .Theorem th)
- tryLookupApply ids asmTac fnExpr "pspec class expr theorem" pspecClassExpr do
+ tryLookupApply keep ids asmTac fnExpr "pspec class expr theorem" pspecClassExpr do
-- It failed: try to lookup a *class* spec theorem
let pspecClass ← do
match ← getFirstArgAppName args with
@@ -245,7 +250,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na
| some argName => do
let thName ← pspecClassAttr.find? fName argName
pure (thName.map fun th => .Theorem th)
- tryLookupApply ids asmTac fnExpr "pspec class theorem" pspecClass do
+ tryLookupApply keep ids asmTac fnExpr "pspec class theorem" pspecClass do
-- Try a recursive call - we try the assumptions of kind "auxDecl"
let ctx ← Lean.MonadLCtx.getLCtx
let decls ← ctx.getAllDecls
@@ -253,21 +258,29 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na
| .default | .implDetail => false | .auxDecl => true)
for decl in decls.reverse do
trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}"
- let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue
+ let res ← do try progressWith fnExpr (.Local decl) keep ids asmTac catch _ => continue
match res with
| .Ok => return ()
| .Error msg => throwError msg
-- Nothing worked: failed
throwError "Progress failed"
-syntax progressArgs := ("with" ident)? ("as" " ⟨ " (ident)+ " ⟩")?
+syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " (ident)+ " ⟩")?
def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
let args := args.raw
-- Process the arguments to retrieve the identifiers to use
trace[Progress] "Progress arguments: {args}"
let args := args.getArgs
- let withArg := (args.get! 0).getArgs
+ let keep : Option Name ← do
+ let args := (args.get! 0).getArgs
+ if args.size > 0 then do
+ let args := (args.get! 1).getArgs
+ if args.size > 0 then pure (some (args.get! 1).getId)
+ else do pure (some (← mkFreshUserName `h))
+ else pure none
+ trace[Progress] "Keep: {keep}"
+ let withArg := (args.get! 1).getArgs
let withArg ← do
if withArg.size > 0 then
let id := withArg.get! 1
@@ -287,11 +300,11 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
| id :: _ =>
pure (some (.Theorem id))
else pure none
- let args := (args.get! 1).getArgs
+ let args := (args.get! 2).getArgs
let args := (args.get! 2).getArgs
let ids := args.map Syntax.getId
trace[Progress] "User-provided ids: {ids}"
- progressAsmsOrLookupTheorem withArg ids (firstTac [assumptionTac, Arith.scalarTac])
+ progressAsmsOrLookupTheorem keep withArg ids (firstTac [assumptionTac, Arith.scalarTac])
elab "progress" args:progressArgs : tactic =>
evalProgress args
@@ -306,11 +319,11 @@ namespace Test
#eval showStoredPSpec
#eval showStoredPSpecClass
- theorem Scalar.add_spec {ty} {x y : Scalar ty}
+ theorem Scalar.add_spec1 {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ x.val + y.val)
(hmax : x.val + y.val ≤ Scalar.max ty) :
∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
- progress
+ progress keep as h with Scalar.add_spec as ⟨ z ⟩
simp [*]
/-