From 9c230dddebb171ee1b3e0176838441163836b875 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Tue, 24 Oct 2023 18:16:53 +0200 Subject: Handle properly the builtin, non fallible functions --- backends/lean/Base/Primitives/Base.lean | 4 ++-- compiler/ExtractBase.ml | 4 ++-- compiler/ExtractBuiltin.ml | 34 ++++++++++++++++++++++++++++++++- compiler/FunsAnalysis.ml | 34 ++++++++++++++++++--------------- tests/lean/NoNestedBorrows.lean | 12 ++++++------ 5 files changed, 62 insertions(+), 26 deletions(-) diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean index 2bd081c0..10af8f67 100644 --- a/backends/lean/Base/Primitives/Base.lean +++ b/backends/lean/Base/Primitives/Base.lean @@ -120,8 +120,8 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } := -- MISC -- ---------- -@[simp] def core.mem.replace (a : Type) (x : a) (_ : a) : Result a := ret x -@[simp] def core.mem.replace_back (a : Type) (_ : a) (y : a) : Result a := ret y +@[simp] def core.mem.replace (a : Type) (x : a) (_ : a) : a := x +@[simp] def core.mem.replace_back (a : Type) (_ : a) (y : a) : a := y /-- Aeneas-translated function -- useful to reduce non-recursive definitions. Use with `simp [ aeneas ]` -/ diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index e004aba8..8f32ba44 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -22,8 +22,8 @@ type region_group_info = { *) } -module StringSet = Collections.MakeSet (Collections.OrderedString) -module StringMap = Collections.MakeMap (Collections.OrderedString) +module StringSet = Collections.StringSet +module StringMap = Collections.StringMap type name = Names.name type type_name = Names.type_name diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml index 9cc7c226..4c6fe014 100644 --- a/compiler/ExtractBuiltin.ml +++ b/compiler/ExtractBuiltin.ml @@ -376,7 +376,39 @@ let mk_builtin_funs_map () = let builtin_funs_map = mk_memoized mk_builtin_funs_map let builtin_non_fallible_funs = - [ "alloc::boxed::Box::deref"; "alloc::boxed::Box::deref_mut" ] + let int_names = + [ + "usize"; + "u8"; + "u16"; + "u32"; + "u64"; + "u128"; + "isize"; + "i8"; + "i16"; + "i32"; + "i64"; + "i128"; + ] + in + let int_ops = + [ "wrapping_add"; "wrapping_sub"; "rotate_left"; "rotate_right" ] + in + let int_funs = + List.map + (fun int_name -> + List.map (fun op -> "core::num::" ^ int_name ^ "::" ^ op) int_ops) + int_names + in + let int_funs = List.concat int_funs in + [ + "alloc::boxed::Box::deref"; + "alloc::boxed::Box::deref_mut"; + "core::mem::replace"; + "core::mem::take"; + ] + @ int_funs let builtin_non_fallible_funs_set = SimpleNameSet.of_list diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml index 1273f57d..3ba5d35d 100644 --- a/compiler/FunsAnalysis.ml +++ b/compiler/FunsAnalysis.ml @@ -57,21 +57,16 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) let stateful = ref false in let can_diverge = ref false in let is_rec = ref false in + let is_builtin_non_fallible_group = ref false in (* We have some specialized knowledge of some library functions; we don't have any more custom treatment than this, and these functions can be modeled suitably in Primitives.fst, rather than special-casing for them all the way. *) - let module M = struct type opaque_info = { fallible: bool; stateful: bool } end in - let open M in - let opaque_info (f: fun_decl) = - match f.name with - | [ Ident "core"; Ident "num"; Ident "u32"; _; Ident "wrapping_add" ] - | [ Ident "core"; Ident "num"; Ident "u32"; _; Ident "rotate_left" ] -> - { fallible = false; stateful = false } - | _ -> - (* Opaque function: we consider they fail by default *) - { fallible = true; stateful = true } + let is_builtin_non_fallible (f : fun_decl) : bool = + let open ExtractBuiltin in + let name = name_to_simple_name f.name in + SimpleNameSet.mem name builtin_non_fallible_funs_set in (* JP: Why not use a reduce visitor here with a tuple of the values to be @@ -124,11 +119,16 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) in (* Sanity check: global bodies don't contain stateful calls *) assert ((not f.is_global_decl_body) || not !stateful); + let is_builtin_non_fallible = is_builtin_non_fallible f in + is_builtin_non_fallible_group := + !is_builtin_non_fallible_group || is_builtin_non_fallible; match f.body with | None -> - let info = opaque_info f in - obj#may_fail info.fallible; - stateful := (not f.is_global_decl_body) && use_state && info.stateful + obj#may_fail (not is_builtin_non_fallible); + stateful := + (not f.is_global_decl_body) + && use_state + && not is_builtin_non_fallible | Some body -> obj#visit_statement () body.body in List.iter visit_fun d; @@ -136,12 +136,16 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t) * groups containing globals contain exactly one declaration *) let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in assert ((not is_global_decl_body) || List.length d = 1); + assert ((not !is_builtin_non_fallible_group) || List.length d = 1); (* We ignore on purpose functions that cannot fail and consider they *can* * fail: the result of the analysis is not used yet to adjust the translation * so that the functions which syntactically can't fail don't use an error monad. - * However, we do keep the result of the analysis for global bodies. + * However, we do keep the result of the analysis for global bodies and for + * builtin functions which are marked as non-fallible. * *) - can_fail := (not is_global_decl_body) || !can_fail; + can_fail := + ((not is_global_decl_body) && not !is_builtin_non_fallible_group) + || !can_fail; { can_fail = !can_fail; stateful = !stateful; diff --git a/tests/lean/NoNestedBorrows.lean b/tests/lean/NoNestedBorrows.lean index a90d6ea2..d6d603ce 100644 --- a/tests/lean/NoNestedBorrows.lean +++ b/tests/lean/NoNestedBorrows.lean @@ -139,9 +139,9 @@ def test_list1 : Result Unit := /- [no_nested_borrows::test_box1]: forward function -/ def test_box1 : Result Unit := - let b := (I32.ofInt 1) + let b := 1#i32 let x := b - if not (x = (I32.ofInt 1)) + if not (x = 1#i32) then Result.fail Error.panic else Result.ret () @@ -316,7 +316,7 @@ divergent def list_rev_aux /- [no_nested_borrows::list_rev]: merged forward/backward function (there is a single backward function, and the forward function returns ()) -/ def list_rev (T : Type) (l : List T) : Result (List T) := - let li := mem.replace (List T) l List.Nil + let li := core.mem.replace (List T) l List.Nil list_rev_aux T li List.Nil /- [no_nested_borrows::test_list_functions]: forward function -/ @@ -478,10 +478,10 @@ def test_weird_borrows1 : Result Unit := /- [no_nested_borrows::test_mem_replace]: merged forward/backward function (there is a single backward function, and the forward function returns ()) -/ def test_mem_replace (px : U32) : Result U32 := - let y := mem.replace U32 px (U32.ofInt 1) - if not (y = (U32.ofInt 0)) + let y := core.mem.replace U32 px 1#u32 + if not (y = 0#u32) then Result.fail Error.panic - else Result.ret (U32.ofInt 2) + else Result.ret 2#u32 /- [no_nested_borrows::test_shared_borrow_bool1]: forward function -/ def test_shared_borrow_bool1 (b : Bool) : Result U32 := -- cgit v1.2.3