summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-08-31 19:10:00 +0200
committerSon Ho2023-08-31 19:10:00 +0200
commitc61b32393508479657b51b777a0b4816815a55a5 (patch)
tree3e5f018d13c237a1858267eebd80cc16149578db
parentf8555e3c1ecfc9667795c19975067b37ba5c617f (diff)
Make progress on Extract and ExtractBase
Diffstat (limited to '')
-rw-r--r--compiler/Config.ml8
-rw-r--r--compiler/Extract.ml143
-rw-r--r--compiler/ExtractBase.ml33
3 files changed, 132 insertions, 52 deletions
diff --git a/compiler/Config.ml b/compiler/Config.ml
index bd80769f..ccbb4c75 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -323,3 +323,11 @@ let wrap_opaque_in_sig = ref false
information), we use short names (i.e., the original field names).
*)
let record_fields_short_names = ref false
+
+(** Parameterize the traits with their associated types, so as not to use
+ types as first class objects.
+
+ This is useful for some backends with limited expressiveness like HOL4,
+ and to account for type constraints (like [fn f<T : Foo>(...) where T::bar = usize]).
+ *)
+let parameterize_trait_types = ref false
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 7daec16f..4238a152 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -757,6 +757,21 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty)
: string =
+ (* Small helper to derive var names from ADT type names.
+
+ We do the following:
+ - convert the type name to snake case
+ - take the first letter of every "letter group"
+ Ex.: "HashMap" -> "hash_map" -> "hm"
+ *)
+ let name_from_type_ident (name : string) : string =
+ let cl = to_snake_case name in
+ let cl = String.split_on_char '_' cl in
+ let cl = List.filter (fun s -> String.length s > 0) cl in
+ assert (List.length cl > 0);
+ let cl = List.map (fun s -> s.[0]) cl in
+ StringUtils.string_of_chars cl
+ in
(* If there is a basename, we use it *)
match basename with
| Some basename ->
@@ -765,11 +780,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| None -> (
(* No basename: we use the first letter of the type *)
match ty with
- | Adt (type_id, tys, _) -> (
+ | Adt (type_id, generics) -> (
match type_id with
| Tuple ->
(* The "pair" case is frequent enough to have its special treatment *)
- if List.length tys = 2 then "p" else "t"
+ if List.length generics.types = 2 then "p" else "t"
| Assumed Result -> "r"
| Assumed Error -> ConstStrings.error_basename
| Assumed Fuel -> ConstStrings.fuel_basename
@@ -784,21 +799,13 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let def =
TypeDeclId.Map.find adt_id ctx.type_context.type_decls
in
- (* We do the following:
- * - compute the type name, and retrieve the last ident
- * - convert this to snake case
- * - take the first letter of every "letter group"
+ (* Derive the var name from the last ident of the type name
* Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm"
*)
- (* Thename shouldn't be empty, and its last element should
+ (* The name shouldn't be empty, and its last element should
* be an ident *)
let cl = List.nth def.name (List.length def.name - 1) in
- let cl = to_snake_case (Names.as_ident cl) in
- let cl = String.split_on_char '_' cl in
- let cl = List.filter (fun s -> String.length s > 0) cl in
- assert (List.length cl > 0);
- let cl = List.map (fun s -> s.[0]) cl in
- StringUtils.string_of_chars cl)
+ name_from_type_ident (Names.as_ident cl))
| TypeVar _ -> (
(* TODO: use "t" also for F* *)
match !backend with
@@ -806,7 +813,8 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *))
| Literal lty -> (
match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i")
- | Arrow _ -> "f")
+ | Arrow _ -> "f"
+ | TraitType (_, _, name) -> name_from_type_ident name)
in
let type_var_basename (_varset : StringSet.t) (basename : string) : string =
(* Rust type variables are snake-case and start with a capital letter *)
@@ -1131,13 +1139,13 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
(no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit =
let extract_rec = extract_ty ctx fmt no_params_tys in
match ty with
- | Adt (type_id, tys, cgs) -> (
- let has_params = tys <> [] || cgs <> [] in
+ | Adt (type_id, generics) -> (
+ let has_params = generics <> empty_generic_args in
match type_id with
| Tuple ->
(* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type:
* we have to write [unit]... *)
- if tys = [] then F.pp_print_string fmt (unit_name ())
+ if generics.types = [] then F.pp_print_string fmt (unit_name ())
else (
F.pp_print_string fmt "(";
Collections.List.iter_link
@@ -1152,7 +1160,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
in
F.pp_print_string fmt product;
F.pp_print_space fmt ())
- (extract_rec true) tys;
+ (extract_rec true) generics.types;
F.pp_print_string fmt ")")
| AdtId _ | Assumed _ -> (
(* HOL4 behaves differently. Where in Coq/FStar/Lean we would write:
@@ -1169,36 +1177,34 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
(* TODO: for now, only the opaque *functions* are extracted in the
opaque module. The opaque *types* are assumed. *)
F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx);
- if tys <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_rec true) tys);
- if cgs <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_const_generic ctx fmt true)
- cgs);
+ extract_generic_args ctx fmt no_params_tys generics;
if print_paren then F.pp_print_string fmt ")"
| HOL4 ->
- (* Const generics are unsupported in HOL4 *)
- assert (cgs = []);
+ let { types; const_generics; trait_refs } = generics in
+ (* Const generics are not supported in HOL4 *)
+ assert (const_generics = []);
let print_tys =
match type_id with
| AdtId id -> not (TypeDeclId.Set.mem id no_params_tys)
| Assumed _ -> true
| _ -> raise (Failure "Unreachable")
in
- if tys <> [] && print_tys then (
- let print_paren = List.length tys > 1 in
+ if const_generics <> [] && print_tys then (
+ let print_paren = List.length types > 1 in
if print_paren then F.pp_print_string fmt "(";
Collections.List.iter_link
(fun () ->
F.pp_print_string fmt ",";
F.pp_print_space fmt ())
- (extract_rec true) tys;
+ (extract_rec true) types;
if print_paren then F.pp_print_string fmt ")";
F.pp_print_space fmt ());
- F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx)))
+ F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx);
+ if trait_refs <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_trait_ref ctx fmt no_params_tys true)
+ trait_refs)))
| TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
| Literal lty -> extract_literal_type ctx fmt lty
| Arrow (arg_ty, ret_ty) ->
@@ -1209,6 +1215,64 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
extract_rec false ret_ty;
if inside then F.pp_print_string fmt ")"
+ | TraitType (trait_ref, generics, type_name) ->
+ if !parameterize_trait_types then raise (Failure "Unimplemented")
+ else (
+ (* HOL4 doesn't have 1st class types *)
+ assert (!backend <> HOL4);
+ if trait_ref.trait_id <> Self then (
+ F.pp_print_string fmt "(";
+ extract_trait_ref ctx fmt no_params_tys false trait_ref;
+ extract_generic_args ctx fmt no_params_tys generics;
+ (* TODO: lookup the type name *)
+ F.pp_print_string fmt (")." ^ type_name))
+ else
+ (* Can only happen when extracting the signature of a trait method
+ *declaration*. If extracting items for a trait method implementation,
+ the type should have been normalized. For trait method declarations
+ we directly reference the item. *)
+ let trait_decl_id = Option.get ctx.trait_decl_id in
+ assert (generics = empty_generic_args);
+ F.pp_print_string fmt type_name)
+
+and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit =
+ let use_brackets = tr.generics <> empty_generic_args && inside in
+ if use_brackets then F.pp_print_string fmt "(";
+ extract_trait_instance_id ctx fmt no_params_tys inside tr.trait_id;
+ extract_generic_args ctx fmt no_params_tys tr.generics;
+ if use_brackets then F.pp_print_string fmt ")"
+
+and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (generics : generic_args) : unit =
+ let { types; const_generics; trait_refs } = generics in
+ if types <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_ty ctx fmt no_params_tys true)
+ types);
+ if const_generics <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_const_generic ctx fmt true)
+ const_generics);
+ if trait_refs <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_trait_ref ctx fmt no_params_tys true)
+ trait_refs)
+
+and extract_trait_instance_id (_ctx : extraction_ctx) (_fmt : F.formatter)
+ (_no_params_tys : TypeDeclId.Set.t) (_inside : bool)
+ (id : trait_instance_id) : unit =
+ match id with
+ | Self -> raise (Failure "TODO")
+ | TraitImpl _ -> raise (Failure "TODO")
+ | Clause _ -> raise (Failure "TODO")
+ | ParentClause _ -> raise (Failure "TODO")
+ | ItemClause _ -> raise (Failure "TODO")
+ | TraitRef _ -> raise (Failure "TODO")
+ | UnknownTrait _ -> raise (Failure "TODO")
(** Compute the names for all the top-level identifiers used in a type
definition (type name, variant names, field names, etc. but not type
@@ -1551,19 +1615,16 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
*)
let is_opaque = type_kind = None in
let is_opaque_coq = !backend = Coq && is_opaque in
- let use_forall =
- is_opaque_coq && (def.type_params <> [] || def.const_generic_params <> [])
- in
+ let use_forall = is_opaque_coq && def.generics <> empty_generic_params in
(* Retrieve the definition name *)
let with_opaque_pre = false in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
(* Add the type and const generic params - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let ctx_body, type_params, cg_params =
- ctx_add_type_const_generic_params def.type_params def.const_generic_params
- ctx
+ let ctx_body, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params def.generics ctx
in
- let ty_cg_params = List.append type_params cg_params in
+ let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
(* Add a break before *)
if !backend <> HOL4 || not (decl_is_first_from_group kind) then
F.pp_print_break fmt 0 0;
@@ -1586,7 +1647,7 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* HOL4 doesn't support const generics *)
assert (cg_params = [] || !backend <> HOL4);
(* Print the type/const generic parameters *)
- if ty_cg_params <> [] && !backend <> HOL4 then (
+ if all_params <> [] && !backend <> HOL4 then (
if use_forall then (
F.pp_print_space fmt ();
F.pp_print_string fmt ":";
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index d733c763..96ecfd42 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -586,6 +586,8 @@ type extraction_ctx = {
in case a Rust function only has one backward translation
and we filter the forward function because it returns unit.
*)
+ trait_decl_id : trait_decl_id option;
+ (** If we are extracting a trait declaration, identifies it *)
}
(** Debugging function, used when communicating name collisions to the user,
@@ -885,12 +887,24 @@ let ctx_add_const_generic_params (vars : const_generic_var list)
ctx_add_const_generic_var var.name var.index ctx)
ctx vars
-let ctx_add_type_const_generic_params (tvars : type_var list)
- (cgvars : const_generic_var list) (ctx : extraction_ctx) :
- extraction_ctx * string list * string list =
- let ctx, tys = ctx_add_type_params tvars ctx in
- let ctx, cgs = ctx_add_const_generic_params cgvars ctx in
- (ctx, tys, cgs)
+let ctx_add_trait_clauses (clauses : trait_clause list) (ctx : extraction_ctx) :
+ extraction_ctx * string list =
+ List.fold_left_map
+ (fun ctx (c : trait_clause) -> ctx_add_trait_clause c ctx)
+ ctx clauses
+
+(** Returns the lists of names for:
+ - the type variables
+ - the const generic variables
+ - the trait clauses
+ *)
+let ctx_add_generic_params (generics : generic_params) (ctx : extraction_ctx) :
+ extraction_ctx * string list * string list * string list =
+ let { types; const_generics; trait_clauses } = generics in
+ let ctx, tys = ctx_add_type_params types ctx in
+ let ctx, cgs = ctx_add_const_generic_params const_generics ctx in
+ let ctx, tcs = ctx_add_trait_clauses trait_clauses ctx in
+ (ctx, tys, cgs, tcs)
let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) :
extraction_ctx * string =
@@ -1003,14 +1017,11 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
| None -> None
| Some rg_id ->
let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in
- let regions =
+ let region_names =
List.map
- (fun rid -> T.RegionVarId.nth sg.region_params rid)
+ (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
rg.regions
in
- let region_names =
- List.map (fun (r : T.region_var) -> r.name) regions
- in
Some { id = rg_id; region_names }
in
let is_opaque = def.body = None in