summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-12-22 21:03:59 +0100
committerSon Ho2023-12-22 21:03:59 +0100
commitdd7552bec1be1695682801fca6ba6dfcfa990fbb (patch)
tree2bf42baae680ef57deadce1fbe3824ad405dc055 /compiler
parent70d506d148e5ae1a3e4115034161f449aff666ed (diff)
Update the computation of the effect info for the loops
Diffstat (limited to '')
-rw-r--r--compiler/SymbolicToPure.ml141
1 files changed, 95 insertions, 46 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index bf92482a..f0d1ca62 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -134,6 +134,8 @@ type loop_info = {
Initialized with [None], gets updated to [Some] only if we merge
the fwd/back functions.
*)
+ fwd_effect_info : fun_effect_info;
+ back_effect_infos : fun_effect_info RegionGroupId.Map.t;
}
[@@deriving show]
@@ -922,17 +924,31 @@ let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref)
(lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) :
fun_effect_info =
- match fun_id with
- | TraitMethod (_, _, fid) | FunId (FRegular fid) ->
- let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in
- let info =
- match gid with
- | None -> dsg.fwd_info.effect_info
- | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info
- in
- { info with is_rec = info.is_rec || Option.is_some lid }
- | FunId (FAssumed _) ->
- compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid
+ match lid with
+ | None -> (
+ match fun_id with
+ | TraitMethod (_, _, fid) | FunId (FRegular fid) ->
+ let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in
+ let info =
+ match gid with
+ | None -> dsg.fwd_info.effect_info
+ | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info
+ in
+ { info with is_rec = info.is_rec || Option.is_some lid }
+ | FunId (FAssumed _) ->
+ compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid)
+ | Some lid -> (
+ (* This is necessarily for the current function *)
+ match fun_id with
+ | FunId (FRegular fid) -> (
+ assert (fid = ctx.fun_decl.def_id);
+ (* Lookup the loop *)
+ let lid = V.LoopId.Map.find lid ctx.loop_ids_map in
+ let loop_info = LoopId.Map.find lid ctx.loops in
+ match gid with
+ | None -> loop_info.fwd_effect_info
+ | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos)
+ | _ -> raise (Failure "Unreachable"))
(** Translate a function signature to a decomposed function signature.
@@ -1901,7 +1917,7 @@ and translate_panic (ctx : bs_ctx) : texpression =
Remark: in case we merge the forward/backward functions, we introduce
those in [translate_forward_end].
- *)
+*)
and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
(ctx : bs_ctx) : texpression =
(* There are two cases:
@@ -3381,31 +3397,47 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
let ctx = !ctx in
(* The output type of the loop function *)
- let output_ty =
+ let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
+ let back_effect_infos, output_ty =
if !Config.return_back_funs then
(* The loop backward functions consume the same additional inputs as the parent
function, but have custom outputs *)
- let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in
+ let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
- let back_tys =
- List.filter_map
- (fun ((back_sg, given_back) : back_sg_info * ty list) ->
+ let back_info_tys =
+ List.map
+ (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+ (* Remark: the effect info of the backward function for the loop
+ is almost the same as for the backward function of the parent function.
+ Quite importantly, the fact that the function is stateful and/or can fail
+ mostly depends on whether it has inputs or not, and the backward functions
+ for the loops have the same inputs as the backward functions for the parent
+ function.
+ *)
let effect_info = back_sg.effect_info in
+ let effect_info = { effect_info with is_rec = true } in
(* Compute the input/output types *)
let inputs = List.map snd back_sg.inputs in
let outputs = given_back in
(* Filter if necessary *)
- if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
- then None
- else
- let output = mk_simpl_tuple_ty outputs in
- let output =
- mk_back_output_ty_from_effect_info effect_info inputs output
- in
- let ty = mk_arrows inputs output in
- Some ty)
+ let ty =
+ if
+ !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+ then None
+ else
+ let output = mk_simpl_tuple_ty outputs in
+ let output =
+ mk_back_output_ty_from_effect_info effect_info inputs output
+ in
+ let ty = mk_arrows inputs output in
+ Some ty
+ in
+ ((id, effect_info), ty))
(List.combine back_sgs given_back_tys)
in
+ let back_info = List.map fst back_info_tys in
+ let back_info = RegionGroupId.Map.of_list back_info in
+ let back_tys = List.filter_map snd back_info_tys in
let output =
if ctx.sg.fwd_info.ignore_output then back_tys
else ctx.sg.fwd_output :: back_tys
@@ -3416,27 +3448,42 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
else output
in
- if effect_info.can_fail && inputs <> [] then mk_result_ty output
- else output
+ let output =
+ if effect_info.can_fail && inputs <> [] then mk_result_ty output
+ else output
+ in
+ (back_info, output)
else
- match ctx.bid with
- | None ->
- (* Forward function: same type as the parent function *)
- (translate_fun_sig_from_decomposed ctx.sg None).output
- | Some rg_id ->
- (* Backward function: custom return type *)
- let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in
- let output = mk_simpl_tuple_ty doutputs in
- let fwd_effect_info = ctx.sg.fwd_info.effect_info in
- let output =
- if fwd_effect_info.stateful then
- mk_simpl_tuple_ty [ mk_state_ty; output ]
- else output
- in
- let output =
- if fwd_effect_info.can_fail then mk_result_ty output else output
- in
- output
+ let back_info =
+ RegionGroupId.Map.of_list
+ (List.map
+ (fun ((id, back_sg) : _ * back_sg_info) ->
+ (id, { back_sg.effect_info with is_rec = true }))
+ (RegionGroupId.Map.bindings ctx.sg.back_sg))
+ in
+ let output =
+ match ctx.bid with
+ | None ->
+ (* Forward function: same type as the parent function *)
+ (translate_fun_sig_from_decomposed ctx.sg None).output
+ | Some rg_id ->
+ (* Backward function: custom return type *)
+ let doutputs =
+ T.RegionGroupId.Map.find rg_id rg_to_given_back_tys
+ in
+ let output = mk_simpl_tuple_ty doutputs in
+ let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+ let output =
+ if fwd_effect_info.stateful then
+ mk_simpl_tuple_ty [ mk_state_ty; output ]
+ else output
+ in
+ let output =
+ if fwd_effect_info.can_fail then mk_result_ty output else output
+ in
+ output
+ in
+ (back_info, output)
in
(* Add the loop information in the context *)
@@ -3480,6 +3527,8 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
forward_output_no_state_no_result = None;
back_outputs = rg_to_given_back_tys;
back_funs = None;
+ fwd_effect_info;
+ back_effect_infos;
}
in
let loops = LoopId.Map.add loop_id loop_info ctx.loops in