summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Progress
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Progress/Base.lean290
-rw-r--r--backends/lean/Base/Progress/Progress.lean91
2 files changed, 98 insertions, 283 deletions
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index 76a92795..0ad16ab6 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -2,11 +2,12 @@ import Lean
import Std.Lean.HashSet
import Base.Utils
import Base.Primitives.Base
+import Base.Extensions
namespace Progress
open Lean Elab Term Meta
-open Utils
+open Utils Extensions
-- We can't define and use trace classes in the same file
initialize registerTraceClass `Progress
@@ -15,17 +16,17 @@ initialize registerTraceClass `Progress
structure PSpecDesc where
-- The universally quantified variables
+ -- Can be fvars or mvars
fvars : Array Expr
-- The existentially quantified variables
evars : Array Expr
+ -- The function applied to its arguments
+ fArgsExpr : Expr
-- The function
- fExpr : Expr
fName : Name
-- The function arguments
fLevels : List Level
args : Array Expr
- -- The universally quantified variables which appear in the function arguments
- argsFVars : Array FVarId
-- The returned value
ret : Expr
-- The postcondition (if there is)
@@ -37,7 +38,7 @@ section Methods
variable [MonadError m]
variable {a : Type}
- /- Analyze a pspec theorem to decompose its arguments.
+ /- Analyze a goal or a pspec theorem to decompose its arguments.
PSpec theorems should be of the following shape:
```
@@ -56,12 +57,20 @@ section Methods
TODO: generalize for when we do inductive proofs
-/
partial
- def withPSpec [Inhabited (m a)] [Nonempty (m a)] (th : Expr) (k : PSpecDesc → m a)
- (sanityChecks : Bool := false) :
+ def withPSpec [Inhabited (m a)] [Nonempty (m a)]
+ (isGoal : Bool) (th : Expr) (k : PSpecDesc → m a) :
m a := do
trace[Progress] "Proposition: {th}"
-- Dive into the quantified variables and the assumptions
- forallTelescope th.consumeMData fun fvars th => do
+ -- Note that if we analyze a pspec theorem to register it in a database (i.e.
+ -- a discrimination tree), we need to introduce *meta-variables* for the
+ -- quantified variables.
+ let telescope (k : Array Expr → Expr → m a) : m a :=
+ if isGoal then forallTelescope th.consumeMData (fun fvars th => k fvars th)
+ else do
+ let (fvars, _, th) ← forallMetaTelescope th.consumeMData
+ k fvars th
+ telescope fun fvars th => do
trace[Progress] "Universally quantified arguments and assumptions: {fvars}"
-- Dive into the existentials
existsTelescope th.consumeMData fun evars th => do
@@ -78,7 +87,7 @@ section Methods
-- destruct the application to get the function name
mExpr.consumeMData.withApp fun mf margs => do
trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}"
- let (fExpr, f, args) ← do
+ let (fArgsExpr, f, args) ← do
if mf.isConst ∧ mf.constName = ``Bind.bind then do
-- Dive into the bind
let fExpr := (margs.get! 4).consumeMData
@@ -86,29 +95,27 @@ section Methods
else pure (mExpr, mf, margs)
trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}"
if ¬ f.isConst then throwError "Not a constant: {f}"
- -- Compute the set of universally quantified variables which appear in the function arguments
- let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
- -- Sanity check
- if sanityChecks then
- -- All the variables which appear in the inputs given to the function are
- -- universally quantified (in particular, they are not *existentially* quantified)
- let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!))
- let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar)
- if ¬ filtArgsFVars.isEmpty then
+ -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB)
+ -- Check if some existentially quantified variables
+ let _ := do
+ -- Collect all the free variables in the arguments
+ let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
+ -- Check if they intersect the fvars we introduced for the existentially quantified variables
+ let evarsSet : HashSet FVarId := HashSet.ofArray (evars.map (fun (x : Expr) => x.fvarId!))
+ let filtArgsFVars := allArgsFVars.toArray.filter (fun var => evarsSet.contains var)
+ if filtArgsFVars.isEmpty then pure ()
+ else
let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId)
throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}"
- let argsFVars := fvars.map (fun x => x.fvarId!)
- let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar)
-- Return
- trace[Progress] "Function: {f.constName!}";
+ trace[Progress] "Function with arguments: {fArgsExpr}";
let thDesc := {
fvars := fvars
evars := evars
- fExpr
+ fArgsExpr
fName := f.constName!
fLevels := f.constLevels!
args := args
- argsFVars
ret := ret
post := post
}
@@ -116,117 +123,18 @@ section Methods
end Methods
-def getPSpecFunName (th : Expr) : MetaM Name :=
- withPSpec th (fun d => do pure d.fName) true
+def getPSpecFunArgsExpr (isGoal : Bool) (th : Expr) : MetaM Expr :=
+ withPSpec isGoal th (fun d => do pure d.fArgsExpr)
-def getPSpecClassFunNames (th : Expr) : MetaM (Name × Name) :=
- withPSpec th (fun d => do
- let arg0 := d.args.get! 0
- arg0.withApp fun f _ => do
- if ¬ f.isConst then throwError "Not a constant: {f}"
- pure (d.fName, f.constName)
- ) true
-
-def getPSpecClassFunNameArg (th : Expr) : MetaM (Name × Expr) :=
- withPSpec th (fun d => do
- let arg0 := d.args.get! 0
- pure (d.fName, arg0)
- ) true
-
--- "Regular" pspec attribute
+-- pspec attribute
structure PSpecAttr where
attr : AttributeImpl
- ext : MapDeclarationExtension Name
- deriving Inhabited
-
-/- pspec attribute for type classes: we use the name of the type class to
- lookup another map. We use the *first* argument of the type class to lookup
- into this second map.
-
- Example:
- ========
- We use type classes for addition. For instance, the addition between two
- U32 is written (without syntactic sugar) as `HAdd.add (Scalar ty) x y`. As a consequence,
- we store the theorem through the bindings: HAdd.add → Scalar → ...
-
- SH: TODO: this (and `PSpecClassExprAttr`) is a bit ad-hoc. For now it works for the
- specs of the scalar operations, which is what I really need, but I'm not sure it
- applies well to other situations. A better way would probably to use type classes, but
- I couldn't get them to work on those cases. It is worth retrying.
--/
-structure PSpecClassAttr where
- attr : AttributeImpl
- ext : MapDeclarationExtension (NameMap Name)
- deriving Inhabited
-
-/- Same as `PSpecClassAttr` but we use the full first argument (it works when it
- is a constant). -/
-structure PSpecClassExprAttr where
- attr : AttributeImpl
- ext : MapDeclarationExtension (HashMap Expr Name)
+ ext : DiscrTreeExtension Name true
deriving Inhabited
--- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR?
-def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) :
- IO (MapDeclarationExtension α) :=
- registerSimplePersistentEnvExtension {
- name := name,
- addImportedFn := fun a => a.foldl (fun s a => a.foldl (fun s (k, v) => s.insert k v) s) RBMap.empty,
- addEntryFn := fun s n => s.insert n.1 n.2 ,
- toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1)
- }
-
--- Declare an extension of maps of maps (using [RBMap]).
--- The important point is that we need to merge the bound values (which are maps).
-def mkMapMapDeclarationExtension [Inhabited β] (ord : α → α → Ordering)
- (name : Name := by exact decl_name%) :
- IO (MapDeclarationExtension (RBMap α β ord)) :=
- registerSimplePersistentEnvExtension {
- name := name,
- addImportedFn := fun a =>
- a.foldl (fun s a => a.foldl (
- -- We need to merge the maps
- fun s (k0, k1_to_v) =>
- match s.find? k0 with
- | none =>
- -- No binding: insert one
- s.insert k0 k1_to_v
- | some m =>
- -- There is already a binding: merge
- let m := RBMap.fold (fun m k v => m.insert k v) m k1_to_v
- s.insert k0 m)
- s) RBMap.empty,
- addEntryFn := fun s n => s.insert n.1 n.2 ,
- toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1)
- }
-
--- Declare an extension of maps of maps (using [HashMap]).
--- The important point is that we need to merge the bound values (which are maps).
-def mkMapHashMapDeclarationExtension [BEq α] [Hashable α] [Inhabited β]
- (name : Name := by exact decl_name%) :
- IO (MapDeclarationExtension (HashMap α β)) :=
- registerSimplePersistentEnvExtension {
- name := name,
- addImportedFn := fun a =>
- a.foldl (fun s a => a.foldl (
- -- We need to merge the maps
- fun s (k0, k1_to_v) =>
- match s.find? k0 with
- | none =>
- -- No binding: insert one
- s.insert k0 k1_to_v
- | some m =>
- -- There is already a binding: merge
- let m := HashMap.fold (fun m k v => m.insert k v) m k1_to_v
- s.insert k0 m)
- s) RBMap.empty,
- addEntryFn := fun s n => s.insert n.1 n.2 ,
- toArrayFn := fun es => es.toArray.qsort (fun a b => Name.quickLt a.1 b.1)
- }
-
-/- The persistent map from function to pspec theorems. -/
+/- The persistent map from expressions to pspec theorems. -/
initialize pspecAttr : PSpecAttr ← do
- let ext ← mkMapDeclarationExtension `pspecMap
+ let ext ← mkDiscrTreeExtention `pspecMap true
let attrImpl : AttributeImpl := {
name := `pspec
descr := "Marks theorems to use with the `progress` tactic"
@@ -238,130 +146,30 @@ initialize pspecAttr : PSpecAttr ← do
-- Lookup the theorem
let env ← getEnv
let thDecl := env.constants.find! thName
- let fName ← MetaM.run' (getPSpecFunName thDecl.type)
- trace[Progress] "Registering spec theorem for {fName}"
- let env := ext.addEntry env (fName, thName)
- setEnv env
- pure ()
- }
- registerBuiltinAttribute attrImpl
- pure { attr := attrImpl, ext := ext }
-
-/- The persistent map from type classes to pspec theorems -/
-initialize pspecClassAttr : PSpecClassAttr ← do
- let ext : MapDeclarationExtension (NameMap Name) ←
- mkMapMapDeclarationExtension Name.quickCmp `pspecClassMap
- let attrImpl : AttributeImpl := {
- name := `cpspec
- descr := "Marks theorems to use for type classes with the `progress` tactic"
- add := fun thName stx attrKind => do
- Attribute.Builtin.ensureNoArgs stx
- -- TODO: use the attribute kind
- unless attrKind == AttributeKind.global do
- throwError "invalid attribute 'cpspec', must be global"
- -- Lookup the theorem
- let env ← getEnv
- let thDecl := env.constants.find! thName
- let (fName, argName) ← MetaM.run' (getPSpecClassFunNames thDecl.type)
- trace[Progress] "Registering class spec theorem for ({fName}, {argName})"
- -- Update the entry if there is one, add an entry if there is none
- let env :=
- match (ext.getState (← getEnv)).find? fName with
- | none =>
- let m := RBMap.ofList [(argName, thName)]
- ext.addEntry env (fName, m)
- | some m =>
- let m := m.insert argName thName
- ext.addEntry env (fName, m)
+ let fKey ← MetaM.run' (do
+ let fExpr ← getPSpecFunArgsExpr false thDecl.type
+ trace[Progress] "Registering spec theorem for {fExpr}"
+ -- Convert the function expression to a discrimination tree key
+ DiscrTree.mkPath fExpr)
+ let env := ext.addEntry env (fKey, thName)
setEnv env
+ trace[Progress] "Saved the environment"
pure ()
}
registerBuiltinAttribute attrImpl
pure { attr := attrImpl, ext := ext }
-/- The 2nd persistent map from type classes to pspec theorems -/
-initialize pspecClassExprAttr : PSpecClassExprAttr ← do
- let ext : MapDeclarationExtension (HashMap Expr Name) ←
- mkMapHashMapDeclarationExtension `pspecClassExprMap
- let attrImpl : AttributeImpl := {
- name := `cepspec
- descr := "Marks theorems to use for type classes with the `progress` tactic"
- add := fun thName stx attrKind => do
- Attribute.Builtin.ensureNoArgs stx
- -- TODO: use the attribute kind
- unless attrKind == AttributeKind.global do
- throwError "invalid attribute 'cpspec', must be global"
- -- Lookup the theorem
- let env ← getEnv
- let thDecl := env.constants.find! thName
- let (fName, arg) ← MetaM.run' (getPSpecClassFunNameArg thDecl.type)
- -- Sanity check: no variables appear in the argument
- MetaM.run' do
- let fvars ← getFVarIds arg
- if ¬ fvars.isEmpty then throwError "The first argument ({arg}) contains variables"
- -- We store two bindings:
- -- - arg to theorem name
- -- - reduced arg to theorem name
- let rarg ← MetaM.run' (reduceAll arg)
- trace[Progress] "Registering class spec theorem for ({fName}, {arg}) and ({fName}, {rarg})"
- -- Update the entry if there is one, add an entry if there is none
- let env :=
- match (ext.getState (← getEnv)).find? fName with
- | none =>
- let m := HashMap.ofList [(arg, thName), (rarg, thName)]
- ext.addEntry env (fName, m)
- | some m =>
- let m := m.insert arg thName
- let m := m.insert rarg thName
- ext.addEntry env (fName, m)
- setEnv env
- pure ()
- }
- registerBuiltinAttribute attrImpl
- pure { attr := attrImpl, ext := ext }
-
-
-def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do
- return (s.ext.getState (← getEnv)).find? name
-
-def PSpecClassAttr.find? (s : PSpecClassAttr) (className argName : Name) : MetaM (Option Name) := do
- match (s.ext.getState (← getEnv)).find? className with
- | none => return none
- | some map => return map.find? argName
-
-def PSpecClassExprAttr.find? (s : PSpecClassExprAttr) (className : Name) (arg : Expr) : MetaM (Option Name) := do
- match (s.ext.getState (← getEnv)).find? className with
- | none => return none
- | some map => return map.find? arg
-
-def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do
- pure (s.ext.getState (← getEnv))
+def PSpecAttr.find? (s : PSpecAttr) (e : Expr) : MetaM (Array Name) := do
+ (s.ext.getState (← getEnv)).getMatch e
-def PSpecClassAttr.getState (s : PSpecClassAttr) : MetaM (NameMap (NameMap Name)) := do
- pure (s.ext.getState (← getEnv))
-
-def PSpecClassExprAttr.getState (s : PSpecClassExprAttr) : MetaM (NameMap (HashMap Expr Name)) := do
+def PSpecAttr.getState (s : PSpecAttr) : MetaM (DiscrTree Name true) := do
pure (s.ext.getState (← getEnv))
def showStoredPSpec : MetaM Unit := do
let st ← pspecAttr.getState
- let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!""
- IO.println s
-
-def showStoredPSpecClass : MetaM Unit := do
- let st ← pspecClassAttr.getState
- let s := st.toList.foldl (fun s (f, m) =>
- let ms := m.toList.foldl (fun s (f, th) =>
- f!"{s}\n {f} → {th}") f!""
- f!"{s}\n{f} → [{ms}]") f!""
- IO.println s
-
-def showStoredPSpecExprClass : MetaM Unit := do
- let st ← pspecClassExprAttr.getState
- let s := st.toList.foldl (fun s (f, m) =>
- let ms := m.toList.foldl (fun s (f, th) =>
- f!"{s}\n {f} → {th}") f!""
- f!"{s}\n{f} → [{ms}]") f!""
+ -- TODO: how can we iterate over (at least) the values stored in the tree?
+ --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!""
+ let s := f!"{st}"
IO.println s
end Progress
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index ba63f09d..a6a4e82a 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -204,11 +204,11 @@ def getFirstArg (args : Array Expr) : Option Expr := do
if args.size = 0 then none
else some (args.get! 0)
-/- Helper: try to lookup a theorem and apply it, or continue with another tactic
- if it fails -/
+/- Helper: try to lookup a theorem and apply it.
+ Return true if it succeeded. -/
def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool)
(asmTac : TacticM Unit) (fExpr : Expr)
- (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do
+ (kind : String) (th : Option TheoremOrLocal) : TacticM Bool := do
let res ← do
match th with
| none =>
@@ -223,9 +223,9 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost :
pure (some res)
catch _ => none
match res with
- | some .Ok => return ()
+ | some .Ok => return true
| some (.Error msg) => throwError msg
- | none => x
+ | none => return false
-- The array of ids are identifiers to use when introducing fresh variables
def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal)
@@ -236,11 +236,19 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
let goalTy ← mgoal.getType
trace[Progress] "goal: {goalTy}"
-- Dive into the goal to lookup the theorem
- let (fExpr, fName, args) ← do
- withPSpec goalTy fun desc =>
- -- TODO: check that no quantified variables in the arguments
- pure (desc.fExpr, desc.fName, desc.args)
- trace[Progress] "Function: {fName}"
+ -- Remark: if we don't isolate the call to `withPSpec` to immediately "close"
+ -- the terms immediately, we may end up with the error:
+ -- "(kernel) declaration has free variables"
+ -- I'm not sure I understand why.
+ -- TODO: we should also check that no quantified variable appears in fExpr.
+ -- If such variables appear, we should just fail because the goal doesn't
+ -- have the proper shape.
+ let fExpr ← do
+ let isGoal := true
+ withPSpec isGoal goalTy fun desc => do
+ let fExpr := desc.fArgsExpr
+ trace[Progress] "Expression to match: {fExpr}"
+ pure fExpr
-- If the user provided a theorem/assumption: use it.
-- Otherwise, lookup one.
match withTh with
@@ -258,36 +266,24 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
match res with
| .Ok => return ()
| .Error msg => throwError msg
- -- It failed: try to lookup a theorem
- -- TODO: use a list of theorems, and try them one by one?
- trace[Progress] "No assumption succeeded: trying to lookup a theorem"
- let pspec ← do
- let thName ← pspecAttr.find? fName
- pure (thName.map fun th => .Theorem th)
- tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec do
- -- It failed: try to lookup a *class* expr spec theorem (those are more
- -- specific than class spec theorems)
- trace[Progress] "Failed using a pspec theorem: trying to lookup a pspec class expr theorem"
- let pspecClassExpr ← do
- match getFirstArg args with
- | none => pure none
- | some arg => do
- trace[Progress] "Using: f:{fName}, arg: {arg}"
- let thName ← pspecClassExprAttr.find? fName arg
- pure (thName.map fun th => .Theorem th)
- tryLookupApply keep ids splitPost asmTac fExpr "pspec class expr theorem" pspecClassExpr do
- -- It failed: try to lookup a *class* spec theorem
- trace[Progress] "Failed using a pspec class expr theorem: trying to lookup a pspec class theorem"
- let pspecClass ← do
- match ← getFirstArgAppName args with
- | none => pure none
- | some argName => do
- trace[Progress] "Using: f: {fName}, arg: {argName}"
- let thName ← pspecClassAttr.find? fName argName
- pure (thName.map fun th => .Theorem th)
- tryLookupApply keep ids splitPost asmTac fExpr "pspec class theorem" pspecClass do
- trace[Progress] "Failed using a pspec class theorem: trying to use a recursive assumption"
- -- Try a recursive call - we try the assumptions of kind "auxDecl"
+ -- It failed: lookup the pspec theorems which match the expression
+ trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem"
+ let pspecs : Array TheoremOrLocal ← do
+ let thNames ← pspecAttr.find? fExpr
+ -- TODO: because of reduction, there may be several valid theorems (for
+ -- instance for the scalars). We need to sort them from most specific to
+ -- least specific. For now, we assume the most specific theorems are at
+ -- the end.
+ let thNames := thNames.reverse
+ trace[Progress] "Looked up pspec theorems: {thNames}"
+ pure (thNames.map fun th => TheoremOrLocal.Theorem th)
+ -- Try the theorems one by one
+ for pspec in pspecs do
+ if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return ()
+ else pure ()
+ -- It failed: try to use the recursive assumptions
+ trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption"
+ -- We try to apply the assumptions of kind "auxDecl"
let ctx ← Lean.MonadLCtx.getLCtx
let decls ← ctx.getAllDecls
let decls := decls.filter (λ decl => match decl.kind with
@@ -381,8 +377,6 @@ namespace Test
-- The following commands display the databases of theorems
-- #eval showStoredPSpec
- -- #eval showStoredPSpecClass
- -- #eval showStoredPSpecExprClass
open alloc.vec
example {ty} {x y : Scalar ty}
@@ -402,6 +396,19 @@ namespace Test
example {x y : U32}
(hmax : x.val + y.val ≤ U32.max) :
∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
+ -- This spec theorem is suboptimal, but it is good to check that it works
+ progress with Scalar.add_spec as ⟨ z, h1 .. ⟩
+ simp [*, h1]
+
+ example {x y : U32}
+ (hmax : x.val + y.val ≤ U32.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
+ progress with U32.add_spec as ⟨ z, h1 .. ⟩
+ simp [*, h1]
+
+ example {x y : U32}
+ (hmax : x.val + y.val ≤ U32.max) :
+ ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by
progress keep _ as ⟨ z, h1 .. ⟩
simp [*, h1]