summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge
diff options
context:
space:
mode:
authorSon Ho2024-06-17 06:16:43 +0200
committerSon Ho2024-06-17 06:16:43 +0200
commite57e6f08e5cc34bf4e9237650f5ecbab440b9ea2 (patch)
tree1e48b2d23719d72f39282213a1806591cc35c3b8 /backends/lean/Base/Diverge
parentf3b22b5cca9bc1154f55a81c9a82dc491074067d (diff)
parent85098d7caf5e3196c2e8f92411efd2814bfed1ea (diff)
Merge branch 'son/update-lean' into has-int-pred
Diffstat (limited to 'backends/lean/Base/Diverge')
-rw-r--r--backends/lean/Base/Diverge/Base.lean21
-rw-r--r--backends/lean/Base/Diverge/Elab.lean129
2 files changed, 95 insertions, 55 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 0f20125f..aab4db8f 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -1,7 +1,6 @@
import Lean
import Lean.Meta.Tactic.Simp
import Init.Data.List.Basic
-import Mathlib.Tactic.Linarith
import Base.Primitives.Base
import Base.Arith.Base
import Base.Diverge.ElabBase
@@ -36,20 +35,19 @@ namespace Lemmas
revert m
induction k -- TODO: induction h rather?
case zero =>
- simp_all
intro m h1 h2
have h: n = m := by omega
unfold for_all_fin_aux; simp_all
simp_all
-- There is no i s.t. m ≤ i
intro i h3; cases i; simp_all
- linarith
+ omega
case succ k hi =>
intro m hk hmn
intro hf i hmi
have hne: m ≠ n := by
have hineq := Nat.lt_of_sub_eq_succ hk
- linarith
+ omega
-- m = i?
if heq: m = i then
-- Yes: simply use the `for_all_fin_aux` hyp
@@ -64,7 +62,7 @@ namespace Lemmas
have heq1: n - (m + 1) = k := by
-- TODO: very annoying arithmetic proof
simp [Nat.sub_eq_iff_eq_add hineq]
- have hineq1: m ≤ n := by linarith
+ have hineq1: m ≤ n := by omega
simp [Nat.sub_eq_iff_eq_add hineq1] at hk
simp_arith [hk]
have hi := hi (m + 1) heq1 hineq
@@ -199,7 +197,7 @@ namespace Fix
| 0 =>
exfalso
zify at *
- linarith
+ omega
| Nat.succ m1 =>
simp_arith at Hle
simp [fix_fuel]
@@ -407,7 +405,7 @@ namespace Fix
. simp at Hl
-- Make a case disjunction on `h y (fix_fuel m k)`: if it is not equal
-- to div, use the monotonicity of `h y`
- have Hle : m ≤ n := by linarith
+ have Hle : m ≤ n := by omega
have Hffmono := fix_fuel_mono Hkmono Hle
have Hmono := Hhmono y Hffmono
simp [result_rel] at Hmono
@@ -568,6 +566,7 @@ namespace FixI
have Heq := Fix.is_valid_fix_fixed_eq Hvalid'
simp [fix]
conv => lhs; rw [Heq]
+ rfl
/- Some utilities to define the mutually recursive functions -/
@@ -778,6 +777,7 @@ namespace FixII
have Heq := Fix.is_valid_fix_fixed_eq Hvalid'
simp [fix]
conv => lhs; rw [Heq]
+ rfl
/- Some utilities to define the mutually recursive functions -/
@@ -966,6 +966,7 @@ namespace Ex1
have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a)
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
end Ex1
@@ -1011,6 +1012,7 @@ namespace Ex2
have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a)
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
end Ex2
@@ -1183,6 +1185,7 @@ namespace Ex4
.ok b) := by
simp [is_even, is_odd];
conv => lhs; rw [body_fix_eq]
+ rfl
theorem is_odd_eq (i : Int) : is_odd i =
(if i = 0
@@ -1192,6 +1195,7 @@ namespace Ex4
.ok b) := by
simp [is_even, is_odd];
conv => lhs; rw [body_fix_eq]
+ rfl
end Ex4
namespace Ex5
@@ -1263,6 +1267,7 @@ namespace Ex5
have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a)
simp [id]
conv => lhs; rw [Heq]; simp; rw [id_body]
+ rfl
end Ex5
@@ -1336,6 +1341,7 @@ namespace Ex6
have Heq := is_valid_fix_fixed_eq body_is_valid
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
-- Write the proof term explicitly: the generation of the proof term (without tactics)
-- is automatable, and the proof term is actually a lot simpler and smaller when we
@@ -1429,6 +1435,7 @@ namespace Ex7
have Heq := is_valid_fix_fixed_eq body_is_valid
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
-- Write the proof term explicitly: the generation of the proof term (without tactics)
-- is automatable, and the proof term is actually a lot simpler and smaller when we
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 5db8ffed..60955051 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -22,6 +22,10 @@ def normalize_let_bindings := true
open WF in
+-- Small utility - it seems that `Name.append` doesn't do what we want
+def appendToName (n : Name) (s : String) : Name :=
+ Name.str n s
+
-- TODO: use those
def UnitType := Expr.const ``PUnit [Level.succ .zero]
def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero]
@@ -548,7 +552,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Add the declaration
let value ← mkLambdaFVars #[kk_var] body
trace[Diverge.def.genBody] "Body after abstracting kk: {value}"
- let name := preDef.declName.append "body"
+ let name := appendToName preDef.declName "body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -603,7 +607,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
let body ← mkLambdaFVars #[kk_var, i_var] body
trace[Diverge.def] "mkDeclareMutRecBody: body: {body}"
-- Add the declaration
- let name := grName.append "mut_rec_body"
+ let name := appendToName grName "mut_rec_body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -1047,7 +1051,7 @@ partial def proveSingleBodyIsValid
mkForallFVars #[k_var, t_var, x_var] ty
trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}"
-- Save the theorem
- let name := preDef.declName ++ "body_is_valid"
+ let name := appendToName preDef.declName "body_is_valid"
let decl := Declaration.thmDecl {
name
levelParams := preDef.levelParams
@@ -1107,7 +1111,7 @@ def proveMutRecIsValid
trace[Diverge.def.valid] "Generated the term: {isValid}"
-- Save the theorem
let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst]
- let name := grName ++ "mut_rec_body_is_valid"
+ let name := appendToName grName "mut_rec_body_is_valid"
let decl := Declaration.thmDecl {
name
levelParams := grLvlParams
@@ -1196,7 +1200,7 @@ partial def proveUnfoldingThms (isValidThm : Expr)
let proof ← mkLambdaFVars xs proof
trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}"
-- Declare the theorem
- let name := preDef.declName ++ "unfold"
+ let name := appendToName preDef.declName "unfold"
let decl := Declaration.thmDecl {
name
levelParams := preDef.levelParams
@@ -1282,7 +1286,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term)
let param_in_out_ty ← do
let value ← mkLambdaFVars #[i_var] param_in_out_ty
- let name := grName.append "param_in_out_ty"
+ let name := appendToName grName "param_in_out_ty"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -1392,44 +1396,71 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef
-def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
- let scopeLevelNames ← getLevelNames
- let headers ← elabHeaders views
- let headers ← levelMVarToParamHeaders views headers
- let allUserLevelNames := getAllUserLevelNames headers
- withFunLocalDecls headers fun funFVars => do
- for view in views, funFVar in funFVars do
- addLocalVarInfo view.declId funFVar
- -- Add fake use site to prevent "unused variable" warning (if the
- -- function is actually not recursive, Lean would print this warning).
- -- Remark: we could detect this case and encode the function without
- -- using the fixed-point. In practice it shouldn't happen however:
- -- we define non-recursive functions with the `divergent` keyword
- -- only for testing purposes.
- addTermInfo' view.declId funFVar
- let values ←
- try
- let values ← elabFunValues headers
- Term.synthesizeSyntheticMVarsNoPostponing
- values.mapM (instantiateMVars ·)
- catch ex =>
- logException ex
- headers.mapM fun header => mkSorry header.type (synthetic := true)
- let headers ← headers.mapM instantiateMVarsAtHeader
- let letRecsToLift ← getLetRecsToLift
- let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
- checkLetRecsToLiftTypes funFVars letRecsToLift
- withUsed vars headers values letRecsToLift fun vars => do
- let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
- for preDef in preDefs do
- trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
- let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
- let preDefs ← instantiateMVarsAtPreDecls preDefs
- let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
- for preDef in preDefs do
- trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
- checkForHiddenUnivLevels allUserLevelNames preDefs
- addPreDefinitions preDefs
+-- Comes from Term.isExample
+def isExample (views : Array DefView) : Bool :=
+ views.any (·.kind.isExample)
+
+open Language in
+def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit :=
+ if isExample views then
+ withoutModifyingEnv do
+ -- save correct environment in info tree
+ withSaveInfoContext do
+ go
+ else
+ go
+where
+ go :=
+ withAlwaysResolvedPromises views.size fun bodyPromises =>
+ withAlwaysResolvedPromises views.size fun tacPromises => do
+ let scopeLevelNames ← getLevelNames
+ let headers ← elabHeaders views bodyPromises tacPromises
+ let headers ← levelMVarToParamHeaders views headers
+ let allUserLevelNames := getAllUserLevelNames headers
+ withFunLocalDecls headers fun funFVars => do
+ for view in views, funFVar in funFVars do
+ addLocalVarInfo view.declId funFVar
+ -- Modification 1:
+ -- Add fake use site to prevent "unused variable" warning (if the
+ -- function is actually not recursive, Lean would print this warning).
+ -- Remark: we could detect this case and encode the function without
+ -- using the fixed-point. In practice it shouldn't happen however:
+ -- we define non-recursive functions with the `divergent` keyword
+ -- only for testing purposes.
+ addTermInfo' view.declId funFVar
+ let values ←
+ try
+ let values ← elabFunValues headers
+ Term.synthesizeSyntheticMVarsNoPostponing
+ values.mapM (instantiateMVars ·)
+ catch ex =>
+ logException ex
+ headers.mapM fun header => mkSorry header.type (synthetic := true)
+ let headers ← headers.mapM instantiateMVarsAtHeader
+ let letRecsToLift ← getLetRecsToLift
+ let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
+ checkLetRecsToLiftTypes funFVars letRecsToLift
+ withUsed vars headers values letRecsToLift fun vars => do
+ let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
+ for preDef in preDefs do
+ trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
+ let preDefs ← instantiateMVarsAtPreDecls preDefs
+ let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
+ for preDef in preDefs do
+ trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ checkForHiddenUnivLevels allUserLevelNames preDefs
+ addPreDefinitions preDefs -- Modification 2: we use our custom function here
+ processDeriving headers
+
+ processDeriving (headers : Array DefViewElabHeader) := do
+ for header in headers, view in views do
+ if let some classNamesStx := view.deriving? then
+ for classNameStx in classNamesStx do
+ let className ← realizeGlobalConstNoOverload classNameStx
+ withRef classNameStx do
+ unless (← processDefDeriving className header.declName) do
+ throwError "failed to synthesize instance '{className}' for '{header.declName}'"
open Command in
def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
@@ -1439,7 +1470,8 @@ def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
let modifiers ← elabModifiers mods
let (binders, type) := expandOptDeclSig sig
let deriving? := none
- pure { ref := d, kind := DefKind.def, modifiers,
+ let headerRef := Syntax.missing -- Not sure what to put here
+ pure { ref := d, kind := DefKind.def, headerRef, modifiers,
declId := id, binders, type? := type, value := val, deriving? }
runTermElabM fun vars => Term.elabMutualDef vars views
@@ -1460,7 +1492,7 @@ elab_rules : command
if (`_root_).isPrefixOf name then throwUnsupportedSyntax
let view := extractMacroScopes name
let .str ns shortName := view.name | throwUnsupportedSyntax
- let shortName' := { view with name := shortName }.review
+ let shortName' := { view with name := Name.mkSimple shortName }.review
let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end)
if ns matches .anonymous then
Command.elabCommand cmd
@@ -1475,6 +1507,7 @@ namespace Tests
--set_option trace.Diverge.def.genBody true
--set_option trace.Diverge.def.valid true
--set_option trace.Diverge.def.genBody.visit true
+ --set_option trace.Diverge.def.unfold true
divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
match ls with
@@ -1492,7 +1525,7 @@ namespace Tests
0 ≤ i → i < ls.length →
∃ x, list_nth ls i = .ok x := by
induction ls
- . intro i hpos h; simp at h; linarith
+ . intro i hpos h; simp at h; omega
. rename_i hd tl ih
intro i hpos h
-- We can directly use `rw [list_nth]`
@@ -1502,7 +1535,7 @@ namespace Tests
. -- We don't have to do this if we use scalar_tac
have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all
simp at h
- have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith)
+ have ⟨ x, ih ⟩ := ih (i - 1) (by omega) (by omega)
simp [ih]
tauto