diff options
author | Son Ho | 2023-12-11 18:34:10 +0100 |
---|---|---|
committer | Son Ho | 2023-12-11 18:34:10 +0100 |
commit | 78367ef21c147b26040e0f6062a907fceab1f390 (patch) | |
tree | 6dfb4095d4df161023164b512847ba90f3f72300 /backends/lean/Base/Diverge | |
parent | ee669c4dbf8be12a3dd7249c645fd7092ba3e8eb (diff) |
Start working on higher-order examples for Diverge
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Base.lean | 3 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 501 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/ElabBase.lean | 63 |
3 files changed, 353 insertions, 214 deletions
diff --git a/backends/lean/Base/Diverge/Base.lean b/backends/lean/Base/Diverge/Base.lean index a7107c1e..bdc3ed04 100644 --- a/backends/lean/Base/Diverge/Base.lean +++ b/backends/lean/Base/Diverge/Base.lean @@ -5,6 +5,7 @@ import Mathlib.Tactic.RunCmd import Mathlib.Tactic.Linarith import Base.Primitives.Base import Base.Arith.Base +import Base.Diverge.ElabBase /- TODO: this is very useful, but is there more? -/ set_option profiler true @@ -1467,6 +1468,7 @@ namespace Ex8 .ret (hd :: tl) /- The validity theorem for `map`, generic in `f` -/ + @[divspec] theorem map_is_valid (i : id) (t : ty i) {{f : (a i t → Result (b i t)) → (a i t) → Result c}} @@ -1479,6 +1481,7 @@ namespace Ex8 intros apply is_valid_p_bind <;> try simp_all + @[divspec] theorem map_is_valid' (i : id) (t : ty i) (k : ((i:id) → (t:ty i) → a i t → Result (b i t)) → (i:id) → (t:ty i) → a i t → Result (b i t)) diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 3c23db64..97364d14 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -21,6 +21,10 @@ open Utils open WF in +-- TODO: use those +def UnitType := Expr.const ``PUnit [Level.succ .zero] +def UnitValue := Expr.const ``PUnit.unit [Level.succ .zero] + def mkProdType (x y : Expr) : MetaM Expr := mkAppM ``Prod #[x, y] @@ -382,29 +386,71 @@ def mkFinVal (n i : Nat) : MetaM Expr := do let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] mkAppOptM ``OfNat.ofNat #[none, none, ofNat] +/- Information about the type of a function in a declaration group. + + In the comments about the fields, we take as example the + `list_nth (α : Type) (ls : List α) (i : Int) : Result α` function. + -/ +structure TypeInfo where + /- The total number of input arguments. + + For list_nth: 3 + -/ + total_num_args : ℕ + /- The number of type parameters (they should be a prefix of the input arguments). + + For `list_nth`: 1 + -/ + num_params : ℕ + /- The type of the dependent tuple grouping the type parameters. + + For `list_nth`: `Type` + -/ + params_ty : Expr + /- The type of the tuple grouping the input values. This is a function taking + as input a value of type `params_ty`. + + For `list_nth`: `λ a => List a × Int` + -/ + in_ty : Expr + /- The output type, without the `Return`. This is a function taking + as input a value of type `params_ty`. + + For `list_nth`: `λ a => a` + -/ + out_ty : Expr + +def mkInOutTyFromTypeInfo (info : TypeInfo) : MetaM Expr := do + mkInOutTy info.params_ty info.in_ty info.out_ty + +instance : Inhabited TypeInfo := + { default := { total_num_args := 0, num_params := 0, params_ty := UnitType, + in_ty := UnitType, out_ty := UnitType } } + +instance : ToMessageData TypeInfo := + ⟨ λ ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩ => + f!"\{ total_num_args: {total_num_args}, num_params: {num_params}, params_ty: {params_ty}, in_ty: {in_ty}, out_ty: {out_ty} }}" + ⟩ + /- Generate and declare as individual definitions the bodies for the individual funcions: - replace the recursive calls with calls to the continutation `k` - make those bodies take one single dependent tuple as input We name the declarations: "[original_name].body". We return the new declarations. - - Inputs: - - `paramInOutTys`: (number of type parameters, sigma type grouping the type parameters, input type, output type) -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : + (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - /- Compute the map from name to (index, num type parameters, parameters type, input type). - Example for `list_nth (α : Type) (ls : List α) (i : Int) : Result α`: `"list_nth" → (0, 1, α, (List α × Int))` + /- Compute the map from name to (index, type info). + Remark: the continuation has an indexed type; we use the index (a finite number of type `Fin`) to control which function we call at the recursive call site. -/ - let nameToInfo : HashMap Name (Nat × Nat × Expr × Expr) := + let nameToInfo : HashMap Name (Nat × TypeInfo) := let bl := preDefs.mapIdx fun i d => - let (num_params, params_ty, in_ty, _) := paramInOutTys.get! i.val - (d.declName, (i.val, num_params, params_ty, in_ty)) + (d.declName, (i.val, paramInOutTys.get! i.val)) HashMap.ofList bl.toList trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" @@ -423,18 +469,29 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let name := f.constName! match nameToInfo.find? name with | none => pure e - | some (id, num_params, params_ty, _in_ty) => + | some (id, type_info) => trace[Diverge.def.genBody.visit] "this is a recursive call" -- This is a recursive call: replace it -- Compute the index let i ← mkFinVal grSize id + -- It can happen that there are no input values given to the + -- recursive calls, and only type parameters. + let num_args := args.size + if num_args ≠ type_info.total_num_args ∧ num_args ≠ type_info.num_params then + throwError "Invalid number of arguments for the recursive call: {e}" -- Split the arguments, and put them in two tuples (the first -- one is a dependent tuple) - let (param_args, args) := args.toList.splitAt num_params + let (param_args, args) := args.toList.splitAt type_info.num_params trace[Diverge.def.genBody.visit] "param_args: {param_args}, args: {args}" - let param_args ← mkSigmasVal params_ty param_args - let args ← mkProdsVal args - mkAppM' kk_var #[i, param_args, args] + let param_args ← mkSigmasVal type_info.params_ty param_args + -- Check if there are input values + if num_args = type_info.total_num_args then do + trace[Diverge.def.genBody.visit] "Recursive call with input values" + let args ← mkProdsVal args + mkAppM' kk_var #[i, param_args, args] + else do + trace[Diverge.def.genBody.visit] "Recursive call without input values" + mkAppM' kk_var #[i, param_args] else -- Not a recursive call: do nothing pure e @@ -458,8 +515,8 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- (over which we match to retrieve the individual arguments). lambdaTelescope body fun args body => do -- Split the arguments between the type parameters and the "regular" inputs - let (_, num_params, _, _) := nameToInfo.find! preDef.declName - let (param_args, args) := args.toList.splitAt num_params + let (_, type_info) := nameToInfo.find! preDef.declName + let (param_args, args) := args.toList.splitAt type_info.num_params let body ← mkProdsMatchOrUnit args body trace[Diverge.def.genBody] "Body after mkProdsMatchOrUnit: {body}" let body ← mkSigmasMatchOrUnit param_args body @@ -494,27 +551,30 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -/ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) - (param_ty in_ty out_ty : Expr) (paramInOutTys : List (Nat × Expr × Expr × Expr)) + (param_ty in_ty out_ty : Expr) (paramInOutTys : Array TypeInfo) (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize -- TODO: not very clean let paramInOutTyType ← do - let (_, x, y, z) := paramInOutTys.get! 0 - inferType (← mkInOutTy x y z) - let rec mkFuns (paramInOutTys : List (Nat × Expr × Expr × Expr)) (bl : List Expr) : MetaM Expr := + let info := paramInOutTys.get! 0 + inferType (← mkInOutTyFromTypeInfo info) + let rec mkFuns (paramInOutTys : List TypeInfo) (bl : List Expr) : MetaM Expr := match paramInOutTys, bl with | [], [] => mkAppOptM ``FixII.Funs.Nil #[finTypeExpr, param_ty, in_ty, out_ty] - | (_, pty, ity, oty) :: paramInOutTys, b :: bl => do + | info :: paramInOutTys, b :: bl => do + let pty := info.params_ty + let ity := info.in_ty + let oty := info.out_ty -- Retrieving ity and oty - this is not very clean let paramInOutTysExpr ← mkListLit paramInOutTyType - (← paramInOutTys.mapM (λ (_, x, y, z) => mkInOutTy x y z)) + (← paramInOutTys.mapM mkInOutTyFromTypeInfo) let fl ← mkFuns paramInOutTys bl mkAppOptM ``FixII.Funs.Cons #[finTypeExpr, param_ty, in_ty, out_ty, pty, ity, oty, paramInOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" - let bodyFuns ← mkFuns paramInOutTys bodies.toList + let bodyFuns ← mkFuns paramInOutTys.toList bodies.toList -- Wrap in `get_fun` let body ← mkAppM ``FixII.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables @@ -574,7 +634,7 @@ mutual ``` -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do - trace[Diverge.def.valid] "proveValid: {e}" + trace[Diverge.def.valid] "proveExprIsValid: {e}" match e with | .const _ _ => throwError "Unimplemented" -- Shouldn't get there? | .bvar _ @@ -602,160 +662,173 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do proveNoKExprIsValid k_var e | .app .. => e.withApp fun f args => do - -- There are several cases: first, check if this is a match/if - -- Check if the expression is a (dependent) if then else. - -- We treat the if then else expressions differently from the other matches, - -- and have dedicated theorems for them. - let isIte := e.isIte - if isIte || e.isDIte then do - e.withApp fun f args => do - trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" - if args.size ≠ 5 then - throwError "Wrong number of parameters for {f}: {args}" - let cond := args.get! 1 - let dec := args.get! 2 - -- Prove that the branches are valid - let br0 := args.get! 3 - let br1 := args.get! 4 - let proveBranchValid (br : Expr) : MetaM Expr := - if isIte then proveExprIsValid k_var kk_var br - else do - -- There is a lambda - lambdaOne br fun x br => do - let brValid ← proveExprIsValid k_var kk_var br - mkLambdaFVars #[x] brValid - let br0Valid ← proveBranchValid br0 - let br1Valid ← proveBranchValid br1 - let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite - let eIsValid ← - mkAppOptM const #[none, none, none, none, none, - some k_var, some cond, some dec, none, none, - some br0Valid, some br1Valid] - trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" - pure eIsValid - -- Check if the expression is a match (this case is for when the elaborator - -- introduces auxiliary definitions to hide the match behind syntactic - -- sugar): - else if let some me := ← matchMatcherApp? e then do - trace[Diverge.def.valid] - "matcherApp: - - params: {me.params} - - motive: {me.motive} - - discrs: {me.discrs} - - altNumParams: {me.altNumParams} - - alts: {me.alts} - - remaining: {me.remaining}" - -- matchMatcherApp does all the work for us: we simply need to gather - -- the information and call the auxiliary helper `proveMatchIsValid` - if me.remaining.size ≠ 0 then - throwError "MatcherApp: non empty remaining array: {me.remaining}" - let me : MatchInfo := { - matcherName := me.matcherName - matcherLevels := me.matcherLevels - params := me.params - motive := me.motive - scruts := me.discrs - branchesNumParams := me.altNumParams - branches := me.alts - } - proveMatchIsValid k_var kk_var me - -- Check if the expression is a raw match (this case is for when the expression - -- is a direct call to the primitive `casesOn` function, without syntactic sugar). - -- We have to check this case because functions like `mkSigmasMatch`, which we - -- use to currify function bodies, introduce such raw matches. - else if ← isCasesExpr f then do - trace[Diverge.def.valid] "rawMatch: {e}" - /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. - - The casesOn definition is always of the following shape: - - input parameters (implicit parameters) - - motive (implicit), -- the motive gives the return type of the match - - scrutinee (explicit) - - branches (explicit). - In particular, we notice that the scrutinee is the first *explicit* - parameter - this is how we spot it. - -/ - let matcherName := f.constName! - let matcherLevels := f.constLevels!.toArray - -- Find the first explicit parameter: this is the scrutinee - forallTelescope (← inferType f) fun xs _ => do - let rec findFirstExplicit (i : Nat) : MetaM Nat := do - if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" - else - let x := xs.get! i - let xFVarId := x.fvarId! - let localDecl ← xFVarId.getDecl - match localDecl.binderInfo with - | .default => pure i - | _ => findFirstExplicit (i + 1) - let scrutIdx ← findFirstExplicit 0 - -- Split the arguments - let params := args.extract 0 (scrutIdx - 1) - let motive := args.get! (scrutIdx - 1) - let scrut := args.get! scrutIdx - let branches := args.extract (scrutIdx + 1) args.size - /- Compute the number of parameters for the branches: for this we use - the type of the uninstantiated casesOn constant (we can't just - destruct the lambdas in the branch expressions because the result - of a match might be a lambda expression). -/ - let branchesNumParams : Array Nat ← do - let env ← getEnv - let decl := env.constants.find! matcherName - let ty := decl.type - forallTelescope ty fun xs _ => do - let xs := xs.extract (scrutIdx + 1) xs.size - xs.mapM fun x => do - let xty ← inferType x - forallTelescope xty fun ys _ => do - pure ys.size - let me : MatchInfo := { - matcherName, - matcherLevels, - params, - motive, - scruts := #[scrut], - branchesNumParams, - branches, - } - proveMatchIsValid k_var kk_var me - -- Check if this is a monadic let-binding - else if f.isConstOf ``Bind.bind then do - trace[Diverge.def.valid] "bind:\n{args}" - -- We simply need to prove that the subexpressions are valid, and call - -- the appropriate lemma. - let x := args.get! 4 - let y := args.get! 5 - -- Prove that the subexpressions are valid - let xValid ← proveExprIsValid k_var kk_var x - trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" - let yValid ← do - -- This is a lambda expression - lambdaOne y fun x y => do - trace[Diverge.def.valid] "bind: y: {y}" - let yValid ← proveExprIsValid k_var kk_var y - trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" - trace[Diverge.def.valid] "bind: yValid: x: {x}" - let yValid ← mkLambdaFVars #[x] yValid - trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" - pure yValid - -- Put everything together - trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" - mkAppM ``FixII.is_valid_p_bind #[xValid, yValid] - -- Check if this is a recursive call, i.e., a call to the continuation `kk` - else if f.isFVarOf kk_var.fvarId! then do - trace[Diverge.def.valid] "rec: args: \n{args}" - if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}" - let i_arg := args.get! 0 - let t_arg := args.get! 1 - let x_arg := args.get! 2 - let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg] - trace[Diverge.def.valid] "rec: result: \n{eIsValid}" - pure eIsValid + proveAppIsValid k_var kk_var e f args + +partial def proveAppIsValid (k_var kk_var : Expr) (e : Expr) (f : Expr) (args : Array Expr): MetaM Expr := do + trace[Diverge.def.valid] "proveAppIsValid: {f} {args}" + /- There are several cases: first, check if this is a match/if + Check if the expression is a (dependent) if then else. + We treat the if then else expressions differently from the other matches, + and have dedicated theorems for them. -/ + let isIte := e.isIte + if isIte || e.isDIte then do + e.withApp fun f args => do + trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" + if args.size ≠ 5 then + throwError "Wrong number of parameters for {f}: {args}" + let cond := args.get! 1 + let dec := args.get! 2 + -- Prove that the branches are valid + let br0 := args.get! 3 + let br1 := args.get! 4 + let proveBranchValid (br : Expr) : MetaM Expr := + if isIte then proveExprIsValid k_var kk_var br else do - -- Remaining case: normal application. - -- It shouldn't use the continuation. - -- TODO: actually, it can - proveNoKExprIsValid k_var e + -- There is a lambda + lambdaOne br fun x br => do + let brValid ← proveExprIsValid k_var kk_var br + mkLambdaFVars #[x] brValid + let br0Valid ← proveBranchValid br0 + let br1Valid ← proveBranchValid br1 + let const := if isIte then ``FixII.is_valid_p_ite else ``FixII.is_valid_p_dite + let eIsValid ← + mkAppOptM const #[none, none, none, none, none, + some k_var, some cond, some dec, none, none, + some br0Valid, some br1Valid] + trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + /- Check if the expression is a match (this case is for when the elaborator + introduces auxiliary definitions to hide the match behind syntactic + sugar): -/ + else if let some me := ← matchMatcherApp? e then do + trace[Diverge.def.valid] + "matcherApp: + - params: {me.params} + - motive: {me.motive} + - discrs: {me.discrs} + - altNumParams: {me.altNumParams} + - alts: {me.alts} + - remaining: {me.remaining}" + -- matchMatcherApp does all the work for us: we simply need to gather + -- the information and call the auxiliary helper `proveMatchIsValid` + if me.remaining.size ≠ 0 then + throwError "MatcherApp: non empty remaining array: {me.remaining}" + let me : MatchInfo := { + matcherName := me.matcherName + matcherLevels := me.matcherLevels + params := me.params + motive := me.motive + scruts := me.discrs + branchesNumParams := me.altNumParams + branches := me.alts + } + proveMatchIsValid k_var kk_var me + /- Check if the expression is a raw match (this case is for when the expression + is a direct call to the primitive `casesOn` function, without syntactic sugar). + We have to check this case because functions like `mkSigmasMatch`, which we + use to currify function bodies, introduce such raw matches. -/ + else if ← isCasesExpr f then do + trace[Diverge.def.valid] "rawMatch: {e}" + /- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + + The casesOn definition is always of the following shape: + - input parameters (implicit parameters) + - motive (implicit), -- the motive gives the return type of the match + - scrutinee (explicit) + - branches (explicit). + In particular, we notice that the scrutinee is the first *explicit* + parameter - this is how we spot it. + -/ + let matcherName := f.constName! + let matcherLevels := f.constLevels!.toArray + -- Find the first explicit parameter: this is the scrutinee + forallTelescope (← inferType f) fun xs _ => do + let rec findFirstExplicit (i : Nat) : MetaM Nat := do + if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" + else + let x := xs.get! i + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + match localDecl.binderInfo with + | .default => pure i + | _ => findFirstExplicit (i + 1) + let scrutIdx ← findFirstExplicit 0 + -- Split the arguments + let params := args.extract 0 (scrutIdx - 1) + let motive := args.get! (scrutIdx - 1) + let scrut := args.get! scrutIdx + let branches := args.extract (scrutIdx + 1) args.size + /- Compute the number of parameters for the branches: for this we use + the type of the uninstantiated casesOn constant (we can't just + destruct the lambdas in the branch expressions because the result + of a match might be a lambda expression). -/ + let branchesNumParams : Array Nat ← do + let env ← getEnv + let decl := env.constants.find! matcherName + let ty := decl.type + forallTelescope ty fun xs _ => do + let xs := xs.extract (scrutIdx + 1) xs.size + xs.mapM fun x => do + let xty ← inferType x + forallTelescope xty fun ys _ => do + pure ys.size + let me : MatchInfo := { + matcherName, + matcherLevels, + params, + motive, + scruts := #[scrut], + branchesNumParams, + branches, + } + proveMatchIsValid k_var kk_var me + -- Check if this is a monadic let-binding + else if f.isConstOf ``Bind.bind then do + trace[Diverge.def.valid] "bind:\n{args}" + -- We simply need to prove that the subexpressions are valid, and call + -- the appropriate lemma. + let x := args.get! 4 + let y := args.get! 5 + -- Prove that the subexpressions are valid + let xValid ← proveExprIsValid k_var kk_var x + trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" + let yValid ← do + -- This is a lambda expression + lambdaOne y fun x y => do + trace[Diverge.def.valid] "bind: y: {y}" + let yValid ← proveExprIsValid k_var kk_var y + trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" + trace[Diverge.def.valid] "bind: yValid: x: {x}" + let yValid ← mkLambdaFVars #[x] yValid + trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" + pure yValid + -- Put everything together + trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" + mkAppM ``FixII.is_valid_p_bind #[xValid, yValid] + -- Check if this is a recursive call, i.e., a call to the continuation `kk` + else if f.isFVarOf kk_var.fvarId! then do + trace[Diverge.def.valid] "rec: args: \n{args}" + if args.size ≠ 3 then throwError "Recursive call with invalid number of parameters: {args}" + let i_arg := args.get! 0 + let t_arg := args.get! 1 + let x_arg := args.get! 2 + let eIsValid ← mkAppM ``FixII.is_valid_p_rec #[k_var, i_arg, t_arg, x_arg] + trace[Diverge.def.valid] "rec: result: \n{eIsValid}" + pure eIsValid + else do + /- Remaining case: normal application. + Check if the arguments use the continuation: + - if no: this is simple + - if yes: we have to lookup theorems in div spec database and continue -/ + trace[Diverge.def.valid] "regular app: {e}" + let allArgsFVars ← args.foldlM (fun hs arg => getFVarIds arg hs) HashSet.empty + if ¬ allArgsFVars.contains kk_var.fvarId! then do + -- Simple case + trace[Diverge.def.valid] "kk doesn't appear in the arguments" + proveNoKExprIsValid k_var e + else do + -- Lookup in the database for suitable theorems + throwError "TODO: {e}" -- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do @@ -912,6 +985,7 @@ def proveMutRecIsValid -- Then prove that the mut rec body is valid trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" let isValid ← proveFunsBodyIsValid paramInOutTys bodyFuns k_var bodiesValid + trace[Diverge.def.valid] "Generated the term: {isValid}" -- Save the theorem let thmTy ← mkAppM ``FixII.is_valid #[mutRecBodyConst] let name := grName ++ "mut_rec_body_is_valid" @@ -935,18 +1009,18 @@ def proveMutRecIsValid def is_odd (i : Int) : Result Bool := mut_rec_body 1 i ``` -/ -def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do -- Retrieve the parameters info - let (num_params, param_ty, _, _) := paramInOutTys.get! idx.val + let type_info := paramInOutTys.get! idx.val -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into two tuples - let (params_args, input_args) := xs.toList.splitAt num_params - let params ← mkSigmasVal param_ty params_args + let (params_args, input_args) := xs.toList.splitAt type_info.num_params + let params ← mkSigmasVal type_info.params_ty params_args let input ← mkProdsVal input_args -- Apply the fixed point let fixedBody ← mkAppM ``FixII.fix #[mutRecBody, idx, params, input] @@ -968,7 +1042,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (paramInOutTys : Array (ℕ × Expr × -- Prove the equations that we will use as unfolding theorems partial def proveUnfoldingThms (isValidThm : Expr) - (paramInOutTys : Array (ℕ × Expr × Expr × Expr)) + (paramInOutTys : Array TypeInfo) (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do @@ -993,9 +1067,9 @@ partial def proveUnfoldingThms (isValidThm : Expr) let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input arguments - let (num_params, param_ty, _, _) := paramInOutTys.get! i - let (params, args) := xs.toList.splitAt num_params - let params ← mkSigmasVal param_ty params + let type_info := paramInOutTys.get! i + let (params, args) := xs.toList.splitAt type_info.num_params + let params ← mkSigmasVal type_info.params_ty params let args ← mkProdsVal args let proof ← mkAppM ``congr_fun #[proof, params] let proof ← mkAppM ``congr_fun #[proof, args] @@ -1052,7 +1126,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do we generate: `(Type, λ α => List α × i, λ α => Result α)` -/ - let paramInOutTys : Array (ℕ × Expr × Expr × Expr) ← + let paramInOutTys : Array TypeInfo ← preDefs.mapM (fun preDef => do -- Check the universe parameters - TODO: I'm not sure what the best thing -- to do is. In practice, all the type parameters should be in Type 0, so @@ -1060,19 +1134,20 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do if preDef.levelParams ≠ grLvlParams then throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do + let total_num_args := in_tys.size let (params, in_tys) ← splitInputArgs in_tys out_ty trace[Diverge.def] "Decomposed arguments: {preDef.declName}: {params}, {in_tys}, {out_ty}" let num_params := params.size - let params_sigma ← mkSigmasType params.data - let in_tys ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data) + let params_ty ← mkSigmasType params.data + let in_ty ← mkSigmasMatchOrUnit params.data (← mkProdsType in_tys.data) -- Retrieve the type in the "Result" let out_ty ← getResultTy out_ty let out_ty ← mkSigmasMatchOrUnit params.data out_ty - trace[Diverge.def] "inOutTy: {preDef.declName}: {params_sigma}, {in_tys}, {out_ty}" - pure (num_params, params_sigma, in_tys, out_ty))) + trace[Diverge.def] "inOutTy: {preDef.declName}: {params_ty}, {in_tys}, {out_ty}" + pure ⟨ total_num_args, num_params, params_ty, in_ty, out_ty ⟩)) trace[Diverge.def] "paramInOutTys: {paramInOutTys}" -- Turn the list of input types/input args/output type tuples into expressions - let paramInOutTysExpr ← paramInOutTys.mapM (λ (_, x, y, z) => do mkInOutTy x y z) + let paramInOutTysExpr ← liftM (paramInOutTys.mapM mkInOutTyFromTypeInfo) let paramInOutTysExpr ← mkListLit (← inferType (paramInOutTysExpr.get! 0)) paramInOutTysExpr.toList trace[Diverge.def] "paramInOutTys: {paramInOutTys}" @@ -1135,7 +1210,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" - let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys.toList bodies + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var param_ty in_ty out_ty paramInOutTys bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by @@ -1275,27 +1350,27 @@ elab_rules : command namespace Tests /- Some examples of partial functions -/ - /- section HigherOrder - open FixI - - -- The index type - variable {id : Type u} - - -- The input/output types - variable {a : id → Type v} {b : (i:id) → a i → Type w} - - -- Example with a higher-order function - theorem map_is_valid - {{f : ((i:id) → (x:a i) → Result (b i x)) → (i:id) → (x:a i) → Result (b i x)}} - (Hfvalid : ∀ k i x, is_valid_p k (λ k => f k i x)) - (k : (a → Result b) → a → Result b) - (ls : List a) : - is_valid_p k (λ k => Ex5.map (f k) ls) := - induction ls <;> simp [map] - apply is_valid_p_bind <;> try simp_all - intros - apply is_valid_p_bind <;> try simp_all - end HigherOrder -/ + section HigherOrder + open Ex8 + + inductive Tree (a : Type u) := + | leaf (x : a) + | node (tl : List (Tree a)) + + set_option trace.Diverge.def true + -- set_option trace.Diverge.def.genBody true + set_option trace.Diverge.def.valid true + divergent def id {a : Type u} (t : Tree a) : Result (Tree a) := + match t with + | .leaf x => .ret (.leaf x) + | .node tl => + do + let tl ← map id tl + .ret (.node tl) + + set_option trace.Diverge.def false + + end HigherOrder divergent def list_nth {a: Type u} (ls : List a) (i : Int) : Result a := match ls with diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index b818d5af..0d33e9d2 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -1,13 +1,15 @@ import Lean import Base.Utils import Base.Primitives.Base +import Base.Extensions namespace Diverge open Lean Elab Term Meta -open Utils +open Utils Extensions -- We can't define and use trace classes in the same file +initialize registerTraceClass `Diverge initialize registerTraceClass `Diverge.elab initialize registerTraceClass `Diverge.def initialize registerTraceClass `Diverge.def.sigmas @@ -20,4 +22,63 @@ initialize registerTraceClass `Diverge.def.unfold -- For the attribute (for higher-order functions) initialize registerTraceClass `Diverge.attr +-- Attribute + +-- divspec attribute +structure DivSpecAttr where + attr : AttributeImpl + ext : DiscrTreeExtension Name true + deriving Inhabited + +/- The persistent map from expressions to divspec theorems. -/ +initialize divspecAttr : DivSpecAttr ← do + let ext ← mkDiscrTreeExtention `divspecMap true + let attrImpl : AttributeImpl := { + name := `divspec + descr := "Marks theorems to use with the `divergent` encoding" + add := fun thName stx attrKind => do + Attribute.Builtin.ensureNoArgs stx + -- TODO: use the attribute kind + unless attrKind == AttributeKind.global do + throwError "invalid attribute divspec, must be global" + -- Lookup the theorem + let env ← getEnv + let thDecl := env.constants.find! thName + let fKey : Array (DiscrTree.Key true) ← MetaM.run' (do + /- The theorem should have the shape: + `∀ ..., is_valid_p k (λ k => ...)` + + Dive into the ∀: + -/ + let (_, _, fExpr) ← forallMetaTelescope thDecl.type.consumeMData + /- Dive into the argument of `is_valid_p`: -/ + fExpr.consumeMData.withApp fun _ args => do + if args.size ≠ 7 then throwError "Invalid number of arguments to is_valid_p" + let fExpr := args.get! 6 + /- Dive into the lambda: -/ + let (_, _, fExpr) ← lambdaMetaTelescope fExpr.consumeMData + trace[Diverge] "Registering divspec theorem for {fExpr}" + -- Convert the function expression to a discrimination tree key + DiscrTree.mkPath fExpr) + let env := ext.addEntry env (fKey, thName) + setEnv env + trace[Diverge] "Saved the environment" + pure () + } + registerBuiltinAttribute attrImpl + pure { attr := attrImpl, ext := ext } + +def DivSpecAttr.find? (s : DivSpecAttr) (e : Expr) : MetaM (Array Name) := do + (s.ext.getState (← getEnv)).getMatch e + +def DivSpecAttr.getState (s : DivSpecAttr) : MetaM (DiscrTree Name true) := do + pure (s.ext.getState (← getEnv)) + +def showStoredDivSpec : MetaM Unit := do + let st ← divspecAttr.getState + -- TODO: how can we iterate over (at least) the values stored in the tree? + --let s := st.toList.foldl (fun s (f, th) => f!"{s}\n{f} → {th}") f!"" + let s := f!"{st}" + IO.println s + end Diverge |