summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-11 11:31:43 +0100
committerSon Ho2023-12-11 11:31:43 +0100
commit10a77d17ea06b732106348588bedc6a89766d56f (patch)
tree92a3cfabd788a2f74f71157e80b49f722d8d15f1
parent3c092169efcbc36a9b435c68c590b36f69204f94 (diff)
Reactivate the sanity checks for the progress tactic
-rw-r--r--backends/lean/Base/Progress/Base.lean40
-rw-r--r--backends/lean/Base/Progress/Progress.lean4
-rw-r--r--backends/lean/Base/Utils.lean6
3 files changed, 24 insertions, 26 deletions
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index d50c357c..a72cd641 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -40,6 +40,7 @@ def mkDiscrTreeExtention [Inhabited α] [BEq α] (name : Name := by exact decl_n
structure PSpecDesc where
-- The universally quantified variables
+ -- Can be fvars or mvars
fvars : Array Expr
-- The existentially quantified variables
evars : Array Expr
@@ -50,8 +51,6 @@ structure PSpecDesc where
-- The function arguments
fLevels : List Level
args : Array Expr
- -- The universally quantified variables which appear in the function arguments
- argsFVars : Array FVarId
-- The returned value
ret : Expr
-- The postcondition (if there is)
@@ -82,7 +81,7 @@ section Methods
TODO: generalize for when we do inductive proofs
-/
partial
- def withPSpec [Inhabited (m a)] [Nonempty (m a)] (sanityChecks : Bool := false)
+ def withPSpec [Inhabited (m a)] [Nonempty (m a)]
(isGoal : Bool) (th : Expr) (k : PSpecDesc → m a) :
m a := do
trace[Progress] "Proposition: {th}"
@@ -120,19 +119,18 @@ section Methods
else pure (mExpr, mf, margs)
trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}"
if ¬ f.isConst then throwError "Not a constant: {f}"
- -- Compute the set of universally quantified variables which appear in the function arguments
- let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
- -- Sanity check
- if sanityChecks then
- -- All the variables which appear in the inputs given to the function are
- -- universally quantified (in particular, they are not *existentially* quantified)
- let fvarsSet : HashSet FVarId := HashSet.ofArray (fvars.map (fun x => x.fvarId!))
- let filtArgsFVars := allArgsFVars.toArray.filter (fun fvar => ¬ fvarsSet.contains fvar)
- if ¬ filtArgsFVars.isEmpty then
+ -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB)
+ -- Check if some existentially quantified variables
+ let _ := do
+ -- Collect all the free variables in the arguments
+ let allArgsFVars := ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
+ -- Check if they intersect the fvars we introduced for the existentially quantified variables
+ let evarsSet : HashSet FVarId := HashSet.ofArray (evars.map (fun (x : Expr) => x.fvarId!))
+ let filtArgsFVars := allArgsFVars.toArray.filter (fun var => evarsSet.contains var)
+ if filtArgsFVars.isEmpty then pure ()
+ else
let filtArgsFVars := filtArgsFVars.map (fun fvarId => Expr.fvar fvarId)
throwError "Some of the function inputs are not universally quantified: {filtArgsFVars}"
- let argsFVars := fvars.map (fun x => x.fvarId!)
- let argsFVars := argsFVars.filter (fun fvar => allArgsFVars.contains fvar)
-- Return
trace[Progress] "Function with arguments: {fArgsExpr}";
let thDesc := {
@@ -142,7 +140,6 @@ section Methods
fName := f.constName!
fLevels := f.constLevels!
args := args
- argsFVars
ret := ret
post := post
}
@@ -150,11 +147,8 @@ section Methods
end Methods
-/-def getPSpecFunArgsExpr (th : Expr) : MetaM Expr :=
- withPSpec true th (fun d => do pure d.fArgsExpr)
-
-def getPSpecFunName (th : Expr) : MetaM Name :=
- withPSpec true th (fun d => do pure d.fName)-/
+def getPSpecFunArgsExpr (isGoal : Bool) (th : Expr) : MetaM Expr :=
+ withPSpec isGoal th (fun d => do pure d.fArgsExpr)
-- pspec attribute
structure PSpecAttr where
@@ -176,14 +170,14 @@ initialize pspecAttr : PSpecAttr ← do
-- Lookup the theorem
let env ← getEnv
let thDecl := env.constants.find! thName
- let isGoal := false
- let fKey ← MetaM.run' (withPSpec true isGoal thDecl.type fun d => do
- let fExpr := d.fArgsExpr
+ let fKey ← MetaM.run' (do
+ let fExpr ← getPSpecFunArgsExpr false thDecl.type
trace[Progress] "Registering spec theorem for {fExpr}"
-- Convert the function expression to a discrimination tree key
DiscrTree.mkPath fExpr)
let env := ext.addEntry env (fKey, thName)
setEnv env
+ trace[Progress] "Saved the environment"
pure ()
}
registerBuiltinAttribute attrImpl
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index 93b7d7d5..a6a4e82a 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -245,7 +245,7 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL
-- have the proper shape.
let fExpr ← do
let isGoal := true
- withPSpec false isGoal goalTy fun desc => do
+ withPSpec isGoal goalTy fun desc => do
let fExpr := desc.fArgsExpr
trace[Progress] "Expression to match: {fExpr}"
pure fExpr
@@ -386,8 +386,6 @@ namespace Test
progress keep _ as ⟨ z, h1 .. ⟩
simp [*, h1]
- set_option trace.Progress false
-
example {ty} {x y : Scalar ty}
(hmin : Scalar.min ty ≤ x.val + y.val)
(hmax : x.val + y.val ≤ Scalar.max ty) :
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index b917a789..95b2c38b 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -381,6 +381,12 @@ partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM
let hs := if body.isFVar then hs.insert body.fvarId! else hs
args.foldlM (fun hs arg => getFVarIds arg hs) hs
+-- Return the set of MVarIds in the expression
+partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do
+ e.withApp fun body args => do
+ let hs := if body.isMVar then hs.insert body.mvarId! else hs
+ args.foldlM (fun hs arg => getMVarIds arg hs) hs
+
-- Tactic to split on a disjunction.
-- The expression `h` should be an fvar.
-- TODO: there must be simpler. Use use _root_.Lean.MVarId.cases for instance