summaryrefslogtreecommitdiff
path: root/backends/lean/Base
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--backends/lean/Base/Utils.lean112
1 files changed, 81 insertions, 31 deletions
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean
index 2366e800..b0032281 100644
--- a/backends/lean/Base/Utils.lean
+++ b/backends/lean/Base/Utils.lean
@@ -159,47 +159,96 @@ elab "print_ctx_decls" : tactic => do
let decls ← ctx.getDecls
printDecls decls
--- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`)
+-- A map-reduce visitor function for expressions (adapted from `AbstractNestedProofs.visit`)
-- The continuation takes as parameters:
-- - the current depth of the expression (useful for printing/debugging)
-- - the expression to explore
-partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
- let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do
+partial def mapreduceVisit {a : Type} (k : Nat → a → Expr → MetaM (a × Expr))
+ (state : a) (e : Expr) : MetaM (a × Expr) := do
+ let mapreduceVisitBinders (state : a) (xs : Array Expr) (k2 : a → MetaM (a × Expr)) :
+ MetaM (a × Expr) := do
let localInstances ← getLocalInstances
- let mut lctx ← getLCtx
- for x in xs do
- let xFVarId := x.fvarId!
- let localDecl ← xFVarId.getDecl
- let type ← mapVisit k localDecl.type
- let localDecl := localDecl.setType type
- let localDecl ← match localDecl.value? with
- | some value => let value ← mapVisit k value; pure <| localDecl.setValue value
- | none => pure localDecl
- lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl
- withLCtx lctx localInstances k2
+ -- Update the local declarations for the bindings in context `lctx`
+ let rec visit_xs (lctx : LocalContext) (state : a) (xs : List Expr) : MetaM (LocalContext × a) := do
+ match xs with
+ | [] => pure (lctx, state)
+ | x :: xs => do
+ let xFVarId := x.fvarId!
+ let localDecl ← xFVarId.getDecl
+ let (state, type) ← mapreduceVisit k state localDecl.type
+ let localDecl := localDecl.setType type
+ let (state, localDecl) ← match localDecl.value? with
+ | some value =>
+ let (state, value) ← mapreduceVisit k state value
+ pure (state, localDecl.setValue value)
+ | none => pure (state, localDecl)
+ let lctx := lctx.modifyLocalDecl xFVarId fun _ => localDecl
+ -- Recursive call
+ visit_xs lctx state xs
+ let (lctx, state) ← visit_xs (← getLCtx) state xs.toList
+ -- Call the continuation with the updated context
+ withLCtx lctx localInstances (k2 state)
-- TODO: use a cache? (Lean.checkCache)
- let rec visit (i : Nat) (e : Expr) : MetaM Expr := do
+ let rec visit (i : Nat) (state : a) (e : Expr) : MetaM (a × Expr) := do
-- Explore
- let e ← k i e
+ let (state, e) ← k i state e
match e with
| .bvar _
| .fvar _
| .mvar _
| .sort _
| .lit _
- | .const _ _ => pure e
- | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1)))
+ | .const _ _ => pure (state, e)
+ | .app .. => do e.withApp fun f args => do
+ let (state, args) ← args.foldlM (fun (state, args) arg => do let (state, arg) ← visit (i + 1) state arg; pure (state, arg :: args)) (state, [])
+ let args := args.reverse
+ let (state, f) ← visit (i + 1) state f
+ let e' := mkAppN f (Array.mk args)
+ return (state, e')
| .lam .. =>
lambdaLetTelescope e fun xs b =>
- mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false)
+ mapreduceVisitBinders state xs fun state => do
+ let (state, b) ← visit (i + 1) state b
+ let e' ← mkLambdaFVars xs b (usedLetOnly := false)
+ return (state, e')
| .forallE .. => do
- forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b)
+ forallTelescope e fun xs b =>
+ mapreduceVisitBinders state xs fun state => do
+ let (state, b) ← visit (i + 1) state b
+ let e' ← mkForallFVars xs b
+ return (state, e')
| .letE .. => do
- lambdaLetTelescope e fun xs b => mapVisitBinders xs do
- mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false)
- | .mdata _ b => return e.updateMData! (← visit (i + 1) b)
- | .proj _ _ b => return e.updateProj! (← visit (i + 1) b)
- visit 0 e
+ lambdaLetTelescope e fun xs b =>
+ mapreduceVisitBinders state xs fun state => do
+ let (state, b) ← visit (i + 1) state b
+ let e' ← mkLambdaFVars xs b (usedLetOnly := false)
+ return (state, e')
+ | .mdata _ b => do
+ let (state, b) ← visit (i + 1) state b
+ return (state, e.updateMData! b)
+ | .proj _ _ b => do
+ let (state, b) ← visit (i + 1) state b
+ return (state, e.updateProj! b)
+ visit 0 state e
+
+-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`)
+-- The continuation takes as parameters:
+-- - the current depth of the expression (useful for printing/debugging)
+-- - the expression to explore
+partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do
+ let k' i (_ : Unit) e := do
+ let e ← k i e
+ pure ((), e)
+ let (_, e) ← mapreduceVisit k' () e
+ pure e
+
+-- A reduce visitor
+partial def reduceVisit {a : Type} (k : Nat → a → Expr → MetaM a) (s : a) (e : Expr) : MetaM a := do
+ let k' i (s : a) e := do
+ let s ← k i s e
+ pure (s, e)
+ let (s, _) ← mapreduceVisit k' s e
+ pure s
-- Generate a fresh user name for an anonymous proposition to introduce in the
-- assumptions
@@ -376,16 +425,17 @@ def destEq (e : Expr) : MetaM (Expr × Expr) := do
else throwError "Not an equality: {e}"
-- Return the set of FVarIds in the expression
+-- TODO: this collects fvars introduced in the inner bindings
partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do
- e.consumeMData.withApp fun body args => do
- let hs := if body.isFVar then hs.insert body.fvarId! else hs
- args.foldlM (fun hs arg => getFVarIds arg hs) hs
+ reduceVisit (fun _ (hs : HashSet FVarId) e =>
+ if e.isFVar then pure (hs.insert e.fvarId!) else pure hs)
+ hs e
-- Return the set of MVarIds in the expression
partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do
- e.consumeMData.withApp fun body args => do
- let hs := if body.isMVar then hs.insert body.mvarId! else hs
- args.foldlM (fun hs arg => getMVarIds arg hs) hs
+ reduceVisit (fun _ (hs : HashSet MVarId) e =>
+ if e.isMVar then pure (hs.insert e.mvarId!) else pure hs)
+ hs e
-- Tactic to split on a disjunction.
-- The expression `h` should be an fvar.