summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
authorSon Ho2023-07-25 14:08:44 +0200
committerSon Ho2023-07-25 14:08:44 +0200
commit876137dff361620d8ade1a4ee94fa9274df0bdc6 (patch)
treed25cb5bf68b53b2f67e67186317f666407d09a04 /backends/lean
parentc652e97f7ab13164150331b4aa3f2e7ef11d24b9 (diff)
Improve int_tac and scalar_tac
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Arith/Int.lean63
-rw-r--r--backends/lean/Base/Arith/Scalar.lean6
-rw-r--r--backends/lean/Base/IList/IList.lean12
-rw-r--r--backends/lean/Base/Primitives/Vec.lean25
-rw-r--r--backends/lean/Base/Progress/Progress.lean13
5 files changed, 87 insertions, 32 deletions
diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean
index fa957293..3415866e 100644
--- a/backends/lean/Base/Arith/Int.lean
+++ b/backends/lean/Base/Arith/Int.lean
@@ -24,12 +24,29 @@ class PropHasImp (x : Prop) where
concl : Prop
prop : x → concl
+instance (p : Int → Prop) : HasIntProp (Subtype p) where
+ prop_ty := λ x => p x
+ prop := λ x => x.property
+
-- This also works for `x ≠ y` because this expression reduces to `¬ x = y`
-- and `Ne` is marked as `reducible`
instance (x y : Int) : PropHasImp (¬ x = y) where
concl := x < y ∨ x > y
prop := λ (h:x ≠ y) => ne_is_lt_or_gt h
+-- Check if a proposition is a linear integer proposition.
+-- We notably use this to check the goals.
+class IsLinearIntProp (x : Prop) where
+
+instance (x y : Int) : IsLinearIntProp (x < y) where
+instance (x y : Int) : IsLinearIntProp (x > y) where
+instance (x y : Int) : IsLinearIntProp (x ≤ y) where
+instance (x y : Int) : IsLinearIntProp (x ≥ y) where
+instance (x y : Int) : IsLinearIntProp (x ≥ y) where
+/- It seems we don't need to do any special preprocessing when the *goal*
+ has the following shape - I guess `linarith` automatically calls `intro` -/
+instance (x y : Int) : IsLinearIntProp (¬ x = y) where
+
open Lean Lean.Elab Lean.Meta
-- Explore a term by decomposing the applications (we explore the applied
@@ -189,14 +206,27 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U
elab "int_tac_preprocess" : tactic =>
intTacPreprocess (do pure ())
-def intTac (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
+-- Check if the goal is a linear arithmetic goal
+def goalIsLinearInt : Tactic.TacticM Bool := do
+ Tactic.withMainContext do
+ let gty ← Tactic.getMainTarget
+ match ← trySynthInstance (← mkAppM ``IsLinearIntProp #[gty]) with
+ | .some _ => pure true
+ | _ => pure false
+
+def intTac (splitGoalConjs : Bool) (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
Tactic.withMainContext do
Tactic.focus do
+ let g ← Tactic.getMainGoal
+ trace[Arith] "Original goal: {g}"
+ -- Introduce all the universally quantified variables (includes the assumptions)
+ let (_, g) ← g.intros
+ Tactic.setGoals [g]
-- Preprocess - wondering if we should do this before or after splitting
-- the goal. I think before leads to a smaller proof term?
Tactic.allGoals (intTacPreprocess extraPreprocess)
-- Split the conjunctions in the goal
- Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget)
+ if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget)
-- Call linarith
let linarith := do
let cfg : Linarith.LinarithConfig := {
@@ -204,10 +234,25 @@ def intTac (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
splitNe := false
}
Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg
- Tactic.allGoals linarith
-
-elab "int_tac" : tactic =>
- intTac (do pure ())
+ Tactic.allGoals do
+ -- We check if the goal is a linear arithmetic goal: if yes, we directly
+ -- call linarith, otherwise we first apply exfalso (we do this because
+ -- linarith is too general and sometimes fails to do this correctly).
+ if ← goalIsLinearInt then do
+ trace[Arith] "linarith goal: {← Tactic.getMainGoal}"
+ linarith
+ else do
+ let g ← Tactic.getMainGoal
+ let gs ← g.apply (Expr.const ``False.elim [.zero])
+ let goals ← Tactic.getGoals
+ Tactic.setGoals (gs ++ goals)
+ Tactic.allGoals do
+ trace[Arith] "linarith goal: {← Tactic.getMainGoal}"
+ linarith
+
+elab "int_tac" args:(" split_goal"?): tactic =>
+ let split := args.raw.getArgs.size > 0
+ intTac split (do pure ())
example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by
int_tac_preprocess
@@ -219,10 +264,14 @@ example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by
-- Checking that things append correctly when there are several disjunctions
example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by
- int_tac
+ int_tac split_goal
-- Checking that things append correctly when there are several disjunctions
example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by
+ int_tac split_goal
+
+-- Checking that we can prove exfalso
+example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by
int_tac
end Arith
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index f8903ecf..a56ea08b 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -28,11 +28,11 @@ elab "scalar_tac_preprocess" : tactic =>
intTacPreprocess scalarTacExtraPreprocess
-- A tactic to solve linear arithmetic goals in the presence of scalars
-def scalarTac : Tactic.TacticM Unit := do
- intTac scalarTacExtraPreprocess
+def scalarTac (splitGoalConjs : Bool) : Tactic.TacticM Unit := do
+ intTac splitGoalConjs scalarTacExtraPreprocess
elab "scalar_tac" : tactic =>
- scalarTac
+ scalarTac false
instance (ty : ScalarTy) : HasIntProp (Scalar ty) where
-- prop_ty is inferred
diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean
index 1773e593..2443b1a6 100644
--- a/backends/lean/Base/IList/IList.lean
+++ b/backends/lean/Base/IList/IList.lean
@@ -46,21 +46,18 @@ theorem indexOpt_bounds (ls : List α) (i : Int) :
ls.indexOpt i = none ↔ i < 0 ∨ ls.len ≤ i :=
match ls with
| [] =>
- have : ¬ (i < 0) → 0 ≤ i := by intro; linarith -- TODO: simplify (we could boost int_tac)
+ have : ¬ (i < 0) → 0 ≤ i := by int_tac
by simp; tauto
| _ :: tl =>
have := indexOpt_bounds tl (i - 1)
if h: i = 0 then
by
simp [*];
- -- TODO: int_tac/scalar_tac should also explore the goal!
- have := tl.len_pos
- linarith
+ int_tac
else by
simp [*]
constructor <;> intros <;>
- -- TODO: tactic to split all disjunctions
- rename_i hor <;> cases hor <;>
+ casesm* _ ∨ _ <;> -- splits all the disjunctions
first | left; int_tac | right; int_tac
theorem indexOpt_eq_index [Inhabited α] (ls : List α) (i : Int) :
@@ -126,7 +123,6 @@ theorem length_update (ls : List α) (i : Int) (x : α) : (ls.update i x).length
theorem len_update (ls : List α) (i : Int) (x : α) : (ls.update i x).len = ls.len := by
simp [len_eq_length]
-
theorem left_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.length = l1'.length) :
l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by
revert l1'
@@ -203,7 +199,7 @@ theorem index_eq
(l.update i x).index i = x
:=
fun _ _ => match l with
- | [] => by simp at *; exfalso; scalar_tac -- TODO: exfalso needed. Son FIXME
+ | [] => by simp at *; scalar_tac
| hd :: tl =>
if h: i = 0 then
by
diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean
index be3a0e5b..35092c29 100644
--- a/backends/lean/Base/Primitives/Vec.lean
+++ b/backends/lean/Base/Primitives/Vec.lean
@@ -16,20 +16,19 @@ open Result Error
-- VECTORS --
-------------
-def Vec (α : Type u) := { l : List α // List.length l ≤ Usize.max }
+def Vec (α : Type u) := { l : List α // l.length ≤ Usize.max }
-- TODO: do we really need it? It should be with Subtype by default
-instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val
+instance Vec.cast (a : Type u): Coe (Vec a) (List a) where coe := λ v => v.val
-instance (a : Type) : Arith.HasIntProp (Vec a) where
- prop_ty := λ v => v.val.length ≤ Scalar.max ScalarTy.Usize
- prop := λ ⟨ _, l ⟩ => l
+instance (a : Type u) : Arith.HasIntProp (Vec a) where
+ prop_ty := λ v => v.val.len ≤ Scalar.max ScalarTy.Usize
+ prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *]
-example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by
- intro_has_int_prop_instances
- simp_all [Scalar.max, Scalar.min]
+@[simp]
+abbrev Vec.length {α : Type u} (v : Vec α) : Int := v.val.len
-example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by
+example {a: Type u} (v : Vec a) : v.length ≤ Scalar.max ScalarTy.Usize := by
scalar_tac
def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩
@@ -38,9 +37,6 @@ def Vec.len (α : Type u) (v : Vec α) : Usize :=
let ⟨ v, l ⟩ := v
Usize.ofIntCore (List.length v) (by simp [Scalar.min, Usize.min]) l
-@[simp]
-abbrev Vec.length {α : Type u} (v : Vec α) : Int := v.val.len
-
-- This shouldn't be used
def Vec.push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := ()
@@ -115,11 +111,14 @@ theorem Vec.index_mut_spec {α : Type u} [Inhabited α] (v: Vec α) (i: Usize) :
have h := List.indexOpt_eq_index v.val i.val (by scalar_tac) (by simp [*])
simp only [*]
+instance {α : Type u} (p : Vec α → Prop) : Arith.HasIntProp (Subtype p) where
+ prop_ty := λ x => p x
+ prop := λ x => x.property
+
def Vec.index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α) : Result (Vec α) :=
match v.val.indexOpt i.val with
| none => fail .arrayOutOfBounds
| some _ =>
- -- TODO: int_tac: introduce the refinements in the context?
.ret ⟨ v.val.update i.val x, by have := v.property; simp [*] ⟩
@[pspec]
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index c0ddc63d..a281f1d2 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -307,7 +307,18 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
let args := (args.get! 2).getArgs
(args.get! 3).getArgs.size > 0
trace[Progress] "Split post: {splitPost}"
- progressAsmsOrLookupTheorem keep withArg ids splitPost (firstTac [assumptionTac, Arith.scalarTac])
+ /- For scalarTac we have a fast track: if the goal is not a linear
+ arithmetic goal, we skip (note that otherwise, scalarTac would try
+ to prove a contradiction) -/
+ let scalarTac : TacticM Unit := do
+ if ← Arith.goalIsLinearInt then
+ -- Also: we don't try to split the goal if it is a conjunction
+ -- (it shouldn't be)
+ Arith.scalarTac false
+ else
+ throwError "Not a linear arithmetic goal"
+ progressAsmsOrLookupTheorem keep withArg ids splitPost (
+ firstTac [assumptionTac, scalarTac])
elab "progress" args:progressArgs : tactic =>
evalProgress args