From 37e5d5501e024869037bf0ea1559229a8be62da7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 16:24:44 +0200 Subject: Generate the proofs of validity in Elab.lean --- backends/lean/Base/Diverge/Base.lean | 76 +++++- backends/lean/Base/Diverge/Elab.lean | 403 ++++++++++++++++++++++++++++--- backends/lean/Base/Diverge/ElabBase.lean | 1 + 3 files changed, 446 insertions(+), 34 deletions(-) diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index aa0539ba..89365d25 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -434,6 +434,23 @@ namespace Fix is_valid_p k (λ k => k x) := by simp_all [is_valid_p, is_mono_p_rec, is_cont_p_rec] + theorem is_valid_p_ite + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((x:a) → Result (b x)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (ite cond e1 e2) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((x:a) → Result (b x)) → (x:a) → Result (b x)) + (cond : Prop) [h : Decidable cond] + {e1 : cond → ((x:a) → Result (b x)) → Result c} + {e2 : Not cond → ((x:a) → Result (b x)) → Result c} + (he1: ∀ x, is_valid_p k (e1 x)) (he2 : ∀ x, is_valid_p k (e2 x)) : + is_valid_p k (dite cond e1 e2) := by + split <;> simp [*] + -- Lean is good at unification: we can write a very general version -- (in particular, it will manage to figure out `g` and `h` when we -- apply the lemma) @@ -680,6 +697,24 @@ namespace FixI is_valid_p k (λ k => k i x) := by simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen] + theorem is_valid_p_ite + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (cond : Prop) [h : Decidable cond] + {e1 e2 : ((i:id) → (x:a i) → Result (b i x)) → Result c} + (he1: is_valid_p k e1) (he2 : is_valid_p k e2) : + is_valid_p k (λ k => ite cond (e1 k) (e2 k)) := by + split <;> assumption + + theorem is_valid_p_dite + (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)) + (cond : Prop) [h : Decidable cond] + {e1 : ((i:id) → (x:a i) → Result (b i x)) → cond → Result c} + {e2 : ((i:id) → (x:a i) → Result (b i x)) → Not cond → Result c} + (he1: ∀ x, is_valid_p k (λ k => e1 k x)) + (he2 : ∀ x, is_valid_p k (λ k => e2 k x)) : + is_valid_p k (λ k => dite cond (e1 k) (e2 k)) := by + split <;> simp [*] + theorem is_valid_p_bind {{k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} {{g : ((i:id) → (x:a i) → Result (b i x)) → Result c}} @@ -699,6 +734,9 @@ namespace FixI | .Nil => True | .Cons f fl => (∀ x, FixI.is_valid_p k (λ k => f k x)) ∧ fl.is_valid_p k + theorem Funs.is_valid_p_Nil (k : k_ty id a b) : + Funs.is_valid_p k Funs.Nil := by simp [Funs.is_valid_p] + def Funs.is_valid_p_is_valid_p_aux {k : k_ty id a b} {tys : List in_out_ty} @@ -1116,7 +1154,7 @@ namespace Ex6 def body (k : (i : Fin 1) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 1) : (x : input_ty i) → Result (output_ty i x) := get_fun bodies i k - theorem list_nth_body_is_valid: is_valid body := by + theorem body_is_valid: is_valid body := by -- Split the proof into proofs of validity of the individual bodies rw [is_valid] simp only [body] @@ -1131,6 +1169,20 @@ namespace Ex6 split <;> simp split <;> simp + -- Writing the proof terms explicitly + theorem list_nth_body_is_valid' (k : k_ty (Fin 1) input_ty output_ty) + (x : (a : Type u) × List a × Int) : is_valid_p k (fun k => list_nth_body k x) := + let ⟨ a, ls, i ⟩ := x + match ls with + | [] => is_valid_p_same k (.fail .panic) + | hd :: tl => + is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 ⟨a, tl, i-1⟩) + + theorem body_is_valid' : is_valid body := + fun k => + Funs.is_valid_p_is_valid_p tys k bodies + (And.intro (list_nth_body_is_valid' k) (Funs.is_valid_p_Nil k)) + noncomputable def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := fix body 0 ⟨ a, ls , i ⟩ @@ -1144,8 +1196,28 @@ namespace Ex6 if i = 0 then .ret hd else list_nth tl (i - 1) := by - have Heq := is_valid_fix_fixed_eq list_nth_body_is_valid + have Heq := is_valid_fix_fixed_eq body_is_valid simp [list_nth] conv => lhs; rw [Heq] + -- Write the proof term explicitly: the generation of the proof term (without tactics) + -- is automatable, and the proof term is actually a lot simpler and smaller when we + -- don't use tactics. + theorem list_nth_eq'.{u} {a : Type u} (ls : List a) (i : Int) : + list_nth ls i = + match ls with + | [] => .fail .panic + | hd :: tl => + if i = 0 then .ret hd + else list_nth tl (i - 1) + := + -- Use the fixed-point equation + have Heq := is_valid_fix_fixed_eq body_is_valid.{u} + -- Add the index + have Heqi := congr_fun Heq 0 + -- Add the input + have Heqix := congr_fun Heqi { fst := a, snd := (ls, i) } + -- Done + Heqix + end Ex6 diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index f7de7518..cf40ea8f 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,6 +16,7 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true +set_option trace.Diverge.def.valid true -- set_option trace.Diverge.def.sigmas true /- The following was copied from the `wfRecursion` function. -/ @@ -196,7 +197,6 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => @[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f -#check Array.map -- Return the expression: `Fin n` -- TODO: use more def mkFin (n : Nat) : Expr := @@ -227,7 +227,7 @@ def mkFinValOld (n i : Nat) : MetaM Expr := do We name the declarations: "[original_name].body". We return the new declarations. -/ -def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) +def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size @@ -260,7 +260,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) let i ← mkFinVal grSize id -- Put the arguments in one big dependent tuple let args ← mkSigmas args.toList - mkAppM' k_var #[i, args] + mkAppM' kk_var #[i, args] else -- Not a recursive call: do nothing pure e @@ -281,8 +281,8 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) let body ← mkSigmasMatch args.toList body 0 -- Add the declaration - let value ← mkLambdaFVars #[k_var] body - let name := preDef.declName.append "body" + let value ← mkLambdaFVars #[kk_var] body + let name := preDef.declName.append "sbody" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -297,16 +297,17 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant let body := Lean.mkConst name (levelParams.map .param) - -- let body ← mkAppM' body #[k_var] + -- let body ← mkAppM' body #[kk_var] trace[Diverge.def] "individual body (after decl): {body}" pure body -- Generate a unique function body from the bodies of the mutually recursive group, --- and add it as a declaration in the context -def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) - (i_var k_var : Expr) +-- and add it as a declaration in the context. +-- We return the list of bodies (of type `Funs ...`) and the mutually recursive body. +def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) + (kk_var i_var : Expr) (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) - (bodies : Array Expr) : MetaM Expr := do + (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize @@ -323,15 +324,15 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType let fl ← mkFuns inOutTys bl mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] - | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length" + | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" let bodyFuns ← mkFuns inOutTys bodies.toList -- Wrap in `get_fun` - let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var] + let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables - let body ← mkLambdaFVars #[k_var, i_var] body - trace[Diverge.def] "mkDeclareMutualBody: body: {body}" + let body ← mkLambdaFVars #[kk_var, i_var] body + trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" -- Add the declaration - let name := grName.append "mutrec_body" + let name := grName.append "mut_rec_body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -344,10 +345,348 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) } addDecl decl -- Return the constant - pure (Lean.mkConst name (levelParams.map .param)) + pure (bodyFuns, Lean.mkConst name (levelParams.map .param)) + +def isCasesExpr (e : Expr) : MetaM Bool := do + let e := e.getAppFn + if e.isConst then + return isCasesOnRecursor (← getEnv) e.constName + else return false + +structure MatchInfo where + matcherName : Name + matcherLevels : Array Level + params : Array Expr + motive : Expr + scruts : Array Expr + branchesNumParams : Array Nat + branches : Array Expr + +instance : ToMessageData MatchInfo where + -- This is not a very clean formatting, but we don't need more + toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" + +-- An expression which doesn't use the continuation kk is valid +def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" + let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] + trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + +mutual + +partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveValid: {e}" + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => throwError "Unimplemented" + | .lam .. => throwError "Unimplemented" + | .forallE .. => throwError "Unreachable" -- Shouldn't get there + | .letE .. => throwError "TODO" + -- lambdaLetTelescope e fun xs b => mapVisitBinders xs do + -- mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .mdata _ b => proveExprIsValid k_var kk_var b + | .proj _ _ _ => + -- The projection shouldn't use the continuation + proveNoKExprIsValid k_var e + | .app .. => + e.withApp fun f args => do + -- There are several cases: first, check if this is a match/if + -- The expression is a (dependent) if then else + let isIte := e.isIte + if isIte || e.isDIte then do + e.withApp fun f args => do + trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" + if args.size ≠ 5 then + throwError "Wrong number of parameters for {f}: {args}" + let cond := args.get! 1 + let dec := args.get! 2 + -- Prove that the branches are valid + let br0 := args.get! 3 + let br1 := args.get! 4 + let proveBranchValid (br : Expr) : MetaM Expr := + if isIte then proveExprIsValid k_var kk_var br + else do + -- There is a lambda -- TODO: how do we remove exacly *one* lambda? + lambdaLetTelescope br fun xs br => do + let x := xs.get! 0 + let xs := xs.extract 1 xs.size + let br ← mkLambdaFVars xs br + let brValid ← proveExprIsValid k_var kk_var br + mkLambdaFVars #[x] brValid + let br0Valid ← proveBranchValid br0 + let br1Valid ← proveBranchValid br1 + let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite + let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] + trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + -- The expression is a match (this case is for when the elaborator + -- introduces auxiliary definitions to hide the match behind syntactic + -- sugar) + else if let some me := ← matchMatcherApp? e then do + trace[Diverge.def.valid] + "matcherApp: + - params: {me.params} + - motive: {me.motive} + - discrs: {me.discrs} + - altNumParams: {me.altNumParams} + - alts: {me.alts} + - remaining: {me.remaining}" + -- matchMatcherApp has already done the work for us + if me.remaining.size ≠ 0 then + throwError "MatcherApp: non empty remaining array: {me.remaining}" + let me : MatchInfo := { + matcherName := me.matcherName + matcherLevels := me.matcherLevels + params := me.params + motive := me.motive + scruts := me.discrs + branchesNumParams := me.altNumParams + branches := me.alts + } + proveMatchIsValid k_var kk_var me + -- The expression is a raw match (this case is for when the expression + -- is a direct call to the primitive `casesOn` function, without + -- syntactic sugar) + else if ← isCasesExpr f then do + trace[Diverge.def.valid] "rawMatch: {e}" + -- The casesOn definition is always of the following shape: + -- input parameters (implicit parameters), then motive (implicit), + -- scrutinee (explicit), branches (explicit). + let matcherName := f.constName! + let matcherLevels := f.constLevels!.toArray + -- Find the first explicit parameter: this is the scrutinee + forallTelescope (← inferType f) fun xs _ => do + let rec findFirstExplicit (i : Nat) : MetaM Nat := do + if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" + else + let x := xs.get! i + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + match localDecl.binderInfo with + | .default => pure i + | _ => findFirstExplicit (i + 1) + let scrutIdx ← findFirstExplicit 0 + -- Split the arguments + let params := args.extract 0 (scrutIdx - 1) + let motive := args.get! (scrutIdx - 1) + let scrut := args.get! scrutIdx + let branches := args.extract (scrutIdx + 1) args.size + -- Compute the number of parameters for the branches: for this we use + -- the type of the uninstantiated casesOn constant + let branchesNumParams : Array Nat ← do + let env ← getEnv + let decl := env.constants.find! matcherName + let ty := decl.type + forallTelescope ty fun xs _ => do + let xs := xs.extract (scrutIdx + 1) xs.size + xs.mapM fun x => do + let xty ← inferType x + forallTelescope xty fun ys _ => do + pure ys.size + let me : MatchInfo := { + matcherName, + matcherLevels, + params, + motive, + scruts := #[scrut], + branchesNumParams, + branches, + } + proveMatchIsValid k_var kk_var me + -- Monadic let-binding + else if f.isConstOf ``Bind.bind then do + trace[Diverge.def.valid] "bind:\n{args}" + let x := args.get! 4 + let y := args.get! 5 + -- Prove that the subexpressions are valid + let xValid ← proveExprIsValid k_var kk_var x + trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" + let yValid ← do + -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? + lambdaLetTelescope y fun xs y => do + let x := xs.get! 0 + let xs := xs.extract 1 xs.size + let y ← mkLambdaFVars xs y + trace[Diverge.def.valid] "bind: y: {y}" + let yValid ← proveExprIsValid k_var kk_var y + trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" + trace[Diverge.def.valid] "bind: yValid: x: {x}" + let yValid ← mkLambdaFVars #[x] yValid + trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" + pure yValid + -- Put everything together + trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" + mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] + -- Recursive call + else if f.isFVarOf kk_var.fvarId! then do + trace[Diverge.def.valid] "rec: args: \n{args}" + if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" + let i_arg := args.get! 0 + let x_arg := args.get! 1 + let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg] + trace[Diverge.def.valid] "rec: result: \n{eIsValid}" + pure eIsValid + else do + -- Remaining case: normal application. + -- It shouldn't use the continuation + proveNoKExprIsValid k_var e + +partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do + trace[Diverge.def.valid] "proveMatchIsValid: {me}" + -- Prove the validity of the branch expressions + let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do + -- Go inside the lambdas - note that we have to be careful: some of the + -- binders might come from the match, and some of the binders might come + -- from the fact that the expression in the match is a lambda expression: + -- we use the branchesNumParams field for this reason + lambdaLetTelescope br fun xs br => do + let numParams := me.branchesNumParams.get! idx + let xs_beg := xs.extract 0 numParams + let xs_end := xs.extract numParams xs.size + let br ← mkLambdaFVars xs_end br + -- Prove that the branch expression is valid + let brValid ← proveExprIsValid k_var kk_var br + -- Reconstruct the lambda expression + mkLambdaFVars xs_beg brValid + trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" + -- Put together: compute the motive. + -- It must be of the shape: + -- ``` + -- λ scrut => is_valid_p k (λ k => match scrut with ...) + -- ``` + let validMotive : Expr ← do + -- The motive is a function of the scrutinees (i.e., a lambda expression): + -- introduce binders for the scrutinees + let declInfos := me.scruts.mapIdx fun idx scrut => + let name : Name := (.num (.str .anonymous "scrut") idx) + let ty := λ (_ : Array Expr) => inferType scrut + (name, ty) + withLocalDeclsD declInfos fun scrutVars => do + -- Create a match expression but where the scrutinees have been replaced + -- by variables + let params : Array (Option Expr) := me.params.map some + let motive : Option Expr := some me.motive + let scruts : Array (Option Expr) := scrutVars.map some + let branches : Array (Option Expr) := me.branches.map some + let args := params ++ [motive] ++ scruts ++ branches + let matchE ← mkAppOptM me.matcherName args + -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args) + -- Wrap in the `is_valid_p` predicate + let matchE ← mkLambdaFVars #[kk_var] matchE + let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] + -- Abstract away the scrutinee variables + mkLambdaFVars scrutVars validMotive + trace[Diverge.def.valid] "valid motive: {validMotive}" + -- Put together + let valid ← do + let params : Array (Option Expr) := me.params.map (λ _ => none) + let motive := some validMotive + let scruts := me.scruts.map some + let branches := branchesValid.map some + let args := params ++ [motive] ++ scruts ++ branches + mkAppOptM me.matcherName args + trace[Diverge.def.valid] "proveMatchIsValid:\n{valid}:\n{← inferType valid}" + pure valid + +end + +-- Prove that a single body (in the mutually recursive group) is valid +partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : + MetaM Expr := do + trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}" + -- Lookup the definition (`bodyConst` is the definition of the body, we want + -- to retrieve the value itself to dive inside) + let name := bodyConst.constName! + let env ← getEnv + let body := (env.constants.find! name).value! + trace[Diverge.def.valid] "body: {body}" + lambdaLetTelescope body fun xs body => do + assert! xs.size = 2 + let kk_var := xs.get! 0 + let x_var := xs.get! 1 + -- State the type of the theorem to prove + let thmTy ← mkAppM ``FixI.is_valid_p + #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])] + trace[Diverge.def.valid] "thmTy: {thmTy}" + -- Prove that the body is valid + let proof ← proveExprIsValid k_var kk_var body + let proof ← mkLambdaFVars #[k_var, x_var] proof + trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" + -- The target type (we don't have to do this: this is simply a sanity check, + -- and this allows a nicer debugging output) + let thmTy ← do + let body ← mkAppM' bodyConst #[kk_var, x_var] + let body ← mkLambdaFVars #[kk_var] body + let ty ← mkAppM ``FixI.is_valid_p #[k_var, body] + mkForallFVars #[k_var, x_var] ty + trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" + -- Save the theorem + let name := preDef.declName ++ "sbody_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveSingleBodyIsValid: added thm: {name}" + -- Return the theorem + pure (Expr.const name (preDef.levelParams.map .param)) + +partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) + (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do + -- Create the big "and" expression, which groups the validity proof of the individual bodies + let rec mkValidConj (i : Nat) : MetaM Expr := do + if i = bodiesValid.size then + -- We reached the end + mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var] + else do + -- We haven't reached the end: introduce a conjunction + let valid := bodiesValid.get! i + let valid ← mkAppM' valid #[k_var] + mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)] + let andExpr ← mkValidConj 0 + -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation + let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] + mkLambdaFVars #[k_var] isValid + +-- Prove that the mut rec body is valid +-- TODO: maybe this function should introduce k_var itself +def proveMutRecIsValid + (grName : Name) (grLvlParams : List Name) + (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) + (k_var : Expr) (preDefs : Array PreDefinition) + (bodies : Array Expr) : MetaM Expr := do + -- First prove that the individual bodies are valid + let bodiesValid ← + bodies.mapIdxM fun idx body => do + let preDef := preDefs.get! idx + proveSingleBodyIsValid k_var preDef body + -- Then prove that the mut rec body is valid + let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid + -- Save the theorem + let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] + let name := grName ++ "mut_rec_body_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := grLvlParams + type := thmTy + value := isValid + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveFunsBodyIsValid: added thm: {name}:\n{thmTy}" + -- Return the theorem + pure (Expr.const name (grLvlParams.map .param)) -- Generate the final definions by using the mutual body and the fixed point operator. -def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM Unit := do let grSize := preDefs.size let _ ← preDefs.mapIdxM fun idx preDef => do @@ -357,7 +696,7 @@ def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : -- Group the inputs into a dependent tuple let input ← mkSigmas xs.toList -- Apply the fixed point - let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input] + let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody -- Create the declaration let name := preDef.declName @@ -454,24 +793,26 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Introduce the continuation `k` let in_ty ← mkLambdaFVars #[i_var] in_ty let out_ty ← mkLambdaFVars #[i_var, input] out_ty - let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] -- - trace[Diverge.def] "k_var_ty: {k_var_ty}" - withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do - trace[Diverge.def] "k_var: {k_var}" + let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] + trace[Diverge.def] "kk_var_ty: {kk_var_ty}" + withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do + trace[Diverge.def] "kk_var: {kk_var}" -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations - let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs -- Generate the mutually recursive body - let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies - trace[Diverge.def] "mut rec body (after decl): {body}" + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies + trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point - -- TODO + let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] + withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do + let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions - let defs ← mkDeclareFixDefs body preDefs + let defs ← mkDeclareFixDefs mutRecBody preDefs -- Prove the unfolding equations -- TODO @@ -496,13 +837,10 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC for preDefs in cliques do trace[Diverge.elab] "{preDefs.map (·.declName)}" try - trace[Diverge.elab] "calling divRecursion" withRef (preDefs[0]!.ref) do divRecursion preDefs - trace[Diverge.elab] "divRecursion succeeded" catch ex => - -- If it failed, we - trace[Diverge.elab] "divRecursion failed" + -- If it failed, we add the functions as partial functions hasErrors := true logException ex let s ← saveState @@ -600,7 +938,8 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := else return (← list_nth ls (i - 1)) #print list_nth.in_out_ty -#check list_nth.body +#check list_nth.sbody +#check list_nth.mut_rec_body #print list_nth mutual diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 82f79f94..281dbd6c 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -8,6 +8,7 @@ initialize registerTraceClass `Diverge.elab initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas initialize registerTraceClass `Diverge.def.genBody +initialize registerTraceClass `Diverge.def.valid -- TODO: move -- TODO: small helper -- cgit v1.2.3