diff options
author | Son Ho | 2022-02-09 10:55:40 +0100 |
---|---|---|
committer | Son Ho | 2022-02-09 10:55:40 +0100 |
commit | b5295c0bf9e7aee437eed8f8fc57e4fba46cb8ef (patch) | |
tree | df7b68348ce4bd784e2c14012652ad3e0fc6f91f /src | |
parent | b85a44d557c7c03e0052b03a824612a99409ef03 (diff) |
Implement filtering of useless forward functions
Diffstat (limited to 'src')
-rw-r--r-- | src/PureMicroPasses.ml | 74 | ||||
-rw-r--r-- | src/SymbolicToPure.ml | 103 | ||||
-rw-r--r-- | src/Translate.ml | 42 | ||||
-rw-r--r-- | src/main.ml | 5 |
4 files changed, 165 insertions, 59 deletions
diff --git a/src/PureMicroPasses.ml b/src/PureMicroPasses.ml index 59871600..7094d885 100644 --- a/src/PureMicroPasses.ml +++ b/src/PureMicroPasses.ml @@ -47,6 +47,20 @@ type config = { See the comments for [expression_contains_child_call_in_all_paths] for additional explanations. + + TODO: rename to [filter_useless_monadic_calls] + *) + filter_useless_functions : bool; + (** If [filter_unused_monadic_calls] is activated, some functions + become useless: if this option is true, we don't extract them. + + The calls to functions which always get filtered are: + - the forward functions with unit return value + - the backward functions which don't output anything (backward + functions coming from rust functions with no mutable borrows + as input values - note that if a function doesn't take mutable + borrows as inputs, it can't return mutable borrows; we actually + dynamically check for that). *) add_unit_args : bool; (** Add unit input arguments to functions with no arguments. *) @@ -612,11 +626,47 @@ let filter_unused (filter_monadic_calls : bool) (ctx : trans_ctx) { def with body; inputs_lvs } (** Return `None` if the function is a backward function with no outputs (so - that we eliminate the definition which is useless) *) -let filter_if_backward_with_no_outputs (def : fun_def) : fun_def option = - if Option.is_some def.back_id && def.signature.outputs = [] then None + that we eliminate the definition which is useless). + + Note that the calls to such functions are filtered when translating from + symbolic to pure. Here, we remove the definitions altogether, because they + are now useless + *) +let filter_if_backward_with_no_outputs (config : config) (def : fun_def) : + fun_def option = + if + config.filter_useless_functions && Option.is_some def.back_id + && def.signature.outputs = [] + then None else Some def +(** Return `false` if the forward function is useless and should be filtered. + + - a forward function with no output (comes from a Rust function with + unit return type) + - the function has mutable borrows as inputs (which is materialized + by the fact we generated backward functions which were not filtered). + + In such situation, every call to the Rust function will be translated to: + - a call to the forward function which returns nothing + - calls to the backward functions + As a failing backward function implies the forward function also fails, + we can filter the calls to the forward function, which thus becomes + useless. + In such situation, we can remove the forward function definition + altogether. + *) +let keep_forward (config : config) (trans : pure_fun_translation) : bool = + let fwd, backs = trans in + (* Note that at this point, the output types are no longer seen as tuples: + * they should be lists of length 1. *) + if + config.filter_useless_functions + && fwd.signature.outputs = [ mk_result_ty unit_ty ] + && backs <> [] + then false + else true + (** Add unit arguments (optionally) to functions with no arguments, and change their output type to use `result` *) @@ -852,7 +902,7 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : * Note that the calls to those functions should already have been removed, * when translating from symbolic to pure. Here, we remove the definitions * altogether, because they are now useless *) - let def = filter_if_backward_with_no_outputs def in + let def = filter_if_backward_with_no_outputs config def in match def with | None -> None @@ -924,9 +974,21 @@ let apply_passes_to_def (config : config) (ctx : trans_ctx) (def : fun_def) : (* We are done *) Some def +(** Return the forward/backward translations on which we applied the micro-passes. + + Also returns a boolean indicating whether the forward function should be kept + or not (because useful/useless - `true` means we need to keep the forward + function). + Note that we don't "filter" the forward function and return a boolean instead, + because this function contains useful information to extract the backward + functions: keeping it is not necessary but more convenient. + *) let apply_passes_to_pure_fun_translation (config : config) (ctx : trans_ctx) - (trans : pure_fun_translation) : pure_fun_translation = + (trans : pure_fun_translation) : bool * pure_fun_translation = + (* Apply the passes to the individual functions *) let forward, backwards = trans in let forward = Option.get (apply_passes_to_def config ctx forward) in let backwards = List.filter_map (apply_passes_to_def config ctx) backwards in - (forward, backwards) + let trans = (forward, backwards) in + (* Compute whether we need to filter the forward function or not *) + (keep_forward config trans, trans) diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index ca214d7c..f2ed1053 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -12,6 +12,34 @@ module PP = PrintPure (** The local logger *) let log = L.symbolic_to_pure_log +type config = { + filter_useless_back_calls : bool; + (** If `true`, filter the useless calls to backward functions. + + The useless calls are calls to backward functions which have no outputs. + This case happens if the original Rust function only takes *shared* borrows + as inputs, and is thus pretty common. + + We are allowed to do this only because in this specific case, + the backward function fails *exactly* when the forward function fails + (they actually do exactly the same thing, the only difference being + that the forward function can potentially return a value), and upon + reaching the place where we should introduce a call to the backward + function, we know we have introduced a call to the forward function. + + Also note that in general, backward functions "do more things" than + forward functions, and have more opportunities to fail (even though + in the generated code, backward functions should fail exactly when + the forward functions fail). + + We might want to move this optimization to the micro-passes subsequent + to the translation from symbolic to pure, but it is really super easy + to do it when going from symbolic to pure. + Note that we later filter the useless *forward* calls in the micro-passes, + where it is more natural to do. + *) +} + type type_context = { cfim_type_defs : T.type_def TypeDefId.Map.t; type_defs : type_def TypeDefId.Map.t; @@ -915,9 +943,10 @@ let fun_is_monadic (fun_id : A.fun_id) : bool = | A.Local _ -> true | A.Assumed aid -> Assumed.assumed_is_monadic aid -let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = +let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) + : texpression = match e with - | S.Return opt_v -> translate_return opt_v ctx + | S.Return opt_v -> translate_return config opt_v ctx | Panic -> (* Here we use the function return type - note that it is ok because * we don't match on panics which happen inside the function body - @@ -926,13 +955,13 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = let e = Value (v, None) in let ty = v.ty in { e; ty } - | FunCall (call, e) -> translate_function_call call e ctx - | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx - | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx - | Meta (meta, e) -> translate_meta meta e ctx + | FunCall (call, e) -> translate_function_call config call e ctx + | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx + | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx + | Meta (meta, e) -> translate_meta config meta e ctx -and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression - = +and translate_return (_config : config) (opt_v : V.typed_value option) + (ctx : bs_ctx) : texpression = (* There are two cases: - either we are translating a forward function, in which case the optional value should be `Some` (it is the returned value) @@ -964,8 +993,8 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression let ty = ret_value.ty in { e; ty } -and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : - texpression = +and translate_function_call (config : config) (call : S.call) (e : S.expression) + (ctx : bs_ctx) : texpression = (* Translate the function call *) let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in let args = List.map (typed_value_to_rvalue ctx) call.args in @@ -1011,12 +1040,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in let call = { e = call; ty = call_ty } in (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Put together *) mk_let monadic dest_v call next_e -and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : - texpression = +and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) + (ctx : bs_ctx) : texpression = log#ldebug (lazy ("translate_end_abstraction: abstraction kind: " @@ -1064,7 +1093,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty)) variables_values; (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Generate the assignemnts *) let monadic = false in List.fold_right @@ -1129,7 +1158,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : * if necessary *) let ctx, func = bs_ctx_register_backward_call abs ctx in (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Put everything together *) let args_mplaces = List.map (fun _ -> None) inputs in let args = @@ -1144,17 +1173,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* **Optimization**: * ================= * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. This case - * happens if the function only takes *shared* borrows as inputs, - * and is thus pretty common. We might want to move the optimization - * to the micro-passes code, but it is really super easy to do it - * here. Note that we are allowed to do it only because in this case, - * the backward function *fails exactly when the forward function fails* - * (they actually do exactly the same thing, the only difference being - * that the forward function can potentially return a value), and we - * know that we called the forward function before. + * have any output, we don't introduce any function call. + * See the comment in [config]. *) - if outputs = [] then ( + if config.filter_useless_back_calls && outputs = [] then ( (* No outputs - we do a small sanity check: the backward function * should have exactly the same number of inputs as the forward: * this number can be different only if the forward function returned @@ -1218,7 +1240,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : assert (given_back.ty = input.ty)) given_back_inputs; (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Generate the assignments *) let monadic = false in List.fold_right @@ -1228,8 +1250,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : e) given_back_inputs next_e -and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) - (exp : S.expansion) (ctx : bs_ctx) : texpression = +and translate_expansion (config : config) (p : S.mplace option) + (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) let scrutinee_var = lookup_var_for_symbolic_value sv ctx in let scrutinee = mk_typed_rvalue_from_var scrutinee_var in @@ -1246,7 +1268,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* The (mut/shared) borrow type is extracted to identity: we thus simply * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in let monadic = false in mk_let monadic (mk_typed_lvalue_from_var var None) @@ -1263,7 +1285,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* There is exactly one branch: no branching *) let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in - let branch = translate_expression branch ctx in + let branch = translate_expression config branch ctx in match type_id with | T.AdtId adt_id -> (* Detect if this is an enumeration or not *) @@ -1349,7 +1371,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) in let pat_ty = scrutinee.ty in let pat = mk_adt_lvalue pat_ty variant_id vars in - let branch = translate_expression branch ctx in + let branch = translate_expression config branch ctx in { pat; branch } in let branches = @@ -1367,8 +1389,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any * new values/variables *) - let true_e = translate_expression true_e ctx in - let false_e = translate_expression false_e ctx in + let true_e = translate_expression config true_e ctx in + let false_e = translate_expression config false_e ctx in let e = Switch (mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e)) @@ -1381,12 +1403,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match_branch = (* We don't need to update the context: we don't introduce any * new values/variables *) - let branch = translate_expression branch_e ctx in + let branch = translate_expression config branch_e ctx in let pat = mk_typed_lvalue_from_constant_value (V.Scalar v) in { pat; branch } in let branches = List.map translate_branch branches in - let otherwise = translate_expression otherwise ctx in + let otherwise = translate_expression config otherwise ctx in let pat_ty = Integer int_ty in let otherwise_pat : typed_lvalue = { value = LvVar Dummy; ty = pat_ty } in let otherwise : match_branch = @@ -1402,9 +1424,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches); { e; ty } -and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : - texpression = - let next_e = translate_expression e ctx in +and translate_meta (config : config) (meta : S.meta) (e : S.expression) + (ctx : bs_ctx) : texpression = + let next_e = translate_expression config e ctx in let meta = match meta with | S.Assignment (p, rv) -> @@ -1416,7 +1438,8 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : let ty = next_e.ty in { e; ty } -let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def = +let translate_fun_def (config : config) (ctx : bs_ctx) (body : S.expression) : + fun_def = let def = ctx.fun_def in let bid = ctx.bid in log#ldebug @@ -1431,7 +1454,7 @@ let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def = let def_id = def.A.def_id in let basename = def.name in let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in - let body = translate_expression body ctx in + let body = translate_expression config body ctx in (* Compute the list of (properly ordered) input variables *) let backward_inputs : var list = match bid with diff --git a/src/Translate.ml b/src/Translate.ml index 3781fc33..d51ec826 100644 --- a/src/Translate.ml +++ b/src/Translate.ml @@ -59,7 +59,7 @@ let translate_function_to_symbolics (config : C.partial_config) TODO: maybe we should introduce a record for this. *) let translate_function_to_pure (config : C.partial_config) - (trans_ctx : trans_ctx) + (mp_config : Micro.config) (trans_ctx : trans_ctx) (fun_sigs : SymbolicToPure.fun_sig_named_outputs RegularFunIdMap.t) (pure_type_defs : Pure.type_def Pure.TypeDefId.Map.t) (fdef : A.fun_def) : pure_fun_translation = @@ -134,9 +134,17 @@ let translate_function_to_pure (config : C.partial_config) { ctx with forward_inputs } in + (* The symbolic to pure config *) + let sp_config = + { + SymbolicToPure.filter_useless_back_calls = + mp_config.filter_unused_monadic_calls; + } + in + (* Translate the forward function *) let pure_forward = - SymbolicToPure.translate_fun_def + SymbolicToPure.translate_fun_def sp_config (add_forward_inputs (fst symbolic_forward) ctx) (snd symbolic_forward) in @@ -196,7 +204,7 @@ let translate_function_to_pure (config : C.partial_config) in (* Translate *) - SymbolicToPure.translate_fun_def ctx symbolic + SymbolicToPure.translate_fun_def sp_config ctx symbolic in let pure_backwards = List.map translate_backward fdef.signature.regions_hierarchy @@ -207,7 +215,7 @@ let translate_function_to_pure (config : C.partial_config) let translate_module_to_pure (config : C.partial_config) (mp_config : Micro.config) (m : M.cfim_module) : - trans_ctx * Pure.type_def list * pure_fun_translation list = + trans_ctx * Pure.type_def list * (bool * pure_fun_translation) list = (* Debug *) log#ldebug (lazy "translate_module_to_pure"); @@ -249,7 +257,8 @@ let translate_module_to_pure (config : C.partial_config) (* Translate all the functions *) let pure_translations = List.map - (translate_function_to_pure config trans_ctx fun_sigs type_defs_map) + (translate_function_to_pure config mp_config trans_ctx fun_sigs + type_defs_map) m.functions in @@ -305,7 +314,7 @@ let translate_module (filename : string) (dest_dir : string) let extract_ctx = List.fold_left - (fun extract_ctx def -> + (fun extract_ctx (_, def) -> ExtractToFStar.extract_fun_def_register_names extract_ctx def) extract_ctx trans_funs in @@ -337,7 +346,8 @@ let translate_module (filename : string) (dest_dir : string) let trans_funs = Pure.FunDefId.Map.of_list (List.map - (fun ((fd, bdl) : pure_fun_translation) -> (fd.def_id, (fd, bdl))) + (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) -> + (fd.def_id, (keep_fwd, (fd, bdl)))) trans_funs) in @@ -368,11 +378,16 @@ let translate_module (filename : string) (dest_dir : string) (* In case of (non-mutually) recursive functions, we use a simple procedure to * check if the forward and backward functions are mutually recursive. *) - let export_functions (is_rec : bool) (pure_ls : pure_fun_translation list) : - unit = - (* Generate the function definitions *) + let export_functions (is_rec : bool) + (pure_ls : (bool * pure_fun_translation) list) : unit = + (* Generate the function definitions, filtering the uselss forward + * functions. *) let fls = - List.concat (List.map (fun (fwd, back_ls) -> fwd :: back_ls) pure_ls) + List.concat + (List.map + (fun (keep_fwd, (fwd, back_ls)) -> + if keep_fwd then fwd :: back_ls else back_ls) + pure_ls) in (* Check if the functions are mutually recursive - this really works * to check if the forward and backward translations of a single @@ -397,8 +412,9 @@ let translate_module (filename : string) (dest_dir : string) (* Insert unit tests if necessary *) if test_unit_functions then List.iter - (fun (fwd, _) -> - ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd) + (fun (keep_fwd, (fwd, _)) -> + if keep_fwd then + ExtractToFStar.extract_unit_test_if_unit_fun extract_ctx fmt fwd) pure_ls in diff --git a/src/main.ml b/src/main.ml index 5e652809..17ab6421 100644 --- a/src/main.ml +++ b/src/main.ml @@ -27,6 +27,7 @@ let () = let decompose_monads = ref false in let unfold_monads = ref true in let filter_unused_calls = ref true in + let filter_useless_functions = ref true in let test_units = ref false in let test_trans_units = ref false in @@ -50,6 +51,9 @@ let () = ( "-filter-unused-calls", Arg.Set filter_unused_calls, " Filter the unused function calls, when possible" ); + ( "-filter-useless-funs", + Arg.Set filter_useless_functions, + " Filter the useless forward/backward functions" ); ( "-test-units", Arg.Set test_units, " Test the unit functions with the concrete interpreter" ); @@ -142,6 +146,7 @@ let () = Micro.decompose_monadic_let_bindings = !decompose_monads; unfold_monadic_let_bindings = !unfold_monads; filter_unused_monadic_calls = !filter_unused_calls; + filter_useless_functions = !filter_useless_functions; add_unit_args = false; } in |