summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-09-03 13:32:43 +0200
committerSon Ho2023-09-03 13:32:43 +0200
commit0cafb31dd42c95f22e0b6680531c27fa0508e376 (patch)
tree1ad715f988b14ca8d6a5755299586c7c77701950 /compiler
parent4cf1217f593b46a17130403df85b5f39f9e3eb85 (diff)
Make progress on the extraction
Diffstat (limited to 'compiler')
-rw-r--r--compiler/AssociatedTypes.ml10
-rw-r--r--compiler/Extract.ml76
-rw-r--r--compiler/ExtractBase.ml57
-rw-r--r--compiler/PrintPure.ml4
-rw-r--r--compiler/Pure.ml16
-rw-r--r--compiler/SymbolicToPure.ml27
6 files changed, 138 insertions, 52 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
index 07ab70bd..c4a9538d 100644
--- a/compiler/AssociatedTypes.ml
+++ b/compiler/AssociatedTypes.ml
@@ -100,7 +100,7 @@ let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool =
match id with
| T.Self | Clause _ -> true
| TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ -> false
- | ParentClause (id, _) | ItemClause (id, _, _) ->
+ | ParentClause (id, _, _) | ItemClause (id, _, _, _) ->
trait_instance_id_is_local_clause id
(** About the conversion functions: for now we need them (TODO: merge ety, rty, etc.),
@@ -212,14 +212,14 @@ and ctx_normalize_trait_instance_id :
(id, None)
| Clause _ -> (id, None)
| BuiltinOrAuto _ -> (id, None)
- | ParentClause (inst_id, clause_id) -> (
+ | ParentClause (inst_id, decl_id, clause_id) -> (
let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
(* Check if the inst_id refers to a specific implementation, if yes project *)
match impl with
| None ->
(* This is actually a local clause *)
assert (trait_instance_id_is_local_clause inst_id);
- (ParentClause (inst_id, clause_id), None)
+ (ParentClause (inst_id, decl_id, clause_id), None)
| Some impl ->
(* We figure out the parent clause by doing the following:
{[
@@ -243,14 +243,14 @@ and ctx_normalize_trait_instance_id :
(* Sanity check: the clause necessarily refers to an impl *)
let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in
(TraitRef clause, Some clause))
- | ItemClause (inst_id, item_name, clause_id) -> (
+ | ItemClause (inst_id, decl_id, item_name, clause_id) -> (
let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
(* Check if the inst_id refers to a specific implementation, if yes project *)
match impl with
| None ->
(* This is actually a local clause *)
assert (trait_instance_id_is_local_clause inst_id);
- (ParentClause (inst_id, clause_id), None)
+ (ParentClause (inst_id, decl_id, clause_id), None)
| Some impl ->
(* We figure out the item clause by doing the following:
{[
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index 3c4feca5..ad89a59e 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -1223,23 +1223,29 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
if inside then F.pp_print_string fmt ")"
| TraitType (trait_ref, generics, type_name) ->
if !parameterize_trait_types then raise (Failure "Unimplemented")
- else (
+ else if trait_ref.trait_id <> Self then (
(* HOL4 doesn't have 1st class types *)
assert (!backend <> HOL4);
- if trait_ref.trait_id <> Self then (
- F.pp_print_string fmt "(";
- extract_trait_ref ctx fmt no_params_tys false trait_ref;
- extract_generic_args ctx fmt no_params_tys generics;
- (* TODO: lookup the type name *)
- F.pp_print_string fmt (")." ^ type_name))
- else
- (* Can only happen when extracting the signature of a trait method
- *declaration*. If extracting items for a trait method implementation,
- the type should have been normalized. For trait method declarations
- we directly reference the item. *)
- let trait_decl_id = Option.get ctx.trait_decl_id in
- assert (generics = empty_generic_args);
- F.pp_print_string fmt type_name)
+ let use_brackets = generics <> empty_generic_args in
+ if use_brackets then F.pp_print_string fmt "(";
+ extract_trait_ref ctx fmt no_params_tys false trait_ref;
+ extract_generic_args ctx fmt no_params_tys generics;
+ let name =
+ ctx_get_trait_assoc_type trait_ref.trait_decl_ref.trait_decl_id
+ type_name ctx
+ in
+ if use_brackets then F.pp_print_string fmt ")";
+ F.pp_print_string fmt ("." ^ name))
+ else
+ (* Can only happen when extracting the signature of a trait method
+ *declaration* or a provided trait method (for a declaration).
+ If extracting items for a trait method implementation,
+ the type should have been normalized. For trait method declarations
+ we directly reference the item. *)
+ assert (ctx.trait_decl_id <> None);
+ assert (generics = empty_generic_args);
+ let name = ctx_get_local_trait_assoc_type type_name ctx in
+ F.pp_print_string fmt name
and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter)
(no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit =
@@ -1270,17 +1276,35 @@ and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter)
(extract_trait_ref ctx fmt no_params_tys true)
trait_refs)
-and extract_trait_instance_id (_ctx : extraction_ctx) (_fmt : F.formatter)
- (_no_params_tys : TypeDeclId.Set.t) (_inside : bool)
- (id : trait_instance_id) : unit =
+and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (id : trait_instance_id)
+ : unit =
+ let with_opaque_pre = false in
match id with
- | Self -> raise (Failure "TODO")
- | TraitImpl _ -> raise (Failure "TODO")
- | Clause _ -> raise (Failure "TODO")
- | ParentClause _ -> raise (Failure "TODO")
- | ItemClause _ -> raise (Failure "TODO")
- | TraitRef _ -> raise (Failure "TODO")
- | UnknownTrait _ -> raise (Failure "TODO")
+ | Self ->
+ (* This has specific treatment depending on the item we're extracting
+ (associated type, etc.). We should have caught this elsewhere. *)
+ raise (Failure "Unexpected")
+ | TraitImpl id ->
+ let name = ctx_get_trait_impl with_opaque_pre id ctx in
+ F.pp_print_string fmt name
+ | Clause id ->
+ let name = ctx_get_local_trait_clause id ctx in
+ F.pp_print_string fmt name
+ | ParentClause (inst_id, decl_id, clause_id) ->
+ (* Use the trait decl id to lookup the name *)
+ let name = ctx_get_trait_parent_clause decl_id clause_id ctx in
+ extract_trait_instance_id ctx fmt no_params_tys true inst_id;
+ F.pp_print_string fmt ("." ^ name)
+ | ItemClause (inst_id, decl_id, item_name, clause_id) ->
+ (* Use the trait decl id to lookup the name *)
+ let name = ctx_get_trait_item_clause decl_id item_name clause_id ctx in
+ extract_trait_instance_id ctx fmt no_params_tys true inst_id;
+ F.pp_print_string fmt ("." ^ name)
+ | TraitRef trait_ref -> extract_trait_ref ctx fmt no_params_tys true trait_ref
+ | UnknownTrait _ ->
+ (* This is an error case *)
+ raise (Failure "Unexpected")
(** Compute the names for all the top-level identifiers used in a type
definition (type name, variant names, field names, etc. but not type
@@ -1673,7 +1697,7 @@ let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
insert_req_space ();
(* ( *)
left_bracket ();
- let n = ctx_get_trait_clause_var clause.clause_id ctx in
+ let n = ctx_get_local_trait_clause clause.clause_id ctx in
F.pp_print_string fmt n;
F.pp_print_space fmt ();
F.pp_print_string fmt ":";
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index 02ff266e..697b1027 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -408,7 +408,11 @@ type id =
| VarId of VarId.id
| TraitDeclId of TraitDeclId.id
| TraitImplId of TraitImplId.id
- | TraitClauseId of TraitClauseId.id
+ | LocalTraitClauseId of TraitClauseId.id
+ | LocalTraitAssocTypeId of string (** Specifically for: [Self::Ty] *)
+ | TraitAssocTypeId of TraitDeclId.id * string (** A trait associated type *)
+ | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
+ | TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id
| UnknownId
(** Used for stored various strings like keywords, definitions which
should always be in context, etc. and which can't be linked to one
@@ -746,7 +750,20 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| VarId id -> "var_id: " ^ VarId.to_string id
| TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id
| TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
- | TraitClauseId id -> "trait_clause_id: " ^ TraitClauseId.to_string id
+ | LocalTraitClauseId id ->
+ "local_trait_clause_id: " ^ TraitClauseId.to_string id
+ | LocalTraitAssocTypeId type_name -> "local_trait_assoc_type_id: " ^ type_name
+ | TraitParentClauseId (id, clause_id) ->
+ "trait_parent_clause_id: decl_id:" ^ TraitDeclId.to_string id
+ ^ ", clause_id: "
+ ^ TraitClauseId.to_string clause_id
+ | TraitItemClauseId (id, item_name, clause_id) ->
+ "trait_item_clause_id: decl_id:" ^ TraitDeclId.to_string id
+ ^ ", item name: " ^ item_name ^ ", clause_id: "
+ ^ TraitClauseId.to_string clause_id
+ | TraitAssocTypeId (id, type_name) ->
+ "trait_assoc_type_id: decl_id:" ^ TraitDeclId.to_string id
+ ^ ", type name: " ^ type_name
(** We might not check for collisions for some specific ids (ex.: field names) *)
let allow_collisions (id : id) : bool =
@@ -849,6 +866,26 @@ let ctx_get_trait_impl (with_opaque_pre : bool) (id : trait_impl_id)
(ctx : extraction_ctx) : string =
ctx_get with_opaque_pre (TraitImplId id) ctx
+let ctx_get_trait_assoc_type (id : trait_decl_id) (type_name : string)
+ (ctx : extraction_ctx) : string =
+ let is_opaque = false in
+ ctx_get is_opaque (TraitAssocTypeId (id, type_name)) ctx
+
+let ctx_get_local_trait_assoc_type (type_name : string) (ctx : extraction_ctx) :
+ string =
+ let is_opaque = false in
+ ctx_get is_opaque (LocalTraitAssocTypeId type_name) ctx
+
+let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
+ (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre (TraitParentClauseId (id, clause)) ctx
+
+let ctx_get_trait_item_clause (id : trait_decl_id) (item : string)
+ (clause : trait_clause_id) (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre (TraitItemClauseId (id, item, clause)) ctx
+
let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string =
let is_opaque = false in
ctx_get is_opaque (VarId id) ctx
@@ -862,10 +899,10 @@ let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx)
let is_opaque = false in
ctx_get is_opaque (ConstGenericVarId id) ctx
-let ctx_get_trait_clause_var (id : TraitClauseId.id) (ctx : extraction_ctx) :
+let ctx_get_local_trait_clause (id : TraitClauseId.id) (ctx : extraction_ctx) :
string =
let is_opaque = false in
- ctx_get is_opaque (TraitClauseId id) ctx
+ ctx_get is_opaque (LocalTraitClauseId id) ctx
let ctx_get_field (type_id : type_id) (field_id : FieldId.id)
(ctx : extraction_ctx) : string =
@@ -933,13 +970,13 @@ let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) :
(ctx, name)
(** Generate a unique trait clause name and add it to the context *)
-let ctx_add_trait_clause (basename : string) (id : TraitClauseId.id)
+let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id)
(ctx : extraction_ctx) : extraction_ctx * string =
let is_opaque = false in
let name =
basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
in
- let ctx = ctx_add is_opaque (TraitClauseId id) name ctx in
+ let ctx = ctx_add is_opaque (LocalTraitClauseId id) name ctx in
(ctx, name)
(** See {!ctx_add_var} *)
@@ -964,12 +1001,12 @@ let ctx_add_const_generic_params (vars : const_generic_var list)
ctx_add_const_generic_var var.name var.index ctx)
ctx vars
-let ctx_add_trait_clauses (clauses : trait_clause list) (ctx : extraction_ctx) :
- extraction_ctx * string list =
+let ctx_add_local_trait_clauses (clauses : trait_clause list)
+ (ctx : extraction_ctx) : extraction_ctx * string list =
List.fold_left_map
(fun ctx (c : trait_clause) ->
let basename = ctx.fmt.trait_clause_basename ctx.names_map.names_set c in
- ctx_add_trait_clause basename c.clause_id ctx)
+ ctx_add_local_trait_clause basename c.clause_id ctx)
ctx clauses
(** Returns the lists of names for:
@@ -982,7 +1019,7 @@ let ctx_add_generic_params (generics : generic_params) (ctx : extraction_ctx) :
let { types; const_generics; trait_clauses } = generics in
let ctx, tys = ctx_add_type_params types ctx in
let ctx, cgs = ctx_add_const_generic_params const_generics ctx in
- let ctx, tcs = ctx_add_trait_clauses trait_clauses ctx in
+ let ctx, tcs = ctx_add_local_trait_clauses trait_clauses ctx in
(ctx, tys, cgs, tcs)
let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) :
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index 77d25823..fc39074d 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -273,11 +273,11 @@ and trait_instance_id_to_string (fmt : type_formatter) (inside : bool)
| Self -> "Self"
| TraitImpl id -> fmt.trait_impl_id_to_string id
| Clause id -> fmt.trait_clause_id_to_string id
- | ParentClause (inst_id, clause_id) ->
+ | ParentClause (inst_id, _decl_id, clause_id) ->
let inst_id = trait_instance_id_to_string fmt false inst_id in
let clause_id = fmt.trait_clause_id_to_string clause_id in
"parent(" ^ inst_id ^ ")::" ^ clause_id
- | ItemClause (inst_id, item_name, clause_id) ->
+ | ItemClause (inst_id, _decl_id, item_name, clause_id) ->
let inst_id = trait_instance_id_to_string fmt false inst_id in
let clause_id = fmt.trait_clause_id_to_string clause_id in
"(" ^ inst_id ^ ")::" ^ item_name ^ "::[" ^ clause_id ^ "]"
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 272ec328..725f71ad 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -276,7 +276,16 @@ type ty =
| TraitType of trait_ref * generic_args * string
(** The string is for the name of the associated type *)
-and trait_ref = { trait_id : trait_instance_id; generics : generic_args }
+and trait_ref = {
+ trait_id : trait_instance_id;
+ generics : generic_args;
+ trait_decl_ref : trait_decl_ref;
+}
+
+and trait_decl_ref = {
+ trait_decl_id : trait_decl_id;
+ decl_generics : generic_args; (* The name: annoying field collisions... *)
+}
and generic_args = {
types : ty list;
@@ -288,8 +297,9 @@ and trait_instance_id =
| Self
| TraitImpl of trait_impl_id
| Clause of trait_clause_id
- | ParentClause of trait_instance_id * trait_clause_id
- | ItemClause of trait_instance_id * trait_item_name * trait_clause_id
+ | ParentClause of trait_instance_id * trait_decl_id * trait_clause_id
+ | ItemClause of
+ trait_instance_id * trait_decl_id * trait_item_name * trait_clause_id
| TraitRef of trait_ref
| UnknownTrait of string
[@@deriving
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index c827475b..166f08a0 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -393,7 +393,15 @@ and translate_trait_ref (translate_ty : 'r T.ty -> ty) (tr : 'r T.trait_ref) :
trait_ref =
let trait_id = translate_trait_instance_id translate_ty tr.trait_id in
let generics = translate_generic_args translate_ty tr.generics in
- { trait_id; generics }
+ let trait_decl_ref =
+ translate_trait_decl_ref translate_ty tr.trait_decl_ref
+ in
+ { trait_id; generics; trait_decl_ref }
+
+and translate_trait_decl_ref (translate_ty : 'r T.ty -> ty)
+ (tr : 'r T.trait_decl_ref) : trait_decl_ref =
+ let decl_generics = translate_generic_args translate_ty tr.decl_generics in
+ { trait_decl_id = tr.trait_decl_id; decl_generics }
and translate_trait_instance_id (translate_ty : 'r T.ty -> ty)
(id : 'r T.trait_instance_id) : trait_instance_id =
@@ -405,12 +413,12 @@ and translate_trait_instance_id (translate_ty : 'r T.ty -> ty)
(* We should have eliminated those in the prepasses *)
raise (Failure "Unreachable")
| Clause id -> Clause id
- | ParentClause (inst_id, clause_id) ->
+ | ParentClause (inst_id, decl_id, clause_id) ->
let inst_id = translate_trait_instance_id inst_id in
- ParentClause (inst_id, clause_id)
- | ItemClause (inst_id, item_name, clause_id) ->
+ ParentClause (inst_id, decl_id, clause_id)
+ | ItemClause (inst_id, decl_id, item_name, clause_id) ->
let inst_id = translate_trait_instance_id inst_id in
- ItemClause (inst_id, item_name, clause_id)
+ ItemClause (inst_id, decl_id, item_name, clause_id)
| TraitRef tr -> TraitRef (translate_trait_ref translate_ty tr)
| UnknownTrait s -> raise (Failure ("Unknown trait found: " ^ s))
@@ -2644,7 +2652,14 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
let trait_refs =
List.map
(fun (c : trait_clause) ->
- { trait_id = Clause c.clause_id; generics = empty_generic_args })
+ let trait_decl_ref =
+ { trait_decl_id = c.trait_id; decl_generics = empty_generic_args }
+ in
+ {
+ trait_id = Clause c.clause_id;
+ generics = empty_generic_args;
+ trait_decl_ref;
+ })
trait_clauses
in
{ types; const_generics; trait_refs }