summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge/Elab.lean
diff options
context:
space:
mode:
authorSon Ho2024-06-17 06:16:43 +0200
committerSon Ho2024-06-17 06:16:43 +0200
commite57e6f08e5cc34bf4e9237650f5ecbab440b9ea2 (patch)
tree1e48b2d23719d72f39282213a1806591cc35c3b8 /backends/lean/Base/Diverge/Elab.lean
parentf3b22b5cca9bc1154f55a81c9a82dc491074067d (diff)
parent85098d7caf5e3196c2e8f92411efd2814bfed1ea (diff)
Merge branch 'son/update-lean' into has-int-pred
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean129
1 files changed, 81 insertions, 48 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 5db8ffed..60955051 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -22,6 +22,10 @@ def normalize_let_bindings := true
open WF in
+-- Small utility - it seems that `Name.append` doesn't do what we want
+def appendToName (n : Name) (s : String) : Name :=
+ Name.str n s
+
-- TODO: use those
def UnitType := Expr.const ``PUnit [Level.succ .zero]
def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero]
@@ -548,7 +552,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Add the declaration
let value ← mkLambdaFVars #[kk_var] body
trace[Diverge.def.genBody] "Body after abstracting kk: {value}"
- let name := preDef.declName.append "body"
+ let name := appendToName preDef.declName "body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -603,7 +607,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
let body ← mkLambdaFVars #[kk_var, i_var] body
trace[Diverge.def] "mkDeclareMutRecBody: body: {body}"
-- Add the declaration
- let name := grName.append "mut_rec_body"
+ let name := appendToName grName "mut_rec_body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -1047,7 +1051,7 @@ partial def proveSingleBodyIsValid
mkForallFVars #[k_var, t_var, x_var] ty
trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}"
-- Save the theorem
- let name := preDef.declName ++ "body_is_valid"
+ let name := appendToName preDef.declName "body_is_valid"
let decl := Declaration.thmDecl {
name
levelParams := preDef.levelParams
@@ -1107,7 +1111,7 @@ def proveMutRecIsValid
trace[Diverge.def.valid] "Generated the term: {isValid}"
-- Save the theorem
let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst]
- let name := grName ++ "mut_rec_body_is_valid"
+ let name := appendToName grName "mut_rec_body_is_valid"
let decl := Declaration.thmDecl {
name
levelParams := grLvlParams
@@ -1196,7 +1200,7 @@ partial def proveUnfoldingThms (isValidThm : Expr)
let proof ← mkLambdaFVars xs proof
trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}"
-- Declare the theorem
- let name := preDef.declName ++ "unfold"
+ let name := appendToName preDef.declName "unfold"
let decl := Declaration.thmDecl {
name
levelParams := preDef.levelParams
@@ -1282,7 +1286,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term)
let param_in_out_ty ← do
let value ← mkLambdaFVars #[i_var] param_in_out_ty
- let name := grName.append "param_in_out_ty"
+ let name := appendToName grName "param_in_out_ty"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -1392,44 +1396,71 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef
-def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
- let scopeLevelNames ← getLevelNames
- let headers ← elabHeaders views
- let headers ← levelMVarToParamHeaders views headers
- let allUserLevelNames := getAllUserLevelNames headers
- withFunLocalDecls headers fun funFVars => do
- for view in views, funFVar in funFVars do
- addLocalVarInfo view.declId funFVar
- -- Add fake use site to prevent "unused variable" warning (if the
- -- function is actually not recursive, Lean would print this warning).
- -- Remark: we could detect this case and encode the function without
- -- using the fixed-point. In practice it shouldn't happen however:
- -- we define non-recursive functions with the `divergent` keyword
- -- only for testing purposes.
- addTermInfo' view.declId funFVar
- let values ←
- try
- let values ← elabFunValues headers
- Term.synthesizeSyntheticMVarsNoPostponing
- values.mapM (instantiateMVars ·)
- catch ex =>
- logException ex
- headers.mapM fun header => mkSorry header.type (synthetic := true)
- let headers ← headers.mapM instantiateMVarsAtHeader
- let letRecsToLift ← getLetRecsToLift
- let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
- checkLetRecsToLiftTypes funFVars letRecsToLift
- withUsed vars headers values letRecsToLift fun vars => do
- let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
- for preDef in preDefs do
- trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
- let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
- let preDefs ← instantiateMVarsAtPreDecls preDefs
- let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
- for preDef in preDefs do
- trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
- checkForHiddenUnivLevels allUserLevelNames preDefs
- addPreDefinitions preDefs
+-- Comes from Term.isExample
+def isExample (views : Array DefView) : Bool :=
+ views.any (·.kind.isExample)
+
+open Language in
+def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit :=
+ if isExample views then
+ withoutModifyingEnv do
+ -- save correct environment in info tree
+ withSaveInfoContext do
+ go
+ else
+ go
+where
+ go :=
+ withAlwaysResolvedPromises views.size fun bodyPromises =>
+ withAlwaysResolvedPromises views.size fun tacPromises => do
+ let scopeLevelNames ← getLevelNames
+ let headers ← elabHeaders views bodyPromises tacPromises
+ let headers ← levelMVarToParamHeaders views headers
+ let allUserLevelNames := getAllUserLevelNames headers
+ withFunLocalDecls headers fun funFVars => do
+ for view in views, funFVar in funFVars do
+ addLocalVarInfo view.declId funFVar
+ -- Modification 1:
+ -- Add fake use site to prevent "unused variable" warning (if the
+ -- function is actually not recursive, Lean would print this warning).
+ -- Remark: we could detect this case and encode the function without
+ -- using the fixed-point. In practice it shouldn't happen however:
+ -- we define non-recursive functions with the `divergent` keyword
+ -- only for testing purposes.
+ addTermInfo' view.declId funFVar
+ let values ←
+ try
+ let values ← elabFunValues headers
+ Term.synthesizeSyntheticMVarsNoPostponing
+ values.mapM (instantiateMVars ·)
+ catch ex =>
+ logException ex
+ headers.mapM fun header => mkSorry header.type (synthetic := true)
+ let headers ← headers.mapM instantiateMVarsAtHeader
+ let letRecsToLift ← getLetRecsToLift
+ let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
+ checkLetRecsToLiftTypes funFVars letRecsToLift
+ withUsed vars headers values letRecsToLift fun vars => do
+ let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
+ for preDef in preDefs do
+ trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
+ let preDefs ← instantiateMVarsAtPreDecls preDefs
+ let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
+ for preDef in preDefs do
+ trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ checkForHiddenUnivLevels allUserLevelNames preDefs
+ addPreDefinitions preDefs -- Modification 2: we use our custom function here
+ processDeriving headers
+
+ processDeriving (headers : Array DefViewElabHeader) := do
+ for header in headers, view in views do
+ if let some classNamesStx := view.deriving? then
+ for classNameStx in classNamesStx do
+ let className ← realizeGlobalConstNoOverload classNameStx
+ withRef classNameStx do
+ unless (← processDefDeriving className header.declName) do
+ throwError "failed to synthesize instance '{className}' for '{header.declName}'"
open Command in
def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
@@ -1439,7 +1470,8 @@ def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
let modifiers ← elabModifiers mods
let (binders, type) := expandOptDeclSig sig
let deriving? := none
- pure { ref := d, kind := DefKind.def, modifiers,
+ let headerRef := Syntax.missing -- Not sure what to put here
+ pure { ref := d, kind := DefKind.def, headerRef, modifiers,
declId := id, binders, type? := type, value := val, deriving? }
runTermElabM fun vars => Term.elabMutualDef vars views
@@ -1460,7 +1492,7 @@ elab_rules : command
if (`_root_).isPrefixOf name then throwUnsupportedSyntax
let view := extractMacroScopes name
let .str ns shortName := view.name | throwUnsupportedSyntax
- let shortName' := { view with name := shortName }.review
+ let shortName' := { view with name := Name.mkSimple shortName }.review
let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end)
if ns matches .anonymous then
Command.elabCommand cmd
@@ -1475,6 +1507,7 @@ namespace Tests
--set_option trace.Diverge.def.genBody true
--set_option trace.Diverge.def.valid true
--set_option trace.Diverge.def.genBody.visit true
+ --set_option trace.Diverge.def.unfold true
divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
match ls with
@@ -1492,7 +1525,7 @@ namespace Tests
0 ≤ i → i < ls.length →
∃ x, list_nth ls i = .ok x := by
induction ls
- . intro i hpos h; simp at h; linarith
+ . intro i hpos h; simp at h; omega
. rename_i hd tl ih
intro i hpos h
-- We can directly use `rw [list_nth]`
@@ -1502,7 +1535,7 @@ namespace Tests
. -- We don't have to do this if we use scalar_tac
have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all
simp at h
- have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith)
+ have ⟨ x, ih ⟩ := ih (i - 1) (by omega) (by omega)
simp [ih]
tauto