summaryrefslogtreecommitdiff
path: root/backends/lean/Base/Diverge
diff options
context:
space:
mode:
authorSon Ho2023-07-04 22:45:02 +0200
committerSon Ho2023-07-04 22:45:02 +0200
commit442caaf62e4a217b9a10116c4e529c49f83c4efd (patch)
tree2f32cf144004a098efcae541d106d6b94912eb92 /backends/lean/Base/Diverge
parentb643bd00747e75d69b6066c55a1798b61277c4b6 (diff)
Fix an issue with mkSigmasVal
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Diverge/Elab.lean228
-rw-r--r--backends/lean/Base/Diverge/ElabBase.lean47
2 files changed, 169 insertions, 106 deletions
diff --git a/backends/lean/Base/Diverge/Elab.lean b/backends/lean/Base/Diverge/Elab.lean
index 4b08fe44..1af06fea 100644
--- a/backends/lean/Base/Diverge/Elab.lean
+++ b/backends/lean/Base/Diverge/Elab.lean
@@ -26,38 +26,42 @@ def mkInOutTy (x y : Expr) : MetaM Expr :=
mkAppM ``FixI.mk_in_out_ty #[x, y]
-- Return the `a` in `Return a`
-def get_result_ty (ty : Expr) : MetaM Expr :=
+def getResultTy (ty : Expr) : MetaM Expr :=
ty.withApp fun f args => do
if ¬ f.isConstOf ``Result ∨ args.size ≠ 1 then
- throwError "Invalid argument to get_result_ty: {ty}"
+ throwError "Invalid argument to getResultTy: {ty}"
else
pure (args.get! 0)
-/- Group a list of expressions into a dependent tuple.
+/- Deconstruct a sigma type.
- Example:
- xl = [`a : Type`, `ls : List a`]
- returns:
- `⟨ (a:Type), (ls: List a) ⟩`
+ For instance, deconstructs `(a : Type) × List a` into
+ `Type` and `λ a => List a`.
-/
-def mkSigmasVal (xl : List Expr) : MetaM Expr :=
- match xl with
- | [] => do
- trace[Diverge.def.sigmas] "mkSigmasVal: []"
- pure (Expr.const ``PUnit.unit [])
- | [x] => do
- trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]"
- pure x
- | fst :: xl => do
- trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]"
- let alpha ← Lean.Meta.inferType fst
- let snd ← mkSigmasVal xl
- let snd_ty ← inferType snd
- let beta ← mkLambdaFVars #[fst] snd_ty
- trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}"
- mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd]
-
-/- Generate a Sigma type from a list of expressions.
+def getSigmaTypes (ty : Expr) : MetaM (Expr × Expr) := do
+ ty.withApp fun f args => do
+ if ¬ f.isConstOf ``Sigma ∨ args.size ≠ 2 then
+ throwError "Invalid argument to getSigmaTypes: {ty}"
+ else
+ pure (args.get! 0, args.get! 1)
+
+/- Like `lambdaTelescopeN` but only destructs a fixed number of lambdas -/
+def lambdaTelescopeN (e : Expr) (n : Nat) (k : Array Expr → Expr → MetaM α) : MetaM α :=
+ lambdaTelescope e fun xs body => do
+ if xs.size < n then throwError "lambdaTelescopeN: not enough lambdas";
+ let xs := xs.extract 0 n
+ let ys := xs.extract n xs.size
+ let body ← mkLambdaFVars ys body
+ k xs body
+
+/- Like `lambdaTelescope`, but only destructs one lambda
+ TODO: is there an equivalent of this function somewhere in the
+ standard library? -/
+def lambdaOne (e : Expr) (k : Expr → Expr → MetaM α) : MetaM α :=
+ lambdaTelescopeN e 1 λ xs b => k (xs.get! 0) b
+
+/- Generate a Sigma type from a list of *variables* (all the expressions
+ must be variables).
Example:
- xl = [(a:Type), (ls:List a), (i:Int)]
@@ -84,6 +88,53 @@ def mkSigmasType (xl : List Expr) : MetaM Expr :=
trace[Diverge.def.sigmas] "mkSigmasOfTypes: ({alpha}) ({beta})"
mkAppOptM ``Sigma #[some alpha, some beta]
+/- Apply a lambda expression to some arguments, simplifying the lambdas -/
+def applyLambdaToArgs (e : Expr) (xs : Array Expr) : MetaM Expr := do
+ lambdaTelescopeN e xs.size fun vars body =>
+ -- Create the substitution
+ let s : HashMap FVarId Expr := HashMap.ofList (List.zip (vars.toList.map Expr.fvarId!) xs.toList)
+ -- Substitute in the body
+ pure (body.replace fun e =>
+ match e with
+ | Expr.fvar fvarId => match s.find? fvarId with
+ | none => e
+ | some v => v
+ | _ => none)
+
+/- Group a list of expressions into a dependent tuple.
+
+ Example:
+ xl = [`a : Type`, `ls : List a`]
+ returns:
+ `⟨ (a:Type), (ls: List a) ⟩`
+
+ We need the type argument because as the elements in the tuple are
+ "concrete", we can't in all generality figure out the type of the tuple.
+
+ Example:
+ `⟨ True, 3 ⟩ : (x : Bool) × (if x then Int else Unit)`
+ -/
+def mkSigmasVal (ty : Expr) (xl : List Expr) : MetaM Expr :=
+ match xl with
+ | [] => do
+ trace[Diverge.def.sigmas] "mkSigmasVal: []"
+ pure (Expr.const ``PUnit.unit [])
+ | [x] => do
+ trace[Diverge.def.sigmas] "mkSigmasVal: [{x}]"
+ pure x
+ | fst :: xl => do
+ trace[Diverge.def.sigmas] "mkSigmasVal: [{fst}::{xl}]"
+ -- Deconstruct the type
+ let (alpha, beta) ← getSigmaTypes ty
+ -- Compute the "second" field
+ -- Specialize beta for fst
+ let nty ← applyLambdaToArgs beta #[fst]
+ -- Recursive call
+ let snd ← mkSigmasVal nty xl
+ -- Put everything together
+ trace[Diverge.def.sigmas] "mkSigmasVal:\n{alpha}\n{beta}\n{fst}\n{snd}"
+ mkAppOptM ``Sigma.mk #[some alpha, some beta, some fst, some snd]
+
def mkAnonymous (s : String) (i : Nat) : Name :=
.num (.str .anonymous s) i
@@ -208,52 +259,57 @@ def mkFinVal (n i : Nat) : MetaM Expr := do
We return the new declarations.
-/
def mkDeclareUnaryBodies (grLvlParams : List Name) (kk_var : Expr)
- (preDefs : Array PreDefinition) :
+ (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) :
MetaM (Array Expr) := do
let grSize := preDefs.size
- -- Compute the map from name to index - the continuation has an indexed type:
- -- we use the index (a finite number of type `Fin`) to control which function
- -- we call at the recursive call site.
- let nameToId : HashMap Name Nat :=
- let namesIds := preDefs.mapIdx (fun i d => (d.declName, i.val))
- HashMap.ofList namesIds.toList
+ -- Compute the map from name to (index × input type).
+ -- Remark: the continuation has an indexed type; we use the index (a finite number of
+ -- type `Fin`) to control which function we call at the recursive call site.
+ let nameToInfo : HashMap Name (Nat × Expr) :=
+ let bl := preDefs.mapIdx fun i d => (d.declName, (i.val, (inOutTys.get! i.val).fst))
+ HashMap.ofList bl.toList
- trace[Diverge.def.genBody] "nameToId: {nameToId.toList}"
+ trace[Diverge.def.genBody] "nameToId: {nameToInfo.toList}"
-- Auxiliary function to explore the function bodies and replace the
-- recursive calls
- let visit_e (e : Expr) : MetaM Expr := do
- trace[Diverge.def.genBody] "visiting expression: {e}"
- match e with
- | .app .. => do
- e.withApp fun f args => do
- trace[Diverge.def.genBody] "this is an app: {f} {args}"
- -- Check if this is a recursive call
- if f.isConst then
- let name := f.constName!
- match nameToId.find? name with
- | none => pure e
- | some id =>
- -- This is a recursive call: replace it
- -- Compute the index
- let i ← mkFinVal grSize id
- -- Put the arguments in one big dependent tuple
- let args ← mkSigmasVal args.toList
- mkAppM' kk_var #[i, args]
- else
- -- Not a recursive call: do nothing
- pure e
- | .const name _ =>
- -- Sanity check: we eliminated all the recursive calls
- if (nameToId.find? name).isSome then
- throwError "mkUnaryBodies: a recursive call was not eliminated"
- else pure e
- | _ => pure e
+ let visit_e (i : Nat) (e : Expr) : MetaM Expr := do
+ trace[Diverge.def.genBody] "visiting expression (dept: {i}): {e}"
+ let ne ← do
+ match e with
+ | .app .. => do
+ e.withApp fun f args => do
+ trace[Diverge.def.genBody] "this is an app: {f} {args}"
+ -- Check if this is a recursive call
+ if f.isConst then
+ let name := f.constName!
+ match nameToInfo.find? name with
+ | none => pure e
+ | some (id, in_ty) =>
+ trace[Diverge.def.genBody] "this is a recursive call"
+ -- This is a recursive call: replace it
+ -- Compute the index
+ let i ← mkFinVal grSize id
+ -- Put the arguments in one big dependent tuple
+ let args ← mkSigmasVal in_ty args.toList
+ mkAppM' kk_var #[i, args]
+ else
+ -- Not a recursive call: do nothing
+ pure e
+ | .const name _ =>
+ -- Sanity check: we eliminated all the recursive calls
+ if (nameToInfo.find? name).isSome then
+ throwError "mkUnaryBodies: a recursive call was not eliminated"
+ else pure e
+ | _ => pure e
+ trace[Diverge.def.genBody] "done with expression (depth: {i}): {e}"
+ pure ne
-- Explore the bodies
preDefs.mapM fun preDef => do
-- Replace the recursive calls
+ trace[Diverge.def.genBody] "About to replace recursive calls in {preDef.declName}"
let body ← mapVisit visit_e preDef.value
trace[Diverge.def.genBody] "Body after replacement of the recursive calls: {body}"
@@ -413,11 +469,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
let proveBranchValid (br : Expr) : MetaM Expr :=
if isIte then proveExprIsValid k_var kk_var br
else do
- -- There is a lambda -- TODO: how do we remove exacly *one* lambda?
- lambdaTelescope br fun xs br => do
- let x := xs.get! 0
- let xs := xs.extract 1 xs.size
- let br ← mkLambdaFVars xs br
+ -- There is a lambda
+ lambdaOne br fun x br => do
let brValid ← proveExprIsValid k_var kk_var br
mkLambdaFVars #[x] brValid
let br0Valid ← proveBranchValid br0
@@ -521,11 +574,8 @@ partial def proveExprIsValid (k_var kk_var : Expr) (e : Expr) : MetaM Expr := do
let xValid ← proveExprIsValid k_var kk_var x
trace[Diverge.def.valid] "bind: xValid:\n{xValid}:\n{← inferType xValid}"
let yValid ← do
- -- This is a lambda expression -- TODO: how do we remove exacly *one* lambda?
- lambdaTelescope y fun xs y => do
- let x := xs.get! 0
- let xs := xs.extract 1 xs.size
- let y ← mkLambdaFVars xs y
+ -- This is a lambda expression
+ lambdaOne y fun x y => do
trace[Diverge.def.valid] "bind: y: {y}"
let yValid ← proveExprIsValid k_var kk_var y
trace[Diverge.def.valid] "bind: yValid (no forall): {yValid}"
@@ -559,15 +609,12 @@ partial def proveMatchIsValid (k_var kk_var : Expr) (me : MatchInfo) : MetaM Exp
-- binders might come from the match, and some of the binders might come
-- from the fact that the expression in the match is a lambda expression:
-- we use the branchesNumParams field for this reason
- lambdaTelescope br fun xs br => do
let numParams := me.branchesNumParams.get! idx
- let xs_beg := xs.extract 0 numParams
- let xs_end := xs.extract numParams xs.size
- let br ← mkLambdaFVars xs_end br
+ lambdaTelescopeN br numParams fun xs br => do
-- Prove that the branch expression is valid
let brValid ← proveExprIsValid k_var kk_var br
-- Reconstruct the lambda expression
- mkLambdaFVars xs_beg brValid
+ mkLambdaFVars xs brValid
trace[Diverge.def.valid] "branchesValid:\n{branchesValid}"
-- Compute the motive, which has the following shape:
-- ```
@@ -726,15 +773,17 @@ def proveMutRecIsValid
-- def is_even (i : Int) : Result Bool := mut_rec_body 0 i
-- def is_odd (i : Int) : Result Bool := mut_rec_body 1 i
-- ```
-def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
+def mkDeclareFixDefs (mutRecBody : Expr) (inOutTys : Array (Expr × Expr)) (preDefs : Array PreDefinition) :
TermElabM (Array Name) := do
let grSize := preDefs.size
let defs ← preDefs.mapIdxM fun idx preDef => do
lambdaTelescope preDef.value fun xs _ => do
+ -- Retrieve the input type
+ let in_ty := (inOutTys.get! idx.val).fst
-- Create the index
let idx ← mkFinVal grSize idx.val
-- Group the inputs into a dependent tuple
- let input ← mkSigmasVal xs.toList
+ let input ← mkSigmasVal in_ty xs.toList
-- Apply the fixed point
let fixedBody ← mkAppM ``FixI.fix #[mutRecBody, idx, input]
let fixedBody ← mkLambdaFVars xs fixedBody
@@ -754,8 +803,8 @@ def mkDeclareFixDefs (mutRecBody : Expr) (preDefs : Array PreDefinition) :
pure defs
-- Prove the equations that we will use as unfolding theorems
-partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinition)
- (decls : Array Name) : MetaM Unit := do
+partial def proveUnfoldingThms (isValidThm : Expr) (inOutTys : Array (Expr × Expr))
+ (preDefs : Array PreDefinition) (decls : Array Name) : MetaM Unit := do
let grSize := preDefs.size
let proveIdx (i : Nat) : MetaM Unit := do
let preDef := preDefs.get! i
@@ -779,7 +828,7 @@ partial def proveUnfoldingThms (isValidThm : Expr) (preDefs : Array PreDefinitio
let idx ← mkFinVal grSize i
let proof ← mkAppM ``congr_fun #[proof, idx]
-- Add the input argument
- let arg ← mkSigmasVal xs.toList
+ let arg ← mkSigmasVal (inOutTys.get! i).fst xs.toList
let proof ← mkAppM ``congr_fun #[proof, arg]
-- Abstract the arguments away
let proof ← mkLambdaFVars xs proof
@@ -833,7 +882,7 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
forallTelescope preDef.type (fun in_tys out_ty => do
let in_ty ← liftM (mkSigmasType in_tys.toList)
-- Retrieve the type in the "Result"
- let out_ty ← get_result_ty out_ty
+ let out_ty ← getResultTy out_ty
let out_ty ← liftM (mkSigmasMatch in_tys.toList out_ty)
pure (in_ty, out_ty)
)
@@ -886,8 +935,8 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Replace the recursive calls in all the function bodies by calls to the
-- continuation `k` and and generate for those bodies declarations
- trace[Diverge.def] "# Generating the unary bodies"
- let bodies ← mkDeclareUnaryBodies grLvlParams kk_var preDefs
+ trace[Diverge.def] "# Generating the unary bodies"
+ let bodies ← mkDeclareUnaryBodies grLvlParams kk_var inOutTys preDefs
trace[Diverge.def] "Unary bodies (after decl): {bodies}"
-- Generate the mutually recursive body
trace[Diverge.def] "# Generating the mut rec body"
@@ -903,11 +952,11 @@ def divRecursion (preDefs : Array PreDefinition) : TermElabM Unit := do
-- Generate the final definitions
trace[Diverge.def] "# Generating the final definitions"
- let decls ← mkDeclareFixDefs mutRecBody preDefs
+ let decls ← mkDeclareFixDefs mutRecBody inOutTys preDefs
-- Prove the unfolding theorems
trace[Diverge.def] "# Proving the unfolding theorems"
- proveUnfoldingThms isValidThm preDefs decls
+ proveUnfoldingThms isValidThm inOutTys preDefs decls
-- Generating code -- TODO
addAndCompilePartialRec preDefs
@@ -1102,6 +1151,15 @@ namespace Tests
#check isCons.unfold
+ -- Testing what happens when we use concrete arguments in dependent tuples
+ divergent def test1
+ (_ : Option Bool) (_ : Unit) :
+ Result Unit
+ :=
+ test1 Option.none ()
+
+ #check test1.unfold
+
end Tests
end Diverge
diff --git a/backends/lean/Base/Diverge/ElabBase.lean b/backends/lean/Base/Diverge/ElabBase.lean
index 1c1062c0..aaaea6f7 100644
--- a/backends/lean/Base/Diverge/ElabBase.lean
+++ b/backends/lean/Base/Diverge/ElabBase.lean
@@ -83,7 +83,10 @@ print_decl test1
print_decl test2
-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`)
-partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
+-- The continuation takes as parameters:
+-- - the current depth of the expression (useful for printing/debugging)
+-- - the expression to explore
+partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do
let localInstances ← getLocalInstances
let mut lctx ← getLCtx
@@ -98,25 +101,27 @@ partial def mapVisit (k : Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl
withLCtx lctx localInstances k2
-- TODO: use a cache? (Lean.checkCache)
- -- Explore
- let e ← k e
- match e with
- | .bvar _
- | .fvar _
- | .mvar _
- | .sort _
- | .lit _
- | .const _ _ => pure e
- | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (mapVisit k))
- | .lam .. =>
- lambdaLetTelescope e fun xs b =>
- mapVisitBinders xs do mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false)
- | .forallE .. => do
- forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← mapVisit k b)
- | .letE .. => do
- lambdaLetTelescope e fun xs b => mapVisitBinders xs do
- mkLambdaFVars xs (← mapVisit k b) (usedLetOnly := false)
- | .mdata _ b => return e.updateMData! (← mapVisit k b)
- | .proj _ _ b => return e.updateProj! (← mapVisit k b)
+ let rec visit (i : Nat) (e : Expr) : MetaM Expr := do
+ -- Explore
+ let e ← k i e
+ match e with
+ | .bvar _
+ | .fvar _
+ | .mvar _
+ | .sort _
+ | .lit _
+ | .const _ _ => pure e
+ | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1)))
+ | .lam .. =>
+ lambdaLetTelescope e fun xs b =>
+ mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false)
+ | .forallE .. => do
+ forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b)
+ | .letE .. => do
+ lambdaLetTelescope e fun xs b => mapVisitBinders xs do
+ mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false)
+ | .mdata _ b => return e.updateMData! (← visit (i + 1) b)
+ | .proj _ _ b => return e.updateProj! (← visit (i + 1) b)
+ visit 0 e
end Diverge