summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
authorSon Ho2023-09-01 12:01:03 +0200
committerSon Ho2023-09-01 12:01:03 +0200
commit06360698561019d7f480dcb4263e2099d9a03ca5 (patch)
treea38b04f4798efb26f4620449b002055af98d79c2 /compiler
parentc61b32393508479657b51b777a0b4816815a55a5 (diff)
Implement the normalization functions in AssociatedTypes
Diffstat (limited to '')
-rw-r--r--compiler/AssociatedTypes.ml262
-rw-r--r--compiler/Contexts.ml43
-rw-r--r--compiler/InterpreterLoopsJoinCtxs.ml6
-rw-r--r--compiler/InterpreterStatements.ml2
-rw-r--r--compiler/Substitute.ml36
5 files changed, 335 insertions, 14 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
index 4e5625cb..8e08db6e 100644
--- a/compiler/AssociatedTypes.ml
+++ b/compiler/AssociatedTypes.ml
@@ -18,11 +18,255 @@ module L = Logging
(** The local logger *)
let log = L.associated_types_log
+(** A trait instance id refers to a local clause if it only uses the variants:
+ [Self], [Clause], [ParentClause], [ItemClause] *)
+let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool =
+ match id with
+ | T.Self | Clause _ -> true
+ | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ -> false
+ | ParentClause (id, _) | ItemClause (id, _, _) ->
+ trait_instance_id_is_local_clause id
+
+(** About the conversion functions: for now we need them (TODO: merge ety, rty, etc.),
+ but they should be applied to types without regions.
+ *)
+type 'r norm_ctx = {
+ ctx : C.eval_ctx;
+ get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option;
+ convert_ety : T.ety -> 'r T.ty;
+ convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref;
+}
+
(** Normalize a type by simplyfying the references to trait associated types
and choosing a representative when there are equalities between types
enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. *)
-let ctx_normalize_type (_ctx : C.eval_ctx) (_ty : 'r T.ty) : 'r T.ty =
- raise (Failure "Unimplemented")
+let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty =
+ fun ctx ty ->
+ match ty with
+ | T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics)
+ | TypeVar _ | Literal _ | Never -> ty
+ | Ref (r, ty, rkind) ->
+ let ty = ctx_normalize_ty ctx ty in
+ T.Ref (r, ty, rkind)
+ | TraitType (trait_ref, generics, type_name) -> (
+ (* Normalize and attempt to project the type from the trait ref *)
+ let trait_ref = ctx_normalize_trait_ref ctx trait_ref in
+ let generics = ctx_normalize_generic_args ctx generics in
+ let ty : 'r T.ty =
+ match trait_ref.trait_id with
+ | T.TraitRef { T.trait_id = T.TraitImpl impl_id; generics; _ } ->
+ (* Lookup the implementation *)
+ let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
+ (* Lookup the type *)
+ let ty = snd (List.assoc type_name trait_impl.types) in
+ (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *)
+ let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in
+ (* Substitute - annoying: we can't use *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_subst_from_generics_no_regions trait_impl.generics
+ generics tr_self
+ in
+ let ty = Subst.ty_substitute subst ty in
+ (* Reconvert *)
+ let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in
+ (* Normalize *)
+ ctx_normalize_ty ctx ty
+ | _ ->
+ (* We can't project *)
+ assert (trait_instance_id_is_local_clause trait_ref.trait_id);
+ T.TraitType (trait_ref, generics, type_name)
+ in
+ let tr : 'r C.trait_type_ref = { C.trait_ref; type_name } in
+ (* Lookup the representative, if there is *)
+ match ctx.get_ty_repr tr with None -> ty | Some ty -> ty)
+
+(** This returns the normalized trait instance id together with an optional
+ reference to a trait **implementation**.
+
+ We need this in particular to simplify the trait instance ids after we
+ performed a substitution.
+
+ Example:
+ ========
+ {[
+ trait Trait {
+ type S
+ }
+
+ impl TraitImpl for Foo {
+ type S = usize
+ }
+
+ fn f<T : Trait>(...) -> T::S;
+
+ ...
+ let x = f<Foo>[TraitImpl](...); // T::S ~~> TraitImpl::S ~~> usize
+ ]}
+
+ Several remarks:
+ - as we do not allow higher-order types (yet) then local clauses (and
+ sub-clauses) can't have generic arguments
+ - the [TraitRef] case only happens because of substitution, the role of
+ the normalization is in particular to eliminate it. Inside a [TraitRef]
+ there is necessarily:
+ - an id referencing a local (sub-)clause, that is an id using the variants
+ [Self], [Clause], [ItemClause] and [ParentClause] exclusively. We can't
+ simplify those cases: all we can do is remove the [TraitRef] wrapper
+ by leveraging the fact that the generic arguments must be empty.
+ - a [TraitImpl]. Note that the [TraitImpl] is necessarily just a [TraitImpl],
+ it can't be for instance a [ParentClause(TraitImpl ...)] because the
+ trait resolution would then directly reference the implementation
+ designated by [ParentClause(TraitImpl ...)] (and same for the other cases).
+ In this case we can lookup the trait implementation and recursively project
+ over it.
+ *)
+and ctx_normalize_trait_instance_id :
+ 'r.
+ 'r norm_ctx ->
+ 'r T.trait_instance_id ->
+ 'r T.trait_instance_id * 'r T.trait_ref option =
+ fun ctx id ->
+ match id with
+ | Self -> (id, None)
+ | TraitImpl _ ->
+ (* The [TraitImpl] shouldn't be inside any projection - we check this
+ elsewhere by asserting that whenever we return [None] for the impl
+ trait ref, then the id actually refers to a local clause. *)
+ (id, None)
+ | Clause _ -> (id, None)
+ | BuiltinOrAuto _ -> (id, None)
+ | ParentClause (inst_id, clause_id) -> (
+ let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
+ (* Check if the inst_id refers to a specific implementation, if yes project *)
+ match impl with
+ | None ->
+ (* This is actually a local clause *)
+ assert (trait_instance_id_is_local_clause inst_id);
+ (ParentClause (inst_id, clause_id), None)
+ | Some impl ->
+ (* We figure out the parent clause by doing the following:
+ {[
+ // The implementation we are looking at
+ impl Impl1 : Trait1 { ... }
+
+ // Check the trait it implements
+ trait Trait1 : ParentTrait1 + ParentTrait2 { ... }
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ those are the parent clauses
+ ]}
+
+ We can find the parent clauses in the [trait_decl_ref] field, which
+ tells us which specific instantiation of [Trait1] is implemented
+ by [Impl1].
+ *)
+ let clause =
+ T.TraitClauseId.nth impl.trait_decl_ref.decl_generics.trait_refs
+ clause_id
+ in
+ (* Sanity check: the clause necessarily refers to an impl *)
+ let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in
+ (TraitRef clause, Some clause))
+ | ItemClause (inst_id, item_name, clause_id) -> (
+ let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
+ (* Check if the inst_id refers to a specific implementation, if yes project *)
+ match impl with
+ | None ->
+ (* This is actually a local clause *)
+ assert (trait_instance_id_is_local_clause inst_id);
+ (ParentClause (inst_id, clause_id), None)
+ | Some impl ->
+ (* We figure out the item clause by doing the following:
+ {[
+ // The implementation we are looking at
+ impl Impl1 : Trait1<R> {
+ type S = ...
+ with Impl2 : Trait2 ... // Instances satisfying the declared bounds
+ ^^^^^^^^^^^^^^^^^^
+ Lookup the clause from here
+ }
+ ]}
+ *)
+ (* The referenced instance should be an impl *)
+ let impl_id =
+ TypesUtils.trait_instance_id_as_trait_impl impl.trait_id
+ in
+ let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
+ (* Lookup the clause *)
+ let item = List.assoc item_name trait_impl.types in
+ let clause = T.TraitClauseId.nth (fst item) clause_id in
+ (* Sanity check: the clause necessarily refers to an impl *)
+ let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in
+ (* We need to convert the clause type -
+ TODO: we have too many problems with those conversions, we should
+ merge the types. *)
+ let clause = ctx.convert_etrait_ref clause in
+ (TraitRef clause, Some clause))
+ | TraitRef { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } ->
+ (* We can't simplify the id *yet* : we will simplify it when projecting.
+ However, we have an implementation to return *)
+ (* Normalize the generics *)
+ let generics = ctx_normalize_generic_args ctx generics in
+ let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
+ let trait_ref : 'r T.trait_ref =
+ { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref }
+ in
+ (TraitRef trait_ref, Some trait_ref)
+ | TraitRef trait_ref ->
+ (* The trait instance id necessarily refers to a local sub-clause. We
+ can't project over it and can only peel off the [TraitRef] wrapper *)
+ assert (trait_instance_id_is_local_clause trait_ref.trait_id);
+ assert (trait_ref.generics = TypesUtils.mk_empty_generic_args);
+ (trait_ref.trait_id, None)
+ | UnknownTrait _ ->
+ (* This is actually an error case *)
+ (id, None)
+
+and ctx_normalize_generic_args (ctx : 'r norm_ctx)
+ (generics : 'r T.generic_args) : 'r T.generic_args =
+ let { T.regions; types; const_generics; trait_refs } = generics in
+ let types = List.map (ctx_normalize_ty ctx) types in
+ let trait_refs = List.map (ctx_normalize_trait_ref ctx) trait_refs in
+ { T.regions; types; const_generics; trait_refs }
+
+and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) :
+ 'r T.trait_ref =
+ let { T.trait_id; generics; trait_decl_ref } = trait_ref in
+ let trait_id, _ = ctx_normalize_trait_instance_id ctx trait_id in
+ let generics = ctx_normalize_generic_args ctx generics in
+ let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
+ { T.trait_id; generics; trait_decl_ref }
+
+(* Not sure this one is really necessary *)
+and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx)
+ (trait_decl_ref : 'r T.trait_decl_ref) : 'r T.trait_decl_ref =
+ let { T.trait_decl_id; decl_generics } = trait_decl_ref in
+ let decl_generics = ctx_normalize_generic_args ctx decl_generics in
+ { T.trait_decl_id; decl_generics }
+
+let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty =
+ let get_ty_repr x = C.RTraitTypeRefMap.find_opt x ctx.norm_trait_rtypes in
+ let ctx : T.RegionId.id T.region norm_ctx =
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = TypesUtils.ety_no_regions_to_rty;
+ convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
+ }
+ in
+ ctx_normalize_ty ctx ty
+
+let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety =
+ let get_ty_repr x = C.ETraitTypeRefMap.find_opt x ctx.norm_trait_etypes in
+ let ctx : T.erased_region norm_ctx =
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = (fun x -> x);
+ convert_etrait_ref = (fun x -> x);
+ }
+ in
+ ctx_normalize_ty ctx ty
(** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *)
let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx)
@@ -33,7 +277,7 @@ let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx)
in
List.map
(fun (variant_id, types) ->
- (variant_id, List.map (ctx_normalize_type ctx) types))
+ (variant_id, List.map (ctx_normalize_rty ctx) types))
res
(** Same as [type_decl_get_instantiated_field_rtypes] but normalizes the types *)
@@ -43,7 +287,7 @@ let type_decl_get_inst_norm_field_rtypes (ctx : C.eval_ctx) (def : T.type_decl)
let types =
Subst.type_decl_get_instantiated_field_rtypes def opt_variant_id generics
in
- List.map (ctx_normalize_type ctx) types
+ List.map (ctx_normalize_rty ctx) types
(** Same as [ctx_adt_value_get_instantiated_field_rtypes] but normalizes the types *)
let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx)
@@ -52,7 +296,7 @@ let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx)
let types =
Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id generics
in
- List.map (ctx_normalize_type ctx) types
+ List.map (ctx_normalize_rty ctx) types
(** Same as [ctx_adt_value_get_instantiated_field_etypes] but normalizes the types *)
let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl)
@@ -61,7 +305,7 @@ let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl)
let types =
Subst.type_decl_get_instantiated_field_etypes def opt_variant_id generics
in
- List.map (ctx_normalize_type ctx) types
+ List.map (ctx_normalize_ety ctx) types
(** Same as [ctx_adt_get_instantiated_field_etypes] but normalizes the types *)
let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx)
@@ -71,7 +315,7 @@ let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx)
Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id
generics
in
- List.map (ctx_normalize_type ctx) types
+ List.map (ctx_normalize_ety ctx) types
(** Same as [substitute_signature] but normalizes the types *)
let ctx_subst_norm_signature (ctx : C.eval_ctx)
@@ -86,6 +330,6 @@ let ctx_subst_norm_signature (ctx : C.eval_ctx)
sg
in
let { A.regions_hierarchy; inputs; output } = sg in
- let inputs = List.map (ctx_normalize_type ctx) inputs in
- let output = ctx_normalize_type ctx output in
+ let inputs = List.map (ctx_normalize_rty ctx) inputs in
+ let output = ctx_normalize_rty ctx output in
{ regions_hierarchy; inputs; output }
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index 2d396924..0719364e 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -270,6 +270,37 @@ type decls_ctx = {
}
[@@deriving show]
+(** A reference to a trait associated type *)
+type 'r trait_type_ref = { trait_ref : 'r trait_ref; type_name : string }
+[@@deriving show, ord]
+
+type etrait_type_ref = erased_region trait_type_ref [@@deriving show, ord]
+
+type rtrait_type_ref = Types.RegionId.id Types.region trait_type_ref
+[@@deriving show, ord]
+
+(* TODO: correctly use the functors so as not to have a duplication below *)
+module ETraitTypeRefOrd = struct
+ type t = etrait_type_ref
+
+ let compare = compare_etrait_type_ref
+ let to_string = show_etrait_type_ref
+ let pp_t = pp_etrait_type_ref
+ let show_t = show_etrait_type_ref
+end
+
+module RTraitTypeRefOrd = struct
+ type t = rtrait_type_ref
+
+ let compare = compare_rtrait_type_ref
+ let to_string = show_rtrait_type_ref
+ let pp_t = pp_rtrait_type_ref
+ let show_t = show_rtrait_type_ref
+end
+
+module ETraitTypeRefMap = Collections.MakeMap (ETraitTypeRefOrd)
+module RTraitTypeRefMap = Collections.MakeMap (RTraitTypeRefOrd)
+
(** Evaluation context *)
type eval_ctx = {
type_context : type_context;
@@ -285,6 +316,18 @@ type eval_ctx = {
can be symbolic values or concrete values (in the latter case:
if we run in interpreter mode) *)
trait_clauses : etrait_ref list;
+ norm_trait_etypes : ety ETraitTypeRefMap.t;
+ (** The normalized trait types (a map from trait types to their representatives).
+ Note that this doesn't support account higher-order types. *)
+ norm_trait_rtypes : rty RTraitTypeRefMap.t;
+ (** We need this because we manipulate two kinds of types.
+ Note that we actually forbid regions from appearing both in the trait
+ references and in the constraints given to the associated types,
+ meaning that we don't have to worry about mismatches due to changes
+ in region ids.
+
+ TODO: how not to duplicate?
+ *)
env : env;
ended_regions : RegionId.Set.t;
}
diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml
index a34a7d06..045ba9d8 100644
--- a/compiler/InterpreterLoopsJoinCtxs.ml
+++ b/compiler/InterpreterLoopsJoinCtxs.ml
@@ -561,6 +561,8 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
const_generic_vars;
const_generic_vars_map;
trait_clauses;
+ norm_trait_etypes;
+ norm_trait_rtypes;
env = _;
ended_regions = ended_regions0;
} =
@@ -577,6 +579,8 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
const_generic_vars = _;
const_generic_vars_map = _;
trait_clauses = _;
+ norm_trait_etypes = _;
+ norm_trait_rtypes = _;
env = _;
ended_regions = ended_regions1;
} =
@@ -595,6 +599,8 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
const_generic_vars;
const_generic_vars_map;
trait_clauses;
+ norm_trait_etypes;
+ norm_trait_rtypes;
env;
ended_regions;
}
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index d38f8b95..3fb07956 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -326,7 +326,7 @@ let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id)
Subst.erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self
sg.output
in
- Assoc.ctx_normalize_type ctx ty
+ Assoc.ctx_normalize_ety ctx ty
let move_return_value (config : C.config) (pop_return_value : bool)
(cf : V.typed_value option -> m_fun) : m_fun =
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index 64e7716a..fe88faea 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -39,13 +39,22 @@ let ty_substitute_visitor (subst : ('r1, 'r2) subst) =
method! visit_Self _ = subst.tr_self
end
-(** Substitute types variables and regions in a type. *)
+(** Substitute types variables and regions in a type.
+
+ **IMPORTANT**: this doesn't normalize the type.
+ *)
let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty =
let visitor = ty_substitute_visitor subst in
visitor#visit_ty () ty
+(** **IMPORTANT**: this doesn't normalize the trait ref. *)
+let trait_ref_substitute (subst : ('r1, 'r2) subst) (tr : 'r1 T.trait_ref) :
+ 'r2 T.trait_ref =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_trait_ref () tr
+
(** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *)
-let erase_regions (ty : T.rty) : T.ety =
+let erase_regions (ty : 'r T.ty) : T.ety =
let subst =
{
r_subst = (fun _ -> T.Erased);
@@ -169,8 +178,9 @@ let make_trait_subst_from_clauses (clauses : T.trait_clause list)
trs
let make_subst_from_generics (params : T.generic_params)
- (args : 'r T.generic_args) (tr_self : 'r T.trait_instance_id) :
- (T.region_var_id T.region, 'r) subst =
+ (args : 'r T.region T.generic_args)
+ (tr_self : 'r T.region T.trait_instance_id) :
+ (T.region_var_id T.region, 'r T.region) subst =
let r_subst = make_region_subst_from_vars params.T.regions args.T.regions in
let ty_subst = make_type_subst_from_vars params.T.types args.T.types in
let cg_subst =
@@ -182,6 +192,24 @@ let make_subst_from_generics (params : T.generic_params)
in
{ r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+let make_subst_from_generics_no_regions :
+ 'r.
+ T.generic_params ->
+ 'r T.generic_args ->
+ 'r T.trait_instance_id ->
+ (T.region_var_id T.region, 'r) subst =
+ fun params args tr_self ->
+ let r_subst _ = raise (Failure "Unexpected region") in
+ let ty_subst = make_type_subst_from_vars params.T.types args.T.types in
+ let cg_subst =
+ make_const_generic_subst_from_vars params.T.const_generics
+ args.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs
+ in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
let make_esubst_from_generics (params : T.generic_params)
(generics : T.egeneric_args) (tr_self : T.etrait_instance_id) =
let r_subst _ = T.Erased in