diff options
Diffstat (limited to 'backends')
-rw-r--r-- | backends/lean/Base/IList/IList.lean | 67 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 55 |
2 files changed, 101 insertions, 21 deletions
diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean index ddb10236..1773e593 100644 --- a/backends/lean/Base/IList/IList.lean +++ b/backends/lean/Base/IList/IList.lean @@ -175,6 +175,73 @@ theorem idrop_eq_nil_of_le (hineq : ls.len ≤ i) : idrop i ls = [] := by apply hi linarith +@[simp] +theorem index_ne + {α : Type u} [Inhabited α] (l: List α) (i: ℤ) (j: ℤ) (x: α) : + 0 ≤ i → i < l.len → 0 ≤ j → j < l.len → j ≠ i → + (l.update i x).index j = l.index j + := + λ _ _ _ _ _ => match l with + | [] => by simp at * + | hd :: tl => + if h: i = 0 then + have : j ≠ 0 := by scalar_tac + by simp [*] + else if h : j = 0 then + have : i ≠ 0 := by scalar_tac + by simp [*] + else + by + simp [*] + simp at * + apply index_ne <;> scalar_tac + +@[simp] +theorem index_eq + {α : Type u} [Inhabited α] (l: List α) (i: ℤ) (x: α) : + 0 ≤ i → i < l.len → + (l.update i x).index i = x + := + fun _ _ => match l with + | [] => by simp at *; exfalso; scalar_tac -- TODO: exfalso needed. Son FIXME + | hd :: tl => + if h: i = 0 then + by + simp [*] + else + by + simp [*] + simp at * + apply index_eq <;> scalar_tac + +def allP {α : Type u} (l : List α) (p: α → Prop) : Prop := + foldr (fun a r => p a ∧ r) True l + +@[simp] +theorem allP_nil {α : Type u} (p: α → Prop) : allP [] p := + by simp [allP, foldr] + +@[simp] +theorem allP_cons {α : Type u} (hd: α) (tl : List α) (p: α → Prop) : + allP (hd :: tl) p ↔ p hd ∧ allP tl p + := by simp [allP, foldr] + +def pairwise_rel + {α : Type u} (rel : α → α → Prop) (l: List α) : Prop + := match l with + | [] => True + | hd :: tl => allP tl (rel hd) ∧ pairwise_rel rel tl + +@[simp] +theorem pairwise_rel_nil {α : Type u} (rel : α → α → Prop) : + pairwise_rel rel [] + := by simp [pairwise_rel] + +@[simp] +theorem pairwise_rel_cons {α : Type u} (rel : α → α → Prop) (hd: α) (tl: List α) : + pairwise_rel rel (hd :: tl) ↔ allP tl (rel hd) ∧ pairwise_rel rel tl + := by simp [pairwise_rel] + end Lemmas end List diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index 974a6364..1f734415 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -39,7 +39,7 @@ inductive ProgressError | Error (msg : MessageData) deriving Inhabited -def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) +def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (keep : Option Name) (ids : Array Name) (asmTac : TacticM Unit) : TacticM ProgressError := do /- Apply the theorem We try to match the theorem with the goal @@ -88,7 +88,7 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) match th with | .Theorem thName => mkAppOptM thName (mvars.map some) | .Local decl => mkAppOptM' (mkFVar decl.fvarId) (mvars.map some) - let asmName ← mkFreshUserName `h + let asmName ← do match keep with | none => mkFreshUserName `h | some n => do pure n let thTy ← inferType th let thAsm ← Utils.addDeclTac asmName th thTy (asLet := false) withMainContext do -- The context changed - TODO: remove once addDeclTac is updated @@ -109,7 +109,9 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) if ← isConj (← inferType h) then do let hName := (← h.fvarId!.getDecl).userName let (optId, ids) := listTryPopHead ids - let optIds := match optId with | none => none | some id => some (hName, id) + let optIds ← match optId with + | none => do pure (some (hName, ← mkFreshUserName `h)) + | some id => do pure (some (hName, id)) splitConjTac h optIds (fun hEq hPost => k hEq (some hPost) ids) else k h none ids -- Simplify the target by using the equality and some monad simplifications, @@ -118,9 +120,12 @@ def progressWith (fnExpr : Expr) (th : TheoremOrLocal) (ids : Array Name) trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] [hEq.fvarId!] (.targets #[] true) - -- Clear the equality - let mgoal ← getMainGoal - let mgoal ← mgoal.tryClearMany #[hEq.fvarId!] + -- Clear the equality, unless the user requests not to do so + let mgoal ← do + if keep.isSome then getMainGoal + else do + let mgoal ← getMainGoal + mgoal.tryClearMany #[hEq.fvarId!] setGoals (mgoal :: (← getUnsolvedGoals)) trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}" -- Continue splitting following the ids provided by the user @@ -170,7 +175,7 @@ def getFirstArg (args : Array Expr) : Option Expr := do /- Helper: try to lookup a theorem and apply it, or continue with another tactic if it fails -/ -def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) +def tryLookupApply (keep : Option Name) (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do let res ← do match th with @@ -182,7 +187,7 @@ def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) -- Apply the theorem let res ← do try - let res ← progressWith fnExpr th ids asmTac + let res ← progressWith fnExpr th keep ids asmTac pure (some res) catch _ => none match res with @@ -191,7 +196,7 @@ def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr) | none => x -- The array of ids are identifiers to use when introducing fresh variables -def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do +def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do withMainContext do -- Retrieve the goal let mgoal ← Tactic.getMainGoal @@ -209,7 +214,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na -- Otherwise, lookup one. match withTh with | some th => do - match ← progressWith fnExpr th ids asmTac with + match ← progressWith fnExpr th keep ids asmTac with | .Ok => return () | .Error msg => throwError msg | none => @@ -218,7 +223,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na let decls ← ctx.getDecls for decl in decls.reverse do trace[Progress] "Trying assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue + let res ← do try progressWith fnExpr (.Local decl) keep ids asmTac catch _ => continue match res with | .Ok => return () | .Error msg => throwError msg @@ -228,7 +233,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na let pspec ← do let thName ← pspecAttr.find? fName pure (thName.map fun th => .Theorem th) - tryLookupApply ids asmTac fnExpr "pspec theorem" pspec do + tryLookupApply keep ids asmTac fnExpr "pspec theorem" pspec do -- It failed: try to lookup a *class* expr spec theorem (those are more -- specific than class spec theorems) let pspecClassExpr ← do @@ -237,7 +242,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na | some arg => do let thName ← pspecClassExprAttr.find? fName arg pure (thName.map fun th => .Theorem th) - tryLookupApply ids asmTac fnExpr "pspec class expr theorem" pspecClassExpr do + tryLookupApply keep ids asmTac fnExpr "pspec class expr theorem" pspecClassExpr do -- It failed: try to lookup a *class* spec theorem let pspecClass ← do match ← getFirstArgAppName args with @@ -245,7 +250,7 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na | some argName => do let thName ← pspecClassAttr.find? fName argName pure (thName.map fun th => .Theorem th) - tryLookupApply ids asmTac fnExpr "pspec class theorem" pspecClass do + tryLookupApply keep ids asmTac fnExpr "pspec class theorem" pspecClass do -- Try a recursive call - we try the assumptions of kind "auxDecl" let ctx ← Lean.MonadLCtx.getLCtx let decls ← ctx.getAllDecls @@ -253,21 +258,29 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na | .default | .implDetail => false | .auxDecl => true) for decl in decls.reverse do trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue + let res ← do try progressWith fnExpr (.Local decl) keep ids asmTac catch _ => continue match res with | .Ok => return () | .Error msg => throwError msg -- Nothing worked: failed throwError "Progress failed" -syntax progressArgs := ("with" ident)? ("as" " ⟨ " (ident)+ " ⟩")? +syntax progressArgs := ("keep" ("as" (ident))?)? ("with" ident)? ("as" " ⟨ " (ident)+ " ⟩")? def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do let args := args.raw -- Process the arguments to retrieve the identifiers to use trace[Progress] "Progress arguments: {args}" let args := args.getArgs - let withArg := (args.get! 0).getArgs + let keep : Option Name ← do + let args := (args.get! 0).getArgs + if args.size > 0 then do + let args := (args.get! 1).getArgs + if args.size > 0 then pure (some (args.get! 1).getId) + else do pure (some (← mkFreshUserName `h)) + else pure none + trace[Progress] "Keep: {keep}" + let withArg := (args.get! 1).getArgs let withArg ← do if withArg.size > 0 then let id := withArg.get! 1 @@ -287,11 +300,11 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do | id :: _ => pure (some (.Theorem id)) else pure none - let args := (args.get! 1).getArgs + let args := (args.get! 2).getArgs let args := (args.get! 2).getArgs let ids := args.map Syntax.getId trace[Progress] "User-provided ids: {ids}" - progressAsmsOrLookupTheorem withArg ids (firstTac [assumptionTac, Arith.scalarTac]) + progressAsmsOrLookupTheorem keep withArg ids (firstTac [assumptionTac, Arith.scalarTac]) elab "progress" args:progressArgs : tactic => evalProgress args @@ -306,11 +319,11 @@ namespace Test #eval showStoredPSpec #eval showStoredPSpecClass - theorem Scalar.add_spec {ty} {x y : Scalar ty} + theorem Scalar.add_spec1 {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val + y.val) (hmax : x.val + y.val ≤ Scalar.max ty) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by - progress + progress keep as h with Scalar.add_spec as ⟨ z ⟩ simp [*] /- |