summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-09-01 14:43:11 +0200
committerSon Ho2023-09-01 14:43:11 +0200
commit1e39985a44646f1c352def6e4b29365a113a5dee (patch)
tree2d64633f1ae8d2bd941f085ad2dbada3ef7896d8
parent06360698561019d7f480dcb4263e2099d9a03ca5 (diff)
Compute the normalized trait types maps and update Interpreter
-rw-r--r--compiler/AssociatedTypes.ml135
-rw-r--r--compiler/Contexts.ml2
-rw-r--r--compiler/Interpreter.ml88
-rw-r--r--compiler/InterpreterLoopsJoinCtxs.ml3
-rw-r--r--compiler/LlbcAst.ml1
-rw-r--r--compiler/Substitute.ml47
6 files changed, 206 insertions, 70 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
index 8e08db6e..07ab70bd 100644
--- a/compiler/AssociatedTypes.ml
+++ b/compiler/AssociatedTypes.ml
@@ -14,10 +14,86 @@ module A = LlbcAst
module C = Contexts
module Subst = Substitute
module L = Logging
+module UF = UnionFind
(** The local logger *)
let log = L.associated_types_log
+let trait_type_ref_substitute (subst : ('r, 'r1) Subst.subst)
+ (r : 'r C.trait_type_ref) : 'r1 C.trait_type_ref =
+ let { C.trait_ref; type_name } = r in
+ let trait_ref = Subst.trait_ref_substitute subst trait_ref in
+ { C.trait_ref; type_name }
+
+module RTyOrd = struct
+ type t = T.rty
+
+ let compare = T.compare_rty
+ let to_string = T.show_rty
+ let pp_t = T.pp_rty
+ let show_t = T.show_rty
+end
+
+module RTyMap = Collections.MakeMap (RTyOrd)
+
+(** Compute the representative classes of trait associated types, for normalization *)
+let compute_norm_trait_types_from_preds
+ (trait_type_constraints : T.rtrait_type_constraint list) :
+ T.ety C.ETraitTypeRefMap.t * T.rty C.RTraitTypeRefMap.t =
+ (* Compute a union-find structure by recursively exploring the predicates and clauses *)
+ let norm : T.rty UF.elem RTyMap.t ref = ref RTyMap.empty in
+ let get_ref (ty : T.rty) : T.rty UF.elem =
+ match RTyMap.find_opt ty !norm with
+ | Some r -> r
+ | None ->
+ let r = UF.make ty in
+ norm := RTyMap.add ty r !norm;
+ r
+ in
+ let add_trait_type_constraint (c : T.rtrait_type_constraint) =
+ let trait_ty = T.TraitType (c.trait_ref, c.generics, c.type_name) in
+ let trait_ty_ref = get_ref trait_ty in
+ let ty_ref = get_ref c.ty in
+ let new_repr = UF.get ty_ref in
+ let merged = UF.union trait_ty_ref ty_ref in
+ (* Not sure the set operation is necessary, but I want to control which
+ representative is chosen *)
+ UF.set merged new_repr
+ in
+ (* Explore the local predicates *)
+ List.iter add_trait_type_constraint trait_type_constraints;
+ (* TODO: explore the local clauses *)
+ (* Compute the norm maps *)
+ let rbindings =
+ List.map (fun (k, v) -> (k, UF.get v)) (RTyMap.bindings !norm)
+ in
+ (* Filter the keys to keep only the trait type aliases *)
+ let rbindings =
+ List.filter_map
+ (fun (k, v) ->
+ match k with
+ | T.TraitType (trait_ref, generics, type_name) ->
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ Some ({ C.trait_ref; type_name }, v)
+ | _ -> None)
+ rbindings
+ in
+ let ebindings =
+ List.map
+ (fun (k, v) ->
+ ( trait_type_ref_substitute Subst.erase_regions_subst k,
+ Subst.erase_regions v ))
+ rbindings
+ in
+ (C.ETraitTypeRefMap.of_list ebindings, C.RTraitTypeRefMap.of_list rbindings)
+
+let ctx_add_norm_trait_types_from_preds (ctx : C.eval_ctx)
+ (trait_type_constraints : T.rtrait_type_constraint list) : C.eval_ctx =
+ let norm_trait_etypes, norm_trait_rtypes =
+ compute_norm_trait_types_from_preds trait_type_constraints
+ in
+ { ctx with C.norm_trait_etypes; norm_trait_rtypes }
+
(** A trait instance id refers to a local clause if it only uses the variants:
[Self], [Clause], [ParentClause], [ItemClause] *)
let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool =
@@ -244,29 +320,41 @@ and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx)
let decl_generics = ctx_normalize_generic_args ctx decl_generics in
{ T.trait_decl_id; decl_generics }
-let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty =
+let ctx_normalize_trait_type_constraint (ctx : 'r norm_ctx)
+ (ttc : 'r T.trait_type_constraint) : 'r T.trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let trait_ref = ctx_normalize_trait_ref ctx trait_ref in
+ let generics = ctx_normalize_generic_args ctx generics in
+ let ty = ctx_normalize_ty ctx ty in
+ { T.trait_ref; generics; type_name; ty }
+
+let mk_rnorm_ctx (ctx : C.eval_ctx) : T.RegionId.id T.region norm_ctx =
let get_ty_repr x = C.RTraitTypeRefMap.find_opt x ctx.norm_trait_rtypes in
- let ctx : T.RegionId.id T.region norm_ctx =
- {
- ctx;
- get_ty_repr;
- convert_ety = TypesUtils.ety_no_regions_to_rty;
- convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
- }
- in
- ctx_normalize_ty ctx ty
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = TypesUtils.ety_no_regions_to_rty;
+ convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
+ }
-let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety =
+let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx =
let get_ty_repr x = C.ETraitTypeRefMap.find_opt x ctx.norm_trait_etypes in
- let ctx : T.erased_region norm_ctx =
- {
- ctx;
- get_ty_repr;
- convert_ety = (fun x -> x);
- convert_etrait_ref = (fun x -> x);
- }
- in
- ctx_normalize_ty ctx ty
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = (fun x -> x);
+ convert_etrait_ref = (fun x -> x);
+ }
+
+let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty =
+ ctx_normalize_ty (mk_rnorm_ctx ctx) ty
+
+let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety =
+ ctx_normalize_ty (mk_enorm_ctx ctx) ty
+
+let ctx_normalize_rtrait_type_constraint (ctx : C.eval_ctx)
+ (ttc : T.rtrait_type_constraint) : T.rtrait_type_constraint =
+ ctx_normalize_trait_type_constraint (mk_rnorm_ctx ctx) ttc
(** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *)
let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx)
@@ -329,7 +417,10 @@ let ctx_subst_norm_signature (ctx : C.eval_ctx)
Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self
sg
in
- let { A.regions_hierarchy; inputs; output } = sg in
+ let { A.regions_hierarchy; inputs; output; trait_type_constraints } = sg in
let inputs = List.map (ctx_normalize_rty ctx) inputs in
let output = ctx_normalize_rty ctx output in
- { regions_hierarchy; inputs; output }
+ let trait_type_constraints =
+ List.map (ctx_normalize_rtrait_type_constraint ctx) trait_type_constraints
+ in
+ { regions_hierarchy; inputs; output; trait_type_constraints }
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index 0719364e..9d22a643 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -40,6 +40,7 @@ type dummy_var_id = DummyVarId.id [@@deriving show, ord]
fn f x : fun_type =
let id = fresh_id () in
...
+ fun () -> ...
let g = f x in // <-- the fresh identifier gets generated here
let x1 = g () in // <-- no fresh generation here
@@ -315,7 +316,6 @@ type eval_ctx = {
(** The map from const generic vars to their values. Those values
can be symbolic values or concrete values (in the latter case:
if we run in interpreter mode) *)
- trait_clauses : etrait_ref list;
norm_trait_etypes : ety ETraitTypeRefMap.t;
(** The normalized trait types (a map from trait types to their representatives).
Note that this doesn't support account higher-order types. *)
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index eb66013d..b5e9fcb9 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -32,10 +32,12 @@ let compute_contexts (m : A.crate) : C.decls_ctx =
let trait_impls_ctx = { C.trait_impls } in
{ C.type_ctx; fun_ctx; global_ctx; trait_decls_ctx; trait_impls_ctx }
+(** **WARNING**: this function doesn't compute the normalized types
+ (for the trait type aliases). This should be computed afterwards.
+ *)
let initialize_eval_context (ctx : C.decls_ctx)
(region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list)
- (const_generic_vars : T.const_generic_var list)
- (trait_clauses : T.etrait_ref list) : C.eval_ctx =
+ (const_generic_vars : T.const_generic_var list) : C.eval_ctx =
C.reset_global_counters ();
let const_generic_vars_map =
T.ConstGenericVarId.Map.of_list
@@ -56,11 +58,53 @@ let initialize_eval_context (ctx : C.decls_ctx)
C.type_vars;
C.const_generic_vars;
C.const_generic_vars_map;
- C.trait_clauses;
+ C.norm_trait_etypes = C.ETraitTypeRefMap.empty (* Empty for now *);
+ C.norm_trait_rtypes = C.RTraitTypeRefMap.empty (* Empty for now *);
C.env = [ C.Frame ];
C.ended_regions = T.RegionId.Set.empty;
}
+(** Instantiate a function signature for a symbolic execution *)
+let symbolic_instantiate_fun_sig (ctx : C.eval_ctx) (fdef : A.fun_decl) :
+ A.inst_fun_sig =
+ let sg = fdef.signature in
+ let tr_self =
+ match fdef.kind with
+ | RegularKind | TraitMethodImpl _ -> T.UnknownTrait __FUNCTION__
+ | TraitMethodDecl _ | TraitMethodProvided _ ->
+ raise (Failure "Unimplemented")
+ in
+ let generics =
+ let { T.regions; types; const_generics; trait_clauses } = sg.generics in
+ let regions = List.map (fun _ -> T.Erased) regions in
+ let types = List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) types in
+ let const_generics =
+ List.map
+ (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
+ const_generics
+ in
+ (* Annoying that we have to generate this substitution here *)
+ let r_subst _ = raise (Failure "Unexpected region") in
+ let ty_subst = Subst.make_type_subst_from_vars sg.generics.types types in
+ let cg_subst =
+ Subst.make_const_generic_subst_from_vars sg.generics.const_generics
+ const_generics
+ in
+ let tr_subst _ = raise (Failure "Unexpected local trait clause") in
+ let subst = { Subst.r_subst; ty_subst; cg_subst; tr_subst; tr_self } in
+ let trait_refs =
+ List.map
+ (fun (c : T.trait_clause) ->
+ let { T.trait_id = trait_decl_id; generics; _ } = c in
+ let generics = Subst.generic_args_substitute subst generics in
+ let trait_decl_ref = { T.trait_decl_id; decl_generics = generics } in
+ { T.trait_id = T.Clause c.clause_id; generics; trait_decl_ref })
+ trait_clauses
+ in
+ { T.regions; types; const_generics; trait_refs }
+ in
+ instantiate_fun_sig ctx generics tr_self sg
+
(** Initialize an evaluation context to execute a function.
Introduces local variables initialized in the following manner:
@@ -94,18 +138,15 @@ let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl)
in
let ctx =
initialize_eval_context ctx region_groups sg.generics.types
- sg.generics.const_generics sg.generics.trait_clauses
+ sg.generics.const_generics
in
(* Instantiate the signature *)
- let type_params =
- List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
- in
- let cg_params =
- List.map
- (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
- sg.const_generic_params
+ let inst_sg = symbolic_instantiate_fun_sig ctx fdef in
+ (* Compute the normalization maps *)
+ let ctx =
+ AssociatedTypes.ctx_add_norm_trait_types_from_preds ctx
+ inst_sg.trait_type_constraints
in
- let inst_sg = instantiate_fun_sig type_params cg_params sg in
(* Create fresh symbolic values for the inputs *)
let input_svs =
List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs
@@ -180,15 +221,7 @@ let evaluate_function_symbolic_synthesize_backward_from_return
* an instantiation of the signature, so that we use fresh
* region ids for the return abstractions. *)
let sg = fdef.signature in
- let type_params =
- List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
- in
- let cg_params =
- List.map
- (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
- sg.const_generic_params
- in
- let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in
+ let ret_inst_sg = symbolic_instantiate_fun_sig ctx fdef in
let ret_rty = ret_inst_sg.output in
(* Move the return value out of the return variable *)
let pop_return_value = is_regular_return in
@@ -362,19 +395,14 @@ let evaluate_function_symbolic_synthesize_backward_from_return
for the synthesis)
- the symbolic AST generated by the symbolic execution
*)
-let evaluate_function_symbolic (synthesize : bool)
- (type_context : C.type_context) (fun_context : C.fun_context)
- (global_context : C.global_context) (fdef : A.fun_decl) :
- V.symbolic_value list * SA.expression option =
+let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx)
+ (fdef : A.fun_decl) : V.symbolic_value list * SA.expression option =
(* Debug *)
let name_to_string () = Print.fun_name_to_string fdef.A.name in
log#ldebug (lazy ("evaluate_function_symbolic: " ^ name_to_string ()));
(* Create the evaluation context *)
- let ctx, input_svs, inst_sg =
- initialize_symbolic_context_for_fun type_context fun_context global_context
- fdef
- in
+ let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in
(* Create the continuation to finish the evaluation *)
let config = C.mk_config C.SymbolicMode in
@@ -518,7 +546,7 @@ module Test = struct
(* Create the evaluation context *)
let decls_ctx = compute_contexts crate in
- let ctx = initialize_eval_context decls_ctx [] [] [] [] in
+ let ctx = initialize_eval_context decls_ctx [] [] [] in
(* Insert the (uninitialized) local variables *)
let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in
diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml
index 045ba9d8..fa44e20e 100644
--- a/compiler/InterpreterLoopsJoinCtxs.ml
+++ b/compiler/InterpreterLoopsJoinCtxs.ml
@@ -560,7 +560,6 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
type_vars;
const_generic_vars;
const_generic_vars_map;
- trait_clauses;
norm_trait_etypes;
norm_trait_rtypes;
env = _;
@@ -578,7 +577,6 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
type_vars = _;
const_generic_vars = _;
const_generic_vars_map = _;
- trait_clauses = _;
norm_trait_etypes = _;
norm_trait_rtypes = _;
env = _;
@@ -598,7 +596,6 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
type_vars;
const_generic_vars;
const_generic_vars_map;
- trait_clauses;
norm_trait_etypes;
norm_trait_rtypes;
env;
diff --git a/compiler/LlbcAst.ml b/compiler/LlbcAst.ml
index f4d26e18..2db859b2 100644
--- a/compiler/LlbcAst.ml
+++ b/compiler/LlbcAst.ml
@@ -11,6 +11,7 @@ type abs_region_groups = (AbstractionId.id, RegionId.id) g_region_groups
(** A function signature, after instantiation *)
type inst_fun_sig = {
regions_hierarchy : abs_region_groups;
+ trait_type_constraints : rtrait_type_constraint list;
inputs : rty list;
output : rty;
}
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index fe88faea..b1680282 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -41,30 +41,35 @@ let ty_substitute_visitor (subst : ('r1, 'r2) subst) =
(** Substitute types variables and regions in a type.
- **IMPORTANT**: this doesn't normalize the type.
+ **IMPORTANT**: this doesn't normalize the types.
*)
let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty =
let visitor = ty_substitute_visitor subst in
visitor#visit_ty () ty
-(** **IMPORTANT**: this doesn't normalize the trait ref. *)
+(** **IMPORTANT**: this doesn't normalize the types. *)
let trait_ref_substitute (subst : ('r1, 'r2) subst) (tr : 'r1 T.trait_ref) :
'r2 T.trait_ref =
let visitor = ty_substitute_visitor subst in
visitor#visit_trait_ref () tr
+(** **IMPORTANT**: this doesn't normalize the types. *)
+let generic_args_substitute (subst : ('r1, 'r2) subst) (g : 'r1 T.generic_args)
+ : 'r2 T.generic_args =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_generic_args () g
+
+let erase_regions_subst : ('r, T.erased_region) subst =
+ {
+ r_subst = (fun _ -> T.Erased);
+ ty_subst = (fun vid -> T.TypeVar vid);
+ cg_subst = (fun id -> T.ConstGenericVar id);
+ tr_subst = (fun id -> T.Clause id);
+ tr_self = T.Self;
+ }
+
(** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *)
-let erase_regions (ty : 'r T.ty) : T.ety =
- let subst =
- {
- r_subst = (fun _ -> T.Erased);
- ty_subst = (fun vid -> T.TypeVar vid);
- cg_subst = (fun id -> T.ConstGenericVar id);
- tr_subst = (fun id -> T.Clause id);
- tr_self = T.Self;
- }
- in
- ty_substitute subst ty
+let erase_regions (ty : 'r T.ty) : T.ety = ty_substitute erase_regions_subst ty
(** Generate fresh regions for region variables.
@@ -425,6 +430,15 @@ let fun_body_substitute_in_body
let body = statement_substitute subst body.body in
(locals, body)
+let trait_type_constraint_substitute (subst : ('r1, 'r2) subst)
+ (ttc : 'r1 T.trait_type_constraint) : 'r2 T.trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let visitor = ty_substitute_visitor subst in
+ let trait_ref = visitor#visit_trait_ref () trait_ref in
+ let generics = visitor#visit_generic_args () generics in
+ let ty = visitor#visit_ty () ty in
+ { T.trait_ref; generics; type_name; ty }
+
(** Substitute a function signature.
**IMPORTANT:** this function doesn't normalize the types.
@@ -448,7 +462,12 @@ let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id)
{ id; regions; parents }
in
let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in
- { A.regions_hierarchy; inputs; output }
+ let trait_type_constraints =
+ List.map
+ (trait_type_constraint_substitute subst)
+ sg.preds.trait_type_constraints
+ in
+ { A.inputs; output; regions_hierarchy; trait_type_constraints }
(** Substitute variable identifiers in a type *)
let ty_substitute_ids (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)