diff options
author | Son Ho | 2023-12-22 21:03:59 +0100 |
---|---|---|
committer | Son Ho | 2023-12-22 21:03:59 +0100 |
commit | dd7552bec1be1695682801fca6ba6dfcfa990fbb (patch) | |
tree | 2bf42baae680ef57deadce1fbe3824ad405dc055 /compiler | |
parent | 70d506d148e5ae1a3e4115034161f449aff666ed (diff) |
Update the computation of the effect info for the loops
Diffstat (limited to 'compiler')
-rw-r--r-- | compiler/SymbolicToPure.ml | 141 |
1 files changed, 95 insertions, 46 deletions
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml index bf92482a..f0d1ca62 100644 --- a/compiler/SymbolicToPure.ml +++ b/compiler/SymbolicToPure.ml @@ -134,6 +134,8 @@ type loop_info = { Initialized with [None], gets updated to [Some] only if we merge the fwd/back functions. *) + fwd_effect_info : fun_effect_info; + back_effect_infos : fun_effect_info RegionGroupId.Map.t; } [@@deriving show] @@ -922,17 +924,31 @@ let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t) let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) : fun_effect_info = - match fun_id with - | TraitMethod (_, _, fid) | FunId (FRegular fid) -> - let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in - let info = - match gid with - | None -> dsg.fwd_info.effect_info - | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info - in - { info with is_rec = info.is_rec || Option.is_some lid } - | FunId (FAssumed _) -> - compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid + match lid with + | None -> ( + match fun_id with + | TraitMethod (_, _, fid) | FunId (FRegular fid) -> + let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in + let info = + match gid with + | None -> dsg.fwd_info.effect_info + | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info + in + { info with is_rec = info.is_rec || Option.is_some lid } + | FunId (FAssumed _) -> + compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid) + | Some lid -> ( + (* This is necessarily for the current function *) + match fun_id with + | FunId (FRegular fid) -> ( + assert (fid = ctx.fun_decl.def_id); + (* Lookup the loop *) + let lid = V.LoopId.Map.find lid ctx.loop_ids_map in + let loop_info = LoopId.Map.find lid ctx.loops in + match gid with + | None -> loop_info.fwd_effect_info + | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos) + | _ -> raise (Failure "Unreachable")) (** Translate a function signature to a decomposed function signature. @@ -1901,7 +1917,7 @@ and translate_panic (ctx : bs_ctx) : texpression = Remark: in case we merge the forward/backward functions, we introduce those in [translate_forward_end]. - *) +*) and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option) (ctx : bs_ctx) : texpression = (* There are two cases: @@ -3381,31 +3397,47 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = let ctx = !ctx in (* The output type of the loop function *) - let output_ty = + let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in + let back_effect_infos, output_ty = if !Config.return_back_funs then (* The loop backward functions consume the same additional inputs as the parent function, but have custom outputs *) - let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in + let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in - let back_tys = - List.filter_map - (fun ((back_sg, given_back) : back_sg_info * ty list) -> + let back_info_tys = + List.map + (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) -> + (* Remark: the effect info of the backward function for the loop + is almost the same as for the backward function of the parent function. + Quite importantly, the fact that the function is stateful and/or can fail + mostly depends on whether it has inputs or not, and the backward functions + for the loops have the same inputs as the backward functions for the parent + function. + *) let effect_info = back_sg.effect_info in + let effect_info = { effect_info with is_rec = true } in (* Compute the input/output types *) let inputs = List.map snd back_sg.inputs in let outputs = given_back in (* Filter if necessary *) - if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] - then None - else - let output = mk_simpl_tuple_ty outputs in - let output = - mk_back_output_ty_from_effect_info effect_info inputs output - in - let ty = mk_arrows inputs output in - Some ty) + let ty = + if + !Config.simplify_merged_fwd_backs && inputs = [] && outputs = [] + then None + else + let output = mk_simpl_tuple_ty outputs in + let output = + mk_back_output_ty_from_effect_info effect_info inputs output + in + let ty = mk_arrows inputs output in + Some ty + in + ((id, effect_info), ty)) (List.combine back_sgs given_back_tys) in + let back_info = List.map fst back_info_tys in + let back_info = RegionGroupId.Map.of_list back_info in + let back_tys = List.filter_map snd back_info_tys in let output = if ctx.sg.fwd_info.ignore_output then back_tys else ctx.sg.fwd_output :: back_tys @@ -3416,27 +3448,42 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ] else output in - if effect_info.can_fail && inputs <> [] then mk_result_ty output - else output + let output = + if effect_info.can_fail && inputs <> [] then mk_result_ty output + else output + in + (back_info, output) else - match ctx.bid with - | None -> - (* Forward function: same type as the parent function *) - (translate_fun_sig_from_decomposed ctx.sg None).output - | Some rg_id -> - (* Backward function: custom return type *) - let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in - let output = mk_simpl_tuple_ty doutputs in - let fwd_effect_info = ctx.sg.fwd_info.effect_info in - let output = - if fwd_effect_info.stateful then - mk_simpl_tuple_ty [ mk_state_ty; output ] - else output - in - let output = - if fwd_effect_info.can_fail then mk_result_ty output else output - in - output + let back_info = + RegionGroupId.Map.of_list + (List.map + (fun ((id, back_sg) : _ * back_sg_info) -> + (id, { back_sg.effect_info with is_rec = true })) + (RegionGroupId.Map.bindings ctx.sg.back_sg)) + in + let output = + match ctx.bid with + | None -> + (* Forward function: same type as the parent function *) + (translate_fun_sig_from_decomposed ctx.sg None).output + | Some rg_id -> + (* Backward function: custom return type *) + let doutputs = + T.RegionGroupId.Map.find rg_id rg_to_given_back_tys + in + let output = mk_simpl_tuple_ty doutputs in + let fwd_effect_info = ctx.sg.fwd_info.effect_info in + let output = + if fwd_effect_info.stateful then + mk_simpl_tuple_ty [ mk_state_ty; output ] + else output + in + let output = + if fwd_effect_info.can_fail then mk_result_ty output else output + in + output + in + (back_info, output) in (* Add the loop information in the context *) @@ -3480,6 +3527,8 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression = forward_output_no_state_no_result = None; back_outputs = rg_to_given_back_tys; back_funs = None; + fwd_effect_info; + back_effect_infos; } in let loops = LoopId.Map.add loop_id loop_info ctx.loops in |