summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Arith
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base/Arith')
-rw-r--r--backends/lean/Base/Arith/Arith.lean122
1 files changed, 29 insertions, 93 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean
index ff628cf3..3557d350 100644
--- a/backends/lean/Base/Arith/Arith.lean
+++ b/backends/lean/Base/Arith/Arith.lean
@@ -230,25 +230,20 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr))
let type ← inferType e
let name ← mkFreshUserName `h
-- Add a declaration
- let nval ← Utils.addDecl name e type (asLet := false)
+ let nval ← Utils.addDeclTac name e type (asLet := false)
-- Simplify to unfold the declaration to unfold (i.e., the projector)
- let simpTheorems ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems)
- -- Add the equational theorem for the decl to unfold
- let simpTheorems ← simpTheorems.addDeclToUnfold declToUnfold
- let congrTheorems ← getSimpCongrTheorems
- let ctx : Simp.Context := { simpTheorems := #[simpTheorems], congrTheorems }
- -- Where to apply the simplifier
- let loc := Tactic.Location.targets #[mkIdent name] false
- -- Apply the simplifier
- let _ ← Tactic.simpLocation ctx (discharge? := .none) loc
+ Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false)
-- Return the new value
pure nval
+def introHasPropInstances : Tactic.TacticM (Array Expr) := do
+ trace[Arith] "Introducing the HasProp instances"
+ introInstances ``HasProp.prop_ty lookupHasProp
+
-- Lookup the instances of `HasProp for all the sub-expressions in the context,
-- and introduce the corresponding assumptions
elab "intro_has_prop_instances" : tactic => do
- trace[Arith] "Introducing the HasProp instances"
- let _ ← introInstances ``HasProp.prop_ty lookupHasProp
+ let _ ← introHasPropInstances
example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
intro_has_prop_instances
@@ -258,74 +253,6 @@ example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by
intro_has_prop_instances
simp_all [Scalar.max, Scalar.min]
--- Tactic to split on a disjunction.
--- The expression `h` should be an fvar.
-def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM Unit := do
- trace[Arith] "assumption on which to split: {h}"
- -- Retrieve the main goal
- Tactic.withMainContext do
- let goalType ← Tactic.getMainTarget
- let hDecl := (← getLCtx).get! h.fvarId!
- let hName := hDecl.userName
- -- Case disjunction
- let hTy ← inferType h
- hTy.withApp fun f xs => do
- trace[Arith] "as app: {f} {xs}"
- -- Sanity check
- if ¬ (f.isConstOf ``Or ∧ xs.size = 2) then throwError "Invalid argument to splitDisj"
- let a := xs.get! 0
- let b := xs.get! 1
- -- Introduce the new goals
- -- Returns:
- -- - the match branch
- -- - a fresh new mvar id
- let mkGoal (hTy : Expr) (nGoalName : String) : MetaM (Expr × MVarId) := do
- -- Introduce a variable for the assumption (`a` or `b`). Note that we reuse
- -- the name of the assumption we split.
- withLocalDeclD hName hTy fun var => do
- -- The new goal
- let mgoal ← mkFreshExprSyntheticOpaqueMVar goalType (tag := Name.mkSimple nGoalName)
- -- Clear the assumption that we split
- let mgoal ← mgoal.mvarId!.tryClearMany #[h.fvarId!]
- -- The branch expression
- let branch ← mkLambdaFVars #[var] (mkMVar mgoal)
- pure (branch, mgoal)
- let (inl, mleft) ← mkGoal a "left"
- let (inr, mright) ← mkGoal b "right"
- trace[Arith] "left: {inl}: {mleft}"
- trace[Arith] "right: {inr}: {mright}"
- -- Create the match expression
- withLocalDeclD (← mkFreshUserName `h) hTy fun hVar => do
- let motive ← mkLambdaFVars #[hVar] goalType
- let casesExpr ← mkAppOptM ``Or.casesOn #[a, b, motive, h, inl, inr]
- let mgoal ← Tactic.getMainGoal
- trace[Arith] "goals: {← Tactic.getUnsolvedGoals}"
- trace[Arith] "main goal: {mgoal}"
- mgoal.assign casesExpr
- let goals ← Tactic.getUnsolvedGoals
- -- Focus on the left
- Tactic.setGoals [mleft]
- kleft
- let leftGoals ← Tactic.getUnsolvedGoals
- -- Focus on the right
- Tactic.setGoals [mright]
- kright
- let rightGoals ← Tactic.getUnsolvedGoals
- -- Put all the goals back
- Tactic.setGoals (leftGoals ++ rightGoals ++ goals)
- trace[Arith] "new goals: {← Tactic.getUnsolvedGoals}"
-
-elab "split_disj " n:ident : tactic => do
- Tactic.withMainContext do
- let decl ← Lean.Meta.getLocalDeclFromUserName n.getId
- let fvar := mkFVar decl.fvarId
- splitDisj fvar (fun _ => pure ()) (fun _ => pure ())
-
-example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by
- split_disj h0
- . left; assumption
- . right; assumption
-
-- Lookup the instances of `PropHasImp for all the sub-expressions in the context,
-- and introduce the corresponding assumptions
elab "intro_prop_has_imp_instances" : tactic => do
@@ -357,7 +284,7 @@ def intTacPreprocess : Tactic.TacticM Unit := do
| [] => pure ()
| asm :: asms =>
let k := splitOnAsms asms
- splitDisj asm k k
+ Utils.splitDisjTac asm k k
-- Introduce
let asms ← introInstances ``PropHasImp.concl lookupPropHasImp
-- Split
@@ -403,18 +330,27 @@ example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) :
int_tac
-- A tactic to solve linear arithmetic goals in the presence of scalars
-syntax "scalar_tac" : tactic
-macro_rules
- | `(tactic| scalar_tac) =>
- `(tactic|
- intro_has_prop_instances;
- have := Scalar.cMin_bound ScalarTy.Usize;
- have := Scalar.cMin_bound ScalarTy.Isize;
- have := Scalar.cMax_bound ScalarTy.Usize;
- have := Scalar.cMax_bound ScalarTy.Isize;
- -- TODO: not too sure about that
- simp only [*, Scalar.max, Scalar.min, Scalar.cMin, Scalar.cMax] at *;
- int_tac)
+def scalarTac : Tactic.TacticM Unit := do
+ Tactic.withMainContext do
+ -- Introduce the scalar bounds
+ let _ ← introHasPropInstances
+ Tactic.allGoals do
+ -- Inroduce the bounds for the isize/usize types
+ let add (e : Expr) : Tactic.TacticM Unit := do
+ let ty ← inferType e
+ let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false)
+ add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Usize []])
+ add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
+ -- Reveal the concrete bounds - TODO: not too sure about that.
+ -- Maybe we should reveal the "concrete" bounds (after normalization)
+ Utils.simpAt [``Scalar.max, ``Scalar.min, ``Scalar.cMin, ``Scalar.cMax] [] [] .wildcard
+ -- Apply the integer tactic
+ intTac
+
+elab "scalar_tac" : tactic =>
+ scalarTac
example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by
scalar_tac