diff options
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 98 | ||||
-rw-r--r-- | backends/lean/Base/Utils.lean | 47 |
2 files changed, 107 insertions, 38 deletions
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index b0db465d..835dc468 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -25,7 +25,17 @@ inductive TheoremOrLocal where | Theorem (thName : Name) | Local (asm : LocalDecl) -def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do +/- Type to propagate the errors of `progressWith`. + We need this because we use the exceptions to backtrack, when trying to + use the assumptions for instance. When there is actually an error we want + to propagate to the user, we return it. -/ +inductive ProgressError +| Ok +| Error (msg : MessageData) +deriving Inhabited + +def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) + (asmTac : TacticM Unit) : TacticM ProgressError := do /- Apply the theorem We try to match the theorem with the goal In order to do so, we introduce meta-variables for all the parameters @@ -77,32 +87,62 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) (asmTa -- The assumption should be of the shape: -- `∃ x1 ... xn, f args = ... ∧ ...` -- We introduce the existentially quantified variables and split the top-most - -- conjunction if there is one - splitAllExistsTac thAsm fun h => do - -- Split the conjunction - let splitConj (k : Expr → TacticM Unit) : TacticM Unit := do - if ← isConj (← inferType h) then - splitConjTac h (fun h _ => k h) - else k h - -- Simplify the target by using the equality and some monad simplifications - splitConj fun h => do + -- conjunction if there is one. We use the provided `ids` list to name the + -- introduced variables. + let res ← splitAllExistsTac thAsm ids.toList fun h ids => do + -- Split the conjunctions. + -- For the conjunctions, we split according once to separate the equality `f ... = .ret ...` + -- from the postcondition, if there is, then continue to split the postcondition if there + -- are remaining ids. + let splitEqAndPost (k : Expr → Option Expr → List Name → TacticM ProgressError) : TacticM ProgressError := do + if ← isConj (← inferType h) then do + let hName := (← h.fvarId!.getDecl).userName + let (optId, ids) := listTryPopHead ids + let optIds := match optId with | none => none | some id => some (hName, id) + splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids) + else k h none ids + -- Simplify the target by using the equality and some monad simplifications, + -- then continue splitting the post-condition + splitEqAndPost fun hEq hPost ids => do + trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] - [h.fvarId!] (.targets #[] true) + [hEq.fvarId!] (.targets #[] true) -- Clear the equality let mgoal ← getMainGoal - let mgoal ← mgoal.tryClearMany #[h.fvarId!] + let mgoal ← mgoal.tryClearMany #[hEq.fvarId!] setGoals (mgoal :: (← getUnsolvedGoals)) - -- Update the set of goals - let curGoals ← getUnsolvedGoals - let newGoals := mvars.map Expr.mvarId! - let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned - trace[Progress] "new goals: {newGoals}" - setGoals newGoals.toList - allGoals asmTac - let newGoals ← getUnsolvedGoals - setGoals (newGoals ++ curGoals) - -- - pure () + trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}" + -- Continue splitting following the ids provided by the user + if ¬ ids.isEmpty then + let hPost ← + match hPost with + | none => do return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split") + | some hPost => pure hPost + let curPostId := (← hPost.fvarId!.getDecl).userName + let rec splitPost (hPost : Expr) (ids : List Name) : TacticM ProgressError := do + match ids with + | [] => pure .Ok -- Stop + | nid :: ids => do + -- Split + if ← isConj hPost then + splitConjTac hPost (some (nid, curPostId)) (λ _ nhPost => splitPost nhPost ids) + else return (.Error m!"Too many ids provided ({nid :: ids}) not enough conjuncts to split in the postcondition") + splitPost hPost ids + else return .Ok + match res with + | .Error _ => return res -- Can we get there? We're using "return" + | .Ok => + -- Update the set of goals + let curGoals ← getUnsolvedGoals + let newGoals := mvars.map Expr.mvarId! + let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned + trace[Progress] "new goals: {newGoals}" + setGoals newGoals.toList + allGoals asmTac + let newGoals ← getUnsolvedGoals + setGoals (newGoals ++ curGoals) + -- + pure .Ok -- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do @@ -124,8 +164,9 @@ def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : Tac for decl in decls.reverse do trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" try - progressWith fnExpr (.Local decl) ids asmTac - return () + match ← progressWith fnExpr (.Local decl) ids asmTac with + | .Ok => return () + | .Error msg => throwError msg catch _ => continue -- It failed: try to lookup a theorem -- TODO: use a list of theorems, and try them one by one? @@ -136,9 +177,10 @@ def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : Tac | some thName => pure thName trace[Progress] "Lookuped up: {thName}" -- Apply the theorem - progressWith fnExpr (.Theorem thName) ids asmTac + match ← progressWith fnExpr (.Theorem thName) ids asmTac with + | .Ok => return () + | .Error msg => throwError msg -#check Syntax syntax progressArgs := ("as" " ⟨ " (ident)+ " ⟩")? def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do @@ -168,7 +210,7 @@ namespace Test @[pspec] theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) : ∃ (x: α), v.index α i = .ret x := by - progress as ⟨ x y z ⟩ + progress as ⟨ x ⟩ simp set_option trace.Progress false diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 505412b9..599c3a9f 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -396,13 +396,21 @@ example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by . right; assumption --- Tactic to split on an exists -def splitExistsTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do +-- Tactic to split on an exists. +-- `h` must be an FVar +def splitExistsTac (h : Expr) (optId : Option Name) (k : Expr → Expr → TacticM α) : TacticM α := do withMainContext do let goal ← getMainGoal let hTy ← inferType h if isExists hTy then do - let newGoals ← goal.cases h.fvarId! #[] + -- Try to use the user-provided names + let altVarNames ← + match optId with + | none => pure #[] + | some id => do + let hDecl ← h.fvarId!.getDecl + pure #[{ varNames := [id, hDecl.userName] }] + let newGoals ← goal.cases h.fvarId! altVarNames -- There should be exactly one goal match newGoals.toList with | [ newGoal ] => @@ -418,18 +426,37 @@ def splitExistsTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := else throwError "Not a conjunction" -partial def splitAllExistsTac [Inhabited α] (h : Expr) (k : Expr → TacticM α) : TacticM α := do +-- TODO: move +def listTryPopHead (ls : List α) : Option α × List α := + match ls with + | [] => (none, ls) + | hd :: tl => (some hd, tl) + +/- Destruct all the existentials appearing in `h`, and introduce them as variables + in the context. + + If `ids` is not empty, we use it to name the introduced variables. We + transmit the stripped expression and the remaining ids to the continuation. + -/ +partial def splitAllExistsTac [Inhabited α] (h : Expr) (ids : List Name) (k : Expr → List Name → TacticM α) : TacticM α := do try - splitExistsTac h (fun _ body => splitAllExistsTac body k) - catch _ => k h + let (optId, ids) := listTryPopHead ids + splitExistsTac h optId (fun _ body => splitAllExistsTac body ids k) + catch _ => k h ids -- Tactic to split on a conjunction. -def splitConjTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do +def splitConjTac (h : Expr) (optIds : Option (Name × Name)) (k : Expr → Expr → TacticM α) : TacticM α := do withMainContext do let goal ← getMainGoal let hTy ← inferType h if ← isConj hTy then do - let newGoals ← goal.cases h.fvarId! #[] + -- Try to use the user-provided names + let altVarNames ← + match optIds with + | none => pure #[] + | some (id0, id1) => do + pure #[{ varNames := [id0, id1] }] + let newGoals ← goal.cases h.fvarId! altVarNames -- There should be exactly one goal match newGoals.toList with | [ newGoal ] => @@ -449,13 +476,13 @@ elab "split_conj " n:ident : tactic => do withMainContext do let decl ← Lean.Meta.getLocalDeclFromUserName n.getId let fvar := mkFVar decl.fvarId - splitConjTac fvar (fun _ _ => pure ()) + splitConjTac fvar none (fun _ _ => pure ()) elab "split_all_exists " n:ident : tactic => do withMainContext do let decl ← Lean.Meta.getLocalDeclFromUserName n.getId let fvar := mkFVar decl.fvarId - splitAllExistsTac fvar (fun _ => pure ()) + splitAllExistsTac fvar [] (fun _ _ => pure ()) example (h : a ∧ b) : a := by split_all_exists h |