summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon HO2024-06-13 15:19:11 +0200
committerGitHub2024-06-13 15:19:11 +0200
commit234fa36da87b672397f96098bcf832d869f2cfbb (patch)
treeafb669e46c958fc516b8441278a006582d7f2400
parent40e79f1fd64a6535334b1af19a817b27a9a0296c (diff)
parent87d088fa9e4493f32ae3f8d447ff1ff6d44e6396 (diff)
Merge pull request #242 from AeneasVerif/son/scalars2
Update the scalar notations for the Lean backend
-rw-r--r--backends/lean/Base/Primitives.lean1
-rw-r--r--backends/lean/Base/Primitives/ArraySlice.lean2
-rw-r--r--backends/lean/Base/Primitives/CoreConvertNum.lean1
-rw-r--r--backends/lean/Base/Primitives/Scalar.lean123
-rw-r--r--backends/lean/Base/Primitives/ScalarNotations.lean109
-rw-r--r--backends/lean/Base/Primitives/Vec.lean4
-rw-r--r--compiler/Extract.ml6
-rw-r--r--compiler/ExtractTypes.ml14
-rw-r--r--tests/lean/Matches.lean4
9 files changed, 150 insertions, 114 deletions
diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean
index f80c2004..93617049 100644
--- a/backends/lean/Base/Primitives.lean
+++ b/backends/lean/Base/Primitives.lean
@@ -1,6 +1,7 @@
import Base.Primitives.Base
import Base.Tuples
import Base.Primitives.Scalar
+import Base.Primitives.ScalarNotations
import Base.Primitives.ArraySlice
import Base.Primitives.Vec
import Base.Primitives.Alloc
diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean
index 157f9df1..17ee626f 100644
--- a/backends/lean/Base/Primitives/ArraySlice.lean
+++ b/backends/lean/Base/Primitives/ArraySlice.lean
@@ -126,7 +126,7 @@ abbrev Slice.v {α : Type u} (v : Slice α) : List α := v.val
example {a: Type u} (v : Slice a) : v.length ≤ Scalar.max ScalarTy.Usize := by
scalar_tac
-def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp; decide ⟩
+def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩
-- TODO: very annoying that the α is an explicit parameter
def Slice.len (α : Type u) (v : Slice α) : Usize :=
diff --git a/backends/lean/Base/Primitives/CoreConvertNum.lean b/backends/lean/Base/Primitives/CoreConvertNum.lean
index eb456a96..b53d11db 100644
--- a/backends/lean/Base/Primitives/CoreConvertNum.lean
+++ b/backends/lean/Base/Primitives/CoreConvertNum.lean
@@ -4,6 +4,7 @@ import Init.Data.List.Basic
import Mathlib.Tactic.Linarith
import Base.IList
import Base.Primitives.Scalar
+import Base.Primitives.ScalarNotations
import Base.Primitives.ArraySlice
import Base.Arith
import Base.Progress.Base
diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean
index 157ade2c..f4264b9b 100644
--- a/backends/lean/Base/Primitives/Scalar.lean
+++ b/backends/lean/Base/Primitives/Scalar.lean
@@ -312,38 +312,13 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) :
λ h => by
apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith
-/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns.
- This is particularly useful once we introduce notations like `#u32` (which
- desugards to `Scalar.ofIntCore`) as it allows to write expressions like this:
- Example:
- ```
- match x with
- | 0#u32 => ...
- | 1#u32 => ...
- | ...
- ```
- -/
-@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
+def Scalar.ofIntCore {ty : ScalarTy} (x : Int)
(h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty :=
{ val := x, hmin := h.left, hmax := h.right }
--- The definitions below are used later to introduce nice syntax for constants,
--- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284
-
-class InBounds (ty : ScalarTy) (x : Int) :=
- hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty
-
--- This trick to trigger reduction for decidable propositions comes from
--- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807
-class Decide (p : Prop) [Decidable p] : Prop where
- isTrue : p
-instance : @Decide p (.isTrue h) := @Decide.mk p (_) h
-
-instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where
- hInBounds := Decide.isTrue
-
-@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty :=
- Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds)
+@[reducible] def Scalar.ofInt {ty : ScalarTy} (x : Int)
+ (hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty := by decide) : Scalar ty :=
+ Scalar.ofIntCore x (Scalar.bound_suffices ty x hInBounds)
@[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop :=
Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty
@@ -412,9 +387,8 @@ theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) :
simp [tryMk, ofOption, tryMkOpt]
split_ifs <;> simp
-instance (ty: ScalarTy) : InBounds ty 0 where
- hInBounds := by
- induction ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
+@[simp] theorem zero_in_cbounds {ty : ScalarTy} : Scalar.cMin ty ≤ 0 ∧ 0 ≤ Scalar.cMax ty := by
+ cases ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide
def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val)
@@ -1268,73 +1242,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128
-- ofInt
-- TODO: typeclass?
-@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize
-@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8
-@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16
-@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32
-@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64
-@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128
-@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize
-@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8
-@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16
-@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32
-@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64
-@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128
-
-postfix:max "#isize" => Isize.ofInt
-postfix:max "#i8" => I8.ofInt
-postfix:max "#i16" => I16.ofInt
-postfix:max "#i32" => I32.ofInt
-postfix:max "#i64" => I64.ofInt
-postfix:max "#i128" => I128.ofInt
-postfix:max "#usize" => Usize.ofInt
-postfix:max "#u8" => U8.ofInt
-postfix:max "#u16" => U16.ofInt
-postfix:max "#u32" => U32.ofInt
-postfix:max "#u64" => U64.ofInt
-postfix:max "#u128" => U128.ofInt
-
-/- Testing the notations -/
-example := 0#u32
-example := 1#u32
-example := 1#i32
-example := 0#isize
-example := (-1)#isize
-example (x : U32) : Bool :=
- match x with
- | 0#u32 => true
- | _ => false
-
-example (x : U32) : Bool :=
- match x with
- | 1#u32 => true
- | _ => false
-
-example (x : I32) : Bool :=
- match x with
- | (-1)#i32 => true
- | _ => false
-
--- Notation for pattern matching
--- We make the precedence looser than the negation.
-notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
-
-example {ty} (x : Scalar ty) : ℤ :=
- match x with
- | v#scalar => v
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | 1#scalar => true
- | _ => false
-
-example {ty} (x : Scalar ty) : Bool :=
- match x with
- | -(1 : Int)#scalar => true
- | _ => false
-
--- Testing the notations
-example : Result Usize := 0#usize + 1#usize
+abbrev Isize.ofInt := @Scalar.ofInt .Isize
+abbrev I8.ofInt := @Scalar.ofInt .I8
+abbrev I16.ofInt := @Scalar.ofInt .I16
+abbrev I32.ofInt := @Scalar.ofInt .I32
+abbrev I64.ofInt := @Scalar.ofInt .I64
+abbrev I128.ofInt := @Scalar.ofInt .I128
+abbrev Usize.ofInt := @Scalar.ofInt .Usize
+abbrev U8.ofInt := @Scalar.ofInt .U8
+abbrev U16.ofInt := @Scalar.ofInt .U16
+abbrev U32.ofInt := @Scalar.ofInt .U32
+abbrev U64.ofInt := @Scalar.ofInt .U64
+abbrev U128.ofInt := @Scalar.ofInt .U128
-- TODO: factor those lemmas out
@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by
@@ -1464,18 +1383,18 @@ theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max (â†
-- Max theory
-- TODO: do the min theory later on.
-theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 ≤ x := by
+theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 (by simp) ≤ x := by
apply (Scalar.le_equiv _ _).2
convert x.hmin
cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min]
@[simp]
theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
- Max.max (Scalar.ofInt 0) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
+ Max.max (Scalar.ofInt 0 (by simp)) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x)
@[simp]
theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty):
- Max.max x (Scalar.ofInt 0) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
+ Max.max x (Scalar.ofInt 0 (by simp)) = x := max_eq_left (Scalar.zero_le_unsigned s.out x)
-- Leading zeros
def core.num.Usize.leading_zeros (x : Usize) : U32 := sorry
diff --git a/backends/lean/Base/Primitives/ScalarNotations.lean b/backends/lean/Base/Primitives/ScalarNotations.lean
new file mode 100644
index 00000000..3bc86a9c
--- /dev/null
+++ b/backends/lean/Base/Primitives/ScalarNotations.lean
@@ -0,0 +1,109 @@
+import Lean
+import Lean.Meta.Tactic.Simp
+import Mathlib.Tactic.Linarith
+import Base.Primitives.Scalar
+import Base.Arith
+
+namespace Primitives
+
+open Lean Meta Elab Term
+
+/- Something strange happens here: when we solve the goal with scalar_tac, it
+ sometimes leaves meta-variables in place, which then causes issues when
+ type-checking functions. For instance, it happens when we have const-generics
+ in the translation: the constants contain meta-variables, which are then
+ used in the types, which cause issues later. An example is given below:
+ -/
+macro:max x:term:max noWs "#isize" : term => `(Isize.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i8" : term => `(I8.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i16" : term => `(I16.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i32" : term => `(I32.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i64" : term => `(I64.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#i128" : term => `(I128.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#usize" : term => `(Usize.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u8" : term => `(U8.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u16" : term => `(U16.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u32" : term => `(U32.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u64" : term => `(U64.ofInt $x (by first | decide | scalar_tac))
+macro:max x:term:max noWs "#u128" : term => `(U128.ofInt $x (by first | decide | scalar_tac))
+
+-- Notation for pattern matching
+-- We make the precedence looser than the negation.
+notation:70 a:70 "#scalar" => Scalar.mk (a) _ _
+
+/- Testing the notations -/
+example := 0#u32
+example := 1#u32
+example := 1#i32
+example := 0#isize
+example := (-1)#isize
+
+example := 1#u32
+
+/-
+-- This doesn't work anymore
+example (x : U32) : Bool :=
+ match x with
+ | 0#u32 => true
+ | _ => false
+
+example (x : U32) : Bool :=
+ match x with
+ | 1#u32 => true
+ | _ => false
+
+example (x : I32) : Bool :=
+ match x with
+ | (-1)#i32 => true
+ | _ => false
+-/
+
+example (x : U32) : Bool :=
+ match x with
+ | 0#scalar => true
+ | _ => false
+
+example (x : U32) : Bool :=
+ match x with
+ | 1#scalar => true
+ | _ => false
+
+example (x : I32) : Bool :=
+ match x with
+ | (-1)#scalar => true
+ | _ => false
+
+example {ty} (x : Scalar ty) : ℤ :=
+ match x with
+ | v#scalar => v
+
+example {ty} (x : Scalar ty) : Bool :=
+ match x with
+ | 1#scalar => true
+ | _ => false
+
+example {ty} (x : Scalar ty) : Bool :=
+ match x with
+ | -(1 : Int)#scalar => true
+ | _ => false
+
+-- Testing the notations
+example : Result Usize := 0#usize + 1#usize
+
+-- More complex expressions
+example (x y : Int) (h : 0 ≤ x + y ∧ x + y ≤ 1000) : U32 := (x + y)#u32
+
+namespace Scalar.Examples
+
+ abbrev Array (a : Type) (len : U32) := { l : List a // l.length = len.val }
+
+ -- Checking the syntax
+ example : Array Int 0#u32 := ⟨ [], by simp ⟩
+
+ /- The example below fails if we don't use `decide` in the elaboration
+ of the scalar notation -/
+ example (a : Array (Array Int 32#u32) 32#u32) := a
+
+end Scalar.Examples
+
+end Primitives
diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean
index 5ed7b606..d144fcb8 100644
--- a/backends/lean/Base/Primitives/Vec.lean
+++ b/backends/lean/Base/Primitives/Vec.lean
@@ -34,7 +34,7 @@ abbrev Vec.v {α : Type u} (v : Vec α) : List α := v.val
example {a: Type u} (v : Vec a) : v.length ≤ Scalar.max ScalarTy.Usize := by
scalar_tac
-def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp; decide ⟩
+def Vec.new (α : Type u): Vec α := ⟨ [], by apply Scalar.cMax_suffices .Usize; simp ⟩
instance (α : Type u) : Inhabited (Vec α) := by
constructor
@@ -192,7 +192,7 @@ def alloc.slice.Slice.to_vec
def core.slice.Slice.reverse (T : Type) (s : Slice T) : Slice T :=
⟨ s.val.reverse, by sorry ⟩
-def alloc.vec.Vec.with_capacity (T : Type) (s : Usize) : alloc.vec.Vec T := Vec.new T
+def alloc.vec.Vec.with_capacity (T : Type) (_ : Usize) : alloc.vec.Vec T := Vec.new T
/- [alloc::vec::{(core::ops::deref::Deref for alloc::vec::Vec<T, A>)#9}::deref]:
Source: '/rustc/d59363ad0b6391b7fc5bbb02c9ccf9300eef3753/library/alloc/src/vec/mod.rs', lines 2624:4-2624:27
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index eab85054..4acf3f99 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -241,11 +241,12 @@ let rec extract_typed_pattern (span : Meta.span) (ctx : extraction_ctx)
(fmt : F.formatter) (is_let : bool) (inside : bool) ?(with_type = false)
(v : typed_pattern) : extraction_ctx =
if with_type then F.pp_print_string fmt "(";
+ let is_pattern = true in
let inside = inside && not with_type in
let ctx =
match v.value with
| PatConstant cv ->
- extract_literal span fmt inside cv;
+ extract_literal span fmt is_pattern inside cv;
ctx
| PatVar (v, _) ->
let vname = ctx_compute_var_basename span ctx v.basename v.ty in
@@ -307,6 +308,7 @@ let extract_texpression_errors (fmt : F.formatter) =
let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx)
(fmt : F.formatter) (inside : bool) (e : texpression) : unit =
+ let is_pattern = false in
match e.e with
| Var var_id ->
let var_name = ctx_get_var span var_id ctx in
@@ -314,7 +316,7 @@ let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx)
| CVar var_id ->
let var_name = ctx_get_const_generic_var span var_id ctx in
F.pp_print_string fmt var_name
- | Const cv -> extract_literal span fmt inside cv
+ | Const cv -> extract_literal span fmt is_pattern inside cv
| App _ ->
let app, args = destruct_apps e in
extract_App span ctx fmt inside app args
diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml
index 15e75da2..631db13e 100644
--- a/compiler/ExtractTypes.ml
+++ b/compiler/ExtractTypes.ml
@@ -11,12 +11,13 @@ include ExtractBase
Inputs:
- formatter
+ - [is_pattern]: if [true], it means we are generating a (match) pattern
- [inside]: if [true], the value should be wrapped in parentheses
if it is made of an application (ex.: [U32 3])
- the constant value
*)
-let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool)
- (cv : literal) : unit =
+let extract_literal (span : Meta.span) (fmt : F.formatter) (is_pattern : bool)
+ (inside : bool) (cv : literal) : unit =
match cv with
| VScalar sv -> (
match backend () with
@@ -39,8 +40,11 @@ let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool)
let iname = int_name sv.int_ty in
F.pp_print_string fmt ("%" ^ iname)
| Lean ->
- let iname = String.lowercase_ascii (int_name sv.int_ty) in
- F.pp_print_string fmt ("#" ^ iname)
+ (* We don't use the same notation for patterns and regular literals *)
+ if is_pattern then F.pp_print_string fmt "#scalar"
+ else
+ let iname = String.lowercase_ascii (int_name sv.int_ty) in
+ F.pp_print_string fmt ("#" ^ iname)
| HOL4 -> ()
| _ -> craise __FILE__ __LINE__ span "Unreachable");
if print_brackets then F.pp_print_string fmt ")")
@@ -409,7 +413,7 @@ let extract_const_generic (span : Meta.span) (ctx : extraction_ctx)
| CgGlobal id ->
let s = ctx_get_global span id ctx in
F.pp_print_string fmt s
- | CgValue v -> extract_literal span fmt inside v
+ | CgValue v -> extract_literal span fmt false inside v
| CgVar id ->
let s = ctx_get_const_generic_var span id ctx in
F.pp_print_string fmt s
diff --git a/tests/lean/Matches.lean b/tests/lean/Matches.lean
index 3e3a558b..9233841b 100644
--- a/tests/lean/Matches.lean
+++ b/tests/lean/Matches.lean
@@ -9,8 +9,8 @@ namespace matches
Source: 'tests/src/matches.rs', lines 4:0-4:27 -/
def match_u32 (x : U32) : Result U32 :=
match x with
- | 0#u32 => Result.ok 0#u32
- | 1#u32 => Result.ok 1#u32
+ | 0#scalar => Result.ok 0#u32
+ | 1#scalar => Result.ok 1#u32
| _ => Result.ok 2#u32
end matches