summaryrefslogtreecommitdiff
path: root/backends/lean
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Arith/Arith.lean122
-rw-r--r--backends/lean/Base/Diverge/Base.lean2
-rw-r--r--backends/lean/Base/Primitives.lean89
-rw-r--r--backends/lean/Base/Progress/Base.lean24
-rw-r--r--backends/lean/Base/Progress/Progress.lean36
-rw-r--r--backends/lean/Base/Utils.lean247
6 files changed, 349 insertions, 171 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean
index ff628cf3..3557d350 100644
--- a/backends/lean/Base/Arith/Arith.lean
+++ b/backends/lean/Base/Arith/Arith.lean
@@ -230,25 +230,20 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr))
let type ← inferType e
let name ← mkFreshUserName `h
-- Add a declaration
- let nval ← Utils.addDecl name e type (asLet := false)
+ let nval ← Utils.addDeclTac name e type (asLet := false)
-- Simplify to unfold the declaration to unfold (i.e., the projector)
- let simpTheorems ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems)
- -- Add the equational theorem for the decl to unfold
- let simpTheorems ← simpTheorems.addDeclToUnfold declToUnfold
- let congrTheorems ← getSimpCongrTheorems
- let ctx : Simp.Context := { simpTheorems := #[simpTheorems], congrTheorems }
- -- Where to apply the simplifier
- let loc := Tactic.Location.targets #[mkIdent name] false
- -- Apply the simplifier
- let _ ← Tactic.simpLocation ctx (discharge? := .none) loc
+ Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false)
-- Return the new value
pure nval
+def introHasPropInstances : Tactic.TacticM (Array Expr) := do
+ trace[Arith] "Introducing the HasProp instances"
+ introInstances ``HasProp.prop_ty lookupHasProp
+
-- Lookup the instances of `HasProp for all the sub-expressions in the context,
-- and introduce the corresponding assumptions
elab "intro_has_prop_instances" : tactic => do
- trace[Arith] "Introducing the HasProp instances"
- let _ ← introInstances ``HasProp.prop_ty lookupHasProp
+ let _ ← introHasPropInstances
example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
intro_has_prop_instances
@@ -258,74 +253,6 @@ example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by
intro_has_prop_instances
simp_all [Scalar.max, Scalar.min]
--- Tactic to split on a disjunction.
--- The expression `h` should be an fvar.
-def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
- trace[Arith] "assumption on which to split: {h}"
- -- Retrieve the main goal
- Tactic.withMainContext do
- let goalType ← Tactic.getMainTarget
- let hDecl := (← getLCtx).get! h.fvarId!
- let hName := hDecl.userName
- -- Case disjunction
- let hTy ← inferType h
- hTy.withApp fun f xs => do
- trace[Arith] "as app: {f} {xs}"
- -- Sanity check
- if ¬ (f.isConstOf ``Or ∧ xs.size = 2) then throwError "Invalid argument to splitDisj"
- let a := xs.get! 0
- let b := xs.get! 1
- -- Introduce the new goals
- -- Returns:
- -- - the match branch
- -- - a fresh new mvar id
- let mkGoal (hTy : Expr) (nGoalName : String) : MetaM (Expr × MVarId) := do
- -- Introduce a variable for the assumption (`a` or `b`). Note that we reuse
- -- the name of the assumption we split.
- withLocalDeclD hName hTy fun var => do
- -- The new goal
- let mgoal ← mkFreshExprSyntheticOpaqueMVar goalType (tag := Name.mkSimple nGoalName)
- -- Clear the assumption that we split
- let mgoal ← mgoal.mvarId!.tryClearMany #[h.fvarId!]
- -- The branch expression
- let branch ← mkLambdaFVars #[var] (mkMVar mgoal)
- pure (branch, mgoal)
- let (inl, mleft) ← mkGoal a "left"
- let (inr, mright) ← mkGoal b "right"
- trace[Arith] "left: {inl}: {mleft}"
- trace[Arith] "right: {inr}: {mright}"
- -- Create the match expression
- withLocalDeclD (← mkFreshUserName `h) hTy fun hVar => do
- let motive ← mkLambdaFVars #[hVar] goalType
- let casesExpr ← mkAppOptM ``Or.casesOn #[a, b, motive, h, inl, inr]
- let mgoal ← Tactic.getMainGoal
- trace[Arith] "goals: {← Tactic.getUnsolvedGoals}"
- trace[Arith] "main goal: {mgoal}"
- mgoal.assign casesExpr
- let goals ← Tactic.getUnsolvedGoals
- -- Focus on the left
- Tactic.setGoals [mleft]
- kleft
- let leftGoals ← Tactic.getUnsolvedGoals
- -- Focus on the right
- Tactic.setGoals [mright]
- kright
- let rightGoals ← Tactic.getUnsolvedGoals
- -- Put all the goals back
- Tactic.setGoals (leftGoals ++ rightGoals ++ goals)
- trace[Arith] "new goals: {← Tactic.getUnsolvedGoals}"
-
-elab "split_disj " n:ident : tactic => do
- Tactic.withMainContext do
- let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
- let fvar := mkFVar decl.fvarId
- splitDisj fvar (fun _ => pure ()) (fun _ => pure ())
-
-example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by
- split_disj h0
- . left; assumption
- . right; assumption
-
-- Lookup the instances of `PropHasImp for all the sub-expressions in the context,
-- and introduce the corresponding assumptions
elab "intro_prop_has_imp_instances" : tactic => do
@@ -357,7 +284,7 @@ def intTacPreprocess : Tactic.TacticM Unit := do
| [] => pure ()
| asm :: asms =>
let k := splitOnAsms asms
- splitDisj asm k k
+ Utils.splitDisjTac asm k k
-- Introduce
let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
-- Split
@@ -403,18 +330,27 @@ example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) :
int_tac
-- A tactic to solve linear arithmetic goals in the presence of scalars
-syntax "scalar_tac" : tactic
-macro_rules
- | `(tactic| scalar_tac) =>
- `(tactic|
- intro_has_prop_instances;
- have := Scalar.cMin_bound ScalarTy.Usize;
- have := Scalar.cMin_bound ScalarTy.Isize;
- have := Scalar.cMax_bound ScalarTy.Usize;
- have := Scalar.cMax_bound ScalarTy.Isize;
- -- TODO: not too sure about that
- simp only [*, Scalar.max, Scalar.min, Scalar.cMin, Scalar.cMax] at *;
- int_tac)
+def scalarTac : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ -- Introduce the scalar bounds
+ let _ ← introHasPropInstances
+ Tactic.allGoals do
+ -- Inroduce the bounds for the isize/usize types
+ let add (e : Expr) : Tactic.TacticM Unit := do
+ let ty ← inferType e
+ let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false)
+ add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Usize []])
+ add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
+ -- Reveal the concrete bounds - TODO: not too sure about that.
+ -- Maybe we should reveal the "concrete" bounds (after normalization)
+ Utils.simpAt [``Scalar.max, ``Scalar.min, ``Scalar.cMin, ``Scalar.cMax] [] [] .wildcard
+ -- Apply the integer tactic
+ intTac
+
+elab "scalar_tac" : tactic =>
+ scalarTac
example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
scalar_tac
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index e22eb914..d2c91ff8 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -14,7 +14,7 @@ TODO:
Actually, the cases from mathlib seems already quite powerful
(https://leanprover-community.github.io/mathlib_docs/tactics.html#cases)
For instance: cases h : e
- Also: cases_matching
+ Also: **casesm**
- better split tactic
- we need conversions to operate on the head of applications.
Actually, something like this works:
diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean
index 14f5971e..6210688d 100644
--- a/backends/lean/Base/Primitives.lean
+++ b/backends/lean/Base/Primitives.lean
@@ -175,27 +175,28 @@ open System.Platform.getNumBits
@[simp] def U128.min : Int := 0
@[simp] def U128.max : Int := HPow.hPow 2 128 - 1
-#assert (I8.min == -128)
-#assert (I8.max == 127)
-#assert (I16.min == -32768)
-#assert (I16.max == 32767)
-#assert (I32.min == -2147483648)
-#assert (I32.max == 2147483647)
-#assert (I64.min == -9223372036854775808)
-#assert (I64.max == 9223372036854775807)
-#assert (I128.min == -170141183460469231731687303715884105728)
-#assert (I128.max == 170141183460469231731687303715884105727)
-#assert (U8.min == 0)
-#assert (U8.max == 255)
-#assert (U16.min == 0)
-#assert (U16.max == 65535)
-#assert (U32.min == 0)
-#assert (U32.max == 4294967295)
-#assert (U64.min == 0)
-#assert (U64.max == 18446744073709551615)
-#assert (U128.min == 0)
-#assert (U128.max == 340282366920938463463374607431768211455)
-
+-- The normalized bounds
+@[simp] def I8.norm_min := -128
+@[simp] def I8.norm_max := 127
+@[simp] def I16.norm_min := -32768
+@[simp] def I16.norm_max := 32767
+@[simp] def I32.norm_min := -2147483648
+@[simp] def I32.norm_max := 2147483647
+@[simp] def I64.norm_min := -9223372036854775808
+@[simp] def I64.norm_max := 9223372036854775807
+@[simp] def I128.norm_min := -170141183460469231731687303715884105728
+@[simp] def I128.norm_max := 170141183460469231731687303715884105727
+@[simp] def U8.norm_min := 0
+@[simp] def U8.norm_max := 255
+@[simp] def U16.norm_min := 0
+@[simp] def U16.norm_max := 65535
+@[simp] def U32.norm_min := 0
+@[simp] def U32.norm_max := 4294967295
+@[simp] def U64.norm_min := 0
+@[simp] def U64.norm_max := 18446744073709551615
+@[simp] def U128.norm_min := 0
+@[simp] def U128.norm_max := 340282366920938463463374607431768211455
+
inductive ScalarTy :=
| Isize
| I8
@@ -240,6 +241,46 @@ def Scalar.max (ty : ScalarTy) : Int :=
| .U64 => U64.max
| .U128 => U128.max
+@[simp] def Scalar.norm_min (ty : ScalarTy) : Int :=
+ match ty with
+ -- We can't normalize the bounds for isize/usize
+ | .Isize => Isize.min
+ | .Usize => Usize.min
+ --
+ | .I8 => I8.norm_min
+ | .I16 => I16.norm_min
+ | .I32 => I32.norm_min
+ | .I64 => I64.norm_min
+ | .I128 => I128.norm_min
+ | .U8 => U8.norm_min
+ | .U16 => U16.norm_min
+ | .U32 => U32.norm_min
+ | .U64 => U64.norm_min
+ | .U128 => U128.norm_min
+
+@[simp] def Scalar.norm_max (ty : ScalarTy) : Int :=
+ match ty with
+ -- We can't normalize the bounds for isize/usize
+ | .Isize => Isize.max
+ | .Usize => Usize.max
+ --
+ | .I8 => I8.norm_max
+ | .I16 => I16.norm_max
+ | .I32 => I32.norm_max
+ | .I64 => I64.norm_max
+ | .I128 => I128.norm_max
+ | .U8 => U8.norm_max
+ | .U16 => U16.norm_max
+ | .U32 => U32.norm_max
+ | .U64 => U64.norm_max
+ | .U128 => U128.norm_max
+
+def Scalar.norm_min_eq (ty : ScalarTy) : Scalar.min ty = Scalar.norm_min ty := by
+ cases ty <;> rfl
+
+def Scalar.norm_max_eq (ty : ScalarTy) : Scalar.max ty = Scalar.norm_max ty := by
+ cases ty <;> rfl
+
-- "Conservative" bounds
-- We use those because we can't compare to the isize bounds (which can't
-- reduce at compile-time). Whenever we perform an arithmetic operation like
@@ -249,13 +290,13 @@ def Scalar.max (ty : ScalarTy) : Int :=
-- type-checking time.
def Scalar.cMin (ty : ScalarTy) : Int :=
match ty with
- | .Isize => I32.min
+ | .Isize => Scalar.min .I32
| _ => Scalar.min ty
def Scalar.cMax (ty : ScalarTy) : Int :=
match ty with
- | .Isize => I32.max
- | .Usize => U32.max
+ | .Isize => Scalar.max .I32
+ | .Usize => Scalar.max .U32
| _ => Scalar.max ty
theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index 3f44f46c..613f38f8 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -10,26 +10,6 @@ open Utils
-- We can't define and use trace classes in the same file
initialize registerTraceClass `Progress
--- Return the first conjunct if the expression is a conjunction, or the
--- expression itself otherwise. Also return the second conjunct if it is a
--- conjunction.
-def getFirstConj (e : Expr) : MetaM (Expr × Option Expr) := do
- e.withApp fun f args =>
- if f.isConstOf ``And ∧ args.size = 2 then pure (args.get! 0, some (args.get! 1))
- else pure (e, none)
-
--- Destruct an equaliy and return the two sides
-def destEq (e : Expr) : MetaM (Expr × Expr) := do
- e.withApp fun f args =>
- if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2)
- else throwError "Not an equality: {e}"
-
--- Return the set of FVarIds in the expression
-partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do
- e.withApp fun body args => do
- let hs := if body.isFVar then hs.insert body.fvarId! else hs
- args.foldlM (fun hs arg => getFVarIds arg hs) hs
-
/- # Progress tactic -/
structure PSpecDesc where
@@ -103,7 +83,7 @@ section Methods
existsTelescope th fun evars th => do
trace[Progress] "Existentials: {evars}"
-- Take the first conjunct
- let (th, post) ← getFirstConj th
+ let (th, post) ← optSplitConj th
-- Destruct the equality
let (th, ret) ← destEq th
-- Destruct the application to get the name
@@ -169,7 +149,5 @@ initialize pspecAttr : PSpecAttr ← do
def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do
return (s.ext.getState (← getEnv)).find? name
- --return s.ext.find? (← getEnv) name
-
end Progress
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index 1b9ee55c..4c68b3bd 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -21,9 +21,6 @@ namespace Test
#eval pspecAttr.find? ``Primitives.Vec.index
end Test
-#check isDefEq
-#check allGoals
-
def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do
withMainContext do
-- Retrieve the goal
@@ -80,7 +77,28 @@ def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do
let th ← mkAppOptM thName (mvars.map some)
let asmName ← mkFreshUserName `h
let thTy ← inferType th
- let thAsm ← Utils.addDecl asmName th thTy (asLet := false)
+ let thAsm ← Utils.addDeclTac asmName th thTy (asLet := false)
+ withMainContext do -- The context changed - TODO: remove once addDeclTac is updated
+ let ngoal ← getMainGoal
+ trace[Progress] "current goal: {ngoal}"
+ trace[Progress] "current goal: {← ngoal.isAssigned}"
+ -- The assumption should be of the shape:
+ -- `∃ x1 ... xn, f args = ... ∧ ...`
+ -- We introduce the existentially quantified variables and split the top-most
+ -- conjunction if there is one
+ splitAllExistsTac thAsm fun h => do
+ -- Split the conjunction
+ let splitConj (k : Expr → TacticM Unit) : TacticM Unit := do
+ if ← isConj (← inferType h) then
+ splitConjTac h (fun h _ => k h)
+ else k h
+ -- Simplify the target by using the equality
+ splitConj fun h => do
+ simpAt [] [] [h.fvarId!] (.targets #[] true)
+ -- Clear the equality
+ let mgoal ← getMainGoal
+ let mgoal ← mgoal.tryClearMany #[h.fvarId!]
+ setGoals (mgoal :: (← getUnsolvedGoals))
-- Update the set of goals
let curGoals ← getUnsolvedGoals
let newGoals := mvars.map Expr.mvarId!
@@ -94,7 +112,7 @@ def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do
pure ()
elab "progress" : tactic => do
- progressLookupTheorem (firstTac [assumptionTac, Arith.intTac])
+ progressLookupTheorem (firstTac [assumptionTac, Arith.scalarTac])
namespace Test
open Primitives
@@ -103,10 +121,12 @@ namespace Test
@[pspec]
theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) :
- ∃ x, v.index α i = .ret x := by
+ ∃ (x: α), v.index α i = .ret x := by
progress
- tauto
-
+ simp
+
+ set_option trace.Progress false
+
end Test
end Progress
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 1351f3d4..14feb567 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -1,9 +1,10 @@
import Lean
import Mathlib.Tactic.Core
+import Mathlib.Tactic.LeftRight
namespace Utils
-open Lean Elab Term Meta
+open Lean Elab Term Meta Tactic
-- Useful helper to explore definitions and figure out the variant
-- of their sub-expressions.
@@ -156,9 +157,10 @@ section Methods
end Methods
-def addDecl (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : Tactic.TacticM Expr :=
+-- TODO: this should take a continuation
+def addDeclTac (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : TacticM Expr :=
-- I don't think we need that
- Lean.Elab.Tactic.withMainContext do
+ withMainContext do
-- Insert the new declaration
let withDecl := if asLet then withLetDecl name type val else withLocalDeclD name type
withDecl fun nval => do
@@ -169,7 +171,7 @@ def addDecl (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : Tactic.Tac
trace[Arith] " new decl: \"{decl.userName}\" ({nval}) : {decl.type} := {decl.value}"
--
-- Tranform the main goal `?m0` to `let x = nval in ?m1`
- let mvarId ← Tactic.getMainGoal
+ let mvarId ← getMainGoal
let newMVar ← mkFreshExprSyntheticOpaqueMVar (← mvarId.getType)
let newVal ← mkLetFVars #[nval] newMVar
-- There are two cases:
@@ -179,30 +181,30 @@ def addDecl (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : Tactic.Tac
let newVal := if asLet then newVal else mkAppN newVal #[val]
-- Assign the main goal and update the current goal
mvarId.assign newVal
- let goals ← Tactic.getUnsolvedGoals
- Lean.Elab.Tactic.setGoals (newMVar.mvarId! :: goals)
+ let goals ← getUnsolvedGoals
+ setGoals (newMVar.mvarId! :: goals)
-- Return the new value - note: we are in the *new* context, created
-- after the declaration was added, so it will persist
pure nval
-def addDeclSyntax (name : Name) (val : Syntax) (asLet : Bool) : Tactic.TacticM Unit :=
+def addDeclTacSyntax (name : Name) (val : Syntax) (asLet : Bool) : TacticM Unit :=
-- I don't think we need that
- Lean.Elab.Tactic.withMainContext do
+ withMainContext do
--
- let val ← elabTerm val .none
+ let val ← Term.elabTerm val .none
let type ← inferType val
-- In some situations, the type will be left as a metavariable (for instance,
-- if the term is `3`, Lean has the choice between `Nat` and `Int` and will
-- not choose): we force the instantiation of the meta-variable
synthesizeSyntheticMVarsUsingDefault
--
- let _ ← addDecl name val type asLet
+ let _ ← addDeclTac name val type asLet
elab "custom_let " n:ident " := " v:term : tactic => do
- addDeclSyntax n.getId v (asLet := true)
+ addDeclTacSyntax n.getId v (asLet := true)
elab "custom_have " n:ident " := " v:term : tactic =>
- addDeclSyntax n.getId v (asLet := false)
+ addDeclTacSyntax n.getId v (asLet := false)
example : Nat := by
custom_let x := 4
@@ -213,14 +215,14 @@ example (x : Bool) : Nat := by
cases x <;> custom_let x := 3 <;> apply x
-- Repeatedly apply a tactic
-partial def repeatTac (tac : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
+partial def repeatTac (tac : TacticM Unit) : TacticM Unit := do
try
tac
- Tactic.allGoals (Tactic.focus (repeatTac tac))
+ allGoals (focus (repeatTac tac))
-- TODO: does this restore the state?
catch _ => pure ()
-def firstTac (tacl : List (Tactic.TacticM Unit)) : Tactic.TacticM Unit := do
+def firstTac (tacl : List (TacticM Unit)) : TacticM Unit := do
match tacl with
| [] => pure ()
| tac :: tacl =>
@@ -228,15 +230,216 @@ def firstTac (tacl : List (Tactic.TacticM Unit)) : Tactic.TacticM Unit := do
catch _ => firstTac tacl
-- Split the goal if it is a conjunction
-def splitConjTarget : Tactic.TacticM Unit := do
- Tactic.withMainContext do
+def splitConjTarget : TacticM Unit := do
+ withMainContext do
let and_intro := Expr.const ``And.intro []
- let mvarIds' ← _root_.Lean.MVarId.apply (← Tactic.getMainGoal) and_intro
+ let mvarIds' ← _root_.Lean.MVarId.apply (← getMainGoal) and_intro
Term.synthesizeSyntheticMVarsNoPostponing
- Tactic.replaceMainGoal mvarIds'
+ replaceMainGoal mvarIds'
--- Taken from Lean.Elab.Tactic.evalAssumption
-def assumptionTac : Tactic.TacticM Unit :=
- Tactic.liftMetaTactic fun mvarId => do mvarId.assumption; pure []
+-- Taken from Lean.Elab.evalAssumption
+def assumptionTac : TacticM Unit :=
+ liftMetaTactic fun mvarId => do mvarId.assumption; pure []
+
+def isConj (e : Expr) : MetaM Bool :=
+ e.withApp fun f args => pure (f.isConstOf ``And ∧ args.size = 2)
+
+-- Return the first conjunct if the expression is a conjunction, or the
+-- expression itself otherwise. Also return the second conjunct if it is a
+-- conjunction.
+def optSplitConj (e : Expr) : MetaM (Expr × Option Expr) := do
+ e.withApp fun f args =>
+ if f.isConstOf ``And ∧ args.size = 2 then pure (args.get! 0, some (args.get! 1))
+ else pure (e, none)
+
+-- Destruct an equaliy and return the two sides
+def destEq (e : Expr) : MetaM (Expr × Expr) := do
+ e.withApp fun f args =>
+ if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2)
+ else throwError "Not an equality: {e}"
+
+-- Return the set of FVarIds in the expression
+partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do
+ e.withApp fun body args => do
+ let hs := if body.isFVar then hs.insert body.fvarId! else hs
+ args.foldlM (fun hs arg => getFVarIds arg hs) hs
+
+-- Tactic to split on a disjunction.
+-- The expression `h` should be an fvar.
+-- TODO: there must be simpler. Use use _root_.Lean.MVarId.cases for instance
+def splitDisjTac (h : Expr) (kleft kright : TacticM Unit) : TacticM Unit := do
+ trace[Arith] "assumption on which to split: {h}"
+ -- Retrieve the main goal
+ withMainContext do
+ let goalType ← getMainTarget
+ let hDecl := (← getLCtx).get! h.fvarId!
+ let hName := hDecl.userName
+ -- Case disjunction
+ let hTy ← inferType h
+ hTy.withApp fun f xs => do
+ trace[Arith] "as app: {f} {xs}"
+ -- Sanity check
+ if ¬ (f.isConstOf ``Or ∧ xs.size = 2) then throwError "Invalid argument to splitDisjTac"
+ let a := xs.get! 0
+ let b := xs.get! 1
+ -- Introduce the new goals
+ -- Returns:
+ -- - the match branch
+ -- - a fresh new mvar id
+ let mkGoal (hTy : Expr) (nGoalName : String) : MetaM (Expr × MVarId) := do
+ -- Introduce a variable for the assumption (`a` or `b`). Note that we reuse
+ -- the name of the assumption we split.
+ withLocalDeclD hName hTy fun var => do
+ -- The new goal
+ let mgoal ← mkFreshExprSyntheticOpaqueMVar goalType (tag := Name.mkSimple nGoalName)
+ -- Clear the assumption that we split
+ let mgoal ← mgoal.mvarId!.tryClearMany #[h.fvarId!]
+ -- The branch expression
+ let branch ← mkLambdaFVars #[var] (mkMVar mgoal)
+ pure (branch, mgoal)
+ let (inl, mleft) ← mkGoal a "left"
+ let (inr, mright) ← mkGoal b "right"
+ trace[Arith] "left: {inl}: {mleft}"
+ trace[Arith] "right: {inr}: {mright}"
+ -- Create the match expression
+ withLocalDeclD (← mkFreshUserName `h) hTy fun hVar => do
+ let motive ← mkLambdaFVars #[hVar] goalType
+ let casesExpr ← mkAppOptM ``Or.casesOn #[a, b, motive, h, inl, inr]
+ let mgoal ← getMainGoal
+ trace[Arith] "goals: {← getUnsolvedGoals}"
+ trace[Arith] "main goal: {mgoal}"
+ mgoal.assign casesExpr
+ let goals ← getUnsolvedGoals
+ -- Focus on the left
+ setGoals [mleft]
+ withMainContext kleft
+ let leftGoals ← getUnsolvedGoals
+ -- Focus on the right
+ setGoals [mright]
+ withMainContext kright
+ let rightGoals ← getUnsolvedGoals
+ -- Put all the goals back
+ setGoals (leftGoals ++ rightGoals ++ goals)
+ trace[Arith] "new goals: {← getUnsolvedGoals}"
+
+elab "split_disj " n:ident : tactic => do
+ withMainContext do
+ let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
+ let fvar := mkFVar decl.fvarId
+ splitDisjTac fvar (fun _ => pure ()) (fun _ => pure ())
+
+example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by
+ split_disj h0
+ . left; assumption
+ . right; assumption
+
+
+-- Tactic to split on an exists
+def splitExistsTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do
+ withMainContext do
+ let goal ← getMainGoal
+ let hTy ← inferType h
+ if isExists hTy then do
+ let newGoals ← goal.cases h.fvarId! #[]
+ -- There should be exactly one goal
+ match newGoals.toList with
+ | [ newGoal ] =>
+ -- Set the new goal
+ let goals ← getUnsolvedGoals
+ setGoals (newGoal.mvarId :: goals)
+ -- There should be exactly two fields
+ let fields := newGoal.fields
+ withMainContext do
+ k (fields.get! 0) (fields.get! 1)
+ | _ =>
+ throwError "Unreachable"
+ else
+ throwError "Not a conjunction"
+
+partial def splitAllExistsTac [Inhabited α] (h : Expr) (k : Expr → TacticM α) : TacticM α := do
+ try
+ splitExistsTac h (fun _ body => splitAllExistsTac body k)
+ catch _ => k h
+
+-- Tactic to split on a conjunction.
+def splitConjTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do
+ withMainContext do
+ let goal ← getMainGoal
+ let hTy ← inferType h
+ if ← isConj hTy then do
+ let newGoals ← goal.cases h.fvarId! #[]
+ -- There should be exactly one goal
+ match newGoals.toList with
+ | [ newGoal ] =>
+ -- Set the new goal
+ let goals ← getUnsolvedGoals
+ setGoals (newGoal.mvarId :: goals)
+ -- There should be exactly two fields
+ let fields := newGoal.fields
+ withMainContext do
+ k (fields.get! 0) (fields.get! 1)
+ | _ =>
+ throwError "Unreachable"
+ else
+ throwError "Not a conjunction"
+
+elab "split_conj " n:ident : tactic => do
+ withMainContext do
+ let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
+ let fvar := mkFVar decl.fvarId
+ splitConjTac fvar (fun _ _ => pure ())
+
+elab "split_all_exists " n:ident : tactic => do
+ withMainContext do
+ let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
+ let fvar := mkFVar decl.fvarId
+ splitAllExistsTac fvar (fun _ => pure ())
+
+example (h : a ∧ b) : a := by
+ split_all_exists h
+ split_conj h
+ assumption
+
+example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by
+ split_all_exists h
+ rename_i x y z h
+ exists x + y + z
+
+/- Call the simp tactic.
+ The initialization of the context is adapted from Tactic.elabSimpArgs.
+ Something very annoying is that there is no function which allows to
+ initialize a simp context without doing an elaboration - as a consequence
+ we write our own here. -/
+def simpAt (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId)
+ (loc : Tactic.Location) :
+ Tactic.TacticM Unit := do
+ -- Initialize with the builtin simp theorems
+ let simpThms ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems)
+ -- Add the equational theorem for the declarations to unfold
+ let simpThms ←
+ declsToUnfold.foldlM (fun thms decl => thms.addDeclToUnfold decl) simpThms
+ -- Add the hypotheses and the rewriting theorems
+ let simpThms ←
+ hypsToUse.foldlM (fun thms fvarId =>
+ -- post: TODO: don't know what that is
+ -- inv: invert the equality
+ thms.add (.fvar fvarId) #[] (mkFVar fvarId) (post := false) (inv := false)
+ -- thms.eraseCore (.fvar fvar)
+ ) simpThms
+ -- Add the rewriting theorems to use
+ let simpThms ←
+ thms.foldlM (fun thms thmName => do
+ let info ← getConstInfo thmName
+ if (← isProp info.type) then
+ -- post: TODO: don't know what that is
+ -- inv: invert the equality
+ thms.addConst thmName (post := false) (inv := false)
+ else
+ throwError "Not a proposition: {thmName}"
+ ) simpThms
+ let congrTheorems ← getSimpCongrTheorems
+ let ctx : Simp.Context := { simpTheorems := #[simpThms], congrTheorems }
+ -- Apply the simplifier
+ let _ ← Tactic.simpLocation ctx (discharge? := .none) loc
end Utils