From 24c5289d0ca039c1c64081285d7d120a04f40699 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 19:48:42 +0100 Subject: Update the validity proofs for higher-order functions --- backends/lean/Base/Diverge/Elab.lean | 101 ++++++++++++++++++++++++++++++++--- backends/lean/Base/Utils.lean | 6 +-- 2 files changed, 96 insertions(+), 11 deletions(-) diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 97364d14..e555b3e3 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -820,15 +820,88 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Check if the arguments use the continuation: - if no: this is simple - if yes: we have to lookup theorems in div spec database and continue -/ - trace[Diverge.def.valid] "regular app: {e}" + trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}" let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}" if ¬ allArgsFVars.contains kk_var.fvarId! then do -- Simple case trace[Diverge.def.valid] "kk doesn't appear in the arguments" proveNoKExprIsValid k_var e else do -- Lookup in the database for suitable theorems - throwError "TODO: {e}" + trace[Diverge.def.valid] "kk appears in the arguments" + let thms ← divspecAttr.find? e + trace[Diverge.def.valid] "Looked up theorems: {thms}" + -- Try the theorems one by one + proveAppIsValidApplyThms k_var kk_var e f args thms.toList + +partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) + (f : Expr) (args : Array Expr) (thms : List Name) : MetaM Expr := do + trace[Diverge.def.valid] "thms: {thms}" + match thms with + | [] => throwError "Could not prove that the following expression is valid: {e}" + | thName :: thms => + -- Lookup the theorem itself + let env ← getEnv + let thDecl := env.constants.find! thName + -- Introduce fresh meta-variables for the universes + let ul : List (Name × Level) ← + thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar)) + let ulMap : HashMap Name Level := HashMap.ofList ul + let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x) + -- Introduce meta variables for the universally quantified variables + let (mvars, _binders, thTy) ← forallMetaTelescope thTy + -- thTy should now be of the shape: `is_valid_p k (λ kk => ...)` + /-thTy.consumeMData.withApp fun _ args => do + if args.size ≠ 7 then throwError "Invalid number of arguments (expected 7): {thTy}" + let thTermToMatch := args.get! 6 -/ + let thTermToMatch := thTy + trace[Diverge.def.valid] "thTermToMatch: {thTermToMatch}" + -- Create the term: `is_valid_p k (λ kk => e)` + let termToMatch ← mkLambdaFVars #[kk_var] e + let termToMatch ← mkAppM ``FixII.is_valid_p #[k_var, termToMatch] + trace[Diverge.def.valid] "termToMatch: {termToMatch}" + -- Attempt to match + let ok ← isDefEq thTermToMatch termToMatch + if ¬ ok then + -- Failure: attempt with the other theorems + proveAppIsValidApplyThms k_var kk_var e f args thms + else do + -- Success: continue with this theorem + -- Instantiate the meta variables (some of them will not be instantiated: + -- they are new subgoals) + let mvars ← mvars.mapM instantiateMVars + let th ← mkAppOptM thName (Array.map some mvars) + trace[Diverge.def.valid] "Instantiated theorm: {th}\n{← inferType th}" + -- Filter the meta variables between the instantiated ones + for mvar in mvars do + if mvar.isMVar then do + -- Prove the subgoal (i.e., the precondition of the theorem) + let mvarId := mvar.mvarId! + let mvarDecl ← mvarId.getDecl + -- Dive in the type + forallTelescope mvarDecl.type fun forall_vars mvar_e => do + -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)` + -- We need to retrieve the new `k` variable, and dive into the + -- `λ kk => ...` + mvar_e.consumeMData.withApp fun is_valid args => do + if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 2 then + throwError "Invalid precondition: {mvar_e}" + else do + let k_var := args.get! 0 + let e_lam := args.get! 1 + lambdaTelescope e_lam.consumeMData fun lvars e => do + if lvars.size ≠ 1 then throwError "Invalid number of lambdas (expected 1): {e_lam}" + let kk_var := lvars.get! 0 + -- Continue + let e_valid ← proveExprIsValid k_var kk_var e + let e_valid ← mkForallFVars forall_vars e_valid + -- Assign the meta variable + mvarId.assign e_valid + else + -- Nothing to do + pure () + pure th -- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do @@ -852,7 +925,8 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp ^^^^^^^^^^^^^^^^^^^^ this is the original match expression, with the the difference that the scrutinee(s) is a variable - ``` -/ + ``` + -/ let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees @@ -905,7 +979,7 @@ partial def proveSingleBodyIsValid trace[Diverge.def.valid] "body: {body}" lambdaTelescope body fun xs body => do trace[Diverge.def.valid] "xs: {xs}" - assert! xs.size = 3 + if xs.size ≠ 3 then throwError "Invalid number of lambdas: {xs} (expected 3)" let kk_var := xs.get! 0 let t_var := xs.get! 1 let x_var := xs.get! 2 @@ -1349,6 +1423,7 @@ elab_rules : command Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) namespace Tests + /- Some examples of partial functions -/ section HigherOrder open Ex8 @@ -1357,9 +1432,6 @@ namespace Tests | leaf (x : a) | node (tl : List (Tree a)) - set_option trace.Diverge.def true - -- set_option trace.Diverge.def.genBody true - set_option trace.Diverge.def.valid true divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := match t with | .leaf x => .ret (.leaf x) @@ -1368,7 +1440,20 @@ namespace Tests let tl ← map id tl .ret (.node tl) - set_option trace.Diverge.def false + #check id.unfold + + /-set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def incr (t : Tree Nat) : Result (Tree Nat) := + match t with + | .leaf x => .ret (.leaf (x + 1)) + | .node tl => + do + let tl ← map incr tl + .ret (.node tl) + + set_option trace.Diverge.def false-/ end HigherOrder diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 95b2c38b..2366e800 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -371,19 +371,19 @@ def splitConjTarget : TacticM Unit := do -- Destruct an equaliy and return the two sides def destEq (e : Expr) : MetaM (Expr × Expr) := do - e.withApp fun f args => + e.consumeMData.withApp fun f args => if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2) else throwError "Not an equality: {e}" -- Return the set of FVarIds in the expression partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do - e.withApp fun body args => do + e.consumeMData.withApp fun body args => do let hs := if body.isFVar then hs.insert body.fvarId! else hs args.foldlM (fun hs arg => getFVarIds arg hs) hs -- Return the set of MVarIds in the expression partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do - e.withApp fun body args => do + e.consumeMData.withApp fun body args => do let hs := if body.isMVar then hs.insert body.mvarId! else hs args.foldlM (fun hs arg => getMVarIds arg hs) hs -- cgit v1.2.3