summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-12-14 17:48:44 +0100
committerSon Ho2023-12-14 17:48:44 +0100
commitf1f41818fb14a6c46442ca42a49a3aab0a5b1aaf (patch)
treee2fb7a4a227ed5699b0535ffe5289344f738ab81
parentf074320eee2203857e669cfb72f7f8f94ab52151 (diff)
Make progress on generated merged fwd/back functions
-rw-r--r--compiler/SymbolicToPure.ml56
-rw-r--r--compiler/Translate.ml4
2 files changed, 32 insertions, 28 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 1fd4896e..86c80f87 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -958,19 +958,7 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
(* Compute the additinal inputs for the current function, if it is a backward
* function *)
let back_inputs =
- match gid with
- | None -> []
- | Some gid ->
- (* For now, we don't allow nested borrows, so the additional inputs to the
- backward function can only come from borrows that were returned like
- in (for the backward function we introduce for 'a):
- {[
- fn f<'a>(...) -> &'a mut u32;
- ]}
- Upon ending the abstraction for 'a, we need to get back the borrow
- the function returned.
- *)
- List.filter_map (translate_back_ty_for_gid gid) [ sg.output ]
+ match gid with None -> [] | Some gid -> translate_back_inputs_for_gid gid
in
(* If the function is stateful, the inputs are:
- forward: [fwd_ty0, ..., fwd_tyn, state]
@@ -989,11 +977,12 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
See {!effect_info}. *)
if effect_info.stateful_group then [ mk_state_ty ] else []
in
- let back_state_ty =
+ let mk_back_state_ty_for_gid (gid : RegionGroupId.id option) : ty list =
(* For the backward state, we check if the function is a backward function,
and it is stateful *)
if effect_info.stateful && Option.is_some gid then [ mk_state_ty ] else []
in
+ let back_state_ty = mk_back_state_ty_for_gid gid in
(* Concatenate the inputs, in the following order:
* - forward inputs
@@ -1072,22 +1061,35 @@ let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
num_fwd_inputs_with_fuel_no_state + List.length fwd_state_ty;
}
in
+ let compute_back_info (back_state_ty : ty list)
+ (num_back_inputs_no_state : int) : inputs_info =
+ let n = num_back_inputs_no_state in
+ (* Note that backward functions never use fuel *)
+ {
+ has_fuel = false;
+ num_inputs_no_fuel_no_state = n;
+ num_inputs_with_fuel_no_state = n;
+ (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
+ num_inputs_with_fuel_with_state = n + List.length back_state_ty;
+ }
+ in
let back_info : back_inputs_info =
if !Config.return_back_funs then
+ (* Create the map *)
+ AllBacks
+ (RegionGroupId.Map.of_list
+ (List.map
+ (fun (rg : T.region_var_group) ->
+ ( rg.id,
+ let back_inputs = translate_back_inputs_for_gid rg.id in
+ let num_back_inputs = List.length back_inputs in
+ (* TODO: slightly overkill *)
+ let back_state_ty = mk_back_state_ty_for_gid (Some rg.id) in
+ compute_back_info back_state_ty num_back_inputs ))
+ regions_hierarchy))
+ else
SingleBack
- (Option.map
- (fun n ->
- (* Note that backward functions never use fuel *)
- {
- has_fuel = false;
- num_inputs_no_fuel_no_state = n;
- num_inputs_with_fuel_no_state = n;
- (* Length of [back_state_ty]: similar trick as for [fwd_state_ty] *)
- num_inputs_with_fuel_with_state = n + List.length back_state_ty;
- })
- num_back_inputs_no_state)
- else (* Create the map *)
- failwith "TODO"
+ (Option.map (compute_back_info back_state_ty) num_back_inputs_no_state)
in
let info = { fwd_info; back_info; effect_info } in
assert (fun_sig_info_is_wf info);
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 54e24066..06d4bd6d 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -223,7 +223,9 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
sg.info.fwd_info.num_inputs_with_fuel_with_state
in
let num_back_inputs =
- (Option.get sg.info.back_info).num_inputs_no_fuel_no_state
+ match sg.info.back_info with
+ | SingleBack (Some info) -> info.num_inputs_no_fuel_no_state
+ | _ -> raise (Failure "Unexpected")
in
Collections.List.subslice sg.inputs num_forward_inputs
(num_forward_inputs + num_back_inputs)