summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Progress
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base/Progress')
-rw-r--r--backends/lean/Base/Progress/Base.lean175
-rw-r--r--backends/lean/Base/Progress/Progress.lean112
2 files changed, 287 insertions, 0 deletions
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
new file mode 100644
index 00000000..3f44f46c
--- /dev/null
+++ b/backends/lean/Base/Progress/Base.lean
@@ -0,0 +1,175 @@
+import Lean
+import Base.Utils
+import Base.Primitives
+
+namespace Progress
+
+open Lean Elab Term Meta
+open Utils
+
+-- We can't define and use trace classes in the same file
+initialize registerTraceClass `Progress
+
+-- Return the first conjunct if the expression is a conjunction, or the
+-- expression itself otherwise. Also return the second conjunct if it is a
+-- conjunction.
+def getFirstConj (e : Expr) : MetaM (Expr × Option Expr) := do
+ e.withApp fun f args =>
+ if f.isConstOf ``And ∧ args.size = 2 then pure (args.get! 0, some (args.get! 1))
+ else pure (e, none)
+
+-- Destruct an equaliy and return the two sides
+def destEq (e : Expr) : MetaM (Expr × Expr) := do
+ e.withApp fun f args =>
+ if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2)
+ else throwError "Not an equality: {e}"
+
+-- Return the set of FVarIds in the expression
+partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do
+ e.withApp fun body args => do
+ let hs := if body.isFVar then hs.insert body.fvarId! else hs
+ args.foldlM (fun hs arg => getFVarIds arg hs) hs
+
+/- # Progress tactic -/
+
+structure PSpecDesc where
+ -- The universally quantified variables
+ fvars : Array Expr
+ -- The existentially quantified variables
+ evars : Array Expr
+ -- The function
+ fName : Name
+ -- The function arguments
+ fLevels : List Level
+ args : Array Expr
+ -- The universally quantified variables which appear in the function arguments
+ argsFVars : Array FVarId
+ -- The returned value
+ ret : Expr
+ -- The postcondition (if there is)
+ post : Option Expr
+
+section Methods
+ variable [MonadLiftT MetaM m] [MonadControlT MetaM m] [Monad m] [MonadOptions m]
+ variable [MonadTrace m] [MonadLiftT IO m] [MonadRef m] [AddMessageContext m]
+ variable [MonadError m]
+ variable {a : Type}
+
+ /- Analyze a pspec theorem to decompose its arguments.
+
+ PSpec theorems should be of the following shape:
+ ```
+ ∀ x1 ... xn, H1 → ... Hn → ∃ y1 ... ym. f x1 ... xn = .ret ... ∧ Post1 ∧ ... ∧ Postk
+ ```
+
+ The continuation `k` receives the following inputs:
+ - universally quantified variables
+ - assumptions
+ - existentially quantified variables
+ - function name
+ - function arguments
+ - return
+ - postconditions
+
+ TODO: generalize for when we do inductive proofs
+ -/
+ partial
+ def withPSpec [Inhabited (m a)] [Nonempty (m a)] (th : Expr) (k : PSpecDesc → m a)
+ (sanityChecks : Bool := false) :
+ m a := do
+ trace[Progress] "Theorem: {th}"
+ -- Dive into the quantified variables and the assumptions
+ forallTelescope th fun fvars th => do
+ trace[Progress] "All argumens: {fvars}"
+ /- -- Filter the argumens which are not propositions
+ let rec getFirstPropIdx (i : Nat) : MetaM Nat := do
+ if i ≥ fargs.size then pure i
+ else do
+ let x := fargs.get! i
+ if ← Meta.isProp (← inferType x) then pure i
+ else getFirstPropIdx (i + 1)
+ let i ← getFirstPropIdx 0
+ let fvars := fargs.extract 0 i
+ let hyps := fargs.extract i fargs.size
+ trace[Progress] "Quantified variables: {fvars}"
+ trace[Progress] "Assumptions: {hyps}"
+ -- Sanity check: all hypotheses are propositions (in particular, all the
+ -- quantified variables are at the beginning)
+ let hypsAreProp ← hyps.allM fun x => do Meta.isProp (← inferType x)
+ if ¬ hypsAreProp then
+ throwError "The theorem doesn't have the proper shape: all the quantified arguments should be at the beginning"
+ -/
+ -- Dive into the existentials
+ existsTelescope th fun evars th => do
+ trace[Progress] "Existentials: {evars}"
+ -- Take the first conjunct
+ let (th, post) ← getFirstConj th
+ -- Destruct the equality
+ let (th, ret) ← destEq th
+ -- Destruct the application to get the name
+ th.withApp fun f args => do
+ if ¬ f.isConst then throwError "Not a constant: {f}"
+ -- Compute the set of universally quantified variables which appear in the function arguments
+ let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
+ -- Sanity check
+ if sanityChecks then
+ let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!))
+ let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar)
+ if ¬ filtArgsFVars.isEmpty then
+ let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId)
+ throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}"
+ let argsFVars := fvars.map (fun x => x.fvarId!)
+ let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar)
+ -- Return
+ trace[Progress] "Function: {f.constName!}";
+ let thDesc := {
+ fvars := fvars
+ evars := evars
+ fName := f.constName!
+ fLevels := f.constLevels!
+ args := args
+ argsFVars
+ ret := ret
+ post := post
+ }
+ k thDesc
+end Methods
+
+
+def getPSpecFunName (th : Expr) : MetaM Name :=
+ withPSpec th (fun d => do pure d.fName) true
+
+structure PSpecAttr where
+ attr : AttributeImpl
+ ext : MapDeclarationExtension Name
+ deriving Inhabited
+
+/- The persistent map from function to pspec theorems. -/
+initialize pspecAttr : PSpecAttr ← do
+ let ext ← mkMapDeclarationExtension `pspecMap
+ let attrImpl := {
+ name := `pspec
+ descr := "Marks theorems to use with the `progress` tactic"
+ add := fun thName stx attrKind => do
+ Attribute.Builtin.ensureNoArgs stx
+ -- TODO: use the attribute kind
+ unless attrKind == AttributeKind.global do
+ throwError "invalid attribute 'pspec', must be global"
+ -- Lookup the theorem
+ let env ← getEnv
+ let thDecl := env.constants.find! thName
+ let fName ← MetaM.run' (getPSpecFunName thDecl.type)
+ trace[Progress] "Registering spec theorem for {fName}"
+ let env := ext.addEntry env (fName, thName)
+ setEnv env
+ pure ()
+ }
+ registerBuiltinAttribute attrImpl
+ pure { attr := attrImpl, ext := ext }
+
+def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do
+ return (s.ext.getState (← getEnv)).find? name
+ --return s.ext.find? (← getEnv) name
+
+
+end Progress
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
new file mode 100644
index 00000000..1b9ee55c
--- /dev/null
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -0,0 +1,112 @@
+import Lean
+import Base.Arith
+import Base.Progress.Base
+
+namespace Progress
+
+open Lean Elab Term Meta Tactic
+open Utils
+
+namespace Test
+ open Primitives
+
+ set_option trace.Progress true
+
+ @[pspec]
+ theorem vec_index_test (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) :
+ ∃ x, v.index α i = .ret x := by
+ apply
+ sorry
+
+ #eval pspecAttr.find? ``Primitives.Vec.index
+end Test
+
+#check isDefEq
+#check allGoals
+
+def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do
+ withMainContext do
+ -- Retrieve the goal
+ let mgoal ← Tactic.getMainGoal
+ let goalTy ← mgoal.getType
+ -- Dive into the goal to lookup the theorem
+ let (fName, fLevels, args) ← do
+ withPSpec goalTy fun desc =>
+ -- TODO: check that no universally quantified variables in the arguments
+ pure (desc.fName, desc.fLevels, desc.args)
+ -- TODO: also try the assumptions
+ trace[Progress] "Function: {fName}"
+ -- TODO: use a list of theorems, and try them one by one?
+ let thName ← do
+ match ← pspecAttr.find? fName with
+ | none => throwError "Could not find a pspec theorem for {fName}"
+ | some thName => pure thName
+ trace[Progress] "Lookuped up: {thName}"
+ /- Apply the theorem
+ We try to match the theorem with the goal
+ In order to do so, we introduce meta-variables for all the parameters
+ (i.e., quantified variables and assumpions), and unify those with the goal.
+ Remark: we do not introduce meta-variables for the quantified variables
+ which don't appear in the function arguments (we want to let them
+ quantified).
+ We also make sure that all the meta variables which appear in the
+ function arguments have been instantiated
+ -/
+ let env ← getEnv
+ let thDecl := env.constants.find! thName
+ let thTy := thDecl.type
+ -- TODO: the tactic fails if we uncomment withNewMCtxDepth
+ -- withNewMCtxDepth do
+ let (mvars, binders, thExBody) ← forallMetaTelescope thTy
+ -- Introduce the existentially quantified variables and the post-condition
+ -- in the context
+ let thBody ←
+ existsTelescope thExBody fun _evars thBody => do
+ let (thBody, _) ← destEq thBody
+ -- There shouldn't be any existential variables in thBody
+ pure thBody
+ -- Match the body with the target
+ let target := mkAppN (.const fName fLevels) args
+ trace[Progress] "mvars:\n{mvars.map Expr.mvarId!}"
+ trace[Progress] "thBody: {thBody}"
+ trace[Progress] "target: {target}"
+ let ok ← isDefEq thBody target
+ if ¬ ok then throwError "Could not unify the theorem with the target:\n- theorem: {thBody}\n- target: {target}"
+ postprocessAppMVars `progress mgoal mvars binders true true
+ Term.synthesizeSyntheticMVarsNoPostponing
+ let thBody ← instantiateMVars thBody
+ trace[Progress] "thBody (after instantiation): {thBody}"
+ -- Add the instantiated theorem to the assumptions (we apply it on the metavariables).
+ let th ← mkAppOptM thName (mvars.map some)
+ let asmName ← mkFreshUserName `h
+ let thTy ← inferType th
+ let thAsm ← Utils.addDecl asmName th thTy (asLet := false)
+ -- Update the set of goals
+ let curGoals ← getUnsolvedGoals
+ let newGoals := mvars.map Expr.mvarId!
+ let newGoals ← newGoals.filterM fun mvar => not <$> mvar.isAssigned
+ trace[Progress] "new goals: {newGoals}"
+ setGoals newGoals.toList
+ allGoals asmTac
+ let newGoals ← getUnsolvedGoals
+ setGoals (newGoals ++ curGoals)
+ --
+ pure ()
+
+elab "progress" : tactic => do
+ progressLookupTheorem (firstTac [assumptionTac, Arith.intTac])
+
+namespace Test
+ open Primitives
+
+ set_option trace.Progress true
+
+ @[pspec]
+ theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) :
+ ∃ x, v.index α i = .ret x := by
+ progress
+ tauto
+
+end Test
+
+end Progress