summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2023-07-11 15:23:49 +0200
committerSon Ho2023-07-11 15:23:49 +0200
commit6166c410a4b3353377e640acbae9f56e877a9118 (patch)
tree648f114119502db5126604ffee4c011c2c2913e3 /backends/lean
parent7206b48a73d6204baea99f4f4675be2518a8f8c2 (diff)
Work on the progress tactic
Diffstat (limited to 'backends/lean')
-rw-r--r--backends/lean/Base/Arith.lean1
-rw-r--r--backends/lean/Base/Arith/Arith.lean42
-rw-r--r--backends/lean/Base/Progress.lean1
-rw-r--r--backends/lean/Base/Progress/Base.lean175
-rw-r--r--backends/lean/Base/Progress/Progress.lean112
-rw-r--r--backends/lean/Base/Utils.lean28
6 files changed, 346 insertions, 13 deletions
diff --git a/backends/lean/Base/Arith.lean b/backends/lean/Base/Arith.lean
new file mode 100644
index 00000000..fd5698c5
--- /dev/null
+++ b/backends/lean/Base/Arith.lean
@@ -0,0 +1 @@
+import Base.Arith.Arith
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean
index 0ba73d18..ff628cf3 100644
--- a/backends/lean/Base/Arith/Arith.lean
+++ b/backends/lean/Base/Arith/Arith.lean
@@ -146,7 +146,7 @@ def collectInstances
-- Similar to `collectInstances`, but explores all the local declarations in the
-- main context.
def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do
- Lean.Elab.Tactic.withMainContext do
+ Tactic.withMainContext do
-- Get the local context
let ctx ← Lean.MonadLCtx.getLCtx
-- Just a matter of precaution
@@ -263,8 +263,8 @@ example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by
def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
trace[Arith] "assumption on which to split: {h}"
-- Retrieve the main goal
- Lean.Elab.Tactic.withMainContext do
- let goalType ← Lean.Elab.Tactic.getMainTarget
+ Tactic.withMainContext do
+ let goalType ← Tactic.getMainTarget
let hDecl := (← getLCtx).get! h.fvarId!
let hName := hDecl.userName
-- Case disjunction
@@ -316,7 +316,7 @@ def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM U
trace[Arith] "new goals: {← Tactic.getUnsolvedGoals}"
elab "split_disj " n:ident : tactic => do
- Lean.Elab.Tactic.withMainContext do
+ Tactic.withMainContext do
let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
let fvar := mkFVar decl.fvarId
splitDisj fvar (fun _ => pure ()) (fun _ => pure ())
@@ -347,7 +347,7 @@ example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by
TODO: we could create a PR for mathlib.
-/
def intTacPreprocess : Tactic.TacticM Unit := do
- Lean.Elab.Tactic.withMainContext do
+ Tactic.withMainContext do
-- Lookup the instances of PropHasImp (this is how we detect assumptions
-- of the proper shape), introduce assumptions in the context and split
-- on those
@@ -366,19 +366,31 @@ def intTacPreprocess : Tactic.TacticM Unit := do
elab "int_tac_preprocess" : tactic =>
intTacPreprocess
+def intTac : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ Tactic.focus do
+ -- Preprocess - wondering if we should do this before or after splitting
+ -- the goal. I think before leads to a smaller proof term?
+ Tactic.allGoals intTacPreprocess
+ -- Split the conjunctions in the goal
+ Utils.repeatTac Utils.splitConjTarget
+ -- Call linarith
+ let linarith :=
+ let cfg : Linarith.LinarithConfig := {
+ -- We do this with our custom preprocessing
+ splitNe := false
+ }
+ Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg
+ Tactic.allGoals linarith
+
+elab "int_tac" : tactic =>
+ intTac
+
example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by
int_tac_preprocess
linarith
linarith
-syntax "int_tac" : tactic
-macro_rules
- | `(tactic| int_tac) =>
- `(tactic|
- (repeat (apply And.intro)) <;> -- TODO: improve this
- int_tac_preprocess <;>
- linarith)
-
example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by
int_tac
@@ -386,6 +398,10 @@ example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by
example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by
int_tac
+-- Checking that things append correctly when there are several disjunctions
+example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by
+ int_tac
+
-- A tactic to solve linear arithmetic goals in the presence of scalars
syntax "scalar_tac" : tactic
macro_rules
diff --git a/backends/lean/Base/Progress.lean b/backends/lean/Base/Progress.lean
new file mode 100644
index 00000000..d812b896
--- /dev/null
+++ b/backends/lean/Base/Progress.lean
@@ -0,0 +1 @@
+import Base.Progress.Progress
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
new file mode 100644
index 00000000..3f44f46c
--- /dev/null
+++ b/backends/lean/Base/Progress/Base.lean
@@ -0,0 +1,175 @@
+import Lean
+import Base.Utils
+import Base.Primitives
+
+namespace Progress
+
+open Lean Elab Term Meta
+open Utils
+
+-- We can't define and use trace classes in the same file
+initialize registerTraceClass `Progress
+
+-- Return the first conjunct if the expression is a conjunction, or the
+-- expression itself otherwise. Also return the second conjunct if it is a
+-- conjunction.
+def getFirstConj (e : Expr) : MetaM (Expr × Option Expr) := do
+ e.withApp fun f args =>
+ if f.isConstOf ``And ∧ args.size = 2 then pure (args.get! 0, some (args.get! 1))
+ else pure (e, none)
+
+-- Destruct an equaliy and return the two sides
+def destEq (e : Expr) : MetaM (Expr × Expr) := do
+ e.withApp fun f args =>
+ if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2)
+ else throwError "Not an equality: {e}"
+
+-- Return the set of FVarIds in the expression
+partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do
+ e.withApp fun body args => do
+ let hs := if body.isFVar then hs.insert body.fvarId! else hs
+ args.foldlM (fun hs arg => getFVarIds arg hs) hs
+
+/- # Progress tactic -/
+
+structure PSpecDesc where
+ -- The universally quantified variables
+ fvars : Array Expr
+ -- The existentially quantified variables
+ evars : Array Expr
+ -- The function
+ fName : Name
+ -- The function arguments
+ fLevels : List Level
+ args : Array Expr
+ -- The universally quantified variables which appear in the function arguments
+ argsFVars : Array FVarId
+ -- The returned value
+ ret : Expr
+ -- The postcondition (if there is)
+ post : Option Expr
+
+section Methods
+ variable [MonadLiftT MetaM m] [MonadControlT MetaM m] [Monad m] [MonadOptions m]
+ variable [MonadTrace m] [MonadLiftT IO m] [MonadRef m] [AddMessageContext m]
+ variable [MonadError m]
+ variable {a : Type}
+
+ /- Analyze a pspec theorem to decompose its arguments.
+
+ PSpec theorems should be of the following shape:
+ ```
+ ∀ x1 ... xn, H1 → ... Hn → ∃ y1 ... ym. f x1 ... xn = .ret ... ∧ Post1 ∧ ... ∧ Postk
+ ```
+
+ The continuation `k` receives the following inputs:
+ - universally quantified variables
+ - assumptions
+ - existentially quantified variables
+ - function name
+ - function arguments
+ - return
+ - postconditions
+
+ TODO: generalize for when we do inductive proofs
+ -/
+ partial
+ def withPSpec [Inhabited (m a)] [Nonempty (m a)] (th : Expr) (k : PSpecDesc → m a)
+ (sanityChecks : Bool := false) :
+ m a := do
+ trace[Progress] "Theorem: {th}"
+ -- Dive into the quantified variables and the assumptions
+ forallTelescope th fun fvars th => do
+ trace[Progress] "All argumens: {fvars}"
+ /- -- Filter the argumens which are not propositions
+ let rec getFirstPropIdx (i : Nat) : MetaM Nat := do
+ if i ≥ fargs.size then pure i
+ else do
+ let x := fargs.get! i
+ if ← Meta.isProp (← inferType x) then pure i
+ else getFirstPropIdx (i + 1)
+ let i ← getFirstPropIdx 0
+ let fvars := fargs.extract 0 i
+ let hyps := fargs.extract i fargs.size
+ trace[Progress] "Quantified variables: {fvars}"
+ trace[Progress] "Assumptions: {hyps}"
+ -- Sanity check: all hypotheses are propositions (in particular, all the
+ -- quantified variables are at the beginning)
+ let hypsAreProp ← hyps.allM fun x => do Meta.isProp (← inferType x)
+ if ¬ hypsAreProp then
+ throwError "The theorem doesn't have the proper shape: all the quantified arguments should be at the beginning"
+ -/
+ -- Dive into the existentials
+ existsTelescope th fun evars th => do
+ trace[Progress] "Existentials: {evars}"
+ -- Take the first conjunct
+ let (th, post) ← getFirstConj th
+ -- Destruct the equality
+ let (th, ret) ← destEq th
+ -- Destruct the application to get the name
+ th.withApp fun f args => do
+ if ¬ f.isConst then throwError "Not a constant: {f}"
+ -- Compute the set of universally quantified variables which appear in the function arguments
+ let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
+ -- Sanity check
+ if sanityChecks then
+ let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!))
+ let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar)
+ if ¬ filtArgsFVars.isEmpty then
+ let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId)
+ throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}"
+ let argsFVars := fvars.map (fun x => x.fvarId!)
+ let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar)
+ -- Return
+ trace[Progress] "Function: {f.constName!}";
+ let thDesc := {
+ fvars := fvars
+ evars := evars
+ fName := f.constName!
+ fLevels := f.constLevels!
+ args := args
+ argsFVars
+ ret := ret
+ post := post
+ }
+ k thDesc
+end Methods
+
+
+def getPSpecFunName (th : Expr) : MetaM Name :=
+ withPSpec th (fun d => do pure d.fName) true
+
+structure PSpecAttr where
+ attr : AttributeImpl
+ ext : MapDeclarationExtension Name
+ deriving Inhabited
+
+/- The persistent map from function to pspec theorems. -/
+initialize pspecAttr : PSpecAttr ← do
+ let ext ← mkMapDeclarationExtension `pspecMap
+ let attrImpl := {
+ name := `pspec
+ descr := "Marks theorems to use with the `progress` tactic"
+ add := fun thName stx attrKind => do
+ Attribute.Builtin.ensureNoArgs stx
+ -- TODO: use the attribute kind
+ unless attrKind == AttributeKind.global do
+ throwError "invalid attribute 'pspec', must be global"
+ -- Lookup the theorem
+ let env ← getEnv
+ let thDecl := env.constants.find! thName
+ let fName ← MetaM.run' (getPSpecFunName thDecl.type)
+ trace[Progress] "Registering spec theorem for {fName}"
+ let env := ext.addEntry env (fName, thName)
+ setEnv env
+ pure ()
+ }
+ registerBuiltinAttribute attrImpl
+ pure { attr := attrImpl, ext := ext }
+
+def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do
+ return (s.ext.getState (← getEnv)).find? name
+ --return s.ext.find? (← getEnv) name
+
+
+end Progress
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
new file mode 100644
index 00000000..1b9ee55c
--- /dev/null
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -0,0 +1,112 @@
+import Lean
+import Base.Arith
+import Base.Progress.Base
+
+namespace Progress
+
+open Lean Elab Term Meta Tactic
+open Utils
+
+namespace Test
+ open Primitives
+
+ set_option trace.Progress true
+
+ @[pspec]
+ theorem vec_index_test (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) :
+ ∃ x, v.index α i = .ret x := by
+ apply
+ sorry
+
+ #eval pspecAttr.find? ``Primitives.Vec.index
+end Test
+
+#check isDefEq
+#check allGoals
+
+def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do
+ withMainContext do
+ -- Retrieve the goal
+ let mgoal ← Tactic.getMainGoal
+ let goalTy ← mgoal.getType
+ -- Dive into the goal to lookup the theorem
+ let (fName, fLevels, args) ← do
+ withPSpec goalTy fun desc =>
+ -- TODO: check that no universally quantified variables in the arguments
+ pure (desc.fName, desc.fLevels, desc.args)
+ -- TODO: also try the assumptions
+ trace[Progress] "Function: {fName}"
+ -- TODO: use a list of theorems, and try them one by one?
+ let thName ← do
+ match ← pspecAttr.find? fName with
+ | none => throwError "Could not find a pspec theorem for {fName}"
+ | some thName => pure thName
+ trace[Progress] "Lookuped up: {thName}"
+ /- Apply the theorem
+ We try to match the theorem with the goal
+ In order to do so, we introduce meta-variables for all the parameters
+ (i.e., quantified variables and assumpions), and unify those with the goal.
+ Remark: we do not introduce meta-variables for the quantified variables
+ which don't appear in the function arguments (we want to let them
+ quantified).
+ We also make sure that all the meta variables which appear in the
+ function arguments have been instantiated
+ -/
+ let env ← getEnv
+ let thDecl := env.constants.find! thName
+ let thTy := thDecl.type
+ -- TODO: the tactic fails if we uncomment withNewMCtxDepth
+ -- withNewMCtxDepth do
+ let (mvars, binders, thExBody) ← forallMetaTelescope thTy
+ -- Introduce the existentially quantified variables and the post-condition
+ -- in the context
+ let thBody ←
+ existsTelescope thExBody fun _evars thBody => do
+ let (thBody, _) ← destEq thBody
+ -- There shouldn't be any existential variables in thBody
+ pure thBody
+ -- Match the body with the target
+ let target := mkAppN (.const fName fLevels) args
+ trace[Progress] "mvars:\n{mvars.map Expr.mvarId!}"
+ trace[Progress] "thBody: {thBody}"
+ trace[Progress] "target: {target}"
+ let ok ← isDefEq thBody target
+ if ¬ ok then throwError "Could not unify the theorem with the target:\n- theorem: {thBody}\n- target: {target}"
+ postprocessAppMVars `progress mgoal mvars binders true true
+ Term.synthesizeSyntheticMVarsNoPostponing
+ let thBody ← instantiateMVars thBody
+ trace[Progress] "thBody (after instantiation): {thBody}"
+ -- Add the instantiated theorem to the assumptions (we apply it on the metavariables).
+ let th ← mkAppOptM thName (mvars.map some)
+ let asmName ← mkFreshUserName `h
+ let thTy ← inferType th
+ let thAsm ← Utils.addDecl asmName th thTy (asLet := false)
+ -- Update the set of goals
+ let curGoals ← getUnsolvedGoals
+ let newGoals := mvars.map Expr.mvarId!
+ let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned
+ trace[Progress] "new goals: {newGoals}"
+ setGoals newGoals.toList
+ allGoals asmTac
+ let newGoals ← getUnsolvedGoals
+ setGoals (newGoals ++ curGoals)
+ --
+ pure ()
+
+elab "progress" : tactic => do
+ progressLookupTheorem (firstTac [assumptionTac, Arith.intTac])
+
+namespace Test
+ open Primitives
+
+ set_option trace.Progress true
+
+ @[pspec]
+ theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) :
+ ∃ x, v.index α i = .ret x := by
+ progress
+ tauto
+
+end Test
+
+end Progress
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 2ce63620..1351f3d4 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -1,4 +1,5 @@
import Lean
+import Mathlib.Tactic.Core
namespace Utils
@@ -211,4 +212,31 @@ example : Nat := by
example (x : Bool) : Nat := by
cases x <;> custom_let x := 3 <;> apply x
+-- Repeatedly apply a tactic
+partial def repeatTac (tac : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
+ try
+ tac
+ Tactic.allGoals (Tactic.focus (repeatTac tac))
+ -- TODO: does this restore the state?
+ catch _ => pure ()
+
+def firstTac (tacl : List (Tactic.TacticM Unit)) : Tactic.TacticM Unit := do
+ match tacl with
+ | [] => pure ()
+ | tac :: tacl =>
+ try tac
+ catch _ => firstTac tacl
+
+-- Split the goal if it is a conjunction
+def splitConjTarget : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ let and_intro := Expr.const ``And.intro []
+ let mvarIds' ← _root_.Lean.MVarId.apply (← Tactic.getMainGoal) and_intro
+ Term.synthesizeSyntheticMVarsNoPostponing
+ Tactic.replaceMainGoal mvarIds'
+
+-- Taken from Lean.Elab.Tactic.evalAssumption
+def assumptionTac : Tactic.TacticM Unit :=
+ Tactic.liftMetaTactic fun mvarId => do mvarId.assumption; pure []
+
end Utils