summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2023-07-13 14:00:11 +0200
committerSon Ho2023-07-13 14:00:11 +0200
commit2dbd529b499c2bb9dae754df0e449cad577ac7a0 (patch)
tree72c1cfbc8d29443fc2d70fd3f0ebfbd315954483 /backends/lean
parent6cc0279045d40231f1cce83f0edb7aada1e59d92 (diff)
Add IList.lean
Diffstat (limited to 'backends/lean')
-rw-r--r--backends/lean/Base/Arith/Arith.lean136
-rw-r--r--backends/lean/Base/Arith/Base.lean40
-rw-r--r--backends/lean/Base/IList.lean127
3 files changed, 239 insertions, 64 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean
index ab4fd182..2ff030fe 100644
--- a/backends/lean/Base/Arith/Arith.lean
+++ b/backends/lean/Base/Arith/Arith.lean
@@ -15,25 +15,6 @@ namespace Arith
open Primitives Utils
--- TODO: move?
-theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by
- cases h: x <;> simp_all
- . rename_i n;
- cases n <;> simp_all
- . apply Int.negSucc_lt_zero
-
--- TODO: move?
-theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by
- have hne : x - y ≠ 0 := by
- simp
- intro h
- have: x = y := by linarith
- simp_all
- have h := ne_zero_is_lt_or_gt hne
- match h with
- | .inl _ => left; linarith
- | .inr _ => right; linarith
-
-- TODO: move
instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val
@@ -48,17 +29,21 @@ instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val
-/
def Scalar.toInt {ty : ScalarTy} (x : Scalar ty) : Int := x.val
--- Remark: I tried a version of the shape `HasProp {a : Type} (x : a)`
+-- Remark: I tried a version of the shape `HasScalarProp {a : Type} (x : a)`
-- but the lookup didn't work
-class HasProp (a : Sort u) where
+class HasScalarProp (a : Sort u) where
+ prop_ty : a → Prop
+ prop : ∀ x:a, prop_ty x
+
+class HasIntProp (a : Sort u) where
prop_ty : a → Prop
prop : ∀ x:a, prop_ty x
-instance (ty : ScalarTy) : HasProp (Scalar ty) where
+instance (ty : ScalarTy) : HasScalarProp (Scalar ty) where
-- prop_ty is inferred
prop := λ x => And.intro x.hmin x.hmax
-instance (a : Type) : HasProp (Vec a) where
+instance (a : Type) : HasScalarProp (Vec a) where
prop_ty := λ v => v.val.length ≤ Scalar.max ScalarTy.Usize
prop := λ ⟨ _, l ⟩ => l
@@ -117,37 +102,49 @@ def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.Tact
let decls ← ctx.getDecls
decls.foldlM (fun hs d => collectInstances k hs d.toExpr) hs
--- Return an instance of `HasProp` for `e` if it has some
-def lookupHasProp (e : Expr) : MetaM (Option Expr) := do
- trace[Arith] "lookupHasProp"
+-- Helper
+def lookupProp (fName : String) (className : Name) (e : Expr) : MetaM (Option Expr) := do
+ trace[Arith] fName
-- TODO: do we need Lean.observing?
-- This actually eliminates the error messages
Lean.observing? do
- trace[Arith] "lookupHasProp: observing"
+ trace[Arith] m!"{fName}: observing"
let ty ← Lean.Meta.inferType e
- let hasProp ← mkAppM ``HasProp #[ty]
+ let hasProp ← mkAppM className #[ty]
let hasPropInst ← trySynthInstance hasProp
match hasPropInst with
| LOption.some i =>
- trace[Arith] "Found HasProp instance"
+ trace[Arith] "Found HasScalarProp instance"
let i_prop ← mkProjection i (Name.mkSimple "prop")
some (← mkAppM' i_prop #[e])
| _ => none
--- Collect the instances of `HasProp` for the subexpressions in the context
-def collectHasPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
- collectInstancesFromMainCtx lookupHasProp
+-- Return an instance of `HasIntProp` for `e` if it has some
+def lookupHasIntProp (e : Expr) : MetaM (Option Expr) :=
+ lookupProp "lookupHasScalarProp" ``HasIntProp e
+
+-- Return an instance of `HasScalarProp` for `e` if it has some
+def lookupHasScalarProp (e : Expr) : MetaM (Option Expr) :=
+ lookupProp "lookupHasScalarProp" ``HasScalarProp e
+
+-- Collect the instances of `HasIntProp` for the subexpressions in the context
+def collectHasIntPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
+ collectInstancesFromMainCtx lookupHasIntProp
+
+-- Collect the instances of `HasScalarProp` for the subexpressions in the context
+def collectHasScalarPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
+ collectInstancesFromMainCtx lookupHasScalarProp
elab "display_has_prop_instances" : tactic => do
- trace[Arith] "Displaying the HasProp instances"
- let hs ← collectHasPropInstancesFromMainCtx
+ trace[Arith] "Displaying the HasScalarProp instances"
+ let hs ← collectHasScalarPropInstancesFromMainCtx
hs.forM fun e => do
- trace[Arith] "+ HasProp instance: {e}"
+ trace[Arith] "+ HasScalarProp instance: {e}"
example (x : U32) : True := by
- let i : HasProp U32 := inferInstance
- have p := @HasProp.prop _ i x
- simp only [HasProp.prop_ty] at p
+ let i : HasScalarProp U32 := inferInstance
+ have p := @HasScalarProp.prop _ i x
+ simp only [HasScalarProp.prop_ty] at p
display_has_prop_instances
simp
@@ -196,14 +193,18 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr))
-- Return the new value
pure nval
-def introHasPropInstances : Tactic.TacticM (Array Expr) := do
- trace[Arith] "Introducing the HasProp instances"
- introInstances ``HasProp.prop_ty lookupHasProp
+def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do
+ trace[Arith] "Introducing the HasIntProp instances"
+ introInstances ``HasIntProp.prop_ty lookupHasIntProp
+
+def introHasScalarPropInstances : Tactic.TacticM (Array Expr) := do
+ trace[Arith] "Introducing the HasScalarProp instances"
+ introInstances ``HasScalarProp.prop_ty lookupHasScalarProp
--- Lookup the instances of `HasProp for all the sub-expressions in the context,
+-- Lookup the instances of `HasScalarProp for all the sub-expressions in the context,
-- and introduce the corresponding assumptions
elab "intro_has_prop_instances" : tactic => do
- let _ ← introHasPropInstances
+ let _ ← introHasScalarPropInstances
example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
intro_has_prop_instances
@@ -246,6 +247,7 @@ def intTacPreprocess : Tactic.TacticM Unit := do
let k := splitOnAsms asms
Utils.splitDisjTac asm k k
-- Introduce
+ let _ ← introHasIntPropInstances
let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
-- Split
splitOnAsms asms.toList
@@ -289,29 +291,35 @@ example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) :
example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by
int_tac
+def scalarTacPreprocess (tac : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ -- Introduce the scalar bounds
+ let _ ← introHasScalarPropInstances
+ Tactic.allGoals do
+ -- Inroduce the bounds for the isize/usize types
+ let add (e : Expr) : Tactic.TacticM Unit := do
+ let ty ← inferType e
+ let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false)
+ add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
+ -- Reveal the concrete bounds - TODO: not too sure about that.
+ -- Maybe we should reveal the "concrete" bounds (after normalization)
+ Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
+ ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
+ ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
+ ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
+ ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max
+ ] [] [] .wildcard
+ -- Finish the proof
+ tac
+
+elab "scalar_tac_preprocess" : tactic =>
+ scalarTacPreprocess intTacPreprocess
+
-- A tactic to solve linear arithmetic goals in the presence of scalars
def scalarTac : Tactic.TacticM Unit := do
- Tactic.withMainContext do
- -- Introduce the scalar bounds
- let _ ← introHasPropInstances
- Tactic.allGoals do
- -- Inroduce the bounds for the isize/usize types
- let add (e : Expr) : Tactic.TacticM Unit := do
- let ty ← inferType e
- let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false)
- add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
- add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
- add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
- -- Reveal the concrete bounds - TODO: not too sure about that.
- -- Maybe we should reveal the "concrete" bounds (after normalization)
- Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
- ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
- ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
- ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
- ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max
- ] [] [] .wildcard
- -- Apply the integer tactic
- intTac
+ scalarTacPreprocess intTac
elab "scalar_tac" : tactic =>
scalarTac
diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean
index ddd2dc24..a6e59b74 100644
--- a/backends/lean/Base/Arith/Base.lean
+++ b/backends/lean/Base/Arith/Base.lean
@@ -1,4 +1,6 @@
import Lean
+import Std.Data.Int.Lemmas
+import Mathlib.Tactic.Linarith
namespace Arith
@@ -7,4 +9,42 @@ open Lean Elab Term Meta
-- We can't define and use trace classes in the same file
initialize registerTraceClass `Arith
+-- TODO: move?
+theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by
+ cases h: x <;> simp_all
+ . rename_i n;
+ cases n <;> simp_all
+ . apply Int.negSucc_lt_zero
+
+-- TODO: move?
+theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by
+ have hne : x - y ≠ 0 := by
+ simp
+ intro h
+ have: x = y := by linarith
+ simp_all
+ have h := ne_zero_is_lt_or_gt hne
+ match h with
+ | .inl _ => left; linarith
+ | .inr _ => right; linarith
+
+
+/- Induction over positive integers -/
+-- TODO: move
+theorem int_pos_ind (p : Int → Prop) :
+ (zero:p 0) → (pos:∀ i, 0 ≤ i → p i → p (i + 1)) → ∀ i, 0 ≤ i → p i := by
+ intro h0 hr i hpos
+-- have heq : Int.toNat i = i := by
+-- cases i <;> simp_all
+ have ⟨ n, heq ⟩ : {n:Nat // n = i } := ⟨ Int.toNat i, by cases i <;> simp_all ⟩
+ revert i
+ induction n
+ . intro i hpos heq
+ cases i <;> simp_all
+ . rename_i n hi
+ intro i hpos heq
+ cases i <;> simp_all
+ rename_i m
+ cases m <;> simp_all
+
end Arith
diff --git a/backends/lean/Base/IList.lean b/backends/lean/Base/IList.lean
new file mode 100644
index 00000000..7e764d63
--- /dev/null
+++ b/backends/lean/Base/IList.lean
@@ -0,0 +1,127 @@
+/- Complementary list functions and lemmas which operate on integers rather
+ than natural numbers. -/
+
+import Std.Data.Int.Lemmas
+import Mathlib.Tactic.Linarith
+import Base.Arith
+
+namespace List
+
+#check List.get
+def len (ls : List α) : Int :=
+ match ls with
+ | [] => 0
+ | _ :: tl => 1 + len tl
+
+-- Remark: if i < 0, then the result is none
+def optIndex (i : Int) (ls : List α) : Option α :=
+ match ls with
+ | [] => none
+ | hd :: tl => if i = 0 then some hd else optIndex (i - 1) tl
+
+-- Remark: if i < 0, then the result is the defaul element
+def index [Inhabited α] (i : Int) (ls : List α) : α :=
+ match ls with
+ | [] => Inhabited.default
+ | x :: tl =>
+ if i = 0 then x else index (i - 1) tl
+
+-- Remark: the list is unchanged if the index is not in bounds (in particular
+-- if it is < 0)
+def update (ls : List α) (i : Int) (y : α) : List α :=
+ match ls with
+ | [] => []
+ | x :: tl => if i = 0 then y :: tl else x :: update tl (i - 1) y
+
+-- Remark: the whole list is dropped if the index is not in bounds (in particular
+-- if it is < 0)
+def idrop (i : Int) (ls : List α) : List α :=
+ match ls with
+ | [] => []
+ | x :: tl => if i = 0 then x :: tl else idrop (i - 1) tl
+
+@[simp] theorem len_nil : len ([] : List α) = 0 := by simp [len]
+@[simp] theorem len_cons : len ((x :: tl) : List α) = 1 + len tl := by simp [len]
+
+@[simp] theorem index_zero_cons [Inhabited α] : index 0 ((x :: tl) : List α) = x := by simp [index]
+@[simp] theorem index_nzero_cons [Inhabited α] (hne : i ≠ 0) : index i ((x :: tl) : List α) = index (i - 1) tl := by simp [*, index]
+
+@[simp] theorem update_nil : update ([] : List α) i y = [] := by simp [update]
+@[simp] theorem update_zero_cons : update ((x :: tl) : List α) 0 y = y :: tl := by simp [update]
+@[simp] theorem update_nzero_cons (hne : i ≠ 0) : update ((x :: tl) : List α) i y = x :: update tl (i - 1) y := by simp [*, update]
+
+@[simp] theorem idrop_nil : idrop i ([] : List α) = [] := by simp [idrop]
+@[simp] theorem idrop_zero : idrop 0 (ls : List α) = ls := by cases ls <;> simp [idrop]
+@[simp] theorem idrop_nzero_cons (hne : i ≠ 0) : idrop i ((x :: tl) : List α) = idrop (i - 1) tl := by simp [*, idrop]
+
+theorem len_eq_length (ls : List α) : ls.len = ls.length := by
+ induction ls
+ . rfl
+ . simp [*, Int.ofNat_succ, Int.add_comm]
+
+theorem len_pos : 0 ≤ (ls : List α).len := by
+ induction ls <;> simp [*]
+ linarith
+
+instance (a : Type u) : Arith.HasIntProp (List a) where
+ prop_ty := λ ls => 0 ≤ ls.len
+ prop := λ ls => ls.len_pos
+
+@[simp] theorem len_append (l1 l2 : List α) : (l1 ++ l2).len = l1.len + l2.len := by
+ -- Remark: simp loops here because of the following rewritings:
+ -- @Nat.cast_add: ↑(List.length l1 + List.length l2) ==> ↑(List.length l1) + ↑(List.length l2)
+ -- Int.ofNat_add_ofNat: ↑(List.length l1) + ↑(List.length l2) ==> ↑(List.length l1 + List.length l2)
+ -- TODO: post an issue?
+ simp only [len_eq_length]
+ simp only [length_append]
+ simp only [Int.ofNat_add]
+
+theorem left_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.length = l1'.length) :
+ l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by
+ revert l1'
+ induction l1
+ . intro l1'; cases l1' <;> simp [*]
+ . intro l1'; cases l1' <;> simp_all; tauto
+
+theorem right_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.length = l2'.length) :
+ l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by
+ have := left_length_eq_append_eq l1 l2 l1' l2'
+ constructor <;> intro heq2 <;>
+ have : l1.length + l2.length = l1'.length + l2'.length := by
+ have : (l1 ++ l2).length = (l1' ++ l2').length := by simp [*]
+ simp only [length_append] at this
+ apply this
+ . simp [heq] at this
+ tauto
+ . tauto
+
+theorem left_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.len = l1'.len) :
+ l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by
+ simp [len_eq_length] at heq
+ apply left_length_eq_append_eq
+ assumption
+
+theorem right_len_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l2.len = l2'.len) :
+ l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by
+ simp [len_eq_length] at heq
+ apply right_length_eq_append_eq
+ assumption
+
+open Arith in
+theorem idrop_eq_nil_of_le (hineq : ls.len ≤ i) : idrop i ls = [] := by
+ revert i
+ induction ls <;> simp [*]
+ rename_i hd tl hi
+ intro i hineq
+ if heq: i = 0 then
+ simp [*] at *
+ have := tl.len_pos
+ linarith
+ else
+ simp at hineq
+ have : 0 < i := by int_tac
+ simp [*]
+ apply hi
+ linarith
+
+end List