From 124ee77181c4255e2c8f730305b0b1b7802b9a58 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Mar 2024 17:43:55 +0100 Subject: Add a notation for tuple field accesses in Lean --- backends/lean/Base/Primitives/Base.lean | 51 +++++++++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index 9dbaf133..adec9a8b 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -123,6 +123,57 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } := simp [Bind.bind] cases e <;> simp +------------------------------- +-- Tuple field access syntax -- +------------------------------- +-- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple +-- The `noWs` parser is used to ensure there is no whitespace. +syntax term noWs ".#" noWs num : term + +open Lean Meta Elab Term + +-- Auxliary function for computing the number of elements in a tuple (`Prod`) type. +def getArity (type : Expr) : Nat := + match type with + | .app (.app (.const ``Prod _) _) as => getArity as + 1 + | _ => 1 -- It is not product + +-- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element +def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do + match i with + | 0 => mkAppM ``Prod.fst #[tuple] + | i+1 => + if n = 2 then + -- If the tuple has only two elements and `i` is not `0`, + -- we just return the second element. + mkAppM ``Prod.snd #[tuple] + else + -- Otherwise, we continue with the rest of the tuple. + let tuple ← mkAppM ``Prod.snd #[tuple] + mkGetIdx tuple (n-1) i + +-- Now, we define the elaboration function for the new syntax `a#i` +elab_rules : term +| `($a:term.#$i:num) => do + -- Convert `i : Syntax` into a natural number + let i := i.getNat + -- Return error if it is 0. + unless i ≥ 0 do + throwError "tuple index must be greater or equal to 0" + -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type + let tuple ← elabTerm a none + let type ← inferType tuple + -- Instantiate assigned metavariable occurring in `type` + let type ← instantiateMVars type + -- Ensure `tuple`'s type is a `Prod`uct. + unless type.isAppOf ``Prod do + throwError "tuple expected{indentExpr type}" + let n := getArity type + -- Ensure `i` is a valid index + unless i < n do + throwError "invalid tuple access at {i}, tuple has {n} elements" + mkGetIdx tuple n i + ---------- -- MISC -- ---------- -- cgit v1.2.3 From 23ce25c77052c02312f19f17c51fe0b61d6abc93 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 7 Mar 2024 18:17:41 +0100 Subject: Introduce a notation for constant scalars in match patterns --- backends/lean/Base/Primitives/Scalar.lean | 29 +++++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 285bc7fb..422cbc6a 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -488,6 +488,17 @@ class HNeg (α : Type u) (β : outParam (Type v)) where prefix:75 "-" => HNeg.hNeg +/- We need this, otherwise we break pattern matching like in: + + ``` + def is_minus_one (x : Int) : Bool := + match x with + | -1 => true + | _ => false + ``` +-/ +attribute [match_pattern] HNeg.hNeg + instance : HNeg Isize (Result Isize) where hNeg x := Scalar.neg x instance : HNeg I8 (Result I8) where hNeg x := Scalar.neg x instance : HNeg I16 (Result I16) where hNeg x := Scalar.neg x @@ -1113,4 +1124,22 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) := -- else -- .fail integerOverflow +-- Notation for pattern matching +-- We make the precedence looser than the negation. +notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ + +example {ty} (x : Scalar ty) : ℤ := + match x with + | v#scalar => v + +example {ty} (x : Scalar ty) : Bool := + match x with + | 1#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : Bool := + match x with + | -(1 : Int)#scalar => true + | _ => false + end Primitives -- cgit v1.2.3 From bc397dea5c5a67766c9c0381efad222524f68881 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 07:56:44 +0100 Subject: Update the notation for heterogeneous negation --- backends/lean/Base/Primitives/Scalar.lean | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 422cbc6a..3afd13d2 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -478,15 +478,24 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by Remark: there is no heterogeneous negation in the Lean prelude: we thus introduce one here. -The notation typeclass for heterogeneous addition. -This enables the notation `- a : β` where `a : α`. +The notation typeclass for heterogeneous negation. -/ class HNeg (α : Type u) (β : outParam (Type v)) where /-- `- a` computes the negation of `a`. The meaning of this notation is type-dependent. -/ hNeg : α → β -prefix:75 "-" => HNeg.hNeg +/- Notation for heterogeneous negation. + + We initially used the notation "-" but it conflicted with the homogeneous + negation too much. In particular, it made terms like `-10` ambiguous, + and seemingly caused to backtracking in elaboration, leading to definitions + like arrays of constants to take an unreasonable time to get elaborated + and type-checked. + + TODO: PR to replace Neg with HNeg in Lean? + -/ +prefix:75 "-." => HNeg.hNeg /- We need this, otherwise we break pattern matching like in: -- cgit v1.2.3 From e1e888f23935bfb34830fe160593e09df75a7f20 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 08:03:37 +0100 Subject: Update the code generation --- compiler/ExtractBase.ml | 8 ++++++-- compiler/ExtractTypes.ml | 5 ----- 2 files changed, 6 insertions(+), 7 deletions(-) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index db887539..5aa8323e 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -782,9 +782,13 @@ let ctx_get_termination_measure (def_id : A.FunDeclId.id) let unop_name (unop : unop) : string = match unop with | Not -> ( - match !backend with FStar | Lean -> "not" | Coq -> "negb" | HOL4 -> "~") + match !backend with + | FStar -> "not" + | Lean -> "¬" + | Coq -> "negb" + | HOL4 -> "~") | Neg (int_ty : integer_type) -> ( - match !backend with Lean -> "-" | _ -> int_name int_ty ^ "_neg") + match !backend with Lean -> "-." | _ -> int_name int_ty ^ "_neg") | Cast _ -> (* We never directly use the unop name in this case *) raise (Failure "Unsupported") diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 51e3fd77..a3dbf3cc 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -31,11 +31,6 @@ let extract_literal (fmt : F.formatter) (inside : bool) (cv : literal) : unit = (* We need to add parentheses if the value is negative *) if sv.value >= Z.of_int 0 then F.pp_print_string fmt (Z.to_string sv.value) - else if !backend = Lean then - (* TODO: parsing issues with Lean because there are ambiguous - interpretations between int values and nat values *) - F.pp_print_string fmt - ("(-(" ^ Z.to_string (Z.neg sv.value) ^ ":Int))") else F.pp_print_string fmt ("(" ^ Z.to_string sv.value ^ ")"); (match !backend with | Coq -> -- cgit v1.2.3 From f74647773d7dd21580fd938dd9b4e300719b0234 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 08:03:48 +0100 Subject: Regenerate the test files --- tests/lean/Arrays.lean | 2 +- tests/lean/Hashmap/Funs.lean | 12 ++++----- tests/lean/HashmapMain/Funs.lean | 12 ++++----- tests/lean/NoNestedBorrows.lean | 56 ++++++++++++++++++++-------------------- tests/lean/Paper.lean | 10 +++---- 5 files changed, 46 insertions(+), 46 deletions(-) diff --git a/tests/lean/Arrays.lean b/tests/lean/Arrays.lean index 5158ca28..6f9cd94c 100644 --- a/tests/lean/Arrays.lean +++ b/tests/lean/Arrays.lean @@ -397,7 +397,7 @@ divergent def sum2_loop def sum2 (s : Slice U32) (s2 : Slice U32) : Result U32 := let i := Slice.len U32 s let i1 := Slice.len U32 s2 - if not (i = i1) + if ¬ (i = i1) then Result.fail .panic else sum2_loop s s2 0#u32 0#usize diff --git a/tests/lean/Hashmap/Funs.lean b/tests/lean/Hashmap/Funs.lean index 3978bfc7..f0706725 100644 --- a/tests/lean/Hashmap/Funs.lean +++ b/tests/lean/Hashmap/Funs.lean @@ -397,14 +397,14 @@ def test1 : Result Unit := let hm3 ← HashMap.insert U64 hm2 1024#usize 138#u64 let hm4 ← HashMap.insert U64 hm3 1056#usize 256#u64 let i ← HashMap.get U64 hm4 128#usize - if not (i = 18#u64) + if ¬ (i = 18#u64) then Result.fail .panic else do let (_, get_mut_back) ← HashMap.get_mut U64 hm4 1024#usize let hm5 ← get_mut_back 56#u64 let i1 ← HashMap.get U64 hm5 1024#usize - if not (i1 = 56#u64) + if ¬ (i1 = 56#u64) then Result.fail .panic else do @@ -412,22 +412,22 @@ def test1 : Result Unit := match x with | none => Result.fail .panic | some x1 => - if not (x1 = 56#u64) + if ¬ (x1 = 56#u64) then Result.fail .panic else do let i2 ← HashMap.get U64 hm6 0#usize - if not (i2 = 42#u64) + if ¬ (i2 = 42#u64) then Result.fail .panic else do let i3 ← HashMap.get U64 hm6 128#usize - if not (i3 = 18#u64) + if ¬ (i3 = 18#u64) then Result.fail .panic else do let i4 ← HashMap.get U64 hm6 1056#usize - if not (i4 = 256#u64) + if ¬ (i4 = 256#u64) then Result.fail .panic else Result.ret () diff --git a/tests/lean/HashmapMain/Funs.lean b/tests/lean/HashmapMain/Funs.lean index ebed2570..31441b4a 100644 --- a/tests/lean/HashmapMain/Funs.lean +++ b/tests/lean/HashmapMain/Funs.lean @@ -419,14 +419,14 @@ def hashmap.test1 : Result Unit := let hm3 ← hashmap.HashMap.insert U64 hm2 1024#usize 138#u64 let hm4 ← hashmap.HashMap.insert U64 hm3 1056#usize 256#u64 let i ← hashmap.HashMap.get U64 hm4 128#usize - if not (i = 18#u64) + if ¬ (i = 18#u64) then Result.fail .panic else do let (_, get_mut_back) ← hashmap.HashMap.get_mut U64 hm4 1024#usize let hm5 ← get_mut_back 56#u64 let i1 ← hashmap.HashMap.get U64 hm5 1024#usize - if not (i1 = 56#u64) + if ¬ (i1 = 56#u64) then Result.fail .panic else do @@ -434,22 +434,22 @@ def hashmap.test1 : Result Unit := match x with | none => Result.fail .panic | some x1 => - if not (x1 = 56#u64) + if ¬ (x1 = 56#u64) then Result.fail .panic else do let i2 ← hashmap.HashMap.get U64 hm6 0#usize - if not (i2 = 42#u64) + if ¬ (i2 = 42#u64) then Result.fail .panic else do let i3 ← hashmap.HashMap.get U64 hm6 128#usize - if not (i3 = 18#u64) + if ¬ (i3 = 18#u64) then Result.fail .panic else do let i4 ← hashmap.HashMap.get U64 hm6 1056#usize - if not (i4 = 256#u64) + if ¬ (i4 = 256#u64) then Result.fail .panic else Result.ret () diff --git a/tests/lean/NoNestedBorrows.lean b/tests/lean/NoNestedBorrows.lean index bed71d94..ef81f2e9 100644 --- a/tests/lean/NoNestedBorrows.lean +++ b/tests/lean/NoNestedBorrows.lean @@ -46,7 +46,7 @@ inductive Sum (T1 T2 : Type) := /- [no_nested_borrows::neg_test]: Source: 'src/no_nested_borrows.rs', lines 48:0-48:30 -/ def neg_test (x : I32) : Result I32 := - - x + -. x /- [no_nested_borrows::add_u32]: Source: 'src/no_nested_borrows.rs', lines 54:0-54:37 -/ @@ -185,7 +185,7 @@ def test3 : Result Unit := let x ← get_max 4#u32 3#u32 let y ← get_max 10#u32 11#u32 let z ← x + y - if not (z = 15#u32) + if ¬ (z = 15#u32) then Result.fail .panic else Result.ret () @@ -196,8 +196,8 @@ def test3 : Result Unit := Source: 'src/no_nested_borrows.rs', lines 169:0-169:18 -/ def test_neg1 : Result Unit := do - let y ← - 3#i32 - if not (y = (-(3:Int))#i32) + let y ← -. 3#i32 + if ¬ (y = (-3)#i32) then Result.fail .panic else Result.ret () @@ -207,7 +207,7 @@ def test_neg1 : Result Unit := /- [no_nested_borrows::refs_test1]: Source: 'src/no_nested_borrows.rs', lines 176:0-176:19 -/ def refs_test1 : Result Unit := - if not (1#i32 = 1#i32) + if ¬ (1#i32 = 1#i32) then Result.fail .panic else Result.ret () @@ -217,15 +217,15 @@ def refs_test1 : Result Unit := /- [no_nested_borrows::refs_test2]: Source: 'src/no_nested_borrows.rs', lines 187:0-187:19 -/ def refs_test2 : Result Unit := - if not (2#i32 = 2#i32) + if ¬ (2#i32 = 2#i32) then Result.fail .panic else - if not (0#i32 = 0#i32) + if ¬ (0#i32 = 0#i32) then Result.fail .panic else - if not (2#i32 = 2#i32) + if ¬ (2#i32 = 2#i32) then Result.fail .panic - else if not (2#i32 = 2#i32) + else if ¬ (2#i32 = 2#i32) then Result.fail .panic else Result.ret () @@ -247,7 +247,7 @@ def test_box1 : Result Unit := let (_, deref_mut_back) ← alloc.boxed.Box.deref_mut I32 0#i32 let b ← deref_mut_back 1#i32 let x ← alloc.boxed.Box.deref I32 b - if not (x = 1#i32) + if ¬ (x = 1#i32) then Result.fail .panic else Result.ret () @@ -278,7 +278,7 @@ def test_panic (b : Bool) : Result Unit := def test_copy_int : Result Unit := do let y ← copy_int 0#i32 - if not (0#i32 = y) + if ¬ (0#i32 = y) then Result.fail .panic else Result.ret () @@ -297,7 +297,7 @@ def is_cons (T : Type) (l : List T) : Result Bool := def test_is_cons : Result Unit := do let b ← is_cons I32 (List.Cons 0#i32 List.Nil) - if not b + if ¬ b then Result.fail .panic else Result.ret () @@ -317,7 +317,7 @@ def test_split_list : Result Unit := do let p ← split_list I32 (List.Cons 0#i32 List.Nil) let (hd, _) := p - if not (hd = 0#i32) + if ¬ (hd = 0#i32) then Result.fail .panic else Result.ret () @@ -342,14 +342,14 @@ def choose_test : Result Unit := do let (z, choose_back) ← choose I32 true 0#i32 0#i32 let z1 ← z + 1#i32 - if not (z1 = 1#i32) + if ¬ (z1 = 1#i32) then Result.fail .panic else do let (x, y) ← choose_back z1 - if not (x = 1#i32) + if ¬ (x = 1#i32) then Result.fail .panic - else if not (y = 0#i32) + else if ¬ (y = 0#i32) then Result.fail .panic else Result.ret () @@ -441,22 +441,22 @@ def test_list_functions : Result Unit := let l := List.Cons 2#i32 List.Nil let l1 := List.Cons 1#i32 l let i ← list_length I32 (List.Cons 0#i32 l1) - if not (i = 3#u32) + if ¬ (i = 3#u32) then Result.fail .panic else do let i1 ← list_nth_shared I32 (List.Cons 0#i32 l1) 0#u32 - if not (i1 = 0#i32) + if ¬ (i1 = 0#i32) then Result.fail .panic else do let i2 ← list_nth_shared I32 (List.Cons 0#i32 l1) 1#u32 - if not (i2 = 1#i32) + if ¬ (i2 = 1#i32) then Result.fail .panic else do let i3 ← list_nth_shared I32 (List.Cons 0#i32 l1) 2#u32 - if not (i3 = 2#i32) + if ¬ (i3 = 2#i32) then Result.fail .panic else do @@ -464,17 +464,17 @@ def test_list_functions : Result Unit := list_nth_mut I32 (List.Cons 0#i32 l1) 1#u32 let ls ← list_nth_mut_back 3#i32 let i4 ← list_nth_shared I32 ls 0#u32 - if not (i4 = 0#i32) + if ¬ (i4 = 0#i32) then Result.fail .panic else do let i5 ← list_nth_shared I32 ls 1#u32 - if not (i5 = 3#i32) + if ¬ (i5 = 3#i32) then Result.fail .panic else do let i6 ← list_nth_shared I32 ls 2#u32 - if not (i6 = 2#i32) + if ¬ (i6 = 2#i32) then Result.fail .panic else Result.ret () @@ -555,24 +555,24 @@ def test_constants : Result Unit := do let swt ← new_tuple1 let (i, _) := swt.p - if not (i = 1#u32) + if ¬ (i = 1#u32) then Result.fail .panic else do let swt1 ← new_tuple2 let (i1, _) := swt1.p - if not (i1 = 1#i16) + if ¬ (i1 = 1#i16) then Result.fail .panic else do let swt2 ← new_tuple3 let (i2, _) := swt2.p - if not (i2 = 1#u64) + if ¬ (i2 = 1#u64) then Result.fail .panic else do let swp ← new_pair1 - if not (swp.p.x = 1#u32) + if ¬ (swp.p.x = 1#u32) then Result.fail .panic else Result.ret () @@ -591,7 +591,7 @@ def test_weird_borrows1 : Result Unit := Source: 'src/no_nested_borrows.rs', lines 481:0-481:37 -/ def test_mem_replace (px : U32) : Result U32 := let (y, _) := core.mem.replace U32 px 1#u32 - if not (y = 0#u32) + if ¬ (y = 0#u32) then Result.fail .panic else Result.ret 2#u32 diff --git a/tests/lean/Paper.lean b/tests/lean/Paper.lean index a35c8db0..4930a05c 100644 --- a/tests/lean/Paper.lean +++ b/tests/lean/Paper.lean @@ -15,7 +15,7 @@ def ref_incr (x : I32) : Result I32 := def test_incr : Result Unit := do let i ← ref_incr 0#i32 - if not (i = 1#i32) + if ¬ (i = 1#i32) then Result.fail .panic else Result.ret () @@ -40,14 +40,14 @@ def test_choose : Result Unit := do let (z, choose_back) ← choose I32 true 0#i32 0#i32 let z1 ← z + 1#i32 - if not (z1 = 1#i32) + if ¬ (z1 = 1#i32) then Result.fail .panic else do let (x, y) ← choose_back z1 - if not (x = 1#i32) + if ¬ (x = 1#i32) then Result.fail .panic - else if not (y = 0#i32) + else if ¬ (y = 0#i32) then Result.fail .panic else Result.ret () @@ -101,7 +101,7 @@ def test_nth : Result Unit := let x1 ← x + 1#i32 let l2 ← list_nth_mut_back x1 let i ← sum l2 - if not (i = 7#i32) + if ¬ (i = 7#i32) then Result.fail .panic else Result.ret () -- cgit v1.2.3 From 46b126f4e0e86f14475bc310e150948434726dc7 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 08:50:50 +0100 Subject: Update the handling of notations like #u32 or #isize --- backends/lean/Base/Arith/Scalar.lean | 2 +- backends/lean/Base/Primitives/ArraySlice.lean | 2 +- backends/lean/Base/Primitives/Scalar.lean | 168 ++++++++++++++++---------- backends/lean/Base/Primitives/Vec.lean | 2 +- 4 files changed, 104 insertions(+), 70 deletions(-) diff --git a/backends/lean/Base/Arith/Scalar.lean b/backends/lean/Base/Arith/Scalar.lean index 43fd2766..9441be86 100644 --- a/backends/lean/Base/Arith/Scalar.lean +++ b/backends/lean/Base/Arith/Scalar.lean @@ -74,7 +74,7 @@ example : U32.ofInt 1 ≤ U32.max := by scalar_tac example (x : Int) (h0 : 0 ≤ x) (h1 : x ≤ U32.max) : - U32.ofInt x (by constructor <;> scalar_tac) ≤ U32.max := by + U32.ofIntCore x (by constructor <;> scalar_tac) ≤ U32.max := by scalar_tac -- Not equal diff --git a/backends/lean/Base/Primitives/ArraySlice.lean b/backends/lean/Base/Primitives/ArraySlice.lean index c90a85b8..e1a39d40 100644 --- a/backends/lean/Base/Primitives/ArraySlice.lean +++ b/backends/lean/Base/Primitives/ArraySlice.lean @@ -131,7 +131,7 @@ def Slice.new (α : Type u): Slice α := ⟨ [], by apply Scalar.cMax_suffices . -- TODO: very annoying that the α is an explicit parameter def Slice.len (α : Type u) (v : Slice α) : Usize := - Usize.ofIntCore v.val.len (by scalar_tac) (by scalar_tac) + Usize.ofIntCore v.val.len (by constructor <;> scalar_tac) @[simp] theorem Slice.len_val {α : Type u} (v : Slice α) : (Slice.len α v).val = v.length := diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 3afd13d2..bf6b01a6 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -281,25 +281,38 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : λ h => by apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith -def Scalar.ofIntCore {ty : ScalarTy} (x : Int) - (hmin : Scalar.min ty ≤ x) (hmax : x ≤ Scalar.max ty) : Scalar ty := - { val := x, hmin := hmin, hmax := hmax } - --- Tactic to prove that integers are in bounds --- TODO: use this: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam -syntax "intlit" : tactic -macro_rules - | `(tactic| intlit) => `(tactic| apply Scalar.bound_suffices; decide) - -def Scalar.ofInt {ty : ScalarTy} (x : Int) - (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty := by intlit) : Scalar ty := - -- Remark: we initially wrote: - -- let ⟨ hmin, hmax ⟩ := h - -- Scalar.ofIntCore x hmin hmax - -- We updated to the line below because a similar pattern in `Scalar.tryMk` - -- made reduction block. Both versions seem to work for `Scalar.ofInt`, though. - -- TODO: investigate - Scalar.ofIntCore x h.left h.right +/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns. + This is particularly useful once we introduce notations like `#u32` (which + desugards to `Scalar.ofIntCore`) as it allows to write expressions like this: + Example: + ``` + match x with + | 0#u32 => ... + | 1#u32 => ... + | ... + ``` + -/ +@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int) + (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty := + { val := x, hmin := h.left, hmax := h.right } + +-- The definitions below are used later to introduce nice syntax for constants, +-- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284 + +class InBounds (ty : ScalarTy) (x : Int) := + hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty + +-- This trick to trigger reduction for decidable propositions comes from +-- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807 +class Decide (p : Prop) [Decidable p] : Prop where + isTrue : p +instance : @Decide p (.isTrue h) := @Decide.mk p (_) h + +instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where + hInBounds := Decide.isTrue + +@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty := + Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds) @[simp] def Scalar.check_bounds (ty : ScalarTy) (x : Int) : Bool := (Scalar.cMin ty ≤ x || Scalar.min ty ≤ x) ∧ (x ≤ Scalar.cMax ty || x ≤ Scalar.max ty) @@ -326,7 +339,7 @@ def Scalar.tryMk (ty : ScalarTy) (x : Int) : Result (Scalar ty) := -- ``` -- then normalization blocks (for instance, some proofs which use reflexivity fail). -- However, the version below doesn't block reduction (TODO: investigate): - return Scalar.ofInt x (Scalar.check_bounds_prop h) + return Scalar.ofIntCore x (Scalar.check_bounds_prop h) else fail integerOverflow def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) @@ -439,8 +452,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by constructor; cases ty <;> apply (Scalar.ofInt 0) -- TODO: reducible? -@[reducible] def core_isize_min : Isize := Scalar.ofInt Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) -@[reducible] def core_isize_max : Isize := Scalar.ofInt Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) +@[reducible] def core_isize_min : Isize := Scalar.ofIntCore Isize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) +@[reducible] def core_isize_max : Isize := Scalar.ofIntCore Isize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Isize)) @[reducible] def core_i8_min : I8 := Scalar.ofInt I8.min @[reducible] def core_i8_max : I8 := Scalar.ofInt I8.max @[reducible] def core_i16_min : I16 := Scalar.ofInt I16.min @@ -453,8 +466,8 @@ instance (ty : ScalarTy) : Inhabited (Scalar ty) := by @[reducible] def core_i128_max : I128 := Scalar.ofInt I128.max -- TODO: reducible? -@[reducible] def core_usize_min : Usize := Scalar.ofInt Usize.min -@[reducible] def core_usize_max : Usize := Scalar.ofInt Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize)) +@[reducible] def core_usize_min : Usize := Scalar.ofIntCore Usize.min (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize)) +@[reducible] def core_usize_max : Usize := Scalar.ofIntCore Usize.max (by simp [Scalar.min, Scalar.max]; apply (Scalar.min_le_max .Usize)) @[reducible] def core_u8_min : U8 := Scalar.ofInt U8.min @[reducible] def core_u8_max : U8 := Scalar.ofInt U8.max @[reducible] def core_u16_min : U16 := Scalar.ofInt U16.min @@ -985,18 +998,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128 -- ofInt -- TODO: typeclass? -abbrev Isize.ofInt := @Scalar.ofInt .Isize -abbrev I8.ofInt := @Scalar.ofInt .I8 -abbrev I16.ofInt := @Scalar.ofInt .I16 -abbrev I32.ofInt := @Scalar.ofInt .I32 -abbrev I64.ofInt := @Scalar.ofInt .I64 -abbrev I128.ofInt := @Scalar.ofInt .I128 -abbrev Usize.ofInt := @Scalar.ofInt .Usize -abbrev U8.ofInt := @Scalar.ofInt .U8 -abbrev U16.ofInt := @Scalar.ofInt .U16 -abbrev U32.ofInt := @Scalar.ofInt .U32 -abbrev U64.ofInt := @Scalar.ofInt .U64 -abbrev U128.ofInt := @Scalar.ofInt .U128 +@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize +@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8 +@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16 +@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32 +@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64 +@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128 +@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize +@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8 +@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16 +@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32 +@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64 +@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128 postfix:max "#isize" => Isize.ofInt postfix:max "#i8" => I8.ofInt @@ -1011,47 +1024,86 @@ postfix:max "#u32" => U32.ofInt postfix:max "#u64" => U64.ofInt postfix:max "#u128" => U128.ofInt +/- Testing the notations -/ +example := 0#u32 +example := 1#u32 +example := 1#i32 +example := 0#isize +example := (-1)#isize +example (x : U32) : Bool := + match x with + | 0#u32 => true + | _ => false + +example (x : U32) : Bool := + match x with + | 1#u32 => true + | _ => false + +example (x : I32) : Bool := + match x with + | (-1)#i32 => true + | _ => false + +-- Notation for pattern matching +-- We make the precedence looser than the negation. +notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ + +example {ty} (x : Scalar ty) : ℤ := + match x with + | v#scalar => v + +example {ty} (x : Scalar ty) : Bool := + match x with + | 1#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : Bool := + match x with + | -(1 : Int)#scalar => true + | _ => false + -- Testing the notations example : Result Usize := 0#usize + 1#usize -- TODO: factor those lemmas out -@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofInt x h).val = x := by +@[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by simp [Scalar.ofInt, Scalar.ofIntCore] -@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofInt x h).val = x := by +@[simp] theorem Isize.ofInt_val_eq (h : Scalar.min ScalarTy.Isize ≤ x ∧ x ≤ Scalar.max ScalarTy.Isize) : (Isize.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofInt x h).val = x := by +@[simp] theorem I8.ofInt_val_eq (h : Scalar.min ScalarTy.I8 ≤ x ∧ x ≤ Scalar.max ScalarTy.I8) : (I8.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofInt x h).val = x := by +@[simp] theorem I16.ofInt_val_eq (h : Scalar.min ScalarTy.I16 ≤ x ∧ x ≤ Scalar.max ScalarTy.I16) : (I16.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofInt x h).val = x := by +@[simp] theorem I32.ofInt_val_eq (h : Scalar.min ScalarTy.I32 ≤ x ∧ x ≤ Scalar.max ScalarTy.I32) : (I32.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofInt x h).val = x := by +@[simp] theorem I64.ofInt_val_eq (h : Scalar.min ScalarTy.I64 ≤ x ∧ x ≤ Scalar.max ScalarTy.I64) : (I64.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofInt x h).val = x := by +@[simp] theorem I128.ofInt_val_eq (h : Scalar.min ScalarTy.I128 ≤ x ∧ x ≤ Scalar.max ScalarTy.I128) : (I128.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofInt x h).val = x := by +@[simp] theorem Usize.ofInt_val_eq (h : Scalar.min ScalarTy.Usize ≤ x ∧ x ≤ Scalar.max ScalarTy.Usize) : (Usize.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofInt x h).val = x := by +@[simp] theorem U8.ofInt_val_eq (h : Scalar.min ScalarTy.U8 ≤ x ∧ x ≤ Scalar.max ScalarTy.U8) : (U8.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofInt x h).val = x := by +@[simp] theorem U16.ofInt_val_eq (h : Scalar.min ScalarTy.U16 ≤ x ∧ x ≤ Scalar.max ScalarTy.U16) : (U16.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofInt x h).val = x := by +@[simp] theorem U32.ofInt_val_eq (h : Scalar.min ScalarTy.U32 ≤ x ∧ x ≤ Scalar.max ScalarTy.U32) : (U32.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofInt x h).val = x := by +@[simp] theorem U64.ofInt_val_eq (h : Scalar.min ScalarTy.U64 ≤ x ∧ x ≤ Scalar.max ScalarTy.U64) : (U64.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofInt x h).val = x := by +@[simp] theorem U128.ofInt_val_eq (h : Scalar.min ScalarTy.U128 ≤ x ∧ x ≤ Scalar.max ScalarTy.U128) : (U128.ofIntCore x h).val = x := by apply Scalar.ofInt_val_eq h -- Comparisons @@ -1133,22 +1185,4 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) := -- else -- .fail integerOverflow --- Notation for pattern matching --- We make the precedence looser than the negation. -notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ - -example {ty} (x : Scalar ty) : ℤ := - match x with - | v#scalar => v - -example {ty} (x : Scalar ty) : Bool := - match x with - | 1#scalar => true - | _ => false - -example {ty} (x : Scalar ty) : Bool := - match x with - | -(1 : Int)#scalar => true - | _ => false - end Primitives diff --git a/backends/lean/Base/Primitives/Vec.lean b/backends/lean/Base/Primitives/Vec.lean index b03de15b..65249c12 100644 --- a/backends/lean/Base/Primitives/Vec.lean +++ b/backends/lean/Base/Primitives/Vec.lean @@ -43,7 +43,7 @@ instance (α : Type u) : Inhabited (Vec α) := by -- TODO: very annoying that the α is an explicit parameter def Vec.len (α : Type u) (v : Vec α) : Usize := - Usize.ofIntCore v.val.len (by scalar_tac) (by scalar_tac) + Usize.ofIntCore v.val.len (by constructor <;> scalar_tac) @[simp] theorem Vec.len_val {α : Type u} (v : Vec α) : (Vec.len α v).val = v.length := -- cgit v1.2.3 From b6f63f106baef03dd61f1100bd46c9bad7cb79e4 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 08:53:38 +0100 Subject: Remove some comments --- backends/lean/Base/Primitives/Scalar.lean | 31 ------------------------------- 1 file changed, 31 deletions(-) diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index bf6b01a6..3d90f1a5 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -1154,35 +1154,4 @@ instance (ty : ScalarTy) : DecidableEq (Scalar ty) := @[simp] theorem Scalar.neq_to_neq_val {ty} : ∀ {i j : Scalar ty}, (¬ i = j) ↔ ¬ i.val = j.val := by intro i j; cases i; cases j; simp --- -- We now define a type class that subsumes the various machine integer types, so --- -- as to write a concise definition for scalar_cast, rather than exhaustively --- -- enumerating all of the possible pairs. We remark that Rust has sane semantics --- -- and fails if a cast operation would involve a truncation or modulo. - --- class MachineInteger (t: Type) where --- size: Nat --- val: t -> Fin size --- ofNatCore: (n:Nat) -> LT.lt n size -> t - --- set_option hygiene false in --- run_cmd --- for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do --- Lean.Elab.Command.elabCommand (← `( --- namespace $typeName --- instance: MachineInteger $typeName where --- size := size --- val := val --- ofNatCore := ofNatCore --- end $typeName --- )) - --- -- Aeneas only instantiates the destination type (`src` is implicit). We rely on --- -- Lean to infer `src`. - --- def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): Result dst := --- if h: MachineInteger.val x < MachineInteger.size dst then --- .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h) --- else --- .fail integerOverflow - end Primitives -- cgit v1.2.3 From 41d6f78a0ad6bd272164894bead3258b2001ec0c Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 09:22:08 +0100 Subject: Update the tuples notations --- backends/lean/Base/Primitives.lean | 1 + backends/lean/Base/Primitives/Base.lean | 51 --------------------- backends/lean/Base/Tuples.lean | 80 +++++++++++++++++++++++++++++++++ 3 files changed, 81 insertions(+), 51 deletions(-) create mode 100644 backends/lean/Base/Tuples.lean diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean index 613b6076..7196d2ec 100644 --- a/backends/lean/Base/Primitives.lean +++ b/backends/lean/Base/Primitives.lean @@ -1,4 +1,5 @@ import Base.Primitives.Base +import Base.Tuples import Base.Primitives.Scalar import Base.Primitives.ArraySlice import Base.Primitives.Vec diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index adec9a8b..9dbaf133 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -123,57 +123,6 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } := simp [Bind.bind] cases e <;> simp -------------------------------- --- Tuple field access syntax -- -------------------------------- --- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple --- The `noWs` parser is used to ensure there is no whitespace. -syntax term noWs ".#" noWs num : term - -open Lean Meta Elab Term - --- Auxliary function for computing the number of elements in a tuple (`Prod`) type. -def getArity (type : Expr) : Nat := - match type with - | .app (.app (.const ``Prod _) _) as => getArity as + 1 - | _ => 1 -- It is not product - --- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element -def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do - match i with - | 0 => mkAppM ``Prod.fst #[tuple] - | i+1 => - if n = 2 then - -- If the tuple has only two elements and `i` is not `0`, - -- we just return the second element. - mkAppM ``Prod.snd #[tuple] - else - -- Otherwise, we continue with the rest of the tuple. - let tuple ← mkAppM ``Prod.snd #[tuple] - mkGetIdx tuple (n-1) i - --- Now, we define the elaboration function for the new syntax `a#i` -elab_rules : term -| `($a:term.#$i:num) => do - -- Convert `i : Syntax` into a natural number - let i := i.getNat - -- Return error if it is 0. - unless i ≥ 0 do - throwError "tuple index must be greater or equal to 0" - -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type - let tuple ← elabTerm a none - let type ← inferType tuple - -- Instantiate assigned metavariable occurring in `type` - let type ← instantiateMVars type - -- Ensure `tuple`'s type is a `Prod`uct. - unless type.isAppOf ``Prod do - throwError "tuple expected{indentExpr type}" - let n := getArity type - -- Ensure `i` is a valid index - unless i < n do - throwError "invalid tuple access at {i}, tuple has {n} elements" - mkGetIdx tuple n i - ---------- -- MISC -- ---------- diff --git a/backends/lean/Base/Tuples.lean b/backends/lean/Base/Tuples.lean new file mode 100644 index 00000000..d8e4a843 --- /dev/null +++ b/backends/lean/Base/Tuples.lean @@ -0,0 +1,80 @@ +import Lean +import Base.Utils + +namespace Primitives + +------------------------------- +-- Tuple field access syntax -- +------------------------------- +-- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple +-- The `noWs` parser is used to ensure there is no whitespace. +syntax term noWs ".#" noWs num : term + +open Lean Meta Elab Term + +-- Auxliary function for computing the number of elements in a tuple (`Prod`) type. +def getArity (type : Expr) : Nat := + match type with + | .app (.app (.const ``Prod _) _) as => getArity as + 1 + | _ => 1 -- It is not product + +-- Given a `tuple` of size `n`, construct a term that for accessing the `i`-th element +def mkGetIdx (tuple : Expr) (n : Nat) (i : Nat) : MetaM Expr := do + match i with + | 0 => mkAppM ``Prod.fst #[tuple] + | i+1 => + if n = 2 then + -- If the tuple has only two elements and `i` is not `0`, + -- we just return the second element. + mkAppM ``Prod.snd #[tuple] + else + -- Otherwise, we continue with the rest of the tuple. + let tuple ← mkAppM ``Prod.snd #[tuple] + mkGetIdx tuple (n-1) i + +-- Now, we define the elaboration function for the new syntax `a#i` +elab_rules : term +| `($a:term.#$i:num) => do + -- Convert `i : Syntax` into a natural number + let i := i.getNat + -- Return error if it is 0. + unless i ≥ 0 do + throwError "tuple index must be greater or equal to 0" + -- Convert `a : Syntax` into an `tuple : Expr` without providing expected type + let tuple ← elabTerm a none + let type ← inferType tuple + -- Instantiate assigned metavariable occurring in `type` + let type ← instantiateMVars type + /- In case we are indexing into a type abbreviation, we need to unfold the type. + + TODO: we have to be careful about not unfolding too much, + for instance because of the following code: + ``` + def Pair T U := T × U + def Tuple T U V := T × Pair U V + ``` + We have to make sure that, given `x : Tuple T U V`, `x.1` evaluates + to the pair (an element of type `Pair T U`), not to the first field + of the pair (an element of type `T`). + + We have a similar issue below if we generate code from the following Rust definition: + ``` + struct Tuple(u32, (u32, u32)); + ``` + The issue is that in Rust, field 1 of `Tuple` is a pair `(u32, u32)`, but + in Lean there is no difference between `A × B × C` and `A × (B × C)`. + + In case such situations happen we probably need to resort to chaining + the pair projectors, like in: `x.snd.fst`. + -/ + let type ← whnf type + -- Ensure `tuple`'s type is a `Prod`uct. + unless type.isAppOf ``Prod do + throwError "tuple expected{indentExpr type}" + let n := getArity type + -- Ensure `i` is a valid index + unless i < n do + throwError "invalid tuple access at {i}, tuple has {n} elements" + mkGetIdx tuple n i + +end Primitives -- cgit v1.2.3 From 9d541d1ab6b91e59e4f78f4711af085a33ee4f82 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 09:25:11 +0100 Subject: Update the tuples syntax --- backends/lean/Base/Tuples.lean | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/backends/lean/Base/Tuples.lean b/backends/lean/Base/Tuples.lean index d8e4a843..4c59dac9 100644 --- a/backends/lean/Base/Tuples.lean +++ b/backends/lean/Base/Tuples.lean @@ -8,7 +8,9 @@ namespace Primitives ------------------------------- -- Declare new syntax `a.#i` for accessing the `i`-th term in a tuple -- The `noWs` parser is used to ensure there is no whitespace. -syntax term noWs ".#" noWs num : term +-- We use the maximum precedence to make the syntax work with function calls. +-- Ex.: `f (0, 1).#0` +syntax:max term noWs ".#" noWs num : term open Lean Meta Elab Term -- cgit v1.2.3 From 44248ccfe3bfb8c45e5bb434d8dfb3dfa6e6b69c Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 09:42:29 +0100 Subject: Update the generation of constant bodies for Lean --- backends/lean/Base/Primitives/Base.lean | 4 ++-- compiler/Extract.ml | 3 +-- tests/lean/Arrays.lean | 2 +- tests/lean/Constants.lean | 36 ++++++++++++++++----------------- tests/lean/NoNestedBorrows.lean | 4 ++-- tests/lean/Traits.lean | 5 ++--- 6 files changed, 26 insertions(+), 28 deletions(-) diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index 9dbaf133..0b9d9c39 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -69,7 +69,7 @@ def div? {α: Type u} (r: Result α): Bool := def massert (b:Bool) : Result Unit := if b then ret () else fail assertionFailure -def eval_global {α: Type u} (x: Result α) (_: ret? x): α := +def eval_global {α: Type u} (x: Result α) (_: ret? x := by decide): α := match x with | fail _ | div => by contradiction | ret x => x @@ -78,7 +78,7 @@ def eval_global {α: Type u} (x: Result α) (_: ret? x): α := def bind {α : Type u} {β : Type v} (x: Result α) (f: α → Result β) : Result β := match x with - | ret v => f v + | ret v => f v | fail v => fail v | div => div diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 6c523549..0a21d4ec 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -1863,8 +1863,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter) (fun fmt -> let body = match !backend with - | FStar -> "eval_global " ^ body_name - | Lean -> "eval_global " ^ body_name ^ " (by decide)" + | FStar | Lean -> "eval_global " ^ body_name | Coq -> body_name ^ "%global" | HOL4 -> "get_return_value " ^ body_name in diff --git a/tests/lean/Arrays.lean b/tests/lean/Arrays.lean index 6f9cd94c..d2bb7cf2 100644 --- a/tests/lean/Arrays.lean +++ b/tests/lean/Arrays.lean @@ -452,7 +452,7 @@ def f3 : Result U32 := /- [arrays::SZ] Source: 'src/arrays.rs', lines 286:0-286:19 -/ def sz_body : Result Usize := Result.ret 32#usize -def sz_c : Usize := eval_global sz_body (by decide) +def sz_c : Usize := eval_global sz_body /- [arrays::f5]: Source: 'src/arrays.rs', lines 289:0-289:31 -/ diff --git a/tests/lean/Constants.lean b/tests/lean/Constants.lean index 4c626ab3..32e0317b 100644 --- a/tests/lean/Constants.lean +++ b/tests/lean/Constants.lean @@ -8,17 +8,17 @@ namespace constants /- [constants::X0] Source: 'src/constants.rs', lines 5:0-5:17 -/ def x0_body : Result U32 := Result.ret 0#u32 -def x0_c : U32 := eval_global x0_body (by decide) +def x0_c : U32 := eval_global x0_body /- [constants::X1] Source: 'src/constants.rs', lines 7:0-7:17 -/ def x1_body : Result U32 := Result.ret core_u32_max -def x1_c : U32 := eval_global x1_body (by decide) +def x1_c : U32 := eval_global x1_body /- [constants::X2] Source: 'src/constants.rs', lines 10:0-10:17 -/ def x2_body : Result U32 := Result.ret 3#u32 -def x2_c : U32 := eval_global x2_body (by decide) +def x2_c : U32 := eval_global x2_body /- [constants::incr]: Source: 'src/constants.rs', lines 17:0-17:32 -/ @@ -28,7 +28,7 @@ def incr (n : U32) : Result U32 := /- [constants::X3] Source: 'src/constants.rs', lines 15:0-15:17 -/ def x3_body : Result U32 := incr 32#u32 -def x3_c : U32 := eval_global x3_body (by decide) +def x3_c : U32 := eval_global x3_body /- [constants::mk_pair0]: Source: 'src/constants.rs', lines 23:0-23:51 -/ @@ -49,22 +49,22 @@ def mk_pair1 (x : U32) (y : U32) : Result (Pair U32 U32) := /- [constants::P0] Source: 'src/constants.rs', lines 31:0-31:24 -/ def p0_body : Result (U32 × U32) := mk_pair0 0#u32 1#u32 -def p0_c : (U32 × U32) := eval_global p0_body (by decide) +def p0_c : (U32 × U32) := eval_global p0_body /- [constants::P1] Source: 'src/constants.rs', lines 32:0-32:28 -/ def p1_body : Result (Pair U32 U32) := mk_pair1 0#u32 1#u32 -def p1_c : Pair U32 U32 := eval_global p1_body (by decide) +def p1_c : Pair U32 U32 := eval_global p1_body /- [constants::P2] Source: 'src/constants.rs', lines 33:0-33:24 -/ def p2_body : Result (U32 × U32) := Result.ret (0#u32, 1#u32) -def p2_c : (U32 × U32) := eval_global p2_body (by decide) +def p2_c : (U32 × U32) := eval_global p2_body /- [constants::P3] Source: 'src/constants.rs', lines 34:0-34:28 -/ def p3_body : Result (Pair U32 U32) := Result.ret { x := 0#u32, y := 1#u32 } -def p3_c : Pair U32 U32 := eval_global p3_body (by decide) +def p3_c : Pair U32 U32 := eval_global p3_body /- [constants::Wrap] Source: 'src/constants.rs', lines 49:0-49:18 -/ @@ -79,7 +79,7 @@ def Wrap.new (T : Type) (value : T) : Result (Wrap T) := /- [constants::Y] Source: 'src/constants.rs', lines 41:0-41:22 -/ def y_body : Result (Wrap I32) := Wrap.new I32 2#i32 -def y_c : Wrap I32 := eval_global y_body (by decide) +def y_c : Wrap I32 := eval_global y_body /- [constants::unwrap_y]: Source: 'src/constants.rs', lines 43:0-43:30 -/ @@ -89,12 +89,12 @@ def unwrap_y : Result I32 := /- [constants::YVAL] Source: 'src/constants.rs', lines 47:0-47:19 -/ def yval_body : Result I32 := unwrap_y -def yval_c : I32 := eval_global yval_body (by decide) +def yval_c : I32 := eval_global yval_body /- [constants::get_z1::Z1] Source: 'src/constants.rs', lines 62:4-62:17 -/ def get_z1_z1_body : Result I32 := Result.ret 3#i32 -def get_z1_z1_c : I32 := eval_global get_z1_z1_body (by decide) +def get_z1_z1_c : I32 := eval_global get_z1_z1_body /- [constants::get_z1]: Source: 'src/constants.rs', lines 61:0-61:28 -/ @@ -109,17 +109,17 @@ def add (a : I32) (b : I32) : Result I32 := /- [constants::Q1] Source: 'src/constants.rs', lines 74:0-74:17 -/ def q1_body : Result I32 := Result.ret 5#i32 -def q1_c : I32 := eval_global q1_body (by decide) +def q1_c : I32 := eval_global q1_body /- [constants::Q2] Source: 'src/constants.rs', lines 75:0-75:17 -/ def q2_body : Result I32 := Result.ret q1_c -def q2_c : I32 := eval_global q2_body (by decide) +def q2_c : I32 := eval_global q2_body /- [constants::Q3] Source: 'src/constants.rs', lines 76:0-76:17 -/ def q3_body : Result I32 := add q2_c 3#i32 -def q3_c : I32 := eval_global q3_body (by decide) +def q3_c : I32 := eval_global q3_body /- [constants::get_z2]: Source: 'src/constants.rs', lines 70:0-70:28 -/ @@ -132,21 +132,21 @@ def get_z2 : Result I32 := /- [constants::S1] Source: 'src/constants.rs', lines 80:0-80:18 -/ def s1_body : Result U32 := Result.ret 6#u32 -def s1_c : U32 := eval_global s1_body (by decide) +def s1_c : U32 := eval_global s1_body /- [constants::S2] Source: 'src/constants.rs', lines 81:0-81:18 -/ def s2_body : Result U32 := incr s1_c -def s2_c : U32 := eval_global s2_body (by decide) +def s2_c : U32 := eval_global s2_body /- [constants::S3] Source: 'src/constants.rs', lines 82:0-82:29 -/ def s3_body : Result (Pair U32 U32) := Result.ret p3_c -def s3_c : Pair U32 U32 := eval_global s3_body (by decide) +def s3_c : Pair U32 U32 := eval_global s3_body /- [constants::S4] Source: 'src/constants.rs', lines 83:0-83:29 -/ def s4_body : Result (Pair U32 U32) := mk_pair1 7#u32 8#u32 -def s4_c : Pair U32 U32 := eval_global s4_body (by decide) +def s4_c : Pair U32 U32 := eval_global s4_body end constants diff --git a/tests/lean/NoNestedBorrows.lean b/tests/lean/NoNestedBorrows.lean index ef81f2e9..71d064d8 100644 --- a/tests/lean/NoNestedBorrows.lean +++ b/tests/lean/NoNestedBorrows.lean @@ -139,12 +139,12 @@ def mix_arith_i32 (x : I32) (y : I32) (z : I32) : Result I32 := /- [no_nested_borrows::CONST0] Source: 'src/no_nested_borrows.rs', lines 125:0-125:23 -/ def const0_body : Result Usize := 1#usize + 1#usize -def const0_c : Usize := eval_global const0_body (by decide) +def const0_c : Usize := eval_global const0_body /- [no_nested_borrows::CONST1] Source: 'src/no_nested_borrows.rs', lines 126:0-126:23 -/ def const1_body : Result Usize := 2#usize * 2#usize -def const1_c : Usize := eval_global const1_body (by decide) +def const1_c : Usize := eval_global const1_body /- [no_nested_borrows::cast_u32_to_i32]: Source: 'src/no_nested_borrows.rs', lines 128:0-128:37 -/ diff --git a/tests/lean/Traits.lean b/tests/lean/Traits.lean index 3ef4febc..f83fbc2f 100644 --- a/tests/lean/Traits.lean +++ b/tests/lean/Traits.lean @@ -248,8 +248,7 @@ def traits.ToTypetraitsBoolWrapperTInst (T : Type) (ToTypeBoolTInst : ToType /- [traits::WithConstTy::LEN2] Source: 'src/traits.rs', lines 164:4-164:21 -/ def with_const_ty_len2_body : Result Usize := Result.ret 32#usize -def with_const_ty_len2_c : Usize := - eval_global with_const_ty_len2_body (by decide) +def with_const_ty_len2_c : Usize := eval_global with_const_ty_len2_body /- Trait declaration: [traits::WithConstTy] Source: 'src/traits.rs', lines 161:0-161:39 -/ @@ -264,7 +263,7 @@ structure WithConstTy (Self : Type) (LEN : Usize) where /- [traits::{bool#8}::LEN1] Source: 'src/traits.rs', lines 175:4-175:21 -/ def bool_len1_body : Result Usize := Result.ret 12#usize -def bool_len1_c : Usize := eval_global bool_len1_body (by decide) +def bool_len1_c : Usize := eval_global bool_len1_body /- [traits::{bool#8}::f]: Source: 'src/traits.rs', lines 180:4-180:39 -/ -- cgit v1.2.3 From 5427563a8000f281ac614a2501fb9983beb44f21 Mon Sep 17 00:00:00 2001 From: Zyad Hassan Date: Fri, 23 Feb 2024 16:37:58 -0800 Subject: Fix tuple indexing for Lean backend --- backends/lean/Base/IList/IList.lean | 2 +- compiler/Extract.ml | 37 +++++++++++++++++++++++++++++++++---- tests/lean/NoNestedBorrows.lean | 2 +- 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/backends/lean/Base/IList/IList.lean b/backends/lean/Base/IList/IList.lean index 51457c20..ca5ee266 100644 --- a/backends/lean/Base/IList/IList.lean +++ b/backends/lean/Base/IList/IList.lean @@ -33,7 +33,7 @@ def indexOpt (ls : List α) (i : Int) : Option α := @[simp] theorem indexOpt_zero_cons : indexOpt ((x :: tl) : List α) 0 = some x := by simp [indexOpt] @[simp] theorem indexOpt_nzero_cons (hne : i ≠ 0) : indexOpt ((x :: tl) : List α) i = indexOpt tl (i - 1) := by simp [*, indexOpt] --- Remark: if i < 0, then the result is the defaul element +-- Remark: if i < 0, then the result is the default element def index [Inhabited α] (ls : List α) (i : Int) : α := match ls with | [] => Inhabited.default diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 0a21d4ec..d7ef5f34 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -577,12 +577,17 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) in (* Check if we extract the type as a tuple, and it only has one field. In this case, there is no projection. *) - let has_one_field = + let num_fields = match proj.adt_id with | TAdtId id -> ( let d = TypeDeclId.Map.find id ctx.trans_types in - match d.kind with Struct [ _ ] -> true | _ -> false) - | _ -> false + match d.kind with + | Struct fields -> Some (List.length fields) + | _ -> None) + | _ -> None + in + let has_one_field = + match num_fields with Some len -> len = 1 | None -> false in if is_tuple_struct && has_one_field then extract_texpression ctx fmt inside arg @@ -590,7 +595,31 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) (* Exactly one argument: pretty-print *) let field_name = (* Check if we need to extract the type as a tuple *) - if is_tuple_struct then FieldId.to_string proj.field_id + if is_tuple_struct then + match !backend with + | FStar | HOL4 | Coq -> FieldId.to_string proj.field_id + | Lean -> + (* Tuples in Lean are syntax sugar for nested products/pairs, + so we need to map the field id accordingly. + A field id i maps to: + (.2)^i if i is the last element of the tuple + (.2)^i.1 otherwise + where (.2)^i denotes .2 repeated i times. + For example, 3 maps to .2.2.2 if the tuple has 4 fields and + to .2.2.2.1 if it has more than 4 fields. + Note that the first "." is added below *) + let field_id = FieldId.to_int proj.field_id in + (* Helper: repeat "2.2.2..." *) + let rec repeat_snd n = + match n with + | 0 -> "" + | 1 -> "2" + | _ -> "2." ^ repeat_snd (n - 1) + in + let twos_prefix = repeat_snd field_id in + if field_id + 1 = Option.get num_fields then twos_prefix + else if field_id = 0 then "1" + else twos_prefix ^ ".1" else ctx_get_field proj.adt_id proj.field_id ctx in (* Open a box *) diff --git a/tests/lean/NoNestedBorrows.lean b/tests/lean/NoNestedBorrows.lean index 71d064d8..a326bdf7 100644 --- a/tests/lean/NoNestedBorrows.lean +++ b/tests/lean/NoNestedBorrows.lean @@ -643,7 +643,7 @@ def Tuple (T1 T2 : Type) := T1 × T2 /- [no_nested_borrows::use_tuple_struct]: Source: 'src/no_nested_borrows.rs', lines 556:0-556:48 -/ def use_tuple_struct (x : Tuple U32 U32) : Result (Tuple U32 U32) := - Result.ret (1#u32, x.1) + Result.ret (1#u32, x.2) /- [no_nested_borrows::create_tuple_struct]: Source: 'src/no_nested_borrows.rs', lines 560:0-560:61 -/ -- cgit v1.2.3 From a7452421be018e5d75065e2038f2f50042a80f3c Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 8 Mar 2024 10:42:10 +0100 Subject: Update the code generated for tuple projectors --- compiler/Config.ml | 4 +++ compiler/Extract.ml | 57 ++++++++++++++++++++++++++++------------- compiler/Main.ml | 4 +++ tests/lean/NoNestedBorrows.lean | 2 +- 4 files changed, 48 insertions(+), 19 deletions(-) diff --git a/compiler/Config.ml b/compiler/Config.ml index 2bb1ca34..3b0070c0 100644 --- a/compiler/Config.ml +++ b/compiler/Config.ml @@ -469,3 +469,7 @@ let use_tuple_structs = ref true let backend_has_tuple_projectors () = match !backend with Lean -> true | Coq | FStar | HOL4 -> false + +(** We we use nested projectors for tuple (like: [(0, 1).snd.fst]) or do + we use better projector syntax? *) +let use_nested_tuple_projectors = ref false diff --git a/compiler/Extract.ml b/compiler/Extract.ml index d7ef5f34..dbca4f8f 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -601,25 +601,46 @@ and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter) | Lean -> (* Tuples in Lean are syntax sugar for nested products/pairs, so we need to map the field id accordingly. - A field id i maps to: - (.2)^i if i is the last element of the tuple - (.2)^i.1 otherwise - where (.2)^i denotes .2 repeated i times. - For example, 3 maps to .2.2.2 if the tuple has 4 fields and - to .2.2.2.1 if it has more than 4 fields. - Note that the first "." is added below *) + + We give two possibilities: + - either we use the custom syntax [.#i], like in: [(0, 1).#1] + - or we introduce nested projections which use the field + projectors [.1] and [.2], like in: [(0, 1).2.1] + + This necessary in some situations, for instance if we have + in Rust: + {[ + struct Tuple(u32, (u32, u32)); + ]} + + The issue comes from the fact that in Lean [A * B * C] and [A * (B * + C)] are the same type. As a result, in Rust, field 1 of [Tuple] is + the pair (an element of type [(u32, u32)]), however in Lean it would + be the first element of the pair (an element of type [u32]). If such + situations happen, we allow to force using the nested projectors by + providing the proper command line argument. TODO: we can actually + check the type to determine exactly when we need to use nested + projectors and when we don't. + + When using nested projectors, a field id i maps to: + - (.2)^i if i is the last element of the tuple + - (.2)^i.1 otherwise + where (.2)^i denotes .2 repeated i times. + For example, 3 maps to .2.2.2 if the tuple has 4 fields and + to .2.2.2.1 if it has more than 4 fields. + Note that the first "." is added below. + *) let field_id = FieldId.to_int proj.field_id in - (* Helper: repeat "2.2.2..." *) - let rec repeat_snd n = - match n with - | 0 -> "" - | 1 -> "2" - | _ -> "2." ^ repeat_snd (n - 1) - in - let twos_prefix = repeat_snd field_id in - if field_id + 1 = Option.get num_fields then twos_prefix - else if field_id = 0 then "1" - else twos_prefix ^ ".1" + if !Config.use_nested_tuple_projectors then + (* Nested projection: "2.2.2..." *) + if field_id = 0 then "1" + else + let twos_prefix = + String.concat "." (Collections.List.repeat field_id "2") + in + if field_id + 1 = Option.get num_fields then twos_prefix + else twos_prefix ^ ".1" + else "#" ^ string_of_int field_id else ctx_get_field proj.adt_id proj.field_id ctx in (* Open a box *) diff --git a/compiler/Main.ml b/compiler/Main.ml index 0b8ec439..4a2d01dc 100644 --- a/compiler/Main.ml +++ b/compiler/Main.ml @@ -123,6 +123,10 @@ let () = ( "-split-fwd-back", Arg.Clear return_back_funs, " Split the forward and backward functions." ); + ( "-tuple-nested-proj", + Arg.Set use_nested_tuple_projectors, + " Use nested projectors for tuples (e.g., (0, 1).snd.fst instead of \ + (0, 1).1)." ); ] in diff --git a/tests/lean/NoNestedBorrows.lean b/tests/lean/NoNestedBorrows.lean index a326bdf7..a85209ea 100644 --- a/tests/lean/NoNestedBorrows.lean +++ b/tests/lean/NoNestedBorrows.lean @@ -643,7 +643,7 @@ def Tuple (T1 T2 : Type) := T1 × T2 /- [no_nested_borrows::use_tuple_struct]: Source: 'src/no_nested_borrows.rs', lines 556:0-556:48 -/ def use_tuple_struct (x : Tuple U32 U32) : Result (Tuple U32 U32) := - Result.ret (1#u32, x.2) + Result.ret (1#u32, x.#1) /- [no_nested_borrows::create_tuple_struct]: Source: 'src/no_nested_borrows.rs', lines 560:0-560:61 -/ -- cgit v1.2.3