summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-21 17:00:52 +0100
committerSon Ho2023-12-21 17:00:52 +0100
commitd4b3d0e6adae5bb9a2f62872dbcedc29aaa9fa30 (patch)
treef26f591884621ba089c3f606d92c0daf8bcf35c9
parentcf3eea59ee61f2341daf7248664b8be878f128af (diff)
Filter the useless backward functions
-rw-r--r--compiler/SymbolicToPure.ml220
1 files changed, 145 insertions, 75 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index d3b0933c..f37ea201 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -67,7 +67,7 @@ type call_info = {
Those inputs include the fuel and the state, if pertinent.
*)
- back_funs : texpression RegionGroupId.Map.t option;
+ back_funs : texpression option RegionGroupId.Map.t option;
(** If we do not split between the forward/backward functions: the
variables we introduced for the backward functions.
@@ -78,6 +78,10 @@ type call_info = {
here
...
]}
+
+ The expression might be [None] in case the backward function
+ has to be filtered (because it does nothing - the backward
+ functions for shared borrows for instance).
*)
}
[@@deriving show]
@@ -125,7 +129,7 @@ type loop_info = {
(** The map from region group ids to the types of the values given back
by the corresponding loop abstractions.
*)
- back_funs : texpression RegionGroupId.Map.t option;
+ back_funs : texpression option RegionGroupId.Map.t option;
(** Same as {!call_info.back_funs}.
Initialized with [None], gets updated to [Some] only if we merge
the fwd/back functions.
@@ -777,8 +781,8 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
(args : texpression list)
- (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx
- =
+ (back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) :
+ bs_ctx =
let calls = ctx.calls in
assert (not (V.FunCallId.Map.mem call_id calls));
let info = { forward; forward_inputs = args; back_funs } in
@@ -790,13 +794,15 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
[back_args]: the *additional* list of inputs received by the backward function,
including the state.
- Returns the updated context and the expression corresponding to the function.
+ Returns the updated context and the expression corresponding to the function
+ that we need to call. This function may be [None] if it has to be ignored
+ (because it does nothing).
*)
let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
(call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id)
(inherited_args : texpression list) (back_args : texpression list)
(generics : generic_args) (output_ty : ty) (ctx : bs_ctx) :
- bs_ctx * texpression =
+ bs_ctx * texpression option =
(* Insert the abstraction in the call informations *)
let info = V.FunCallId.Map.find call_id ctx.calls in
let calls = V.FunCallId.Map.add call_id info ctx.calls in
@@ -827,7 +833,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info)
in
let func_ty = mk_arrows input_tys ret_ty in
let func = { id = FunOrOp fun_id; generics } in
- { e = Qualif func; ty = func_ty }
+ Some { e = Qualif func; ty = func_ty }
in
(* Update the context and return *)
({ ctx with calls; abstractions }, func)
@@ -1128,23 +1134,36 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty
in
if effect_info.can_fail then mk_result_ty output else output
-(** Compute the arrow types for all the backward functions. *)
+(** Compute the arrow types for all the backward functions.
+
+ If a backward function has no inputs/outputs we filter it.
+ *)
let compute_back_tys (dsg : Pure.decomposed_fun_sig)
- (subst : (generic_args * trait_instance_id) option) : ty list =
+ (subst : (generic_args * trait_instance_id) option) : ty option list =
List.map
(fun (back_sg : back_sg_info) ->
let effect_info = back_sg.effect_info in
- (* Compute *)
+ (* Compute the input/output types *)
let inputs = List.map snd back_sg.inputs in
- let output = mk_simpl_tuple_ty back_sg.outputs in
- let output = mk_output_ty_from_effect_info effect_info output in
- let ty = mk_arrows inputs output in
- (* Substitute - TODO: normalize *)
- match subst with
- | None -> ty
- | Some (generics, tr_self) ->
- let subst = make_subst_from_generics dsg.generics generics tr_self in
- ty_substitute subst ty)
+ let outputs = back_sg.outputs in
+ (* Filter if necessary *)
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then
+ None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output = mk_output_ty_from_effect_info effect_info output in
+ let ty = mk_arrows inputs output in
+ (* Substitute - TODO: normalize *)
+ let ty =
+ match subst with
+ | None -> ty
+ | Some (generics, tr_self) ->
+ let subst =
+ make_subst_from_generics dsg.generics generics tr_self
+ in
+ ty_substitute subst ty
+ in
+ Some ty)
(RegionGroupId.Map.values dsg.back_sg)
let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
@@ -1169,7 +1188,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig)
if !Config.return_back_funs then (
assert (gid = None);
(* Compute the arrow types for all the backward functions *)
- let back_tys = compute_back_tys dsg None in
+ let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in
(* Group the forward output and the types of the backward functions *)
let effect_info = dsg.fwd_info.effect_info in
let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in
@@ -1259,8 +1278,19 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) :
bs_ctx * var list =
List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars
+let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) :
+ bs_ctx * var option list =
+ List.fold_left_map
+ (fun ctx var ->
+ match var with
+ | None -> (ctx, None)
+ | Some (name, ty) ->
+ let ctx, var = fresh_var name ty ctx in
+ (ctx, Some var))
+ ctx vars
+
(* Introduce variables for the backward functions *)
-let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list =
+let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list =
(* We lookup the LLBC definition in an attempt to derive pretty names
for the backward functions. *)
let back_var_names =
@@ -1291,7 +1321,13 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list =
(RegionGroupId.Map.bindings ctx.sg.back_sg)
in
let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in
- fresh_vars back_vars ctx
+ let back_vars =
+ List.map
+ (fun (name, ty) ->
+ match ty with None -> None | Some ty -> Some (name, ty))
+ back_vars
+ in
+ fresh_opt_vars back_vars ctx
let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with
@@ -1748,6 +1784,7 @@ and translate_panic (ctx : bs_ctx) : texpression =
| None ->
if !Config.return_back_funs then
let back_tys = compute_back_tys ctx.sg None in
+ let back_tys = List.filter_map (fun x -> x) back_tys in
let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in
mk_output output
else mk_output ctx.sg.fwd_output
@@ -1933,21 +1970,33 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
name ^ "_back"
in
let ctx, back_vars =
- fresh_vars
- (List.map (fun ty -> (Some back_fun_name, ty)) back_tys)
+ fresh_opt_vars
+ (List.map
+ (fun ty ->
+ match ty with
+ | None -> None
+ | Some ty -> Some (Some back_fun_name, ty))
+ back_tys)
ctx
in
let back_funs =
- List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
in
let gids =
List.map
(fun (g : T.region_var_group) -> g.id)
call.regions_hierarchy
in
+ let back_vars =
+ List.map (Option.map mk_texpression_from_var) back_vars
+ in
let back_funs_map =
- RegionGroupId.Map.of_list
- (List.combine gids (List.map mk_texpression_from_var back_vars))
+ RegionGroupId.Map.of_list (List.combine gids back_vars)
in
(ctx, Some back_funs_map, back_funs)
else (ctx, None, [])
@@ -2220,15 +2269,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
(fun (arg, mp) -> mk_opt_mplace_texpression mp arg)
(List.combine inputs args_mplaces)
in
- log#ldebug
- (lazy
- (let args = List.map (texpression_to_string ctx) args in
- "func: "
- ^ texpression_to_string ctx func
- ^ "\nfunc type: "
- ^ pure_ty_to_string ctx func.ty
- ^ "\n\nargs:\n" ^ String.concat "\n" args));
- let call = mk_apps func args in
(* **Optimization**:
=================
We do a small optimization here if we split the forward/backward functions.
@@ -2252,7 +2292,22 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
a value containing mutable borrows, which can't be the case... *)
assert (List.length inputs = List.length fwd_inputs);
next_e)
- else mk_let effect_info.can_fail output call next_e
+ else
+ (* The backward function might also have been filtered if we do not
+ split the forward/backward functions *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ log#ldebug
+ (lazy
+ (let args = List.map (texpression_to_string ctx) args in
+ "func: "
+ ^ texpression_to_string ctx func
+ ^ "\nfunc type: "
+ ^ pure_ty_to_string ctx func.ty
+ ^ "\n\nargs:\n" ^ String.concat "\n" args));
+ let call = mk_apps func args in
+ mk_let effect_info.can_fail output call next_e
and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -2348,7 +2403,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
| V.LoopSynthInput ->
(* Actually the same case as [SynthInput] *)
translate_end_abstraction_synth_input ectx abs e ctx rg_id
- | V.LoopCall ->
+ | V.LoopCall -> (
(* We need to introduce a call to the backward function corresponding
to a forward call which happened earlier *)
let fun_id = E.FRegular ctx.fun_decl.def_id in
@@ -2419,9 +2474,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
let func_ty = mk_arrows input_tys ret_ty in
let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
let func = { id = FunOrOp func; generics } in
- { e = Qualif func; ty = func_ty }
+ Some { e = Qualif func; ty = func_ty }
in
- let call = mk_apps func args in
(* **Optimization**:
=================
We do a small optimization here in case we split the forward/backward
@@ -2447,38 +2501,44 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
assert (List.length inputs = List.length fwd_inputs);
next_e)
else
- (* Add meta-information - this is slightly hacky: we look at the
- values consumed by the abstraction (note that those come from
- *before* we applied the fixed-point context) and use them to
- guide the naming of the output vars.
-
- Also, we need to convert the backward outputs from patterns to
- variables.
-
- Finally, in practice, this works well only for loop bodies:
- we do this only in this case.
- TODO: improve the heuristics, to give weight to the hints for
- instance.
- *)
- let next_e =
- if ctx.inside_loop then
- let consumed_values = abs_to_consumed ctx ectx abs in
- let var_values = List.combine outputs consumed_values in
- let var_values =
- List.filter_map
- (fun (var, v) ->
- match var.Pure.value with
- | PatVar (var, _) -> Some (var, v)
- | _ -> None)
- var_values
+ (* In case we merge the fwd/back functions we filter the backward
+ functions elsewhere *)
+ match func with
+ | None -> next_e
+ | Some func ->
+ let call = mk_apps func args in
+ (* Add meta-information - this is slightly hacky: we look at the
+ values consumed by the abstraction (note that those come from
+ *before* we applied the fixed-point context) and use them to
+ guide the naming of the output vars.
+
+ Also, we need to convert the backward outputs from patterns to
+ variables.
+
+ Finally, in practice, this works well only for loop bodies:
+ we do this only in this case.
+ TODO: improve the heuristics, to give weight to the hints for
+ instance.
+ *)
+ let next_e =
+ if ctx.inside_loop then
+ let consumed_values = abs_to_consumed ctx ectx abs in
+ let var_values = List.combine outputs consumed_values in
+ let var_values =
+ List.filter_map
+ (fun (var, v) ->
+ match var.Pure.value with
+ | PatVar (var, _) -> Some (var, v)
+ | _ -> None)
+ var_values
+ in
+ let vars, values = List.split var_values in
+ mk_emeta_symbolic_assignments vars values next_e
+ else next_e
in
- let vars, values = List.split var_values in
- mk_emeta_symbolic_assignments vars values next_e
- else next_e
- in
- (* Create the let-binding *)
- mk_let effect_info.can_fail output call next_e
+ (* Create the let-binding *)
+ mk_let effect_info.can_fail output call next_e)
and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
@@ -2894,7 +2954,7 @@ and translate_forward_end (ectx : C.eval_ctx)
let _, back_vars = fresh_back_vars_for_current_fun ctx in
(* Create the return expressions *)
- let vars = fwd_var :: back_vars in
+ let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in
let vars = List.map mk_texpression_from_var vars in
let ret = mk_simpl_tuple_texpression vars in
let state_var = List.map mk_texpression_from_var state_var in
@@ -2903,12 +2963,16 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Bind the expressions for the backward function and the expression
for the computation of the forward output *)
+ let back_vars_els =
+ List.filter_map
+ (fun (v, el) -> match v with None -> None | Some v -> Some (v, el))
+ (List.combine back_vars back_el)
+ in
let e =
List.fold_right
(fun (var, back_e) e ->
mk_let false (mk_typed_pattern_from_var var None) back_e e)
- (List.combine back_vars back_el)
- ret
+ back_vars_els ret
in
(* Bind the expression for the forward output *)
let fwd_var = mk_typed_pattern_from_var fwd_var None in
@@ -2976,12 +3040,18 @@ and translate_forward_end (ectx : C.eval_ctx)
if !Config.return_back_funs then
let ctx, back_vars = fresh_back_vars_for_current_fun ctx in
let back_funs =
- List.map (fun v -> mk_typed_pattern_from_var v None) back_vars
+ List.filter_map
+ (fun v ->
+ match v with
+ | None -> None
+ | Some v -> Some (mk_typed_pattern_from_var v None))
+ back_vars
in
let gids = RegionGroupId.Map.keys ctx.sg.back_sg in
let back_funs_map =
RegionGroupId.Map.of_list
- (List.combine gids (List.map mk_texpression_from_var back_vars))
+ (List.combine gids
+ (List.map (Option.map mk_texpression_from_var) back_vars))
in
(ctx, Some back_funs_map, back_funs)
else (ctx, None, [])