summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2023-07-20 11:22:18 +0200
committerSon Ho2023-07-20 11:22:18 +0200
commit975b7c555cbffef2648a6469b777d1f1760d926d (patch)
tree147695467dd2bf4d92cf3c84556b15e741fb41be /backends/lean
parent821b09b14794ebc2fe7b7047fc60fd56fb2cd107 (diff)
Improve progress further
Diffstat (limited to 'backends/lean')
-rw-r--r--backends/lean/Base/Progress/Base.lean102
-rw-r--r--backends/lean/Base/Progress/Progress.lean112
2 files changed, 142 insertions, 72 deletions
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index 72438d40..3599d866 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -62,24 +62,6 @@ section Methods
-- Dive into the quantified variables and the assumptions
forallTelescope th fun fvars th => do
trace[Progress] "Universally quantified arguments and assumptions: {fvars}"
- /- -- Filter the argumens which are not propositions
- let rec getFirstPropIdx (i : Nat) : MetaM Nat := do
- if i ≥ fargs.size then pure i
- else do
- let x := fargs.get! i
- if ← Meta.isProp (← inferType x) then pure i
- else getFirstPropIdx (i + 1)
- let i ← getFirstPropIdx 0
- let fvars := fargs.extract 0 i
- let hyps := fargs.extract i fargs.size
- trace[Progress] "Quantified variables: {fvars}"
- trace[Progress] "Assumptions: {hyps}"
- -- Sanity check: all hypotheses are propositions (in particular, all the
- -- quantified variables are at the beginning)
- let hypsAreProp ← hyps.allM fun x => do Meta.isProp (← inferType x)
- if ¬ hypsAreProp then
- throwError "The theorem doesn't have the proper shape: all the quantified arguments should be at the beginning"
- -/
-- Dive into the existentials
existsTelescope th fun evars th => do
trace[Progress] "Existentials: {evars}"
@@ -98,6 +80,8 @@ section Methods
let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
-- Sanity check
if sanityChecks then
+ -- All the variables which appear in the inputs given to the function are
+ -- universally quantified (in particular, they are not *existentially* quantified)
let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!))
let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar)
if ¬ filtArgsFVars.isEmpty then
@@ -132,6 +116,12 @@ def getPSpecClassFunNames (th : Expr) : MetaM (Name × Name) :=
pure (d.fName, f.constName)
) true
+def getPSpecClassFunNameArg (th : Expr) : MetaM (Name × Expr) :=
+ withPSpec th (fun d => do
+ let arg0 := d.args.get! 0
+ pure (d.fName, arg0)
+ ) true
+
-- "Regular" pspec attribute
structure PSpecAttr where
attr : AttributeImpl
@@ -145,14 +135,26 @@ structure PSpecAttr where
Example:
========
We use type classes for addition. For instance, the addition between two
- U32 is written (without syntactic sugar) as `HAdd.add (Scalar ) x y`. As a consequence,
+ U32 is written (without syntactic sugar) as `HAdd.add (Scalar ty) x y`. As a consequence,
we store the theorem through the bindings: HAdd.add → Scalar → ...
+
+ SH: TODO: this (and `PSpecClassExprAttr`) is a bit ad-hoc. For now it works for the
+ specs of the scalar operations, which is what I really need, but I'm not sure it
+ applies well to other situations. A better way would probably to use type classes, but
+ I couldn't get them to work on those cases. It is worth retrying.
-/
structure PSpecClassAttr where
attr : AttributeImpl
ext : MapDeclarationExtension (NameMap Name)
deriving Inhabited
+/- Same as `PSpecClassAttr` but we use the full first argument (it works when it
+ is a constant). -/
+structure PSpecClassExprAttr where
+ attr : AttributeImpl
+ ext : MapDeclarationExtension (HashMap Expr Name)
+ deriving Inhabited
+
-- TODO: the original function doesn't define correctly the `addImportedFn`. Do a PR?
def mkMapDeclarationExtension [Inhabited α] (name : Name := by exact decl_name%) : IO (MapDeclarationExtension α) :=
registerSimplePersistentEnvExtension {
@@ -216,21 +218,69 @@ initialize pspecClassAttr : PSpecClassAttr ← do
registerBuiltinAttribute attrImpl
pure { attr := attrImpl, ext := ext }
+/- The 2nd persistent map from type classes to pspec theorems -/
+initialize pspecClassExprAttr : PSpecClassExprAttr ← do
+ let ext : MapDeclarationExtension (HashMap Expr Name) ← mkMapDeclarationExtension `pspecClassExprMap
+ let attrImpl : AttributeImpl := {
+ name := `cepspec
+ descr := "Marks theorems to use for type classes with the `progress` tactic"
+ add := fun thName stx attrKind => do
+ Attribute.Builtin.ensureNoArgs stx
+ -- TODO: use the attribute kind
+ unless attrKind == AttributeKind.global do
+ throwError "invalid attribute 'cpspec', must be global"
+ -- Lookup the theorem
+ let env ← getEnv
+ let thDecl := env.constants.find! thName
+ let (fName, arg) ← MetaM.run' (getPSpecClassFunNameArg thDecl.type)
+ -- Sanity check: no variables appear in the argument
+ MetaM.run' do
+ let fvars ← getFVarIds arg
+ if ¬ fvars.isEmpty then throwError "The first argument ({arg}) contains variables"
+ -- We store two bindings:
+ -- - arg to theorem name
+ -- - reduced arg to theorem name
+ let rarg ← MetaM.run' (reduce arg)
+ trace[Progress] "Registering class spec theorem for ({fName}, {arg}) and ({fName}, {rarg})"
+ -- Update the entry if there is one, add an entry if there is none
+ let env :=
+ match (ext.getState (← getEnv)).find? fName with
+ | none =>
+ let m := HashMap.ofList [(arg, thName), (rarg, thName)]
+ ext.addEntry env (fName, m)
+ | some m =>
+ let m := m.insert arg thName
+ let m := m.insert rarg thName
+ ext.addEntry env (fName, m)
+ setEnv env
+ pure ()
+ }
+ registerBuiltinAttribute attrImpl
+ pure { attr := attrImpl, ext := ext }
+
def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do
return (s.ext.getState (← getEnv)).find? name
-def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do
- pure (s.ext.getState (← getEnv))
-
def PSpecClassAttr.find? (s : PSpecClassAttr) (className argName : Name) : MetaM (Option Name) := do
match (s.ext.getState (← getEnv)).find? className with
| none => return none
| some map => return map.find? argName
+def PSpecClassExprAttr.find? (s : PSpecClassExprAttr) (className : Name) (arg : Expr) : MetaM (Option Name) := do
+ match (s.ext.getState (← getEnv)).find? className with
+ | none => return none
+ | some map => return map.find? arg
+
+def PSpecAttr.getState (s : PSpecAttr) : MetaM (NameMap Name) := do
+ pure (s.ext.getState (← getEnv))
+
def PSpecClassAttr.getState (s : PSpecClassAttr) : MetaM (NameMap (NameMap Name)) := do
pure (s.ext.getState (← getEnv))
+def PSpecClassExprAttr.getState (s : PSpecClassExprAttr) : MetaM (NameMap (HashMap Expr Name)) := do
+ pure (s.ext.getState (← getEnv))
+
def showStoredPSpec : MetaM Unit := do
let st ← pspecAttr.getState
let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!""
@@ -244,4 +294,12 @@ def showStoredPSpecClass : MetaM Unit := do
f!"{s}\n{f} → [{ms}]") f!""
IO.println s
+def showStoredPSpecExprClass : MetaM Unit := do
+ let st ← pspecClassExprAttr.getState
+ let s := st.toList.foldl (fun s (f, m) =>
+ let ms := m.toList.foldl (fun s (f, th) =>
+ f!"{s}\n {f} → {th}") f!""
+ f!"{s}\n{f} → [{ms}]") f!""
+ IO.println s
+
end Progress
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index 84053150..974a6364 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -27,6 +27,9 @@ inductive TheoremOrLocal where
| Theorem (thName : Name)
| Local (asm : LocalDecl)
+instance : ToMessageData TheoremOrLocal where
+ toMessageData := λ x => match x with | .Theorem thName => m!"{thName}" | .Local asm => m!"{asm.userName}"
+
/- Type to propagate the errors of `progressWith`.
We need this because we use the exceptions to backtrack, when trying to
use the assumptions for instance. When there is actually an error we want
@@ -161,6 +164,32 @@ def getFirstArgAppName (args : Array Expr) : MetaM (Option Name) := do
if f.isConst then pure (some f.constName)
else pure none
+def getFirstArg (args : Array Expr) : Option Expr := do
+ if args.size = 0 then none
+ else some (args.get! 0)
+
+/- Helper: try to lookup a theorem and apply it, or continue with another tactic
+ if it fails -/
+def tryLookupApply (ids : Array Name) (asmTac : TacticM Unit) (fnExpr : Expr)
+ (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do
+ let res ← do
+ match th with
+ | none =>
+ trace[Progress] "Could not find a {kind}"
+ pure none
+ | some th => do
+ trace[Progress] "Lookuped up {kind}: {th}"
+ -- Apply the theorem
+ let res ← do
+ try
+ let res ← progressWith fnExpr th ids asmTac
+ pure (some res)
+ catch _ => none
+ match res with
+ | some .Ok => return ()
+ | some (.Error msg) => throwError msg
+ | none => x
+
-- The array of ids are identifiers to use when introducing fresh variables
def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Name) (asmTac : TacticM Unit) : TacticM Unit := do
withMainContext do
@@ -196,57 +225,40 @@ def progressAsmsOrLookupTheorem (withTh : Option TheoremOrLocal) (ids : Array Na
-- It failed: try to lookup a theorem
-- TODO: use a list of theorems, and try them one by one?
trace[Progress] "No assumption succeeded: trying to lookup a theorem"
- let res ←
- match ← pspecAttr.find? fName with
- | some thName =>
- trace[Progress] "Lookuped up theorem: {thName}"
- -- Apply the theorem
- let res ← do
- try
- let res ← progressWith fnExpr (.Theorem thName) ids asmTac
- pure (some res)
- catch _ => none
- | none =>
- trace[Progress] "Could not find a pspec theorem for {fName}"
- pure none
- match res with
- | some .Ok => return ()
- | some (.Error msg) => throwError msg
- | none =>
- -- It failed: try to lookup a *class* spec theorem
- let res ← do
- match ← getFirstArgAppName args with
- | none => none
- | some argName => do
- match ← pspecClassAttr.find? fName argName with
- | some thName =>
- trace[Progress] "Lookuped up class theorem: {thName}"
- -- Apply the theorem
- let res ← do
- try
- let res ← progressWith fnExpr (.Theorem thName) ids asmTac
- pure (some res)
- catch _ => none
- | none =>
- trace[Progress] "Could not find a class pspec theorem for ({fName}, {argName})"
- pure none
+ let pspec ← do
+ let thName ← pspecAttr.find? fName
+ pure (thName.map fun th => .Theorem th)
+ tryLookupApply ids asmTac fnExpr "pspec theorem" pspec do
+ -- It failed: try to lookup a *class* expr spec theorem (those are more
+ -- specific than class spec theorems)
+ let pspecClassExpr ← do
+ match getFirstArg args with
+ | none => pure none
+ | some arg => do
+ let thName ← pspecClassExprAttr.find? fName arg
+ pure (thName.map fun th => .Theorem th)
+ tryLookupApply ids asmTac fnExpr "pspec class expr theorem" pspecClassExpr do
+ -- It failed: try to lookup a *class* spec theorem
+ let pspecClass ← do
+ match ← getFirstArgAppName args with
+ | none => pure none
+ | some argName => do
+ let thName ← pspecClassAttr.find? fName argName
+ pure (thName.map fun th => .Theorem th)
+ tryLookupApply ids asmTac fnExpr "pspec class theorem" pspecClass do
+ -- Try a recursive call - we try the assumptions of kind "auxDecl"
+ let ctx ← Lean.MonadLCtx.getLCtx
+ let decls ← ctx.getAllDecls
+ let decls := decls.filter (λ decl => match decl.kind with
+ | .default | .implDetail => false | .auxDecl => true)
+ for decl in decls.reverse do
+ trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}"
+ let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue
match res with
- | some .Ok => return ()
- | some (.Error msg) => throwError msg
- | none =>
- -- Try a recursive call - we try the assumptions of kind "auxDecl"
- let ctx ← Lean.MonadLCtx.getLCtx
- let decls ← ctx.getAllDecls
- let decls := decls.filter (λ decl => match decl.kind with
- | .default | .implDetail => false | .auxDecl => true)
- for decl in decls.reverse do
- trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}"
- let res ← do try progressWith fnExpr (.Local decl) ids asmTac catch _ => continue
- match res with
- | .Ok => return ()
- | .Error msg => throwError msg
- -- Nothing worked: failed
- throwError "Progress failed"
+ | .Ok => return ()
+ | .Error msg => throwError msg
+ -- Nothing worked: failed
+ throwError "Progress failed"
syntax progressArgs := ("with" ident)? ("as" " ⟨ " (ident)+ " ⟩")?