From ee669c4dbf8be12a3dd7249c645fd7092ba3e8eb Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 11 Dec 2023 17:07:34 +0100 Subject: Cleanup a bit --- backends/lean/Base/Diverge/Elab.lean | 63 +++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 16 deletions(-) diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 423a2514..3c23db64 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -21,6 +21,9 @@ open Utils open WF in +def mkProdType (x y : Expr) : MetaM Expr := + mkAppM ``Prod #[x, y] + def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] @@ -47,6 +50,17 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do else pure (args.get! 0, args.get! 1) +/- Make a sigma type. + + `x` should be a variable, and `ty` and type which (might) uses `x` + -/ +def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do + trace[Diverge.def.sigmas] "mkSigmaType: {x} {sty}" + let alpha ← inferType x + let beta ← mkLambdaFVars #[x] sty + trace[Diverge.def.sigmas] "mkSigmaType: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + /- Generate a Sigma type from a list of *variables* (all the expressions must be variables). @@ -64,16 +78,12 @@ def mkSigmasType (xl : List Expr) : MetaM Expr := pure (Expr.const ``PUnit [Level.succ .zero]) | [x] => do trace[Diverge.def.sigmas] "mkSigmasType: [{x}]" - let ty ← Lean.Meta.inferType x + let ty ← inferType x pure ty | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]" - let alpha ← Lean.Meta.inferType x let sty ← mkSigmasType xl - trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}" - let beta ← mkLambdaFVars #[x] sty - trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})" - mkAppOptM ``Sigma #[some alpha, some beta] + mkSigmaType x sty /- Generate a product type from a list of *variables* (this is similar to `mkSigmas`). @@ -90,11 +100,11 @@ def mkProdsType (xl : List Expr) : MetaM Expr := pure (Expr.const ``PUnit [Level.succ .zero]) | [x] => do trace[Diverge.def.prods] "mkProdsType: [{x}]" - let ty ← Lean.Meta.inferType x + let ty ← inferType x pure ty | x :: xl => do trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]" - let ty ← Lean.Meta.inferType x + let ty ← inferType x let xl_ty ← mkProdsType xl mkAppM ``Prod #[ty, xl_ty] @@ -114,7 +124,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) := match in_tys with | [] => do - let fvars ← getFVarIds (← Lean.Meta.inferType out_ty) + let fvars ← getFVarIds (← inferType out_ty) pure (fvars, [], []) | ty :: in_tys => do let (fvars, in_tys, in_args) ← splitAux in_tys @@ -132,7 +142,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr × pure (fvars, [ty], in_args) else -- We must split later: update the fvars set - let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty)) + let fvars := fvars.insertMany (← getFVarIds (← inferType ty)) pure (fvars, [], ty :: in_args) let (_, in_tys, in_args) ← splitAux in_tys.data pure (Array.mk in_tys, Array.mk in_args) @@ -250,13 +260,13 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met | Sigma.mk x ... -- the hole is given by a recursive call on the tail ``` -/ trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst + let alpha ← inferType fst let snd_ty ← mkSigmasType xl let beta ← mkLambdaFVars #[fst] snd_ty let snd ← mkSigmasMatch xl out (index + 1) let mk ← mkLambdaFVars #[fst] snd -- Introduce the "scrut" variable - let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty + let scrut_ty ← mkSigmaType fst snd_ty withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient @@ -294,12 +304,12 @@ partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Meta mkLambdaFVars #[x] out | fst :: xl => do trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst + let alpha ← inferType fst let beta ← mkProdsType xl let snd ← mkProdsMatch xl out (index + 1) let mk ← mkLambdaFVars #[fst] snd -- Introduce the "scrut" variable - let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta + let scrut_ty ← mkProdType alpha beta withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient @@ -1265,8 +1275,29 @@ elab_rules : command namespace Tests /- Some examples of partial functions -/ - - divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + /- section HigherOrder + open FixI + + -- The index type + variable {id : Type u} + + -- The input/output types + variable {a : id → Type v} {b : (i:id) → a i → Type w} + + -- Example with a higher-order function + theorem map_is_valid + {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} + (Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x)) + (k : (a → Result b) → a → Result b) + (ls : List a) : + is_valid_p k (λ k => Ex5.map (f k) ls) := + induction ls <;> simp [map] + apply is_valid_p_bind <;> try simp_all + intros + apply is_valid_p_bind <;> try simp_all + end HigherOrder -/ + + divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with | [] => .fail .panic | x :: ls => -- cgit v1.2.3