diff options
author | Son Ho | 2023-07-12 14:34:55 +0200 |
---|---|---|
committer | Son Ho | 2023-07-12 14:34:55 +0200 |
commit | a18d899a2c2b9bdd36f4a5a4b70472c12a835a96 (patch) | |
tree | b8bc72479804fcd416fea42478d7e7945845cbe7 /backends/lean | |
parent | 6166c410a4b3353377e640acbae9f56e877a9118 (diff) |
Finish a first version of the progress tactic
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Arith/Arith.lean | 122 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 2 | ||||
-rw-r--r-- | backends/lean/Base/Primitives.lean | 89 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Base.lean | 24 | ||||
-rw-r--r-- | backends/lean/Base/Progress/Progress.lean | 36 | ||||
-rw-r--r-- | backends/lean/Base/Utils.lean | 247 |
6 files changed, 349 insertions, 171 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean index ff628cf3..3557d350 100644 --- a/backends/lean/Base/Arith/Arith.lean +++ b/backends/lean/Base/Arith/Arith.lean @@ -230,25 +230,20 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) let type ← inferType e let name ← mkFreshUserName `h -- Add a declaration - let nval ← Utils.addDecl name e type (asLet := false) + let nval ← Utils.addDeclTac name e type (asLet := false) -- Simplify to unfold the declaration to unfold (i.e., the projector) - let simpTheorems ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems) - -- Add the equational theorem for the decl to unfold - let simpTheorems ← simpTheorems.addDeclToUnfold declToUnfold - let congrTheorems ← getSimpCongrTheorems - let ctx : Simp.Context := { simpTheorems := #[simpTheorems], congrTheorems } - -- Where to apply the simplifier - let loc := Tactic.Location.targets #[mkIdent name] false - -- Apply the simplifier - let _ ← Tactic.simpLocation ctx (discharge? := .none) loc + Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false) -- Return the new value pure nval +def introHasPropInstances : Tactic.TacticM (Array Expr) := do + trace[Arith] "Introducing the HasProp instances" + introInstances ``HasProp.prop_ty lookupHasProp + -- Lookup the instances of `HasProp for all the sub-expressions in the context, -- and introduce the corresponding assumptions elab "intro_has_prop_instances" : tactic => do - trace[Arith] "Introducing the HasProp instances" - let _ ← introInstances ``HasProp.prop_ty lookupHasProp + let _ ← introHasPropInstances example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by intro_has_prop_instances @@ -258,74 +253,6 @@ example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by intro_has_prop_instances simp_all [Scalar.max, Scalar.min] --- Tactic to split on a disjunction. --- The expression `h` should be an fvar. -def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM Unit := do - trace[Arith] "assumption on which to split: {h}" - -- Retrieve the main goal - Tactic.withMainContext do - let goalType ← Tactic.getMainTarget - let hDecl := (← getLCtx).get! h.fvarId! - let hName := hDecl.userName - -- Case disjunction - let hTy ← inferType h - hTy.withApp fun f xs => do - trace[Arith] "as app: {f} {xs}" - -- Sanity check - if ¬ (f.isConstOf ``Or ∧ xs.size = 2) then throwError "Invalid argument to splitDisj" - let a := xs.get! 0 - let b := xs.get! 1 - -- Introduce the new goals - -- Returns: - -- - the match branch - -- - a fresh new mvar id - let mkGoal (hTy : Expr) (nGoalName : String) : MetaM (Expr × MVarId) := do - -- Introduce a variable for the assumption (`a` or `b`). Note that we reuse - -- the name of the assumption we split. - withLocalDeclD hName hTy fun var => do - -- The new goal - let mgoal ← mkFreshExprSyntheticOpaqueMVar goalType (tag := Name.mkSimple nGoalName) - -- Clear the assumption that we split - let mgoal ← mgoal.mvarId!.tryClearMany #[h.fvarId!] - -- The branch expression - let branch ← mkLambdaFVars #[var] (mkMVar mgoal) - pure (branch, mgoal) - let (inl, mleft) ← mkGoal a "left" - let (inr, mright) ← mkGoal b "right" - trace[Arith] "left: {inl}: {mleft}" - trace[Arith] "right: {inr}: {mright}" - -- Create the match expression - withLocalDeclD (← mkFreshUserName `h) hTy fun hVar => do - let motive ← mkLambdaFVars #[hVar] goalType - let casesExpr ← mkAppOptM ``Or.casesOn #[a, b, motive, h, inl, inr] - let mgoal ← Tactic.getMainGoal - trace[Arith] "goals: {← Tactic.getUnsolvedGoals}" - trace[Arith] "main goal: {mgoal}" - mgoal.assign casesExpr - let goals ← Tactic.getUnsolvedGoals - -- Focus on the left - Tactic.setGoals [mleft] - kleft - let leftGoals ← Tactic.getUnsolvedGoals - -- Focus on the right - Tactic.setGoals [mright] - kright - let rightGoals ← Tactic.getUnsolvedGoals - -- Put all the goals back - Tactic.setGoals (leftGoals ++ rightGoals ++ goals) - trace[Arith] "new goals: {← Tactic.getUnsolvedGoals}" - -elab "split_disj " n:ident : tactic => do - Tactic.withMainContext do - let decl ← Lean.Meta.getLocalDeclFromUserName n.getId - let fvar := mkFVar decl.fvarId - splitDisj fvar (fun _ => pure ()) (fun _ => pure ()) - -example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by - split_disj h0 - . left; assumption - . right; assumption - -- Lookup the instances of `PropHasImp for all the sub-expressions in the context, -- and introduce the corresponding assumptions elab "intro_prop_has_imp_instances" : tactic => do @@ -357,7 +284,7 @@ def intTacPreprocess : Tactic.TacticM Unit := do | [] => pure () | asm :: asms => let k := splitOnAsms asms - splitDisj asm k k + Utils.splitDisjTac asm k k -- Introduce let asms ← introInstances ``PropHasImp.concl lookupPropHasImp -- Split @@ -403,18 +330,27 @@ example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : int_tac -- A tactic to solve linear arithmetic goals in the presence of scalars -syntax "scalar_tac" : tactic -macro_rules - | `(tactic| scalar_tac) => - `(tactic| - intro_has_prop_instances; - have := Scalar.cMin_bound ScalarTy.Usize; - have := Scalar.cMin_bound ScalarTy.Isize; - have := Scalar.cMax_bound ScalarTy.Usize; - have := Scalar.cMax_bound ScalarTy.Isize; - -- TODO: not too sure about that - simp only [*, Scalar.max, Scalar.min, Scalar.cMin, Scalar.cMax] at *; - int_tac) +def scalarTac : Tactic.TacticM Unit := do + Tactic.withMainContext do + -- Introduce the scalar bounds + let _ ← introHasPropInstances + Tactic.allGoals do + -- Inroduce the bounds for the isize/usize types + let add (e : Expr) : Tactic.TacticM Unit := do + let ty ← inferType e + let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) + -- Reveal the concrete bounds - TODO: not too sure about that. + -- Maybe we should reveal the "concrete" bounds (after normalization) + Utils.simpAt [``Scalar.max, ``Scalar.min, ``Scalar.cMin, ``Scalar.cMax] [] [] .wildcard + -- Apply the integer tactic + intTac + +elab "scalar_tac" : tactic => + scalarTac example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by scalar_tac diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index e22eb914..d2c91ff8 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -14,7 +14,7 @@ TODO: Actually, the cases from mathlib seems already quite powerful (https://leanprover-community.github.io/mathlib_docs/tactics.html#cases) For instance: cases h : e - Also: cases_matching + Also: **casesm** - better split tactic - we need conversions to operate on the head of applications. Actually, something like this works: diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean index 14f5971e..6210688d 100644 --- a/backends/lean/Base/Primitives.lean +++ b/backends/lean/Base/Primitives.lean @@ -175,27 +175,28 @@ open System.Platform.getNumBits @[simp] def U128.min : Int := 0 @[simp] def U128.max : Int := HPow.hPow 2 128 - 1 -#assert (I8.min == -128) -#assert (I8.max == 127) -#assert (I16.min == -32768) -#assert (I16.max == 32767) -#assert (I32.min == -2147483648) -#assert (I32.max == 2147483647) -#assert (I64.min == -9223372036854775808) -#assert (I64.max == 9223372036854775807) -#assert (I128.min == -170141183460469231731687303715884105728) -#assert (I128.max == 170141183460469231731687303715884105727) -#assert (U8.min == 0) -#assert (U8.max == 255) -#assert (U16.min == 0) -#assert (U16.max == 65535) -#assert (U32.min == 0) -#assert (U32.max == 4294967295) -#assert (U64.min == 0) -#assert (U64.max == 18446744073709551615) -#assert (U128.min == 0) -#assert (U128.max == 340282366920938463463374607431768211455) - +-- The normalized bounds +@[simp] def I8.norm_min := -128 +@[simp] def I8.norm_max := 127 +@[simp] def I16.norm_min := -32768 +@[simp] def I16.norm_max := 32767 +@[simp] def I32.norm_min := -2147483648 +@[simp] def I32.norm_max := 2147483647 +@[simp] def I64.norm_min := -9223372036854775808 +@[simp] def I64.norm_max := 9223372036854775807 +@[simp] def I128.norm_min := -170141183460469231731687303715884105728 +@[simp] def I128.norm_max := 170141183460469231731687303715884105727 +@[simp] def U8.norm_min := 0 +@[simp] def U8.norm_max := 255 +@[simp] def U16.norm_min := 0 +@[simp] def U16.norm_max := 65535 +@[simp] def U32.norm_min := 0 +@[simp] def U32.norm_max := 4294967295 +@[simp] def U64.norm_min := 0 +@[simp] def U64.norm_max := 18446744073709551615 +@[simp] def U128.norm_min := 0 +@[simp] def U128.norm_max := 340282366920938463463374607431768211455 + inductive ScalarTy := | Isize | I8 @@ -240,6 +241,46 @@ def Scalar.max (ty : ScalarTy) : Int := | .U64 => U64.max | .U128 => U128.max +@[simp] def Scalar.norm_min (ty : ScalarTy) : Int := + match ty with + -- We can't normalize the bounds for isize/usize + | .Isize => Isize.min + | .Usize => Usize.min + -- + | .I8 => I8.norm_min + | .I16 => I16.norm_min + | .I32 => I32.norm_min + | .I64 => I64.norm_min + | .I128 => I128.norm_min + | .U8 => U8.norm_min + | .U16 => U16.norm_min + | .U32 => U32.norm_min + | .U64 => U64.norm_min + | .U128 => U128.norm_min + +@[simp] def Scalar.norm_max (ty : ScalarTy) : Int := + match ty with + -- We can't normalize the bounds for isize/usize + | .Isize => Isize.max + | .Usize => Usize.max + -- + | .I8 => I8.norm_max + | .I16 => I16.norm_max + | .I32 => I32.norm_max + | .I64 => I64.norm_max + | .I128 => I128.norm_max + | .U8 => U8.norm_max + | .U16 => U16.norm_max + | .U32 => U32.norm_max + | .U64 => U64.norm_max + | .U128 => U128.norm_max + +def Scalar.norm_min_eq (ty : ScalarTy) : Scalar.min ty = Scalar.norm_min ty := by + cases ty <;> rfl + +def Scalar.norm_max_eq (ty : ScalarTy) : Scalar.max ty = Scalar.norm_max ty := by + cases ty <;> rfl + -- "Conservative" bounds -- We use those because we can't compare to the isize bounds (which can't -- reduce at compile-time). Whenever we perform an arithmetic operation like @@ -249,13 +290,13 @@ def Scalar.max (ty : ScalarTy) : Int := -- type-checking time. def Scalar.cMin (ty : ScalarTy) : Int := match ty with - | .Isize => I32.min + | .Isize => Scalar.min .I32 | _ => Scalar.min ty def Scalar.cMax (ty : ScalarTy) : Int := match ty with - | .Isize => I32.max - | .Usize => U32.max + | .Isize => Scalar.max .I32 + | .Usize => Scalar.max .U32 | _ => Scalar.max ty theorem Scalar.cMin_bound ty : Scalar.min ty ≤ Scalar.cMin ty := by diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 3f44f46c..613f38f8 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -10,26 +10,6 @@ open Utils -- We can't define and use trace classes in the same file initialize registerTraceClass `Progress --- Return the first conjunct if the expression is a conjunction, or the --- expression itself otherwise. Also return the second conjunct if it is a --- conjunction. -def getFirstConj (e : Expr) : MetaM (Expr × Option Expr) := do - e.withApp fun f args => - if f.isConstOf ``And ∧ args.size = 2 then pure (args.get! 0, some (args.get! 1)) - else pure (e, none) - --- Destruct an equaliy and return the two sides -def destEq (e : Expr) : MetaM (Expr × Expr) := do - e.withApp fun f args => - if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2) - else throwError "Not an equality: {e}" - --- Return the set of FVarIds in the expression -partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do - e.withApp fun body args => do - let hs := if body.isFVar then hs.insert body.fvarId! else hs - args.foldlM (fun hs arg => getFVarIds arg hs) hs - /- # Progress tactic -/ structure PSpecDesc where @@ -103,7 +83,7 @@ section Methods existsTelescope th fun evars th => do trace[Progress] "Existentials: {evars}" -- Take the first conjunct - let (th, post) ← getFirstConj th + let (th, post) ← optSplitConj th -- Destruct the equality let (th, ret) ← destEq th -- Destruct the application to get the name @@ -169,7 +149,5 @@ initialize pspecAttr : PSpecAttr ← do def PSpecAttr.find? (s : PSpecAttr) (name : Name) : MetaM (Option Name) := do return (s.ext.getState (← getEnv)).find? name - --return s.ext.find? (← getEnv) name - end Progress diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index 1b9ee55c..4c68b3bd 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -21,9 +21,6 @@ namespace Test #eval pspecAttr.find? ``Primitives.Vec.index end Test -#check isDefEq -#check allGoals - def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do withMainContext do -- Retrieve the goal @@ -80,7 +77,28 @@ def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do let th ← mkAppOptM thName (mvars.map some) let asmName ← mkFreshUserName `h let thTy ← inferType th - let thAsm ← Utils.addDecl asmName th thTy (asLet := false) + let thAsm ← Utils.addDeclTac asmName th thTy (asLet := false) + withMainContext do -- The context changed - TODO: remove once addDeclTac is updated + let ngoal ← getMainGoal + trace[Progress] "current goal: {ngoal}" + trace[Progress] "current goal: {← ngoal.isAssigned}" + -- The assumption should be of the shape: + -- `∃ x1 ... xn, f args = ... ∧ ...` + -- We introduce the existentially quantified variables and split the top-most + -- conjunction if there is one + splitAllExistsTac thAsm fun h => do + -- Split the conjunction + let splitConj (k : Expr → TacticM Unit) : TacticM Unit := do + if ← isConj (← inferType h) then + splitConjTac h (fun h _ => k h) + else k h + -- Simplify the target by using the equality + splitConj fun h => do + simpAt [] [] [h.fvarId!] (.targets #[] true) + -- Clear the equality + let mgoal ← getMainGoal + let mgoal ← mgoal.tryClearMany #[h.fvarId!] + setGoals (mgoal :: (← getUnsolvedGoals)) -- Update the set of goals let curGoals ← getUnsolvedGoals let newGoals := mvars.map Expr.mvarId! @@ -94,7 +112,7 @@ def progressLookupTheorem (asmTac : TacticM Unit) : TacticM Unit := do pure () elab "progress" : tactic => do - progressLookupTheorem (firstTac [assumptionTac, Arith.intTac]) + progressLookupTheorem (firstTac [assumptionTac, Arith.scalarTac]) namespace Test open Primitives @@ -103,10 +121,12 @@ namespace Test @[pspec] theorem vec_index_test2 (α : Type u) (v: Vec α) (i: Usize) (h: i.val < v.val.length) : - ∃ x, v.index α i = .ret x := by + ∃ (x: α), v.index α i = .ret x := by progress - tauto - + simp + + set_option trace.Progress false + end Test end Progress diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 1351f3d4..14feb567 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -1,9 +1,10 @@ import Lean import Mathlib.Tactic.Core +import Mathlib.Tactic.LeftRight namespace Utils -open Lean Elab Term Meta +open Lean Elab Term Meta Tactic -- Useful helper to explore definitions and figure out the variant -- of their sub-expressions. @@ -156,9 +157,10 @@ section Methods end Methods -def addDecl (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : Tactic.TacticM Expr := +-- TODO: this should take a continuation +def addDeclTac (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : TacticM Expr := -- I don't think we need that - Lean.Elab.Tactic.withMainContext do + withMainContext do -- Insert the new declaration let withDecl := if asLet then withLetDecl name type val else withLocalDeclD name type withDecl fun nval => do @@ -169,7 +171,7 @@ def addDecl (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : Tactic.Tac trace[Arith] " new decl: \"{decl.userName}\" ({nval}) : {decl.type} := {decl.value}" -- -- Tranform the main goal `?m0` to `let x = nval in ?m1` - let mvarId ← Tactic.getMainGoal + let mvarId ← getMainGoal let newMVar ← mkFreshExprSyntheticOpaqueMVar (← mvarId.getType) let newVal ← mkLetFVars #[nval] newMVar -- There are two cases: @@ -179,30 +181,30 @@ def addDecl (name : Name) (val : Expr) (type : Expr) (asLet : Bool) : Tactic.Tac let newVal := if asLet then newVal else mkAppN newVal #[val] -- Assign the main goal and update the current goal mvarId.assign newVal - let goals ← Tactic.getUnsolvedGoals - Lean.Elab.Tactic.setGoals (newMVar.mvarId! :: goals) + let goals ← getUnsolvedGoals + setGoals (newMVar.mvarId! :: goals) -- Return the new value - note: we are in the *new* context, created -- after the declaration was added, so it will persist pure nval -def addDeclSyntax (name : Name) (val : Syntax) (asLet : Bool) : Tactic.TacticM Unit := +def addDeclTacSyntax (name : Name) (val : Syntax) (asLet : Bool) : TacticM Unit := -- I don't think we need that - Lean.Elab.Tactic.withMainContext do + withMainContext do -- - let val ← elabTerm val .none + let val ← Term.elabTerm val .none let type ← inferType val -- In some situations, the type will be left as a metavariable (for instance, -- if the term is `3`, Lean has the choice between `Nat` and `Int` and will -- not choose): we force the instantiation of the meta-variable synthesizeSyntheticMVarsUsingDefault -- - let _ ← addDecl name val type asLet + let _ ← addDeclTac name val type asLet elab "custom_let " n:ident " := " v:term : tactic => do - addDeclSyntax n.getId v (asLet := true) + addDeclTacSyntax n.getId v (asLet := true) elab "custom_have " n:ident " := " v:term : tactic => - addDeclSyntax n.getId v (asLet := false) + addDeclTacSyntax n.getId v (asLet := false) example : Nat := by custom_let x := 4 @@ -213,14 +215,14 @@ example (x : Bool) : Nat := by cases x <;> custom_let x := 3 <;> apply x -- Repeatedly apply a tactic -partial def repeatTac (tac : Tactic.TacticM Unit) : Tactic.TacticM Unit := do +partial def repeatTac (tac : TacticM Unit) : TacticM Unit := do try tac - Tactic.allGoals (Tactic.focus (repeatTac tac)) + allGoals (focus (repeatTac tac)) -- TODO: does this restore the state? catch _ => pure () -def firstTac (tacl : List (Tactic.TacticM Unit)) : Tactic.TacticM Unit := do +def firstTac (tacl : List (TacticM Unit)) : TacticM Unit := do match tacl with | [] => pure () | tac :: tacl => @@ -228,15 +230,216 @@ def firstTac (tacl : List (Tactic.TacticM Unit)) : Tactic.TacticM Unit := do catch _ => firstTac tacl -- Split the goal if it is a conjunction -def splitConjTarget : Tactic.TacticM Unit := do - Tactic.withMainContext do +def splitConjTarget : TacticM Unit := do + withMainContext do let and_intro := Expr.const ``And.intro [] - let mvarIds' ← _root_.Lean.MVarId.apply (← Tactic.getMainGoal) and_intro + let mvarIds' ← _root_.Lean.MVarId.apply (← getMainGoal) and_intro Term.synthesizeSyntheticMVarsNoPostponing - Tactic.replaceMainGoal mvarIds' + replaceMainGoal mvarIds' --- Taken from Lean.Elab.Tactic.evalAssumption -def assumptionTac : Tactic.TacticM Unit := - Tactic.liftMetaTactic fun mvarId => do mvarId.assumption; pure [] +-- Taken from Lean.Elab.evalAssumption +def assumptionTac : TacticM Unit := + liftMetaTactic fun mvarId => do mvarId.assumption; pure [] + +def isConj (e : Expr) : MetaM Bool := + e.withApp fun f args => pure (f.isConstOf ``And ∧ args.size = 2) + +-- Return the first conjunct if the expression is a conjunction, or the +-- expression itself otherwise. Also return the second conjunct if it is a +-- conjunction. +def optSplitConj (e : Expr) : MetaM (Expr × Option Expr) := do + e.withApp fun f args => + if f.isConstOf ``And ∧ args.size = 2 then pure (args.get! 0, some (args.get! 1)) + else pure (e, none) + +-- Destruct an equaliy and return the two sides +def destEq (e : Expr) : MetaM (Expr × Expr) := do + e.withApp fun f args => + if f.isConstOf ``Eq ∧ args.size = 3 then pure (args.get! 1, args.get! 2) + else throwError "Not an equality: {e}" + +-- Return the set of FVarIds in the expression +partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do + e.withApp fun body args => do + let hs := if body.isFVar then hs.insert body.fvarId! else hs + args.foldlM (fun hs arg => getFVarIds arg hs) hs + +-- Tactic to split on a disjunction. +-- The expression `h` should be an fvar. +-- TODO: there must be simpler. Use use _root_.Lean.MVarId.cases for instance +def splitDisjTac (h : Expr) (kleft kright : TacticM Unit) : TacticM Unit := do + trace[Arith] "assumption on which to split: {h}" + -- Retrieve the main goal + withMainContext do + let goalType ← getMainTarget + let hDecl := (← getLCtx).get! h.fvarId! + let hName := hDecl.userName + -- Case disjunction + let hTy ← inferType h + hTy.withApp fun f xs => do + trace[Arith] "as app: {f} {xs}" + -- Sanity check + if ¬ (f.isConstOf ``Or ∧ xs.size = 2) then throwError "Invalid argument to splitDisjTac" + let a := xs.get! 0 + let b := xs.get! 1 + -- Introduce the new goals + -- Returns: + -- - the match branch + -- - a fresh new mvar id + let mkGoal (hTy : Expr) (nGoalName : String) : MetaM (Expr × MVarId) := do + -- Introduce a variable for the assumption (`a` or `b`). Note that we reuse + -- the name of the assumption we split. + withLocalDeclD hName hTy fun var => do + -- The new goal + let mgoal ← mkFreshExprSyntheticOpaqueMVar goalType (tag := Name.mkSimple nGoalName) + -- Clear the assumption that we split + let mgoal ← mgoal.mvarId!.tryClearMany #[h.fvarId!] + -- The branch expression + let branch ← mkLambdaFVars #[var] (mkMVar mgoal) + pure (branch, mgoal) + let (inl, mleft) ← mkGoal a "left" + let (inr, mright) ← mkGoal b "right" + trace[Arith] "left: {inl}: {mleft}" + trace[Arith] "right: {inr}: {mright}" + -- Create the match expression + withLocalDeclD (← mkFreshUserName `h) hTy fun hVar => do + let motive ← mkLambdaFVars #[hVar] goalType + let casesExpr ← mkAppOptM ``Or.casesOn #[a, b, motive, h, inl, inr] + let mgoal ← getMainGoal + trace[Arith] "goals: {← getUnsolvedGoals}" + trace[Arith] "main goal: {mgoal}" + mgoal.assign casesExpr + let goals ← getUnsolvedGoals + -- Focus on the left + setGoals [mleft] + withMainContext kleft + let leftGoals ← getUnsolvedGoals + -- Focus on the right + setGoals [mright] + withMainContext kright + let rightGoals ← getUnsolvedGoals + -- Put all the goals back + setGoals (leftGoals ++ rightGoals ++ goals) + trace[Arith] "new goals: {← getUnsolvedGoals}" + +elab "split_disj " n:ident : tactic => do + withMainContext do + let decl ← Lean.Meta.getLocalDeclFromUserName n.getId + let fvar := mkFVar decl.fvarId + splitDisjTac fvar (fun _ => pure ()) (fun _ => pure ()) + +example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by + split_disj h0 + . left; assumption + . right; assumption + + +-- Tactic to split on an exists +def splitExistsTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do + withMainContext do + let goal ← getMainGoal + let hTy ← inferType h + if isExists hTy then do + let newGoals ← goal.cases h.fvarId! #[] + -- There should be exactly one goal + match newGoals.toList with + | [ newGoal ] => + -- Set the new goal + let goals ← getUnsolvedGoals + setGoals (newGoal.mvarId :: goals) + -- There should be exactly two fields + let fields := newGoal.fields + withMainContext do + k (fields.get! 0) (fields.get! 1) + | _ => + throwError "Unreachable" + else + throwError "Not a conjunction" + +partial def splitAllExistsTac [Inhabited α] (h : Expr) (k : Expr → TacticM α) : TacticM α := do + try + splitExistsTac h (fun _ body => splitAllExistsTac body k) + catch _ => k h + +-- Tactic to split on a conjunction. +def splitConjTac (h : Expr) (k : Expr → Expr → TacticM α) : TacticM α := do + withMainContext do + let goal ← getMainGoal + let hTy ← inferType h + if ← isConj hTy then do + let newGoals ← goal.cases h.fvarId! #[] + -- There should be exactly one goal + match newGoals.toList with + | [ newGoal ] => + -- Set the new goal + let goals ← getUnsolvedGoals + setGoals (newGoal.mvarId :: goals) + -- There should be exactly two fields + let fields := newGoal.fields + withMainContext do + k (fields.get! 0) (fields.get! 1) + | _ => + throwError "Unreachable" + else + throwError "Not a conjunction" + +elab "split_conj " n:ident : tactic => do + withMainContext do + let decl ← Lean.Meta.getLocalDeclFromUserName n.getId + let fvar := mkFVar decl.fvarId + splitConjTac fvar (fun _ _ => pure ()) + +elab "split_all_exists " n:ident : tactic => do + withMainContext do + let decl ← Lean.Meta.getLocalDeclFromUserName n.getId + let fvar := mkFVar decl.fvarId + splitAllExistsTac fvar (fun _ => pure ()) + +example (h : a ∧ b) : a := by + split_all_exists h + split_conj h + assumption + +example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by + split_all_exists h + rename_i x y z h + exists x + y + z + +/- Call the simp tactic. + The initialization of the context is adapted from Tactic.elabSimpArgs. + Something very annoying is that there is no function which allows to + initialize a simp context without doing an elaboration - as a consequence + we write our own here. -/ +def simpAt (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) + (loc : Tactic.Location) : + Tactic.TacticM Unit := do + -- Initialize with the builtin simp theorems + let simpThms ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems) + -- Add the equational theorem for the declarations to unfold + let simpThms ← + declsToUnfold.foldlM (fun thms decl => thms.addDeclToUnfold decl) simpThms + -- Add the hypotheses and the rewriting theorems + let simpThms ← + hypsToUse.foldlM (fun thms fvarId => + -- post: TODO: don't know what that is + -- inv: invert the equality + thms.add (.fvar fvarId) #[] (mkFVar fvarId) (post := false) (inv := false) + -- thms.eraseCore (.fvar fvar) + ) simpThms + -- Add the rewriting theorems to use + let simpThms ← + thms.foldlM (fun thms thmName => do + let info ← getConstInfo thmName + if (← isProp info.type) then + -- post: TODO: don't know what that is + -- inv: invert the equality + thms.addConst thmName (post := false) (inv := false) + else + throwError "Not a proposition: {thmName}" + ) simpThms + let congrTheorems ← getSimpCongrTheorems + let ctx : Simp.Context := { simpTheorems := #[simpThms], congrTheorems } + -- Apply the simplifier + let _ ← Tactic.simpLocation ctx (discharge? := .none) loc end Utils |