summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean63
1 files changed, 47 insertions, 16 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 41209021..4b08fe44 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -255,10 +255,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
preDefs.mapM fun preDef => do
-- Replace the recursive calls
let body ← mapVisit visit_e preDef.value
+ trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}"
-- Currify the function by grouping the arguments into a dependent tuple
-- (over which we match to retrieve the individual arguments).
- lambdaLetTelescope body fun args body => do
+ lambdaTelescope body fun args body => do
let body ← mkSigmasMatch args.toList body 0
-- Add the declaration
@@ -376,15 +377,18 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
| .sort _ => throwError "Unreachable"
| .lam .. => throwError "Unimplemented"
| .forallE .. => throwError "Unreachable" -- Shouldn't get there
- | .letE dName dTy dValue body _nonDep => do
- -- Introduce a local declaration for the let-binding
- withLetDecl dName dTy dValue fun decl => do
+ | .letE .. => do
+ -- Telescope all the let-bindings (remark: this also telescopes the lambdas)
+ lambdaLetTelescope e fun xs body => do
+ -- Note that we don't visit the bound values: there shouldn't be
+ -- recursive calls, lambda expressions, etc. inside
+ -- Prove that the body is valid
let isValid ← proveExprIsValid k_var kk_var body
- -- Add the let-binding around.
+ -- Add the let-bindings around.
-- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside,
-- but because it reduces in the end it doesn't matter. More precisely:
-- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression.
- mkLetFVars #[decl] isValid
+ mkLambdaFVars xs isValid (usedLetOnly := false)
| .mdata _ b => proveExprIsValid k_var kk_var b
| .proj _ _ _ =>
-- The projection shouldn't use the continuation
@@ -410,7 +414,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
if isIte then proveExprIsValid k_var kk_var br
else do
-- There is a lambda -- TODO: how do we remove exacly *one* lambda?
- lambdaLetTelescope br fun xs br => do
+ lambdaTelescope br fun xs br => do
let x := xs.get! 0
let xs := xs.extract 1 xs.size
let br ← mkLambdaFVars xs br
@@ -518,7 +522,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}"
let yValid ← do
-- This is a lambda expression -- TODO: how do we remove exacly *one* lambda?
- lambdaLetTelescope y fun xs y => do
+ lambdaTelescope y fun xs y => do
let x := xs.get! 0
let xs := xs.extract 1 xs.size
let y ← mkLambdaFVars xs y
@@ -555,7 +559,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
-- binders might come from the match, and some of the binders might come
-- from the fact that the expression in the match is a lambda expression:
-- we use the branchesNumParams field for this reason
- lambdaLetTelescope br fun xs br => do
+ lambdaTelescope br fun xs br => do
let numParams := me.branchesNumParams.get! idx
let xs_beg := xs.extract 0 numParams
let xs_end := xs.extract numParams xs.size
@@ -622,7 +626,7 @@ partial def proveSingleBodyIsValid
let env ← getEnv
let body := (env.constants.find! name).value!
trace[Diverge.def.valid] "body: {body}"
- lambdaLetTelescope body fun xs body => do
+ lambdaTelescope body fun xs body => do
assert! xs.size = 2
let kk_var := xs.get! 0
let x_var := xs.get! 1
@@ -695,8 +699,10 @@ def proveMutRecIsValid
let bodiesValid ←
bodies.mapIdxM fun idx body => do
let preDef := preDefs.get! idx
+ trace[Diverge.def.valid] "## Proving that the body {body} is valid"
proveSingleBodyIsValid k_var preDef body
-- Then prove that the mut rec body is valid
+ trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid"
let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid
-- Save the theorem
let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst]
@@ -724,7 +730,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
TermElabM (Array Name) := do
let grSize := preDefs.size
let defs ← preDefs.mapIdxM fun idx preDef => do
- lambdaLetTelescope preDef.value fun xs _ => do
+ lambdaTelescope preDef.value fun xs _ => do
-- Create the index
let idx ← mkFinVal grSize idx.val
-- Group the inputs into a dependent tuple
@@ -755,7 +761,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio
let preDef := preDefs.get! i
let defName := decls.get! i
-- Retrieve the arguments
- lambdaLetTelescope preDef.value fun xs body => do
+ lambdaTelescope preDef.value fun xs body => do
trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}"
trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}"
-- The theorem statement
@@ -799,7 +805,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
- trace[Diverge.def] ("divRecursion: defs: " ++ msg)
+ trace[Diverge.def] ("divRecursion: defs:\n" ++ msg)
-- TODO: what is this?
for preDef in preDefs do
@@ -880,8 +886,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Replace the recursive calls in all the function bodies by calls to the
-- continuation `k` and and generate for those bodies declarations
+ trace[Diverge.def] "# Generating the unary bodies"
let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs
+ trace[Diverge.def] "Unary bodies (after decl): {bodies}"
-- Generate the mutually recursive body
+ trace[Diverge.def] "# Generating the mut rec body"
let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies
trace[Diverge.def] "mut rec body (after decl): {mutRecBody}"
@@ -889,15 +898,18 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- our fixed-point
let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty]
withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do
+ trace[Diverge.def] "# Proving that the mut rec body is valid"
let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies
-- Generate the final definitions
+ trace[Diverge.def] "# Generating the final definitions"
let decls ← mkDeclareFixDefs mutRecBody preDefs
-- Prove the unfolding theorems
+ trace[Diverge.def] "# Proving the unfolding theorems"
proveUnfoldingThms isValidThm preDefs decls
- -- Process the definitions - TODO
+ -- Generating code -- TODO
addAndCompilePartialRec preDefs
-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main
@@ -1064,13 +1076,32 @@ namespace Tests
-- Testing dependent branching and let-bindings
-- TODO: why the linter warning?
- divergent def is_non_zero (i : Int) : Result Bool :=
+ divergent def isNonZero (i : Int) : Result Bool :=
if _h:i = 0 then return false
else
let b := true
return b
- #check is_non_zero.unfold
+ #check isNonZero.unfold
+
+ -- Testing let-bindings
+ divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool :=
+ let i0 := ls.length
+ if i < i0
+ then Result.ret True
+ else Result.ret False
+
+ #check iInBounds.unfold
+
+ divergent def isCons
+ {a : Type} (ls : List a) : Result Bool :=
+ let ls1 := ls
+ match ls1 with
+ | [] => Result.ret False
+ | x :: tl => Result.ret True
+
+ #check isCons.unfold
+
end Tests
end Diverge