diff options
author | Son Ho | 2023-06-30 15:53:39 +0200 |
---|---|---|
committer | Son Ho | 2023-06-30 15:53:39 +0200 |
commit | 1c9331ce92b68b9a83c601212149a6c24591708f (patch) | |
tree | 7918a0c930ff675bb83e5a5030dd8208a9e500e3 /backends/lean/Base/Diverge | |
parent | fdc8693772ecb1978873018c790061854f00a015 (diff) |
Generate the fixed-point bodies in Elab.lean
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 8 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 451 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/ElabBase.lean | 47 |
3 files changed, 391 insertions, 115 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 22b59bd0..aa0539ba 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -554,14 +554,14 @@ namespace FixI /- Some utilities to define the mutually recursive functions -/ -- TODO: use more - @[simp] def kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + abbrev kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := (i:id) → (x:a i) → Result (b i x) - @[simp] def k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := + abbrev k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := kk_ty id a b → kk_ty id a b - def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) + abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) -- TODO: remove? - @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : + abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : in_out_ty := Sigma.mk in_ty out_ty diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 116c5d8b..f7de7518 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,31 +16,62 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true +-- set_option trace.Diverge.def.sigmas true /- The following was copied from the `wfRecursion` function. -/ open WF in --- Replace the recursive calls by a call to the continuation --- def replace_rec_calls +def mkList (xl : List Expr) (ty : Expr) : MetaM Expr := + match xl with + | [] => + mkAppOptM ``List.nil #[some ty] + | x :: tl => do + let tl ← mkList tl ty + mkAppOptM ``List.cons #[some ty, some x, some tl] --- print_decl is_even_body -#check instOfNatNat -#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... -#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ... -#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat +def mkProd (x y : Expr) : MetaM Expr := + mkAppM ``Prod.mk #[x, y] +def mkInOutTy (x y : Expr) : MetaM Expr := + mkAppM ``FixI.mk_in_out_ty #[x, y] -- TODO: is there already such a utility somewhere? -- TODO: change to mkSigmas def mkProds (tys : List Expr) : MetaM Expr := match tys with - | [] => do return (Expr.const ``PUnit.unit []) - | [ty] => do return ty + | [] => do pure (Expr.const ``PUnit.unit []) + | [ty] => do pure ty | ty :: tys => do let pty ← mkProds tys mkAppM ``Prod.mk #[ty, pty] +-- Return the `a` in `Return a` +def get_result_ty (ty : Expr) : MetaM Expr := + ty.withApp fun f args => do + if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then + throwError "Invalid argument to get_result_ty: {ty}" + else + pure (args.get! 0) + +-- Group a list of expressions into a dependent tuple +def mkSigmas (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmas: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmas: [{x}]" + pure x + | fst :: xl => do + trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let snd ← mkSigmas xl + let snd_ty ← inferType snd + let beta ← mkLambdaFVars #[fst] snd_ty + trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + /- Generate the input type of a function body, which is a sigma type (i.e., a dependent tuple) which groups all its inputs. @@ -55,11 +86,11 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" - return (Expr.const ``PUnit.unit []) + pure (Expr.const ``PUnit.unit []) | [x] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" let ty ← Lean.Meta.inferType x - return ty + pure ty | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x @@ -71,15 +102,26 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index -/- Generate the out_ty of the body of a function, which from an input (a sigma - type generated by `mkSigmasTypesOfTypes`) gives the output type of the function. +/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous + `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following + expression: + ``` + fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple + match x with + | (x0, ..., xn) => out + ``` + + The `index` parameter is used for naming purposes: we use it to numerotate the + bound variables that we introduce. Example: + ======== + More precisely: - xl = `[a:Type, ls:List a, i:Int]` - - out_ty = `a` - - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables + - out = `a` + - index = 0 - Generates: + generates: ``` match scrut0 with | Sigma.mk x scrut1 => @@ -88,36 +130,47 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index a ``` -/ -def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr := +partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := match xl with | [] => do -- This would be unexpected - throwError "mkSigmasOutType: empyt list of input parameters" + throwError "mkSigmasMatch: empyt list of input parameters" | [x] => do -- In the explanations above: inner match case - trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]" - mkLambdaFVars #[x] out_ty + trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" + mkLambdaFVars #[x] out | fst :: xl => do -- In the explanations above: outer match case -- Remark: for the naming purposes, we use the same convention as for the -- fields and parameters in `Sigma.casesOn and `Sigma.mk - trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst let snd_ty ← mkSigmasTypesOfTypes xl let beta ← mkLambdaFVars #[fst] snd_ty - let snd ← mkSigmasOutType xl out_ty (index + 1) + let snd ← mkSigmasMatch xl out (index + 1) let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do let mk ← mkLambdaFVars #[fst] snd - trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})" - let motive ← mkLambdaFVars #[scrut] (← inferType out_ty) - trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" - let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] - let out ← mkLambdaFVars #[scrut] out - trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}" - return out - -/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/ + trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + match out_ty with + | .sort _ | .lit _ | .const .. => + -- The type of the motive doesn't depend on the scrutinee + mkLambdaFVars #[scrut] out_ty + | _ => + -- The type of the motive *may* depend on the scrutinee + -- TODO: make this more efficient (we could change the output type of + -- mkSigmasMatch + mkSigmasMatch (fst :: xl) out_ty + trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + let sm ← mkLambdaFVars #[scrut] sm + trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" + pure sm + +/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) (fun (_ls : List a) => Int) @@ -135,14 +188,199 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => list_nth_out_ty2 a scrut1) /- -/ +-- TODO: move +-- TODO: we can use Array.mapIdx +@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β + | [] => [] + | a::as => f i a :: mapiAux (i+1) f as + +@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f + +#check Array.map +-- Return the expression: `Fin n` +-- TODO: use more +def mkFin (n : Nat) : Expr := + mkAppN (.const ``Fin []) #[.lit (.natVal n)] + +-- Return the expression: `i : Fin n` +def mkFinVal (n i : Nat) : MetaM Expr := do + let n_lit : Expr := .lit (.natVal (n - 1)) + let i_lit : Expr := .lit (.natVal i) + -- We could use `trySynthInstance`, but as we know the instance that we are + -- going to use, we can save the lookup + let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] + mkAppOptM ``OfNat.ofNat #[none, none, ofNat] + +-- TODO: remove? +def mkFinValOld (n i : Nat) : MetaM Expr := do + let finTy := mkFin n + let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)] + match ← trySynthInstance ofNat with + | LOption.some x => + mkAppOptM ``OfNat.ofNat #[none, none, x] + | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} " + +/- Generate and declare as individual definitions the bodies for the individual funcions: + - replace the recursive calls with calls to the continutation `k` + - make those bodies take one single dependent tuple as input + + We name the declarations: "[original_name].body". + We return the new declarations. + -/ +def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) + (preDefs : Array PreDefinition) : + MetaM (Array Expr) := do + let grSize := preDefs.size + + -- Compute the map from name to index - the continuation has an indexed type: + -- we use the index (a finite number of type `Fin`) to control the function + -- we call at the recursive call + let nameToId : HashMap Name Nat := + let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList + HashMap.ofList namesIds + + trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" + + -- Auxiliary function to explore the function bodies and replace the + -- recursive calls + let visit_e (e : Expr) : MetaM Expr := do + trace[Diverge.def.genBody] "visiting expression: {e}" + match e with + | .app .. => do + e.withApp fun f args => do + trace[Diverge.def.genBody] "this is an app: {f} {args}" + -- Check if this is a recursive call + if f.isConst then + let name := f.constName! + match nameToId.find? name with + | none => pure e + | some id => + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Put the arguments in one big dependent tuple + let args ← mkSigmas args.toList + mkAppM' k_var #[i, args] + else + -- Not a recursive call: do nothing + pure e + | .const name _ => + -- Sanity check: we eliminated all the recursive calls + if (nameToId.find? name).isSome then + throwError "mkUnaryBodies: a recursive call was not eliminated" + else pure e + | _ => pure e + + -- Explore the bodies + preDefs.mapM fun preDef => do + -- Replace the recursive calls + let body ← mapVisit visit_e preDef.value + + -- Change the type + lambdaLetTelescope body fun args body => do + let body ← mkSigmasMatch args.toList body 0 + + -- Add the declaration + let value ← mkLambdaFVars #[k_var] body + let name := preDef.declName.append "body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value -- TODO: change the type + value := value + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) value + 1) + safety := .safe + all := [name] + } + addDecl decl + trace[Diverge.def] "individual body of {preDef.declName}: {body}" + -- Return the constant + let body := Lean.mkConst name (levelParams.map .param) + -- let body ← mkAppM' body #[k_var] + trace[Diverge.def] "individual body (after decl): {body}" + pure body + +-- Generate a unique function body from the bodies of the mutually recursive group, +-- and add it as a declaration in the context +def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) + (i_var k_var : Expr) + (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) + (bodies : Array Expr) : MetaM Expr := do + -- Generate the body + let grSize := bodies.size + let finTypeExpr := mkFin grSize + -- TODO: not very clean + let inOutTyType ← do + let (x, y) := inOutTys.get! 0 + inferType (← mkInOutTy x y) + let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := + match inOutTys, bl with + | [], [] => + mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] + | (ity, oty) :: inOutTys, b :: bl => do + -- Retrieving ity and oty - this is not very clean + let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType + let fl ← mkFuns inOutTys bl + mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] + | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length" + let bodyFuns ← mkFuns inOutTys bodies.toList + -- Wrap in `get_fun` + let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var] + -- Add the index `i` and the continuation `k` as a variables + let body ← mkLambdaFVars #[k_var, i_var] body + trace[Diverge.def] "mkDeclareMutualBody: body: {body}" + -- Add the declaration + let name := grName.append "mutrec_body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType body + value := body + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) body + 1) + safety := .safe + all := [name] + } + addDecl decl + -- Return the constant + pure (Lean.mkConst name (levelParams.map .param)) + +-- Generate the final definions by using the mutual body and the fixed point operator. +def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : + TermElabM Unit := do + let grSize := preDefs.size + let _ ← preDefs.mapIdxM fun idx preDef => do + lambdaLetTelescope preDef.value fun xs _ => do + -- Create the index + let idx ← mkFinVal grSize idx.val + -- Group the inputs into a dependent tuple + let input ← mkSigmas xs.toList + -- Apply the fixed point + let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input] + let fixedBody ← mkLambdaFVars xs fixedBody + -- Create the declaration + let name := preDef.declName + let decl := Declaration.defnDecl { + name := name + levelParams := preDef.levelParams + type := preDef.type + value := fixedBody + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) fixedBody + 1) + safety := .safe + all := [name] + } + addDecl decl + pure () + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) -- CHANGE HERE This function should add definitions with these names/types/values ^^ -- Temporarily add the predefinitions as axioms - for preDef in preDefs do - addAsAxiom preDef + -- for preDef in preDefs do + -- addAsAxiom preDef -- TODO: what is this? for preDef in preDefs do @@ -154,25 +392,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grName := def0.declName trace[Diverge.def] "group name: {grName}" - /- Compute the type of the continuation. - - We do the following - - we make sure all the definitions have the same universe parameters - (we can make this more general later) - - we group all the type parameters together, make sure all the - definitions have the same type parameters, and enforce - a uniform polymorphism (we can also lift this later). - This would require generalizing a bit our indexed fixed point to - make the output type parametric in the input. - - we group all the non-type parameters: we parameterize the continuation - by those - -/ + /- # Compute the input/output types of the continuation `k`. -/ let grLvlParams := def0.levelParams - trace[Diverge.def] "def0 type: {def0.type}" + trace[Diverge.def] "def0 universe levels: {def0.levelParams}" - -- Compute the list of pairs: (input type × output type) + -- We first compute the list of pairs: (input type × output type) let inOutTys : Array (Expr × Expr) ← preDefs.mapM (fun preDef => do + withRef preDef.ref do -- is the withRef useful? -- Check the universe parameters - TODO: I'm not sure what the best thing -- to do is. In practice, all the type parameters should be in Type 0, so -- we shouldn't have universe issues. @@ -180,68 +407,74 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) - let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty) - return (in_ty, out_ty) + -- Retrieve the type in the "Result" + let out_ty ← get_result_ty out_ty + let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) + pure (in_ty, out_ty) ) ) trace[Diverge.def] "inOutTys: {inOutTys}" - -/- -- Small utility: compute the list of type parameters - let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := - Lean.Meta.forallTelescope ty fun tys out_ty => do - trace[Diverge.def] "types: {tys}" -/- let (_, params) ← StateT.run (do - for x in tys do - let ty ← Lean.Meta.inferType x - match ty with - | .sort _ => do - let st ← StateT.get - StateT.set (ty :: st) - | _ => do break - ) ([] : List Expr) - let params := params.reverse - trace[Diverge.def] " type parameters {params}" - return params -/ - let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) := - match ls with - | x :: tl => do - let ty ← Lean.Meta.inferType x - match ty with - | .sort _ => do - let (ty_params, params) ← get_params tl - return (x :: ty_params, params) - | _ => do return ([], ls) - | _ => do return ([], []) - let (ty_params, params) ← get_params tys.toList - trace[Diverge.def] " parameters: {ty_params}; {params}" - return (ty_params, params, out_ty) - let (grTyParams, _, _) ← do - getTypeParams def0.type - - -- Compute the input types and the output types - let all_tys ← preDefs.mapM fun preDef => do - let (tyParams, params, ret_ty) ← getTypeParams preDef.type - -- TODO: this is not complete, there are more checks to perform - if tyParams.length ≠ grTyParams.length then - throwError "Non-uniform polymorphism" - return (params, ret_ty) - - -- TODO: I think there are issues with the free variables - let (input_tys, output_tys) := List.unzip all_tys.toList - let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) - - trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/ - - -- Compute the names set - let names := preDefs.map PreDefinition.declName - let names := HashSet.empty.insertMany names - - -- - -- for preDef in preDefs do - -- trace[Diverge.def] "about to explore: {preDef.declName}" - -- explore_term "" preDef.value - - -- Compute the bodies + -- Turn the list of input/output type pairs into an expresion + let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) + let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0)) + + -- From the list of pairs of input/output types, actually compute the + -- type of the continuation `k`. + -- We first introduce the index `i : Fin n` where `n` is the number of + -- functions in the group. + let i_var_ty := mkFin preDefs.size + withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do + let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] + trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" + -- Add an auxiliary definition for `in_out_ty` + let in_out_ty ← do + let value ← mkLambdaFVars #[i_var] in_out_ty + let name := grName.append "in_out_ty" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value + value := value + hints := .abbrev + safety := .safe + all := [name] + } + addDecl decl + -- Return the constant + let in_out_ty := Lean.mkConst name (levelParams.map .param) + mkAppM' in_out_ty #[i_var] + trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" + let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] + trace[Diverge.def] "in_ty: {in_ty}" + withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do + let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] + trace[Diverge.def] "out_ty: {out_ty}" + + -- Introduce the continuation `k` + let in_ty ← mkLambdaFVars #[i_var] in_ty + let out_ty ← mkLambdaFVars #[i_var, input] out_ty + let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] -- + trace[Diverge.def] "k_var_ty: {k_var_ty}" + withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do + trace[Diverge.def] "k_var: {k_var}" + + -- Replace the recursive calls in all the function bodies by calls to the + -- continuation `k` and and generate for those bodies declarations + let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs + -- Generate the mutually recursive body + let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies + trace[Diverge.def] "mut rec body (after decl): {body}" + + -- Prove that the mut rec body satisfies the validity criteria required by + -- our fixed-point + -- TODO + + -- Generate the final definitions + let defs ← mkDeclareFixDefs body preDefs + + -- Prove the unfolding equations + -- TODO -- Process the definitions addAndCompilePartialRec preDefs @@ -366,6 +599,10 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := if i = 0 then return x else return (← list_nth ls (i - 1)) +#print list_nth.in_out_ty +#check list_nth.body +#print list_nth + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 441b25f0..82f79f94 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -4,13 +4,14 @@ namespace Diverge open Lean Elab Term Meta -initialize registerTraceClass `Diverge.elab (inherited := true) -initialize registerTraceClass `Diverge.def.sigmas (inherited := true) -initialize registerTraceClass `Diverge.def (inherited := true) +initialize registerTraceClass `Diverge.elab +initialize registerTraceClass `Diverge.def +initialize registerTraceClass `Diverge.def.sigmas +initialize registerTraceClass `Diverge.def.genBody -- TODO: move -- TODO: small helper -def explore_term (incr : String) (e : Expr) : TermElabM Unit := +def explore_term (incr : String) (e : Expr) : MetaM Unit := match e with | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return () | .fvar _ => do logInfo m!"{incr}fvar: {e}"; return () @@ -78,4 +79,42 @@ private def test2 (x : Nat) : Nat := x print_decl test1 print_decl test2 +-- We adapted this from AbstractNestedProofs.visit +-- A map visitor function for expressions +partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do + let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do + let localInstances ← getLocalInstances + let mut lctx ← getLCtx + for x in xs do + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + let type ← mapVisit k localDecl.type + let localDecl := localDecl.setType type + let localDecl ← match localDecl.value? with + | some value => let value ← mapVisit k value; pure <| localDecl.setValue value + | none => pure localDecl + lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl + withLCtx lctx localInstances k2 + -- TODO: use a cache? (Lean.checkCache) + -- Explore + let e ← k e + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => pure e + | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (mapVisit k)) + | .lam .. => + lambdaLetTelescope e fun xs b => + mapVisitBinders xs do mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .forallE .. => do + forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← mapVisit k b) + | .letE .. => do + lambdaLetTelescope e fun xs b => mapVisitBinders xs do + mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .mdata _ b => return e.updateMData! (← mapVisit k b) + | .proj _ _ b => return e.updateProj! (← mapVisit k b) + end Diverge |