From 6cc0279045d40231f1cce83f0edb7aada1e59d92 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 13 Jul 2023 10:37:16 +0200 Subject: Finish implementing the syntax for `progress` --- backends/lean/Base/Progress/Progress.lean | 98 ++++++++++++++++++++++--------- 1 file changed, 70 insertions(+), 28 deletions(-) (limited to 'backends/lean/Base/Progress') diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index b0db465d..835dc468 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -25,7 +25,17 @@ inductive TheoremOrLocal where | Theorem (thName : Name) | Local (asm : LocalDecl) -def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do +/- Type to propagate the errors of `progressWith`. + We need this because we use the exceptions to backtrack, when trying to + use the assumptions for instance. When there is actually an error we want + to propagate to the user, we return it. -/ +inductive ProgressError +| Ok +| Error (msg : MessageData) +deriving Inhabited + +def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) + (asmTac : TacticM Unit) : TacticM ProgressError := do /- Apply the theorem We try to match the theorem with the goal In order to do so, we introduce meta-variables for all the parameters @@ -77,32 +87,62 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) (asmTa -- The assumption should be of the shape: -- `∃ x1 ... xn, f args = ... ∧ ...` -- We introduce the existentially quantified variables and split the top-most - -- conjunction if there is one - splitAllExistsTac thAsm fun h => do - -- Split the conjunction - let splitConj (k : Expr → TacticM Unit) : TacticM Unit := do - if ← isConj (← inferType h) then - splitConjTac h (fun h _ => k h) - else k h - -- Simplify the target by using the equality and some monad simplifications - splitConj fun h => do + -- conjunction if there is one. We use the provided `ids` list to name the + -- introduced variables. + let res ← splitAllExistsTac thAsm ids.toList fun h ids => do + -- Split the conjunctions. + -- For the conjunctions, we split according once to separate the equality `f ... = .ret ...` + -- from the postcondition, if there is, then continue to split the postcondition if there + -- are remaining ids. + let splitEqAndPost (k : Expr → Option Expr → List Name → TacticM ProgressError) : TacticM ProgressError := do + if ← isConj (← inferType h) then do + let hName := (← h.fvarId!.getDecl).userName + let (optId, ids) := listTryPopHead ids + let optIds := match optId with | none => none | some id => some (hName, id) + splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids) + else k h none ids + -- Simplify the target by using the equality and some monad simplifications, + -- then continue splitting the post-condition + splitEqAndPost fun hEq hPost ids => do + trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] - [h.fvarId!] (.targets #[] true) + [hEq.fvarId!] (.targets #[] true) -- Clear the equality let mgoal ← getMainGoal - let mgoal ← mgoal.tryClearMany #[h.fvarId!] + let mgoal ← mgoal.tryClearMany #[hEq.fvarId!] setGoals (mgoal :: (← getUnsolvedGoals)) - -- Update the set of goals - let curGoals ← getUnsolvedGoals - let newGoals := mvars.map Expr.mvarId! - let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned - trace[Progress] "new goals: {newGoals}" - setGoals newGoals.toList - allGoals asmTac - let newGoals ← getUnsolvedGoals - setGoals (newGoals ++ curGoals) - -- - pure () + trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}" + -- Continue splitting following the ids provided by the user + if ¬ ids.isEmpty then + let hPost ← + match hPost with + | none => do return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split") + | some hPost => pure hPost + let curPostId := (← hPost.fvarId!.getDecl).userName + let rec splitPost (hPost : Expr) (ids : List Name) : TacticM ProgressError := do + match ids with + | [] => pure .Ok -- Stop + | nid :: ids => do + -- Split + if ← isConj hPost then + splitConjTac hPost (some (nid, curPostId)) (λ _ nhPost => splitPost nhPost ids) + else return (.Error m!"Too many ids provided ({nid :: ids}) not enough conjuncts to split in the postcondition") + splitPost hPost ids + else return .Ok + match res with + | .Error _ => return res -- Can we get there? We're using "return" + | .Ok => + -- Update the set of goals + let curGoals ← getUnsolvedGoals + let newGoals := mvars.map Expr.mvarId! + let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned + trace[Progress] "new goals: {newGoals}" + setGoals newGoals.toList + allGoals asmTac + let newGoals ← getUnsolvedGoals + setGoals (newGoals ++ curGoals) + -- + pure .Ok -- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do @@ -124,8 +164,9 @@ def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : Tac for decl in decls.reverse do trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" try - progressWith fnExpr (.Local decl) ids asmTac - return () + match ← progressWith fnExpr (.Local decl) ids asmTac with + | .Ok => return () + | .Error msg => throwError msg catch _ => continue -- It failed: try to lookup a theorem -- TODO: use a list of theorems, and try them one by one? @@ -136,9 +177,10 @@ def progressAsmsOrLookupTheorem (ids : Array Name) (asmTac : TacticM Unit) : Tac | some thName => pure thName trace[Progress] "Lookuped up: {thName}" -- Apply the theorem - progressWith fnExpr (.Theorem thName) ids asmTac + match ← progressWith fnExpr (.Theorem thName) ids asmTac with + | .Ok => return () + | .Error msg => throwError msg -#check Syntax syntax progressArgs := ("as" " ⟨ " (ident)+ " ⟩")? def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do @@ -168,7 +210,7 @@ namespace Test @[pspec] theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) : ∃ (x: α), v.index α i = .ret x := by - progress as ⟨ x y z ⟩ + progress as ⟨ x ⟩ simp set_option trace.Progress false -- cgit v1.2.3