From 19bde89b84619defc2a822c3bf96bdca9c97eee7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 28 Jun 2023 12:16:10 +0200 Subject: Reorganize backends/lean/Base --- backends/lean/Base/Diverge/Elab.lean | 182 +++++++++++++++++++++++++++++++++++ 1 file changed, 182 insertions(+) create mode 100644 backends/lean/Base/Diverge/Elab.lean (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean new file mode 100644 index 00000000..313c5a79 --- /dev/null +++ b/backends/lean/Base/Diverge/Elab.lean @@ -0,0 +1,182 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Base.Diverge.Base +import Base.Diverge.ElabBase + +namespace Diverge + +/- Automating the generation of the encoding and the proofs so as to use nice + syntactic sugar. -/ + +syntax (name := divergentDef) + declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command + +open Lean Elab Term Meta Primitives + +initialize registerTraceClass `Diverge.divRecursion (inherited := true) + +set_option trace.Diverge.divRecursion true + +/- The following was copied from the `wfRecursion` function. -/ + +open WF in +def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do + let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) + logInfo ("divRecursion: defs: " ++ msg) + + -- CHANGE HERE This function should add definitions with these names/types/values ^^ + -- Temporarily add the predefinitions as axioms + for preDef in preDefs do + addAsAxiom preDef + + -- TODO: what is this? + for preDef in preDefs do + applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation + + -- Process the definitions + addAndCompilePartialRec preDefs + +-- The following function is copy&pasted from Lean.Elab.PreDefinition.Main +-- This is the only part where we make actual changes and hook into the equation compiler. +-- (I've removed all the well-founded stuff to make it easier to read though.) + +open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDefs + addAndCompilePartial addAsAxioms from Lean.Elab.PreDefinition.Main + +def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do + for preDef in preDefs do + trace[Elab.definition.body] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef + let preDefs ← betaReduceLetRecApps preDefs + let cliques := partitionPreDefs preDefs + let mut hasErrors := false + for preDefs in cliques do + trace[Elab.definition.scc] "{preDefs.map (·.declName)}" + try + logInfo "calling divRecursion" + withRef (preDefs[0]!.ref) do + divRecursion preDefs + logInfo "divRecursion succeeded" + catch ex => + -- If it failed, we + logInfo "divRecursion failed" + hasErrors := true + logException ex + let s ← saveState + try + if preDefs.all fun preDef => preDef.kind == DefKind.def || + preDefs.all fun preDef => preDef.kind == DefKind.abbrev then + -- try to add as partial definition + try + addAndCompilePartial preDefs (useSorry := true) + catch _ => + -- Compilation failed try again just as axiom + s.restore + addAsAxioms preDefs + else return () + catch _ => s.restore + +-- The following two functions are copy&pasted from Lean.Elab.MutualDef + +open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues + instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef + +def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM Unit := do + let scopeLevelNames ← getLevelNames + let headers ← elabHeaders views + let headers ← levelMVarToParamHeaders views headers + let allUserLevelNames := getAllUserLevelNames headers + withFunLocalDecls headers fun funFVars => do + for view in views, funFVar in funFVars do + addLocalVarInfo view.declId funFVar + let values ← + try + let values ← elabFunValues headers + Term.synthesizeSyntheticMVarsNoPostponing + values.mapM (instantiateMVars ·) + catch ex => + logException ex + headers.mapM fun header => mkSorry header.type (synthetic := true) + let headers ← headers.mapM instantiateMVarsAtHeader + let letRecsToLift ← getLetRecsToLift + let letRecsToLift ← letRecsToLift.mapM instantiateMVarsAtLetRecToLift + checkLetRecsToLiftTypes funFVars letRecsToLift + withUsed vars headers values letRecsToLift fun vars => do + let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift + for preDef in preDefs do + trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs + let preDefs ← instantiateMVarsAtPreDecls preDefs + let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames + for preDef in preDefs do + trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + checkForHiddenUnivLevels allUserLevelNames preDefs + addPreDefinitions preDefs + +open Command in +def Command.elabMutualDef (ds : Array Syntax) : CommandElabM Unit := do + let views ← ds.mapM fun d => do + let `($mods:declModifiers divergent def $id:declId $sig:optDeclSig $val:declVal) := d + | throwUnsupportedSyntax + let modifiers ← elabModifiers mods + let (binders, type) := expandOptDeclSig sig + let deriving? := none + pure { ref := d, kind := DefKind.def, modifiers, + declId := id, binders, type? := type, value := val, deriving? } + runTermElabM fun vars => Term.elabMutualDef vars views + +-- Special command so that we don't fall back to the built-in mutual when we produce an error. +local syntax "_divergent" Parser.Command.mutual : command +elab_rules : command | `(_divergent mutual $decls* end) => Command.elabMutualDef decls + +macro_rules + | `(mutual $decls* end) => do + unless !decls.isEmpty && decls.all (·.1.getKind == ``divergentDef) do + Macro.throwUnsupported + `(command| _divergent mutual $decls* end) + +open private setDeclIdName from Lean.Elab.Declaration +elab_rules : command + | `($mods:declModifiers divergent%$tk def $id:declId $sig:optDeclSig $val:declVal) => do + let (name, _) := expandDeclIdCore id + if (`_root_).isPrefixOf name then throwUnsupportedSyntax + let view := extractMacroScopes name + let .str ns shortName := view.name | throwUnsupportedSyntax + let shortName' := { view with name := shortName }.review + let cmd ← `(mutual $mods:declModifiers divergent%$tk def $(⟨setDeclIdName id shortName'⟩):declId $sig:optDeclSig $val:declVal end) + if ns matches .anonymous then + Command.elabCommand cmd + else + Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) + +mutual + divergent def is_even (i : Int) : Result Bool := + if i = 0 then return true else return (← is_odd (i - 1)) + + divergent def is_odd (i : Int) : Result Bool := + if i = 0 then return false else return (← is_even (i - 1)) +end + +example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by + induction i + unfold is_even + sorry + +divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + +mutual + divergent def foo (i : Int) : Result Nat := + if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 + + divergent def bar (i : Int) : Result Nat := + if i > 20 then foo (i / 20) else .ret 42 +end + +end Diverge -- cgit v1.2.3 From a6de153f3bfda7feb27d16fcdf2131d37f99c7a3 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 11:22:32 +0200 Subject: Start working on Elab.lean --- backends/lean/Base/Diverge/Elab.lean | 138 ++++++++++++++++++++++++++++++++--- 1 file changed, 127 insertions(+), 11 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 313c5a79..22e0039f 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -15,16 +15,53 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives -initialize registerTraceClass `Diverge.divRecursion (inherited := true) - -set_option trace.Diverge.divRecursion true +set_option trace.Diverge.def true /- The following was copied from the `wfRecursion` function. -/ open WF in + + + +-- Replace the recursive calls by a call to the continuation +-- def replace_rec_calls + +#check Lean.Meta.forallTelescope +#check Expr +#check withRef +#check MonadRef.withRef +#check Nat +#check Array +#check Lean.Meta.inferType +#check Nat +#check Int + +#check (0, 1) +#check Prod +#check () +#check Unit +#check Sigma + +-- print_decl is_even_body +#check instOfNatNat +#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... +#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ... +#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat + + +-- TODO: is there already such a utility somewhere? +-- TODO: change to mkSigmas +def mkProds (tys : List Expr) : MetaM Expr := + match tys with + | [] => do return (Expr.const ``PUnit.unit []) + | [ty] => do return ty + | ty :: tys => do + let pty ← mkProds tys + mkAppM ``Prod.mk #[ty, pty] + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) - logInfo ("divRecursion: defs: " ++ msg) + trace[Diverge.def] ("divRecursion: defs: " ++ msg) -- CHANGE HERE This function should add definitions with these names/types/values ^^ -- Temporarily add the predefinitions as axioms @@ -35,6 +72,85 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation + -- Retrieve the name of the first definition, that we will use as the namespace + -- for the definitions common to the group + let def0 := preDefs[0]! + let grName := def0.declName + trace[Diverge.def] "group name: {grName}" + + /- Compute the type of the continuation. + + We do the following + - we make sure all the definitions have the same universe parameters + (we can make this more general later) + - we group all the type parameters together, make sure all the + definitions have the same type parameters, and enforce + a uniform polymorphism (we can also lift this later). + This would require generalizing a bit our indexed fixed point to + make the output type parametric in the input. + - we group all the non-type parameters: we parameterize the continuation + by those + -/ + let grLvlParams := def0.levelParams + trace[Diverge.def] "def0 type: {def0.type}" + + -- Small utility: compute the list of type parameters + let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := + Lean.Meta.forallTelescope ty fun tys out_ty => do + trace[Diverge.def] "types: {tys}" +/- let (_, params) ← StateT.run (do + for x in tys do + let ty ← Lean.Meta.inferType x + match ty with + | .sort _ => do + let st ← StateT.get + StateT.set (ty :: st) + | _ => do break + ) ([] : List Expr) + let params := params.reverse + trace[Diverge.def] " type parameters {params}" + return params -/ + let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) := + match ls with + | x :: tl => do + let ty ← Lean.Meta.inferType x + match ty with + | .sort _ => do + let (ty_params, params) ← get_params tl + return (x :: ty_params, params) + | _ => do return ([], ls) + | _ => do return ([], []) + let (ty_params, params) ← get_params tys.toList + trace[Diverge.def] " parameters: {ty_params}; {params}" + return (ty_params, params, out_ty) + let (grTyParams, _, _) ← do + getTypeParams def0.type + + -- Compute the input types and the output types + let all_tys ← preDefs.mapM fun preDef => do + let (tyParams, params, ret_ty) ← getTypeParams preDef.type + -- TODO: this is not complete, there are more checks to perform + if tyParams.length ≠ grTyParams.length then + throwError "Non-uniform polymorphism" + return (params, ret_ty) + + -- TODO: I think there are issues with the free variables + let (input_tys, output_tys) := List.unzip all_tys.toList + let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) + + trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" + + -- Compute the names set + let names := preDefs.map PreDefinition.declName + let names := HashSet.empty.insertMany names + + -- + for preDef in preDefs do + trace[Diverge.def] "about to explore: {preDef.declName}" + explore_term "" preDef.value + + -- Compute the bodies + -- Process the definitions addAndCompilePartialRec preDefs @@ -47,21 +163,21 @@ open private ensureNoUnassignedMVarsAtPreDef betaReduceLetRecApps partitionPreDe def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLCtx {} {} do for preDef in preDefs do - trace[Elab.definition.body] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" let preDefs ← preDefs.mapM ensureNoUnassignedMVarsAtPreDef let preDefs ← betaReduceLetRecApps preDefs let cliques := partitionPreDefs preDefs let mut hasErrors := false for preDefs in cliques do - trace[Elab.definition.scc] "{preDefs.map (·.declName)}" + trace[Diverge.elab] "{preDefs.map (·.declName)}" try - logInfo "calling divRecursion" + trace[Diverge.elab] "calling divRecursion" withRef (preDefs[0]!.ref) do divRecursion preDefs - logInfo "divRecursion succeeded" + trace[Diverge.elab] "divRecursion succeeded" catch ex => -- If it failed, we - logInfo "divRecursion failed" + trace[Diverge.elab] "divRecursion failed" hasErrors := true logException ex let s ← saveState @@ -106,12 +222,12 @@ def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM U withUsed vars headers values letRecsToLift fun vars => do let preDefs ← MutualClosure.main vars headers funFVars values letRecsToLift for preDef in preDefs do - trace[Elab.definition] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "{preDef.declName} : {preDef.type} :=\n{preDef.value}" let preDefs ← withLevelNames allUserLevelNames <| levelMVarToParamPreDecls preDefs let preDefs ← instantiateMVarsAtPreDecls preDefs let preDefs ← fixLevelParams preDefs scopeLevelNames allUserLevelNames for preDef in preDefs do - trace[Elab.definition] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" + trace[Diverge.elab] "after eraseAuxDiscr, {preDef.declName} : {preDef.type} :=\n{preDef.value}" checkForHiddenUnivLevels allUserLevelNames preDefs addPreDefinitions preDefs -- cgit v1.2.3 From fdc8693772ecb1978873018c790061854f00a015 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 29 Jun 2023 23:15:20 +0200 Subject: Write function to compute the input/output types --- backends/lean/Base/Diverge/Elab.lean | 154 ++++++++++++++++++++++++++++------- 1 file changed, 123 insertions(+), 31 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 22e0039f..116c5d8b 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -13,7 +13,7 @@ namespace Diverge syntax (name := divergentDef) declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command -open Lean Elab Term Meta Primitives +open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true @@ -21,27 +21,9 @@ set_option trace.Diverge.def true open WF in - - -- Replace the recursive calls by a call to the continuation -- def replace_rec_calls -#check Lean.Meta.forallTelescope -#check Expr -#check withRef -#check MonadRef.withRef -#check Nat -#check Array -#check Lean.Meta.inferType -#check Nat -#check Int - -#check (0, 1) -#check Prod -#check () -#check Unit -#check Sigma - -- print_decl is_even_body #check instOfNatNat #check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... @@ -59,6 +41,100 @@ def mkProds (tys : List Expr) : MetaM Expr := let pty ← mkProds tys mkAppM ``Prod.mk #[ty, pty] +/- Generate the input type of a function body, which is a sigma type (i.e., a + dependent tuple) which groups all its inputs. + + Example: + - xl = [(a:Type), (ls:List a), (i:Int)] + + Generates: + `(a:Type) × (ls:List a) × (i:Int)` + + -/ +def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" + return (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" + let ty ← Lean.Meta.inferType x + return ty + | x :: xl => do + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" + let alpha ← Lean.Meta.inferType x + let sty ← mkSigmasTypesOfTypes xl + trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" + let beta ← mkLambdaFVars #[x] sty + trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" + mkAppOptM ``Sigma #[some alpha, some beta] + +def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index + +/- Generate the out_ty of the body of a function, which from an input (a sigma + type generated by `mkSigmasTypesOfTypes`) gives the output type of the function. + + Example: + - xl = `[a:Type, ls:List a, i:Int]` + - out_ty = `a` + - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables + + Generates: + ``` + match scrut0 with + | Sigma.mk x scrut1 => + match scrut1 with + | Sigma.mk ls i => + a + ``` +-/ +def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr := + match xl with + | [] => do + -- This would be unexpected + throwError "mkSigmasOutType: empyt list of input parameters" + | [x] => do + -- In the explanations above: inner match case + trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]" + mkLambdaFVars #[x] out_ty + | fst :: xl => do + -- In the explanations above: outer match case + -- Remark: for the naming purposes, we use the same convention as for the + -- fields and parameters in `Sigma.casesOn and `Sigma.mk + trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let snd_ty ← mkSigmasTypesOfTypes xl + let beta ← mkLambdaFVars #[fst] snd_ty + let snd ← mkSigmasOutType xl out_ty (index + 1) + let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) + withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do + let mk ← mkLambdaFVars #[fst] snd + trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})" + let motive ← mkLambdaFVars #[scrut] (← inferType out_ty) + trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + let out ← mkLambdaFVars #[scrut] out + trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}" + return out + +/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/ +private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := + @Sigma.casesOn (List a) + (fun (_ls : List a) => Int) + (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) + scrut1 + (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) + +private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => + @Sigma (List a) (fun (_ls : List a) => Int))) := + @Sigma.casesOn (Type) + (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int)) + (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type) + scrut0 + (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) => + list_nth_out_ty2 a scrut1) +/- -/ + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) @@ -94,7 +170,23 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grLvlParams := def0.levelParams trace[Diverge.def] "def0 type: {def0.type}" - -- Small utility: compute the list of type parameters + -- Compute the list of pairs: (input type × output type) + let inOutTys : Array (Expr × Expr) ← + preDefs.mapM (fun preDef => do + -- Check the universe parameters - TODO: I'm not sure what the best thing + -- to do is. In practice, all the type parameters should be in Type 0, so + -- we shouldn't have universe issues. + if preDef.levelParams ≠ grLvlParams then + throwError "Non-uniform polymorphism in the universes" + forallTelescope preDef.type (fun in_tys out_ty => do + let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) + let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty) + return (in_ty, out_ty) + ) + ) + trace[Diverge.def] "inOutTys: {inOutTys}" + +/- -- Small utility: compute the list of type parameters let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := Lean.Meta.forallTelescope ty fun tys out_ty => do trace[Diverge.def] "types: {tys}" @@ -138,16 +230,16 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let (input_tys, output_tys) := List.unzip all_tys.toList let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) - trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" + trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/ -- Compute the names set let names := preDefs.map PreDefinition.declName let names := HashSet.empty.insertMany names -- - for preDef in preDefs do - trace[Diverge.def] "about to explore: {preDef.declName}" - explore_term "" preDef.value + -- for preDef in preDefs do + -- trace[Diverge.def] "about to explore: {preDef.declName}" + -- explore_term "" preDef.value -- Compute the bodies @@ -267,6 +359,13 @@ elab_rules : command else Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) +divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) @@ -280,13 +379,6 @@ example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ unfold is_even sorry -divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := - match ls with - | [] => .fail .panic - | x :: ls => - if i = 0 then return x - else return (← list_nth ls (i - 1)) - mutual divergent def foo (i : Int) : Result Nat := if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 -- cgit v1.2.3 From 1c9331ce92b68b9a83c601212149a6c24591708f Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 30 Jun 2023 15:53:39 +0200 Subject: Generate the fixed-point bodies in Elab.lean --- backends/lean/Base/Diverge/Elab.lean | 451 ++++++++++++++++++++++++++--------- 1 file changed, 344 insertions(+), 107 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 116c5d8b..f7de7518 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,31 +16,62 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true +-- set_option trace.Diverge.def.sigmas true /- The following was copied from the `wfRecursion` function. -/ open WF in --- Replace the recursive calls by a call to the continuation --- def replace_rec_calls +def mkList (xl : List Expr) (ty : Expr) : MetaM Expr := + match xl with + | [] => + mkAppOptM ``List.nil #[some ty] + | x :: tl => do + let tl ← mkList tl ty + mkAppOptM ``List.cons #[some ty, some x, some tl] --- print_decl is_even_body -#check instOfNatNat -#check OfNat.ofNat -- @OfNat.ofNat ℕ 2 ... -#check OfNat.ofNat -- @OfNat.ofNat (Fin 2) 1 ... -#check Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat +def mkProd (x y : Expr) : MetaM Expr := + mkAppM ``Prod.mk #[x, y] +def mkInOutTy (x y : Expr) : MetaM Expr := + mkAppM ``FixI.mk_in_out_ty #[x, y] -- TODO: is there already such a utility somewhere? -- TODO: change to mkSigmas def mkProds (tys : List Expr) : MetaM Expr := match tys with - | [] => do return (Expr.const ``PUnit.unit []) - | [ty] => do return ty + | [] => do pure (Expr.const ``PUnit.unit []) + | [ty] => do pure ty | ty :: tys => do let pty ← mkProds tys mkAppM ``Prod.mk #[ty, pty] +-- Return the `a` in `Return a` +def get_result_ty (ty : Expr) : MetaM Expr := + ty.withApp fun f args => do + if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then + throwError "Invalid argument to get_result_ty: {ty}" + else + pure (args.get! 0) + +-- Group a list of expressions into a dependent tuple +def mkSigmas (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmas: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmas: [{x}]" + pure x + | fst :: xl => do + trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]" + let alpha ← Lean.Meta.inferType fst + let snd ← mkSigmas xl + let snd_ty ← inferType snd + let beta ← mkLambdaFVars #[fst] snd_ty + trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + /- Generate the input type of a function body, which is a sigma type (i.e., a dependent tuple) which groups all its inputs. @@ -55,11 +86,11 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" - return (Expr.const ``PUnit.unit []) + pure (Expr.const ``PUnit.unit []) | [x] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}]" let ty ← Lean.Meta.inferType x - return ty + pure ty | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x @@ -71,15 +102,26 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index -/- Generate the out_ty of the body of a function, which from an input (a sigma - type generated by `mkSigmasTypesOfTypes`) gives the output type of the function. +/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous + `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following + expression: + ``` + fun x:((x0:ty0) × ... × (xn:tyn) => -- **Dependent** tuple + match x with + | (x0, ..., xn) => out + ``` + + The `index` parameter is used for naming purposes: we use it to numerotate the + bound variables that we introduce. Example: + ======== + More precisely: - xl = `[a:Type, ls:List a, i:Int]` - - out_ty = `a` - - index = 0 -- For naming purposes: we use it to numerotate the "scrutinee" variables + - out = `a` + - index = 0 - Generates: + generates: ``` match scrut0 with | Sigma.mk x scrut1 => @@ -88,36 +130,47 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index a ``` -/ -def mkSigmasOutType (xl : List Expr) (out_ty : Expr) (index : Nat := 0) : MetaM Expr := +partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : MetaM Expr := match xl with | [] => do -- This would be unexpected - throwError "mkSigmasOutType: empyt list of input parameters" + throwError "mkSigmasMatch: empyt list of input parameters" | [x] => do -- In the explanations above: inner match case - trace[Diverge.def.sigmas] "mkSigmasOutType: [{x}]" - mkLambdaFVars #[x] out_ty + trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" + mkLambdaFVars #[x] out | fst :: xl => do -- In the explanations above: outer match case -- Remark: for the naming purposes, we use the same convention as for the -- fields and parameters in `Sigma.casesOn and `Sigma.mk - trace[Diverge.def.sigmas] "mkSigmasOutType: [{fst}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst let snd_ty ← mkSigmasTypesOfTypes xl let beta ← mkLambdaFVars #[fst] snd_ty - let snd ← mkSigmasOutType xl out_ty (index + 1) + let snd ← mkSigmasMatch xl out (index + 1) let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do let mk ← mkLambdaFVars #[fst] snd - trace[Diverge.def.sigmas] "mkSigmasOutType: scrut: ({scrut}) : ({← inferType scrut})" - let motive ← mkLambdaFVars #[scrut] (← inferType out_ty) - trace[Diverge.def.sigmas] "mkSigmasOutType:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" - let out ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] - let out ← mkLambdaFVars #[scrut] out - trace[Diverge.def.sigmas] "mkSigmasOutType: out: {out}" - return out - -/- Small tests for list_nth: give a model of what `mkSigmasOutType` should generate -/ + trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" + -- TODO: make the computation of the motive more efficient + let motive ← do + let out_ty ← inferType out + match out_ty with + | .sort _ | .lit _ | .const .. => + -- The type of the motive doesn't depend on the scrutinee + mkLambdaFVars #[scrut] out_ty + | _ => + -- The type of the motive *may* depend on the scrutinee + -- TODO: make this more efficient (we could change the output type of + -- mkSigmasMatch + mkSigmasMatch (fst :: xl) out_ty + trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" + let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + let sm ← mkLambdaFVars #[scrut] sm + trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" + pure sm + +/- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) (fun (_ls : List a) => Int) @@ -135,14 +188,199 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => list_nth_out_ty2 a scrut1) /- -/ +-- TODO: move +-- TODO: we can use Array.mapIdx +@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β + | [] => [] + | a::as => f i a :: mapiAux (i+1) f as + +@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f + +#check Array.map +-- Return the expression: `Fin n` +-- TODO: use more +def mkFin (n : Nat) : Expr := + mkAppN (.const ``Fin []) #[.lit (.natVal n)] + +-- Return the expression: `i : Fin n` +def mkFinVal (n i : Nat) : MetaM Expr := do + let n_lit : Expr := .lit (.natVal (n - 1)) + let i_lit : Expr := .lit (.natVal i) + -- We could use `trySynthInstance`, but as we know the instance that we are + -- going to use, we can save the lookup + let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] + mkAppOptM ``OfNat.ofNat #[none, none, ofNat] + +-- TODO: remove? +def mkFinValOld (n i : Nat) : MetaM Expr := do + let finTy := mkFin n + let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)] + match ← trySynthInstance ofNat with + | LOption.some x => + mkAppOptM ``OfNat.ofNat #[none, none, x] + | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} " + +/- Generate and declare as individual definitions the bodies for the individual funcions: + - replace the recursive calls with calls to the continutation `k` + - make those bodies take one single dependent tuple as input + + We name the declarations: "[original_name].body". + We return the new declarations. + -/ +def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) + (preDefs : Array PreDefinition) : + MetaM (Array Expr) := do + let grSize := preDefs.size + + -- Compute the map from name to index - the continuation has an indexed type: + -- we use the index (a finite number of type `Fin`) to control the function + -- we call at the recursive call + let nameToId : HashMap Name Nat := + let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList + HashMap.ofList namesIds + + trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" + + -- Auxiliary function to explore the function bodies and replace the + -- recursive calls + let visit_e (e : Expr) : MetaM Expr := do + trace[Diverge.def.genBody] "visiting expression: {e}" + match e with + | .app .. => do + e.withApp fun f args => do + trace[Diverge.def.genBody] "this is an app: {f} {args}" + -- Check if this is a recursive call + if f.isConst then + let name := f.constName! + match nameToId.find? name with + | none => pure e + | some id => + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Put the arguments in one big dependent tuple + let args ← mkSigmas args.toList + mkAppM' k_var #[i, args] + else + -- Not a recursive call: do nothing + pure e + | .const name _ => + -- Sanity check: we eliminated all the recursive calls + if (nameToId.find? name).isSome then + throwError "mkUnaryBodies: a recursive call was not eliminated" + else pure e + | _ => pure e + + -- Explore the bodies + preDefs.mapM fun preDef => do + -- Replace the recursive calls + let body ← mapVisit visit_e preDef.value + + -- Change the type + lambdaLetTelescope body fun args body => do + let body ← mkSigmasMatch args.toList body 0 + + -- Add the declaration + let value ← mkLambdaFVars #[k_var] body + let name := preDef.declName.append "body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value -- TODO: change the type + value := value + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) value + 1) + safety := .safe + all := [name] + } + addDecl decl + trace[Diverge.def] "individual body of {preDef.declName}: {body}" + -- Return the constant + let body := Lean.mkConst name (levelParams.map .param) + -- let body ← mkAppM' body #[k_var] + trace[Diverge.def] "individual body (after decl): {body}" + pure body + +-- Generate a unique function body from the bodies of the mutually recursive group, +-- and add it as a declaration in the context +def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) + (i_var k_var : Expr) + (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) + (bodies : Array Expr) : MetaM Expr := do + -- Generate the body + let grSize := bodies.size + let finTypeExpr := mkFin grSize + -- TODO: not very clean + let inOutTyType ← do + let (x, y) := inOutTys.get! 0 + inferType (← mkInOutTy x y) + let rec mkFuns (inOutTys : List (Expr × Expr)) (bl : List Expr) : MetaM Expr := + match inOutTys, bl with + | [], [] => + mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] + | (ity, oty) :: inOutTys, b :: bl => do + -- Retrieving ity and oty - this is not very clean + let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType + let fl ← mkFuns inOutTys bl + mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] + | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length" + let bodyFuns ← mkFuns inOutTys bodies.toList + -- Wrap in `get_fun` + let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var] + -- Add the index `i` and the continuation `k` as a variables + let body ← mkLambdaFVars #[k_var, i_var] body + trace[Diverge.def] "mkDeclareMutualBody: body: {body}" + -- Add the declaration + let name := grName.append "mutrec_body" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType body + value := body + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) body + 1) + safety := .safe + all := [name] + } + addDecl decl + -- Return the constant + pure (Lean.mkConst name (levelParams.map .param)) + +-- Generate the final definions by using the mutual body and the fixed point operator. +def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : + TermElabM Unit := do + let grSize := preDefs.size + let _ ← preDefs.mapIdxM fun idx preDef => do + lambdaLetTelescope preDef.value fun xs _ => do + -- Create the index + let idx ← mkFinVal grSize idx.val + -- Group the inputs into a dependent tuple + let input ← mkSigmas xs.toList + -- Apply the fixed point + let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input] + let fixedBody ← mkLambdaFVars xs fixedBody + -- Create the declaration + let name := preDef.declName + let decl := Declaration.defnDecl { + name := name + levelParams := preDef.levelParams + type := preDef.type + value := fixedBody + hints := ReducibilityHints.regular (getMaxHeight (← getEnv) fixedBody + 1) + safety := .safe + all := [name] + } + addDecl decl + pure () + def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) -- CHANGE HERE This function should add definitions with these names/types/values ^^ -- Temporarily add the predefinitions as axioms - for preDef in preDefs do - addAsAxiom preDef + -- for preDef in preDefs do + -- addAsAxiom preDef -- TODO: what is this? for preDef in preDefs do @@ -154,25 +392,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let grName := def0.declName trace[Diverge.def] "group name: {grName}" - /- Compute the type of the continuation. - - We do the following - - we make sure all the definitions have the same universe parameters - (we can make this more general later) - - we group all the type parameters together, make sure all the - definitions have the same type parameters, and enforce - a uniform polymorphism (we can also lift this later). - This would require generalizing a bit our indexed fixed point to - make the output type parametric in the input. - - we group all the non-type parameters: we parameterize the continuation - by those - -/ + /- # Compute the input/output types of the continuation `k`. -/ let grLvlParams := def0.levelParams - trace[Diverge.def] "def0 type: {def0.type}" + trace[Diverge.def] "def0 universe levels: {def0.levelParams}" - -- Compute the list of pairs: (input type × output type) + -- We first compute the list of pairs: (input type × output type) let inOutTys : Array (Expr × Expr) ← preDefs.mapM (fun preDef => do + withRef preDef.ref do -- is the withRef useful? -- Check the universe parameters - TODO: I'm not sure what the best thing -- to do is. In practice, all the type parameters should be in Type 0, so -- we shouldn't have universe issues. @@ -180,68 +407,74 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) - let out_ty ← liftM (mkSigmasOutType in_tys.toList out_ty) - return (in_ty, out_ty) + -- Retrieve the type in the "Result" + let out_ty ← get_result_ty out_ty + let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) + pure (in_ty, out_ty) ) ) trace[Diverge.def] "inOutTys: {inOutTys}" - -/- -- Small utility: compute the list of type parameters - let getTypeParams (ty: Expr) : MetaM (List Expr × List Expr × Expr) := - Lean.Meta.forallTelescope ty fun tys out_ty => do - trace[Diverge.def] "types: {tys}" -/- let (_, params) ← StateT.run (do - for x in tys do - let ty ← Lean.Meta.inferType x - match ty with - | .sort _ => do - let st ← StateT.get - StateT.set (ty :: st) - | _ => do break - ) ([] : List Expr) - let params := params.reverse - trace[Diverge.def] " type parameters {params}" - return params -/ - let rec get_params (ls : List Expr) : MetaM (List Expr × List Expr) := - match ls with - | x :: tl => do - let ty ← Lean.Meta.inferType x - match ty with - | .sort _ => do - let (ty_params, params) ← get_params tl - return (x :: ty_params, params) - | _ => do return ([], ls) - | _ => do return ([], []) - let (ty_params, params) ← get_params tys.toList - trace[Diverge.def] " parameters: {ty_params}; {params}" - return (ty_params, params, out_ty) - let (grTyParams, _, _) ← do - getTypeParams def0.type - - -- Compute the input types and the output types - let all_tys ← preDefs.mapM fun preDef => do - let (tyParams, params, ret_ty) ← getTypeParams preDef.type - -- TODO: this is not complete, there are more checks to perform - if tyParams.length ≠ grTyParams.length then - throwError "Non-uniform polymorphism" - return (params, ret_ty) - - -- TODO: I think there are issues with the free variables - let (input_tys, output_tys) := List.unzip all_tys.toList - let input_tys : List Expr ← liftM (List.mapM mkProds input_tys) - - trace[Diverge.def] " in/out tys: {input_tys}; {output_tys}" -/ - - -- Compute the names set - let names := preDefs.map PreDefinition.declName - let names := HashSet.empty.insertMany names - - -- - -- for preDef in preDefs do - -- trace[Diverge.def] "about to explore: {preDef.declName}" - -- explore_term "" preDef.value - - -- Compute the bodies + -- Turn the list of input/output type pairs into an expresion + let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) + let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0)) + + -- From the list of pairs of input/output types, actually compute the + -- type of the continuation `k`. + -- We first introduce the index `i : Fin n` where `n` is the number of + -- functions in the group. + let i_var_ty := mkFin preDefs.size + withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do + let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] + trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" + -- Add an auxiliary definition for `in_out_ty` + let in_out_ty ← do + let value ← mkLambdaFVars #[i_var] in_out_ty + let name := grName.append "in_out_ty" + let levelParams := grLvlParams + let decl := Declaration.defnDecl { + name := name + levelParams := levelParams + type := ← inferType value + value := value + hints := .abbrev + safety := .safe + all := [name] + } + addDecl decl + -- Return the constant + let in_out_ty := Lean.mkConst name (levelParams.map .param) + mkAppM' in_out_ty #[i_var] + trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" + let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] + trace[Diverge.def] "in_ty: {in_ty}" + withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do + let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] + trace[Diverge.def] "out_ty: {out_ty}" + + -- Introduce the continuation `k` + let in_ty ← mkLambdaFVars #[i_var] in_ty + let out_ty ← mkLambdaFVars #[i_var, input] out_ty + let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] -- + trace[Diverge.def] "k_var_ty: {k_var_ty}" + withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do + trace[Diverge.def] "k_var: {k_var}" + + -- Replace the recursive calls in all the function bodies by calls to the + -- continuation `k` and and generate for those bodies declarations + let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs + -- Generate the mutually recursive body + let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies + trace[Diverge.def] "mut rec body (after decl): {body}" + + -- Prove that the mut rec body satisfies the validity criteria required by + -- our fixed-point + -- TODO + + -- Generate the final definitions + let defs ← mkDeclareFixDefs body preDefs + + -- Prove the unfolding equations + -- TODO -- Process the definitions addAndCompilePartialRec preDefs @@ -366,6 +599,10 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := if i = 0 then return x else return (← list_nth ls (i - 1)) +#print list_nth.in_out_ty +#check list_nth.body +#print list_nth + mutual divergent def is_even (i : Int) : Result Bool := if i = 0 then return true else return (← is_odd (i - 1)) -- cgit v1.2.3 From 37e5d5501e024869037bf0ea1559229a8be62da7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 16:24:44 +0200 Subject: Generate the proofs of validity in Elab.lean --- backends/lean/Base/Diverge/Elab.lean | 403 ++++++++++++++++++++++++++++++++--- 1 file changed, 371 insertions(+), 32 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index f7de7518..cf40ea8f 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,6 +16,7 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true +set_option trace.Diverge.def.valid true -- set_option trace.Diverge.def.sigmas true /- The following was copied from the `wfRecursion` function. -/ @@ -196,7 +197,6 @@ private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => @[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f -#check Array.map -- Return the expression: `Fin n` -- TODO: use more def mkFin (n : Nat) : Expr := @@ -227,7 +227,7 @@ def mkFinValOld (n i : Nat) : MetaM Expr := do We name the declarations: "[original_name].body". We return the new declarations. -/ -def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) +def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size @@ -260,7 +260,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) let i ← mkFinVal grSize id -- Put the arguments in one big dependent tuple let args ← mkSigmas args.toList - mkAppM' k_var #[i, args] + mkAppM' kk_var #[i, args] else -- Not a recursive call: do nothing pure e @@ -281,8 +281,8 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) let body ← mkSigmasMatch args.toList body 0 -- Add the declaration - let value ← mkLambdaFVars #[k_var] body - let name := preDef.declName.append "body" + let value ← mkLambdaFVars #[kk_var] body + let name := preDef.declName.append "sbody" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -297,16 +297,17 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (k_var : Expr) trace[Diverge.def] "individual body of {preDef.declName}: {body}" -- Return the constant let body := Lean.mkConst name (levelParams.map .param) - -- let body ← mkAppM' body #[k_var] + -- let body ← mkAppM' body #[kk_var] trace[Diverge.def] "individual body (after decl): {body}" pure body -- Generate a unique function body from the bodies of the mutually recursive group, --- and add it as a declaration in the context -def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) - (i_var k_var : Expr) +-- and add it as a declaration in the context. +-- We return the list of bodies (of type `Funs ...`) and the mutually recursive body. +def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) + (kk_var i_var : Expr) (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) - (bodies : Array Expr) : MetaM Expr := do + (bodies : Array Expr) : MetaM (Expr × Expr) := do -- Generate the body let grSize := bodies.size let finTypeExpr := mkFin grSize @@ -323,15 +324,15 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType let fl ← mkFuns inOutTys bl mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] - | _, _ => throwError "mkDeclareMutualBody: `tys` and `bodies` don't have the same length" + | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" let bodyFuns ← mkFuns inOutTys bodies.toList -- Wrap in `get_fun` - let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, k_var] + let body ← mkAppM ``FixI.get_fun #[bodyFuns, i_var, kk_var] -- Add the index `i` and the continuation `k` as a variables - let body ← mkLambdaFVars #[k_var, i_var] body - trace[Diverge.def] "mkDeclareMutualBody: body: {body}" + let body ← mkLambdaFVars #[kk_var, i_var] body + trace[Diverge.def] "mkDeclareMutRecBody: body: {body}" -- Add the declaration - let name := grName.append "mutrec_body" + let name := grName.append "mut_rec_body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -344,10 +345,348 @@ def mkDeclareMutualBody (grName : Name) (grLvlParams : List Name) } addDecl decl -- Return the constant - pure (Lean.mkConst name (levelParams.map .param)) + pure (bodyFuns, Lean.mkConst name (levelParams.map .param)) + +def isCasesExpr (e : Expr) : MetaM Bool := do + let e := e.getAppFn + if e.isConst then + return isCasesOnRecursor (← getEnv) e.constName + else return false + +structure MatchInfo where + matcherName : Name + matcherLevels : Array Level + params : Array Expr + motive : Expr + scruts : Array Expr + branchesNumParams : Array Nat + branches : Array Expr + +instance : ToMessageData MatchInfo where + -- This is not a very clean formatting, but we don't need more + toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" + +-- An expression which doesn't use the continuation kk is valid +def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" + let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] + trace[Diverge.def.valid] "proveNoKExprIsValid: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + +mutual + +partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do + trace[Diverge.def.valid] "proveValid: {e}" + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => throwError "Unimplemented" + | .lam .. => throwError "Unimplemented" + | .forallE .. => throwError "Unreachable" -- Shouldn't get there + | .letE .. => throwError "TODO" + -- lambdaLetTelescope e fun xs b => mapVisitBinders xs do + -- mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .mdata _ b => proveExprIsValid k_var kk_var b + | .proj _ _ _ => + -- The projection shouldn't use the continuation + proveNoKExprIsValid k_var e + | .app .. => + e.withApp fun f args => do + -- There are several cases: first, check if this is a match/if + -- The expression is a (dependent) if then else + let isIte := e.isIte + if isIte || e.isDIte then do + e.withApp fun f args => do + trace[Diverge.def.valid] "ite/dite: {f}:\n{args}" + if args.size ≠ 5 then + throwError "Wrong number of parameters for {f}: {args}" + let cond := args.get! 1 + let dec := args.get! 2 + -- Prove that the branches are valid + let br0 := args.get! 3 + let br1 := args.get! 4 + let proveBranchValid (br : Expr) : MetaM Expr := + if isIte then proveExprIsValid k_var kk_var br + else do + -- There is a lambda -- TODO: how do we remove exacly *one* lambda? + lambdaLetTelescope br fun xs br => do + let x := xs.get! 0 + let xs := xs.extract 1 xs.size + let br ← mkLambdaFVars xs br + let brValid ← proveExprIsValid k_var kk_var br + mkLambdaFVars #[x] brValid + let br0Valid ← proveBranchValid br0 + let br1Valid ← proveBranchValid br1 + let const := if isIte then ``FixI.is_valid_p_ite else ``FixI.is_valid_p_dite + let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] + trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" + pure eIsValid + -- The expression is a match (this case is for when the elaborator + -- introduces auxiliary definitions to hide the match behind syntactic + -- sugar) + else if let some me := ← matchMatcherApp? e then do + trace[Diverge.def.valid] + "matcherApp: + - params: {me.params} + - motive: {me.motive} + - discrs: {me.discrs} + - altNumParams: {me.altNumParams} + - alts: {me.alts} + - remaining: {me.remaining}" + -- matchMatcherApp has already done the work for us + if me.remaining.size ≠ 0 then + throwError "MatcherApp: non empty remaining array: {me.remaining}" + let me : MatchInfo := { + matcherName := me.matcherName + matcherLevels := me.matcherLevels + params := me.params + motive := me.motive + scruts := me.discrs + branchesNumParams := me.altNumParams + branches := me.alts + } + proveMatchIsValid k_var kk_var me + -- The expression is a raw match (this case is for when the expression + -- is a direct call to the primitive `casesOn` function, without + -- syntactic sugar) + else if ← isCasesExpr f then do + trace[Diverge.def.valid] "rawMatch: {e}" + -- The casesOn definition is always of the following shape: + -- input parameters (implicit parameters), then motive (implicit), + -- scrutinee (explicit), branches (explicit). + let matcherName := f.constName! + let matcherLevels := f.constLevels!.toArray + -- Find the first explicit parameter: this is the scrutinee + forallTelescope (← inferType f) fun xs _ => do + let rec findFirstExplicit (i : Nat) : MetaM Nat := do + if i ≥ xs.size then throwError "Unexpected: could not find an explicit parameter" + else + let x := xs.get! i + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + match localDecl.binderInfo with + | .default => pure i + | _ => findFirstExplicit (i + 1) + let scrutIdx ← findFirstExplicit 0 + -- Split the arguments + let params := args.extract 0 (scrutIdx - 1) + let motive := args.get! (scrutIdx - 1) + let scrut := args.get! scrutIdx + let branches := args.extract (scrutIdx + 1) args.size + -- Compute the number of parameters for the branches: for this we use + -- the type of the uninstantiated casesOn constant + let branchesNumParams : Array Nat ← do + let env ← getEnv + let decl := env.constants.find! matcherName + let ty := decl.type + forallTelescope ty fun xs _ => do + let xs := xs.extract (scrutIdx + 1) xs.size + xs.mapM fun x => do + let xty ← inferType x + forallTelescope xty fun ys _ => do + pure ys.size + let me : MatchInfo := { + matcherName, + matcherLevels, + params, + motive, + scruts := #[scrut], + branchesNumParams, + branches, + } + proveMatchIsValid k_var kk_var me + -- Monadic let-binding + else if f.isConstOf ``Bind.bind then do + trace[Diverge.def.valid] "bind:\n{args}" + let x := args.get! 4 + let y := args.get! 5 + -- Prove that the subexpressions are valid + let xValid ← proveExprIsValid k_var kk_var x + trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" + let yValid ← do + -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? + lambdaLetTelescope y fun xs y => do + let x := xs.get! 0 + let xs := xs.extract 1 xs.size + let y ← mkLambdaFVars xs y + trace[Diverge.def.valid] "bind: y: {y}" + let yValid ← proveExprIsValid k_var kk_var y + trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" + trace[Diverge.def.valid] "bind: yValid: x: {x}" + let yValid ← mkLambdaFVars #[x] yValid + trace[Diverge.def.valid] "bind: yValid (forall): {yValid}: {← inferType yValid}" + pure yValid + -- Put everything together + trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" + mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] + -- Recursive call + else if f.isFVarOf kk_var.fvarId! then do + trace[Diverge.def.valid] "rec: args: \n{args}" + if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" + let i_arg := args.get! 0 + let x_arg := args.get! 1 + let eIsValid ← mkAppM ``FixI.is_valid_p_rec #[k_var, i_arg, x_arg] + trace[Diverge.def.valid] "rec: result: \n{eIsValid}" + pure eIsValid + else do + -- Remaining case: normal application. + -- It shouldn't use the continuation + proveNoKExprIsValid k_var e + +partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do + trace[Diverge.def.valid] "proveMatchIsValid: {me}" + -- Prove the validity of the branch expressions + let branchesValid:Array Expr ← me.branches.mapIdxM fun idx br => do + -- Go inside the lambdas - note that we have to be careful: some of the + -- binders might come from the match, and some of the binders might come + -- from the fact that the expression in the match is a lambda expression: + -- we use the branchesNumParams field for this reason + lambdaLetTelescope br fun xs br => do + let numParams := me.branchesNumParams.get! idx + let xs_beg := xs.extract 0 numParams + let xs_end := xs.extract numParams xs.size + let br ← mkLambdaFVars xs_end br + -- Prove that the branch expression is valid + let brValid ← proveExprIsValid k_var kk_var br + -- Reconstruct the lambda expression + mkLambdaFVars xs_beg brValid + trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" + -- Put together: compute the motive. + -- It must be of the shape: + -- ``` + -- λ scrut => is_valid_p k (λ k => match scrut with ...) + -- ``` + let validMotive : Expr ← do + -- The motive is a function of the scrutinees (i.e., a lambda expression): + -- introduce binders for the scrutinees + let declInfos := me.scruts.mapIdx fun idx scrut => + let name : Name := (.num (.str .anonymous "scrut") idx) + let ty := λ (_ : Array Expr) => inferType scrut + (name, ty) + withLocalDeclsD declInfos fun scrutVars => do + -- Create a match expression but where the scrutinees have been replaced + -- by variables + let params : Array (Option Expr) := me.params.map some + let motive : Option Expr := some me.motive + let scruts : Array (Option Expr) := scrutVars.map some + let branches : Array (Option Expr) := me.branches.map some + let args := params ++ [motive] ++ scruts ++ branches + let matchE ← mkAppOptM me.matcherName args + -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args) + -- Wrap in the `is_valid_p` predicate + let matchE ← mkLambdaFVars #[kk_var] matchE + let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] + -- Abstract away the scrutinee variables + mkLambdaFVars scrutVars validMotive + trace[Diverge.def.valid] "valid motive: {validMotive}" + -- Put together + let valid ← do + let params : Array (Option Expr) := me.params.map (λ _ => none) + let motive := some validMotive + let scruts := me.scruts.map some + let branches := branchesValid.map some + let args := params ++ [motive] ++ scruts ++ branches + mkAppOptM me.matcherName args + trace[Diverge.def.valid] "proveMatchIsValid:\n{valid}:\n{← inferType valid}" + pure valid + +end + +-- Prove that a single body (in the mutually recursive group) is valid +partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : + MetaM Expr := do + trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}" + -- Lookup the definition (`bodyConst` is the definition of the body, we want + -- to retrieve the value itself to dive inside) + let name := bodyConst.constName! + let env ← getEnv + let body := (env.constants.find! name).value! + trace[Diverge.def.valid] "body: {body}" + lambdaLetTelescope body fun xs body => do + assert! xs.size = 2 + let kk_var := xs.get! 0 + let x_var := xs.get! 1 + -- State the type of the theorem to prove + let thmTy ← mkAppM ``FixI.is_valid_p + #[k_var, ← mkLambdaFVars #[kk_var] (← mkAppM' bodyConst #[kk_var, x_var])] + trace[Diverge.def.valid] "thmTy: {thmTy}" + -- Prove that the body is valid + let proof ← proveExprIsValid k_var kk_var body + let proof ← mkLambdaFVars #[k_var, x_var] proof + trace[Diverge.def.valid] "proveSingleBodyIsValid: proof:\n{proof}:\n{← inferType proof}" + -- The target type (we don't have to do this: this is simply a sanity check, + -- and this allows a nicer debugging output) + let thmTy ← do + let body ← mkAppM' bodyConst #[kk_var, x_var] + let body ← mkLambdaFVars #[kk_var] body + let ty ← mkAppM ``FixI.is_valid_p #[k_var, body] + mkForallFVars #[k_var, x_var] ty + trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" + -- Save the theorem + let name := preDef.declName ++ "sbody_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveSingleBodyIsValid: added thm: {name}" + -- Return the theorem + pure (Expr.const name (preDef.levelParams.map .param)) + +partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) + (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do + -- Create the big "and" expression, which groups the validity proof of the individual bodies + let rec mkValidConj (i : Nat) : MetaM Expr := do + if i = bodiesValid.size then + -- We reached the end + mkAppM ``FixI.Funs.is_valid_p_Nil #[k_var] + else do + -- We haven't reached the end: introduce a conjunction + let valid := bodiesValid.get! i + let valid ← mkAppM' valid #[k_var] + mkAppM ``And.intro #[valid, ← mkValidConj (i + 1)] + let andExpr ← mkValidConj 0 + -- Wrap in the `is_valid_p_is_valid_p` theorem, and abstract the continuation + let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] + mkLambdaFVars #[k_var] isValid + +-- Prove that the mut rec body is valid +-- TODO: maybe this function should introduce k_var itself +def proveMutRecIsValid + (grName : Name) (grLvlParams : List Name) + (inOutTys : Expr) (bodyFuns mutRecBodyConst : Expr) + (k_var : Expr) (preDefs : Array PreDefinition) + (bodies : Array Expr) : MetaM Expr := do + -- First prove that the individual bodies are valid + let bodiesValid ← + bodies.mapIdxM fun idx body => do + let preDef := preDefs.get! idx + proveSingleBodyIsValid k_var preDef body + -- Then prove that the mut rec body is valid + let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid + -- Save the theorem + let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] + let name := grName ++ "mut_rec_body_is_valid" + let decl := Declaration.thmDecl { + name + levelParams := grLvlParams + type := thmTy + value := isValid + all := [name] + } + addDecl decl + trace[Diverge.def.valid] "proveFunsBodyIsValid: added thm: {name}:\n{thmTy}" + -- Return the theorem + pure (Expr.const name (grLvlParams.map .param)) -- Generate the final definions by using the mutual body and the fixed point operator. -def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM Unit := do let grSize := preDefs.size let _ ← preDefs.mapIdxM fun idx preDef => do @@ -357,7 +696,7 @@ def mkDeclareFixDefs (mutBody : Expr) (preDefs : Array PreDefinition) : -- Group the inputs into a dependent tuple let input ← mkSigmas xs.toList -- Apply the fixed point - let fixedBody ← mkAppM ``FixI.fix #[mutBody, idx, input] + let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody -- Create the declaration let name := preDef.declName @@ -454,24 +793,26 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Introduce the continuation `k` let in_ty ← mkLambdaFVars #[i_var] in_ty let out_ty ← mkLambdaFVars #[i_var, input] out_ty - let k_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] -- - trace[Diverge.def] "k_var_ty: {k_var_ty}" - withLocalDeclD (.num (.str .anonymous "k") 2) k_var_ty fun k_var => do - trace[Diverge.def] "k_var: {k_var}" + let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] + trace[Diverge.def] "kk_var_ty: {kk_var_ty}" + withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do + trace[Diverge.def] "kk_var: {kk_var}" -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations - let bodies ← mkDeclareUnaryBodies grLvlParams k_var preDefs + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs -- Generate the mutually recursive body - let body ← mkDeclareMutualBody grName grLvlParams i_var k_var in_ty out_ty inOutTys.toList bodies - trace[Diverge.def] "mut rec body (after decl): {body}" + let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies + trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point - -- TODO + let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] + withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do + let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions - let defs ← mkDeclareFixDefs body preDefs + let defs ← mkDeclareFixDefs mutRecBody preDefs -- Prove the unfolding equations -- TODO @@ -496,13 +837,10 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC for preDefs in cliques do trace[Diverge.elab] "{preDefs.map (·.declName)}" try - trace[Diverge.elab] "calling divRecursion" withRef (preDefs[0]!.ref) do divRecursion preDefs - trace[Diverge.elab] "divRecursion succeeded" catch ex => - -- If it failed, we - trace[Diverge.elab] "divRecursion failed" + -- If it failed, we add the functions as partial functions hasErrors := true logException ex let s ← saveState @@ -600,7 +938,8 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := else return (← list_nth ls (i - 1)) #print list_nth.in_out_ty -#check list_nth.body +#check list_nth.sbody +#check list_nth.mut_rec_body #print list_nth mutual -- cgit v1.2.3 From 7ceab6a725e5bd17c05bfd381753e453b15afaf7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 16:46:59 +0200 Subject: Add a missing case in the validity proofs --- backends/lean/Base/Diverge/Elab.lean | 25 +++++++++++++++++++------ 1 file changed, 19 insertions(+), 6 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index cf40ea8f..063480a2 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -378,17 +378,22 @@ mutual partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveValid: {e}" match e with + | .const _ _ => throwError "Unimplemented" -- Shouldn't get there? | .bvar _ | .fvar _ - | .mvar _ - | .sort _ | .lit _ - | .const _ _ => throwError "Unimplemented" + | .mvar _ + | .sort _ => throwError "Unreachable" | .lam .. => throwError "Unimplemented" | .forallE .. => throwError "Unreachable" -- Shouldn't get there - | .letE .. => throwError "TODO" - -- lambdaLetTelescope e fun xs b => mapVisitBinders xs do - -- mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) + | .letE dName dTy dValue body _nonDep => do + -- Introduce a local declaration for the let-binding + withLetDecl dName dTy dValue fun decl => do + let isValid ← proveExprIsValid k_var kk_var body + -- Add the let-binding around (rem.: the let-binding should be + -- *inside* the `is_valid_p`, not outside, but because it reduces + -- in the end it doesn't matter) + mkLetFVars #[decl] isValid | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => -- The projection shouldn't use the continuation @@ -963,4 +968,12 @@ mutual if i > 20 then foo (i / 20) else .ret 42 end +-- Testing dependent branching and let-bindings +-- TODO: why the linter warning? +divergent def is_non_zero (i : Int) : Result Bool := + if _h:i = 0 then return false + else + let b := true + return b + end Diverge -- cgit v1.2.3 From 9214484c471ad931924865855687f9a2ffe255dd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 18:02:52 +0200 Subject: Automate the proofs of the unfolding theorems for Diverge --- backends/lean/Base/Diverge/Elab.lean | 107 ++++++++++++++++++++++++++++------- 1 file changed, 88 insertions(+), 19 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 063480a2..91c51a31 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -16,8 +16,9 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta set_option trace.Diverge.def true -set_option trace.Diverge.def.valid true +-- set_option trace.Diverge.def.valid true -- set_option trace.Diverge.def.sigmas true +set_option trace.Diverge.def.unfold true /- The following was copied from the `wfRecursion` function. -/ @@ -390,9 +391,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- Introduce a local declaration for the let-binding withLetDecl dName dTy dValue fun decl => do let isValid ← proveExprIsValid k_var kk_var body - -- Add the let-binding around (rem.: the let-binding should be - -- *inside* the `is_valid_p`, not outside, but because it reduces - -- in the end it doesn't matter) + -- Add the let-binding around. + -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, + -- but because it reduces in the end it doesn't matter. More precisely: + -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. mkLetFVars #[decl] isValid | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => @@ -692,9 +694,9 @@ def proveMutRecIsValid -- Generate the final definions by using the mutual body and the fixed point operator. def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : - TermElabM Unit := do + TermElabM (Array Name) := do let grSize := preDefs.size - let _ ← preDefs.mapIdxM fun idx preDef => do + let defs ← preDefs.mapIdxM fun idx preDef => do lambdaLetTelescope preDef.value fun xs _ => do -- Create the index let idx ← mkFinVal grSize idx.val @@ -715,7 +717,58 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : all := [name] } addDecl decl - pure () + pure name + pure defs + +-- Prove the equations that we will use as unfolding theorems +partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition) + (decls : Array Name) : MetaM Unit := do + let grSize := preDefs.size + let proveIdx (i : Nat) : MetaM Unit := do + let preDef := preDefs.get! i + let defName := decls.get! i + -- Retrieve the arguments + lambdaLetTelescope preDef.value fun xs body => do + trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" + trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" + -- The theorem statement + let thmTy ← do + -- The equation: the declaration gives the lhs, the pre-def gives the rhs + let lhs ← mkAppOptM defName (xs.map some) + let rhs := body + let eq ← mkAppM ``Eq #[lhs, rhs] + mkForallFVars xs eq + trace[Diverge.def.unfold] "proveUnfoldingThms: thm statement: {thmTy}" + -- The proof + -- Use the fixed-point equation + let proof ← mkAppM ``FixI.is_valid_fix_fixed_eq #[isValidThm] + -- Add the index + let idx ← mkFinVal grSize i + let proof ← mkAppM ``congr_fun #[proof, idx] + -- Add the input argument + let arg ← mkSigmas xs.toList + let proof ← mkAppM ``congr_fun #[proof, arg] + -- Abstract the arguments away + let proof ← mkLambdaFVars xs proof + trace[Diverge.def.unfold] "proveUnfoldingThms: proof: {proof}:\n{← inferType proof}" + -- Declare the theorem + let name := preDef.declName ++ "unfold" + let decl := Declaration.thmDecl { + name + levelParams := preDef.levelParams + type := thmTy + value := proof + all := [name] + } + addDecl decl + trace[Diverge.def.unfold] "proveUnfoldingThms: added thm: {name}:\n{thmTy}" + let rec prove (i : Nat) : MetaM Unit := do + if i = preDefs.size then pure () + else do + proveIdx i + prove (i + 1) + -- + prove 0 def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) @@ -817,12 +870,12 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions - let defs ← mkDeclareFixDefs mutRecBody preDefs + let decls ← mkDeclareFixDefs mutRecBody preDefs - -- Prove the unfolding equations - -- TODO + -- Prove the unfolding theorems + proveUnfoldingThms isValidThm preDefs decls - -- Process the definitions + -- Process the definitions - TODO addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -942,10 +995,23 @@ divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := if i = 0 then return x else return (← list_nth ls (i - 1)) -#print list_nth.in_out_ty -#check list_nth.sbody -#check list_nth.mut_rec_body -#print list_nth +example {a: Type} (ls : List a) : + ∀ (i : Int), + 0 ≤ i → i < ls.length → + ∃ x, list_nth ls i = .ret x := by + induction ls + . intro i hpos h; simp at h; linarith + . rename_i hd tl ih + intro i hpos h + rw [list_nth.unfold]; simp + split <;> simp [*] + . tauto + . -- TODO: we shouldn't have to do that + have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all + simp at h + have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + simp [ih] + tauto mutual divergent def is_even (i : Int) : Result Bool := @@ -955,10 +1021,8 @@ mutual if i = 0 then return false else return (← is_even (i - 1)) end -example (i : Int) : is_even i = .ret (i % 2 = 0) ∧ is_odd i = .ret (i % 2 ≠ 0) := by - induction i - unfold is_even - sorry +#print is_even.unfold +#print is_odd.unfold mutual divergent def foo (i : Int) : Result Nat := @@ -968,6 +1032,9 @@ mutual if i > 20 then foo (i / 20) else .ret 42 end +#print foo.unfold +#print bar.unfold + -- Testing dependent branching and let-bindings -- TODO: why the linter warning? divergent def is_non_zero (i : Int) : Result Bool := @@ -976,4 +1043,6 @@ divergent def is_non_zero (i : Int) : Result Bool := let b := true return b +#print is_non_zero.unfold + end Diverge -- cgit v1.2.3 From 75fae6384716f24fe137283d4a41836782b9aec7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 3 Jul 2023 19:26:27 +0200 Subject: Cleanup a bit Diverge/Elab.lean --- backends/lean/Base/Diverge/Elab.lean | 366 +++++++++++++++++++---------------- 1 file changed, 197 insertions(+), 169 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 91c51a31..cc580265 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -15,39 +15,16 @@ syntax (name := divergentDef) open Lean Elab Term Meta Primitives Lean.Meta -set_option trace.Diverge.def true --- set_option trace.Diverge.def.valid true --- set_option trace.Diverge.def.sigmas true -set_option trace.Diverge.def.unfold true - /- The following was copied from the `wfRecursion` function. -/ open WF in -def mkList (xl : List Expr) (ty : Expr) : MetaM Expr := - match xl with - | [] => - mkAppOptM ``List.nil #[some ty] - | x :: tl => do - let tl ← mkList tl ty - mkAppOptM ``List.cons #[some ty, some x, some tl] - def mkProd (x y : Expr) : MetaM Expr := mkAppM ``Prod.mk #[x, y] def mkInOutTy (x y : Expr) : MetaM Expr := mkAppM ``FixI.mk_in_out_ty #[x, y] --- TODO: is there already such a utility somewhere? --- TODO: change to mkSigmas -def mkProds (tys : List Expr) : MetaM Expr := - match tys with - | [] => do pure (Expr.const ``PUnit.unit []) - | [ty] => do pure ty - | ty :: tys => do - let pty ← mkProds tys - mkAppM ``Prod.mk #[ty, pty] - -- Return the `a` in `Return a` def get_result_ty (ty : Expr) : MetaM Expr := ty.withApp fun f args => do @@ -56,26 +33,31 @@ def get_result_ty (ty : Expr) : MetaM Expr := else pure (args.get! 0) --- Group a list of expressions into a dependent tuple -def mkSigmas (xl : List Expr) : MetaM Expr := +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + -/ +def mkSigmasVal (xl : List Expr) : MetaM Expr := match xl with | [] => do - trace[Diverge.def.sigmas] "mkSigmas: []" + trace[Diverge.def.sigmas] "mkSigmasVal: []" pure (Expr.const ``PUnit.unit []) | [x] => do - trace[Diverge.def.sigmas] "mkSigmas: [{x}]" + trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" pure x | fst :: xl => do - trace[Diverge.def.sigmas] "mkSigmas: [{fst}::{xl}]" + trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst - let snd ← mkSigmas xl + let snd ← mkSigmasVal xl let snd_ty ← inferType snd let beta ← mkLambdaFVars #[fst] snd_ty - trace[Diverge.def.sigmas] "mkSigmas:\n{alpha}\n{beta}\n{fst}\n{snd}" + trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] -/- Generate the input type of a function body, which is a sigma type (i.e., a - dependent tuple) which groups all its inputs. +/- Generate a Sigma type from a list of expressions. Example: - xl = [(a:Type), (ls:List a), (i:Int)] @@ -84,7 +66,7 @@ def mkSigmas (xl : List Expr) : MetaM Expr := `(a:Type) × (ls:List a) × (i:Int)` -/ -def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := +def mkSigmasType (xl : List Expr) : MetaM Expr := match xl with | [] => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: []" @@ -96,15 +78,16 @@ def mkSigmasTypesOfTypes (xl : List Expr) : MetaM Expr := | x :: xl => do trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]" let alpha ← Lean.Meta.inferType x - let sty ← mkSigmasTypesOfTypes xl + let sty ← mkSigmasType xl trace[Diverge.def.sigmas] "mkSigmasOfTypes: [{x}::{xl}]: alpha={alpha}, sty={sty}" let beta ← mkLambdaFVars #[x] sty trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] -def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index +def mkAnonymous (s : String) (i : Nat) : Name := + .num (.str .anonymous s) i -/- Given a list of values `[x0:ty0, ..., xn:ty1]` where every `xi` might use the previous +/- Given a list of values `[x0:ty0, ..., xn:ty1]`, where every `xi` might use the previous `xj` (j < i) and a value `out` which uses `x0`, ..., `xn`, generate the following expression: ``` @@ -112,20 +95,22 @@ def mk_indexed_name (index : Nat) : Name := .num (.str .anonymous "_uniq") index match x with | (x0, ..., xn) => out ``` - + The `index` parameter is used for naming purposes: we use it to numerotate the bound variables that we introduce. + We use this function to currify functions (the function bodies given to the + fixed-point operator must be unary functions). + Example: ======== - More precisely: - xl = `[a:Type, ls:List a, i:Int]` - out = `a` - index = 0 - generates: + generates (getting rid of most of the syntactic sugar): ``` - match scrut0 with + λ scrut0 => match scrut0 with | Sigma.mk x scrut1 => match scrut1 with | Sigma.mk ls i => @@ -138,21 +123,30 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met -- This would be unexpected throwError "mkSigmasMatch: empyt list of input parameters" | [x] => do - -- In the explanations above: inner match case + -- In the example given for the explanations: this is the inner match case trace[Diverge.def.sigmas] "mkSigmasMatch: [{x}]" mkLambdaFVars #[x] out | fst :: xl => do - -- In the explanations above: outer match case + -- In the example given for the explanations: this is the outer match case -- Remark: for the naming purposes, we use the same convention as for the - -- fields and parameters in `Sigma.casesOn and `Sigma.mk + -- fields and parameters in `Sigma.casesOn` and `Sigma.mk` (looking at + -- those definitions might help) + -- + -- We want to build the match expression: + -- ``` + -- λ scrut => + -- match scrut with + -- | Sigma.mk x ... -- the hole is given by a recursive call on the tail + -- ``` trace[Diverge.def.sigmas] "mkSigmasMatch: [{fst}::{xl}]" let alpha ← Lean.Meta.inferType fst - let snd_ty ← mkSigmasTypesOfTypes xl + let snd_ty ← mkSigmasType xl let beta ← mkLambdaFVars #[fst] snd_ty let snd ← mkSigmasMatch xl out (index + 1) - let scrut_ty ← mkSigmasTypesOfTypes (fst :: xl) - withLocalDeclD (mk_indexed_name index) scrut_ty fun scrut => do let mk ← mkLambdaFVars #[fst] snd + -- Introduce the "scrut" variable + let scrut_ty ← mkSigmasType (fst :: xl) + withLocalDeclD (mkAnonymous "scrut" index) scrut_ty fun scrut => do trace[Diverge.def.sigmas] "mkSigmasMatch: scrut: ({scrut}) : ({← inferType scrut})" -- TODO: make the computation of the motive more efficient let motive ← do @@ -166,38 +160,32 @@ partial def mkSigmasMatch (xl : List Expr) (out : Expr) (index : Nat := 0) : Met -- TODO: make this more efficient (we could change the output type of -- mkSigmasMatch mkSigmasMatch (fst :: xl) out_ty + -- The final expression: putting everything together trace[Diverge.def.sigmas] "mkSigmasMatch:\n ({alpha})\n ({beta})\n ({motive})\n ({scrut})\n ({mk})" let sm ← mkAppOptM ``Sigma.casesOn #[some alpha, some beta, some motive, some scrut, some mk] + -- Abstracting the "scrut" variable let sm ← mkLambdaFVars #[scrut] sm trace[Diverge.def.sigmas] "mkSigmasMatch: sm: {sm}" pure sm /- Small tests for list_nth: give a model of what `mkSigmasMatch` should generate -/ -private def list_nth_out_ty2 (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := +private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) := @Sigma.casesOn (List a) (fun (_ls : List a) => Int) (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) scrut1 (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) -private def list_nth_out_ty1 (scrut0 : @Sigma (Type) (fun (a:Type) => +private def list_nth_out_ty_outer (scrut0 : @Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) := @Sigma.casesOn (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int)) (fun (_scrut0:@Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) => Type) scrut0 (fun (a : Type) (scrut1: @Sigma (List a) (fun (_ls : List a) => Int)) => - list_nth_out_ty2 a scrut1) + list_nth_out_ty_inner a scrut1) /- -/ --- TODO: move --- TODO: we can use Array.mapIdx -@[specialize] def mapiAux (i : Nat) (f : Nat → α → β) : List α → List β - | [] => [] - | a::as => f i a :: mapiAux (i+1) f as - -@[specialize] def mapi (f : Nat → α → β) : List α → List β := mapiAux 0 f - -- Return the expression: `Fin n` -- TODO: use more def mkFin (n : Nat) : Expr := @@ -212,15 +200,6 @@ def mkFinVal (n i : Nat) : MetaM Expr := do let ofNat ← mkAppOptM ``Fin.instOfNatFinHAddNatInstHAddInstAddNatOfNat #[n_lit, i_lit] mkAppOptM ``OfNat.ofNat #[none, none, ofNat] --- TODO: remove? -def mkFinValOld (n i : Nat) : MetaM Expr := do - let finTy := mkFin n - let ofNat ← mkAppM ``OfNat #[finTy, .lit (.natVal i)] - match ← trySynthInstance ofNat with - | LOption.some x => - mkAppOptM ``OfNat.ofNat #[none, none, x] - | _ => throwError "mkFinVal: could not synthesize an instance of {ofNat} " - /- Generate and declare as individual definitions the bodies for the individual funcions: - replace the recursive calls with calls to the continutation `k` - make those bodies take one single dependent tuple as input @@ -234,11 +213,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) let grSize := preDefs.size -- Compute the map from name to index - the continuation has an indexed type: - -- we use the index (a finite number of type `Fin`) to control the function - -- we call at the recursive call + -- we use the index (a finite number of type `Fin`) to control which function + -- we call at the recursive call site. let nameToId : HashMap Name Nat := - let namesIds := mapi (fun i d => (d.declName, i)) preDefs.toList - HashMap.ofList namesIds + let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val)) + HashMap.ofList namesIds.toList trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" @@ -260,7 +239,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Compute the index let i ← mkFinVal grSize id -- Put the arguments in one big dependent tuple - let args ← mkSigmas args.toList + let args ← mkSigmasVal args.toList mkAppM' kk_var #[i, args] else -- Not a recursive call: do nothing @@ -277,13 +256,14 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Replace the recursive calls let body ← mapVisit visit_e preDef.value - -- Change the type + -- Currify the function by grouping the arguments into a dependent tuple + -- (over which we match to retrieve the individual arguments). lambdaLetTelescope body fun args body => do let body ← mkSigmasMatch args.toList body 0 -- Add the declaration let value ← mkLambdaFVars #[kk_var] body - let name := preDef.declName.append "sbody" + let name := preDef.declName.append "body" let levelParams := grLvlParams let decl := Declaration.defnDecl { name := name @@ -304,7 +284,7 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) -- Generate a unique function body from the bodies of the mutually recursive group, -- and add it as a declaration in the context. --- We return the list of bodies (of type `Funs ...`) and the mutually recursive body. +-- We return the list of bodies (of type `FixI.Funs ...`) and the mutually recursive body. def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) (kk_var i_var : Expr) (in_ty out_ty : Expr) (inOutTys : List (Expr × Expr)) @@ -322,7 +302,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) mkAppOptM ``FixI.Funs.Nil #[finTypeExpr, in_ty, out_ty] | (ity, oty) :: inOutTys, b :: bl => do -- Retrieving ity and oty - this is not very clean - let inOutTysExpr ← mkList (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) inOutTyType + let inOutTysExpr ← mkListLit inOutTyType (← inOutTys.mapM (λ (x, y) => mkInOutTy x y)) let fl ← mkFuns inOutTys bl mkAppOptM ``FixI.Funs.Cons #[finTypeExpr, in_ty, out_ty, ity, oty, inOutTysExpr, b, fl] | _, _ => throwError "mkDeclareMutRecBody: `tys` and `bodies` don't have the same length" @@ -345,7 +325,7 @@ def mkDeclareMutRecBody (grName : Name) (grLvlParams : List Name) all := [name] } addDecl decl - -- Return the constant + -- Return the bodies and the constant pure (bodyFuns, Lean.mkConst name (levelParams.map .param)) def isCasesExpr (e : Expr) : MetaM Bool := do @@ -367,7 +347,8 @@ instance : ToMessageData MatchInfo where -- This is not a very clean formatting, but we don't need more toMessageData := fun me => m!"\n- matcherName: {me.matcherName}\n- params: {me.params}\n- motive: {me.motive}\n- scruts: {me.scruts}\n- branchesNumParams: {me.branchesNumParams}\n- branches: {me.branches}" --- An expression which doesn't use the continuation kk is valid +-- Small helper: prove that an expression which doesn't use the continuation `kk` +-- is valid, and return the proof. def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveNoKExprIsValid: {e}" let eIsValid ← mkAppM ``FixI.is_valid_p_same #[k_var, e] @@ -376,6 +357,14 @@ def proveNoKExprIsValid (k_var : Expr) (e : Expr) : MetaM Expr := do mutual +/- Prove that an expression is valid, and return the proof. + + More precisely, if `e` is an expression which potentially uses the continution + `kk`, return an expression of type: + ``` + is_valid_p k (λ kk => e) + ``` + -/ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveValid: {e}" match e with @@ -403,7 +392,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .app .. => e.withApp fun f args => do -- There are several cases: first, check if this is a match/if - -- The expression is a (dependent) if then else + -- Check if the expression is a (dependent) if then else. + -- We treat the if then else expressions differently from the other matches, + -- and have dedicated theorems for them. let isIte := e.isIte if isIte || e.isDIte then do e.withApp fun f args => do @@ -431,9 +422,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let eIsValid ← mkAppOptM const #[none, none, none, none, some k_var, some cond, some dec, none, none, some br0Valid, some br1Valid] trace[Diverge.def.valid] "ite/dite: result:\n{eIsValid}:\n{← inferType eIsValid}" pure eIsValid - -- The expression is a match (this case is for when the elaborator + -- Check if the expression is a match (this case is for when the elaborator -- introduces auxiliary definitions to hide the match behind syntactic - -- sugar) + -- sugar): else if let some me := ← matchMatcherApp? e then do trace[Diverge.def.valid] "matcherApp: @@ -443,7 +434,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do - altNumParams: {me.altNumParams} - alts: {me.alts} - remaining: {me.remaining}" - -- matchMatcherApp has already done the work for us + -- matchMatcherApp does all the work for us: we simply need to gather + -- the information and call the auxiliary helper `proveMatchIsValid` if me.remaining.size ≠ 0 then throwError "MatcherApp: non empty remaining array: {me.remaining}" let me : MatchInfo := { @@ -456,14 +448,21 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do branches := me.alts } proveMatchIsValid k_var kk_var me - -- The expression is a raw match (this case is for when the expression - -- is a direct call to the primitive `casesOn` function, without - -- syntactic sugar) + -- Check if the expression is a raw match (this case is for when the expression + -- is a direct call to the primitive `casesOn` function, without syntactic sugar). + -- We have to check this case because functions like `mkSigmasMatch`, which we + -- use to currify function bodies, introduce such raw matches. else if ← isCasesExpr f then do trace[Diverge.def.valid] "rawMatch: {e}" + -- Deconstruct the match, and call the auxiliary helper `proveMatchIsValid`. + -- -- The casesOn definition is always of the following shape: - -- input parameters (implicit parameters), then motive (implicit), - -- scrutinee (explicit), branches (explicit). + -- - input parameters (implicit parameters) + -- - motive (implicit), -- the motive gives the return type of the match + -- - scrutinee (explicit) + -- - branches (explicit). + -- In particular, we notice that the scrutinee is the first *explicit* + -- parameter - this is how we spot it. let matcherName := f.constName! let matcherLevels := f.constLevels!.toArray -- Find the first explicit parameter: this is the scrutinee @@ -484,7 +483,9 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let scrut := args.get! scrutIdx let branches := args.extract (scrutIdx + 1) args.size -- Compute the number of parameters for the branches: for this we use - -- the type of the uninstantiated casesOn constant + -- the type of the uninstantiated casesOn constant (we can't just + -- destruct the lambdas in the branch expressions because the result + -- of a match might be a lambda expression). let branchesNumParams : Array Nat ← do let env ← getEnv let decl := env.constants.find! matcherName @@ -505,9 +506,11 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do branches, } proveMatchIsValid k_var kk_var me - -- Monadic let-binding + -- Check if this is a monadic let-binding else if f.isConstOf ``Bind.bind then do trace[Diverge.def.valid] "bind:\n{args}" + -- We simply need to prove that the subexpressions are valid, and call + -- the appropriate lemma. let x := args.get! 4 let y := args.get! 5 -- Prove that the subexpressions are valid @@ -529,7 +532,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do -- Put everything together trace[Diverge.def.valid] "bind:\n- xValid: {xValid}: {← inferType xValid}\n- yValid: {yValid}: {← inferType yValid}" mkAppM ``FixI.is_valid_p_bind #[xValid, yValid] - -- Recursive call + -- Check if this is a recursive call, i.e., a call to the continuation `kk` else if f.isFVarOf kk_var.fvarId! then do trace[Diverge.def.valid] "rec: args: \n{args}" if args.size ≠ 2 then throwError "Recursive call with invalid number of parameters: {args}" @@ -540,9 +543,10 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do pure eIsValid else do -- Remaining case: normal application. - -- It shouldn't use the continuation + -- It shouldn't use the continuation. proveNoKExprIsValid k_var e +-- Prove that a match expression is valid. partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Expr := do trace[Diverge.def.valid] "proveMatchIsValid: {me}" -- Prove the validity of the branch expressions @@ -561,16 +565,18 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- Reconstruct the lambda expression mkLambdaFVars xs_beg brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" - -- Put together: compute the motive. - -- It must be of the shape: + -- Compute the motive, which has the following shape: -- ``` -- λ scrut => is_valid_p k (λ k => match scrut with ...) + -- ^^^^^^^^^^^^^^^^^^^^ + -- this is the original match expression, with the + -- the difference that the scrutinee(s) is a variable -- ``` let validMotive : Expr ← do -- The motive is a function of the scrutinees (i.e., a lambda expression): -- introduce binders for the scrutinees let declInfos := me.scruts.mapIdx fun idx scrut => - let name : Name := (.num (.str .anonymous "scrut") idx) + let name : Name := mkAnonymous "scrut" idx let ty := λ (_ : Array Expr) => inferType scrut (name, ty) withLocalDeclsD declInfos fun scrutVars => do @@ -582,7 +588,6 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp let branches : Array (Option Expr) := me.branches.map some let args := params ++ [motive] ++ scruts ++ branches let matchE ← mkAppOptM me.matcherName args - -- let matchE ← mkLambdaFVars scrutVars (← mkAppOptM me.matcherName args) -- Wrap in the `is_valid_p` predicate let matchE ← mkLambdaFVars #[kk_var] matchE let validMotive ← mkAppM ``FixI.is_valid_p #[k_var, matchE] @@ -591,6 +596,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp trace[Diverge.def.valid] "valid motive: {validMotive}" -- Put together let valid ← do + -- We let Lean infer the parameters let params : Array (Option Expr) := me.params.map (λ _ => none) let motive := some validMotive let scruts := me.scruts.map some @@ -602,12 +608,16 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp end --- Prove that a single body (in the mutually recursive group) is valid -partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : +-- Prove that a single body (in the mutually recursive group) is valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `is_even.body` and `is_odd.body` are valid. +partial def proveSingleBodyIsValid + (k_var : Expr) (preDef : PreDefinition) (bodyConst : Expr) : MetaM Expr := do trace[Diverge.def.valid] "proveSingleBodyIsValid: bodyConst: {bodyConst}" - -- Lookup the definition (`bodyConst` is the definition of the body, we want - -- to retrieve the value itself to dive inside) + -- Lookup the definition (`bodyConst` is a const, we want to retrieve its + -- definition to dive inside) let name := bodyConst.constName! let env ← getEnv let body := (env.constants.find! name).value! @@ -633,7 +643,7 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body mkForallFVars #[k_var, x_var] ty trace[Diverge.def.valid] "proveSingleBodyIsValid: thmTy\n{thmTy}:\n{← inferType thmTy}" -- Save the theorem - let name := preDef.declName ++ "sbody_is_valid" + let name := preDef.declName ++ "body_is_valid" let decl := Declaration.thmDecl { name levelParams := preDef.levelParams @@ -646,6 +656,11 @@ partial def proveSingleBodyIsValid (k_var : Expr) (preDef : PreDefinition) (body -- Return the theorem pure (Expr.const name (preDef.levelParams.map .param)) +-- Prove that the list of bodies are valid. +-- +-- For instance, if we define the mutually recursive group [`is_even`, `is_odd`], +-- we prove that `Funs.Cons is_even.body (Funs.Cons is_odd.body Funs.Nil)` is +-- valid. partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) (k_var : Expr) (bodiesValid : Array Expr) : MetaM Expr := do -- Create the big "and" expression, which groups the validity proof of the individual bodies @@ -663,7 +678,13 @@ partial def proveFunsBodyIsValid (inOutTys: Expr) (bodyFuns : Expr) let isValid ← mkAppM ``FixI.Funs.is_valid_p_is_valid_p #[inOutTys, k_var, bodyFuns, andExpr] mkLambdaFVars #[k_var] isValid --- Prove that the mut rec body is valid +-- Prove that the mut rec body (i.e., the unary body which groups the bodies +-- of all the functions in the mutually recursive group and on which we will +-- apply the fixed-point operator) is valid. +-- +-- We save the proof in the theorem "[GROUP_NAME]."mut_rec_body_is_valid", +-- which we return. +-- -- TODO: maybe this function should introduce k_var itself def proveMutRecIsValid (grName : Name) (grLvlParams : List Name) @@ -693,6 +714,12 @@ def proveMutRecIsValid pure (Expr.const name (grLvlParams.map .param)) -- Generate the final definions by using the mutual body and the fixed point operator. +-- +-- For instance: +-- ``` +-- def is_even (i : Int) : Result Bool := mut_rec_body 0 i +-- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i +-- ``` def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size @@ -701,7 +728,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple - let input ← mkSigmas xs.toList + let input ← mkSigmasVal xs.toList -- Apply the fixed point let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody @@ -746,7 +773,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input argument - let arg ← mkSigmas xs.toList + let arg ← mkSigmasVal xs.toList let proof ← mkAppM ``congr_fun #[proof, arg] -- Abstract the arguments away let proof ← mkLambdaFVars xs proof @@ -774,11 +801,6 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) trace[Diverge.def] ("divRecursion: defs: " ++ msg) - -- CHANGE HERE This function should add definitions with these names/types/values ^^ - -- Temporarily add the predefinitions as axioms - -- for preDef in preDefs do - -- addAsAxiom preDef - -- TODO: what is this? for preDef in preDefs do applyAttributesOf #[preDef] AttributeApplicationTime.afterCompilation @@ -803,7 +825,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do if preDef.levelParams ≠ grLvlParams then throwError "Non-uniform polymorphism in the universes" forallTelescope preDef.type (fun in_tys out_ty => do - let in_ty ← liftM (mkSigmasTypesOfTypes in_tys.toList) + let in_ty ← liftM (mkSigmasType in_tys.toList) -- Retrieve the type in the "Result" let out_ty ← get_result_ty out_ty let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) @@ -813,14 +835,14 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do trace[Diverge.def] "inOutTys: {inOutTys}" -- Turn the list of input/output type pairs into an expresion let inOutTysExpr ← inOutTys.mapM (λ (x, y) => mkInOutTy x y) - let inOutTysExpr ← mkList inOutTysExpr.toList (← inferType (inOutTysExpr.get! 0)) + let inOutTysExpr ← mkListLit (← inferType (inOutTysExpr.get! 0)) inOutTysExpr.toList -- From the list of pairs of input/output types, actually compute the -- type of the continuation `k`. -- We first introduce the index `i : Fin n` where `n` is the number of -- functions in the group. let i_var_ty := mkFin preDefs.size - withLocalDeclD (.num (.str .anonymous "i") 0) i_var_ty fun i_var => do + withLocalDeclD (mkAnonymous "i" 0) i_var_ty fun i_var => do let in_out_ty ← mkAppM ``List.get #[inOutTysExpr, i_var] trace[Diverge.def] "in_out_ty := {in_out_ty} : {← inferType in_out_ty}" -- Add an auxiliary definition for `in_out_ty` @@ -844,7 +866,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do trace[Diverge.def] "in_out_ty (after decl) := {in_out_ty} : {← inferType in_out_ty}" let in_ty ← mkAppM ``Sigma.fst #[in_out_ty] trace[Diverge.def] "in_ty: {in_ty}" - withLocalDeclD (.num (.str .anonymous "x") 1) in_ty fun input => do + withLocalDeclD (mkAnonymous "x" 1) in_ty fun input => do let out_ty ← mkAppM' (← mkAppM ``Sigma.snd #[in_out_ty]) #[input] trace[Diverge.def] "out_ty: {out_ty}" @@ -853,7 +875,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let out_ty ← mkLambdaFVars #[i_var, input] out_ty let kk_var_ty ← mkAppM ``FixI.kk_ty #[i_var_ty, in_ty, out_ty] trace[Diverge.def] "kk_var_ty: {kk_var_ty}" - withLocalDeclD (.num (.str .anonymous "kk") 2) kk_var_ty fun kk_var => do + withLocalDeclD (mkAnonymous "kk" 2) kk_var_ty fun kk_var => do trace[Diverge.def] "kk_var: {kk_var}" -- Replace the recursive calls in all the function bodies by calls to the @@ -866,7 +888,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Prove that the mut rec body satisfies the validity criteria required by -- our fixed-point let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] - withLocalDeclD (.num (.str .anonymous "k") 3) k_var_ty fun k_var => do + withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions @@ -915,7 +937,7 @@ def addPreDefinitions (preDefs : Array PreDefinition) : TermElabM Unit := withLC else return () catch _ => s.restore --- The following two functions are copy&pasted from Lean.Elab.MutualDef +-- The following two functions are copy-pasted from Lean.Elab.MutualDef open private elabHeaders levelMVarToParamHeaders getAllUserLevelNames withFunLocalDecls elabFunValues instantiateMVarsAtHeader instantiateMVarsAtLetRecToLift checkLetRecsToLiftTypes withUsed from Lean.Elab.MutualDef @@ -988,61 +1010,67 @@ elab_rules : command else Command.elabCommand <| ← `(namespace $(mkIdentFrom id ns) $cmd end $(mkIdentFrom id ns)) -divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := - match ls with - | [] => .fail .panic - | x :: ls => - if i = 0 then return x - else return (← list_nth ls (i - 1)) - -example {a: Type} (ls : List a) : - ∀ (i : Int), - 0 ≤ i → i < ls.length → - ∃ x, list_nth ls i = .ret x := by - induction ls - . intro i hpos h; simp at h; linarith - . rename_i hd tl ih - intro i hpos h - rw [list_nth.unfold]; simp - split <;> simp [*] - . tauto - . -- TODO: we shouldn't have to do that - have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all - simp at h - have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) - simp [ih] - tauto - -mutual - divergent def is_even (i : Int) : Result Bool := - if i = 0 then return true else return (← is_odd (i - 1)) - - divergent def is_odd (i : Int) : Result Bool := - if i = 0 then return false else return (← is_even (i - 1)) -end - -#print is_even.unfold -#print is_odd.unfold - -mutual - divergent def foo (i : Int) : Result Nat := - if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 - - divergent def bar (i : Int) : Result Nat := - if i > 20 then foo (i / 20) else .ret 42 -end - -#print foo.unfold -#print bar.unfold - --- Testing dependent branching and let-bindings --- TODO: why the linter warning? -divergent def is_non_zero (i : Int) : Result Bool := - if _h:i = 0 then return false - else - let b := true - return b +namespace Tests + /- Some examples of partial functions -/ + + divergent def list_nth {a: Type} (ls : List a) (i : Int) : Result a := + match ls with + | [] => .fail .panic + | x :: ls => + if i = 0 then return x + else return (← list_nth ls (i - 1)) + + #check list_nth.unfold + + example {a: Type} (ls : List a) : + ∀ (i : Int), + 0 ≤ i → i < ls.length → + ∃ x, list_nth ls i = .ret x := by + induction ls + . intro i hpos h; simp at h; linarith + . rename_i hd tl ih + intro i hpos h + rw [list_nth.unfold]; simp + split <;> simp [*] + . tauto + . -- TODO: we shouldn't have to do that + have hneq : 0 < i := by cases i <;> rename_i a _ <;> simp_all; cases a <;> simp_all + simp at h + have ⟨ x, ih ⟩ := ih (i - 1) (by linarith) (by linarith) + simp [ih] + tauto + + mutual + divergent def is_even (i : Int) : Result Bool := + if i = 0 then return true else return (← is_odd (i - 1)) + + divergent def is_odd (i : Int) : Result Bool := + if i = 0 then return false else return (← is_even (i - 1)) + end + + #check is_even.unfold + #check is_odd.unfold + + mutual + divergent def foo (i : Int) : Result Nat := + if i > 10 then return (← foo (i / 10)) + (← bar i) else bar 10 + + divergent def bar (i : Int) : Result Nat := + if i > 20 then foo (i / 20) else .ret 42 + end + + #check foo.unfold + #check bar.unfold + + -- Testing dependent branching and let-bindings + -- TODO: why the linter warning? + divergent def is_non_zero (i : Int) : Result Bool := + if _h:i = 0 then return false + else + let b := true + return b -#print is_non_zero.unfold + #check is_non_zero.unfold +end Tests end Diverge -- cgit v1.2.3 From 4fd17e4bb91eb46d4704643dfbfbbf0874837b07 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 12:49:37 +0200 Subject: Make Diverge use Primitives --- backends/lean/Base/Diverge/Elab.lean | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index cc580265..41209021 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -174,7 +174,7 @@ private def list_nth_out_ty_inner (a :Type) (scrut1: @Sigma (List a) (fun (_ls : (fun (_ls : List a) => Int) (fun (_scrut1:@Sigma (List a) (fun (_ls : List a) => Int)) => Type) scrut1 - (fun (_ls : List a) (_i : Int) => Diverge.Primitives.Result a) + (fun (_ls : List a) (_i : Int) => Primitives.Result a) private def list_nth_out_ty_outer (scrut0 : @Sigma (Type) (fun (a:Type) => @Sigma (List a) (fun (_ls : List a) => Int))) := -- cgit v1.2.3 From bd873499f9a8d517cc948c6336a5c6ce856d846d Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 17:30:35 +0200 Subject: Fix some issues with the extraction to Lean --- backends/lean/Base/Diverge/Elab.lean | 63 +++++++++++++++++++++++++++--------- 1 file changed, 47 insertions(+), 16 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 41209021..4b08fe44 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -255,10 +255,11 @@ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) preDefs.mapM fun preDef => do -- Replace the recursive calls let body ← mapVisit visit_e preDef.value + trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" -- Currify the function by grouping the arguments into a dependent tuple -- (over which we match to retrieve the individual arguments). - lambdaLetTelescope body fun args body => do + lambdaTelescope body fun args body => do let body ← mkSigmasMatch args.toList body 0 -- Add the declaration @@ -376,15 +377,18 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do | .sort _ => throwError "Unreachable" | .lam .. => throwError "Unimplemented" | .forallE .. => throwError "Unreachable" -- Shouldn't get there - | .letE dName dTy dValue body _nonDep => do - -- Introduce a local declaration for the let-binding - withLetDecl dName dTy dValue fun decl => do + | .letE .. => do + -- Telescope all the let-bindings (remark: this also telescopes the lambdas) + lambdaLetTelescope e fun xs body => do + -- Note that we don't visit the bound values: there shouldn't be + -- recursive calls, lambda expressions, etc. inside + -- Prove that the body is valid let isValid ← proveExprIsValid k_var kk_var body - -- Add the let-binding around. + -- Add the let-bindings around. -- Rem.: the let-binding should be *inside* the `is_valid_p`, not outside, -- but because it reduces in the end it doesn't matter. More precisely: -- `P (let x := v in y)` and `let x := v in P y` reduce to the same expression. - mkLetFVars #[decl] isValid + mkLambdaFVars xs isValid (usedLetOnly := false) | .mdata _ b => proveExprIsValid k_var kk_var b | .proj _ _ _ => -- The projection shouldn't use the continuation @@ -410,7 +414,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do if isIte then proveExprIsValid k_var kk_var br else do -- There is a lambda -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let br ← mkLambdaFVars xs br @@ -518,7 +522,7 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" let yValid ← do -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? - lambdaLetTelescope y fun xs y => do + lambdaTelescope y fun xs y => do let x := xs.get! 0 let xs := xs.extract 1 xs.size let y ← mkLambdaFVars xs y @@ -555,7 +559,7 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- binders might come from the match, and some of the binders might come -- from the fact that the expression in the match is a lambda expression: -- we use the branchesNumParams field for this reason - lambdaLetTelescope br fun xs br => do + lambdaTelescope br fun xs br => do let numParams := me.branchesNumParams.get! idx let xs_beg := xs.extract 0 numParams let xs_end := xs.extract numParams xs.size @@ -622,7 +626,7 @@ partial def proveSingleBodyIsValid let env ← getEnv let body := (env.constants.find! name).value! trace[Diverge.def.valid] "body: {body}" - lambdaLetTelescope body fun xs body => do + lambdaTelescope body fun xs body => do assert! xs.size = 2 let kk_var := xs.get! 0 let x_var := xs.get! 1 @@ -695,8 +699,10 @@ def proveMutRecIsValid let bodiesValid ← bodies.mapIdxM fun idx body => do let preDef := preDefs.get! idx + trace[Diverge.def.valid] "## Proving that the body {body} is valid" proveSingleBodyIsValid k_var preDef body -- Then prove that the mut rec body is valid + trace[Diverge.def.valid] "## Proving that the 'Funs' body is valid" let isValid ← proveFunsBodyIsValid inOutTys bodyFuns k_var bodiesValid -- Save the theorem let thmTy ← mkAppM ``FixI.is_valid #[mutRecBodyConst] @@ -724,7 +730,7 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do - lambdaLetTelescope preDef.value fun xs _ => do + lambdaTelescope preDef.value fun xs _ => do -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple @@ -755,7 +761,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let preDef := preDefs.get! i let defName := decls.get! i -- Retrieve the arguments - lambdaLetTelescope preDef.value fun xs body => do + lambdaTelescope preDef.value fun xs body => do trace[Diverge.def.unfold] "proveUnfoldingThms: xs: {xs}" trace[Diverge.def.unfold] "proveUnfoldingThms: body: {body}" -- The theorem statement @@ -799,7 +805,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do let msg := toMessageData <| preDefs.map fun pd => (pd.declName, pd.levelParams, pd.type, pd.value) - trace[Diverge.def] ("divRecursion: defs: " ++ msg) + trace[Diverge.def] ("divRecursion: defs:\n" ++ msg) -- TODO: what is this? for preDef in preDefs do @@ -880,8 +886,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations + trace[Diverge.def] "# Generating the unary bodies" let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs + trace[Diverge.def] "Unary bodies (after decl): {bodies}" -- Generate the mutually recursive body + trace[Diverge.def] "# Generating the mut rec body" let (bodyFuns, mutRecBody) ← mkDeclareMutRecBody grName grLvlParams kk_var i_var in_ty out_ty inOutTys.toList bodies trace[Diverge.def] "mut rec body (after decl): {mutRecBody}" @@ -889,15 +898,18 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- our fixed-point let k_var_ty ← mkAppM ``FixI.k_ty #[i_var_ty, in_ty, out_ty] withLocalDeclD (mkAnonymous "k" 3) k_var_ty fun k_var => do + trace[Diverge.def] "# Proving that the mut rec body is valid" let isValidThm ← proveMutRecIsValid grName grLvlParams inOutTysExpr bodyFuns mutRecBody k_var preDefs bodies -- Generate the final definitions + trace[Diverge.def] "# Generating the final definitions" let decls ← mkDeclareFixDefs mutRecBody preDefs -- Prove the unfolding theorems + trace[Diverge.def] "# Proving the unfolding theorems" proveUnfoldingThms isValidThm preDefs decls - -- Process the definitions - TODO + -- Generating code -- TODO addAndCompilePartialRec preDefs -- The following function is copy&pasted from Lean.Elab.PreDefinition.Main @@ -1064,13 +1076,32 @@ namespace Tests -- Testing dependent branching and let-bindings -- TODO: why the linter warning? - divergent def is_non_zero (i : Int) : Result Bool := + divergent def isNonZero (i : Int) : Result Bool := if _h:i = 0 then return false else let b := true return b - #check is_non_zero.unfold + #check isNonZero.unfold + + -- Testing let-bindings + divergent def iInBounds {a : Type} (ls : List a) (i : Int) : Result Bool := + let i0 := ls.length + if i < i0 + then Result.ret True + else Result.ret False + + #check iInBounds.unfold + + divergent def isCons + {a : Type} (ls : List a) : Result Bool := + let ls1 := ls + match ls1 with + | [] => Result.ret False + | x :: tl => Result.ret True + + #check isCons.unfold + end Tests end Diverge -- cgit v1.2.3 From 442caaf62e4a217b9a10116c4e529c49f83c4efd Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 4 Jul 2023 22:45:02 +0200 Subject: Fix an issue with mkSigmasVal --- backends/lean/Base/Diverge/Elab.lean | 228 ++++++++++++++++++++++------------- 1 file changed, 143 insertions(+), 85 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 4b08fe44..1af06fea 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -26,38 +26,42 @@ def mkInOutTy (x y : Expr) : MetaM Expr := mkAppM ``FixI.mk_in_out_ty #[x, y] -- Return the `a` in `Return a` -def get_result_ty (ty : Expr) : MetaM Expr := +def getResultTy (ty : Expr) : MetaM Expr := ty.withApp fun f args => do if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then - throwError "Invalid argument to get_result_ty: {ty}" + throwError "Invalid argument to getResultTy: {ty}" else pure (args.get! 0) -/- Group a list of expressions into a dependent tuple. +/- Deconstruct a sigma type. - Example: - xl = [`a : Type`, `ls : List a`] - returns: - `⟨ (a:Type), (ls: List a) ⟩` + For instance, deconstructs `(a : Type) × List a` into + `Type` and `λ a => List a`. -/ -def mkSigmasVal (xl : List Expr) : MetaM Expr := - match xl with - | [] => do - trace[Diverge.def.sigmas] "mkSigmasVal: []" - pure (Expr.const ``PUnit.unit []) - | [x] => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" - pure x - | fst :: xl => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst - let snd ← mkSigmasVal xl - let snd_ty ← inferType snd - let beta ← mkLambdaFVars #[fst] snd_ty - trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" - mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] - -/- Generate a Sigma type from a list of expressions. +def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do + ty.withApp fun f args => do + if ¬ f.isConstOf ``Sigma ∨ args.size ≠ 2 then + throwError "Invalid argument to getSigmaTypes: {ty}" + else + pure (args.get! 0, args.get! 1) + +/- Like `lambdaTelescopeN` but only destructs a fixed number of lambdas -/ +def lambdaTelescopeN (e : Expr) (n : Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := + lambdaTelescope e fun xs body => do + if xs.size < n then throwError "lambdaTelescopeN: not enough lambdas"; + let xs := xs.extract 0 n + let ys := xs.extract n xs.size + let body ← mkLambdaFVars ys body + k xs body + +/- Like `lambdaTelescope`, but only destructs one lambda + TODO: is there an equivalent of this function somewhere in the + standard library? -/ +def lambdaOne (e : Expr) (k : Expr → Expr → MetaM α) : MetaM α := + lambdaTelescopeN e 1 λ xs b => k (xs.get! 0) b + +/- Generate a Sigma type from a list of *variables* (all the expressions + must be variables). Example: - xl = [(a:Type), (ls:List a), (i:Int)] @@ -84,6 +88,53 @@ def mkSigmasType (xl : List Expr) : MetaM Expr := trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] +/- Apply a lambda expression to some arguments, simplifying the lambdas -/ +def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do + lambdaTelescopeN e xs.size fun vars body => + -- Create the substitution + let s : HashMap FVarId Expr := HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList) + -- Substitute in the body + pure (body.replace fun e => + match e with + | Expr.fvar fvarId => match s.find? fvarId with + | none => e + | some v => v + | _ => none) + +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + + We need the type argument because as the elements in the tuple are + "concrete", we can't in all generality figure out the type of the tuple. + + Example: + `⟨ True, 3 ⟩ : (x : Bool) × (if x then Int else Unit)` + -/ +def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasVal: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" + pure x + | fst :: xl => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" + -- Deconstruct the type + let (alpha, beta) ← getSigmaTypes ty + -- Compute the "second" field + -- Specialize beta for fst + let nty ← applyLambdaToArgs beta #[fst] + -- Recursive call + let snd ← mkSigmasVal nty xl + -- Put everything together + trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + def mkAnonymous (s : String) (i : Nat) : Name := .num (.str .anonymous s) i @@ -208,52 +259,57 @@ def mkFinVal (n i : Nat) : MetaM Expr := do We return the new declarations. -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (preDefs : Array PreDefinition) : + (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - -- Compute the map from name to index - the continuation has an indexed type: - -- we use the index (a finite number of type `Fin`) to control which function - -- we call at the recursive call site. - let nameToId : HashMap Name Nat := - let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val)) - HashMap.ofList namesIds.toList + -- Compute the map from name to (index × input type). + -- Remark: the continuation has an indexed type; we use the index (a finite number of + -- type `Fin`) to control which function we call at the recursive call site. + let nameToInfo : HashMap Name (Nat × Expr) := + let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst)) + HashMap.ofList bl.toList - trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" + trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" -- Auxiliary function to explore the function bodies and replace the -- recursive calls - let visit_e (e : Expr) : MetaM Expr := do - trace[Diverge.def.genBody] "visiting expression: {e}" - match e with - | .app .. => do - e.withApp fun f args => do - trace[Diverge.def.genBody] "this is an app: {f} {args}" - -- Check if this is a recursive call - if f.isConst then - let name := f.constName! - match nameToId.find? name with - | none => pure e - | some id => - -- This is a recursive call: replace it - -- Compute the index - let i ← mkFinVal grSize id - -- Put the arguments in one big dependent tuple - let args ← mkSigmasVal args.toList - mkAppM' kk_var #[i, args] - else - -- Not a recursive call: do nothing - pure e - | .const name _ => - -- Sanity check: we eliminated all the recursive calls - if (nameToId.find? name).isSome then - throwError "mkUnaryBodies: a recursive call was not eliminated" - else pure e - | _ => pure e + let visit_e (i : Nat) (e : Expr) : MetaM Expr := do + trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}" + let ne ← do + match e with + | .app .. => do + e.withApp fun f args => do + trace[Diverge.def.genBody] "this is an app: {f} {args}" + -- Check if this is a recursive call + if f.isConst then + let name := f.constName! + match nameToInfo.find? name with + | none => pure e + | some (id, in_ty) => + trace[Diverge.def.genBody] "this is a recursive call" + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Put the arguments in one big dependent tuple + let args ← mkSigmasVal in_ty args.toList + mkAppM' kk_var #[i, args] + else + -- Not a recursive call: do nothing + pure e + | .const name _ => + -- Sanity check: we eliminated all the recursive calls + if (nameToInfo.find? name).isSome then + throwError "mkUnaryBodies: a recursive call was not eliminated" + else pure e + | _ => pure e + trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}" + pure ne -- Explore the bodies preDefs.mapM fun preDef => do -- Replace the recursive calls + trace[Diverge.def.genBody] "About to replace recursive calls in {preDef.declName}" let body ← mapVisit visit_e preDef.value trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" @@ -413,11 +469,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let proveBranchValid (br : Expr) : MetaM Expr := if isIte then proveExprIsValid k_var kk_var br else do - -- There is a lambda -- TODO: how do we remove exacly *one* lambda? - lambdaTelescope br fun xs br => do - let x := xs.get! 0 - let xs := xs.extract 1 xs.size - let br ← mkLambdaFVars xs br + -- There is a lambda + lambdaOne br fun x br => do let brValid ← proveExprIsValid k_var kk_var br mkLambdaFVars #[x] brValid let br0Valid ← proveBranchValid br0 @@ -521,11 +574,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let xValid ← proveExprIsValid k_var kk_var x trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" let yValid ← do - -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? - lambdaTelescope y fun xs y => do - let x := xs.get! 0 - let xs := xs.extract 1 xs.size - let y ← mkLambdaFVars xs y + -- This is a lambda expression + lambdaOne y fun x y => do trace[Diverge.def.valid] "bind: y: {y}" let yValid ← proveExprIsValid k_var kk_var y trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" @@ -559,15 +609,12 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- binders might come from the match, and some of the binders might come -- from the fact that the expression in the match is a lambda expression: -- we use the branchesNumParams field for this reason - lambdaTelescope br fun xs br => do let numParams := me.branchesNumParams.get! idx - let xs_beg := xs.extract 0 numParams - let xs_end := xs.extract numParams xs.size - let br ← mkLambdaFVars xs_end br + lambdaTelescopeN br numParams fun xs br => do -- Prove that the branch expression is valid let brValid ← proveExprIsValid k_var kk_var br -- Reconstruct the lambda expression - mkLambdaFVars xs_beg brValid + mkLambdaFVars xs brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" -- Compute the motive, which has the following shape: -- ``` @@ -726,15 +773,17 @@ def proveMutRecIsValid -- def is_even (i : Int) : Result Bool := mut_rec_body 0 i -- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i -- ``` -def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do + -- Retrieve the input type + let in_ty := (inOutTys.get! idx.val).fst -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple - let input ← mkSigmasVal xs.toList + let input ← mkSigmasVal in_ty xs.toList -- Apply the fixed point let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody @@ -754,8 +803,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : pure defs -- Prove the equations that we will use as unfolding theorems -partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition) - (decls : Array Name) : MetaM Unit := do +partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr)) + (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do let preDef := preDefs.get! i @@ -779,7 +828,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input argument - let arg ← mkSigmasVal xs.toList + let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList let proof ← mkAppM ``congr_fun #[proof, arg] -- Abstract the arguments away let proof ← mkLambdaFVars xs proof @@ -833,7 +882,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do forallTelescope preDef.type (fun in_tys out_ty => do let in_ty ← liftM (mkSigmasType in_tys.toList) -- Retrieve the type in the "Result" - let out_ty ← get_result_ty out_ty + let out_ty ← getResultTy out_ty let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) pure (in_ty, out_ty) ) @@ -886,8 +935,8 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations - trace[Diverge.def] "# Generating the unary bodies" - let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs + trace[Diverge.def] "# Generating the unary bodies" + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs trace[Diverge.def] "Unary bodies (after decl): {bodies}" -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" @@ -903,11 +952,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Generate the final definitions trace[Diverge.def] "# Generating the final definitions" - let decls ← mkDeclareFixDefs mutRecBody preDefs + let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs -- Prove the unfolding theorems trace[Diverge.def] "# Proving the unfolding theorems" - proveUnfoldingThms isValidThm preDefs decls + proveUnfoldingThms isValidThm inOutTys preDefs decls -- Generating code -- TODO addAndCompilePartialRec preDefs @@ -1102,6 +1151,15 @@ namespace Tests #check isCons.unfold + -- Testing what happens when we use concrete arguments in dependent tuples + divergent def test1 + (_ : Option Bool) (_ : Unit) : + Result Unit + := + test1 Option.none () + + #check test1.unfold + end Tests end Diverge -- cgit v1.2.3 From 2496a08691809683e256af7c479588a2fae8e3d7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 6 Jul 2023 14:23:21 +0200 Subject: Register the unfolding theorems in the Lean equation compilers and solve a "unused variable" warning --- backends/lean/Base/Diverge/Elab.lean | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 1af06fea..e5b39440 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -843,6 +843,8 @@ partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Ex all := [name] } addDecl decl + -- Add the unfolding theorem to the equation compiler + eqnsAttribute.add preDef.declName #[name] trace[Diverge.def.unfold] "proveUnfoldingThms: added thm: {name}:\n{thmTy}" let rec prove (i : Nat) : MetaM Unit := do if i = preDefs.size then pure () @@ -1011,6 +1013,13 @@ def Term.elabMutualDef (vars : Array Expr) (views : Array DefView) : TermElabM U withFunLocalDecls headers fun funFVars => do for view in views, funFVar in funFVars do addLocalVarInfo view.declId funFVar + -- Add fake use site to prevent "unused variable" warning (if the + -- function is actually not recursive, Lean would print this warning). + -- Remark: we could detect this case and encode the function without + -- using the fixed-point. In practice it shouldn't happen however: + -- we define non-recursive functions with the `divergent` keyword + -- only for testing purposes. + addTermInfo' view.declId funFVar let values ← try let values ← elabFunValues headers @@ -1091,7 +1100,8 @@ namespace Tests . intro i hpos h; simp at h; linarith . rename_i hd tl ih intro i hpos h - rw [list_nth.unfold]; simp + -- We can directly use `rw [list_nth]`! + rw [list_nth]; simp split <;> simp [*] . tauto . -- TODO: we shouldn't have to do that @@ -1147,7 +1157,7 @@ namespace Tests let ls1 := ls match ls1 with | [] => Result.ret False - | x :: tl => Result.ret True + | _ :: _ => Result.ret True #check isCons.unfold -- cgit v1.2.3 From 9515bbad5b58ed1c51ac6d9fc9d7a4e5884b6273 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 6 Jul 2023 15:23:53 +0200 Subject: Reorganize a bit the lean backend files --- backends/lean/Base/Diverge/Elab.lean | 2 ++ 1 file changed, 2 insertions(+) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index e5b39440..96f7abc0 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -2,6 +2,7 @@ import Lean import Lean.Meta.Tactic.Simp import Init.Data.List.Basic import Mathlib.Tactic.RunCmd +import Base.Utils import Base.Diverge.Base import Base.Diverge.ElabBase @@ -13,6 +14,7 @@ namespace Diverge syntax (name := divergentDef) declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command +open Utils open Lean Elab Term Meta Primitives Lean.Meta /- The following was copied from the `wfRecursion` function. -/ -- cgit v1.2.3 From 7206b48a73d6204baea99f4f4675be2518a8f8c2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 10 Jul 2023 15:06:12 +0200 Subject: Start working on the progress tactic --- backends/lean/Base/Diverge/Elab.lean | 17 +---------------- 1 file changed, 1 insertion(+), 16 deletions(-) (limited to 'backends/lean/Base/Diverge/Elab.lean') diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 96f7abc0..f109e847 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -14,8 +14,8 @@ namespace Diverge syntax (name := divergentDef) declModifiers "divergent" "def" declId ppIndent(optDeclSig) declVal : command -open Utils open Lean Elab Term Meta Primitives Lean.Meta +open Utils /- The following was copied from the `wfRecursion` function. -/ @@ -47,21 +47,6 @@ def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do else pure (args.get! 0, args.get! 1) -/- Like `lambdaTelescopeN` but only destructs a fixed number of lambdas -/ -def lambdaTelescopeN (e : Expr) (n : Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := - lambdaTelescope e fun xs body => do - if xs.size < n then throwError "lambdaTelescopeN: not enough lambdas"; - let xs := xs.extract 0 n - let ys := xs.extract n xs.size - let body ← mkLambdaFVars ys body - k xs body - -/- Like `lambdaTelescope`, but only destructs one lambda - TODO: is there an equivalent of this function somewhere in the - standard library? -/ -def lambdaOne (e : Expr) (k : Expr → Expr → MetaM α) : MetaM α := - lambdaTelescopeN e 1 λ xs b => k (xs.get! 0) b - /- Generate a Sigma type from a list of *variables* (all the expressions must be variables). -- cgit v1.2.3