summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2023-06-30 15:53:39 +0200
committerSon Ho2023-06-30 15:53:39 +0200
commit1c9331ce92b68b9a83c601212149a6c24591708f (patch)
tree7918a0c930ff675bb83e5a5030dd8208a9e500e3 /backends/lean
parentfdc8693772ecb1978873018c790061854f00a015 (diff)
Generate the fixed-point bodies in Elab.lean
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Base.lean8
-rw-r--r--backends/lean/Base/Diverge/Elab.lean451
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean47
3 files changed, 391 insertions, 115 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 22b59bd0..aa0539ba 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -554,14 +554,14 @@ namespace FixI
/- Some utilities to define the mutually recursive functions -/
-- TODO: use more
- @[simp] def kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=
+ abbrev kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=
(i:id) → (x:a i) → Result (b i x)
- @[simp] def k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=
+ abbrev k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=
kk_ty id a b → kk_ty id a b
- def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v)
+ abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v)
-- TODO: remove?
- @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) :
+ abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) :
in_out_ty :=
Sigma.mk in_ty out_ty
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 116c5d8b..f7de7518 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -16,31 +16,62 @@ syntax (name := divergentDef)
open Lean Elab Term Meta Primitives Lean.Meta
set_option trace.Diverge.def true
+-- set_option trace.Diverge.def.sigmas true
/- The following was copied from the `wfRecursion` function. -/
open WF in
--- Replace the recursive calls by a call to the continuation
--- def replace_rec_calls
+def mkList (xl : List Expr) (ty : Expr) : MetaM Expr :=
+ match xl with
+ | [] =>
+ mkAppOptM ``List.nil #[some ty]
+ | x :: tl => do
+ let tl ← mkList tl ty
+ mkAppOptM ``List.cons #[some ty, some x, some tl]
--- print_decl is_even_body
-#check instOfNatNat
-#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ...
-#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ...
-#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat
+def mkProd (x y : Expr) : MetaM Expr :=
+ mkAppM ``Prod.mk #[x, y]
+def mkInOutTy (x y : Expr) : MetaM Expr :=
+ mkAppM ``FixI.mk_in_out_ty #[x, y]
-- TODO: is there already such a utility somewhere?
-- TODO: change to mkSigmas
def mkProds (tys : List Expr) : MetaM Expr :=
match tys with
- | [] => do return (Expr.const ``PUnit.unit [])
- | [ty] => do return ty
+ | [] => do pure (Expr.const ``PUnit.unit [])
+ | [ty] => do pure ty
| ty :: tys => do
let pty ← mkProds tys
mkAppM ``Prod.mk #[ty, pty]
+-- Return the `a` in `Return a`
+def get_result_ty (ty : Expr) : MetaM Expr :=
+ ty.withApp fun f args => do
+ if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then
+ throwError "Invalid argument to get_result_ty: {ty}"
+ else
+ pure (args.get! 0)
+
+-- Group a list of expressions into a dependent tuple
+def mkSigmas (xl : List Expr) : MetaM Expr :=
+ match xl with
+ | [] => do
+ trace[Diverge.def.sigmas] "mkSigmas: []"
+ pure (Expr.const ``PUnit.unit [])
+ | [x] => do
+ trace[Diverge.def.sigmas] "mkSigmas: [{x}]"
+ pure x
+ | fst :: xl => do
+ trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]"
+ let alpha ← Lean.Meta.inferType fst
+ let snd ← mkSigmas xl
+ let snd_ty ← inferType snd
+ let beta ← mkLambdaFVars #[fst] snd_ty
+ trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}"
+ mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd]
+
/- Generate the input type of a function body, which is a sigma type (i.e., a
dependent tuple) which groups all its inputs.
@@ -55,11 +86,11 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=
match xl with
| [] => do
trace[Diverge.def.sigmas] "mkSigmasOfTypes: []"
- return (Expr.const ``PUnit.unit [])
+ pure (Expr.const ``PUnit.unit [])
| [x] => do
trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]"
let ty ← Lean.Meta.inferType x
- return ty
+ pure ty
| x :: xl => do
trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]"
let alpha ← Lean.Meta.inferType x
@@ -71,15 +102,26 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=
def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index
-/- Generate the out_ty of the body of a function, which from an input (a sigma
- type generated by `mkSigmasTypesOfTypes`) gives the output type of the function.
+/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous
+ `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following
+ expression:
+ ```
+ fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple
+ match x with
+ | (x0, ..., xn) => out
+ ```
+
+ The `index` parameter is used for naming purposes: we use it to numerotate the
+ bound variables that we introduce.
Example:
+ ========
+ More precisely:
- xl = `[a:Type, ls:List a, i:Int]`
- - out_ty = `a`
- - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables
+ - out = `a`
+ - index = 0
- Generates:
+ generates:
```
match scrut0 with
| Sigma.mk x scrut1 =>
@@ -88,36 +130,47 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index
a
```
-/
-def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr :=
+partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr :=
match xl with
| [] => do
-- This would be unexpected
- throwError "mkSigmasOutType: empyt list of input parameters"
+ throwError "mkSigmasMatch: empyt list of input parameters"
| [x] => do
-- In the explanations above: inner match case
- trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]"
- mkLambdaFVars #[x] out_ty
+ trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]"
+ mkLambdaFVars #[x] out
| fst :: xl => do
-- In the explanations above: outer match case
-- Remark: for the naming purposes, we use the same convention as for the
-- fields and parameters in `Sigma.casesOn and `Sigma.mk
- trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]"
+ trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]"
let alpha ← Lean.Meta.inferType fst
let snd_ty ← mkSigmasTypesOfTypes xl
let beta ← mkLambdaFVars #[fst] snd_ty
- let snd ← mkSigmasOutType xl out_ty (index + 1)
+ let snd ← mkSigmasMatch xl out (index + 1)
let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl)
withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do
let mk ← mkLambdaFVars #[fst] snd
- trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})"
- let motive ← mkLambdaFVars #[scrut] (← inferType out_ty)
- trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})"
- let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk]
- let out ← mkLambdaFVars #[scrut] out
- trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}"
- return out
-
-/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/
+ trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})"
+ -- TODO: make the computation of the motive more efficient
+ let motive ← do
+ let out_ty ← inferType out
+ match out_ty with
+ | .sort _ | .lit _ | .const .. =>
+ -- The type of the motive doesn't depend on the scrutinee
+ mkLambdaFVars #[scrut] out_ty
+ | _ =>
+ -- The type of the motive *may* depend on the scrutinee
+ -- TODO: make this more efficient (we could change the output type of
+ -- mkSigmasMatch
+ mkSigmasMatch (fst :: xl) out_ty
+ trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})"
+ let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk]
+ let sm ← mkLambdaFVars #[scrut] sm
+ trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}"
+ pure sm
+
+/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/
private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) :=
@Sigma.casesOn (List a)
(fun (_ls : List a) => Int)
@@ -135,14 +188,199 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) =>
list_nth_out_ty2 a scrut1)
/- -/
+-- TODO: move
+-- TODO: we can use Array.mapIdx
+@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β
+ | [] => []
+ | a::as => f i a :: mapiAux (i+1) f as
+
+@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f
+
+#check Array.map
+-- Return the expression: `Fin n`
+-- TODO: use more
+def mkFin (n : Nat) : Expr :=
+ mkAppN (.const ``Fin []) #[.lit (.natVal n)]
+
+-- Return the expression: `i : Fin n`
+def mkFinVal (n i : Nat) : MetaM Expr := do
+ let n_lit : Expr := .lit (.natVal (n - 1))
+ let i_lit : Expr := .lit (.natVal i)
+ -- We could use `trySynthInstance`, but as we know the instance that we are
+ -- going to use, we can save the lookup
+ let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit]
+ mkAppOptM ``OfNat.ofNat #[none, none, ofNat]
+
+-- TODO: remove?
+def mkFinValOld (n i : Nat) : MetaM Expr := do
+ let finTy := mkFin n
+ let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)]
+ match ← trySynthInstance ofNat with
+ | LOption.some x =>
+ mkAppOptM ``OfNat.ofNat #[none, none, x]
+ | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} "
+
+/- Generate and declare as individual definitions the bodies for the individual funcions:
+ - replace the recursive calls with calls to the continutation `k`
+ - make those bodies take one single dependent tuple as input
+
+ We name the declarations: "[original_name].body".
+ We return the new declarations.
+ -/
+def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr)
+ (preDefs : Array PreDefinition) :
+ MetaM (Array Expr) := do
+ let grSize := preDefs.size
+
+ -- Compute the map from name to index - the continuation has an indexed type:
+ -- we use the index (a finite number of type `Fin`) to control the function
+ -- we call at the recursive call
+ let nameToId : HashMap Name Nat :=
+ let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList
+ HashMap.ofList namesIds
+
+ trace[Diverge.def.genBody] "nameToId: {nameToId.toList}"
+
+ -- Auxiliary function to explore the function bodies and replace the
+ -- recursive calls
+ let visit_e (e : Expr) : MetaM Expr := do
+ trace[Diverge.def.genBody] "visiting expression: {e}"
+ match e with
+ | .app .. => do
+ e.withApp fun f args => do
+ trace[Diverge.def.genBody] "this is an app: {f} {args}"
+ -- Check if this is a recursive call
+ if f.isConst then
+ let name := f.constName!
+ match nameToId.find? name with
+ | none => pure e
+ | some id =>
+ -- This is a recursive call: replace it
+ -- Compute the index
+ let i ← mkFinVal grSize id
+ -- Put the arguments in one big dependent tuple
+ let args ← mkSigmas args.toList
+ mkAppM' k_var #[i, args]
+ else
+ -- Not a recursive call: do nothing
+ pure e
+ | .const name _ =>
+ -- Sanity check: we eliminated all the recursive calls
+ if (nameToId.find? name).isSome then
+ throwError "mkUnaryBodies: a recursive call was not eliminated"
+ else pure e
+ | _ => pure e
+
+ -- Explore the bodies
+ preDefs.mapM fun preDef => do
+ -- Replace the recursive calls
+ let body ← mapVisit visit_e preDef.value
+
+ -- Change the type
+ lambdaLetTelescope body fun args body => do
+ let body ← mkSigmasMatch args.toList body 0
+
+ -- Add the declaration
+ let value ← mkLambdaFVars #[k_var] body
+ let name := preDef.declName.append "body"
+ let levelParams := grLvlParams
+ let decl := Declaration.defnDecl {
+ name := name
+ levelParams := levelParams
+ type := ← inferType value -- TODO: change the type
+ value := value
+ hints := ReducibilityHints.regular (getMaxHeight (← getEnv) value + 1)
+ safety := .safe
+ all := [name]
+ }
+ addDecl decl
+ trace[Diverge.def] "individual body of {preDef.declName}: {body}"
+ -- Return the constant
+ let body := Lean.mkConst name (levelParams.map .param)
+ -- let body ← mkAppM' body #[k_var]
+ trace[Diverge.def] "individual body (after decl): {body}"
+ pure body
+
+-- Generate a unique function body from the bodies of the mutually recursive group,
+-- and add it as a declaration in the context
+def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name)
+ (i_var k_var : Expr)
+ (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr))
+ (bodies : Array Expr) : MetaM Expr := do
+ -- Generate the body
+ let grSize := bodies.size
+ let finTypeExpr := mkFin grSize
+ -- TODO: not very clean
+ let inOutTyType ← do
+ let (x, y) := inOutTys.get! 0
+ inferType (← mkInOutTy x y)
+ let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr :=
+ match inOutTys, bl with
+ | [], [] =>
+ mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty]
+ | (ity, oty) :: inOutTys, b :: bl => do
+ -- Retrieving ity and oty - this is not very clean
+ let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType
+ let fl ← mkFuns inOutTys bl
+ mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl]
+ | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length"
+ let bodyFuns ← mkFuns inOutTys bodies.toList
+ -- Wrap in `get_fun`
+ let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var]
+ -- Add the index `i` and the continuation `k` as a variables
+ let body ← mkLambdaFVars #[k_var, i_var] body
+ trace[Diverge.def] "mkDeclareMutualBody: body: {body}"
+ -- Add the declaration
+ let name := grName.append "mutrec_body"
+ let levelParams := grLvlParams
+ let decl := Declaration.defnDecl {
+ name := name
+ levelParams := levelParams
+ type := ← inferType body
+ value := body
+ hints := ReducibilityHints.regular (getMaxHeight (← getEnv) body + 1)
+ safety := .safe
+ all := [name]
+ }
+ addDecl decl
+ -- Return the constant
+ pure (Lean.mkConst name (levelParams.map .param))
+
+-- Generate the final definions by using the mutual body and the fixed point operator.
+def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) :
+ TermElabM Unit := do
+ let grSize := preDefs.size
+ let _ ← preDefs.mapIdxM fun idx preDef => do
+ lambdaLetTelescope preDef.value fun xs _ => do
+ -- Create the index
+ let idx ← mkFinVal grSize idx.val
+ -- Group the inputs into a dependent tuple
+ let input ← mkSigmas xs.toList
+ -- Apply the fixed point
+ let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input]
+ let fixedBody ← mkLambdaFVars xs fixedBody
+ -- Create the declaration
+ let name := preDef.declName
+ let decl := Declaration.defnDecl {
+ name := name
+ levelParams := preDef.levelParams
+ type := preDef.type
+ value := fixedBody
+ hints := ReducibilityHints.regular (getMaxHeight (← getEnv) fixedBody + 1)
+ safety := .safe
+ all := [name]
+ }
+ addDecl decl
+ pure ()
+
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
trace[Diverge.def] ("divRecursion: defs: " ++ msg)
-- CHANGE HERE This function should add definitions with these names/types/values ^^
-- Temporarily add the predefinitions as axioms
- for preDef in preDefs do
- addAsAxiom preDef
+ -- for preDef in preDefs do
+ -- addAsAxiom preDef
-- TODO: what is this?
for preDef in preDefs do
@@ -154,25 +392,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let grName := def0.declName
trace[Diverge.def] "group name: {grName}"
- /- Compute the type of the continuation.
-
- We do the following
- - we make sure all the definitions have the same universe parameters
- (we can make this more general later)
- - we group all the type parameters together, make sure all the
- definitions have the same type parameters, and enforce
- a uniform polymorphism (we can also lift this later).
- This would require generalizing a bit our indexed fixed point to
- make the output type parametric in the input.
- - we group all the non-type parameters: we parameterize the continuation
- by those
- -/
+ /- # Compute the input/output types of the continuation `k`. -/
let grLvlParams := def0.levelParams
- trace[Diverge.def] "def0 type: {def0.type}"
+ trace[Diverge.def] "def0 universe levels: {def0.levelParams}"
- -- Compute the list of pairs: (input type × output type)
+ -- We first compute the list of pairs: (input type × output type)
let inOutTys : Array (Expr × Expr) ←
preDefs.mapM (fun preDef => do
+ withRef preDef.ref do -- is the withRef useful?
-- Check the universe parameters - TODO: I'm not sure what the best thing
-- to do is. In practice, all the type parameters should be in Type 0, so
-- we shouldn't have universe issues.
@@ -180,68 +407,74 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
throwError "Non-uniform polymorphism in the universes"
forallTelescope preDef.type (fun in_tys out_ty => do
let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList)
- let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty)
- return (in_ty, out_ty)
+ -- Retrieve the type in the "Result"
+ let out_ty ← get_result_ty out_ty
+ let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty)
+ pure (in_ty, out_ty)
)
)
trace[Diverge.def] "inOutTys: {inOutTys}"
-
-/- -- Small utility: compute the list of type parameters
- let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) :=
- Lean.Meta.forallTelescope ty fun tys out_ty => do
- trace[Diverge.def] "types: {tys}"
-/- let (_, params) ← StateT.run (do
- for x in tys do
- let ty ← Lean.Meta.inferType x
- match ty with
- | .sort _ => do
- let st ← StateT.get
- StateT.set (ty :: st)
- | _ => do break
- ) ([] : List Expr)
- let params := params.reverse
- trace[Diverge.def] " type parameters {params}"
- return params -/
- let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) :=
- match ls with
- | x :: tl => do
- let ty ← Lean.Meta.inferType x
- match ty with
- | .sort _ => do
- let (ty_params, params) ← get_params tl
- return (x :: ty_params, params)
- | _ => do return ([], ls)
- | _ => do return ([], [])
- let (ty_params, params) ← get_params tys.toList
- trace[Diverge.def] " parameters: {ty_params}; {params}"
- return (ty_params, params, out_ty)
- let (grTyParams, _, _) ← do
- getTypeParams def0.type
-
- -- Compute the input types and the output types
- let all_tys ← preDefs.mapM fun preDef => do
- let (tyParams, params, ret_ty) ← getTypeParams preDef.type
- -- TODO: this is not complete, there are more checks to perform
- if tyParams.length ≠ grTyParams.length then
- throwError "Non-uniform polymorphism"
- return (params, ret_ty)
-
- -- TODO: I think there are issues with the free variables
- let (input_tys, output_tys) := List.unzip all_tys.toList
- let input_tys : List Expr ← liftM (List.mapM mkProds input_tys)
-
- trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/
-
- -- Compute the names set
- let names := preDefs.map PreDefinition.declName
- let names := HashSet.empty.insertMany names
-
- --
- -- for preDef in preDefs do
- -- trace[Diverge.def] "about to explore: {preDef.declName}"
- -- explore_term "" preDef.value
-
- -- Compute the bodies
+ -- Turn the list of input/output type pairs into an expresion
+ let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y)
+ let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0))
+
+ -- From the list of pairs of input/output types, actually compute the
+ -- type of the continuation `k`.
+ -- We first introduce the index `i : Fin n` where `n` is the number of
+ -- functions in the group.
+ let i_var_ty := mkFin preDefs.size
+ withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do
+ let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var]
+ trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}"
+ -- Add an auxiliary definition for `in_out_ty`
+ let in_out_ty ← do
+ let value ← mkLambdaFVars #[i_var] in_out_ty
+ let name := grName.append "in_out_ty"
+ let levelParams := grLvlParams
+ let decl := Declaration.defnDecl {
+ name := name
+ levelParams := levelParams
+ type := ← inferType value
+ value := value
+ hints := .abbrev
+ safety := .safe
+ all := [name]
+ }
+ addDecl decl
+ -- Return the constant
+ let in_out_ty := Lean.mkConst name (levelParams.map .param)
+ mkAppM' in_out_ty #[i_var]
+ trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}"
+ let in_ty ← mkAppM ``Sigma.fst #[in_out_ty]
+ trace[Diverge.def] "in_ty: {in_ty}"
+ withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do
+ let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input]
+ trace[Diverge.def] "out_ty: {out_ty}"
+
+ -- Introduce the continuation `k`
+ let in_ty ← mkLambdaFVars #[i_var] in_ty
+ let out_ty ← mkLambdaFVars #[i_var, input] out_ty
+ let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] --
+ trace[Diverge.def] "k_var_ty: {k_var_ty}"
+ withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do
+ trace[Diverge.def] "k_var: {k_var}"
+
+ -- Replace the recursive calls in all the function bodies by calls to the
+ -- continuation `k` and and generate for those bodies declarations
+ let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs
+ -- Generate the mutually recursive body
+ let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies
+ trace[Diverge.def] "mut rec body (after decl): {body}"
+
+ -- Prove that the mut rec body satisfies the validity criteria required by
+ -- our fixed-point
+ -- TODO
+
+ -- Generate the final definitions
+ let defs ← mkDeclareFixDefs body preDefs
+
+ -- Prove the unfolding equations
+ -- TODO
-- Process the definitions
addAndCompilePartialRec preDefs
@@ -366,6 +599,10 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
if i = 0 then return x
else return (← list_nth ls (i - 1))
+#print list_nth.in_out_ty
+#check list_nth.body
+#print list_nth
+
mutual
divergent def is_even (i : Int) : Result Bool :=
if i = 0 then return true else return (← is_odd (i - 1))
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index 441b25f0..82f79f94 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -4,13 +4,14 @@ namespace Diverge
open Lean Elab Term Meta
-initialize registerTraceClass `Diverge.elab (inherited := true)
-initialize registerTraceClass `Diverge.def.sigmas (inherited := true)
-initialize registerTraceClass `Diverge.def (inherited := true)
+initialize registerTraceClass `Diverge.elab
+initialize registerTraceClass `Diverge.def
+initialize registerTraceClass `Diverge.def.sigmas
+initialize registerTraceClass `Diverge.def.genBody
-- TODO: move
-- TODO: small helper
-def explore_term (incr : String) (e : Expr) : TermElabM Unit :=
+def explore_term (incr : String) (e : Expr) : MetaM Unit :=
match e with
| .bvar _ => do logInfo m!"{incr}bvar: {e}"; return ()
| .fvar _ => do logInfo m!"{incr}fvar: {e}"; return ()
@@ -78,4 +79,42 @@ private def test2 (x : Nat) : Nat := x
print_decl test1
print_decl test2
+-- We adapted this from AbstractNestedProofs.visit
+-- A map visitor function for expressions
+partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
+ let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do
+ let localInstances ← getLocalInstances
+ let mut lctx ← getLCtx
+ for x in xs do
+ let xFVarId := x.fvarId!
+ let localDecl ← xFVarId.getDecl
+ let type ← mapVisit k localDecl.type
+ let localDecl := localDecl.setType type
+ let localDecl ← match localDecl.value? with
+ | some value => let value ← mapVisit k value; pure <| localDecl.setValue value
+ | none => pure localDecl
+ lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl
+ withLCtx lctx localInstances k2
+ -- TODO: use a cache? (Lean.checkCache)
+ -- Explore
+ let e ← k e
+ match e with
+ | .bvar _
+ | .fvar _
+ | .mvar _
+ | .sort _
+ | .lit _
+ | .const _ _ => pure e
+ | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (mapVisit k))
+ | .lam .. =>
+ lambdaLetTelescope e fun xs b =>
+ mapVisitBinders xs do mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false)
+ | .forallE .. => do
+ forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← mapVisit k b)
+ | .letE .. => do
+ lambdaLetTelescope e fun xs b => mapVisitBinders xs do
+ mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false)
+ | .mdata _ b => return e.updateMData! (← mapVisit k b)
+ | .proj _ _ b => return e.updateProj! (← mapVisit k b)
+
end Diverge