From f1f41818fb14a6c46442ca42a49a3aab0a5b1aaf Mon Sep 17 00:00:00 2001 From: Son Ho Date: Thu, 14 Dec 2023 17:48:44 +0100 Subject: Make progress on generated merged fwd/back functions --- compiler/SymbolicToPure.ml | 56 ++++++++++++++++++++++++---------------------- compiler/Translate.ml | 4 +++- 2 files changed, 32 insertions(+), 28 deletions(-) (limited to 'compiler') diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index 1fd4896e..86c80f87 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -958,19 +958,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) (* Compute the additinal inputs for the current function, if it is a backward * function *) let back_inputs = - match gid with - | None -> [] - | Some gid -> - (* For now, we don't allow nested borrows, so the additional inputs to the - backward function can only come from borrows that were returned like - in (for the backward function we introduce for 'a): - {[ - fn f<'a>(...) -> &'a mut u32; - ]} - Upon ending the abstraction for 'a, we need to get back the borrow - the function returned. - *) - List.filter_map (translate_back_ty_for_gid gid) [ sg.output ] + match gid with None -> [] | Some gid -> translate_back_inputs_for_gid gid in (* If the function is stateful, the inputs are: - forward: [fwd_ty0, ..., fwd_tyn, state] @@ -989,11 +977,12 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) See {!effect_info}. *) if effect_info.stateful_group then [ mk_state_ty ] else [] in - let back_state_ty = + let mk_back_state_ty_for_gid (gid : RegionGroupId.id option) : ty list = (* For the backward state, we check if the function is a backward function, and it is stateful *) if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else [] in + let back_state_ty = mk_back_state_ty_for_gid gid in (* Concatenate the inputs, in the following order: * - forward inputs @@ -1072,22 +1061,35 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id) num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty; } in + let compute_back_info (back_state_ty : ty list) + (num_back_inputs_no_state : int) : inputs_info = + let n = num_back_inputs_no_state in + (* Note that backward functions never use fuel *) + { + has_fuel = false; + num_inputs_no_fuel_no_state = n; + num_inputs_with_fuel_no_state = n; + (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) + num_inputs_with_fuel_with_state = n + List.length back_state_ty; + } + in let back_info : back_inputs_info = if !Config.return_back_funs then + (* Create the map *) + AllBacks + (RegionGroupId.Map.of_list + (List.map + (fun (rg : T.region_var_group) -> + ( rg.id, + let back_inputs = translate_back_inputs_for_gid rg.id in + let num_back_inputs = List.length back_inputs in + (* TODO: slightly overkill *) + let back_state_ty = mk_back_state_ty_for_gid (Some rg.id) in + compute_back_info back_state_ty num_back_inputs )) + regions_hierarchy)) + else SingleBack - (Option.map - (fun n -> - (* Note that backward functions never use fuel *) - { - has_fuel = false; - num_inputs_no_fuel_no_state = n; - num_inputs_with_fuel_no_state = n; - (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *) - num_inputs_with_fuel_with_state = n + List.length back_state_ty; - }) - num_back_inputs_no_state) - else (* Create the map *) - failwith "TODO" + (Option.map (compute_back_info back_state_ty) num_back_inputs_no_state) in let info = { fwd_info; back_info; effect_info } in assert (fun_sig_info_is_wf info); diff --git a/compiler/Translate.ml b/compiler/Translate.ml index 54e24066..06d4bd6d 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -223,7 +223,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx) sg.info.fwd_info.num_inputs_with_fuel_with_state in let num_back_inputs = - (Option.get sg.info.back_info).num_inputs_no_fuel_no_state + match sg.info.back_info with + | SingleBack (Some info) -> info.num_inputs_no_fuel_no_state + | _ -> raise (Failure "Unexpected") in Collections.List.subslice sg.inputs num_forward_inputs (num_forward_inputs + num_back_inputs) -- cgit v1.2.3