summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--fstar/Primitives.fst49
-rw-r--r--src/Assumed.ml69
-rw-r--r--src/CfimAstUtils.ml12
-rw-r--r--src/PureMicroPasses.ml13
-rw-r--r--src/SymbolicToPure.ml6
-rw-r--r--src/Translate.ml4
-rw-r--r--src/main.ml1
7 files changed, 105 insertions, 49 deletions
diff --git a/fstar/Primitives.fst b/fstar/Primitives.fst
index 6b1ec539..78697259 100644
--- a/fstar/Primitives.fst
+++ b/fstar/Primitives.fst
@@ -1,9 +1,19 @@
/// This file lists primitive and assumed functions and types
module Primitives
open FStar.Mul
+open FStar.List.Tot
#set-options "--z3rlimit 15 --fuel 0 --ifuel 1"
+(*** Utilities *)
+val list_update (#a : Type0) (ls : list a) (i : nat{i < length ls}) (x : a) :
+ ls':list a{length ls' = length ls}
+#push-options "--fuel 1"
+let rec list_update #a ls i x =
+ match ls with
+ | x' :: ls -> if i = 0 then x :: ls else x' :: list_update ls (i-1) x
+#pop-options
+
(*** Result *)
type result (a : Type0) : Type0 =
| Return : a -> result a
@@ -23,6 +33,10 @@ let massert (b:bool) : result unit = if b then Return () else Fail
(*** Misc *)
type char = FStar.Char.char
+type string = string
+
+let mem_replace_fwd (a : Type0) (x : a) (y : a) : a = x
+let mem_replace_back (a : Type0) (x : a) (y : a) : a = y
(*** Scalars *)
/// Rk.: most of the following code was at least partially generated
@@ -219,3 +233,38 @@ let u16_mul = scalar_mul #U16
let u32_mul = scalar_mul #U32
let u64_mul = scalar_mul #U64
let u128_mul = scalar_mul #U128
+
+(*** Vector *)
+type vec (a : Type0) = v:list a{length v <= usize_max}
+
+let vec_new (a : Type0) : vec a = assert_norm(length #a [] == 0); []
+let vec_len (a : Type0) (v : vec a) : usize = length v
+
+// The **forward** function shouldn't be used
+let vec_push_fwd (a : Type0) (v : vec a) (x : a) : unit = ()
+let vec_push_back (a : Type0) (v : vec a) (x : a) : result (vec a) =
+ if length v < usize_max then begin
+ (**) assert_norm(length [x] == 1);
+ (**) append_length v [x];
+ (**) assert(length (append v [x]) = length v + 1);
+ Return (append v [x])
+ end
+ else Fail
+
+// The **forward** function shouldn't be used
+let vec_insert_fwd (a : Type0) (v : vec a) (i : usize) (x : a) : result unit =
+ if i < length v then Return () else Fail
+let vec_insert_back (a : Type0) (v : vec a) (i : usize) (x : a) : result (vec a) =
+ if i < length v then Return (list_update v i x) else Fail
+
+// The **backward** function shouldn't be used
+let vec_index_fwd (a : Type0) (v : vec a) (i : usize) : result a =
+ if i < length v then Return (index v i) else Fail
+let vec_index_back (a : Type0) (v : vec a) (i : usize) (x : a) : result unit =
+ if i < length v then Return () else Fail
+
+let vec_index_mut_fwd (a : Type0) (v : vec a) (i : usize) : result a =
+ if i < length v then Return (index v i) else Fail
+let vec_index_mut_back (a : Type0) (v : vec a) (i : usize) (nx : a) : result (vec a) =
+ if i < length v then Return (list_update v i nx) else Fail
+
diff --git a/src/Assumed.ml b/src/Assumed.ml
index 5a9fb51b..527b2395 100644
--- a/src/Assumed.ml
+++ b/src/Assumed.ml
@@ -227,37 +227,54 @@ module Sig = struct
let vec_index_mut_sig : A.fun_sig = vec_index_gen_sig true
end
-(** The list of assumed functions, and their signatures.
+type assumed_info = A.assumed_fun_id * A.fun_sig * bool * Identifiers.name
+
+(** The list of assumed functions and all their information:
+ - their signature
+ - a boolean indicating whether they are monadic or not (i.e., if they
+ can fail or not)
+ - their name
Rk.: following what is written above, we don't include `Box::free`.
+
+ Remark about the vector functions: for `Vec::len` to be correct and return
+ a `usize`, we have to make sure that vectors are bounded by the max usize.
+ Followingly, `Vec::push` is monadic.
*)
-let assumed_sigs : (A.assumed_fun_id * A.fun_sig) list =
+let assumed_infos : assumed_info list =
+ let deref_pre = [ "core"; "ops"; "deref" ] in
+ let vec_pre = [ "alloc"; "vec"; "Vec" ] in
+ let index_pre = [ "core"; "ops"; "index" ] in
[
- (Replace, Sig.mem_replace_sig);
- (BoxNew, Sig.box_new_sig);
- (BoxDeref, Sig.box_deref_shared_sig);
- (BoxDerefMut, Sig.box_deref_mut_sig);
- (VecNew, Sig.vec_new_sig);
- (VecPush, Sig.vec_push_sig);
- (VecInsert, Sig.vec_insert_sig);
- (VecLen, Sig.vec_len_sig);
- (VecIndex, Sig.vec_index_shared_sig);
- (VecIndexMut, Sig.vec_index_mut_sig);
+ (Replace, Sig.mem_replace_sig, false, [ "core"; "mem"; "replace" ]);
+ (BoxNew, Sig.box_new_sig, false, [ "alloc"; "boxed"; "Box"; "new" ]);
+ (BoxDeref, Sig.box_deref_shared_sig, false, deref_pre @ [ "Deref"; "deref" ]);
+ ( BoxDerefMut,
+ Sig.box_deref_mut_sig,
+ false,
+ deref_pre @ [ "DerefMut"; "deref_mut" ] );
+ (VecNew, Sig.vec_new_sig, false, vec_pre @ [ "new" ]);
+ (VecPush, Sig.vec_push_sig, true, vec_pre @ [ "push" ]);
+ (VecInsert, Sig.vec_insert_sig, true, vec_pre @ [ "insert" ]);
+ (VecLen, Sig.vec_len_sig, false, vec_pre @ [ "len" ]);
+ (VecIndex, Sig.vec_index_shared_sig, true, index_pre @ [ "Index"; "index" ]);
+ ( VecIndexMut,
+ Sig.vec_index_mut_sig,
+ true,
+ index_pre @ [ "IndexMut"; "index_mut" ] );
]
+let get_assumed_info (id : A.assumed_fun_id) : assumed_info =
+ List.find (fun (id', _, _, _) -> id = id') assumed_infos
+
let get_assumed_sig (id : A.assumed_fun_id) : A.fun_sig =
- snd (List.find (fun (id', _) -> id = id') assumed_sigs)
+ let _, sg, _, _ = get_assumed_info id in
+ sg
-let assumed_names : (A.assumed_fun_id * Identifiers.name) list =
- [
- (Replace, [ "core"; "mem"; "replace" ]);
- (BoxNew, [ "alloc"; "boxed"; "Box"; "new" ]);
- (BoxDeref, [ "core"; "ops"; "deref"; "Deref"; "deref" ]);
- (BoxDerefMut, [ "core"; "ops"; "deref"; "DerefMut"; "deref_mut" ]);
- (VecNew, [ "alloc"; "vec"; "Vec"; "new" ]);
- (VecPush, [ "alloc"; "vec"; "Vec"; "push" ]);
- (VecInsert, [ "alloc"; "vec"; "Vec"; "insert" ]);
- (VecLen, [ "alloc"; "vec"; "Vec"; "len" ]);
- (VecIndex, [ "core"; "ops"; "index"; "Index"; "index" ]);
- (VecIndexMut, [ "core"; "ops"; "index"; "IndexMut"; "index_mut" ]);
- ]
+let get_assumed_name (id : A.assumed_fun_id) : Identifiers.name =
+ let _, _, _, name = get_assumed_info id in
+ name
+
+let assumed_is_monadic (id : A.assumed_fun_id) : bool =
+ let _, _, b, _ = get_assumed_info id in
+ b
diff --git a/src/CfimAstUtils.ml b/src/CfimAstUtils.ml
index 902156f2..6a2f680a 100644
--- a/src/CfimAstUtils.ml
+++ b/src/CfimAstUtils.ml
@@ -23,21 +23,13 @@ let lookup_fun_sig (fun_id : fun_id) (fun_defs : fun_def FunDefId.Map.t) :
fun_sig =
match fun_id with
| Local id -> (FunDefId.Map.find id fun_defs).signature
- | Assumed aid ->
- let _, sg =
- List.find (fun (aid', _) -> aid = aid') Assumed.assumed_sigs
- in
- sg
+ | Assumed aid -> Assumed.get_assumed_sig aid
let lookup_fun_name (fun_id : fun_id) (fun_defs : fun_def FunDefId.Map.t) :
Identifiers.name =
match fun_id with
| Local id -> (FunDefId.Map.find id fun_defs).name
- | Assumed aid ->
- let _, sg =
- List.find (fun (aid', _) -> aid = aid') Assumed.assumed_names
- in
- sg
+ | Assumed aid -> Assumed.get_assumed_name aid
(** Small utility: list the transitive parents of a region var group.
We don't do that in an efficient manner, but it doesn't matter.
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 5ac2af4e..99598937 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -48,6 +48,8 @@ type config = {
See the comments for [expression_contains_child_call_in_all_paths]
for additional explanations.
*)
+ add_unit_args : bool;
+ (** Add unit input arguments to functions with no arguments. *)
}
(** A configuration to control the application of the passes *)
@@ -615,16 +617,15 @@ let filter_if_backward_with_no_outputs (def : fun_def) : fun_def option =
if Option.is_some def.back_id && def.signature.outputs = [] then None
else Some def
-(** Add unit arguments for functions with no arguments, and change their return type *)
-let to_monadic (def : fun_def) : fun_def =
+(** Add unit arguments (optionally) for functions with no arguments, and change their return type *)
+let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def =
(* Update the body *)
let obj =
object
inherit [_] map_expression as super
method! visit_call env call =
- (* If no arguments, introduce unit *)
- if call.args = [] then
+ if call.args = [] && add_unit_args then
let args = [ mk_value_expression unit_rvalue None ] in
{ call with args } (* Otherwise: nothing to do *)
else super#visit_call env call
@@ -635,7 +636,7 @@ let to_monadic (def : fun_def) : fun_def =
(* Update the signature: first the input types *)
let def =
- if def.inputs = [] then (
+ if def.inputs = [] && add_unit_args then (
assert (def.signature.inputs = []);
let signature = { def.signature with inputs = [ unit_ty ] } in
let var_cnt = get_expression_min_var_counter def.body.e in
@@ -858,7 +859,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
* **Rk.**: from now onwards, the types in the AST are correct (until now,
* functions had return type `t` where they should have return type `result t`).
* Also, from now onwards, the outputs list has length 1. x*)
- let def = to_monadic def in
+ let def = to_monadic config.add_unit_args def in
log#ldebug (lazy ("to_monadic:\n\n" ^ fun_def_to_string ctx def ^ "\n"));
(* Convert the unit variables to `()` if they are used as right-values or
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index 1967732d..ca214d7c 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -913,11 +913,7 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list =
let fun_is_monadic (fun_id : A.fun_id) : bool =
match fun_id with
| A.Local _ -> true
- | A.Assumed
- ( A.Replace | A.BoxNew | BoxDeref | BoxDerefMut | BoxFree | VecNew
- | VecPush | VecLen ) ->
- false
- | A.Assumed (A.VecInsert | VecIndex | VecIndexMut) -> true
+ | A.Assumed aid -> Assumed.assumed_is_monadic aid
let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
match e with
diff --git a/src/Translate.ml b/src/Translate.ml
index 028114cf..913c5cf8 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -227,9 +227,9 @@ let translate_module_to_pure (config : C.partial_config)
(* Translate all the function *signatures* *)
let assumed_sigs =
List.map
- (fun (id, sg) ->
+ (fun (id, sg, _, _) ->
(A.Assumed id, List.map (fun _ -> None) (sg : A.fun_sig).inputs, sg))
- Assumed.assumed_sigs
+ Assumed.assumed_infos
in
let local_sigs =
List.map
diff --git a/src/main.ml b/src/main.ml
index 193b20c2..5e652809 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -142,6 +142,7 @@ let () =
Micro.decompose_monadic_let_bindings = !decompose_monads;
unfold_monadic_let_bindings = !unfold_monads;
filter_unused_monadic_calls = !filter_unused_calls;
+ add_unit_args = false;
}
in
Translate.translate_module filename dest_dir config micro_passes_config