From 1fc263ec0f527698b2f4d734d9757c9f723d0bee Mon Sep 17 00:00:00 2001 From: Son Ho Date: Mon, 18 Sep 2023 12:25:40 +0200 Subject: Improve scalar_tac --- backends/lean/Base/Arith/Int.lean | 8 ++-- backends/lean/Base/Arith/Scalar.lean | 17 +++++++- backends/lean/Base/Primitives/Scalar.lean | 67 ++++++++----------------------- 3 files changed, 36 insertions(+), 56 deletions(-) (limited to 'backends/lean') diff --git a/backends/lean/Base/Arith/Int.lean b/backends/lean/Base/Arith/Int.lean index eb6701c2..3359ecdb 100644 --- a/backends/lean/Base/Arith/Int.lean +++ b/backends/lean/Base/Arith/Int.lean @@ -211,9 +211,11 @@ def intTacPreprocess (extraPreprocess : Tactic.TacticM Unit) : Tactic.TacticM U let _ ← introHasIntPropInstances -- Extra preprocessing, before we split on the disjunctions extraPreprocess - -- Split - let asms ← introInstances ``PropHasImp.concl lookupPropHasImp - splitOnAsms asms.toList + -- Split - note that the extra-preprocessing step might actually have + -- proven the goal (by doing simplifications for instance) + Tactic.allGoals do + let asms ← introInstances ``PropHasImp.concl lookupPropHasImp + splitOnAsms asms.toList elab "int_tac_preprocess" : tactic => intTacPreprocess (do pure ()) diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index db672489..47751c8a 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -16,14 +16,15 @@ def scalarTacExtraPreprocess : Tactic.TacticM Unit := do add (← mkAppM ``Scalar.cMin_bound #[.const ``ScalarTy.Isize []]) add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Usize []]) add (← mkAppM ``Scalar.cMax_bound #[.const ``ScalarTy.Isize []]) - -- Reveal the concrete bounds + -- Reveal the concrete bounds, simplify calls to [ofInt] Utils.simpAt [``Scalar.min, ``Scalar.max, ``Scalar.cMin, ``Scalar.cMax, ``I8.min, ``I16.min, ``I32.min, ``I64.min, ``I128.min, ``I8.max, ``I16.max, ``I32.max, ``I64.max, ``I128.max, ``U8.min, ``U16.min, ``U32.min, ``U64.min, ``U128.min, ``U8.max, ``U16.max, ``U32.max, ``U64.max, ``U128.max, ``Usize.min - ] [] [] .wildcard + ] [``Scalar.ofInt_val_eq, ``Scalar.neq_to_neq_val] [] .wildcard + elab "scalar_tac_preprocess" : tactic => intTacPreprocess scalarTacExtraPreprocess @@ -50,4 +51,16 @@ example (x y : U32) : x.val ≤ Scalar.max ScalarTy.U32 := by example (x : U32 × U32) : 0 ≤ x.fst.val := by scalar_tac +-- Checking that we properly handle [ofInt] +example : U32.ofInt 1 ≤ U32.max := by + scalar_tac + +example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) : + U32.ofInt x (by constructor <;> scalar_tac) ≤ U32.max := by + scalar_tac + +-- Not equal +example (x : U32) (h0 : ¬ x = U32.ofInt 0) : 0 < x.val := by + scalar_tac + end Arith diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index ffc969f3..aa1d452d 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -709,60 +709,22 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128 -- ofInt -- TODO: typeclass? -def Isize.ofInt := @Scalar.ofInt .Isize -def I8.ofInt := @Scalar.ofInt .I8 -def I16.ofInt := @Scalar.ofInt .I16 -def I32.ofInt := @Scalar.ofInt .I32 -def I64.ofInt := @Scalar.ofInt .I64 -def I128.ofInt := @Scalar.ofInt .I128 -def Usize.ofInt := @Scalar.ofInt .Usize -def U8.ofInt := @Scalar.ofInt .U8 -def U16.ofInt := @Scalar.ofInt .U16 -def U32.ofInt := @Scalar.ofInt .U32 -def U64.ofInt := @Scalar.ofInt .U64 -def U128.ofInt := @Scalar.ofInt .U128 - --- TODO: factor those lemmas out +@[reducible] def Isize.ofInt := @Scalar.ofInt .Isize +@[reducible] def I8.ofInt := @Scalar.ofInt .I8 +@[reducible] def I16.ofInt := @Scalar.ofInt .I16 +@[reducible] def I32.ofInt := @Scalar.ofInt .I32 +@[reducible] def I64.ofInt := @Scalar.ofInt .I64 +@[reducible] def I128.ofInt := @Scalar.ofInt .I128 +@[reducible] def Usize.ofInt := @Scalar.ofInt .Usize +@[reducible] def U8.ofInt := @Scalar.ofInt .U8 +@[reducible] def U16.ofInt := @Scalar.ofInt .U16 +@[reducible] def U32.ofInt := @Scalar.ofInt .U32 +@[reducible] def U64.ofInt := @Scalar.ofInt .U64 +@[reducible] def U128.ofInt := @Scalar.ofInt .U128 + @[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofInt x h).val = x := by simp [Scalar.ofInt, Scalar.ofIntCore] -@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - -@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofInt x h).val = x := by - apply Scalar.ofInt_val_eq h - - -- Comparisons instance {ty} : LT (Scalar ty) where lt a b := LT.lt a.val b.val @@ -790,6 +752,9 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) := instance (ty : ScalarTy) : CoeOut (Scalar ty) Int where coe := λ v => v.val +@[simp] theorem Scalar.neq_to_neq_val {ty} : ∀ {i j : Scalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by + intro i j; cases i; cases j; simp + -- -- We now define a type class that subsumes the various machine integer types, so -- -- as to write a concise definition for scalar_cast, rather than exhaustively -- -- enumerating all of the possible pairs. We remark that Rust has sane semantics -- cgit v1.2.3