summaryrefslogtreecommitdiff
path: root/src/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to 'src/SymbolicToPure.ml')
-rw-r--r--src/SymbolicToPure.ml103
1 files changed, 63 insertions, 40 deletions
diff --git a/src/SymbolicToPure.ml b/src/SymbolicToPure.ml
index ca214d7c..f2ed1053 100644
--- a/src/SymbolicToPure.ml
+++ b/src/SymbolicToPure.ml
@@ -12,6 +12,34 @@ module PP = PrintPure
(** The local logger *)
let log = L.symbolic_to_pure_log
+type config = {
+ filter_useless_back_calls : bool;
+ (** If `true`, filter the useless calls to backward functions.
+
+ The useless calls are calls to backward functions which have no outputs.
+ This case happens if the original Rust function only takes *shared* borrows
+ as inputs, and is thus pretty common.
+
+ We are allowed to do this only because in this specific case,
+ the backward function fails *exactly* when the forward function fails
+ (they actually do exactly the same thing, the only difference being
+ that the forward function can potentially return a value), and upon
+ reaching the place where we should introduce a call to the backward
+ function, we know we have introduced a call to the forward function.
+
+ Also note that in general, backward functions "do more things" than
+ forward functions, and have more opportunities to fail (even though
+ in the generated code, backward functions should fail exactly when
+ the forward functions fail).
+
+ We might want to move this optimization to the micro-passes subsequent
+ to the translation from symbolic to pure, but it is really super easy
+ to do it when going from symbolic to pure.
+ Note that we later filter the useless *forward* calls in the micro-passes,
+ where it is more natural to do.
+ *)
+}
+
type type_context = {
cfim_type_defs : T.type_def TypeDefId.Map.t;
type_defs : type_def TypeDefId.Map.t;
@@ -915,9 +943,10 @@ let fun_is_monadic (fun_id : A.fun_id) : bool =
| A.Local _ -> true
| A.Assumed aid -> Assumed.assumed_is_monadic aid
-let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
+let rec translate_expression (config : config) (e : S.expression) (ctx : bs_ctx)
+ : texpression =
match e with
- | S.Return opt_v -> translate_return opt_v ctx
+ | S.Return opt_v -> translate_return config opt_v ctx
| Panic ->
(* Here we use the function return type - note that it is ok because
* we don't match on panics which happen inside the function body -
@@ -926,13 +955,13 @@ let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
let e = Value (v, None) in
let ty = v.ty in
{ e; ty }
- | FunCall (call, e) -> translate_function_call call e ctx
- | EndAbstraction (abs, e) -> translate_end_abstraction abs e ctx
- | Expansion (p, sv, exp) -> translate_expansion p sv exp ctx
- | Meta (meta, e) -> translate_meta meta e ctx
+ | FunCall (call, e) -> translate_function_call config call e ctx
+ | EndAbstraction (abs, e) -> translate_end_abstraction config abs e ctx
+ | Expansion (p, sv, exp) -> translate_expansion config p sv exp ctx
+ | Meta (meta, e) -> translate_meta config meta e ctx
-and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
- =
+and translate_return (_config : config) (opt_v : V.typed_value option)
+ (ctx : bs_ctx) : texpression =
(* There are two cases:
- either we are translating a forward function, in which case the optional
value should be `Some` (it is the returned value)
@@ -964,8 +993,8 @@ and translate_return (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression
let ty = ret_value.ty in
{ e; ty }
-and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
- texpression =
+and translate_function_call (config : config) (call : S.call) (e : S.expression)
+ (ctx : bs_ctx) : texpression =
(* Translate the function call *)
let type_params = List.map (ctx_translate_fwd_ty ctx) call.type_params in
let args = List.map (typed_value_to_rvalue ctx) call.args in
@@ -1011,12 +1040,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
let call_ty = if monadic then mk_result_ty dest_v.ty else dest_v.ty in
let call = { e = call; ty = call_ty } in
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Put together *)
mk_let monadic dest_v call next_e
-and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
- texpression =
+and translate_end_abstraction (config : config) (abs : V.abs) (e : S.expression)
+ (ctx : bs_ctx) : texpression =
log#ldebug
(lazy
("translate_end_abstraction: abstraction kind: "
@@ -1064,7 +1093,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(fun (var, v) -> assert ((var : var).ty = (v : typed_rvalue).ty))
variables_values;
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Generate the assignemnts *)
let monadic = false in
List.fold_right
@@ -1129,7 +1158,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
* if necessary *)
let ctx, func = bs_ctx_register_backward_call abs ctx in
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Put everything together *)
let args_mplaces = List.map (fun _ -> None) inputs in
let args =
@@ -1144,17 +1173,10 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
(* **Optimization**:
* =================
* We do a small optimization here: if the backward function doesn't
- * have any output, we don't introduce any function call. This case
- * happens if the function only takes *shared* borrows as inputs,
- * and is thus pretty common. We might want to move the optimization
- * to the micro-passes code, but it is really super easy to do it
- * here. Note that we are allowed to do it only because in this case,
- * the backward function *fails exactly when the forward function fails*
- * (they actually do exactly the same thing, the only difference being
- * that the forward function can potentially return a value), and we
- * know that we called the forward function before.
+ * have any output, we don't introduce any function call.
+ * See the comment in [config].
*)
- if outputs = [] then (
+ if config.filter_useless_back_calls && outputs = [] then (
(* No outputs - we do a small sanity check: the backward function
* should have exactly the same number of inputs as the forward:
* this number can be different only if the forward function returned
@@ -1218,7 +1240,7 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
assert (given_back.ty = input.ty))
given_back_inputs;
(* Translate the next expression *)
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
(* Generate the assignments *)
let monadic = false in
List.fold_right
@@ -1228,8 +1250,8 @@ and translate_end_abstraction (abs : V.abs) (e : S.expression) (ctx : bs_ctx) :
e)
given_back_inputs next_e
-and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
- (exp : S.expansion) (ctx : bs_ctx) : texpression =
+and translate_expansion (config : config) (p : S.mplace option)
+ (sv : V.symbolic_value) (exp : S.expansion) (ctx : bs_ctx) : texpression =
(* Translate the scrutinee *)
let scrutinee_var = lookup_var_for_symbolic_value sv ctx in
let scrutinee = mk_typed_rvalue_from_var scrutinee_var in
@@ -1246,7 +1268,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* The (mut/shared) borrow type is extracted to identity: we thus simply
* introduce an reassignment *)
let ctx, var = fresh_var_for_symbolic_value nsv ctx in
- let next_e = translate_expression e ctx in
+ let next_e = translate_expression config e ctx in
let monadic = false in
mk_let monadic
(mk_typed_lvalue_from_var var None)
@@ -1263,7 +1285,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
(* There is exactly one branch: no branching *)
let type_id, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
- let branch = translate_expression branch ctx in
+ let branch = translate_expression config branch ctx in
match type_id with
| T.AdtId adt_id ->
(* Detect if this is an enumeration or not *)
@@ -1349,7 +1371,7 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
in
let pat_ty = scrutinee.ty in
let pat = mk_adt_lvalue pat_ty variant_id vars in
- let branch = translate_expression branch ctx in
+ let branch = translate_expression config branch ctx in
{ pat; branch }
in
let branches =
@@ -1367,8 +1389,8 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
| ExpandBool (true_e, false_e) ->
(* We don't need to update the context: we don't introduce any
* new values/variables *)
- let true_e = translate_expression true_e ctx in
- let false_e = translate_expression false_e ctx in
+ let true_e = translate_expression config true_e ctx in
+ let false_e = translate_expression config false_e ctx in
let e =
Switch
(mk_value_expression scrutinee scrutinee_mplace, If (true_e, false_e))
@@ -1381,12 +1403,12 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
match_branch =
(* We don't need to update the context: we don't introduce any
* new values/variables *)
- let branch = translate_expression branch_e ctx in
+ let branch = translate_expression config branch_e ctx in
let pat = mk_typed_lvalue_from_constant_value (V.Scalar v) in
{ pat; branch }
in
let branches = List.map translate_branch branches in
- let otherwise = translate_expression otherwise ctx in
+ let otherwise = translate_expression config otherwise ctx in
let pat_ty = Integer int_ty in
let otherwise_pat : typed_lvalue = { value = LvVar Dummy; ty = pat_ty } in
let otherwise : match_branch =
@@ -1402,9 +1424,9 @@ and translate_expansion (p : S.mplace option) (sv : V.symbolic_value)
List.for_all (fun (br : match_branch) -> br.branch.ty = ty) branches);
{ e; ty }
-and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
- texpression =
- let next_e = translate_expression e ctx in
+and translate_meta (config : config) (meta : S.meta) (e : S.expression)
+ (ctx : bs_ctx) : texpression =
+ let next_e = translate_expression config e ctx in
let meta =
match meta with
| S.Assignment (p, rv) ->
@@ -1416,7 +1438,8 @@ and translate_meta (meta : S.meta) (e : S.expression) (ctx : bs_ctx) :
let ty = next_e.ty in
{ e; ty }
-let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =
+let translate_fun_def (config : config) (ctx : bs_ctx) (body : S.expression) :
+ fun_def =
let def = ctx.fun_def in
let bid = ctx.bid in
log#ldebug
@@ -1431,7 +1454,7 @@ let translate_fun_def (ctx : bs_ctx) (body : S.expression) : fun_def =
let def_id = def.A.def_id in
let basename = def.name in
let signature = bs_ctx_lookup_local_function_sig def_id bid ctx in
- let body = translate_expression body ctx in
+ let body = translate_expression config body ctx in
(* Compute the list of (properly ordered) input variables *)
let backward_inputs : var list =
match bid with