summaryrefslogtreecommitdiff
path: root/backends
diff options
context:
space:
mode:
authorSon Ho2023-10-17 10:36:15 +0200
committerSon Ho2023-10-17 10:44:20 +0200
commit61368028027a7c160c33b05ec605c26833212667 (patch)
treea841ad12878ec4f9314b2af17f1a774b41b6dd7b /backends
parent584726e9c4e4378129a35f6cfbbbf934448d10a9 (diff)
Refold the scalar types when applying progress
Diffstat (limited to 'backends')
-rw-r--r--backends/lean/Base/Arith/Int.lean4
-rw-r--r--backends/lean/Base/Arith/Scalar.lean2
-rw-r--r--backends/lean/Base/Progress/Progress.lean26
-rw-r--r--backends/lean/Base/Utils.lean75
4 files changed, 91 insertions, 16 deletions
diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean
index 2959e245..a57f8bb1 100644
--- a/backends/lean/Base/Arith/Int.lean
+++ b/backends/lean/Base/Arith/Int.lean
@@ -162,7 +162,7 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr))
-- Add a declaration
let nval ← Utils.addDeclTac name e type (asLet := false)
-- Simplify to unfold the declaration to unfold (i.e., the projector)
- Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false)
+ Utils.simpAt true [declToUnfold] [] [] (Location.targets #[mkIdent name] false)
-- Return the new value
pure nval
@@ -240,7 +240,7 @@ def intTac (splitGoalConjs : Bool) (extraPreprocess : Tactic.TacticM Unit) : Ta
-- the goal. I think before leads to a smaller proof term?
Tactic.allGoals (intTacPreprocess extraPreprocess)
-- More preprocessing
- Tactic.allGoals (Utils.tryTac (Utils.simpAt [] [``nat_zero_eq_int_zero] [] .wildcard))
+ Tactic.allGoals (Utils.tryTac (Utils.simpAt true [] [``nat_zero_eq_int_zero] [] .wildcard))
-- Split the conjunctions in the goal
if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget)
-- Call linarith
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index 66c31129..2342cce6 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -17,7 +17,7 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
-- Reveal the concrete bounds, simplify calls to [ofInt]
- Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
+ Utils.simpAt true [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index 8b0759c5..ecf05dab 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -8,6 +8,27 @@ namespace Progress
open Lean Elab Term Meta Tactic
open Utils
+-- TODO: the scalar types annoyingly often get reduced when we use the progress
+-- tactic. We should find a way of controling reduction. For now we use rewriting
+-- lemmas to make sure the goal remains clean, but this complexifies proof terms.
+-- It seems there used to be a `fold` tactic.
+theorem scalar_isize_eq : Primitives.Scalar .Isize = Primitives.Isize := by rfl
+theorem scalar_i8_eq : Primitives.Scalar .I8 = Primitives.I8 := by rfl
+theorem scalar_i16_eq : Primitives.Scalar .I16 = Primitives.I16 := by rfl
+theorem scalar_i32_eq : Primitives.Scalar .I32 = Primitives.I32 := by rfl
+theorem scalar_i64_eq : Primitives.Scalar .I64 = Primitives.I64 := by rfl
+theorem scalar_i128_eq : Primitives.Scalar .I128 = Primitives.I128 := by rfl
+theorem scalar_usize_eq : Primitives.Scalar .Usize = Primitives.Usize := by rfl
+theorem scalar_u8_eq : Primitives.Scalar .U8 = Primitives.U8 := by rfl
+theorem scalar_u16_eq : Primitives.Scalar .U16 = Primitives.U16 := by rfl
+theorem scalar_u32_eq : Primitives.Scalar .U32 = Primitives.U32 := by rfl
+theorem scalar_u64_eq : Primitives.Scalar .U64 = Primitives.U64 := by rfl
+theorem scalar_u128_eq : Primitives.Scalar .U128 = Primitives.U128 := by rfl
+def scalar_eqs := [
+ ``scalar_isize_eq, ``scalar_i8_eq, ``scalar_i16_eq, ``scalar_i32_eq, ``scalar_i64_eq, ``scalar_i128_eq,
+ ``scalar_usize_eq, ``scalar_u8_eq, ``scalar_u16_eq, ``scalar_u32_eq, ``scalar_u64_eq, ``scalar_u128_eq
+]
+
inductive TheoremOrLocal where
| Theorem (thName : Name)
| Local (asm : LocalDecl)
@@ -111,8 +132,11 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
splitEqAndPost fun hEq hPost ids => do
trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}"
tryTac (
- simpAt [] [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
+ simpAt true []
+ [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
[hEq.fvarId!] (.targets #[] true))
+ -- TODO: remove this (some types get unfolded too much: we "fold" them back)
+ tryTac (simpAt true [] scalar_eqs [] .wildcard_dep)
-- Clear the equality, unless the user requests not to do so
let mgoal ← do
if keep.isSome then getMainGoal
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 5224e1c3..b917a789 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -604,16 +604,12 @@ example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by
rename_i x y z
exists x + y + z
-/- Call the simp tactic.
- The initialization of the context is adapted from Tactic.elabSimpArgs.
- Something very annoying is that there is no function which allows to
- initialize a simp context without doing an elaboration - as a consequence
- we write our own here. -/
-def simpAt (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId)
- (loc : Tactic.Location) :
- Tactic.TacticM Unit := do
- -- Initialize with the builtin simp theorems
- let simpThms ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems)
+def mkSimpCtx (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) :
+ Tactic.TacticM Simp.Context := do
+ -- Initialize either with the builtin simp theorems or with all the simp theorems
+ let simpThms ←
+ if simpOnly then Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems)
+ else getSimpTheorems
-- Add the equational theorem for the declarations to unfold
let simpThms ←
declsToUnfold.foldlM (fun thms decl => thms.addDeclToUnfold decl) simpThms
@@ -637,8 +633,63 @@ def simpAt (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVar
throwError "Not a proposition: {thmName}"
) simpThms
let congrTheorems ← getSimpCongrTheorems
- let ctx : Simp.Context := { simpTheorems := #[simpThms], congrTheorems }
+ pure { simpTheorems := #[simpThms], congrTheorems }
+
+
+inductive Location where
+ /-- Apply the tactic everywhere. Same as `Tactic.Location.wildcard` -/
+ | wildcard
+ /-- Apply the tactic everywhere, including in the variable types (i.e., in
+ assumptions which are not propositions). --/
+ | wildcard_dep
+ /-- Same as Tactic.Location -/
+ | targets (hypotheses : Array Syntax) (type : Bool)
+
+-- Comes from Tactic.simpLocation
+def customSimpLocation (ctx : Simp.Context) (discharge? : Option Simp.Discharge := none)
+ (loc : Location) : TacticM Simp.UsedSimps := do
+ match loc with
+ | Location.targets hyps simplifyTarget =>
+ withMainContext do
+ let fvarIds ← Lean.Elab.Tactic.getFVarIds hyps
+ go fvarIds simplifyTarget
+ | Location.wildcard =>
+ withMainContext do
+ go (← (← getMainGoal).getNondepPropHyps) (simplifyTarget := true)
+ | Location.wildcard_dep =>
+ withMainContext do
+ let ctx ← Lean.MonadLCtx.getLCtx
+ let decls ← ctx.getDecls
+ let tgts := (decls.map (fun d => d.fvarId)).toArray
+ go tgts (simplifyTarget := true)
+where
+ go (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) : TacticM Simp.UsedSimps := do
+ let mvarId ← getMainGoal
+ let (result?, usedSimps) ← simpGoal mvarId ctx (simplifyTarget := simplifyTarget) (discharge? := discharge?) (fvarIdsToSimp := fvarIdsToSimp)
+ match result? with
+ | none => replaceMainGoal []
+ | some (_, mvarId) => replaceMainGoal [mvarId]
+ return usedSimps
+
+/- Call the simp tactic.
+ The initialization of the context is adapted from Tactic.elabSimpArgs.
+ Something very annoying is that there is no function which allows to
+ initialize a simp context without doing an elaboration - as a consequence
+ we write our own here. -/
+def simpAt (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId)
+ (loc : Location) :
+ Tactic.TacticM Unit := do
+ -- Initialize the simp context
+ let ctx ← mkSimpCtx simpOnly declsToUnfold thms hypsToUse
+ -- Apply the simplifier
+ let _ ← customSimpLocation ctx (discharge? := .none) loc
+
+-- Call the simpAll tactic
+def simpAll (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) :
+ Tactic.TacticM Unit := do
+ -- Initialize the simp context
+ let ctx ← mkSimpCtx false declsToUnfold thms hypsToUse
-- Apply the simplifier
- let _ ← Tactic.simpLocation ctx (discharge? := .none) loc
+ let _ ← Lean.Meta.simpAll (← getMainGoal) ctx
end Utils