diff options
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 63 |
1 files changed, 47 insertions, 16 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 41209021..4b08fe44 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -255,10 +255,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) preDefs.mapM fun preDef => do -- Replace the recursive calls let body ← mapVisit visit_e preDef.value + trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" -- Currify the function by grouping the arguments into a dependent tuple -- (over which we match to retrieve the individual arguments). - lambdaLetTelescope body fun args body => do + lambdaTelescope body fun args body => do let body ← mkSigmasMatch args.toList body 0 -- Add the declaration @@ -376,15 +377,18 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .sort _ => throwError "Unreachable" | .lam .. => throwError "Unimplemented" | .forallE .. => throwError "Unreachable" -- Shouldn't get there - | .letE dName dTy dValue body _nonDep => do - -- Introduce a local declaration for the let-binding - withLetDecl dName dTy dValue fun decl => do + | .letE .. => do + -- Telescope all the let-bindings (remark: this also telescopes the lambdas) + lambdaLetTelescope e fun xs body => do + -- Note that we don't visit the bound values: there shouldn't be + -- recursive calls, lambda expressions, etc. inside + -- Prove that the body is valid let isValid ← proveExprIsValid k_var kk_var body - -- Add the let-binding around. + -- Add the let-bindings around. -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, -- but because it reduces in the end it doesn't matter. More precisely: -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. - mkLetFVars #[decl] isValid + mkLambdaFVars xs isValid (usedLetOnly := false) | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => -- The projection shouldn't use the continuation @@ -410,7 +414,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do if isIte then proveExprIsValid k_var kk_var br else do -- There is a lambda -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let br ← mkLambdaFVars xs br @@ -518,7 +522,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" let yValid ← do -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope y fun xs y => do + lambdaTelescope y fun xs y => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let y ← mkLambdaFVars xs y @@ -555,7 +559,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- binders might come from the match, and some of the binders might come -- from the fact that the expression in the match is a lambda expression: -- we use the branchesNumParams field for this reason - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let numParams := me.branchesNumParams.get! idx let xs_beg := xs.extract 0 numParams let xs_end := xs.extract numParams xs.size @@ -622,7 +626,7 @@ partial def proveSingleBodyIsValid let env ← getEnv let body := (env.constants.find! name).value! trace[Diverge.def.valid] "body: {body}" - lambdaLetTelescope body fun xs body => do + lambdaTelescope body fun xs body => do assert! xs.size = 2 let kk_var := xs.get! 0 let x_var := xs.get! 1 @@ -695,8 +699,10 @@ def proveMutRecIsValid let bodiesValid ← bodies.mapIdxM fun idx body => do let preDef := preDefs.get! idx + trace[Diverge.def.valid] "## Proving that the body {body} is valid" proveSingleBodyIsValid k_var preDef body -- Then prove that the mut rec body is valid + trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid -- Save the theorem let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] @@ -724,7 +730,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do - lambdaLetTelescope preDef.value fun xs _ => do + lambdaTelescope preDef.value fun xs _ => do -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple @@ -755,7 +761,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let preDef := preDefs.get! i let defName := decls.get! i -- Retrieve the arguments - lambdaLetTelescope preDef.value fun xs body => do + lambdaTelescope preDef.value fun xs body => do trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" -- The theorem statement @@ -799,7 +805,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) - trace[Diverge.def] ("divRecursion: defs: " ++ msg) + trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) -- TODO: what is this? for preDef in preDefs do @@ -880,8 +886,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations + trace[Diverge.def] "# Generating the unary bodies" let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs + trace[Diverge.def] "Unary bodies (after decl): {bodies}" -- Generate the mutually recursive body + trace[Diverge.def] "# Generating the mut rec body" let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" @@ -889,15 +898,18 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- our fixed-point let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do + trace[Diverge.def] "# Proving that the mut rec body is valid" let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions + trace[Diverge.def] "# Generating the final definitions" let decls ← mkDeclareFixDefs mutRecBody preDefs -- Prove the unfolding theorems + trace[Diverge.def] "# Proving the unfolding theorems" proveUnfoldingThms isValidThm preDefs decls - -- Process the definitions - TODO + -- Generating code -- TODO addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -1064,13 +1076,32 @@ namespace Tests -- Testing dependent branching and let-bindings -- TODO: why the linter warning? - divergent def is_non_zero (i : Int) : Result Bool := + divergent def isNonZero (i : Int) : Result Bool := if _h:i = 0 then return false else let b := true return b - #check is_non_zero.unfold + #check isNonZero.unfold + + -- Testing let-bindings + divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool := + let i0 := ls.length + if i < i0 + then Result.ret True + else Result.ret False + + #check iInBounds.unfold + + divergent def isCons + {a : Type} (ls : List a) : Result Bool := + let ls1 := ls + match ls1 with + | [] => Result.ret False + | x :: tl => Result.ret True + + #check isCons.unfold + end Tests end Diverge |