From fdc8693772ecb1978873018c790061854f00a015 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 23:15:20 +0200 Subject: Write function to compute the input/output types --- backends/lean/Base/Diverge/Base.lean | 3 +- backends/lean/Base/Diverge/Elab.lean | 154 ++++++++++++++++++++++++------- backends/lean/Base/Diverge/ElabBase.lean | 1 + 3 files changed, 126 insertions(+), 32 deletions(-) diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index 630c0bf6..22b59bd0 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -560,6 +560,7 @@ namespace FixI kk_ty id a b → kk_ty id a b def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v) + -- TODO: remove? @[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) : in_out_ty := Sigma.mk in_ty out_ty @@ -1143,7 +1144,7 @@ namespace Ex6 if i = 0 then .ret hd else list_nth tl (i - 1) := by - have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a) + have Heq := is_valid_fix_fixed_eq list_nth_body_is_valid simp [list_nth] conv => lhs; rw [Heq] diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 22e0039f..116c5d8b 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -13,7 +13,7 @@ namespace Diverge syntax (name := divergentDef) declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command -open Lean Elab Term Meta Primitives +open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true @@ -21,27 +21,9 @@ set_option trace.Diverge.def true open WF in - - -- Replace the recursive calls by a call to the continuation -- def replace_rec_calls -#check Lean.Meta.forallTelescope -#check Expr -#check withRef -#check MonadRef.withRef -#check Nat -#check Array -#check Lean.Meta.inferType -#check Nat -#check Int - -#check (0, 1) -#check Prod -#check () -#check Unit -#check Sigma - -- print_decl is_even_body #check instOfNatNat #check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... @@ -59,6 +41,100 @@ def mkProds (tys : List Expr) : MetaM Expr := let pty ← mkProds tys mkAppM ``Prod.mk #[ty, pty] +/- Generate the input type of a function body, which is a sigma type (i.e., a + dependent tuple) which groups all its inputs. + + Example: + - xl = [(a:Type), (ls:List a), (i:Int)] + + Generates: + `(a:Type) × (ls:List a) × (i:Int)` + + -/ +def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" + return (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" + let ty ← Lean.Meta.inferType x + return ty + | x :: xl => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" + let alpha ← Lean.Meta.inferType x + let sty ← mkSigmasTypesOfTypes xl + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" + let beta ← mkLambdaFVars #[x] sty + trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + +def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index + +/- Generate the out_ty of the body of a function, which from an input (a sigma + type generated by `mkSigmasTypesOfTypes`) gives the output type of the function. + + Example: + - xl = `[a:Type, ls:List a, i:Int]` + - out_ty = `a` + - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables + + Generates: + ``` + match scrut0 with + | Sigma.mk x scrut1 => + match scrut1 with + | Sigma.mk ls i => + a + ``` +-/ +def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkSigmasOutType: empyt list of input parameters" + | [x] => do + -- In the explanations above: inner match case + trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]" + mkLambdaFVars #[x] out_ty + | fst :: xl => do + -- In the explanations above: outer match case + -- Remark: for the naming purposes, we use the same convention as for the + -- fields and parameters in `Sigma.casesOn and `Sigma.mk + trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let snd_ty ← mkSigmasTypesOfTypes xl + let beta ← mkLambdaFVars #[fst] snd_ty + let snd ← mkSigmasOutType xl out_ty (index + 1) + let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) + withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do + let mk ← mkLambdaFVars #[fst] snd + trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})" + let motive ← mkLambdaFVars #[scrut] (← inferType out_ty) + trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + let out ← mkLambdaFVars #[scrut] out + trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}" + return out + +/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/ +private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := + @Sigma.casesOn (List a) + (fun (_ls : List a) => Int) + (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) + scrut1 + (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) + +private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => + @Sigma (List a) (fun (_ls : List a) => Int))) := + @Sigma.casesOn (Type) + (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int)) + (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type) + scrut0 + (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) => + list_nth_out_ty2 a scrut1) +/- -/ + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) @@ -94,7 +170,23 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grLvlParams := def0.levelParams trace[Diverge.def] "def0 type: {def0.type}" - -- Small utility: compute the list of type parameters + -- Compute the list of pairs: (input type × output type) + let inOutTys : Array (Expr × Expr) ← + preDefs.mapM (fun preDef => do + -- Check the universe parameters - TODO: I'm not sure what the best thing + -- to do is. In practice, all the type parameters should be in Type 0, so + -- we shouldn't have universe issues. + if preDef.levelParams ≠ grLvlParams then + throwError "Non-uniform polymorphism in the universes" + forallTelescope preDef.type (fun in_tys out_ty => do + let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) + let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty) + return (in_ty, out_ty) + ) + ) + trace[Diverge.def] "inOutTys: {inOutTys}" + +/- -- Small utility: compute the list of type parameters let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := Lean.Meta.forallTelescope ty fun tys out_ty => do trace[Diverge.def] "types: {tys}" @@ -138,16 +230,16 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let (input_tys, output_tys) := List.unzip all_tys.toList let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) - trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" + trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/ -- Compute the names set let names := preDefs.map PreDefinition.declName let names := HashSet.empty.insertMany names -- - for preDef in preDefs do - trace[Diverge.def] "about to explore: {preDef.declName}" - explore_term "" preDef.value + -- for preDef in preDefs do + -- trace[Diverge.def] "about to explore: {preDef.declName}" + -- explore_term "" preDef.value -- Compute the bodies @@ -267,6 +359,13 @@ elab_rules : command else Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) +divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) @@ -280,13 +379,6 @@ example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ unfold is_even sorry -divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := - match ls with - | [] => .fail .panic - | x :: ls => - if i = 0 then return x - else return (← list_nth ls (i - 1)) - mutual divergent def foo (i : Int) : Result Nat := if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 84b73a30..441b25f0 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -5,6 +5,7 @@ namespace Diverge open Lean Elab Term Meta initialize registerTraceClass `Diverge.elab (inherited := true) +initialize registerTraceClass `Diverge.def.sigmas (inherited := true) initialize registerTraceClass `Diverge.def (inherited := true) -- TODO: move -- cgit v1.2.3