import Lean
import Lean.Meta.Tactic.Simp
import Init.Data.List.Basic
import Mathlib.Tactic.RunCmd
import Base.Diverge.Base
import Base.Diverge.ElabBase

namespace Diverge

/- Automating the generation of the encoding and the proofs so as to use nice
   syntactic sugar. -/

syntax (name := divergentDef)
  declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command

open Lean Elab Term Meta Primitives Lean.Meta

set_option trace.Diverge.def true

/- The following was copied from the `wfRecursion` function. -/

open WF in

-- Replace the recursive calls by a call to the continuation
-- def replace_rec_calls

-- print_decl is_even_body
#check instOfNatNat
#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ...
#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ...
#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat

/- Build the (non-dependent) product *type* of a list of types.

   - `[]` gives the unit type `PUnit.{1}` (i.e. `PUnit : Type`)
   - `[ty]` gives `ty` itself
   - `ty :: tys` gives `Prod ty (mkProds tys)`

   Note that we must generate the type constructors (`PUnit`, `Prod`) and not
   the value constructors (`PUnit.unit`, `Prod.mk`): the result is meant to be
   used as a type.
 -/
-- TODO: is there already such a utility somewhere?
-- TODO: change to mkSigmas
def mkProds (tys : List Expr) : MetaM Expr :=
  match tys with
  | [] => do
    -- `PUnit` is universe polymorphic: instantiate it at level 1 to get `PUnit : Type`
    return (Expr.const ``PUnit [Level.succ Level.zero])
  | [ty] => do return ty
  | ty :: tys => do
    let pty ← mkProds tys
    mkAppM ``Prod #[ty, pty]

/- Generate the input type of a function body, which is a sigma type (i.e., a
   dependent tuple) which groups all its inputs.

   Example:
   - xl = [(a:Type), (ls:List a), (i:Int)]

   Generates:
   `(a:Type) × (ls:List a) × (i:Int)`

   `xl` is a list of expressions (typically the free variables introduced by
   `forallTelescope`): we retrieve their types with `inferType`, and abstract
   over them with `mkLambdaFVars` to build the second argument of `Sigma`.
   For an empty list we return the unit type `PUnit.{1}` (i.e. `PUnit : Type`).
 -/
def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=
  match xl with
  | [] => do
    trace[Diverge.def.sigmas] "mkSigmasOfTypes: []"
    -- We need the *type* `PUnit` here (not the value `PUnit.unit`), and we
    -- have to provide its universe parameter.
    return (Expr.const ``PUnit [Level.succ Level.zero])
  | [x] => do
    trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]"
    let ty ← Lean.Meta.inferType x
    return ty
  | x :: xl => do
    trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]"
    let alpha ← Lean.Meta.inferType x
    let sty ← mkSigmasTypesOfTypes xl
    trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}"
    -- The type of the tail may depend on `x`: abstract over it
    let beta ← mkLambdaFVars #[x] sty
    trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})"
    mkAppOptM ``Sigma #[some alpha, some beta]

/- Build the name `_uniq.<index>` (used to name the "scrutinee" variables
   introduced by `mkSigmasOutType`). -/
def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index

/- Generate the out_ty of the body of a function, which from an input (a sigma
   type generated by `mkSigmasTypesOfTypes`) gives the output type of the
   function.

   Example:
   - xl = `[a:Type, ls:List a, i:Int]`
   - out_ty = `a`
   - index = 0 -- For naming purposes: we use it to number the "scrutinee" variables

   Generates:
   ```
   match scrut0 with
   | Sigma.mk x scrut1 =>
     match scrut1 with
     | Sigma.mk ls i =>
       a
   ```

   Throws an error if `xl` is empty.
 -/
def mkSigmasOutType (xl : List Expr) (out_ty : Expr)
  (index : Nat := 0) : MetaM Expr :=
  match xl with
  | [] => do
    -- This would be unexpected
    throwError "mkSigmasOutType: empty list of input parameters"
  | [x] => do
    -- In the explanations above: inner match case
    trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]"
    mkLambdaFVars #[x] out_ty
  | fst :: xl => do
    -- In the explanations above: outer match case
    -- Remark: for the naming purposes, we use the same convention as for the
    -- fields and parameters in `Sigma.casesOn` and `Sigma.mk`
    trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]"
    let alpha ← Lean.Meta.inferType fst
    let snd_ty ← mkSigmasTypesOfTypes xl
    let beta ← mkLambdaFVars #[fst] snd_ty
    -- Recursive call: the continuation for the second component of the pair
    let snd ← mkSigmasOutType xl out_ty (index + 1)
    let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl)
    -- Introduce the scrutinee, on which we do the case analysis
    withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do
      let mk ← mkLambdaFVars #[fst] snd
      trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})"
      -- The motive is constant: it ignores the scrutinee and returns the
      -- type of `out_ty`
      let motive ← mkLambdaFVars #[scrut] (← inferType out_ty)
      trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})"
      let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk]
      let out ← mkLambdaFVars #[scrut] out
      trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}"
      return out

/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/
private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) :=
  @Sigma.casesOn (List a)
                 (fun (_ls : List a) => Int)
                 (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type)
                 scrut1
                 (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a)

private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) =>
                              @Sigma (List a) (fun (_ls : List a) => Int))) :=
  @Sigma.casesOn (Type)
                 (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))
                 (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type)
                 scrut0
                 (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) =>
                  list_nth_out_ty2 a scrut1)

/- Process a group of (mutually recursive) divergent pre-definitions.

   Work in progress. For now, this function:
   - adds the pre-definitions as axioms and applies their `afterCompilation`
     attributes
   - checks that all the definitions use the same universe parameters as the
     first one (throws an error otherwise)
   - computes, for every definition, the pair (input type, output type) by
     means of `mkSigmasTypesOfTypes` and `mkSigmasOutType` (those pairs are
     only traced for now)
   - calls `addAndCompilePartialRec` on the pre-definitions

   Remark: we access `preDefs[0]!`, meaning we expect `preDefs` to be non-empty.
 -/
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
  let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
  trace[Diverge.def] ("divRecursion: defs: " ++ msg)

  -- CHANGE HERE This function should add definitions with these names/types/values ^^
  -- Temporarily add the predefinitions as axioms
  for preDef in preDefs do
    addAsAxiom preDef

  -- TODO: what is this?
  for preDef in preDefs do
    applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation

  -- Retrieve the name of the first definition, that we will use as the namespace
  -- for the definitions common to the group
  let def0 := preDefs[0]!
  let grName := def0.declName
  trace[Diverge.def] "group name: {grName}"

  /- Compute the type of the continuation.

     We do the following
     - we make sure all the definitions have the same universe parameters
       (we can make this more general later)
     - we group all the type parameters together, make sure all the
       definitions have the same type parameters, and enforce
       a uniform polymorphism (we can also lift this later).
       This would require generalizing a bit our indexed fixed point to
       make the output type parametric in the input.
     - we group all the non-type parameters: we parameterize the continuation
       by those
  -/
  let grLvlParams := def0.levelParams
  trace[Diverge.def] "def0 type: {def0.type}"

  -- Compute the list of pairs: (input type × output type)
  let inOutTys : Array (Expr × Expr) ←
      preDefs.mapM (fun preDef => do
        -- Check the universe parameters - TODO: I'm not sure what the best thing
        -- to do is. In practice, all the type parameters should be in Type 0, so
        -- we shouldn't have universe issues.
        if preDef.levelParams ≠ grLvlParams then
          throwError "Non-uniform polymorphism in the universes"
        forallTelescope preDef.type (fun in_tys out_ty => do
          let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList)
          let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty)
          return (in_ty, out_ty)
        )
      )
  trace[Diverge.def] "inOutTys: {inOutTys}"

  /-
  -- Small utility: compute the list of type parameters
  let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) :=
    Lean.Meta.forallTelescope ty fun tys out_ty => do
      trace[Diverge.def] "types: {tys}"
      /-
      let (_, params) ← StateT.run (do
        for x in tys do
          let ty ← Lean.Meta.inferType x
          match ty with
          | .sort _ => do
            let st ← StateT.get
            StateT.set (ty :: st)
          | _ => do break
        ) ([] : List Expr)
      let params := params.reverse
      trace[Diverge.def] " type parameters {params}"
      return params
      -/
      let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) :=
        match ls with
        | x :: tl => do
          let ty ← Lean.Meta.inferType x
          match ty with
          | .sort _ => do
            let (ty_params, params) ← get_params tl
            return (x :: ty_params, params)
          | _ => do return ([], ls)
        | _ => do return ([], [])
      let (ty_params, params) ← get_params tys.toList
      trace[Diverge.def] " parameters: {ty_params}; {params}"
      return (ty_params, params, out_ty)
  let (grTyParams, _, _) ← do
    getTypeParams def0.type

  -- Compute the input types and the output types
  let all_tys ← preDefs.mapM fun preDef => do
    let (tyParams, params, ret_ty) ← getTypeParams preDef.type
    -- TODO: this is not complete, there are more checks to perform
    if tyParams.length ≠ grTyParams.length then
      throwError "Non-uniform polymorphism"
    return (params, ret_ty)

  -- TODO: I think there are issues with the free variables
  let (input_tys, output_tys) := List.unzip all_tys.toList
  let input_tys : List Expr ← liftM (List.mapM mkProds input_tys)

  trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}"
  -/

  -- Compute the names set
  let names := preDefs.map PreDefinition.declName
  let names := HashSet.empty.insertMany names

  --
  -- for preDef in preDefs do
  --   trace[Diverge.def] "about to explore: {preDef.declName}"
  --   explore_term "" preDef.value

  -- Compute the bodies

  -- Process the definitions
  addAndCompilePartialRec preDefs

-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main
-- This is the only part where we make actual changes and hook into the equation compiler.
-- (I've removed all the well-founded stuff to make it easier to read though.)

open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDefs
  addAndCompilePartial addAsAxioms from Lean.Elab.PreDefinition.Main

/- Add a set of pre-definitions to the environment.

   We partition the pre-definitions into cliques, and call `divRecursion` on
   every clique. If `divRecursion` throws an exception, we log it and fall
   back to adding the definitions of the clique as partial definitions (or as
   axioms if that fails as well).
 -/
def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do
  for preDef in preDefs do
    trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
  let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef
  let preDefs ← betaReduceLetRecApps preDefs
  let cliques := partitionPreDefs preDefs
  for preDefs in cliques do
    trace[Diverge.elab] "{preDefs.map (·.declName)}"
    try
      trace[Diverge.elab] "calling divRecursion"
      withRef (preDefs[0]!.ref) do
        divRecursion preDefs
      trace[Diverge.elab] "divRecursion succeeded"
    catch ex =>
      -- If it failed, we log the exception, then attempt to register the
      -- definitions as partial definitions, or as axioms as a last resort.
      -- NOTE(review): `divRecursion` may already have added the definitions as
      -- axioms before failing; check whether we should restore the state
      -- saved *before* the call.
      trace[Diverge.elab] "divRecursion failed"
      logException ex
      let s ← saveState
      try
        -- All the definitions are `def`s, or all of them are `abbrev`s
        if (preDefs.all fun preDef => preDef.kind == DefKind.def) ||
           (preDefs.all fun preDef => preDef.kind == DefKind.abbrev) then
          -- try to add as partial definition
          try
            addAndCompilePartial preDefs (useSorry := true)
          catch _ =>
            -- Compilation failed try again just as axiom
            s.restore
            addAsAxioms preDefs
        else return ()
      catch _ => s.restore

-- The following two functions are copy&pasted from Lean.Elab.MutualDef

open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues
  instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef

/- Elaborate a group of (mutually recursive) divergent definitions: elaborate
   the headers and the bodies, compute the pre-definitions, and finally
   delegate to `addPreDefinitions` (defined above) to add them to the
   environment. If the elaboration of the bodies fails, we log the exception
   and use `sorry` for the values. -/
def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do
  let scopeLevelNames ← getLevelNames
  let headers ← elabHeaders views
  let headers ← levelMVarToParamHeaders views headers
  let allUserLevelNames := getAllUserLevelNames headers
  withFunLocalDecls headers fun funFVars => do
    for view in views, funFVar in funFVars do
      addLocalVarInfo view.declId funFVar
    let values ←
      try
        let values ← elabFunValues headers
        Term.synthesizeSyntheticMVarsNoPostponing
        values.mapM (instantiateMVars ·)
      catch ex =>
        logException ex
        headers.mapM fun header => mkSorry header.type (synthetic := true)
    let headers ← headers.mapM instantiateMVarsAtHeader
    let letRecsToLift ← getLetRecsToLift
    let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift
    checkLetRecsToLiftTypes funFVars letRecsToLift
    withUsed vars headers values letRecsToLift fun vars => do
      let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift
      for preDef in preDefs do
        trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}"
      let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs
      let preDefs ← instantiateMVarsAtPreDecls preDefs
      let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames
      for preDef in preDefs do
        trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}"
      checkForHiddenUnivLevels allUserLevelNames preDefs
      addPreDefinitions preDefs

/- Elaborate a list of `divergent def` commands: convert every command to a
   `DefView` (of kind `DefKind.def`), then call `Term.elabMutualDef`.
   Throws `unsupportedSyntax` if one of the commands is not a `divergent def`. -/
open Command in
def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do
  let views ← ds.mapM fun d => do
    let `($mods:declModifiers divergent def $id:declId $sig:optDeclSig $val:declVal) := d
      | throwUnsupportedSyntax
    let modifiers ← elabModifiers mods
    let (binders, type) := expandOptDeclSig sig
    let deriving? := none
    pure { ref := d, kind := DefKind.def, modifiers,
           declId := id, binders, type? := type, value := val, deriving? }
  runTermElabM fun vars => Term.elabMutualDef vars views

-- Special command so that we don't fall back to the built-in mutual when we produce an error.
-- Internal command: a `mutual ... end` block prefixed with `_divergent`.
-- It is elaborated by `Command.elabMutualDef`.
local syntax "_divergent" Parser.Command.mutual : command
elab_rules : command | `(_divergent mutual $decls* end) => Command.elabMutualDef decls

-- Rewrite a `mutual ... end` block to `_divergent mutual ... end` when it is
-- non-empty and all its declarations have the `divergentDef` kind. Otherwise,
-- this macro reports the syntax as unsupported.
macro_rules
  | `(mutual $decls* end) => do
    unless !decls.isEmpty && decls.all (·.1.getKind == ``divergentDef) do
      Macro.throwUnsupported
    `(command| _divergent mutual $decls* end)

open private setDeclIdName from Lean.Elab.Declaration

-- Elaborate a single `divergent def`: we wrap it in a `mutual ... end` block
-- (which is then handled by the macro above).
elab_rules : command
  | `($mods:declModifiers divergent%$tk def $id:declId $sig:optDeclSig $val:declVal) => do
    let (name, _) := expandDeclIdCore id
    -- Names prefixed with `_root_` are not supported
    if (`_root_).isPrefixOf name then throwUnsupportedSyntax
    -- Split the name into a namespace (`ns`) and a short name
    let view := extractMacroScopes name
    let .str ns shortName := view.name | throwUnsupportedSyntax
    let shortName' := { view with name := shortName }.review
    -- The declaration inside the `mutual` block uses the short name
    let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end)
    if ns matches .anonymous then
      Command.elabCommand cmd
    else
      -- Elaborate the command inside `namespace ns ... end ns`
      Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns))

-- Example: a single recursive divergent definition
divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
  match ls with
  | [] => .fail .panic
  | x :: ls =>
    if i = 0 then return x
    else return (← list_nth ls (i - 1))

-- Example: a group of mutually recursive divergent definitions
mutual

divergent def is_even (i : Int) : Result Bool :=
  if i = 0 then return true else return (← is_odd (i - 1))

divergent def is_odd (i : Int) : Result Bool :=
  if i = 0 then return false else return (← is_even (i - 1))

end

-- The proof below is left unfinished (it ends with `sorry`)
example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by
  induction i
  unfold is_even
  sorry

-- Example: another group of mutually recursive divergent definitions
mutual

divergent def foo (i : Int) : Result Nat :=
  if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10

divergent def bar (i : Int) : Result Nat :=
  if i > 20 then foo (i / 20) else .ret 42

end

end Diverge