summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-11 19:48:42 +0100
committerSon Ho2023-12-11 19:48:42 +0100
commit24c5289d0ca039c1c64081285d7d120a04f40699 (patch)
treed2d73c3d73f62be0778a46beb428b2d4b59188d4
parent78367ef21c147b26040e0f6062a907fceab1f390 (diff)
Update the validity proofs for higher-order functions
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean101
-rw-r--r--backends/lean/Base/Utils.lean6
2 files changed, 96 insertions, 11 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 97364d14..e555b3e3 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -820,15 +820,88 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args :
Check if the arguments use the continuation:
- if no: this is simple
- if yes: we have to lookup theorems in div spec database and continue -/
- trace[Diverge.def.valid] "regular app: {e}"
+ trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}"
let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
+ trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}"
if ¬ allArgsFVars.contains kk_var.fvarId! then do
-- Simple case
trace[Diverge.def.valid] "kk doesn't appear in the arguments"
proveNoKExprIsValid k_var e
else do
-- Lookup in the database for suitable theorems
- throwError "TODO: {e}"
+ trace[Diverge.def.valid] "kk appears in the arguments"
+ let thms ← divspecAttr.find? e
+ trace[Diverge.def.valid] "Looked up theorems: {thms}"
+ -- Try the theorems one by one
+ proveAppIsValidApplyThms k_var kk_var e f args thms.toList
+
+partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr)
+ (f : Expr) (args : Array Expr) (thms : List Name) : MetaM Expr := do
+ trace[Diverge.def.valid] "thms: {thms}"
+ match thms with
+ | [] => throwError "Could not prove that the following expression is valid: {e}"
+ | thName :: thms =>
+ -- Lookup the theorem itself
+ let env ← getEnv
+ let thDecl := env.constants.find! thName
+ -- Introduce fresh meta-variables for the universes
+ let ul : List (Name × Level) ←
+ thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar))
+ let ulMap : HashMap Name Level := HashMap.ofList ul
+ let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x)
+ -- Introduce meta variables for the universally quantified variables
+ let (mvars, _binders, thTy) ← forallMetaTelescope thTy
+ -- thTy should now be of the shape: `is_valid_p k (λ kk => ...)`
+ /-thTy.consumeMData.withApp fun _ args => do
+ if args.size ≠ 7 then throwError "Invalid number of arguments (expected 7): {thTy}"
+ let thTermToMatch := args.get! 6 -/
+ let thTermToMatch := thTy
+ trace[Diverge.def.valid] "thTermToMatch: {thTermToMatch}"
+ -- Create the term: `is_valid_p k (λ kk => e)`
+ let termToMatch ← mkLambdaFVars #[kk_var] e
+ let termToMatch ← mkAppM ``FixII.is_valid_p #[k_var, termToMatch]
+ trace[Diverge.def.valid] "termToMatch: {termToMatch}"
+ -- Attempt to match
+ let ok ← isDefEq thTermToMatch termToMatch
+ if ¬ ok then
+ -- Failure: attempt with the other theorems
+ proveAppIsValidApplyThms k_var kk_var e f args thms
+ else do
+ -- Success: continue with this theorem
+ -- Instantiate the meta variables (some of them will not be instantiated:
+ -- they are new subgoals)
+ let mvars ← mvars.mapM instantiateMVars
+ let th ← mkAppOptM thName (Array.map some mvars)
+ trace[Diverge.def.valid] "Instantiated theorm: {th}\n{← inferType th}"
+ -- Filter the meta variables between the instantiated ones
+ for mvar in mvars do
+ if mvar.isMVar then do
+ -- Prove the subgoal (i.e., the precondition of the theorem)
+ let mvarId := mvar.mvarId!
+ let mvarDecl ← mvarId.getDecl
+ -- Dive in the type
+ forallTelescope mvarDecl.type fun forall_vars mvar_e => do
+ -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)`
+ -- We need to retrieve the new `k` variable, and dive into the
+ -- `λ kk => ...`
+ mvar_e.consumeMData.withApp fun is_valid args => do
+ if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 2 then
+ throwError "Invalid precondition: {mvar_e}"
+ else do
+ let k_var := args.get! 0
+ let e_lam := args.get! 1
+ lambdaTelescope e_lam.consumeMData fun lvars e => do
+ if lvars.size ≠ 1 then throwError "Invalid number of lambdas (expected 1): {e_lam}"
+ let kk_var := lvars.get! 0
+ -- Continue
+ let e_valid ← proveExprIsValid k_var kk_var e
+ let e_valid ← mkForallFVars forall_vars e_valid
+ -- Assign the meta variable
+ mvarId.assign e_valid
+ else
+ -- Nothing to do
+ pure ()
+ pure th
-- Prove that a match expression is valid.
partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do
@@ -852,7 +925,8 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
^^^^^^^^^^^^^^^^^^^^
this is the original match expression, with the
the difference that the scrutinee(s) is a variable
- ``` -/
+ ```
+ -/
let validMotive : Expr ← do
-- The motive is a function of the scrutinees (i.e., a lambda expression):
-- introduce binders for the scrutinees
@@ -905,7 +979,7 @@ partial def proveSingleBodyIsValid
trace[Diverge.def.valid] "body: {body}"
lambdaTelescope body fun xs body => do
trace[Diverge.def.valid] "xs: {xs}"
- assert! xs.size = 3
+ if xs.size ≠ 3 then throwError "Invalid number of lambdas: {xs} (expected 3)"
let kk_var := xs.get! 0
let t_var := xs.get! 1
let x_var := xs.get! 2
@@ -1349,6 +1423,7 @@ elab_rules : command
Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns))
namespace Tests
+
/- Some examples of partial functions -/
section HigherOrder
open Ex8
@@ -1357,9 +1432,6 @@ namespace Tests
| leaf (x : a)
| node (tl : List (Tree a))
- set_option trace.Diverge.def true
- -- set_option trace.Diverge.def.genBody true
- set_option trace.Diverge.def.valid true
divergent def id {a : Type u} (t : Tree a) : Result (Tree a) :=
match t with
| .leaf x => .ret (.leaf x)
@@ -1368,7 +1440,20 @@ namespace Tests
let tl ← map id tl
.ret (.node tl)
- set_option trace.Diverge.def false
+ #check id.unfold
+
+ /-set_option trace.Diverge.def true
+ -- set_option trace.Diverge.def.genBody true
+ set_option trace.Diverge.def.valid true
+ divergent def incr (t : Tree Nat) : Result (Tree Nat) :=
+ match t with
+ | .leaf x => .ret (.leaf (x + 1))
+ | .node tl =>
+ do
+ let tl ← map incr tl
+ .ret (.node tl)
+
+ set_option trace.Diverge.def false-/
end HigherOrder
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 95b2c38b..2366e800 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -371,19 +371,19 @@ def splitConjTarget : TacticM Unit := do
-- Destruct an equaliy and return the two sides
def destEq (e : Expr) : MetaM (Expr × Expr) := do
- e.withApp fun f args =>
+ e.consumeMData.withApp fun f args =>
if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2)
else throwError "Not an equality: {e}"
-- Return the set of FVarIds in the expression
partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do
- e.withApp fun body args => do
+ e.consumeMData.withApp fun body args => do
let hs := if body.isFVar then hs.insert body.fvarId! else hs
args.foldlM (fun hs arg => getFVarIds arg hs) hs
-- Return the set of MVarIds in the expression
partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do
- e.withApp fun body args => do
+ e.consumeMData.withApp fun body args => do
let hs := if body.isMVar then hs.insert body.mvarId! else hs
args.foldlM (fun hs arg => getMVarIds arg hs) hs