From 8835d87df111d09122267fadc9a32f16b52d234a Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 14:37:43 +0100 Subject: Make good progress on merging the fwd/back functions --- compiler/SymbolicToPure.ml | 266 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 209 insertions(+), 57 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index e2787271..1ce6c698 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -67,6 +67,18 @@ type call_info = { Those inputs include the fuel and the state, if pertinent. *) + back_funs : texpression RegionGroupId.Map.t option; + (** If we do not split between the forward/backward functions: the + variables we introduced for the backward functions. + + Example: + {[ + let x, back = Vec.index_mut n v in + ^^^^ + here + ... + ]} + *) } [@@deriving show] @@ -118,6 +130,8 @@ type loop_info = { (** Body synthesis context *) type bs_ctx = { + (* TODO: there are a lot of duplications with the various decls ctx *) + decls_ctx : C.decls_ctx; type_ctx : type_ctx; fun_ctx : fun_ctx; global_ctx : global_ctx; @@ -757,17 +771,27 @@ let translate_fun_id_or_trait_method_ref (ctx : bs_ctx) TraitMethod (trait_ref, method_name, fun_decl_id) let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call) - (args : texpression list) (ctx : bs_ctx) : bs_ctx = + (args : texpression list) + (back_funs : texpression RegionGroupId.Map.t option) (ctx : bs_ctx) : bs_ctx + = let calls = ctx.calls in assert (not (V.FunCallId.Map.mem call_id calls)); - let info = { forward; forward_inputs = args } in + let info = { forward; forward_inputs = args; back_funs } in let calls = V.FunCallId.Map.add call_id info calls in { ctx with calls } -(** [back_args]: the *additional* list of inputs received by the backward function *) -let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) - (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx) - : bs_ctx * fun_or_op_id = +(** [inherit_args]: the list of inputs inherited from the forward function and + the ancestors backward functions, if pertinent. + [back_args]: the *additional* list of inputs received by the backward function, + including the state. + + Returns the updated context and the expression corresponding to the function. + *) +let bs_ctx_register_backward_call (abs : V.abs) (effect_info : fun_effect_info) + (call_id : V.FunCallId.id) (back_id : T.RegionGroupId.id) + (inherited_args : texpression list) (back_args : texpression list) + (generics : generic_args) (output_ty : ty) (ctx : bs_ctx) : + bs_ctx * texpression = (* Insert the abstraction in the call informations *) let info = V.FunCallId.Map.find call_id ctx.calls in let calls = V.FunCallId.Map.add call_id info ctx.calls in @@ -777,16 +801,31 @@ let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id) let abstractions = V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions in - (* Retrieve the fun_id *) - let fun_id = - match info.forward.call_id with - | S.Fun (fid, _) -> - let fid = translate_fun_id_or_trait_method_ref ctx fid in - Fun (FromLlbc (fid, None, Some back_id)) - | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + (* Compute the expression corresponding to the function *) + let func = + if !Config.return_back_funs then + (* Lookup the variable introduced for the backward function *) + RegionGroupId.Map.find back_id (Option.get info.back_funs) + else + (* Retrieve the fun_id *) + let fun_id = + match info.forward.call_id with + | S.Fun (fid, _) -> + let fid = translate_fun_id_or_trait_method_ref ctx fid in + Fun (FromLlbc (fid, None, Some back_id)) + | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable") + in + let args = List.append inherited_args back_args in + let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in + let ret_ty = + if effect_info.can_fail then mk_result_ty output_ty else output_ty + in + let func_ty = mk_arrows input_tys ret_ty in + let func = { id = FunOrOp fun_id; generics } in + { e = Qualif func; ty = func_ty } in (* Update the context and return *) - ({ ctx with calls; abstractions }, fun_id) + ({ ctx with calls; abstractions }, func) (** List the ancestors of an abstraction *) let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs) @@ -878,15 +917,12 @@ let get_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) We use [bid] ("backward function id") only if we split the forward and the backward functions. *) -let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) - (fun_id : A.fun_id) (sg : A.fun_sig) (input_names : string option list) : - decomposed_fun_sig = +let translate_fun_sig_with_regions_hierarchy_to_decomposed + (decls_ctx : C.decls_ctx) (fun_id : A.fun_id_or_trait_method_ref) + (regions_hierarchy : T.region_var_groups) (sg : A.fun_sig) + (input_names : string option list) : decomposed_fun_sig = let fun_infos = decls_ctx.fun_ctx.fun_infos in let type_infos = decls_ctx.type_ctx.type_infos in - (* Retrieve the list of parent backward functions *) - let regions_hierarchy = - FunIdMap.find fun_id decls_ctx.fun_ctx.regions_hierarchies - in (* We need an evaluation context to normalize the types (to normalize the associated types, etc. - for instance it may happen that the types refer to the types associated to a trait ref, but where the trait ref @@ -915,9 +951,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) in (* Is the forward function stateful, and can it fail? *) - let fwd_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None None - in + let fwd_effect_info = get_fun_effect_info fun_infos fun_id None None in (* Compute the forward inputs *) let fwd_fuel = mk_fuel_input_ty_as_list fwd_effect_info in let fwd_inputs_no_fuel_no_state = @@ -1030,7 +1064,7 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) RegionGroupId.id * back_sg_info = let gid = rg.id in let back_effect_info = - get_fun_effect_info fun_infos (FunId fun_id) None (Some gid) + get_fun_effect_info fun_infos fun_id None (Some gid) in let inputs_no_state = translate_back_inputs_for_gid gid in let inputs_no_state = @@ -1072,6 +1106,16 @@ let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) fwd_info; } +let translate_fun_sig_to_decomposed (decls_ctx : C.decls_ctx) + (fun_id : FunDeclId.id) (sg : A.fun_sig) (input_names : string option list) + : decomposed_fun_sig = + (* Retrieve the list of parent backward functions *) + let regions_hierarchy = + FunIdMap.find (FRegular fun_id) decls_ctx.fun_ctx.regions_hierarchies + in + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + (FunId (FRegular fun_id)) regions_hierarchy sg input_names + let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty = let output = @@ -1090,6 +1134,40 @@ let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = mk_arrows inputs output) (RegionGroupId.Map.values dsg.back_sg) +(** Return the pure signature of a backward function, in the case the + forward/backward functions are merged (i.e., the forward functions + return the backward functions). + + TODO: merge with {!translate_fun_sig_from_decomposed} + *) +let translate_ret_back_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) + (gid : RegionGroupId.id) : fun_sig = + assert !Config.return_back_funs; + + let generics = dsg.generics in + let llbc_generics = dsg.llbc_generics in + let preds = dsg.preds in + (* Compute the effects info *) + let fwd_info = dsg.fwd_info in + let back_effect_info = + RegionGroupId.Map.of_list + (List.map + (fun ((gid, info) : RegionGroupId.id * back_sg_info) -> + (gid, info.effect_info)) + (RegionGroupId.Map.bindings dsg.back_sg)) + in + (* Two cases depending on whether we split the forward/backward functions + or not *) + let mk_output_ty = mk_output_ty_from_effect_info in + + let back_sg = RegionGroupId.Map.find gid dsg.back_sg in + let effect_info = back_sg.effect_info in + (* Do not prepend the forward inputs *) + let inputs = List.map snd back_sg.inputs in + let output = mk_simpl_tuple_ty back_sg.outputs in + let output = mk_output_ty effect_info output in + { generics; llbc_generics; preds; inputs; output; fwd_info; back_effect_info } + let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1774,7 +1852,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let ctx, dest = fresh_var_for_symbolic_value call.dest ctx in (* Retrieve the function id, and register the function call in the context * if necessary. *) - let ctx, fun_id, effect_info, args, out_state = + let ctx, fun_id, effect_info, args, back_funs, out_state = match call.call_id with | S.Fun (fid, call_id) -> (* Regular function call *) @@ -1798,9 +1876,80 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : (List.concat [ fuel; args; [ state_var ] ], ctx, Some nstate_var) else (List.concat [ fuel; args ], ctx, None) in + (* If we do not split the forward/backward functions: generate the + variables for the backward functions returned by the forward + function. *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + (* We need to compute the signatures of the backward functions. *) + let sg = Option.get call.sg in + let decls_ctx = ctx.decls_ctx in + let dsg = + translate_fun_sig_with_regions_hierarchy_to_decomposed decls_ctx + fid call.regions_hierarchy sg + (List.map (fun _ -> None) sg.inputs) + in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in + let back_sgs = + List.map (translate_ret_back_fun_sig_from_decomposed dsg) gids + in + (* Introduce variables for the backward functions *) + let back_tys = + List.map + (fun (sg : fun_sig) -> mk_arrows sg.inputs sg.output) + back_sgs + in + (* Compute a proper basename for the variables *) + let back_fun_name = + let name = + match fid with + | FunId (FAssumed fid) -> ( + match fid with + | BoxNew -> "box_new" + | BoxFree -> "box_free" + | ArrayRepeat -> "array_repeat" + | ArrayIndexShared -> "index_shared" + | ArrayIndexMut -> "index_mut" + | ArrayToSliceShared -> "to_slice_shared" + | ArrayToSliceMut -> "to_slice_mut" + | SliceIndexShared -> "index_shared" + | SliceIndexMut -> "index_mut") + | FunId (FRegular fid) | TraitMethod (_, _, fid) -> ( + let decl = + FunDeclId.Map.find fid ctx.fun_ctx.llbc_fun_decls + in + match Collections.List.last decl.name with + | PeIdent (s, _) -> s + | PeImpl _ -> + (* We shouldn't get there *) + raise (Failure "Unexpected")) + in + name ^ "_back" + in + let ctx, back_vars = + fresh_vars + (List.map (fun ty -> (Some back_fun_name, ty)) back_tys) + ctx + in + let back_funs = + List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids (List.map mk_texpression_from_var back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) + in (* Register the function call *) - let ctx = bs_ctx_register_forward_call call_id call args ctx in - (ctx, func, effect_info, args, out_state) + let ctx = + bs_ctx_register_forward_call call_id call args back_funs_map ctx + in + (ctx, func, effect_info, args, back_funs, out_state) | S.Unop E.Not -> let effect_info = { @@ -1811,7 +1960,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop Not, effect_info, args, None) + (ctx, Unop Not, effect_info, args, [], None) | S.Unop E.Neg -> ( match args with | [ arg ] -> @@ -1827,7 +1976,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Neg int_ty), effect_info, args, None) + (ctx, Unop (Neg int_ty), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) | S.Unop (E.Cast cast_kind) -> ( match cast_kind with @@ -1842,7 +1991,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, None) + (ctx, Unop (Cast (src_ty, tgt_ty)), effect_info, args, [], None) | CastFnPtr _ -> raise (Failure "TODO: function casts")) | S.Binop binop -> ( match args with @@ -1862,11 +2011,12 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : is_rec = false; } in - (ctx, Binop (binop, int_ty0), effect_info, args, None) + (ctx, Binop (binop, int_ty0), effect_info, args, [], None) | _ -> raise (Failure "Unreachable")) in let dest_v = let dest = mk_typed_pattern_from_var dest dest_mplace in + let dest = mk_simpl_tuple_pattern (dest :: back_funs) in match out_state with | None -> dest | Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ] @@ -2026,9 +2176,11 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inpus *) - let inputs = - List.concat [ fwd_inputs; back_ancestors_inputs; back_inputs; back_state ] + let inherited_inputs = + if !Config.return_back_funs then [] + else List.concat [ fwd_inputs; back_ancestors_inputs ] in + let back_inputs = List.append back_inputs back_state in (* Retrieve the values given back by this function: those are the output * values. We rely on the fact that there are no nested borrows to use the * meta-place information from the input values given to the forward function @@ -2046,43 +2198,43 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs) | Some nstate -> mk_simpl_tuple_pattern [ nstate; output ] in (* Retrieve the function id, and register the function call in the context - * if necessary *) + if necessary.Arith_status *) let ctx, func = - bs_ctx_register_backward_call abs call_id rg_id back_inputs ctx + bs_ctx_register_backward_call abs effect_info call_id rg_id inherited_inputs + back_inputs generics output.ty ctx in (* Translate the next expression *) let next_e = translate_expression e ctx in (* Put everything together *) + let inputs = List.append inherited_inputs back_inputs in let args_mplaces = List.map (fun _ -> None) inputs in let args = List.map (fun (arg, mp) -> mk_opt_mplace_texpression mp arg) (List.combine inputs args_mplaces) in - let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in - let ret_ty = - if effect_info.can_fail then mk_result_ty output.ty else output.ty - in - let func_ty = mk_arrows input_tys ret_ty in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None then ( + ================= + We do a small optimization here if we split the forward/backward functions. + If the backward function doesn't have any output, we don't introduce any function + call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None + then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else mk_let effect_info.can_fail output call next_e -- cgit v1.2.3