summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--compiler/Extract.ml218
-rw-r--r--compiler/ExtractBase.ml63
-rw-r--r--compiler/Translate.ml1
3 files changed, 221 insertions, 61 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index e07305f1..e140ea1c 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -841,6 +841,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
(* TODO: actually use the clause to derive the name *)
"cl"
in
+ let trait_self_clause_basename = "self_clause" in
let append_index (basename : string) (i : int) : string =
basename ^ string_of_int i
in
@@ -936,6 +937,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
var_basename;
type_var_basename;
const_generic_var_basename;
+ trait_self_clause_basename;
trait_clause_basename;
append_index;
extract_literal;
@@ -1237,15 +1239,26 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
if use_brackets then F.pp_print_string fmt ")";
F.pp_print_string fmt ("." ^ name))
else
- (* Can only happen when extracting the signature of a trait method
- *declaration* or a provided trait method (for a declaration).
- If extracting items for a trait method implementation,
- the type should have been normalized. For trait method declarations
- we directly reference the item. *)
- assert (ctx.trait_decl_id <> None);
- assert (generics = empty_generic_args);
- let name = ctx_get_local_trait_assoc_type type_name ctx in
- F.pp_print_string fmt name
+ (* There are two situations:
+ - we are extracting a declared item (typically a function signature)
+ for a trait declaration. We directly refer to the item (we extract
+ trait declarations as structures, so we can refer to their fields)
+ - we are extracting a provided method for a trait declaration. We
+ refer to the item in the self trait clause (see {!SelfTraitClauseId}).
+
+ Remark: we can't get there for trait *implementations* because then the
+ types should have been normalized.
+ *)
+ let trait_decl_id = Option.get ctx.trait_decl_id in
+ let item_name = ctx_get_trait_assoc_type trait_decl_id type_name ctx in
+ assert (generics = empty_generic_args);
+ if ctx.is_provided_method then
+ (* Provided method: use the trait self clause *)
+ let self_clause = ctx_get_trait_self_clause ctx in
+ F.pp_print_string fmt (self_clause ^ "." ^ item_name)
+ else
+ (* Declaration: directly refer to the item *)
+ F.pp_print_string fmt item_name
and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter)
(no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit =
@@ -1632,11 +1645,37 @@ let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter)
let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
if !space then space := false else F.pp_print_space fmt ()
+(** Extract the trait self clause.
+
+ We add the trait self clause for provided methods (see {!TraitSelfClauseId}).
+ *)
+let extract_trait_self_clause (insert_req_space : unit -> unit)
+ (ctx : extraction_ctx) (fmt : F.formatter) (trait_decl : A.trait_decl)
+ (params : string list) : unit =
+ insert_req_space ();
+ F.pp_print_string fmt "(";
+ let self_clause = ctx_get_trait_self_clause ctx in
+ F.pp_print_string fmt self_clause;
+ F.pp_print_string fmt ":";
+ let with_opaque_pre = false in
+ let trait_id = ctx_get_trait_decl with_opaque_pre trait_decl.def_id ctx in
+ F.pp_print_string fmt trait_id;
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ params;
+ F.pp_print_string fmt ")"
+
+(**
+ - [trait_decl]: if [Some], it means we are extracting the generics for a provided
+ method and need to insert a trait self clause (see {!TraitSelfClauseId}).
+ *)
let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
(no_params_tys : TypeDeclId.Set.t) (use_forall : bool) (as_implicits : bool)
- (space : bool ref option) (generics : generic_params)
- (type_params : string list) (cg_params : string list)
- (trait_clauses : string list) : unit =
+ (space : bool ref option) (trait_decl : A.trait_decl option)
+ (generics : generic_params) (type_params : string list)
+ (cg_params : string list) (trait_clauses : string list) : unit =
let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
(* HOL4 doesn't support const generics *)
assert (cg_params = [] || !backend <> HOL4);
@@ -1660,53 +1699,102 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
F.pp_print_string fmt "forall");
- (* Note that in HOL4 we don't print the type parameters. *)
- if !backend <> HOL4 then (
- (* Print the type parameters *)
- if type_params <> [] then (
- insert_req_space ();
- (* ( *)
- left_bracket ();
+ (* Small helper - we may need to split the parameters *)
+ let print_generics (type_params : string list)
+ (const_generics : const_generic_var list)
+ (trait_clauses : trait_clause list) : unit =
+ (* Note that in HOL4 we don't print the type parameters. *)
+ if !backend <> HOL4 then (
+ (* Print the type parameters *)
+ if type_params <> [] then (
+ insert_req_space ();
+ (* ( *)
+ left_bracket ();
+ List.iter
+ (fun s ->
+ F.pp_print_string fmt s;
+ F.pp_print_space fmt ())
+ type_params;
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (type_keyword ());
+ (* ) *)
+ right_bracket ());
+ (* Print the const generic parameters *)
List.iter
- (fun s ->
- F.pp_print_string fmt s;
- F.pp_print_space fmt ())
- type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt (type_keyword ());
- (* ) *)
- right_bracket ());
- (* Print the const generic parameters *)
+ (fun (var : const_generic_var) ->
+ insert_req_space ();
+ (* ( *)
+ left_bracket ();
+ let n = ctx_get_const_generic_var var.index ctx in
+ F.pp_print_string fmt n;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_literal_type ctx fmt var.ty;
+ (* ) *)
+ right_bracket ())
+ const_generics);
+ (* Print the trait clauses *)
List.iter
- (fun (var : const_generic_var) ->
+ (fun (clause : trait_clause) ->
insert_req_space ();
(* ( *)
left_bracket ();
- let n = ctx_get_const_generic_var var.index ctx in
+ let n = ctx_get_local_trait_clause clause.clause_id ctx in
F.pp_print_string fmt n;
F.pp_print_space fmt ();
F.pp_print_string fmt ":";
F.pp_print_space fmt ();
- extract_literal_type ctx fmt var.ty;
+ extract_trait_clause_type ctx fmt no_params_tys clause;
(* ) *)
right_bracket ())
- generics.const_generics);
- (* Print the trait clauses *)
- List.iter
- (fun (clause : trait_clause) ->
- insert_req_space ();
- (* ( *)
- left_bracket ();
- let n = ctx_get_local_trait_clause clause.clause_id ctx in
- F.pp_print_string fmt n;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_trait_clause_type ctx fmt no_params_tys clause;
- (* ) *)
- right_bracket ())
- generics.trait_clauses)
+ trait_clauses
+ in
+ (* If we extract the generics for a provided method for a trait declaration
+ (indicated by the trait decl given as input), we need to split the generics:
+ - we print the generics for the trait decl
+ - we print the trait self clause
+ - we print the generics for the trait method
+ *)
+ match trait_decl with
+ | None ->
+ print_generics type_params generics.const_generics
+ generics.trait_clauses
+ | Some trait_decl ->
+ (* Split the generics between the generics specific to the trait decl
+ and those specific to the trait method *)
+ let open Collections.List in
+ let dtype_params, mtype_params =
+ split_at type_params (length trait_decl.generics.types)
+ in
+ let dcgs, mcgs =
+ split_at generics.const_generics
+ (length trait_decl.generics.const_generics)
+ in
+ let dtrait_clauses, mtrait_clauses =
+ split_at generics.trait_clauses
+ (length trait_decl.generics.trait_clauses)
+ in
+ (* Extract the trait decl generics *)
+ print_generics dtype_params dcgs dtrait_clauses;
+ (* Extract the trait self clause *)
+ let params =
+ concat
+ [
+ dtype_params;
+ map
+ (fun (cg : const_generic_var) ->
+ ctx_get_const_generic_var cg.index ctx)
+ dcgs;
+ map
+ (fun c -> ctx_get_local_trait_clause c.clause_id ctx)
+ dtrait_clauses;
+ ]
+ in
+ extract_trait_self_clause insert_req_space ctx fmt trait_decl params;
+ (* Extract the method generics *)
+ print_generics mtype_params mcgs mtrait_clauses)
(** Extract a type declaration.
@@ -1769,7 +1857,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Print the generic parameters *)
let as_implicits = false in
extract_generic_params ctx_body fmt type_decl_group use_forall as_implicits
- None def.generics type_params cg_params trait_clauses;
+ None None def.generics type_params cg_params trait_clauses;
(* Print the "=" if we extract the body*)
if extract_body then (
F.pp_print_space fmt ();
@@ -2002,7 +2090,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let use_forall = false in
let as_implicits = true in
extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall
- as_implicits None decl.generics type_params cg_params trait_clauses;
+ as_implicits None None decl.generics type_params cg_params
+ trait_clauses;
(* Print the record parameter *)
F.pp_print_space fmt ();
F.pp_print_string fmt "(";
@@ -2994,8 +3083,8 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
(** A small utility to print the parameters of a function signature.
We return two contexts:
- - the context augmented with bindings for the type parameters
- - the context augmented with bindings for the type parameters *and*
+ - the context augmented with bindings for the generics
+ - the context augmented with bindings for the generics *and*
bindings for the input values
We also return names for the type parameters, const generics, etc.
@@ -3009,6 +3098,28 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
(fmt : F.formatter) (def : fun_decl) :
extraction_ctx * extraction_ctx * string list =
+ (* First, add the associated types and constants if the function is a method
+ in a trait declaration.
+
+ About the order: we want to make sure the names are reserved for
+ those (variable names might collide with them but it is ok, we will add
+ suffixes to the variables).
+
+ TODO: micro-pass to update what happens when calling trait provided
+ functions.
+ *)
+ let ctx, trait_decl =
+ match def.kind with
+ | TraitMethodProvided (decl_id, _) ->
+ let trait_decl =
+ T.TraitDeclId.Map.find decl_id
+ ctx.trans_ctx.trait_decls_context.trait_decls
+ in
+ let ctx, _ = ctx_add_trait_self_clause ctx in
+ let ctx = { ctx with is_provided_method = true } in
+ (ctx, Some trait_decl)
+ | _ -> (ctx, None)
+ in
(* Add the type parameters - note that we need those bindings only for the
* body translation (they are not top-level) *)
let ctx, type_params, cg_params, trait_clauses =
@@ -3020,7 +3131,8 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
let use_forall = false in
let as_implicits = false in
extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall as_implicits
- (Some space) def.signature.generics type_params cg_params trait_clauses;
+ (Some space) trait_decl def.signature.generics type_params cg_params
+ trait_clauses;
(* Close the box for the generics *)
F.pp_close_box fmt ();
(* The input parameters - note that doing this adds bindings to the context *)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 697b1027..251d8b36 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -291,6 +291,7 @@ type formatter = {
(** Generates a type variable basename. *)
const_generic_var_basename : StringSet.t -> string -> string;
(** Generates a const generic variable basename. *)
+ trait_self_clause_basename : string;
trait_clause_basename : StringSet.t -> trait_clause -> string;
(** Return a base name for a trait clause. We might add a suffix to prevent
collisions.
@@ -409,10 +410,44 @@ type id =
| TraitDeclId of TraitDeclId.id
| TraitImplId of TraitImplId.id
| LocalTraitClauseId of TraitClauseId.id
- | LocalTraitAssocTypeId of string (** Specifically for: [Self::Ty] *)
| TraitAssocTypeId of TraitDeclId.id * string (** A trait associated type *)
| TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
| TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id
+ | TraitSelfClauseId
+ (** Specifically for the clause: [Self : Trait].
+
+ For now, we forbid provided methods (methods in trait declarations
+ with a default implementation) from being overriden in trait implementations.
+ We extract trait provided methods such that they take an instance of
+ the trait as input: this instance is given by the trait self clause.
+
+ For instance:
+ {[
+ //
+ // Rust
+ //
+ trait ToU64 {
+ fn to_u64(&self) -> u64;
+
+ // Provided method
+ fn is_pos(&self) -> bool {
+ self.to_u64() > 0
+ }
+ }
+
+ //
+ // Generated code
+ //
+ struct ToU64 (T : Type) {
+ to_u64 : T -> u64;
+ }
+
+ // The trait self clause
+ // vvvvvvvvvvvvvvvvvvvvvv
+ let is_pos (T : Type) (trait_self : ToU64 T) (self : T) : bool =
+ trait_self.to_u64 self > 0
+ ]}
+ *)
| UnknownId
(** Used for stored various strings like keywords, definitions which
should always be in context, etc. and which can't be linked to one
@@ -618,6 +653,7 @@ type extraction_ctx = {
*)
trait_decl_id : trait_decl_id option;
(** If we are extracting a trait declaration, identifies it *)
+ is_provided_method : bool;
}
(** Debugging function, used when communicating name collisions to the user,
@@ -752,7 +788,6 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
| LocalTraitClauseId id ->
"local_trait_clause_id: " ^ TraitClauseId.to_string id
- | LocalTraitAssocTypeId type_name -> "local_trait_assoc_type_id: " ^ type_name
| TraitParentClauseId (id, clause_id) ->
"trait_parent_clause_id: decl_id:" ^ TraitDeclId.to_string id
^ ", clause_id: "
@@ -764,11 +799,14 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| TraitAssocTypeId (id, type_name) ->
"trait_assoc_type_id: decl_id:" ^ TraitDeclId.to_string id
^ ", type name: " ^ type_name
+ | TraitSelfClauseId -> "trait_self_clause"
(** We might not check for collisions for some specific ids (ex.: field names) *)
let allow_collisions (id : id) : bool =
match id with
- | FieldId (_, _) -> !Config.record_fields_short_names
+ | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitAssocTypeId _
+ ->
+ !Config.record_fields_short_names
| _ -> false
let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx)
@@ -858,6 +896,10 @@ let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string =
let is_opaque = false in
ctx_get_type is_opaque (Assumed id) ctx
+let ctx_get_trait_self_clause (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre TraitSelfClauseId ctx
+
let ctx_get_trait_decl (with_opaque_pre : bool) (id : trait_decl_id)
(ctx : extraction_ctx) : string =
ctx_get with_opaque_pre (TraitDeclId id) ctx
@@ -871,11 +913,6 @@ let ctx_get_trait_assoc_type (id : trait_decl_id) (type_name : string)
let is_opaque = false in
ctx_get is_opaque (TraitAssocTypeId (id, type_name)) ctx
-let ctx_get_local_trait_assoc_type (type_name : string) (ctx : extraction_ctx) :
- string =
- let is_opaque = false in
- ctx_get is_opaque (LocalTraitAssocTypeId type_name) ctx
-
let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
(ctx : extraction_ctx) : string =
let with_opaque_pre = false in
@@ -969,6 +1006,16 @@ let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) :
let ctx = ctx_add is_opaque (VarId id) name ctx in
(ctx, name)
+(** Generate a unique variable name for the trait self clause and add it to the context *)
+let ctx_add_trait_self_clause (ctx : extraction_ctx) : extraction_ctx * string =
+ let is_opaque = false in
+ let basename = ctx.fmt.trait_self_clause_basename in
+ let name =
+ basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
+ in
+ let ctx = ctx_add is_opaque TraitSelfClauseId name ctx in
+ (ctx, name)
+
(** Generate a unique trait clause name and add it to the context *)
let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index f4f59187..790dbe14 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -1007,6 +1007,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
fun_name_info = PureUtils.RegularFunIdMap.empty;
trait_decl_id = None (* None by default *);
+ is_provided_method = false (* false by default *);
}
in