From b5295c0bf9e7aee437eed8f8fc57e4fba46cb8ef Mon Sep 17 00:00:00 2001 From: Son Ho Date: Wed, 9 Feb 2022 10:55:40 +0100 Subject: Implement filtering of useless forward functions --- src/SymbolicToPure.ml | 103 ++++++++++++++++++++++++++++++-------------------- 1 file changed, 63 insertions(+), 40 deletions(-) (limited to 'src/SymbolicToPure.ml') diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml index ca214d7c..f2ed1053 100644 --- a/src/SymbolicToPure.ml +++ b/src/SymbolicToPure.ml @@ -12,6 +12,34 @@ module PP = PrintPure (** The local logger *) let log = L.symbolic_to_pure_log +type config = { + filter_useless_back_calls : bool; + (** If `true`, filter the useless calls to backward functions. + + The useless calls are calls to backward functions which have no outputs. + This case happens if the original Rust function only takes *shared* borrows + as inputs, and is thus pretty common. + + We are allowed to do this only because in this specific case, + the backward function fails *exactly* when the forward function fails + (they actually do exactly the same thing, the only difference being + that the forward function can potentially return a value), and upon + reaching the place where we should introduce a call to the backward + function, we know we have introduced a call to the forward function. + + Also note that in general, backward functions "do more things" than + forward functions, and have more opportunities to fail (even though + in the generated code, backward functions should fail exactly when + the forward functions fail). + + We might want to move this optimization to the micro-passes subsequent + to the translation from symbolic to pure, but it is really super easy + to do it when going from symbolic to pure. + Note that we later filter the useless *forward* calls in the micro-passes, + where it is more natural to do. + *) +} + type type_context = { cfim_type_defs : T.type_def TypeDefId.Map.t; type_defs : type_def TypeDefId.Map.t; @@ -915,9 +943,10 @@ let fun_is_monadic (fun_id : A.fun_id) : bool = | A.Local _ -> true | A.Assumed aid -> Assumed.assumed_is_monadic aid -let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = +let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx) + : texpression = match e with - | S.Return opt_v -> translate_return opt_v ctx + | S.Return opt_v -> translate_return config opt_v ctx | Panic -> (* Here we use the function return type - note that it is ok because * we don't match on panics which happen inside the function body - @@ -926,13 +955,13 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression = let e = Value (v, None) in let ty = v.ty in { e; ty } - | FunCall (call, e) -> translate_function_call call e ctx - | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx - | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx - | Meta (meta, e) -> translate_meta meta e ctx + | FunCall (call, e) -> translate_function_call config call e ctx + | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx + | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx + | Meta (meta, e) -> translate_meta config meta e ctx -and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression - = +and translate_return (_config : config) (opt_v : V.typed_value option) + (ctx : bs_ctx) : texpression = (* There are two cases: - either we are translating a forward function, in which case the optional value should be `Some` (it is the returned value) @@ -964,8 +993,8 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression let ty = ret_value.ty in { e; ty } -and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : - texpression = +and translate_function_call (config : config) (call : S.call) (e : S.expression) + (ctx : bs_ctx) : texpression = (* Translate the function call *) let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in let args = List.map (typed_value_to_rvalue ctx) call.args in @@ -1011,12 +1040,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in let call = { e = call; ty = call_ty } in (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Put together *) mk_let monadic dest_v call next_e -and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : - texpression = +and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression) + (ctx : bs_ctx) : texpression = log#ldebug (lazy ("translate_end_abstraction: abstraction kind: " @@ -1064,7 +1093,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty)) variables_values; (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Generate the assignemnts *) let monadic = false in List.fold_right @@ -1129,7 +1158,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : * if necessary *) let ctx, func = bs_ctx_register_backward_call abs ctx in (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Put everything together *) let args_mplaces = List.map (fun _ -> None) inputs in let args = @@ -1144,17 +1173,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : (* **Optimization**: * ================= * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. This case - * happens if the function only takes *shared* borrows as inputs, - * and is thus pretty common. We might want to move the optimization - * to the micro-passes code, but it is really super easy to do it - * here. Note that we are allowed to do it only because in this case, - * the backward function *fails exactly when the forward function fails* - * (they actually do exactly the same thing, the only difference being - * that the forward function can potentially return a value), and we - * know that we called the forward function before. + * have any output, we don't introduce any function call. + * See the comment in [config]. *) - if outputs = [] then ( + if config.filter_useless_back_calls && outputs = [] then ( (* No outputs - we do a small sanity check: the backward function * should have exactly the same number of inputs as the forward: * this number can be different only if the forward function returned @@ -1218,7 +1240,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : assert (given_back.ty = input.ty)) given_back_inputs; (* Translate the next expression *) - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in (* Generate the assignments *) let monadic = false in List.fold_right @@ -1228,8 +1250,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : e) given_back_inputs next_e -and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) - (exp : S.expansion) (ctx : bs_ctx) : texpression = +and translate_expansion (config : config) (p : S.mplace option) + (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression = (* Translate the scrutinee *) let scrutinee_var = lookup_var_for_symbolic_value sv ctx in let scrutinee = mk_typed_rvalue_from_var scrutinee_var in @@ -1246,7 +1268,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* The (mut/shared) borrow type is extracted to identity: we thus simply * introduce an reassignment *) let ctx, var = fresh_var_for_symbolic_value nsv ctx in - let next_e = translate_expression e ctx in + let next_e = translate_expression config e ctx in let monadic = false in mk_let monadic (mk_typed_lvalue_from_var var None) @@ -1263,7 +1285,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) (* There is exactly one branch: no branching *) let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in let ctx, vars = fresh_vars_for_symbolic_values svl ctx in - let branch = translate_expression branch ctx in + let branch = translate_expression config branch ctx in match type_id with | T.AdtId adt_id -> (* Detect if this is an enumeration or not *) @@ -1349,7 +1371,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) in let pat_ty = scrutinee.ty in let pat = mk_adt_lvalue pat_ty variant_id vars in - let branch = translate_expression branch ctx in + let branch = translate_expression config branch ctx in { pat; branch } in let branches = @@ -1367,8 +1389,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) | ExpandBool (true_e, false_e) -> (* We don't need to update the context: we don't introduce any * new values/variables *) - let true_e = translate_expression true_e ctx in - let false_e = translate_expression false_e ctx in + let true_e = translate_expression config true_e ctx in + let false_e = translate_expression config false_e ctx in let e = Switch (mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e)) @@ -1381,12 +1403,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) match_branch = (* We don't need to update the context: we don't introduce any * new values/variables *) - let branch = translate_expression branch_e ctx in + let branch = translate_expression config branch_e ctx in let pat = mk_typed_lvalue_from_constant_value (V.Scalar v) in { pat; branch } in let branches = List.map translate_branch branches in - let otherwise = translate_expression otherwise ctx in + let otherwise = translate_expression config otherwise ctx in let pat_ty = Integer int_ty in let otherwise_pat : typed_lvalue = { value = LvVar Dummy; ty = pat_ty } in let otherwise : match_branch = @@ -1402,9 +1424,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value) List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches); { e; ty } -and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : - texpression = - let next_e = translate_expression e ctx in +and translate_meta (config : config) (meta : S.meta) (e : S.expression) + (ctx : bs_ctx) : texpression = + let next_e = translate_expression config e ctx in let meta = match meta with | S.Assignment (p, rv) -> @@ -1416,7 +1438,8 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) : let ty = next_e.ty in { e; ty } -let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def = +let translate_fun_def (config : config) (ctx : bs_ctx) (body : S.expression) : + fun_def = let def = ctx.fun_def in let bid = ctx.bid in log#ldebug @@ -1431,7 +1454,7 @@ let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def = let def_id = def.A.def_id in let basename = def.name in let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in - let body = translate_expression body ctx in + let body = translate_expression config body ctx in (* Compute the list of (properly ordered) input variables *) let backward_inputs : var list = match bid with -- cgit v1.2.3