summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
authorSon Ho2021-11-25 14:35:43 +0100
committerSon Ho2021-11-25 14:35:43 +0100
commitbc910f7aef5dac064f3db47ce601b6ef78c14ff5 (patch)
tree10d23d125613f4e70eae1d5b4962d0675be708f4 /src
parentf8b7ed1a4e75ae80c5cfe3859d3cadc9dc9c5c40 (diff)
Implement various substitution functions
Diffstat (limited to '')
-rw-r--r--src/Identifiers.ml27
-rw-r--r--src/Substitute.ml96
2 files changed, 73 insertions, 50 deletions
diff --git a/src/Identifiers.ml b/src/Identifiers.ml
index 6f74e062..23c887c4 100644
--- a/src/Identifiers.ml
+++ b/src/Identifiers.ml
@@ -22,6 +22,8 @@ module type Id = sig
val to_int : id -> int
+ val of_int : id -> int
+
val empty_vector : 'a vector
val vector_to_list : 'a vector -> 'a list
@@ -35,15 +37,23 @@ module type Id = sig
val nth_opt : 'a vector -> id -> 'a option
val update_nth : 'a vector -> id -> 'a -> 'a vector
+ (** Update the nth element of the vector.
+
+ Raises [Invalid_argument] if the identifier is out of range.
+ *)
val iter : ('a -> unit) -> 'a vector -> unit
val map : ('a -> 'b) -> 'a vector -> 'b vector
+ val mapi : (id -> 'a -> 'b) -> 'a vector -> 'b vector
+
val for_all : ('a -> bool) -> 'a vector -> bool
val exists : ('a -> bool) -> 'a vector -> bool
+ module Ord : Map.OrderedType with type t = id
+
module Set : Set.S with type elt = id
val set_to_string : Set.t -> string
@@ -84,6 +94,8 @@ module IdGen () : Id = struct
let to_int x = x
+ let of_int x = x
+
let empty_vector = []
let vector_to_list v = v
@@ -98,7 +110,7 @@ module IdGen () : Id = struct
let rec update_nth vec id v =
match (vec, id) with
- | [], _ -> failwith "Unreachable"
+ | [], _ -> raise (Invalid_argument "Out of range")
| _ :: vec', 0 -> v :: vec'
| x :: vec', _ -> x :: update_nth vec' (id - 1) v
@@ -106,21 +118,20 @@ module IdGen () : Id = struct
let map = List.map
+ let mapi = List.mapi
+
let for_all = List.for_all
let exists = List.exists
- module Set = Set.Make (struct
+ module Ord = struct
type t = id
let compare = compare
- end)
-
- module Map = Map.Make (struct
- type t = id
+ end
- let compare = compare
- end)
+ module Set = Set.Make (Ord)
+ module Map = Map.Make (Ord)
let set_to_string ids =
let ids = Set.fold (fun id ids -> to_string id :: ids) ids [] in
diff --git a/src/Substitute.ml b/src/Substitute.ml
index f8afc4b9..1957dc24 100644
--- a/src/Substitute.ml
+++ b/src/Substitute.ml
@@ -15,58 +15,70 @@ let rec ty_subst (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty)
| Adt (def_id, regions, tys) ->
Adt (def_id, List.map rsubst regions, List.map subst tys)
| Tuple tys -> Tuple (List.map subst tys)
- | Bool | Char | Never | Integer _ | Str -> (* no change *) ty
| Array aty -> Array (subst aty)
| Slice sty -> Slice (subst sty)
| Ref (r, ref_ty, ref_kind) -> Ref (rsubst r, subst ref_ty, ref_kind)
| Assumed (aty, regions, tys) ->
Assumed (aty, List.map rsubst regions, List.map subst tys)
+ (* Below variants: we technically return the same value, but because
+ one has type ['r1 ty] and the other has type ['r2 ty] we need to
+ deconstruct then reconstruct *)
+ | Bool -> Bool
+ | Char -> Char
+ | Never -> Never
+ | Integer int_ty -> Integer int_ty
+ | Str -> Str
-(*(** Works *)
-let ty_subst2 (rsubst : 'r1 -> T.erased_region)
- (tsubst : T.TypeVarId.id -> T.erased_region T.ty) (ty : 'r1 T.ty) :
- T.erased_region T.ty =
+(** Erase the regions in a type and substitute the type variables *)
+let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety)
+ (ty : T.rty) : T.ety =
+ let rsubst (r : T.RegionVarId.id T.region) : T.erased_region = T.Erased in
ty_subst rsubst tsubst ty
-let ty_subst3 (rsubst : int -> T.erased_region)
- (tsubst : T.TypeVarId.id -> T.erased_region T.ty) (ty : int T.ty) :
- T.erased_region T.ty =
- ty_subst rsubst tsubst ty*)
-
-(*(** Doesn't work *)
-let ty_subst3 (rsubst : T.RegionVarId.id T.region -> T.erased_region)
- (tsubst : T.TypeVarId.id -> T.erased_region T.ty)
- (ty : T.RegionVarId.id T.region T.ty) : T.erased_region T.ty =
- ty_subst2 rsubst tsubst ty*)
-
-(*
-(** Erase the regions in a type and substitute the type variables *)
-let erase_regions_substitute_types
- (rsubst : T.RegionVarId.id T.region -> T.erased_region)
- (tsubst : T.TypeVarId.id -> T.erased_region T.ty)
- (t : T.RegionVarId.id T.region T.ty) =
- ty_subst rsubst tsubst t*)
+(** Create a type substitution from a list of type variable ids and a list of
+ types (with which to substitute the type variable ids *)
+let make_type_subst (var_ids : T.TypeVarId.id list) (tys : 'r T.ty list) :
+ T.TypeVarId.id -> 'r T.ty =
+ let ls = List.combine var_ids tys in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> T.TypeVarId.Map.add k v mp)
+ T.TypeVarId.Map.empty ls
+ in
+ fun id -> T.TypeVarId.Map.find id mp
-(*let erase_regions_substitute_types
- (rsubst : T.RegionVarId.id T.region -> T.erased_region)
- (tsubst : T.TypeVarId.id -> T.erased_region T.ty)
- (ty : T.RegionVarId.id T.region T.ty) : T.erased_region T.ty =
- ty_subst rsubst tsubst ty*)
+(** Retrieve the list of fields for the given variant of a [type_def].
-(*let erase_regions_substitute_types
- (tsubst : T.TypeVarId.id -> T.erased_region T.ty)
- (ty : T.RegionVarId.id T.region T.ty) : T.erased_region T.ty =
- let rsubst (r : T.RegionVarId.id T.region) : T.erased_region = T.Erased in
- ty_subst rsubst tsubst ty*)
+ Raises [Invalid_argument] if the arguments are incorrect.
-(*let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety)
- (ty : T.rty) : T.ety =
- let rsubst (r : T.RegionVarId.id T.region) : T.erased_region = T.Erased in
- ty_subst rsubst tsubst ty*)
+ TODO: move
+ *)
+let type_def_get_fields (def : T.type_def)
+ (opt_variant_id : T.VariantId.id option) : T.field T.FieldId.vector =
+ match (def.kind, opt_variant_id) with
+ | Enum variants, Some variant_id ->
+ (T.VariantId.nth variants variant_id).fields
+ | Struct fields, None -> fields
+ | _ ->
+ raise
+ (Invalid_argument
+ "The variant id should be [Some] if and only if the definition is \
+ an enumeration")
-(*(** Instantiate the type variables in an ADT definition, and return the list
+(** Instantiate the type variables in an ADT definition, and return the list
of types of the fields for the chosen variant *)
-let get_adt_instantiated_field_type
- (def : T.type_def)
- (opt_variant_id : T.VariantId.id option) (types : T.ety list) :
- ety list =*)
+let type_def_get_instantiated_field_type (def : T.type_def)
+ (opt_variant_id : T.VariantId.id option) (types : T.ety list) :
+ T.ety T.FieldId.vector =
+ (* let indices = List.mapi (fun i _ -> TypeVarId.of_int i) def.type_params in*)
+ let ty_subst =
+ make_type_subst
+ (List.map
+ (fun x -> x.T.tv_index)
+ (T.TypeVarId.vector_to_list def.T.type_params))
+ types
+ in
+ let fields = type_def_get_fields def opt_variant_id in
+ T.FieldId.map
+ (fun f -> erase_regions_substitute_types ty_subst f.T.field_ty)
+ fields