diff options
Diffstat (limited to 'backends/lean/Base/Diverge')
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 21 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 129 |
2 files changed, 95 insertions, 55 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 0f20125f..aab4db8f 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -1,7 +1,6 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic -import Mathlib.Tactic.Linarith import Base.Primitives.Base import Base.Arith.Base import Base.Diverge.ElabBase @@ -36,20 +35,19 @@ namespace Lemmas revert m induction k -- TODO: induction h rather? case zero => - simp_all intro m h1 h2 have h: n = m := by omega unfold for_all_fin_aux; simp_all simp_all -- There is no i s.t. m ≤ i intro i h3; cases i; simp_all - linarith + omega case succ k hi => intro m hk hmn intro hf i hmi have hne: m ≠ n := by have hineq := Nat.lt_of_sub_eq_succ hk - linarith + omega -- m = i? if heq: m = i then -- Yes: simply use the `for_all_fin_aux` hyp @@ -64,7 +62,7 @@ namespace Lemmas have heq1: n - (m + 1) = k := by -- TODO: very annoying arithmetic proof simp [Nat.sub_eq_iff_eq_add hineq] - have hineq1: m ≤ n := by linarith + have hineq1: m ≤ n := by omega simp [Nat.sub_eq_iff_eq_add hineq1] at hk simp_arith [hk] have hi := hi (m + 1) heq1 hineq @@ -199,7 +197,7 @@ namespace Fix | 0 => exfalso zify at * - linarith + omega | Nat.succ m1 => simp_arith at Hle simp [fix_fuel] @@ -407,7 +405,7 @@ namespace Fix . simp at Hl -- Make a case disjunction on `h y (fix_fuel m k)`: if it is not equal -- to div, use the monotonicity of `h y` - have Hle : m ≤ n := by linarith + have Hle : m ≤ n := by omega have Hffmono := fix_fuel_mono Hkmono Hle have Hmono := Hhmono y Hffmono simp [result_rel] at Hmono @@ -568,6 +566,7 @@ namespace FixI have Heq := Fix.is_valid_fix_fixed_eq Hvalid' simp [fix] conv => lhs; rw [Heq] + rfl /- Some utilities to define the mutually recursive functions -/ @@ -778,6 +777,7 @@ namespace FixII have Heq := Fix.is_valid_fix_fixed_eq Hvalid' simp [fix] conv => lhs; rw [Heq] + rfl /- Some utilities to define the mutually recursive functions -/ @@ -966,6 +966,7 @@ namespace Ex1 have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) simp [list_nth] conv => lhs; rw [Heq] + rfl end Ex1 @@ -1011,6 +1012,7 @@ namespace Ex2 have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) simp [list_nth] conv => lhs; rw [Heq] + rfl end Ex2 @@ -1183,6 +1185,7 @@ namespace Ex4 .ok b) := by simp [is_even, is_odd]; conv => lhs; rw [body_fix_eq] + rfl theorem is_odd_eq (i : Int) : is_odd i = (if i = 0 @@ -1192,6 +1195,7 @@ namespace Ex4 .ok b) := by simp [is_even, is_odd]; conv => lhs; rw [body_fix_eq] + rfl end Ex4 namespace Ex5 @@ -1263,6 +1267,7 @@ namespace Ex5 have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a) simp [id] conv => lhs; rw [Heq]; simp; rw [id_body] + rfl end Ex5 @@ -1336,6 +1341,7 @@ namespace Ex6 have Heq := is_valid_fix_fixed_eq body_is_valid simp [list_nth] conv => lhs; rw [Heq] + rfl -- Write the proof term explicitly: the generation of the proof term (without tactics) -- is automatable, and the proof term is actually a lot simpler and smaller when we @@ -1429,6 +1435,7 @@ namespace Ex7 have Heq := is_valid_fix_fixed_eq body_is_valid simp [list_nth] conv => lhs; rw [Heq] + rfl -- Write the proof term explicitly: the generation of the proof term (without tactics) -- is automatable, and the proof term is actually a lot simpler and smaller when we diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 5db8ffed..60955051 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -22,6 +22,10 @@ def normalize_let_bindings := true open WF in +-- Small utility - it seems that `Name.append` doesn't do what we want +def appendToName (n : Name) (s : String) : Name := + Name.str n s + -- TODO: use those def UnitType := Expr.const ``PUnit [Level.succ .zero] def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero] @@ -548,7 +552,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Add the declaration let value ← mkLambdaFVars #[kk_var] body trace[Diverge.def.genBody] "Body after abstracting kk: {value}" - let name := preDef.declName.append "body" + let name := appendToName preDef.declName "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -603,7 +607,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) let body ← mkLambdaFVars #[kk_var, i_var] body trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" -- Add the declaration - let name := grName.append "mut_rec_body" + let name := appendToName grName "mut_rec_body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -1047,7 +1051,7 @@ partial def proveSingleBodyIsValid mkForallFVars #[k_var, t_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem - let name := preDef.declName ++ "body_is_valid" + let name := appendToName preDef.declName "body_is_valid" let decl := Declaration.thmDecl { name levelParams := preDef.levelParams @@ -1107,7 +1111,7 @@ def proveMutRecIsValid trace[Diverge.def.valid] "Generated the term: {isValid}" -- Save the theorem let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst] - let name := grName ++ "mut_rec_body_is_valid" + let name := appendToName grName "mut_rec_body_is_valid" let decl := Declaration.thmDecl { name levelParams := grLvlParams @@ -1196,7 +1200,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) let proof ← mkLambdaFVars xs proof trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" -- Declare the theorem - let name := preDef.declName ++ "unfold" + let name := appendToName preDef.declName "unfold" let decl := Declaration.thmDecl { name levelParams := preDef.levelParams @@ -1282,7 +1286,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term) let param_in_out_ty ← do let value ← mkLambdaFVars #[i_var] param_in_out_ty - let name := grName.append "param_in_out_ty" + let name := appendToName grName "param_in_out_ty" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -1392,44 +1396,71 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef -def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do - let scopeLevelNames ← getLevelNames - let headers ← elabHeaders views - let headers ← levelMVarToParamHeaders views headers - let allUserLevelNames := getAllUserLevelNames headers - withFunLocalDecls headers fun funFVars => do - for view in views, funFVar in funFVars do - addLocalVarInfo view.declId funFVar - -- Add fake use site to prevent "unused variable" warning (if the - -- function is actually not recursive, Lean would print this warning). - -- Remark: we could detect this case and encode the function without - -- using the fixed-point. In practice it shouldn't happen however: - -- we define non-recursive functions with the `divergent` keyword - -- only for testing purposes. - addTermInfo' view.declId funFVar - let values ← - try - let values ← elabFunValues headers - Term.synthesizeSyntheticMVarsNoPostponing - values.mapM (instantiateMVars ·) - catch ex => - logException ex - headers.mapM fun header => mkSorry header.type (synthetic := true) - let headers ← headers.mapM instantiateMVarsAtHeader - let letRecsToLift ← getLetRecsToLift - let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift - checkLetRecsToLiftTypes funFVars letRecsToLift - withUsed vars headers values letRecsToLift fun vars => do - let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift - for preDef in preDefs do - trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" - let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs - let preDefs ← instantiateMVarsAtPreDecls preDefs - let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames - for preDef in preDefs do - trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" - checkForHiddenUnivLevels allUserLevelNames preDefs - addPreDefinitions preDefs +-- Comes from Term.isExample +def isExample (views : Array DefView) : Bool := + views.any (·.kind.isExample) + +open Language in +def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := + if isExample views then + withoutModifyingEnv do + -- save correct environment in info tree + withSaveInfoContext do + go + else + go +where + go := + withAlwaysResolvedPromises views.size fun bodyPromises => + withAlwaysResolvedPromises views.size fun tacPromises => do + let scopeLevelNames ← getLevelNames + let headers ← elabHeaders views bodyPromises tacPromises + let headers ← levelMVarToParamHeaders views headers + let allUserLevelNames := getAllUserLevelNames headers + withFunLocalDecls headers fun funFVars => do + for view in views, funFVar in funFVars do + addLocalVarInfo view.declId funFVar + -- Modification 1: + -- Add fake use site to prevent "unused variable" warning (if the + -- function is actually not recursive, Lean would print this warning). + -- Remark: we could detect this case and encode the function without + -- using the fixed-point. In practice it shouldn't happen however: + -- we define non-recursive functions with the `divergent` keyword + -- only for testing purposes. + addTermInfo' view.declId funFVar + let values ← + try + let values ← elabFunValues headers + Term.synthesizeSyntheticMVarsNoPostponing + values.mapM (instantiateMVars ·) + catch ex => + logException ex + headers.mapM fun header => mkSorry header.type (synthetic := true) + let headers ← headers.mapM instantiateMVarsAtHeader + let letRecsToLift ← getLetRecsToLift + let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift + checkLetRecsToLiftTypes funFVars letRecsToLift + withUsed vars headers values letRecsToLift fun vars => do + let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift + for preDef in preDefs do + trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs + let preDefs ← instantiateMVarsAtPreDecls preDefs + let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames + for preDef in preDefs do + trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + checkForHiddenUnivLevels allUserLevelNames preDefs + addPreDefinitions preDefs -- Modification 2: we use our custom function here + processDeriving headers + + processDeriving (headers : Array DefViewElabHeader) := do + for header in headers, view in views do + if let some classNamesStx := view.deriving? then + for classNameStx in classNamesStx do + let className ← realizeGlobalConstNoOverload classNameStx + withRef classNameStx do + unless (← processDefDeriving className header.declName) do + throwError "failed to synthesize instance '{className}' for '{header.declName}'" open Command in def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do @@ -1439,7 +1470,8 @@ def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do let modifiers ← elabModifiers mods let (binders, type) := expandOptDeclSig sig let deriving? := none - pure { ref := d, kind := DefKind.def, modifiers, + let headerRef := Syntax.missing -- Not sure what to put here + pure { ref := d, kind := DefKind.def, headerRef, modifiers, declId := id, binders, type? := type, value := val, deriving? } runTermElabM fun vars => Term.elabMutualDef vars views @@ -1460,7 +1492,7 @@ elab_rules : command if (`_root_).isPrefixOf name then throwUnsupportedSyntax let view := extractMacroScopes name let .str ns shortName := view.name | throwUnsupportedSyntax - let shortName' := { view with name := shortName }.review + let shortName' := { view with name := Name.mkSimple shortName }.review let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end) if ns matches .anonymous then Command.elabCommand cmd @@ -1475,6 +1507,7 @@ namespace Tests --set_option trace.Diverge.def.genBody true --set_option trace.Diverge.def.valid true --set_option trace.Diverge.def.genBody.visit true + --set_option trace.Diverge.def.unfold true divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with @@ -1492,7 +1525,7 @@ namespace Tests 0 ≤ i → i < ls.length → ∃ x, list_nth ls i = .ok x := by induction ls - . intro i hpos h; simp at h; linarith + . intro i hpos h; simp at h; omega . rename_i hd tl ih intro i hpos h -- We can directly use `rw [list_nth]` @@ -1502,7 +1535,7 @@ namespace Tests . -- We don't have to do this if we use scalar_tac have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all simp at h - have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + have ⟨ x, ih ⟩ := ih (i - 1) (by omega) (by omega) simp [ih] tauto |