summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJonathan Protzenko2023-01-31 15:28:17 -0800
committerSon HO2023-06-04 21:44:33 +0200
commit7586cf83f59ca784ff4bfd5d11e460fd41acec98 (patch)
tree7a023c2d7fcc8ee76e45568524c4dba1ea95f03d
parentb30bac9e20735ab47327a2ac3122c2cfce162845 (diff)
Fill out more of the primitives file, attempt at type classes for scalar_cast
-rw-r--r--backends/lean/primitives.lean28
-rw-r--r--compiler/Extract.ml25
-rw-r--r--tests/lean/hashmap_on_disk/Base/Primitives.lean28
-rw-r--r--tests/lean/hashmap_on_disk/HashmapMain/Funs.lean4
4 files changed, 69 insertions, 16 deletions
diff --git a/backends/lean/primitives.lean b/backends/lean/primitives.lean
index 3f1d13f1..0a51aacd 100644
--- a/backends/lean/primitives.lean
+++ b/backends/lean/primitives.lean
@@ -1,5 +1,6 @@
import Lean
import Init.Data.List.Basic
+import Mathlib.Tactic.RunCmd
-------------
-- PRELUDE --
@@ -140,7 +141,7 @@ def USize.checked_add (n: USize) (m: USize): result USize :=
def USize.checked_rem (n: USize) (m: USize): result USize :=
if h: m > 0 then
.ret ⟨ n.val % m.val, by
- have h1: m.val < USize.size := m.val.isLt
+ have h1: ↑m.val < USize.size := m.val.isLt
have h2: n.val.val % m.val.val < m.val.val := @Nat.mod_lt n.val m.val h
apply Nat.lt_trans h2 h1
@@ -156,13 +157,36 @@ def USize.checked_mul (n: USize) (m: USize): result USize :=
def USize.checked_div (n: USize) (m: USize): result USize :=
if m > 0 then
.ret ⟨ n.val / m.val, by
- have h1: n.val < USize.size := n.val.isLt
+ have h1: ↑n.val < USize.size := n.val.isLt
have h2: n.val.val / m.val.val <= n.val.val := @Nat.div_le_self n.val m.val
apply Nat.lt_of_le_of_lt h2 h1
else
.fail integerOverflow
+class MachineInteger (t: Type) where
+ size: Nat
+ val: t -> Fin size
+ ofNatCore: (n:Nat) -> LT.lt n size -> t
+
+set_option hygiene false in
+run_cmd
+ for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do
+ Lean.Elab.Command.elabCommand (← `(
+ namespace $typeName
+ instance: MachineInteger $typeName where
+ size := size
+ val := val
+ ofNatCore := ofNatCore
+ end $typeName
+ ))
+
+def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): result dst :=
+ if h: MachineInteger.val x < MachineInteger.size dst then
+ .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h)
+ else
+ .fail integerOverflow
+
-- One needs to perform a little bit of reasoning in order to successfully
-- inject constants into USize, so we provide a general-purpose macro
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 64069cb0..a5ff6796 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -251,13 +251,18 @@ let extract_unop (extract_expr : bool -> texpression -> unit)
if inside then F.pp_print_string fmt "(";
F.pp_print_string fmt "scalar_cast";
F.pp_print_space fmt ();
- F.pp_print_string fmt
- (StringUtils.capitalize_first_letter
- (PrintPure.integer_type_to_string src));
- F.pp_print_space fmt ();
- F.pp_print_string fmt
- (StringUtils.capitalize_first_letter
- (PrintPure.integer_type_to_string tgt));
+ if !backend <> Lean then begin
+ F.pp_print_string fmt
+ (StringUtils.capitalize_first_letter
+ (PrintPure.integer_type_to_string src));
+ F.pp_print_space fmt ()
+ end;
+ if !backend = Lean then
+ F.pp_print_string fmt (int_name tgt)
+ else
+ F.pp_print_string fmt
+ (StringUtils.capitalize_first_letter
+ (PrintPure.integer_type_to_string tgt));
F.pp_print_space fmt ();
extract_expr true arg;
if inside then F.pp_print_string fmt ")"
@@ -2345,11 +2350,11 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Termination clause *)
if has_decreases_clause && !backend = Lean then begin
let def_body = Option.get def.body in
- let vars = List.map (fun (v: var) -> v.id) def_body.inputs in
+ let all_vars = List.map (fun (v: var) -> v.id) def_body.inputs in
let num_fwd_inputs =
def.signature.info.num_fwd_inputs_with_fuel_with_state
in
- let vars = Collections.List.prefix num_fwd_inputs vars in
+ let vars = Collections.List.prefix num_fwd_inputs all_vars in
(* terminates_by *)
let terminates_name = ctx_get_terminates_clause def.def_id def.loop_id ctx in
@@ -2363,7 +2368,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter)
Collections.List.iter_link
(F.pp_print_space fmt)
(fun v -> F.pp_print_string fmt (ctx_get_var v ctx_body))
- vars;
+ all_vars;
F.pp_print_space fmt ();
F.pp_print_string fmt "=>";
F.pp_close_box fmt ();
diff --git a/tests/lean/hashmap_on_disk/Base/Primitives.lean b/tests/lean/hashmap_on_disk/Base/Primitives.lean
index 3f1d13f1..0a51aacd 100644
--- a/tests/lean/hashmap_on_disk/Base/Primitives.lean
+++ b/tests/lean/hashmap_on_disk/Base/Primitives.lean
@@ -1,5 +1,6 @@
import Lean
import Init.Data.List.Basic
+import Mathlib.Tactic.RunCmd
-------------
-- PRELUDE --
@@ -140,7 +141,7 @@ def USize.checked_add (n: USize) (m: USize): result USize :=
def USize.checked_rem (n: USize) (m: USize): result USize :=
if h: m > 0 then
.ret ⟨ n.val % m.val, by
- have h1: m.val < USize.size := m.val.isLt
+ have h1: ↑m.val < USize.size := m.val.isLt
have h2: n.val.val % m.val.val < m.val.val := @Nat.mod_lt n.val m.val h
apply Nat.lt_trans h2 h1
@@ -156,13 +157,36 @@ def USize.checked_mul (n: USize) (m: USize): result USize :=
def USize.checked_div (n: USize) (m: USize): result USize :=
if m > 0 then
.ret ⟨ n.val / m.val, by
- have h1: n.val < USize.size := n.val.isLt
+ have h1: ↑n.val < USize.size := n.val.isLt
have h2: n.val.val / m.val.val <= n.val.val := @Nat.div_le_self n.val m.val
apply Nat.lt_of_le_of_lt h2 h1
else
.fail integerOverflow
+class MachineInteger (t: Type) where
+ size: Nat
+ val: t -> Fin size
+ ofNatCore: (n:Nat) -> LT.lt n size -> t
+
+set_option hygiene false in
+run_cmd
+ for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do
+ Lean.Elab.Command.elabCommand (← `(
+ namespace $typeName
+ instance: MachineInteger $typeName where
+ size := size
+ val := val
+ ofNatCore := ofNatCore
+ end $typeName
+ ))
+
+def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): result dst :=
+ if h: MachineInteger.val x < MachineInteger.size dst then
+ .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h)
+ else
+ .fail integerOverflow
+
-- One needs to perform a little bit of reasoning in order to successfully
-- inject constants into USize, so we provide a general-purpose macro
diff --git a/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean b/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean
index f9d4a8c5..79d711e9 100644
--- a/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean
+++ b/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean
@@ -264,7 +264,7 @@ def hashmap_hash_map_insert_no_resize_fwd_back
def hashmap_hash_map_try_resize_fwd_back
(T : Type) (self : hashmap_hash_map_t T) : result (hashmap_hash_map_t T) :=
do
- let max_usize <- scalar_cast U32 Usize core_num_u32_max_c
+ let max_usize <- scalar_cast USize core_num_u32_max_c
let capacity := vec_len (hashmap_list_t T) self.hashmap_hash_map_slots
let n1 <- USize.checked_div max_usize (USize.ofNatCore 2 (by intlit))
let (i, i0) := self.hashmap_hash_map_max_load_factor
@@ -401,7 +401,7 @@ def hashmap_hash_map_insert_no_resize_fwd_back
result.ret (hashmap_list_t.HashmapListCons ckey cvalue l)
| hashmap_list_t.HashmapListNil => result.fail error.panic
- termination_by hashmap_hash_map_get_mut_in_list_loop_back ls key =>
+ termination_by hashmap_hash_map_get_mut_in_list_loop_back ls key ret0 =>
hashmap_hash_map_get_mut_in_list_loop_terminates T ls key
decreasing_by hashmap_hash_map_get_mut_in_list_loop_decreases ls key