diff options
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Arith/Base.lean | 8 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Int.lean | 48 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Scalar.lean | 12 | ||||
-rw-r--r-- | backends/lean/Base/IList/IList.lean | 23 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/ArraySlice.lean | 2 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Scalar.lean | 11 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Vec.lean | 13 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 6 | ||||
-rw-r--r-- | backends/lean/Base/Utils.lean | 42 |
9 files changed, 125 insertions, 40 deletions
diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean index fb6b12e5..320b4b53 100644 --- a/backends/lean/Base/Arith/Base.lean +++ b/backends/lean/Base/Arith/Base.lean @@ -52,10 +52,6 @@ theorem int_pos_ind (p : Int → Prop) : rename_i m cases m <;> simp_all --- We sometimes need this to make sure no natural numbers appear in the goals --- TODO: there is probably something more general to do -theorem nat_zero_eq_int_zero : (0 : Nat) = (0 : Int) := by simp - -- This is mostly used in termination proofs theorem to_int_to_nat_lt (x y : ℤ) (h0 : 0 ≤ x) (h1 : x < y) : ↑(x.toNat) < y := by @@ -68,4 +64,8 @@ theorem to_int_sub_to_nat_lt (x y : ℤ) (x' : ℕ) have : 0 ≤ x := by omega simp [Int.toNat_sub_of_le, *] +-- WARNING: do not use this with `simp` as it might loop. The left-hand side indeed reduces to the +-- righ-hand side, meaning the rewriting can be applied to `n` itself. +theorem ofNat_instOfNatNat_eq (n : Nat) : @OfNat.ofNat Nat n (instOfNatNat n) = n := by rfl + end Arith diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean index 068d6f2f..b1927cfd 100644 --- a/backends/lean/Base/Arith/Int.lean +++ b/backends/lean/Base/Arith/Int.lean @@ -3,6 +3,7 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic +import Mathlib.Tactic.Ring.RingNF import Base.Utils import Base.Arith.Base @@ -111,7 +112,7 @@ def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.Tact let hs := HashSet.empty -- Explore the declarations let decls ← ctx.getDecls - let hs ← decls.foldlM (fun hs d => do + let hs ← decls.foldlM (fun hs d => do -- Collect instances over all subexpressions in the context. -- Note that we explore the *type* of the local declarations: if we have -- for instance `h : A ∧ B` in the context, the expression itself is simply @@ -154,7 +155,7 @@ def lookupHasIntPred (e : Expr) : MetaM (Option Expr) := lookupProp "lookupHasIntPred" ``HasIntPred e (fun term => pure #[term]) (fun _ => pure #[]) -- Collect the instances of `HasIntPred` for the subexpressions in the context -def collectHasIntPredInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do +def collectHasIntPredInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do collectInstancesFromMainCtx lookupHasIntPred -- Return an instance of `PropHasImp` for `e` if it has some @@ -201,7 +202,7 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) -- Add a declaration let nval ← Utils.addDeclTac name e type (asLet := false) -- Simplify to unfold the declaration to unfold (i.e., the projector) - Utils.simpAt true {} #[] [declToUnfold] [] [] (Location.targets #[mkIdent name] false) + Utils.simpAt true {} [] [declToUnfold] [] [] (Location.targets #[mkIdent name] false) -- Return the new value pure nval @@ -214,7 +215,7 @@ def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do elab "intro_has_int_prop_instances" : tactic => do let _ ← introHasIntPropInstances -def introHasIntPredInstances : Tactic.TacticM (Array Expr) := do +def introHasIntPredInstances : Tactic.TacticM (Array Expr) := do trace[Arith] "Introducing the HasIntPred instances" introInstances ``HasIntPred.concl lookupHasIntPred @@ -230,6 +231,8 @@ def introPropHasImpInstances : Tactic.TacticM (Array Expr) := do elab "intro_prop_has_imp_instances" : tactic => do let _ ← introPropHasImpInstances +def intTacSimpRocs : List Name := [``Int.reduceNegSucc, ``Int.reduceNeg] + /- Boosting a bit the `omega` tac. -/ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do @@ -244,7 +247,33 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U extraPreprocess -- Reduce all the terms in the goal - note that the extra preprocessing step -- might have proven the goal, hence the `Tactic.allGoals` - Tactic.allGoals do tryTac (dsimpAt false {} #[] [] [] [] Tactic.Location.wildcard) + let dsimp := + Tactic.allGoals do tryTac ( + -- We set `simpOnly` at false on purpose + dsimpAt false {} intTacSimpRocs + -- Declarations to unfold + [] + -- Theorems + [] + [] Tactic.Location.wildcard) + dsimp + -- More preprocessing: apply norm_cast to the whole context + Tactic.allGoals (Utils.tryTac (Utils.normCastAtAll)) + -- norm_cast does weird things with negative numbers so we reapply simp + dsimp + -- We also need this, in case the goal is: ¬ False + Tactic.allGoals do tryTac ( + Utils.simpAt true {} + -- Simprocs + intTacSimpRocs + -- Unfoldings + [] + -- Simp lemmas + [``not_false_eq_true] + -- Hypotheses + [] + (.targets #[] true) + ) elab "int_tac_preprocess" : tactic => intTacPreprocess (do pure ()) @@ -260,8 +289,6 @@ def intTac (tacName : String) (splitGoalConjs : Bool) (extraPreprocess : Tactic -- Preprocess - wondering if we should do this before or after splitting -- the goal. I think before leads to a smaller proof term? Tactic.allGoals (intTacPreprocess extraPreprocess) - -- More preprocessing - Tactic.allGoals (Utils.tryTac (Utils.simpAt true {} #[] [] [``nat_zero_eq_int_zero] [] .wildcard)) -- Split the conjunctions in the goal if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget) -- Call omega @@ -298,4 +325,11 @@ example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by int_tac +-- Intermediate cast through natural numbers +example (a : Prop) (x : Int) (h0: (0 : Nat) < x) (h1: x < 0) : a := by + int_tac + +example (x : Int) (h : x ≤ -3) : x ≤ -2 := by + int_tac + end Arith diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index ecc5acaf..31110b95 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -19,7 +19,7 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do -- Reveal the concrete bounds, simplify calls to [ofInt] Utils.simpAt true {} -- Simprocs - #[] + intTacSimpRocs -- Unfoldings [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, @@ -59,11 +59,11 @@ instance (ty : ScalarTy) : HasIntProp (Scalar ty) where -- prop_ty is inferred prop := λ x => And.intro x.hmin x.hmax -example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by +example (x _y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by intro_has_int_prop_instances simp [*] -example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by +example (x _y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by scalar_tac -- Checking that we explore the goal *and* projectors correctly @@ -92,4 +92,10 @@ example (x : U32) (h0 : ¬ x = U32.ofInt 0) : 0 < x.val := by example {u: U64} (h1: (u : Int) < 2): (u : Int) = 0 ∨ (u : Int) = 1 := by scalar_tac +example (x : I32) : -100000000000 < x.val := by + scalar_tac + +example : (Usize.ofInt 2).val ≠ 0 := by + scalar_tac + end Arith diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean index 96843f55..ab71daed 100644 --- a/backends/lean/Base/IList/IList.lean +++ b/backends/lean/Base/IList/IList.lean @@ -43,6 +43,9 @@ def index [Inhabited α] (ls : List α) (i : Int) : α := @[simp] theorem index_zero_cons [Inhabited α] : index ((x :: tl) : List α) 0 = x := by simp [index] @[simp] theorem index_nzero_cons [Inhabited α] (hne : i ≠ 0) : index ((x :: tl) : List α) i = index tl (i - 1) := by simp [*, index] +@[simp] theorem index_zero_lt_cons [Inhabited α] (hne : 0 < i) : index ((x :: tl) : List α) i = index tl (i - 1) := by + have : i ≠ 0 := by scalar_tac + simp [*, index] theorem indexOpt_bounds (ls : List α) (i : Int) : ls.indexOpt i = none ↔ i < 0 ∨ ls.len ≤ i := @@ -453,16 +456,18 @@ theorem index_update_eq simp at * apply index_update_eq <;> scalar_tac -theorem update_map_eq {α : Type u} {β : Type v} (ls : List α) (i : Int) (x : α) (f : α → β) : +@[simp] +theorem map_update_eq {α : Type u} {β : Type v} (ls : List α) (i : Int) (x : α) (f : α → β) : (ls.update i x).map f = (ls.map f).update i (f x) := match ls with | [] => by simp | hd :: tl => if h : i = 0 then by simp [*] else - have hi := update_map_eq tl (i - 1) x f + have hi := map_update_eq tl (i - 1) x f by simp [*] +@[simp] theorem len_flatten_update_eq {α : Type u} (ls : List (List α)) (i : Int) (x : List α) (h0 : 0 ≤ i) (h1 : i < ls.len) : (ls.update i x).flatten.len = ls.flatten.len + x.len - (ls.index i).len := @@ -476,6 +481,20 @@ theorem len_flatten_update_eq {α : Type u} (ls : List (List α)) (i : Int) (x : simp [*] int_tac +theorem len_index_le_len_flatten (ls : List (List α)) : + forall (i : Int), (ls.index i).len ≤ ls.flatten.len := by + induction ls <;> intro i <;> simp_all + . rw [List.index] + simp [default] + . rename ∀ _, _ => ih + if hi: i = 0 then + simp_all + int_tac + else + replace ih := ih (i - 1) + simp_all + int_tac + @[simp] theorem index_map_eq {α : Type u} {β : Type v} [Inhabited α] [Inhabited β] (ls : List α) (i : Int) (f : α → β) (h0 : 0 ≤ i) (h1 : i < ls.len) : diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean index be460987..899871af 100644 --- a/backends/lean/Base/Primitives/ArraySlice.lean +++ b/backends/lean/Base/Primitives/ArraySlice.lean @@ -129,7 +129,7 @@ example {a: Type u} (v : Slice a) : v.length ≤ Scalar.max ScalarTy.Usize := by def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ -- TODO: very annoying that the α is an explicit parameter -def Slice.len (α : Type u) (v : Slice α) : Usize := +abbrev Slice.len (α : Type u) (v : Slice α) : Usize := Usize.ofIntCore v.val.len (by constructor <;> scalar_tac) @[simp] diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 9f809ead..31038e0d 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1301,22 +1301,25 @@ instance {ty} : LT (Scalar ty) where instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val --- Not marking this one with @[simp] on purpose +-- Not marking this one with @[simp] on purpose: if we have `x = y` somewhere in the context, +-- we may want to use it to substitute `y` with `x` somewhere. +-- TODO: mark it as simp anyway? theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) : x = y ↔ (↑x : Int) = ↑y := by cases x; cases y; simp_all -- This is sometimes useful when rewriting the goal with the local assumptions +-- TODO: this doesn't get triggered @[simp] theorem Scalar.eq_imp {ty : ScalarTy} (x y : Scalar ty) : (↑x : Int) = ↑y → x = y := (eq_equiv x y).mpr -theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) : +@[simp] theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) : x < y ↔ (↑x : Int) < ↑y := by simp [LT.lt] @[simp] theorem Scalar.lt_imp {ty : ScalarTy} (x y : Scalar ty) : (↑x : Int) < (↑y) → x < y := (lt_equiv x y).mpr -theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) : +@[simp] theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) : x ≤ y ↔ (↑x : Int) ≤ ↑y := by simp [LE.le] @[simp] theorem Scalar.le_imp {ty : ScalarTy} (x y : Scalar ty) : @@ -1377,8 +1380,6 @@ theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max ( -- TODO: there should be a shorter way to prove this. rw [max_def, max_def] split_ifs <;> simp_all - refine' absurd _ (lt_irrefl a) - exact lt_of_le_of_lt (by assumption) ((Scalar.lt_equiv _ _).2 (by assumption)) -- Max theory -- TODO: do the min theory later on. diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean index 0b010944..e584777a 100644 --- a/backends/lean/Base/Primitives/Vec.lean +++ b/backends/lean/Base/Primitives/Vec.lean @@ -33,14 +33,15 @@ abbrev Vec.v {α : Type u} (v : Vec α) : List α := v.val example {a: Type u} (v : Vec a) : v.length ≤ Scalar.max ScalarTy.Usize := by scalar_tac -def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ +abbrev Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ instance (α : Type u) : Inhabited (Vec α) := by constructor apply Vec.new -- TODO: very annoying that the α is an explicit parameter -def Vec.len (α : Type u) (v : Vec α) : Usize := +@[simp] +abbrev Vec.len (α : Type u) (v : Vec α) : Usize := Usize.ofIntCore v.val.len (by constructor <;> scalar_tac) @[simp] @@ -63,6 +64,14 @@ def Vec.push (α : Type u) (v : Vec α) (x : α) : Result (Vec α) else fail maximumSizeExceeded +@[pspec] +theorem Vec.push_spec {α : Type u} (v : Vec α) (x : α) (h : v.val.len < Usize.max) : + ∃ v1, v.push α x = ok v1 ∧ + v1.val = v.val ++ [x] := by + simp [push] + split <;> simp_all [List.len_eq_length] + scalar_tac + -- This shouldn't be used def Vec.insert_fwd (α : Type u) (v: Vec α) (i: Usize) (_: α) : Result Unit := if i.val < v.length then diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index da601b73..35cc8399 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -131,7 +131,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) Tactic.focus do let _ ← tryTac - (simpAt true {} #[] [] + (simpAt true {} [] [] [``Primitives.bind_tc_ok, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] [hEq.fvarId!] (.targets #[] true)) -- It may happen that at this point the goal is already solved (though this is rare) @@ -140,7 +140,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) else trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}" -- TODO: remove this (some types get unfolded too much: we "fold" them back) - let _ ← tryTac (simpAt true {} #[] [] scalar_eqs [] .wildcard_dep) + let _ ← tryTac (simpAt true {} [] [] scalar_eqs [] .wildcard_dep) trace[Progress] "goal after folding back scalar types: {← getMainGoal}" -- Clear the equality, unless the user requests not to do so let mgoal ← do @@ -410,7 +410,7 @@ namespace Test -- This spec theorem is suboptimal, but it is good to check that it works progress with Scalar.add_spec as ⟨ z, h1 .. ⟩ simp [*, h1] - + example {x y : U32} (hmax : x.val + y.val ≤ U32.max) : ∃ z, x + y = ok z ∧ z.val = x.val + y.val := by diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 5954f048..b9de2fd1 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -664,21 +664,26 @@ example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by Something very annoying is that there is no function which allows to initialize a simp context without doing an elaboration - as a consequence we write our own here. -/ -def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : - Tactic.TacticM Simp.Context := do +def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (kind : SimpKind) + (simprocs : List Name) (declsToUnfold : List Name) + (thms : List Name) (hypsToUse : List FVarId) : + Tactic.TacticM (Simp.Context × Simp.SimprocsArray) := do -- Initialize either with the builtin simp theorems or with all the simp theorems let simpThms ← if simpOnly then Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems) else getSimpTheorems -- Add the equational theorem for the declarations to unfold + let addDeclToUnfold (thms : SimpTheorems) (decl : Name) : Tactic.TacticM SimpTheorems := + if kind == .dsimp then pure (thms.addDeclToUnfoldCore decl) + else thms.addDeclToUnfold decl let simpThms ← - declsToUnfold.foldlM (fun thms decl => thms.addDeclToUnfold decl) simpThms + declsToUnfold.foldlM addDeclToUnfold simpThms -- Add the hypotheses and the rewriting theorems let simpThms ← hypsToUse.foldlM (fun thms fvarId => - -- post: TODO: don't know what that is + -- post: TODO: don't know what that is. It seems to be true by default. -- inv: invert the equality - thms.add (.fvar fvarId) #[] (mkFVar fvarId) (post := false) (inv := false) + thms.add (.fvar fvarId) #[] (mkFVar fvarId) (post := true) (inv := false) -- thms.eraseCore (.fvar fvar) ) simpThms -- Add the rewriting theorems to use @@ -693,7 +698,10 @@ def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (declsToUnfold : List Nam throwError "Not a proposition: {thmName}" ) simpThms let congrTheorems ← getSimpCongrTheorems - pure { config, simpTheorems := #[simpThms], congrTheorems } + let defaultSimprocs ← if simpOnly then pure {} else Simp.getSimprocs + let simprocs ← simprocs.foldlM (fun simprocs name => simprocs.add name true) defaultSimprocs + let ctx := { config, simpTheorems := #[simpThms], congrTheorems } + pure (ctx, #[simprocs]) inductive Location where /-- Apply the tactic everywhere. Same as `Tactic.Location.wildcard` -/ @@ -725,30 +733,30 @@ def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (dis simpLocation.go ctx simprocs discharge? tgts (simplifyTarget := true) /- Call the simp tactic. -/ -def simpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : Simp.SimprocsArray) +def simpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : List Name) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Location) : Tactic.TacticM Unit := do -- Initialize the simp context - let ctx ← mkSimpCtx simpOnly config declsToUnfold thms hypsToUse + let (ctx, simprocs) ← mkSimpCtx simpOnly config .simp simprocs declsToUnfold thms hypsToUse -- Apply the simplifier let _ ← customSimpLocation ctx simprocs (discharge? := .none) loc /- Call the dsimp tactic. -/ -def dsimpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : Simp.SimprocsArray) +def dsimpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : List Name) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Tactic.Location) : Tactic.TacticM Unit := do -- Initialize the simp context - let ctx ← mkSimpCtx simpOnly config declsToUnfold thms hypsToUse + let (ctx, simprocs) ← mkSimpCtx simpOnly config .dsimp simprocs declsToUnfold thms hypsToUse -- Apply the simplifier dsimpLocation ctx simprocs loc -- Call the simpAll tactic -def simpAll (config : Simp.Config) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : +def simpAll (config : Simp.Config) (simprocs : List Name) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) : Tactic.TacticM Unit := do -- Initialize the simp context - let ctx ← mkSimpCtx false config declsToUnfold thms hypsToUse + let (ctx, simprocs) ← mkSimpCtx false config .simpAll simprocs declsToUnfold thms hypsToUse -- Apply the simplifier - let _ ← Lean.Meta.simpAll (← getMainGoal) ctx + let _ ← Lean.Meta.simpAll (← getMainGoal) ctx simprocs /- Adapted from Elab.Tactic.Rewrite -/ def rewriteTarget (eqThm : Expr) (symm : Bool) (config : Rewrite.Config := {}) : TacticM Unit := do @@ -811,4 +819,12 @@ def rewriteAt (cfg : Rewrite.Config) (rpt : Bool) else evalRewriteSeqAux cfg thms loc +/-- Apply norm_cast to the whole context -/ +def normCastAtAll : TacticM Unit := do + withMainContext do + let ctx ← Lean.MonadLCtx.getLCtx + let decls ← ctx.getDecls + NormCast.normCastTarget + decls.forM (fun d => NormCast.normCastHyp d.fvarId) + end Utils |