diff options
Diffstat (limited to 'backends/lean/Base/Arith')
-rw-r--r-- | backends/lean/Base/Arith/Arith.lean | 122 |
1 files changed, 29 insertions, 93 deletions
diff --git a/backends/lean/Base/Arith/Arith.lean b/backends/lean/Base/Arith/Arith.lean index ff628cf3..3557d350 100644 --- a/backends/lean/Base/Arith/Arith.lean +++ b/backends/lean/Base/Arith/Arith.lean @@ -230,25 +230,20 @@ def introInstances (declToUnfold : Name) (lookup : Expr → MetaM (Option Expr)) let type ← inferType e let name ← mkFreshUserName `h -- Add a declaration - let nval ← Utils.addDecl name e type (asLet := false) + let nval ← Utils.addDeclTac name e type (asLet := false) -- Simplify to unfold the declaration to unfold (i.e., the projector) - let simpTheorems ← Tactic.simpOnlyBuiltins.foldlM (·.addConst ·) ({} : SimpTheorems) - -- Add the equational theorem for the decl to unfold - let simpTheorems ← simpTheorems.addDeclToUnfold declToUnfold - let congrTheorems ← getSimpCongrTheorems - let ctx : Simp.Context := { simpTheorems := #[simpTheorems], congrTheorems } - -- Where to apply the simplifier - let loc := Tactic.Location.targets #[mkIdent name] false - -- Apply the simplifier - let _ ← Tactic.simpLocation ctx (discharge? := .none) loc + Utils.simpAt [declToUnfold] [] [] (Tactic.Location.targets #[mkIdent name] false) -- Return the new value pure nval +def introHasPropInstances : Tactic.TacticM (Array Expr) := do + trace[Arith] "Introducing the HasProp instances" + introInstances ``HasProp.prop_ty lookupHasProp + -- Lookup the instances of `HasProp for all the sub-expressions in the context, -- and introduce the corresponding assumptions elab "intro_has_prop_instances" : tactic => do - trace[Arith] "Introducing the HasProp instances" - let _ ← introInstances ``HasProp.prop_ty lookupHasProp + let _ ← introHasPropInstances example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by intro_has_prop_instances @@ -258,74 +253,6 @@ example {a: Type} (v : Vec a) : v.val.length ≤ Scalar.max ScalarTy.Usize := by intro_has_prop_instances simp_all [Scalar.max, Scalar.min] --- Tactic to split on a disjunction. --- The expression `h` should be an fvar. -def splitDisj (h : Expr) (kleft kright : Tactic.TacticM Unit) : Tactic.TacticM Unit := do - trace[Arith] "assumption on which to split: {h}" - -- Retrieve the main goal - Tactic.withMainContext do - let goalType ← Tactic.getMainTarget - let hDecl := (← getLCtx).get! h.fvarId! - let hName := hDecl.userName - -- Case disjunction - let hTy ← inferType h - hTy.withApp fun f xs => do - trace[Arith] "as app: {f} {xs}" - -- Sanity check - if ¬ (f.isConstOf ``Or ∧ xs.size = 2) then throwError "Invalid argument to splitDisj" - let a := xs.get! 0 - let b := xs.get! 1 - -- Introduce the new goals - -- Returns: - -- - the match branch - -- - a fresh new mvar id - let mkGoal (hTy : Expr) (nGoalName : String) : MetaM (Expr × MVarId) := do - -- Introduce a variable for the assumption (`a` or `b`). Note that we reuse - -- the name of the assumption we split. - withLocalDeclD hName hTy fun var => do - -- The new goal - let mgoal ← mkFreshExprSyntheticOpaqueMVar goalType (tag := Name.mkSimple nGoalName) - -- Clear the assumption that we split - let mgoal ← mgoal.mvarId!.tryClearMany #[h.fvarId!] - -- The branch expression - let branch ← mkLambdaFVars #[var] (mkMVar mgoal) - pure (branch, mgoal) - let (inl, mleft) ← mkGoal a "left" - let (inr, mright) ← mkGoal b "right" - trace[Arith] "left: {inl}: {mleft}" - trace[Arith] "right: {inr}: {mright}" - -- Create the match expression - withLocalDeclD (← mkFreshUserName `h) hTy fun hVar => do - let motive ← mkLambdaFVars #[hVar] goalType - let casesExpr ← mkAppOptM ``Or.casesOn #[a, b, motive, h, inl, inr] - let mgoal ← Tactic.getMainGoal - trace[Arith] "goals: {← Tactic.getUnsolvedGoals}" - trace[Arith] "main goal: {mgoal}" - mgoal.assign casesExpr - let goals ← Tactic.getUnsolvedGoals - -- Focus on the left - Tactic.setGoals [mleft] - kleft - let leftGoals ← Tactic.getUnsolvedGoals - -- Focus on the right - Tactic.setGoals [mright] - kright - let rightGoals ← Tactic.getUnsolvedGoals - -- Put all the goals back - Tactic.setGoals (leftGoals ++ rightGoals ++ goals) - trace[Arith] "new goals: {← Tactic.getUnsolvedGoals}" - -elab "split_disj " n:ident : tactic => do - Tactic.withMainContext do - let decl ← Lean.Meta.getLocalDeclFromUserName n.getId - let fvar := mkFVar decl.fvarId - splitDisj fvar (fun _ => pure ()) (fun _ => pure ()) - -example (x y : Int) (h0 : x ≤ y ∨ x ≥ y) : x ≤ y ∨ x ≥ y := by - split_disj h0 - . left; assumption - . right; assumption - -- Lookup the instances of `PropHasImp for all the sub-expressions in the context, -- and introduce the corresponding assumptions elab "intro_prop_has_imp_instances" : tactic => do @@ -357,7 +284,7 @@ def intTacPreprocess : Tactic.TacticM Unit := do | [] => pure () | asm :: asms => let k := splitOnAsms asms - splitDisj asm k k + Utils.splitDisjTac asm k k -- Introduce let asms ← introInstances ``PropHasImp.concl lookupPropHasImp -- Split @@ -403,18 +330,27 @@ example (x y : Int) (h0: 0 ≤ x) (h1: x ≠ 0) (h2 : 0 ≤ y) (h3 : y ≠ 0) : int_tac -- A tactic to solve linear arithmetic goals in the presence of scalars -syntax "scalar_tac" : tactic -macro_rules - | `(tactic| scalar_tac) => - `(tactic| - intro_has_prop_instances; - have := Scalar.cMin_bound ScalarTy.Usize; - have := Scalar.cMin_bound ScalarTy.Isize; - have := Scalar.cMax_bound ScalarTy.Usize; - have := Scalar.cMax_bound ScalarTy.Isize; - -- TODO: not too sure about that - simp only [*, Scalar.max, Scalar.min, Scalar.cMin, Scalar.cMax] at *; - int_tac) +def scalarTac : Tactic.TacticM Unit := do + Tactic.withMainContext do + -- Introduce the scalar bounds + let _ ← introHasPropInstances + Tactic.allGoals do + -- Inroduce the bounds for the isize/usize types + let add (e : Expr) : Tactic.TacticM Unit := do + let ty ← inferType e + let _ ← Utils.addDeclTac (← mkFreshUserName `h) e ty (asLet := false) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) + add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) + -- Reveal the concrete bounds - TODO: not too sure about that. + -- Maybe we should reveal the "concrete" bounds (after normalization) + Utils.simpAt [``Scalar.max, ``Scalar.min, ``Scalar.cMin, ``Scalar.cMax] [] [] .wildcard + -- Apply the integer tactic + intTac + +elab "scalar_tac" : tactic => + scalarTac example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by scalar_tac |