From 06360698561019d7f480dcb4263e2099d9a03ca5 Mon Sep 17 00:00:00 2001 From: Son Ho Date: Fri, 1 Sep 2023 12:01:03 +0200 Subject: Implement the normalization functions in AssociatedTypes --- compiler/AssociatedTypes.ml | 262 +++++++++++++++++++++++++++++++++-- compiler/Contexts.ml | 43 ++++++ compiler/InterpreterLoopsJoinCtxs.ml | 6 + compiler/InterpreterStatements.ml | 2 +- compiler/Substitute.ml | 36 ++++- 5 files changed, 335 insertions(+), 14 deletions(-) (limited to 'compiler') diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml index 4e5625cb..8e08db6e 100644 --- a/compiler/AssociatedTypes.ml +++ b/compiler/AssociatedTypes.ml @@ -18,11 +18,255 @@ module L = Logging (** The local logger *) let log = L.associated_types_log +(** A trait instance id refers to a local clause if it only uses the variants: + [Self], [Clause], [ParentClause], [ItemClause] *) +let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool = + match id with + | T.Self | Clause _ -> true + | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ -> false + | ParentClause (id, _) | ItemClause (id, _, _) -> + trait_instance_id_is_local_clause id + +(** About the conversion functions: for now we need them (TODO: merge ety, rty, etc.), + but they should be applied to types without regions. + *) +type 'r norm_ctx = { + ctx : C.eval_ctx; + get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option; + convert_ety : T.ety -> 'r T.ty; + convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref; +} + (** Normalize a type by simplyfying the references to trait associated types and choosing a representative when there are equalities between types enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. *) -let ctx_normalize_type (_ctx : C.eval_ctx) (_ty : 'r T.ty) : 'r T.ty = - raise (Failure "Unimplemented") +let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty = + fun ctx ty -> + match ty with + | T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics) + | TypeVar _ | Literal _ | Never -> ty + | Ref (r, ty, rkind) -> + let ty = ctx_normalize_ty ctx ty in + T.Ref (r, ty, rkind) + | TraitType (trait_ref, generics, type_name) -> ( + (* Normalize and attempt to project the type from the trait ref *) + let trait_ref = ctx_normalize_trait_ref ctx trait_ref in + let generics = ctx_normalize_generic_args ctx generics in + let ty : 'r T.ty = + match trait_ref.trait_id with + | T.TraitRef { T.trait_id = T.TraitImpl impl_id; generics; _ } -> + (* Lookup the implementation *) + let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in + (* Lookup the type *) + let ty = snd (List.assoc type_name trait_impl.types) in + (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *) + let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in + (* Substitute - annoying: we can't use *) + let tr_self = T.UnknownTrait __FUNCTION__ in + let subst = + Subst.make_subst_from_generics_no_regions trait_impl.generics + generics tr_self + in + let ty = Subst.ty_substitute subst ty in + (* Reconvert *) + let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in + (* Normalize *) + ctx_normalize_ty ctx ty + | _ -> + (* We can't project *) + assert (trait_instance_id_is_local_clause trait_ref.trait_id); + T.TraitType (trait_ref, generics, type_name) + in + let tr : 'r C.trait_type_ref = { C.trait_ref; type_name } in + (* Lookup the representative, if there is *) + match ctx.get_ty_repr tr with None -> ty | Some ty -> ty) + +(** This returns the normalized trait instance id together with an optional + reference to a trait **implementation**. + + We need this in particular to simplify the trait instance ids after we + performed a substitution. + + Example: + ======== + {[ + trait Trait { + type S + } + + impl TraitImpl for Foo { + type S = usize + } + + fn f(...) -> T::S; + + ... + let x = f[TraitImpl](...); // T::S ~~> TraitImpl::S ~~> usize + ]} + + Several remarks: + - as we do not allow higher-order types (yet) then local clauses (and + sub-clauses) can't have generic arguments + - the [TraitRef] case only happens because of substitution, the role of + the normalization is in particular to eliminate it. Inside a [TraitRef] + there is necessarily: + - an id referencing a local (sub-)clause, that is an id using the variants + [Self], [Clause], [ItemClause] and [ParentClause] exclusively. We can't + simplify those cases: all we can do is remove the [TraitRef] wrapper + by leveraging the fact that the generic arguments must be empty. + - a [TraitImpl]. Note that the [TraitImpl] is necessarily just a [TraitImpl], + it can't be for instance a [ParentClause(TraitImpl ...)] because the + trait resolution would then directly reference the implementation + designated by [ParentClause(TraitImpl ...)] (and same for the other cases). + In this case we can lookup the trait implementation and recursively project + over it. + *) +and ctx_normalize_trait_instance_id : + 'r. + 'r norm_ctx -> + 'r T.trait_instance_id -> + 'r T.trait_instance_id * 'r T.trait_ref option = + fun ctx id -> + match id with + | Self -> (id, None) + | TraitImpl _ -> + (* The [TraitImpl] shouldn't be inside any projection - we check this + elsewhere by asserting that whenever we return [None] for the impl + trait ref, then the id actually refers to a local clause. *) + (id, None) + | Clause _ -> (id, None) + | BuiltinOrAuto _ -> (id, None) + | ParentClause (inst_id, clause_id) -> ( + let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in + (* Check if the inst_id refers to a specific implementation, if yes project *) + match impl with + | None -> + (* This is actually a local clause *) + assert (trait_instance_id_is_local_clause inst_id); + (ParentClause (inst_id, clause_id), None) + | Some impl -> + (* We figure out the parent clause by doing the following: + {[ + // The implementation we are looking at + impl Impl1 : Trait1 { ... } + + // Check the trait it implements + trait Trait1 : ParentTrait1 + ParentTrait2 { ... } + ^^^^^^^^^^^^^^^^^^^^^^^^^^^ + those are the parent clauses + ]} + + We can find the parent clauses in the [trait_decl_ref] field, which + tells us which specific instantiation of [Trait1] is implemented + by [Impl1]. + *) + let clause = + T.TraitClauseId.nth impl.trait_decl_ref.decl_generics.trait_refs + clause_id + in + (* Sanity check: the clause necessarily refers to an impl *) + let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in + (TraitRef clause, Some clause)) + | ItemClause (inst_id, item_name, clause_id) -> ( + let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in + (* Check if the inst_id refers to a specific implementation, if yes project *) + match impl with + | None -> + (* This is actually a local clause *) + assert (trait_instance_id_is_local_clause inst_id); + (ParentClause (inst_id, clause_id), None) + | Some impl -> + (* We figure out the item clause by doing the following: + {[ + // The implementation we are looking at + impl Impl1 : Trait1 { + type S = ... + with Impl2 : Trait2 ... // Instances satisfying the declared bounds + ^^^^^^^^^^^^^^^^^^ + Lookup the clause from here + } + ]} + *) + (* The referenced instance should be an impl *) + let impl_id = + TypesUtils.trait_instance_id_as_trait_impl impl.trait_id + in + let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in + (* Lookup the clause *) + let item = List.assoc item_name trait_impl.types in + let clause = T.TraitClauseId.nth (fst item) clause_id in + (* Sanity check: the clause necessarily refers to an impl *) + let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in + (* We need to convert the clause type - + TODO: we have too many problems with those conversions, we should + merge the types. *) + let clause = ctx.convert_etrait_ref clause in + (TraitRef clause, Some clause)) + | TraitRef { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } -> + (* We can't simplify the id *yet* : we will simplify it when projecting. + However, we have an implementation to return *) + (* Normalize the generics *) + let generics = ctx_normalize_generic_args ctx generics in + let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in + let trait_ref : 'r T.trait_ref = + { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } + in + (TraitRef trait_ref, Some trait_ref) + | TraitRef trait_ref -> + (* The trait instance id necessarily refers to a local sub-clause. We + can't project over it and can only peel off the [TraitRef] wrapper *) + assert (trait_instance_id_is_local_clause trait_ref.trait_id); + assert (trait_ref.generics = TypesUtils.mk_empty_generic_args); + (trait_ref.trait_id, None) + | UnknownTrait _ -> + (* This is actually an error case *) + (id, None) + +and ctx_normalize_generic_args (ctx : 'r norm_ctx) + (generics : 'r T.generic_args) : 'r T.generic_args = + let { T.regions; types; const_generics; trait_refs } = generics in + let types = List.map (ctx_normalize_ty ctx) types in + let trait_refs = List.map (ctx_normalize_trait_ref ctx) trait_refs in + { T.regions; types; const_generics; trait_refs } + +and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) : + 'r T.trait_ref = + let { T.trait_id; generics; trait_decl_ref } = trait_ref in + let trait_id, _ = ctx_normalize_trait_instance_id ctx trait_id in + let generics = ctx_normalize_generic_args ctx generics in + let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in + { T.trait_id; generics; trait_decl_ref } + +(* Not sure this one is really necessary *) +and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx) + (trait_decl_ref : 'r T.trait_decl_ref) : 'r T.trait_decl_ref = + let { T.trait_decl_id; decl_generics } = trait_decl_ref in + let decl_generics = ctx_normalize_generic_args ctx decl_generics in + { T.trait_decl_id; decl_generics } + +let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty = + let get_ty_repr x = C.RTraitTypeRefMap.find_opt x ctx.norm_trait_rtypes in + let ctx : T.RegionId.id T.region norm_ctx = + { + ctx; + get_ty_repr; + convert_ety = TypesUtils.ety_no_regions_to_rty; + convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref; + } + in + ctx_normalize_ty ctx ty + +let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety = + let get_ty_repr x = C.ETraitTypeRefMap.find_opt x ctx.norm_trait_etypes in + let ctx : T.erased_region norm_ctx = + { + ctx; + get_ty_repr; + convert_ety = (fun x -> x); + convert_etrait_ref = (fun x -> x); + } + in + ctx_normalize_ty ctx ty (** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *) let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx) @@ -33,7 +277,7 @@ let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx) in List.map (fun (variant_id, types) -> - (variant_id, List.map (ctx_normalize_type ctx) types)) + (variant_id, List.map (ctx_normalize_rty ctx) types)) res (** Same as [type_decl_get_instantiated_field_rtypes] but normalizes the types *) @@ -43,7 +287,7 @@ let type_decl_get_inst_norm_field_rtypes (ctx : C.eval_ctx) (def : T.type_decl) let types = Subst.type_decl_get_instantiated_field_rtypes def opt_variant_id generics in - List.map (ctx_normalize_type ctx) types + List.map (ctx_normalize_rty ctx) types (** Same as [ctx_adt_value_get_instantiated_field_rtypes] but normalizes the types *) let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx) @@ -52,7 +296,7 @@ let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx) let types = Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id generics in - List.map (ctx_normalize_type ctx) types + List.map (ctx_normalize_rty ctx) types (** Same as [ctx_adt_value_get_instantiated_field_etypes] but normalizes the types *) let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl) @@ -61,7 +305,7 @@ let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl) let types = Subst.type_decl_get_instantiated_field_etypes def opt_variant_id generics in - List.map (ctx_normalize_type ctx) types + List.map (ctx_normalize_ety ctx) types (** Same as [ctx_adt_get_instantiated_field_etypes] but normalizes the types *) let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx) @@ -71,7 +315,7 @@ let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx) Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id generics in - List.map (ctx_normalize_type ctx) types + List.map (ctx_normalize_ety ctx) types (** Same as [substitute_signature] but normalizes the types *) let ctx_subst_norm_signature (ctx : C.eval_ctx) @@ -86,6 +330,6 @@ let ctx_subst_norm_signature (ctx : C.eval_ctx) sg in let { A.regions_hierarchy; inputs; output } = sg in - let inputs = List.map (ctx_normalize_type ctx) inputs in - let output = ctx_normalize_type ctx output in + let inputs = List.map (ctx_normalize_rty ctx) inputs in + let output = ctx_normalize_rty ctx output in { regions_hierarchy; inputs; output } diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml index 2d396924..0719364e 100644 --- a/compiler/Contexts.ml +++ b/compiler/Contexts.ml @@ -270,6 +270,37 @@ type decls_ctx = { } [@@deriving show] +(** A reference to a trait associated type *) +type 'r trait_type_ref = { trait_ref : 'r trait_ref; type_name : string } +[@@deriving show, ord] + +type etrait_type_ref = erased_region trait_type_ref [@@deriving show, ord] + +type rtrait_type_ref = Types.RegionId.id Types.region trait_type_ref +[@@deriving show, ord] + +(* TODO: correctly use the functors so as not to have a duplication below *) +module ETraitTypeRefOrd = struct + type t = etrait_type_ref + + let compare = compare_etrait_type_ref + let to_string = show_etrait_type_ref + let pp_t = pp_etrait_type_ref + let show_t = show_etrait_type_ref +end + +module RTraitTypeRefOrd = struct + type t = rtrait_type_ref + + let compare = compare_rtrait_type_ref + let to_string = show_rtrait_type_ref + let pp_t = pp_rtrait_type_ref + let show_t = show_rtrait_type_ref +end + +module ETraitTypeRefMap = Collections.MakeMap (ETraitTypeRefOrd) +module RTraitTypeRefMap = Collections.MakeMap (RTraitTypeRefOrd) + (** Evaluation context *) type eval_ctx = { type_context : type_context; @@ -285,6 +316,18 @@ type eval_ctx = { can be symbolic values or concrete values (in the latter case: if we run in interpreter mode) *) trait_clauses : etrait_ref list; + norm_trait_etypes : ety ETraitTypeRefMap.t; + (** The normalized trait types (a map from trait types to their representatives). + Note that this doesn't support account higher-order types. *) + norm_trait_rtypes : rty RTraitTypeRefMap.t; + (** We need this because we manipulate two kinds of types. + Note that we actually forbid regions from appearing both in the trait + references and in the constraints given to the associated types, + meaning that we don't have to worry about mismatches due to changes + in region ids. + + TODO: how not to duplicate? + *) env : env; ended_regions : RegionId.Set.t; } diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml index a34a7d06..045ba9d8 100644 --- a/compiler/InterpreterLoopsJoinCtxs.ml +++ b/compiler/InterpreterLoopsJoinCtxs.ml @@ -561,6 +561,8 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) const_generic_vars; const_generic_vars_map; trait_clauses; + norm_trait_etypes; + norm_trait_rtypes; env = _; ended_regions = ended_regions0; } = @@ -577,6 +579,8 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) const_generic_vars = _; const_generic_vars_map = _; trait_clauses = _; + norm_trait_etypes = _; + norm_trait_rtypes = _; env = _; ended_regions = ended_regions1; } = @@ -595,6 +599,8 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx) const_generic_vars; const_generic_vars_map; trait_clauses; + norm_trait_etypes; + norm_trait_rtypes; env; ended_regions; } diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml index d38f8b95..3fb07956 100644 --- a/compiler/InterpreterStatements.ml +++ b/compiler/InterpreterStatements.ml @@ -326,7 +326,7 @@ let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id) Subst.erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self sg.output in - Assoc.ctx_normalize_type ctx ty + Assoc.ctx_normalize_ety ctx ty let move_return_value (config : C.config) (pop_return_value : bool) (cf : V.typed_value option -> m_fun) : m_fun = diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml index 64e7716a..fe88faea 100644 --- a/compiler/Substitute.ml +++ b/compiler/Substitute.ml @@ -39,13 +39,22 @@ let ty_substitute_visitor (subst : ('r1, 'r2) subst) = method! visit_Self _ = subst.tr_self end -(** Substitute types variables and regions in a type. *) +(** Substitute types variables and regions in a type. + + **IMPORTANT**: this doesn't normalize the type. + *) let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty = let visitor = ty_substitute_visitor subst in visitor#visit_ty () ty +(** **IMPORTANT**: this doesn't normalize the trait ref. *) +let trait_ref_substitute (subst : ('r1, 'r2) subst) (tr : 'r1 T.trait_ref) : + 'r2 T.trait_ref = + let visitor = ty_substitute_visitor subst in + visitor#visit_trait_ref () tr + (** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *) -let erase_regions (ty : T.rty) : T.ety = +let erase_regions (ty : 'r T.ty) : T.ety = let subst = { r_subst = (fun _ -> T.Erased); @@ -169,8 +178,9 @@ let make_trait_subst_from_clauses (clauses : T.trait_clause list) trs let make_subst_from_generics (params : T.generic_params) - (args : 'r T.generic_args) (tr_self : 'r T.trait_instance_id) : - (T.region_var_id T.region, 'r) subst = + (args : 'r T.region T.generic_args) + (tr_self : 'r T.region T.trait_instance_id) : + (T.region_var_id T.region, 'r T.region) subst = let r_subst = make_region_subst_from_vars params.T.regions args.T.regions in let ty_subst = make_type_subst_from_vars params.T.types args.T.types in let cg_subst = @@ -182,6 +192,24 @@ let make_subst_from_generics (params : T.generic_params) in { r_subst; ty_subst; cg_subst; tr_subst; tr_self } +let make_subst_from_generics_no_regions : + 'r. + T.generic_params -> + 'r T.generic_args -> + 'r T.trait_instance_id -> + (T.region_var_id T.region, 'r) subst = + fun params args tr_self -> + let r_subst _ = raise (Failure "Unexpected region") in + let ty_subst = make_type_subst_from_vars params.T.types args.T.types in + let cg_subst = + make_const_generic_subst_from_vars params.T.const_generics + args.T.const_generics + in + let tr_subst = + make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs + in + { r_subst; ty_subst; cg_subst; tr_subst; tr_self } + let make_esubst_from_generics (params : T.generic_params) (generics : T.egeneric_args) (tr_self : T.etrait_instance_id) = let r_subst _ = T.Erased in -- cgit v1.2.3