From 834ea9747fced38f222aec251d2eaaf14a3328e2 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Thu, 1 Dec 2022 11:26:30 +0100
Subject: Improve some visitors and ctx_merge_regions

---
 compiler/Contexts.ml               | 68 +++++++++++++++++++++++++++++++++-----
 compiler/InterpreterBorrows.ml     | 28 +---------------
 compiler/InterpreterBorrowsCore.ml |  7 +---
 compiler/Substitute.ml             | 15 ++++++++-
 4 files changed, 76 insertions(+), 42 deletions(-)

(limited to 'compiler')

diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index cb6a092f..69c4ec3b 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -14,6 +14,8 @@ open Identifiers
   *)
 module DummyVarId = IdGen ()
 
+type dummy_var_id = DummyVarId.id [@@deriving show, ord]
+
 (** Some global counters.
 
   Note that those counters were initially stored in {!eval_ctx} values,
@@ -104,29 +106,79 @@ let reset_global_counters () =
   fun_call_id_counter := FunCallId.generator_zero;
   dummy_var_id_counter := DummyVarId.generator_zero
 
+(** Ancestor for {!var_binder} iter visitor *)
+class ['self] iter_var_binder_base =
+  object (_self : 'self)
+    inherit [_] iter_abs
+    method visit_var_id : 'env -> var_id -> unit = fun _ _ -> ()
+    method visit_dummy_var_id : 'env -> dummy_var_id -> unit = fun _ _ -> ()
+  end
+
+(** Ancestor for {!var_binder} map visitor *)
+class ['self] map_var_binder_base =
+  object (_self : 'self)
+    inherit [_] map_abs
+    method visit_var_id : 'env -> var_id -> var_id = fun _ x -> x
+
+    method visit_dummy_var_id : 'env -> dummy_var_id -> dummy_var_id =
+      fun _ x -> x
+  end
+
 (** A binder used in an environment, to map a variable to a value *)
 type var_binder = {
-  index : VarId.id;  (** Unique variable identifier *)
+  index : var_id;  (** Unique variable identifier *)
   name : string option;  (** Possible name *)
 }
-[@@deriving show]
+[@@deriving
+  show,
+    visitors
+      {
+        name = "iter_var_binder";
+        variety = "iter";
+        ancestors = [ "iter_var_binder_base" ];
+        nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+        concrete = true;
+      },
+    visitors
+      {
+        name = "map_var_binder";
+        variety = "map";
+        ancestors = [ "map_var_binder_base" ];
+        nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+        concrete = true;
+      }]
 
 (** A binder, for a "real" variable or a dummy variable *)
-type binder = VarBinder of var_binder | DummyBinder of DummyVarId.id
-[@@deriving show]
+type binder = VarBinder of var_binder | DummyBinder of dummy_var_id
+[@@deriving
+  show,
+    visitors
+      {
+        name = "iter_binder";
+        variety = "iter";
+        ancestors = [ "iter_var_binder" ];
+        nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+        concrete = true;
+      },
+    visitors
+      {
+        name = "map_binder";
+        variety = "map";
+        ancestors = [ "map_var_binder" ];
+        nude = true (* Don't inherit {!VisitorsRuntime.iter} *);
+        concrete = true;
+      }]
 
 (** Ancestor for {!env_elem} iter visitor *)
 class ['self] iter_env_elem_base =
   object (_self : 'self)
-    inherit [_] iter_abs
-    method visit_binder : 'env -> binder -> unit = fun _ _ -> ()
+    inherit [_] iter_binder
   end
 
 (** Ancestor for {!env_elem} map visitor *)
 class ['self] map_env_elem_base =
   object (_self : 'self)
-    inherit [_] map_abs
-    method visit_binder : 'env -> binder -> binder = fun _ x -> x
+    inherit [_] map_binder
   end
 
 (** Environment value: mapping from variable to value, abstraction (only
diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml
index 9a78e77a..fc97937d 100644
--- a/compiler/InterpreterBorrows.ml
+++ b/compiler/InterpreterBorrows.ml
@@ -2400,29 +2400,7 @@ let merge_abstractions_aux (abs_kind : V.abs_kind) (can_end : bool)
 let ctx_merge_regions (ctx : C.eval_ctx) (rid : T.RegionId.id)
     (rids : T.RegionId.Set.t) : C.eval_ctx =
   let rsubst x = if T.RegionId.Set.mem x rids then rid else x in
-  let merge_in_abs (abs : V.abs) : V.abs =
-    let avalues =
-      List.map (Substitute.typed_avalue_subst_rids rsubst) abs.V.avalues
-    in
-    let regions = T.RegionId.Set.diff abs.V.regions rids in
-    let ancestors_regions =
-      if T.RegionId.Set.disjoint abs.V.ancestors_regions rids then
-        abs.V.ancestors_regions
-      else
-        T.RegionId.Set.add rid
-          (T.RegionId.Set.diff abs.V.ancestors_regions rids)
-    in
-    { abs with V.avalues; regions; ancestors_regions }
-  in
-
-  let env =
-    List.map
-      (fun ee ->
-        match ee with
-        | C.Abs abs -> C.Abs (merge_in_abs abs)
-        | Var _ | Frame -> ee)
-      ctx.env
-  in
+  let env = Substitute.env_subst_rids rsubst ctx.env in
   { ctx with C.env }
 
 let merge_abstractions (abs_kind : V.abs_kind) (can_end : bool)
@@ -2458,9 +2436,5 @@ let merge_abstractions (abs_kind : V.abs_kind) (can_end : bool)
       ctx_merge_regions ctx rid rids
   in
 
-  (* Sanity check *)
-  (* Sanity check *)
-  if !Config.check_invariants then Invariants.check_invariants ctx;
-
   (* Return *)
   (ctx, nabs.abs_id)
diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml
index 6db23cc4..fced4fbb 100644
--- a/compiler/InterpreterBorrowsCore.ml
+++ b/compiler/InterpreterBorrowsCore.ml
@@ -199,12 +199,7 @@ let compute_contexts_ids (ctxl : C.eval_ctx list) : ctx_ids =
   let obj =
     object
       inherit [_] C.iter_eval_ctx
-
-      method! visit_binder _ bv =
-        match bv with
-        | VarBinder _ -> ()
-        | DummyBinder bid -> dids := C.DummyVarId.Set.add bid !dids
-
+      method! visit_dummy_var_id _ did = dids := C.DummyVarId.Set.add did !dids
       method! visit_borrow_id _ id = bids := V.BorrowId.Set.add id !bids
 
       method! visit_loan_id _ id =
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index 8348424c..9adbbcba 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -397,7 +397,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
 
   let visitor =
     object (self : 'self)
-      inherit [_] V.map_abs
+      inherit [_] C.map_env
       method! visit_borrow_id _ bid = bsubst bid
       method! visit_loan_id _ bid = bsubst bid
       method! visit_ety _ ty = ty_substitute_ids tsubst ty
@@ -427,6 +427,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
       visitor#visit_typed_avalue () x
 
     method visit_abs (x : V.abs) : V.abs = visitor#visit_abs () x
+    method visit_env (env : C.env) : C.env = visitor#visit_env () env
   end
 
 let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
@@ -468,3 +469,15 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
      asubst)
     #visit_typed_avalue
     x
+
+let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env
+    =
+  let asubst _ = raise (Failure "Unreachable") in
+  (subst_ids_visitor rsubst
+     (fun x -> x)
+     (fun x -> x)
+     (fun x -> x)
+     (fun x -> x)
+     asubst)
+    #visit_env
+    x
-- 
cgit v1.2.3