diff options
-rw-r--r-- | fstar/Primitives.fst | 49 | ||||
-rw-r--r-- | src/Assumed.ml | 69 | ||||
-rw-r--r-- | src/CfimAstUtils.ml | 12 | ||||
-rw-r--r-- | src/PureMicroPasses.ml | 13 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 6 | ||||
-rw-r--r-- | src/Translate.ml | 4 | ||||
-rw-r--r-- | src/main.ml | 1 |
7 files changed, 105 insertions, 49 deletions
diff --git a/fstar/Primitives.fst b/fstar/Primitives.fst index 6b1ec539..78697259 100644 --- a/fstar/Primitives.fst +++ b/fstar/Primitives.fst @@ -1,9 +1,19 @@ /// This file lists primitive and assumed functions and types module Primitives open FStar.Mul +open FStar.List.Tot #set-options "--z3rlimit 15 --fuel 0 --ifuel 1" +(*** Utilities *) +val list_update (#a : Type0) (ls : list a) (i : nat{i < length ls}) (x : a) : + ls':list a{length ls' = length ls} +#push-options "--fuel 1" +let rec list_update #a ls i x = + match ls with + | x' :: ls -> if i = 0 then x :: ls else x' :: list_update ls (i-1) x +#pop-options + (*** Result *) type result (a : Type0) : Type0 = | Return : a -> result a @@ -23,6 +33,10 @@ let massert (b:bool) : result unit = if b then Return () else Fail (*** Misc *) type char = FStar.Char.char +type string = string + +let mem_replace_fwd (a : Type0) (x : a) (y : a) : a = x +let mem_replace_back (a : Type0) (x : a) (y : a) : a = y (*** Scalars *) /// Rk.: most of the following code was at least partially generated @@ -219,3 +233,38 @@ let u16_mul = scalar_mul #U16 let u32_mul = scalar_mul #U32 let u64_mul = scalar_mul #U64 let u128_mul = scalar_mul #U128 + +(*** Vector *) +type vec (a : Type0) = v:list a{length v <= usize_max} + +let vec_new (a : Type0) : vec a = assert_norm(length #a [] == 0); [] +let vec_len (a : Type0) (v : vec a) : usize = length v + +// The **forward** function shouldn't be used +let vec_push_fwd (a : Type0) (v : vec a) (x : a) : unit = () +let vec_push_back (a : Type0) (v : vec a) (x : a) : result (vec a) = + if length v < usize_max then begin + (**) assert_norm(length [x] == 1); + (**) append_length v [x]; + (**) assert(length (append v [x]) = length v + 1); + Return (append v [x]) + end + else Fail + +// The **forward** function shouldn't be used +let vec_insert_fwd (a : Type0) (v : vec a) (i : usize) (x : a) : result unit = + if i < length v then Return () else Fail +let vec_insert_back (a : Type0) (v : vec a) (i : usize) (x : a) : result (vec a) = + if i < length v then Return (list_update v i x) else Fail + +// The **backward** function shouldn't be used +let vec_index_fwd (a : Type0) (v : vec a) (i : usize) : result a = + if i < length v then Return (index v i) else Fail +let vec_index_back (a : Type0) (v : vec a) (i : usize) (x : a) : result unit = + if i < length v then Return () else Fail + +let vec_index_mut_fwd (a : Type0) (v : vec a) (i : usize) : result a = + if i < length v then Return (index v i) else Fail +let vec_index_mut_back (a : Type0) (v : vec a) (i : usize) (nx : a) : result (vec a) = + if i < length v then Return (list_update v i nx) else Fail + diff --git a/src/Assumed.ml b/src/Assumed.ml index 5a9fb51b..527b2395 100644 --- a/src/Assumed.ml +++ b/src/Assumed.ml @@ -227,37 +227,54 @@ module Sig = struct let vec_index_mut_sig : A.fun_sig = vec_index_gen_sig true end -(** The list of assumed functions, and their signatures. +type assumed_info = A.assumed_fun_id * A.fun_sig * bool * Identifiers.name + +(** The list of assumed functions and all their information: + - their signature + - a boolean indicating whether they are monadic or not (i.e., if they + can fail or not) + - their name Rk.: following what is written above, we don't include `Box::free`. + + Remark about the vector functions: for `Vec::len` to be correct and return + a `usize`, we have to make sure that vectors are bounded by the max usize. + Followingly, `Vec::push` is monadic. *) -let assumed_sigs : (A.assumed_fun_id * A.fun_sig) list = +let assumed_infos : assumed_info list = + let deref_pre = [ "core"; "ops"; "deref" ] in + let vec_pre = [ "alloc"; "vec"; "Vec" ] in + let index_pre = [ "core"; "ops"; "index" ] in [ - (Replace, Sig.mem_replace_sig); - (BoxNew, Sig.box_new_sig); - (BoxDeref, Sig.box_deref_shared_sig); - (BoxDerefMut, Sig.box_deref_mut_sig); - (VecNew, Sig.vec_new_sig); - (VecPush, Sig.vec_push_sig); - (VecInsert, Sig.vec_insert_sig); - (VecLen, Sig.vec_len_sig); - (VecIndex, Sig.vec_index_shared_sig); - (VecIndexMut, Sig.vec_index_mut_sig); + (Replace, Sig.mem_replace_sig, false, [ "core"; "mem"; "replace" ]); + (BoxNew, Sig.box_new_sig, false, [ "alloc"; "boxed"; "Box"; "new" ]); + (BoxDeref, Sig.box_deref_shared_sig, false, deref_pre @ [ "Deref"; "deref" ]); + ( BoxDerefMut, + Sig.box_deref_mut_sig, + false, + deref_pre @ [ "DerefMut"; "deref_mut" ] ); + (VecNew, Sig.vec_new_sig, false, vec_pre @ [ "new" ]); + (VecPush, Sig.vec_push_sig, true, vec_pre @ [ "push" ]); + (VecInsert, Sig.vec_insert_sig, true, vec_pre @ [ "insert" ]); + (VecLen, Sig.vec_len_sig, false, vec_pre @ [ "len" ]); + (VecIndex, Sig.vec_index_shared_sig, true, index_pre @ [ "Index"; "index" ]); + ( VecIndexMut, + Sig.vec_index_mut_sig, + true, + index_pre @ [ "IndexMut"; "index_mut" ] ); ] +let get_assumed_info (id : A.assumed_fun_id) : assumed_info = + List.find (fun (id', _, _, _) -> id = id') assumed_infos + let get_assumed_sig (id : A.assumed_fun_id) : A.fun_sig = - snd (List.find (fun (id', _) -> id = id') assumed_sigs) + let _, sg, _, _ = get_assumed_info id in + sg -let assumed_names : (A.assumed_fun_id * Identifiers.name) list = - [ - (Replace, [ "core"; "mem"; "replace" ]); - (BoxNew, [ "alloc"; "boxed"; "Box"; "new" ]); - (BoxDeref, [ "core"; "ops"; "deref"; "Deref"; "deref" ]); - (BoxDerefMut, [ "core"; "ops"; "deref"; "DerefMut"; "deref_mut" ]); - (VecNew, [ "alloc"; "vec"; "Vec"; "new" ]); - (VecPush, [ "alloc"; "vec"; "Vec"; "push" ]); - (VecInsert, [ "alloc"; "vec"; "Vec"; "insert" ]); - (VecLen, [ "alloc"; "vec"; "Vec"; "len" ]); - (VecIndex, [ "core"; "ops"; "index"; "Index"; "index" ]); - (VecIndexMut, [ "core"; "ops"; "index"; "IndexMut"; "index_mut" ]); - ] +let get_assumed_name (id : A.assumed_fun_id) : Identifiers.name = + let _, _, _, name = get_assumed_info id in + name + +let assumed_is_monadic (id : A.assumed_fun_id) : bool = + let _, _, b, _ = get_assumed_info id in + b diff --git a/src/CfimAstUtils.ml b/src/CfimAstUtils.ml index 902156f2..6a2f680a 100644 --- a/src/CfimAstUtils.ml +++ b/src/CfimAstUtils.ml @@ -23,21 +23,13 @@ let lookup_fun_sig (fun_id : fun_id) (fun_defs : fun_def FunDefId.Map.t) : fun_sig = match fun_id with | Local id -> (FunDefId.Map.find id fun_defs).signature - | Assumed aid -> - let _, sg = - List.find (fun (aid', _) -> aid = aid') Assumed.assumed_sigs - in - sg + | Assumed aid -> Assumed.get_assumed_sig aid let lookup_fun_name (fun_id : fun_id) (fun_defs : fun_def FunDefId.Map.t) : Identifiers.name = match fun_id with | Local id -> (FunDefId.Map.find id fun_defs).name - | Assumed aid -> - let _, sg = - List.find (fun (aid', _) -> aid = aid') Assumed.assumed_names - in - sg + | Assumed aid -> Assumed.get_assumed_name aid (** Small utility: list the transitive parents of a region var group. We don't do that in an efficient manner, but it doesn't matter. diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 5ac2af4e..99598937 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -48,6 +48,8 @@ type config = { See the comments for [expression_contains_child_call_in_all_paths] for additional explanations. *) + add_unit_args : bool; + (** Add unit input arguments to functions with no arguments. *) } (** A configuration to control the application of the passes *) @@ -615,16 +617,15 @@ let filter_if_backward_with_no_outputs (def : fun_def) : fun_def option = if Option.is_some def.back_id && def.signature.outputs = [] then None else Some def -(** Add unit arguments for functions with no arguments, and change their return type *) -let to_monadic (def : fun_def) : fun_def = +(** Add unit arguments (optionally) for functions with no arguments, and change their return type *) +let to_monadic (add_unit_args : bool) (def : fun_def) : fun_def = (* Update the body *) let obj = object inherit [_] map_expression as super method! visit_call env call = - (* If no arguments, introduce unit *) - if call.args = [] then + if call.args = [] && add_unit_args then let args = [ mk_value_expression unit_rvalue None ] in { call with args } (* Otherwise: nothing to do *) else super#visit_call env call @@ -635,7 +636,7 @@ let to_monadic (def : fun_def) : fun_def = (* Update the signature: first the input types *) let def = - if def.inputs = [] then ( + if def.inputs = [] && add_unit_args then ( assert (def.signature.inputs = []); let signature = { def.signature with inputs = [ unit_ty ] } in let var_cnt = get_expression_min_var_counter def.body.e in @@ -858,7 +859,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : * **Rk.**: from now onwards, the types in the AST are correct (until now, * functions had return type `t` where they should have return type `result t`). * Also, from now onwards, the outputs list has length 1. x*) - let def = to_monadic def in + let def = to_monadic config.add_unit_args def in log#ldebug (lazy ("to_monadic:\n\n" ^ fun_def_to_string ctx def ^ "\n")); (* Convert the unit variables to `()` if they are used as right-values or diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index 1967732d..ca214d7c 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -913,11 +913,7 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) : S.call * V.abs list = let fun_is_monadic (fun_id : A.fun_id) : bool = match fun_id with | A.Local _ -> true - | A.Assumed - ( A.Replace | A.BoxNew | BoxDeref | BoxDerefMut | BoxFree | VecNew - | VecPush | VecLen ) -> - false - | A.Assumed (A.VecInsert | VecIndex | VecIndexMut) -> true + | A.Assumed aid -> Assumed.assumed_is_monadic aid let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = match e with diff --git a/src/Translate.ml b/src/Translate.ml index 028114cf..913c5cf8 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -227,9 +227,9 @@ let translate_module_to_pure (config : C.partial_config) (* Translate all the function *signatures* *) let assumed_sigs = List.map - (fun (id, sg) -> + (fun (id, sg, _, _) -> (A.Assumed id, List.map (fun _ -> None) (sg : A.fun_sig).inputs, sg)) - Assumed.assumed_sigs + Assumed.assumed_infos in let local_sigs = List.map diff --git a/src/main.ml b/src/main.ml index 193b20c2..5e652809 100644 --- a/src/main.ml +++ b/src/main.ml @@ -142,6 +142,7 @@ let () = Micro.decompose_monadic_let_bindings = !decompose_monads; unfold_monadic_let_bindings = !unfold_monads; filter_unused_monadic_calls = !filter_unused_calls; + add_unit_args = false; } in Translate.translate_module filename dest_dir config micro_passes_config |