From 9f0e4605e1c8816dbf5ed3e9e893b25e9a2be4a3 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 26 Jan 2024 00:17:59 +0100 Subject: Improve the Lean backend --- backends/lean/Base/Arith/Scalar.lean | 11 +++++++++-- backends/lean/Base/Primitives/Base.lean | 7 +++++++ backends/lean/Base/Primitives/Scalar.lean | 11 +++++++++++ backends/lean/Base/Progress/Base.lean | 20 ++++++++++++-------- backends/lean/Base/Progress/Progress.lean | 13 +++++++++++++ 5 files changed, 52 insertions(+), 10 deletions(-) (limited to 'backends/lean') diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index 2342cce6..43fd2766 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -17,13 +17,20 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) -- Reveal the concrete bounds, simplify calls to [ofInt] - Utils.simpAt true [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, + Utils.simpAt true + -- Unfoldings + [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max, ``Usize.min - ] [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val] [] .wildcard + ] + -- Simp lemmas + [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val, + ``Scalar.lt_equiv, ``Scalar.le_equiv, ``Scalar.eq_equiv] + -- Hypotheses + [] .wildcard elab "scalar_tac_preprocess" : tactic => diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index 3d70c84a..9dbaf133 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -116,6 +116,13 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } := @[simp] theorem bind_tc_div (f : α → Result β) : (do let y ← div; f y) = div := by simp [Bind.bind, bind] +@[simp] theorem bind_assoc_eq {a b c : Type u} + (e : Result a) (g : a → Result b) (h : b → Result c) : + (Bind.bind (Bind.bind e g) h) = + (Bind.bind e (λ x => Bind.bind (g x) h)) := by + simp [Bind.bind] + cases e <;> simp + ---------- -- MISC -- ---------- diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index a8eda6d5..2c34774b 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1038,6 +1038,17 @@ instance {ty} : LT (Scalar ty) where instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val +theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) : + x < y ↔ x.val < y.val := by simp [LT.lt] + +theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) : + x ≤ y ↔ x.val ≤ y.val := by simp [LE.le] + +-- Not marking this one with @[simp] on purpose +theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) : + x = y ↔ x.val = y.val := by + cases x; cases y; simp_all + instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt .. instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe .. diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 0ad16ab6..935af3f5 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -82,17 +82,21 @@ section Methods -- Destruct the equality let (mExpr, ret) ← destEq th.consumeMData trace[Progress] "After splitting the equality:\n- lhs: {th}\n- rhs: {ret}" - -- Destruct the monadic application to dive into the bind, if necessary (this - -- is for when we use `withPSpec` inside of the `progress` tactic), and - -- destruct the application to get the function name - mExpr.consumeMData.withApp fun mf margs => do - trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}" - let (fArgsExpr, f, args) ← do + -- Recursively destruct the monadic application to dive into the binds, + -- if necessary (this is for when we use `withPSpec` inside of the `progress` tactic), + -- and destruct the application to get the function name + let rec strip_monad mExpr := do + mExpr.consumeMData.withApp fun mf margs => do + trace[Progress] "After stripping the arguments of the monad expression:\n- mf: {mf}\n- margs: {margs}" if mf.isConst ∧ mf.constName = ``Bind.bind then do -- Dive into the bind let fExpr := (margs.get! 4).consumeMData - fExpr.withApp fun f args => pure (fExpr, f, args) - else pure (mExpr, mf, margs) + -- Recursve + strip_monad fExpr + else + -- No bind + pure (mExpr, mf, margs) + let (fArgsExpr, f, args) ← strip_monad mExpr trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}" if ¬ f.isConst then throwError "Not a constant: {f}" -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB) diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index a6a4e82a..0fb276aa 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -421,6 +421,19 @@ namespace Test progress simp [*] + /- Checking that progress can handle nested blocks -/ + example {α : Type} (v: Vec α) (i: Usize) (x : α) + (hbounds : i.val < v.length) : + ∃ nv, + (do + (do + let _ ← v.update_usize i x + .ret ()) + .ret ()) = ret nv + := by + progress + simp [*] + end Test end Progress -- cgit v1.2.3 From c709eadb14e2ecd21c9c4a6a9def39334f27552b Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sat, 27 Jan 2024 21:26:38 +0100 Subject: Fix a minor issue with the progress tactic --- backends/lean/Base/Progress/Base.lean | 16 ++- backends/lean/Base/Progress/Progress.lean | 178 +++++++++++++++++------------- 2 files changed, 112 insertions(+), 82 deletions(-) (limited to 'backends/lean') diff --git a/backends/lean/Base/Progress/Base.lean b/backends/lean/Base/Progress/Base.lean index 935af3f5..a64212a5 100644 --- a/backends/lean/Base/Progress/Base.lean +++ b/backends/lean/Base/Progress/Base.lean @@ -22,8 +22,9 @@ structure PSpecDesc where evars : Array Expr -- The function applied to its arguments fArgsExpr : Expr - -- The function - fName : Name + -- ⊤ if the function is a constant (must be if we are registering a theorem, + -- but is not necessarily the case if we are looking at a goal) + fIsConst : Bool -- The function arguments fLevels : List Level args : Array Expr @@ -98,7 +99,12 @@ section Methods pure (mExpr, mf, margs) let (fArgsExpr, f, args) ← strip_monad mExpr trace[Progress] "After stripping the arguments of the function call:\n- f: {f}\n- args: {args}" - if ¬ f.isConst then throwError "Not a constant: {f}" + let fLevels ← do + -- If we are registering a theorem, then the function must be a constant + if ¬ f.isConst then + if isGoal then pure [] + else throwError "Not a constant: {f}" + else pure f.constLevels! -- *Sanity check* (activated if we are analyzing a theorem to register it in a DB) -- Check if some existentially quantified variables let _ := do @@ -117,8 +123,8 @@ section Methods fvars := fvars evars := evars fArgsExpr - fName := f.constName! - fLevels := f.constLevels! + fIsConst := f.isConst + fLevels args := args ret := ret post := post diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index 0fb276aa..dc30c441 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -106,7 +106,7 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) withMainContext do -- The context changed - TODO: remove once addDeclTac is updated let ngoal ← getMainGoal trace[Progress] "current goal: {ngoal}" - trace[Progress] "current goal: {← ngoal.isAssigned}" + trace[Progress] "current goal is assigned: {← ngoal.isAssigned}" -- The assumption should be of the shape: -- `∃ x1 ... xn, f args = ... ∧ ...` -- We introduce the existentially quantified variables and split the top-most @@ -131,50 +131,59 @@ def progressWith (fExpr : Expr) (th : TheoremOrLocal) -- then continue splitting the post-condition splitEqAndPost fun hEq hPost ids => do trace[Progress] "eq and post:\n{hEq} : {← inferType hEq}\n{hPost}" - tryTac ( - simpAt true [] - [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] - [hEq.fvarId!] (.targets #[] true)) - -- TODO: remove this (some types get unfolded too much: we "fold" them back) - tryTac (simpAt true [] scalar_eqs [] .wildcard_dep) - -- Clear the equality, unless the user requests not to do so - let mgoal ← do - if keep.isSome then getMainGoal - else do - let mgoal ← getMainGoal - mgoal.tryClearMany #[hEq.fvarId!] - setGoals (mgoal :: (← getUnsolvedGoals)) - trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}" - -- Continue splitting following the post following the user's instructions - match hPost with - | none => - -- Sanity check - if ¬ ids.isEmpty then - return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split") - else return .Ok - | some hPost => do - let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do - match ids0 with - | [] => - /- We used all the user provided ids. - Split the remaining conjunctions by using fresh ids if the user - instructed to fully split the post-condition, otherwise stop -/ - if splitPost then - splitFullConjTac true hPost (λ _ => pure .Ok) - else pure .Ok - | nid :: ids => do - trace[Progress] "Splitting post: {← inferType hPost}" - -- Split - let nid ← do - match nid with - | none => mkFreshAnonPropUserName - | some nid => pure nid - trace[Progress] "\n- prevId: {prevId}\n- nid: {nid}\n- remaining ids: {ids}" - if ← isConj (← inferType hPost) then - splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids) - else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition") - let curPostId := (← hPost.fvarId!.getDecl).userName - splitPostWithIds curPostId hPost ids + trace[Progress] "current goal: {← getMainGoal}" + Tactic.focus do + let _ ← + tryTac + (simpAt true [] + [``Primitives.bind_tc_ret, ``Primitives.bind_tc_fail, ``Primitives.bind_tc_div] + [hEq.fvarId!] (.targets #[] true)) + -- It may happen that at this point the goal is already solved (though this is rare) + -- TODO: not sure this is the best way of checking it + if (← getUnsolvedGoals) == [] then pure .Ok + else + trace[Progress] "goal after applying the eq and simplifying the binds: {← getMainGoal}" + -- TODO: remove this (some types get unfolded too much: we "fold" them back) + let _ ← tryTac (simpAt true [] scalar_eqs [] .wildcard_dep) + trace[Progress] "goal after folding back scalar types: {← getMainGoal}" + -- Clear the equality, unless the user requests not to do so + let mgoal ← do + if keep.isSome then getMainGoal + else do + let mgoal ← getMainGoal + mgoal.tryClearMany #[hEq.fvarId!] + setGoals (mgoal :: (← getUnsolvedGoals)) + trace[Progress] "Goal after splitting eq and post and simplifying the target: {mgoal}" + -- Continue splitting following the post following the user's instructions + match hPost with + | none => + -- Sanity check + if ¬ ids.isEmpty then + return (.Error m!"Too many ids provided ({ids}): there is no postcondition to split") + else return .Ok + | some hPost => do + let rec splitPostWithIds (prevId : Name) (hPost : Expr) (ids0 : List (Option Name)) : TacticM ProgressError := do + match ids0 with + | [] => + /- We used all the user provided ids. + Split the remaining conjunctions by using fresh ids if the user + instructed to fully split the post-condition, otherwise stop -/ + if splitPost then + splitFullConjTac true hPost (λ _ => pure .Ok) + else pure .Ok + | nid :: ids => do + trace[Progress] "Splitting post: {← inferType hPost}" + -- Split + let nid ← do + match nid with + | none => mkFreshAnonPropUserName + | some nid => pure nid + trace[Progress] "\n- prevId: {prevId}\n- nid: {nid}\n- remaining ids: {ids}" + if ← isConj (← inferType hPost) then + splitConjTac hPost (some (prevId, nid)) (λ _ nhPost => splitPostWithIds nid nhPost ids) + else return (.Error m!"Too many ids provided ({ids0}) not enough conjuncts to split in the postcondition") + let curPostId := (← hPost.fvarId!.getDecl).userName + splitPostWithIds curPostId hPost ids match res with | .Error _ => return res -- Can we get there? We're using "return" | .Ok => @@ -223,9 +232,9 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : pure (some res) catch _ => none match res with - | some .Ok => return true + | some .Ok => pure true | some (.Error msg) => throwError msg - | none => return false + | none => pure false -- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) @@ -266,36 +275,42 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL match res with | .Ok => return () | .Error msg => throwError msg - -- It failed: lookup the pspec theorems which match the expression - trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem" - let pspecs : Array TheoremOrLocal ← do - let thNames ← pspecAttr.find? fExpr - -- TODO: because of reduction, there may be several valid theorems (for - -- instance for the scalars). We need to sort them from most specific to - -- least specific. For now, we assume the most specific theorems are at - -- the end. - let thNames := thNames.reverse - trace[Progress] "Looked up pspec theorems: {thNames}" - pure (thNames.map fun th => TheoremOrLocal.Theorem th) - -- Try the theorems one by one - for pspec in pspecs do - if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return () - else pure () - -- It failed: try to use the recursive assumptions - trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" - -- We try to apply the assumptions of kind "auxDecl" - let ctx ← Lean.MonadLCtx.getLCtx - let decls ← ctx.getAllDecls - let decls := decls.filter (λ decl => match decl.kind with - | .default | .implDetail => false | .auxDecl => true) - for decl in decls.reverse do - trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" - let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue - match res with - | .Ok => return () - | .Error msg => throwError msg - -- Nothing worked: failed - throwError "Progress failed" + -- It failed: lookup the pspec theorems which match the expression *only + -- if the function is a constant* + let fIsConst ← do + fExpr.consumeMData.withApp fun mf _ => do + pure mf.isConst + if ¬ fIsConst then throwError "Progress failed" + else do + trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem" + let pspecs : Array TheoremOrLocal ← do + let thNames ← pspecAttr.find? fExpr + -- TODO: because of reduction, there may be several valid theorems (for + -- instance for the scalars). We need to sort them from most specific to + -- least specific. For now, we assume the most specific theorems are at + -- the end. + let thNames := thNames.reverse + trace[Progress] "Looked up pspec theorems: {thNames}" + pure (thNames.map fun th => TheoremOrLocal.Theorem th) + -- Try the theorems one by one + for pspec in pspecs do + if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return () + else pure () + -- It failed: try to use the recursive assumptions + trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" + -- We try to apply the assumptions of kind "auxDecl" + let ctx ← Lean.MonadLCtx.getLCtx + let decls ← ctx.getAllDecls + let decls := decls.filter (λ decl => match decl.kind with + | .default | .implDetail => false | .auxDecl => true) + for decl in decls.reverse do + trace[Progress] "Trying recursive assumption: {decl.userName} : {decl.type}" + let res ← do try progressWith fExpr (.Local decl) keep ids splitPost asmTac catch _ => continue + match res with + | .Ok => return () + | .Error msg => throwError msg + -- Nothing worked: failed + throwError "Progress failed" syntax progressArgs := ("keep" (ident <|> "_"))? ("with" ident)? ("as" " ⟨ " (ident <|> "_"),* " .."? " ⟩")? @@ -434,6 +449,15 @@ namespace Test progress simp [*] + /- Checking the case where simplifying the goal after instantiating the + pspec theorem the goal actually solves it, and where the function is + not a constant. We also test the case where the function under scrutinee + is not a constant. -/ + example {x : U32} + (f : U32 → Result Unit) (h : ∀ x, f x = .ret ()) : + f x = ret () := by + progress + end Test end Progress -- cgit v1.2.3 From d8247d99520738188bbd160be7de03550f8156ce Mon Sep 17 00:00:00 2001 From: Son Ho Date: Sat, 27 Jan 2024 21:31:52 +0100 Subject: Add some lemmas to the Lean backend --- backends/lean/Base/Primitives/Scalar.lean | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) (limited to 'backends/lean') diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 2c34774b..fe8dc8ec 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1038,16 +1038,26 @@ instance {ty} : LT (Scalar ty) where instance {ty} : LE (Scalar ty) where le a b := LE.le a.val b.val +-- Not marking this one with @[simp] on purpose +theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) : + x = y ↔ x.val = y.val := by + cases x; cases y; simp_all + +-- This is sometimes useful when rewriting the goal with the local assumptions +@[simp] theorem Scalar.eq_imp {ty : ScalarTy} (x y : Scalar ty) : + x = y → x.val = y.val := (eq_equiv x y).mp + theorem Scalar.lt_equiv {ty : ScalarTy} (x y : Scalar ty) : x < y ↔ x.val < y.val := by simp [LT.lt] +@[simp] theorem Scalar.lt_imp {ty : ScalarTy} (x y : Scalar ty) : + x < y → x.val < y.val := (lt_equiv x y).mp + theorem Scalar.le_equiv {ty : ScalarTy} (x y : Scalar ty) : x ≤ y ↔ x.val ≤ y.val := by simp [LE.le] --- Not marking this one with @[simp] on purpose -theorem Scalar.eq_equiv {ty : ScalarTy} (x y : Scalar ty) : - x = y ↔ x.val = y.val := by - cases x; cases y; simp_all +@[simp] theorem Scalar.le_imp {ty : ScalarTy} (x y : Scalar ty) : + x ≤ y → x.val ≤ y.val := (le_equiv x y).mp instance Scalar.decLt {ty} (a b : Scalar ty) : Decidable (LT.lt a b) := Int.decLt .. instance Scalar.decLe {ty} (a b : Scalar ty) : Decidable (LE.le a b) := Int.decLe .. -- cgit v1.2.3