summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2022-02-09 10:55:40 +0100
committerSon Ho2022-02-09 10:55:40 +0100
commitb5295c0bf9e7aee437eed8f8fc57e4fba46cb8ef (patch)
treedf7b68348ce4bd784e2c14012652ad3e0fc6f91f /src
parentb85a44d557c7c03e0052b03a824612a99409ef03 (diff)
Implement filtering of useless forward functions
Diffstat (limited to 'src')
-rw-r--r--src/PureMicroPasses.ml74
-rw-r--r--src/SymbolicToPure.ml103
-rw-r--r--src/Translate.ml42
-rw-r--r--src/main.ml5
4 files changed, 165 insertions, 59 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml
index 59871600..7094d885 100644
--- a/src/PureMicroPasses.ml
+++ b/src/PureMicroPasses.ml
@@ -47,6 +47,20 @@ type config = {
See the comments for [expression_contains_child_call_in_all_paths]
for additional explanations.
+
+ TODO: rename to [filter_useless_monadic_calls]
+ *)
+ filter_useless_functions : bool;
+ (** If [filter_unused_monadic_calls] is activated, some functions
+ become useless: if this option is true, we don't extract them.
+
+ The calls to functions which always get filtered are:
+ - the forward functions with unit return value
+ - the backward functions which don't output anything (backward
+ functions coming from rust functions with no mutable borrows
+ as input values - note that if a function doesn't take mutable
+ borrows as inputs, it can't return mutable borrows; we actually
+ dynamically check for that).
*)
add_unit_args : bool;
(** Add unit input arguments to functions with no arguments. *)
@@ -612,11 +626,47 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx)
{ def with body; inputs_lvs }
(** Return `None` if the function is a backward function with no outputs (so
- that we eliminate the definition which is useless) *)
-let filter_if_backward_with_no_outputs (def : fun_def) : fun_def option =
- if Option.is_some def.back_id && def.signature.outputs = [] then None
+ that we eliminate the definition which is useless).
+
+ Note that the calls to such functions are filtered when translating from
+ symbolic to pure. Here, we remove the definitions altogether, because they
+ are now useless
+ *)
+let filter_if_backward_with_no_outputs (config : config) (def : fun_def) :
+ fun_def option =
+ if
+ config.filter_useless_functions && Option.is_some def.back_id
+ && def.signature.outputs = []
+ then None
else Some def
+(** Return `false` if the forward function is useless and should be filtered.
+
+ - a forward function with no output (comes from a Rust function with
+ unit return type)
+ - the function has mutable borrows as inputs (which is materialized
+ by the fact we generated backward functions which were not filtered).
+
+ In such situation, every call to the Rust function will be translated to:
+ - a call to the forward function which returns nothing
+ - calls to the backward functions
+ As a failing backward function implies the forward function also fails,
+ we can filter the calls to the forward function, which thus becomes
+ useless.
+ In such situation, we can remove the forward function definition
+ altogether.
+ *)
+let keep_forward (config : config) (trans : pure_fun_translation) : bool =
+ let fwd, backs = trans in
+ (* Note that at this point, the output types are no longer seen as tuples:
+ * they should be lists of length 1. *)
+ if
+ config.filter_useless_functions
+ && fwd.signature.outputs = [ mk_result_ty unit_ty ]
+ && backs <> []
+ then false
+ else true
+
(** Add unit arguments (optionally) to functions with no arguments, and
change their output type to use `result`
*)
@@ -852,7 +902,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
* Note that the calls to those functions should already have been removed,
* when translating from symbolic to pure. Here, we remove the definitions
* altogether, because they are now useless *)
- let def = filter_if_backward_with_no_outputs def in
+ let def = filter_if_backward_with_no_outputs config def in
match def with
| None -> None
@@ -924,9 +974,21 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) :
(* We are done *)
Some def
+(** Return the forward/backward translations on which we applied the micro-passes.
+
+ Also returns a boolean indicating whether the forward function should be kept
+ or not (because useful/useless - `true` means we need to keep the forward
+ function).
+ Note that we don't "filter" the forward function and return a boolean instead,
+ because this function contains useful information to extract the backward
+ functions: keeping it is not necessary but more convenient.
+ *)
let apply_passes_to_pure_fun_translation (config : config) (ctx : trans_ctx)
- (trans : pure_fun_translation) : pure_fun_translation =
+ (trans : pure_fun_translation) : bool * pure_fun_translation =
+ (* Apply the passes to the individual functions *)
let forward, backwards = trans in
let forward = Option.get (apply_passes_to_def config ctx forward) in
let backwards = List.filter_map (apply_passes_to_def config ctx) backwards in
- (forward, backwards)
+ let trans = (forward, backwards) in
+ (* Compute whether we need to filter the forward function or not *)
+ (keep_forward config trans, trans)
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index ca214d7c..f2ed1053 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -12,6 +12,34 @@ module PP = PrintPure
(** The local logger *)
let log = L.symbolic_to_pure_log
+type config = {
+ filter_useless_back_calls : bool;
+ (** If `true`, filter the useless calls to backward functions.
+
+ The useless calls are calls to backward functions which have no outputs.
+ This case happens if the original Rust function only takes *shared* borrows
+ as inputs, and is thus pretty common.
+
+ We are allowed to do this only because in this specific case,
+ the backward function fails *exactly* when the forward function fails
+ (they actually do exactly the same thing, the only difference being
+ that the forward function can potentially return a value), and upon
+ reaching the place where we should introduce a call to the backward
+ function, we know we have introduced a call to the forward function.
+
+ Also note that in general, backward functions "do more things" than
+ forward functions, and have more opportunities to fail (even though
+ in the generated code, backward functions should fail exactly when
+ the forward functions fail).
+
+ We might want to move this optimization to the micro-passes subsequent
+ to the translation from symbolic to pure, but it is really super easy
+ to do it when going from symbolic to pure.
+ Note that we later filter the useless *forward* calls in the micro-passes,
+ where it is more natural to do.
+ *)
+}
+
type type_context = {
cfim_type_defs : T.type_def TypeDefId.Map.t;
type_defs : type_def TypeDefId.Map.t;
@@ -915,9 +943,10 @@ let fun_is_monadic (fun_id : A.fun_id) : bool =
| A.Local _ -> true
| A.Assumed aid -> Assumed.assumed_is_monadic aid
-let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
+let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
+ : texpression =
match e with
- | S.Return opt_v -> translate_return opt_v ctx
+ | S.Return opt_v -> translate_return config opt_v ctx
| Panic ->
(* Here we use the function return type - note that it is ok because
* we don't match on panics which happen inside the function body -
@@ -926,13 +955,13 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
let e = Value (v, None) in
let ty = v.ty in
{ e; ty }
- | FunCall (call, e) -> translate_function_call call e ctx
- | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx
- | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx
- | Meta (meta, e) -> translate_meta meta e ctx
+ | FunCall (call, e) -> translate_function_call config call e ctx
+ | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx
+ | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx
+ | Meta (meta, e) -> translate_meta config meta e ctx
-and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
- =
+and translate_return (_config : config) (opt_v : V.typed_value option)
+ (ctx : bs_ctx) : texpression =
(* There are two cases:
- either we are translating a forward function, in which case the optional
value should be `Some` (it is the returned value)
@@ -964,8 +993,8 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
let ty = ret_value.ty in
{ e; ty }
-and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
- texpression =
+and translate_function_call (config : config) (call : S.call) (e : S.expression)
+ (ctx : bs_ctx) : texpression =
(* Translate the function call *)
let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in
let args = List.map (typed_value_to_rvalue ctx) call.args in
@@ -1011,12 +1040,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in
let call = { e = call; ty = call_ty } in
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Put together *)
mk_let monadic dest_v call next_e
-and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
- texpression =
+and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
+ (ctx : bs_ctx) : texpression =
log#ldebug
(lazy
("translate_end_abstraction: abstraction kind: "
@@ -1064,7 +1093,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty))
variables_values;
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Generate the assignemnts *)
let monadic = false in
List.fold_right
@@ -1129,7 +1158,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
* if necessary *)
let ctx, func = bs_ctx_register_backward_call abs ctx in
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Put everything together *)
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
@@ -1144,17 +1173,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* **Optimization**:
* =================
* We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call. This case
- * happens if the function only takes *shared* borrows as inputs,
- * and is thus pretty common. We might want to move the optimization
- * to the micro-passes code, but it is really super easy to do it
- * here. Note that we are allowed to do it only because in this case,
- * the backward function *fails exactly when the forward function fails*
- * (they actually do exactly the same thing, the only difference being
- * that the forward function can potentially return a value), and we
- * know that we called the forward function before.
+ * have any output, we don't introduce any function call.
+ * See the comment in [config].
*)
- if outputs = [] then (
+ if config.filter_useless_back_calls && outputs = [] then (
(* No outputs - we do a small sanity check: the backward function
* should have exactly the same number of inputs as the forward:
* this number can be different only if the forward function returned
@@ -1218,7 +1240,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
assert (given_back.ty = input.ty))
given_back_inputs;
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Generate the assignments *)
let monadic = false in
List.fold_right
@@ -1228,8 +1250,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
e)
given_back_inputs next_e
-and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
- (exp : S.expansion) (ctx : bs_ctx) : texpression =
+and translate_expansion (config : config) (p : S.mplace option)
+ (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression =
(* Translate the scrutinee *)
let scrutinee_var = lookup_var_for_symbolic_value sv ctx in
let scrutinee = mk_typed_rvalue_from_var scrutinee_var in
@@ -1246,7 +1268,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* The (mut/shared) borrow type is extracted to identity: we thus simply
* introduce an reassignment *)
let ctx, var = fresh_var_for_symbolic_value nsv ctx in
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
let monadic = false in
mk_let monadic
(mk_typed_lvalue_from_var var None)
@@ -1263,7 +1285,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* There is exactly one branch: no branching *)
let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
- let branch = translate_expression branch ctx in
+ let branch = translate_expression config branch ctx in
match type_id with
| T.AdtId adt_id ->
(* Detect if this is an enumeration or not *)
@@ -1349,7 +1371,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
in
let pat_ty = scrutinee.ty in
let pat = mk_adt_lvalue pat_ty variant_id vars in
- let branch = translate_expression branch ctx in
+ let branch = translate_expression config branch ctx in
{ pat; branch }
in
let branches =
@@ -1367,8 +1389,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
| ExpandBool (true_e, false_e) ->
(* We don't need to update the context: we don't introduce any
* new values/variables *)
- let true_e = translate_expression true_e ctx in
- let false_e = translate_expression false_e ctx in
+ let true_e = translate_expression config true_e ctx in
+ let false_e = translate_expression config false_e ctx in
let e =
Switch
(mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e))
@@ -1381,12 +1403,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
match_branch =
(* We don't need to update the context: we don't introduce any
* new values/variables *)
- let branch = translate_expression branch_e ctx in
+ let branch = translate_expression config branch_e ctx in
let pat = mk_typed_lvalue_from_constant_value (V.Scalar v) in
{ pat; branch }
in
let branches = List.map translate_branch branches in
- let otherwise = translate_expression otherwise ctx in
+ let otherwise = translate_expression config otherwise ctx in
let pat_ty = Integer int_ty in
let otherwise_pat : typed_lvalue = { value = LvVar Dummy; ty = pat_ty } in
let otherwise : match_branch =
@@ -1402,9 +1424,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches);
{ e; ty }
-and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
- texpression =
- let next_e = translate_expression e ctx in
+and translate_meta (config : config) (meta : S.meta) (e : S.expression)
+ (ctx : bs_ctx) : texpression =
+ let next_e = translate_expression config e ctx in
let meta =
match meta with
| S.Assignment (p, rv) ->
@@ -1416,7 +1438,8 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
let ty = next_e.ty in
{ e; ty }
-let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =
+let translate_fun_def (config : config) (ctx : bs_ctx) (body : S.expression) :
+ fun_def =
let def = ctx.fun_def in
let bid = ctx.bid in
log#ldebug
@@ -1431,7 +1454,7 @@ let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =
let def_id = def.A.def_id in
let basename = def.name in
let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in
- let body = translate_expression body ctx in
+ let body = translate_expression config body ctx in
(* Compute the list of (properly ordered) input variables *)
let backward_inputs : var list =
match bid with
diff --git a/src/Translate.ml b/src/Translate.ml
index 3781fc33..d51ec826 100644
--- a/src/Translate.ml
+++ b/src/Translate.ml
@@ -59,7 +59,7 @@ let translate_function_to_symbolics (config : C.partial_config)
TODO: maybe we should introduce a record for this.
*)
let translate_function_to_pure (config : C.partial_config)
- (trans_ctx : trans_ctx)
+ (mp_config : Micro.config) (trans_ctx : trans_ctx)
(fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t)
(pure_type_defs : Pure.type_def Pure.TypeDefId.Map.t) (fdef : A.fun_def) :
pure_fun_translation =
@@ -134,9 +134,17 @@ let translate_function_to_pure (config : C.partial_config)
{ ctx with forward_inputs }
in
+ (* The symbolic to pure config *)
+ let sp_config =
+ {
+ SymbolicToPure.filter_useless_back_calls =
+ mp_config.filter_unused_monadic_calls;
+ }
+ in
+
(* Translate the forward function *)
let pure_forward =
- SymbolicToPure.translate_fun_def
+ SymbolicToPure.translate_fun_def sp_config
(add_forward_inputs (fst symbolic_forward) ctx)
(snd symbolic_forward)
in
@@ -196,7 +204,7 @@ let translate_function_to_pure (config : C.partial_config)
in
(* Translate *)
- SymbolicToPure.translate_fun_def ctx symbolic
+ SymbolicToPure.translate_fun_def sp_config ctx symbolic
in
let pure_backwards =
List.map translate_backward fdef.signature.regions_hierarchy
@@ -207,7 +215,7 @@ let translate_function_to_pure (config : C.partial_config)
let translate_module_to_pure (config : C.partial_config)
(mp_config : Micro.config) (m : M.cfim_module) :
- trans_ctx * Pure.type_def list * pure_fun_translation list =
+ trans_ctx * Pure.type_def list * (bool * pure_fun_translation) list =
(* Debug *)
log#ldebug (lazy "translate_module_to_pure");
@@ -249,7 +257,8 @@ let translate_module_to_pure (config : C.partial_config)
(* Translate all the functions *)
let pure_translations =
List.map
- (translate_function_to_pure config trans_ctx fun_sigs type_defs_map)
+ (translate_function_to_pure config mp_config trans_ctx fun_sigs
+ type_defs_map)
m.functions
in
@@ -305,7 +314,7 @@ let translate_module (filename : string) (dest_dir : string)
let extract_ctx =
List.fold_left
- (fun extract_ctx def ->
+ (fun extract_ctx (_, def) ->
ExtractToFStar.extract_fun_def_register_names extract_ctx def)
extract_ctx trans_funs
in
@@ -337,7 +346,8 @@ let translate_module (filename : string) (dest_dir : string)
let trans_funs =
Pure.FunDefId.Map.of_list
(List.map
- (fun ((fd, bdl) : pure_fun_translation) -> (fd.def_id, (fd, bdl)))
+ (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
+ (fd.def_id, (keep_fwd, (fd, bdl))))
trans_funs)
in
@@ -368,11 +378,16 @@ let translate_module (filename : string) (dest_dir : string)
(* In case of (non-mutually) recursive functions, we use a simple procedure to
* check if the forward and backward functions are mutually recursive.
*)
- let export_functions (is_rec : bool) (pure_ls : pure_fun_translation list) :
- unit =
- (* Generate the function definitions *)
+ let export_functions (is_rec : bool)
+ (pure_ls : (bool * pure_fun_translation) list) : unit =
+ (* Generate the function definitions, filtering the uselss forward
+ * functions. *)
let fls =
- List.concat (List.map (fun (fwd, back_ls) -> fwd :: back_ls) pure_ls)
+ List.concat
+ (List.map
+ (fun (keep_fwd, (fwd, back_ls)) ->
+ if keep_fwd then fwd :: back_ls else back_ls)
+ pure_ls)
in
(* Check if the functions are mutually recursive - this really works
* to check if the forward and backward translations of a single
@@ -397,8 +412,9 @@ let translate_module (filename : string) (dest_dir : string)
(* Insert unit tests if necessary *)
if test_unit_functions then
List.iter
- (fun (fwd, _) ->
- ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd)
+ (fun (keep_fwd, (fwd, _)) ->
+ if keep_fwd then
+ ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd)
pure_ls
in
diff --git a/src/main.ml b/src/main.ml
index 5e652809..17ab6421 100644
--- a/src/main.ml
+++ b/src/main.ml
@@ -27,6 +27,7 @@ let () =
let decompose_monads = ref false in
let unfold_monads = ref true in
let filter_unused_calls = ref true in
+ let filter_useless_functions = ref true in
let test_units = ref false in
let test_trans_units = ref false in
@@ -50,6 +51,9 @@ let () =
( "-filter-unused-calls",
Arg.Set filter_unused_calls,
" Filter the unused function calls, when possible" );
+ ( "-filter-useless-funs",
+ Arg.Set filter_useless_functions,
+ " Filter the useless forward/backward functions" );
( "-test-units",
Arg.Set test_units,
" Test the unit functions with the concrete interpreter" );
@@ -142,6 +146,7 @@ let () =
Micro.decompose_monadic_let_bindings = !decompose_monads;
unfold_monadic_let_bindings = !unfold_monads;
filter_unused_monadic_calls = !filter_unused_calls;
+ filter_useless_functions = !filter_useless_functions;
add_unit_args = false;
}
in