diff options
author | Jonathan Protzenko | 2023-01-31 15:28:17 -0800 |
---|---|---|
committer | Son HO | 2023-06-04 21:44:33 +0200 |
commit | 7586cf83f59ca784ff4bfd5d11e460fd41acec98 (patch) | |
tree | 7a023c2d7fcc8ee76e45568524c4dba1ea95f03d | |
parent | b30bac9e20735ab47327a2ac3122c2cfce162845 (diff) |
Fill out more of the primitives file, attempt at type classes for scalar_cast
-rw-r--r-- | backends/lean/primitives.lean | 28 | ||||
-rw-r--r-- | compiler/Extract.ml | 25 | ||||
-rw-r--r-- | tests/lean/hashmap_on_disk/Base/Primitives.lean | 28 | ||||
-rw-r--r-- | tests/lean/hashmap_on_disk/HashmapMain/Funs.lean | 4 |
4 files changed, 69 insertions, 16 deletions
diff --git a/backends/lean/primitives.lean b/backends/lean/primitives.lean index 3f1d13f1..0a51aacd 100644 --- a/backends/lean/primitives.lean +++ b/backends/lean/primitives.lean @@ -1,5 +1,6 @@ import Lean import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd ------------- -- PRELUDE -- @@ -140,7 +141,7 @@ def USize.checked_add (n: USize) (m: USize): result USize := def USize.checked_rem (n: USize) (m: USize): result USize := if h: m > 0 then .ret ⟨ n.val % m.val, by - have h1: m.val < USize.size := m.val.isLt + have h1: ↑m.val < USize.size := m.val.isLt have h2: n.val.val % m.val.val < m.val.val := @Nat.mod_lt n.val m.val h apply Nat.lt_trans h2 h1 ⟩ @@ -156,13 +157,36 @@ def USize.checked_mul (n: USize) (m: USize): result USize := def USize.checked_div (n: USize) (m: USize): result USize := if m > 0 then .ret ⟨ n.val / m.val, by - have h1: n.val < USize.size := n.val.isLt + have h1: ↑n.val < USize.size := n.val.isLt have h2: n.val.val / m.val.val <= n.val.val := @Nat.div_le_self n.val m.val apply Nat.lt_of_le_of_lt h2 h1 ⟩ else .fail integerOverflow +class MachineInteger (t: Type) where + size: Nat + val: t -> Fin size + ofNatCore: (n:Nat) -> LT.lt n size -> t + +set_option hygiene false in +run_cmd + for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do + Lean.Elab.Command.elabCommand (← `( + namespace $typeName + instance: MachineInteger $typeName where + size := size + val := val + ofNatCore := ofNatCore + end $typeName + )) + +def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): result dst := + if h: MachineInteger.val x < MachineInteger.size dst then + .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h) + else + .fail integerOverflow + -- One needs to perform a little bit of reasoning in order to successfully -- inject constants into USize, so we provide a general-purpose macro diff --git a/compiler/Extract.ml b/compiler/Extract.ml index 64069cb0..a5ff6796 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -251,13 +251,18 @@ let extract_unop (extract_expr : bool -> texpression -> unit) if inside then F.pp_print_string fmt "("; F.pp_print_string fmt "scalar_cast"; F.pp_print_space fmt (); - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string src)); - F.pp_print_space fmt (); - F.pp_print_string fmt - (StringUtils.capitalize_first_letter - (PrintPure.integer_type_to_string tgt)); + if !backend <> Lean then begin + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string src)); + F.pp_print_space fmt () + end; + if !backend = Lean then + F.pp_print_string fmt (int_name tgt) + else + F.pp_print_string fmt + (StringUtils.capitalize_first_letter + (PrintPure.integer_type_to_string tgt)); F.pp_print_space fmt (); extract_expr true arg; if inside then F.pp_print_string fmt ")" @@ -2345,11 +2350,11 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) (* Termination clause *) if has_decreases_clause && !backend = Lean then begin let def_body = Option.get def.body in - let vars = List.map (fun (v: var) -> v.id) def_body.inputs in + let all_vars = List.map (fun (v: var) -> v.id) def_body.inputs in let num_fwd_inputs = def.signature.info.num_fwd_inputs_with_fuel_with_state in - let vars = Collections.List.prefix num_fwd_inputs vars in + let vars = Collections.List.prefix num_fwd_inputs all_vars in (* terminates_by *) let terminates_name = ctx_get_terminates_clause def.def_id def.loop_id ctx in @@ -2363,7 +2368,7 @@ let extract_fun_decl (ctx : extraction_ctx) (fmt : F.formatter) Collections.List.iter_link (F.pp_print_space fmt) (fun v -> F.pp_print_string fmt (ctx_get_var v ctx_body)) - vars; + all_vars; F.pp_print_space fmt (); F.pp_print_string fmt "=>"; F.pp_close_box fmt (); diff --git a/tests/lean/hashmap_on_disk/Base/Primitives.lean b/tests/lean/hashmap_on_disk/Base/Primitives.lean index 3f1d13f1..0a51aacd 100644 --- a/tests/lean/hashmap_on_disk/Base/Primitives.lean +++ b/tests/lean/hashmap_on_disk/Base/Primitives.lean @@ -1,5 +1,6 @@ import Lean import Init.Data.List.Basic +import Mathlib.Tactic.RunCmd ------------- -- PRELUDE -- @@ -140,7 +141,7 @@ def USize.checked_add (n: USize) (m: USize): result USize := def USize.checked_rem (n: USize) (m: USize): result USize := if h: m > 0 then .ret ⟨ n.val % m.val, by - have h1: m.val < USize.size := m.val.isLt + have h1: ↑m.val < USize.size := m.val.isLt have h2: n.val.val % m.val.val < m.val.val := @Nat.mod_lt n.val m.val h apply Nat.lt_trans h2 h1 ⟩ @@ -156,13 +157,36 @@ def USize.checked_mul (n: USize) (m: USize): result USize := def USize.checked_div (n: USize) (m: USize): result USize := if m > 0 then .ret ⟨ n.val / m.val, by - have h1: n.val < USize.size := n.val.isLt + have h1: ↑n.val < USize.size := n.val.isLt have h2: n.val.val / m.val.val <= n.val.val := @Nat.div_le_self n.val m.val apply Nat.lt_of_le_of_lt h2 h1 ⟩ else .fail integerOverflow +class MachineInteger (t: Type) where + size: Nat + val: t -> Fin size + ofNatCore: (n:Nat) -> LT.lt n size -> t + +set_option hygiene false in +run_cmd + for typeName in [`UInt8, `UInt16, `UInt32, `UInt64, `USize].map Lean.mkIdent do + Lean.Elab.Command.elabCommand (← `( + namespace $typeName + instance: MachineInteger $typeName where + size := size + val := val + ofNatCore := ofNatCore + end $typeName + )) + +def scalar_cast { src: Type } (dst: Type) [ MachineInteger src ] [ MachineInteger dst ] (x: src): result dst := + if h: MachineInteger.val x < MachineInteger.size dst then + .ret (MachineInteger.ofNatCore (MachineInteger.val x).val h) + else + .fail integerOverflow + -- One needs to perform a little bit of reasoning in order to successfully -- inject constants into USize, so we provide a general-purpose macro diff --git a/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean b/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean index f9d4a8c5..79d711e9 100644 --- a/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean +++ b/tests/lean/hashmap_on_disk/HashmapMain/Funs.lean @@ -264,7 +264,7 @@ def hashmap_hash_map_insert_no_resize_fwd_back def hashmap_hash_map_try_resize_fwd_back (T : Type) (self : hashmap_hash_map_t T) : result (hashmap_hash_map_t T) := do - let max_usize <- scalar_cast U32 Usize core_num_u32_max_c + let max_usize <- scalar_cast USize core_num_u32_max_c let capacity := vec_len (hashmap_list_t T) self.hashmap_hash_map_slots let n1 <- USize.checked_div max_usize (USize.ofNatCore 2 (by intlit)) let (i, i0) := self.hashmap_hash_map_max_load_factor @@ -401,7 +401,7 @@ def hashmap_hash_map_insert_no_resize_fwd_back result.ret (hashmap_list_t.HashmapListCons ckey cvalue l) | hashmap_list_t.HashmapListNil => result.fail error.panic - termination_by hashmap_hash_map_get_mut_in_list_loop_back ls key => + termination_by hashmap_hash_map_get_mut_in_list_loop_back ls key ret0 => hashmap_hash_map_get_mut_in_list_loop_terminates T ls key decreasing_by hashmap_hash_map_get_mut_in_list_loop_decreases ls key |