diff options
| author | Son Ho | 2023-12-11 17:07:34 +0100 | 
|---|---|---|
| committer | Son Ho | 2023-12-11 17:07:34 +0100 | 
| commit | ee669c4dbf8be12a3dd7249c645fd7092ba3e8eb (patch) | |
| tree | 7c5fe4cb881b30698cfaad48ee84198aa2d0f331 /backends | |
| parent | c23a37617188a1bbf913b5c700522abc33bf39c9 (diff) | |
Cleanup a bit
Diffstat (limited to 'backends')
| -rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 63 | 
1 files changed, 47 insertions, 16 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 423a2514..3c23db64 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -21,6 +21,9 @@ open Utils  open WF in +def mkProdType (x y : Expr) : MetaM Expr := +  mkAppM ``Prod #[x, y] +  def mkProd (x y : Expr) : MetaM Expr :=    mkAppM ``Prod.mk #[x, y] @@ -47,6 +50,17 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do    else      pure (args.get! 0, args.get! 1) +/- Make a sigma type. + +   `x` should be a variable, and `ty` and type which (might) uses `x` + -/ +def mkSigmaType (x : Expr) (sty : Expr) : MetaM Expr := do +  trace[Diverge.def.sigmas] "mkSigmaType: {x} {sty}" +  let alpha ← inferType x +  let beta ← mkLambdaFVars #[x] sty +  trace[Diverge.def.sigmas] "mkSigmaType: ({alpha}) ({beta})" +  mkAppOptM ``Sigma #[some alpha, some beta] +  /- Generate a Sigma type from a list of *variables* (all the expressions     must be variables). @@ -64,16 +78,12 @@ def mkSigmasType (xl : List Expr) : MetaM Expr :=      pure (Expr.const ``PUnit [Level.succ .zero])    | [x] => do      trace[Diverge.def.sigmas] "mkSigmasType: [{x}]" -    let ty ← Lean.Meta.inferType x +    let ty ← inferType x      pure ty    | x :: xl => do      trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]" -    let alpha ← Lean.Meta.inferType x      let sty ← mkSigmasType xl -    trace[Diverge.def.sigmas] "mkSigmasType: [{x}::{xl}]: alpha={alpha}, sty={sty}" -    let beta ← mkLambdaFVars #[x] sty -    trace[Diverge.def.sigmas] "mkSigmasType: ({alpha}) ({beta})" -    mkAppOptM ``Sigma #[some alpha, some beta] +    mkSigmaType x sty  /- Generate a product type from a list of *variables* (this is similar to `mkSigmas`). @@ -90,11 +100,11 @@ def mkProdsType (xl : List Expr) : MetaM Expr :=      pure (Expr.const ``PUnit [Level.succ .zero])    | [x] => do      trace[Diverge.def.prods] "mkProdsType: [{x}]" -    let ty ← Lean.Meta.inferType x +    let ty ← inferType x      pure ty    | x :: xl => do      trace[Diverge.def.prods] "mkProdsType: [{x}::{xl}]" -    let ty ← Lean.Meta.inferType x +    let ty ← inferType x      let xl_ty ← mkProdsType xl      mkAppM ``Prod #[ty, xl_ty] @@ -114,7 +124,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr ×    let rec splitAux (in_tys : List Expr) : MetaM (HashSet FVarId × List Expr × List Expr) :=      match in_tys with      | [] => do -      let fvars ← getFVarIds (← Lean.Meta.inferType out_ty) +      let fvars ← getFVarIds (← inferType out_ty)        pure (fvars, [], [])      | ty :: in_tys => do        let (fvars, in_tys, in_args) ← splitAux in_tys @@ -132,7 +142,7 @@ def splitInputArgs (in_tys : Array Expr) (out_ty : Expr) : MetaM (Array Expr ×            pure (fvars, [ty], in_args)          else            -- We must split later: update the fvars set -          let fvars := fvars.insertMany (← getFVarIds (← Lean.Meta.inferType ty)) +          let fvars := fvars.insertMany (← getFVarIds (← inferType ty))            pure (fvars, [], ty :: in_args)    let (_, in_tys, in_args) ← splitAux in_tys.data    pure (Array.mk in_tys, Array.mk in_args) @@ -250,13 +260,13 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met         | Sigma.mk x ...  -- the hole is given by a recursive call on the tail         ``` -/      trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" -    let alpha ← Lean.Meta.inferType fst +    let alpha ← inferType fst      let snd_ty ← mkSigmasType xl      let beta ← mkLambdaFVars #[fst] snd_ty      let snd ← mkSigmasMatch xl out (index + 1)      let mk ← mkLambdaFVars #[fst] snd      -- Introduce the "scrut" variable -    let scrut_ty ← mkSigmasType (fst :: xl) -- TODO: factor out with snd_ty +    let scrut_ty ← mkSigmaType fst snd_ty      withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do      trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})"      -- TODO: make the computation of the motive more efficient @@ -294,12 +304,12 @@ partial def mkProdsMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Meta      mkLambdaFVars #[x] out    | fst :: xl => do      trace[Diverge.def.prods] "mkProdsMatch: [{fst}::{xl}]" -    let alpha ← Lean.Meta.inferType fst +    let alpha ← inferType fst      let beta ← mkProdsType xl      let snd ← mkProdsMatch xl out (index + 1)      let mk ← mkLambdaFVars #[fst] snd      -- Introduce the "scrut" variable -    let scrut_ty ← mkProdsType (fst :: xl) -- TODO: factor out with beta +    let scrut_ty ← mkProdType alpha beta      withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do      trace[Diverge.def.prods] "mkProdsMatch: scrut: ({scrut}) : ({← inferType scrut})"      -- TODO: make the computation of the motive more efficient @@ -1265,8 +1275,29 @@ elab_rules : command  namespace Tests    /- Some examples of partial functions -/ - -  divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := +  /- section HigherOrder +  open FixI + +  -- The index type +  variable {id : Type u} + +  -- The input/output types +  variable {a : id → Type v} {b : (i:id) → a i → Type w} + +  -- Example with a higher-order function +  theorem map_is_valid +    {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} +    (Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x)) +    (k : (a → Result b) → a → Result b) +    (ls : List a) : +    is_valid_p k (λ k => Ex5.map (f k) ls) := +    induction ls <;> simp [map] +    apply is_valid_p_bind <;> try simp_all +    intros +    apply is_valid_p_bind <;> try simp_all +  end HigherOrder -/ + +  divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=      match ls with      | [] => .fail .panic      | x :: ls =>  | 
