summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Arith.lean
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base/Arith.lean')
-rw-r--r--backends/lean/Base/Arith.lean221
1 files changed, 221 insertions, 0 deletions
diff --git a/backends/lean/Base/Arith.lean b/backends/lean/Base/Arith.lean
new file mode 100644
index 00000000..6339f218
--- /dev/null
+++ b/backends/lean/Base/Arith.lean
@@ -0,0 +1,221 @@
+/- This file contains tactics to solve arithmetic goals -/
+
+import Lean
+import Lean.Meta.Tactic.Simp
+import Init.Data.List.Basic
+import Mathlib.Tactic.RunCmd
+import Mathlib.Tactic.Linarith
+import Base.Primitives
+
+/-
+Mathlib tactics:
+- rcases: https://leanprover-community.github.io/mathlib_docs/tactics.html#rcases
+- split_ifs: https://leanprover-community.github.io/mathlib_docs/tactics.html#split_ifs
+- norm_num: https://leanprover-community.github.io/mathlib_docs/tactics.html#norm_num
+- should we use linarith or omega?
+- hint: https://leanprover-community.github.io/mathlib_docs/tactics.html#hint
+- classical: https://leanprover-community.github.io/mathlib_docs/tactics.html#classical
+-/
+
+namespace List
+
+ -- TODO: I could not find this function??
+ @[simp] def flatten {a : Type u} : List (List a) → List a
+ | [] => []
+ | x :: ls => x ++ flatten ls
+
+end List
+
+namespace Lean
+
+namespace LocalContext
+
+ open Lean Lean.Elab Command Term Lean.Meta
+
+ -- Small utility: return the list of declarations in the context, from
+ -- the last to the first.
+ def getAllDecls (lctx : Lean.LocalContext) : MetaM (List Lean.LocalDecl) :=
+ lctx.foldrM (fun d ls => do let d ← instantiateLocalDeclMVars d; pure (d :: ls)) []
+
+ -- Return the list of declarations in the context, but filter the
+ -- declarations which are considered as implementation details
+ def getDecls (lctx : Lean.LocalContext) : MetaM (List Lean.LocalDecl) := do
+ let ls ← lctx.getAllDecls
+ pure (ls.filter (fun d => not d.isImplementationDetail))
+
+end LocalContext
+
+end Lean
+
+namespace Arith
+
+open Primitives
+
+--set_option pp.explicit true
+--set_option pp.notation false
+--set_option pp.coercions false
+
+-- TODO: move
+instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val
+
+-- TODO: move
+/- Remark: we can't write the following instance because of restrictions about
+ the type class parameters (`ty` doesn't appear in the return type, which is
+ forbidden):
+
+ ```
+ instance Scalar.cast (ty : ScalarTy) : Coe (Scalar ty) Int where coe := λ v => v.val
+ ```
+ -/
+def Scalar.toInt {ty : ScalarTy} (x : Scalar ty) : Int := x.val
+
+-- We use this type-class to test if an expression is a scalar (if we manage
+-- to lookup an instance of this type-class, then it is)
+class IsScalar (a : Type) where
+
+instance (ty : ScalarTy) : IsScalar (Scalar ty) where
+
+--example (ty : ScalarTy) : IsScalar (Scalar ty) := _
+
+open Lean Lean.Elab Command Term Lean.Meta
+
+-- Return true if the expression is a scalar expression
+def isScalarExpr (e : Expr) : MetaM Bool := do
+ -- Try to convert the expression to a scalar
+ -- TODO: I tried to do it with Lean.Meta.mkAppM but it didn't work: how
+ -- do we allow Lean to perform (controlled) unfoldings for instantiation
+ -- purposes?
+ let r ← Lean.observing? do
+ let ty ← Lean.Meta.inferType e
+ let isScalar ← mkAppM `Arith.IsScalar #[ty]
+ let isScalar ← trySynthInstance isScalar
+ match isScalar with
+ | LOption.some x => some x
+ | _ => none
+ match r with
+ | .some _ => pure true
+ | _ => pure false
+
+-- Explore a term and return the set of scalar expressions found inside
+partial def collectScalarExprsAux (hs : HashSet Expr) (e : Expr) : MetaM (HashSet Expr) := do
+ -- We do it in a very simpler manner: we deconstruct applications,
+ -- and recursively explore the sub-expressions. Note that we do
+ -- not go inside foralls and abstractions (should we?).
+ e.withApp fun f args => do
+ let hs ← if ← isScalarExpr f then pure (hs.insert f) else pure hs
+ let hs ← args.foldlM collectScalarExprsAux hs
+ pure hs
+
+-- Explore a term and return the list of scalar expressions found inside
+def collectScalarExprs (e : Expr) : MetaM (HashSet Expr) :=
+ collectScalarExprsAux HashSet.empty e
+
+-- Collect the scalar expressions in the context
+def getScalarExprsFromMainCtx : Tactic.TacticM (HashSet Expr) := do
+ Lean.Elab.Tactic.withMainContext do
+ -- Get the local context
+ let ctx ← Lean.MonadLCtx.getLCtx
+ -- Just a matter of precaution
+ let ctx ← instantiateLCtxMVars ctx
+ -- Initialize the hashset
+ let hs := HashSet.empty
+ -- Explore the declarations
+ let decls ← ctx.getDecls
+ let hs ← decls.foldlM (fun hs d => collectScalarExprsAux hs d.toExpr) hs
+ -- Return
+ pure hs
+
+
+#check TSyntax
+#check mkIdent
+-- TODO: addDecl?
+-- Project the scalar expressions into the context, to retrieve the bound inequalities
+-- def projectScalarExpr (e: Expr) : Tactic.TacticM Unit := do
+-- let e ← `($e)
+-- let e ← Lean.Elab.Term.elabTerm `($e) none
+-- Lean.Elab.Tactic.evalCases `($e)
+
+elab "list_scalar_exprs" : tactic => do
+ let hs ← getScalarExprsFromMainCtx
+ hs.forM fun e => do
+ dbg_trace f!"+ Scalar expression: {e}"
+
+#check LocalContext
+
+elab "list_local_decls_1" : tactic =>
+ Lean.Elab.Tactic.withMainContext do
+ -- Get the local context
+ let ctx ← Lean.MonadLCtx.getLCtx
+ let ctx ← instantiateLCtxMVars ctx
+ let decls ← ctx.getDecls
+ -- Filter the scalar expressions
+ let decls ← decls.filterMapM fun decl: Lean.LocalDecl => do
+ let declExpr := decl.toExpr
+ let declName := decl.userName
+ let declType ← Lean.Meta.inferType declExpr
+ dbg_trace f!"+ local decl: name: {declName} | expr: {declExpr} | ty: {declType}"
+ -- Try to convert the expression to a scalar
+ -- TODO: I tried to do it with Lean.Meta.mkAppM but it didn't work: how
+ -- do we allow Lean to perform (controlled) unfoldings for instantiation
+ -- purposes?
+ let r ← Lean.observing? do
+ let isScalar ← mkAppM `Arith.IsScalar #[declType]
+ let isScalar ← trySynthInstance isScalar
+ match isScalar with
+ | LOption.some x => some x
+ | _ => none
+ match r with
+ | .some _ => dbg_trace f!" Scalar expression"; pure r
+ | _ => dbg_trace f!" Not a scalar"; pure .none
+ pure ()
+ -- match ← Lean.observing? (Lean.Meta.mkAppM `Scalar.toInt #[decl.toExpr]) with
+ -- | .none => dbg_trace f!" Not a scalar"
+ -- | .some _ => dbg_trace f!" Scalar expression"
+
+#check Lean.Environment.addDecl
+#check Expr
+#check LocalContext
+#check MVarId
+#check Lean.Elab.Tactic.setGoals
+#check Lean.Elab.Tactic.Context
+#check withLocalDecl
+#check Lean.MVarId.assert
+#check LocalDecl
+
+-- Insert x = 3 in the context
+elab "custom_let" : tactic =>
+ -- I don't think we need that
+ Lean.Elab.Tactic.withMainContext do
+ --
+ let type := (Expr.const `Nat [])
+ let val : Expr := .lit (.natVal 3)
+ let n := `x -- the name is "x"
+ withLetDecl n type val fun nval => do
+ -- For debugging
+ let lctx ← Lean.MonadLCtx.getLCtx
+ let fid := nval.fvarId!
+ let decl := lctx.get! fid
+ dbg_trace f!" nval: \"{decl.userName}\" ({nval}) : {decl.type} := {decl.value}"
+ --
+ -- Tranform the main goal `m0?` to `let x = nval in m1?`
+ let mvarId ← Tactic.getMainGoal
+ let newMVar ← mkFreshExprSyntheticOpaqueMVar (← mvarId.getType)
+ let newVal ← mkLetFVars #[nval] newMVar
+ -- Focus on the current goal
+ Tactic.focus do
+ -- Assign the main goal.
+ -- We must do this *after* we focused on the current goal, because
+ -- after we assigned the meta variable the goal is considered as solved
+ mvarId.assign newVal
+ -- Replace the list of goals with the new goal - we can do this because
+ -- we focused on the current goal
+ Lean.Elab.Tactic.setGoals [newMVar.mvarId!]
+
+example : Nat := by
+ custom_let
+ apply x
+
+example (x : Bool) : Nat := by
+ cases x <;> custom_let <;> apply x
+
+end Arith