diff options
Diffstat (limited to 'backends/lean/Base/Progress')
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 37 |
1 files changed, 21 insertions, 16 deletions
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index c8f94e9e..c0ddc63d 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -24,7 +24,7 @@ inductive ProgressError deriving Inhabited def progressWith (fExpr : Expr) (th : TheoremOrLocal) - (keep : Option Name) (ids : Array Name) (splitPost : Bool) + (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM ProgressError := do /- Apply the theorem We try to match the theorem with the goal @@ -90,13 +90,14 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) -- For the conjunctions, we split according once to separate the equality `f ... = .ret ...` -- from the postcondition, if there is, then continue to split the postcondition if there -- are remaining ids. - let splitEqAndPost (k : Expr → Option Expr → List Name → TacticM ProgressError) : TacticM ProgressError := do + let splitEqAndPost (k : Expr → Option Expr → List (Option Name) → TacticM ProgressError) : TacticM ProgressError := do if ← isConj (← inferType h) then do let hName := (← h.fvarId!.getDecl).userName - let (optId, ids) := listTryPopHead ids - let optIds ← match optId with - | none => do pure (some (hName, ← mkFreshUserName `h)) - | some id => do pure (some (hName, id)) + let (optIds, ids) ← do + match ids with + | [] => do pure (some (hName, ← mkFreshUserName `h), []) + | none :: ids => do pure (some (hName, ← mkFreshUserName `h), ids) + | some id :: ids => do pure (some (hName, id), ids) splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids) else k h none ids -- Simplify the target by using the equality and some monad simplifications, @@ -121,8 +122,8 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split") else return .Ok | some hPost => do - let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids : List Name) : TacticM ProgressError := do - match ids with + let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do + match ids0 with | [] => /- We used all the user provided ids. Split the remaining conjunctions by using fresh ids if the user @@ -133,9 +134,13 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) | nid :: ids => do trace[Progress] "Splitting post: {hPost}" -- Split + let nid ← do + match nid with + | none => mkFreshUserName `h + | some nid => pure nid if ← isConj (← inferType hPost) then splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids) - else return (.Error m!"Too many ids provided ({nid :: ids}) not enough conjuncts to split in the postcondition") + else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition") let curPostId := (← hPost.fvarId!.getDecl).userName splitPostWithIds curPostId hPost ids match res with @@ -168,7 +173,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do /- Helper: try to lookup a theorem and apply it, or continue with another tactic if it fails -/ -def tryLookupApply (keep : Option Name) (ids : Array Name) (splitPost : Bool) +def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) (fExpr : Expr) (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do let res ← do @@ -191,7 +196,7 @@ def tryLookupApply (keep : Option Name) (ids : Array Name) (splitPost : Bool) -- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) - (ids : Array Name) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Unit := do + (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Unit := do withMainContext do -- Retrieve the goal let mgoal ← Tactic.getMainGoal @@ -258,7 +263,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL -- Nothing worked: failed throwError "Progress failed" -syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " ident,* " .."? " ⟩")? +syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " (ident <|> "_"),* " .."? " ⟩")? def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do let args := args.raw @@ -296,7 +301,7 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do let ids := let args := (args.get! 2).getArgs let args := (args.get! 2).getSepArgs - args.map Syntax.getId + args.map (λ s => if s.isIdent then some s.getId else none) trace[Progress] "User-provided ids: {ids}" let splitPost : Bool := let args := (args.get! 2).getArgs @@ -307,7 +312,7 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do elab "progress" args:progressArgs : tactic => evalProgress args -/-namespace Test +/- namespace Test open Primitives Result set_option trace.Progress true @@ -321,9 +326,9 @@ elab "progress" args:progressArgs : tactic => (hmax : x.val + y.val ≤ Scalar.max ty) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by -- progress keep as h with Scalar.add_spec as ⟨ z ⟩ - progress keep as h as ⟨ z, h1 .. ⟩ + progress keep as h as ⟨ x, h1 .. ⟩ simp [*] -end Test-/ +end Test -/ end Progress |