summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--backends/lean/Base/Progress/Progress.lean98
-rw-r--r--backends/lean/Base/Utils.lean47
2 files changed, 107 insertions, 38 deletions
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index b0db465d..835dc468 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -25,7 +25,17 @@ inductive TheoremOrLocal where
| Theorem (thName : Name)
| Local (asm : LocalDecl)
-def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do
+/- Type to propagate the errors of `progressWith`.
+ We need this because we use the exceptions to backtrack, when trying to
+ use the assumptions for instance. When there is actually an error we want
+ to propagate to the user, we return it. -/
+inductive ProgressError
+| Ok
+| Error (msg : MessageData)
+deriving Inhabited
+
+def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name)
+ (asmTac : TacticM Unit) : TacticM ProgressError := do
/- Apply the theorem
We try to match the theorem with the goal
In order to do so, we introduce meta-variables for all the parameters
@@ -77,32 +87,62 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) (asmTa
-- The assumption should be of the shape:
-- `∃ x1 ... xn, f args = ... ∧ ...`
-- We introduce the existentially quantified variables and split the top-most
- -- conjunction if there is one
- splitAllExistsTac thAsm fun h => do
- -- Split the conjunction
- let splitConj (k : Expr → TacticM Unit) : TacticM Unit := do
- if ← isConj (← inferType h) then
- splitConjTac h (fun h _ => k h)
- else k h
- -- Simplify the target by using the equality and some monad simplifications
- splitConj fun h => do
+ -- conjunction if there is one. We use the provided `ids` list to name the
+ -- introduced variables.
+ let res ← splitAllExistsTac thAsm ids.toList fun h ids => do
+ -- Split the conjunctions.
+ -- For the conjunctions, we split according once to separate the equality `f ... = .ret ...`
+ -- from the postcondition, if there is, then continue to split the postcondition if there
+ -- are remaining ids.
+ let splitEqAndPost (k : Expr → Option Expr → List Name → TacticM ProgressError) : TacticM ProgressError := do
+ if ← isConj (← inferType h) then do
+ let hName := (← h.fvarId!.getDecl).userName
+ let (optId, ids) := listTryPopHead ids
+ let optIds := match optId with | none => none | some id => some (hName, id)
+ splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids)
+ else k h none ids
+ -- Simplify the target by using the equality and some monad simplifications,
+ -- then continue splitting the post-condition
+ splitEqAndPost fun hEq hPost ids => do
+ trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}"
simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
- [h.fvarId!] (.targets #[] true)
+ [hEq.fvarId!] (.targets #[] true)
-- Clear the equality
let mgoal ← getMainGoal
- let mgoal ← mgoal.tryClearMany #[h.fvarId!]
+ let mgoal ← mgoal.tryClearMany #[hEq.fvarId!]
setGoals (mgoal :: (← getUnsolvedGoals))
- -- Update the set of goals
- let curGoals ← getUnsolvedGoals
- let newGoals := mvars.map Expr.mvarId!
- let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned
- trace[Progress] "new goals: {newGoals}"
- setGoals newGoals.toList
- allGoals asmTac
- let newGoals ← getUnsolvedGoals
- setGoals (newGoals ++ curGoals)
- --
- pure ()
+ trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}"
+ -- Continue splitting following the ids provided by the user
+ if ¬ ids.isEmpty then
+ let hPost ←
+ match hPost with
+ | none => do return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split")
+ | some hPost => pure hPost
+ let curPostId := (← hPost.fvarId!.getDecl).userName
+ let rec splitPost (hPost : Expr) (ids : List Name) : TacticM ProgressError := do
+ match ids with
+ | [] => pure .Ok -- Stop
+ | nid :: ids => do
+ -- Split
+ if ← isConj hPost then
+ splitConjTac hPost (some (nid, curPostId)) (λ _ nhPost => splitPost nhPost ids)
+ else return (.Error m!"Too many ids provided ({nid :: ids}) not enough conjuncts to split in the postcondition")
+ splitPost hPost ids
+ else return .Ok
+ match res with
+ | .Error _ => return res -- Can we get there? We're using "return"
+ | .Ok =>
+ -- Update the set of goals
+ let curGoals ← getUnsolvedGoals
+ let newGoals := mvars.map Expr.mvarId!
+ let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned
+ trace[Progress] "new goals: {newGoals}"
+ setGoals newGoals.toList
+ allGoals asmTac
+ let newGoals ← getUnsolvedGoals
+ setGoals (newGoals ++ curGoals)
+ --
+ pure .Ok
-- The array of ids are identifiers to use when introducing fresh variables
def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do
@@ -124,8 +164,9 @@ def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : Tac
for decl in decls.reverse do
trace[Progress] "Trying assumption: {decl.userName} : {decl.type}"
try
- progressWith fnExpr (.Local decl) ids asmTac
- return ()
+ match ← progressWith fnExpr (.Local decl) ids asmTac with
+ | .Ok => return ()
+ | .Error msg => throwError msg
catch _ => continue
-- It failed: try to lookup a theorem
-- TODO: use a list of theorems, and try them one by one?
@@ -136,9 +177,10 @@ def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : Tac
| some thName => pure thName
trace[Progress] "Lookuped up: {thName}"
-- Apply the theorem
- progressWith fnExpr (.Theorem thName) ids asmTac
+ match ← progressWith fnExpr (.Theorem thName) ids asmTac with
+ | .Ok => return ()
+ | .Error msg => throwError msg
-#check Syntax
syntax progressArgs := ("as" " ⟨ " (ident)+ " ⟩")?
def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
@@ -168,7 +210,7 @@ namespace Test
@[pspec]
theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) :
∃ (x: α), v.index α i = .ret x := by
- progress as ⟨ x y z ⟩
+ progress as ⟨ x ⟩
simp
set_option trace.Progress false
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 505412b9..599c3a9f 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -396,13 +396,21 @@ example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by
. right; assumption
--- Tactic to split on an exists
-def splitExistsTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do
+-- Tactic to split on an exists.
+-- `h` must be an FVar
+def splitExistsTac (h : Expr) (optId : Option Name) (k : Expr → Expr → TacticM α) : TacticM α := do
withMainContext do
let goal ← getMainGoal
let hTy ← inferType h
if isExists hTy then do
- let newGoals ← goal.cases h.fvarId! #[]
+ -- Try to use the user-provided names
+ let altVarNames ←
+ match optId with
+ | none => pure #[]
+ | some id => do
+ let hDecl ← h.fvarId!.getDecl
+ pure #[{ varNames := [id, hDecl.userName] }]
+ let newGoals ← goal.cases h.fvarId! altVarNames
-- There should be exactly one goal
match newGoals.toList with
| [ newGoal ] =>
@@ -418,18 +426,37 @@ def splitExistsTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α :=
else
throwError "Not a conjunction"
-partial def splitAllExistsTac [Inhabited α] (h : Expr) (k : Expr → TacticM α) : TacticM α := do
+-- TODO: move
+def listTryPopHead (ls : List α) : Option α × List α :=
+ match ls with
+ | [] => (none, ls)
+ | hd :: tl => (some hd, tl)
+
+/- Destruct all the existentials appearing in `h`, and introduce them as variables
+ in the context.
+
+ If `ids` is not empty, we use it to name the introduced variables. We
+ transmit the stripped expression and the remaining ids to the continuation.
+ -/
+partial def splitAllExistsTac [Inhabited α] (h : Expr) (ids : List Name) (k : Expr → List Name → TacticM α) : TacticM α := do
try
- splitExistsTac h (fun _ body => splitAllExistsTac body k)
- catch _ => k h
+ let (optId, ids) := listTryPopHead ids
+ splitExistsTac h optId (fun _ body => splitAllExistsTac body ids k)
+ catch _ => k h ids
-- Tactic to split on a conjunction.
-def splitConjTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do
+def splitConjTac (h : Expr) (optIds : Option (Name × Name)) (k : Expr → Expr → TacticM α) : TacticM α := do
withMainContext do
let goal ← getMainGoal
let hTy ← inferType h
if ← isConj hTy then do
- let newGoals ← goal.cases h.fvarId! #[]
+ -- Try to use the user-provided names
+ let altVarNames ←
+ match optIds with
+ | none => pure #[]
+ | some (id0, id1) => do
+ pure #[{ varNames := [id0, id1] }]
+ let newGoals ← goal.cases h.fvarId! altVarNames
-- There should be exactly one goal
match newGoals.toList with
| [ newGoal ] =>
@@ -449,13 +476,13 @@ elab "split_conj " n:ident : tactic => do
withMainContext do
let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
let fvar := mkFVar decl.fvarId
- splitConjTac fvar (fun _ _ => pure ())
+ splitConjTac fvar none (fun _ _ => pure ())
elab "split_all_exists " n:ident : tactic => do
withMainContext do
let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
let fvar := mkFVar decl.fvarId
- splitAllExistsTac fvar (fun _ => pure ())
+ splitAllExistsTac fvar [] (fun _ _ => pure ())
example (h : a ∧ b) : a := by
split_all_exists h