diff options
Diffstat (limited to 'backends/lean/Base/Progress')
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 55 |
1 files changed, 34 insertions, 21 deletions
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index 974a6364..1f734415 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -39,7 +39,7 @@ inductive ProgressError | Error (msg : MessageData) deriving Inhabited -def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) +def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (keep : Option Name) (ids : Array Name) (asmTac : TacticM Unit) : TacticM ProgressError := do /- Apply the theorem We try to match the theorem with the goal @@ -88,7 +88,7 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) match th with | .Theorem thName => mkAppOptM thName (mvars.map some) | .Local decl => mkAppOptM' (mkFVar decl.fvarId) (mvars.map some) - let asmName ← mkFreshUserName `h + let asmName ← do match keep with | none => mkFreshUserName `h | some n => do pure n let thTy ← inferType th let thAsm ← Utils.addDeclTac asmName th thTy (asLet := false) withMainContext do -- The context changed - TODO: remove once addDeclTac is updated @@ -109,7 +109,9 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) if ← isConj (← inferType h) then do let hName := (← h.fvarId!.getDecl).userName let (optId, ids) := listTryPopHead ids - let optIds := match optId with | none => none | some id => some (hName, id) + let optIds ← match optId with + | none => do pure (some (hName, ← mkFreshUserName `h)) + | some id => do pure (some (hName, id)) splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids) else k h none ids -- Simplify the target by using the equality and some monad simplifications, @@ -118,9 +120,12 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] [hEq.fvarId!] (.targets #[] true) - -- Clear the equality - let mgoal ← getMainGoal - let mgoal ← mgoal.tryClearMany #[hEq.fvarId!] + -- Clear the equality, unless the user requests not to do so + let mgoal ← do + if keep.isSome then getMainGoal + else do + let mgoal ← getMainGoal + mgoal.tryClearMany #[hEq.fvarId!] setGoals (mgoal :: (← getUnsolvedGoals)) trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}" -- Continue splitting following the ids provided by the user @@ -170,7 +175,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do /- Helper: try to lookup a theorem and apply it, or continue with another tactic if it fails -/ -def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) +def tryLookupApply (keep : Option Name) (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do let res ← do match th with @@ -182,7 +187,7 @@ def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) -- Apply the theorem let res ← do try - let res ← progressWith fnExpr th ids asmTac + let res ← progressWith fnExpr th keep ids asmTac pure (some res) catch _ => none match res with @@ -191,7 +196,7 @@ def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) | none => x -- The array of ids are identifiers to use when introducing fresh variables -def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do +def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do withMainContext do -- Retrieve the goal let mgoal ← Tactic.getMainGoal @@ -209,7 +214,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na -- Otherwise, lookup one. match withTh with | some th => do - match ← progressWith fnExpr th ids asmTac with + match ← progressWith fnExpr th keep ids asmTac with | .Ok => return () | .Error msg => throwError msg | none => @@ -218,7 +223,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na let decls ← ctx.getDecls for decl in decls.reverse do trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue + let res ← do try progressWith fnExpr (.Local decl) keep ids asmTac catch _ => continue match res with | .Ok => return () | .Error msg => throwError msg @@ -228,7 +233,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na let pspec ← do let thName ← pspecAttr.find? fName pure (thName.map fun th => .Theorem th) - tryLookupApply ids asmTac fnExpr "pspec theorem" pspec do + tryLookupApply keep ids asmTac fnExpr "pspec theorem" pspec do -- It failed: try to lookup a *class* expr spec theorem (those are more -- specific than class spec theorems) let pspecClassExpr ← do @@ -237,7 +242,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na | some arg => do let thName ← pspecClassExprAttr.find? fName arg pure (thName.map fun th => .Theorem th) - tryLookupApply ids asmTac fnExpr "pspec class expr theorem" pspecClassExpr do + tryLookupApply keep ids asmTac fnExpr "pspec class expr theorem" pspecClassExpr do -- It failed: try to lookup a *class* spec theorem let pspecClass ← do match ← getFirstArgAppName args with @@ -245,7 +250,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na | some argName => do let thName ← pspecClassAttr.find? fName argName pure (thName.map fun th => .Theorem th) - tryLookupApply ids asmTac fnExpr "pspec class theorem" pspecClass do + tryLookupApply keep ids asmTac fnExpr "pspec class theorem" pspecClass do -- Try a recursive call - we try the assumptions of kind "auxDecl" let ctx ← Lean.MonadLCtx.getLCtx let decls ← ctx.getAllDecls @@ -253,21 +258,29 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na | .default | .implDetail => false | .auxDecl => true) for decl in decls.reverse do trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue + let res ← do try progressWith fnExpr (.Local decl) keep ids asmTac catch _ => continue match res with | .Ok => return () | .Error msg => throwError msg -- Nothing worked: failed throwError "Progress failed" -syntax progressArgs := ("with" ident)? ("as" " ⟨ " (ident)+ " ⟩")? +syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " (ident)+ " ⟩")? def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do let args := args.raw -- Process the arguments to retrieve the identifiers to use trace[Progress] "Progress arguments: {args}" let args := args.getArgs - let withArg := (args.get! 0).getArgs + let keep : Option Name ← do + let args := (args.get! 0).getArgs + if args.size > 0 then do + let args := (args.get! 1).getArgs + if args.size > 0 then pure (some (args.get! 1).getId) + else do pure (some (← mkFreshUserName `h)) + else pure none + trace[Progress] "Keep: {keep}" + let withArg := (args.get! 1).getArgs let withArg ← do if withArg.size > 0 then let id := withArg.get! 1 @@ -287,11 +300,11 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do | id :: _ => pure (some (.Theorem id)) else pure none - let args := (args.get! 1).getArgs + let args := (args.get! 2).getArgs let args := (args.get! 2).getArgs let ids := args.map Syntax.getId trace[Progress] "User-provided ids: {ids}" - progressAsmsOrLookupTheorem withArg ids (firstTac [assumptionTac, Arith.scalarTac]) + progressAsmsOrLookupTheorem keep withArg ids (firstTac [assumptionTac, Arith.scalarTac]) elab "progress" args:progressArgs : tactic => evalProgress args @@ -306,11 +319,11 @@ namespace Test #eval showStoredPSpec #eval showStoredPSpecClass - theorem Scalar.add_spec {ty} {x y : Scalar ty} + theorem Scalar.add_spec1 {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val + y.val) (hmax : x.val + y.val ≤ Scalar.max ty) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by - progress + progress keep as h with Scalar.add_spec as ⟨ z ⟩ simp [*] /- |