diff options
author | Son Ho | 2023-12-21 17:00:52 +0100 |
---|---|---|
committer | Son Ho | 2023-12-21 17:00:52 +0100 |
commit | d4b3d0e6adae5bb9a2f62872dbcedc29aaa9fa30 (patch) | |
tree | f26f591884621ba089c3f606d92c0daf8bcf35c9 | |
parent | cf3eea59ee61f2341daf7248664b8be878f128af (diff) |
Filter the useless backward functions
-rw-r--r-- | compiler/SymbolicToPure.ml | 220 |
1 files changed, 145 insertions, 75 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index d3b0933c..f37ea201 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,7 +67,7 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) - back_funs : texpression RegionGroupId.Map.t option; + back_funs : texpression option RegionGroupId.Map.t option; (** If we do not split between the forward/backward functions: the variables we introduced for the backward functions. @@ -78,6 +78,10 @@ type call_info = { here ... ]} + + The expression might be [None] in case the backward function + has to be filtered (because it does nothing - the backward + functions for shared borrows for instance). *) } [@@deriving show] @@ -125,7 +129,7 @@ type loop_info = { (** The map from region group ids to the types of the values given back by the corresponding loop abstractions. *) - back_funs : texpression RegionGroupId.Map.t option; + back_funs : texpression option RegionGroupId.Map.t option; (** Same as {!call_info.back_funs}. Initialized with [None], gets updated to [Some] only if we merge the fwd/back functions. @@ -777,8 +781,8 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) (args : texpression list) - (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx - = + (back_funs : texpression option RegionGroupId.Map.t option) (ctx : bs_ctx) : + bs_ctx = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); let info = { forward; forward_inputs = args; back_funs } in @@ -790,13 +794,15 @@ let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) [back_args]: the *additional* list of inputs received by the backward function, including the state. - Returns the updated context and the expression corresponding to the function. + Returns the updated context and the expression corresponding to the function + that we need to call. This function may be [None] if it has to be ignored + (because it does nothing). *) let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) (inherited_args : texpression list) (back_args : texpression list) (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : - bs_ctx * texpression = + bs_ctx * texpression option = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -827,7 +833,7 @@ let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) in let func_ty = mk_arrows input_tys ret_ty in let func = { id = FunOrOp fun_id; generics } in - { e = Qualif func; ty = func_ty } + Some { e = Qualif func; ty = func_ty } in (* Update the context and return *) ({ ctx with calls; abstractions }, func) @@ -1128,23 +1134,36 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions. *) +(** Compute the arrow types for all the backward functions. + + If a backward function has no inputs/outputs we filter it. + *) let compute_back_tys (dsg : Pure.decomposed_fun_sig) - (subst : (generic_args * trait_instance_id) option) : ty list = + (subst : (generic_args * trait_instance_id) option) : ty option list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in - (* Compute *) + (* Compute the input/output types *) let inputs = List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty_from_effect_info effect_info output in - let ty = mk_arrows inputs output in - (* Substitute - TODO: normalize *) - match subst with - | None -> ty - | Some (generics, tr_self) -> - let subst = make_subst_from_generics dsg.generics generics tr_self in - ty_substitute subst ty) + let outputs = back_sg.outputs in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] then + None + else + let output = mk_simpl_tuple_ty outputs in + let output = mk_output_ty_from_effect_info effect_info output in + let ty = mk_arrows inputs output in + (* Substitute - TODO: normalize *) + let ty = + match subst with + | None -> ty + | Some (generics, tr_self) -> + let subst = + make_subst_from_generics dsg.generics generics tr_self + in + ty_substitute subst ty + in + Some ty) (RegionGroupId.Map.values dsg.back_sg) let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) @@ -1169,7 +1188,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = compute_back_tys dsg None in + let back_tys = List.filter_map (fun x -> x) (compute_back_tys dsg None) in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1259,8 +1278,19 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) : bs_ctx * var list = List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars +let fresh_opt_vars (vars : (string option * ty) option list) (ctx : bs_ctx) : + bs_ctx * var option list = + List.fold_left_map + (fun ctx var -> + match var with + | None -> (ctx, None) + | Some (name, ty) -> + let ctx, var = fresh_var name ty ctx in + (ctx, Some var)) + ctx vars + (* Introduce variables for the backward functions *) -let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = +let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var option list = (* We lookup the LLBC definition in an attempt to derive pretty names for the backward functions. *) let back_var_names = @@ -1291,7 +1321,13 @@ let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = (RegionGroupId.Map.bindings ctx.sg.back_sg) in let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in - fresh_vars back_vars ctx + let back_vars = + List.map + (fun (name, ty) -> + match ty with None -> None | Some ty -> Some (name, ty)) + back_vars + in + fresh_opt_vars back_vars ctx let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with @@ -1748,6 +1784,7 @@ and translate_panic (ctx : bs_ctx) : texpression = | None -> if !Config.return_back_funs then let back_tys = compute_back_tys ctx.sg None in + let back_tys = List.filter_map (fun x -> x) back_tys in let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in mk_output output else mk_output ctx.sg.fwd_output @@ -1933,21 +1970,33 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : name ^ "_back" in let ctx, back_vars = - fresh_vars - (List.map (fun ty -> (Some back_fun_name, ty)) back_tys) + fresh_opt_vars + (List.map + (fun ty -> + match ty with + | None -> None + | Some ty -> Some (Some back_fun_name, ty)) + back_tys) ctx in let back_funs = - List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars in let gids = List.map (fun (g : T.region_var_group) -> g.id) call.regions_hierarchy in + let back_vars = + List.map (Option.map mk_texpression_from_var) back_vars + in let back_funs_map = - RegionGroupId.Map.of_list - (List.combine gids (List.map mk_texpression_from_var back_vars)) + RegionGroupId.Map.of_list (List.combine gids back_vars) in (ctx, Some back_funs_map, back_funs) else (ctx, None, []) @@ -2220,15 +2269,6 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - log#ldebug - (lazy - (let args = List.map (texpression_to_string ctx) args in - "func: " - ^ texpression_to_string ctx func - ^ "\nfunc type: " - ^ pure_ty_to_string ctx func.ty - ^ "\n\nargs:\n" ^ String.concat "\n" args)); - let call = mk_apps func args in (* **Optimization**: ================= We do a small optimization here if we split the forward/backward functions. @@ -2252,7 +2292,22 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) - else mk_let effect_info.can_fail output call next_e + else + (* The backward function might also have been filtered if we do not + split the forward/backward functions *) + match func with + | None -> next_e + | Some func -> + log#ldebug + (lazy + (let args = List.map (texpression_to_string ctx) args in + "func: " + ^ texpression_to_string ctx func + ^ "\nfunc type: " + ^ pure_ty_to_string ctx func.ty + ^ "\n\nargs:\n" ^ String.concat "\n" args)); + let call = mk_apps func args in + mk_let effect_info.can_fail output call next_e and translate_end_abstraction_identity (ectx : C.eval_ctx) (abs : V.abs) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2348,7 +2403,7 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) | V.LoopSynthInput -> (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id - | V.LoopCall -> + | V.LoopCall -> ( (* We need to introduce a call to the backward function corresponding to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in @@ -2419,9 +2474,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let func_ty = mk_arrows input_tys ret_ty in let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in let func = { id = FunOrOp func; generics } in - { e = Qualif func; ty = func_ty } + Some { e = Qualif func; ty = func_ty } in - let call = mk_apps func args in (* **Optimization**: ================= We do a small optimization here in case we split the forward/backward @@ -2447,38 +2501,44 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) assert (List.length inputs = List.length fwd_inputs); next_e) else - (* Add meta-information - this is slightly hacky: we look at the - values consumed by the abstraction (note that those come from - *before* we applied the fixed-point context) and use them to - guide the naming of the output vars. - - Also, we need to convert the backward outputs from patterns to - variables. - - Finally, in practice, this works well only for loop bodies: - we do this only in this case. - TODO: improve the heuristics, to give weight to the hints for - instance. - *) - let next_e = - if ctx.inside_loop then - let consumed_values = abs_to_consumed ctx ectx abs in - let var_values = List.combine outputs consumed_values in - let var_values = - List.filter_map - (fun (var, v) -> - match var.Pure.value with - | PatVar (var, _) -> Some (var, v) - | _ -> None) - var_values + (* In case we merge the fwd/back functions we filter the backward + functions elsewhere *) + match func with + | None -> next_e + | Some func -> + let call = mk_apps func args in + (* Add meta-information - this is slightly hacky: we look at the + values consumed by the abstraction (note that those come from + *before* we applied the fixed-point context) and use them to + guide the naming of the output vars. + + Also, we need to convert the backward outputs from patterns to + variables. + + Finally, in practice, this works well only for loop bodies: + we do this only in this case. + TODO: improve the heuristics, to give weight to the hints for + instance. + *) + let next_e = + if ctx.inside_loop then + let consumed_values = abs_to_consumed ctx ectx abs in + let var_values = List.combine outputs consumed_values in + let var_values = + List.filter_map + (fun (var, v) -> + match var.Pure.value with + | PatVar (var, _) -> Some (var, v) + | _ -> None) + var_values + in + let vars, values = List.split var_values in + mk_emeta_symbolic_assignments vars values next_e + else next_e in - let vars, values = List.split var_values in - mk_emeta_symbolic_assignments vars values next_e - else next_e - in - (* Create the let-binding *) - mk_let effect_info.can_fail output call next_e + (* Create the let-binding *) + mk_let effect_info.can_fail output call next_e) and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value) (e : S.expression) (ctx : bs_ctx) : texpression = @@ -2894,7 +2954,7 @@ and translate_forward_end (ectx : C.eval_ctx) let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) - let vars = fwd_var :: back_vars in + let vars = fwd_var :: List.filter_map (fun x -> x) back_vars in let vars = List.map mk_texpression_from_var vars in let ret = mk_simpl_tuple_texpression vars in let state_var = List.map mk_texpression_from_var state_var in @@ -2903,12 +2963,16 @@ and translate_forward_end (ectx : C.eval_ctx) (* Bind the expressions for the backward function and the expression for the computation of the forward output *) + let back_vars_els = + List.filter_map + (fun (v, el) -> match v with None -> None | Some v -> Some (v, el)) + (List.combine back_vars back_el) + in let e = List.fold_right (fun (var, back_e) e -> mk_let false (mk_typed_pattern_from_var var None) back_e e) - (List.combine back_vars back_el) - ret + back_vars_els ret in (* Bind the expression for the forward output *) let fwd_var = mk_typed_pattern_from_var fwd_var None in @@ -2976,12 +3040,18 @@ and translate_forward_end (ectx : C.eval_ctx) if !Config.return_back_funs then let ctx, back_vars = fresh_back_vars_for_current_fun ctx in let back_funs = - List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + List.filter_map + (fun v -> + match v with + | None -> None + | Some v -> Some (mk_typed_pattern_from_var v None)) + back_vars in let gids = RegionGroupId.Map.keys ctx.sg.back_sg in let back_funs_map = RegionGroupId.Map.of_list - (List.combine gids (List.map mk_texpression_from_var back_vars)) + (List.combine gids + (List.map (Option.map mk_texpression_from_var) back_vars)) in (ctx, Some back_funs_map, back_funs) else (ctx, None, []) |