From c23a37617188a1bbf913b5c700522abc33bf39c9 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 17:00:01 +0100 Subject: Update Diverge/Elab.lean to use the more general FixII definitions --- backends/lean/Base/Diverge/Elab.lean | 571 +++++++++++++++++++++++------------ 1 file changed, 383 insertions(+), 188 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 0088fd16..423a2514 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -24,8 +24,8 @@ open WF in def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] -def mkInOutTy (x y : Expr) : MetaM Expr := - mkAppM ``FixI.mk_in_out_ty #[x, y] +def mkInOutTy (x y z : Expr) : MetaM Expr := do + mkAppM ``FixII.mk_in_out_ty #[x, y, z] -- Return the `a` in `Return a` def getResultTy (ty : Expr) : MetaM Expr := @@ -60,21 +60,83 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do def mkSigmasType (xl : List Expr) : MetaM Expr := match xl with | [] => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" - pure (Expr.const ``PUnit.unit []) + trace[Diverge.def.sigmas] "mkSigmasType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) | [x] => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" + trace[Diverge.def.sigmas] "mkSigmasType: [{x}]" let ty ← Lean.Meta.inferType x pure ty | x :: xl => do - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x let sty ← mkSigmasType xl - trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" + trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}" let beta ← mkLambdaFVars #[x] sty - trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" + trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] +/- Generate a product type from a list of *variables* (this is similar to `mkSigmas`). + + Example: + - xl = [(ls:List a), (i:Int)] + + Generates: + `List a × Int` + -/ +def mkProdsType (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.prods] "mkProdsType: []" + pure (Expr.const ``PUnit [Level.succ .zero]) + | [x] => do + trace[Diverge.def.prods] "mkProdsType: [{x}]" + let ty ← Lean.Meta.inferType x + pure ty + | x :: xl => do + trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]" + let ty ← Lean.Meta.inferType x + let xl_ty ← mkProdsType xl + mkAppM ``Prod #[ty, xl_ty] + +/- Split the input arguments between the types and the "regular" arguments. + + We do something simple: we treat an input argument as an + input type iff it appears in the type of the following arguments. + + Note that what really matters is that we find the arguments which appear + in the output type. + + Also, we stop at the first input that we treat as an + input type. + -/ +def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × Array Expr) := do + -- Look for the first parameter which appears in the subsequent parameters + let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) := + match in_tys with + | [] => do + let fvars ← getFVarIds (← Lean.Meta.inferType out_ty) + pure (fvars, [], []) + | ty :: in_tys => do + let (fvars, in_tys, in_args) ← splitAux in_tys + -- Have we already found where to split between type variables/regular + -- variables? + if ¬ in_tys.isEmpty then + -- The fvars set is now useless: no need to update it anymore + pure (fvars, ty :: in_tys, in_args) + else + -- Check if ty appears in the set of free variables: + let ty_id := ty.fvarId! + if fvars.contains ty_id then + -- We must split here. Note that we don't need to update the fvars + -- set: it is not useful anymore + pure (fvars, [ty], in_args) + else + -- We must split later: update the fvars set + let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty)) + pure (fvars, [], ty :: in_args) + let (_, in_tys, in_args) ← splitAux in_tys.data + pure (Array.mk in_tys, Array.mk in_args) + /- Apply a lambda expression to some arguments, simplifying the lambdas -/ def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do lambdaTelescopeN e xs.size fun vars body => @@ -105,7 +167,7 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasVal: []" - pure (Expr.const ``PUnit.unit []) + pure (Expr.const ``PUnit.unit [Level.succ .zero]) | [x] => do trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" pure x @@ -122,6 +184,17 @@ def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] +/- Group a list of expressions into a (non-dependent) tuple -/ +def mkProdsVal (xl : List Expr) : MetaM Expr := + match xl with + | [] => + pure (Expr.const ``PUnit.unit [Level.succ .zero]) + | [x] => do + pure x + | x :: xl => do + let xl ← mkProdsVal xl + mkAppM ``Prod.mk #[x, xl] + def mkAnonymous (s : String) (i : Nat) : Name := .num (.str .anonymous s) i @@ -159,23 +232,23 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met match xl with | [] => do -- This would be unexpected - throwError "mkSigmasMatch: empyt list of input parameters" + throwError "mkSigmasMatch: empty list of input parameters" | [x] => do -- In the example given for the explanations: this is the inner match case trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" mkLambdaFVars #[x] out | fst :: xl => do - -- In the example given for the explanations: this is the outer match case - -- Remark: for the naming purposes, we use the same convention as for the - -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at - -- those definitions might help) - -- - -- We want to build the match expression: - -- ``` - -- λ scrut => - -- match scrut with - -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail - -- ``` + /- In the example given for the explanations: this is the outer match case + Remark: for the naming purposes, we use the same convention as for the + fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + those definitions might help) + + We want to build the match expression: + ``` + λ scrut => + match scrut with + | Sigma.mk x ... -- the hole is given by a recursive call on the tail + ``` -/ trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst let snd_ty ← mkSigmasType xl @@ -183,7 +256,7 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met let snd ← mkSigmasMatch xl out (index + 1) let mk ← mkLambdaFVars #[fst] snd -- Introduce the "scrut" variable - let scrut_ty ← mkSigmasType (fst :: xl) + let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient @@ -206,6 +279,67 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" pure sm +/- This is similar to `mkSigmasMatch`, but with non-dependent tuples + + Remark: factor out with `mkSigmasMatch`? This is extremely similar. +-/ +partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkProdsMatch: empty list of input parameters" + | [x] => do + -- In the example given for the explanations: this is the inner match case + trace[Diverge.def.prods] "mkProdsMatch: [{x}]" + mkLambdaFVars #[x] out + | fst :: xl => do + trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let beta ← mkProdsType xl + let snd ← mkProdsMatch xl out (index + 1) + let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do + trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + mkLambdaFVars #[scrut] out_ty + -- The final expression: putting everything together + trace[Diverge.def.prods] "mkProdsMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Prod.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable + let sm ← mkLambdaFVars #[scrut] sm + trace[Diverge.def.prods] "mkProdsMatch: sm: {sm}" + pure sm + +/- Same as `mkSigmasMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkSigmasMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkSigmasMatch xl out + +/- Same as `mkProdsMatch` but also accepts an empty list of inputs, in which case + it generates the expression: + ``` + λ () => e + ``` -/ +def mkProdsMatchOrUnit (xl : List Expr) (out : Expr) : MetaM Expr := + if xl.isEmpty then do + let scrut_ty := Expr.const ``PUnit [Level.succ .zero] + withLocalDeclD (mkAnonymous "scrut" 0) scrut_ty fun scrut => do + mkLambdaFVars #[scrut] out + else + mkProdsMatch xl out + /- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) @@ -244,17 +378,23 @@ def mkFinVal (n i : Nat) : MetaM Expr := do We name the declarations: "[original_name].body". We return the new declarations. + + Inputs: + - `paramInOutTys`: (number of type parameters, sigma type grouping the type parameters, input type, output type) -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : + (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - -- Compute the map from name to (index × input type). - -- Remark: the continuation has an indexed type; we use the index (a finite number of - -- type `Fin`) to control which function we call at the recursive call site. - let nameToInfo : HashMap Name (Nat × Expr) := - let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst)) + /- Compute the map from name to (index, num type parameters, parameters type, input type). + Example for `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: `"list_nth" → (0, 1, α, (List α × Int))` + Remark: the continuation has an indexed type; we use the index (a finite number of + type `Fin`) to control which function we call at the recursive call site. -/ + let nameToInfo : HashMap Name (Nat × Nat × Expr × Expr) := + let bl := preDefs.mapIdx fun i d => + let (num_params, params_ty, in_ty, _) := paramInOutTys.get! i.val + (d.declName, (i.val, num_params, params_ty, in_ty)) HashMap.ofList bl.toList trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" @@ -262,25 +402,29 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Auxiliary function to explore the function bodies and replace the -- recursive calls let visit_e (i : Nat) (e : Expr) : MetaM Expr := do - trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}" + trace[Diverge.def.genBody.visit] "visiting expression (dept: {i}): {e}" let ne ← do match e with | .app .. => do e.withApp fun f args => do - trace[Diverge.def.genBody] "this is an app: {f} {args}" + trace[Diverge.def.genBody.visit] "this is an app: {f} {args}" -- Check if this is a recursive call if f.isConst then let name := f.constName! match nameToInfo.find? name with | none => pure e - | some (id, in_ty) => - trace[Diverge.def.genBody] "this is a recursive call" + | some (id, num_params, params_ty, _in_ty) => + trace[Diverge.def.genBody.visit] "this is a recursive call" -- This is a recursive call: replace it -- Compute the index let i ← mkFinVal grSize id - -- Put the arguments in one big dependent tuple - let args ← mkSigmasVal in_ty args.toList - mkAppM' kk_var #[i, args] + -- Split the arguments, and put them in two tuples (the first + -- one is a dependent tuple) + let (param_args, args) := args.toList.splitAt num_params + trace[Diverge.def.genBody.visit] "param_args: {param_args}, args: {args}" + let param_args ← mkSigmasVal params_ty param_args + let args ← mkProdsVal args + mkAppM' kk_var #[i, param_args, args] else -- Not a recursive call: do nothing pure e @@ -290,7 +434,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) throwError "mkUnaryBodies: a recursive call was not eliminated" else pure e | _ => pure e - trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}" + trace[Diverge.def.genBody.visit] "done with expression (depth: {i}): {e}" pure ne -- Explore the bodies @@ -300,13 +444,20 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let body ← mapVisit visit_e preDef.value trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" - -- Currify the function by grouping the arguments into a dependent tuple + -- Currify the function by grouping the arguments into dependent tuples -- (over which we match to retrieve the individual arguments). lambdaTelescope body fun args body => do - let body ← mkSigmasMatch args.toList body 0 + -- Split the arguments between the type parameters and the "regular" inputs + let (_, num_params, _, _) := nameToInfo.find! preDef.declName + let (param_args, args) := args.toList.splitAt num_params + let body ← mkProdsMatchOrUnit args body + trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}" + let body ← mkSigmasMatchOrUnit param_args body + trace[Diverge.def.genBody] "Body after mkSigmasMatchOrUnit: {body}" -- Add the declaration let value ← mkLambdaFVars #[kk_var] body + trace[Diverge.def.genBody] "Body after abstracting kk: {value}" let name := preDef.declName.append "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { @@ -318,6 +469,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) safety := .safe all := [name] } + trace[Diverge.def.genBody] "About to add decl" addDecl decl trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant @@ -326,33 +478,35 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) trace[Diverge.def] "individual body (after decl): {body}" pure body --- Generate a unique function body from the bodies of the mutually recursive group, --- and add it as a declaration in the context. --- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. +/- Generate a unique function body from the bodies of the mutually recursive group, + and add it as a declaration in the context. + We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. + -/ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) - (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) + (param_ty in_ty out_ty : Expr) (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize -- TODO: not very clean - let inOutTyType ← do - let (x, y) := inOutTys.get! 0 - inferType (← mkInOutTy x y) - let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := - match inOutTys, bl with + let paramInOutTyType ← do + let (_, x, y, z) := paramInOutTys.get! 0 + inferType (← mkInOutTy x y z) + let rec mkFuns (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bl : List Expr) : MetaM Expr := + match paramInOutTys, bl with | [], [] => - mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] - | (ity, oty) :: inOutTys, b :: bl => do + mkAppOptM ``FixII.Funs.Nil #[finTypeExpr, param_ty, in_ty, out_ty] + | (_, pty, ity, oty) :: paramInOutTys, b :: bl => do -- Retrieving ity and oty - this is not very clean - let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) - let fl ← mkFuns inOutTys bl - mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] + let paramInOutTysExpr ← mkListLit paramInOutTyType + (← paramInOutTys.mapM (λ (_, x, y, z) => mkInOutTy x y z)) + let fl ← mkFuns paramInOutTys bl + mkAppOptM ``FixII.Funs.Cons #[finTypeExpr, param_ty, in_ty, out_ty, pty, ity, oty, paramInOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" - let bodyFuns ← mkFuns inOutTys bodies.toList + let bodyFuns ← mkFuns paramInOutTys bodies.toList -- Wrap in `get_fun` - let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var] + let body ← mkAppM ``FixII.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables let body ← mkLambdaFVars #[kk_var, i_var] body trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" @@ -391,11 +545,11 @@ instance : ToMessageData MatchInfo where -- This is not a very clean formatting, but we don't need more toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" --- Small helper: prove that an expression which doesn't use the continuation `kk` --- is valid, and return the proof. +/- Small helper: prove that an expression which doesn't use the continuation `kk` + is valid, and return the proof. -/ def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" - let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] + let eIsValid ← mkAppM ``FixII.is_valid_p_same #[k_var, e] trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid @@ -462,8 +616,11 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do mkLambdaFVars #[x] brValid let br0Valid ← proveBranchValid br0 let br1Valid ← proveBranchValid br1 - let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite - let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] + let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite + let eIsValid ← + mkAppOptM const #[none, none, none, none, none, + some k_var, some cond, some dec, none, none, + some br0Valid, some br1Valid] trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid -- Check if the expression is a match (this case is for when the elaborator @@ -498,15 +655,16 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- use to currify function bodies, introduce such raw matches. else if ← isCasesExpr f then do trace[Diverge.def.valid] "rawMatch: {e}" - -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. - -- - -- The casesOn definition is always of the following shape: - -- - input parameters (implicit parameters) - -- - motive (implicit), -- the motive gives the return type of the match - -- - scrutinee (explicit) - -- - branches (explicit). - -- In particular, we notice that the scrutinee is the first *explicit* - -- parameter - this is how we spot it. + /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + + The casesOn definition is always of the following shape: + - input parameters (implicit parameters) + - motive (implicit), -- the motive gives the return type of the match + - scrutinee (explicit) + - branches (explicit). + In particular, we notice that the scrutinee is the first *explicit* + parameter - this is how we spot it. + -/ let matcherName := f.constName! let matcherLevels := f.constLevels!.toArray -- Find the first explicit parameter: this is the scrutinee @@ -526,10 +684,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let motive := args.get! (scrutIdx - 1) let scrut := args.get! scrutIdx let branches := args.extract (scrutIdx + 1) args.size - -- Compute the number of parameters for the branches: for this we use - -- the type of the uninstantiated casesOn constant (we can't just - -- destruct the lambdas in the branch expressions because the result - -- of a match might be a lambda expression). + /- Compute the number of parameters for the branches: for this we use + the type of the uninstantiated casesOn constant (we can't just + destruct the lambdas in the branch expressions because the result + of a match might be a lambda expression). -/ let branchesNumParams : Array Nat ← do let env ← getEnv let decl := env.constants.find! matcherName @@ -572,14 +730,15 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do pure yValid -- Put everything together trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" - mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] + mkAppM ``FixII.is_valid_p_bind #[xValid, yValid] -- Check if this is a recursive call, i.e., a call to the continuation `kk` else if f.isFVarOf kk_var.fvarId! then do trace[Diverge.def.valid] "rec: args: \n{args}" - if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" + if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}" let i_arg := args.get! 0 - let x_arg := args.get! 1 - let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg] + let t_arg := args.get! 1 + let x_arg := args.get! 2 + let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg] trace[Diverge.def.valid] "rec: result: \n{eIsValid}" pure eIsValid else do @@ -593,10 +752,10 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp trace[Diverge.def.valid] "proveMatchIsValid: {me}" -- Prove the validity of the branch expressions let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do - -- Go inside the lambdas - note that we have to be careful: some of the - -- binders might come from the match, and some of the binders might come - -- from the fact that the expression in the match is a lambda expression: - -- we use the branchesNumParams field for this reason + /- Go inside the lambdas - note that we have to be careful: some of the + binders might come from the match, and some of the binders might come + from the fact that the expression in the match is a lambda expression: + we use the branchesNumParams field for this reason. -/ let numParams := me.branchesNumParams.get! idx lambdaTelescopeN br numParams fun xs br => do -- Prove that the branch expression is valid @@ -604,13 +763,13 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- Reconstruct the lambda expression mkLambdaFVars xs brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" - -- Compute the motive, which has the following shape: - -- ``` - -- λ scrut => is_valid_p k (λ k => match scrut with ...) - -- ^^^^^^^^^^^^^^^^^^^^ - -- this is the original match expression, with the - -- the difference that the scrutinee(s) is a variable - -- ``` + /- Compute the motive, which has the following shape: + ``` + λ scrut => is_valid_p k (λ k => match scrut with ...) + ^^^^^^^^^^^^^^^^^^^^ + this is the original match expression, with the + the difference that the scrutinee(s) is a variable + ``` -/ let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees @@ -629,7 +788,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp let matchE ← mkAppOptM me.matcherName args -- Wrap in the `is_valid_p` predicate let matchE ← mkLambdaFVars #[kk_var] matchE - let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] + let validMotive ← mkAppM ``FixII.is_valid_p #[k_var, matchE] -- Abstract away the scrutinee variables mkLambdaFVars scrutVars validMotive trace[Diverge.def.valid] "valid motive: {validMotive}" @@ -647,10 +806,10 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp end --- Prove that a single body (in the mutually recursive group) is valid. --- --- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], --- we prove that `is_even.body` and `is_odd.body` are valid. +/- Prove that a single body (in the mutually recursive group) is valid. + + For instance, if we define the mutually recursive group [`is_even`, `is_odd`], + we prove that `is_even.body` and `is_odd.body` are valid. -/ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : MetaM Expr := do @@ -662,24 +821,28 @@ partial def proveSingleBodyIsValid let body := (env.constants.find! name).value! trace[Diverge.def.valid] "body: {body}" lambdaTelescope body fun xs body => do - assert! xs.size = 2 + trace[Diverge.def.valid] "xs: {xs}" + assert! xs.size = 3 let kk_var := xs.get! 0 - let x_var := xs.get! 1 + let t_var := xs.get! 1 + let x_var := xs.get! 2 -- State the type of the theorem to prove - let thmTy ← mkAppM ``FixI.is_valid_p - #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])] + trace[Diverge.def.valid] "bodyConst: {bodyConst} : {← inferType bodyConst}" + let bodyApp ← mkAppOptM' bodyConst #[.some kk_var, .some t_var, .some x_var] + trace[Diverge.def.valid] "bodyApp: {bodyApp}" + let bodyApp ← mkLambdaFVars #[kk_var] bodyApp + trace[Diverge.def.valid] "bodyApp: {bodyApp}" + let thmTy ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp] trace[Diverge.def.valid] "thmTy: {thmTy}" -- Prove that the body is valid let proof ← proveExprIsValid k_var kk_var body - let proof ← mkLambdaFVars #[k_var, x_var] proof + let proof ← mkLambdaFVars #[k_var, t_var, x_var] proof trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" -- The target type (we don't have to do this: this is simply a sanity check, -- and this allows a nicer debugging output) let thmTy ← do - let body ← mkAppM' bodyConst #[kk_var, x_var] - let body ← mkLambdaFVars #[kk_var] body - let ty ← mkAppM ``FixI.is_valid_p #[k_var, body] - mkForallFVars #[k_var, x_var] ty + let ty ← mkAppM ``FixII.is_valid_p #[k_var, bodyApp] + mkForallFVars #[k_var, t_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem let name := preDef.declName ++ "body_is_valid" @@ -695,18 +858,18 @@ partial def proveSingleBodyIsValid -- Return the theorem pure (Expr.const name (preDef.levelParams.map .param)) --- Prove that the list of bodies are valid. --- --- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], --- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is --- valid. -partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) +/- Prove that the list of bodies are valid. + + For instance, if we define the mutually recursive group [`is_even`, `is_odd`], + we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is + valid. -/ +partial def proveFunsBodyIsValid (paramInOutTys: Expr) (bodyFuns : Expr) (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do -- Create the big "and" expression, which groups the validity proof of the individual bodies let rec mkValidConj (i : Nat) : MetaM Expr := do if i = bodiesValid.size then -- We reached the end - mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var] + mkAppM ``FixII.Funs.is_valid_p_Nil #[k_var] else do -- We haven't reached the end: introduce a conjunction let valid := bodiesValid.get! i @@ -714,20 +877,20 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)] let andExpr ← mkValidConj 0 -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation - let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] + let isValid ← mkAppM ``FixII.Funs.is_valid_p_is_valid_p #[paramInOutTys, k_var, bodyFuns, andExpr] mkLambdaFVars #[k_var] isValid --- Prove that the mut rec body (i.e., the unary body which groups the bodies --- of all the functions in the mutually recursive group and on which we will --- apply the fixed-point operator) is valid. --- --- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", --- which we return. --- --- TODO: maybe this function should introduce k_var itself +/- Prove that the mut rec body (i.e., the unary body which groups the bodies + of all the functions in the mutually recursive group and on which we will + apply the fixed-point operator) is valid. + + We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", + which we return. + + TODO: maybe this function should introduce k_var itself -/ def proveMutRecIsValid (grName : Name) (grLvlParams : List Name) - (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) + (paramInOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) (k_var : Expr) (preDefs : Array PreDefinition) (bodies : Array Expr) : MetaM Expr := do -- First prove that the individual bodies are valid @@ -738,9 +901,9 @@ def proveMutRecIsValid proveSingleBodyIsValid k_var preDef body -- Then prove that the mut rec body is valid trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" - let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid + let isValid ← proveFunsBodyIsValid paramInOutTys bodyFuns k_var bodiesValid -- Save the theorem - let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] + let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst] let name := grName ++ "mut_rec_body_is_valid" let decl := Declaration.thmDecl { name @@ -754,26 +917,29 @@ def proveMutRecIsValid -- Return the theorem pure (Expr.const name (grLvlParams.map .param)) --- Generate the final definions by using the mutual body and the fixed point operator. --- --- For instance: --- ``` --- def is_even (i : Int) : Result Bool := mut_rec_body 0 i --- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i --- ``` -def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : +/- Generate the final definions by using the mutual body and the fixed point operator. + + For instance: + ``` + def is_even (i : Int) : Result Bool := mut_rec_body 0 i + def is_odd (i : Int) : Result Bool := mut_rec_body 1 i + ``` + -/ +def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do - -- Retrieve the input type - let in_ty := (inOutTys.get! idx.val).fst + -- Retrieve the parameters info + let (num_params, param_ty, _, _) := paramInOutTys.get! idx.val -- Create the index let idx ← mkFinVal grSize idx.val - -- Group the inputs into a dependent tuple - let input ← mkSigmasVal in_ty xs.toList + -- Group the inputs into two tuples + let (params_args, input_args) := xs.toList.splitAt num_params + let params ← mkSigmasVal param_ty params_args + let input ← mkProdsVal input_args -- Apply the fixed point - let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] + let fixedBody ← mkAppM ``FixII.fix #[mutRecBody, idx, params, input] let fixedBody ← mkLambdaFVars xs fixedBody -- Create the declaration let name := preDef.declName @@ -791,7 +957,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preD pure defs -- Prove the equations that we will use as unfolding theorems -partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr)) +partial def proveUnfoldingThms (isValidThm : Expr) + (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do @@ -811,14 +978,18 @@ partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Ex trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}" -- The proof -- Use the fixed-point equation - let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm] + let proof ← mkAppM ``FixII.is_valid_fix_fixed_eq #[isValidThm] -- Add the index let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] - -- Add the input argument - let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList - let proof ← mkAppM ``congr_fun #[proof, arg] - -- Abstract the arguments away + -- Add the input arguments + let (num_params, param_ty, _, _) := paramInOutTys.get! i + let (params, args) := xs.toList.splitAt num_params + let params ← mkSigmasVal param_ty params + let args ← mkProdsVal args + let proof ← mkAppM ``congr_fun #[proof, params] + let proof ← mkAppM ``congr_fun #[proof, args] + -- Abstract all the arguments away let proof ← mkLambdaFVars xs proof trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" -- Declare the theorem @@ -846,7 +1017,9 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) - -- TODO: what is this? + -- Apply all the "attribute" functions (for instance, the function which + -- registers the theorem in the simp database if there is the `simp` attribute, + -- etc.) for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation @@ -860,40 +1033,52 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grLvlParams := def0.levelParams trace[Diverge.def] "def0 universe levels: {def0.levelParams}" - -- We first compute the list of pairs: (input type × output type) - let inOutTys : Array (Expr × Expr) ← - preDefs.mapM (fun preDef => do - withRef preDef.ref do -- is the withRef useful? - -- Check the universe parameters - TODO: I'm not sure what the best thing - -- to do is. In practice, all the type parameters should be in Type 0, so - -- we shouldn't have universe issues. - if preDef.levelParams ≠ grLvlParams then - throwError "Non-uniform polymorphism in the universes" - forallTelescope preDef.type (fun in_tys out_ty => do - let in_ty ← liftM (mkSigmasType in_tys.toList) - -- Retrieve the type in the "Result" - let out_ty ← getResultTy out_ty - let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) - pure (in_ty, out_ty) - ) - ) - trace[Diverge.def] "inOutTys: {inOutTys}" - -- Turn the list of input/output type pairs into an expresion - let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) - let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList - - -- From the list of pairs of input/output types, actually compute the - -- type of the continuation `k`. - -- We first introduce the index `i : Fin n` where `n` is the number of - -- functions in the group. + /- We first compute the tuples: (type parameters × input type × output type) + - type parameters: this is a sigma type + - input type: λ params_type => product type + - output type: λ params_type => output type + For instance, on the function: + `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: + we generate: + `(Type, λ α => List α × i, λ α => Result α)` + -/ + let paramInOutTys : Array (ℕ × Expr × Expr × Expr) ← + preDefs.mapM (fun preDef => do + -- Check the universe parameters - TODO: I'm not sure what the best thing + -- to do is. In practice, all the type parameters should be in Type 0, so + -- we shouldn't have universe issues. + if preDef.levelParams ≠ grLvlParams then + throwError "Non-uniform polymorphism in the universes" + forallTelescope preDef.type (fun in_tys out_ty => do + let (params, in_tys) ← splitInputArgs in_tys out_ty + trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}" + let num_params := params.size + let params_sigma ← mkSigmasType params.data + let in_tys ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data) + -- Retrieve the type in the "Result" + let out_ty ← getResultTy out_ty + let out_ty ← mkSigmasMatchOrUnit params.data out_ty + trace[Diverge.def] "inOutTy: {preDef.declName}: {params_sigma}, {in_tys}, {out_ty}" + pure (num_params, params_sigma, in_tys, out_ty))) + trace[Diverge.def] "paramInOutTys: {paramInOutTys}" + -- Turn the list of input types/input args/output type tuples into expressions + let paramInOutTysExpr ← paramInOutTys.mapM (λ (_, x, y, z) => do mkInOutTy x y z) + let paramInOutTysExpr ← mkListLit (← inferType (paramInOutTysExpr.get! 0)) paramInOutTysExpr.toList + trace[Diverge.def] "paramInOutTys: {paramInOutTys}" + + /- From the list of pairs of input/output types, actually compute the + type of the continuation `k`. + We first introduce the index `i : Fin n` where `n` is the number of + functions in the group. + -/ let i_var_ty := mkFin preDefs.size withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do - let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] - trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" - -- Add an auxiliary definition for `in_out_ty` - let in_out_ty ← do - let value ← mkLambdaFVars #[i_var] in_out_ty - let name := grName.append "in_out_ty" + let param_in_out_ty ← mkAppM ``List.get #[paramInOutTysExpr, i_var] + trace[Diverge.def] "param_in_out_ty := {param_in_out_ty} : {← inferType param_in_out_ty}" + -- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term) + let param_in_out_ty ← do + let value ← mkLambdaFVars #[i_var] param_in_out_ty + let name := grName.append "param_in_out_ty" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -906,19 +1091,28 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do } addDecl decl -- Return the constant - let in_out_ty := Lean.mkConst name (levelParams.map .param) - mkAppM' in_out_ty #[i_var] - trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" - let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] + let param_in_out_ty := Lean.mkConst name (levelParams.map .param) + mkAppM' param_in_out_ty #[i_var] + trace[Diverge.def] "param_in_out_ty (after decl) := {param_in_out_ty} : {← inferType param_in_out_ty}" + -- Decompose between: param_ty, in_ty, out_ty + let param_ty ← mkAppM ``Sigma.fst #[param_in_out_ty] + let in_out_ty ← mkAppM ``Sigma.snd #[param_in_out_ty] + let in_ty ← mkAppM ``Prod.fst #[in_out_ty] + let out_ty ← mkAppM ``Prod.snd #[in_out_ty] + trace[Diverge.def] "param_ty: {param_ty}" + trace[Diverge.def] "in_ty: {in_ty}" + trace[Diverge.def] "out_ty: {out_ty}" + withLocalDeclD (mkAnonymous "t" 1) param_ty fun param => do + let in_ty ← mkAppM' in_ty #[param] + let out_ty ← mkAppM' out_ty #[param] trace[Diverge.def] "in_ty: {in_ty}" - withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do - let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] trace[Diverge.def] "out_ty: {out_ty}" -- Introduce the continuation `k` - let in_ty ← mkLambdaFVars #[i_var] in_ty - let out_ty ← mkLambdaFVars #[i_var, input] out_ty - let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] + let param_ty ← mkLambdaFVars #[i_var] param_ty + let in_ty ← mkLambdaFVars #[i_var, param] in_ty + let out_ty ← mkLambdaFVars #[i_var, param] out_ty + let kk_var_ty ← mkAppM ``FixII.kk_ty #[i_var_ty, param_ty, in_ty, out_ty] trace[Diverge.def] "kk_var_ty: {kk_var_ty}" withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do trace[Diverge.def] "kk_var: {kk_var}" @@ -926,29 +1120,30 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations trace[Diverge.def] "# Generating the unary bodies" - let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var paramInOutTys preDefs trace[Diverge.def] "Unary bodies (after decl): {bodies}" + -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" - let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys.toList bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point - let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] + let k_var_ty ← mkAppM ``FixII.k_ty #[i_var_ty, param_ty, in_ty, out_ty] withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do trace[Diverge.def] "# Proving that the mut rec body is valid" - let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies + let isValidThm ← proveMutRecIsValid grName grLvlParams paramInOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions trace[Diverge.def] "# Generating the final definitions" - let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs + let decls ← mkDeclareFixDefs mutRecBody paramInOutTys preDefs -- Prove the unfolding theorems trace[Diverge.def] "# Proving the unfolding theorems" - proveUnfoldingThms isValidThm inOutTys preDefs decls + proveUnfoldingThms isValidThm paramInOutTys preDefs decls - -- Generating code -- TODO + -- Generating code addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main -- cgit v1.2.3