summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge/Elab.lean
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base/Diverge/Elab.lean')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean366
1 files changed, 197 insertions, 169 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 91c51a31..cc580265 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -15,39 +15,16 @@ syntax (name := divergentDef)
open Lean Elab Term Meta Primitives Lean.Meta
-set_option trace.Diverge.def true
--- set_option trace.Diverge.def.valid true
--- set_option trace.Diverge.def.sigmas true
-set_option trace.Diverge.def.unfold true
-
/- The following was copied from the `wfRecursion` function. -/
open WF in
-def mkList (xl : List Expr) (ty : Expr) : MetaM Expr :=
- match xl with
- | [] =>
- mkAppOptM ``List.nil #[some ty]
- | x :: tl => do
- let tl ← mkList tl ty
- mkAppOptM ``List.cons #[some ty, some x, some tl]
-
def mkProd (x y : Expr) : MetaM Expr :=
mkAppM ``Prod.mk #[x, y]
def mkInOutTy (x y : Expr) : MetaM Expr :=
mkAppM ``FixI.mk_in_out_ty #[x, y]
--- TODO: is there already such a utility somewhere?
--- TODO: change to mkSigmas
-def mkProds (tys : List Expr) : MetaM Expr :=
- match tys with
- | [] => do pure (Expr.const ``PUnit.unit [])
- | [ty] => do pure ty
- | ty :: tys => do
- let pty ← mkProds tys
- mkAppM ``Prod.mk #[ty, pty]
-
-- Return the `a` in `Return a`
def get_result_ty (ty : Expr) : MetaM Expr :=
ty.withApp fun f args => do
@@ -56,26 +33,31 @@ def get_result_ty (ty : Expr) : MetaM Expr :=
else
pure (args.get! 0)
--- Group a list of expressions into a dependent tuple
-def mkSigmas (xl : List Expr) : MetaM Expr :=
+/- Group a list of expressions into a dependent tuple.
+
+ Example:
+ xl = [`a : Type`, `ls : List a`]
+ returns:
+ `⟨ (a:Type), (ls: List a) ⟩`
+ -/
+def mkSigmasVal (xl : List Expr) : MetaM Expr :=
match xl with
| [] => do
- trace[Diverge.def.sigmas] "mkSigmas: []"
+ trace[Diverge.def.sigmas] "mkSigmasVal: []"
pure (Expr.const ``PUnit.unit [])
| [x] => do
- trace[Diverge.def.sigmas] "mkSigmas: [{x}]"
+ trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]"
pure x
| fst :: xl => do
- trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]"
+ trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]"
let alpha ← Lean.Meta.inferType fst
- let snd ← mkSigmas xl
+ let snd ← mkSigmasVal xl
let snd_ty ← inferType snd
let beta ← mkLambdaFVars #[fst] snd_ty
- trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}"
+ trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}"
mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd]
-/- Generate the input type of a function body, which is a sigma type (i.e., a
- dependent tuple) which groups all its inputs.
+/- Generate a Sigma type from a list of expressions.
Example:
- xl = [(a:Type), (ls:List a), (i:Int)]
@@ -84,7 +66,7 @@ def mkSigmas (xl : List Expr) : MetaM Expr :=
`(a:Type) × (ls:List a) × (i:Int)`
-/
-def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=
+def mkSigmasType (xl : List Expr) : MetaM Expr :=
match xl with
| [] => do
trace[Diverge.def.sigmas] "mkSigmasOfTypes: []"
@@ -96,15 +78,16 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=
| x :: xl => do
trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]"
let alpha ← Lean.Meta.inferType x
- let sty ← mkSigmasTypesOfTypes xl
+ let sty ← mkSigmasType xl
trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}"
let beta ← mkLambdaFVars #[x] sty
trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})"
mkAppOptM ``Sigma #[some alpha, some beta]
-def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index
+def mkAnonymous (s : String) (i : Nat) : Name :=
+ .num (.str .anonymous s) i
-/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous
+/- Given a list of values `[x0:ty0, ..., xn:ty1]`, where every `xi` might use the previous
`xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following
expression:
```
@@ -112,20 +95,22 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index
match x with
| (x0, ..., xn) => out
```
-
+
The `index` parameter is used for naming purposes: we use it to numerotate the
bound variables that we introduce.
+ We use this function to currify functions (the function bodies given to the
+ fixed-point operator must be unary functions).
+
Example:
========
- More precisely:
- xl = `[a:Type, ls:List a, i:Int]`
- out = `a`
- index = 0
- generates:
+ generates (getting rid of most of the syntactic sugar):
```
- match scrut0 with
+ λ scrut0 => match scrut0 with
| Sigma.mk x scrut1 =>
match scrut1 with
| Sigma.mk ls i =>
@@ -138,21 +123,30 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met
-- This would be unexpected
throwError "mkSigmasMatch: empyt list of input parameters"
| [x] => do
- -- In the explanations above: inner match case
+ -- In the example given for the explanations: this is the inner match case
trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]"
mkLambdaFVars #[x] out
| fst :: xl => do
- -- In the explanations above: outer match case
+ -- In the example given for the explanations: this is the outer match case
-- Remark: for the naming purposes, we use the same convention as for the
- -- fields and parameters in `Sigma.casesOn and `Sigma.mk
+ -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at
+ -- those definitions might help)
+ --
+ -- We want to build the match expression:
+ -- ```
+ -- λ scrut =>
+ -- match scrut with
+ -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail
+ -- ```
trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]"
let alpha ← Lean.Meta.inferType fst
- let snd_ty ← mkSigmasTypesOfTypes xl
+ let snd_ty ← mkSigmasType xl
let beta ← mkLambdaFVars #[fst] snd_ty
let snd ← mkSigmasMatch xl out (index + 1)
- let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl)
- withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do
let mk ← mkLambdaFVars #[fst] snd
+ -- Introduce the "scrut" variable
+ let scrut_ty ← mkSigmasType (fst :: xl)
+ withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do
trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})"
-- TODO: make the computation of the motive more efficient
let motive ← do
@@ -166,38 +160,32 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met
-- TODO: make this more efficient (we could change the output type of
-- mkSigmasMatch
mkSigmasMatch (fst :: xl) out_ty
+ -- The final expression: putting everything together
trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})"
let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk]
+ -- Abstracting the "scrut" variable
let sm ← mkLambdaFVars #[scrut] sm
trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}"
pure sm
/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/
-private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) :=
+private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) :=
@Sigma.casesOn (List a)
(fun (_ls : List a) => Int)
(fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type)
scrut1
(fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a)
-private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) =>
+private def list_nth_out_ty_outer (scrut0 : @Sigma (Type) (fun (a:Type) =>
@Sigma (List a) (fun (_ls : List a) => Int))) :=
@Sigma.casesOn (Type)
(fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))
(fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type)
scrut0
(fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) =>
- list_nth_out_ty2 a scrut1)
+ list_nth_out_ty_inner a scrut1)
/- -/
--- TODO: move
--- TODO: we can use Array.mapIdx
-@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β
- | [] => []
- | a::as => f i a :: mapiAux (i+1) f as
-
-@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f
-
-- Return the expression: `Fin n`
-- TODO: use more
def mkFin (n : Nat) : Expr :=
@@ -212,15 +200,6 @@ def mkFinVal (n i : Nat) : MetaM Expr := do
let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit]
mkAppOptM ``OfNat.ofNat #[none, none, ofNat]
--- TODO: remove?
-def mkFinValOld (n i : Nat) : MetaM Expr := do
- let finTy := mkFin n
- let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)]
- match ← trySynthInstance ofNat with
- | LOption.some x =>
- mkAppOptM ``OfNat.ofNat #[none, none, x]
- | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} "
-
/- Generate and declare as individual definitions the bodies for the individual funcions:
- replace the recursive calls with calls to the continutation `k`
- make those bodies take one single dependent tuple as input
@@ -234,11 +213,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
let grSize := preDefs.size
-- Compute the map from name to index - the continuation has an indexed type:
- -- we use the index (a finite number of type `Fin`) to control the function
- -- we call at the recursive call
+ -- we use the index (a finite number of type `Fin`) to control which function
+ -- we call at the recursive call site.
let nameToId : HashMap Name Nat :=
- let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList
- HashMap.ofList namesIds
+ let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val))
+ HashMap.ofList namesIds.toList
trace[Diverge.def.genBody] "nameToId: {nameToId.toList}"
@@ -260,7 +239,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Compute the index
let i ← mkFinVal grSize id
-- Put the arguments in one big dependent tuple
- let args ← mkSigmas args.toList
+ let args ← mkSigmasVal args.toList
mkAppM' kk_var #[i, args]
else
-- Not a recursive call: do nothing
@@ -277,13 +256,14 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Replace the recursive calls
let body ← mapVisit visit_e preDef.value
- -- Change the type
+ -- Currify the function by grouping the arguments into a dependent tuple
+ -- (over which we match to retrieve the individual arguments).
lambdaLetTelescope body fun args body => do
let body ← mkSigmasMatch args.toList body 0
-- Add the declaration
let value ← mkLambdaFVars #[kk_var] body
- let name := preDef.declName.append "sbody"
+ let name := preDef.declName.append "body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -304,7 +284,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Generate a unique function body from the bodies of the mutually recursive group,
-- and add it as a declaration in the context.
--- We return the list of bodies (of type `Funs ...`) and the mutually recursive body.
+-- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body.
def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
(kk_var i_var : Expr)
(in_ty out_ty : Expr) (inOutTys : List (Expr × Expr))
@@ -322,7 +302,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty]
| (ity, oty) :: inOutTys, b :: bl => do
-- Retrieving ity and oty - this is not very clean
- let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType
+ let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y))
let fl ← mkFuns inOutTys bl
mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl]
| _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length"
@@ -345,7 +325,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
all := [name]
}
addDecl decl
- -- Return the constant
+ -- Return the bodies and the constant
pure (bodyFuns, Lean.mkConst name (levelParams.map .param))
def isCasesExpr (e : Expr) : MetaM Bool := do
@@ -367,7 +347,8 @@ instance : ToMessageData MatchInfo where
-- This is not a very clean formatting, but we don't need more
toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}"
--- An expression which doesn't use the continuation kk is valid
+-- Small helper: prove that an expression which doesn't use the continuation `kk`
+-- is valid, and return the proof.
def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do
trace[Diverge.def.valid] "proveNoKExprIsValid: {e}"
let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e]
@@ -376,6 +357,14 @@ def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do
mutual
+/- Prove that an expression is valid, and return the proof.
+
+ More precisely, if `e` is an expression which potentially uses the continution
+ `kk`, return an expression of type:
+ ```
+ is_valid_p k (λ kk => e)
+ ```
+ -/
partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
trace[Diverge.def.valid] "proveValid: {e}"
match e with
@@ -403,7 +392,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
| .app .. =>
e.withApp fun f args => do
-- There are several cases: first, check if this is a match/if
- -- The expression is a (dependent) if then else
+ -- Check if the expression is a (dependent) if then else.
+ -- We treat the if then else expressions differently from the other matches,
+ -- and have dedicated theorems for them.
let isIte := e.isIte
if isIte || e.isDIte then do
e.withApp fun f args => do
@@ -431,9 +422,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid]
trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}"
pure eIsValid
- -- The expression is a match (this case is for when the elaborator
+ -- Check if the expression is a match (this case is for when the elaborator
-- introduces auxiliary definitions to hide the match behind syntactic
- -- sugar)
+ -- sugar):
else if let some me := ← matchMatcherApp? e then do
trace[Diverge.def.valid]
"matcherApp:
@@ -443,7 +434,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
- altNumParams: {me.altNumParams}
- alts: {me.alts}
- remaining: {me.remaining}"
- -- matchMatcherApp has already done the work for us
+ -- matchMatcherApp does all the work for us: we simply need to gather
+ -- the information and call the auxiliary helper `proveMatchIsValid`
if me.remaining.size ≠ 0 then
throwError "MatcherApp: non empty remaining array: {me.remaining}"
let me : MatchInfo := {
@@ -456,14 +448,21 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
branches := me.alts
}
proveMatchIsValid k_var kk_var me
- -- The expression is a raw match (this case is for when the expression
- -- is a direct call to the primitive `casesOn` function, without
- -- syntactic sugar)
+ -- Check if the expression is a raw match (this case is for when the expression
+ -- is a direct call to the primitive `casesOn` function, without syntactic sugar).
+ -- We have to check this case because functions like `mkSigmasMatch`, which we
+ -- use to currify function bodies, introduce such raw matches.
else if ← isCasesExpr f then do
trace[Diverge.def.valid] "rawMatch: {e}"
+ -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`.
+ --
-- The casesOn definition is always of the following shape:
- -- input parameters (implicit parameters), then motive (implicit),
- -- scrutinee (explicit), branches (explicit).
+ -- - input parameters (implicit parameters)
+ -- - motive (implicit), -- the motive gives the return type of the match
+ -- - scrutinee (explicit)
+ -- - branches (explicit).
+ -- In particular, we notice that the scrutinee is the first *explicit*
+ -- parameter - this is how we spot it.
let matcherName := f.constName!
let matcherLevels := f.constLevels!.toArray
-- Find the first explicit parameter: this is the scrutinee
@@ -484,7 +483,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
let scrut := args.get! scrutIdx
let branches := args.extract (scrutIdx + 1) args.size
-- Compute the number of parameters for the branches: for this we use
- -- the type of the uninstantiated casesOn constant
+ -- the type of the uninstantiated casesOn constant (we can't just
+ -- destruct the lambdas in the branch expressions because the result
+ -- of a match might be a lambda expression).
let branchesNumParams : Array Nat ← do
let env ← getEnv
let decl := env.constants.find! matcherName
@@ -505,9 +506,11 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
branches,
}
proveMatchIsValid k_var kk_var me
- -- Monadic let-binding
+ -- Check if this is a monadic let-binding
else if f.isConstOf ``Bind.bind then do
trace[Diverge.def.valid] "bind:\n{args}"
+ -- We simply need to prove that the subexpressions are valid, and call
+ -- the appropriate lemma.
let x := args.get! 4
let y := args.get! 5
-- Prove that the subexpressions are valid
@@ -529,7 +532,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
-- Put everything together
trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}"
mkAppM ``FixI.is_valid_p_bind #[xValid, yValid]
- -- Recursive call
+ -- Check if this is a recursive call, i.e., a call to the continuation `kk`
else if f.isFVarOf kk_var.fvarId! then do
trace[Diverge.def.valid] "rec: args: \n{args}"
if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}"
@@ -540,9 +543,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
pure eIsValid
else do
-- Remaining case: normal application.
- -- It shouldn't use the continuation
+ -- It shouldn't use the continuation.
proveNoKExprIsValid k_var e
+-- Prove that a match expression is valid.
partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do
trace[Diverge.def.valid] "proveMatchIsValid: {me}"
-- Prove the validity of the branch expressions
@@ -561,16 +565,18 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
-- Reconstruct the lambda expression
mkLambdaFVars xs_beg brValid
trace[Diverge.def.valid] "branchesValid:\n{branchesValid}"
- -- Put together: compute the motive.
- -- It must be of the shape:
+ -- Compute the motive, which has the following shape:
-- ```
-- λ scrut => is_valid_p k (λ k => match scrut with ...)
+ -- ^^^^^^^^^^^^^^^^^^^^
+ -- this is the original match expression, with the
+ -- the difference that the scrutinee(s) is a variable
-- ```
let validMotive : Expr ← do
-- The motive is a function of the scrutinees (i.e., a lambda expression):
-- introduce binders for the scrutinees
let declInfos := me.scruts.mapIdx fun idx scrut =>
- let name : Name := (.num (.str .anonymous "scrut") idx)
+ let name : Name := mkAnonymous "scrut" idx
let ty := λ (_ : Array Expr) => inferType scrut
(name, ty)
withLocalDeclsD declInfos fun scrutVars => do
@@ -582,7 +588,6 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
let branches : Array (Option Expr) := me.branches.map some
let args := params ++ [motive] ++ scruts ++ branches
let matchE ← mkAppOptM me.matcherName args
- -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args)
-- Wrap in the `is_valid_p` predicate
let matchE ← mkLambdaFVars #[kk_var] matchE
let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE]
@@ -591,6 +596,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
trace[Diverge.def.valid] "valid motive: {validMotive}"
-- Put together
let valid ← do
+ -- We let Lean infer the parameters
let params : Array (Option Expr) := me.params.map (λ _ => none)
let motive := some validMotive
let scruts := me.scruts.map some
@@ -602,12 +608,16 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
end
--- Prove that a single body (in the mutually recursive group) is valid
-partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) :
+-- Prove that a single body (in the mutually recursive group) is valid.
+--
+-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`],
+-- we prove that `is_even.body` and `is_odd.body` are valid.
+partial def proveSingleBodyIsValid
+ (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) :
MetaM Expr := do
trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}"
- -- Lookup the definition (`bodyConst` is the definition of the body, we want
- -- to retrieve the value itself to dive inside)
+ -- Lookup the definition (`bodyConst` is a const, we want to retrieve its
+ -- definition to dive inside)
let name := bodyConst.constName!
let env ← getEnv
let body := (env.constants.find! name).value!
@@ -633,7 +643,7 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body
mkForallFVars #[k_var, x_var] ty
trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}"
-- Save the theorem
- let name := preDef.declName ++ "sbody_is_valid"
+ let name := preDef.declName ++ "body_is_valid"
let decl := Declaration.thmDecl {
name
levelParams := preDef.levelParams
@@ -646,6 +656,11 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body
-- Return the theorem
pure (Expr.const name (preDef.levelParams.map .param))
+-- Prove that the list of bodies are valid.
+--
+-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`],
+-- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is
+-- valid.
partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr)
(k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do
-- Create the big "and" expression, which groups the validity proof of the individual bodies
@@ -663,7 +678,13 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr)
let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr]
mkLambdaFVars #[k_var] isValid
--- Prove that the mut rec body is valid
+-- Prove that the mut rec body (i.e., the unary body which groups the bodies
+-- of all the functions in the mutually recursive group and on which we will
+-- apply the fixed-point operator) is valid.
+--
+-- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid",
+-- which we return.
+--
-- TODO: maybe this function should introduce k_var itself
def proveMutRecIsValid
(grName : Name) (grLvlParams : List Name)
@@ -693,6 +714,12 @@ def proveMutRecIsValid
pure (Expr.const name (grLvlParams.map .param))
-- Generate the final definions by using the mutual body and the fixed point operator.
+--
+-- For instance:
+-- ```
+-- def is_even (i : Int) : Result Bool := mut_rec_body 0 i
+-- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i
+-- ```
def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
TermElabM (Array Name) := do
let grSize := preDefs.size
@@ -701,7 +728,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
-- Create the index
let idx ← mkFinVal grSize idx.val
-- Group the inputs into a dependent tuple
- let input ← mkSigmas xs.toList
+ let input ← mkSigmasVal xs.toList
-- Apply the fixed point
let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input]
let fixedBody ← mkLambdaFVars xs fixedBody
@@ -746,7 +773,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio
let idx ← mkFinVal grSize i
let proof ← mkAppM ``congr_fun #[proof, idx]
-- Add the input argument
- let arg ← mkSigmas xs.toList
+ let arg ← mkSigmasVal xs.toList
let proof ← mkAppM ``congr_fun #[proof, arg]
-- Abstract the arguments away
let proof ← mkLambdaFVars xs proof
@@ -774,11 +801,6 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
trace[Diverge.def] ("divRecursion: defs: " ++ msg)
- -- CHANGE HERE This function should add definitions with these names/types/values ^^
- -- Temporarily add the predefinitions as axioms
- -- for preDef in preDefs do
- -- addAsAxiom preDef
-
-- TODO: what is this?
for preDef in preDefs do
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
@@ -803,7 +825,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
if preDef.levelParams ≠ grLvlParams then
throwError "Non-uniform polymorphism in the universes"
forallTelescope preDef.type (fun in_tys out_ty => do
- let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList)
+ let in_ty ← liftM (mkSigmasType in_tys.toList)
-- Retrieve the type in the "Result"
let out_ty ← get_result_ty out_ty
let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty)
@@ -813,14 +835,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
trace[Diverge.def] "inOutTys: {inOutTys}"
-- Turn the list of input/output type pairs into an expresion
let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y)
- let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0))
+ let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList
-- From the list of pairs of input/output types, actually compute the
-- type of the continuation `k`.
-- We first introduce the index `i : Fin n` where `n` is the number of
-- functions in the group.
let i_var_ty := mkFin preDefs.size
- withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do
+ withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do
let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var]
trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}"
-- Add an auxiliary definition for `in_out_ty`
@@ -844,7 +866,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}"
let in_ty ← mkAppM ``Sigma.fst #[in_out_ty]
trace[Diverge.def] "in_ty: {in_ty}"
- withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do
+ withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do
let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input]
trace[Diverge.def] "out_ty: {out_ty}"
@@ -853,7 +875,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let out_ty ← mkLambdaFVars #[i_var, input] out_ty
let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty]
trace[Diverge.def] "kk_var_ty: {kk_var_ty}"
- withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do
+ withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do
trace[Diverge.def] "kk_var: {kk_var}"
-- Replace the recursive calls in all the function bodies by calls to the
@@ -866,7 +888,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Prove that the mut rec body satisfies the validity criteria required by
-- our fixed-point
let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty]
- withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do
+ withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do
let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies
-- Generate the final definitions
@@ -915,7 +937,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
else return ()
catch _ => s.restore
--- The following two functions are copy&pasted from Lean.Elab.MutualDef
+-- The following two functions are copy-pasted from Lean.Elab.MutualDef
open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef
@@ -988,61 +1010,67 @@ elab_rules : command
else
Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns))
-divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
- match ls with
- | [] => .fail .panic
- | x :: ls =>
- if i = 0 then return x
- else return (← list_nth ls (i - 1))
-
-example {a: Type} (ls : List a) :
- ∀ (i : Int),
- 0 ≤ i → i < ls.length →
- ∃ x, list_nth ls i = .ret x := by
- induction ls
- . intro i hpos h; simp at h; linarith
- . rename_i hd tl ih
- intro i hpos h
- rw [list_nth.unfold]; simp
- split <;> simp [*]
- . tauto
- . -- TODO: we shouldn't have to do that
- have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all
- simp at h
- have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith)
- simp [ih]
- tauto
-
-mutual
- divergent def is_even (i : Int) : Result Bool :=
- if i = 0 then return true else return (← is_odd (i - 1))
-
- divergent def is_odd (i : Int) : Result Bool :=
- if i = 0 then return false else return (← is_even (i - 1))
-end
-
-#print is_even.unfold
-#print is_odd.unfold
-
-mutual
- divergent def foo (i : Int) : Result Nat :=
- if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10
-
- divergent def bar (i : Int) : Result Nat :=
- if i > 20 then foo (i / 20) else .ret 42
-end
-
-#print foo.unfold
-#print bar.unfold
-
--- Testing dependent branching and let-bindings
--- TODO: why the linter warning?
-divergent def is_non_zero (i : Int) : Result Bool :=
- if _h:i = 0 then return false
- else
- let b := true
- return b
+namespace Tests
+ /- Some examples of partial functions -/
+
+ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
+ match ls with
+ | [] => .fail .panic
+ | x :: ls =>
+ if i = 0 then return x
+ else return (← list_nth ls (i - 1))
+
+ #check list_nth.unfold
+
+ example {a: Type} (ls : List a) :
+ ∀ (i : Int),
+ 0 ≤ i → i < ls.length →
+ ∃ x, list_nth ls i = .ret x := by
+ induction ls
+ . intro i hpos h; simp at h; linarith
+ . rename_i hd tl ih
+ intro i hpos h
+ rw [list_nth.unfold]; simp
+ split <;> simp [*]
+ . tauto
+ . -- TODO: we shouldn't have to do that
+ have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all
+ simp at h
+ have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith)
+ simp [ih]
+ tauto
+
+ mutual
+ divergent def is_even (i : Int) : Result Bool :=
+ if i = 0 then return true else return (← is_odd (i - 1))
+
+ divergent def is_odd (i : Int) : Result Bool :=
+ if i = 0 then return false else return (← is_even (i - 1))
+ end
+
+ #check is_even.unfold
+ #check is_odd.unfold
+
+ mutual
+ divergent def foo (i : Int) : Result Nat :=
+ if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10
+
+ divergent def bar (i : Int) : Result Nat :=
+ if i > 20 then foo (i / 20) else .ret 42
+ end
+
+ #check foo.unfold
+ #check bar.unfold
+
+ -- Testing dependent branching and let-bindings
+ -- TODO: why the linter warning?
+ divergent def is_non_zero (i : Int) : Result Bool :=
+ if _h:i = 0 then return false
+ else
+ let b := true
+ return b
-#print is_non_zero.unfold
+ #check is_non_zero.unfold
+end Tests
end Diverge