diff options
| author | Son Ho | 2023-06-30 15:53:39 +0200 | 
|---|---|---|
| committer | Son Ho | 2023-06-30 15:53:39 +0200 | 
| commit | 1c9331ce92b68b9a83c601212149a6c24591708f (patch) | |
| tree | 7918a0c930ff675bb83e5a5030dd8208a9e500e3 /backends/lean | |
| parent | fdc8693772ecb1978873018c790061854f00a015 (diff) | |
Generate the fixed-point bodies in Elab.lean
Diffstat (limited to '')
| -rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 8 | ||||
| -rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 451 | ||||
| -rw-r--r-- | backends/lean/Base/Diverge/ElabBase.lean | 47 | 
3 files changed, 391 insertions, 115 deletions
| diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 22b59bd0..aa0539ba 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -554,14 +554,14 @@ namespace FixI    /- Some utilities to define the mutually recursive functions -/    -- TODO: use more -  @[simp] def kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := +  abbrev kk_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=      (i:id) → (x:a i) → Result (b i x) -  @[simp] def k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) := +  abbrev k_ty (id : Type u) (a : id → Type v) (b : (i:id) → (x:a i) → Type w) :=      kk_ty id a b → kk_ty id a b -  def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) +  abbrev in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v)    -- TODO: remove? -  @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : +  abbrev mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) :      in_out_ty :=      Sigma.mk in_ty out_ty diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 116c5d8b..f7de7518 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,31 +16,62 @@ syntax (name := divergentDef)  open Lean Elab Term Meta Primitives Lean.Meta  set_option trace.Diverge.def true +-- set_option trace.Diverge.def.sigmas true  /- The following was copied from the `wfRecursion` function. -/  open WF in --- Replace the recursive calls by a call to the continuation --- def replace_rec_calls +def mkList (xl : List Expr) (ty : Expr) : MetaM Expr := +  match xl with +  | [] => +    mkAppOptM ``List.nil #[some ty] +  | x :: tl => do +    let tl ← mkList tl ty +    mkAppOptM ``List.cons #[some ty, some x, some tl] ---  print_decl is_even_body -#check instOfNatNat -#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... -#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ... -#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat +def mkProd (x y : Expr) : MetaM Expr := +  mkAppM ``Prod.mk #[x, y] +def mkInOutTy (x y : Expr) : MetaM Expr := +  mkAppM ``FixI.mk_in_out_ty #[x, y]  -- TODO: is there already such a utility somewhere?  -- TODO: change to mkSigmas  def mkProds (tys : List Expr) : MetaM Expr :=    match tys with -  | [] => do return (Expr.const ``PUnit.unit []) -  | [ty] => do return ty +  | [] => do pure (Expr.const ``PUnit.unit []) +  | [ty] => do pure ty    | ty :: tys => do      let pty ← mkProds tys      mkAppM ``Prod.mk #[ty, pty] +-- Return the `a` in `Return a` +def get_result_ty (ty : Expr) : MetaM Expr := +  ty.withApp fun f args => do +  if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then +    throwError "Invalid argument to get_result_ty: {ty}" +  else +    pure (args.get! 0) + +-- Group a list of expressions into a dependent tuple +def mkSigmas (xl : List Expr) : MetaM Expr := +  match xl with +  | [] => do +    trace[Diverge.def.sigmas] "mkSigmas: []" +    pure (Expr.const ``PUnit.unit []) +  | [x] => do +    trace[Diverge.def.sigmas] "mkSigmas: [{x}]" +    pure x +  | fst :: xl => do +    trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]" +    let alpha ← Lean.Meta.inferType fst +    let snd ← mkSigmas xl +    let snd_ty ← inferType snd +    let beta ← mkLambdaFVars #[fst] snd_ty +    trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}" +    mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] +  /- Generate the input type of a function body, which is a sigma type (i.e., a     dependent tuple) which groups all its inputs. @@ -55,11 +86,11 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=    match xl with    | [] => do      trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" -    return (Expr.const ``PUnit.unit []) +    pure (Expr.const ``PUnit.unit [])    | [x] => do      trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]"      let ty ← Lean.Meta.inferType x -    return ty +    pure ty    | x :: xl => do      trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]"      let alpha ← Lean.Meta.inferType x @@ -71,15 +102,26 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=  def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index -/- Generate the out_ty of the body of a function, which from an input (a sigma -   type generated by `mkSigmasTypesOfTypes`) gives the output type of the function. +/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous +   `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following +   expression: +   ``` +   fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple +   match x with +   | (x0, ..., xn) => out +   ``` +    +   The `index` parameter is used for naming purposes: we use it to numerotate the +   bound variables that we introduce.     Example: +   ======== +   More precisely:     - xl = `[a:Type, ls:List a, i:Int]` -   - out_ty = `a` -   - index = 0      -- For naming purposes: we use it to numerotate the "scrutinee" variables +   - out = `a` +   - index = 0 -   Generates: +   generates:     ```     match scrut0 with     | Sigma.mk x scrut1 => @@ -88,36 +130,47 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index         a     ```  -/ -def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr := +partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr :=    match xl with    | [] => do      -- This would be unexpected -    throwError "mkSigmasOutType: empyt list of input parameters" +    throwError "mkSigmasMatch: empyt list of input parameters"    | [x] => do      -- In the explanations above: inner match case -    trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]" -    mkLambdaFVars #[x] out_ty +    trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" +    mkLambdaFVars #[x] out    | fst :: xl => do      -- In the explanations above: outer match case      -- Remark: for the naming purposes, we use the same convention as for the      -- fields and parameters in `Sigma.casesOn and `Sigma.mk -    trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]" +    trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]"      let alpha ← Lean.Meta.inferType fst      let snd_ty ← mkSigmasTypesOfTypes xl      let beta ← mkLambdaFVars #[fst] snd_ty -    let snd ← mkSigmasOutType xl out_ty (index + 1) +    let snd ← mkSigmasMatch xl out (index + 1)      let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl)      withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do      let mk ← mkLambdaFVars #[fst] snd -    trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})" -    let motive ← mkLambdaFVars #[scrut] (← inferType out_ty) -    trace[Diverge.def.sigmas] "mkSigmasOutType:\n  ({alpha})\n  ({beta})\n  ({motive})\n  ({scrut})\n  ({mk})" -    let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] -    let out ← mkLambdaFVars #[scrut] out -    trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}" -    return out - -/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/ +    trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" +    -- TODO: make the computation of the motive more efficient +    let motive ← do +      let out_ty ← inferType out +      match out_ty  with +      | .sort _ | .lit _ | .const .. => +        -- The type of the motive doesn't depend on the scrutinee +        mkLambdaFVars #[scrut] out_ty +      | _ => +        -- The type of the motive *may* depend on the scrutinee +        -- TODO: make this more efficient (we could change the output type of +        -- mkSigmasMatch +        mkSigmasMatch (fst :: xl) out_ty +    trace[Diverge.def.sigmas] "mkSigmasMatch:\n  ({alpha})\n  ({beta})\n  ({motive})\n  ({scrut})\n  ({mk})" +    let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] +    let sm ← mkLambdaFVars #[scrut] sm +    trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" +    pure sm + +/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/  private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) :=    @Sigma.casesOn (List a)                   (fun (_ls : List a) => Int) @@ -135,14 +188,199 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) =>                    list_nth_out_ty2 a scrut1)  /- -/ +-- TODO: move +-- TODO: we can use Array.mapIdx +@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β +  | []    => [] +  | a::as => f i a :: mapiAux (i+1) f as + +@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f + +#check Array.map +-- Return the expression: `Fin n` +-- TODO: use more +def mkFin (n : Nat) : Expr := +  mkAppN (.const ``Fin []) #[.lit (.natVal n)] + +-- Return the expression: `i : Fin n` +def mkFinVal (n i : Nat) : MetaM Expr := do +  let n_lit : Expr := .lit (.natVal (n - 1)) +  let i_lit : Expr := .lit (.natVal i) +  -- We could use `trySynthInstance`, but as we know the instance that we are +  -- going to use, we can save the lookup +  let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] +  mkAppOptM ``OfNat.ofNat #[none, none, ofNat] + +-- TODO: remove? +def mkFinValOld (n i : Nat) : MetaM Expr := do +  let finTy := mkFin n +  let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)] +  match ← trySynthInstance ofNat with +  | LOption.some x => +    mkAppOptM ``OfNat.ofNat #[none, none, x] +  | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} " + +/- Generate and declare as individual definitions the bodies for the individual funcions: +   - replace the recursive calls with calls to the continutation `k` +   - make those bodies take one single dependent tuple as input + +   We name the declarations: "[original_name].body". +   We return the new declarations. + -/ +def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) +  (preDefs : Array PreDefinition) : +  MetaM (Array Expr) := do +  let grSize := preDefs.size + +  -- Compute the map from name to index - the continuation has an indexed type: +  -- we use the index (a finite number of type `Fin`) to control the function +  -- we call at the recursive call +  let nameToId : HashMap Name Nat := +    let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList +    HashMap.ofList namesIds + +  trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" + +  -- Auxiliary function to explore the function bodies and replace the +  -- recursive calls +  let visit_e (e : Expr) : MetaM Expr := do +    trace[Diverge.def.genBody] "visiting expression: {e}" +    match e with +    | .app .. => do +      e.withApp fun f args => do +        trace[Diverge.def.genBody] "this is an app: {f} {args}" +        -- Check if this is a recursive call +        if f.isConst then +          let name := f.constName! +          match nameToId.find? name with +          | none => pure e +          | some id => +            -- This is a recursive call: replace it +            -- Compute the index +            let i ← mkFinVal grSize id +            -- Put the arguments in one big dependent tuple +            let args ← mkSigmas args.toList +            mkAppM' k_var #[i, args] +        else +          -- Not a recursive call: do nothing +          pure e +     | .const name _ => +       -- Sanity check: we eliminated all the recursive calls +       if (nameToId.find? name).isSome then +         throwError "mkUnaryBodies: a recursive call was not eliminated" +       else pure e +     | _ => pure e + +  -- Explore the bodies +  preDefs.mapM fun preDef => do +    -- Replace the recursive calls +    let body ← mapVisit visit_e preDef.value + +    -- Change the type +    lambdaLetTelescope body fun args body => do +    let body ← mkSigmasMatch args.toList body 0 + +    -- Add the declaration +    let value ← mkLambdaFVars #[k_var] body +    let name := preDef.declName.append "body" +    let levelParams := grLvlParams +    let decl := Declaration.defnDecl { +      name := name +      levelParams := levelParams +      type := ← inferType value -- TODO: change the type +      value := value +      hints := ReducibilityHints.regular (getMaxHeight (← getEnv) value + 1) +      safety := .safe +      all := [name] +    } +    addDecl decl +    trace[Diverge.def] "individual body of {preDef.declName}: {body}" +    -- Return the constant +    let body := Lean.mkConst name (levelParams.map .param) +    -- let body ← mkAppM' body #[k_var] +    trace[Diverge.def] "individual body (after decl): {body}" +    pure body + +-- Generate a unique function body from the bodies of the mutually recursive group, +-- and add it as a declaration in the context +def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) +  (i_var k_var : Expr) +  (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) +  (bodies : Array Expr) : MetaM Expr := do +  -- Generate the body +  let grSize := bodies.size +  let finTypeExpr := mkFin grSize +  -- TODO: not very clean +  let inOutTyType ← do +    let (x, y) := inOutTys.get! 0 +    inferType (← mkInOutTy x y) +  let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := +    match inOutTys, bl with +    | [], [] => +      mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] +    | (ity, oty) :: inOutTys, b :: bl => do +      -- Retrieving ity and oty - this is not very clean +      let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType +      let fl ← mkFuns inOutTys bl +      mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] +    | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length" +  let bodyFuns ← mkFuns inOutTys bodies.toList +  -- Wrap in `get_fun` +  let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var] +  -- Add the index `i` and the continuation `k` as a variables +  let body ← mkLambdaFVars #[k_var, i_var] body +  trace[Diverge.def] "mkDeclareMutualBody: body: {body}" +  -- Add the declaration +  let name := grName.append "mutrec_body" +  let levelParams := grLvlParams +  let decl := Declaration.defnDecl { +    name := name +    levelParams := levelParams +    type := ← inferType body +    value := body +    hints := ReducibilityHints.regular (getMaxHeight (← getEnv) body + 1) +    safety := .safe +    all := [name] +  } +  addDecl decl +  -- Return the constant +  pure (Lean.mkConst name (levelParams.map .param)) + +-- Generate the final definions by using the mutual body and the fixed point operator. +def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : +  TermElabM Unit := do +  let grSize := preDefs.size +  let _ ← preDefs.mapIdxM fun idx preDef => do +    lambdaLetTelescope preDef.value fun xs _ => do +    -- Create the index +    let idx ← mkFinVal grSize idx.val +    -- Group the inputs into a dependent tuple +    let input ← mkSigmas xs.toList +    -- Apply the fixed point +    let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input] +    let fixedBody ← mkLambdaFVars xs fixedBody +    -- Create the declaration +    let name := preDef.declName +    let decl := Declaration.defnDecl { +      name := name +      levelParams := preDef.levelParams +      type := preDef.type +      value := fixedBody +      hints := ReducibilityHints.regular (getMaxHeight (← getEnv) fixedBody + 1) +      safety := .safe +      all := [name] +    } +    addDecl decl +  pure () +  def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do    let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)    trace[Diverge.def] ("divRecursion: defs: " ++ msg)    -- CHANGE HERE This function should add definitions with these names/types/values ^^    -- Temporarily add the predefinitions as axioms -  for preDef in preDefs do -    addAsAxiom preDef +  -- for preDef in preDefs do +  --  addAsAxiom preDef    -- TODO: what is this?    for preDef in preDefs do @@ -154,25 +392,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do    let grName := def0.declName    trace[Diverge.def] "group name: {grName}" -  /- Compute the type of the continuation. - -     We do the following -     - we make sure all the definitions have the same universe parameters -       (we can make this more general later) -     - we group all the type parameters together, make sure all the -       definitions have the same type parameters, and enforce -       a uniform polymorphism (we can also lift this later). -       This would require generalizing a bit our indexed fixed point to -       make the output type parametric in the input. -     - we group all the non-type parameters: we parameterize the continuation -       by those -  -/ +  /- # Compute the input/output types of the continuation `k`. -/    let grLvlParams := def0.levelParams -  trace[Diverge.def] "def0 type: {def0.type}" +  trace[Diverge.def] "def0 universe levels: {def0.levelParams}" -  -- Compute the list of pairs: (input type × output type) +  -- We first compute the list of pairs: (input type × output type)    let inOutTys : Array (Expr × Expr) ←        preDefs.mapM (fun preDef => do +        withRef preDef.ref do -- is the withRef useful?          -- Check the universe parameters - TODO: I'm not sure what the best thing          -- to do is. In practice, all the type parameters should be in Type 0, so          -- we shouldn't have universe issues. @@ -180,68 +407,74 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do            throwError "Non-uniform polymorphism in the universes"          forallTelescope preDef.type (fun in_tys out_ty => do            let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) -          let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty) -          return (in_ty, out_ty) +          -- Retrieve the type in the "Result" +          let out_ty ← get_result_ty out_ty +          let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) +          pure (in_ty, out_ty)          )        )    trace[Diverge.def] "inOutTys: {inOutTys}" - -/-  -- Small utility: compute the list of type parameters -  let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := -    Lean.Meta.forallTelescope ty fun tys out_ty => do -      trace[Diverge.def] "types: {tys}" -/-      let (_, params) ← StateT.run (do -        for x in tys do -          let ty ← Lean.Meta.inferType x -          match ty with -          | .sort _ => do -            let st ← StateT.get -            StateT.set (ty :: st) -          | _ => do break -        ) ([] : List Expr) -      let params := params.reverse -      trace[Diverge.def] "  type parameters {params}" -      return params -/ -      let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) := -        match ls with -        | x :: tl => do -          let ty ← Lean.Meta.inferType x -          match ty with -          | .sort _ => do -            let (ty_params, params) ← get_params tl -            return (x :: ty_params, params) -          | _ => do return ([], ls) -        | _ => do return ([], []) -      let (ty_params, params) ← get_params tys.toList -      trace[Diverge.def] "  parameters: {ty_params}; {params}" -      return (ty_params, params, out_ty) -  let (grTyParams, _, _) ← do -    getTypeParams def0.type - -  -- Compute the input types and the output types -  let all_tys ← preDefs.mapM fun preDef => do -    let (tyParams, params, ret_ty) ← getTypeParams preDef.type -    -- TODO: this is not complete, there are more checks to perform -    if tyParams.length ≠ grTyParams.length then -      throwError "Non-uniform polymorphism" -    return (params, ret_ty) - -  -- TODO: I think there are issues with the free variables -  let (input_tys, output_tys) := List.unzip all_tys.toList -  let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) - -  trace[Diverge.def] "  in/out tys: {input_tys}; {output_tys}" -/ - -  -- Compute the names set -  let names := preDefs.map PreDefinition.declName -  let names := HashSet.empty.insertMany names - -  -- -  -- for preDef in preDefs do -  --  trace[Diverge.def] "about to explore: {preDef.declName}" -  --  explore_term "" preDef.value - -  -- Compute the bodies +  -- Turn the list of input/output type pairs into an expresion +  let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) +  let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0)) + +  -- From the list of pairs of input/output types, actually compute the +  -- type of the continuation `k`. +  -- We first introduce the index `i : Fin n` where `n` is the number of +  -- functions in the group. +  let i_var_ty := mkFin preDefs.size +  withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do +  let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] +  trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" +  -- Add an auxiliary definition for `in_out_ty` +  let in_out_ty ← do +    let value ← mkLambdaFVars #[i_var] in_out_ty +    let name := grName.append "in_out_ty" +    let levelParams := grLvlParams +    let decl := Declaration.defnDecl { +      name := name +      levelParams := levelParams +      type := ← inferType value +      value := value +      hints := .abbrev +      safety := .safe +      all := [name] +    } +    addDecl decl +    -- Return the constant +    let in_out_ty := Lean.mkConst name (levelParams.map .param) +    mkAppM' in_out_ty #[i_var] +  trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" +  let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] +  trace[Diverge.def] "in_ty: {in_ty}" +  withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do +  let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] +  trace[Diverge.def] "out_ty: {out_ty}" + +  -- Introduce the continuation `k` +  let in_ty ← mkLambdaFVars #[i_var] in_ty +  let out_ty ← mkLambdaFVars #[i_var, input] out_ty +  let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] -- +  trace[Diverge.def] "k_var_ty: {k_var_ty}" +  withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do +  trace[Diverge.def] "k_var: {k_var}" + +  -- Replace the recursive calls in all the function bodies by calls to the +  -- continuation `k` and and generate for those bodies declarations +  let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs +  -- Generate the mutually recursive body +  let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies +  trace[Diverge.def] "mut rec body (after decl): {body}" + +  -- Prove that the mut rec body satisfies the validity criteria required by +  -- our fixed-point +  -- TODO + +  -- Generate the final definitions +  let defs ← mkDeclareFixDefs body preDefs + +  -- Prove the unfolding equations +  -- TODO    -- Process the definitions    addAndCompilePartialRec preDefs @@ -366,6 +599,10 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=      if i = 0 then return x      else return (← list_nth ls (i - 1)) +#print list_nth.in_out_ty +#check list_nth.body +#print list_nth +  mutual    divergent def is_even (i : Int) : Result Bool :=      if i = 0 then return true else return (← is_odd (i - 1)) diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 441b25f0..82f79f94 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -4,13 +4,14 @@ namespace Diverge  open Lean Elab Term Meta -initialize registerTraceClass `Diverge.elab (inherited := true) -initialize registerTraceClass `Diverge.def.sigmas (inherited := true) -initialize registerTraceClass `Diverge.def (inherited := true) +initialize registerTraceClass `Diverge.elab +initialize registerTraceClass `Diverge.def +initialize registerTraceClass `Diverge.def.sigmas +initialize registerTraceClass `Diverge.def.genBody  -- TODO: move  -- TODO: small helper -def explore_term (incr : String) (e : Expr) : TermElabM Unit := +def explore_term (incr : String) (e : Expr) : MetaM Unit :=    match e with    | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return ()    | .fvar _ => do logInfo m!"{incr}fvar: {e}"; return () @@ -78,4 +79,42 @@ private def test2 (x : Nat) : Nat := x  print_decl test1  print_decl test2 +-- We adapted this from AbstractNestedProofs.visit +-- A map visitor function for expressions +partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do +  let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do +    let localInstances ← getLocalInstances +    let mut lctx ← getLCtx +    for x in xs do +      let xFVarId := x.fvarId! +      let localDecl ← xFVarId.getDecl +      let type      ← mapVisit k localDecl.type +      let localDecl := localDecl.setType type +      let localDecl ← match localDecl.value? with +         | some value => let value ← mapVisit k value; pure <| localDecl.setValue value +         | none       => pure localDecl +      lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl +    withLCtx lctx localInstances k2 +  -- TODO: use a cache? (Lean.checkCache) +  -- Explore +  let e ← k e +  match e with +  | .bvar _ +  | .fvar _ +  | .mvar _ +  | .sort _ +  | .lit _ +  | .const _ _ => pure e +  | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (mapVisit k)) +  | .lam .. => +    lambdaLetTelescope e fun xs b => +      mapVisitBinders xs do mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) +  | .forallE .. => do +    forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← mapVisit k b) +  | .letE .. => do +    lambdaLetTelescope e fun xs b => mapVisitBinders xs do +      mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) +  | .mdata _ b => return e.updateMData! (← mapVisit k b) +  | .proj _ _ b => return e.updateProj! (← mapVisit k b) +  end Diverge | 
