summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean107
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean1
2 files changed, 89 insertions, 19 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 063480a2..91c51a31 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -16,8 +16,9 @@ syntax (name := divergentDef)
open Lean Elab Term Meta Primitives Lean.Meta
set_option trace.Diverge.def true
-set_option trace.Diverge.def.valid true
+-- set_option trace.Diverge.def.valid true
-- set_option trace.Diverge.def.sigmas true
+set_option trace.Diverge.def.unfold true
/- The following was copied from the `wfRecursion` function. -/
@@ -390,9 +391,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
-- Introduce a local declaration for the let-binding
withLetDecl dName dTy dValue fun decl => do
let isValid ← proveExprIsValid k_var kk_var body
- -- Add the let-binding around (rem.: the let-binding should be
- -- *inside* the `is_valid_p`, not outside, but because it reduces
- -- in the end it doesn't matter)
+ -- Add the let-binding around.
+ -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside,
+ -- but because it reduces in the end it doesn't matter. More precisely:
+ -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression.
mkLetFVars #[decl] isValid
| .mdata _ b => proveExprIsValid k_var kk_var b
| .proj _ _ _ =>
@@ -692,9 +694,9 @@ def proveMutRecIsValid
-- Generate the final definions by using the mutual body and the fixed point operator.
def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
- TermElabM Unit := do
+ TermElabM (Array Name) := do
let grSize := preDefs.size
- let _ ← preDefs.mapIdxM fun idx preDef => do
+ let defs ← preDefs.mapIdxM fun idx preDef => do
lambdaLetTelescope preDef.value fun xs _ => do
-- Create the index
let idx ← mkFinVal grSize idx.val
@@ -715,7 +717,58 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
all := [name]
}
addDecl decl
- pure ()
+ pure name
+ pure defs
+
+-- Prove the equations that we will use as unfolding theorems
+partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition)
+ (decls : Array Name) : MetaM Unit := do
+ let grSize := preDefs.size
+ let proveIdx (i : Nat) : MetaM Unit := do
+ let preDef := preDefs.get! i
+ let defName := decls.get! i
+ -- Retrieve the arguments
+ lambdaLetTelescope preDef.value fun xs body => do
+ trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}"
+ trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}"
+ -- The theorem statement
+ let thmTy ← do
+ -- The equation: the declaration gives the lhs, the pre-def gives the rhs
+ let lhs ← mkAppOptM defName (xs.map some)
+ let rhs := body
+ let eq ← mkAppM ``Eq #[lhs, rhs]
+ mkForallFVars xs eq
+ trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}"
+ -- The proof
+ -- Use the fixed-point equation
+ let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm]
+ -- Add the index
+ let idx ← mkFinVal grSize i
+ let proof ← mkAppM ``congr_fun #[proof, idx]
+ -- Add the input argument
+ let arg ← mkSigmas xs.toList
+ let proof ← mkAppM ``congr_fun #[proof, arg]
+ -- Abstract the arguments away
+ let proof ← mkLambdaFVars xs proof
+ trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}"
+ -- Declare the theorem
+ let name := preDef.declName ++ "unfold"
+ let decl := Declaration.thmDecl {
+ name
+ levelParams := preDef.levelParams
+ type := thmTy
+ value := proof
+ all := [name]
+ }
+ addDecl decl
+ trace[Diverge.def.unfold] "proveUnfoldingThms: added thm: {name}:\n{thmTy}"
+ let rec prove (i : Nat) : MetaM Unit := do
+ if i = preDefs.size then pure ()
+ else do
+ proveIdx i
+ prove (i + 1)
+ --
+ prove 0
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
@@ -817,12 +870,12 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies
-- Generate the final definitions
- let defs ← mkDeclareFixDefs mutRecBody preDefs
+ let decls ← mkDeclareFixDefs mutRecBody preDefs
- -- Prove the unfolding equations
- -- TODO
+ -- Prove the unfolding theorems
+ proveUnfoldingThms isValidThm preDefs decls
- -- Process the definitions
+ -- Process the definitions - TODO
addAndCompilePartialRec preDefs
-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main
@@ -942,10 +995,23 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
if i = 0 then return x
else return (← list_nth ls (i - 1))
-#print list_nth.in_out_ty
-#check list_nth.sbody
-#check list_nth.mut_rec_body
-#print list_nth
+example {a: Type} (ls : List a) :
+ ∀ (i : Int),
+ 0 ≤ i → i < ls.length →
+ ∃ x, list_nth ls i = .ret x := by
+ induction ls
+ . intro i hpos h; simp at h; linarith
+ . rename_i hd tl ih
+ intro i hpos h
+ rw [list_nth.unfold]; simp
+ split <;> simp [*]
+ . tauto
+ . -- TODO: we shouldn't have to do that
+ have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all
+ simp at h
+ have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith)
+ simp [ih]
+ tauto
mutual
divergent def is_even (i : Int) : Result Bool :=
@@ -955,10 +1021,8 @@ mutual
if i = 0 then return false else return (← is_even (i - 1))
end
-example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by
- induction i
- unfold is_even
- sorry
+#print is_even.unfold
+#print is_odd.unfold
mutual
divergent def foo (i : Int) : Result Nat :=
@@ -968,6 +1032,9 @@ mutual
if i > 20 then foo (i / 20) else .ret 42
end
+#print foo.unfold
+#print bar.unfold
+
-- Testing dependent branching and let-bindings
-- TODO: why the linter warning?
divergent def is_non_zero (i : Int) : Result Bool :=
@@ -976,4 +1043,6 @@ divergent def is_non_zero (i : Int) : Result Bool :=
let b := true
return b
+#print is_non_zero.unfold
+
end Diverge
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index 281dbd6c..fd95291e 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -9,6 +9,7 @@ initialize registerTraceClass `Diverge.def
initialize registerTraceClass `Diverge.def.sigmas
initialize registerTraceClass `Diverge.def.genBody
initialize registerTraceClass `Diverge.def.valid
+initialize registerTraceClass `Diverge.def.unfold
-- TODO: move
-- TODO: small helper