summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base')
-rw-r--r--backends/lean/Base/Arith/Base.lean11
-rw-r--r--backends/lean/Base/Arith/Int.lean10
-rw-r--r--backends/lean/Base/Arith/Scalar.lean59
-rw-r--r--backends/lean/Base/Diverge/Base.lean21
-rw-r--r--backends/lean/Base/Diverge/Elab.lean129
-rw-r--r--backends/lean/Base/Extensions.lean1
-rw-r--r--backends/lean/Base/IList/IList.lean12
-rw-r--r--backends/lean/Base/Primitives.lean1
-rw-r--r--backends/lean/Base/Primitives/ArraySlice.lean5
-rw-r--r--backends/lean/Base/Primitives/CoreConvertNum.lean1
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean154
-rw-r--r--backends/lean/Base/Primitives/ScalarNotations.lean109
-rw-r--r--backends/lean/Base/Primitives/Vec.lean7
-rw-r--r--backends/lean/Base/Progress/Base.lean3
-rw-r--r--backends/lean/Base/Progress/Progress.lean23
-rw-r--r--backends/lean/Base/Utils.lean123
16 files changed, 400 insertions, 269 deletions
diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean
index 8ada4171..fb6b12e5 100644
--- a/backends/lean/Base/Arith/Base.lean
+++ b/backends/lean/Base/Arith/Base.lean
@@ -1,6 +1,5 @@
import Lean
-import Std.Data.Int.Lemmas
-import Mathlib.Tactic.Linarith
+import Mathlib.Tactic.Linarith -- Introduces a lot of useful lemmas
namespace Arith
@@ -21,12 +20,12 @@ theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by
have hne : x - y ≠ 0 := by
simp
intro h
- have: x = y := by linarith
+ have: x = y := by omega
simp_all
have h := ne_zero_is_lt_or_gt hne
match h with
- | .inl _ => left; linarith
- | .inr _ => right; linarith
+ | .inl _ => left; omega
+ | .inr _ => right; omega
-- TODO: move?
theorem add_one_le_iff_le_ne (n m : Nat) (h1 : m ≤ n) (h2 : m ≠ n) : m + 1 ≤ n := by
@@ -66,7 +65,7 @@ theorem to_int_to_nat_lt (x y : ℤ) (h0 : 0 ≤ x) (h1 : x < y) :
theorem to_int_sub_to_nat_lt (x y : ℤ) (x' : ℕ)
(h0 : ↑x' ≤ x) (h1 : x - ↑x' < y) :
↑(x.toNat - x') < y := by
- have : 0 ≤ x := by linarith
+ have : 0 ≤ x := by omega
simp [Int.toNat_sub_of_le, *]
end Arith
diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean
index a1cb9da3..ab6dd4ab 100644
--- a/backends/lean/Base/Arith/Int.lean
+++ b/backends/lean/Base/Arith/Int.lean
@@ -27,7 +27,7 @@ class HasIntPred {a: Sort u} (x: a) where
prop : concl
/- Proposition with implications: if we find P we can introduce Q in the context -/
-class PropHasImp (x : Prop) where
+class PropHasImp (x : Sort u) where
concl : Prop
prop : x → concl
@@ -199,7 +199,7 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr))
-- Add a declaration
let nval ← Utils.addDeclTac name e type (asLet := false)
-- Simplify to unfold the declaration to unfold (i.e., the projector)
- Utils.simpAt true [declToUnfold] [] [] (Location.targets #[mkIdent name] false)
+ Utils.simpAt true {} #[] [declToUnfold] [] [] (Location.targets #[mkIdent name] false)
-- Return the new value
pure nval
@@ -242,7 +242,7 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U
extraPreprocess
-- Reduce all the terms in the goal - note that the extra preprocessing step
-- might have proven the goal, hence the `Tactic.allGoals`
- Tactic.allGoals do tryTac (dsimpAt false [] [] [] Tactic.Location.wildcard)
+ Tactic.allGoals do tryTac (dsimpAt false {} #[] [] [] [] Tactic.Location.wildcard)
elab "int_tac_preprocess" : tactic =>
intTacPreprocess (do pure ())
@@ -259,10 +259,10 @@ def intTac (tacName : String) (splitGoalConjs : Bool) (extraPreprocess : Tactic
-- the goal. I think before leads to a smaller proof term?
Tactic.allGoals (intTacPreprocess extraPreprocess)
-- More preprocessing
- Tactic.allGoals (Utils.tryTac (Utils.simpAt true [] [``nat_zero_eq_int_zero] [] .wildcard))
+ Tactic.allGoals (Utils.tryTac (Utils.simpAt true {} #[] [] [``nat_zero_eq_int_zero] [] .wildcard))
-- Split the conjunctions in the goal
if splitGoalConjs then Tactic.allGoals (Utils.repeatTac Utils.splitConjTarget)
- -- Call linarith
+ -- Call omega
Tactic.allGoals do
try do Tactic.Omega.omegaTactic {}
catch _ =>
diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean
index 86b2e216..ecc5acaf 100644
--- a/backends/lean/Base/Arith/Scalar.lean
+++ b/backends/lean/Base/Arith/Scalar.lean
@@ -8,30 +8,31 @@ open Lean Lean.Elab Lean.Meta
open Primitives
def scalarTacExtraPreprocess : Tactic.TacticM Unit := do
- Tactic.withMainContext do
- -- Inroduce the bounds for the isize/usize types
- let add (e : Expr) : Tactic.TacticM Unit := do
- let ty ← inferType e
- let _ ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) e ty (asLet := false)
- add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
- add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
- add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
- -- Reveal the concrete bounds, simplify calls to [ofInt]
- Utils.simpAt true
- -- Unfoldings
- [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
- ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
- ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
- ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
- ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max,
- ``Usize.min
- ]
- -- Simp lemmas
- [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val,
- ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv]
- -- Hypotheses
- [] .wildcard
-
+ Tactic.withMainContext do
+ -- Inroduce the bounds for the isize/usize types
+ let add (e : Expr) : Tactic.TacticM Unit := do
+ let ty ← inferType e
+ let _ ← Utils.addDeclTac (← Utils.mkFreshAnonPropUserName) e ty (asLet := false)
+ add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []])
+ add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []])
+ -- Reveal the concrete bounds, simplify calls to [ofInt]
+ Utils.simpAt true {}
+ -- Simprocs
+ #[]
+ -- Unfoldings
+ [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax,
+ ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min,
+ ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max,
+ ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min,
+ ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max,
+ ``Usize.min
+ ]
+ -- Simp lemmas
+ [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val,
+ ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv]
+ -- Hypotheses
+ [] .wildcard
elab "scalar_tac_preprocess" : tactic =>
intTacPreprocess scalarTacExtraPreprocess
@@ -81,4 +82,14 @@ example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) :
example (x : U32) (h0 : ¬ x = U32.ofInt 0) : 0 < x.val := by
scalar_tac
+/- See this: https://aeneas-verif.zulipchat.com/#narrow/stream/349819-general/topic/U64.20trouble/near/444049757
+
+ We solved it by removing the instance `OfNat` for `Scalar`.
+ Note however that we could also solve it with a simplification lemma.
+ However, after testing, we noticed we could only apply such a lemma with
+ the rewriting tactic (not the simplifier), probably because of the use
+ of typeclasses. -/
+example {u: U64} (h1: (u : Int) < 2): (u : Int) = 0 ∨ (u : Int) = 1 := by
+ scalar_tac
+
end Arith
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 0f20125f..aab4db8f 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -1,7 +1,6 @@
import Lean
import Lean.Meta.Tactic.Simp
import Init.Data.List.Basic
-import Mathlib.Tactic.Linarith
import Base.Primitives.Base
import Base.Arith.Base
import Base.Diverge.ElabBase
@@ -36,20 +35,19 @@ namespace Lemmas
revert m
induction k -- TODO: induction h rather?
case zero =>
- simp_all
intro m h1 h2
have h: n = m := by omega
unfold for_all_fin_aux; simp_all
simp_all
-- There is no i s.t. m ≤ i
intro i h3; cases i; simp_all
- linarith
+ omega
case succ k hi =>
intro m hk hmn
intro hf i hmi
have hne: m ≠ n := by
have hineq := Nat.lt_of_sub_eq_succ hk
- linarith
+ omega
-- m = i?
if heq: m = i then
-- Yes: simply use the `for_all_fin_aux` hyp
@@ -64,7 +62,7 @@ namespace Lemmas
have heq1: n - (m + 1) = k := by
-- TODO: very annoying arithmetic proof
simp [Nat.sub_eq_iff_eq_add hineq]
- have hineq1: m ≤ n := by linarith
+ have hineq1: m ≤ n := by omega
simp [Nat.sub_eq_iff_eq_add hineq1] at hk
simp_arith [hk]
have hi := hi (m + 1) heq1 hineq
@@ -199,7 +197,7 @@ namespace Fix
| 0 =>
exfalso
zify at *
- linarith
+ omega
| Nat.succ m1 =>
simp_arith at Hle
simp [fix_fuel]
@@ -407,7 +405,7 @@ namespace Fix
. simp at Hl
-- Make a case disjunction on `h y (fix_fuel m k)`: if it is not equal
-- to div, use the monotonicity of `h y`
- have Hle : m ≤ n := by linarith
+ have Hle : m ≤ n := by omega
have Hffmono := fix_fuel_mono Hkmono Hle
have Hmono := Hhmono y Hffmono
simp [result_rel] at Hmono
@@ -568,6 +566,7 @@ namespace FixI
have Heq := Fix.is_valid_fix_fixed_eq Hvalid'
simp [fix]
conv => lhs; rw [Heq]
+ rfl
/- Some utilities to define the mutually recursive functions -/
@@ -778,6 +777,7 @@ namespace FixII
have Heq := Fix.is_valid_fix_fixed_eq Hvalid'
simp [fix]
conv => lhs; rw [Heq]
+ rfl
/- Some utilities to define the mutually recursive functions -/
@@ -966,6 +966,7 @@ namespace Ex1
have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a)
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
end Ex1
@@ -1011,6 +1012,7 @@ namespace Ex2
have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a)
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
end Ex2
@@ -1183,6 +1185,7 @@ namespace Ex4
.ok b) := by
simp [is_even, is_odd];
conv => lhs; rw [body_fix_eq]
+ rfl
theorem is_odd_eq (i : Int) : is_odd i =
(if i = 0
@@ -1192,6 +1195,7 @@ namespace Ex4
.ok b) := by
simp [is_even, is_odd];
conv => lhs; rw [body_fix_eq]
+ rfl
end Ex4
namespace Ex5
@@ -1263,6 +1267,7 @@ namespace Ex5
have Heq := is_valid_fix_fixed_eq (@id_body_is_valid a)
simp [id]
conv => lhs; rw [Heq]; simp; rw [id_body]
+ rfl
end Ex5
@@ -1336,6 +1341,7 @@ namespace Ex6
have Heq := is_valid_fix_fixed_eq body_is_valid
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
-- Write the proof term explicitly: the generation of the proof term (without tactics)
-- is automatable, and the proof term is actually a lot simpler and smaller when we
@@ -1429,6 +1435,7 @@ namespace Ex7
have Heq := is_valid_fix_fixed_eq body_is_valid
simp [list_nth]
conv => lhs; rw [Heq]
+ rfl
-- Write the proof term explicitly: the generation of the proof term (without tactics)
-- is automatable, and the proof term is actually a lot simpler and smaller when we
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 5db8ffed..60955051 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -22,6 +22,10 @@ def normalize_let_bindings := true
open WF in
+-- Small utility - it seems that `Name.append` doesn't do what we want
+def appendToName (n : Name) (s : String) : Name :=
+ Name.str n s
+
-- TODO: use those
def UnitType := Expr.const ``PUnit [Level.succ .zero]
def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero]
@@ -548,7 +552,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- Add the declaration
let value ← mkLambdaFVars #[kk_var] body
trace[Diverge.def.genBody] "Body after abstracting kk: {value}"
- let name := preDef.declName.append "body"
+ let name := appendToName preDef.declName "body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -603,7 +607,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
let body ← mkLambdaFVars #[kk_var, i_var] body
trace[Diverge.def] "mkDeclareMutRecBody: body: {body}"
-- Add the declaration
- let name := grName.append "mut_rec_body"
+ let name := appendToName grName "mut_rec_body"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -1047,7 +1051,7 @@ partial def proveSingleBodyIsValid
mkForallFVars #[k_var, t_var, x_var] ty
trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}"
-- Save the theorem
- let name := preDef.declName ++ "body_is_valid"
+ let name := appendToName preDef.declName "body_is_valid"
let decl := Declaration.thmDecl {
name
levelParams := preDef.levelParams
@@ -1107,7 +1111,7 @@ def proveMutRecIsValid
trace[Diverge.def.valid] "Generated the term: {isValid}"
-- Save the theorem
let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst]
- let name := grName ++ "mut_rec_body_is_valid"
+ let name := appendToName grName "mut_rec_body_is_valid"
let decl := Declaration.thmDecl {
name
levelParams := grLvlParams
@@ -1196,7 +1200,7 @@ partial def proveUnfoldingThms (isValidThm : Expr)
let proof ← mkLambdaFVars xs proof
trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}"
-- Declare the theorem
- let name := preDef.declName ++ "unfold"
+ let name := appendToName preDef.declName "unfold"
let decl := Declaration.thmDecl {
name
levelParams := preDef.levelParams
@@ -1282,7 +1286,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Add an auxiliary definition for `param_in_out_ty` (this is a potentially big term)
let param_in_out_ty ← do
let value ← mkLambdaFVars #[i_var] param_in_out_ty
- let name := grName.append "param_in_out_ty"
+ let name := appendToName grName "param_in_out_ty"
let levelParams := grLvlParams
let decl := Declaration.defnDecl {
name := name
@@ -1392,44 +1396,71 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC
open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef
-def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
- let scopeLevelNames ← getLevelNames
- let headers ← elabHeaders views
- let headers ← levelMVarToParamHeaders views headers
- let allUserLevelNames := getAllUserLevelNames headers
- withFunLocalDecls headers fun funFVars => do
- for view in views, funFVar in funFVars do
- addLocalVarInfo view.declId funFVar
- -- Add fake use site to prevent "unused variable" warning (if the
- -- function is actually not recursive, Lean would print this warning).
- -- Remark: we could detect this case and encode the function without
- -- using the fixed-point. In practice it shouldn't happen however:
- -- we define non-recursive functions with the `divergent` keyword
- -- only for testing purposes.
- addTermInfo' view.declId funFVar
- let values ←
- try
- let values ← elabFunValues headers
- Term.synthesizeSyntheticMVarsNoPostponing
- values.mapM (instantiateMVars ·)
- catch ex =>
- logException ex
- headers.mapM fun header => mkSorry header.type (synthetic := true)
- let headers ← headers.mapM instantiateMVarsAtHeader
- let letRecsToLift ← getLetRecsToLift
- let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
- checkLetRecsToLiftTypes funFVars letRecsToLift
- withUsed vars headers values letRecsToLift fun vars => do
- let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
- for preDef in preDefs do
- trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
- let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
- let preDefs ← instantiateMVarsAtPreDecls preDefs
- let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
- for preDef in preDefs do
- trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
- checkForHiddenUnivLevels allUserLevelNames preDefs
- addPreDefinitions preDefs
+-- Comes from Term.isExample
+def isExample (views : Array DefView) : Bool :=
+ views.any (·.kind.isExample)
+
+open Language in
+def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit :=
+ if isExample views then
+ withoutModifyingEnv do
+ -- save correct environment in info tree
+ withSaveInfoContext do
+ go
+ else
+ go
+where
+ go :=
+ withAlwaysResolvedPromises views.size fun bodyPromises =>
+ withAlwaysResolvedPromises views.size fun tacPromises => do
+ let scopeLevelNames ← getLevelNames
+ let headers ← elabHeaders views bodyPromises tacPromises
+ let headers ← levelMVarToParamHeaders views headers
+ let allUserLevelNames := getAllUserLevelNames headers
+ withFunLocalDecls headers fun funFVars => do
+ for view in views, funFVar in funFVars do
+ addLocalVarInfo view.declId funFVar
+ -- Modification 1:
+ -- Add fake use site to prevent "unused variable" warning (if the
+ -- function is actually not recursive, Lean would print this warning).
+ -- Remark: we could detect this case and encode the function without
+ -- using the fixed-point. In practice it shouldn't happen however:
+ -- we define non-recursive functions with the `divergent` keyword
+ -- only for testing purposes.
+ addTermInfo' view.declId funFVar
+ let values ←
+ try
+ let values ← elabFunValues headers
+ Term.synthesizeSyntheticMVarsNoPostponing
+ values.mapM (instantiateMVars ·)
+ catch ex =>
+ logException ex
+ headers.mapM fun header => mkSorry header.type (synthetic := true)
+ let headers ← headers.mapM instantiateMVarsAtHeader
+ let letRecsToLift ← getLetRecsToLift
+ let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
+ checkLetRecsToLiftTypes funFVars letRecsToLift
+ withUsed vars headers values letRecsToLift fun vars => do
+ let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
+ for preDef in preDefs do
+ trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
+ let preDefs ← instantiateMVarsAtPreDecls preDefs
+ let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
+ for preDef in preDefs do
+ trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ checkForHiddenUnivLevels allUserLevelNames preDefs
+ addPreDefinitions preDefs -- Modification 2: we use our custom function here
+ processDeriving headers
+
+ processDeriving (headers : Array DefViewElabHeader) := do
+ for header in headers, view in views do
+ if let some classNamesStx := view.deriving? then
+ for classNameStx in classNamesStx do
+ let className ← realizeGlobalConstNoOverload classNameStx
+ withRef classNameStx do
+ unless (← processDefDeriving className header.declName) do
+ throwError "failed to synthesize instance '{className}' for '{header.declName}'"
open Command in
def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
@@ -1439,7 +1470,8 @@ def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
let modifiers ← elabModifiers mods
let (binders, type) := expandOptDeclSig sig
let deriving? := none
- pure { ref := d, kind := DefKind.def, modifiers,
+ let headerRef := Syntax.missing -- Not sure what to put here
+ pure { ref := d, kind := DefKind.def, headerRef, modifiers,
declId := id, binders, type? := type, value := val, deriving? }
runTermElabM fun vars => Term.elabMutualDef vars views
@@ -1460,7 +1492,7 @@ elab_rules : command
if (`_root_).isPrefixOf name then throwUnsupportedSyntax
let view := extractMacroScopes name
let .str ns shortName := view.name | throwUnsupportedSyntax
- let shortName' := { view with name := shortName }.review
+ let shortName' := { view with name := Name.mkSimple shortName }.review
let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end)
if ns matches .anonymous then
Command.elabCommand cmd
@@ -1475,6 +1507,7 @@ namespace Tests
--set_option trace.Diverge.def.genBody true
--set_option trace.Diverge.def.valid true
--set_option trace.Diverge.def.genBody.visit true
+ --set_option trace.Diverge.def.unfold true
divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
match ls with
@@ -1492,7 +1525,7 @@ namespace Tests
0 ≤ i → i < ls.length →
∃ x, list_nth ls i = .ok x := by
induction ls
- . intro i hpos h; simp at h; linarith
+ . intro i hpos h; simp at h; omega
. rename_i hd tl ih
intro i hpos h
-- We can directly use `rw [list_nth]`
@@ -1502,7 +1535,7 @@ namespace Tests
. -- We don't have to do this if we use scalar_tac
have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all
simp at h
- have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith)
+ have ⟨ x, ih ⟩ := ih (i - 1) (by omega) (by omega)
simp [ih]
tauto
diff --git a/backends/lean/Base/Extensions.lean b/backends/lean/Base/Extensions.lean
index c0e80861..b491f81b 100644
--- a/backends/lean/Base/Extensions.lean
+++ b/backends/lean/Base/Extensions.lean
@@ -1,5 +1,4 @@
import Lean
-import Std.Lean.HashSet
import Base.Utils
import Base.Primitives.Base
diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean
index 9fe2297f..96843f55 100644
--- a/backends/lean/Base/IList/IList.lean
+++ b/backends/lean/Base/IList/IList.lean
@@ -1,7 +1,6 @@
/- Complementary list functions and lemmas which operate on integers rather
than natural numbers. -/
-import Std.Data.Int.Lemmas
import Base.Arith
import Base.Utils
@@ -17,7 +16,7 @@ def len (ls : List α) : Int :=
theorem len_pos : 0 ≤ (ls : List α).len := by
induction ls <;> simp [*]
- linarith
+ omega
instance (l: List a) : Arith.HasIntPred (l.len) where
concl := 0 ≤ l.len
@@ -171,6 +170,7 @@ theorem ireplicate_replicate {α : Type u} (l : ℤ) (x : α) (h : 0 ≤ l) :
have hl : l.toNat = .succ (l.toNat - 1) := by
cases hl: l.toNat <;> simp_all
conv => rhs; rw[hl]
+ rfl
termination_by l.toNat
decreasing_by int_decr_tac
@@ -279,12 +279,12 @@ open Arith in
if heq: i = 0 then
simp [*] at *
have := tl.len_pos
- linarith
+ omega
else
have : 0 < i := by int_tac
simp [*]
apply hi
- linarith
+ omega
theorem idrop_len_le (i : Int) (ls : List α) : (ls.idrop i).len ≤ ls.len :=
match ls with
@@ -293,13 +293,13 @@ theorem idrop_len_le (i : Int) (ls : List α) : (ls.idrop i).len ≤ ls.len :=
if h: i = 0 then by simp [*]
else
have := idrop_len_le (i - 1) tl
- by simp [*]; linarith
+ by simp [*]; omega
@[simp]
theorem idrop_len (i : Int) (ls : List α) (_ : 0 ≤ i) (_ : i ≤ ls.len) :
(ls.idrop i).len = ls.len - i :=
match ls with
- | [] => by simp_all; linarith
+ | [] => by simp_all; omega
| hd :: tl =>
if h: i = 0 then by simp [*]
else
diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean
index f80c2004..93617049 100644
--- a/backends/lean/Base/Primitives.lean
+++ b/backends/lean/Base/Primitives.lean
@@ -1,6 +1,7 @@
import Base.Primitives.Base
import Base.Tuples
import Base.Primitives.Scalar
+import Base.Primitives.ScalarNotations
import Base.Primitives.ArraySlice
import Base.Primitives.Vec
import Base.Primitives.Alloc
diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean
index 157f9df1..be460987 100644
--- a/backends/lean/Base/Primitives/ArraySlice.lean
+++ b/backends/lean/Base/Primitives/ArraySlice.lean
@@ -126,7 +126,7 @@ abbrev Slice.v {α : Type u} (v : Slice α) : List α := v.val
example {a: Type u} (v : Slice a) : v.length ≤ Scalar.max ScalarTy.Usize := by
scalar_tac
-def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp; decide ⟩
+def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩
-- TODO: very annoying that the α is an explicit parameter
def Slice.len (α : Type u) (v : Slice α) : Usize :=
@@ -325,8 +325,7 @@ theorem Slice.subslice_spec {α : Type u} [Inhabited α] (s : Slice α) (r : Ran
have := List.index_slice r.start.val r.end_.val i s.val (by scalar_tac) (by scalar_tac) (by trivial) (by scalar_tac)
simp [*]
-attribute [pp_dot] List.len List.length List.index -- use the dot notation when printing
-set_option pp.coercions false -- do not print coercions with ↑ (this doesn't parse)
+set_option pp.fieldNotation.generalized true
def Slice.update_subslice (α : Type u) (s : Slice α) (r : Range Usize) (ss : Slice α) : Result (Slice α) :=
-- TODO: not completely sure here
diff --git a/backends/lean/Base/Primitives/CoreConvertNum.lean b/backends/lean/Base/Primitives/CoreConvertNum.lean
index eb456a96..b53d11db 100644
--- a/backends/lean/Base/Primitives/CoreConvertNum.lean
+++ b/backends/lean/Base/Primitives/CoreConvertNum.lean
@@ -4,6 +4,7 @@ import Init.Data.List.Basic
import Mathlib.Tactic.Linarith
import Base.IList
import Base.Primitives.Scalar
+import Base.Primitives.ScalarNotations
import Base.Primitives.ArraySlice
import Base.Arith
import Base.Progress.Base
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 8fb067e1..9f809ead 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -1,6 +1,5 @@
import Lean
import Lean.Meta.Tactic.Simp
-import Mathlib.Tactic.Linarith
import Base.Primitives.Base
import Base.Primitives.Core
import Base.Diverge.Base
@@ -9,6 +8,9 @@ import Base.Arith.Int
namespace Primitives
+-- Deactivate the warnings which appear when we use `#assert`
+set_option linter.hashCommand false
+
----------------------
-- MACHINE INTEGERS --
----------------------
@@ -279,11 +281,11 @@ theorem Scalar.cMax_bound ty : Scalar.cMax ty ≤ Scalar.max ty := by
theorem Scalar.cMin_suffices ty (h : Scalar.cMin ty ≤ x) : Scalar.min ty ≤ x := by
have := Scalar.cMin_bound ty
- linarith
+ omega
theorem Scalar.cMax_suffices ty (h : x ≤ Scalar.cMax ty) : x ≤ Scalar.max ty := by
have := Scalar.cMax_bound ty
- linarith
+ omega
/-- The scalar type.
@@ -310,40 +312,15 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
:=
λ h => by
- apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith
+ apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> omega
-/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns.
- This is particularly useful once we introduce notations like `#u32` (which
- desugards to `Scalar.ofIntCore`) as it allows to write expressions like this:
- Example:
- ```
- match x with
- | 0#u32 => ...
- | 1#u32 => ...
- | ...
- ```
- -/
-@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
+def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
(h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty :=
{ val := x, hmin := h.left, hmax := h.right }
--- The definitions below are used later to introduce nice syntax for constants,
--- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284
-
-class InBounds (ty : ScalarTy) (x : Int) :=
- hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty
-
--- This trick to trigger reduction for decidable propositions comes from
--- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807
-class Decide (p : Prop) [Decidable p] : Prop where
- isTrue : p
-instance : @Decide p (.isTrue h) := @Decide.mk p (_) h
-
-instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where
- hInBounds := Decide.isTrue
-
-@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty :=
- Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds)
+@[reducible] def Scalar.ofInt {ty : ScalarTy} (x : Int)
+ (hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty := by decide) : Scalar ty :=
+ Scalar.ofIntCore x (Scalar.bound_suffices ty x hInBounds)
@[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop :=
Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
@@ -351,10 +328,17 @@ instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty
@[simp] abbrev Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool :=
(Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty)
+/- Discussion:
+ This coercion can be slightly annoying at times, because if we write
+ something like `u = 3` (where `u` is, for instance, as `U32`), then instead of
+ coercing `u` to `Int`, Lean will lift `3` to `U32`).
+ For now we deactivate it.
+
-- TODO(raitobezarius): the inbounds constraint is a bit ugly as we can pretty trivially
-- discharge the lhs on ≥ 0.
instance {ty: ScalarTy} [InBounds ty (Int.ofNat n)]: OfNat (Scalar ty) (n: ℕ) where
ofNat := Scalar.ofInt n
+-/
theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int}
(h: Scalar.check_bounds ty x) :
@@ -363,7 +347,7 @@ theorem Scalar.check_bounds_imp_in_bounds {ty : ScalarTy} {x : Int}
have ⟨ hmin, hmax ⟩ := h
have hbmin := Scalar.cMin_bound ty
have hbmax := Scalar.cMax_bound ty
- cases hmin <;> cases hmax <;> apply And.intro <;> linarith
+ cases hmin <;> cases hmax <;> apply And.intro <;> omega
theorem Scalar.check_bounds_eq_in_bounds (ty : ScalarTy) (x : Int) :
Scalar.check_bounds ty x ↔ Scalar.in_bounds ty x := by
@@ -405,9 +389,8 @@ theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) :
simp [tryMk, ofOption, tryMkOpt]
split_ifs <;> simp
-instance (ty: ScalarTy) : InBounds ty 0 where
- hInBounds := by
- induction ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
+@[simp] theorem zero_in_cbounds {ty : ScalarTy} : Scalar.cMin ty ≤ 0 ∧ 0 ≤ Scalar.cMax ty := by
+ cases ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
@@ -749,7 +732,6 @@ theorem Scalar.add_spec {ty} {x y : Scalar ty}
(∃ z, x + y = ok z ∧ (↑z : Int) = ↑x + ↑y) := by
have h := @add_equiv ty x y
split at h <;> simp_all
- apply h
theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
(hmax : ↑x + ↑y ≤ Scalar.max ty) :
@@ -757,7 +739,7 @@ theorem Scalar.add_unsigned_spec {ty} (s: ¬ ty.isSigned) {x y : Scalar ty}
have hmin : Scalar.min ty ≤ ↑x + ↑y := by
have hx := x.hmin
have hy := y.hmin
- cases ty <;> simp [min, ScalarTy.isSigned] at * <;> linarith
+ cases ty <;> simp [min, ScalarTy.isSigned] at * <;> omega
apply add_spec <;> assumption
/- Fine-grained theorems -/
@@ -844,7 +826,6 @@ theorem Scalar.sub_spec {ty} {x y : Scalar ty}
∃ z, x - y = ok z ∧ (↑z : Int) = ↑x - ↑y := by
have h := @sub_equiv ty x y
split at h <;> simp_all
- apply h
theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned)
{x y : Scalar ty} (hmin : Scalar.min ty ≤ ↑x - ↑y) :
@@ -853,7 +834,7 @@ theorem Scalar.sub_unsigned_spec {ty : ScalarTy} (s : ¬ ty.isSigned)
have hx := x.hmin
have hxm := x.hmax
have hy := y.hmin
- cases ty <;> simp [min, max, ScalarTy.isSigned] at * <;> linarith
+ cases ty <;> simp [min, max, ScalarTy.isSigned] at * <;> omega
intros
apply sub_spec <;> assumption
@@ -1049,11 +1030,11 @@ theorem Scalar.div_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
have hx := x.hmin
have hy := y.hmin
simp [h] at hx hy
- have hmin : 0 ≤ ↑x / ↑y := Int.ediv_nonneg hx hy
+ have hmin : 0 ≤ x.val / y.val := Int.ediv_nonneg hx hy
have hmax : ↑x / ↑y ≤ Scalar.max ty := by
have := Int.ediv_le_self ↑y hx
have := x.hmax
- linarith
+ omega
have hs := @div_spec ty x y hnz
simp [*] at hs
apply hs
@@ -1170,7 +1151,7 @@ theorem Scalar.rem_unsigned_spec {ty} (s: ¬ ty.isSigned) (x : Scalar ty) {y : S
have h : (0 : Int) < y := by int_tac
have h := Int.emod_lt_of_pos ↑x h
have := y.hmax
- linarith
+ omega
have hs := @rem_spec ty x y hnz
simp [*] at hs
simp [*]
@@ -1261,73 +1242,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128
-- ofInt
-- TODO: typeclass?
-@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize
-@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8
-@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16
-@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32
-@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64
-@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128
-@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize
-@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8
-@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16
-@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32
-@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64
-@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128
-
-postfix:max "#isize" => Isize.ofInt
-postfix:max "#i8" => I8.ofInt
-postfix:max "#i16" => I16.ofInt
-postfix:max "#i32" => I32.ofInt
-postfix:max "#i64" => I64.ofInt
-postfix:max "#i128" => I128.ofInt
-postfix:max "#usize" => Usize.ofInt
-postfix:max "#u8" => U8.ofInt
-postfix:max "#u16" => U16.ofInt
-postfix:max "#u32" => U32.ofInt
-postfix:max "#u64" => U64.ofInt
-postfix:max "#u128" => U128.ofInt
-
-/- Testing the notations -/
-example := 0#u32
-example := 1#u32
-example := 1#i32
-example := 0#isize
-example := (-1)#isize
-example (x : U32) : Bool :=
- match x with
- | 0#u32 => true
- | _ => false
-
-example (x : U32) : Bool :=
- match x with
- | 1#u32 => true
- | _ => false
-
-example (x : I32) : Bool :=
- match x with
- | (-1)#i32 => true
- | _ => false
-
--- Notation for pattern matching
--- We make the precedence looser than the negation.
-notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
-
-example {ty} (x : Scalar ty) : ℤ :=
- match x with
- | v#scalar => v
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | 1#scalar => true
- | _ => false
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | -(1 : Int)#scalar => true
- | _ => false
-
--- Testing the notations
-example : Result Usize := 0#usize + 1#usize
+abbrev Isize.ofInt := @Scalar.ofInt .Isize
+abbrev I8.ofInt := @Scalar.ofInt .I8
+abbrev I16.ofInt := @Scalar.ofInt .I16
+abbrev I32.ofInt := @Scalar.ofInt .I32
+abbrev I64.ofInt := @Scalar.ofInt .I64
+abbrev I128.ofInt := @Scalar.ofInt .I128
+abbrev Usize.ofInt := @Scalar.ofInt .Usize
+abbrev U8.ofInt := @Scalar.ofInt .U8
+abbrev U16.ofInt := @Scalar.ofInt .U16
+abbrev U32.ofInt := @Scalar.ofInt .U32
+abbrev U64.ofInt := @Scalar.ofInt .U64
+abbrev U128.ofInt := @Scalar.ofInt .U128
-- TODO: factor those lemmas out
@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by
@@ -1457,18 +1383,18 @@ theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max (
-- Max theory
-- TODO: do the min theory later on.
-theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 ≤ x := by
+theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 (by simp) ≤ x := by
apply (Scalar.le_equiv _ _).2
convert x.hmin
cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min]
@[simp]
theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
- Max.max (Scalar.ofInt 0) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
+ Max.max (Scalar.ofInt 0 (by simp)) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
@[simp]
theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
- Max.max x (Scalar.ofInt 0) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
+ Max.max x (Scalar.ofInt 0 (by simp)) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
-- Leading zeros
def core.num.Usize.leading_zeros (x : Usize) : U32 := sorry
diff --git a/backends/lean/Base/Primitives/ScalarNotations.lean b/backends/lean/Base/Primitives/ScalarNotations.lean
new file mode 100644
index 00000000..3bc86a9c
--- /dev/null
+++ b/backends/lean/Base/Primitives/ScalarNotations.lean
@@ -0,0 +1,109 @@
+import Lean
+import Lean.Meta.Tactic.Simp
+import Mathlib.Tactic.Linarith
+import Base.Primitives.Scalar
+import Base.Arith
+
+namespace Primitives
+
+open Lean Meta Elab Term
+
+/- Something strange happens here: when we solve the goal with scalar_tac, it
+ sometimes leaves meta-variables in place, which then causes issues when
+ type-checking functions. For instance, it happens when we have const-generics
+ in the translation: the constants contain meta-variables, which are then
+ used in the types, which cause issues later. An example is given below:
+ -/
+macro:max x:term:max noWs "#isize" : term => `(Isize.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i8" : term => `(I8.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i16" : term => `(I16.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i32" : term => `(I32.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i64" : term => `(I64.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i128" : term => `(I128.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#usize" : term => `(Usize.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u8" : term => `(U8.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u16" : term => `(U16.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u32" : term => `(U32.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u64" : term => `(U64.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u128" : term => `(U128.ofInt $x (by first | decide | scalar_tac))
+
+-- Notation for pattern matching
+-- We make the precedence looser than the negation.
+notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
+
+/- Testing the notations -/
+example := 0#u32
+example := 1#u32
+example := 1#i32
+example := 0#isize
+example := (-1)#isize
+
+example := 1#u32
+
+/-
+-- This doesn't work anymore
+example (x : U32) : Bool :=
+ match x with
+ | 0#u32 => true
+ | _ => false
+
+example (x : U32) : Bool :=
+ match x with
+ | 1#u32 => true
+ | _ => false
+
+example (x : I32) : Bool :=
+ match x with
+ | (-1)#i32 => true
+ | _ => false
+-/
+
+example (x : U32) : Bool :=
+ match x with
+ | 0#scalar => true
+ | _ => false
+
+example (x : U32) : Bool :=
+ match x with
+ | 1#scalar => true
+ | _ => false
+
+example (x : I32) : Bool :=
+ match x with
+ | (-1)#scalar => true
+ | _ => false
+
+example {ty} (x : Scalar ty) : ℤ :=
+ match x with
+ | v#scalar => v
+
+example {ty} (x : Scalar ty) : Bool :=
+ match x with
+ | 1#scalar => true
+ | _ => false
+
+example {ty} (x : Scalar ty) : Bool :=
+ match x with
+ | -(1 : Int)#scalar => true
+ | _ => false
+
+-- Testing the notations
+example : Result Usize := 0#usize + 1#usize
+
+-- More complex expressions
+example (x y : Int) (h : 0 ≤ x + y ∧ x + y ≤ 1000) : U32 := (x + y)#u32
+
+namespace Scalar.Examples
+
+ abbrev Array (a : Type) (len : U32) := { l : List a // l.length = len.val }
+
+ -- Checking the syntax
+ example : Array Int 0#u32 := ⟨ [], by simp ⟩
+
+ /- The example below fails if we don't use `decide` in the elaboration
+ of the scalar notation -/
+ example (a : Array (Array Int 32#u32) 32#u32) := a
+
+end Scalar.Examples
+
+end Primitives
diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean
index 5ed7b606..0b010944 100644
--- a/backends/lean/Base/Primitives/Vec.lean
+++ b/backends/lean/Base/Primitives/Vec.lean
@@ -2,7 +2,6 @@
import Lean
import Lean.Meta.Tactic.Simp
import Init.Data.List.Basic
-import Mathlib.Tactic.Linarith
import Base.IList
import Base.Primitives.Scalar
import Base.Primitives.ArraySlice
@@ -34,7 +33,7 @@ abbrev Vec.v {α : Type u} (v : Vec α) : List α := v.val
example {a: Type u} (v : Vec a) : v.length ≤ Scalar.max ScalarTy.Usize := by
scalar_tac
-def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp; decide ⟩
+def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩
instance (α : Type u) : Inhabited (Vec α) := by
constructor
@@ -59,7 +58,7 @@ def Vec.push (α : Type u) (v : Vec α) (x : α) : Result (Vec α)
have h : nlen ≤ Usize.max := by
simp [Usize.max] at *
have hm := Usize.refined_max.property
- cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try linarith
+ cases h <;> cases hm <;> simp [U32.max, U64.max] at * <;> try omega
ok ⟨ List.concat v.val x, by simp at *; assumption ⟩
else
fail maximumSizeExceeded
@@ -192,7 +191,7 @@ def alloc.slice.Slice.to_vec
def core.slice.Slice.reverse (T : Type) (s : Slice T) : Slice T :=
⟨ s.val.reverse, by sorry ⟩
-def alloc.vec.Vec.with_capacity (T : Type) (s : Usize) : alloc.vec.Vec T := Vec.new T
+def alloc.vec.Vec.with_capacity (T : Type) (_ : Usize) : alloc.vec.Vec T := Vec.new T
/- [alloc::vec::{(core::ops::deref::Deref for alloc::vec::Vec<T, A>)#9}::deref]:
Source: '/rustc/d59363ad0b6391b7fc5bbb02c9ccf9300eef3753/library/alloc/src/vec/mod.rs', lines 2624:4-2624:27
diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean
index 03c80a42..0e46737f 100644
--- a/backends/lean/Base/Progress/Base.lean
+++ b/backends/lean/Base/Progress/Base.lean
@@ -1,5 +1,4 @@
import Lean
-import Std.Lean.HashSet
import Base.Utils
import Base.Primitives.Base
import Base.Extensions
@@ -111,7 +110,7 @@ section Methods
-- Collect all the free variables in the arguments
let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
-- Check if they intersect the fvars we introduced for the existentially quantified variables
- let evarsSet : HashSet FVarId := HashSet.ofArray (evars.map (fun (x : Expr) => x.fvarId!))
+ let evarsSet : HashSet FVarId := HashSet.empty.insertMany (evars.map (fun (x : Expr) => x.fvarId!))
let filtArgsFVars := allArgsFVars.toArray.filter (fun var => evarsSet.contains var)
if filtArgsFVars.isEmpty then pure ()
else
diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean
index f2a56e50..da601b73 100644
--- a/backends/lean/Base/Progress/Progress.lean
+++ b/backends/lean/Base/Progress/Progress.lean
@@ -58,17 +58,13 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
We also make sure that all the meta variables which appear in the
function arguments have been instantiated
-/
- let env ← getEnv
let thTy ← do
match th with
| .Theorem thName =>
- let thDecl := env.constants.find! thName
- -- We have to introduce fresh meta-variables for the universes already
- let ul : List (Name × Level) ←
- thDecl.levelParams.mapM (λ x => do pure (x, ← mkFreshLevelMVar))
- let ulMap : HashMap Name Level := HashMap.ofList ul
- let thTy := thDecl.type.instantiateLevelParamsCore (λ x => ulMap.find! x)
- pure thTy
+ -- Lookup the theorem and introduce fresh meta-variables for the universes
+ let th ← mkConstWithFreshMVarLevels thName
+ -- Retrieve the type
+ inferType th
| .Local asmDecl => pure asmDecl.type
trace[Progress] "Looked up theorem/assumption type: {thTy}"
-- TODO: the tactic fails if we uncomment withNewMCtxDepth
@@ -135,7 +131,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
Tactic.focus do
let _ ←
tryTac
- (simpAt true []
+ (simpAt true {} #[] []
[``Primitives.bind_tc_ok, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div]
[hEq.fvarId!] (.targets #[] true))
-- It may happen that at this point the goal is already solved (though this is rare)
@@ -144,7 +140,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal)
else
trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}"
-- TODO: remove this (some types get unfolded too much: we "fold" them back)
- let _ ← tryTac (simpAt true [] scalar_eqs [] .wildcard_dep)
+ let _ ← tryTac (simpAt true {} #[] [] scalar_eqs [] .wildcard_dep)
trace[Progress] "goal after folding back scalar types: {← getMainGoal}"
-- Clear the equality, unless the user requests not to do so
let mgoal ← do
@@ -350,11 +346,8 @@ def evalProgress (args : TSyntax `Progress.progressArgs) : TacticM Unit := do
-- Not a local declaration: should be a theorem
trace[Progress] "With arg: theorem"
addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none
- let cs ← resolveGlobalConstWithInfos id
- match cs with
- | [] => throwError "Could not find theorem {id}"
- | id :: _ =>
- pure (some (.Theorem id))
+ let some (.const name _) ← Term.resolveId? id | throwError m!"Could not find theorem: {id}"
+ pure (some (.Theorem name))
else pure none
let ids :=
let args := asArgs.getArgs
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 7ae5a832..5954f048 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -7,7 +7,6 @@ Mathlib tactics:
- rcases: https://leanprover-community.github.io/mathlib_docs/tactics.html#rcases
- split_ifs: https://leanprover-community.github.io/mathlib_docs/tactics.html#split_ifs
- norm_num: https://leanprover-community.github.io/mathlib_docs/tactics.html#norm_num
-- should we use linarith or omega?
- hint: https://leanprover-community.github.io/mathlib_docs/tactics.html#hint
- classical: https://leanprover-community.github.io/mathlib_docs/tactics.html#classical
-/
@@ -133,8 +132,9 @@ open Lean.Elab.Command
liftTermElabM do
let id := stx[1]
addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none
- let cs ← resolveGlobalConstWithInfos id
- explore_decl cs[0]!
+ let some cs ← Term.resolveId? id | throwError m!"Unknown id: {id}"
+ let name := cs.constName!
+ explore_decl name
private def test1 : Nat := 0
private def test2 (x : Nat) : Nat := x
@@ -664,7 +664,7 @@ example (h : ∃ x y z, x + y + z ≥ 0) : ∃ x, x ≥ 0 := by
Something very annoying is that there is no function which allows to
initialize a simp context without doing an elaboration - as a consequence
we write our own here. -/
-def mkSimpCtx (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) :
+def mkSimpCtx (simpOnly : Bool) (config : Simp.Config) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) :
Tactic.TacticM Simp.Context := do
-- Initialize either with the builtin simp theorems or with all the simp theorems
let simpThms ←
@@ -693,7 +693,7 @@ def mkSimpCtx (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (
throwError "Not a proposition: {thmName}"
) simpThms
let congrTheorems ← getSimpCongrTheorems
- pure { simpTheorems := #[simpThms], congrTheorems }
+ pure { config, simpTheorems := #[simpThms], congrTheorems }
inductive Location where
/-- Apply the tactic everywhere. Same as `Tactic.Location.wildcard` -/
@@ -704,56 +704,111 @@ inductive Location where
/-- Same as Tactic.Location -/
| targets (hypotheses : Array Syntax) (type : Bool)
--- Comes from Tactic.simpLocation
-def customSimpLocation (ctx : Simp.Context) (discharge? : Option Simp.Discharge := none)
- (loc : Location) : TacticM Simp.UsedSimps := do
+-- Adapted from Tactic.simpLocation
+def customSimpLocation (ctx : Simp.Context) (simprocs : Simp.SimprocsArray) (discharge? : Option Simp.Discharge := none)
+ (loc : Location) : TacticM Simp.Stats := do
match loc with
| Location.targets hyps simplifyTarget =>
- withMainContext do
- let fvarIds ← Lean.Elab.Tactic.getFVarIds hyps
- go fvarIds simplifyTarget
+ -- Simply call the regular simpLocation
+ simpLocation ctx simprocs discharge? (Tactic.Location.targets hyps simplifyTarget)
| Location.wildcard =>
- withMainContext do
- go (← (← getMainGoal).getNondepPropHyps) (simplifyTarget := true)
+ -- Simply call the regular simpLocation
+ simpLocation ctx simprocs discharge? Tactic.Location.wildcard
| Location.wildcard_dep =>
+ -- Custom behavior
withMainContext do
- let ctx ← Lean.MonadLCtx.getLCtx
- let decls ← ctx.getDecls
+ -- Lookup *all* the declarations
+ let lctx ← Lean.MonadLCtx.getLCtx
+ let decls ← lctx.getDecls
let tgts := (decls.map (fun d => d.fvarId)).toArray
- go tgts (simplifyTarget := true)
-where
- go (fvarIdsToSimp : Array FVarId) (simplifyTarget : Bool) : TacticM Simp.UsedSimps := do
- let mvarId ← getMainGoal
- let (result?, usedSimps) ← simpGoal mvarId ctx (simplifyTarget := simplifyTarget) (discharge? := discharge?) (fvarIdsToSimp := fvarIdsToSimp)
- match result? with
- | none => replaceMainGoal []
- | some (_, mvarId) => replaceMainGoal [mvarId]
- return usedSimps
+ -- Call the regular simpLocation.go
+ simpLocation.go ctx simprocs discharge? tgts (simplifyTarget := true)
/- Call the simp tactic. -/
-def simpAt (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId)
- (loc : Location) :
+def simpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : Simp.SimprocsArray)
+ (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Location) :
Tactic.TacticM Unit := do
-- Initialize the simp context
- let ctx ← mkSimpCtx simpOnly declsToUnfold thms hypsToUse
+ let ctx ← mkSimpCtx simpOnly config declsToUnfold thms hypsToUse
-- Apply the simplifier
- let _ ← customSimpLocation ctx (discharge? := .none) loc
+ let _ ← customSimpLocation ctx simprocs (discharge? := .none) loc
/- Call the dsimp tactic. -/
-def dsimpAt (simpOnly : Bool) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId)
- (loc : Tactic.Location) :
+def dsimpAt (simpOnly : Bool) (config : Simp.Config) (simprocs : Simp.SimprocsArray)
+ (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) (loc : Tactic.Location) :
Tactic.TacticM Unit := do
-- Initialize the simp context
- let ctx ← mkSimpCtx simpOnly declsToUnfold thms hypsToUse
+ let ctx ← mkSimpCtx simpOnly config declsToUnfold thms hypsToUse
-- Apply the simplifier
- dsimpLocation ctx loc
+ dsimpLocation ctx simprocs loc
-- Call the simpAll tactic
-def simpAll (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) :
+def simpAll (config : Simp.Config) (declsToUnfold : List Name) (thms : List Name) (hypsToUse : List FVarId) :
Tactic.TacticM Unit := do
-- Initialize the simp context
- let ctx ← mkSimpCtx false declsToUnfold thms hypsToUse
+ let ctx ← mkSimpCtx false config declsToUnfold thms hypsToUse
-- Apply the simplifier
let _ ← Lean.Meta.simpAll (← getMainGoal) ctx
+/- Adapted from Elab.Tactic.Rewrite -/
+def rewriteTarget (eqThm : Expr) (symm : Bool) (config : Rewrite.Config := {}) : TacticM Unit := do
+ Term.withSynthesize <| withMainContext do
+ let r ← (← getMainGoal).rewrite (← getMainTarget) eqThm symm (config := config)
+ let mvarId' ← (← getMainGoal).replaceTargetEq r.eNew r.eqProof
+ replaceMainGoal (mvarId' :: r.mvarIds)
+
+/- Adapted from Elab.Tactic.Rewrite -/
+def rewriteLocalDecl (eqThm : Expr) (symm : Bool) (fvarId : FVarId) (config : Rewrite.Config := {}) :
+ TacticM Unit := withMainContext do
+ -- Note: we cannot execute `replaceLocalDecl` inside `Term.withSynthesize`.
+ -- See issues #2711 and #2727.
+ let rwResult ← Term.withSynthesize <| withMainContext do
+ let localDecl ← fvarId.getDecl
+ (← getMainGoal).rewrite localDecl.type eqThm symm (config := config)
+ let replaceResult ← (← getMainGoal).replaceLocalDecl fvarId rwResult.eNew rwResult.eqProof
+ replaceMainGoal (replaceResult.mvarId :: rwResult.mvarIds)
+
+/- Adapted from Elab.Tactic.Rewrite -/
+def rewriteWithThms
+ (thms : List (Bool × Expr))
+ (rewrite : (symm : Bool) → (thm : Expr) → TacticM Unit)
+ : TacticM Unit := do
+ let rec go thms :=
+ match thms with
+ | [] => throwError "Failed to rewrite with any theorem"
+ | (symm, eqThm)::thms =>
+ rewrite symm eqThm <|> go thms
+ go thms
+
+/- Adapted from Elab.Tactic.Rewrite -/
+def evalRewriteSeqAux (cfg : Rewrite.Config) (thms : List (Bool × Expr)) (loc : Tactic.Location) : TacticM Unit :=
+ rewriteWithThms thms fun symm term => do
+ withLocation loc
+ (rewriteLocalDecl term symm · cfg)
+ (rewriteTarget term symm cfg)
+ (throwTacticEx `rewrite · "did not find instance of the pattern in the current goal")
+
+/-- `rpt`: if `true`, repeatedly rewrite -/
+def rewriteAt (cfg : Rewrite.Config) (rpt : Bool)
+ (thms : List (Bool × Name)) (loc : Tactic.Location) : TacticM Unit := do
+ -- Lookup the theorems
+ let lookupThm (x : Bool × Name) : TacticM (List (Bool × Expr)) := do
+ let thName := x.snd
+ let lookupOne (thName : Name) : TacticM (Bool × Expr) := do
+ -- Lookup the theorem and introduce fresh meta-variables for the universes
+ let th ← mkConstWithFreshMVarLevels thName
+ pure (x.fst, th)
+ match ← getEqnsFor? thName (nonRec := true) with
+ | some eqThms => do
+ eqThms.data.mapM lookupOne
+ | none => do
+ pure [← lookupOne thName]
+ let thms ← List.mapM lookupThm thms
+ let thms := thms.flatten
+ -- Rewrite
+ if rpt then
+ Utils.repeatTac (evalRewriteSeqAux cfg thms loc)
+ else
+ evalRewriteSeqAux cfg thms loc
+
end Utils