summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-06-29 23:15:20 +0200
committerSon Ho2023-06-29 23:15:20 +0200
commitfdc8693772ecb1978873018c790061854f00a015 (patch)
tree7ef00d00d7f939fc364faca43974bbdb871a48cf
parent0cee49de70bec6d3ec2221b64a532d19ad71e5e0 (diff)
Write function to compute the input/output types
-rw-r--r--backends/lean/Base/Diverge/Base.lean3
-rw-r--r--backends/lean/Base/Diverge/Elab.lean154
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean1
3 files changed, 126 insertions, 32 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index 630c0bf6..22b59bd0 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -560,6 +560,7 @@ namespace FixI
kk_ty id a b → kk_ty id a b
def in_out_ty : Type (imax (u + 1) (v + 1)) := (in_ty : Type u) × ((x:in_ty) → Type v)
+ -- TODO: remove?
@[simp] def mk_in_out_ty (in_ty : Type u) (out_ty : in_ty → Type v) :
in_out_ty :=
Sigma.mk in_ty out_ty
@@ -1143,7 +1144,7 @@ namespace Ex6
if i = 0 then .ret hd
else list_nth tl (i - 1)
:= by
- have Heq := is_valid_fix_fixed_eq (@list_nth_body_is_valid a)
+ have Heq := is_valid_fix_fixed_eq list_nth_body_is_valid
simp [list_nth]
conv => lhs; rw [Heq]
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 22e0039f..116c5d8b 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -13,7 +13,7 @@ namespace Diverge
syntax (name := divergentDef)
declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command
-open Lean Elab Term Meta Primitives
+open Lean Elab Term Meta Primitives Lean.Meta
set_option trace.Diverge.def true
@@ -21,27 +21,9 @@ set_option trace.Diverge.def true
open WF in
-
-
-- Replace the recursive calls by a call to the continuation
-- def replace_rec_calls
-#check Lean.Meta.forallTelescope
-#check Expr
-#check withRef
-#check MonadRef.withRef
-#check Nat
-#check Array
-#check Lean.Meta.inferType
-#check Nat
-#check Int
-
-#check (0, 1)
-#check Prod
-#check ()
-#check Unit
-#check Sigma
-
-- print_decl is_even_body
#check instOfNatNat
#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ...
@@ -59,6 +41,100 @@ def mkProds (tys : List Expr) : MetaM Expr :=
let pty ← mkProds tys
mkAppM ``Prod.mk #[ty, pty]
+/- Generate the input type of a function body, which is a sigma type (i.e., a
+ dependent tuple) which groups all its inputs.
+
+ Example:
+ - xl = [(a:Type), (ls:List a), (i:Int)]
+
+ Generates:
+ `(a:Type) × (ls:List a) × (i:Int)`
+
+ -/
+def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr :=
+ match xl with
+ | [] => do
+ trace[Diverge.def.sigmas] "mkSigmasOfTypes: []"
+ return (Expr.const ``PUnit.unit [])
+ | [x] => do
+ trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]"
+ let ty ← Lean.Meta.inferType x
+ return ty
+ | x :: xl => do
+ trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]"
+ let alpha ← Lean.Meta.inferType x
+ let sty ← mkSigmasTypesOfTypes xl
+ trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}"
+ let beta ← mkLambdaFVars #[x] sty
+ trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})"
+ mkAppOptM ``Sigma #[some alpha, some beta]
+
+def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index
+
+/- Generate the out_ty of the body of a function, which from an input (a sigma
+ type generated by `mkSigmasTypesOfTypes`) gives the output type of the function.
+
+ Example:
+ - xl = `[a:Type, ls:List a, i:Int]`
+ - out_ty = `a`
+ - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables
+
+ Generates:
+ ```
+ match scrut0 with
+ | Sigma.mk x scrut1 =>
+ match scrut1 with
+ | Sigma.mk ls i =>
+ a
+ ```
+-/
+def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr :=
+ match xl with
+ | [] => do
+ -- This would be unexpected
+ throwError "mkSigmasOutType: empyt list of input parameters"
+ | [x] => do
+ -- In the explanations above: inner match case
+ trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]"
+ mkLambdaFVars #[x] out_ty
+ | fst :: xl => do
+ -- In the explanations above: outer match case
+ -- Remark: for the naming purposes, we use the same convention as for the
+ -- fields and parameters in `Sigma.casesOn and `Sigma.mk
+ trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]"
+ let alpha ← Lean.Meta.inferType fst
+ let snd_ty ← mkSigmasTypesOfTypes xl
+ let beta ← mkLambdaFVars #[fst] snd_ty
+ let snd ← mkSigmasOutType xl out_ty (index + 1)
+ let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl)
+ withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do
+ let mk ← mkLambdaFVars #[fst] snd
+ trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})"
+ let motive ← mkLambdaFVars #[scrut] (← inferType out_ty)
+ trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})"
+ let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk]
+ let out ← mkLambdaFVars #[scrut] out
+ trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}"
+ return out
+
+/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/
+private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) :=
+ @Sigma.casesOn (List a)
+ (fun (_ls : List a) => Int)
+ (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type)
+ scrut1
+ (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a)
+
+private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) =>
+ @Sigma (List a) (fun (_ls : List a) => Int))) :=
+ @Sigma.casesOn (Type)
+ (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))
+ (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type)
+ scrut0
+ (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) =>
+ list_nth_out_ty2 a scrut1)
+/- -/
+
def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value)
trace[Diverge.def] ("divRecursion: defs: " ++ msg)
@@ -94,7 +170,23 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let grLvlParams := def0.levelParams
trace[Diverge.def] "def0 type: {def0.type}"
- -- Small utility: compute the list of type parameters
+ -- Compute the list of pairs: (input type × output type)
+ let inOutTys : Array (Expr × Expr) ←
+ preDefs.mapM (fun preDef => do
+ -- Check the universe parameters - TODO: I'm not sure what the best thing
+ -- to do is. In practice, all the type parameters should be in Type 0, so
+ -- we shouldn't have universe issues.
+ if preDef.levelParams ≠ grLvlParams then
+ throwError "Non-uniform polymorphism in the universes"
+ forallTelescope preDef.type (fun in_tys out_ty => do
+ let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList)
+ let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty)
+ return (in_ty, out_ty)
+ )
+ )
+ trace[Diverge.def] "inOutTys: {inOutTys}"
+
+/- -- Small utility: compute the list of type parameters
let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) :=
Lean.Meta.forallTelescope ty fun tys out_ty => do
trace[Diverge.def] "types: {tys}"
@@ -138,16 +230,16 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
let (input_tys, output_tys) := List.unzip all_tys.toList
let input_tys : List Expr ← liftM (List.mapM mkProds input_tys)
- trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}"
+ trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/
-- Compute the names set
let names := preDefs.map PreDefinition.declName
let names := HashSet.empty.insertMany names
--
- for preDef in preDefs do
- trace[Diverge.def] "about to explore: {preDef.declName}"
- explore_term "" preDef.value
+ -- for preDef in preDefs do
+ -- trace[Diverge.def] "about to explore: {preDef.declName}"
+ -- explore_term "" preDef.value
-- Compute the bodies
@@ -267,6 +359,13 @@ elab_rules : command
else
Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns))
+divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
+ match ls with
+ | [] => .fail .panic
+ | x :: ls =>
+ if i = 0 then return x
+ else return (← list_nth ls (i - 1))
+
mutual
divergent def is_even (i : Int) : Result Bool :=
if i = 0 then return true else return (← is_odd (i - 1))
@@ -280,13 +379,6 @@ example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠
unfold is_even
sorry
-divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
- match ls with
- | [] => .fail .panic
- | x :: ls =>
- if i = 0 then return x
- else return (← list_nth ls (i - 1))
-
mutual
divergent def foo (i : Int) : Result Nat :=
if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index 84b73a30..441b25f0 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -5,6 +5,7 @@ namespace Diverge
open Lean Elab Term Meta
initialize registerTraceClass `Diverge.elab (inherited := true)
+initialize registerTraceClass `Diverge.def.sigmas (inherited := true)
initialize registerTraceClass `Diverge.def (inherited := true)
-- TODO: move