summaryrefslogtreecommitdiff
path: root/compiler/Substitute.ml
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--compiler/Substitute.ml120
1 files changed, 84 insertions, 36 deletions
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index a05b2c5a..e28f005d 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -11,6 +11,9 @@ open Contexts
type subst = {
r_subst : region -> region;
+ (** Remark: this might be called with bound regions with a negative
+ DeBruijn index. A negative DeBruijn index means that the region
+ is locally bound. *)
ty_subst : TypeVarId.id -> ty;
cg_subst : ConstGenericVarId.id -> const_generic;
(** Substitution from *local* trait clause to trait instance *)
@@ -19,11 +22,35 @@ type subst = {
tr_self : trait_instance_id;
}
+let empty_subst : subst =
+ {
+ r_subst = (fun x -> x);
+ ty_subst = (fun id -> TVar id);
+ cg_subst = (fun id -> CgVar id);
+ tr_subst = (fun id -> Clause id);
+ tr_self = Self;
+ }
+
let st_substitute_visitor (subst : subst) =
- object
+ object (self)
inherit [_] map_statement
- method! visit_region _ r = subst.r_subst r
- method! visit_TVar _ id = subst.ty_subst id
+ method! visit_region (subst : subst) r = subst.r_subst r
+
+ (** We need to properly handle the DeBruijn indices *)
+ method! visit_TArrow subst regions inputs output =
+ (* Decrement the DeBruijn indices before calling the substitution *)
+ let r_subst r =
+ match r with
+ | RBVar (db, rid) -> subst.r_subst (RBVar (db - 1, rid))
+ | _ -> subst.r_subst r
+ in
+ let subst = { subst with r_subst } in
+ (* Note that we ignore the bound regions variables *)
+ let inputs = List.map (self#visit_ty subst) inputs in
+ let output = self#visit_ty subst output in
+ TArrow (regions, inputs, output)
+
+ method! visit_TVar (subst : subst) id = subst.ty_subst id
method! visit_type_var_id _ _ =
(* We should never get here because we reimplemented [visit_TypeVar] *)
@@ -35,8 +62,8 @@ let st_substitute_visitor (subst : subst) =
(* We should never get here because we reimplemented [visit_Var] *)
raise (Failure "Unexpected")
- method! visit_Clause _ id = subst.tr_subst id
- method! visit_Self _ = subst.tr_self
+ method! visit_Clause (subst : subst) id = subst.tr_subst id
+ method! visit_Self (subst : subst) = subst.tr_self
end
(** Substitute types variables and regions in a type.
@@ -45,27 +72,27 @@ let st_substitute_visitor (subst : subst) =
*)
let ty_substitute (subst : subst) (ty : ty) : ty =
let visitor = st_substitute_visitor subst in
- visitor#visit_ty () ty
+ visitor#visit_ty subst ty
(** **IMPORTANT**: this doesn't normalize the types. *)
let trait_ref_substitute (subst : subst) (tr : trait_ref) : trait_ref =
let visitor = st_substitute_visitor subst in
- visitor#visit_trait_ref () tr
+ visitor#visit_trait_ref subst tr
(** **IMPORTANT**: this doesn't normalize the types. *)
let trait_instance_id_substitute (subst : subst) (tr : trait_instance_id) :
trait_instance_id =
let visitor = st_substitute_visitor subst in
- visitor#visit_trait_instance_id () tr
+ visitor#visit_trait_instance_id subst tr
(** **IMPORTANT**: this doesn't normalize the types. *)
let generic_args_substitute (subst : subst) (g : generic_args) : generic_args =
let visitor = st_substitute_visitor subst in
- visitor#visit_generic_args () g
+ visitor#visit_generic_args subst g
let predicates_substitute (subst : subst) (p : predicates) : predicates =
let visitor = st_substitute_visitor subst in
- visitor#visit_predicates () p
+ visitor#visit_predicates subst p
let erase_regions_subst : subst =
{
@@ -96,26 +123,40 @@ let generic_args_erase_regions (tr : generic_args) : generic_args =
TODO: simplify? we only need the subst [RegionVarId.id -> RegionId.id]
*)
-let fresh_regions_with_substs (region_vars : region_var list) :
- RegionId.id list * (RegionId.id -> RegionId.id) * (region -> region) =
+let fresh_regions_with_substs ~(fail_if_not_found : bool)
+ (region_vars : RegionVarId.id list) :
+ RegionId.id list
+ * (RegionVarId.id -> RegionId.id option)
+ * (region -> region) =
(* Generate fresh regions *)
let fresh_region_ids = List.map (fun _ -> fresh_region_id ()) region_vars in
(* Generate the map from region var ids to regions *)
let ls = List.combine region_vars fresh_region_ids in
- let rid_map =
- List.fold_left
- (fun mp ((k : region_var), v) -> RegionId.Map.add k.index v mp)
- RegionId.Map.empty ls
- in
+ let rid_map = RegionVarId.Map.of_list ls in
(* Generate the substitution from region var id to region *)
- let rid_subst id = RegionId.Map.find id rid_map in
+ let rid_subst id = RegionVarId.Map.find_opt id rid_map in
(* Generate the substitution from region to region *)
let r_subst (r : region) =
- match r with RStatic | RErased -> r | RVar id -> RVar (rid_subst id)
+ match r with
+ | RStatic | RErased | RFVar _ -> r
+ | RBVar (bdid, id) ->
+ if bdid = 0 then
+ match rid_subst id with
+ | None -> if fail_if_not_found then raise Not_found else r
+ | Some r -> RFVar r
+ else r
in
(* Return *)
(fresh_region_ids, rid_subst, r_subst)
+let fresh_regions_with_substs_from_vars ~(fail_if_not_found : bool)
+ (region_vars : region_var list) :
+ RegionId.id list
+ * (RegionVarId.id -> RegionId.id option)
+ * (region -> region) =
+ fresh_regions_with_substs ~fail_if_not_found
+ (List.map (fun (r : region_var) -> r.index) region_vars)
+
(** Erase the regions in a type and perform a substitution *)
let erase_regions_substitute_types (ty_subst : TypeVarId.id -> ty)
(cg_subst : ConstGenericVarId.id -> const_generic)
@@ -127,16 +168,21 @@ let erase_regions_substitute_types (ty_subst : TypeVarId.id -> ty)
(** Create a region substitution from a list of region variable ids and a list of
regions (with which to substitute the region variable ids *)
-let make_region_subst (var_ids : RegionId.id list) (regions : region list) :
+let make_region_subst (var_ids : RegionVarId.id list) (regions : region list) :
region -> region =
let ls = List.combine var_ids regions in
let mp =
List.fold_left
- (fun mp (k, v) -> RegionId.Map.add k v mp)
- RegionId.Map.empty ls
+ (fun mp (k, v) -> RegionVarId.Map.add k v mp)
+ RegionVarId.Map.empty ls
in
fun r ->
- match r with RStatic | RErased -> r | RVar id -> RegionId.Map.find id mp
+ match r with
+ | RStatic | RErased -> r
+ | RFVar _ -> raise (Failure "Unexpected")
+ | RBVar (bdid, id) ->
+ (* Only substitute the bound regions with DeBruijn index equal to 0 *)
+ if bdid = 0 then RegionVarId.Map.find id mp else r
let make_region_subst_from_vars (vars : region_var list) (regions : region list)
: region -> region =
@@ -298,27 +344,27 @@ let ctx_adt_value_get_instantiated_field_types (ctx : eval_ctx)
(** Apply a type substitution to a place *)
let place_substitute (subst : subst) (p : place) : place =
(* There is in fact nothing to do *)
- (st_substitute_visitor subst)#visit_place () p
+ (st_substitute_visitor subst)#visit_place subst p
(** Apply a type substitution to an operand *)
let operand_substitute (subst : subst) (op : operand) : operand =
- (st_substitute_visitor subst)#visit_operand () op
+ (st_substitute_visitor subst)#visit_operand subst op
(** Apply a type substitution to an rvalue *)
let rvalue_substitute (subst : subst) (rv : rvalue) : rvalue =
- (st_substitute_visitor subst)#visit_rvalue () rv
+ (st_substitute_visitor subst)#visit_rvalue subst rv
(** Apply a type substitution to an assertion *)
let assertion_substitute (subst : subst) (a : assertion) : assertion =
- (st_substitute_visitor subst)#visit_assertion () a
+ (st_substitute_visitor subst)#visit_assertion subst a
(** Apply a type substitution to a call *)
let call_substitute (subst : subst) (call : call) : call =
- (st_substitute_visitor subst)#visit_call () call
+ (st_substitute_visitor subst)#visit_call subst call
(** Apply a type substitution to a statement *)
let statement_substitute (subst : subst) (st : statement) : statement =
- (st_substitute_visitor subst)#visit_statement () st
+ (st_substitute_visitor subst)#visit_statement subst st
(** Apply a type substitution to a function body. Return the local variables
and the body. *)
@@ -336,9 +382,9 @@ let trait_type_constraint_substitute (subst : subst)
(ttc : trait_type_constraint) : trait_type_constraint =
let { trait_ref; generics; type_name; ty } = ttc in
let visitor = st_substitute_visitor subst in
- let trait_ref = visitor#visit_trait_ref () trait_ref in
- let generics = visitor#visit_generic_args () generics in
- let ty = visitor#visit_ty () ty in
+ let trait_ref = visitor#visit_trait_ref subst trait_ref in
+ let generics = visitor#visit_generic_args subst generics in
+ let ty = visitor#visit_ty subst ty in
{ trait_ref; generics; type_name; ty }
(** Substitute a function signature, together with the regions hierarchy
@@ -347,18 +393,20 @@ let trait_type_constraint_substitute (subst : subst)
**IMPORTANT:** this function doesn't normalize the types.
*)
let substitute_signature (asubst : RegionGroupId.id -> AbstractionId.id)
- (r_subst : RegionId.id -> RegionId.id) (ty_subst : TypeVarId.id -> ty)
+ (r_subst : RegionVarId.id -> RegionId.id) (ty_subst : TypeVarId.id -> ty)
(cg_subst : ConstGenericVarId.id -> const_generic)
(tr_subst : TraitClauseId.id -> trait_instance_id)
(tr_self : trait_instance_id) (sg : fun_sig)
- (regions_hierarchy : region_groups) : inst_fun_sig =
+ (regions_hierarchy : region_var_groups) : inst_fun_sig =
let r_subst' (r : region) : region =
- match r with RStatic | RErased -> r | RVar rid -> RVar (r_subst rid)
+ match r with
+ | RStatic | RErased | RFVar _ -> r
+ | RBVar (bdid, rid) -> if bdid = 0 then RFVar (r_subst rid) else r
in
let subst = { r_subst = r_subst'; ty_subst; cg_subst; tr_subst; tr_self } in
let inputs = List.map (ty_substitute subst) sg.inputs in
let output = ty_substitute subst sg.output in
- let subst_region_group (rg : region_group) : abs_region_group =
+ let subst_region_group (rg : region_var_group) : abs_region_group =
let id = asubst rg.id in
let regions = List.map r_subst rg.regions in
let parents = List.map asubst rg.parents in