summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Arith
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base/Arith')
-rw-r--r--backends/lean/Base/Arith/Arith.lean0
-rw-r--r--backends/lean/Base/Arith/Base.lean60
-rw-r--r--backends/lean/Base/Arith/Int.lean280
-rw-r--r--backends/lean/Base/Arith/Scalar.lean49
4 files changed, 389 insertions, 0 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean
new file mode 100644
index 00000000..e69de29b
--- /dev/null
+++ b/backends/lean/Base/Arith/Arith.lean
diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean
new file mode 100644
index 00000000..9c11ed45
--- /dev/null
+++ b/backends/lean/Base/Arith/Base.lean
@@ -0,0 +1,60 @@
+import Lean
+import Std.Data.Int.Lemmas
+import Mathlib.Tactic.Linarith
+
+namespace Arith
+
+open Lean Elab Term Meta
+
+-- We can't define and use trace classes in the same file
+initialize registerTraceClass `Arith
+
+-- TODO: move?
+theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by
+ cases h: x <;> simp_all
+ . rename_i n;
+ cases n <;> simp_all
+ . apply Int.negSucc_lt_zero
+
+-- TODO: move?
+theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by
+ have hne : x - y ≠ 0 := by
+ simp
+ intro h
+ have: x = y := by linarith
+ simp_all
+ have h := ne_zero_is_lt_or_gt hne
+ match h with
+ | .inl _ => left; linarith
+ | .inr _ => right; linarith
+
+-- TODO: move?
+theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by
+ -- Damn, those proofs on natural numbers are hard - I wish Omega was in mathlib4...
+ simp [Nat.add_one_le_iff]
+ simp [Nat.lt_iff_le_and_ne]
+ simp_all
+
+/- Induction over positive integers -/
+-- TODO: move
+theorem int_pos_ind (p : Int → Prop) :
+ (zero:p 0) → (pos:∀ i, 0 ≤ i → p i → p (i + 1)) → ∀ i, 0 ≤ i → p i := by
+ intro h0 hr i hpos
+-- have heq : Int.toNat i = i := by
+-- cases i <;> simp_all
+ have ⟨ n, heq ⟩ : {n:Nat // n = i } := ⟨ Int.toNat i, by cases i <;> simp_all ⟩
+ revert i
+ induction n
+ . intro i hpos heq
+ cases i <;> simp_all
+ . rename_i n hi
+ intro i hpos heq
+ cases i <;> simp_all
+ rename_i m
+ cases m <;> simp_all
+
+-- We sometimes need this to make sure no natural numbers appear in the goals
+-- TODO: there is probably something more general to do
+theorem nat_zero_eq_int_zero : (0 : Nat) = (0 : Int) := by simp
+
+end Arith
diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean
new file mode 100644
index 00000000..7a5bbe98
--- /dev/null
+++ b/backends/lean/Base/Arith/Int.lean
@@ -0,0 +1,280 @@
+/- This file contains tactics to solve arithmetic goals -/
+
+import Lean
+import Lean.Meta.Tactic.Simp
+import Init.Data.List.Basic
+import Mathlib.Tactic.RunCmd
+import Mathlib.Tactic.Linarith
+-- TODO: there is no Omega tactic for now - it seems it hasn't been ported yet
+--import Mathlib.Tactic.Omega
+import Base.Utils
+import Base.Arith.Base
+
+namespace Arith
+
+open Utils
+
+-- Remark: I tried a version of the shape `HasScalarProp {a : Type} (x : a)`
+-- but the lookup didn't work
+class HasIntProp (a : Sort u) where
+ prop_ty : a → Prop
+ prop : ∀ x:a, prop_ty x
+
+class PropHasImp (x : Prop) where
+ concl : Prop
+ prop : x → concl
+
+instance (p : Int → Prop) : HasIntProp (Subtype p) where
+ prop_ty := λ x => p x
+ prop := λ x => x.property
+
+-- This also works for `x ≠ y` because this expression reduces to `¬ x = y`
+-- and `Ne` is marked as `reducible`
+instance (x y : Int) : PropHasImp (¬ x = y) where
+ concl := x < y ∨ x > y
+ prop := λ (h:x ≠ y) => ne_is_lt_or_gt h
+
+-- Check if a proposition is a linear integer proposition.
+-- We notably use this to check the goals.
+class IsLinearIntProp (x : Prop) where
+
+instance (x y : Int) : IsLinearIntProp (x < y) where
+instance (x y : Int) : IsLinearIntProp (x > y) where
+instance (x y : Int) : IsLinearIntProp (x ≤ y) where
+instance (x y : Int) : IsLinearIntProp (x ≥ y) where
+instance (x y : Int) : IsLinearIntProp (x ≥ y) where
+instance (x y : Int) : IsLinearIntProp (x = y) where
+/- It seems we don't need to do any special preprocessing when the *goal*
+ has the following shape - I guess `linarith` automatically calls `intro` -/
+instance (x y : Int) : IsLinearIntProp (¬ x = y) where
+
+open Lean Lean.Elab Lean.Meta
+
+-- Explore a term by decomposing the applications (we explore the applied
+-- functions and their arguments, but ignore lambdas, forall, etc. -
+-- should we go inside?).
+partial def foldTermApps (k : α → Expr → MetaM α) (s : α) (e : Expr) : MetaM α := do
+ -- We do it in a very simpler manner: we deconstruct applications,
+ -- and recursively explore the sub-expressions. Note that we do
+ -- not go inside foralls and abstractions (should we?).
+ e.withApp fun f args => do
+ let s ← k s f
+ args.foldlM (foldTermApps k) s
+
+-- Provided a function `k` which lookups type class instances on an expression,
+-- collect all the instances lookuped by applying `k` on the sub-expressions of `e`.
+def collectInstances
+ (k : Expr → MetaM (Option Expr)) (s : HashSet Expr) (e : Expr) : MetaM (HashSet Expr) := do
+ let k s e := do
+ match ← k e with
+ | none => pure s
+ | some i => pure (s.insert i)
+ foldTermApps k s e
+
+-- Similar to `collectInstances`, but explores all the local declarations in the
+-- main context.
+def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do
+ Tactic.withMainContext do
+ -- Get the local context
+ let ctx ← Lean.MonadLCtx.getLCtx
+ -- Just a matter of precaution
+ let ctx ← instantiateLCtxMVars ctx
+ -- Initialize the hashset
+ let hs := HashSet.empty
+ -- Explore the declarations
+ let decls ← ctx.getDecls
+ decls.foldlM (fun hs d => collectInstances k hs d.toExpr) hs
+
+-- Helper
+def lookupProp (fName : String) (className : Name) (e : Expr) : MetaM (Option Expr) := do
+ trace[Arith] fName
+ -- TODO: do we need Lean.observing?
+ -- This actually eliminates the error messages
+ Lean.observing? do
+ trace[Arith] m!"{fName}: observing"
+ let ty ← Lean.Meta.inferType e
+ let hasProp ← mkAppM className #[ty]
+ let hasPropInst ← trySynthInstance hasProp
+ match hasPropInst with
+ | LOption.some i =>
+ trace[Arith] "Found {fName} instance"
+ let i_prop ← mkProjection i (Name.mkSimple "prop")
+ some (← mkAppM' i_prop #[e])
+ | _ => none
+
+-- Return an instance of `HasIntProp` for `e` if it has some
+def lookupHasIntProp (e : Expr) : MetaM (Option Expr) :=
+ lookupProp "lookupHasIntProp" ``HasIntProp e
+
+-- Collect the instances of `HasIntProp` for the subexpressions in the context
+def collectHasIntPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
+ collectInstancesFromMainCtx lookupHasIntProp
+
+-- Return an instance of `PropHasImp` for `e` if it has some
+def lookupPropHasImp (e : Expr) : MetaM (Option Expr) := do
+ trace[Arith] "lookupPropHasImp"
+ -- TODO: do we need Lean.observing?
+ -- This actually eliminates the error messages
+ Lean.observing? do
+ trace[Arith] "lookupPropHasImp: observing"
+ let ty ← Lean.Meta.inferType e
+ trace[Arith] "lookupPropHasImp: ty: {ty}"
+ let cl ← mkAppM ``PropHasImp #[ty]
+ let inst ← trySynthInstance cl
+ match inst with
+ | LOption.some i =>
+ trace[Arith] "Found PropHasImp instance"
+ let i_prop ← mkProjection i (Name.mkSimple "prop")
+ some (← mkAppM' i_prop #[e])
+ | _ => none
+
+-- Collect the instances of `PropHasImp` for the subexpressions in the context
+def collectPropHasImpInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do
+ collectInstancesFromMainCtx lookupPropHasImp
+
+elab "display_prop_has_imp_instances" : tactic => do
+ trace[Arith] "Displaying the PropHasImp instances"
+ let hs ← collectPropHasImpInstancesFromMainCtx
+ hs.forM fun e => do
+ trace[Arith] "+ PropHasImp instance: {e}"
+
+example (x y : Int) (_ : x ≠ y) (_ : ¬ x = y) : True := by
+ display_prop_has_imp_instances
+ simp
+
+-- Lookup instances in a context and introduce them with additional declarations.
+def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) : Tactic.TacticM (Array Expr) := do
+ let hs ← collectInstancesFromMainCtx lookup
+ hs.toArray.mapM fun e => do
+ let type ← inferType e
+ let name ← mkFreshAnonPropUserName
+ -- Add a declaration
+ let nval ← Utils.addDeclTac name e type (asLet := false)
+ -- Simplify to unfold the declaration to unfold (i.e., the projector)
+ Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false)
+ -- Return the new value
+ pure nval
+
+def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do
+ trace[Arith] "Introducing the HasIntProp instances"
+ introInstances ``HasIntProp.prop_ty lookupHasIntProp
+
+-- Lookup the instances of `HasIntProp for all the sub-expressions in the context,
+-- and introduce the corresponding assumptions
+elab "intro_has_int_prop_instances" : tactic => do
+ let _ ← introHasIntPropInstances
+
+-- Lookup the instances of `PropHasImp for all the sub-expressions in the context,
+-- and introduce the corresponding assumptions
+elab "intro_prop_has_imp_instances" : tactic => do
+ trace[Arith] "Introducing the PropHasImp instances"
+ let _ ← introInstances ``PropHasImp.concl lookupPropHasImp
+
+example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by
+ intro_prop_has_imp_instances
+ rename_i h
+ split_disj h
+ . linarith
+ . linarith
+
+/- Boosting a bit the linarith tac.
+
+ We do the following:
+ - for all the assumptions of the shape `(x : Int) ≠ y` or `¬ (x = y), we
+ introduce two goals with the assumptions `x < y` and `x > y`
+ TODO: we could create a PR for mathlib.
+ -/
+def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ -- Lookup the instances of PropHasImp (this is how we detect assumptions
+ -- of the proper shape), introduce assumptions in the context and split
+ -- on those
+ -- TODO: get rid of the assumptions that we split
+ let rec splitOnAsms (asms : List Expr) : Tactic.TacticM Unit :=
+ match asms with
+ | [] => pure ()
+ | asm :: asms =>
+ let k := splitOnAsms asms
+ Utils.splitDisjTac asm k k
+ -- Introduce the scalar bounds
+ let _ ← introHasIntPropInstances
+ -- Extra preprocessing, before we split on the disjunctions
+ extraPreprocess
+ -- Split
+ let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
+ splitOnAsms asms.toList
+
+elab "int_tac_preprocess" : tactic =>
+ intTacPreprocess (do pure ())
+
+-- Check if the goal is a linear arithmetic goal
+def goalIsLinearInt : Tactic.TacticM Bool := do
+ Tactic.withMainContext do
+ let gty ← Tactic.getMainTarget
+ match ← trySynthInstance (← mkAppM ``IsLinearIntProp #[gty]) with
+ | .some _ => pure true
+ | _ => pure false
+
+def intTac (splitGoalConjs : Bool) (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ Tactic.focus do
+ let g ← Tactic.getMainGoal
+ trace[Arith] "Original goal: {g}"
+ -- Introduce all the universally quantified variables (includes the assumptions)
+ let (_, g) ← g.intros
+ Tactic.setGoals [g]
+ -- Preprocess - wondering if we should do this before or after splitting
+ -- the goal. I think before leads to a smaller proof term?
+ Tactic.allGoals (intTacPreprocess extraPreprocess)
+ -- More preprocessing
+ Tactic.allGoals (Utils.simpAt [] [``nat_zero_eq_int_zero] [] .wildcard)
+ -- Split the conjunctions in the goal
+ if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget)
+ -- Call linarith
+ let linarith := do
+ let cfg : Linarith.LinarithConfig := {
+ -- We do this with our custom preprocessing
+ splitNe := false
+ }
+ Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg
+ Tactic.allGoals do
+ -- We check if the goal is a linear arithmetic goal: if yes, we directly
+ -- call linarith, otherwise we first apply exfalso (we do this because
+ -- linarith is too general and sometimes fails to do this correctly).
+ if ← goalIsLinearInt then do
+ trace[Arith] "linarith goal: {← Tactic.getMainGoal}"
+ linarith
+ else do
+ let g ← Tactic.getMainGoal
+ let gs ← g.apply (Expr.const ``False.elim [.zero])
+ let goals ← Tactic.getGoals
+ Tactic.setGoals (gs ++ goals)
+ Tactic.allGoals do
+ trace[Arith] "linarith goal: {← Tactic.getMainGoal}"
+ linarith
+
+elab "int_tac" args:(" split_goal"?): tactic =>
+ let split := args.raw.getArgs.size > 0
+ intTac split (do pure ())
+
+example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by
+ int_tac_preprocess
+ linarith
+ linarith
+
+example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by
+ int_tac
+
+-- Checking that things append correctly when there are several disjunctions
+example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by
+ int_tac split_goal
+
+-- Checking that things append correctly when there are several disjunctions
+example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by
+ int_tac split_goal
+
+-- Checking that we can prove exfalso
+example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by
+ int_tac
+
+end Arith
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
new file mode 100644
index 00000000..b792ff21
--- /dev/null
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -0,0 +1,49 @@
+import Base.Arith.Int
+import Base.Primitives.Scalar
+
+/- Automation for scalars - TODO: not sure it is worth having two files (Int.lean and Scalar.lean) -/
+namespace Arith
+
+open Lean Lean.Elab Lean.Meta
+open Primitives
+
+def scalarTacExtraPreprocess : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ -- Inroduce the bounds for the isize/usize types
+ let add (e : Expr) : Tactic.TacticM Unit := do
+ let ty ← inferType e
+ let _ ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) e ty (asLet := false)
+ add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
+ -- Reveal the concrete bounds
+ Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
+ ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
+ ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
+ ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
+ ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max,
+ ``Usize.min
+ ] [] [] .wildcard
+
+elab "scalar_tac_preprocess" : tactic =>
+ intTacPreprocess scalarTacExtraPreprocess
+
+-- A tactic to solve linear arithmetic goals in the presence of scalars
+def scalarTac (splitGoalConjs : Bool) : Tactic.TacticM Unit := do
+ intTac splitGoalConjs scalarTacExtraPreprocess
+
+elab "scalar_tac" : tactic =>
+ scalarTac false
+
+instance (ty : ScalarTy) : HasIntProp (Scalar ty) where
+ -- prop_ty is inferred
+ prop := λ x => And.intro x.hmin x.hmax
+
+example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
+ intro_has_int_prop_instances
+ simp [*]
+
+example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
+ scalar_tac
+
+end Arith