summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2023-07-03 16:24:44 +0200
committerSon Ho2023-07-03 16:24:44 +0200
commit37e5d5501e024869037bf0ea1559229a8be62da7 (patch)
treef7f48b6cddd0a2c03a07a24b43bad0df675c2d54 /backends/lean
parent1c9331ce92b68b9a83c601212149a6c24591708f (diff)
Generate the proofs of validity in Elab.lean
Diffstat (limited to 'backends/lean')
-rw-r--r--backends/lean/Base/Diverge/Base.lean76
-rw-r--r--backends/lean/Base/Diverge/Elab.lean403
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean1
3 files changed, 446 insertions, 34 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index aa0539ba..89365d25 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -434,6 +434,23 @@ namespace Fix
is_valid_p k (λ k => k x) := by
simp_all [is_valid_p, is_mono_p_rec, is_cont_p_rec]
+ theorem is_valid_p_ite
+ (k : ((x:a) → Result (b x)) → (x:a) → Result (b x))
+ (cond : Prop) [h : Decidable cond]
+ {e1 e2 : ((x:a) → Result (b x)) → Result c}
+ (he1: is_valid_p k e1) (he2 : is_valid_p k e2) :
+ is_valid_p k (ite cond e1 e2) := by
+ split <;> assumption
+
+ theorem is_valid_p_dite
+ (k : ((x:a) → Result (b x)) → (x:a) → Result (b x))
+ (cond : Prop) [h : Decidable cond]
+ {e1 : cond → ((x:a) → Result (b x)) → Result c}
+ {e2 : Not cond → ((x:a) → Result (b x)) → Result c}
+ (he1: ∀ x, is_valid_p k (e1 x)) (he2 : ∀ x, is_valid_p k (e2 x)) :
+ is_valid_p k (dite cond e1 e2) := by
+ split <;> simp [*]
+
-- Lean is good at unification: we can write a very general version
-- (in particular, it will manage to figure out `g` and `h` when we
-- apply the lemma)
@@ -680,6 +697,24 @@ namespace FixI
is_valid_p k (λ k => k i x) := by
simp [is_valid_p, k_to_gen, e_to_gen, kk_to_gen, kk_of_gen]
+ theorem is_valid_p_ite
+ (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x))
+ (cond : Prop) [h : Decidable cond]
+ {e1 e2 : ((i:id) → (x:a i) → Result (b i x)) → Result c}
+ (he1: is_valid_p k e1) (he2 : is_valid_p k e2) :
+ is_valid_p k (λ k => ite cond (e1 k) (e2 k)) := by
+ split <;> assumption
+
+ theorem is_valid_p_dite
+ (k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x))
+ (cond : Prop) [h : Decidable cond]
+ {e1 : ((i:id) → (x:a i) → Result (b i x)) → cond → Result c}
+ {e2 : ((i:id) → (x:a i) → Result (b i x)) → Not cond → Result c}
+ (he1: ∀ x, is_valid_p k (λ k => e1 k x))
+ (he2 : ∀ x, is_valid_p k (λ k => e2 k x)) :
+ is_valid_p k (λ k => dite cond (e1 k) (e2 k)) := by
+ split <;> simp [*]
+
theorem is_valid_p_bind
{{k : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}}
{{g : ((i:id) → (x:a i) → Result (b i x)) → Result c}}
@@ -699,6 +734,9 @@ namespace FixI
| .Nil => True
| .Cons f fl => (∀ x, FixI.is_valid_p k (λ k => f k x)) ∧ fl.is_valid_p k
+ theorem Funs.is_valid_p_Nil (k : k_ty id a b) :
+ Funs.is_valid_p k Funs.Nil := by simp [Funs.is_valid_p]
+
def Funs.is_valid_p_is_valid_p_aux
{k : k_ty id a b}
{tys : List in_out_ty}
@@ -1116,7 +1154,7 @@ namespace Ex6
def body (k : (i : Fin 1) → (x : input_ty i) → Result (output_ty i x)) (i: Fin 1) :
(x : input_ty i) → Result (output_ty i x) := get_fun bodies i k
- theorem list_nth_body_is_valid: is_valid body := by
+ theorem body_is_valid: is_valid body := by
-- Split the proof into proofs of validity of the individual bodies
rw [is_valid]
simp only [body]
@@ -1131,6 +1169,20 @@ namespace Ex6
split <;> simp
split <;> simp
+ -- Writing the proof terms explicitly
+ theorem list_nth_body_is_valid' (k : k_ty (Fin 1) input_ty output_ty)
+ (x : (a : Type u) × List a × Int) : is_valid_p k (fun k => list_nth_body k x) :=
+ let ⟨ a, ls, i ⟩ := x
+ match ls with
+ | [] => is_valid_p_same k (.fail .panic)
+ | hd :: tl =>
+ is_valid_p_ite k (Eq i 0) (is_valid_p_same k (.ret hd)) (is_valid_p_rec k 0 ⟨a, tl, i-1⟩)
+
+ theorem body_is_valid' : is_valid body :=
+ fun k =>
+ Funs.is_valid_p_is_valid_p tys k bodies
+ (And.intro (list_nth_body_is_valid' k) (Funs.is_valid_p_Nil k))
+
noncomputable
def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
fix body 0 ⟨ a, ls , i ⟩
@@ -1144,8 +1196,28 @@ namespace Ex6
if i = 0 then .ret hd
else list_nth tl (i - 1)
:= by
- have Heq := is_valid_fix_fixed_eq list_nth_body_is_valid
+ have Heq := is_valid_fix_fixed_eq body_is_valid
simp [list_nth]
conv => lhs; rw [Heq]
+ -- Write the proof term explicitly: the generation of the proof term (without tactics)
+ -- is automatable, and the proof term is actually a lot simpler and smaller when we
+ -- don't use tactics.
+ theorem list_nth_eq'.{u} {a : Type u} (ls : List a) (i : Int) :
+ list_nth ls i =
+ match ls with
+ | [] => .fail .panic
+ | hd :: tl =>
+ if i = 0 then .ret hd
+ else list_nth tl (i - 1)
+ :=
+ -- Use the fixed-point equation
+ have Heq := is_valid_fix_fixed_eq body_is_valid.{u}
+ -- Add the index
+ have Heqi := congr_fun Heq 0
+ -- Add the input
+ have Heqix := congr_fun Heqi { fst := a, snd := (ls, i) }
+ -- Done
+ Heqix
+
end Ex6
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index f7de7518..cf40ea8f 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -16,6 +16,7 @@ syntax (name := divergentDef)
open Lean Elab Term Meta Primitives Lean.Meta
set_option trace.Diverge.def true
+set_option trace.Diverge.def.valid true
-- set_option trace.Diverge.def.sigmas true
/- The following was copied from the `wfRecursion` function. -/
@@ -196,7 +197,6 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) =>
@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f
-#check Array.map
-- Return the expression: `Fin n`
-- TODO: use more
def mkFin (n : Nat) : Expr :=
@@ -227,7 +227,7 @@ def mkFinValOld (n i : Nat) : MetaM Expr := do
We name the declarations: "[original_name].body".
We return the new declarations.
-/
-def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr)
+def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
(preDefs : Array PreDefinition) :
MetaM (Array Expr) := do
let grSize := preDefs.size
@@ -260,7 +260,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr)
let i ← mkFinVal grSize id
-- Put the arguments in one big dependent tuple
let args ← mkSigmas args.toList
- mkAppM' k_var #[i, args]
+ mkAppM' kk_var #[i, args]
else
-- Not a recursive call: do nothing
pure e
@@ -281,8 +281,8 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr)
let body ← mkSigmasMatch args.toList body 0
-- Add the declaration
- let value ← mkLambdaFVars #[k_var] body
- let name := preDef.declName.append "body"
+ let value ← mkLambdaFVars #[kk_var] body
+ let name := preDef.declName.append "sbody"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -297,16 +297,17 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr)
trace[Diverge.def] "individual body of {preDef.declName}: {body}"
-- Return the constant
let body := Lean.mkConst name (levelParams.map .param)
- -- let body ← mkAppM' body #[k_var]
+ -- let body ← mkAppM' body #[kk_var]
trace[Diverge.def] "individual body (after decl): {body}"
pure body
-- Generate a unique function body from the bodies of the mutually recursive group,
--- and add it as a declaration in the context
-def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name)
- (i_var k_var : Expr)
+-- and add it as a declaration in the context.
+-- We return the list of bodies (of type `Funs ...`) and the mutually recursive body.
+def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
+ (kk_var i_var : Expr)
(in_ty out_ty : Expr) (inOutTys : List (Expr × Expr))
- (bodies : Array Expr) : MetaM Expr := do
+ (bodies : Array Expr) : MetaM (Expr × Expr) := do
-- Generate the body
let grSize := bodies.size
let finTypeExpr := mkFin grSize
@@ -323,15 +324,15 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name)
let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType
let fl ← mkFuns inOutTys bl
mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl]
- | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length"
+ | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length"
let bodyFuns ← mkFuns inOutTys bodies.toList
-- Wrap in `get_fun`
- let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var]
+ let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var]
-- Add the index `i` and the continuation `k` as a variables
- let body ← mkLambdaFVars #[k_var, i_var] body
- trace[Diverge.def] "mkDeclareMutualBody: body: {body}"
+ let body ← mkLambdaFVars #[kk_var, i_var] body
+ trace[Diverge.def] "mkDeclareMutRecBody: body: {body}"
-- Add the declaration
- let name := grName.append "mutrec_body"
+ let name := grName.append "mut_rec_body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -344,10 +345,348 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name)
}
addDecl decl
-- Return the constant
- pure (Lean.mkConst name (levelParams.map .param))
+ pure (bodyFuns, Lean.mkConst name (levelParams.map .param))
+
+def isCasesExpr (e : Expr) : MetaM Bool := do
+ let e := e.getAppFn
+ if e.isConst then
+ return isCasesOnRecursor (← getEnv) e.constName
+ else return false
+
+structure MatchInfo where
+ matcherName : Name
+ matcherLevels : Array Level
+ params : Array Expr
+ motive : Expr
+ scruts : Array Expr
+ branchesNumParams : Array Nat
+ branches : Array Expr
+
+instance : ToMessageData MatchInfo where
+ -- This is not a very clean formatting, but we don't need more
+ toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}"
+
+-- An expression which doesn't use the continuation kk is valid
+def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do
+ trace[Diverge.def.valid] "proveNoKExprIsValid: {e}"
+ let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e]
+ trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}"
+ pure eIsValid
+
+mutual
+
+partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
+ trace[Diverge.def.valid] "proveValid: {e}"
+ match e with
+ | .bvar _
+ | .fvar _
+ | .mvar _
+ | .sort _
+ | .lit _
+ | .const _ _ => throwError "Unimplemented"
+ | .lam .. => throwError "Unimplemented"
+ | .forallE .. => throwError "Unreachable" -- Shouldn't get there
+ | .letE .. => throwError "TODO"
+ -- lambdaLetTelescope e fun xs b => mapVisitBinders xs do
+ -- mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false)
+ | .mdata _ b => proveExprIsValid k_var kk_var b
+ | .proj _ _ _ =>
+ -- The projection shouldn't use the continuation
+ proveNoKExprIsValid k_var e
+ | .app .. =>
+ e.withApp fun f args => do
+ -- There are several cases: first, check if this is a match/if
+ -- The expression is a (dependent) if then else
+ let isIte := e.isIte
+ if isIte || e.isDIte then do
+ e.withApp fun f args => do
+ trace[Diverge.def.valid] "ite/dite: {f}:\n{args}"
+ if args.size ≠ 5 then
+ throwError "Wrong number of parameters for {f}: {args}"
+ let cond := args.get! 1
+ let dec := args.get! 2
+ -- Prove that the branches are valid
+ let br0 := args.get! 3
+ let br1 := args.get! 4
+ let proveBranchValid (br : Expr) : MetaM Expr :=
+ if isIte then proveExprIsValid k_var kk_var br
+ else do
+ -- There is a lambda -- TODO: how do we remove exacly *one* lambda?
+ lambdaLetTelescope br fun xs br => do
+ let x := xs.get! 0
+ let xs := xs.extract 1 xs.size
+ let br ← mkLambdaFVars xs br
+ let brValid ← proveExprIsValid k_var kk_var br
+ mkLambdaFVars #[x] brValid
+ let br0Valid ← proveBranchValid br0
+ let br1Valid ← proveBranchValid br1
+ let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite
+ let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid]
+ trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}"
+ pure eIsValid
+ -- The expression is a match (this case is for when the elaborator
+ -- introduces auxiliary definitions to hide the match behind syntactic
+ -- sugar)
+ else if let some me := ← matchMatcherApp? e then do
+ trace[Diverge.def.valid]
+ "matcherApp:
+ - params: {me.params}
+ - motive: {me.motive}
+ - discrs: {me.discrs}
+ - altNumParams: {me.altNumParams}
+ - alts: {me.alts}
+ - remaining: {me.remaining}"
+ -- matchMatcherApp has already done the work for us
+ if me.remaining.size ≠ 0 then
+ throwError "MatcherApp: non empty remaining array: {me.remaining}"
+ let me : MatchInfo := {
+ matcherName := me.matcherName
+ matcherLevels := me.matcherLevels
+ params := me.params
+ motive := me.motive
+ scruts := me.discrs
+ branchesNumParams := me.altNumParams
+ branches := me.alts
+ }
+ proveMatchIsValid k_var kk_var me
+ -- The expression is a raw match (this case is for when the expression
+ -- is a direct call to the primitive `casesOn` function, without
+ -- syntactic sugar)
+ else if ← isCasesExpr f then do
+ trace[Diverge.def.valid] "rawMatch: {e}"
+ -- The casesOn definition is always of the following shape:
+ -- input parameters (implicit parameters), then motive (implicit),
+ -- scrutinee (explicit), branches (explicit).
+ let matcherName := f.constName!
+ let matcherLevels := f.constLevels!.toArray
+ -- Find the first explicit parameter: this is the scrutinee
+ forallTelescope (← inferType f) fun xs _ => do
+ let rec findFirstExplicit (i : Nat) : MetaM Nat := do
+ if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter"
+ else
+ let x := xs.get! i
+ let xFVarId := x.fvarId!
+ let localDecl ← xFVarId.getDecl
+ match localDecl.binderInfo with
+ | .default => pure i
+ | _ => findFirstExplicit (i + 1)
+ let scrutIdx ← findFirstExplicit 0
+ -- Split the arguments
+ let params := args.extract 0 (scrutIdx - 1)
+ let motive := args.get! (scrutIdx - 1)
+ let scrut := args.get! scrutIdx
+ let branches := args.extract (scrutIdx + 1) args.size
+ -- Compute the number of parameters for the branches: for this we use
+ -- the type of the uninstantiated casesOn constant
+ let branchesNumParams : Array Nat ← do
+ let env ← getEnv
+ let decl := env.constants.find! matcherName
+ let ty := decl.type
+ forallTelescope ty fun xs _ => do
+ let xs := xs.extract (scrutIdx + 1) xs.size
+ xs.mapM fun x => do
+ let xty ← inferType x
+ forallTelescope xty fun ys _ => do
+ pure ys.size
+ let me : MatchInfo := {
+ matcherName,
+ matcherLevels,
+ params,
+ motive,
+ scruts := #[scrut],
+ branchesNumParams,
+ branches,
+ }
+ proveMatchIsValid k_var kk_var me
+ -- Monadic let-binding
+ else if f.isConstOf ``Bind.bind then do
+ trace[Diverge.def.valid] "bind:\n{args}"
+ let x := args.get! 4
+ let y := args.get! 5
+ -- Prove that the subexpressions are valid
+ let xValid ← proveExprIsValid k_var kk_var x
+ trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}"
+ let yValid ← do
+ -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda?
+ lambdaLetTelescope y fun xs y => do
+ let x := xs.get! 0
+ let xs := xs.extract 1 xs.size
+ let y ← mkLambdaFVars xs y
+ trace[Diverge.def.valid] "bind: y: {y}"
+ let yValid ← proveExprIsValid k_var kk_var y
+ trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}"
+ trace[Diverge.def.valid] "bind: yValid: x: {x}"
+ let yValid ← mkLambdaFVars #[x] yValid
+ trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}"
+ pure yValid
+ -- Put everything together
+ trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}"
+ mkAppM ``FixI.is_valid_p_bind #[xValid, yValid]
+ -- Recursive call
+ else if f.isFVarOf kk_var.fvarId! then do
+ trace[Diverge.def.valid] "rec: args: \n{args}"
+ if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}"
+ let i_arg := args.get! 0
+ let x_arg := args.get! 1
+ let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg]
+ trace[Diverge.def.valid] "rec: result: \n{eIsValid}"
+ pure eIsValid
+ else do
+ -- Remaining case: normal application.
+ -- It shouldn't use the continuation
+ proveNoKExprIsValid k_var e
+
+partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do
+ trace[Diverge.def.valid] "proveMatchIsValid: {me}"
+ -- Prove the validity of the branch expressions
+ let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do
+ -- Go inside the lambdas - note that we have to be careful: some of the
+ -- binders might come from the match, and some of the binders might come
+ -- from the fact that the expression in the match is a lambda expression:
+ -- we use the branchesNumParams field for this reason
+ lambdaLetTelescope br fun xs br => do
+ let numParams := me.branchesNumParams.get! idx
+ let xs_beg := xs.extract 0 numParams
+ let xs_end := xs.extract numParams xs.size
+ let br ← mkLambdaFVars xs_end br
+ -- Prove that the branch expression is valid
+ let brValid ← proveExprIsValid k_var kk_var br
+ -- Reconstruct the lambda expression
+ mkLambdaFVars xs_beg brValid
+ trace[Diverge.def.valid] "branchesValid:\n{branchesValid}"
+ -- Put together: compute the motive.
+ -- It must be of the shape:
+ -- ```
+ -- λ scrut => is_valid_p k (λ k => match scrut with ...)
+ -- ```
+ let validMotive : Expr ← do
+ -- The motive is a function of the scrutinees (i.e., a lambda expression):
+ -- introduce binders for the scrutinees
+ let declInfos := me.scruts.mapIdx fun idx scrut =>
+ let name : Name := (.num (.str .anonymous "scrut") idx)
+ let ty := λ (_ : Array Expr) => inferType scrut
+ (name, ty)
+ withLocalDeclsD declInfos fun scrutVars => do
+ -- Create a match expression but where the scrutinees have been replaced
+ -- by variables
+ let params : Array (Option Expr) := me.params.map some
+ let motive : Option Expr := some me.motive
+ let scruts : Array (Option Expr) := scrutVars.map some
+ let branches : Array (Option Expr) := me.branches.map some
+ let args := params ++ [motive] ++ scruts ++ branches
+ let matchE ← mkAppOptM me.matcherName args
+ -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args)
+ -- Wrap in the `is_valid_p` predicate
+ let matchE ← mkLambdaFVars #[kk_var] matchE
+ let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE]
+ -- Abstract away the scrutinee variables
+ mkLambdaFVars scrutVars validMotive
+ trace[Diverge.def.valid] "valid motive: {validMotive}"
+ -- Put together
+ let valid ← do
+ let params : Array (Option Expr) := me.params.map (λ _ => none)
+ let motive := some validMotive
+ let scruts := me.scruts.map some
+ let branches := branchesValid.map some
+ let args := params ++ [motive] ++ scruts ++ branches
+ mkAppOptM me.matcherName args
+ trace[Diverge.def.valid] "proveMatchIsValid:\n{valid}:\n{← inferType valid}"
+ pure valid
+
+end
+
+-- Prove that a single body (in the mutually recursive group) is valid
+partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) :
+ MetaM Expr := do
+ trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}"
+ -- Lookup the definition (`bodyConst` is the definition of the body, we want
+ -- to retrieve the value itself to dive inside)
+ let name := bodyConst.constName!
+ let env ← getEnv
+ let body := (env.constants.find! name).value!
+ trace[Diverge.def.valid] "body: {body}"
+ lambdaLetTelescope body fun xs body => do
+ assert! xs.size = 2
+ let kk_var := xs.get! 0
+ let x_var := xs.get! 1
+ -- State the type of the theorem to prove
+ let thmTy ← mkAppM ``FixI.is_valid_p
+ #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])]
+ trace[Diverge.def.valid] "thmTy: {thmTy}"
+ -- Prove that the body is valid
+ let proof ← proveExprIsValid k_var kk_var body
+ let proof ← mkLambdaFVars #[k_var, x_var] proof
+ trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}"
+ -- The target type (we don't have to do this: this is simply a sanity check,
+ -- and this allows a nicer debugging output)
+ let thmTy ← do
+ let body ← mkAppM' bodyConst #[kk_var, x_var]
+ let body ← mkLambdaFVars #[kk_var] body
+ let ty ← mkAppM ``FixI.is_valid_p #[k_var, body]
+ mkForallFVars #[k_var, x_var] ty
+ trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}"
+ -- Save the theorem
+ let name := preDef.declName ++ "sbody_is_valid"
+ let decl := Declaration.thmDecl {
+ name
+ levelParams := preDef.levelParams
+ type := thmTy
+ value := proof
+ all := [name]
+ }
+ addDecl decl
+ trace[Diverge.def.valid] "proveSingleBodyIsValid: added thm: {name}"
+ -- Return the theorem
+ pure (Expr.const name (preDef.levelParams.map .param))
+
+partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr)
+ (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do
+ -- Create the big "and" expression, which groups the validity proof of the individual bodies
+ let rec mkValidConj (i : Nat) : MetaM Expr := do
+ if i = bodiesValid.size then
+ -- We reached the end
+ mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var]
+ else do
+ -- We haven't reached the end: introduce a conjunction
+ let valid := bodiesValid.get! i
+ let valid ← mkAppM' valid #[k_var]
+ mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)]
+ let andExpr ← mkValidConj 0
+ -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation
+ let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr]
+ mkLambdaFVars #[k_var] isValid
+
+-- Prove that the mut rec body is valid
+-- TODO: maybe this function should introduce k_var itself
+def proveMutRecIsValid
+ (grName : Name) (grLvlParams : List Name)
+ (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr)
+ (k_var : Expr) (preDefs : Array PreDefinition)
+ (bodies : Array Expr) : MetaM Expr := do
+ -- First prove that the individual bodies are valid
+ let bodiesValid ←
+ bodies.mapIdxM fun idx body => do
+ let preDef := preDefs.get! idx
+ proveSingleBodyIsValid k_var preDef body
+ -- Then prove that the mut rec body is valid
+ let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid
+ -- Save the theorem
+ let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst]
+ let name := grName ++ "mut_rec_body_is_valid"
+ let decl := Declaration.thmDecl {
+ name
+ levelParams := grLvlParams
+ type := thmTy
+ value := isValid
+ all := [name]
+ }
+ addDecl decl
+ trace[Diverge.def.valid] "proveFunsBodyIsValid: added thm: {name}:\n{thmTy}"
+ -- Return the theorem
+ pure (Expr.const name (grLvlParams.map .param))
-- Generate the final definions by using the mutual body and the fixed point operator.
-def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) :
+def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
TermElabM Unit := do
let grSize := preDefs.size
let _ ← preDefs.mapIdxM fun idx preDef => do
@@ -357,7 +696,7 @@ def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) :
-- Group the inputs into a dependent tuple
let input ← mkSigmas xs.toList
-- Apply the fixed point
- let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input]
+ let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input]
let fixedBody ← mkLambdaFVars xs fixedBody
-- Create the declaration
let name := preDef.declName
@@ -454,24 +793,26 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Introduce the continuation `k`
let in_ty ← mkLambdaFVars #[i_var] in_ty
let out_ty ← mkLambdaFVars #[i_var, input] out_ty
- let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] --
- trace[Diverge.def] "k_var_ty: {k_var_ty}"
- withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do
- trace[Diverge.def] "k_var: {k_var}"
+ let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty]
+ trace[Diverge.def] "kk_var_ty: {kk_var_ty}"
+ withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do
+ trace[Diverge.def] "kk_var: {kk_var}"
-- Replace the recursive calls in all the function bodies by calls to the
-- continuation `k` and and generate for those bodies declarations
- let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs
+ let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs
-- Generate the mutually recursive body
- let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies
- trace[Diverge.def] "mut rec body (after decl): {body}"
+ let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies
+ trace[Diverge.def] "mut rec body (after decl): {mutRecBody}"
-- Prove that the mut rec body satisfies the validity criteria required by
-- our fixed-point
- -- TODO
+ let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty]
+ withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do
+ let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies
-- Generate the final definitions
- let defs ← mkDeclareFixDefs body preDefs
+ let defs ← mkDeclareFixDefs mutRecBody preDefs
-- Prove the unfolding equations
-- TODO
@@ -496,13 +837,10 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
for preDefs in cliques do
trace[Diverge.elab] "{preDefs.map (·.declName)}"
try
- trace[Diverge.elab] "calling divRecursion"
withRef (preDefs[0]!.ref) do
divRecursion preDefs
- trace[Diverge.elab] "divRecursion succeeded"
catch ex =>
- -- If it failed, we
- trace[Diverge.elab] "divRecursion failed"
+ -- If it failed, we add the functions as partial functions
hasErrors := true
logException ex
let s ← saveState
@@ -600,7 +938,8 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
else return (← list_nth ls (i - 1))
#print list_nth.in_out_ty
-#check list_nth.body
+#check list_nth.sbody
+#check list_nth.mut_rec_body
#print list_nth
mutual
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index 82f79f94..281dbd6c 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -8,6 +8,7 @@ initialize registerTraceClass `Diverge.elab
initialize registerTraceClass `Diverge.def
initialize registerTraceClass `Diverge.def.sigmas
initialize registerTraceClass `Diverge.def.genBody
+initialize registerTraceClass `Diverge.def.valid
-- TODO: move
-- TODO: small helper