From a2f19257651df3c8473e17ef73a5389b9cb89bbf Mon Sep 17 00:00:00 2001
From: Son Ho
Date: Sun, 3 Sep 2023 16:35:05 +0200
Subject: Make progress on the extraction

---
 compiler/Extract.ml     | 218 ++++++++++++++++++++++++++++++++++++------------
 compiler/ExtractBase.ml |  63 ++++++++++++--
 compiler/Translate.ml   |   1 +
 3 files changed, 221 insertions(+), 61 deletions(-)

diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index e07305f1..e140ea1c 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -841,6 +841,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
     (* TODO: actually use the clause to derive the name *)
     "cl"
   in
+  let trait_self_clause_basename = "self_clause" in
   let append_index (basename : string) (i : int) : string =
     basename ^ string_of_int i
   in
@@ -936,6 +937,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
     var_basename;
     type_var_basename;
     const_generic_var_basename;
+    trait_self_clause_basename;
     trait_clause_basename;
     append_index;
     extract_literal;
@@ -1237,15 +1239,26 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
         if use_brackets then F.pp_print_string fmt ")";
         F.pp_print_string fmt ("." ^ name))
       else
-        (* Can only happen when extracting the signature of a trait method
-           *declaration* or a provided trait method (for a declaration).
-           If extracting items for a trait method implementation,
-           the type should have been normalized. For trait method declarations
-           we directly reference the item. *)
-        assert (ctx.trait_decl_id <> None);
-      assert (generics = empty_generic_args);
-      let name = ctx_get_local_trait_assoc_type type_name ctx in
-      F.pp_print_string fmt name
+        (* There are two situations:
+           - we are extracting a declared item (typically a function signature)
+             for a trait declaration. We directly refer to the item (we extract
+             trait declarations as structures, so we can refer to their fields)
+           - we are extracting a provided method for a trait declaration. We
+             refer to the item in the self trait clause (see {!SelfTraitClauseId}).
+
+           Remark: we can't get there for trait *implementations* because then the
+           types should have been normalized.
+        *)
+        let trait_decl_id = Option.get ctx.trait_decl_id in
+        let item_name = ctx_get_trait_assoc_type trait_decl_id type_name ctx in
+        assert (generics = empty_generic_args);
+        if ctx.is_provided_method then
+          (* Provided method: use the trait self clause *)
+          let self_clause = ctx_get_trait_self_clause ctx in
+          F.pp_print_string fmt (self_clause ^ "." ^ item_name)
+        else
+          (* Declaration: directly refer to the item *)
+          F.pp_print_string fmt item_name
 
 and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter)
     (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit =
@@ -1632,11 +1645,37 @@ let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter)
 let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
   if !space then space := false else F.pp_print_space fmt ()
 
+(** Extract the trait self clause.
+
+    We add the trait self clause for provided methods (see {!TraitSelfClauseId}).
+ *)
+let extract_trait_self_clause (insert_req_space : unit -> unit)
+    (ctx : extraction_ctx) (fmt : F.formatter) (trait_decl : A.trait_decl)
+    (params : string list) : unit =
+  insert_req_space ();
+  F.pp_print_string fmt "(";
+  let self_clause = ctx_get_trait_self_clause ctx in
+  F.pp_print_string fmt self_clause;
+  F.pp_print_string fmt ":";
+  let with_opaque_pre = false in
+  let trait_id = ctx_get_trait_decl with_opaque_pre trait_decl.def_id ctx in
+  F.pp_print_string fmt trait_id;
+  List.iter
+    (fun p ->
+      F.pp_print_space fmt ();
+      F.pp_print_string fmt p)
+    params;
+  F.pp_print_string fmt ")"
+
+(**
+ - [trait_decl]: if [Some], it means we are extracting the generics for a provided
+   method and need to insert a trait self clause (see {!TraitSelfClauseId}).
+ *)
 let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
     (no_params_tys : TypeDeclId.Set.t) (use_forall : bool) (as_implicits : bool)
-    (space : bool ref option) (generics : generic_params)
-    (type_params : string list) (cg_params : string list)
-    (trait_clauses : string list) : unit =
+    (space : bool ref option) (trait_decl : A.trait_decl option)
+    (generics : generic_params) (type_params : string list)
+    (cg_params : string list) (trait_clauses : string list) : unit =
   let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
   (* HOL4 doesn't support const generics *)
   assert (cg_params = [] || !backend <> HOL4);
@@ -1660,53 +1699,102 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
       F.pp_print_string fmt ":";
       F.pp_print_space fmt ();
       F.pp_print_string fmt "forall");
-    (* Note that in HOL4 we don't print the type parameters. *)
-    if !backend <> HOL4 then (
-      (* Print the type parameters *)
-      if type_params <> [] then (
-        insert_req_space ();
-        (* ( *)
-        left_bracket ();
+    (* Small helper - we may need to split the parameters *)
+    let print_generics (type_params : string list)
+        (const_generics : const_generic_var list)
+        (trait_clauses : trait_clause list) : unit =
+      (* Note that in HOL4 we don't print the type parameters. *)
+      if !backend <> HOL4 then (
+        (* Print the type parameters *)
+        if type_params <> [] then (
+          insert_req_space ();
+          (* ( *)
+          left_bracket ();
+          List.iter
+            (fun s ->
+              F.pp_print_string fmt s;
+              F.pp_print_space fmt ())
+            type_params;
+          F.pp_print_string fmt ":";
+          F.pp_print_space fmt ();
+          F.pp_print_string fmt (type_keyword ());
+          (* ) *)
+          right_bracket ());
+        (* Print the const generic parameters *)
         List.iter
-          (fun s ->
-            F.pp_print_string fmt s;
-            F.pp_print_space fmt ())
-          type_params;
-        F.pp_print_string fmt ":";
-        F.pp_print_space fmt ();
-        F.pp_print_string fmt (type_keyword ());
-        (* ) *)
-        right_bracket ());
-      (* Print the const generic parameters *)
+          (fun (var : const_generic_var) ->
+            insert_req_space ();
+            (* ( *)
+            left_bracket ();
+            let n = ctx_get_const_generic_var var.index ctx in
+            F.pp_print_string fmt n;
+            F.pp_print_space fmt ();
+            F.pp_print_string fmt ":";
+            F.pp_print_space fmt ();
+            extract_literal_type ctx fmt var.ty;
+            (* ) *)
+            right_bracket ())
+          const_generics);
+      (* Print the trait clauses *)
       List.iter
-        (fun (var : const_generic_var) ->
+        (fun (clause : trait_clause) ->
           insert_req_space ();
           (* ( *)
           left_bracket ();
-          let n = ctx_get_const_generic_var var.index ctx in
+          let n = ctx_get_local_trait_clause clause.clause_id ctx in
           F.pp_print_string fmt n;
           F.pp_print_space fmt ();
           F.pp_print_string fmt ":";
           F.pp_print_space fmt ();
-          extract_literal_type ctx fmt var.ty;
+          extract_trait_clause_type ctx fmt no_params_tys clause;
           (* ) *)
           right_bracket ())
-        generics.const_generics);
-    (* Print the trait clauses *)
-    List.iter
-      (fun (clause : trait_clause) ->
-        insert_req_space ();
-        (* ( *)
-        left_bracket ();
-        let n = ctx_get_local_trait_clause clause.clause_id ctx in
-        F.pp_print_string fmt n;
-        F.pp_print_space fmt ();
-        F.pp_print_string fmt ":";
-        F.pp_print_space fmt ();
-        extract_trait_clause_type ctx fmt no_params_tys clause;
-        (* ) *)
-        right_bracket ())
-      generics.trait_clauses)
+        trait_clauses
+    in
+    (* If we extract the generics for a provided method for a trait declaration
+       (indicated by the trait decl given as input), we need to split the generics:
+       - we print the generics for the trait decl
+       - we print the trait self clause
+       - we print the generics for the trait method
+    *)
+    match trait_decl with
+    | None ->
+        print_generics type_params generics.const_generics
+          generics.trait_clauses
+    | Some trait_decl ->
+        (* Split the generics between the generics specific to the trait decl
+           and those specific to the trait method *)
+        let open Collections.List in
+        let dtype_params, mtype_params =
+          split_at type_params (length trait_decl.generics.types)
+        in
+        let dcgs, mcgs =
+          split_at generics.const_generics
+            (length trait_decl.generics.const_generics)
+        in
+        let dtrait_clauses, mtrait_clauses =
+          split_at generics.trait_clauses
+            (length trait_decl.generics.trait_clauses)
+        in
+        (* Extract the trait decl generics *)
+        print_generics dtype_params dcgs dtrait_clauses;
+        (* Extract the trait self clause *)
+        let params =
+          concat
+            [
+              dtype_params;
+              map
+                (fun (cg : const_generic_var) ->
+                  ctx_get_const_generic_var cg.index ctx)
+                dcgs;
+              map
+                (fun c -> ctx_get_local_trait_clause c.clause_id ctx)
+                dtrait_clauses;
+            ]
+        in
+        extract_trait_self_clause insert_req_space ctx fmt trait_decl params;
+        (* Extract the method generics *)
+        print_generics mtype_params mcgs mtrait_clauses)
 
 (** Extract a type declaration.
 
@@ -1769,7 +1857,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
   (* Print the generic parameters *)
   let as_implicits = false in
   extract_generic_params ctx_body fmt type_decl_group use_forall as_implicits
-    None def.generics type_params cg_params trait_clauses;
+    None None def.generics type_params cg_params trait_clauses;
   (* Print the "=" if we extract the body*)
   if extract_body then (
     F.pp_print_space fmt ();
@@ -2002,7 +2090,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
           let use_forall = false in
           let as_implicits = true in
           extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall
-            as_implicits None decl.generics type_params cg_params trait_clauses;
+            as_implicits None None decl.generics type_params cg_params
+            trait_clauses;
           (* Print the record parameter *)
           F.pp_print_space fmt ();
           F.pp_print_string fmt "(";
@@ -2994,8 +3083,8 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
 (** A small utility to print the parameters of a function signature.
 
     We return two contexts:
-    - the context augmented with bindings for the type parameters
-    - the context augmented with bindings for the type parameters *and*
+    - the context augmented with bindings for the generics
+    - the context augmented with bindings for the generics *and*
       bindings for the input values
 
     We also return names for the type parameters, const generics, etc.
@@ -3009,6 +3098,28 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
 let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
     (fmt : F.formatter) (def : fun_decl) :
     extraction_ctx * extraction_ctx * string list =
+  (* First, add the associated types and constants if the function is a method
+     in a trait declaration.
+
+     About the order: we want to make sure the names are reserved for
+     those (variable names might collide with them but it is ok, we will add
+     suffixes to the variables).
+
+     TODO: micro-pass to update what happens when calling trait provided
+     functions.
+  *)
+  let ctx, trait_decl =
+    match def.kind with
+    | TraitMethodProvided (decl_id, _) ->
+        let trait_decl =
+          T.TraitDeclId.Map.find decl_id
+            ctx.trans_ctx.trait_decls_context.trait_decls
+        in
+        let ctx, _ = ctx_add_trait_self_clause ctx in
+        let ctx = { ctx with is_provided_method = true } in
+        (ctx, Some trait_decl)
+    | _ -> (ctx, None)
+  in
   (* Add the type parameters - note that we need those bindings only for the
    * body translation (they are not top-level) *)
   let ctx, type_params, cg_params, trait_clauses =
@@ -3020,7 +3131,8 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
   let use_forall = false in
   let as_implicits = false in
   extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall as_implicits
-    (Some space) def.signature.generics type_params cg_params trait_clauses;
+    (Some space) trait_decl def.signature.generics type_params cg_params
+    trait_clauses;
   (* Close the box for the generics *)
   F.pp_close_box fmt ();
   (* The input parameters - note that doing this adds bindings to the context *)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 697b1027..251d8b36 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -291,6 +291,7 @@ type formatter = {
       (** Generates a type variable basename. *)
   const_generic_var_basename : StringSet.t -> string -> string;
       (** Generates a const generic variable basename. *)
+  trait_self_clause_basename : string;
   trait_clause_basename : StringSet.t -> trait_clause -> string;
       (** Return a base name for a trait clause. We might add a suffix to prevent
           collisions.
@@ -409,10 +410,44 @@ type id =
   | TraitDeclId of TraitDeclId.id
   | TraitImplId of TraitImplId.id
   | LocalTraitClauseId of TraitClauseId.id
-  | LocalTraitAssocTypeId of string  (** Specifically for: [Self::Ty] *)
   | TraitAssocTypeId of TraitDeclId.id * string  (** A trait associated type *)
   | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
   | TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id
+  | TraitSelfClauseId
+      (** Specifically for the clause: [Self : Trait].
+
+          For now, we forbid provided methods (methods in trait declarations
+          with a default implementation) from being overriden in trait implementations.
+          We extract trait provided methods such that they take an instance of
+          the trait as input: this instance is given by the trait self clause.
+
+          For instance:
+          {[
+            //
+            // Rust
+            //
+            trait ToU64 {
+              fn to_u64(&self) -> u64;
+
+              // Provided method
+              fn is_pos(&self) -> bool {
+                self.to_u64() > 0
+              }
+            }
+
+            //
+            // Generated code
+            //
+            struct ToU64 (T : Type) {
+              to_u64 : T -> u64;
+            }
+
+            //                    The trait self clause
+            //                    vvvvvvvvvvvvvvvvvvvvvv
+            let is_pos (T : Type) (trait_self : ToU64 T) (self : T) : bool =
+              trait_self.to_u64 self > 0
+          ]}
+       *)
   | UnknownId
       (** Used for stored various strings like keywords, definitions which
           should always be in context, etc. and which can't be linked to one
@@ -618,6 +653,7 @@ type extraction_ctx = {
         *)
   trait_decl_id : trait_decl_id option;
       (** If we are extracting a trait declaration, identifies it *)
+  is_provided_method : bool;
 }
 
 (** Debugging function, used when communicating name collisions to the user,
@@ -752,7 +788,6 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
   | TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
   | LocalTraitClauseId id ->
       "local_trait_clause_id: " ^ TraitClauseId.to_string id
-  | LocalTraitAssocTypeId type_name -> "local_trait_assoc_type_id: " ^ type_name
   | TraitParentClauseId (id, clause_id) ->
       "trait_parent_clause_id: decl_id:" ^ TraitDeclId.to_string id
       ^ ", clause_id: "
@@ -764,11 +799,14 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
   | TraitAssocTypeId (id, type_name) ->
       "trait_assoc_type_id: decl_id:" ^ TraitDeclId.to_string id
       ^ ", type name: " ^ type_name
+  | TraitSelfClauseId -> "trait_self_clause"
 
 (** We might not check for collisions for some specific ids (ex.: field names) *)
 let allow_collisions (id : id) : bool =
   match id with
-  | FieldId (_, _) -> !Config.record_fields_short_names
+  | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitAssocTypeId _
+    ->
+      !Config.record_fields_short_names
   | _ -> false
 
 let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx)
@@ -858,6 +896,10 @@ let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string =
   let is_opaque = false in
   ctx_get_type is_opaque (Assumed id) ctx
 
+let ctx_get_trait_self_clause (ctx : extraction_ctx) : string =
+  let with_opaque_pre = false in
+  ctx_get with_opaque_pre TraitSelfClauseId ctx
+
 let ctx_get_trait_decl (with_opaque_pre : bool) (id : trait_decl_id)
     (ctx : extraction_ctx) : string =
   ctx_get with_opaque_pre (TraitDeclId id) ctx
@@ -871,11 +913,6 @@ let ctx_get_trait_assoc_type (id : trait_decl_id) (type_name : string)
   let is_opaque = false in
   ctx_get is_opaque (TraitAssocTypeId (id, type_name)) ctx
 
-let ctx_get_local_trait_assoc_type (type_name : string) (ctx : extraction_ctx) :
-    string =
-  let is_opaque = false in
-  ctx_get is_opaque (LocalTraitAssocTypeId type_name) ctx
-
 let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
     (ctx : extraction_ctx) : string =
   let with_opaque_pre = false in
@@ -969,6 +1006,16 @@ let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) :
   let ctx = ctx_add is_opaque (VarId id) name ctx in
   (ctx, name)
 
+(** Generate a unique variable name for the trait self clause and add it to the context *)
+let ctx_add_trait_self_clause (ctx : extraction_ctx) : extraction_ctx * string =
+  let is_opaque = false in
+  let basename = ctx.fmt.trait_self_clause_basename in
+  let name =
+    basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
+  in
+  let ctx = ctx_add is_opaque TraitSelfClauseId name ctx in
+  (ctx, name)
+
 (** Generate a unique trait clause name and add it to the context *)
 let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id)
     (ctx : extraction_ctx) : extraction_ctx * string =
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index f4f59187..790dbe14 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -1007,6 +1007,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
       use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
       fun_name_info = PureUtils.RegularFunIdMap.empty;
       trait_decl_id = None (* None by default *);
+      is_provided_method = false (* false by default *);
     }
   in
 
-- 
cgit v1.2.3