From cf3eea59ee61f2341daf7248664b8be878f128af Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 21 Dec 2023 16:35:27 +0100 Subject: Update SymbolicToPure.ml for the loops --- compiler/SymbolicToPure.ml | 221 +++++++++++++++++++++++++-------------------- 1 file changed, 125 insertions(+), 96 deletions(-) (limited to 'compiler/SymbolicToPure.ml') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index ef0a0bde..d3b0933c 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -125,6 +125,11 @@ type loop_info = { (** The map from region group ids to the types of the values given back by the corresponding loop abstractions. *) + back_funs : texpression RegionGroupId.Map.t option; + (** Same as {!call_info.back_funs}. + Initialized with [None], gets updated to [Some] only if we merge + the fwd/back functions. + *) } [@@deriving show] @@ -1123,45 +1128,25 @@ let mk_output_ty_from_effect_info (effect_info : fun_effect_info) (ty : ty) : ty in if effect_info.can_fail then mk_result_ty output else output -(** Compute the arrow types for all the backward functions. - - TODO: merge with below? - *) -let compute_back_tys (dsg : Pure.decomposed_fun_sig) : ty list = +(** Compute the arrow types for all the backward functions. *) +let compute_back_tys (dsg : Pure.decomposed_fun_sig) + (subst : (generic_args * trait_instance_id) option) : ty list = List.map (fun (back_sg : back_sg_info) -> let effect_info = back_sg.effect_info in + (* Compute *) let inputs = List.map snd back_sg.inputs in let output = mk_simpl_tuple_ty back_sg.outputs in let output = mk_output_ty_from_effect_info effect_info output in - mk_arrows inputs output) + let ty = mk_arrows inputs output in + (* Substitute - TODO: normalize *) + match subst with + | None -> ty + | Some (generics, tr_self) -> + let subst = make_subst_from_generics dsg.generics generics tr_self in + ty_substitute subst ty) (RegionGroupId.Map.values dsg.back_sg) -(** Return the instantiated pure signature of a backward function, in the - case the forward/backward functions are merged (i.e., the forward functions - return the backward functions). - *) -let translate_ret_back_inst_fun_sig_from_decomposed - (dsg : Pure.decomposed_fun_sig) (generics : generic_args) - (gid : RegionGroupId.id) : inst_fun_sig = - assert !Config.return_back_funs; - let mk_output_ty = mk_output_ty_from_effect_info in - (* Lookup the signature information *) - let back_sg = RegionGroupId.Map.find gid dsg.back_sg in - let effect_info = back_sg.effect_info in - (* Do not prepend the forward inputs *) - let inputs = List.map snd back_sg.inputs in - let output = mk_simpl_tuple_ty back_sg.outputs in - let output = mk_output_ty effect_info output in - (* Substitute the types *) - let tr_self = UnknownTrait __FUNCTION__ in - let subst = make_subst_from_generics dsg.generics generics tr_self in - let subst = ty_substitute subst in - let inputs = List.map subst inputs in - let output = subst output in - (* Return *) - { inputs; output } - let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) (gid : RegionGroupId.id option) : fun_sig = let generics = dsg.generics in @@ -1184,7 +1169,7 @@ let translate_fun_sig_from_decomposed (dsg : Pure.decomposed_fun_sig) if !Config.return_back_funs then ( assert (gid = None); (* Compute the arrow types for all the backward functions *) - let back_tys = compute_back_tys dsg in + let back_tys = compute_back_tys dsg None in (* Group the forward output and the types of the backward functions *) let effect_info = dsg.fwd_info.effect_info in let output = mk_simpl_tuple_ty (dsg.fwd_output :: back_tys) in @@ -1274,6 +1259,40 @@ let fresh_vars (vars : (string option * ty) list) (ctx : bs_ctx) : bs_ctx * var list = List.fold_left_map (fun ctx (name, ty) -> fresh_var name ty ctx) ctx vars +(* Introduce variables for the backward functions *) +let fresh_back_vars_for_current_fun (ctx : bs_ctx) : bs_ctx * var list = + (* We lookup the LLBC definition in an attempt to derive pretty names + for the backward functions. *) + let back_var_names = + let def_id = ctx.fun_decl.def_id in + let sg = ctx.fun_decl.signature in + let regions_hierarchy = + LlbcAstUtils.FunIdMap.find (FRegular def_id) + ctx.fun_ctx.regions_hierarchies + in + List.map + (fun (gid, _) -> + let rg = RegionGroupId.nth regions_hierarchy gid in + let region_names = + List.map + (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) + rg.regions + in + let name = + match region_names with + | [] -> "back" + | [ Some r ] -> "back" ^ r + | _ -> + (* Concatenate all the region names *) + "back" + ^ String.concat "" (List.filter_map (fun x -> x) region_names) + in + Some name) + (RegionGroupId.Map.bindings ctx.sg.back_sg) + in + let back_vars = List.combine back_var_names (compute_back_tys ctx.sg None) in + fresh_vars back_vars ctx + let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var = match V.SymbolicValueId.Map.find_opt sv.sv_id ctx.sv_to_var with | Some v -> v @@ -1728,7 +1747,7 @@ and translate_panic (ctx : bs_ctx) : texpression = match ctx.bid with | None -> if !Config.return_back_funs then - let back_tys = compute_back_tys ctx.sg in + let back_tys = compute_back_tys ctx.sg None in let output = mk_simpl_tuple_ty (ctx.sg.fwd_output :: back_tys) in mk_output output else mk_output ctx.sg.fwd_output @@ -1883,22 +1902,9 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : fid call.regions_hierarchy sg (List.map (fun _ -> None) sg.inputs) in - let gids = - List.map - (fun (g : T.region_var_group) -> g.id) - call.regions_hierarchy - in - let back_sgs = - List.map - (translate_ret_back_inst_fun_sig_from_decomposed dsg generics) - gids - in + let tr_self = UnknownTrait __FUNCTION__ in + let back_tys = compute_back_tys dsg (Some (generics, tr_self)) in (* Introduce variables for the backward functions *) - let back_tys = - List.map - (fun (sg : inst_fun_sig) -> mk_arrows sg.inputs sg.output) - back_sgs - in (* Compute a proper basename for the variables *) let back_fun_name = let name = @@ -1934,6 +1940,11 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) : let back_funs = List.map (fun v -> mk_typed_pattern_from_var v None) back_vars in + let gids = + List.map + (fun (g : T.region_var_group) -> g.id) + call.regions_hierarchy + in let back_funs_map = RegionGroupId.Map.of_list (List.combine gids (List.map mk_texpression_from_var back_vars)) @@ -2338,6 +2349,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) (* Actually the same case as [SynthInput] *) translate_end_abstraction_synth_input ectx abs e ctx rg_id | V.LoopCall -> + (* We need to introduce a call to the backward function corresponding + to a forward call which happened earlier *) let fun_id = E.FRegular ctx.fun_decl.def_id in let effect_info = get_fun_effect_info ctx.fun_ctx.fun_infos (FunId fun_id) (Some vloop_id) @@ -2367,7 +2380,10 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) else ([], ctx, None) in (* Concatenate all the inputs *) - let inputs = List.concat [ fwd_inputs; back_inputs; back_state ] in + let inputs = + if !Config.return_back_funs then List.concat [ back_inputs; back_state ] + else List.concat [ fwd_inputs; back_inputs; back_state ] + in (* Retrieve the values given back by this function *) let ctx, outputs = abs_to_given_back None abs ctx in (* Group the output values together: first the updated inputs *) @@ -2391,28 +2407,43 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs) let ret_ty = if effect_info.can_fail then mk_result_ty output.ty else output.ty in - let func_ty = mk_arrows input_tys ret_ty in - let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in - let func = { id = FunOrOp func; generics } in - let func = { e = Qualif func; ty = func_ty } in + (* Create the expression for the function: + - it is either a call to a top-level function, if we split the + forward/backward functions + - or a call to the variable we introduced for the backward function, + if we merge the forward/backward functions *) + let func = + if !Config.return_back_funs then + RegionGroupId.Map.find rg_id (Option.get loop_info.back_funs) + else + let func_ty = mk_arrows input_tys ret_ty in + let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in + let func = { id = FunOrOp func; generics } in + { e = Qualif func; ty = func_ty } + in let call = mk_apps func args in (* **Optimization**: - * ================= - * We do a small optimization here: if the backward function doesn't - * have any output, we don't introduce any function call. - * See the comment in {!Config.filter_useless_monadic_calls}. - * - * TODO: use an option to disallow backward functions from updating the state. - * TODO: a backward function which only gives back shared borrows shouldn't - * update the state (state updates should only be used for mutable borrows, - * with objects like Rc for instance). - *) - if !Config.filter_useless_monadic_calls && outputs = [] && nstate = None + ================= + We do a small optimization here in case we split the forward/backward + functions. + If the backward function doesn't have any output, we don't introduce + any function call. + See the comment in {!Config.filter_useless_monadic_calls}. + + TODO: use an option to disallow backward functions from updating the state. + TODO: a backward function which only gives back shared borrows shouldn't + update the state (state updates should only be used for mutable borrows, + with objects like Rc for instance). + *) + if + (not !Config.return_back_funs) + && !Config.filter_useless_monadic_calls + && outputs = [] && nstate = None then ( (* No outputs - we do a small sanity check: the backward function - * should have exactly the same number of inputs as the forward: - * this number can be different only if the forward function returned - * a value containing mutable borrows, which can't be the case... *) + should have exactly the same number of inputs as the forward: + this number can be different only if the forward function returned + a value containing mutable borrows, which can't be the case... *) assert (List.length inputs = List.length fwd_inputs); next_e) else @@ -2860,35 +2891,7 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce variables for the backward functions. We lookup the LLBC definition in an attempt to derive pretty names for those functions. *) - let back_var_names = - let def_id = ctx.fun_decl.def_id in - let sg = ctx.fun_decl.signature in - let regions_hierarchy = - LlbcAstUtils.FunIdMap.find (FRegular def_id) - ctx.fun_ctx.regions_hierarchies - in - List.map - (fun (gid, _) -> - let rg = RegionGroupId.nth regions_hierarchy gid in - let region_names = - List.map - (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name) - rg.regions - in - let name = - match region_names with - | [] -> "back" - | [ Some r ] -> "back" ^ r - | _ -> - (* Concatenate all the region names *) - "back" - ^ String.concat "" (List.filter_map (fun x -> x) region_names) - in - Some name) - (RegionGroupId.Map.bindings ctx.sg.back_sg) - in - let back_vars = List.combine back_var_names (compute_back_tys ctx.sg) in - let _, back_vars = fresh_vars back_vars ctx in + let _, back_vars = fresh_back_vars_for_current_fun ctx in (* Create the return expressions *) let vars = fwd_var :: back_vars in @@ -2964,8 +2967,32 @@ and translate_forward_end (ectx : C.eval_ctx) (* Introduce a fresh output value for the forward function *) let ctx, output_var = fresh_var None ctx.sg.fwd_output ctx in + (* Introduce fresh variables for the backward functions of the loop. + + For now, the backward functions of the loop are the same as the + backward functions of the outer function. + *) + let ctx, back_funs_map, back_funs = + if !Config.return_back_funs then + let ctx, back_vars = fresh_back_vars_for_current_fun ctx in + let back_funs = + List.map (fun v -> mk_typed_pattern_from_var v None) back_vars + in + let gids = RegionGroupId.Map.keys ctx.sg.back_sg in + let back_funs_map = + RegionGroupId.Map.of_list + (List.combine gids (List.map mk_texpression_from_var back_vars)) + in + (ctx, Some back_funs_map, back_funs) + else (ctx, None, []) + in + + (* Introduce patterns *) let args, ctx, out_pats = + (* Create the pattern for the output value *) let output_pat = mk_typed_pattern_from_var output_var None in + (* Add the returned backward functions (they might be empty) *) + let output_pat = mk_simpl_tuple_pattern (output_pat :: back_funs) in (* Depending on the function effects: * - add the fuel @@ -2988,6 +3015,7 @@ and translate_forward_end (ectx : C.eval_ctx) loop_info with forward_inputs = Some args; forward_output_no_state_no_result = Some output_var; + back_funs = back_funs_map; } in let ctx = @@ -3143,6 +3171,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = forward_inputs = None; forward_output_no_state_no_result = None; back_outputs = rg_to_given_back_tys; + back_funs = None; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in -- cgit v1.2.3