summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge
diff options
context:
space:
mode:
authorSon Ho2023-12-11 17:00:01 +0100
committerSon Ho2023-12-11 17:00:01 +0100
commitc23a37617188a1bbf913b5c700522abc33bf39c9 (patch)
tree574c4d1ccde77f5952d5309621dac152452b2e94 /backends/lean/Base/Diverge
parentcb332ffb55425e6e6bc3b0ef8da7e646b2174fdf (diff)
Update Diverge/Elab.lean to use the more general FixII definitions
Diffstat (limited to 'backends/lean/Base/Diverge')
-rw-r--r--backends/lean/Base/Diverge/Base.lean9
-rw-r--r--backends/lean/Base/Diverge/Elab.lean571
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean2
3 files changed, 387 insertions, 195 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 9d986f4e..a7107c1e 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -581,7 +581,6 @@ namespace FixI
kk_ty id a b → kk_ty id a b
abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v)
- -- TODO: remove?
abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) :
in_out_ty :=
Sigma.mk in_ty out_ty
@@ -717,11 +716,8 @@ namespace FixI
end FixI
namespace FixII
- /- Indexed fixed-point: definitions with indexed types, convenient to use for mutually
- recursive definitions. We simply port the definitions and proofs from Fix to a more
- specific case.
-
- Here however, we group the types into a parameter distinct from the inputs.
+ /- Similar to FixI, but we split the input arguments between the type parameters
+ and the input values.
-/
open Primitives Fix
@@ -792,7 +788,6 @@ namespace FixII
abbrev in_out_ty : Type (imax (u + 1) (imax (v + 1) (w + 1))) :=
(ty : Type u) × (ty → Type v) × (ty → Type w)
- -- TODO: remove?
abbrev mk_in_out_ty (ty : Type u) (in_ty : ty → Type v) (out_ty : ty → Type w) :
in_out_ty :=
Sigma.mk ty (Prod.mk in_ty out_ty)
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 0088fd16..423a2514 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -24,8 +24,8 @@ open WF in
def mkProd (x y : Expr) : MetaM Expr :=
mkAppM ``Prod.mk #[x, y]
-def mkInOutTy (x y : Expr) : MetaM Expr :=
- mkAppM ``FixI.mk_in_out_ty #[x, y]
+def mkInOutTy (x y z : Expr) : MetaM Expr := do
+ mkAppM ``FixII.mk_in_out_ty #[x, y, z]
-- Return the `a` in `Return a`
def getResultTy (ty : Expr) : MetaM Expr :=
@@ -60,21 +60,83 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do
def mkSigmasType (xl : List Expr) : MetaM Expr :=
match xl with
| [] => do
- trace[Diverge.def.sigmas] "mkSigmasOfTypes: []"
- pure (Expr.const ``PUnit.unit [])
+ trace[Diverge.def.sigmas] "mkSigmasType: []"
+ pure (Expr.const ``PUnit [Level.succ .zero])
| [x] => do
- trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]"
+ trace[Diverge.def.sigmas] "mkSigmasType: [{x}]"
let ty ← Lean.Meta.inferType x
pure ty
| x :: xl => do
- trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]"
+ trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]"
let alpha ← Lean.Meta.inferType x
let sty ← mkSigmasType xl
- trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}"
+ trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}"
let beta ← mkLambdaFVars #[x] sty
- trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})"
+ trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})"
mkAppOptM ``Sigma #[some alpha, some beta]
+/- Generate a product type from a list of *variables* (this is similar to `mkSigmas`).
+
+ Example:
+ - xl = [(ls:List a), (i:Int)]
+
+ Generates:
+ `List a × Int`
+ -/
+def mkProdsType (xl : List Expr) : MetaM Expr :=
+ match xl with
+ | [] => do
+ trace[Diverge.def.prods] "mkProdsType: []"
+ pure (Expr.const ``PUnit [Level.succ .zero])
+ | [x] => do
+ trace[Diverge.def.prods] "mkProdsType: [{x}]"
+ let ty ← Lean.Meta.inferType x
+ pure ty
+ | x :: xl => do
+ trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]"
+ let ty ← Lean.Meta.inferType x
+ let xl_ty ← mkProdsType xl
+ mkAppM ``Prod #[ty, xl_ty]
+
+/- Split the input arguments between the types and the "regular" arguments.
+
+ We do something simple: we treat an input argument as an
+ input type iff it appears in the type of the following arguments.
+
+ Note that what really matters is that we find the arguments which appear
+ in the output type.
+
+ Also, we stop at the first input that we treat as an
+ input type.
+ -/
+def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × Array Expr) := do
+ -- Look for the first parameter which appears in the subsequent parameters
+ let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) :=
+ match in_tys with
+ | [] => do
+ let fvars ← getFVarIds (← Lean.Meta.inferType out_ty)
+ pure (fvars, [], [])
+ | ty :: in_tys => do
+ let (fvars, in_tys, in_args) ← splitAux in_tys
+ -- Have we already found where to split between type variables/regular
+ -- variables?
+ if ¬ in_tys.isEmpty then
+ -- The fvars set is now useless: no need to update it anymore
+ pure (fvars, ty :: in_tys, in_args)
+ else
+ -- Check if ty appears in the set of free variables:
+ let ty_id := ty.fvarId!
+ if fvars.contains ty_id then
+ -- We must split here. Note that we don't need to update the fvars
+ -- set: it is not useful anymore
+ pure (fvars, [ty], in_args)
+ else
+ -- We must split later: update the fvars set
+ let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty))
+ pure (fvars, [], ty :: in_args)
+ let (_, in_tys, in_args) ← splitAux in_tys.data
+ pure (Array.mk in_tys, Array.mk in_args)
+
/- Apply a lambda expression to some arguments, simplifying the lambdas -/
def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do
lambdaTelescopeN e xs.size fun vars body =>
@@ -105,7 +167,7 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr :=
match xl with
| [] => do
trace[Diverge.def.sigmas] "mkSigmasVal: []"
- pure (Expr.const ``PUnit.unit [])
+ pure (Expr.const ``PUnit.unit [Level.succ .zero])
| [x] => do
trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]"
pure x
@@ -122,6 +184,17 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr :=
trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}"
mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd]
+/- Group a list of expressions into a (non-dependent) tuple -/
+def mkProdsVal (xl : List Expr) : MetaM Expr :=
+ match xl with
+ | [] =>
+ pure (Expr.const ``PUnit.unit [Level.succ .zero])
+ | [x] => do
+ pure x
+ | x :: xl => do
+ let xl ← mkProdsVal xl
+ mkAppM ``Prod.mk #[x, xl]
+
def mkAnonymous (s : String) (i : Nat) : Name :=
.num (.str .anonymous s) i
@@ -159,23 +232,23 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met
match xl with
| [] => do
-- This would be unexpected
- throwError "mkSigmasMatch: empyt list of input parameters"
+ throwError "mkSigmasMatch: empty list of input parameters"
| [x] => do
-- In the example given for the explanations: this is the inner match case
trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]"
mkLambdaFVars #[x] out
| fst :: xl => do
- -- In the example given for the explanations: this is the outer match case
- -- Remark: for the naming purposes, we use the same convention as for the
- -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at
- -- those definitions might help)
- --
- -- We want to build the match expression:
- -- ```
- -- λ scrut =>
- -- match scrut with
- -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail
- -- ```
+ /- In the example given for the explanations: this is the outer match case
+ Remark: for the naming purposes, we use the same convention as for the
+ fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at
+ those definitions might help)
+
+ We want to build the match expression:
+ ```
+ λ scrut =>
+ match scrut with
+ | Sigma.mk x ... -- the hole is given by a recursive call on the tail
+ ``` -/
trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]"
let alpha ← Lean.Meta.inferType fst
let snd_ty ← mkSigmasType xl
@@ -183,7 +256,7 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met
let snd ← mkSigmasMatch xl out (index + 1)
let mk ← mkLambdaFVars #[fst] snd
-- Introduce the "scrut" variable
- let scrut_ty ← mkSigmasType (fst :: xl)
+ let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty
withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do
trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})"
-- TODO: make the computation of the motive more efficient
@@ -206,6 +279,67 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met
trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}"
pure sm
+/- This is similar to `mkSigmasMatch`, but with non-dependent tuples
+
+ Remark: factor out with `mkSigmasMatch`? This is extremely similar.
+-/
+partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr :=
+ match xl with
+ | [] => do
+ -- This would be unexpected
+ throwError "mkProdsMatch: empty list of input parameters"
+ | [x] => do
+ -- In the example given for the explanations: this is the inner match case
+ trace[Diverge.def.prods] "mkProdsMatch: [{x}]"
+ mkLambdaFVars #[x] out
+ | fst :: xl => do
+ trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]"
+ let alpha ← Lean.Meta.inferType fst
+ let beta ← mkProdsType xl
+ let snd ← mkProdsMatch xl out (index + 1)
+ let mk ← mkLambdaFVars #[fst] snd
+ -- Introduce the "scrut" variable
+ let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta
+ withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do
+ trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})"
+ -- TODO: make the computation of the motive more efficient
+ let motive ← do
+ let out_ty ← inferType out
+ mkLambdaFVars #[scrut] out_ty
+ -- The final expression: putting everything together
+ trace[Diverge.def.prods] "mkProdsMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})"
+ let sm ← mkAppOptM ``Prod.casesOn #[some alpha, some beta, some motive, some scrut, some mk]
+ -- Abstracting the "scrut" variable
+ let sm ← mkLambdaFVars #[scrut] sm
+ trace[Diverge.def.prods] "mkProdsMatch: sm: {sm}"
+ pure sm
+
+/- Same as `mkSigmasMatch` but also accepts an empty list of inputs, in which case
+ it generates the expression:
+ ```
+ λ () => e
+ ``` -/
+def mkSigmasMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr :=
+ if xl.isEmpty then do
+ let scrut_ty := Expr.const ``PUnit [Level.succ .zero]
+ withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do
+ mkLambdaFVars #[scrut] out
+ else
+ mkSigmasMatch xl out
+
+/- Same as `mkProdsMatch` but also accepts an empty list of inputs, in which case
+ it generates the expression:
+ ```
+ λ () => e
+ ``` -/
+def mkProdsMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr :=
+ if xl.isEmpty then do
+ let scrut_ty := Expr.const ``PUnit [Level.succ .zero]
+ withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do
+ mkLambdaFVars #[scrut] out
+ else
+ mkProdsMatch xl out
+
/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/
private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) :=
@Sigma.casesOn (List a)
@@ -244,17 +378,23 @@ def mkFinVal (n i : Nat) : MetaM Expr := do
We name the declarations: "[original_name].body".
We return the new declarations.
+
+ Inputs:
+ - `paramInOutTys`: (number of type parameters, sigma type grouping the type parameters, input type, output type)
-/
def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
- (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) :
+ (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) :
MetaM (Array Expr) := do
let grSize := preDefs.size
- -- Compute the map from name to (index × input type).
- -- Remark: the continuation has an indexed type; we use the index (a finite number of
- -- type `Fin`) to control which function we call at the recursive call site.
- let nameToInfo : HashMap Name (Nat × Expr) :=
- let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst))
+ /- Compute the map from name to (index, num type parameters, parameters type, input type).
+ Example for `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: `"list_nth" → (0, 1, α, (List α × Int))`
+ Remark: the continuation has an indexed type; we use the index (a finite number of
+ type `Fin`) to control which function we call at the recursive call site. -/
+ let nameToInfo : HashMap Name (Nat × Nat × Expr × Expr) :=
+ let bl := preDefs.mapIdx fun i d =>
+ let (num_params, params_ty, in_ty, _) := paramInOutTys.get! i.val
+ (d.declName, (i.val, num_params, params_ty, in_ty))
HashMap.ofList bl.toList
trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}"
@@ -262,25 +402,29 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Auxiliary function to explore the function bodies and replace the
-- recursive calls
let visit_e (i : Nat) (e : Expr) : MetaM Expr := do
- trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}"
+ trace[Diverge.def.genBody.visit] "visiting expression (dept: {i}): {e}"
let ne ← do
match e with
| .app .. => do
e.withApp fun f args => do
- trace[Diverge.def.genBody] "this is an app: {f} {args}"
+ trace[Diverge.def.genBody.visit] "this is an app: {f} {args}"
-- Check if this is a recursive call
if f.isConst then
let name := f.constName!
match nameToInfo.find? name with
| none => pure e
- | some (id, in_ty) =>
- trace[Diverge.def.genBody] "this is a recursive call"
+ | some (id, num_params, params_ty, _in_ty) =>
+ trace[Diverge.def.genBody.visit] "this is a recursive call"
-- This is a recursive call: replace it
-- Compute the index
let i ← mkFinVal grSize id
- -- Put the arguments in one big dependent tuple
- let args ← mkSigmasVal in_ty args.toList
- mkAppM' kk_var #[i, args]
+ -- Split the arguments, and put them in two tuples (the first
+ -- one is a dependent tuple)
+ let (param_args, args) := args.toList.splitAt num_params
+ trace[Diverge.def.genBody.visit] "param_args: {param_args}, args: {args}"
+ let param_args ← mkSigmasVal params_ty param_args
+ let args ← mkProdsVal args
+ mkAppM' kk_var #[i, param_args, args]
else
-- Not a recursive call: do nothing
pure e
@@ -290,7 +434,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
throwError "mkUnaryBodies: a recursive call was not eliminated"
else pure e
| _ => pure e
- trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}"
+ trace[Diverge.def.genBody.visit] "done with expression (depth: {i}): {e}"
pure ne
-- Explore the bodies
@@ -300,13 +444,20 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
let body ← mapVisit visit_e preDef.value
trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}"
- -- Currify the function by grouping the arguments into a dependent tuple
+ -- Currify the function by grouping the arguments into dependent tuples
-- (over which we match to retrieve the individual arguments).
lambdaTelescope body fun args body => do
- let body ← mkSigmasMatch args.toList body 0
+ -- Split the arguments between the type parameters and the "regular" inputs
+ let (_, num_params, _, _) := nameToInfo.find! preDef.declName
+ let (param_args, args) := args.toList.splitAt num_params
+ let body ← mkProdsMatchOrUnit args body
+ trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}"
+ let body ← mkSigmasMatchOrUnit param_args body
+ trace[Diverge.def.genBody] "Body after mkSigmasMatchOrUnit: {body}"
-- Add the declaration
let value ← mkLambdaFVars #[kk_var] body
+ trace[Diverge.def.genBody] "Body after abstracting kk: {value}"
let name := preDef.declName.append "body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
@@ -318,6 +469,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
safety := .safe
all := [name]
}
+ trace[Diverge.def.genBody] "About to add decl"
addDecl decl
trace[Diverge.def] "individual body of {preDef.declName}: {body}"
-- Return the constant
@@ -326,33 +478,35 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
trace[Diverge.def] "individual body (after decl): {body}"
pure body
--- Generate a unique function body from the bodies of the mutually recursive group,
--- and add it as a declaration in the context.
--- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body.
+/- Generate a unique function body from the bodies of the mutually recursive group,
+ and add it as a declaration in the context.
+ We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body.
+ -/
def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
(kk_var i_var : Expr)
- (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr))
+ (param_ty in_ty out_ty : Expr) (paramInOutTys : List (Nat × Expr × Expr × Expr))
(bodies : Array Expr) : MetaM (Expr × Expr) := do
-- Generate the body
let grSize := bodies.size
let finTypeExpr := mkFin grSize
-- TODO: not very clean
- let inOutTyType ← do
- let (x, y) := inOutTys.get! 0
- inferType (← mkInOutTy x y)
- let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr :=
- match inOutTys, bl with
+ let paramInOutTyType ← do
+ let (_, x, y, z) := paramInOutTys.get! 0
+ inferType (← mkInOutTy x y z)
+ let rec mkFuns (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bl : List Expr) : MetaM Expr :=
+ match paramInOutTys, bl with
| [], [] =>
- mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty]
- | (ity, oty) :: inOutTys, b :: bl => do
+ mkAppOptM ``FixII.Funs.Nil #[finTypeExpr, param_ty, in_ty, out_ty]
+ | (_, pty, ity, oty) :: paramInOutTys, b :: bl => do
-- Retrieving ity and oty - this is not very clean
- let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y))
- let fl ← mkFuns inOutTys bl
- mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl]
+ let paramInOutTysExpr ← mkListLit paramInOutTyType
+ (← paramInOutTys.mapM (λ (_, x, y, z) => mkInOutTy x y z))
+ let fl ← mkFuns paramInOutTys bl
+ mkAppOptM ``FixII.Funs.Cons #[finTypeExpr, param_ty, in_ty, out_ty, pty, ity, oty, paramInOutTysExpr, b, fl]
| _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length"
- let bodyFuns ← mkFuns inOutTys bodies.toList
+ let bodyFuns ← mkFuns paramInOutTys bodies.toList
-- Wrap in `get_fun`
- let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var]
+ let body ← mkAppM ``FixII.get_fun #[bodyFuns, i_var, kk_var]
-- Add the index `i` and the continuation `k` as a variables
let body ← mkLambdaFVars #[kk_var, i_var] body
trace[Diverge.def] "mkDeclareMutRecBody: body: {body}"
@@ -391,11 +545,11 @@ instance : ToMessageData MatchInfo where
-- This is not a very clean formatting, but we don't need more
toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}"
--- Small helper: prove that an expression which doesn't use the continuation `kk`
--- is valid, and return the proof.
+/- Small helper: prove that an expression which doesn't use the continuation `kk`
+ is valid, and return the proof. -/
def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do
trace[Diverge.def.valid] "proveNoKExprIsValid: {e}"
- let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e]
+ let eIsValid ← mkAppM ``FixII.is_valid_p_same #[k_var, e]
trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}"
pure eIsValid
@@ -462,8 +616,11 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
mkLambdaFVars #[x] brValid
let br0Valid ← proveBranchValid br0
let br1Valid ← proveBranchValid br1
- let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite
- let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid]
+ let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite
+ let eIsValid ←
+ mkAppOptM const #[none, none, none, none, none,
+ some k_var, some cond, some dec, none, none,
+ some br0Valid, some br1Valid]
trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}"
pure eIsValid
-- Check if the expression is a match (this case is for when the elaborator
@@ -498,15 +655,16 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
-- use to currify function bodies, introduce such raw matches.
else if ← isCasesExpr f then do
trace[Diverge.def.valid] "rawMatch: {e}"
- -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`.
- --
- -- The casesOn definition is always of the following shape:
- -- - input parameters (implicit parameters)
- -- - motive (implicit), -- the motive gives the return type of the match
- -- - scrutinee (explicit)
- -- - branches (explicit).
- -- In particular, we notice that the scrutinee is the first *explicit*
- -- parameter - this is how we spot it.
+ /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`.
+
+ The casesOn definition is always of the following shape:
+ - input parameters (implicit parameters)
+ - motive (implicit), -- the motive gives the return type of the match
+ - scrutinee (explicit)
+ - branches (explicit).
+ In particular, we notice that the scrutinee is the first *explicit*
+ parameter - this is how we spot it.
+ -/
let matcherName := f.constName!
let matcherLevels := f.constLevels!.toArray
-- Find the first explicit parameter: this is the scrutinee
@@ -526,10 +684,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
let motive := args.get! (scrutIdx - 1)
let scrut := args.get! scrutIdx
let branches := args.extract (scrutIdx + 1) args.size
- -- Compute the number of parameters for the branches: for this we use
- -- the type of the uninstantiated casesOn constant (we can't just
- -- destruct the lambdas in the branch expressions because the result
- -- of a match might be a lambda expression).
+ /- Compute the number of parameters for the branches: for this we use
+ the type of the uninstantiated casesOn constant (we can't just
+ destruct the lambdas in the branch expressions because the result
+ of a match might be a lambda expression). -/
let branchesNumParams : Array Nat ← do
let env ← getEnv
let decl := env.constants.find! matcherName
@@ -572,14 +730,15 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
pure yValid
-- Put everything together
trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}"
- mkAppM ``FixI.is_valid_p_bind #[xValid, yValid]
+ mkAppM ``FixII.is_valid_p_bind #[xValid, yValid]
-- Check if this is a recursive call, i.e., a call to the continuation `kk`
else if f.isFVarOf kk_var.fvarId! then do
trace[Diverge.def.valid] "rec: args: \n{args}"
- if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}"
+ if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}"
let i_arg := args.get! 0
- let x_arg := args.get! 1
- let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg]
+ let t_arg := args.get! 1
+ let x_arg := args.get! 2
+ let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg]
trace[Diverge.def.valid] "rec: result: \n{eIsValid}"
pure eIsValid
else do
@@ -593,10 +752,10 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
trace[Diverge.def.valid] "proveMatchIsValid: {me}"
-- Prove the validity of the branch expressions
let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do
- -- Go inside the lambdas - note that we have to be careful: some of the
- -- binders might come from the match, and some of the binders might come
- -- from the fact that the expression in the match is a lambda expression:
- -- we use the branchesNumParams field for this reason
+ /- Go inside the lambdas - note that we have to be careful: some of the
+ binders might come from the match, and some of the binders might come
+ from the fact that the expression in the match is a lambda expression:
+ we use the branchesNumParams field for this reason. -/
let numParams := me.branchesNumParams.get! idx
lambdaTelescopeN br numParams fun xs br => do
-- Prove that the branch expression is valid
@@ -604,13 +763,13 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
-- Reconstruct the lambda expression
mkLambdaFVars xs brValid
trace[Diverge.def.valid] "branchesValid:\n{branchesValid}"
- -- Compute the motive, which has the following shape:
- -- ```
- -- λ scrut => is_valid_p k (λ k => match scrut with ...)
- -- ^^^^^^^^^^^^^^^^^^^^
- -- this is the original match expression, with the
- -- the difference that the scrutinee(s) is a variable
- -- ```
+ /- Compute the motive, which has the following shape:
+ ```
+ λ scrut => is_valid_p k (λ k => match scrut with ...)
+ ^^^^^^^^^^^^^^^^^^^^
+ this is the original match expression, with the
+ the difference that the scrutinee(s) is a variable
+ ``` -/
let validMotive : Expr ← do
-- The motive is a function of the scrutinees (i.e., a lambda expression):
-- introduce binders for the scrutinees
@@ -629,7 +788,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
let matchE ← mkAppOptM me.matcherName args
-- Wrap in the `is_valid_p` predicate
let matchE ← mkLambdaFVars #[kk_var] matchE
- let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE]
+ let validMotive ← mkAppM ``FixII.is_valid_p #[k_var, matchE]
-- Abstract away the scrutinee variables
mkLambdaFVars scrutVars validMotive
trace[Diverge.def.valid] "valid motive: {validMotive}"
@@ -647,10 +806,10 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
end
--- Prove that a single body (in the mutually recursive group) is valid.
---
--- For instance, if we define the mutually recursive group [`is_even`, `is_odd`],
--- we prove that `is_even.body` and `is_odd.body` are valid.
+/- Prove that a single body (in the mutually recursive group) is valid.
+
+ For instance, if we define the mutually recursive group [`is_even`, `is_odd`],
+ we prove that `is_even.body` and `is_odd.body` are valid. -/
partial def proveSingleBodyIsValid
(k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) :
MetaM Expr := do
@@ -662,24 +821,28 @@ partial def proveSingleBodyIsValid
let body := (env.constants.find! name).value!
trace[Diverge.def.valid] "body: {body}"
lambdaTelescope body fun xs body => do
- assert! xs.size = 2
+ trace[Diverge.def.valid] "xs: {xs}"
+ assert! xs.size = 3
let kk_var := xs.get! 0
- let x_var := xs.get! 1
+ let t_var := xs.get! 1
+ let x_var := xs.get! 2
-- State the type of the theorem to prove
- let thmTy ← mkAppM ``FixI.is_valid_p
- #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])]
+ trace[Diverge.def.valid] "bodyConst: {bodyConst} : {← inferType bodyConst}"
+ let bodyApp ← mkAppOptM' bodyConst #[.some kk_var, .some t_var, .some x_var]
+ trace[Diverge.def.valid] "bodyApp: {bodyApp}"
+ let bodyApp ← mkLambdaFVars #[kk_var] bodyApp
+ trace[Diverge.def.valid] "bodyApp: {bodyApp}"
+ let thmTy ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp]
trace[Diverge.def.valid] "thmTy: {thmTy}"
-- Prove that the body is valid
let proof ← proveExprIsValid k_var kk_var body
- let proof ← mkLambdaFVars #[k_var, x_var] proof
+ let proof ← mkLambdaFVars #[k_var, t_var, x_var] proof
trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}"
-- The target type (we don't have to do this: this is simply a sanity check,
-- and this allows a nicer debugging output)
let thmTy ← do
- let body ← mkAppM' bodyConst #[kk_var, x_var]
- let body ← mkLambdaFVars #[kk_var] body
- let ty ← mkAppM ``FixI.is_valid_p #[k_var, body]
- mkForallFVars #[k_var, x_var] ty
+ let ty ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp]
+ mkForallFVars #[k_var, t_var, x_var] ty
trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}"
-- Save the theorem
let name := preDef.declName ++ "body_is_valid"
@@ -695,18 +858,18 @@ partial def proveSingleBodyIsValid
-- Return the theorem
pure (Expr.const name (preDef.levelParams.map .param))
--- Prove that the list of bodies are valid.
---
--- For instance, if we define the mutually recursive group [`is_even`, `is_odd`],
--- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is
--- valid.
-partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr)
+/- Prove that the list of bodies are valid.
+
+ For instance, if we define the mutually recursive group [`is_even`, `is_odd`],
+ we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is
+ valid. -/
+partial def proveFunsBodyIsValid (paramInOutTys: Expr) (bodyFuns : Expr)
(k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do
-- Create the big "and" expression, which groups the validity proof of the individual bodies
let rec mkValidConj (i : Nat) : MetaM Expr := do
if i = bodiesValid.size then
-- We reached the end
- mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var]
+ mkAppM ``FixII.Funs.is_valid_p_Nil #[k_var]
else do
-- We haven't reached the end: introduce a conjunction
let valid := bodiesValid.get! i
@@ -714,20 +877,20 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr)
mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)]
let andExpr ← mkValidConj 0
-- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation
- let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr]
+ let isValid ← mkAppM ``FixII.Funs.is_valid_p_is_valid_p #[paramInOutTys, k_var, bodyFuns, andExpr]
mkLambdaFVars #[k_var] isValid
--- Prove that the mut rec body (i.e., the unary body which groups the bodies
--- of all the functions in the mutually recursive group and on which we will
--- apply the fixed-point operator) is valid.
---
--- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid",
--- which we return.
---
--- TODO: maybe this function should introduce k_var itself
+/- Prove that the mut rec body (i.e., the unary body which groups the bodies
+ of all the functions in the mutually recursive group and on which we will
+ apply the fixed-point operator) is valid.
+
+ We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid",
+ which we return.
+
+ TODO: maybe this function should introduce k_var itself -/
def proveMutRecIsValid
(grName : Name) (grLvlParams : List Name)
- (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr)
+ (paramInOutTys : Expr) (bodyFuns mutRecBodyConst : Expr)
(k_var : Expr) (preDefs : Array PreDefinition)
(bodies : Array Expr) : MetaM Expr := do
-- First prove that the individual bodies are valid
@@ -738,9 +901,9 @@ def proveMutRecIsValid
proveSingleBodyIsValid k_var preDef body
-- Then prove that the mut rec body is valid
trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid"
- let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid
+ let isValid ← proveFunsBodyIsValid paramInOutTys bodyFuns k_var bodiesValid
-- Save the theorem
- let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst]
+ let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst]
let name := grName ++ "mut_rec_body_is_valid"
let decl := Declaration.thmDecl {
name
@@ -754,26 +917,29 @@ def proveMutRecIsValid
-- Return the theorem
pure (Expr.const name (grLvlParams.map .param))
--- Generate the final definions by using the mutual body and the fixed point operator.
---
--- For instance:
--- ```
--- def is_even (i : Int) : Result Bool := mut_rec_body 0 i
--- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i
--- ```
-def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) :
+/- Generate the final definions by using the mutual body and the fixed point operator.
+
+ For instance:
+ ```
+ def is_even (i : Int) : Result Bool := mut_rec_body 0 i
+ def is_odd (i : Int) : Result Bool := mut_rec_body 1 i
+ ```
+ -/
+def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) :
TermElabM (Array Name) := do
let grSize := preDefs.size
let defs ← preDefs.mapIdxM fun idx preDef => do
lambdaTelescope preDef.value fun xs _ => do
- -- Retrieve the input type
- let in_ty := (inOutTys.get! idx.val).fst
+ -- Retrieve the parameters info
+ let (num_params, param_ty, _, _) := paramInOutTys.get! idx.val
-- Create the index
let idx ← mkFinVal grSize idx.val
- -- Group the inputs into a dependent tuple
- let input ← mkSigmasVal in_ty xs.toList
+ -- Group the inputs into two tuples
+ let (params_args, input_args) := xs.toList.splitAt num_params
+ let params ← mkSigmasVal param_ty params_args
+ let input ← mkProdsVal input_args
-- Apply the fixed point
- let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input]
+ let fixedBody ← mkAppM ``FixII.fix #[mutRecBody, idx, params, input]
let fixedBody ← mkLambdaFVars xs fixedBody
-- Create the declaration
let name := preDef.declName
@@ -791,7 +957,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preD
pure defs
-- Prove the equations that we will use as unfolding theorems
-partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr))
+partial def proveUnfoldingThms (isValidThm : Expr)
+ (paramInOutTys : Array (ℕ × Expr × Expr × Expr))
(preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do
let grSize := preDefs.size
let proveIdx (i : Nat) : MetaM Unit := do
@@ -811,14 +978,18 @@ partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Ex
trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}"
-- The proof
-- Use the fixed-point equation
- let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm]
+ let proof ← mkAppM ``FixII.is_valid_fix_fixed_eq #[isValidThm]
-- Add the index
let idx ← mkFinVal grSize i
let proof ← mkAppM ``congr_fun #[proof, idx]
- -- Add the input argument
- let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList
- let proof ← mkAppM ``congr_fun #[proof, arg]
- -- Abstract the arguments away
+ -- Add the input arguments
+ let (num_params, param_ty, _, _) := paramInOutTys.get! i
+ let (params, args) := xs.toList.splitAt num_params
+ let params ← mkSigmasVal param_ty params
+ let args ← mkProdsVal args
+ let proof ← mkAppM ``congr_fun #[proof, params]
+ let proof ← mkAppM ``congr_fun #[proof, args]
+ -- Abstract all the arguments away
let proof ← mkLambdaFVars xs proof
trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}"
-- Declare the theorem
@@ -846,7 +1017,9 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
trace[Diverge.def] ("divRecursion: defs:\n" ++ msg)
- -- TODO: what is this?
+ -- Apply all the "attribute" functions (for instance, the function which
+ -- registers the theorem in the simp database if there is the `simp` attribute,
+ -- etc.)
for preDef in preDefs do
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
@@ -860,40 +1033,52 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let grLvlParams := def0.levelParams
trace[Diverge.def] "def0 universe levels: {def0.levelParams}"
- -- We first compute the list of pairs: (input type × output type)
- let inOutTys : Array (Expr × Expr) ←
- preDefs.mapM (fun preDef => do
- withRef preDef.ref do -- is the withRef useful?
- -- Check the universe parameters - TODO: I'm not sure what the best thing
- -- to do is. In practice, all the type parameters should be in Type 0, so
- -- we shouldn't have universe issues.
- if preDef.levelParams ≠ grLvlParams then
- throwError "Non-uniform polymorphism in the universes"
- forallTelescope preDef.type (fun in_tys out_ty => do
- let in_ty ← liftM (mkSigmasType in_tys.toList)
- -- Retrieve the type in the "Result"
- let out_ty ← getResultTy out_ty
- let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty)
- pure (in_ty, out_ty)
- )
- )
- trace[Diverge.def] "inOutTys: {inOutTys}"
- -- Turn the list of input/output type pairs into an expresion
- let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y)
- let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList
-
- -- From the list of pairs of input/output types, actually compute the
- -- type of the continuation `k`.
- -- We first introduce the index `i : Fin n` where `n` is the number of
- -- functions in the group.
+ /- We first compute the tuples: (type parameters × input type × output type)
+ - type parameters: this is a sigma type
+ - input type: λ params_type => product type
+ - output type: λ params_type => output type
+ For instance, on the function:
+ `list_nth (α : Type) (ls : List α) (i : Int) : Result α`:
+ we generate:
+ `(Type, λ α => List α × i, λ α => Result α)`
+ -/
+ let paramInOutTys : Array (ℕ × Expr × Expr × Expr) ←
+ preDefs.mapM (fun preDef => do
+ -- Check the universe parameters - TODO: I'm not sure what the best thing
+ -- to do is. In practice, all the type parameters should be in Type 0, so
+ -- we shouldn't have universe issues.
+ if preDef.levelParams ≠ grLvlParams then
+ throwError "Non-uniform polymorphism in the universes"
+ forallTelescope preDef.type (fun in_tys out_ty => do
+ let (params, in_tys) ← splitInputArgs in_tys out_ty
+ trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}"
+ let num_params := params.size
+ let params_sigma ← mkSigmasType params.data
+ let in_tys ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data)
+ -- Retrieve the type in the "Result"
+ let out_ty ← getResultTy out_ty
+ let out_ty ← mkSigmasMatchOrUnit params.data out_ty
+ trace[Diverge.def] "inOutTy: {preDef.declName}: {params_sigma}, {in_tys}, {out_ty}"
+ pure (num_params, params_sigma, in_tys, out_ty)))
+ trace[Diverge.def] "paramInOutTys: {paramInOutTys}"
+ -- Turn the list of input types/input args/output type tuples into expressions
+ let paramInOutTysExpr ← paramInOutTys.mapM (λ (_, x, y, z) => do mkInOutTy x y z)
+ let paramInOutTysExpr ← mkListLit (← inferType (paramInOutTysExpr.get! 0)) paramInOutTysExpr.toList
+ trace[Diverge.def] "paramInOutTys: {paramInOutTys}"
+
+ /- From the list of pairs of input/output types, actually compute the
+ type of the continuation `k`.
+ We first introduce the index `i : Fin n` where `n` is the number of
+ functions in the group.
+ -/
let i_var_ty := mkFin preDefs.size
withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do
- let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var]
- trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}"
- -- Add an auxiliary definition for `in_out_ty`
- let in_out_ty ← do
- let value ← mkLambdaFVars #[i_var] in_out_ty
- let name := grName.append "in_out_ty"
+ let param_in_out_ty ← mkAppM ``List.get #[paramInOutTysExpr, i_var]
+ trace[Diverge.def] "param_in_out_ty := {param_in_out_ty} : {← inferType param_in_out_ty}"
+ -- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term)
+ let param_in_out_ty ← do
+ let value ← mkLambdaFVars #[i_var] param_in_out_ty
+ let name := grName.append "param_in_out_ty"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -906,19 +1091,28 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
}
addDecl decl
-- Return the constant
- let in_out_ty := Lean.mkConst name (levelParams.map .param)
- mkAppM' in_out_ty #[i_var]
- trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}"
- let in_ty ← mkAppM ``Sigma.fst #[in_out_ty]
+ let param_in_out_ty := Lean.mkConst name (levelParams.map .param)
+ mkAppM' param_in_out_ty #[i_var]
+ trace[Diverge.def] "param_in_out_ty (after decl) := {param_in_out_ty} : {← inferType param_in_out_ty}"
+ -- Decompose between: param_ty, in_ty, out_ty
+ let param_ty ← mkAppM ``Sigma.fst #[param_in_out_ty]
+ let in_out_ty ← mkAppM ``Sigma.snd #[param_in_out_ty]
+ let in_ty ← mkAppM ``Prod.fst #[in_out_ty]
+ let out_ty ← mkAppM ``Prod.snd #[in_out_ty]
+ trace[Diverge.def] "param_ty: {param_ty}"
+ trace[Diverge.def] "in_ty: {in_ty}"
+ trace[Diverge.def] "out_ty: {out_ty}"
+ withLocalDeclD (mkAnonymous "t" 1) param_ty fun param => do
+ let in_ty ← mkAppM' in_ty #[param]
+ let out_ty ← mkAppM' out_ty #[param]
trace[Diverge.def] "in_ty: {in_ty}"
- withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do
- let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input]
trace[Diverge.def] "out_ty: {out_ty}"
-- Introduce the continuation `k`
- let in_ty ← mkLambdaFVars #[i_var] in_ty
- let out_ty ← mkLambdaFVars #[i_var, input] out_ty
- let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty]
+ let param_ty ← mkLambdaFVars #[i_var] param_ty
+ let in_ty ← mkLambdaFVars #[i_var, param] in_ty
+ let out_ty ← mkLambdaFVars #[i_var, param] out_ty
+ let kk_var_ty ← mkAppM ``FixII.kk_ty #[i_var_ty, param_ty, in_ty, out_ty]
trace[Diverge.def] "kk_var_ty: {kk_var_ty}"
withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do
trace[Diverge.def] "kk_var: {kk_var}"
@@ -926,29 +1120,30 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Replace the recursive calls in all the function bodies by calls to the
-- continuation `k` and and generate for those bodies declarations
trace[Diverge.def] "# Generating the unary bodies"
- let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs
+ let bodies ← mkDeclareUnaryBodies grLvlParams kk_var paramInOutTys preDefs
trace[Diverge.def] "Unary bodies (after decl): {bodies}"
+
-- Generate the mutually recursive body
trace[Diverge.def] "# Generating the mut rec body"
- let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies
+ let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys.toList bodies
trace[Diverge.def] "mut rec body (after decl): {mutRecBody}"
-- Prove that the mut rec body satisfies the validity criteria required by
-- our fixed-point
- let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty]
+ let k_var_ty ← mkAppM ``FixII.k_ty #[i_var_ty, param_ty, in_ty, out_ty]
withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do
trace[Diverge.def] "# Proving that the mut rec body is valid"
- let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies
+ let isValidThm ← proveMutRecIsValid grName grLvlParams paramInOutTysExpr bodyFuns mutRecBody k_var preDefs bodies
-- Generate the final definitions
trace[Diverge.def] "# Generating the final definitions"
- let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs
+ let decls ← mkDeclareFixDefs mutRecBody paramInOutTys preDefs
-- Prove the unfolding theorems
trace[Diverge.def] "# Proving the unfolding theorems"
- proveUnfoldingThms isValidThm inOutTys preDefs decls
+ proveUnfoldingThms isValidThm paramInOutTys preDefs decls
- -- Generating code -- TODO
+ -- Generating code
addAndCompilePartialRec preDefs
-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index 1a0676e2..b818d5af 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -11,7 +11,9 @@ open Utils
initialize registerTraceClass `Diverge.elab
initialize registerTraceClass `Diverge.def
initialize registerTraceClass `Diverge.def.sigmas
+initialize registerTraceClass `Diverge.def.prods
initialize registerTraceClass `Diverge.def.genBody
+initialize registerTraceClass `Diverge.def.genBody.visit
initialize registerTraceClass `Diverge.def.valid
initialize registerTraceClass `Diverge.def.unfold