From dd7552bec1be1695682801fca6ba6dfcfa990fbb Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Fri, 22 Dec 2023 21:03:59 +0100
Subject: Update the computation of the effect info for the loops

---
 compiler/SymbolicToPure.ml | 141 ++++++++++++++++++++++++++++++---------------
 1 file changed, 95 insertions(+), 46 deletions(-)

(limited to 'compiler')

diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index bf92482a..f0d1ca62 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -134,6 +134,8 @@ type loop_info = {
           Initialized with [None], gets updated to [Some] only if we merge
           the fwd/back functions.
        *)
+  fwd_effect_info : fun_effect_info;
+  back_effect_infos : fun_effect_info RegionGroupId.Map.t;
 }
 [@@deriving show]
 
@@ -922,17 +924,31 @@ let compute_raw_fun_effect_info (fun_infos : fun_info A.FunDeclId.Map.t)
 let get_fun_effect_info (ctx : bs_ctx) (fun_id : A.fun_id_or_trait_method_ref)
     (lid : V.LoopId.id option) (gid : T.RegionGroupId.id option) :
     fun_effect_info =
-  match fun_id with
-  | TraitMethod (_, _, fid) | FunId (FRegular fid) ->
-      let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in
-      let info =
-        match gid with
-        | None -> dsg.fwd_info.effect_info
-        | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info
-      in
-      { info with is_rec = info.is_rec || Option.is_some lid }
-  | FunId (FAssumed _) ->
-      compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid
+  match lid with
+  | None -> (
+      match fun_id with
+      | TraitMethod (_, _, fid) | FunId (FRegular fid) ->
+          let dsg = A.FunDeclId.Map.find fid ctx.fun_dsigs in
+          let info =
+            match gid with
+            | None -> dsg.fwd_info.effect_info
+            | Some gid -> (RegionGroupId.Map.find gid dsg.back_sg).effect_info
+          in
+          { info with is_rec = info.is_rec || Option.is_some lid }
+      | FunId (FAssumed _) ->
+          compute_raw_fun_effect_info ctx.fun_ctx.fun_infos fun_id lid gid)
+  | Some lid -> (
+      (* This is necessarily for the current function *)
+      match fun_id with
+      | FunId (FRegular fid) -> (
+          assert (fid = ctx.fun_decl.def_id);
+          (* Lookup the loop *)
+          let lid = V.LoopId.Map.find lid ctx.loop_ids_map in
+          let loop_info = LoopId.Map.find lid ctx.loops in
+          match gid with
+          | None -> loop_info.fwd_effect_info
+          | Some gid -> RegionGroupId.Map.find gid loop_info.back_effect_infos)
+      | _ -> raise (Failure "Unreachable"))
 
 (** Translate a function signature to a decomposed function signature.
 
@@ -1901,7 +1917,7 @@ and translate_panic (ctx : bs_ctx) : texpression =
 
     Remark: in case we merge the forward/backward functions, we introduce
     those in [translate_forward_end].
- *)
+*)
 and translate_return (ectx : C.eval_ctx) (opt_v : V.typed_value option)
     (ctx : bs_ctx) : texpression =
   (* There are two cases:
@@ -3381,31 +3397,47 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
   let ctx = !ctx in
 
   (* The output type of the loop function *)
-  let output_ty =
+  let fwd_effect_info = { ctx.sg.fwd_info.effect_info with is_rec = true } in
+  let back_effect_infos, output_ty =
     if !Config.return_back_funs then
       (* The loop backward functions consume the same additional inputs as the parent
          function, but have custom outputs *)
-      let back_sgs = RegionGroupId.Map.values ctx.sg.back_sg in
+      let back_sgs = RegionGroupId.Map.bindings ctx.sg.back_sg in
       let given_back_tys = RegionGroupId.Map.values rg_to_given_back_tys in
-      let back_tys =
-        List.filter_map
-          (fun ((back_sg, given_back) : back_sg_info * ty list) ->
+      let back_info_tys =
+        List.map
+          (fun (((id, back_sg), given_back) : (_ * back_sg_info) * ty list) ->
+            (* Remark: the effect info of the backward function for the loop
+               is almost the same as for the backward function of the parent function.
+               Quite importantly, the fact that the function is stateful and/or can fail
+               mostly depends on whether it has inputs or not, and the backward functions
+               for the loops have the same inputs as the backward functions for the parent
+               function.
+            *)
             let effect_info = back_sg.effect_info in
+            let effect_info = { effect_info with is_rec = true } in
             (* Compute the input/output types *)
             let inputs = List.map snd back_sg.inputs in
             let outputs = given_back in
             (* Filter if necessary *)
-            if !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
-            then None
-            else
-              let output = mk_simpl_tuple_ty outputs in
-              let output =
-                mk_back_output_ty_from_effect_info effect_info inputs output
-              in
-              let ty = mk_arrows inputs output in
-              Some ty)
+            let ty =
+              if
+                !Config.simplify_merged_fwd_backs && inputs = [] && outputs = []
+              then None
+              else
+                let output = mk_simpl_tuple_ty outputs in
+                let output =
+                  mk_back_output_ty_from_effect_info effect_info inputs output
+                in
+                let ty = mk_arrows inputs output in
+                Some ty
+            in
+            ((id, effect_info), ty))
           (List.combine back_sgs given_back_tys)
       in
+      let back_info = List.map fst back_info_tys in
+      let back_info = RegionGroupId.Map.of_list back_info in
+      let back_tys = List.filter_map snd back_info_tys in
       let output =
         if ctx.sg.fwd_info.ignore_output then back_tys
         else ctx.sg.fwd_output :: back_tys
@@ -3416,27 +3448,42 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
         if effect_info.stateful then mk_simpl_tuple_ty [ mk_state_ty; output ]
         else output
       in
-      if effect_info.can_fail && inputs <> [] then mk_result_ty output
-      else output
+      let output =
+        if effect_info.can_fail && inputs <> [] then mk_result_ty output
+        else output
+      in
+      (back_info, output)
     else
-      match ctx.bid with
-      | None ->
-          (* Forward function: same type as the parent function *)
-          (translate_fun_sig_from_decomposed ctx.sg None).output
-      | Some rg_id ->
-          (* Backward function: custom return type *)
-          let doutputs = T.RegionGroupId.Map.find rg_id rg_to_given_back_tys in
-          let output = mk_simpl_tuple_ty doutputs in
-          let fwd_effect_info = ctx.sg.fwd_info.effect_info in
-          let output =
-            if fwd_effect_info.stateful then
-              mk_simpl_tuple_ty [ mk_state_ty; output ]
-            else output
-          in
-          let output =
-            if fwd_effect_info.can_fail then mk_result_ty output else output
-          in
-          output
+      let back_info =
+        RegionGroupId.Map.of_list
+          (List.map
+             (fun ((id, back_sg) : _ * back_sg_info) ->
+               (id, { back_sg.effect_info with is_rec = true }))
+             (RegionGroupId.Map.bindings ctx.sg.back_sg))
+      in
+      let output =
+        match ctx.bid with
+        | None ->
+            (* Forward function: same type as the parent function *)
+            (translate_fun_sig_from_decomposed ctx.sg None).output
+        | Some rg_id ->
+            (* Backward function: custom return type *)
+            let doutputs =
+              T.RegionGroupId.Map.find rg_id rg_to_given_back_tys
+            in
+            let output = mk_simpl_tuple_ty doutputs in
+            let fwd_effect_info = ctx.sg.fwd_info.effect_info in
+            let output =
+              if fwd_effect_info.stateful then
+                mk_simpl_tuple_ty [ mk_state_ty; output ]
+              else output
+            in
+            let output =
+              if fwd_effect_info.can_fail then mk_result_ty output else output
+            in
+            output
+      in
+      (back_info, output)
   in
 
   (* Add the loop information in the context *)
@@ -3480,6 +3527,8 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
         forward_output_no_state_no_result = None;
         back_outputs = rg_to_given_back_tys;
         back_funs = None;
+        fwd_effect_info;
+        back_effect_infos;
       }
     in
     let loops = LoopId.Map.add loop_id loop_info ctx.loops in
-- 
cgit v1.2.3