diff options
-rw-r--r-- | src/Identifiers.ml | 27 | ||||
-rw-r--r-- | src/Substitute.ml | 96 |
2 files changed, 73 insertions, 50 deletions
diff --git a/src/Identifiers.ml b/src/Identifiers.ml index 6f74e062..23c887c4 100644 --- a/src/Identifiers.ml +++ b/src/Identifiers.ml @@ -22,6 +22,8 @@ module type Id = sig val to_int : id -> int + val of_int : id -> int + val empty_vector : 'a vector val vector_to_list : 'a vector -> 'a list @@ -35,15 +37,23 @@ module type Id = sig val nth_opt : 'a vector -> id -> 'a option val update_nth : 'a vector -> id -> 'a -> 'a vector + (** Update the nth element of the vector. + + Raises [Invalid_argument] if the identifier is out of range. + *) val iter : ('a -> unit) -> 'a vector -> unit val map : ('a -> 'b) -> 'a vector -> 'b vector + val mapi : (id -> 'a -> 'b) -> 'a vector -> 'b vector + val for_all : ('a -> bool) -> 'a vector -> bool val exists : ('a -> bool) -> 'a vector -> bool + module Ord : Map.OrderedType with type t = id + module Set : Set.S with type elt = id val set_to_string : Set.t -> string @@ -84,6 +94,8 @@ module IdGen () : Id = struct let to_int x = x + let of_int x = x + let empty_vector = [] let vector_to_list v = v @@ -98,7 +110,7 @@ module IdGen () : Id = struct let rec update_nth vec id v = match (vec, id) with - | [], _ -> failwith "Unreachable" + | [], _ -> raise (Invalid_argument "Out of range") | _ :: vec', 0 -> v :: vec' | x :: vec', _ -> x :: update_nth vec' (id - 1) v @@ -106,21 +118,20 @@ module IdGen () : Id = struct let map = List.map + let mapi = List.mapi + let for_all = List.for_all let exists = List.exists - module Set = Set.Make (struct + module Ord = struct type t = id let compare = compare - end) - - module Map = Map.Make (struct - type t = id + end - let compare = compare - end) + module Set = Set.Make (Ord) + module Map = Map.Make (Ord) let set_to_string ids = let ids = Set.fold (fun id ids -> to_string id :: ids) ids [] in diff --git a/src/Substitute.ml b/src/Substitute.ml index f8afc4b9..1957dc24 100644 --- a/src/Substitute.ml +++ b/src/Substitute.ml @@ -15,58 +15,70 @@ let rec ty_subst (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty) | Adt (def_id, regions, tys) -> Adt (def_id, List.map rsubst regions, List.map subst tys) | Tuple tys -> Tuple (List.map subst tys) - | Bool | Char | Never | Integer _ | Str -> (* no change *) ty | Array aty -> Array (subst aty) | Slice sty -> Slice (subst sty) | Ref (r, ref_ty, ref_kind) -> Ref (rsubst r, subst ref_ty, ref_kind) | Assumed (aty, regions, tys) -> Assumed (aty, List.map rsubst regions, List.map subst tys) + (* Below variants: we technically return the same value, but because + one has type ['r1 ty] and the other has type ['r2 ty] we need to + deconstruct then reconstruct *) + | Bool -> Bool + | Char -> Char + | Never -> Never + | Integer int_ty -> Integer int_ty + | Str -> Str -(*(** Works *) -let ty_subst2 (rsubst : 'r1 -> T.erased_region) - (tsubst : T.TypeVarId.id -> T.erased_region T.ty) (ty : 'r1 T.ty) : - T.erased_region T.ty = +(** Erase the regions in a type and substitute the type variables *) +let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety) + (ty : T.rty) : T.ety = + let rsubst (r : T.RegionVarId.id T.region) : T.erased_region = T.Erased in ty_subst rsubst tsubst ty -let ty_subst3 (rsubst : int -> T.erased_region) - (tsubst : T.TypeVarId.id -> T.erased_region T.ty) (ty : int T.ty) : - T.erased_region T.ty = - ty_subst rsubst tsubst ty*) - -(*(** Doesn't work *) -let ty_subst3 (rsubst : T.RegionVarId.id T.region -> T.erased_region) - (tsubst : T.TypeVarId.id -> T.erased_region T.ty) - (ty : T.RegionVarId.id T.region T.ty) : T.erased_region T.ty = - ty_subst2 rsubst tsubst ty*) - -(* -(** Erase the regions in a type and substitute the type variables *) -let erase_regions_substitute_types - (rsubst : T.RegionVarId.id T.region -> T.erased_region) - (tsubst : T.TypeVarId.id -> T.erased_region T.ty) - (t : T.RegionVarId.id T.region T.ty) = - ty_subst rsubst tsubst t*) +(** Create a type substitution from a list of type variable ids and a list of + types (with which to substitute the type variable ids *) +let make_type_subst (var_ids : T.TypeVarId.id list) (tys : 'r T.ty list) : + T.TypeVarId.id -> 'r T.ty = + let ls = List.combine var_ids tys in + let mp = + List.fold_left + (fun mp (k, v) -> T.TypeVarId.Map.add k v mp) + T.TypeVarId.Map.empty ls + in + fun id -> T.TypeVarId.Map.find id mp -(*let erase_regions_substitute_types - (rsubst : T.RegionVarId.id T.region -> T.erased_region) - (tsubst : T.TypeVarId.id -> T.erased_region T.ty) - (ty : T.RegionVarId.id T.region T.ty) : T.erased_region T.ty = - ty_subst rsubst tsubst ty*) +(** Retrieve the list of fields for the given variant of a [type_def]. -(*let erase_regions_substitute_types - (tsubst : T.TypeVarId.id -> T.erased_region T.ty) - (ty : T.RegionVarId.id T.region T.ty) : T.erased_region T.ty = - let rsubst (r : T.RegionVarId.id T.region) : T.erased_region = T.Erased in - ty_subst rsubst tsubst ty*) + Raises [Invalid_argument] if the arguments are incorrect. -(*let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety) - (ty : T.rty) : T.ety = - let rsubst (r : T.RegionVarId.id T.region) : T.erased_region = T.Erased in - ty_subst rsubst tsubst ty*) + TODO: move + *) +let type_def_get_fields (def : T.type_def) + (opt_variant_id : T.VariantId.id option) : T.field T.FieldId.vector = + match (def.kind, opt_variant_id) with + | Enum variants, Some variant_id -> + (T.VariantId.nth variants variant_id).fields + | Struct fields, None -> fields + | _ -> + raise + (Invalid_argument + "The variant id should be [Some] if and only if the definition is \ + an enumeration") -(*(** Instantiate the type variables in an ADT definition, and return the list +(** Instantiate the type variables in an ADT definition, and return the list of types of the fields for the chosen variant *) -let get_adt_instantiated_field_type - (def : T.type_def) - (opt_variant_id : T.VariantId.id option) (types : T.ety list) : - ety list =*) +let type_def_get_instantiated_field_type (def : T.type_def) + (opt_variant_id : T.VariantId.id option) (types : T.ety list) : + T.ety T.FieldId.vector = + (* let indices = List.mapi (fun i _ -> TypeVarId.of_int i) def.type_params in*) + let ty_subst = + make_type_subst + (List.map + (fun x -> x.T.tv_index) + (T.TypeVarId.vector_to_list def.T.type_params)) + types + in + let fields = type_def_get_fields def opt_variant_id in + T.FieldId.map + (fun f -> erase_regions_substitute_types ty_subst f.T.field_ty) + fields |