From 6166c410a4b3353377e640acbae9f56e877a9118 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 11 Jul 2023 15:23:49 +0200 Subject: Work on the progress tactic --- backends/lean/Base/Arith.lean | 1 + backends/lean/Base/Arith/Arith.lean | 42 ++++--- backends/lean/Base/Progress.lean | 1 + backends/lean/Base/Progress/Base.lean | 175 ++++++++++++++++++++++++++++++ backends/lean/Base/Progress/Progress.lean | 112 +++++++++++++++++++ backends/lean/Base/Utils.lean | 28 +++++ 6 files changed, 346 insertions(+), 13 deletions(-) create mode 100644 backends/lean/Base/Arith.lean create mode 100644 backends/lean/Base/Progress.lean create mode 100644 backends/lean/Base/Progress/Base.lean create mode 100644 backends/lean/Base/Progress/Progress.lean (limited to 'backends') diff --git a/backends/lean/Base/Arith.lean b/backends/lean/Base/Arith.lean new file mode 100644 index 00000000..fd5698c5 --- /dev/null +++ b/backends/lean/Base/Arith.lean @@ -0,0 +1 @@ +import Base.Arith.Arith diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean index 0ba73d18..ff628cf3 100644 --- a/backends/lean/Base/Arith/Arith.lean +++ b/backends/lean/Base/Arith/Arith.lean @@ -146,7 +146,7 @@ def collectInstances -- Similar to `collectInstances`, but explores all the local declarations in the -- main context. def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do - Lean.Elab.Tactic.withMainContext do + Tactic.withMainContext do -- Get the local context let ctx ← Lean.MonadLCtx.getLCtx -- Just a matter of precaution @@ -263,8 +263,8 @@ example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM Unit := do trace[Arith] "assumption on which to split: {h}" -- Retrieve the main goal - Lean.Elab.Tactic.withMainContext do - let goalType ← Lean.Elab.Tactic.getMainTarget + Tactic.withMainContext do + let goalType ← Tactic.getMainTarget let hDecl := (← getLCtx).get! h.fvarId! let hName := hDecl.userName -- Case disjunction @@ -316,7 +316,7 @@ def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM U trace[Arith] "new goals: {← Tactic.getUnsolvedGoals}" elab "split_disj " n:ident : tactic => do - Lean.Elab.Tactic.withMainContext do + Tactic.withMainContext do let decl ← Lean.Meta.getLocalDeclFromUserName n.getId let fvar := mkFVar decl.fvarId splitDisj fvar (fun _ => pure ()) (fun _ => pure ()) @@ -347,7 +347,7 @@ example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by TODO: we could create a PR for mathlib. -/ def intTacPreprocess : Tactic.TacticM Unit := do - Lean.Elab.Tactic.withMainContext do + Tactic.withMainContext do -- Lookup the instances of PropHasImp (this is how we detect assumptions -- of the proper shape), introduce assumptions in the context and split -- on those @@ -366,19 +366,31 @@ def intTacPreprocess : Tactic.TacticM Unit := do elab "int_tac_preprocess" : tactic => intTacPreprocess +def intTac : Tactic.TacticM Unit := do + Tactic.withMainContext do + Tactic.focus do + -- Preprocess - wondering if we should do this before or after splitting + -- the goal. I think before leads to a smaller proof term? + Tactic.allGoals intTacPreprocess + -- Split the conjunctions in the goal + Utils.repeatTac Utils.splitConjTarget + -- Call linarith + let linarith := + let cfg : Linarith.LinarithConfig := { + -- We do this with our custom preprocessing + splitNe := false + } + Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg + Tactic.allGoals linarith + +elab "int_tac" : tactic => + intTac + example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by int_tac_preprocess linarith linarith -syntax "int_tac" : tactic -macro_rules - | `(tactic| int_tac) => - `(tactic| - (repeat (apply And.intro)) <;> -- TODO: improve this - int_tac_preprocess <;> - linarith) - example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by int_tac @@ -386,6 +398,10 @@ example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by int_tac +-- Checking that things append correctly when there are several disjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by + int_tac + -- A tactic to solve linear arithmetic goals in the presence of scalars syntax "scalar_tac" : tactic macro_rules diff --git a/backends/lean/Base/Progress.lean b/backends/lean/Base/Progress.lean new file mode 100644 index 00000000..d812b896 --- /dev/null +++ b/backends/lean/Base/Progress.lean @@ -0,0 +1 @@ +import Base.Progress.Progress diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean new file mode 100644 index 00000000..3f44f46c --- /dev/null +++ b/backends/lean/Base/Progress/Base.lean @@ -0,0 +1,175 @@ +import Lean +import Base.Utils +import Base.Primitives + +namespace Progress + +open Lean Elab Term Meta +open Utils + +-- We can't define and use trace classes in the same file +initialize registerTraceClass `Progress + +-- Return the first conjunct if the expression is a conjunction, or the +-- expression itself otherwise. Also return the second conjunct if it is a +-- conjunction. +def getFirstConj (e : Expr) : MetaM (Expr × Option Expr) := do + e.withApp fun f args => + if f.isConstOf ``And ∧ args.size = 2 then pure (args.get! 0, some (args.get! 1)) + else pure (e, none) + +-- Destruct an equaliy and return the two sides +def destEq (e : Expr) : MetaM (Expr × Expr) := do + e.withApp fun f args => + if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2) + else throwError "Not an equality: {e}" + +-- Return the set of FVarIds in the expression +partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do + e.withApp fun body args => do + let hs := if body.isFVar then hs.insert body.fvarId! else hs + args.foldlM (fun hs arg => getFVarIds arg hs) hs + +/- # Progress tactic -/ + +structure PSpecDesc where + -- The universally quantified variables + fvars : Array Expr + -- The existentially quantified variables + evars : Array Expr + -- The function + fName : Name + -- The function arguments + fLevels : List Level + args : Array Expr + -- The universally quantified variables which appear in the function arguments + argsFVars : Array FVarId + -- The returned value + ret : Expr + -- The postcondition (if there is) + post : Option Expr + +section Methods + variable [MonadLiftT MetaM m] [MonadControlT MetaM m] [Monad m] [MonadOptions m] + variable [MonadTrace m] [MonadLiftT IO m] [MonadRef m] [AddMessageContext m] + variable [MonadError m] + variable {a : Type} + + /- Analyze a pspec theorem to decompose its arguments. + + PSpec theorems should be of the following shape: + ``` + ∀ x1 ... xn, H1 → ... Hn → ∃ y1 ... ym. f x1 ... xn = .ret ... ∧ Post1 ∧ ... ∧ Postk + ``` + + The continuation `k` receives the following inputs: + - universally quantified variables + - assumptions + - existentially quantified variables + - function name + - function arguments + - return + - postconditions + + TODO: generalize for when we do inductive proofs + -/ + partial + def withPSpec [Inhabited (m a)] [Nonempty (m a)] (th : Expr) (k : PSpecDesc → m a) + (sanityChecks : Bool := false) : + m a := do + trace[Progress] "Theorem: {th}" + -- Dive into the quantified variables and the assumptions + forallTelescope th fun fvars th => do + trace[Progress] "All argumens: {fvars}" + /- -- Filter the argumens which are not propositions + let rec getFirstPropIdx (i : Nat) : MetaM Nat := do + if i ≥ fargs.size then pure i + else do + let x := fargs.get! i + if ← Meta.isProp (← inferType x) then pure i + else getFirstPropIdx (i + 1) + let i ← getFirstPropIdx 0 + let fvars := fargs.extract 0 i + let hyps := fargs.extract i fargs.size + trace[Progress] "Quantified variables: {fvars}" + trace[Progress] "Assumptions: {hyps}" + -- Sanity check: all hypotheses are propositions (in particular, all the + -- quantified variables are at the beginning) + let hypsAreProp ← hyps.allM fun x => do Meta.isProp (← inferType x) + if ¬ hypsAreProp then + throwError "The theorem doesn't have the proper shape: all the quantified arguments should be at the beginning" + -/ + -- Dive into the existentials + existsTelescope th fun evars th => do + trace[Progress] "Existentials: {evars}" + -- Take the first conjunct + let (th, post) ← getFirstConj th + -- Destruct the equality + let (th, ret) ← destEq th + -- Destruct the application to get the name + th.withApp fun f args => do + if ¬ f.isConst then throwError "Not a constant: {f}" + -- Compute the set of universally quantified variables which appear in the function arguments + let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + -- Sanity check + if sanityChecks then + let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!)) + let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar) + if ¬ filtArgsFVars.isEmpty then + let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId) + throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}" + let argsFVars := fvars.map (fun x => x.fvarId!) + let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar) + -- Return + trace[Progress] "Function: {f.constName!}"; + let thDesc := { + fvars := fvars + evars := evars + fName := f.constName! + fLevels := f.constLevels! + args := args + argsFVars + ret := ret + post := post + } + k thDesc +end Methods + + +def getPSpecFunName (th : Expr) : MetaM Name := + withPSpec th (fun d => do pure d.fName) true + +structure PSpecAttr where + attr : AttributeImpl + ext : MapDeclarationExtension Name + deriving Inhabited + +/- The persistent map from function to pspec theorems. -/ +initialize pspecAttr : PSpecAttr ← do + let ext ← mkMapDeclarationExtension `pspecMap + let attrImpl := { + name := `pspec + descr := "Marks theorems to use with the `progress` tactic" + add := fun thName stx attrKind => do + Attribute.Builtin.ensureNoArgs stx + -- TODO: use the attribute kind + unless attrKind == AttributeKind.global do + throwError "invalid attribute 'pspec', must be global" + -- Lookup the theorem + let env ← getEnv + let thDecl := env.constants.find! thName + let fName ← MetaM.run' (getPSpecFunName thDecl.type) + trace[Progress] "Registering spec theorem for {fName}" + let env := ext.addEntry env (fName, thName) + setEnv env + pure () + } + registerBuiltinAttribute attrImpl + pure { attr := attrImpl, ext := ext } + +def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do + return (s.ext.getState (← getEnv)).find? name + --return s.ext.find? (← getEnv) name + + +end Progress diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean new file mode 100644 index 00000000..1b9ee55c --- /dev/null +++ b/backends/lean/Base/Progress/Progress.lean @@ -0,0 +1,112 @@ +import Lean +import Base.Arith +import Base.Progress.Base + +namespace Progress + +open Lean Elab Term Meta Tactic +open Utils + +namespace Test + open Primitives + + set_option trace.Progress true + + @[pspec] + theorem vec_index_test (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) : + ∃ x, v.index α i = .ret x := by + apply + sorry + + #eval pspecAttr.find? ``Primitives.Vec.index +end Test + +#check isDefEq +#check allGoals + +def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do + withMainContext do + -- Retrieve the goal + let mgoal ← Tactic.getMainGoal + let goalTy ← mgoal.getType + -- Dive into the goal to lookup the theorem + let (fName, fLevels, args) ← do + withPSpec goalTy fun desc => + -- TODO: check that no universally quantified variables in the arguments + pure (desc.fName, desc.fLevels, desc.args) + -- TODO: also try the assumptions + trace[Progress] "Function: {fName}" + -- TODO: use a list of theorems, and try them one by one? + let thName ← do + match ← pspecAttr.find? fName with + | none => throwError "Could not find a pspec theorem for {fName}" + | some thName => pure thName + trace[Progress] "Lookuped up: {thName}" + /- Apply the theorem + We try to match the theorem with the goal + In order to do so, we introduce meta-variables for all the parameters + (i.e., quantified variables and assumpions), and unify those with the goal. + Remark: we do not introduce meta-variables for the quantified variables + which don't appear in the function arguments (we want to let them + quantified). + We also make sure that all the meta variables which appear in the + function arguments have been instantiated + -/ + let env ← getEnv + let thDecl := env.constants.find! thName + let thTy := thDecl.type + -- TODO: the tactic fails if we uncomment withNewMCtxDepth + -- withNewMCtxDepth do + let (mvars, binders, thExBody) ← forallMetaTelescope thTy + -- Introduce the existentially quantified variables and the post-condition + -- in the context + let thBody ← + existsTelescope thExBody fun _evars thBody => do + let (thBody, _) ← destEq thBody + -- There shouldn't be any existential variables in thBody + pure thBody + -- Match the body with the target + let target := mkAppN (.const fName fLevels) args + trace[Progress] "mvars:\n{mvars.map Expr.mvarId!}" + trace[Progress] "thBody: {thBody}" + trace[Progress] "target: {target}" + let ok ← isDefEq thBody target + if ¬ ok then throwError "Could not unify the theorem with the target:\n- theorem: {thBody}\n- target: {target}" + postprocessAppMVars `progress mgoal mvars binders true true + Term.synthesizeSyntheticMVarsNoPostponing + let thBody ← instantiateMVars thBody + trace[Progress] "thBody (after instantiation): {thBody}" + -- Add the instantiated theorem to the assumptions (we apply it on the metavariables). + let th ← mkAppOptM thName (mvars.map some) + let asmName ← mkFreshUserName `h + let thTy ← inferType th + let thAsm ← Utils.addDecl asmName th thTy (asLet := false) + -- Update the set of goals + let curGoals ← getUnsolvedGoals + let newGoals := mvars.map Expr.mvarId! + let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned + trace[Progress] "new goals: {newGoals}" + setGoals newGoals.toList + allGoals asmTac + let newGoals ← getUnsolvedGoals + setGoals (newGoals ++ curGoals) + -- + pure () + +elab "progress" : tactic => do + progressLookupTheorem (firstTac [assumptionTac, Arith.intTac]) + +namespace Test + open Primitives + + set_option trace.Progress true + + @[pspec] + theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) : + ∃ x, v.index α i = .ret x := by + progress + tauto + +end Test + +end Progress diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 2ce63620..1351f3d4 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -1,4 +1,5 @@ import Lean +import Mathlib.Tactic.Core namespace Utils @@ -211,4 +212,31 @@ example : Nat := by example (x : Bool) : Nat := by cases x <;> custom_let x := 3 <;> apply x +-- Repeatedly apply a tactic +partial def repeatTac (tac : Tactic.TacticM Unit) : Tactic.TacticM Unit := do + try + tac + Tactic.allGoals (Tactic.focus (repeatTac tac)) + -- TODO: does this restore the state? + catch _ => pure () + +def firstTac (tacl : List (Tactic.TacticM Unit)) : Tactic.TacticM Unit := do + match tacl with + | [] => pure () + | tac :: tacl => + try tac + catch _ => firstTac tacl + +-- Split the goal if it is a conjunction +def splitConjTarget : Tactic.TacticM Unit := do + Tactic.withMainContext do + let and_intro := Expr.const ``And.intro [] + let mvarIds' ← _root_.Lean.MVarId.apply (← Tactic.getMainGoal) and_intro + Term.synthesizeSyntheticMVarsNoPostponing + Tactic.replaceMainGoal mvarIds' + +-- Taken from Lean.Elab.Tactic.evalAssumption +def assumptionTac : Tactic.TacticM Unit := + Tactic.liftMetaTactic fun mvarId => do mvarId.assumption; pure [] + end Utils -- cgit v1.2.3