diff options
author | Son Ho | 2023-07-13 14:00:11 +0200 |
---|---|---|
committer | Son Ho | 2023-07-13 14:00:11 +0200 |
commit | 2dbd529b499c2bb9dae754df0e449cad577ac7a0 (patch) | |
tree | 72c1cfbc8d29443fc2d70fd3f0ebfbd315954483 /backends/lean | |
parent | 6cc0279045d40231f1cce83f0edb7aada1e59d92 (diff) |
Add IList.lean
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Arith/Arith.lean | 136 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Base.lean | 40 | ||||
-rw-r--r-- | backends/lean/Base/IList.lean | 127 |
3 files changed, 239 insertions, 64 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean index ab4fd182..2ff030fe 100644 --- a/backends/lean/Base/Arith/Arith.lean +++ b/backends/lean/Base/Arith/Arith.lean @@ -15,25 +15,6 @@ namespace Arith open Primitives Utils --- TODO: move? -theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by - cases h: x <;> simp_all - . rename_i n; - cases n <;> simp_all - . apply Int.negSucc_lt_zero - --- TODO: move? -theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by - have hne : x - y ≠ 0 := by - simp - intro h - have: x = y := by linarith - simp_all - have h := ne_zero_is_lt_or_gt hne - match h with - | .inl _ => left; linarith - | .inr _ => right; linarith - -- TODO: move instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val @@ -48,17 +29,21 @@ instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val -/ def Scalar.toInt {ty : ScalarTy} (x : Scalar ty) : Int := x.val --- Remark: I tried a version of the shape `HasProp {a : Type} (x : a)` +-- Remark: I tried a version of the shape `HasScalarProp {a : Type} (x : a)` -- but the lookup didn't work -class HasProp (a : Sort u) where +class HasScalarProp (a : Sort u) where + prop_ty : a → Prop + prop : ∀ x:a, prop_ty x + +class HasIntProp (a : Sort u) where prop_ty : a → Prop prop : ∀ x:a, prop_ty x -instance (ty : ScalarTy) : HasProp (Scalar ty) where +instance (ty : ScalarTy) : HasScalarProp (Scalar ty) where -- prop_ty is inferred prop := λ x => And.intro x.hmin x.hmax -instance (a : Type) : HasProp (Vec a) where +instance (a : Type) : HasScalarProp (Vec a) where prop_ty := λ v => v.val.length ≤ Scalar.max ScalarTy.Usize prop := λ ⟨ _, l ⟩ => l @@ -117,37 +102,49 @@ def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.Tact let decls ← ctx.getDecls decls.foldlM (fun hs d => collectInstances k hs d.toExpr) hs --- Return an instance of `HasProp` for `e` if it has some -def lookupHasProp (e : Expr) : MetaM (Option Expr) := do - trace[Arith] "lookupHasProp" +-- Helper +def lookupProp (fName : String) (className : Name) (e : Expr) : MetaM (Option Expr) := do + trace[Arith] fName -- TODO: do we need Lean.observing? -- This actually eliminates the error messages Lean.observing? do - trace[Arith] "lookupHasProp: observing" + trace[Arith] m!"{fName}: observing" let ty ← Lean.Meta.inferType e - let hasProp ← mkAppM ``HasProp #[ty] + let hasProp ← mkAppM className #[ty] let hasPropInst ← trySynthInstance hasProp match hasPropInst with | LOption.some i => - trace[Arith] "Found HasProp instance" + trace[Arith] "Found HasScalarProp instance" let i_prop ← mkProjection i (Name.mkSimple "prop") some (← mkAppM' i_prop #[e]) | _ => none --- Collect the instances of `HasProp` for the subexpressions in the context -def collectHasPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do - collectInstancesFromMainCtx lookupHasProp +-- Return an instance of `HasIntProp` for `e` if it has some +def lookupHasIntProp (e : Expr) : MetaM (Option Expr) := + lookupProp "lookupHasScalarProp" ``HasIntProp e + +-- Return an instance of `HasScalarProp` for `e` if it has some +def lookupHasScalarProp (e : Expr) : MetaM (Option Expr) := + lookupProp "lookupHasScalarProp" ``HasScalarProp e + +-- Collect the instances of `HasIntProp` for the subexpressions in the context +def collectHasIntPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupHasIntProp + +-- Collect the instances of `HasScalarProp` for the subexpressions in the context +def collectHasScalarPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupHasScalarProp elab "display_has_prop_instances" : tactic => do - trace[Arith] "Displaying the HasProp instances" - let hs ← collectHasPropInstancesFromMainCtx + trace[Arith] "Displaying the HasScalarProp instances" + let hs ← collectHasScalarPropInstancesFromMainCtx hs.forM fun e => do - trace[Arith] "+ HasProp instance: {e}" + trace[Arith] "+ HasScalarProp instance: {e}" example (x : U32) : True := by - let i : HasProp U32 := inferInstance - have p := @HasProp.prop _ i x - simp only [HasProp.prop_ty] at p + let i : HasScalarProp U32 := inferInstance + have p := @HasScalarProp.prop _ i x + simp only [HasScalarProp.prop_ty] at p display_has_prop_instances simp @@ -196,14 +193,18 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) -- Return the new value pure nval -def introHasPropInstances : Tactic.TacticM (Array Expr) := do - trace[Arith] "Introducing the HasProp instances" - introInstances ``HasProp.prop_ty lookupHasProp +def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do + trace[Arith] "Introducing the HasIntProp instances" + introInstances ``HasIntProp.prop_ty lookupHasIntProp + +def introHasScalarPropInstances : Tactic.TacticM (Array Expr) := do + trace[Arith] "Introducing the HasScalarProp instances" + introInstances ``HasScalarProp.prop_ty lookupHasScalarProp --- Lookup the instances of `HasProp for all the sub-expressions in the context, +-- Lookup the instances of `HasScalarProp for all the sub-expressions in the context, -- and introduce the corresponding assumptions elab "intro_has_prop_instances" : tactic => do - let _ ← introHasPropInstances + let _ ← introHasScalarPropInstances example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by intro_has_prop_instances @@ -246,6 +247,7 @@ def intTacPreprocess : Tactic.TacticM Unit := do let k := splitOnAsms asms Utils.splitDisjTac asm k k -- Introduce + let _ ← introHasIntPropInstances let asms ← introInstances ``PropHasImp.concl lookupPropHasImp -- Split splitOnAsms asms.toList @@ -289,29 +291,35 @@ example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by int_tac +def scalarTacPreprocess (tac : Tactic.TacticM Unit) : Tactic.TacticM Unit := do + Tactic.withMainContext do + -- Introduce the scalar bounds + let _ ← introHasScalarPropInstances + Tactic.allGoals do + -- Inroduce the bounds for the isize/usize types + let add (e : Expr) : Tactic.TacticM Unit := do + let ty ← inferType e + let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) + -- Reveal the concrete bounds - TODO: not too sure about that. + -- Maybe we should reveal the "concrete" bounds (after normalization) + Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, + ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, + ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, + ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, + ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max + ] [] [] .wildcard + -- Finish the proof + tac + +elab "scalar_tac_preprocess" : tactic => + scalarTacPreprocess intTacPreprocess + -- A tactic to solve linear arithmetic goals in the presence of scalars def scalarTac : Tactic.TacticM Unit := do - Tactic.withMainContext do - -- Introduce the scalar bounds - let _ ← introHasPropInstances - Tactic.allGoals do - -- Inroduce the bounds for the isize/usize types - let add (e : Expr) : Tactic.TacticM Unit := do - let ty ← inferType e - let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false) - add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) - add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) - -- Reveal the concrete bounds - TODO: not too sure about that. - -- Maybe we should reveal the "concrete" bounds (after normalization) - Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, - ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, - ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, - ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, - ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max - ] [] [] .wildcard - -- Apply the integer tactic - intTac + scalarTacPreprocess intTac elab "scalar_tac" : tactic => scalarTac diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean index ddd2dc24..a6e59b74 100644 --- a/backends/lean/Base/Arith/Base.lean +++ b/backends/lean/Base/Arith/Base.lean @@ -1,4 +1,6 @@ import Lean +import Std.Data.Int.Lemmas +import Mathlib.Tactic.Linarith namespace Arith @@ -7,4 +9,42 @@ open Lean Elab Term Meta -- We can't define and use trace classes in the same file initialize registerTraceClass `Arith +-- TODO: move? +theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by + cases h: x <;> simp_all + . rename_i n; + cases n <;> simp_all + . apply Int.negSucc_lt_zero + +-- TODO: move? +theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by + have hne : x - y ≠ 0 := by + simp + intro h + have: x = y := by linarith + simp_all + have h := ne_zero_is_lt_or_gt hne + match h with + | .inl _ => left; linarith + | .inr _ => right; linarith + + +/- Induction over positive integers -/ +-- TODO: move +theorem int_pos_ind (p : Int → Prop) : + (zero:p 0) → (pos:∀ i, 0 ≤ i → p i → p (i + 1)) → ∀ i, 0 ≤ i → p i := by + intro h0 hr i hpos +-- have heq : Int.toNat i = i := by +-- cases i <;> simp_all + have ⟨ n, heq ⟩ : {n:Nat // n = i } := ⟨ Int.toNat i, by cases i <;> simp_all ⟩ + revert i + induction n + . intro i hpos heq + cases i <;> simp_all + . rename_i n hi + intro i hpos heq + cases i <;> simp_all + rename_i m + cases m <;> simp_all + end Arith diff --git a/backends/lean/Base/IList.lean b/backends/lean/Base/IList.lean new file mode 100644 index 00000000..7e764d63 --- /dev/null +++ b/backends/lean/Base/IList.lean @@ -0,0 +1,127 @@ +/- Complementary list functions and lemmas which operate on integers rather + than natural numbers. -/ + +import Std.Data.Int.Lemmas +import Mathlib.Tactic.Linarith +import Base.Arith + +namespace List + +#check List.get +def len (ls : List α) : Int := + match ls with + | [] => 0 + | _ :: tl => 1 + len tl + +-- Remark: if i < 0, then the result is none +def optIndex (i : Int) (ls : List α) : Option α := + match ls with + | [] => none + | hd :: tl => if i = 0 then some hd else optIndex (i - 1) tl + +-- Remark: if i < 0, then the result is the defaul element +def index [Inhabited α] (i : Int) (ls : List α) : α := + match ls with + | [] => Inhabited.default + | x :: tl => + if i = 0 then x else index (i - 1) tl + +-- Remark: the list is unchanged if the index is not in bounds (in particular +-- if it is < 0) +def update (ls : List α) (i : Int) (y : α) : List α := + match ls with + | [] => [] + | x :: tl => if i = 0 then y :: tl else x :: update tl (i - 1) y + +-- Remark: the whole list is dropped if the index is not in bounds (in particular +-- if it is < 0) +def idrop (i : Int) (ls : List α) : List α := + match ls with + | [] => [] + | x :: tl => if i = 0 then x :: tl else idrop (i - 1) tl + +@[simp] theorem len_nil : len ([] : List α) = 0 := by simp [len] +@[simp] theorem len_cons : len ((x :: tl) : List α) = 1 + len tl := by simp [len] + +@[simp] theorem index_zero_cons [Inhabited α] : index 0 ((x :: tl) : List α) = x := by simp [index] +@[simp] theorem index_nzero_cons [Inhabited α] (hne : i ≠ 0) : index i ((x :: tl) : List α) = index (i - 1) tl := by simp [*, index] + +@[simp] theorem update_nil : update ([] : List α) i y = [] := by simp [update] +@[simp] theorem update_zero_cons : update ((x :: tl) : List α) 0 y = y :: tl := by simp [update] +@[simp] theorem update_nzero_cons (hne : i ≠ 0) : update ((x :: tl) : List α) i y = x :: update tl (i - 1) y := by simp [*, update] + +@[simp] theorem idrop_nil : idrop i ([] : List α) = [] := by simp [idrop] +@[simp] theorem idrop_zero : idrop 0 (ls : List α) = ls := by cases ls <;> simp [idrop] +@[simp] theorem idrop_nzero_cons (hne : i ≠ 0) : idrop i ((x :: tl) : List α) = idrop (i - 1) tl := by simp [*, idrop] + +theorem len_eq_length (ls : List α) : ls.len = ls.length := by + induction ls + . rfl + . simp [*, Int.ofNat_succ, Int.add_comm] + +theorem len_pos : 0 ≤ (ls : List α).len := by + induction ls <;> simp [*] + linarith + +instance (a : Type u) : Arith.HasIntProp (List a) where + prop_ty := λ ls => 0 ≤ ls.len + prop := λ ls => ls.len_pos + +@[simp] theorem len_append (l1 l2 : List α) : (l1 ++ l2).len = l1.len + l2.len := by + -- Remark: simp loops here because of the following rewritings: + -- @Nat.cast_add: ↑(List.length l1 + List.length l2) ==> ↑(List.length l1) + ↑(List.length l2) + -- Int.ofNat_add_ofNat: ↑(List.length l1) + ↑(List.length l2) ==> ↑(List.length l1 + List.length l2) + -- TODO: post an issue? + simp only [len_eq_length] + simp only [length_append] + simp only [Int.ofNat_add] + +theorem left_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.length = l1'.length) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + revert l1' + induction l1 + . intro l1'; cases l1' <;> simp [*] + . intro l1'; cases l1' <;> simp_all; tauto + +theorem right_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.length = l2'.length) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + have := left_length_eq_append_eq l1 l2 l1' l2' + constructor <;> intro heq2 <;> + have : l1.length + l2.length = l1'.length + l2'.length := by + have : (l1 ++ l2).length = (l1' ++ l2').length := by simp [*] + simp only [length_append] at this + apply this + . simp [heq] at this + tauto + . tauto + +theorem left_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.len = l1'.len) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + simp [len_eq_length] at heq + apply left_length_eq_append_eq + assumption + +theorem right_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.len = l2'.len) : + l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by + simp [len_eq_length] at heq + apply right_length_eq_append_eq + assumption + +open Arith in +theorem idrop_eq_nil_of_le (hineq : ls.len ≤ i) : idrop i ls = [] := by + revert i + induction ls <;> simp [*] + rename_i hd tl hi + intro i hineq + if heq: i = 0 then + simp [*] at * + have := tl.len_pos + linarith + else + simp at hineq + have : 0 < i := by int_tac + simp [*] + apply hi + linarith + +end List |