diff options
Diffstat (limited to 'compiler/SymbolicToPure.ml')
-rw-r--r-- | compiler/SymbolicToPure.ml | 65 |
1 files changed, 59 insertions, 6 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index cd367d83..bf92482a 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -3368,7 +3368,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = (* Compute the backward outputs *) let ctx = ref ctx in let rg_to_given_back_tys = - T.RegionGroupId.Map.map + RegionGroupId.Map.map (fun (_, tys) -> (* The types shouldn't contain borrows - we can translate them as forward types *) List.map @@ -3380,10 +3380,63 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = in let ctx = !ctx in - let back_output_tys = - match ctx.bid with - | None -> None - | Some rg_id -> Some (T.RegionGroupId.Map.find rg_id rg_to_given_back_tys) + (* The output type of the loop function *) + let output_ty = + if !Config.return_back_funs then + (* The loop backward functions consume the same additional inputs as the parent + function, but have custom outputs *) + let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in + let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in + let back_tys = + List.filter_map + (fun ((back_sg, given_back) : back_sg_info * ty list) -> + let effect_info = back_sg.effect_info in + (* Compute the input/output types *) + let inputs = List.map snd back_sg.inputs in + let outputs = given_back in + (* Filter if necessary *) + if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty) + (List.combine back_sgs given_back_tys) + in + let output = + if ctx.sg.fwd_info.ignore_output then back_tys + else ctx.sg.fwd_output :: back_tys + in + let output = mk_simpl_tuple_ty output in + let effect_info = ctx.sg.fwd_info.effect_info in + let output = + if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + else + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output in (* Add the loop information in the context *) @@ -3460,7 +3513,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = input_state; inputs; inputs_lvs; - back_output_tys; + output_ty; loop_body; } in |