diff options
-rw-r--r-- | backends/lean/Base/Primitives.lean | 1 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/Scalar.lean | 123 | ||||
-rw-r--r-- | backends/lean/Base/Primitives/ScalarNotations.lean | 90 | ||||
-rw-r--r-- | compiler/Extract.ml | 6 | ||||
-rw-r--r-- | compiler/ExtractTypes.ml | 14 | ||||
-rw-r--r-- | tests/lean/Matches.lean | 4 |
6 files changed, 127 insertions, 111 deletions
diff --git a/backends/lean/Base/Primitives.lean b/backends/lean/Base/Primitives.lean index f80c2004..93617049 100644 --- a/backends/lean/Base/Primitives.lean +++ b/backends/lean/Base/Primitives.lean @@ -1,6 +1,7 @@ import Base.Primitives.Base import Base.Tuples import Base.Primitives.Scalar +import Base.Primitives.ScalarNotations import Base.Primitives.ArraySlice import Base.Primitives.Vec import Base.Primitives.Alloc diff --git a/backends/lean/Base/Primitives/Scalar.lean b/backends/lean/Base/Primitives/Scalar.lean index 157ade2c..f4264b9b 100644 --- a/backends/lean/Base/Primitives/Scalar.lean +++ b/backends/lean/Base/Primitives/Scalar.lean @@ -312,38 +312,13 @@ theorem Scalar.bound_suffices (ty : ScalarTy) (x : Int) : λ h => by apply And.intro <;> have hmin := Scalar.cMin_bound ty <;> have hmax := Scalar.cMax_bound ty <;> linarith -/- [match_pattern] attribute: allows to us `Scalar.ofIntCore` inside of patterns. - This is particularly useful once we introduce notations like `#u32` (which - desugards to `Scalar.ofIntCore`) as it allows to write expressions like this: - Example: - ``` - match x with - | 0#u32 => ... - | 1#u32 => ... - | ... - ``` - -/ -@[match_pattern] def Scalar.ofIntCore {ty : ScalarTy} (x : Int) +def Scalar.ofIntCore {ty : ScalarTy} (x : Int) (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : Scalar ty := { val := x, hmin := h.left, hmax := h.right } --- The definitions below are used later to introduce nice syntax for constants, --- like `1#u32`. We are reusing the technique described here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/Different.20elaboration.20inside.2Foutside.20of.20match.20patterns/near/425455284 - -class InBounds (ty : ScalarTy) (x : Int) := - hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty - --- This trick to trigger reduction for decidable propositions comes from --- here: https://leanprover.zulipchat.com/#narrow/stream/270676-lean4/topic/instance.20with.20tactic.20autoparam/near/343495807 -class Decide (p : Prop) [Decidable p] : Prop where - isTrue : p -instance : @Decide p (.isTrue h) := @Decide.mk p (_) h - -instance [Decide (Scalar.cMin ty ≤ v ∧ v ≤ Scalar.cMax ty)] : InBounds ty v where - hInBounds := Decide.isTrue - -@[reducible, match_pattern] def Scalar.ofInt {ty : ScalarTy} (x : Int) [InBounds ty x] : Scalar ty := - Scalar.ofIntCore x (Scalar.bound_suffices ty x InBounds.hInBounds) +@[reducible] def Scalar.ofInt {ty : ScalarTy} (x : Int) + (hInBounds : Scalar.cMin ty ≤ x ∧ x ≤ Scalar.cMax ty := by decide) : Scalar ty := + Scalar.ofIntCore x (Scalar.bound_suffices ty x hInBounds) @[simp] abbrev Scalar.in_bounds (ty : ScalarTy) (x : Int) : Prop := Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty @@ -412,9 +387,8 @@ theorem Scalar.tryMk_eq (ty : ScalarTy) (x : Int) : simp [tryMk, ofOption, tryMkOpt] split_ifs <;> simp -instance (ty: ScalarTy) : InBounds ty 0 where - hInBounds := by - induction ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide +@[simp] theorem zero_in_cbounds {ty : ScalarTy} : Scalar.cMin ty ≤ 0 ∧ 0 ≤ Scalar.cMax ty := by + cases ty <;> simp [Scalar.cMax, Scalar.cMin, Scalar.max, Scalar.min] <;> decide def Scalar.neg {ty : ScalarTy} (x : Scalar ty) : Result (Scalar ty) := Scalar.tryMk ty (- x.val) @@ -1268,73 +1242,18 @@ def U128.ofIntCore := @Scalar.ofIntCore .U128 -- ofInt -- TODO: typeclass? -@[match_pattern] abbrev Isize.ofInt := @Scalar.ofInt .Isize -@[match_pattern] abbrev I8.ofInt := @Scalar.ofInt .I8 -@[match_pattern] abbrev I16.ofInt := @Scalar.ofInt .I16 -@[match_pattern] abbrev I32.ofInt := @Scalar.ofInt .I32 -@[match_pattern] abbrev I64.ofInt := @Scalar.ofInt .I64 -@[match_pattern] abbrev I128.ofInt := @Scalar.ofInt .I128 -@[match_pattern] abbrev Usize.ofInt := @Scalar.ofInt .Usize -@[match_pattern] abbrev U8.ofInt := @Scalar.ofInt .U8 -@[match_pattern] abbrev U16.ofInt := @Scalar.ofInt .U16 -@[match_pattern] abbrev U32.ofInt := @Scalar.ofInt .U32 -@[match_pattern] abbrev U64.ofInt := @Scalar.ofInt .U64 -@[match_pattern] abbrev U128.ofInt := @Scalar.ofInt .U128 - -postfix:max "#isize" => Isize.ofInt -postfix:max "#i8" => I8.ofInt -postfix:max "#i16" => I16.ofInt -postfix:max "#i32" => I32.ofInt -postfix:max "#i64" => I64.ofInt -postfix:max "#i128" => I128.ofInt -postfix:max "#usize" => Usize.ofInt -postfix:max "#u8" => U8.ofInt -postfix:max "#u16" => U16.ofInt -postfix:max "#u32" => U32.ofInt -postfix:max "#u64" => U64.ofInt -postfix:max "#u128" => U128.ofInt - -/- Testing the notations -/ -example := 0#u32 -example := 1#u32 -example := 1#i32 -example := 0#isize -example := (-1)#isize -example (x : U32) : Bool := - match x with - | 0#u32 => true - | _ => false - -example (x : U32) : Bool := - match x with - | 1#u32 => true - | _ => false - -example (x : I32) : Bool := - match x with - | (-1)#i32 => true - | _ => false - --- Notation for pattern matching --- We make the precedence looser than the negation. -notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ - -example {ty} (x : Scalar ty) : ℤ := - match x with - | v#scalar => v - -example {ty} (x : Scalar ty) : Bool := - match x with - | 1#scalar => true - | _ => false - -example {ty} (x : Scalar ty) : Bool := - match x with - | -(1 : Int)#scalar => true - | _ => false - --- Testing the notations -example : Result Usize := 0#usize + 1#usize +abbrev Isize.ofInt := @Scalar.ofInt .Isize +abbrev I8.ofInt := @Scalar.ofInt .I8 +abbrev I16.ofInt := @Scalar.ofInt .I16 +abbrev I32.ofInt := @Scalar.ofInt .I32 +abbrev I64.ofInt := @Scalar.ofInt .I64 +abbrev I128.ofInt := @Scalar.ofInt .I128 +abbrev Usize.ofInt := @Scalar.ofInt .Usize +abbrev U8.ofInt := @Scalar.ofInt .U8 +abbrev U16.ofInt := @Scalar.ofInt .U16 +abbrev U32.ofInt := @Scalar.ofInt .U32 +abbrev U64.ofInt := @Scalar.ofInt .U64 +abbrev U128.ofInt := @Scalar.ofInt .U128 -- TODO: factor those lemmas out @[simp] theorem Scalar.ofInt_val_eq {ty} (h : Scalar.min ty ≤ x ∧ x ≤ Scalar.max ty) : (Scalar.ofIntCore x h).val = x := by @@ -1464,18 +1383,18 @@ theorem coe_max {ty: ScalarTy} (a b: Scalar ty): ↑(Max.max a b) = (Max.max (↠-- Max theory -- TODO: do the min theory later on. -theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 ≤ x := by +theorem Scalar.zero_le_unsigned {ty} (s: ¬ ty.isSigned) (x: Scalar ty): Scalar.ofInt 0 (by simp) ≤ x := by apply (Scalar.le_equiv _ _).2 convert x.hmin cases ty <;> simp [ScalarTy.isSigned] at s <;> simp [Scalar.min] @[simp] theorem Scalar.max_unsigned_left_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty): - Max.max (Scalar.ofInt 0) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x) + Max.max (Scalar.ofInt 0 (by simp)) x = x := max_eq_right (Scalar.zero_le_unsigned s.out x) @[simp] theorem Scalar.max_unsigned_right_zero_eq {ty} [s: Fact (¬ ty.isSigned)] (x: Scalar ty): - Max.max x (Scalar.ofInt 0) = x := max_eq_left (Scalar.zero_le_unsigned s.out x) + Max.max x (Scalar.ofInt 0 (by simp)) = x := max_eq_left (Scalar.zero_le_unsigned s.out x) -- Leading zeros def core.num.Usize.leading_zeros (x : Usize) : U32 := sorry diff --git a/backends/lean/Base/Primitives/ScalarNotations.lean b/backends/lean/Base/Primitives/ScalarNotations.lean new file mode 100644 index 00000000..cc6e6f02 --- /dev/null +++ b/backends/lean/Base/Primitives/ScalarNotations.lean @@ -0,0 +1,90 @@ +import Lean +import Lean.Meta.Tactic.Simp +import Mathlib.Tactic.Linarith +import Base.Primitives.Scalar +import Base.Arith + +namespace Primitives + +open Lean Meta Elab Term + +macro x:term:max "#isize" : term => `(Isize.ofInt $x (by scalar_tac)) +macro x:term:max "#i8" : term => `(I8.ofInt $x (by scalar_tac)) +macro x:term:max "#i16" : term => `(I16.ofInt $x (by scalar_tac)) +macro x:term:max "#i32" : term => `(I32.ofInt $x (by scalar_tac)) +macro x:term:max "#i64" : term => `(I64.ofInt $x (by scalar_tac)) +macro x:term:max "#i128" : term => `(I128.ofInt $x (by scalar_tac)) +macro x:term:max "#usize" : term => `(Usize.ofInt $x (by scalar_tac)) +macro x:term:max "#u8" : term => `(U8.ofInt $x (by scalar_tac)) +macro x:term:max "#u16" : term => `(U16.ofInt $x (by scalar_tac)) +macro x:term:max "#u32" : term => `(U32.ofInt $x (by scalar_tac)) +macro x:term:max "#u64" : term => `(U64.ofInt $x (by scalar_tac)) +macro x:term:max "#u128" : term => `(U128.ofInt $x (by scalar_tac)) + +macro x:term:max noWs "u32" : term => `(U32.ofInt $x (by scalar_tac)) + +-- Notation for pattern matching +-- We make the precedence looser than the negation. +notation:70 a:70 "#scalar" => Scalar.mk (a) _ _ + +/- Testing the notations -/ +example := 0#u32 +example := 1#u32 +example := 1#i32 +example := 0#isize +example := (-1)#isize + +/- +-- This doesn't work anymore +example (x : U32) : Bool := + match x with + | 0#u32 => true + | _ => false + +example (x : U32) : Bool := + match x with + | 1#u32 => true + | _ => false + +example (x : I32) : Bool := + match x with + | (-1)#i32 => true + | _ => false +-/ + +example (x : U32) : Bool := + match x with + | 0#scalar => true + | _ => false + +example (x : U32) : Bool := + match x with + | 1#scalar => true + | _ => false + +example (x : I32) : Bool := + match x with + | (-1)#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : ℤ := + match x with + | v#scalar => v + +example {ty} (x : Scalar ty) : Bool := + match x with + | 1#scalar => true + | _ => false + +example {ty} (x : Scalar ty) : Bool := + match x with + | -(1 : Int)#scalar => true + | _ => false + +-- Testing the notations +example : Result Usize := 0#usize + 1#usize + +-- More complex expressions +example (x y : Int) (h : 0 ≤ x + y ∧ x + y ≤ 1000) : U32 := (x + y)#u32 + +end Primitives diff --git a/compiler/Extract.ml b/compiler/Extract.ml index eab85054..4acf3f99 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -241,11 +241,12 @@ let rec extract_typed_pattern (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter) (is_let : bool) (inside : bool) ?(with_type = false) (v : typed_pattern) : extraction_ctx = if with_type then F.pp_print_string fmt "("; + let is_pattern = true in let inside = inside && not with_type in let ctx = match v.value with | PatConstant cv -> - extract_literal span fmt inside cv; + extract_literal span fmt is_pattern inside cv; ctx | PatVar (v, _) -> let vname = ctx_compute_var_basename span ctx v.basename v.ty in @@ -307,6 +308,7 @@ let extract_texpression_errors (fmt : F.formatter) = let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool) (e : texpression) : unit = + let is_pattern = false in match e.e with | Var var_id -> let var_name = ctx_get_var span var_id ctx in @@ -314,7 +316,7 @@ let rec extract_texpression (span : Meta.span) (ctx : extraction_ctx) | CVar var_id -> let var_name = ctx_get_const_generic_var span var_id ctx in F.pp_print_string fmt var_name - | Const cv -> extract_literal span fmt inside cv + | Const cv -> extract_literal span fmt is_pattern inside cv | App _ -> let app, args = destruct_apps e in extract_App span ctx fmt inside app args diff --git a/compiler/ExtractTypes.ml b/compiler/ExtractTypes.ml index 15e75da2..631db13e 100644 --- a/compiler/ExtractTypes.ml +++ b/compiler/ExtractTypes.ml @@ -11,12 +11,13 @@ include ExtractBase Inputs: - formatter + - [is_pattern]: if [true], it means we are generating a (match) pattern - [inside]: if [true], the value should be wrapped in parentheses if it is made of an application (ex.: [U32 3]) - the constant value *) -let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool) - (cv : literal) : unit = +let extract_literal (span : Meta.span) (fmt : F.formatter) (is_pattern : bool) + (inside : bool) (cv : literal) : unit = match cv with | VScalar sv -> ( match backend () with @@ -39,8 +40,11 @@ let extract_literal (span : Meta.span) (fmt : F.formatter) (inside : bool) let iname = int_name sv.int_ty in F.pp_print_string fmt ("%" ^ iname) | Lean -> - let iname = String.lowercase_ascii (int_name sv.int_ty) in - F.pp_print_string fmt ("#" ^ iname) + (* We don't use the same notation for patterns and regular literals *) + if is_pattern then F.pp_print_string fmt "#scalar" + else + let iname = String.lowercase_ascii (int_name sv.int_ty) in + F.pp_print_string fmt ("#" ^ iname) | HOL4 -> () | _ -> craise __FILE__ __LINE__ span "Unreachable"); if print_brackets then F.pp_print_string fmt ")") @@ -409,7 +413,7 @@ let extract_const_generic (span : Meta.span) (ctx : extraction_ctx) | CgGlobal id -> let s = ctx_get_global span id ctx in F.pp_print_string fmt s - | CgValue v -> extract_literal span fmt inside v + | CgValue v -> extract_literal span fmt false inside v | CgVar id -> let s = ctx_get_const_generic_var span id ctx in F.pp_print_string fmt s diff --git a/tests/lean/Matches.lean b/tests/lean/Matches.lean index 3e3a558b..9233841b 100644 --- a/tests/lean/Matches.lean +++ b/tests/lean/Matches.lean @@ -9,8 +9,8 @@ namespace matches Source: 'tests/src/matches.rs', lines 4:0-4:27 -/ def match_u32 (x : U32) : Result U32 := match x with - | 0#u32 => Result.ok 0#u32 - | 1#u32 => Result.ok 1#u32 + | 0#scalar => Result.ok 0#u32 + | 1#scalar => Result.ok 1#u32 | _ => Result.ok 2#u32 end matches |