diff options
Diffstat (limited to 'backends/lean/Base')
-rw-r--r-- | backends/lean/Base/Arith/Base.lean | 11 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Int.lean | 10 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Scalar.lean | 59 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 21 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 129 | ||||
-rw-r--r-- | backends/lean/Base/Extensions.lean | 1 | ||||
-rw-r--r-- | backends/lean/Base/IList/IList.lean | 12 | ||||
-rw-r--r-- | backends/lean/Base/Primitives.lean | 1 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/ArraySlice.lean | 5 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/CoreConvertNum.lean | 1 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Scalar.lean | 154 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/ScalarNotations.lean | 109 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Vec.lean | 7 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Base.lean | 3 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 23 | ||||
-rw-r--r-- | backends/lean/Base/Utils.lean | 123 |
16 files changed, 400 insertions, 269 deletions
diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean index 8ada4171..fb6b12e5 100644 --- a/backends/lean/Base/Arith/Base.lean +++ b/backends/lean/Base/Arith/Base.lean @@ -1,6 +1,5 @@ import Lean -import Std.Data.Int.Lemmas -import Mathlib.Tactic.Linarith +import Mathlib.Tactic.Linarith -- Introduces a lot of useful lemmas namespace Arith @@ -21,12 +20,12 @@ theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by have hne : x - y ≠ 0 := by simp intro h - have: x = y := by linarith + have: x = y := by omega simp_all have h := ne_zero_is_lt_or_gt hne match h with - | .inl _ => left; linarith - | .inr _ => right; linarith + | .inl _ => left; omega + | .inr _ => right; omega -- TODO: move? theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by @@ -66,7 +65,7 @@ theorem to_int_to_nat_lt (x y : ℤ) (h0 : 0 ≤ x) (h1 : x < y) : theorem to_int_sub_to_nat_lt (x y : ℤ) (x' : ℕ) (h0 : ↑x' ≤ x) (h1 : x - ↑x' < y) : ↑(x.toNat - x') < y := by - have : 0 ≤ x := by linarith + have : 0 ≤ x := by omega simp [Int.toNat_sub_of_le, *] end Arith diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean index a1cb9da3..ab6dd4ab 100644 --- a/backends/lean/Base/Arith/Int.lean +++ b/backends/lean/Base/Arith/Int.lean @@ -27,7 +27,7 @@ class HasIntPred {a: Sort u} (x: a) where prop : concl /- Proposition with implications: if we find P we can introduce Q in the context -/ -class PropHasImp (x : Prop) where +class PropHasImp (x : Sort u) where concl : Prop prop : x → concl @@ -199,7 +199,7 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) -- Add a declaration let nval ← Utils.addDeclTac name e type (asLet := false) -- Simplify to unfold the declaration to unfold (i.e., the projector) - Utils.simpAt true [declToUnfold] [] [] (Location.targets #[mkIdent name] false) + Utils.simpAt true {} #[] [declToUnfold] [] [] (Location.targets #[mkIdent name] false) -- Return the new value pure nval @@ -242,7 +242,7 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U extraPreprocess -- Reduce all the terms in the goal - note that the extra preprocessing step -- might have proven the goal, hence the `Tactic.allGoals` - Tactic.allGoals do tryTac (dsimpAt false [] [] [] Tactic.Location.wildcard) + Tactic.allGoals do tryTac (dsimpAt false {} #[] [] [] [] Tactic.Location.wildcard) elab "int_tac_preprocess" : tactic => intTacPreprocess (do pure ()) @@ -259,10 +259,10 @@ def intTac (tacName : String) (splitGoalConjs : Bool) (extraPreprocess : Tactic -- the goal. I think before leads to a smaller proof term? Tactic.allGoals (intTacPreprocess extraPreprocess) -- More preprocessing - Tactic.allGoals (Utils.tryTac (Utils.simpAt true [] [``nat_zero_eq_int_zero] [] .wildcard)) + Tactic.allGoals (Utils.tryTac (Utils.simpAt true {} #[] [] [``nat_zero_eq_int_zero] [] .wildcard)) -- Split the conjunctions in the goal if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget) - -- Call linarith + -- Call omega Tactic.allGoals do try do Tactic.Omega.omegaTactic {} catch _ => diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index 86b2e216..ecc5acaf 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -8,30 +8,31 @@ open Lean Lean.Elab Lean.Meta open Primitives def scalarTacExtraPreprocess : Tactic.TacticM Unit := do - Tactic.withMainContext do - -- Inroduce the bounds for the isize/usize types - let add (e : Expr) : Tactic.TacticM Unit := do - let ty ← inferType e - let _ ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) e ty (asLet := false) - add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) - -- Reveal the concrete bounds, simplify calls to [ofInt] - Utils.simpAt true - -- Unfoldings - [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, - ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, - ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, - ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, - ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max, - ``Usize.min - ] - -- Simp lemmas - [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val, - ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv] - -- Hypotheses - [] .wildcard - + Tactic.withMainContext do + -- Inroduce the bounds for the isize/usize types + let add (e : Expr) : Tactic.TacticM Unit := do + let ty ← inferType e + let _ ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) e ty (asLet := false) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) + -- Reveal the concrete bounds, simplify calls to [ofInt] + Utils.simpAt true {} + -- Simprocs + #[] + -- Unfoldings + [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, + ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, + ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, + ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, + ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max, + ``Usize.min + ] + -- Simp lemmas + [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val, + ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv] + -- Hypotheses + [] .wildcard elab "scalar_tac_preprocess" : tactic => intTacPreprocess scalarTacExtraPreprocess @@ -81,4 +82,14 @@ example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) : example (x : U32) (h0 : ¬ x = U32.ofInt 0) : 0 < x.val := by scalar_tac +/- See this: https://aeneas-verif.zulipchat.com/#narrow/stream/349819-general/topic/U64.20trouble/near/444049757 + + We solved it by removing the instance `OfNat` for `Scalar`. + Note however that we could also solve it with a simplification lemma. + However, after testing, we noticed we could only apply such a lemma with + the rewriting tactic (not the simplifier), probably because of the use + of typeclasses. -/ +example {u: U64} (h1: (u : Int) < 2): (u : Int) = 0 ∨ (u : Int) = 1 := by + scalar_tac + end Arith diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 0f20125f..aab4db8f 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -1,7 +1,6 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic -import Mathlib.Tactic.Linarith import Base.Primitives.Base import Base.Arith.Base import Base.Diverge.ElabBase @@ -36,20 +35,19 @@ namespace Lemmas revert m induction k -- TODO: induction h rather? case zero => - simp_all intro m h1 h2 have h: n = m := by omega unfold for_all_fin_aux; simp_all simp_all -- There is no i s.t. m ≤ i intro i h3; cases i; simp_all - linarith + omega case succ k hi => intro m hk hmn intro hf i hmi have hne: m ≠ n := by have hineq := Nat.lt_of_sub_eq_succ hk - linarith + omega -- m = i? if heq: m = i then -- Yes: simply use the `for_all_fin_aux` hyp @@ -64,7 +62,7 @@ namespace Lemmas have heq1: n - (m + 1) = k := by -- TODO: very annoying arithmetic proof simp [Nat.sub_eq_iff_eq_add hineq] - have hineq1: m ≤ n := by linarith + have hineq1: m ≤ n := by omega simp [Nat.sub_eq_iff_eq_add hineq1] at hk simp_arith [hk] have hi := hi (m + 1) heq1 hineq @@ -199,7 +197,7 @@ namespace Fix | 0 => exfalso zify at * - linarith + omega | Nat.succ m1 => simp_arith at Hle simp [fix_fuel] @@ -407,7 +405,7 @@ namespace Fix . simp at Hl -- Make a case disjunction on `h y (fix_fuel m k)`: if it is not equal -- to div, use the monotonicity of `h y` - have Hle : m ≤ n := by linarith + have Hle : m ≤ n := by omega have Hffmono := fix_fuel_mono Hkmono Hle have Hmono := Hhmono y Hffmono simp [result_rel] at Hmono @@ -568,6 +566,7 @@ namespace FixI have Heq := Fix.is_valid_fix_fixed_eq Hvalid' simp [fix] conv => lhs; rw [Heq] + rfl /- Some utilities to define the mutually recursive functions -/ @@ -778,6 +777,7 @@ namespace FixII have Heq := Fix.is_valid_fix_fixed_eq Hvalid' simp [fix] conv => lhs; rw [Heq] + rfl /- Some utilities to define the mutually recursive functions -/ @@ -966,6 +966,7 @@ namespace Ex1 have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) simp [list_nth] conv => lhs; rw [Heq] + rfl end Ex1 @@ -1011,6 +1012,7 @@ namespace Ex2 have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) simp [list_nth] conv => lhs; rw [Heq] + rfl end Ex2 @@ -1183,6 +1185,7 @@ namespace Ex4 .ok b) := by simp [is_even, is_odd]; conv => lhs; rw [body_fix_eq] + rfl theorem is_odd_eq (i : Int) : is_odd i = (if i = 0 @@ -1192,6 +1195,7 @@ namespace Ex4 .ok b) := by simp [is_even, is_odd]; conv => lhs; rw [body_fix_eq] + rfl end Ex4 namespace Ex5 @@ -1263,6 +1267,7 @@ namespace Ex5 have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a) simp [id] conv => lhs; rw [Heq]; simp; rw [id_body] + rfl end Ex5 @@ -1336,6 +1341,7 @@ namespace Ex6 have Heq := is_valid_fix_fixed_eq body_is_valid simp [list_nth] conv => lhs; rw [Heq] + rfl -- Write the proof term explicitly: the generation of the proof term (without tactics) -- is automatable, and the proof term is actually a lot simpler and smaller when we @@ -1429,6 +1435,7 @@ namespace Ex7 have Heq := is_valid_fix_fixed_eq body_is_valid simp [list_nth] conv => lhs; rw [Heq] + rfl -- Write the proof term explicitly: the generation of the proof term (without tactics) -- is automatable, and the proof term is actually a lot simpler and smaller when we diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 5db8ffed..60955051 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -22,6 +22,10 @@ def normalize_let_bindings := true open WF in +-- Small utility - it seems that `Name.append` doesn't do what we want +def appendToName (n : Name) (s : String) : Name := + Name.str n s + -- TODO: use those def UnitType := Expr.const ``PUnit [Level.succ .zero] def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero] @@ -548,7 +552,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Add the declaration let value ← mkLambdaFVars #[kk_var] body trace[Diverge.def.genBody] "Body after abstracting kk: {value}" - let name := preDef.declName.append "body" + let name := appendToName preDef.declName "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -603,7 +607,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) let body ← mkLambdaFVars #[kk_var, i_var] body trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" -- Add the declaration - let name := grName.append "mut_rec_body" + let name := appendToName grName "mut_rec_body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -1047,7 +1051,7 @@ partial def proveSingleBodyIsValid mkForallFVars #[k_var, t_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem - let name := preDef.declName ++ "body_is_valid" + let name := appendToName preDef.declName "body_is_valid" let decl := Declaration.thmDecl { name levelParams := preDef.levelParams @@ -1107,7 +1111,7 @@ def proveMutRecIsValid trace[Diverge.def.valid] "Generated the term: {isValid}" -- Save the theorem let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst] - let name := grName ++ "mut_rec_body_is_valid" + let name := appendToName grName "mut_rec_body_is_valid" let decl := Declaration.thmDecl { name levelParams := grLvlParams @@ -1196,7 +1200,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) let proof ← mkLambdaFVars xs proof trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" -- Declare the theorem - let name := preDef.declName ++ "unfold" + let name := appendToName preDef.declName "unfold" let decl := Declaration.thmDecl { name levelParams := preDef.levelParams @@ -1282,7 +1286,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term) let param_in_out_ty ← do let value ← mkLambdaFVars #[i_var] param_in_out_ty - let name := grName.append "param_in_out_ty" + let name := appendToName grName "param_in_out_ty" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -1392,44 +1396,71 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef -def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do - let scopeLevelNames ← getLevelNames - let headers ← elabHeaders views - let headers ← levelMVarToParamHeaders views headers - let allUserLevelNames := getAllUserLevelNames headers - withFunLocalDecls headers fun funFVars => do - for view in views, funFVar in funFVars do - addLocalVarInfo view.declId funFVar - -- Add fake use site to prevent "unused variable" warning (if the - -- function is actually not recursive, Lean would print this warning). - -- Remark: we could detect this case and encode the function without - -- using the fixed-point. In practice it shouldn't happen however: - -- we define non-recursive functions with the `divergent` keyword - -- only for testing purposes. - addTermInfo' view.declId funFVar - let values ← - try - let values ← elabFunValues headers - Term.synthesizeSyntheticMVarsNoPostponing - values.mapM (instantiateMVars ·) - catch ex => - logException ex - headers.mapM fun header => mkSorry header.type (synthetic := true) - let headers ← headers.mapM instantiateMVarsAtHeader - let letRecsToLift ← getLetRecsToLift - let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift - checkLetRecsToLiftTypes funFVars letRecsToLift - withUsed vars headers values letRecsToLift fun vars => do - let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift - for preDef in preDefs do - trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" - let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs - let preDefs ← instantiateMVarsAtPreDecls preDefs - let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames - for preDef in preDefs do - trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" - checkForHiddenUnivLevels allUserLevelNames preDefs - addPreDefinitions preDefs +-- Comes from Term.isExample +def isExample (views : Array DefView) : Bool := + views.any (·.kind.isExample) + +open Language in +def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := + if isExample views then + withoutModifyingEnv do + -- save correct environment in info tree + withSaveInfoContext do + go + else + go +where + go := + withAlwaysResolvedPromises views.size fun bodyPromises => + withAlwaysResolvedPromises views.size fun tacPromises => do + let scopeLevelNames ← getLevelNames + let headers ← elabHeaders views bodyPromises tacPromises + let headers ← levelMVarToParamHeaders views headers + let allUserLevelNames := getAllUserLevelNames headers + withFunLocalDecls headers fun funFVars => do + for view in views, funFVar in funFVars do + addLocalVarInfo view.declId funFVar + -- Modification 1: + -- Add fake use site to prevent "unused variable" warning (if the + -- function is actually not recursive, Lean would print this warning). + -- Remark: we could detect this case and encode the function without + -- using the fixed-point. In practice it shouldn't happen however: + -- we define non-recursive functions with the `divergent` keyword + -- only for testing purposes. + addTermInfo' view.declId funFVar + let values ← + try + let values ← elabFunValues headers + Term.synthesizeSyntheticMVarsNoPostponing + values.mapM (instantiateMVars ·) + catch ex => + logException ex + headers.mapM fun header => mkSorry header.type (synthetic := true) + let headers ← headers.mapM instantiateMVarsAtHeader + let letRecsToLift ← getLetRecsToLift + let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift + checkLetRecsToLiftTypes funFVars letRecsToLift + withUsed vars headers values letRecsToLift fun vars => do + let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift + for preDef in preDefs do + trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs + let preDefs ← instantiateMVarsAtPreDecls preDefs + let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames + for preDef in preDefs do + trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + checkForHiddenUnivLevels allUserLevelNames preDefs + addPreDefinitions preDefs -- Modification 2: we use our custom function here + processDeriving headers + + processDeriving (headers : Array DefViewElabHeader) := do + for header in headers, view in views do + if let some classNamesStx := view.deriving? then + for classNameStx in classNamesStx do + let className ← realizeGlobalConstNoOverload classNameStx + withRef classNameStx do + unless (← processDefDeriving className header.declName) do + throwError "failed to synthesize instance '{className}' for '{header.declName}'" open Command in def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do @@ -1439,7 +1470,8 @@ def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do let modifiers ← elabModifiers mods let (binders, type) := expandOptDeclSig sig let deriving? := none - pure { ref := d, kind := DefKind.def, modifiers, + let headerRef := Syntax.missing -- Not sure what to put here + pure { ref := d, kind := DefKind.def, headerRef, modifiers, declId := id, binders, type? := type, value := val, deriving? } runTermElabM fun vars => Term.elabMutualDef vars views @@ -1460,7 +1492,7 @@ elab_rules : command if (`_root_).isPrefixOf name then throwUnsupportedSyntax let view := extractMacroScopes name let .str ns shortName := view.name | throwUnsupportedSyntax - let shortName' := { view with name := shortName }.review + let shortName' := { view with name := Name.mkSimple shortName }.review let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end) if ns matches .anonymous then Command.elabCommand cmd @@ -1475,6 +1507,7 @@ namespace Tests --set_option trace.Diverge.def.genBody true --set_option trace.Diverge.def.valid true --set_option trace.Diverge.def.genBody.visit true + --set_option trace.Diverge.def.unfold true divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with @@ -1492,7 +1525,7 @@ namespace Tests 0 ≤ i → i < ls.length → ∃ x, list_nth ls i = .ok x := by induction ls - . intro i hpos h; simp at h; linarith + . intro i hpos h; simp at h; omega . rename_i hd tl ih intro i hpos h -- We can directly use `rw [list_nth]` @@ -1502,7 +1535,7 @@ namespace Tests . -- We don't have to do this if we use scalar_tac have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all simp at h - have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + have ⟨ x, ih ⟩ := ih (i - 1) (by omega) (by omega) simp [ih] tauto diff --git a/backends/lean/Base/Extensions.lean b/backends/lean/Base/Extensions.lean index c0e80861..b491f81b 100644 --- a/backends/lean/Base/Extensions.lean +++ b/backends/lean/Base/Extensions.lean @@ -1,5 +1,4 @@ import Lean -import Std.Lean.HashSet import Base.Utils import Base.Primitives.Base diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean index 9fe2297f..96843f55 100644 --- a/backends/lean/Base/IList/IList.lean +++ b/backends/lean/Base/IList/IList.lean @@ -1,7 +1,6 @@ /- Complementary list functions and lemmas which operate on integers rather than natural numbers. -/ -import Std.Data.Int.Lemmas import Base.Arith import Base.Utils @@ -17,7 +16,7 @@ def len (ls : List α) : Int := theorem len_pos : 0 ≤ (ls : List α).len := by induction ls <;> simp [*] - linarith + omega instance (l: List a) : Arith.HasIntPred (l.len) where concl := 0 ≤ l.len @@ -171,6 +170,7 @@ theorem ireplicate_replicate {α : Type u} (l : ℤ) (x : α) (h : 0 ≤ l) : have hl : l.toNat = .succ (l.toNat - 1) := by cases hl: l.toNat <;> simp_all conv => rhs; rw[hl] + rfl termination_by l.toNat decreasing_by int_decr_tac @@ -279,12 +279,12 @@ open Arith in if heq: i = 0 then simp [*] at * have := tl.len_pos - linarith + omega else have : 0 < i := by int_tac simp [*] apply hi - linarith + omega theorem idrop_len_le (i : Int) (ls : List α) : (ls.idrop i).len ≤ ls.len := match ls with @@ -293,13 +293,13 @@ theorem idrop_len_le (i : Int) (ls : List α) : (ls.idrop i).len ≤ ls.len := if h: i = 0 then by simp [*] else have := idrop_len_le (i - 1) tl - by simp [*]; linarith + by simp [*]; omega @[simp] theorem idrop_len (i : Int) (ls : List α) (_ : 0 ≤ i) (_ : i ≤ ls.len) : (ls.idrop i).len = ls.len - i := match ls with - | [] => by simp_all; linarith + | [] => by simp_all; omega | hd :: tl => if h: i = 0 then by simp [*] else diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean index f80c2004..93617049 100644 --- a/backends/lean/Base/Primitives.lean +++ b/backends/lean/Base/Primitives.lean @@ -1,6 +1,7 @@ import Base.Primitives.Base import Base.Tuples import Base.Primitives.Scalar +import Base.Primitives.ScalarNotations import Base.Primitives.ArraySlice import Base.Primitives.Vec import Base.Primitives.Alloc diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean index 157f9df1..be460987 100644 --- a/backends/lean/Base/Primitives/ArraySlice.lean +++ b/backends/lean/Base/Primitives/ArraySlice.lean @@ -126,7 +126,7 @@ abbrev Slice.v {α : Type u} (v : Slice α) : List α := v.val example {a: Type u} (v : Slice a) : v.length ≤ Scalar.max ScalarTy.Usize := by scalar_tac -def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp; decide ⟩ +def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ -- TODO: very annoying that the α is an explicit parameter def Slice.len (α : Type u) (v : Slice α) : Usize := @@ -325,8 +325,7 @@ theorem Slice.subslice_spec {α : Type u} [Inhabited α] (s : Slice α) (r : Ran have := List.index_slice r.start.val r.end_.val i s.val (by scalar_tac) (by scalar_tac) (by trivial) (by scalar_tac) simp [*] -attribute [pp_dot] List.len List.length List.index -- use the dot notation when printing -set_option pp.coercions false -- do not print coercions with ↑ (this doesn't parse) +set_option pp.fieldNotation.generalized true def Slice.update_subslice (α : Type u) (s : Slice α) (r : Range Usize) (ss : Slice α) : Result (Slice α) := -- TODO: not completely sure here diff --git a/backends/lean/Base/Primitives/CoreConvertNum.lean b/backends/lean/Base/Primitives/CoreConvertNum.lean index eb456a96..b53d11db 100644 --- a/backends/lean/Base/Primitives/CoreConvertNum.lean +++ b/backends/lean/Base/Primitives/CoreConvertNum.lean @@ -4,6 +4,7 @@ import Init.Data.List.Basic import Mathlib.Tactic.Linarith import Base.IList import Base.Primitives.Scalar +import Base.Primitives.ScalarNotations import Base.Primitives.ArraySlice import Base.Arith import Base.Progress.Base diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 8fb067e1..9f809ead 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1,6 +1,5 @@ import Lean import Lean.Meta.Tactic.Simp -import Mathlib.Tactic.Linarith import Base.Primitives.Base import Base.Primitives.Core import Base.Diverge.Base @@ -9,6 +8,9 @@ import Base.Arith.Int namespace Primitives +-- Deactivate the warnings which appear when we use `#assert` +set_option linter.hashCommand false + ---------------------- -- MACHINE INTEGERS -- ---------------------- @@ -279,11 +281,11 @@ theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by have := Scalar.cMin_bound ty - linarith + omega theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by have := Scalar.cMax_bound ty - linarith + omega /-- The scalar type. @@ -310,40 +312,15 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := λ h => by - apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith + apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> omega -/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns. - This is particularly useful once we introduce notations like `#u32` (which - desugards to `Scalar.ofIntCore`) as it allows to write expressions like this: - Example: - ``` - match x with - | 0#u32 => ... - | 1#u32 => ... - | ... - ``` - -/ -@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int) +def Scalar.ofIntCore {ty : ScalarTy} (x : Int) (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty := { val := x, hmin := h.left, hmax := h.right } --- The definitions below are used later to introduce nice syntax for constants, --- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284 - -class InBounds (ty : ScalarTy) (x : Int) := - hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty - --- This trick to trigger reduction for decidable propositions comes from --- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807 -class Decide (p : Prop) [Decidable p] : Prop where - isTrue : p -instance : @Decide p (.isTrue h) := @Decide.mk p (_) h - -instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where - hInBounds := Decide.isTrue - -@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty := - Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds) +@[reducible] def Scalar.ofInt {ty : ScalarTy} (x : Int) + (hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty := by decide) : Scalar ty := + Scalar.ofIntCore x (Scalar.bound_suffices ty x hInBounds) @[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop := Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty @@ -351,10 +328,17 @@ instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty @[simp] abbrev Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) +/- Discussion: + This coercion can be slightly annoying at times, because if we write + something like `u = 3` (where `u` is, for instance, as `U32`), then instead of + coercing `u` to `Int`, Lean will lift `3` to `U32`). + For now we deactivate it. + -- TODO(raitobezarius): the inbounds constraint is a bit ugly as we can pretty trivially -- discharge the lhs on ≥ 0. instance {ty: ScalarTy} [InBounds ty (Int.ofNat n)]: OfNat (Scalar ty) (n: ℕ) where ofNat := Scalar.ofInt n +-/ theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int} (h: Scalar.check_bounds ty x) : @@ -363,7 +347,7 @@ theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int} have ⟨ hmin, hmax ⟩ := h have hbmin := Scalar.cMin_bound ty have hbmax := Scalar.cMax_bound ty - cases hmin <;> cases hmax <;> apply And.intro <;> linarith + cases hmin <;> cases hmax <;> apply And.intro <;> omega theorem Scalar.check_bounds_eq_in_bounds (ty : ScalarTy) (x : Int) : Scalar.check_bounds ty x ↔ Scalar.in_bounds ty x := by @@ -405,9 +389,8 @@ theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) : simp [tryMk, ofOption, tryMkOpt] split_ifs <;> simp -instance (ty: ScalarTy) : InBounds ty 0 where - hInBounds := by - induction ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide +@[simp] theorem zero_in_cbounds {ty : ScalarTy} : Scalar.cMin ty ≤ 0 ∧ 0 ≤ Scalar.cMax ty := by + cases ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) @@ -749,7 +732,6 @@ theorem Scalar.add_spec {ty} {x y : Scalar ty} (∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y) := by have h := @add_equiv ty x y split at h <;> simp_all - apply h theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} (hmax : ↑x + ↑y ≤ Scalar.max ty) : @@ -757,7 +739,7 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty} have hmin : Scalar.min ty ≤ ↑x + ↑y := by have hx := x.hmin have hy := y.hmin - cases ty <;> simp [min, ScalarTy.isSigned] at * <;> linarith + cases ty <;> simp [min, ScalarTy.isSigned] at * <;> omega apply add_spec <;> assumption /- Fine-grained theorems -/ @@ -844,7 +826,6 @@ theorem Scalar.sub_spec {ty} {x y : Scalar ty} ∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by have h := @sub_equiv ty x y split at h <;> simp_all - apply h theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned) {x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x - ↑y) : @@ -853,7 +834,7 @@ theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned) have hx := x.hmin have hxm := x.hmax have hy := y.hmin - cases ty <;> simp [min, max, ScalarTy.isSigned] at * <;> linarith + cases ty <;> simp [min, max, ScalarTy.isSigned] at * <;> omega intros apply sub_spec <;> assumption @@ -1049,11 +1030,11 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S have hx := x.hmin have hy := y.hmin simp [h] at hx hy - have hmin : 0 ≤ ↑x / ↑y := Int.ediv_nonneg hx hy + have hmin : 0 ≤ x.val / y.val := Int.ediv_nonneg hx hy have hmax : ↑x / ↑y ≤ Scalar.max ty := by have := Int.ediv_le_self ↑y hx have := x.hmax - linarith + omega have hs := @div_spec ty x y hnz simp [*] at hs apply hs @@ -1170,7 +1151,7 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S have h : (0 : Int) < y := by int_tac have h := Int.emod_lt_of_pos ↑x h have := y.hmax - linarith + omega have hs := @rem_spec ty x y hnz simp [*] at hs simp [*] @@ -1261,73 +1242,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128 -- ofInt -- TODO: typeclass? -@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize -@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8 -@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16 -@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32 -@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64 -@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128 -@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize -@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8 -@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16 -@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32 -@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64 -@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128 - -postfix:max "#isize" => Isize.ofInt -postfix:max "#i8" => I8.ofInt -postfix:max "#i16" => I16.ofInt -postfix:max "#i32" => I32.ofInt -postfix:max "#i64" => I64.ofInt -postfix:max "#i128" => I128.ofInt -postfix:max "#usize" => Usize.ofInt -postfix:max "#u8" => U8.ofInt -postfix:max "#u16" => U16.ofInt -postfix:max "#u32" => U32.ofInt -postfix:max "#u64" => U64.ofInt -postfix:max "#u128" => U128.ofInt - -/- Testing the notations -/ -example := 0#u32 -example := 1#u32 -example := 1#i32 -example := 0#isize -example := (-1)#isize -example (x : U32) : Bool := - match x with - | 0#u32 => true - | _ => false - -example (x : U32) : Bool := - match x with - | 1#u32 => true - | _ => false - -example (x : I32) : Bool := - match x with - | (-1)#i32 => true - | _ => false - --- Notation for pattern matching --- We make the precedence looser than the negation. -notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ - -example {ty} (x : Scalar ty) : ℤ := - match x with - | v#scalar => v - -example {ty} (x : Scalar ty) : Bool := - match x with - | 1#scalar => true - | _ => false - -example {ty} (x : Scalar ty) : Bool := - match x with - | -(1 : Int)#scalar => true - | _ => false - --- Testing the notations -example : Result Usize := 0#usize + 1#usize +abbrev Isize.ofInt := @Scalar.ofInt .Isize +abbrev I8.ofInt := @Scalar.ofInt .I8 +abbrev I16.ofInt := @Scalar.ofInt .I16 +abbrev I32.ofInt := @Scalar.ofInt .I32 +abbrev I64.ofInt := @Scalar.ofInt .I64 +abbrev I128.ofInt := @Scalar.ofInt .I128 +abbrev Usize.ofInt := @Scalar.ofInt .Usize +abbrev U8.ofInt := @Scalar.ofInt .U8 +abbrev U16.ofInt := @Scalar.ofInt .U16 +abbrev U32.ofInt := @Scalar.ofInt .U32 +abbrev U64.ofInt := @Scalar.ofInt .U64 +abbrev U128.ofInt := @Scalar.ofInt .U128 -- TODO: factor those lemmas out @[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by @@ -1457,18 +1383,18 @@ theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max ( -- Max theory -- TODO: do the min theory later on. -theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 ≤ x := by +theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 (by simp) ≤ x := by apply (Scalar.le_equiv _ _).2 convert x.hmin cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min] @[simp] theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty): - Max.max (Scalar.ofInt 0) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x) + Max.max (Scalar.ofInt 0 (by simp)) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x) @[simp] theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty): - Max.max x (Scalar.ofInt 0) = x := max_eq_left (Scalar.zero_le_unsigned s.out x) + Max.max x (Scalar.ofInt 0 (by simp)) = x := max_eq_left (Scalar.zero_le_unsigned s.out x) -- Leading zeros def core.num.Usize.leading_zeros (x : Usize) : U32 := sorry diff --git a/backends/lean/Base/Primitives/ScalarNotations.lean b/backends/lean/Base/Primitives/ScalarNotations.lean new file mode 100644 index 00000000..3bc86a9c --- /dev/null +++ b/backends/lean/Base/Primitives/ScalarNotations.lean @@ -0,0 +1,109 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Mathlib.Tactic.Linarith +import Base.Primitives.Scalar +import Base.Arith + +namespace Primitives + +open Lean Meta Elab Term + +/- Something strange happens here: when we solve the goal with scalar_tac, it + sometimes leaves meta-variables in place, which then causes issues when + type-checking functions. For instance, it happens when we have const-generics + in the translation: the constants contain meta-variables, which are then + used in the types, which cause issues later. An example is given below: + -/ +macro:max x:term:max noWs "#isize" : term => `(Isize.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#i8" : term => `(I8.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#i16" : term => `(I16.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#i32" : term => `(I32.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#i64" : term => `(I64.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#i128" : term => `(I128.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#usize" : term => `(Usize.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u8" : term => `(U8.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u16" : term => `(U16.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u32" : term => `(U32.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u64" : term => `(U64.ofInt $x (by first | decide | scalar_tac)) +macro:max x:term:max noWs "#u128" : term => `(U128.ofInt $x (by first | decide | scalar_tac)) + +-- Notation for pattern matching +-- We make the precedence looser than the negation. +notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ + +/- Testing the notations -/ +example := 0#u32 +example := 1#u32 +example := 1#i32 +example := 0#isize +example := (-1)#isize + +example := 1#u32 + +/- +-- This doesn't work anymore +example (x : U32) : Bool := + match x with + | 0#u32 => true + | _ => false + +example (x : U32) : Bool := + match x with + | 1#u32 => true + | _ => false + +example (x : I32) : Bool := + match x with + | (-1)#i32 => true + | _ => false +-/ + +example (x : U32) : Bool := + match x with + | 0#scalar => true + | _ => false + +example (x : U32) : Bool := + match x with + | 1#scalar => true + | _ => false + +example (x : I32) : Bool := + match x with + | (-1)#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : ℤ := + match x with + | v#scalar => v + +example {ty} (x : Scalar ty) : Bool := + match x with + | 1#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : Bool := + match x with + | -(1 : Int)#scalar => true + | _ => false + +-- Testing the notations +example : Result Usize := 0#usize + 1#usize + +-- More complex expressions +example (x y : Int) (h : 0 ≤ x + y ∧ x + y ≤ 1000) : U32 := (x + y)#u32 + +namespace Scalar.Examples + + abbrev Array (a : Type) (len : U32) := { l : List a // l.length = len.val } + + -- Checking the syntax + example : Array Int 0#u32 := ⟨ [], by simp ⟩ + + /- The example below fails if we don't use `decide` in the elaboration + of the scalar notation -/ + example (a : Array (Array Int 32#u32) 32#u32) := a + +end Scalar.Examples + +end Primitives diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean index 5ed7b606..0b010944 100644 --- a/backends/lean/Base/Primitives/Vec.lean +++ b/backends/lean/Base/Primitives/Vec.lean @@ -2,7 +2,6 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic -import Mathlib.Tactic.Linarith import Base.IList import Base.Primitives.Scalar import Base.Primitives.ArraySlice @@ -34,7 +33,7 @@ abbrev Vec.v {α : Type u} (v : Vec α) : List α := v.val example {a: Type u} (v : Vec a) : v.length ≤ Scalar.max ScalarTy.Usize := by scalar_tac -def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp; decide ⟩ +def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ instance (α : Type u) : Inhabited (Vec α) := by constructor @@ -59,7 +58,7 @@ def Vec.push (α : Type u) (v : Vec α) (x : α) : Result (Vec α) have h : nlen ≤ Usize.max := by simp [Usize.max] at * have hm := Usize.refined_max.property - cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try linarith + cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try omega ok ⟨ List.concat v.val x, by simp at *; assumption ⟩ else fail maximumSizeExceeded @@ -192,7 +191,7 @@ def alloc.slice.Slice.to_vec def core.slice.Slice.reverse (T : Type) (s : Slice T) : Slice T := ⟨ s.val.reverse, by sorry ⟩ -def alloc.vec.Vec.with_capacity (T : Type) (s : Usize) : alloc.vec.Vec T := Vec.new T +def alloc.vec.Vec.with_capacity (T : Type) (_ : Usize) : alloc.vec.Vec T := Vec.new T /- [alloc::vec::{(core::ops::deref::Deref for alloc::vec::Vec<T, A>)#9}::deref]: Source: '/rustc/d59363ad0b6391b7fc5bbb02c9ccf9300eef3753/library/alloc/src/vec/mod.rs', lines 2624:4-2624:27 diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 03c80a42..0e46737f 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -1,5 +1,4 @@ import Lean -import Std.Lean.HashSet import Base.Utils import Base.Primitives.Base import Base.Extensions @@ -111,7 +110,7 @@ section Methods -- Collect all the free variables in the arguments let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty -- Check if they intersect the fvars we introduced for the existentially quantified variables - let evarsSet : HashSet FVarId := HashSet.ofArray (evars.map (fun (x : Expr) => x.fvarId!)) + let evarsSet : HashSet FVarId := HashSet.empty.insertMany (evars.map (fun (x : Expr) => x.fvarId!)) let filtArgsFVars := allArgsFVars.toArray.filter (fun var => evarsSet.contains var) if filtArgsFVars.isEmpty then pure () else diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index f2a56e50..da601b73 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -58,17 +58,13 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) We also make sure that all the meta variables which appear in the function arguments have been instantiated -/ - let env ← getEnv let thTy ← do match th with | .Theorem thName => - let thDecl := env.constants.find! thName - -- We have to introduce fresh meta-variables for the universes already - let ul : List (Name × Level) ← - thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar)) - let ulMap : HashMap Name Level := HashMap.ofList ul - let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x) - pure thTy + -- Lookup the theorem and introduce fresh meta-variables for the universes + let th ← mkConstWithFreshMVarLevels thName + -- Retrieve the type + inferType th | .Local asmDecl => pure asmDecl.type trace[Progress] "Looked up theorem/assumption type: {thTy}" -- TODO: the tactic fails if we uncomment withNewMCtxDepth @@ -135,7 +131,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) Tactic.focus do let _ ← tryTac - (simpAt true [] + (simpAt true {} #[] [] [``Primitives.bind_tc_ok, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] [hEq.fvarId!] (.targets #[] true)) -- It may happen that at this point the goal is already solved (though this is rare) @@ -144,7 +140,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) else trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}" -- TODO: remove this (some types get unfolded too much: we "fold" them back) - let _ ← tryTac (simpAt true [] scalar_eqs [] .wildcard_dep) + let _ ← tryTac (simpAt true {} #[] [] scalar_eqs [] .wildcard_dep) trace[Progress] "goal after folding back scalar types: {← getMainGoal}" -- Clear the equality, unless the user requests not to do so let mgoal ← do @@ -350,11 +346,8 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do -- Not a local declaration: should be a theorem trace[Progress] "With arg: theorem" addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none - let cs ← resolveGlobalConstWithInfos id - match cs with - | [] => throwError "Could not find theorem {id}" - | id :: _ => - pure (some (.Theorem id)) + let some (.const name _) ← Term.resolveId? id | throwError m!"Could not find theorem: {id}" + pure (some (.Theorem name)) else pure none let ids := let args := asArgs.getArgs diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 7ae5a832..5954f048 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -7,7 +7,6 @@ Mathlib tactics: - rcases: https://leanprover-community.github.io/mathlib_docs/tactics.html#rcases - split_ifs: https://leanprover-community.github.io/mathlib_docs/tactics.html#split_ifs - norm_num: https://leanprover-community.github.io/mathlib_docs/tactics.html#norm_num -- should we use linarith or omega? - hint: https://leanprover-community.github.io/mathlib_docs/tactics.html#hint - classical: https://leanprover-community.github.io/mathlib_docs/tactics.html#classical -/ @@ -133,8 +132,9 @@ open Lean.Elab.Command liftTermElabM do let id := stx[1] addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none - let cs ← resolveGlobalConstWithInfos id - explore_decl cs[0]! + let some cs ← Term.resolveId? id | throwError m!"Unknown id: {id}" + let name := cs.constName! + explore_decl name private def test1 : Nat := 0 private def test2 (x : Nat) : Nat := x @@ -664,7 +664,7 @@ example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by Something very annoying is that there is no function which allows to initialize a simp context without doing an elaboration - as a consequence we write our own here. -/ -def mkSimpCtx (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : +def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : Tactic.TacticM Simp.Context := do -- Initialize either with the builtin simp theorems or with all the simp theorems let simpThms ← @@ -693,7 +693,7 @@ def mkSimpCtx (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) ( throwError "Not a proposition: {thmName}" ) simpThms let congrTheorems ← getSimpCongrTheorems - pure { simpTheorems := #[simpThms], congrTheorems } + pure { config, simpTheorems := #[simpThms], congrTheorems } inductive Location where /-- Apply the tactic everywhere. Same as `Tactic.Location.wildcard` -/ @@ -704,56 +704,111 @@ inductive Location where /-- Same as Tactic.Location -/ | targets (hypotheses : Array Syntax) (type : Bool) --- Comes from Tactic.simpLocation -def customSimpLocation (ctx : Simp.Context) (discharge? : Option Simp.Discharge := none) - (loc : Location) : TacticM Simp.UsedSimps := do +-- Adapted from Tactic.simpLocation +def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (discharge? : Option Simp.Discharge := none) + (loc : Location) : TacticM Simp.Stats := do match loc with | Location.targets hyps simplifyTarget => - withMainContext do - let fvarIds ← Lean.Elab.Tactic.getFVarIds hyps - go fvarIds simplifyTarget + -- Simply call the regular simpLocation + simpLocation ctx simprocs discharge? (Tactic.Location.targets hyps simplifyTarget) | Location.wildcard => - withMainContext do - go (← (← getMainGoal).getNondepPropHyps) (simplifyTarget := true) + -- Simply call the regular simpLocation + simpLocation ctx simprocs discharge? Tactic.Location.wildcard | Location.wildcard_dep => + -- Custom behavior withMainContext do - let ctx ← Lean.MonadLCtx.getLCtx - let decls ← ctx.getDecls + -- Lookup *all* the declarations + let lctx ← Lean.MonadLCtx.getLCtx + let decls ← lctx.getDecls let tgts := (decls.map (fun d => d.fvarId)).toArray - go tgts (simplifyTarget := true) -where - go (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) : TacticM Simp.UsedSimps := do - let mvarId ← getMainGoal - let (result?, usedSimps) ← simpGoal mvarId ctx (simplifyTarget := simplifyTarget) (discharge? := discharge?) (fvarIdsToSimp := fvarIdsToSimp) - match result? with - | none => replaceMainGoal [] - | some (_, mvarId) => replaceMainGoal [mvarId] - return usedSimps + -- Call the regular simpLocation.go + simpLocation.go ctx simprocs discharge? tgts (simplifyTarget := true) /- Call the simp tactic. -/ -def simpAt (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) - (loc : Location) : +def simpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : Simp.SimprocsArray) + (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Location) : Tactic.TacticM Unit := do -- Initialize the simp context - let ctx ← mkSimpCtx simpOnly declsToUnfold thms hypsToUse + let ctx ← mkSimpCtx simpOnly config declsToUnfold thms hypsToUse -- Apply the simplifier - let _ ← customSimpLocation ctx (discharge? := .none) loc + let _ ← customSimpLocation ctx simprocs (discharge? := .none) loc /- Call the dsimp tactic. -/ -def dsimpAt (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) - (loc : Tactic.Location) : +def dsimpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : Simp.SimprocsArray) + (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Tactic.Location) : Tactic.TacticM Unit := do -- Initialize the simp context - let ctx ← mkSimpCtx simpOnly declsToUnfold thms hypsToUse + let ctx ← mkSimpCtx simpOnly config declsToUnfold thms hypsToUse -- Apply the simplifier - dsimpLocation ctx loc + dsimpLocation ctx simprocs loc -- Call the simpAll tactic -def simpAll (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : +def simpAll (config : Simp.Config) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : Tactic.TacticM Unit := do -- Initialize the simp context - let ctx ← mkSimpCtx false declsToUnfold thms hypsToUse + let ctx ← mkSimpCtx false config declsToUnfold thms hypsToUse -- Apply the simplifier let _ ← Lean.Meta.simpAll (← getMainGoal) ctx +/- Adapted from Elab.Tactic.Rewrite -/ +def rewriteTarget (eqThm : Expr) (symm : Bool) (config : Rewrite.Config := {}) : TacticM Unit := do + Term.withSynthesize <| withMainContext do + let r ← (← getMainGoal).rewrite (← getMainTarget) eqThm symm (config := config) + let mvarId' ← (← getMainGoal).replaceTargetEq r.eNew r.eqProof + replaceMainGoal (mvarId' :: r.mvarIds) + +/- Adapted from Elab.Tactic.Rewrite -/ +def rewriteLocalDecl (eqThm : Expr) (symm : Bool) (fvarId : FVarId) (config : Rewrite.Config := {}) : + TacticM Unit := withMainContext do + -- Note: we cannot execute `replaceLocalDecl` inside `Term.withSynthesize`. + -- See issues #2711 and #2727. + let rwResult ← Term.withSynthesize <| withMainContext do + let localDecl ← fvarId.getDecl + (← getMainGoal).rewrite localDecl.type eqThm symm (config := config) + let replaceResult ← (← getMainGoal).replaceLocalDecl fvarId rwResult.eNew rwResult.eqProof + replaceMainGoal (replaceResult.mvarId :: rwResult.mvarIds) + +/- Adapted from Elab.Tactic.Rewrite -/ +def rewriteWithThms + (thms : List (Bool × Expr)) + (rewrite : (symm : Bool) → (thm : Expr) → TacticM Unit) + : TacticM Unit := do + let rec go thms := + match thms with + | [] => throwError "Failed to rewrite with any theorem" + | (symm, eqThm)::thms => + rewrite symm eqThm <|> go thms + go thms + +/- Adapted from Elab.Tactic.Rewrite -/ +def evalRewriteSeqAux (cfg : Rewrite.Config) (thms : List (Bool × Expr)) (loc : Tactic.Location) : TacticM Unit := + rewriteWithThms thms fun symm term => do + withLocation loc + (rewriteLocalDecl term symm · cfg) + (rewriteTarget term symm cfg) + (throwTacticEx `rewrite · "did not find instance of the pattern in the current goal") + +/-- `rpt`: if `true`, repeatedly rewrite -/ +def rewriteAt (cfg : Rewrite.Config) (rpt : Bool) + (thms : List (Bool × Name)) (loc : Tactic.Location) : TacticM Unit := do + -- Lookup the theorems + let lookupThm (x : Bool × Name) : TacticM (List (Bool × Expr)) := do + let thName := x.snd + let lookupOne (thName : Name) : TacticM (Bool × Expr) := do + -- Lookup the theorem and introduce fresh meta-variables for the universes + let th ← mkConstWithFreshMVarLevels thName + pure (x.fst, th) + match ← getEqnsFor? thName (nonRec := true) with + | some eqThms => do + eqThms.data.mapM lookupOne + | none => do + pure [← lookupOne thName] + let thms ← List.mapM lookupThm thms + let thms := thms.flatten + -- Rewrite + if rpt then + Utils.repeatTac (evalRewriteSeqAux cfg thms loc) + else + evalRewriteSeqAux cfg thms loc + end Utils |