From 7206b48a73d6204baea99f4f4675be2518a8f8c2 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 10 Jul 2023 15:06:12 +0200 Subject: Start working on the progress tactic --- backends/lean/Base/Arith/Arith.lean | 409 ++++++++++++++++++++++++++++++++++++ backends/lean/Base/Arith/Base.lean | 10 + 2 files changed, 419 insertions(+) create mode 100644 backends/lean/Base/Arith/Arith.lean create mode 100644 backends/lean/Base/Arith/Base.lean (limited to 'backends/lean/Base/Arith') diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean new file mode 100644 index 00000000..0ba73d18 --- /dev/null +++ b/backends/lean/Base/Arith/Arith.lean @@ -0,0 +1,409 @@ +/- This file contains tactics to solve arithmetic goals -/ + +import Lean +import Lean.Meta.Tactic.Simp +import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd +import Mathlib.Tactic.Linarith +-- TODO: there is no Omega tactic for now - it seems it hasn't been ported yet +--import Mathlib.Tactic.Omega +import Base.Primitives +import Base.Utils +import Base.Arith.Base + +/- +Mathlib tactics: +- rcases: https://leanprover-community.github.io/mathlib_docs/tactics.html#rcases +- split_ifs: https://leanprover-community.github.io/mathlib_docs/tactics.html#split_ifs +- norm_num: https://leanprover-community.github.io/mathlib_docs/tactics.html#norm_num +- should we use linarith or omega? +- hint: https://leanprover-community.github.io/mathlib_docs/tactics.html#hint +- classical: https://leanprover-community.github.io/mathlib_docs/tactics.html#classical +-/ + +namespace List + + -- TODO: I could not find this function?? + @[simp] def flatten {a : Type u} : List (List a) → List a + | [] => [] + | x :: ls => x ++ flatten ls + +end List + +namespace Lean + +namespace LocalContext + + open Lean Lean.Elab Command Term Lean.Meta + + -- Small utility: return the list of declarations in the context, from + -- the last to the first. + def getAllDecls (lctx : Lean.LocalContext) : MetaM (List Lean.LocalDecl) := + lctx.foldrM (fun d ls => do let d ← instantiateLocalDeclMVars d; pure (d :: ls)) [] + + -- Return the list of declarations in the context, but filter the + -- declarations which are considered as implementation details + def getDecls (lctx : Lean.LocalContext) : MetaM (List Lean.LocalDecl) := do + let ls ← lctx.getAllDecls + pure (ls.filter (fun d => not d.isImplementationDetail)) + +end LocalContext + +end Lean + +namespace Arith + +open Primitives + +-- TODO: move? +theorem ne_zero_is_lt_or_gt {x : Int} (hne : x ≠ 0) : x < 0 ∨ x > 0 := by + cases h: x <;> simp_all + . rename_i n; + cases n <;> simp_all + . apply Int.negSucc_lt_zero + +-- TODO: move? +theorem ne_is_lt_or_gt {x y : Int} (hne : x ≠ y) : x < y ∨ x > y := by + have hne : x - y ≠ 0 := by + simp + intro h + have: x = y := by linarith + simp_all + have h := ne_zero_is_lt_or_gt hne + match h with + | .inl _ => left; linarith + | .inr _ => right; linarith + +-- TODO: move +instance Vec.cast (a : Type): Coe (Vec a) (List a) where coe := λ v => v.val + +-- TODO: move +/- Remark: we can't write the following instance because of restrictions about + the type class parameters (`ty` doesn't appear in the return type, which is + forbidden): + + ``` + instance Scalar.cast (ty : ScalarTy) : Coe (Scalar ty) Int where coe := λ v => v.val + ``` + -/ +def Scalar.toInt {ty : ScalarTy} (x : Scalar ty) : Int := x.val + +-- Remark: I tried a version of the shape `HasProp {a : Type} (x : a)` +-- but the lookup didn't work +class HasProp (a : Sort u) where + prop_ty : a → Prop + prop : ∀ x:a, prop_ty x + +instance (ty : ScalarTy) : HasProp (Scalar ty) where + -- prop_ty is inferred + prop := λ x => And.intro x.hmin x.hmax + +instance (a : Type) : HasProp (Vec a) where + prop_ty := λ v => v.val.length ≤ Scalar.max ScalarTy.Usize + prop := λ ⟨ _, l ⟩ => l + +class PropHasImp (x : Prop) where + concl : Prop + prop : x → concl + +-- This also works for `x ≠ y` because this expression reduces to `¬ x = y` +-- and `Ne` is marked as `reducible` +instance (x y : Int) : PropHasImp (¬ x = y) where + concl := x < y ∨ x > y + prop := λ (h:x ≠ y) => ne_is_lt_or_gt h + +open Lean Lean.Elab Command Term Lean.Meta + +-- Small utility: print all the declarations in the context +elab "print_all_decls" : tactic => do + let ctx ← Lean.MonadLCtx.getLCtx + for decl in ← ctx.getDecls do + let ty ← Lean.Meta.inferType decl.toExpr + logInfo m!"{decl.toExpr} : {ty}" + pure () + +-- Explore a term by decomposing the applications (we explore the applied +-- functions and their arguments, but ignore lambdas, forall, etc. - +-- should we go inside?). +partial def foldTermApps (k : α → Expr → MetaM α) (s : α) (e : Expr) : MetaM α := do + -- We do it in a very simpler manner: we deconstruct applications, + -- and recursively explore the sub-expressions. Note that we do + -- not go inside foralls and abstractions (should we?). + e.withApp fun f args => do + let s ← k s f + args.foldlM (foldTermApps k) s + +-- Provided a function `k` which lookups type class instances on an expression, +-- collect all the instances lookuped by applying `k` on the sub-expressions of `e`. +def collectInstances + (k : Expr → MetaM (Option Expr)) (s : HashSet Expr) (e : Expr) : MetaM (HashSet Expr) := do + let k s e := do + match ← k e with + | none => pure s + | some i => pure (s.insert i) + foldTermApps k s e + +-- Similar to `collectInstances`, but explores all the local declarations in the +-- main context. +def collectInstancesFromMainCtx (k : Expr → MetaM (Option Expr)) : Tactic.TacticM (HashSet Expr) := do + Lean.Elab.Tactic.withMainContext do + -- Get the local context + let ctx ← Lean.MonadLCtx.getLCtx + -- Just a matter of precaution + let ctx ← instantiateLCtxMVars ctx + -- Initialize the hashset + let hs := HashSet.empty + -- Explore the declarations + let decls ← ctx.getDecls + decls.foldlM (fun hs d => collectInstances k hs d.toExpr) hs + +-- Return an instance of `HasProp` for `e` if it has some +def lookupHasProp (e : Expr) : MetaM (Option Expr) := do + trace[Arith] "lookupHasProp" + -- TODO: do we need Lean.observing? + -- This actually eliminates the error messages + Lean.observing? do + trace[Arith] "lookupHasProp: observing" + let ty ← Lean.Meta.inferType e + let hasProp ← mkAppM ``HasProp #[ty] + let hasPropInst ← trySynthInstance hasProp + match hasPropInst with + | LOption.some i => + trace[Arith] "Found HasProp instance" + let i_prop ← mkProjection i (Name.mkSimple "prop") + some (← mkAppM' i_prop #[e]) + | _ => none + +-- Collect the instances of `HasProp` for the subexpressions in the context +def collectHasPropInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupHasProp + +elab "display_has_prop_instances" : tactic => do + trace[Arith] "Displaying the HasProp instances" + let hs ← collectHasPropInstancesFromMainCtx + hs.forM fun e => do + trace[Arith] "+ HasProp instance: {e}" + +example (x : U32) : True := by + let i : HasProp U32 := inferInstance + have p := @HasProp.prop _ i x + simp only [HasProp.prop_ty] at p + display_has_prop_instances + simp + +-- Return an instance of `PropHasImp` for `e` if it has some +def lookupPropHasImp (e : Expr) : MetaM (Option Expr) := do + trace[Arith] "lookupPropHasImp" + -- TODO: do we need Lean.observing? + -- This actually eliminates the error messages + Lean.observing? do + trace[Arith] "lookupPropHasImp: observing" + let ty ← Lean.Meta.inferType e + trace[Arith] "lookupPropHasImp: ty: {ty}" + let cl ← mkAppM ``PropHasImp #[ty] + let inst ← trySynthInstance cl + match inst with + | LOption.some i => + trace[Arith] "Found PropHasImp instance" + let i_prop ← mkProjection i (Name.mkSimple "prop") + some (← mkAppM' i_prop #[e]) + | _ => none + +-- Collect the instances of `PropHasImp` for the subexpressions in the context +def collectPropHasImpInstancesFromMainCtx : Tactic.TacticM (HashSet Expr) := do + collectInstancesFromMainCtx lookupPropHasImp + +elab "display_prop_has_imp_instances" : tactic => do + trace[Arith] "Displaying the PropHasImp instances" + let hs ← collectPropHasImpInstancesFromMainCtx + hs.forM fun e => do + trace[Arith] "+ PropHasImp instance: {e}" + +example (x y : Int) (_ : x ≠ y) (_ : ¬ x = y) : True := by + display_prop_has_imp_instances + simp + +-- Lookup instances in a context and introduce them with additional declarations. +def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) : Tactic.TacticM (Array Expr) := do + let hs ← collectInstancesFromMainCtx lookup + hs.toArray.mapM fun e => do + let type ← inferType e + let name ← mkFreshUserName `h + -- Add a declaration + let nval ← Utils.addDecl name e type (asLet := false) + -- Simplify to unfold the declaration to unfold (i.e., the projector) + let simpTheorems ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems) + -- Add the equational theorem for the decl to unfold + let simpTheorems ← simpTheorems.addDeclToUnfold declToUnfold + let congrTheorems ← getSimpCongrTheorems + let ctx : Simp.Context := { simpTheorems := #[simpTheorems], congrTheorems } + -- Where to apply the simplifier + let loc := Tactic.Location.targets #[mkIdent name] false + -- Apply the simplifier + let _ ← Tactic.simpLocation ctx (discharge? := .none) loc + -- Return the new value + pure nval + +-- Lookup the instances of `HasProp for all the sub-expressions in the context, +-- and introduce the corresponding assumptions +elab "intro_has_prop_instances" : tactic => do + trace[Arith] "Introducing the HasProp instances" + let _ ← introInstances ``HasProp.prop_ty lookupHasProp + +example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by + intro_has_prop_instances + simp [*] + +example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by + intro_has_prop_instances + simp_all [Scalar.max, Scalar.min] + +-- Tactic to split on a disjunction. +-- The expression `h` should be an fvar. +def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM Unit := do + trace[Arith] "assumption on which to split: {h}" + -- Retrieve the main goal + Lean.Elab.Tactic.withMainContext do + let goalType ← Lean.Elab.Tactic.getMainTarget + let hDecl := (← getLCtx).get! h.fvarId! + let hName := hDecl.userName + -- Case disjunction + let hTy ← inferType h + hTy.withApp fun f xs => do + trace[Arith] "as app: {f} {xs}" + -- Sanity check + if ¬ (f.isConstOf ``Or ∧ xs.size = 2) then throwError "Invalid argument to splitDisj" + let a := xs.get! 0 + let b := xs.get! 1 + -- Introduce the new goals + -- Returns: + -- - the match branch + -- - a fresh new mvar id + let mkGoal (hTy : Expr) (nGoalName : String) : MetaM (Expr × MVarId) := do + -- Introduce a variable for the assumption (`a` or `b`). Note that we reuse + -- the name of the assumption we split. + withLocalDeclD hName hTy fun var => do + -- The new goal + let mgoal ← mkFreshExprSyntheticOpaqueMVar goalType (tag := Name.mkSimple nGoalName) + -- Clear the assumption that we split + let mgoal ← mgoal.mvarId!.tryClearMany #[h.fvarId!] + -- The branch expression + let branch ← mkLambdaFVars #[var] (mkMVar mgoal) + pure (branch, mgoal) + let (inl, mleft) ← mkGoal a "left" + let (inr, mright) ← mkGoal b "right" + trace[Arith] "left: {inl}: {mleft}" + trace[Arith] "right: {inr}: {mright}" + -- Create the match expression + withLocalDeclD (← mkFreshUserName `h) hTy fun hVar => do + let motive ← mkLambdaFVars #[hVar] goalType + let casesExpr ← mkAppOptM ``Or.casesOn #[a, b, motive, h, inl, inr] + let mgoal ← Tactic.getMainGoal + trace[Arith] "goals: {← Tactic.getUnsolvedGoals}" + trace[Arith] "main goal: {mgoal}" + mgoal.assign casesExpr + let goals ← Tactic.getUnsolvedGoals + -- Focus on the left + Tactic.setGoals [mleft] + kleft + let leftGoals ← Tactic.getUnsolvedGoals + -- Focus on the right + Tactic.setGoals [mright] + kright + let rightGoals ← Tactic.getUnsolvedGoals + -- Put all the goals back + Tactic.setGoals (leftGoals ++ rightGoals ++ goals) + trace[Arith] "new goals: {← Tactic.getUnsolvedGoals}" + +elab "split_disj " n:ident : tactic => do + Lean.Elab.Tactic.withMainContext do + let decl ← Lean.Meta.getLocalDeclFromUserName n.getId + let fvar := mkFVar decl.fvarId + splitDisj fvar (fun _ => pure ()) (fun _ => pure ()) + +example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by + split_disj h0 + . left; assumption + . right; assumption + +-- Lookup the instances of `PropHasImp for all the sub-expressions in the context, +-- and introduce the corresponding assumptions +elab "intro_prop_has_imp_instances" : tactic => do + trace[Arith] "Introducing the PropHasImp instances" + let _ ← introInstances ``PropHasImp.concl lookupPropHasImp + +example (x y : Int) (h0 : x ≤ y) (h1 : x ≠ y) : x < y := by + intro_prop_has_imp_instances + rename_i h + split_disj h + . linarith + . linarith + +/- Boosting a bit the linarith tac. + + We do the following: + - for all the assumptions of the shape `(x : Int) ≠ y` or `¬ (x = y), we + introduce two goals with the assumptions `x < y` and `x > y` + TODO: we could create a PR for mathlib. + -/ +def intTacPreprocess : Tactic.TacticM Unit := do + Lean.Elab.Tactic.withMainContext do + -- Lookup the instances of PropHasImp (this is how we detect assumptions + -- of the proper shape), introduce assumptions in the context and split + -- on those + -- TODO: get rid of the assumptions that we split + let rec splitOnAsms (asms : List Expr) : Tactic.TacticM Unit := + match asms with + | [] => pure () + | asm :: asms => + let k := splitOnAsms asms + splitDisj asm k k + -- Introduce + let asms ← introInstances ``PropHasImp.concl lookupPropHasImp + -- Split + splitOnAsms asms.toList + +elab "int_tac_preprocess" : tactic => + intTacPreprocess + +example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by + int_tac_preprocess + linarith + linarith + +syntax "int_tac" : tactic +macro_rules + | `(tactic| int_tac) => + `(tactic| + (repeat (apply And.intro)) <;> -- TODO: improve this + int_tac_preprocess <;> + linarith) + +example (x : Int) (h0: 0 ≤ x) (h1: x ≠ 0) : 0 < x := by + int_tac + +-- Checking that things append correctly when there are several disjunctions +example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : 0 < x ∧ 0 < y := by + int_tac + +-- A tactic to solve linear arithmetic goals in the presence of scalars +syntax "scalar_tac" : tactic +macro_rules + | `(tactic| scalar_tac) => + `(tactic| + intro_has_prop_instances; + have := Scalar.cMin_bound ScalarTy.Usize; + have := Scalar.cMin_bound ScalarTy.Isize; + have := Scalar.cMax_bound ScalarTy.Usize; + have := Scalar.cMax_bound ScalarTy.Isize; + -- TODO: not too sure about that + simp only [*, Scalar.max, Scalar.min, Scalar.cMin, Scalar.cMax] at *; + int_tac) + +example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by + scalar_tac + +example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by + scalar_tac + +end Arith diff --git a/backends/lean/Base/Arith/Base.lean b/backends/lean/Base/Arith/Base.lean new file mode 100644 index 00000000..ddd2dc24 --- /dev/null +++ b/backends/lean/Base/Arith/Base.lean @@ -0,0 +1,10 @@ +import Lean + +namespace Arith + +open Lean Elab Term Meta + +-- We can't define and use trace classes in the same file +initialize registerTraceClass `Arith + +end Arith -- cgit v1.2.3