diff options
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Arith.lean | 2 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Arith.lean | 0 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Base.lean | 60 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Int.lean | 280 | ||||
-rw-r--r-- | backends/lean/Base/Arith/Scalar.lean | 49 |
5 files changed, 391 insertions, 0 deletions
diff --git a/backends/lean/Base/Arith.lean b/backends/lean/Base/Arith.lean new file mode 100644 index 00000000..c0d09fd2 --- /dev/null +++ b/backends/lean/Base/Arith.lean @@ -0,0 +1,2 @@ +import Base.Arith.Int +import Base.Arith.Scalar diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean new file mode 100644 index 00000000..e69de29b --- /dev/null +++ b/backends/lean/Base/Arith/Arith.lean diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean new file mode 100644 index 00000000..9c11ed45 --- /dev/null +++ b/backends/lean/Base/Arith/Base.lean @@ -0,0 +1,60 @@ +import Lean +import Std.Data.Int.Lemmas +import Mathlib.Tactic.Linarith + +namespace Arith + +open Lean Elab Term Meta + +-- We can't define and use trace classes in the same file +initialize registerTraceClass `Arith + +-- TODO: move? +theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by + cases h: x <;> simp_all + . rename_i n; + cases n <;> simp_all + . apply Int.negSucc_lt_zero + +-- TODO: move? +theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by + have hne : x - y ≠ 0 := by + simp + intro h + have: x = y := by linarith + simp_all + have h := ne_zero_is_lt_or_gt hne + match h with + | .inl _ => left; linarith + | .inr _ => right; linarith + +-- TODO: move? +theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by + -- Damn, those proofs on natural numbers are hard - I wish Omega was in mathlib4... + simp [Nat.add_one_le_iff] + simp [Nat.lt_iff_le_and_ne] + simp_all + +/- Induction over positive integers -/ +-- TODO: move +theorem int_pos_ind (p : Int → Prop) : + (zero:p 0) → (pos:∀ i, 0 ≤ i → p i → p (i + 1)) → ∀ i, 0 ≤ i → p i := by + intro h0 hr i hpos +-- have heq : Int.toNat i = i := by +-- cases i <;> simp_all + have ⟨ n, heq ⟩ : {n:Nat // n = i } := ⟨ Int.toNat i, by cases i <;> simp_all ⟩ + revert i + induction n + . intro i hpos heq + cases i <;> simp_all + . rename_i n hi + intro i hpos heq + cases i <;> simp_all + rename_i m + cases m <;> simp_all + +-- We sometimes need this to make sure no natural numbers appear in the goals +-- TODO: there is probably something more general to do +theorem nat_zero_eq_int_zero : (0 : Nat) = (0 : Int) := by simp + +end Arith diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean new file mode 100644 index 00000000..7a5bbe98 --- /dev/null +++ b/backends/lean/Base/Arith/Int.lean @@ -0,0 +1,280 @@ +/- This file contains tactics to solve arithmetic goals -/ + +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Mathlib.Tactic.Linarith +-- TODO: there is no Omega tactic for now - it seems it hasn't been ported yet +--import Mathlib.Tactic.Omega +import Base.Utils +import Base.Arith.Base + +namespace Arith + +open Utils + +-- Remark: I tried a version of the shape `HasScalarProp {a : Type} (x : a)` +-- but the lookup didn't work +class HasIntProp (a : Sort u) where + prop_ty : a → Prop + prop : ∀ x:a, prop_ty x + +class PropHasImp (x : Prop) where + concl : Prop + prop : x → concl + +instance (p : Int → Prop) : HasIntProp (Subtype p) where + prop_ty := λ x => p x + prop := λ x => x.property + +-- This also works for `x ≠ y` because this expression reduces to `¬ x = y` +-- and `Ne` is marked as `reducible` +instance (x y : Int) : PropHasImp (¬ x = y) where + concl := x < y ∨ x > y + prop := λ (h:x ≠ y) => ne_is_lt_or_gt h + +-- Check if a proposition is a linear integer proposition. +-- We notably use this to check the goals. +class IsLinearIntProp (x : Prop) where + +instance (x y : Int) : IsLinearIntProp (x < y) where +instance (x y : Int) : IsLinearIntProp (x > y) where +instance (x y : Int) : IsLinearIntProp (x ≤ y) where +instance (x y : Int) : IsLinearIntProp (x ≥ y) where +instance (x y : Int) : IsLinearIntProp (x ≥ y) where +instance (x y : Int) : IsLinearIntProp (x = y) where +/- It seems we don't need to do any special preprocessing when the *goal* + has the following shape - I guess `linarith` automatically calls `intro` -/ +instance (x y : Int) : IsLinearIntProp (¬ x = y) where + +open Lean Lean.Elab Lean.Meta + +-- Explore a term by decomposing the applications (we explore the applied +-- functions and their arguments, but ignore lambdas, forall, etc. - +-- should we go inside?). +partial def foldTermApps (k : α → Expr → MetaM α) (s : α) (e : Expr) : MetaM α := do + -- We do it in a very simpler manner: we deconstruct applications, + -- and recursively explore the sub-expressions. Note that we do + -- not go inside foralls and abstractions (should we?). + e.withApp fun f args => do + let s ← k s f + args.foldlM (foldTermApps k) s + +-- Provided a function `k` which lookups type class instances on an expression, +-- collect all the instances lookuped by applying `k` on the sub-expressions of `e`. +def collectInstances + (k : Expr → MetaM (Option Expr)) (s : HashSet Expr) (e : Expr) : MetaM (HashSet Expr) := do + let k s e := do + match ← k e with + | none => pure s + | some i => pure (s.insert i) + foldTermApps k s e + +-- Similar to `collectInstances`, but explores all the local declarations in the +-- main context. +def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do + Tactic.withMainContext do + -- Get the local context + let ctx ← Lean.MonadLCtx.getLCtx + -- Just a matter of precaution + let ctx ← instantiateLCtxMVars ctx + -- Initialize the hashset + let hs := HashSet.empty + -- Explore the declarations + let decls ← ctx.getDecls + decls.foldlM (fun hs d => collectInstances k hs d.toExpr) hs + +-- Helper +def lookupProp (fName : String) (className : Name) (e : Expr) : MetaM (Option Expr) := do + trace[Arith] fName + -- TODO: do we need Lean.observing? + -- This actually eliminates the error messages + Lean.observing? do + trace[Arith] m!"{fName}: observing" + let ty ← Lean.Meta.inferType e + let hasProp ← mkAppM className #[ty] + let hasPropInst ← trySynthInstance hasProp + match hasPropInst with + | LOption.some i => + trace[Arith] "Found {fName} instance" + let i_prop ← mkProjection i (Name.mkSimple "prop") + some (← mkAppM' i_prop #[e]) + | _ => none + +-- Return an instance of `HasIntProp` for `e` if it has some +def lookupHasIntProp (e : Expr) : MetaM (Option Expr) := + lookupProp "lookupHasIntProp" ``HasIntProp e + +-- Collect the instances of `HasIntProp` for the subexpressions in the context +def collectHasIntPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupHasIntProp + +-- Return an instance of `PropHasImp` for `e` if it has some +def lookupPropHasImp (e : Expr) : MetaM (Option Expr) := do + trace[Arith] "lookupPropHasImp" + -- TODO: do we need Lean.observing? + -- This actually eliminates the error messages + Lean.observing? do + trace[Arith] "lookupPropHasImp: observing" + let ty ← Lean.Meta.inferType e + trace[Arith] "lookupPropHasImp: ty: {ty}" + let cl ← mkAppM ``PropHasImp #[ty] + let inst ← trySynthInstance cl + match inst with + | LOption.some i => + trace[Arith] "Found PropHasImp instance" + let i_prop ← mkProjection i (Name.mkSimple "prop") + some (← mkAppM' i_prop #[e]) + | _ => none + +-- Collect the instances of `PropHasImp` for the subexpressions in the context +def collectPropHasImpInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupPropHasImp + +elab "display_prop_has_imp_instances" : tactic => do + trace[Arith] "Displaying the PropHasImp instances" + let hs ← collectPropHasImpInstancesFromMainCtx + hs.forM fun e => do + trace[Arith] "+ PropHasImp instance: {e}" + +example (x y : Int) (_ : x ≠ y) (_ : ¬ x = y) : True := by + display_prop_has_imp_instances + simp + +-- Lookup instances in a context and introduce them with additional declarations. +def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) : Tactic.TacticM (Array Expr) := do + let hs ← collectInstancesFromMainCtx lookup + hs.toArray.mapM fun e => do + let type ← inferType e + let name ← mkFreshAnonPropUserName + -- Add a declaration + let nval ← Utils.addDeclTac name e type (asLet := false) + -- Simplify to unfold the declaration to unfold (i.e., the projector) + Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false) + -- Return the new value + pure nval + +def introHasIntPropInstances : Tactic.TacticM (Array Expr) := do + trace[Arith] "Introducing the HasIntProp instances" + introInstances ``HasIntProp.prop_ty lookupHasIntProp + +-- Lookup the instances of `HasIntProp for all the sub-expressions in the context, +-- and introduce the corresponding assumptions +elab "intro_has_int_prop_instances" : tactic => do + let _ ← introHasIntPropInstances + +-- Lookup the instances of `PropHasImp for all the sub-expressions in the context, +-- and introduce the corresponding assumptions +elab "intro_prop_has_imp_instances" : tactic => do + trace[Arith] "Introducing the PropHasImp instances" + let _ ← introInstances ``PropHasImp.concl lookupPropHasImp + +example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by + intro_prop_has_imp_instances + rename_i h + split_disj h + . linarith + . linarith + +/- Boosting a bit the linarith tac. + + We do the following: + - for all the assumptions of the shape `(x : Int) ≠ y` or `¬ (x = y), we + introduce two goals with the assumptions `x < y` and `x > y` + TODO: we could create a PR for mathlib. + -/ +def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do + Tactic.withMainContext do + -- Lookup the instances of PropHasImp (this is how we detect assumptions + -- of the proper shape), introduce assumptions in the context and split + -- on those + -- TODO: get rid of the assumptions that we split + let rec splitOnAsms (asms : List Expr) : Tactic.TacticM Unit := + match asms with + | [] => pure () + | asm :: asms => + let k := splitOnAsms asms + Utils.splitDisjTac asm k k + -- Introduce the scalar bounds + let _ ← introHasIntPropInstances + -- Extra preprocessing, before we split on the disjunctions + extraPreprocess + -- Split + let asms ← introInstances ``PropHasImp.concl lookupPropHasImp + splitOnAsms asms.toList + +elab "int_tac_preprocess" : tactic => + intTacPreprocess (do pure ()) + +-- Check if the goal is a linear arithmetic goal +def goalIsLinearInt : Tactic.TacticM Bool := do + Tactic.withMainContext do + let gty ← Tactic.getMainTarget + match ← trySynthInstance (← mkAppM ``IsLinearIntProp #[gty]) with + | .some _ => pure true + | _ => pure false + +def intTac (splitGoalConjs : Bool) (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM Unit := do + Tactic.withMainContext do + Tactic.focus do + let g ← Tactic.getMainGoal + trace[Arith] "Original goal: {g}" + -- Introduce all the universally quantified variables (includes the assumptions) + let (_, g) ← g.intros + Tactic.setGoals [g] + -- Preprocess - wondering if we should do this before or after splitting + -- the goal. I think before leads to a smaller proof term? + Tactic.allGoals (intTacPreprocess extraPreprocess) + -- More preprocessing + Tactic.allGoals (Utils.simpAt [] [``nat_zero_eq_int_zero] [] .wildcard) + -- Split the conjunctions in the goal + if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget) + -- Call linarith + let linarith := do + let cfg : Linarith.LinarithConfig := { + -- We do this with our custom preprocessing + splitNe := false + } + Tactic.liftMetaFinishingTactic <| Linarith.linarith false [] cfg + Tactic.allGoals do + -- We check if the goal is a linear arithmetic goal: if yes, we directly + -- call linarith, otherwise we first apply exfalso (we do this because + -- linarith is too general and sometimes fails to do this correctly). + if ← goalIsLinearInt then do + trace[Arith] "linarith goal: {← Tactic.getMainGoal}" + linarith + else do + let g ← Tactic.getMainGoal + let gs ← g.apply (Expr.const ``False.elim [.zero]) + let goals ← Tactic.getGoals + Tactic.setGoals (gs ++ goals) + Tactic.allGoals do + trace[Arith] "linarith goal: {← Tactic.getMainGoal}" + linarith + +elab "int_tac" args:(" split_goal"?): tactic => + let split := args.raw.getArgs.size > 0 + intTac split (do pure ()) + +example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by + int_tac_preprocess + linarith + linarith + +example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by + int_tac + +-- Checking that things append correctly when there are several disjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by + int_tac split_goal + +-- Checking that things append correctly when there are several disjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y ∧ x + y ≥ 2 := by + int_tac split_goal + +-- Checking that we can prove exfalso +example (a : Prop) (x : Int) (h0: 0 < x) (h1: x < 0) : a := by + int_tac + +end Arith diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean new file mode 100644 index 00000000..b792ff21 --- /dev/null +++ b/backends/lean/Base/Arith/Scalar.lean @@ -0,0 +1,49 @@ +import Base.Arith.Int +import Base.Primitives.Scalar + +/- Automation for scalars - TODO: not sure it is worth having two files (Int.lean and Scalar.lean) -/ +namespace Arith + +open Lean Lean.Elab Lean.Meta +open Primitives + +def scalarTacExtraPreprocess : Tactic.TacticM Unit := do + Tactic.withMainContext do + -- Inroduce the bounds for the isize/usize types + let add (e : Expr) : Tactic.TacticM Unit := do + let ty ← inferType e + let _ ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) e ty (asLet := false) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) + -- Reveal the concrete bounds + Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, + ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, + ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, + ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, + ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max, + ``Usize.min + ] [] [] .wildcard + +elab "scalar_tac_preprocess" : tactic => + intTacPreprocess scalarTacExtraPreprocess + +-- A tactic to solve linear arithmetic goals in the presence of scalars +def scalarTac (splitGoalConjs : Bool) : Tactic.TacticM Unit := do + intTac splitGoalConjs scalarTacExtraPreprocess + +elab "scalar_tac" : tactic => + scalarTac false + +instance (ty : ScalarTy) : HasIntProp (Scalar ty) where + -- prop_ty is inferred + prop := λ x => And.intro x.hmin x.hmax + +example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by + intro_has_int_prop_instances + simp [*] + +example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by + scalar_tac + +end Arith |