summaryrefslogtreecommitdiff
path: root/compiler/SymbolicToPure.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml65
1 files changed, 59 insertions, 6 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index cd367d83..bf92482a 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -3368,7 +3368,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Compute the backward outputs *)
let ctx = ref ctx in
let rg_to_given_back_tys =
- T.RegionGroupId.Map.map
+ RegionGroupId.Map.map
(fun (_, tys) ->
(* The types shouldn't contain borrows - we can translate them as forward types *)
List.map
@@ -3380,10 +3380,63 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
in
let ctx = !ctx in
- let back_output_tys =
- match ctx.bid with
- | None -> None
- | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys)
+ (* The output type of the loop function *)
+ let output_ty =
+ if !Config.return_back_funs then
+ (* The loop backward functions consume the same additional inputs as the parent
+ function, but have custom outputs *)
+ let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in
+ let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
+ let back_tys =
+ List.filter_map
+ (fun ((back_sg, given_back) : back_sg_info * ty list) ->
+ let effect_info = back_sg.effect_info in
+ (* Compute the input/output types *)
+ let inputs = List.map snd back_sg.inputs in
+ let outputs = given_back in
+ (* Filter if necessary *)
+ if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty)
+ (List.combine back_sgs given_back_tys)
+ in
+ let output =
+ if ctx.sg.fwd_info.ignore_output then back_tys
+ else ctx.sg.fwd_output :: back_tys
+ in
+ let output = mk_simpl_tuple_ty output in
+ let effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ else
+ match ctx.bid with
+ | None ->
+ (* Forward function: same type as the parent function *)
+ (translate_fun_sig_from_decomposed ctx.sg None).output
+ | Some rg_id ->
+ (* Backward function: custom return type *)
+ let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in
+ let output = mk_simpl_tuple_ty doutputs in
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if fwd_effect_info.stateful then
+ mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if fwd_effect_info.can_fail then mk_result_ty output else output
+ in
+ output
in
(* Add the loop information in the context *)
@@ -3460,7 +3513,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
input_state;
inputs;
inputs_lvs;
- back_output_tys;
+ output_ty;
loop_body;
}
in