diff options
author | Son Ho | 2023-07-04 22:45:02 +0200 |
---|---|---|
committer | Son Ho | 2023-07-04 22:45:02 +0200 |
commit | 442caaf62e4a217b9a10116c4e529c49f83c4efd (patch) | |
tree | 2f32cf144004a098efcae541d106d6b94912eb92 /backends/lean/Base | |
parent | b643bd00747e75d69b6066c55a1798b61277c4b6 (diff) |
Fix an issue with mkSigmasVal
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Diverge/Elab.lean | 228 | ||||
-rw-r--r-- | backends/lean/Base/Diverge/ElabBase.lean | 47 |
2 files changed, 169 insertions, 106 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean index 4b08fe44..1af06fea 100644 --- a/backends/lean/Base/Diverge/Elab.lean +++ b/backends/lean/Base/Diverge/Elab.lean @@ -26,38 +26,42 @@ def mkInOutTy (x y : Expr) : MetaM Expr := mkAppM ``FixI.mk_in_out_ty #[x, y] -- Return the `a` in `Return a` -def get_result_ty (ty : Expr) : MetaM Expr := +def getResultTy (ty : Expr) : MetaM Expr := ty.withApp fun f args => do if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then - throwError "Invalid argument to get_result_ty: {ty}" + throwError "Invalid argument to getResultTy: {ty}" else pure (args.get! 0) -/- Group a list of expressions into a dependent tuple. +/- Deconstruct a sigma type. - Example: - xl = [`a : Type`, `ls : List a`] - returns: - `⟨ (a:Type), (ls: List a) ⟩` + For instance, deconstructs `(a : Type) × List a` into + `Type` and `λ a => List a`. -/ -def mkSigmasVal (xl : List Expr) : MetaM Expr := - match xl with - | [] => do - trace[Diverge.def.sigmas] "mkSigmasVal: []" - pure (Expr.const ``PUnit.unit []) - | [x] => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" - pure x - | fst :: xl => do - trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" - let alpha ← Lean.Meta.inferType fst - let snd ← mkSigmasVal xl - let snd_ty ← inferType snd - let beta ← mkLambdaFVars #[fst] snd_ty - trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" - mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] - -/- Generate a Sigma type from a list of expressions. +def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do + ty.withApp fun f args => do + if ¬ f.isConstOf ``Sigma ∨ args.size ≠ 2 then + throwError "Invalid argument to getSigmaTypes: {ty}" + else + pure (args.get! 0, args.get! 1) + +/- Like `lambdaTelescopeN` but only destructs a fixed number of lambdas -/ +def lambdaTelescopeN (e : Expr) (n : Nat) (k : Array Expr → Expr → MetaM α) : MetaM α := + lambdaTelescope e fun xs body => do + if xs.size < n then throwError "lambdaTelescopeN: not enough lambdas"; + let xs := xs.extract 0 n + let ys := xs.extract n xs.size + let body ← mkLambdaFVars ys body + k xs body + +/- Like `lambdaTelescope`, but only destructs one lambda + TODO: is there an equivalent of this function somewhere in the + standard library? -/ +def lambdaOne (e : Expr) (k : Expr → Expr → MetaM α) : MetaM α := + lambdaTelescopeN e 1 λ xs b => k (xs.get! 0) b + +/- Generate a Sigma type from a list of *variables* (all the expressions + must be variables). Example: - xl = [(a:Type), (ls:List a), (i:Int)] @@ -84,6 +88,53 @@ def mkSigmasType (xl : List Expr) : MetaM Expr := trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})" mkAppOptM ``Sigma #[some alpha, some beta] +/- Apply a lambda expression to some arguments, simplifying the lambdas -/ +def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do + lambdaTelescopeN e xs.size fun vars body => + -- Create the substitution + let s : HashMap FVarId Expr := HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList) + -- Substitute in the body + pure (body.replace fun e => + match e with + | Expr.fvar fvarId => match s.find? fvarId with + | none => e + | some v => v + | _ => none) + +/- Group a list of expressions into a dependent tuple. + + Example: + xl = [`a : Type`, `ls : List a`] + returns: + `⟨ (a:Type), (ls: List a) ⟩` + + We need the type argument because as the elements in the tuple are + "concrete", we can't in all generality figure out the type of the tuple. + + Example: + `⟨ True, 3 ⟩ : (x : Bool) × (if x then Int else Unit)` + -/ +def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr := + match xl with + | [] => do + trace[Diverge.def.sigmas] "mkSigmasVal: []" + pure (Expr.const ``PUnit.unit []) + | [x] => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]" + pure x + | fst :: xl => do + trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]" + -- Deconstruct the type + let (alpha, beta) ← getSigmaTypes ty + -- Compute the "second" field + -- Specialize beta for fst + let nty ← applyLambdaToArgs beta #[fst] + -- Recursive call + let snd ← mkSigmasVal nty xl + -- Put everything together + trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}" + mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd] + def mkAnonymous (s : String) (i : Nat) : Name := .num (.str .anonymous s) i @@ -208,52 +259,57 @@ def mkFinVal (n i : Nat) : MetaM Expr := do We return the new declarations. -/ def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr) - (preDefs : Array PreDefinition) : + (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : MetaM (Array Expr) := do let grSize := preDefs.size - -- Compute the map from name to index - the continuation has an indexed type: - -- we use the index (a finite number of type `Fin`) to control which function - -- we call at the recursive call site. - let nameToId : HashMap Name Nat := - let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val)) - HashMap.ofList namesIds.toList + -- Compute the map from name to (index × input type). + -- Remark: the continuation has an indexed type; we use the index (a finite number of + -- type `Fin`) to control which function we call at the recursive call site. + let nameToInfo : HashMap Name (Nat × Expr) := + let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst)) + HashMap.ofList bl.toList - trace[Diverge.def.genBody] "nameToId: {nameToId.toList}" + trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}" -- Auxiliary function to explore the function bodies and replace the -- recursive calls - let visit_e (e : Expr) : MetaM Expr := do - trace[Diverge.def.genBody] "visiting expression: {e}" - match e with - | .app .. => do - e.withApp fun f args => do - trace[Diverge.def.genBody] "this is an app: {f} {args}" - -- Check if this is a recursive call - if f.isConst then - let name := f.constName! - match nameToId.find? name with - | none => pure e - | some id => - -- This is a recursive call: replace it - -- Compute the index - let i ← mkFinVal grSize id - -- Put the arguments in one big dependent tuple - let args ← mkSigmasVal args.toList - mkAppM' kk_var #[i, args] - else - -- Not a recursive call: do nothing - pure e - | .const name _ => - -- Sanity check: we eliminated all the recursive calls - if (nameToId.find? name).isSome then - throwError "mkUnaryBodies: a recursive call was not eliminated" - else pure e - | _ => pure e + let visit_e (i : Nat) (e : Expr) : MetaM Expr := do + trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}" + let ne ← do + match e with + | .app .. => do + e.withApp fun f args => do + trace[Diverge.def.genBody] "this is an app: {f} {args}" + -- Check if this is a recursive call + if f.isConst then + let name := f.constName! + match nameToInfo.find? name with + | none => pure e + | some (id, in_ty) => + trace[Diverge.def.genBody] "this is a recursive call" + -- This is a recursive call: replace it + -- Compute the index + let i ← mkFinVal grSize id + -- Put the arguments in one big dependent tuple + let args ← mkSigmasVal in_ty args.toList + mkAppM' kk_var #[i, args] + else + -- Not a recursive call: do nothing + pure e + | .const name _ => + -- Sanity check: we eliminated all the recursive calls + if (nameToInfo.find? name).isSome then + throwError "mkUnaryBodies: a recursive call was not eliminated" + else pure e + | _ => pure e + trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}" + pure ne -- Explore the bodies preDefs.mapM fun preDef => do -- Replace the recursive calls + trace[Diverge.def.genBody] "About to replace recursive calls in {preDef.declName}" let body ← mapVisit visit_e preDef.value trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}" @@ -413,11 +469,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let proveBranchValid (br : Expr) : MetaM Expr := if isIte then proveExprIsValid k_var kk_var br else do - -- There is a lambda -- TODO: how do we remove exacly *one* lambda? - lambdaTelescope br fun xs br => do - let x := xs.get! 0 - let xs := xs.extract 1 xs.size - let br ← mkLambdaFVars xs br + -- There is a lambda + lambdaOne br fun x br => do let brValid ← proveExprIsValid k_var kk_var br mkLambdaFVars #[x] brValid let br0Valid ← proveBranchValid br0 @@ -521,11 +574,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do let xValid ← proveExprIsValid k_var kk_var x trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}" let yValid ← do - -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda? - lambdaTelescope y fun xs y => do - let x := xs.get! 0 - let xs := xs.extract 1 xs.size - let y ← mkLambdaFVars xs y + -- This is a lambda expression + lambdaOne y fun x y => do trace[Diverge.def.valid] "bind: y: {y}" let yValid ← proveExprIsValid k_var kk_var y trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}" @@ -559,15 +609,12 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp -- binders might come from the match, and some of the binders might come -- from the fact that the expression in the match is a lambda expression: -- we use the branchesNumParams field for this reason - lambdaTelescope br fun xs br => do let numParams := me.branchesNumParams.get! idx - let xs_beg := xs.extract 0 numParams - let xs_end := xs.extract numParams xs.size - let br ← mkLambdaFVars xs_end br + lambdaTelescopeN br numParams fun xs br => do -- Prove that the branch expression is valid let brValid ← proveExprIsValid k_var kk_var br -- Reconstruct the lambda expression - mkLambdaFVars xs_beg brValid + mkLambdaFVars xs brValid trace[Diverge.def.valid] "branchesValid:\n{branchesValid}" -- Compute the motive, which has the following shape: -- ``` @@ -726,15 +773,17 @@ def proveMutRecIsValid -- def is_even (i : Int) : Result Bool := mut_rec_body 0 i -- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i -- ``` -def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : +def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) : TermElabM (Array Name) := do let grSize := preDefs.size let defs ← preDefs.mapIdxM fun idx preDef => do lambdaTelescope preDef.value fun xs _ => do + -- Retrieve the input type + let in_ty := (inOutTys.get! idx.val).fst -- Create the index let idx ← mkFinVal grSize idx.val -- Group the inputs into a dependent tuple - let input ← mkSigmasVal xs.toList + let input ← mkSigmasVal in_ty xs.toList -- Apply the fixed point let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input] let fixedBody ← mkLambdaFVars xs fixedBody @@ -754,8 +803,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) : pure defs -- Prove the equations that we will use as unfolding theorems -partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition) - (decls : Array Name) : MetaM Unit := do +partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr)) + (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do let grSize := preDefs.size let proveIdx (i : Nat) : MetaM Unit := do let preDef := preDefs.get! i @@ -779,7 +828,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio let idx ← mkFinVal grSize i let proof ← mkAppM ``congr_fun #[proof, idx] -- Add the input argument - let arg ← mkSigmasVal xs.toList + let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList let proof ← mkAppM ``congr_fun #[proof, arg] -- Abstract the arguments away let proof ← mkLambdaFVars xs proof @@ -833,7 +882,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do forallTelescope preDef.type (fun in_tys out_ty => do let in_ty ← liftM (mkSigmasType in_tys.toList) -- Retrieve the type in the "Result" - let out_ty ← get_result_ty out_ty + let out_ty ← getResultTy out_ty let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty) pure (in_ty, out_ty) ) @@ -886,8 +935,8 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Replace the recursive calls in all the function bodies by calls to the -- continuation `k` and and generate for those bodies declarations - trace[Diverge.def] "# Generating the unary bodies" - let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs + trace[Diverge.def] "# Generating the unary bodies" + let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs trace[Diverge.def] "Unary bodies (after decl): {bodies}" -- Generate the mutually recursive body trace[Diverge.def] "# Generating the mut rec body" @@ -903,11 +952,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do -- Generate the final definitions trace[Diverge.def] "# Generating the final definitions" - let decls ← mkDeclareFixDefs mutRecBody preDefs + let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs -- Prove the unfolding theorems trace[Diverge.def] "# Proving the unfolding theorems" - proveUnfoldingThms isValidThm preDefs decls + proveUnfoldingThms isValidThm inOutTys preDefs decls -- Generating code -- TODO addAndCompilePartialRec preDefs @@ -1102,6 +1151,15 @@ namespace Tests #check isCons.unfold + -- Testing what happens when we use concrete arguments in dependent tuples + divergent def test1 + (_ : Option Bool) (_ : Unit) : + Result Unit + := + test1 Option.none () + + #check test1.unfold + end Tests end Diverge diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean index 1c1062c0..aaaea6f7 100644 --- a/backends/lean/Base/Diverge/ElabBase.lean +++ b/backends/lean/Base/Diverge/ElabBase.lean @@ -83,7 +83,10 @@ print_decl test1 print_decl test2 -- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) -partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do +-- The continuation takes as parameters: +-- - the current depth of the expression (useful for printing/debugging) +-- - the expression to explore +partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do let localInstances ← getLocalInstances let mut lctx ← getLCtx @@ -98,25 +101,27 @@ partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl withLCtx lctx localInstances k2 -- TODO: use a cache? (Lean.checkCache) - -- Explore - let e ← k e - match e with - | .bvar _ - | .fvar _ - | .mvar _ - | .sort _ - | .lit _ - | .const _ _ => pure e - | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (mapVisit k)) - | .lam .. => - lambdaLetTelescope e fun xs b => - mapVisitBinders xs do mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) - | .forallE .. => do - forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← mapVisit k b) - | .letE .. => do - lambdaLetTelescope e fun xs b => mapVisitBinders xs do - mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false) - | .mdata _ b => return e.updateMData! (← mapVisit k b) - | .proj _ _ b => return e.updateProj! (← mapVisit k b) + let rec visit (i : Nat) (e : Expr) : MetaM Expr := do + -- Explore + let e ← k i e + match e with + | .bvar _ + | .fvar _ + | .mvar _ + | .sort _ + | .lit _ + | .const _ _ => pure e + | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1))) + | .lam .. => + lambdaLetTelescope e fun xs b => + mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) + | .forallE .. => do + forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b) + | .letE .. => do + lambdaLetTelescope e fun xs b => mapVisitBinders xs do + mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) + | .mdata _ b => return e.updateMData! (← visit (i + 1) b) + | .proj _ _ b => return e.updateProj! (← visit (i + 1) b) + visit 0 e end Diverge |