summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Progress
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base/Progress')
-rw-r--r--backends/lean/Base/Progress/Progress.lean37
1 files changed, 21 insertions, 16 deletions
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index c8f94e9e..c0ddc63d 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -24,7 +24,7 @@ inductive ProgressError
deriving Inhabited
def progressWith (fExpr : Expr) (th : TheoremOrLocal)
- (keep : Option Name) (ids : Array Name) (splitPost : Bool)
+ (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool)
(asmTac : TacticM Unit) : TacticM ProgressError := do
/- Apply the theorem
We try to match the theorem with the goal
@@ -90,13 +90,14 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
-- For the conjunctions, we split according once to separate the equality `f ... = .ret ...`
-- from the postcondition, if there is, then continue to split the postcondition if there
-- are remaining ids.
- let splitEqAndPost (k : Expr → Option Expr → List Name → TacticM ProgressError) : TacticM ProgressError := do
+ let splitEqAndPost (k : Expr → Option Expr → List (Option Name) → TacticM ProgressError) : TacticM ProgressError := do
if ← isConj (← inferType h) then do
let hName := (← h.fvarId!.getDecl).userName
- let (optId, ids) := listTryPopHead ids
- let optIds ← match optId with
- | none => do pure (some (hName, ← mkFreshUserName `h))
- | some id => do pure (some (hName, id))
+ let (optIds, ids) ← do
+ match ids with
+ | [] => do pure (some (hName, ← mkFreshUserName `h), [])
+ | none :: ids => do pure (some (hName, ← mkFreshUserName `h), ids)
+ | some id :: ids => do pure (some (hName, id), ids)
splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids)
else k h none ids
-- Simplify the target by using the equality and some monad simplifications,
@@ -121,8 +122,8 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split")
else return .Ok
| some hPost => do
- let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids : List Name) : TacticM ProgressError := do
- match ids with
+ let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do
+ match ids0 with
| [] =>
/- We used all the user provided ids.
Split the remaining conjunctions by using fresh ids if the user
@@ -133,9 +134,13 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
| nid :: ids => do
trace[Progress] "Splitting post: {hPost}"
-- Split
+ let nid ← do
+ match nid with
+ | none => mkFreshUserName `h
+ | some nid => pure nid
if ← isConj (← inferType hPost) then
splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids)
- else return (.Error m!"Too many ids provided ({nid :: ids}) not enough conjuncts to split in the postcondition")
+ else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition")
let curPostId := (← hPost.fvarId!.getDecl).userName
splitPostWithIds curPostId hPost ids
match res with
@@ -168,7 +173,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do
/- Helper: try to lookup a theorem and apply it, or continue with another tactic
if it fails -/
-def tryLookupApply (keep : Option Name) (ids : Array Name) (splitPost : Bool)
+def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool)
(asmTac : TacticM Unit) (fExpr : Expr)
(kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do
let res ← do
@@ -191,7 +196,7 @@ def tryLookupApply (keep : Option Name) (ids : Array Name) (splitPost : Bool)
-- The array of ids are identifiers to use when introducing fresh variables
def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal)
- (ids : Array Name) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Unit := do
+ (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Unit := do
withMainContext do
-- Retrieve the goal
let mgoal ← Tactic.getMainGoal
@@ -258,7 +263,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
-- Nothing worked: failed
throwError "Progress failed"
-syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " ident,* " .."? " ⟩")?
+syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " (ident <|> "_"),* " .."? " ⟩")?
def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
let args := args.raw
@@ -296,7 +301,7 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
let ids :=
let args := (args.get! 2).getArgs
let args := (args.get! 2).getSepArgs
- args.map Syntax.getId
+ args.map (λ s => if s.isIdent then some s.getId else none)
trace[Progress] "User-provided ids: {ids}"
let splitPost : Bool :=
let args := (args.get! 2).getArgs
@@ -307,7 +312,7 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
elab "progress" args:progressArgs : tactic =>
evalProgress args
-/-namespace Test
+/- namespace Test
open Primitives Result
set_option trace.Progress true
@@ -321,9 +326,9 @@ elab "progress" args:progressArgs : tactic =>
(hmax : x.val + y.val ≤ Scalar.max ty) :
∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
-- progress keep as h with Scalar.add_spec as ⟨ z ⟩
- progress keep as h as ⟨ z, h1 .. ⟩
+ progress keep as h as ⟨ x, h1 .. ⟩
simp [*]
-end Test-/
+end Test -/
end Progress