From dba35a41e66019e586502f563ce7c629356fb2d7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 12 Dec 2023 17:28:34 +0100 Subject: Make progress on supporting higher-order divergent functions --- backends/lean/Base/Diverge/Base.lean | 22 +++--- backends/lean/Base/Diverge/Elab.lean | 126 +++++++++++++++++++++++------------ 2 files changed, 95 insertions(+), 53 deletions(-) (limited to 'backends/lean/Base/Diverge') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index bdc3ed04..9458c926 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -1467,26 +1467,28 @@ namespace Ex8 let tl ← map f tl .ret (hd :: tl) - /- The validity theorem for `map`, generic in `f` -/ + /- The validity theorems for `map`, generic in `f` -/ + + -- This is not the most general lemma, but we keep it to test the `divergence` encoding on a simple case @[divspec] - theorem map_is_valid + theorem map_is_valid_simple (i : id) (t : ty i) - {{f : (a i t → Result (b i t)) → (a i t) → Result c}} - (Hfvalid : ∀ k x, is_valid_p k (λ k => f (k i t) x)) (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) (ls : List (a i t)) : - is_valid_p k (λ k => map (f (k i t)) ls) := by + is_valid_p k (λ k => map (k i t) ls) := by induction ls <;> simp [map] apply is_valid_p_bind <;> try simp_all intros apply is_valid_p_bind <;> try simp_all @[divspec] - theorem map_is_valid' - (i : id) (t : ty i) + theorem map_is_valid + (d : Type y) + {{f : ((i:id) → (t : ty i) → a i t → Result (b i t)) → d → Result c}} (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) - (ls : List (a i t)) : - is_valid_p k (λ k => map (k i t) ls) := by + (Hfvalid : ∀ x1, is_valid_p k (fun kk1 => f kk1 x1)) + (ls : List d) : + is_valid_p k (λ k => map (f k) ls) := by induction ls <;> simp [map] apply is_valid_p_bind <;> try simp_all intros @@ -1532,7 +1534,7 @@ namespace Ex9 apply is_valid_p_bind <;> try simp [*] -- We have to show that `map k tl` is valid -- Remark: `map_is_valid` doesn't work here, we need the specialized version - apply map_is_valid' + apply map_is_valid_simple def body (k : (i : Fin 1) → (t : ty i) → (x : input_ty i t) → Result (output_ty i t)) (i: Fin 1) : (t : ty i) → (x : input_ty i t) → Result (output_ty i t) := get_fun bodies i k diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index e555b3e3..08d21e42 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -541,7 +541,6 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant let body := Lean.mkConst name (levelParams.map .param) - -- let body ← mkAppM' body #[kk_var] trace[Diverge.def] "individual body (after decl): {body}" pure body @@ -665,7 +664,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do proveAppIsValid k_var kk_var e f args partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr): MetaM Expr := do - trace[Diverge.def.valid] "proveAppIsValid: {f} {args}" + trace[Diverge.def.valid] "proveAppIsValid: {e}\nDecomposed: {f} {args}" /- There are several cases: first, check if this is a match/if Check if the expression is a (dependent) if then else. We treat the if then else expressions differently from the other matches, @@ -821,7 +820,8 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : - if no: this is simple - if yes: we have to lookup theorems in div spec database and continue -/ trace[Diverge.def.valid] "regular app: {e}, f: {f}, args: {args}" - let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + let argsFVars ← args.mapM getFVarIds + let allArgsFVars := argsFVars.foldl (fun hs fvars => hs.insertMany fvars) HashSet.empty trace[Diverge.def.valid] "allArgsFVars: {allArgsFVars.toList.map mkFVar}" if ¬ allArgsFVars.contains kk_var.fvarId! then do -- Simple case @@ -837,7 +837,6 @@ partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr) (thms : List Name) : MetaM Expr := do - trace[Diverge.def.valid] "thms: {thms}" match thms with | [] => throwError "Could not prove that the following expression is valid: {e}" | thName :: thms => @@ -849,58 +848,67 @@ partial def proveAppIsValidApplyThms (k_var kk_var : Expr) (e : Expr) thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar)) let ulMap : HashMap Name Level := HashMap.ofList ul let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x) + trace[Diverge.def.valid] "Trying with theorem {thName}: {thTy}" -- Introduce meta variables for the universally quantified variables - let (mvars, _binders, thTy) ← forallMetaTelescope thTy - -- thTy should now be of the shape: `is_valid_p k (λ kk => ...)` - /-thTy.consumeMData.withApp fun _ args => do - if args.size ≠ 7 then throwError "Invalid number of arguments (expected 7): {thTy}" - let thTermToMatch := args.get! 6 -/ - let thTermToMatch := thTy + let (mvars, _binders, thTyBody) ← forallMetaTelescope thTy + let thTermToMatch := thTyBody trace[Diverge.def.valid] "thTermToMatch: {thTermToMatch}" -- Create the term: `is_valid_p k (λ kk => e)` let termToMatch ← mkLambdaFVars #[kk_var] e let termToMatch ← mkAppM ``FixII.is_valid_p #[k_var, termToMatch] trace[Diverge.def.valid] "termToMatch: {termToMatch}" -- Attempt to match - let ok ← isDefEq thTermToMatch termToMatch + trace[Diverge.def.valid] "Matching terms:\n\n{termToMatch}\n\n{thTermToMatch}" + let ok ← isDefEq termToMatch thTermToMatch if ¬ ok then -- Failure: attempt with the other theorems proveAppIsValidApplyThms k_var kk_var e f args thms else do - -- Success: continue with this theorem - -- Instantiate the meta variables (some of them will not be instantiated: - -- they are new subgoals) + /- Success: continue with this theorem + + Instantiate the meta variables (some of them will not be instantiated: + they are new subgoals) + -/ let mvars ← mvars.mapM instantiateMVars let th ← mkAppOptM thName (Array.map some mvars) - trace[Diverge.def.valid] "Instantiated theorm: {th}\n{← inferType th}" - -- Filter the meta variables between the instantiated ones - for mvar in mvars do - if mvar.isMVar then do - -- Prove the subgoal (i.e., the precondition of the theorem) - let mvarId := mvar.mvarId! - let mvarDecl ← mvarId.getDecl - -- Dive in the type - forallTelescope mvarDecl.type fun forall_vars mvar_e => do - -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)` - -- We need to retrieve the new `k` variable, and dive into the - -- `λ kk => ...` - mvar_e.consumeMData.withApp fun is_valid args => do - if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 2 then - throwError "Invalid precondition: {mvar_e}" - else do - let k_var := args.get! 0 - let e_lam := args.get! 1 - lambdaTelescope e_lam.consumeMData fun lvars e => do - if lvars.size ≠ 1 then throwError "Invalid number of lambdas (expected 1): {e_lam}" - let kk_var := lvars.get! 0 - -- Continue - let e_valid ← proveExprIsValid k_var kk_var e - let e_valid ← mkForallFVars forall_vars e_valid - -- Assign the meta variable - mvarId.assign e_valid - else - -- Nothing to do - pure () + trace[Diverge.def.valid] "Instantiated theorem: {th}\n{← inferType th}" + -- Filter the instantiated meta variables + let mvars := mvars.filter (fun v => v.isMVar) + let mvars := mvars.map (fun v => v.mvarId!) + trace[Diverge.def.valid] "Remaining subgoals: {mvars}" + for mvarId in mvars do + -- Prove the subgoal (i.e., the precondition of the theorem) + let mvarDecl ← mvarId.getDecl + let declType ← instantiateMVars mvarDecl.type + -- Reduce the subgoal before diving in, if necessary + trace[Diverge.def.valid] "Subgoal: {declType}" + -- Dive in the type + forallTelescope declType fun forall_vars mvar_e => do + trace[Diverge.def.valid] "forall_vars: {forall_vars}" + -- `mvar_e` should have the shape `is_valid_p k (λ kk => ...)` + -- We need to retrieve the new `k` variable, and dive into the + -- `λ kk => ...` + mvar_e.consumeMData.withApp fun is_valid args => do + if is_valid.constName? ≠ ``FixII.is_valid_p ∨ args.size ≠ 7 then + throwError "Invalid precondition: {mvar_e}" + else do + let k_var := args.get! 5 + let e_lam := args.get! 6 + trace[Diverge.def.valid] "k_var: {k_var}\ne_lam: {e_lam}" + -- The outer lambda should be for the kk_var + lambdaOne e_lam.consumeMData fun kk_var e => do + -- Continue + trace[Diverge.def.valid] "kk_var: {kk_var}\ne: {e}" + -- We sometimes need to reduce the term + let e ← whnf e + trace[Diverge.def.valid] "e (after reduction): {e}" + let e_valid ← proveExprIsValid k_var kk_var e + trace[Diverge.def.valid] "e_valid (for e): {e_valid}" + let e_valid ← mkLambdaFVars forall_vars e_valid + trace[Diverge.def.valid] "e_valid (with foralls): {e_valid}" + let _ ← inferType e_valid -- Sanity check + -- Assign the meta variable + mvarId.assign e_valid pure th -- Prove that a match expression is valid. @@ -1442,6 +1450,38 @@ namespace Tests #check id.unfold + -- set_option pp.explicit true + -- set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def id1 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => id1 x) tl + .ret (.node tl) + + #check id1.unfold + + /-set_option trace.Diverge.def false + + -- set_option pp.explicit true + -- set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def id2 {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map (fun x => do let _ ← id2 x; id2 x) tl + .ret (.node tl) + + #check id2.unfold + + set_option trace.Diverge.def false -/ + /-set_option trace.Diverge.def true -- set_option trace.Diverge.def.genBody true set_option trace.Diverge.def.valid true -- cgit v1.2.3