summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--backends/lean/Base/Primitives/Base.lean4
-rw-r--r--compiler/ExtractBase.ml4
-rw-r--r--compiler/ExtractBuiltin.ml34
-rw-r--r--compiler/FunsAnalysis.ml34
-rw-r--r--tests/lean/NoNestedBorrows.lean12
5 files changed, 62 insertions, 26 deletions
diff --git a/backends/lean/Base/Primitives/Base.lean b/backends/lean/Base/Primitives/Base.lean
index 2bd081c0..10af8f67 100644
--- a/backends/lean/Base/Primitives/Base.lean
+++ b/backends/lean/Base/Primitives/Base.lean
@@ -120,8 +120,8 @@ def Result.attach {α: Type} (o : Result α): Result { x : α // o = ret x } :=
-- MISC --
----------
-@[simp] def core.mem.replace (a : Type) (x : a) (_ : a) : Result a := ret x
-@[simp] def core.mem.replace_back (a : Type) (_ : a) (y : a) : Result a := ret y
+@[simp] def core.mem.replace (a : Type) (x : a) (_ : a) : a := x
+@[simp] def core.mem.replace_back (a : Type) (_ : a) (y : a) : a := y
/-- Aeneas-translated function -- useful to reduce non-recursive definitions.
Use with `simp [ aeneas ]` -/
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index e004aba8..8f32ba44 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -22,8 +22,8 @@ type region_group_info = {
*)
}
-module StringSet = Collections.MakeSet (Collections.OrderedString)
-module StringMap = Collections.MakeMap (Collections.OrderedString)
+module StringSet = Collections.StringSet
+module StringMap = Collections.StringMap
type name = Names.name
type type_name = Names.type_name
diff --git a/compiler/ExtractBuiltin.ml b/compiler/ExtractBuiltin.ml
index 9cc7c226..4c6fe014 100644
--- a/compiler/ExtractBuiltin.ml
+++ b/compiler/ExtractBuiltin.ml
@@ -376,7 +376,39 @@ let mk_builtin_funs_map () =
let builtin_funs_map = mk_memoized mk_builtin_funs_map
let builtin_non_fallible_funs =
- [ "alloc::boxed::Box::deref"; "alloc::boxed::Box::deref_mut" ]
+ let int_names =
+ [
+ "usize";
+ "u8";
+ "u16";
+ "u32";
+ "u64";
+ "u128";
+ "isize";
+ "i8";
+ "i16";
+ "i32";
+ "i64";
+ "i128";
+ ]
+ in
+ let int_ops =
+ [ "wrapping_add"; "wrapping_sub"; "rotate_left"; "rotate_right" ]
+ in
+ let int_funs =
+ List.map
+ (fun int_name ->
+ List.map (fun op -> "core::num::" ^ int_name ^ "::" ^ op) int_ops)
+ int_names
+ in
+ let int_funs = List.concat int_funs in
+ [
+ "alloc::boxed::Box::deref";
+ "alloc::boxed::Box::deref_mut";
+ "core::mem::replace";
+ "core::mem::take";
+ ]
+ @ int_funs
let builtin_non_fallible_funs_set =
SimpleNameSet.of_list
diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml
index 1273f57d..3ba5d35d 100644
--- a/compiler/FunsAnalysis.ml
+++ b/compiler/FunsAnalysis.ml
@@ -57,21 +57,16 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
let stateful = ref false in
let can_diverge = ref false in
let is_rec = ref false in
+ let is_builtin_non_fallible_group = ref false in
(* We have some specialized knowledge of some library functions; we don't
have any more custom treatment than this, and these functions can be modeled
suitably in Primitives.fst, rather than special-casing for them all the
way. *)
- let module M = struct type opaque_info = { fallible: bool; stateful: bool } end in
- let open M in
- let opaque_info (f: fun_decl) =
- match f.name with
- | [ Ident "core"; Ident "num"; Ident "u32"; _; Ident "wrapping_add" ]
- | [ Ident "core"; Ident "num"; Ident "u32"; _; Ident "rotate_left" ] ->
- { fallible = false; stateful = false }
- | _ ->
- (* Opaque function: we consider they fail by default *)
- { fallible = true; stateful = true }
+ let is_builtin_non_fallible (f : fun_decl) : bool =
+ let open ExtractBuiltin in
+ let name = name_to_simple_name f.name in
+ SimpleNameSet.mem name builtin_non_fallible_funs_set
in
(* JP: Why not use a reduce visitor here with a tuple of the values to be
@@ -124,11 +119,16 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
in
(* Sanity check: global bodies don't contain stateful calls *)
assert ((not f.is_global_decl_body) || not !stateful);
+ let is_builtin_non_fallible = is_builtin_non_fallible f in
+ is_builtin_non_fallible_group :=
+ !is_builtin_non_fallible_group || is_builtin_non_fallible;
match f.body with
| None ->
- let info = opaque_info f in
- obj#may_fail info.fallible;
- stateful := (not f.is_global_decl_body) && use_state && info.stateful
+ obj#may_fail (not is_builtin_non_fallible);
+ stateful :=
+ (not f.is_global_decl_body)
+ && use_state
+ && not is_builtin_non_fallible
| Some body -> obj#visit_statement () body.body
in
List.iter visit_fun d;
@@ -136,12 +136,16 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
* groups containing globals contain exactly one declaration *)
let is_global_decl_body = List.exists (fun f -> f.is_global_decl_body) d in
assert ((not is_global_decl_body) || List.length d = 1);
+ assert ((not !is_builtin_non_fallible_group) || List.length d = 1);
(* We ignore on purpose functions that cannot fail and consider they *can*
* fail: the result of the analysis is not used yet to adjust the translation
* so that the functions which syntactically can't fail don't use an error monad.
- * However, we do keep the result of the analysis for global bodies.
+ * However, we do keep the result of the analysis for global bodies and for
+ * builtin functions which are marked as non-fallible.
* *)
- can_fail := (not is_global_decl_body) || !can_fail;
+ can_fail :=
+ ((not is_global_decl_body) && not !is_builtin_non_fallible_group)
+ || !can_fail;
{
can_fail = !can_fail;
stateful = !stateful;
diff --git a/tests/lean/NoNestedBorrows.lean b/tests/lean/NoNestedBorrows.lean
index a90d6ea2..d6d603ce 100644
--- a/tests/lean/NoNestedBorrows.lean
+++ b/tests/lean/NoNestedBorrows.lean
@@ -139,9 +139,9 @@ def test_list1 : Result Unit :=
/- [no_nested_borrows::test_box1]: forward function -/
def test_box1 : Result Unit :=
- let b := (I32.ofInt 1)
+ let b := 1#i32
let x := b
- if not (x = (I32.ofInt 1))
+ if not (x = 1#i32)
then Result.fail Error.panic
else Result.ret ()
@@ -316,7 +316,7 @@ divergent def list_rev_aux
/- [no_nested_borrows::list_rev]: merged forward/backward function
(there is a single backward function, and the forward function returns ()) -/
def list_rev (T : Type) (l : List T) : Result (List T) :=
- let li := mem.replace (List T) l List.Nil
+ let li := core.mem.replace (List T) l List.Nil
list_rev_aux T li List.Nil
/- [no_nested_borrows::test_list_functions]: forward function -/
@@ -478,10 +478,10 @@ def test_weird_borrows1 : Result Unit :=
/- [no_nested_borrows::test_mem_replace]: merged forward/backward function
(there is a single backward function, and the forward function returns ()) -/
def test_mem_replace (px : U32) : Result U32 :=
- let y := mem.replace U32 px (U32.ofInt 1)
- if not (y = (U32.ofInt 0))
+ let y := core.mem.replace U32 px 1#u32
+ if not (y = 0#u32)
then Result.fail Error.panic
- else Result.ret (U32.ofInt 2)
+ else Result.ret 2#u32
/- [no_nested_borrows::test_shared_borrow_bool1]: forward function -/
def test_shared_borrow_bool1 (b : Bool) : Result U32 :=