diff options
author | Son Ho | 2023-07-03 18:02:52 +0200 |
---|---|---|
committer | Son Ho | 2023-07-03 18:02:52 +0200 |
commit | 9214484c471ad931924865855687f9a2ffe255dd (patch) | |
tree | 2569a0c257945b418bb5a52e28f8b8ae368f1a8a /backends/lean/Base/Diverge | |
parent | 7ceab6a725e5bd17c05bfd381753e453b15afaf7 (diff) |
Automate the proofs of the unfolding theorems for Diverge
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 107 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/ElabBase.lean | 1 |
2 files changed, 89 insertions, 19 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 063480a2..91c51a31 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,8 +16,9 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true -set_option trace.Diverge.def.valid true +-- set_option trace.Diverge.def.valid true -- set_option trace.Diverge.def.sigmas true +set_option trace.Diverge.def.unfold true /- The following was copied from the `wfRecursion` function. -/ @@ -390,9 +391,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- Introduce a local declaration for the let-binding withLetDecl dName dTy dValue fun decl => do let isValid ← proveExprIsValid k_var kk_var body - -- Add the let-binding around (rem.: the let-binding should be - -- *inside* the `is_valid_p`, not outside, but because it reduces - -- in the end it doesn't matter) + -- Add the let-binding around. + -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, + -- but because it reduces in the end it doesn't matter. More precisely: + -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. mkLetFVars #[decl] isValid | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => @@ -692,9 +694,9 @@ def proveMutRecIsValid -- Generate the final definions by using the mutual body and the fixed point operator. def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : - TermElabM Unit := do + TermElabM (Array Name) := do let grSize := preDefs.size - let _ ← preDefs.mapIdxM fun idx preDef => do + let defs ← preDefs.mapIdxM fun idx preDef => do lambdaLetTelescope preDef.value fun xs _ => do -- Create the index let idx ← mkFinVal grSize idx.val @@ -715,7 +717,58 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : all := [name] } addDecl decl - pure () + pure name + pure defs + +-- Prove the equations that we will use as unfolding theorems +partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition) + (decls : Array Name) : MetaM Unit := do + let grSize := preDefs.size + let proveIdx (i : Nat) : MetaM Unit := do + let preDef := preDefs.get! i + let defName := decls.get! i + -- Retrieve the arguments + lambdaLetTelescope preDef.value fun xs body => do + trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" + trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" + -- The theorem statement + let thmTy ← do + -- The equation: the declaration gives the lhs, the pre-def gives the rhs + let lhs ← mkAppOptM defName (xs.map some) + let rhs := body + let eq ← mkAppM ``Eq #[lhs, rhs] + mkForallFVars xs eq + trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}" + -- The proof + -- Use the fixed-point equation + let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm] + -- Add the index + let idx ← mkFinVal grSize i + let proof ← mkAppM ``congr_fun #[proof, idx] + -- Add the input argument + let arg ← mkSigmas xs.toList + let proof ← mkAppM ``congr_fun #[proof, arg] + -- Abstract the arguments away + let proof ← mkLambdaFVars xs proof + trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" + -- Declare the theorem + let name := preDef.declName ++ "unfold" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + trace[Diverge.def.unfold] "proveUnfoldingThms: added thm: {name}:\n{thmTy}" + let rec prove (i : Nat) : MetaM Unit := do + if i = preDefs.size then pure () + else do + proveIdx i + prove (i + 1) + -- + prove 0 def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) @@ -817,12 +870,12 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions - let defs ← mkDeclareFixDefs mutRecBody preDefs + let decls ← mkDeclareFixDefs mutRecBody preDefs - -- Prove the unfolding equations - -- TODO + -- Prove the unfolding theorems + proveUnfoldingThms isValidThm preDefs decls - -- Process the definitions + -- Process the definitions - TODO addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -942,10 +995,23 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := if i = 0 then return x else return (← list_nth ls (i - 1)) -#print list_nth.in_out_ty -#check list_nth.sbody -#check list_nth.mut_rec_body -#print list_nth +example {a: Type} (ls : List a) : + ∀ (i : Int), + 0 ≤ i → i < ls.length → + ∃ x, list_nth ls i = .ret x := by + induction ls + . intro i hpos h; simp at h; linarith + . rename_i hd tl ih + intro i hpos h + rw [list_nth.unfold]; simp + split <;> simp [*] + . tauto + . -- TODO: we shouldn't have to do that + have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all + simp at h + have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + simp [ih] + tauto mutual divergent def is_even (i : Int) : Result Bool := @@ -955,10 +1021,8 @@ mutual if i = 0 then return false else return (← is_even (i - 1)) end -example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by - induction i - unfold is_even - sorry +#print is_even.unfold +#print is_odd.unfold mutual divergent def foo (i : Int) : Result Nat := @@ -968,6 +1032,9 @@ mutual if i > 20 then foo (i / 20) else .ret 42 end +#print foo.unfold +#print bar.unfold + -- Testing dependent branching and let-bindings -- TODO: why the linter warning? divergent def is_non_zero (i : Int) : Result Bool := @@ -976,4 +1043,6 @@ divergent def is_non_zero (i : Int) : Result Bool := let b := true return b +#print is_non_zero.unfold + end Diverge diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 281dbd6c..fd95291e 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -9,6 +9,7 @@ initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas initialize registerTraceClass `Diverge.def.genBody initialize registerTraceClass `Diverge.def.valid +initialize registerTraceClass `Diverge.def.unfold -- TODO: move -- TODO: small helper |