summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean63
1 files changed, 47 insertions, 16 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 423a2514..3c23db64 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -21,6 +21,9 @@ open Utils
open WF in
+def mkProdType (x y : Expr) : MetaM Expr :=
+ mkAppM ``Prod #[x, y]
+
def mkProd (x y : Expr) : MetaM Expr :=
mkAppM ``Prod.mk #[x, y]
@@ -47,6 +50,17 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do
else
pure (args.get! 0, args.get! 1)
+/- Make a sigma type.
+
+ `x` should be a variable, and `ty` and type which (might) uses `x`
+ -/
+def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do
+ trace[Diverge.def.sigmas] "mkSigmaType: {x} {sty}"
+ let alpha ← inferType x
+ let beta ← mkLambdaFVars #[x] sty
+ trace[Diverge.def.sigmas] "mkSigmaType: ({alpha}) ({beta})"
+ mkAppOptM ``Sigma #[some alpha, some beta]
+
/- Generate a Sigma type from a list of *variables* (all the expressions
must be variables).
@@ -64,16 +78,12 @@ def mkSigmasType (xl : List Expr) : MetaM Expr :=
pure (Expr.const ``PUnit [Level.succ .zero])
| [x] => do
trace[Diverge.def.sigmas] "mkSigmasType: [{x}]"
- let ty ← Lean.Meta.inferType x
+ let ty ← inferType x
pure ty
| x :: xl => do
trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]"
- let alpha ← Lean.Meta.inferType x
let sty ← mkSigmasType xl
- trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}"
- let beta ← mkLambdaFVars #[x] sty
- trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})"
- mkAppOptM ``Sigma #[some alpha, some beta]
+ mkSigmaType x sty
/- Generate a product type from a list of *variables* (this is similar to `mkSigmas`).
@@ -90,11 +100,11 @@ def mkProdsType (xl : List Expr) : MetaM Expr :=
pure (Expr.const ``PUnit [Level.succ .zero])
| [x] => do
trace[Diverge.def.prods] "mkProdsType: [{x}]"
- let ty ← Lean.Meta.inferType x
+ let ty ← inferType x
pure ty
| x :: xl => do
trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]"
- let ty ← Lean.Meta.inferType x
+ let ty ← inferType x
let xl_ty ← mkProdsType xl
mkAppM ``Prod #[ty, xl_ty]
@@ -114,7 +124,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr ×
let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) :=
match in_tys with
| [] => do
- let fvars ← getFVarIds (← Lean.Meta.inferType out_ty)
+ let fvars ← getFVarIds (← inferType out_ty)
pure (fvars, [], [])
| ty :: in_tys => do
let (fvars, in_tys, in_args) ← splitAux in_tys
@@ -132,7 +142,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr ×
pure (fvars, [ty], in_args)
else
-- We must split later: update the fvars set
- let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty))
+ let fvars := fvars.insertMany (← getFVarIds (← inferType ty))
pure (fvars, [], ty :: in_args)
let (_, in_tys, in_args) ← splitAux in_tys.data
pure (Array.mk in_tys, Array.mk in_args)
@@ -250,13 +260,13 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met
| Sigma.mk x ... -- the hole is given by a recursive call on the tail
``` -/
trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]"
- let alpha ← Lean.Meta.inferType fst
+ let alpha ← inferType fst
let snd_ty ← mkSigmasType xl
let beta ← mkLambdaFVars #[fst] snd_ty
let snd ← mkSigmasMatch xl out (index + 1)
let mk ← mkLambdaFVars #[fst] snd
-- Introduce the "scrut" variable
- let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty
+ let scrut_ty ← mkSigmaType fst snd_ty
withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do
trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})"
-- TODO: make the computation of the motive more efficient
@@ -294,12 +304,12 @@ partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Meta
mkLambdaFVars #[x] out
| fst :: xl => do
trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]"
- let alpha ← Lean.Meta.inferType fst
+ let alpha ← inferType fst
let beta ← mkProdsType xl
let snd ← mkProdsMatch xl out (index + 1)
let mk ← mkLambdaFVars #[fst] snd
-- Introduce the "scrut" variable
- let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta
+ let scrut_ty ← mkProdType alpha beta
withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do
trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})"
-- TODO: make the computation of the motive more efficient
@@ -1265,8 +1275,29 @@ elab_rules : command
namespace Tests
/- Some examples of partial functions -/
-
- divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a :=
+ /- section HigherOrder
+ open FixI
+
+ -- The index type
+ variable {id : Type u}
+
+ -- The input/output types
+ variable {a : id → Type v} {b : (i:id) → a i → Type w}
+
+ -- Example with a higher-order function
+ theorem map_is_valid
+ {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}}
+ (Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x))
+ (k : (a → Result b) → a → Result b)
+ (ls : List a) :
+ is_valid_p k (λ k => Ex5.map (f k) ls) :=
+ induction ls <;> simp [map]
+ apply is_valid_p_bind <;> try simp_all
+ intros
+ apply is_valid_p_bind <;> try simp_all
+ end HigherOrder -/
+
+ divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
match ls with
| [] => .fail .panic
| x :: ls =>