From 876137dff361620d8ade1a4ee94fa9274df0bdc6 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 25 Jul 2023 14:08:44 +0200 Subject: Improve int_tac and scalar_tac --- backends/lean/Base/Arith/Int.lean | 63 +++++++++++++++++++++++++++---- backends/lean/Base/Arith/Scalar.lean | 6 +-- backends/lean/Base/IList/IList.lean | 12 ++---- backends/lean/Base/Primitives/Vec.lean | 25 ++++++------ backends/lean/Base/Progress/Progress.lean | 13 ++++++- 5 files changed, 87 insertions(+), 32 deletions(-) (limited to 'backends/lean/Base') diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean index fa957293..3415866e 100644 --- a/backends/lean/Base/Arith/Int.lean +++ b/backends/lean/Base/Arith/Int.lean @@ -24,12 +24,29 @@ class PropHasImp (x : Prop) where concl : Prop prop : x → concl +instance (p : Int → Prop) : HasIntProp (Subtype p) where + prop_ty := λ x => p x + prop := λ x => x.property + -- This also works for `x ≠ y` because this expression reduces to `¬ x = y` -- and `Ne` is marked as `reducible` instance (x y : Int) : PropHasImp (¬ x = y) where concl := x < y ∨ x > y prop := λ (h:x ≠ y) => ne_is_lt_or_gt h +-- Check if a proposition is a linear integer proposition. +-- We notably use this to check the goals. +class IsLinearIntProp (x : Prop) where + +instance (x y : Int) : IsLinearIntProp (x < y) where +instance (x y : Int) : IsLinearIntProp (x > y) where +instance (x y : Int) : IsLinearIntProp (x ≤ y) where +instance (x y : Int) : IsLinearIntProp (x ≥ y) where +instance (x y : Int) : IsLinearIntProp (x ≥ y) where +/- It seems we don't need to do any special preprocessing when the *goal* + has the following shape - I guess `linarith` automatically calls `intro` -/ +instance (x y : Int) : IsLinearIntProp (¬ x = y) where + open Lean Lean.Elab Lean.Meta -- Explore a term by decomposing the applications (we explore the applied @@ -189,14 +206,27 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U elab "int_tac_preprocess" : tactic => intTacPreprocess (do pure ()) -def intTac (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do +-- Check if the goal is a linear arithmetic goal +def goalIsLinearInt : Tactic.TacticM Bool := do + Tactic.withMainContext do + let gty ← Tactic.getMainTarget + match ← trySynthInstance (← mkAppM ``IsLinearIntProp #[gty]) with + | .some _ => pure true + | _ => pure false + +def intTac (splitGoalConjs : Bool) (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do Tactic.withMainContext do Tactic.focus do + let g ← Tactic.getMainGoal + trace[Arith] "Original goal: {g}" + -- Introduce all the universally quantified variables (includes the assumptions) + let (_, g) ← g.intros + Tactic.setGoals [g] -- Preprocess - wondering if we should do this before or after splitting -- the goal. I think before leads to a smaller proof term? Tactic.allGoals (intTacPreprocess extraPreprocess) -- Split the conjunctions in the goal - Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget) + if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget) -- Call linarith let linarith := do let cfg : Linarith.LinarithConfig := { @@ -204,10 +234,25 @@ def intTac (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do splitNe := false } Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg - Tactic.allGoals linarith - -elab "int_tac" : tactic => - intTac (do pure ()) + Tactic.allGoals do + -- We check if the goal is a linear arithmetic goal: if yes, we directly + -- call linarith, otherwise we first apply exfalso (we do this because + -- linarith is too general and sometimes fails to do this correctly). + if ← goalIsLinearInt then do + trace[Arith] "linarith goal: {← Tactic.getMainGoal}" + linarith + else do + let g ← Tactic.getMainGoal + let gs ← g.apply (Expr.const ``False.elim [.zero]) + let goals ← Tactic.getGoals + Tactic.setGoals (gs ++ goals) + Tactic.allGoals do + trace[Arith] "linarith goal: {← Tactic.getMainGoal}" + linarith + +elab "int_tac" args:(" split_goal"?): tactic => + let split := args.raw.getArgs.size > 0 + intTac split (do pure ()) example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by int_tac_preprocess @@ -219,10 +264,14 @@ example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by -- Checking that things append correctly when there are several disjunctions example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by - int_tac + int_tac split_goal -- Checking that things append correctly when there are several disjunctions example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by + int_tac split_goal + +-- Checking that we can prove exfalso +example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by int_tac end Arith diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index f8903ecf..a56ea08b 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -28,11 +28,11 @@ elab "scalar_tac_preprocess" : tactic => intTacPreprocess scalarTacExtraPreprocess -- A tactic to solve linear arithmetic goals in the presence of scalars -def scalarTac : Tactic.TacticM Unit := do - intTac scalarTacExtraPreprocess +def scalarTac (splitGoalConjs : Bool) : Tactic.TacticM Unit := do + intTac splitGoalConjs scalarTacExtraPreprocess elab "scalar_tac" : tactic => - scalarTac + scalarTac false instance (ty : ScalarTy) : HasIntProp (Scalar ty) where -- prop_ty is inferred diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean index 1773e593..2443b1a6 100644 --- a/backends/lean/Base/IList/IList.lean +++ b/backends/lean/Base/IList/IList.lean @@ -46,21 +46,18 @@ theorem indexOpt_bounds (ls : List α) (i : Int) : ls.indexOpt i = none ↔ i < 0 ∨ ls.len ≤ i := match ls with | [] => - have : ¬ (i < 0) → 0 ≤ i := by intro; linarith -- TODO: simplify (we could boost int_tac) + have : ¬ (i < 0) → 0 ≤ i := by int_tac by simp; tauto | _ :: tl => have := indexOpt_bounds tl (i - 1) if h: i = 0 then by simp [*]; - -- TODO: int_tac/scalar_tac should also explore the goal! - have := tl.len_pos - linarith + int_tac else by simp [*] constructor <;> intros <;> - -- TODO: tactic to split all disjunctions - rename_i hor <;> cases hor <;> + casesm* _ ∨ _ <;> -- splits all the disjunctions first | left; int_tac | right; int_tac theorem indexOpt_eq_index [Inhabited α] (ls : List α) (i : Int) : @@ -126,7 +123,6 @@ theorem length_update (ls : List α) (i : Int) (x : α) : (ls.update i x).length theorem len_update (ls : List α) (i : Int) (x : α) : (ls.update i x).len = ls.len := by simp [len_eq_length] - theorem left_length_eq_append_eq (l1 l2 l1' l2' : List α) (heq : l1.length = l1'.length) : l1 ++ l2 = l1' ++ l2' ↔ l1 = l1' ∧ l2 = l2' := by revert l1' @@ -203,7 +199,7 @@ theorem index_eq (l.update i x).index i = x := fun _ _ => match l with - | [] => by simp at *; exfalso; scalar_tac -- TODO: exfalso needed. Son FIXME + | [] => by simp at *; scalar_tac | hd :: tl => if h: i = 0 then by diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean index be3a0e5b..35092c29 100644 --- a/backends/lean/Base/Primitives/Vec.lean +++ b/backends/lean/Base/Primitives/Vec.lean @@ -16,20 +16,19 @@ open Result Error -- VECTORS -- ------------- -def Vec (α : Type u) := { l : List α // List.length l ≤ Usize.max } +def Vec (α : Type u) := { l : List α // l.length ≤ Usize.max } -- TODO: do we really need it? It should be with Subtype by default -instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val +instance Vec.cast (a : Type u): Coe (Vec a) (List a) where coe := λ v => v.val -instance (a : Type) : Arith.HasIntProp (Vec a) where - prop_ty := λ v => v.val.length ≤ Scalar.max ScalarTy.Usize - prop := λ ⟨ _, l ⟩ => l +instance (a : Type u) : Arith.HasIntProp (Vec a) where + prop_ty := λ v => v.val.len ≤ Scalar.max ScalarTy.Usize + prop := λ ⟨ _, l ⟩ => by simp[Scalar.max, List.len_eq_length, *] -example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by - intro_has_int_prop_instances - simp_all [Scalar.max, Scalar.min] +@[simp] +abbrev Vec.length {α : Type u} (v : Vec α) : Int := v.val.len -example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by +example {a: Type u} (v : Vec a) : v.length ≤ Scalar.max ScalarTy.Usize := by scalar_tac def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩ @@ -38,9 +37,6 @@ def Vec.len (α : Type u) (v : Vec α) : Usize := let ⟨ v, l ⟩ := v Usize.ofIntCore (List.length v) (by simp [Scalar.min, Usize.min]) l -@[simp] -abbrev Vec.length {α : Type u} (v : Vec α) : Int := v.val.len - -- This shouldn't be used def Vec.push_fwd (α : Type u) (_ : Vec α) (_ : α) : Unit := () @@ -115,11 +111,14 @@ theorem Vec.index_mut_spec {α : Type u} [Inhabited α] (v: Vec α) (i: Usize) : have h := List.indexOpt_eq_index v.val i.val (by scalar_tac) (by simp [*]) simp only [*] +instance {α : Type u} (p : Vec α → Prop) : Arith.HasIntProp (Subtype p) where + prop_ty := λ x => p x + prop := λ x => x.property + def Vec.index_mut_back (α : Type u) (v: Vec α) (i: Usize) (x: α) : Result (Vec α) := match v.val.indexOpt i.val with | none => fail .arrayOutOfBounds | some _ => - -- TODO: int_tac: introduce the refinements in the context? .ret ⟨ v.val.update i.val x, by have := v.property; simp [*] ⟩ @[pspec] diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index c0ddc63d..a281f1d2 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -307,7 +307,18 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do let args := (args.get! 2).getArgs (args.get! 3).getArgs.size > 0 trace[Progress] "Split post: {splitPost}" - progressAsmsOrLookupTheorem keep withArg ids splitPost (firstTac [assumptionTac, Arith.scalarTac]) + /- For scalarTac we have a fast track: if the goal is not a linear + arithmetic goal, we skip (note that otherwise, scalarTac would try + to prove a contradiction) -/ + let scalarTac : TacticM Unit := do + if ← Arith.goalIsLinearInt then + -- Also: we don't try to split the goal if it is a conjunction + -- (it shouldn't be) + Arith.scalarTac false + else + throwError "Not a linear arithmetic goal" + progressAsmsOrLookupTheorem keep withArg ids splitPost ( + firstTac [assumptionTac, scalarTac]) elab "progress" args:progressArgs : tactic => evalProgress args -- cgit v1.2.3