From a6de153f3bfda7feb27d16fcdf2131d37f99c7a3 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 11:22:32 +0200 Subject: Start working on Elab.lean --- backends/lean/Base/Diverge/Base.lean | 3 + backends/lean/Base/Diverge/Elab.lean | 138 ++++++++++++++++++++++++++++--- backends/lean/Base/Diverge/ElabBase.lean | 75 ++++++++++++++++- 3 files changed, 203 insertions(+), 13 deletions(-) (limited to 'backends/lean') diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 0f92e682..2e60f6e8 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -4,6 +4,9 @@ import Init.Data.List.Basic import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith +-- For debugging +import Base.Diverge.ElabBase + /- TODO: - we want an easier to use cases: diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 313c5a79..22e0039f 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -15,16 +15,53 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives -initialize registerTraceClass `Diverge.divRecursion (inherited := true) - -set_option trace.Diverge.divRecursion true +set_option trace.Diverge.def true /- The following was copied from the `wfRecursion` function. -/ open WF in + + + +-- Replace the recursive calls by a call to the continuation +-- def replace_rec_calls + +#check Lean.Meta.forallTelescope +#check Expr +#check withRef +#check MonadRef.withRef +#check Nat +#check Array +#check Lean.Meta.inferType +#check Nat +#check Int + +#check (0, 1) +#check Prod +#check () +#check Unit +#check Sigma + +-- print_decl is_even_body +#check instOfNatNat +#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... +#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ... +#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat + + +-- TODO: is there already such a utility somewhere? +-- TODO: change to mkSigmas +def mkProds (tys : List Expr) : MetaM Expr := + match tys with + | [] => do return (Expr.const ``PUnit.unit []) + | [ty] => do return ty + | ty :: tys => do + let pty ← mkProds tys + mkAppM ``Prod.mk #[ty, pty] + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) - logInfo ("divRecursion: defs: " ++ msg) + trace[Diverge.def] ("divRecursion: defs: " ++ msg) -- CHANGE HERE This function should add definitions with these names/types/values ^^ -- Temporarily add the predefinitions as axioms @@ -35,6 +72,85 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation + -- Retrieve the name of the first definition, that we will use as the namespace + -- for the definitions common to the group + let def0 := preDefs[0]! + let grName := def0.declName + trace[Diverge.def] "group name: {grName}" + + /- Compute the type of the continuation. + + We do the following + - we make sure all the definitions have the same universe parameters + (we can make this more general later) + - we group all the type parameters together, make sure all the + definitions have the same type parameters, and enforce + a uniform polymorphism (we can also lift this later). + This would require generalizing a bit our indexed fixed point to + make the output type parametric in the input. + - we group all the non-type parameters: we parameterize the continuation + by those + -/ + let grLvlParams := def0.levelParams + trace[Diverge.def] "def0 type: {def0.type}" + + -- Small utility: compute the list of type parameters + let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := + Lean.Meta.forallTelescope ty fun tys out_ty => do + trace[Diverge.def] "types: {tys}" +/- let (_, params) ← StateT.run (do + for x in tys do + let ty ← Lean.Meta.inferType x + match ty with + | .sort _ => do + let st ← StateT.get + StateT.set (ty :: st) + | _ => do break + ) ([] : List Expr) + let params := params.reverse + trace[Diverge.def] " type parameters {params}" + return params -/ + let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) := + match ls with + | x :: tl => do + let ty ← Lean.Meta.inferType x + match ty with + | .sort _ => do + let (ty_params, params) ← get_params tl + return (x :: ty_params, params) + | _ => do return ([], ls) + | _ => do return ([], []) + let (ty_params, params) ← get_params tys.toList + trace[Diverge.def] " parameters: {ty_params}; {params}" + return (ty_params, params, out_ty) + let (grTyParams, _, _) ← do + getTypeParams def0.type + + -- Compute the input types and the output types + let all_tys ← preDefs.mapM fun preDef => do + let (tyParams, params, ret_ty) ← getTypeParams preDef.type + -- TODO: this is not complete, there are more checks to perform + if tyParams.length ≠ grTyParams.length then + throwError "Non-uniform polymorphism" + return (params, ret_ty) + + -- TODO: I think there are issues with the free variables + let (input_tys, output_tys) := List.unzip all_tys.toList + let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) + + trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" + + -- Compute the names set + let names := preDefs.map PreDefinition.declName + let names := HashSet.empty.insertMany names + + -- + for preDef in preDefs do + trace[Diverge.def] "about to explore: {preDef.declName}" + explore_term "" preDef.value + + -- Compute the bodies + -- Process the definitions addAndCompilePartialRec preDefs @@ -47,21 +163,21 @@ open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDe def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do for preDef in preDefs do - trace[Elab.definition.body] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef let preDefs ← betaReduceLetRecApps preDefs let cliques := partitionPreDefs preDefs let mut hasErrors := false for preDefs in cliques do - trace[Elab.definition.scc] "{preDefs.map (·.declName)}" + trace[Diverge.elab] "{preDefs.map (·.declName)}" try - logInfo "calling divRecursion" + trace[Diverge.elab] "calling divRecursion" withRef (preDefs[0]!.ref) do divRecursion preDefs - logInfo "divRecursion succeeded" + trace[Diverge.elab] "divRecursion succeeded" catch ex => -- If it failed, we - logInfo "divRecursion failed" + trace[Diverge.elab] "divRecursion failed" hasErrors := true logException ex let s ← saveState @@ -106,12 +222,12 @@ def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM U withUsed vars headers values letRecsToLift fun vars => do let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift for preDef in preDefs do - trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs let preDefs ← instantiateMVarsAtPreDecls preDefs let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames for preDef in preDefs do - trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" checkForHiddenUnivLevels allUserLevelNames preDefs addPreDefinitions preDefs diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index e693dce2..84b73a30 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -2,8 +2,79 @@ import Lean namespace Diverge -open Lean +open Lean Elab Term Meta -initialize registerTraceClass `Diverge.divRecursion (inherited := true) +initialize registerTraceClass `Diverge.elab (inherited := true) +initialize registerTraceClass `Diverge.def (inherited := true) + +-- TODO: move +-- TODO: small helper +def explore_term (incr : String) (e : Expr) : TermElabM Unit := + match e with + | .bvar _ => do logInfo m!"{incr}bvar: {e}"; return () + | .fvar _ => do logInfo m!"{incr}fvar: {e}"; return () + | .mvar _ => do logInfo m!"{incr}mvar: {e}"; return () + | .sort _ => do logInfo m!"{incr}sort: {e}"; return () + | .const _ _ => do logInfo m!"{incr}const: {e}"; return () + | .app fn arg => do + logInfo m!"{incr}app: {e}" + explore_term (incr ++ " ") fn + explore_term (incr ++ " ") arg + | .lam _bName bTy body _binfo => do + logInfo m!"{incr}lam: {e}" + explore_term (incr ++ " ") bTy + explore_term (incr ++ " ") body + | .forallE _bName bTy body _bInfo => do + logInfo m!"{incr}forallE: {e}" + explore_term (incr ++ " ") bTy + explore_term (incr ++ " ") body + | .letE _dName ty val body _nonDep => do + logInfo m!"{incr}letE: {e}" + explore_term (incr ++ " ") ty + explore_term (incr ++ " ") val + explore_term (incr ++ " ") body + | .lit _ => do logInfo m!"{incr}lit: {e}"; return () + | .mdata _ e => do + logInfo m!"{incr}mdata: {e}" + explore_term (incr ++ " ") e + | .proj _ _ struct => do + logInfo m!"{incr}proj: {e}" + explore_term (incr ++ " ") struct + +def explore_decl (n : Name) : TermElabM Unit := do + logInfo m!"Name: {n}" + let env ← getEnv + let decl := env.constants.find! n + match decl with + | .defnInfo val => + logInfo m!"About to explore defn: {decl.name}" + logInfo m!"# Type:" + explore_term "" val.type + logInfo m!"# Value:" + explore_term "" val.value + | .axiomInfo _ => throwError m!"axiom: {n}" + | .thmInfo _ => throwError m!"thm: {n}" + | .opaqueInfo _ => throwError m!"opaque: {n}" + | .quotInfo _ => throwError m!"quot: {n}" + | .inductInfo _ => throwError m!"induct: {n}" + | .ctorInfo _ => throwError m!"ctor: {n}" + | .recInfo _ => throwError m!"rec: {n}" + +syntax (name := printDecl) "print_decl " ident : command + +open Lean.Elab.Command + +@[command_elab printDecl] def elabPrintDecl : CommandElab := fun stx => do + liftTermElabM do + let id := stx[1] + addCompletionInfo <| CompletionInfo.id id id.getId (danglingDot := false) {} none + let cs ← resolveGlobalConstWithInfos id + explore_decl cs[0]! + +private def test1 : Nat := 0 +private def test2 (x : Nat) : Nat := x + +print_decl test1 +print_decl test2 end Diverge -- cgit v1.2.3