diff options
Diffstat (limited to '')
-rw-r--r-- | compiler/Extract.ml | 218 | ||||
-rw-r--r-- | compiler/ExtractBase.ml | 63 | ||||
-rw-r--r-- | compiler/Translate.ml | 1 |
3 files changed, 221 insertions, 61 deletions
diff --git a/compiler/Extract.ml b/compiler/Extract.ml index e07305f1..e140ea1c 100644 --- a/compiler/Extract.ml +++ b/compiler/Extract.ml @@ -841,6 +841,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) (* TODO: actually use the clause to derive the name *) "cl" in + let trait_self_clause_basename = "self_clause" in let append_index (basename : string) (i : int) : string = basename ^ string_of_int i in @@ -936,6 +937,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string) var_basename; type_var_basename; const_generic_var_basename; + trait_self_clause_basename; trait_clause_basename; append_index; extract_literal; @@ -1237,15 +1239,26 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter) if use_brackets then F.pp_print_string fmt ")"; F.pp_print_string fmt ("." ^ name)) else - (* Can only happen when extracting the signature of a trait method - *declaration* or a provided trait method (for a declaration). - If extracting items for a trait method implementation, - the type should have been normalized. For trait method declarations - we directly reference the item. *) - assert (ctx.trait_decl_id <> None); - assert (generics = empty_generic_args); - let name = ctx_get_local_trait_assoc_type type_name ctx in - F.pp_print_string fmt name + (* There are two situations: + - we are extracting a declared item (typically a function signature) + for a trait declaration. We directly refer to the item (we extract + trait declarations as structures, so we can refer to their fields) + - we are extracting a provided method for a trait declaration. We + refer to the item in the self trait clause (see {!SelfTraitClauseId}). + + Remark: we can't get there for trait *implementations* because then the + types should have been normalized. + *) + let trait_decl_id = Option.get ctx.trait_decl_id in + let item_name = ctx_get_trait_assoc_type trait_decl_id type_name ctx in + assert (generics = empty_generic_args); + if ctx.is_provided_method then + (* Provided method: use the trait self clause *) + let self_clause = ctx_get_trait_self_clause ctx in + F.pp_print_string fmt (self_clause ^ "." ^ item_name) + else + (* Declaration: directly refer to the item *) + F.pp_print_string fmt item_name and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit = @@ -1632,11 +1645,37 @@ let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter) let insert_req_space (fmt : F.formatter) (space : bool ref) : unit = if !space then space := false else F.pp_print_space fmt () +(** Extract the trait self clause. + + We add the trait self clause for provided methods (see {!TraitSelfClauseId}). + *) +let extract_trait_self_clause (insert_req_space : unit -> unit) + (ctx : extraction_ctx) (fmt : F.formatter) (trait_decl : A.trait_decl) + (params : string list) : unit = + insert_req_space (); + F.pp_print_string fmt "("; + let self_clause = ctx_get_trait_self_clause ctx in + F.pp_print_string fmt self_clause; + F.pp_print_string fmt ":"; + let with_opaque_pre = false in + let trait_id = ctx_get_trait_decl with_opaque_pre trait_decl.def_id ctx in + F.pp_print_string fmt trait_id; + List.iter + (fun p -> + F.pp_print_space fmt (); + F.pp_print_string fmt p) + params; + F.pp_print_string fmt ")" + +(** + - [trait_decl]: if [Some], it means we are extracting the generics for a provided + method and need to insert a trait self clause (see {!TraitSelfClauseId}). + *) let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter) (no_params_tys : TypeDeclId.Set.t) (use_forall : bool) (as_implicits : bool) - (space : bool ref option) (generics : generic_params) - (type_params : string list) (cg_params : string list) - (trait_clauses : string list) : unit = + (space : bool ref option) (trait_decl : A.trait_decl option) + (generics : generic_params) (type_params : string list) + (cg_params : string list) (trait_clauses : string list) : unit = let all_params = List.concat [ type_params; cg_params; trait_clauses ] in (* HOL4 doesn't support const generics *) assert (cg_params = [] || !backend <> HOL4); @@ -1660,53 +1699,102 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter) F.pp_print_string fmt ":"; F.pp_print_space fmt (); F.pp_print_string fmt "forall"); - (* Note that in HOL4 we don't print the type parameters. *) - if !backend <> HOL4 then ( - (* Print the type parameters *) - if type_params <> [] then ( - insert_req_space (); - (* ( *) - left_bracket (); + (* Small helper - we may need to split the parameters *) + let print_generics (type_params : string list) + (const_generics : const_generic_var list) + (trait_clauses : trait_clause list) : unit = + (* Note that in HOL4 we don't print the type parameters. *) + if !backend <> HOL4 then ( + (* Print the type parameters *) + if type_params <> [] then ( + insert_req_space (); + (* ( *) + left_bracket (); + List.iter + (fun s -> + F.pp_print_string fmt s; + F.pp_print_space fmt ()) + type_params; + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + F.pp_print_string fmt (type_keyword ()); + (* ) *) + right_bracket ()); + (* Print the const generic parameters *) List.iter - (fun s -> - F.pp_print_string fmt s; - F.pp_print_space fmt ()) - type_params; - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - F.pp_print_string fmt (type_keyword ()); - (* ) *) - right_bracket ()); - (* Print the const generic parameters *) + (fun (var : const_generic_var) -> + insert_req_space (); + (* ( *) + left_bracket (); + let n = ctx_get_const_generic_var var.index ctx in + F.pp_print_string fmt n; + F.pp_print_space fmt (); + F.pp_print_string fmt ":"; + F.pp_print_space fmt (); + extract_literal_type ctx fmt var.ty; + (* ) *) + right_bracket ()) + const_generics); + (* Print the trait clauses *) List.iter - (fun (var : const_generic_var) -> + (fun (clause : trait_clause) -> insert_req_space (); (* ( *) left_bracket (); - let n = ctx_get_const_generic_var var.index ctx in + let n = ctx_get_local_trait_clause clause.clause_id ctx in F.pp_print_string fmt n; F.pp_print_space fmt (); F.pp_print_string fmt ":"; F.pp_print_space fmt (); - extract_literal_type ctx fmt var.ty; + extract_trait_clause_type ctx fmt no_params_tys clause; (* ) *) right_bracket ()) - generics.const_generics); - (* Print the trait clauses *) - List.iter - (fun (clause : trait_clause) -> - insert_req_space (); - (* ( *) - left_bracket (); - let n = ctx_get_local_trait_clause clause.clause_id ctx in - F.pp_print_string fmt n; - F.pp_print_space fmt (); - F.pp_print_string fmt ":"; - F.pp_print_space fmt (); - extract_trait_clause_type ctx fmt no_params_tys clause; - (* ) *) - right_bracket ()) - generics.trait_clauses) + trait_clauses + in + (* If we extract the generics for a provided method for a trait declaration + (indicated by the trait decl given as input), we need to split the generics: + - we print the generics for the trait decl + - we print the trait self clause + - we print the generics for the trait method + *) + match trait_decl with + | None -> + print_generics type_params generics.const_generics + generics.trait_clauses + | Some trait_decl -> + (* Split the generics between the generics specific to the trait decl + and those specific to the trait method *) + let open Collections.List in + let dtype_params, mtype_params = + split_at type_params (length trait_decl.generics.types) + in + let dcgs, mcgs = + split_at generics.const_generics + (length trait_decl.generics.const_generics) + in + let dtrait_clauses, mtrait_clauses = + split_at generics.trait_clauses + (length trait_decl.generics.trait_clauses) + in + (* Extract the trait decl generics *) + print_generics dtype_params dcgs dtrait_clauses; + (* Extract the trait self clause *) + let params = + concat + [ + dtype_params; + map + (fun (cg : const_generic_var) -> + ctx_get_const_generic_var cg.index ctx) + dcgs; + map + (fun c -> ctx_get_local_trait_clause c.clause_id ctx) + dtrait_clauses; + ] + in + extract_trait_self_clause insert_req_space ctx fmt trait_decl params; + (* Extract the method generics *) + print_generics mtype_params mcgs mtrait_clauses) (** Extract a type declaration. @@ -1769,7 +1857,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter) (* Print the generic parameters *) let as_implicits = false in extract_generic_params ctx_body fmt type_decl_group use_forall as_implicits - None def.generics type_params cg_params trait_clauses; + None None def.generics type_params cg_params trait_clauses; (* Print the "=" if we extract the body*) if extract_body then ( F.pp_print_space fmt (); @@ -2002,7 +2090,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx) let use_forall = false in let as_implicits = true in extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall - as_implicits None decl.generics type_params cg_params trait_clauses; + as_implicits None None decl.generics type_params cg_params + trait_clauses; (* Print the record parameter *) F.pp_print_space fmt (); F.pp_print_string fmt "("; @@ -2994,8 +3083,8 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) (** A small utility to print the parameters of a function signature. We return two contexts: - - the context augmented with bindings for the type parameters - - the context augmented with bindings for the type parameters *and* + - the context augmented with bindings for the generics + - the context augmented with bindings for the generics *and* bindings for the input values We also return names for the type parameters, const generics, etc. @@ -3009,6 +3098,28 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter) let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) (fmt : F.formatter) (def : fun_decl) : extraction_ctx * extraction_ctx * string list = + (* First, add the associated types and constants if the function is a method + in a trait declaration. + + About the order: we want to make sure the names are reserved for + those (variable names might collide with them but it is ok, we will add + suffixes to the variables). + + TODO: micro-pass to update what happens when calling trait provided + functions. + *) + let ctx, trait_decl = + match def.kind with + | TraitMethodProvided (decl_id, _) -> + let trait_decl = + T.TraitDeclId.Map.find decl_id + ctx.trans_ctx.trait_decls_context.trait_decls + in + let ctx, _ = ctx_add_trait_self_clause ctx in + let ctx = { ctx with is_provided_method = true } in + (ctx, Some trait_decl) + | _ -> (ctx, None) + in (* Add the type parameters - note that we need those bindings only for the * body translation (they are not top-level) *) let ctx, type_params, cg_params, trait_clauses = @@ -3020,7 +3131,8 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx) let use_forall = false in let as_implicits = false in extract_generic_params ctx fmt TypeDeclId.Set.empty use_forall as_implicits - (Some space) def.signature.generics type_params cg_params trait_clauses; + (Some space) trait_decl def.signature.generics type_params cg_params + trait_clauses; (* Close the box for the generics *) F.pp_close_box fmt (); (* The input parameters - note that doing this adds bindings to the context *) diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml index 697b1027..251d8b36 100644 --- a/compiler/ExtractBase.ml +++ b/compiler/ExtractBase.ml @@ -291,6 +291,7 @@ type formatter = { (** Generates a type variable basename. *) const_generic_var_basename : StringSet.t -> string -> string; (** Generates a const generic variable basename. *) + trait_self_clause_basename : string; trait_clause_basename : StringSet.t -> trait_clause -> string; (** Return a base name for a trait clause. We might add a suffix to prevent collisions. @@ -409,10 +410,44 @@ type id = | TraitDeclId of TraitDeclId.id | TraitImplId of TraitImplId.id | LocalTraitClauseId of TraitClauseId.id - | LocalTraitAssocTypeId of string (** Specifically for: [Self::Ty] *) | TraitAssocTypeId of TraitDeclId.id * string (** A trait associated type *) | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id | TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id + | TraitSelfClauseId + (** Specifically for the clause: [Self : Trait]. + + For now, we forbid provided methods (methods in trait declarations + with a default implementation) from being overriden in trait implementations. + We extract trait provided methods such that they take an instance of + the trait as input: this instance is given by the trait self clause. + + For instance: + {[ + // + // Rust + // + trait ToU64 { + fn to_u64(&self) -> u64; + + // Provided method + fn is_pos(&self) -> bool { + self.to_u64() > 0 + } + } + + // + // Generated code + // + struct ToU64 (T : Type) { + to_u64 : T -> u64; + } + + // The trait self clause + // vvvvvvvvvvvvvvvvvvvvvv + let is_pos (T : Type) (trait_self : ToU64 T) (self : T) : bool = + trait_self.to_u64 self > 0 + ]} + *) | UnknownId (** Used for stored various strings like keywords, definitions which should always be in context, etc. and which can't be linked to one @@ -618,6 +653,7 @@ type extraction_ctx = { *) trait_decl_id : trait_decl_id option; (** If we are extracting a trait declaration, identifies it *) + is_provided_method : bool; } (** Debugging function, used when communicating name collisions to the user, @@ -752,7 +788,6 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id | LocalTraitClauseId id -> "local_trait_clause_id: " ^ TraitClauseId.to_string id - | LocalTraitAssocTypeId type_name -> "local_trait_assoc_type_id: " ^ type_name | TraitParentClauseId (id, clause_id) -> "trait_parent_clause_id: decl_id:" ^ TraitDeclId.to_string id ^ ", clause_id: " @@ -764,11 +799,14 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string = | TraitAssocTypeId (id, type_name) -> "trait_assoc_type_id: decl_id:" ^ TraitDeclId.to_string id ^ ", type name: " ^ type_name + | TraitSelfClauseId -> "trait_self_clause" (** We might not check for collisions for some specific ids (ex.: field names) *) let allow_collisions (id : id) : bool = match id with - | FieldId (_, _) -> !Config.record_fields_short_names + | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitAssocTypeId _ + -> + !Config.record_fields_short_names | _ -> false let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx) @@ -858,6 +896,10 @@ let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string = let is_opaque = false in ctx_get_type is_opaque (Assumed id) ctx +let ctx_get_trait_self_clause (ctx : extraction_ctx) : string = + let with_opaque_pre = false in + ctx_get with_opaque_pre TraitSelfClauseId ctx + let ctx_get_trait_decl (with_opaque_pre : bool) (id : trait_decl_id) (ctx : extraction_ctx) : string = ctx_get with_opaque_pre (TraitDeclId id) ctx @@ -871,11 +913,6 @@ let ctx_get_trait_assoc_type (id : trait_decl_id) (type_name : string) let is_opaque = false in ctx_get is_opaque (TraitAssocTypeId (id, type_name)) ctx -let ctx_get_local_trait_assoc_type (type_name : string) (ctx : extraction_ctx) : - string = - let is_opaque = false in - ctx_get is_opaque (LocalTraitAssocTypeId type_name) ctx - let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id) (ctx : extraction_ctx) : string = let with_opaque_pre = false in @@ -969,6 +1006,16 @@ let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) : let ctx = ctx_add is_opaque (VarId id) name ctx in (ctx, name) +(** Generate a unique variable name for the trait self clause and add it to the context *) +let ctx_add_trait_self_clause (ctx : extraction_ctx) : extraction_ctx * string = + let is_opaque = false in + let basename = ctx.fmt.trait_self_clause_basename in + let name = + basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename + in + let ctx = ctx_add is_opaque TraitSelfClauseId name ctx in + (ctx, name) + (** Generate a unique trait clause name and add it to the context *) let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id) (ctx : extraction_ctx) : extraction_ctx * string = diff --git a/compiler/Translate.ml b/compiler/Translate.ml index f4f59187..790dbe14 100644 --- a/compiler/Translate.ml +++ b/compiler/Translate.ml @@ -1007,6 +1007,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) : use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses; fun_name_info = PureUtils.RegularFunIdMap.empty; trait_decl_id = None (* None by default *); + is_provided_method = false (* false by default *); } in |