From 3c092169efcbc36a9b435c68c590b36f69204f94 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Dec 2023 12:38:55 +0100 Subject: Update the progress tactic to use discrimination trees --- backends/lean/Base/Progress/Progress.lean | 93 +++++++++++++++++-------------- 1 file changed, 51 insertions(+), 42 deletions(-) (limited to 'backends/lean/Base/Progress/Progress.lean') diff --git a/backends/lean/Base/Progress/Progress.lean b/backends/lean/Base/Progress/Progress.lean index ba63f09d..93b7d7d5 100644 --- a/backends/lean/Base/Progress/Progress.lean +++ b/backends/lean/Base/Progress/Progress.lean @@ -204,11 +204,11 @@ def getFirstArg (args : Array Expr) : Option Expr := do if args.size = 0 then none else some (args.get! 0) -/- Helper: try to lookup a theorem and apply it, or continue with another tactic - if it fails -/ +/- Helper: try to lookup a theorem and apply it. + Return true if it succeeded. -/ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : Bool) (asmTac : TacticM Unit) (fExpr : Expr) - (kind : String) (th : Option TheoremOrLocal) (x : TacticM Unit) : TacticM Unit := do + (kind : String) (th : Option TheoremOrLocal) : TacticM Bool := do let res ← do match th with | none => @@ -223,9 +223,9 @@ def tryLookupApply (keep : Option Name) (ids : Array (Option Name)) (splitPost : pure (some res) catch _ => none match res with - | some .Ok => return () + | some .Ok => return true | some (.Error msg) => throwError msg - | none => x + | none => return false -- The array of ids are identifiers to use when introducing fresh variables def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrLocal) @@ -236,11 +236,19 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL let goalTy ← mgoal.getType trace[Progress] "goal: {goalTy}" -- Dive into the goal to lookup the theorem - let (fExpr, fName, args) ← do - withPSpec goalTy fun desc => - -- TODO: check that no quantified variables in the arguments - pure (desc.fExpr, desc.fName, desc.args) - trace[Progress] "Function: {fName}" + -- Remark: if we don't isolate the call to `withPSpec` to immediately "close" + -- the terms immediately, we may end up with the error: + -- "(kernel) declaration has free variables" + -- I'm not sure I understand why. + -- TODO: we should also check that no quantified variable appears in fExpr. + -- If such variables appear, we should just fail because the goal doesn't + -- have the proper shape. + let fExpr ← do + let isGoal := true + withPSpec false isGoal goalTy fun desc => do + let fExpr := desc.fArgsExpr + trace[Progress] "Expression to match: {fExpr}" + pure fExpr -- If the user provided a theorem/assumption: use it. -- Otherwise, lookup one. match withTh with @@ -258,36 +266,24 @@ def progressAsmsOrLookupTheorem (keep : Option Name) (withTh : Option TheoremOrL match res with | .Ok => return () | .Error msg => throwError msg - -- It failed: try to lookup a theorem - -- TODO: use a list of theorems, and try them one by one? - trace[Progress] "No assumption succeeded: trying to lookup a theorem" - let pspec ← do - let thName ← pspecAttr.find? fName - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec do - -- It failed: try to lookup a *class* expr spec theorem (those are more - -- specific than class spec theorems) - trace[Progress] "Failed using a pspec theorem: trying to lookup a pspec class expr theorem" - let pspecClassExpr ← do - match getFirstArg args with - | none => pure none - | some arg => do - trace[Progress] "Using: f:{fName}, arg: {arg}" - let thName ← pspecClassExprAttr.find? fName arg - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec class expr theorem" pspecClassExpr do - -- It failed: try to lookup a *class* spec theorem - trace[Progress] "Failed using a pspec class expr theorem: trying to lookup a pspec class theorem" - let pspecClass ← do - match ← getFirstArgAppName args with - | none => pure none - | some argName => do - trace[Progress] "Using: f: {fName}, arg: {argName}" - let thName ← pspecClassAttr.find? fName argName - pure (thName.map fun th => .Theorem th) - tryLookupApply keep ids splitPost asmTac fExpr "pspec class theorem" pspecClass do - trace[Progress] "Failed using a pspec class theorem: trying to use a recursive assumption" - -- Try a recursive call - we try the assumptions of kind "auxDecl" + -- It failed: lookup the pspec theorems which match the expression + trace[Progress] "No assumption succeeded: trying to lookup a pspec theorem" + let pspecs : Array TheoremOrLocal ← do + let thNames ← pspecAttr.find? fExpr + -- TODO: because of reduction, there may be several valid theorems (for + -- instance for the scalars). We need to sort them from most specific to + -- least specific. For now, we assume the most specific theorems are at + -- the end. + let thNames := thNames.reverse + trace[Progress] "Looked up pspec theorems: {thNames}" + pure (thNames.map fun th => TheoremOrLocal.Theorem th) + -- Try the theorems one by one + for pspec in pspecs do + if ← tryLookupApply keep ids splitPost asmTac fExpr "pspec theorem" pspec then return () + else pure () + -- It failed: try to use the recursive assumptions + trace[Progress] "Failed using a pspec theorem: trying to use a recursive assumption" + -- We try to apply the assumptions of kind "auxDecl" let ctx ← Lean.MonadLCtx.getLCtx let decls ← ctx.getAllDecls let decls := decls.filter (λ decl => match decl.kind with @@ -381,8 +377,6 @@ namespace Test -- The following commands display the databases of theorems -- #eval showStoredPSpec - -- #eval showStoredPSpecClass - -- #eval showStoredPSpecExprClass open alloc.vec example {ty} {x y : Scalar ty} @@ -392,6 +386,8 @@ namespace Test progress keep _ as ⟨ z, h1 .. ⟩ simp [*, h1] + set_option trace.Progress false + example {ty} {x y : Scalar ty} (hmin : Scalar.min ty ≤ x.val + y.val) (hmax : x.val + y.val ≤ Scalar.max ty) : @@ -399,6 +395,19 @@ namespace Test progress keep h with Scalar.add_spec as ⟨ z ⟩ simp [*, h] + example {x y : U32} + (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by + -- This spec theorem is suboptimal, but it is good to check that it works + progress with Scalar.add_spec as ⟨ z, h1 .. ⟩ + simp [*, h1] + + example {x y : U32} + (hmax : x.val + y.val ≤ U32.max) : + ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by + progress with U32.add_spec as ⟨ z, h1 .. ⟩ + simp [*, h1] + example {x y : U32} (hmax : x.val + y.val ≤ U32.max) : ∃ z, x + y = ret z ∧ z.val = x.val + y.val := by -- cgit v1.2.3