From e6f002cfc1dfa41362bbb3a005c4261d09c52c58 Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Wed, 20 Mar 2024 06:17:13 +0100
Subject: Improve the generation of pretty name and the micro passes

---
 compiler/InterpreterBorrows.ml    |  9 +++++-
 compiler/InterpreterStatements.ml |  4 ++-
 compiler/PrintPure.ml             | 25 +++++++++++---
 compiler/Pure.ml                  |  8 ++++-
 compiler/PureMicroPasses.ml       | 51 +++++++++++++++++++++--------
 compiler/SymbolicAst.ml           |  3 ++
 compiler/SymbolicToPure.ml        | 68 +++++++++++++++++++++++++++++++++------
 compiler/SynthesizeSymbolic.ml    |  6 ++++
 8 files changed, 144 insertions(+), 30 deletions(-)

(limited to 'compiler')

diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml
index 03b2b045..17810705 100644
--- a/compiler/InterpreterBorrows.ml
+++ b/compiler/InterpreterBorrows.ml
@@ -935,7 +935,11 @@ let rec end_borrow_aux (config : config) (chain : borrow_or_abs_ids)
       (* Give back the value *)
       let ctx = give_back config l bc ctx in
       (* Do a sanity check and continue *)
-      cf_check cf ctx
+      let cc = cf_check in
+      (* Save a snapshot of the environment for the name generation *)
+      let cc = comp cc SynthesizeSymbolic.cf_save_snapshot in
+      (* Compose *)
+      cc cf ctx
 
 and end_borrows_aux (config : config) (chain : borrow_or_abs_ids)
     (allowed_abs : AbstractionId.id option) (lset : BorrowId.Set.t) : cm_fun =
@@ -1041,6 +1045,9 @@ and end_abstraction_aux (config : config) (chain : borrow_or_abs_ids)
       (* Sanity check: ending an abstraction must preserve the invariants *)
       let cc = comp cc Invariants.cf_check_invariants in
 
+      (* Save a snapshot of the environment for the name generation *)
+      let cc = comp cc SynthesizeSymbolic.cf_save_snapshot in
+
       (* Apply the continuation *)
       cc cf ctx
 
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 95a2956b..6b9f47ce 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -931,9 +931,11 @@ let rec eval_statement (config : config) (st : statement) : st_cm_fun =
       ^ statement_to_string_with_tab ctx st
       ^ "\n]\n\n**Context**:\n" ^ eval_ctx_to_string ctx ^ "\n\n"));
 
+  (* Take a snapshot of the current context for the purpose of generating pretty names *)
+  let cc = S.cf_save_snapshot in
   (* Expand the symbolic values if necessary - we need to do that before
    * checking the invariants *)
-  let cc = greedy_expand_symbolic_values config in
+  let cc = comp cc (greedy_expand_symbolic_values config) in
   (* Sanity check *)
   let cc = comp cc Invariants.cf_check_invariants in
 
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index a401594d..00a431a0 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -585,7 +585,7 @@ let rec texpression_to_string (env : fmt_env) (inside : bool) (indent : string)
       let meta_s = emeta_to_string env meta in
       let e = texpression_to_string env inside indent indent_incr e in
       match meta with
-      | Assignment _ | SymbolicAssignment _ | Tag _ ->
+      | Assignment _ | SymbolicAssignments _ | SymbolicPlaces _ | Tag _ ->
           let e = meta_s ^ "\n" ^ indent ^ e in
           if inside then "(" ^ e ^ ")" else e
       | MPlace _ -> "(" ^ meta_s ^ " " ^ e ^ ")")
@@ -717,10 +717,25 @@ and emeta_to_string (env : fmt_env) (meta : emeta) : string =
         "@assign(" ^ mplace_to_string env lp ^ " := "
         ^ texpression_to_string env false "" "" rv
         ^ rp ^ ")"
-    | SymbolicAssignment (var_id, rv) ->
-        "@symb_assign(" ^ VarId.to_string var_id ^ " := "
-        ^ texpression_to_string env false "" "" rv
-        ^ ")"
+    | SymbolicAssignments info ->
+        let infos =
+          List.map
+            (fun (var_id, rv) ->
+              VarId.to_string var_id ^ " == "
+              ^ texpression_to_string env false "" "" rv)
+            info
+        in
+        let infos = String.concat ", " infos in
+        "@symb_assign(" ^ infos ^ ")"
+    | SymbolicPlaces info ->
+        let infos =
+          List.map
+            (fun (var_id, name) ->
+              VarId.to_string var_id ^ " == \"" ^ name ^ "\"")
+            info
+        in
+        let infos = String.concat ", " infos in
+        "@symb_places(" ^ infos ^ ")"
     | MPlace mp -> "@mplace=" ^ mplace_to_string env mp
     | Tag msg -> "@tag \"" ^ msg ^ "\""
   in
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index cf6710aa..7de7e0f4 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -807,12 +807,18 @@ and emeta =
           The mvalue stores the value which is put in the destination
           The second (optional) mplace stores the origin.
         *)
-  | SymbolicAssignment of (var_id[@opaque]) * mvalue
+  | SymbolicAssignments of ((var_id[@opaque]) * mvalue) list
       (** Informationg linking a variable (from the pure AST) to an
           expression.
 
           We use this to guide the heuristics which derive pretty names.
         *)
+  | SymbolicPlaces of ((var_id[@opaque]) * string) list
+      (** Informationg linking a variable (from the pure AST) to a name.
+
+          We generate this information by exploring the context, and use it
+          to derive pretty names.
+        *)
   | MPlace of mplace  (** Meta-information about the origin of a value *)
   | Tag of string  (** A tag - typically used for debugging *)
 [@@deriving
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index 0ac0851e..a1f6ce33 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -334,6 +334,15 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
     let ctx1 = obj#visit_typed_pattern () lv () in
     merge_ctxs ctx ctx1
   in
+  (* If we have [x = y] where x and y are variables, add a constraint linking
+     the names of x and y *)
+  let add_eq_var_constraint (lv : typed_pattern) (re : texpression)
+      (ctx : pn_ctx) : pn_ctx =
+    match (lv.value, re.e) with
+    | PatVar (lv, _), Var rv when Option.is_some lv.basename ->
+        add_pure_var_constraint rv (Option.get lv.basename) ctx
+    | _ -> ctx
+  in
 
   (* This is used to propagate constraint information about places in case of
    * variable reassignments: we try to propagate the information from the
@@ -428,6 +437,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
     let ctx, re = update_texpression re ctx in
     let ctx, e = update_texpression e ctx in
     let lv = update_typed_pattern ctx lv in
+    let ctx = add_eq_var_constraint lv re ctx in
     (ctx, Let (monadic, lv, re, e))
   (* *)
   and update_switch_body (scrut : texpression) (body : switch_body)
@@ -524,8 +534,15 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
             | _ -> ctx
           in
           ctx
-      | SymbolicAssignment (var_id, rvalue) ->
-          add_pure_var_value_constraint var_id rvalue ctx
+      | SymbolicAssignments infos ->
+          List.fold_left
+            (fun ctx (var_id, rvalue) ->
+              add_pure_var_value_constraint var_id rvalue ctx)
+            ctx infos
+      | SymbolicPlaces infos ->
+          List.fold_left
+            (fun ctx (var_id, name) -> add_pure_var_constraint var_id name ctx)
+            ctx infos
       | MPlace mp -> add_right_constraint mp e ctx
       | Tag _ -> ctx
     in
@@ -1101,14 +1118,6 @@ let simplify_let_then_return _ctx def =
         | _ -> false)
     | _ -> false
   in
-  let match_pattern_and_ret_expr (monadic : bool) (pat : typed_pattern)
-      (e : texpression) : bool =
-    if monadic then
-      match opt_destruct_ret e with
-      | Some e -> match_pattern_and_expr pat e
-      | None -> false
-    else match_pattern_and_expr pat e
-  in
 
   let expr_visitor =
     object (self)
@@ -1124,9 +1133,25 @@ let simplify_let_then_return _ctx def =
         | Switch _ | Loop _ | Let _ ->
             (* Small shortcut to avoid doing the check on every let-binding *)
             not_simpl_e
-        | _ ->
-            if match_pattern_and_ret_expr monadic lv next_e then rv.e
-            else not_simpl_e
+        | _ -> (
+            if (* Do the check *)
+               monadic then
+              (* The first let-binding is monadic *)
+              match opt_destruct_ret next_e with
+              | Some e ->
+                  if match_pattern_and_expr lv e then rv.e else not_simpl_e
+              | None -> not_simpl_e
+            else
+              (* The first let-binding is not monadic *)
+              match opt_destruct_ret next_e with
+              | Some e ->
+                  if match_pattern_and_expr lv e then
+                    (* We need to wrap the right-value in a ret *)
+                    (mk_result_return_texpression rv).e
+                  else not_simpl_e
+              | None ->
+                  if match_pattern_and_expr lv next_e then rv.e else not_simpl_e
+            )
     end
   in
 
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index cc74a16b..e164fd49 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -65,6 +65,9 @@ type call = {
 type emeta =
   | Assignment of Contexts.eval_ctx * mplace * typed_value * mplace option
       (** We generated an assignment (destination, assigned value, src) *)
+  | Snapshot of Contexts.eval_ctx
+      (** Remember an environment snapshot - this is useful to check where the
+          symbolic values are, to compute proper names for instance *)
 [@@deriving show]
 
 type variant_id = VariantId.id [@@deriving show]
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 58fb6d04..3fa550cc 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -1873,10 +1873,48 @@ let get_abs_ancestors (ctx : bs_ctx) (abs : V.abs) (call_id : V.FunCallId.id) :
 (** Add meta-information to an expression *)
 let mk_emeta_symbolic_assignments (vars : var list) (values : texpression list)
     (e : texpression) : texpression =
-  let var_values = List.combine vars values in
-  List.fold_right
-    (fun (var, arg) e -> mk_emeta (SymbolicAssignment (var_get_id var, arg)) e)
-    var_values e
+  let var_values = List.combine (List.map var_get_id vars) values in
+  if var_values <> [] then mk_emeta (SymbolicAssignments var_values) e else e
+
+(** Derive naming information from a context.
+
+    We explore the context and look in which bindings the symbolic values appear:
+    we use this information to derive naming information. *)
+let eval_ctx_to_symbolic_assignments_info (ctx : bs_ctx)
+    (ectx : Contexts.eval_ctx) : (VarId.id * string) list =
+  let info : (VarId.id * string) list ref = ref [] in
+  let push_info name sv = info := (name, sv) :: !info in
+  let visitor =
+    object (self)
+      inherit [_] Contexts.iter_eval_ctx
+
+      method! visit_env_elem _ ee =
+        match ee with
+        | EBinding (BVar { index = _; name = Some name }, v) ->
+            self#visit_typed_value name v
+        | _ -> () (* Ignore *)
+
+      method! visit_value name v =
+        match v with
+        | VLiteral _ | VBottom -> ()
+        | VBorrow (VMutBorrow (_, v)) | VLoan (VSharedLoan (_, v)) ->
+            self#visit_typed_value name v
+        | VSymbolic sv ->
+            (* Translate the type *)
+            let ty = ctx_translate_fwd_ty ctx sv.sv_ty in
+            (* If the type is unit, do nothing *)
+            if ty_is_unit ty then ()
+            else
+              (* Otherwise lookup the variable *)
+              let var = lookup_var_for_symbolic_value sv ctx in
+              push_info var.id name
+        | _ -> ()
+    end
+  in
+  (* Visit the context *)
+  visitor#visit_eval_ctx "x" ectx;
+  (* Return the computed information *)
+  !info
 
 let rec translate_expression (e : S.expression) (ctx : bs_ctx) : texpression =
   match e with
@@ -3528,11 +3566,23 @@ and translate_emeta (meta : S.emeta) (e : S.expression) (ctx : bs_ctx) :
         let lp = translate_mplace lp in
         let rv = typed_value_to_texpression ctx ectx rv in
         let rp = translate_opt_mplace rp in
-        Assignment (lp, rv, rp)
-  in
-  let e = Meta (meta, next_e) in
-  let ty = next_e.ty in
-  { e; ty }
+        Some (Assignment (lp, rv, rp))
+    | S.Snapshot ectx ->
+        let infos = eval_ctx_to_symbolic_assignments_info ctx ectx in
+        if infos <> [] then
+          (* If often happens that the next expression contains exactly the
+             same meta information *)
+          match next_e.e with
+          | Meta (SymbolicPlaces infos1, _) when infos1 = infos -> None
+          | _ -> Some (SymbolicPlaces infos)
+        else None
+  in
+  match meta with
+  | Some meta ->
+      let e = Meta (meta, next_e) in
+      let ty = next_e.ty in
+      { e; ty }
+  | None -> next_e
 
 (** Wrap a function body in a match over the fuel to control termination. *)
 let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index a42c43ac..ad34c48e 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -189,3 +189,9 @@ let synthesize_loop (loop_id : LoopId.id) (input_svalues : symbolic_value list)
              meta;
            })
   | _ -> raise (Failure "Unreachable")
+
+let save_snapshot (ctx : Contexts.eval_ctx) (e : expression option) :
+    expression option =
+  match e with None -> None | Some e -> Some (Meta (Snapshot ctx, e))
+
+let cf_save_snapshot : Cps.cm_fun = fun cf ctx -> save_snapshot ctx (cf ctx)
-- 
cgit v1.2.3