From c652e97f7ab13164150331b4aa3f2e7ef11d24b9 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Tue, 25 Jul 2023 12:13:20 +0200
Subject: Add the possibility of using "_" as ident for progress

---
 backends/lean/Base/Progress/Progress.lean | 37 ++++++++++++++++++-------------
 backends/lean/Base/Utils.lean             |  7 ++++--
 2 files changed, 26 insertions(+), 18 deletions(-)

(limited to 'backends')

diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index c8f94e9e..c0ddc63d 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -24,7 +24,7 @@ inductive ProgressError
 deriving Inhabited
 
 def progressWith (fExpr : Expr) (th : TheoremOrLocal)
-  (keep : Option Name) (ids : Array Name) (splitPost : Bool)
+  (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool)
   (asmTac : TacticM Unit) : TacticM ProgressError := do
   /- Apply the theorem
      We try to match the theorem with the goal
@@ -90,13 +90,14 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
     -- For the conjunctions, we split according once to separate the equality `f ... = .ret ...`
     -- from the postcondition, if there is, then continue to split the postcondition if there
     -- are remaining ids.
-    let splitEqAndPost (k : Expr → Option Expr → List Name → TacticM ProgressError) : TacticM ProgressError := do
+    let splitEqAndPost (k : Expr → Option Expr → List (Option Name) → TacticM ProgressError) : TacticM ProgressError := do
       if ← isConj (← inferType h) then do
         let hName := (← h.fvarId!.getDecl).userName
-        let (optId, ids) := listTryPopHead ids
-        let optIds ← match optId with
-          | none => do pure (some (hName, ← mkFreshUserName `h))
-          | some id => do pure (some (hName, id))
+        let (optIds, ids) ← do
+          match ids with
+          | [] => do pure (some (hName, ← mkFreshUserName `h), [])
+          | none :: ids => do pure (some (hName, ← mkFreshUserName `h), ids)
+          | some id :: ids => do pure (some (hName, id), ids)
         splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids)
       else k h none ids
     -- Simplify the target by using the equality and some monad simplifications,
@@ -121,8 +122,8 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
         return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split")
       else return .Ok
     | some hPost => do
-      let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids : List Name) : TacticM ProgressError := do
-        match ids with
+      let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do
+        match ids0 with
         | [] =>
           /- We used all the user provided ids.
              Split the remaining conjunctions by using fresh ids if the user
@@ -133,9 +134,13 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
         | nid :: ids => do
           trace[Progress] "Splitting post: {hPost}"
           -- Split
+          let nid ← do
+            match nid with
+            | none => mkFreshUserName `h
+            | some nid => pure nid
           if ← isConj (← inferType hPost) then
             splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids)
-          else return (.Error m!"Too many ids provided ({nid :: ids}) not enough conjuncts to split in the postcondition")
+          else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition")
       let curPostId := (← hPost.fvarId!.getDecl).userName
       splitPostWithIds curPostId hPost ids
   match res with
@@ -168,7 +173,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do
 
 /- Helper: try to lookup a theorem and apply it, or continue with another tactic
    if it fails -/
-def tryLookupApply (keep : Option Name) (ids : Array Name) (splitPost : Bool)
+def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool)
   (asmTac : TacticM Unit) (fExpr : Expr)
   (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do
   let res ← do
@@ -191,7 +196,7 @@ def tryLookupApply (keep : Option Name) (ids : Array Name) (splitPost : Bool)
 
 -- The array of ids are identifiers to use when introducing fresh variables
 def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal)
-  (ids : Array Name) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Unit := do
+  (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) : TacticM Unit := do
   withMainContext do
   -- Retrieve the goal
   let mgoal ← Tactic.getMainGoal
@@ -258,7 +263,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
     -- Nothing worked: failed
     throwError "Progress failed"
 
-syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " ident,* " .."? " ⟩")?
+syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " (ident <|> "_"),* " .."? " ⟩")?
 
 def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
   let args := args.raw
@@ -296,7 +301,7 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
   let ids :=
     let args := (args.get! 2).getArgs
     let args := (args.get! 2).getSepArgs
-    args.map Syntax.getId
+    args.map (λ s => if s.isIdent then some s.getId else none)
   trace[Progress] "User-provided ids: {ids}"
   let splitPost : Bool :=
     let args := (args.get! 2).getArgs
@@ -307,7 +312,7 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
 elab "progress" args:progressArgs : tactic =>
   evalProgress args
 
-/-namespace Test
+/- namespace Test
   open Primitives Result
 
   set_option trace.Progress true
@@ -321,9 +326,9 @@ elab "progress" args:progressArgs : tactic =>
     (hmax : x.val + y.val ≤ Scalar.max ty) :
     ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
 --    progress keep as h with Scalar.add_spec as ⟨ z ⟩
-    progress keep as h as ⟨ z, h1 .. ⟩
+    progress keep as h as ⟨ x, h1 .. ⟩
     simp [*]
 
-end Test-/
+end Test -/
 
 end Progress
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 3b3d4729..66497a49 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -484,9 +484,12 @@ def listTryPopHead (ls : List α) : Option α × List α :=
    If `ids` is not empty, we use it to name the introduced variables. We
    transmit the stripped expression and the remaining ids to the continuation.
  -/
-partial def splitAllExistsTac [Inhabited α] (h : Expr) (ids : List Name) (k : Expr → List Name → TacticM α) : TacticM α := do
+partial def splitAllExistsTac [Inhabited α] (h : Expr) (ids : List (Option Name)) (k : Expr → List (Option Name) → TacticM α) : TacticM α := do
   try
-    let (optId, ids) := listTryPopHead ids
+    let (optId, ids) :=
+      match ids with
+      | [] => (none, [])
+      | x :: ids => (x, ids)
     splitExistsTac h optId (fun _ body => splitAllExistsTac body ids k)
   catch _ => k h ids
 
-- 
cgit v1.2.3