summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
Diffstat (limited to 'backends/lean/Base')
-rw-r--r--backends/lean/Base/Diverge/Base.lean3
-rw-r--r--backends/lean/Base/Diverge/Elab.lean138
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean75
3 files changed, 203 insertions, 13 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 0f92e682..2e60f6e8 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -4,6 +4,9 @@ import Init.Data.List.Basic
import Mathlib.Tactic.RunCmd
import Mathlib.Tactic.Linarith
+-- For debugging
+import Base.Diverge.ElabBase
+
/-
TODO:
- we want an easier to use cases:
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 313c5a79..22e0039f 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -15,16 +15,53 @@ syntax (name := divergentDef)
open Lean Elab Term Meta Primitives
-initialize registerTraceClass `Diverge.divRecursion (inherited := true)
-
-set_option trace.Diverge.divRecursion true
+set_option trace.Diverge.def true
/- The following was copied from the `wfRecursion` function. -/
open WF in
+
+
+
+-- Replace the recursive calls by a call to the continuation
+-- def replace_rec_calls
+
+#check Lean.Meta.forallTelescope
+#check Expr
+#check withRef
+#check MonadRef.withRef
+#check Nat
+#check Array
+#check Lean.Meta.inferType
+#check Nat
+#check Int
+
+#check (0, 1)
+#check Prod
+#check ()
+#check Unit
+#check Sigma
+
+-- print_decl is_even_body
+#check instOfNatNat
+#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ...
+#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ...
+#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat
+
+
+-- TODO: is there already such a utility somewhere?
+-- TODO: change to mkSigmas
+def mkProds (tys : List Expr) : MetaM Expr :=
+ match tys with
+ | [] => do return (Expr.const ``PUnit.unit [])
+ | [ty] => do return ty
+ | ty :: tys => do
+ let pty ← mkProds tys
+ mkAppM ``Prod.mk #[ty, pty]
+
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
- logInfo ("divRecursion: defs: " ++ msg)
+ trace[Diverge.def] ("divRecursion: defs: " ++ msg)
-- CHANGE HERE This function should add definitions with these names/types/values ^^
-- Temporarily add the predefinitions as axioms
@@ -35,6 +72,85 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
for preDef in preDefs do
applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation
+ -- Retrieve the name of the first definition, that we will use as the namespace
+ -- for the definitions common to the group
+ let def0 := preDefs[0]!
+ let grName := def0.declName
+ trace[Diverge.def] "group name: {grName}"
+
+ /- Compute the type of the continuation.
+
+ We do the following
+ - we make sure all the definitions have the same universe parameters
+ (we can make this more general later)
+ - we group all the type parameters together, make sure all the
+ definitions have the same type parameters, and enforce
+ a uniform polymorphism (we can also lift this later).
+ This would require generalizing a bit our indexed fixed point to
+ make the output type parametric in the input.
+ - we group all the non-type parameters: we parameterize the continuation
+ by those
+ -/
+ let grLvlParams := def0.levelParams
+ trace[Diverge.def] "def0 type: {def0.type}"
+
+ -- Small utility: compute the list of type parameters
+ let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) :=
+ Lean.Meta.forallTelescope ty fun tys out_ty => do
+ trace[Diverge.def] "types: {tys}"
+/- let (_, params) ← StateT.run (do
+ for x in tys do
+ let ty ← Lean.Meta.inferType x
+ match ty with
+ | .sort _ => do
+ let st ← StateT.get
+ StateT.set (ty :: st)
+ | _ => do break
+ ) ([] : List Expr)
+ let params := params.reverse
+ trace[Diverge.def] " type parameters {params}"
+ return params -/
+ let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) :=
+ match ls with
+ | x :: tl => do
+ let ty ← Lean.Meta.inferType x
+ match ty with
+ | .sort _ => do
+ let (ty_params, params) ← get_params tl
+ return (x :: ty_params, params)
+ | _ => do return ([], ls)
+ | _ => do return ([], [])
+ let (ty_params, params) ← get_params tys.toList
+ trace[Diverge.def] " parameters: {ty_params}; {params}"
+ return (ty_params, params, out_ty)
+ let (grTyParams, _, _) ← do
+ getTypeParams def0.type
+
+ -- Compute the input types and the output types
+ let all_tys ← preDefs.mapM fun preDef => do
+ let (tyParams, params, ret_ty) ← getTypeParams preDef.type
+ -- TODO: this is not complete, there are more checks to perform
+ if tyParams.length ≠ grTyParams.length then
+ throwError "Non-uniform polymorphism"
+ return (params, ret_ty)
+
+ -- TODO: I think there are issues with the free variables
+ let (input_tys, output_tys) := List.unzip all_tys.toList
+ let input_tys : List Expr ← liftM (List.mapM mkProds input_tys)
+
+ trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}"
+
+ -- Compute the names set
+ let names := preDefs.map PreDefinition.declName
+ let names := HashSet.empty.insertMany names
+
+ --
+ for preDef in preDefs do
+ trace[Diverge.def] "about to explore: {preDef.declName}"
+ explore_term "" preDef.value
+
+ -- Compute the bodies
+
-- Process the definitions
addAndCompilePartialRec preDefs
@@ -47,21 +163,21 @@ open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDe
def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do
for preDef in preDefs do
- trace[Elab.definition.body] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef
let preDefs ← betaReduceLetRecApps preDefs
let cliques := partitionPreDefs preDefs
let mut hasErrors := false
for preDefs in cliques do
- trace[Elab.definition.scc] "{preDefs.map (·.declName)}"
+ trace[Diverge.elab] "{preDefs.map (·.declName)}"
try
- logInfo "calling divRecursion"
+ trace[Diverge.elab] "calling divRecursion"
withRef (preDefs[0]!.ref) do
divRecursion preDefs
- logInfo "divRecursion succeeded"
+ trace[Diverge.elab] "divRecursion succeeded"
catch ex =>
-- If it failed, we
- logInfo "divRecursion failed"
+ trace[Diverge.elab] "divRecursion failed"
hasErrors := true
logException ex
let s ← saveState
@@ -106,12 +222,12 @@ def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM U
withUsed vars headers values letRecsToLift fun vars => do
let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
for preDef in preDefs do
- trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
let preDefs ← instantiateMVarsAtPreDecls preDefs
let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
for preDef in preDefs do
- trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
+ trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
checkForHiddenUnivLevels allUserLevelNames preDefs
addPreDefinitions preDefs
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index e693dce2..84b73a30 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -2,8 +2,79 @@ import Lean
namespace Diverge
-open Lean
+open Lean Elab Term Meta
-initialize registerTraceClass `Diverge.divRecursion (inherited := true)
+initialize registerTraceClass `Diverge.elab (inherited := true)
+initialize registerTraceClass `Diverge.def (inherited := true)
+
+-- TODO: move
+-- TODO: small helper
+def explore_term (incr : String) (e : Expr) : TermElabM Unit :=
+ match e with
+ | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return ()
+ | .fvar _ => do logInfo m!"{incr}fvar: {e}"; return ()
+ | .mvar _ => do logInfo m!"{incr}mvar: {e}"; return ()
+ | .sort _ => do logInfo m!"{incr}sort: {e}"; return ()
+ | .const _ _ => do logInfo m!"{incr}const: {e}"; return ()
+ | .app fn arg => do
+ logInfo m!"{incr}app: {e}"
+ explore_term (incr ++ " ") fn
+ explore_term (incr ++ " ") arg
+ | .lam _bName bTy body _binfo => do
+ logInfo m!"{incr}lam: {e}"
+ explore_term (incr ++ " ") bTy
+ explore_term (incr ++ " ") body
+ | .forallE _bName bTy body _bInfo => do
+ logInfo m!"{incr}forallE: {e}"
+ explore_term (incr ++ " ") bTy
+ explore_term (incr ++ " ") body
+ | .letE _dName ty val body _nonDep => do
+ logInfo m!"{incr}letE: {e}"
+ explore_term (incr ++ " ") ty
+ explore_term (incr ++ " ") val
+ explore_term (incr ++ " ") body
+ | .lit _ => do logInfo m!"{incr}lit: {e}"; return ()
+ | .mdata _ e => do
+ logInfo m!"{incr}mdata: {e}"
+ explore_term (incr ++ " ") e
+ | .proj _ _ struct => do
+ logInfo m!"{incr}proj: {e}"
+ explore_term (incr ++ " ") struct
+
+def explore_decl (n : Name) : TermElabM Unit := do
+ logInfo m!"Name: {n}"
+ let env ← getEnv
+ let decl := env.constants.find! n
+ match decl with
+ | .defnInfo val =>
+ logInfo m!"About to explore defn: {decl.name}"
+ logInfo m!"# Type:"
+ explore_term "" val.type
+ logInfo m!"# Value:"
+ explore_term "" val.value
+ | .axiomInfo _ => throwError m!"axiom: {n}"
+ | .thmInfo _ => throwError m!"thm: {n}"
+ | .opaqueInfo _ => throwError m!"opaque: {n}"
+ | .quotInfo _ => throwError m!"quot: {n}"
+ | .inductInfo _ => throwError m!"induct: {n}"
+ | .ctorInfo _ => throwError m!"ctor: {n}"
+ | .recInfo _ => throwError m!"rec: {n}"
+
+syntax (name := printDecl) "print_decl " ident : command
+
+open Lean.Elab.Command
+
+@[command_elab printDecl] def elabPrintDecl : CommandElab := fun stx => do
+ liftTermElabM do
+ let id := stx[1]
+ addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none
+ let cs ← resolveGlobalConstWithInfos id
+ explore_decl cs[0]!
+
+private def test1 : Nat := 0
+private def test2 (x : Nat) : Nat := x
+
+print_decl test1
+print_decl test2
end Diverge