summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge
diff options
context:
space:
mode:
authorSon Ho2023-12-11 18:34:10 +0100
committerSon Ho2023-12-11 18:34:10 +0100
commit78367ef21c147b26040e0f6062a907fceab1f390 (patch)
tree6dfb4095d4df161023164b512847ba90f3f72300 /backends/lean/Base/Diverge
parentee669c4dbf8be12a3dd7249c645fd7092ba3e8eb (diff)
Start working on higher-order examples for Diverge
Diffstat (limited to 'backends/lean/Base/Diverge')
-rw-r--r--backends/lean/Base/Diverge/Base.lean3
-rw-r--r--backends/lean/Base/Diverge/Elab.lean501
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean63
3 files changed, 353 insertions, 214 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean
index a7107c1e..bdc3ed04 100644
--- a/backends/lean/Base/Diverge/Base.lean
+++ b/backends/lean/Base/Diverge/Base.lean
@@ -5,6 +5,7 @@ import Mathlib.Tactic.RunCmd
import Mathlib.Tactic.Linarith
import Base.Primitives.Base
import Base.Arith.Base
+import Base.Diverge.ElabBase
/- TODO: this is very useful, but is there more? -/
set_option profiler true
@@ -1467,6 +1468,7 @@ namespace Ex8
.ret (hd :: tl)
/- The validity theorem for `map`, generic in `f` -/
+ @[divspec]
theorem map_is_valid
(i : id) (t : ty i)
{{f : (a i t → Result (b i t)) → (a i t) → Result c}}
@@ -1479,6 +1481,7 @@ namespace Ex8
intros
apply is_valid_p_bind <;> try simp_all
+ @[divspec]
theorem map_is_valid'
(i : id) (t : ty i)
(k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t))
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 3c23db64..97364d14 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -21,6 +21,10 @@ open Utils
open WF in
+-- TODO: use those
+def UnitType := Expr.const ``PUnit [Level.succ .zero]
+def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero]
+
def mkProdType (x y : Expr) : MetaM Expr :=
mkAppM ``Prod #[x, y]
@@ -382,29 +386,71 @@ def mkFinVal (n i : Nat) : MetaM Expr := do
let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit]
mkAppOptM ``OfNat.ofNat #[none, none, ofNat]
+/- Information about the type of a function in a declaration group.
+
+ In the comments about the fields, we take as example the
+ `list_nth (α : Type) (ls : List α) (i : Int) : Result α` function.
+ -/
+structure TypeInfo where
+ /- The total number of input arguments.
+
+ For list_nth: 3
+ -/
+ total_num_args : ℕ
+ /- The number of type parameters (they should be a prefix of the input arguments).
+
+ For `list_nth`: 1
+ -/
+ num_params : ℕ
+ /- The type of the dependent tuple grouping the type parameters.
+
+ For `list_nth`: `Type`
+ -/
+ params_ty : Expr
+ /- The type of the tuple grouping the input values. This is a function taking
+ as input a value of type `params_ty`.
+
+ For `list_nth`: `λ a => List a × Int`
+ -/
+ in_ty : Expr
+ /- The output type, without the `Return`. This is a function taking
+ as input a value of type `params_ty`.
+
+ For `list_nth`: `λ a => a`
+ -/
+ out_ty : Expr
+
+def mkInOutTyFromTypeInfo (info : TypeInfo) : MetaM Expr := do
+ mkInOutTy info.params_ty info.in_ty info.out_ty
+
+instance : Inhabited TypeInfo :=
+ { default := { total_num_args := 0, num_params := 0, params_ty := UnitType,
+ in_ty := UnitType, out_ty := UnitType } }
+
+instance : ToMessageData TypeInfo :=
+ ⟨ λ ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩ =>
+ f!"\{ total_num_args: {total_num_args}, num_params: {num_params}, params_ty: {params_ty}, in_ty: {in_ty}, out_ty: {out_ty} }}"
+ ⟩
+
/- Generate and declare as individual definitions the bodies for the individual funcions:
- replace the recursive calls with calls to the continutation `k`
- make those bodies take one single dependent tuple as input
We name the declarations: "[original_name].body".
We return the new declarations.
-
- Inputs:
- - `paramInOutTys`: (number of type parameters, sigma type grouping the type parameters, input type, output type)
-/
def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
- (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) :
+ (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) :
MetaM (Array Expr) := do
let grSize := preDefs.size
- /- Compute the map from name to (index, num type parameters, parameters type, input type).
- Example for `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: `"list_nth" → (0, 1, α, (List α × Int))`
+ /- Compute the map from name to (index, type info).
+
Remark: the continuation has an indexed type; we use the index (a finite number of
type `Fin`) to control which function we call at the recursive call site. -/
- let nameToInfo : HashMap Name (Nat × Nat × Expr × Expr) :=
+ let nameToInfo : HashMap Name (Nat × TypeInfo) :=
let bl := preDefs.mapIdx fun i d =>
- let (num_params, params_ty, in_ty, _) := paramInOutTys.get! i.val
- (d.declName, (i.val, num_params, params_ty, in_ty))
+ (d.declName, (i.val, paramInOutTys.get! i.val))
HashMap.ofList bl.toList
trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}"
@@ -423,18 +469,29 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
let name := f.constName!
match nameToInfo.find? name with
| none => pure e
- | some (id, num_params, params_ty, _in_ty) =>
+ | some (id, type_info) =>
trace[Diverge.def.genBody.visit] "this is a recursive call"
-- This is a recursive call: replace it
-- Compute the index
let i ← mkFinVal grSize id
+ -- It can happen that there are no input values given to the
+ -- recursive calls, and only type parameters.
+ let num_args := args.size
+ if num_args ≠ type_info.total_num_args ∧ num_args ≠ type_info.num_params then
+ throwError "Invalid number of arguments for the recursive call: {e}"
-- Split the arguments, and put them in two tuples (the first
-- one is a dependent tuple)
- let (param_args, args) := args.toList.splitAt num_params
+ let (param_args, args) := args.toList.splitAt type_info.num_params
trace[Diverge.def.genBody.visit] "param_args: {param_args}, args: {args}"
- let param_args ← mkSigmasVal params_ty param_args
- let args ← mkProdsVal args
- mkAppM' kk_var #[i, param_args, args]
+ let param_args ← mkSigmasVal type_info.params_ty param_args
+ -- Check if there are input values
+ if num_args = type_info.total_num_args then do
+ trace[Diverge.def.genBody.visit] "Recursive call with input values"
+ let args ← mkProdsVal args
+ mkAppM' kk_var #[i, param_args, args]
+ else do
+ trace[Diverge.def.genBody.visit] "Recursive call without input values"
+ mkAppM' kk_var #[i, param_args]
else
-- Not a recursive call: do nothing
pure e
@@ -458,8 +515,8 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-- (over which we match to retrieve the individual arguments).
lambdaTelescope body fun args body => do
-- Split the arguments between the type parameters and the "regular" inputs
- let (_, num_params, _, _) := nameToInfo.find! preDef.declName
- let (param_args, args) := args.toList.splitAt num_params
+ let (_, type_info) := nameToInfo.find! preDef.declName
+ let (param_args, args) := args.toList.splitAt type_info.num_params
let body ← mkProdsMatchOrUnit args body
trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}"
let body ← mkSigmasMatchOrUnit param_args body
@@ -494,27 +551,30 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
-/
def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name)
(kk_var i_var : Expr)
- (param_ty in_ty out_ty : Expr) (paramInOutTys : List (Nat × Expr × Expr × Expr))
+ (param_ty in_ty out_ty : Expr) (paramInOutTys : Array TypeInfo)
(bodies : Array Expr) : MetaM (Expr × Expr) := do
-- Generate the body
let grSize := bodies.size
let finTypeExpr := mkFin grSize
-- TODO: not very clean
let paramInOutTyType ← do
- let (_, x, y, z) := paramInOutTys.get! 0
- inferType (← mkInOutTy x y z)
- let rec mkFuns (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bl : List Expr) : MetaM Expr :=
+ let info := paramInOutTys.get! 0
+ inferType (← mkInOutTyFromTypeInfo info)
+ let rec mkFuns (paramInOutTys : List TypeInfo) (bl : List Expr) : MetaM Expr :=
match paramInOutTys, bl with
| [], [] =>
mkAppOptM ``FixII.Funs.Nil #[finTypeExpr, param_ty, in_ty, out_ty]
- | (_, pty, ity, oty) :: paramInOutTys, b :: bl => do
+ | info :: paramInOutTys, b :: bl => do
+ let pty := info.params_ty
+ let ity := info.in_ty
+ let oty := info.out_ty
-- Retrieving ity and oty - this is not very clean
let paramInOutTysExpr ← mkListLit paramInOutTyType
- (← paramInOutTys.mapM (λ (_, x, y, z) => mkInOutTy x y z))
+ (← paramInOutTys.mapM mkInOutTyFromTypeInfo)
let fl ← mkFuns paramInOutTys bl
mkAppOptM ``FixII.Funs.Cons #[finTypeExpr, param_ty, in_ty, out_ty, pty, ity, oty, paramInOutTysExpr, b, fl]
| _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length"
- let bodyFuns ← mkFuns paramInOutTys bodies.toList
+ let bodyFuns ← mkFuns paramInOutTys.toList bodies.toList
-- Wrap in `get_fun`
let body ← mkAppM ``FixII.get_fun #[bodyFuns, i_var, kk_var]
-- Add the index `i` and the continuation `k` as a variables
@@ -574,7 +634,7 @@ mutual
```
-/
partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
- trace[Diverge.def.valid] "proveValid: {e}"
+ trace[Diverge.def.valid] "proveExprIsValid: {e}"
match e with
| .const _ _ => throwError "Unimplemented" -- Shouldn't get there?
| .bvar _
@@ -602,160 +662,173 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
proveNoKExprIsValid k_var e
| .app .. =>
e.withApp fun f args => do
- -- There are several cases: first, check if this is a match/if
- -- Check if the expression is a (dependent) if then else.
- -- We treat the if then else expressions differently from the other matches,
- -- and have dedicated theorems for them.
- let isIte := e.isIte
- if isIte || e.isDIte then do
- e.withApp fun f args => do
- trace[Diverge.def.valid] "ite/dite: {f}:\n{args}"
- if args.size ≠ 5 then
- throwError "Wrong number of parameters for {f}: {args}"
- let cond := args.get! 1
- let dec := args.get! 2
- -- Prove that the branches are valid
- let br0 := args.get! 3
- let br1 := args.get! 4
- let proveBranchValid (br : Expr) : MetaM Expr :=
- if isIte then proveExprIsValid k_var kk_var br
- else do
- -- There is a lambda
- lambdaOne br fun x br => do
- let brValid ← proveExprIsValid k_var kk_var br
- mkLambdaFVars #[x] brValid
- let br0Valid ← proveBranchValid br0
- let br1Valid ← proveBranchValid br1
- let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite
- let eIsValid ←
- mkAppOptM const #[none, none, none, none, none,
- some k_var, some cond, some dec, none, none,
- some br0Valid, some br1Valid]
- trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}"
- pure eIsValid
- -- Check if the expression is a match (this case is for when the elaborator
- -- introduces auxiliary definitions to hide the match behind syntactic
- -- sugar):
- else if let some me := ← matchMatcherApp? e then do
- trace[Diverge.def.valid]
- "matcherApp:
- - params: {me.params}
- - motive: {me.motive}
- - discrs: {me.discrs}
- - altNumParams: {me.altNumParams}
- - alts: {me.alts}
- - remaining: {me.remaining}"
- -- matchMatcherApp does all the work for us: we simply need to gather
- -- the information and call the auxiliary helper `proveMatchIsValid`
- if me.remaining.size ≠ 0 then
- throwError "MatcherApp: non empty remaining array: {me.remaining}"
- let me : MatchInfo := {
- matcherName := me.matcherName
- matcherLevels := me.matcherLevels
- params := me.params
- motive := me.motive
- scruts := me.discrs
- branchesNumParams := me.altNumParams
- branches := me.alts
- }
- proveMatchIsValid k_var kk_var me
- -- Check if the expression is a raw match (this case is for when the expression
- -- is a direct call to the primitive `casesOn` function, without syntactic sugar).
- -- We have to check this case because functions like `mkSigmasMatch`, which we
- -- use to currify function bodies, introduce such raw matches.
- else if ← isCasesExpr f then do
- trace[Diverge.def.valid] "rawMatch: {e}"
- /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`.
-
- The casesOn definition is always of the following shape:
- - input parameters (implicit parameters)
- - motive (implicit), -- the motive gives the return type of the match
- - scrutinee (explicit)
- - branches (explicit).
- In particular, we notice that the scrutinee is the first *explicit*
- parameter - this is how we spot it.
- -/
- let matcherName := f.constName!
- let matcherLevels := f.constLevels!.toArray
- -- Find the first explicit parameter: this is the scrutinee
- forallTelescope (← inferType f) fun xs _ => do
- let rec findFirstExplicit (i : Nat) : MetaM Nat := do
- if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter"
- else
- let x := xs.get! i
- let xFVarId := x.fvarId!
- let localDecl ← xFVarId.getDecl
- match localDecl.binderInfo with
- | .default => pure i
- | _ => findFirstExplicit (i + 1)
- let scrutIdx ← findFirstExplicit 0
- -- Split the arguments
- let params := args.extract 0 (scrutIdx - 1)
- let motive := args.get! (scrutIdx - 1)
- let scrut := args.get! scrutIdx
- let branches := args.extract (scrutIdx + 1) args.size
- /- Compute the number of parameters for the branches: for this we use
- the type of the uninstantiated casesOn constant (we can't just
- destruct the lambdas in the branch expressions because the result
- of a match might be a lambda expression). -/
- let branchesNumParams : Array Nat ← do
- let env ← getEnv
- let decl := env.constants.find! matcherName
- let ty := decl.type
- forallTelescope ty fun xs _ => do
- let xs := xs.extract (scrutIdx + 1) xs.size
- xs.mapM fun x => do
- let xty ← inferType x
- forallTelescope xty fun ys _ => do
- pure ys.size
- let me : MatchInfo := {
- matcherName,
- matcherLevels,
- params,
- motive,
- scruts := #[scrut],
- branchesNumParams,
- branches,
- }
- proveMatchIsValid k_var kk_var me
- -- Check if this is a monadic let-binding
- else if f.isConstOf ``Bind.bind then do
- trace[Diverge.def.valid] "bind:\n{args}"
- -- We simply need to prove that the subexpressions are valid, and call
- -- the appropriate lemma.
- let x := args.get! 4
- let y := args.get! 5
- -- Prove that the subexpressions are valid
- let xValid ← proveExprIsValid k_var kk_var x
- trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}"
- let yValid ← do
- -- This is a lambda expression
- lambdaOne y fun x y => do
- trace[Diverge.def.valid] "bind: y: {y}"
- let yValid ← proveExprIsValid k_var kk_var y
- trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}"
- trace[Diverge.def.valid] "bind: yValid: x: {x}"
- let yValid ← mkLambdaFVars #[x] yValid
- trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}"
- pure yValid
- -- Put everything together
- trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}"
- mkAppM ``FixII.is_valid_p_bind #[xValid, yValid]
- -- Check if this is a recursive call, i.e., a call to the continuation `kk`
- else if f.isFVarOf kk_var.fvarId! then do
- trace[Diverge.def.valid] "rec: args: \n{args}"
- if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}"
- let i_arg := args.get! 0
- let t_arg := args.get! 1
- let x_arg := args.get! 2
- let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg]
- trace[Diverge.def.valid] "rec: result: \n{eIsValid}"
- pure eIsValid
+ proveAppIsValid k_var kk_var e f args
+
+partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr): MetaM Expr := do
+ trace[Diverge.def.valid] "proveAppIsValid: {f} {args}"
+ /- There are several cases: first, check if this is a match/if
+ Check if the expression is a (dependent) if then else.
+ We treat the if then else expressions differently from the other matches,
+ and have dedicated theorems for them. -/
+ let isIte := e.isIte
+ if isIte || e.isDIte then do
+ e.withApp fun f args => do
+ trace[Diverge.def.valid] "ite/dite: {f}:\n{args}"
+ if args.size ≠ 5 then
+ throwError "Wrong number of parameters for {f}: {args}"
+ let cond := args.get! 1
+ let dec := args.get! 2
+ -- Prove that the branches are valid
+ let br0 := args.get! 3
+ let br1 := args.get! 4
+ let proveBranchValid (br : Expr) : MetaM Expr :=
+ if isIte then proveExprIsValid k_var kk_var br
else do
- -- Remaining case: normal application.
- -- It shouldn't use the continuation.
- -- TODO: actually, it can
- proveNoKExprIsValid k_var e
+ -- There is a lambda
+ lambdaOne br fun x br => do
+ let brValid ← proveExprIsValid k_var kk_var br
+ mkLambdaFVars #[x] brValid
+ let br0Valid ← proveBranchValid br0
+ let br1Valid ← proveBranchValid br1
+ let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite
+ let eIsValid ←
+ mkAppOptM const #[none, none, none, none, none,
+ some k_var, some cond, some dec, none, none,
+ some br0Valid, some br1Valid]
+ trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}"
+ pure eIsValid
+ /- Check if the expression is a match (this case is for when the elaborator
+ introduces auxiliary definitions to hide the match behind syntactic
+ sugar): -/
+ else if let some me := ← matchMatcherApp? e then do
+ trace[Diverge.def.valid]
+ "matcherApp:
+ - params: {me.params}
+ - motive: {me.motive}
+ - discrs: {me.discrs}
+ - altNumParams: {me.altNumParams}
+ - alts: {me.alts}
+ - remaining: {me.remaining}"
+ -- matchMatcherApp does all the work for us: we simply need to gather
+ -- the information and call the auxiliary helper `proveMatchIsValid`
+ if me.remaining.size ≠ 0 then
+ throwError "MatcherApp: non empty remaining array: {me.remaining}"
+ let me : MatchInfo := {
+ matcherName := me.matcherName
+ matcherLevels := me.matcherLevels
+ params := me.params
+ motive := me.motive
+ scruts := me.discrs
+ branchesNumParams := me.altNumParams
+ branches := me.alts
+ }
+ proveMatchIsValid k_var kk_var me
+ /- Check if the expression is a raw match (this case is for when the expression
+ is a direct call to the primitive `casesOn` function, without syntactic sugar).
+ We have to check this case because functions like `mkSigmasMatch`, which we
+ use to currify function bodies, introduce such raw matches. -/
+ else if ← isCasesExpr f then do
+ trace[Diverge.def.valid] "rawMatch: {e}"
+ /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`.
+
+ The casesOn definition is always of the following shape:
+ - input parameters (implicit parameters)
+ - motive (implicit), -- the motive gives the return type of the match
+ - scrutinee (explicit)
+ - branches (explicit).
+ In particular, we notice that the scrutinee is the first *explicit*
+ parameter - this is how we spot it.
+ -/
+ let matcherName := f.constName!
+ let matcherLevels := f.constLevels!.toArray
+ -- Find the first explicit parameter: this is the scrutinee
+ forallTelescope (← inferType f) fun xs _ => do
+ let rec findFirstExplicit (i : Nat) : MetaM Nat := do
+ if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter"
+ else
+ let x := xs.get! i
+ let xFVarId := x.fvarId!
+ let localDecl ← xFVarId.getDecl
+ match localDecl.binderInfo with
+ | .default => pure i
+ | _ => findFirstExplicit (i + 1)
+ let scrutIdx ← findFirstExplicit 0
+ -- Split the arguments
+ let params := args.extract 0 (scrutIdx - 1)
+ let motive := args.get! (scrutIdx - 1)
+ let scrut := args.get! scrutIdx
+ let branches := args.extract (scrutIdx + 1) args.size
+ /- Compute the number of parameters for the branches: for this we use
+ the type of the uninstantiated casesOn constant (we can't just
+ destruct the lambdas in the branch expressions because the result
+ of a match might be a lambda expression). -/
+ let branchesNumParams : Array Nat ← do
+ let env ← getEnv
+ let decl := env.constants.find! matcherName
+ let ty := decl.type
+ forallTelescope ty fun xs _ => do
+ let xs := xs.extract (scrutIdx + 1) xs.size
+ xs.mapM fun x => do
+ let xty ← inferType x
+ forallTelescope xty fun ys _ => do
+ pure ys.size
+ let me : MatchInfo := {
+ matcherName,
+ matcherLevels,
+ params,
+ motive,
+ scruts := #[scrut],
+ branchesNumParams,
+ branches,
+ }
+ proveMatchIsValid k_var kk_var me
+ -- Check if this is a monadic let-binding
+ else if f.isConstOf ``Bind.bind then do
+ trace[Diverge.def.valid] "bind:\n{args}"
+ -- We simply need to prove that the subexpressions are valid, and call
+ -- the appropriate lemma.
+ let x := args.get! 4
+ let y := args.get! 5
+ -- Prove that the subexpressions are valid
+ let xValid ← proveExprIsValid k_var kk_var x
+ trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}"
+ let yValid ← do
+ -- This is a lambda expression
+ lambdaOne y fun x y => do
+ trace[Diverge.def.valid] "bind: y: {y}"
+ let yValid ← proveExprIsValid k_var kk_var y
+ trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}"
+ trace[Diverge.def.valid] "bind: yValid: x: {x}"
+ let yValid ← mkLambdaFVars #[x] yValid
+ trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}"
+ pure yValid
+ -- Put everything together
+ trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}"
+ mkAppM ``FixII.is_valid_p_bind #[xValid, yValid]
+ -- Check if this is a recursive call, i.e., a call to the continuation `kk`
+ else if f.isFVarOf kk_var.fvarId! then do
+ trace[Diverge.def.valid] "rec: args: \n{args}"
+ if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}"
+ let i_arg := args.get! 0
+ let t_arg := args.get! 1
+ let x_arg := args.get! 2
+ let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg]
+ trace[Diverge.def.valid] "rec: result: \n{eIsValid}"
+ pure eIsValid
+ else do
+ /- Remaining case: normal application.
+ Check if the arguments use the continuation:
+ - if no: this is simple
+ - if yes: we have to lookup theorems in div spec database and continue -/
+ trace[Diverge.def.valid] "regular app: {e}"
+ let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty
+ if ¬ allArgsFVars.contains kk_var.fvarId! then do
+ -- Simple case
+ trace[Diverge.def.valid] "kk doesn't appear in the arguments"
+ proveNoKExprIsValid k_var e
+ else do
+ -- Lookup in the database for suitable theorems
+ throwError "TODO: {e}"
-- Prove that a match expression is valid.
partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do
@@ -912,6 +985,7 @@ def proveMutRecIsValid
-- Then prove that the mut rec body is valid
trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid"
let isValid ← proveFunsBodyIsValid paramInOutTys bodyFuns k_var bodiesValid
+ trace[Diverge.def.valid] "Generated the term: {isValid}"
-- Save the theorem
let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst]
let name := grName ++ "mut_rec_body_is_valid"
@@ -935,18 +1009,18 @@ def proveMutRecIsValid
def is_odd (i : Int) : Result Bool := mut_rec_body 1 i
```
-/
-def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) :
+def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) :
TermElabM (Array Name) := do
let grSize := preDefs.size
let defs ← preDefs.mapIdxM fun idx preDef => do
lambdaTelescope preDef.value fun xs _ => do
-- Retrieve the parameters info
- let (num_params, param_ty, _, _) := paramInOutTys.get! idx.val
+ let type_info := paramInOutTys.get! idx.val
-- Create the index
let idx ← mkFinVal grSize idx.val
-- Group the inputs into two tuples
- let (params_args, input_args) := xs.toList.splitAt num_params
- let params ← mkSigmasVal param_ty params_args
+ let (params_args, input_args) := xs.toList.splitAt type_info.num_params
+ let params ← mkSigmasVal type_info.params_ty params_args
let input ← mkProdsVal input_args
-- Apply the fixed point
let fixedBody ← mkAppM ``FixII.fix #[mutRecBody, idx, params, input]
@@ -968,7 +1042,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr ×
-- Prove the equations that we will use as unfolding theorems
partial def proveUnfoldingThms (isValidThm : Expr)
- (paramInOutTys : Array (ℕ × Expr × Expr × Expr))
+ (paramInOutTys : Array TypeInfo)
(preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do
let grSize := preDefs.size
let proveIdx (i : Nat) : MetaM Unit := do
@@ -993,9 +1067,9 @@ partial def proveUnfoldingThms (isValidThm : Expr)
let idx ← mkFinVal grSize i
let proof ← mkAppM ``congr_fun #[proof, idx]
-- Add the input arguments
- let (num_params, param_ty, _, _) := paramInOutTys.get! i
- let (params, args) := xs.toList.splitAt num_params
- let params ← mkSigmasVal param_ty params
+ let type_info := paramInOutTys.get! i
+ let (params, args) := xs.toList.splitAt type_info.num_params
+ let params ← mkSigmasVal type_info.params_ty params
let args ← mkProdsVal args
let proof ← mkAppM ``congr_fun #[proof, params]
let proof ← mkAppM ``congr_fun #[proof, args]
@@ -1052,7 +1126,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
we generate:
`(Type, λ α => List α × i, λ α => Result α)`
-/
- let paramInOutTys : Array (ℕ × Expr × Expr × Expr) ←
+ let paramInOutTys : Array TypeInfo ←
preDefs.mapM (fun preDef => do
-- Check the universe parameters - TODO: I'm not sure what the best thing
-- to do is. In practice, all the type parameters should be in Type 0, so
@@ -1060,19 +1134,20 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
if preDef.levelParams ≠ grLvlParams then
throwError "Non-uniform polymorphism in the universes"
forallTelescope preDef.type (fun in_tys out_ty => do
+ let total_num_args := in_tys.size
let (params, in_tys) ← splitInputArgs in_tys out_ty
trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}"
let num_params := params.size
- let params_sigma ← mkSigmasType params.data
- let in_tys ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data)
+ let params_ty ← mkSigmasType params.data
+ let in_ty ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data)
-- Retrieve the type in the "Result"
let out_ty ← getResultTy out_ty
let out_ty ← mkSigmasMatchOrUnit params.data out_ty
- trace[Diverge.def] "inOutTy: {preDef.declName}: {params_sigma}, {in_tys}, {out_ty}"
- pure (num_params, params_sigma, in_tys, out_ty)))
+ trace[Diverge.def] "inOutTy: {preDef.declName}: {params_ty}, {in_tys}, {out_ty}"
+ pure ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩))
trace[Diverge.def] "paramInOutTys: {paramInOutTys}"
-- Turn the list of input types/input args/output type tuples into expressions
- let paramInOutTysExpr ← paramInOutTys.mapM (λ (_, x, y, z) => do mkInOutTy x y z)
+ let paramInOutTysExpr ← liftM (paramInOutTys.mapM mkInOutTyFromTypeInfo)
let paramInOutTysExpr ← mkListLit (← inferType (paramInOutTysExpr.get! 0)) paramInOutTysExpr.toList
trace[Diverge.def] "paramInOutTys: {paramInOutTys}"
@@ -1135,7 +1210,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Generate the mutually recursive body
trace[Diverge.def] "# Generating the mut rec body"
- let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys.toList bodies
+ let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys bodies
trace[Diverge.def] "mut rec body (after decl): {mutRecBody}"
-- Prove that the mut rec body satisfies the validity criteria required by
@@ -1275,27 +1350,27 @@ elab_rules : command
namespace Tests
/- Some examples of partial functions -/
- /- section HigherOrder
- open FixI
-
- -- The index type
- variable {id : Type u}
-
- -- The input/output types
- variable {a : id → Type v} {b : (i:id) → a i → Type w}
-
- -- Example with a higher-order function
- theorem map_is_valid
- {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}}
- (Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x))
- (k : (a → Result b) → a → Result b)
- (ls : List a) :
- is_valid_p k (λ k => Ex5.map (f k) ls) :=
- induction ls <;> simp [map]
- apply is_valid_p_bind <;> try simp_all
- intros
- apply is_valid_p_bind <;> try simp_all
- end HigherOrder -/
+ section HigherOrder
+ open Ex8
+
+ inductive Tree (a : Type u) :=
+ | leaf (x : a)
+ | node (tl : List (Tree a))
+
+ set_option trace.Diverge.def true
+ -- set_option trace.Diverge.def.genBody true
+ set_option trace.Diverge.def.valid true
+ divergent def id {a : Type u} (t : Tree a) : Result (Tree a) :=
+ match t with
+ | .leaf x => .ret (.leaf x)
+ | .node tl =>
+ do
+ let tl ← map id tl
+ .ret (.node tl)
+
+ set_option trace.Diverge.def false
+
+ end HigherOrder
divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a :=
match ls with
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index b818d5af..0d33e9d2 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -1,13 +1,15 @@
import Lean
import Base.Utils
import Base.Primitives.Base
+import Base.Extensions
namespace Diverge
open Lean Elab Term Meta
-open Utils
+open Utils Extensions
-- We can't define and use trace classes in the same file
+initialize registerTraceClass `Diverge
initialize registerTraceClass `Diverge.elab
initialize registerTraceClass `Diverge.def
initialize registerTraceClass `Diverge.def.sigmas
@@ -20,4 +22,63 @@ initialize registerTraceClass `Diverge.def.unfold
-- For the attribute (for higher-order functions)
initialize registerTraceClass `Diverge.attr
+-- Attribute
+
+-- divspec attribute
+structure DivSpecAttr where
+ attr : AttributeImpl
+ ext : DiscrTreeExtension Name true
+ deriving Inhabited
+
+/- The persistent map from expressions to divspec theorems. -/
+initialize divspecAttr : DivSpecAttr ← do
+ let ext ← mkDiscrTreeExtention `divspecMap true
+ let attrImpl : AttributeImpl := {
+ name := `divspec
+ descr := "Marks theorems to use with the `divergent` encoding"
+ add := fun thName stx attrKind => do
+ Attribute.Builtin.ensureNoArgs stx
+ -- TODO: use the attribute kind
+ unless attrKind == AttributeKind.global do
+ throwError "invalid attribute divspec, must be global"
+ -- Lookup the theorem
+ let env ← getEnv
+ let thDecl := env.constants.find! thName
+ let fKey : Array (DiscrTree.Key true) ← MetaM.run' (do
+ /- The theorem should have the shape:
+ `∀ ..., is_valid_p k (λ k => ...)`
+
+ Dive into the ∀:
+ -/
+ let (_, _, fExpr) ← forallMetaTelescope thDecl.type.consumeMData
+ /- Dive into the argument of `is_valid_p`: -/
+ fExpr.consumeMData.withApp fun _ args => do
+ if args.size ≠ 7 then throwError "Invalid number of arguments to is_valid_p"
+ let fExpr := args.get! 6
+ /- Dive into the lambda: -/
+ let (_, _, fExpr) ← lambdaMetaTelescope fExpr.consumeMData
+ trace[Diverge] "Registering divspec theorem for {fExpr}"
+ -- Convert the function expression to a discrimination tree key
+ DiscrTree.mkPath fExpr)
+ let env := ext.addEntry env (fKey, thName)
+ setEnv env
+ trace[Diverge] "Saved the environment"
+ pure ()
+ }
+ registerBuiltinAttribute attrImpl
+ pure { attr := attrImpl, ext := ext }
+
+def DivSpecAttr.find? (s : DivSpecAttr) (e : Expr) : MetaM (Array Name) := do
+ (s.ext.getState (← getEnv)).getMatch e
+
+def DivSpecAttr.getState (s : DivSpecAttr) : MetaM (DiscrTree Name true) := do
+ pure (s.ext.getState (← getEnv))
+
+def showStoredDivSpec : MetaM Unit := do
+ let st ← divspecAttr.getState
+ -- TODO: how can we iterate over (at least) the values stored in the tree?
+ --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!""
+ let s := f!"{st}"
+ IO.println s
+
end Diverge