diff options
Diffstat (limited to '')
-rw-r--r-- | backends/lean/Base/Utils.lean | 112 |
1 files changed, 81 insertions, 31 deletions
diff --git a/backends/lean/Base/Utils.lean b/backends/lean/Base/Utils.lean index 2366e800..b0032281 100644 --- a/backends/lean/Base/Utils.lean +++ b/backends/lean/Base/Utils.lean @@ -159,47 +159,96 @@ elab "print_ctx_decls" : tactic => do let decls ← ctx.getDecls printDecls decls --- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) +-- A map-reduce visitor function for expressions (adapted from `AbstractNestedProofs.visit`) -- The continuation takes as parameters: -- - the current depth of the expression (useful for printing/debugging) -- - the expression to explore -partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do - let mapVisitBinders (xs : Array Expr) (k2 : MetaM Expr) : MetaM Expr := do +partial def mapreduceVisit {a : Type} (k : Nat → a → Expr → MetaM (a × Expr)) + (state : a) (e : Expr) : MetaM (a × Expr) := do + let mapreduceVisitBinders (state : a) (xs : Array Expr) (k2 : a → MetaM (a × Expr)) : + MetaM (a × Expr) := do let localInstances ← getLocalInstances - let mut lctx ← getLCtx - for x in xs do - let xFVarId := x.fvarId! - let localDecl ← xFVarId.getDecl - let type ← mapVisit k localDecl.type - let localDecl := localDecl.setType type - let localDecl ← match localDecl.value? with - | some value => let value ← mapVisit k value; pure <| localDecl.setValue value - | none => pure localDecl - lctx :=lctx.modifyLocalDecl xFVarId fun _ => localDecl - withLCtx lctx localInstances k2 + -- Update the local declarations for the bindings in context `lctx` + let rec visit_xs (lctx : LocalContext) (state : a) (xs : List Expr) : MetaM (LocalContext × a) := do + match xs with + | [] => pure (lctx, state) + | x :: xs => do + let xFVarId := x.fvarId! + let localDecl ← xFVarId.getDecl + let (state, type) ← mapreduceVisit k state localDecl.type + let localDecl := localDecl.setType type + let (state, localDecl) ← match localDecl.value? with + | some value => + let (state, value) ← mapreduceVisit k state value + pure (state, localDecl.setValue value) + | none => pure (state, localDecl) + let lctx := lctx.modifyLocalDecl xFVarId fun _ => localDecl + -- Recursive call + visit_xs lctx state xs + let (lctx, state) ← visit_xs (← getLCtx) state xs.toList + -- Call the continuation with the updated context + withLCtx lctx localInstances (k2 state) -- TODO: use a cache? (Lean.checkCache) - let rec visit (i : Nat) (e : Expr) : MetaM Expr := do + let rec visit (i : Nat) (state : a) (e : Expr) : MetaM (a × Expr) := do -- Explore - let e ← k i e + let (state, e) ← k i state e match e with | .bvar _ | .fvar _ | .mvar _ | .sort _ | .lit _ - | .const _ _ => pure e - | .app .. => do e.withApp fun f args => return mkAppN f (← args.mapM (visit (i + 1))) + | .const _ _ => pure (state, e) + | .app .. => do e.withApp fun f args => do + let (state, args) ← args.foldlM (fun (state, args) arg => do let (state, arg) ← visit (i + 1) state arg; pure (state, arg :: args)) (state, []) + let args := args.reverse + let (state, f) ← visit (i + 1) state f + let e' := mkAppN f (Array.mk args) + return (state, e') | .lam .. => lambdaLetTelescope e fun xs b => - mapVisitBinders xs do mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkLambdaFVars xs b (usedLetOnly := false) + return (state, e') | .forallE .. => do - forallTelescope e fun xs b => mapVisitBinders xs do mkForallFVars xs (← visit (i + 1) b) + forallTelescope e fun xs b => + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkForallFVars xs b + return (state, e') | .letE .. => do - lambdaLetTelescope e fun xs b => mapVisitBinders xs do - mkLambdaFVars xs (← visit (i + 1) b) (usedLetOnly := false) - | .mdata _ b => return e.updateMData! (← visit (i + 1) b) - | .proj _ _ b => return e.updateProj! (← visit (i + 1) b) - visit 0 e + lambdaLetTelescope e fun xs b => + mapreduceVisitBinders state xs fun state => do + let (state, b) ← visit (i + 1) state b + let e' ← mkLambdaFVars xs b (usedLetOnly := false) + return (state, e') + | .mdata _ b => do + let (state, b) ← visit (i + 1) state b + return (state, e.updateMData! b) + | .proj _ _ b => do + let (state, b) ← visit (i + 1) state b + return (state, e.updateProj! b) + visit 0 state e + +-- A map visitor function for expressions (adapted from `AbstractNestedProofs.visit`) +-- The continuation takes as parameters: +-- - the current depth of the expression (useful for printing/debugging) +-- - the expression to explore +partial def mapVisit (k : Nat → Expr → MetaM Expr) (e : Expr) : MetaM Expr := do + let k' i (_ : Unit) e := do + let e ← k i e + pure ((), e) + let (_, e) ← mapreduceVisit k' () e + pure e + +-- A reduce visitor +partial def reduceVisit {a : Type} (k : Nat → a → Expr → MetaM a) (s : a) (e : Expr) : MetaM a := do + let k' i (s : a) e := do + let s ← k i s e + pure (s, e) + let (s, _) ← mapreduceVisit k' s e + pure s -- Generate a fresh user name for an anonymous proposition to introduce in the -- assumptions @@ -376,16 +425,17 @@ def destEq (e : Expr) : MetaM (Expr × Expr) := do else throwError "Not an equality: {e}" -- Return the set of FVarIds in the expression +-- TODO: this collects fvars introduced in the inner bindings partial def getFVarIds (e : Expr) (hs : HashSet FVarId := HashSet.empty) : MetaM (HashSet FVarId) := do - e.consumeMData.withApp fun body args => do - let hs := if body.isFVar then hs.insert body.fvarId! else hs - args.foldlM (fun hs arg => getFVarIds arg hs) hs + reduceVisit (fun _ (hs : HashSet FVarId) e => + if e.isFVar then pure (hs.insert e.fvarId!) else pure hs) + hs e -- Return the set of MVarIds in the expression partial def getMVarIds (e : Expr) (hs : HashSet MVarId := HashSet.empty) : MetaM (HashSet MVarId) := do - e.consumeMData.withApp fun body args => do - let hs := if body.isMVar then hs.insert body.mvarId! else hs - args.foldlM (fun hs arg => getMVarIds arg hs) hs + reduceVisit (fun _ (hs : HashSet MVarId) e => + if e.isMVar then pure (hs.insert e.mvarId!) else pure hs) + hs e -- Tactic to split on a disjunction. -- The expression `h` should be an fvar. |