diff options
author | Son Ho | 2023-07-03 19:26:27 +0200 |
---|---|---|
committer | Son Ho | 2023-07-03 19:26:27 +0200 |
commit | 75fae6384716f24fe137283d4a41836782b9aec7 (patch) | |
tree | 61bd96b19da72dab50f95db15235251fecc1fa2b /backends/lean/Base/Diverge | |
parent | 9214484c471ad931924865855687f9a2ffe255dd (diff) |
Cleanup a bit Diverge/Elab.lean
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 366 |
1 files changed, 197 insertions, 169 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 91c51a31..cc580265 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -15,39 +15,16 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta -set_option trace.Diverge.def true --- set_option trace.Diverge.def.valid true --- set_option trace.Diverge.def.sigmas true -set_option trace.Diverge.def.unfold true - /- The following was copied from the `wfRecursion` function. -/ open WF in -def mkList (xl : List Expr) (ty : Expr) : MetaM Expr := - match xl with - | [] => - mkAppOptM ``List.nil #[some ty] - | x :: tl => do - let tl ← mkList tl ty - mkAppOptM ``List.cons #[some ty, some x, some tl] - def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] def mkInOutTy (x y : Expr) : MetaM Expr := mkAppM ``FixI.mk_in_out_ty #[x, y] --- TODO: is there already such a utility somewhere? --- TODO: change to mkSigmas -def mkProds (tys : List Expr) : MetaM Expr := - match tys with - | [] => do pure (Expr.const ``PUnit.unit []) - | [ty] => do pure ty - | ty :: tys => do - let pty ← mkProds tys - mkAppM ``Prod.mk #[ty, pty] - -- Return the `a` in `Return a` def get_result_ty (ty : Expr) : MetaM Expr := ty.withApp fun f args => do @@ -56,26 +33,31 @@ def get_result_ty (ty : Expr) : MetaM Expr := else pure (args.get! 0) --- Group a list of expressions into a dependent tuple -def mkSigmas (xl : List Expr) : MetaM Expr := +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + -/ +def mkSigmasVal (xl : List Expr) : MetaM Expr := match xl with | [] => do - trace[Diverge.def.sigmas] "mkSigmas: []" + trace[Diverge.def.sigmas] "mkSigmasVal: []" pure (Expr.const ``PUnit.unit []) | [x] => do - trace[Diverge.def.sigmas] "mkSigmas: [{x}]" + trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" pure x | fst :: xl => do - trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst - let snd ← mkSigmas xl + let snd ← mkSigmasVal xl let snd_ty ← inferType snd let beta ← mkLambdaFVars #[fst] snd_ty - trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}" + trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] -/- Generate the input type of a function body, which is a sigma type (i.e., a - dependent tuple) which groups all its inputs. +/- Generate a Sigma type from a list of expressions. Example: - xl = [(a:Type), (ls:List a), (i:Int)] @@ -84,7 +66,7 @@ def mkSigmas (xl : List Expr) : MetaM Expr := `(a:Type) × (ls:List a) × (i:Int)` -/ -def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := +def mkSigmasType (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" @@ -96,15 +78,16 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x - let sty ← mkSigmasTypesOfTypes xl + let sty ← mkSigmasType xl trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" let beta ← mkLambdaFVars #[x] sty trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] -def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index +def mkAnonymous (s : String) (i : Nat) : Name := + .num (.str .anonymous s) i -/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous +/- Given a list of values `[x0:ty0, ..., xn:ty1]`, where every `xi` might use the previous `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following expression: ``` @@ -112,20 +95,22 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index match x with | (x0, ..., xn) => out ``` - + The `index` parameter is used for naming purposes: we use it to numerotate the bound variables that we introduce. + We use this function to currify functions (the function bodies given to the + fixed-point operator must be unary functions). + Example: ======== - More precisely: - xl = `[a:Type, ls:List a, i:Int]` - out = `a` - index = 0 - generates: + generates (getting rid of most of the syntactic sugar): ``` - match scrut0 with + λ scrut0 => match scrut0 with | Sigma.mk x scrut1 => match scrut1 with | Sigma.mk ls i => @@ -138,21 +123,30 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met -- This would be unexpected throwError "mkSigmasMatch: empyt list of input parameters" | [x] => do - -- In the explanations above: inner match case + -- In the example given for the explanations: this is the inner match case trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" mkLambdaFVars #[x] out | fst :: xl => do - -- In the explanations above: outer match case + -- In the example given for the explanations: this is the outer match case -- Remark: for the naming purposes, we use the same convention as for the - -- fields and parameters in `Sigma.casesOn and `Sigma.mk + -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + -- those definitions might help) + -- + -- We want to build the match expression: + -- ``` + -- λ scrut => + -- match scrut with + -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail + -- ``` trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst - let snd_ty ← mkSigmasTypesOfTypes xl + let snd_ty ← mkSigmasType xl let beta ← mkLambdaFVars #[fst] snd_ty let snd ← mkSigmasMatch xl out (index + 1) - let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) - withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkSigmasType (fst :: xl) + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient let motive ← do @@ -166,38 +160,32 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met -- TODO: make this more efficient (we could change the output type of -- mkSigmasMatch mkSigmasMatch (fst :: xl) out_ty + -- The final expression: putting everything together trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable let sm ← mkLambdaFVars #[scrut] sm trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" pure sm /- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ -private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := +private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) (fun (_ls : List a) => Int) (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) scrut1 (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) -private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => +private def list_nth_out_ty_outer (scrut0 : @Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) := @Sigma.casesOn (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int)) (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type) scrut0 (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) => - list_nth_out_ty2 a scrut1) + list_nth_out_ty_inner a scrut1) /- -/ --- TODO: move --- TODO: we can use Array.mapIdx -@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β - | [] => [] - | a::as => f i a :: mapiAux (i+1) f as - -@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f - -- Return the expression: `Fin n` -- TODO: use more def mkFin (n : Nat) : Expr := @@ -212,15 +200,6 @@ def mkFinVal (n i : Nat) : MetaM Expr := do let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] mkAppOptM ``OfNat.ofNat #[none, none, ofNat] --- TODO: remove? -def mkFinValOld (n i : Nat) : MetaM Expr := do - let finTy := mkFin n - let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)] - match ← trySynthInstance ofNat with - | LOption.some x => - mkAppOptM ``OfNat.ofNat #[none, none, x] - | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} " - /- Generate and declare as individual definitions the bodies for the individual funcions: - replace the recursive calls with calls to the continutation `k` - make those bodies take one single dependent tuple as input @@ -234,11 +213,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let grSize := preDefs.size -- Compute the map from name to index - the continuation has an indexed type: - -- we use the index (a finite number of type `Fin`) to control the function - -- we call at the recursive call + -- we use the index (a finite number of type `Fin`) to control which function + -- we call at the recursive call site. let nameToId : HashMap Name Nat := - let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList - HashMap.ofList namesIds + let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val)) + HashMap.ofList namesIds.toList trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" @@ -260,7 +239,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Compute the index let i ← mkFinVal grSize id -- Put the arguments in one big dependent tuple - let args ← mkSigmas args.toList + let args ← mkSigmasVal args.toList mkAppM' kk_var #[i, args] else -- Not a recursive call: do nothing @@ -277,13 +256,14 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Replace the recursive calls let body ← mapVisit visit_e preDef.value - -- Change the type + -- Currify the function by grouping the arguments into a dependent tuple + -- (over which we match to retrieve the individual arguments). lambdaLetTelescope body fun args body => do let body ← mkSigmasMatch args.toList body 0 -- Add the declaration let value ← mkLambdaFVars #[kk_var] body - let name := preDef.declName.append "sbody" + let name := preDef.declName.append "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -304,7 +284,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Generate a unique function body from the bodies of the mutually recursive group, -- and add it as a declaration in the context. --- We return the list of bodies (of type `Funs ...`) and the mutually recursive body. +-- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) @@ -322,7 +302,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] | (ity, oty) :: inOutTys, b :: bl => do -- Retrieving ity and oty - this is not very clean - let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType + let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) let fl ← mkFuns inOutTys bl mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" @@ -345,7 +325,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) all := [name] } addDecl decl - -- Return the constant + -- Return the bodies and the constant pure (bodyFuns, Lean.mkConst name (levelParams.map .param)) def isCasesExpr (e : Expr) : MetaM Bool := do @@ -367,7 +347,8 @@ instance : ToMessageData MatchInfo where -- This is not a very clean formatting, but we don't need more toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" --- An expression which doesn't use the continuation kk is valid +-- Small helper: prove that an expression which doesn't use the continuation `kk` +-- is valid, and return the proof. def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] @@ -376,6 +357,14 @@ def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do mutual +/- Prove that an expression is valid, and return the proof. + + More precisely, if `e` is an expression which potentially uses the continution + `kk`, return an expression of type: + ``` + is_valid_p k (λ kk => e) + ``` + -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveValid: {e}" match e with @@ -403,7 +392,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .app .. => e.withApp fun f args => do -- There are several cases: first, check if this is a match/if - -- The expression is a (dependent) if then else + -- Check if the expression is a (dependent) if then else. + -- We treat the if then else expressions differently from the other matches, + -- and have dedicated theorems for them. let isIte := e.isIte if isIte || e.isDIte then do e.withApp fun f args => do @@ -431,9 +422,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid - -- The expression is a match (this case is for when the elaborator + -- Check if the expression is a match (this case is for when the elaborator -- introduces auxiliary definitions to hide the match behind syntactic - -- sugar) + -- sugar): else if let some me := ← matchMatcherApp? e then do trace[Diverge.def.valid] "matcherApp: @@ -443,7 +434,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do - altNumParams: {me.altNumParams} - alts: {me.alts} - remaining: {me.remaining}" - -- matchMatcherApp has already done the work for us + -- matchMatcherApp does all the work for us: we simply need to gather + -- the information and call the auxiliary helper `proveMatchIsValid` if me.remaining.size ≠ 0 then throwError "MatcherApp: non empty remaining array: {me.remaining}" let me : MatchInfo := { @@ -456,14 +448,21 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do branches := me.alts } proveMatchIsValid k_var kk_var me - -- The expression is a raw match (this case is for when the expression - -- is a direct call to the primitive `casesOn` function, without - -- syntactic sugar) + -- Check if the expression is a raw match (this case is for when the expression + -- is a direct call to the primitive `casesOn` function, without syntactic sugar). + -- We have to check this case because functions like `mkSigmasMatch`, which we + -- use to currify function bodies, introduce such raw matches. else if ← isCasesExpr f then do trace[Diverge.def.valid] "rawMatch: {e}" + -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + -- -- The casesOn definition is always of the following shape: - -- input parameters (implicit parameters), then motive (implicit), - -- scrutinee (explicit), branches (explicit). + -- - input parameters (implicit parameters) + -- - motive (implicit), -- the motive gives the return type of the match + -- - scrutinee (explicit) + -- - branches (explicit). + -- In particular, we notice that the scrutinee is the first *explicit* + -- parameter - this is how we spot it. let matcherName := f.constName! let matcherLevels := f.constLevels!.toArray -- Find the first explicit parameter: this is the scrutinee @@ -484,7 +483,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let scrut := args.get! scrutIdx let branches := args.extract (scrutIdx + 1) args.size -- Compute the number of parameters for the branches: for this we use - -- the type of the uninstantiated casesOn constant + -- the type of the uninstantiated casesOn constant (we can't just + -- destruct the lambdas in the branch expressions because the result + -- of a match might be a lambda expression). let branchesNumParams : Array Nat ← do let env ← getEnv let decl := env.constants.find! matcherName @@ -505,9 +506,11 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do branches, } proveMatchIsValid k_var kk_var me - -- Monadic let-binding + -- Check if this is a monadic let-binding else if f.isConstOf ``Bind.bind then do trace[Diverge.def.valid] "bind:\n{args}" + -- We simply need to prove that the subexpressions are valid, and call + -- the appropriate lemma. let x := args.get! 4 let y := args.get! 5 -- Prove that the subexpressions are valid @@ -529,7 +532,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- Put everything together trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] - -- Recursive call + -- Check if this is a recursive call, i.e., a call to the continuation `kk` else if f.isFVarOf kk_var.fvarId! then do trace[Diverge.def.valid] "rec: args: \n{args}" if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" @@ -540,9 +543,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do pure eIsValid else do -- Remaining case: normal application. - -- It shouldn't use the continuation + -- It shouldn't use the continuation. proveNoKExprIsValid k_var e +-- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do trace[Diverge.def.valid] "proveMatchIsValid: {me}" -- Prove the validity of the branch expressions @@ -561,16 +565,18 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- Reconstruct the lambda expression mkLambdaFVars xs_beg brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" - -- Put together: compute the motive. - -- It must be of the shape: + -- Compute the motive, which has the following shape: -- ``` -- λ scrut => is_valid_p k (λ k => match scrut with ...) + -- ^^^^^^^^^^^^^^^^^^^^ + -- this is the original match expression, with the + -- the difference that the scrutinee(s) is a variable -- ``` let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees let declInfos := me.scruts.mapIdx fun idx scrut => - let name : Name := (.num (.str .anonymous "scrut") idx) + let name : Name := mkAnonymous "scrut" idx let ty := λ (_ : Array Expr) => inferType scrut (name, ty) withLocalDeclsD declInfos fun scrutVars => do @@ -582,7 +588,6 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp let branches : Array (Option Expr) := me.branches.map some let args := params ++ [motive] ++ scruts ++ branches let matchE ← mkAppOptM me.matcherName args - -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args) -- Wrap in the `is_valid_p` predicate let matchE ← mkLambdaFVars #[kk_var] matchE let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] @@ -591,6 +596,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp trace[Diverge.def.valid] "valid motive: {validMotive}" -- Put together let valid ← do + -- We let Lean infer the parameters let params : Array (Option Expr) := me.params.map (λ _ => none) let motive := some validMotive let scruts := me.scruts.map some @@ -602,12 +608,16 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp end --- Prove that a single body (in the mutually recursive group) is valid -partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : +-- Prove that a single body (in the mutually recursive group) is valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `is_even.body` and `is_odd.body` are valid. +partial def proveSingleBodyIsValid + (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}" - -- Lookup the definition (`bodyConst` is the definition of the body, we want - -- to retrieve the value itself to dive inside) + -- Lookup the definition (`bodyConst` is a const, we want to retrieve its + -- definition to dive inside) let name := bodyConst.constName! let env ← getEnv let body := (env.constants.find! name).value! @@ -633,7 +643,7 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body mkForallFVars #[k_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem - let name := preDef.declName ++ "sbody_is_valid" + let name := preDef.declName ++ "body_is_valid" let decl := Declaration.thmDecl { name levelParams := preDef.levelParams @@ -646,6 +656,11 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body -- Return the theorem pure (Expr.const name (preDef.levelParams.map .param)) +-- Prove that the list of bodies are valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is +-- valid. partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do -- Create the big "and" expression, which groups the validity proof of the individual bodies @@ -663,7 +678,13 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] mkLambdaFVars #[k_var] isValid --- Prove that the mut rec body is valid +-- Prove that the mut rec body (i.e., the unary body which groups the bodies +-- of all the functions in the mutually recursive group and on which we will +-- apply the fixed-point operator) is valid. +-- +-- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", +-- which we return. +-- -- TODO: maybe this function should introduce k_var itself def proveMutRecIsValid (grName : Name) (grLvlParams : List Name) @@ -693,6 +714,12 @@ def proveMutRecIsValid pure (Expr.const name (grLvlParams.map .param)) -- Generate the final definions by using the mutual body and the fixed point operator. +-- +-- For instance: +-- ``` +-- def is_even (i : Int) : Result Bool := mut_rec_body 0 i +-- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i +-- ``` def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size @@ -701,7 +728,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple - let input ← mkSigmas xs.toList + let input ← mkSigmasVal xs.toList -- Apply the fixed point let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody @@ -746,7 +773,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input argument - let arg ← mkSigmas xs.toList + let arg ← mkSigmasVal xs.toList let proof ← mkAppM ``congr_fun #[proof, arg] -- Abstract the arguments away let proof ← mkLambdaFVars xs proof @@ -774,11 +801,6 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) - -- CHANGE HERE This function should add definitions with these names/types/values ^^ - -- Temporarily add the predefinitions as axioms - -- for preDef in preDefs do - -- addAsAxiom preDef - -- TODO: what is this? for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation @@ -803,7 +825,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do if preDef.levelParams ≠ grLvlParams then throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do - let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) + let in_ty ← liftM (mkSigmasType in_tys.toList) -- Retrieve the type in the "Result" let out_ty ← get_result_ty out_ty let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) @@ -813,14 +835,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do trace[Diverge.def] "inOutTys: {inOutTys}" -- Turn the list of input/output type pairs into an expresion let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) - let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0)) + let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList -- From the list of pairs of input/output types, actually compute the -- type of the continuation `k`. -- We first introduce the index `i : Fin n` where `n` is the number of -- functions in the group. let i_var_ty := mkFin preDefs.size - withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do + withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" -- Add an auxiliary definition for `in_out_ty` @@ -844,7 +866,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] trace[Diverge.def] "in_ty: {in_ty}" - withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do + withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] trace[Diverge.def] "out_ty: {out_ty}" @@ -853,7 +875,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let out_ty ← mkLambdaFVars #[i_var, input] out_ty let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] trace[Diverge.def] "kk_var_ty: {kk_var_ty}" - withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do + withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do trace[Diverge.def] "kk_var: {kk_var}" -- Replace the recursive calls in all the function bodies by calls to the @@ -866,7 +888,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] - withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do + withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions @@ -915,7 +937,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC else return () catch _ => s.restore --- The following two functions are copy&pasted from Lean.Elab.MutualDef +-- The following two functions are copy-pasted from Lean.Elab.MutualDef open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef @@ -988,61 +1010,67 @@ elab_rules : command else Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) -divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := - match ls with - | [] => .fail .panic - | x :: ls => - if i = 0 then return x - else return (← list_nth ls (i - 1)) - -example {a: Type} (ls : List a) : - ∀ (i : Int), - 0 ≤ i → i < ls.length → - ∃ x, list_nth ls i = .ret x := by - induction ls - . intro i hpos h; simp at h; linarith - . rename_i hd tl ih - intro i hpos h - rw [list_nth.unfold]; simp - split <;> simp [*] - . tauto - . -- TODO: we shouldn't have to do that - have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all - simp at h - have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) - simp [ih] - tauto - -mutual - divergent def is_even (i : Int) : Result Bool := - if i = 0 then return true else return (← is_odd (i - 1)) - - divergent def is_odd (i : Int) : Result Bool := - if i = 0 then return false else return (← is_even (i - 1)) -end - -#print is_even.unfold -#print is_odd.unfold - -mutual - divergent def foo (i : Int) : Result Nat := - if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 - - divergent def bar (i : Int) : Result Nat := - if i > 20 then foo (i / 20) else .ret 42 -end - -#print foo.unfold -#print bar.unfold - --- Testing dependent branching and let-bindings --- TODO: why the linter warning? -divergent def is_non_zero (i : Int) : Result Bool := - if _h:i = 0 then return false - else - let b := true - return b +namespace Tests + /- Some examples of partial functions -/ + + divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + + #check list_nth.unfold + + example {a: Type} (ls : List a) : + ∀ (i : Int), + 0 ≤ i → i < ls.length → + ∃ x, list_nth ls i = .ret x := by + induction ls + . intro i hpos h; simp at h; linarith + . rename_i hd tl ih + intro i hpos h + rw [list_nth.unfold]; simp + split <;> simp [*] + . tauto + . -- TODO: we shouldn't have to do that + have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all + simp at h + have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + simp [ih] + tauto + + mutual + divergent def is_even (i : Int) : Result Bool := + if i = 0 then return true else return (← is_odd (i - 1)) + + divergent def is_odd (i : Int) : Result Bool := + if i = 0 then return false else return (← is_even (i - 1)) + end + + #check is_even.unfold + #check is_odd.unfold + + mutual + divergent def foo (i : Int) : Result Nat := + if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 + + divergent def bar (i : Int) : Result Nat := + if i > 20 then foo (i / 20) else .ret 42 + end + + #check foo.unfold + #check bar.unfold + + -- Testing dependent branching and let-bindings + -- TODO: why the linter warning? + divergent def is_non_zero (i : Int) : Result Bool := + if _h:i = 0 then return false + else + let b := true + return b -#print is_non_zero.unfold + #check is_non_zero.unfold +end Tests end Diverge |