summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorSon Ho2023-08-31 12:47:43 +0200
committerSon Ho2023-08-31 12:47:43 +0200
commit6f22190cba92a44b6c74bfcce8f5ed142a68e195 (patch)
treeed0558281093e4e9dac0983aac22c520434644a4
parent8543092569616ef6a75949a72532f7b73dc696f2 (diff)
Start adding support for traits
-rw-r--r--compiler/AssociatedTypes.ml91
-rw-r--r--compiler/Assumed.ml188
-rw-r--r--compiler/Contexts.ml24
-rw-r--r--compiler/FunsAnalysis.ml15
-rw-r--r--compiler/Interpreter.ml96
-rw-r--r--compiler/InterpreterBorrowsCore.ml16
-rw-r--r--compiler/InterpreterExpansion.ml45
-rw-r--r--compiler/InterpreterExpressions.ml44
-rw-r--r--compiler/InterpreterLoopsJoinCtxs.ml9
-rw-r--r--compiler/InterpreterLoopsMatchCtxs.ml19
-rw-r--r--compiler/InterpreterPaths.ml55
-rw-r--r--compiler/InterpreterPaths.mli7
-rw-r--r--compiler/InterpreterProjectors.ml15
-rw-r--r--compiler/InterpreterStatements.ml305
-rw-r--r--compiler/InterpreterStatements.mli6
-rw-r--r--compiler/InterpreterUtils.ml2
-rw-r--r--compiler/Invariants.ml67
-rw-r--r--compiler/Logging.ml3
-rw-r--r--compiler/PrePasses.ml4
-rw-r--r--compiler/Print.ml57
-rw-r--r--compiler/PrintPure.ml155
-rw-r--r--compiler/Pure.ml100
-rw-r--r--compiler/PureTypeCheck.ml46
-rw-r--r--compiler/PureUtils.ml132
-rw-r--r--compiler/Substitute.ml443
-rw-r--r--compiler/SymbolicAst.ml5
-rw-r--r--compiler/SymbolicToPure.ml23
-rw-r--r--compiler/SynthesizeSymbolic.ml34
-rw-r--r--compiler/Translate.ml4
-rw-r--r--compiler/TypesAnalysis.ml23
-rw-r--r--compiler/dune1
31 files changed, 1258 insertions, 776 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
new file mode 100644
index 00000000..4e5625cb
--- /dev/null
+++ b/compiler/AssociatedTypes.ml
@@ -0,0 +1,91 @@
+(** This file implements utilities to handle trait associated types, in
+ particular with normalization helpers.
+
+ When normalizing a type, we simplify the references to the trait associated
+ types, and choose a representative when there are equalities between types
+ enforced by local clauses (i.e., clauses of the shape [where Trait1::T = Trait2::U]).
+ *)
+
+module T = Types
+module TU = TypesUtils
+module V = Values
+module E = Expressions
+module A = LlbcAst
+module C = Contexts
+module Subst = Substitute
+module L = Logging
+
+(** The local logger *)
+let log = L.associated_types_log
+
+(** Normalize a type by simplyfying the references to trait associated types
+ and choosing a representative when there are equalities between types
+ enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. *)
+let ctx_normalize_type (_ctx : C.eval_ctx) (_ty : 'r T.ty) : 'r T.ty =
+ raise (Failure "Unimplemented")
+
+(** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *)
+let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx)
+ (def : T.type_decl) (generics : T.rgeneric_args) :
+ (T.VariantId.id option * T.rty list) list =
+ let res =
+ Subst.type_decl_get_instantiated_variants_fields_rtypes def generics
+ in
+ List.map
+ (fun (variant_id, types) ->
+ (variant_id, List.map (ctx_normalize_type ctx) types))
+ res
+
+(** Same as [type_decl_get_instantiated_field_rtypes] but normalizes the types *)
+let type_decl_get_inst_norm_field_rtypes (ctx : C.eval_ctx) (def : T.type_decl)
+ (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) :
+ T.rty list =
+ let types =
+ Subst.type_decl_get_instantiated_field_rtypes def opt_variant_id generics
+ in
+ List.map (ctx_normalize_type ctx) types
+
+(** Same as [ctx_adt_value_get_instantiated_field_rtypes] but normalizes the types *)
+let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx)
+ (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) :
+ T.rty list =
+ let types =
+ Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id generics
+ in
+ List.map (ctx_normalize_type ctx) types
+
+(** Same as [ctx_adt_value_get_instantiated_field_etypes] but normalizes the types *)
+let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl)
+ (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) :
+ T.ety list =
+ let types =
+ Subst.type_decl_get_instantiated_field_etypes def opt_variant_id generics
+ in
+ List.map (ctx_normalize_type ctx) types
+
+(** Same as [ctx_adt_get_instantiated_field_etypes] but normalizes the types *)
+let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx)
+ (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
+ (generics : T.egeneric_args) : T.ety list =
+ let types =
+ Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id
+ generics
+ in
+ List.map (ctx_normalize_type ctx) types
+
+(** Same as [substitute_signature] but normalizes the types *)
+let ctx_subst_norm_signature (ctx : C.eval_ctx)
+ (asubst : T.RegionGroupId.id -> V.AbstractionId.id)
+ (r_subst : T.RegionVarId.id -> T.RegionId.id)
+ (ty_subst : T.TypeVarId.id -> T.rty)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ let sg =
+ Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self
+ sg
+ in
+ let { A.regions_hierarchy; inputs; output } = sg in
+ let inputs = List.map (ctx_normalize_type ctx) inputs in
+ let output = ctx_normalize_type ctx output in
+ { regions_hierarchy; inputs; output }
diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml
index 25462504..e156c335 100644
--- a/compiler/Assumed.ml
+++ b/compiler/Assumed.ml
@@ -63,75 +63,81 @@ module Sig = struct
let empty_const_generic_params : T.const_generic_var list = []
+ let mk_generic_args regions types const_generics : T.sgeneric_args =
+ { regions; types; const_generics; trait_refs = [] }
+
+ let mk_generic_params regions types const_generics : T.generic_params =
+ { regions; types; const_generics; trait_clauses = [] }
+
let mk_ref_ty (r : T.RegionVarId.id T.region) (ty : T.sty) (is_mut : bool) :
T.sty =
let ref_kind = if is_mut then T.Mut else T.Shared in
mk_ref_ty r ty ref_kind
let mk_array_ty (ty : T.sty) (cg : T.const_generic) : T.sty =
- Adt (Assumed Array, [], [ ty ], [ cg ])
+ Adt (Assumed Array, mk_generic_args [] [ ty ] [ cg ])
+
+ let mk_slice_ty (ty : T.sty) : T.sty =
+ Adt (Assumed Slice, mk_generic_args [] [ ty ] [])
- let mk_slice_ty (ty : T.sty) : T.sty = Adt (Assumed Slice, [], [ ty ], [])
- let range_ty : T.sty = Adt (Assumed Range, [], [ usize_ty ], [])
+ let range_ty : T.sty = Adt (Assumed Range, mk_generic_args [] [ usize_ty ] [])
+
+ let mk_sig generics regions_hierarchy inputs output : A.fun_sig =
+ let preds : T.predicates =
+ { regions_outlive = []; types_outlive = []; trait_type_constraints = [] }
+ in
+ {
+ generics;
+ preds;
+ parent_params_info = None;
+ regions_hierarchy;
+ inputs;
+ output;
+ }
(** [fn<T>(&'a mut T, T) -> T] *)
let mem_replace_sig : A.fun_sig =
(* The signature fields *)
- let region_params = [ region_param_0 ] (* <'a> *) in
+ let regions = [ region_param_0 ] (* <'a> *) in
let regions_hierarchy = [ region_group_0 ] (* [{<'a>}] *) in
- let type_params = [ type_param_0 ] (* <T> *) in
+ let types = [ type_param_0 ] (* <T> *) in
+ let generics = mk_generic_params regions types [] in
let inputs =
[ mk_ref_ty rvar_0 tvar_0 true (* &'a mut T *); tvar_0 (* T *) ]
in
let output = tvar_0 (* T *) in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(T) -> Box<T>] *)
let box_new_sig : A.fun_sig =
- {
- region_params = [];
- regions_hierarchy = [];
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs = [ tvar_0 (* T *) ];
- output = mk_box_ty tvar_0 (* Box<T> *);
- }
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
+ let regions_hierarchy = [] in
+ let inputs = [ tvar_0 (* T *) ] in
+ let output = mk_box_ty tvar_0 (* Box<T> *) in
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(Box<T>) -> ()] *)
let box_free_sig : A.fun_sig =
- {
- region_params = [];
- regions_hierarchy = [];
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs = [ mk_box_ty tvar_0 (* Box<T> *) ];
- output = mk_unit_ty (* () *);
- }
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
+ let regions_hierarchy = [] in
+ let inputs = [ mk_box_ty tvar_0 (* Box<T> *) ] in
+ let output = mk_unit_ty (* () *) in
+ mk_sig generics regions_hierarchy inputs output
(** Helper for [Box::deref_shared] and [Box::deref_mut].
Returns:
[fn<'a, T>(&'a (mut) Box<T>) -> &'a (mut) T]
*)
let box_deref_gen_sig (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- {
- region_params;
- regions_hierarchy;
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs =
- [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ];
- output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *);
- }
+ let inputs =
+ [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ]
+ in
+ let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in
+ mk_sig generics regions_hierarchy inputs output
(** [fn<'a, T>(&'a Box<T>) -> &'a T] *)
let box_deref_shared_sig = box_deref_gen_sig false
@@ -141,26 +147,18 @@ module Sig = struct
(** [fn<T>() -> Vec<T>] *)
let vec_new_sig : A.fun_sig =
- let region_params = [] in
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
let regions_hierarchy = [] in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs = [] in
let output = mk_vec_ty tvar_0 (* Vec<T> *) in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a mut Vec<T>, T)] *)
let vec_push_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *);
@@ -168,21 +166,14 @@ module Sig = struct
]
in
let output = mk_unit_ty (* () *) in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a mut Vec<T>, usize, T)] *)
let vec_insert_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *);
@@ -191,42 +182,28 @@ module Sig = struct
]
in
let output = mk_unit_ty (* () *) in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a Vec<T>) -> usize] *)
let vec_len_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) false (* &'a Vec<T> *) ]
in
let output = mk_usize_ty (* usize *) in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** Helper:
[fn<T>(&'a (mut) Vec<T>, usize) -> &'a (mut) T]
*)
let vec_index_gen_sig (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0 (mk_vec_ty tvar_0) is_mut (* &'a (mut) Vec<T> *);
@@ -234,14 +211,7 @@ module Sig = struct
]
in
let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a Vec<T>, usize) -> &'a T] *)
let vec_index_shared_sig : A.fun_sig = vec_index_gen_sig false
@@ -266,10 +236,10 @@ module Sig = struct
let mk_array_slice_borrow_sig (cgs : T.const_generic_var list)
(input_ty : T.TypeVarId.id -> T.sty) (index_ty : T.sty option)
(output_ty : T.TypeVarId.id -> T.sty) (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] cgs (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0
@@ -285,14 +255,7 @@ module Sig = struct
(output_ty type_param_0.index)
is_mut (* &'a (mut) output_ty<T> *)
in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = cgs;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig =
(* Array<T, N> *)
@@ -342,22 +305,15 @@ module Sig = struct
[fn<T>(&'a [T]) -> usize]
*)
let slice_len_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ]
in
let output = mk_usize_ty (* usize *) in
- {
- region_params;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
end
type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index 14b5d559..2d396924 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -255,11 +255,28 @@ type fun_context = { fun_decls : fun_decl FunDeclId.Map.t } [@@deriving show]
type global_context = { global_decls : global_decl GlobalDeclId.Map.t }
[@@deriving show]
+type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t }
+[@@deriving show]
+
+type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t }
+[@@deriving show]
+
+type decls_ctx = {
+ type_ctx : type_context;
+ fun_ctx : fun_context;
+ global_ctx : global_context;
+ trait_decls_ctx : trait_decls_context;
+ trait_impls_ctx : trait_impls_context;
+}
+[@@deriving show]
+
(** Evaluation context *)
type eval_ctx = {
type_context : type_context;
fun_context : fun_context;
global_context : global_context;
+ trait_decls_context : trait_decls_context;
+ trait_impls_context : trait_impls_context;
region_groups : RegionGroupId.id list;
type_vars : type_var list;
const_generic_vars : const_generic_var list;
@@ -267,6 +284,7 @@ type eval_ctx = {
(** The map from const generic vars to their values. Those values
can be symbolic values or concrete values (in the latter case:
if we run in interpreter mode) *)
+ trait_clauses : etrait_ref list;
env : env;
ended_regions : RegionId.Set.t;
}
@@ -308,6 +326,12 @@ let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) :
global_decl =
GlobalDeclId.Map.find gid ctx.global_context.global_decls
+let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl =
+ TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls
+
+let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl =
+ TraitImplId.Map.find id ctx.trait_impls_context.trait_impls
+
(** Retrieve a variable's value in the current frame *)
let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value =
snd (env_lookup_var env vid)
diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml
index b72fa078..f4406653 100644
--- a/compiler/FunsAnalysis.ml
+++ b/compiler/FunsAnalysis.ml
@@ -70,14 +70,14 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
method! visit_rvalue _env rv =
match rv with
- | Use _ | Ref _ | Global _ | Discriminant _ | Aggregate _ -> ()
+ | Use _ | RvRef _ | Global _ | Discriminant _ | Aggregate _ -> ()
| UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail
| BinaryOp (bop, _, _) ->
can_fail := EU.binop_can_fail bop || !can_fail
method! visit_Call env call =
(match call.func with
- | Regular id ->
+ | FunId (Regular id) ->
if FunDeclId.Set.mem id fun_ids then (
can_diverge := true;
is_rec := true)
@@ -86,9 +86,13 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
self#may_fail info.can_fail;
stateful := !stateful || info.stateful;
can_diverge := !can_diverge || info.can_diverge
- | Assumed id ->
+ | FunId (Assumed id) ->
(* None of the assumed functions can diverge nor are considered stateful *)
- can_fail := !can_fail || Assumed.assumed_can_fail id);
+ can_fail := !can_fail || Assumed.assumed_can_fail id
+ | TraitMethod _ ->
+ (* We consider trait functions can fail, diverge, and are not stateful *)
+ can_fail := true;
+ can_diverge := true);
super#visit_Call env call
method! visit_Panic env =
@@ -141,7 +145,8 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
let rec analyze_decl_groups (decls : declaration_group list) : unit =
match decls with
| [] -> ()
- | Type _ :: decls' -> analyze_decl_groups decls'
+ | (Type _ | TraitDecl _ | TraitImpl _) :: decls' ->
+ analyze_decl_groups decls'
| Fun decl :: decls' ->
analyze_fun_decl_group decl;
analyze_decl_groups decls'
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index 37eeb333..eb66013d 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -12,27 +12,30 @@ module SA = SymbolicAst
(** The local logger *)
let log = L.interpreter_log
-let compute_type_fun_global_contexts (m : A.crate) :
- C.type_context * C.fun_context * C.global_context =
- let type_decls_list, _, _ = split_declarations m.declarations in
+let compute_contexts (m : A.crate) : C.decls_ctx =
+ let type_decls_list, _, _, _, _ = split_declarations m.declarations in
let type_decls = m.types in
let fun_decls = m.functions in
let global_decls = m.globals in
- let type_decls_groups, _funs_defs_groups, _globals_defs_groups =
+ let trait_decls = m.trait_decls in
+ let trait_impls = m.trait_impls in
+ let type_decls_groups, _, _, _, _ =
split_declarations_to_group_maps m.declarations
in
let type_infos =
TypesAnalysis.analyze_type_declarations type_decls type_decls_list
in
- let type_context = { C.type_decls_groups; type_decls; type_infos } in
- let fun_context = { C.fun_decls } in
- let global_context = { C.global_decls } in
- (type_context, fun_context, global_context)
-
-let initialize_eval_context (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
+ let type_ctx = { C.type_decls_groups; type_decls; type_infos } in
+ let fun_ctx = { C.fun_decls } in
+ let global_ctx = { C.global_decls } in
+ let trait_decls_ctx = { C.trait_decls } in
+ let trait_impls_ctx = { C.trait_impls } in
+ { C.type_ctx; fun_ctx; global_ctx; trait_decls_ctx; trait_impls_ctx }
+
+let initialize_eval_context (ctx : C.decls_ctx)
(region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list)
- (const_generic_vars : T.const_generic_var list) : C.eval_ctx =
+ (const_generic_vars : T.const_generic_var list)
+ (trait_clauses : T.etrait_ref list) : C.eval_ctx =
C.reset_global_counters ();
let const_generic_vars_map =
T.ConstGenericVarId.Map.of_list
@@ -44,33 +47,35 @@ let initialize_eval_context (type_context : C.type_context)
const_generic_vars)
in
{
- C.type_context;
- C.fun_context;
- C.global_context;
+ C.type_context = ctx.type_ctx;
+ C.fun_context = ctx.fun_ctx;
+ C.global_context = ctx.global_ctx;
+ C.trait_decls_context = ctx.trait_decls_ctx;
+ C.trait_impls_context = ctx.trait_impls_ctx;
C.region_groups;
C.type_vars;
C.const_generic_vars;
C.const_generic_vars_map;
+ C.trait_clauses;
C.env = [ C.Frame ];
C.ended_regions = T.RegionId.Set.empty;
}
(** Initialize an evaluation context to execute a function.
- Introduces local variables initialized in the following manner:
- - input arguments are initialized as symbolic values
- - the remaining locals are initialized as [⊥]
- Abstractions are introduced for the regions present in the function
- signature.
-
- We return:
- - the initialized evaluation context
- - the list of symbolic values introduced for the input values
- - the instantiated function signature
+ Introduces local variables initialized in the following manner:
+ - input arguments are initialized as symbolic values
+ - the remaining locals are initialized as [⊥]
+ Abstractions are introduced for the regions present in the function
+ signature.
+
+ We return:
+ - the initialized evaluation context
+ - the list of symbolic values introduced for the input values
+ - the instantiated function signature
*)
-let initialize_symbolic_context_for_fun (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (fdef : A.fun_decl) : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig =
+let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl)
+ : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig =
(* The abstractions are not initialized the same way as for function
* calls: they contain *loan* projectors, because they "provide" us
* with the input values (which behave as if they had been returned
@@ -88,8 +93,8 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context)
List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy
in
let ctx =
- initialize_eval_context type_context fun_context global_context
- region_groups sg.type_params sg.const_generic_params
+ initialize_eval_context ctx region_groups sg.generics.types
+ sg.generics.const_generics sg.generics.trait_clauses
in
(* Instantiate the signature *)
let type_params =
@@ -508,17 +513,12 @@ module Test = struct
(lazy ("test_unit_function: " ^ Print.fun_name_to_string fdef.A.name));
(* Sanity check - *)
- assert (List.length fdef.A.signature.region_params = 0);
- assert (List.length fdef.A.signature.type_params = 0);
+ assert (fdef.A.signature.generics = TypesUtils.mk_empty_generic_params);
assert (body.A.arg_count = 0);
(* Create the evaluation context *)
- let type_context, fun_context, global_context =
- compute_type_fun_global_contexts crate
- in
- let ctx =
- initialize_eval_context type_context fun_context global_context [] [] []
- in
+ let decls_ctx = compute_contexts crate in
+ let ctx = initialize_eval_context decls_ctx [] [] [] [] in
(* Insert the (uninitialized) local variables *)
let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in
@@ -546,9 +546,7 @@ module Test = struct
(no parameters, no arguments) - TODO: move *)
let fun_decl_is_transparent_unit (def : A.fun_decl) : bool =
Option.is_some def.body
- && def.A.signature.region_params = []
- && def.A.signature.type_params = []
- && def.A.signature.const_generic_params = []
+ && def.A.signature.generics = TypesUtils.mk_empty_generic_params
&& def.A.signature.inputs = []
(** Test all the unit functions in a list of function definitions *)
@@ -562,20 +560,4 @@ module Test = struct
test_unit_function crate def.A.def_id
in
A.FunDeclId.Map.iter test_unit_fun unit_funs
-
- (** Execute the symbolic interpreter on a function. *)
- let test_function_symbolic (synthesize : bool) (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (fdef : A.fun_decl) : unit =
- (* Debug *)
- log#ldebug
- (lazy ("test_function_symbolic: " ^ Print.fun_name_to_string fdef.A.name));
-
- (* Evaluate *)
- let _ =
- evaluate_function_symbolic synthesize type_context fun_context
- global_context fdef
- in
-
- ()
end
diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml
index bf083aa4..e7da045c 100644
--- a/compiler/InterpreterBorrowsCore.ml
+++ b/compiler/InterpreterBorrowsCore.ml
@@ -100,15 +100,18 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
(compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool)
(ty1 : T.rty) (ty2 : T.rty) : bool =
let compare = compare_rtys default combine compare_regions in
+ (* Normalize the associated types *)
match (ty1, ty2) with
| T.Literal lit1, T.Literal lit2 ->
assert (lit1 = lit2);
default
- | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) ->
+ | T.Adt (id1, generics1), T.Adt (id2, generics2) ->
assert (id1 = id2);
(* There are no regions in the const generics, so we ignore them,
but we still check they are the same, for sanity *)
- assert (cgs1 = cgs2);
+ assert (generics1.const_generics = generics2.const_generics);
+
+ (* We also ignore the trait refs *)
(* The check for the ADTs is very crude: we simply compare the arguments
* two by two.
@@ -123,14 +126,14 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
* this check would still be a reasonable conservative approximation. *)
(* Check the region parameters *)
- let regions = List.combine regions1 regions2 in
+ let regions = List.combine generics1.regions generics2.regions in
let params_b =
List.fold_left
(fun b (r1, r2) -> combine b (compare_regions r1 r2))
default regions
in
(* Check the type parameters *)
- let tys = List.combine tys1 tys2 in
+ let tys = List.combine generics1.types generics2.types in
let tys_b =
List.fold_left
(fun b (ty1, ty2) -> combine b (compare ty1 ty2))
@@ -150,6 +153,11 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
| T.TypeVar id1, T.TypeVar id2 ->
assert (id1 = id2);
default
+ | T.TraitType _, T.TraitType _ ->
+ (* The types should have been normalized. If after normalization we
+ get trait types, we can consider them as variables *)
+ assert (ty1 = ty2);
+ default
| _ ->
log#lerror
(lazy
diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml
index 81e73e3e..ea692386 100644
--- a/compiler/InterpreterExpansion.ml
+++ b/compiler/InterpreterExpansion.ml
@@ -9,6 +9,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
module Inv = Invariants
@@ -204,7 +205,7 @@ let apply_symbolic_expansion_non_borrow (config : C.config)
apply_symbolic_expansion_to_avalues config allow_reborrows original_sv
expansion ctx
-(** Compute the expansion of a non-assumed (i.e.: not [Option], [Box], etc.)
+(** Compute the expansion of a non-assumed (i.e.: not [Box], etc.)
adt value.
The function might return a list of values if the symbolic value to expand
@@ -214,18 +215,15 @@ let apply_symbolic_expansion_non_borrow (config : C.config)
doesn't allow the expansion of enumerations *containing several variants*.
*)
let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool)
- (kind : V.sv_kind) (def_id : T.TypeDeclId.id)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
- =
+ (kind : V.sv_kind) (def_id : T.TypeDeclId.id) (generics : T.rgeneric_args)
+ (ctx : C.eval_ctx) : V.symbolic_expansion list =
(* Lookup the definition and check if it is an enumeration with several
* variants *)
let def = C.ctx_lookup_type_decl ctx def_id in
- assert (List.length regions = List.length def.T.region_params);
+ assert (List.length generics.regions = List.length def.T.generics.regions);
(* Retrieve, for every variant, the list of its instantiated field types *)
let variants_fields_types =
- Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types
- cgs
+ Assoc.type_decl_get_inst_norm_variants_fields_rtypes ctx def generics
in
(* Check if there is strictly more than one variant *)
if List.length variants_fields_types > 1 && not expand_enumerations then
@@ -280,15 +278,14 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) :
doesn't allow the expansion of enumerations *containing several variants*.
*)
let compute_expanded_symbolic_adt_value (expand_enumerations : bool)
- (kind : V.sv_kind) (adt_id : T.type_id)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
- =
- match (adt_id, regions, types) with
+ (kind : V.sv_kind) (adt_id : T.type_id) (generics : T.rgeneric_args)
+ (ctx : C.eval_ctx) : V.symbolic_expansion list =
+ match (adt_id, generics.regions, generics.types) with
| T.AdtId def_id, _, _ ->
compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind
- def_id regions types cgs ctx
- | T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ]
+ def_id generics ctx
+ | T.Tuple, [], _ ->
+ [ compute_expanded_symbolic_tuple_value kind generics.types ]
| T.Assumed T.Option, [], [ ty ] ->
compute_expanded_symbolic_option_value expand_enumerations kind ty
| T.Assumed T.Box, [], [ boxed_ty ] ->
@@ -543,12 +540,12 @@ let expand_symbolic_value_no_branching (config : C.config)
fun cf ctx ->
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types, cgs) ->
+ | T.Adt (adt_id, generics) ->
(* Compute the expanded value *)
let allow_branching = false in
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types cgs ctx
+ generics ctx
in
(* There should be exacly one branch *)
let see = Collections.List.to_cons_nil seel in
@@ -600,12 +597,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value)
(* Execute *)
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types, cgs) ->
+ | T.Adt (adt_id, generics) ->
let allow_branching = true in
(* Compute the expanded value *)
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types cgs ctx
+ generics ctx
in
(* Apply *)
let seel = List.map (fun see -> (Some see, cf_branches)) seel in
@@ -679,7 +676,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
^ symbolic_value_to_string ctx sv));
let cc : cm_fun =
match sv.V.sv_ty with
- | T.Adt (AdtId def_id, _, _, _) ->
+ | T.Adt (AdtId def_id, _) ->
(* {!expand_symbolic_value_no_branching} checks if there are branchings,
* but we prefer to also check it here - this leads to cleaner messages
* and debugging *)
@@ -704,16 +701,16 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
[config]): "
^ Print.name_to_string def.name))
else expand_symbolic_value_no_branching config sv None
- | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) ->
+ | T.Adt ((Tuple | Assumed Box), _) | T.Ref (_, _, _) ->
(* Ok *)
expand_symbolic_value_no_branching config sv None
- | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _, _, _)
- ->
+ | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _) ->
(* We can't expand those *)
raise
(Failure
"Attempted to greedily expand an ADT which can't be expanded ")
- | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable")
+ | T.TypeVar _ | T.Literal _ | Never | T.TraitType _ ->
+ raise (Failure "Unreachable")
in
(* Compose and continue *)
comp cc expand cf ctx
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 2f6a7b49..51f6ff05 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -7,6 +7,7 @@ module E = Expressions
open Utils
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
open ValuesUtils
@@ -141,11 +142,18 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config)
| V.Adt av ->
(* Sanity check *)
(match v.V.ty with
- | T.Adt (T.Assumed (T.Box | Vec), _, _, _) ->
+ | T.Adt (T.Assumed (T.Box | Vec), _) ->
raise (Failure "Can't copy an assumed value other than Option")
- | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy
- | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *)
- | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) ->
+ | T.Adt (T.AdtId _, _) -> assert allow_adt_copy
+ | T.Adt ((T.Assumed Option | T.Tuple), _) -> () (* Ok *)
+ | T.Adt
+ ( T.Assumed (Slice | T.Array),
+ {
+ regions = [];
+ types = [ ty ];
+ const_generics = [];
+ trait_refs = [];
+ } ) ->
assert (ty_is_primitively_copyable ty)
| _ -> raise (Failure "Unreachable"));
let ctx, fields =
@@ -263,6 +271,9 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand)
match cv.value with
| E.CLiteral lit ->
cf (literal_to_typed_value (TypesUtils.ty_as_literal cv.ty) lit) ctx
+ | E.TraitConst (_trait_ref, _generics, _const_name) ->
+ (* TODO *)
+ raise (Failure "Unimplemented")
| E.CVar vid -> (
let ctx0 = ctx in
(* Lookup the const generic value *)
@@ -681,7 +692,8 @@ let eval_rvalue_aggregate (config : C.config)
| E.AggregatedTuple ->
let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in
let v = V.Adt { variant_id = None; field_values = values } in
- let ty = T.Adt (T.Tuple, [], tys, []) in
+ let generics = TypesUtils.mk_generic_args [] tys [] [] in
+ let ty = T.Adt (T.Tuple, generics) in
let aggregated : V.typed_value = { V.value = v; ty } in
(* Call the continuation *)
cf aggregated ctx
@@ -692,20 +704,22 @@ let eval_rvalue_aggregate (config : C.config)
assert (List.length values = 1)
else raise (Failure "Unreachable");
(* Construt the value *)
- let aty = T.Adt (T.Assumed T.Option, [], [ ty ], []) in
+ let generics = TypesUtils.mk_generic_args [] [ ty ] [] [] in
+ let aty = T.Adt (T.Assumed T.Option, generics) in
let av : V.adt_value =
{ V.variant_id = Some variant_id; V.field_values = values }
in
let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
(* Call the continuation *)
cf aggregated ctx
- | E.AggregatedAdt (def_id, opt_variant_id, regions, types, cgs) ->
+ | E.AggregatedAdt (def_id, opt_variant_id, generics) ->
(* Sanity checks *)
let type_decl = C.ctx_lookup_type_decl ctx def_id in
- assert (List.length type_decl.region_params = List.length regions);
+ assert (
+ List.length type_decl.generics.regions = List.length generics.regions);
let expected_field_types =
- Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id
- types cgs
+ Assoc.ctx_adt_get_inst_norm_field_etypes ctx def_id opt_variant_id
+ generics
in
assert (
expected_field_types
@@ -714,7 +728,7 @@ let eval_rvalue_aggregate (config : C.config)
let av : V.adt_value =
{ V.variant_id = opt_variant_id; V.field_values = values }
in
- let aty = T.Adt (T.AdtId def_id, regions, types, cgs) in
+ let aty = T.Adt (T.AdtId def_id, generics) in
let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
(* Call the continuation *)
cf aggregated ctx
@@ -734,7 +748,8 @@ let eval_rvalue_aggregate (config : C.config)
let av : V.adt_value =
{ V.variant_id = None; V.field_values = values }
in
- let aty = T.Adt (T.Assumed T.Range, [], [ ety ], []) in
+ let generics = TypesUtils.mk_generic_args_from_types [ ety ] in
+ let aty = T.Adt (T.Assumed T.Range, generics) in
let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
(* Call the continuation *)
cf aggregated ctx
@@ -744,7 +759,8 @@ let eval_rvalue_aggregate (config : C.config)
(* Sanity check: the number of values is consistent with the length *)
let len = (literal_as_scalar (const_generic_as_literal cg)).value in
assert (len = Z.of_int (List.length values));
- let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in
+ let generics = TypesUtils.mk_generic_args [] [ ety ] [ cg ] [] in
+ let ty = T.Adt (T.Assumed T.Array, generics) in
(* In order to generate a better AST, we introduce a symbolic
value equal to the array. The reason is that otherwise, the
array we introduce here might be duplicated in the generated
@@ -777,7 +793,7 @@ let eval_rvalue_not_global (config : C.config) (rvalue : E.rvalue)
(* Delegate to the proper auxiliary function *)
match rvalue with
| E.Use op -> comp_wrap (eval_operand config op) ctx
- | E.Ref (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx
+ | E.RvRef (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx
| E.UnaryOp (unop, op) -> eval_unary_op config unop op cf ctx
| E.BinaryOp (binop, op1, op2) -> eval_binary_op config binop op1 op2 cf ctx
| E.Aggregate (aggregate_kind, ops) ->
diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml
index 10205c27..a34a7d06 100644
--- a/compiler/InterpreterLoopsJoinCtxs.ml
+++ b/compiler/InterpreterLoopsJoinCtxs.ml
@@ -554,10 +554,13 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context;
fun_context;
global_context;
+ trait_decls_context;
+ trait_impls_context;
region_groups;
type_vars;
const_generic_vars;
const_generic_vars_map;
+ trait_clauses;
env = _;
ended_regions = ended_regions0;
} =
@@ -567,10 +570,13 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context = _;
fun_context = _;
global_context = _;
+ trait_decls_context = _;
+ trait_impls_context = _;
region_groups = _;
type_vars = _;
const_generic_vars = _;
const_generic_vars_map = _;
+ trait_clauses = _;
env = _;
ended_regions = ended_regions1;
} =
@@ -582,10 +588,13 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context;
fun_context;
global_context;
+ trait_decls_context;
+ trait_impls_context;
region_groups;
type_vars;
const_generic_vars;
const_generic_vars_map;
+ trait_clauses;
env;
ended_regions;
}
diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml
index 9248e513..8cab546e 100644
--- a/compiler/InterpreterLoopsMatchCtxs.ml
+++ b/compiler/InterpreterLoopsMatchCtxs.ml
@@ -149,20 +149,25 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty)
(match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty =
let match_rec = match_types match_distinct_types match_regions in
match (ty0, ty1) with
- | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) ->
+ | Adt (id0, generics0), Adt (id1, generics1) ->
assert (id0 = id1);
- assert (cgs0 = cgs1);
+ assert (generics0.const_generics = generics1.const_generics);
+ assert (generics0.trait_refs = generics1.trait_refs);
let id = id0 in
- let cgs = cgs1 in
+ let const_generics = generics1.const_generics in
+ let trait_refs = generics1.trait_refs in
let regions =
List.map
(fun (id0, id1) -> match_regions id0 id1)
- (List.combine regions0 regions1)
+ (List.combine generics0.regions generics1.regions)
in
- let tys =
- List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1)
+ let types =
+ List.map
+ (fun (ty0, ty1) -> match_rec ty0 ty1)
+ (List.combine generics0.types generics1.types)
in
- Adt (id, regions, tys, cgs)
+ let generics = { T.regions; types; const_generics; trait_refs } in
+ Adt (id, generics)
| TypeVar vid0, TypeVar vid1 ->
assert (vid0 = vid1);
let vid = vid0 in
diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml
index 04dc8892..465d0028 100644
--- a/compiler/InterpreterPaths.ml
+++ b/compiler/InterpreterPaths.ml
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open Cps
open ValuesUtils
@@ -97,7 +98,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
match (pe, v.V.value, v.V.ty) with
| ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id),
V.Adt adt,
- T.Adt (type_id, _, _, _) ) -> (
+ T.Adt (type_id, _) ) -> (
(* Check consistency *)
(match (proj_kind, type_id) with
| ProjAdt (def_id, opt_variant_id), T.AdtId def_id' ->
@@ -119,8 +120,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
let updated = { v with value = nadt } in
Ok (ctx, { res with updated }))
(* Tuples *)
- | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _)
- -> (
+ | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _) -> (
assert (arity = List.length adt.field_values);
let fv = T.FieldId.nth adt.field_values field_id in
(* Project *)
@@ -145,9 +145,9 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
(* Box dereferencement *)
| ( DerefBox,
Adt { variant_id = None; field_values = [ bv ] },
- T.Adt (T.Assumed T.Box, _, _, _) ) -> (
- (* We allow moving inside of boxes. In practice, this kind of
- * manipulations should happen only inside unsage code, so
+ T.Adt (T.Assumed T.Box, _) ) -> (
+ (* We allow moving outside of boxes. In practice, this kind of
+ * manipulations should happen only inside unsafe code, so
* it shouldn't happen due to user code, and we leverage it
* when implementing box dereferencement for the concrete
* interpreter *)
@@ -357,24 +357,23 @@ let write_place (access : access_kind) (p : E.place) (nv : V.typed_value)
| Error e -> raise (Failure ("Unreachable: " ^ show_path_fail_kind e))
| Ok ctx -> ctx
-let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t)
+let compute_expanded_bottom_adt_value (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (regions : T.erased_region list) (types : T.ety list)
- (cgs : T.const_generic list) : V.typed_value =
+ (generics : T.egeneric_args) : V.typed_value =
(* Lookup the definition and check if it is an enumeration - it
should be an enumeration if and only if the projection element
is a field projection with *some* variant id. Retrieve the list
of fields at the same time. *)
- let def = T.TypeDeclId.Map.find def_id tyctx in
- assert (List.length regions = List.length def.T.region_params);
+ let def = C.ctx_lookup_type_decl ctx def_id in
+ assert (List.length generics.regions = List.length def.T.generics.regions);
(* Compute the field types *)
let field_types =
- Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types cgs
+ Assoc.type_decl_get_inst_norm_field_etypes ctx def opt_variant_id generics
in
(* Initialize the expanded value *)
let fields = List.map mk_bottom field_types in
let av = V.Adt { variant_id = opt_variant_id; field_values = fields } in
- let ty = T.Adt (T.AdtId def_id, regions, types, cgs) in
+ let ty = T.Adt (T.AdtId def_id, generics) in
{ V.value = av; V.ty }
let compute_expanded_bottom_option_value (variant_id : T.VariantId.id)
@@ -387,7 +386,8 @@ let compute_expanded_bottom_option_value (variant_id : T.VariantId.id)
else raise (Failure "Unreachable")
in
let av = V.Adt { variant_id = Some variant_id; field_values } in
- let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ], []) in
+ let generics = TypesUtils.mk_generic_args [] [ param_ty ] [] [] in
+ let ty = T.Adt (T.Assumed T.Option, generics) in
{ V.value = av; ty }
let compute_expanded_bottom_tuple_value (field_types : T.ety list) :
@@ -395,7 +395,8 @@ let compute_expanded_bottom_tuple_value (field_types : T.ety list) :
(* Generate the field values *)
let fields = List.map mk_bottom field_types in
let v = V.Adt { variant_id = None; field_values = fields } in
- let ty = T.Adt (T.Tuple, [], field_types, []) in
+ let generics = TypesUtils.mk_generic_args [] field_types [] [] in
+ let ty = T.Adt (T.Tuple, generics) in
{ V.value = v; V.ty }
(** Auxiliary helper to expand {!V.Bottom} values.
@@ -447,19 +448,29 @@ let expand_bottom_value_from_projection (access : access_kind) (p : E.place)
match (pe, ty) with
(* "Regular" ADTs *)
| ( Field (ProjAdt (def_id, opt_variant_id), _),
- T.Adt (T.AdtId def_id', regions, types, cgs) ) ->
+ T.Adt (T.AdtId def_id', generics) ) ->
assert (def_id = def_id');
- compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id
- opt_variant_id regions types cgs
+ compute_expanded_bottom_adt_value ctx def_id opt_variant_id generics
(* Option *)
| ( Field (ProjOption variant_id, _),
- T.Adt (T.Assumed T.Option, [], [ ty ], []) ) ->
+ T.Adt
+ ( T.Assumed T.Option,
+ {
+ T.regions = [];
+ types = [ ty ];
+ const_generics = [];
+ trait_refs = [];
+ } ) ) ->
compute_expanded_bottom_option_value variant_id ty
(* Tuples *)
- | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys, []) ->
- assert (arity = List.length tys);
+ | ( Field (ProjTuple arity, _),
+ T.Adt
+ ( T.Tuple,
+ { T.regions = []; types; const_generics = []; trait_refs = [] } ) )
+ ->
+ assert (arity = List.length types);
(* Generate the field values *)
- compute_expanded_bottom_tuple_value tys
+ compute_expanded_bottom_tuple_value types
| _ ->
raise
(Failure
diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli
index 4a9f3b41..041b0a97 100644
--- a/compiler/InterpreterPaths.mli
+++ b/compiler/InterpreterPaths.mli
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open Cps
open InterpreterExpansion
@@ -56,12 +57,10 @@ val compute_expanded_bottom_tuple_value : T.ety list -> V.typed_value
(** Compute an expanded ADT ⊥ value *)
val compute_expanded_bottom_adt_value :
- T.type_decl T.TypeDeclId.Map.t ->
+ C.eval_ctx ->
T.TypeDeclId.id ->
T.VariantId.id option ->
- T.erased_region list ->
- T.ety list ->
- T.const_generic list ->
+ T.egeneric_args ->
V.typed_value
(** Compute an expanded [Option] ⊥ value *)
diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml
index faed066b..9e0c2b75 100644
--- a/compiler/InterpreterProjectors.ml
+++ b/compiler/InterpreterProjectors.ml
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
open InterpreterUtils
@@ -24,12 +25,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx)
else
match (v.V.value, ty) with
| V.Literal _, T.Literal _ -> []
- | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
+ | V.Adt adt, T.Adt (id, generics) ->
(* Retrieve the types of the fields *)
let field_types =
- Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys cgs
+ Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics
in
+
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
let proj_fields =
@@ -103,11 +104,10 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx)
let value : V.avalue =
match (v.V.value, ty) with
| V.Literal _, T.Literal _ -> V.AIgnored
- | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
+ | V.Adt adt, T.Adt (id, generics) ->
(* Retrieve the types of the fields *)
let field_types =
- Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys cgs
+ Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics
in
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
@@ -268,8 +268,7 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t)
let (value, ty) : V.avalue * T.rty =
match (see, original_sv_ty) with
| SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty)
- | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs)
- ->
+ | SeAdt (variant_id, field_values), T.Adt (_id, _generics) ->
(* Project over the field values *)
let field_values =
List.map
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 6d520059..d38f8b95 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -17,6 +17,7 @@ open InterpreterProjectors
open InterpreterExpansion
open InterpreterPaths
open InterpreterExpressions
+module PCtx = Print.EvalCtxLlbcAst
(** The local logger *)
let log = L.statements_log
@@ -232,9 +233,8 @@ let set_discriminant (config : C.config) (p : E.place)
let update_value cf (v : V.typed_value) : m_fun =
fun ctx ->
match (v.V.ty, v.V.value) with
- | ( T.Adt
- (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
- V.Adt av ) -> (
+ | T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), generics), V.Adt av
+ -> (
(* There are two situations:
- either the discriminant is already the proper one (in which case we
don't do anything)
@@ -251,28 +251,26 @@ let set_discriminant (config : C.config) (p : E.place)
let bottom_v =
match type_id with
| T.AdtId def_id ->
- compute_expanded_bottom_adt_value
- ctx.type_context.type_decls def_id (Some variant_id)
- regions types cgs
+ compute_expanded_bottom_adt_value ctx def_id
+ (Some variant_id) generics
| T.Assumed T.Option ->
- assert (regions = []);
+ assert (generics.regions = []);
compute_expanded_bottom_option_value variant_id
- (Collections.List.to_cons_nil types)
+ (Collections.List.to_cons_nil generics.types)
| _ -> raise (Failure "Unreachable")
in
assign_to_place config bottom_v p (cf Unit) ctx)
- | ( T.Adt
- (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
- V.Bottom ) ->
+ | T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), generics), V.Bottom
+ ->
let bottom_v =
match type_id with
| T.AdtId def_id ->
- compute_expanded_bottom_adt_value ctx.type_context.type_decls
- def_id (Some variant_id) regions types cgs
+ compute_expanded_bottom_adt_value ctx def_id (Some variant_id)
+ generics
| T.Assumed T.Option ->
- assert (regions = []);
+ assert (generics.regions = []);
compute_expanded_bottom_option_value variant_id
- (Collections.List.to_cons_nil types)
+ (Collections.List.to_cons_nil generics.types)
| _ -> raise (Failure "Unreachable")
in
assign_to_place config bottom_v p (cf Unit) ctx
@@ -301,24 +299,34 @@ let ctx_push_frame (ctx : C.eval_ctx) : C.eval_ctx =
let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx)
(** Small helper: compute the type of the return value for a specific
- instantiation of a non-local function.
+ instantiation of an assumed function.
*)
-let get_non_local_function_return_type (fid : A.assumed_fun_id)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) : T.ety =
+let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id)
+ (generics : T.egeneric_args) : T.ety =
+ assert (generics.trait_refs = []);
(* [Box::free] has a special treatment *)
- match (fid, region_params, type_params, const_generic_params) with
- | A.BoxFree, [], [ _ ], [] -> mk_unit_ty
+ match fid with
+ | A.BoxFree ->
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ mk_unit_ty
| _ ->
(* Retrieve the function's signature *)
let sg = Assumed.get_assumed_sig fid in
(* Instantiate the return type *)
- let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars sg.const_generic_params
- const_generic_params
+ (* There shouldn't be any reference to Self *)
+ let tr_self : T.erased_region T.trait_instance_id =
+ T.UnknownTrait __FUNCTION__
+ in
+ let { Subst.r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } =
+ Subst.make_esubst_from_generics sg.generics generics tr_self
in
- Subst.erase_regions_substitute_types tsubst cgsubst sg.output
+ let ty =
+ Subst.erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self
+ sg.output
+ in
+ Assoc.ctx_normalize_type ctx ty
let move_return_value (config : C.config) (pop_return_value : bool)
(cf : V.typed_value option -> m_fun) : m_fun =
@@ -418,19 +426,19 @@ let pop_frame_assign (config : C.config) (dest : E.place) : cm_fun =
in
comp cf_pop cf_assign
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_replace_concrete (_config : C.config)
- (_region_params : T.erased_region list) (_type_params : T.ety list)
- (_cg_args : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_replace_concrete (_config : C.config) (_generics : T.egeneric_args) :
+ cm_fun =
fun _cf _ctx -> raise Unimplemented
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_new_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_args : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_box_new_concrete (config : C.config) (generics : T.egeneric_args) :
+ cm_fun =
fun cf ctx ->
(* Check and retrieve the arguments *)
- match (region_params, type_params, cg_args, ctx.env) with
+ match
+ (generics.regions, generics.types, generics.const_generics, ctx.env)
+ with
| ( [],
[ boxed_ty ],
[],
@@ -448,7 +456,8 @@ let eval_box_new_concrete (config : C.config)
(* Create the new box *)
let cf_create cf (moved_input_value : V.typed_value) : m_fun =
(* Create the box value *)
- let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in
+ let generics = TypesUtils.mk_generic_args_from_types [ boxed_ty ] in
+ let box_ty = T.Adt (T.Assumed T.Box, generics) in
let box_v =
V.Adt { variant_id = None; field_values = [ moved_input_value ] }
in
@@ -467,13 +476,14 @@ let eval_box_new_concrete (config : C.config)
| _ -> raise (Failure "Inconsistent state")
(** Auxiliary function which factorizes code to evaluate [std::Deref::deref]
- and [std::DerefMut::deref_mut] - see {!eval_non_local_function_call} *)
+ and [std::DerefMut::deref_mut] - see {!eval_assumed_function_call} *)
let eval_box_deref_mut_or_shared_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_args : T.const_generic list) (is_mut : bool) : cm_fun =
+ (generics : T.egeneric_args) (is_mut : bool) : cm_fun =
fun cf ctx ->
(* Check the arguments *)
- match (region_params, type_params, cg_args, ctx.env) with
+ match
+ (generics.regions, generics.types, generics.const_generics, ctx.env)
+ with
| ( [],
[ boxed_ty ],
[],
@@ -495,7 +505,7 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config)
{ E.var_id = input_var.C.index; projection = [ E.Deref; E.DerefBox ] }
in
let borrow_kind = if is_mut then E.Mut else E.Shared in
- let rv = E.Ref (p, borrow_kind) in
+ let rv = E.RvRef (p, borrow_kind) in
let cf_borrow = eval_rvalue_not_global config rv in
(* Move the borrow to its destination *)
@@ -514,23 +524,19 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config)
comp cf_borrow cf_move cf ctx
| _ -> raise (Failure "Inconsistent state")
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_deref_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_args : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_box_deref_concrete (config : C.config) (generics : T.egeneric_args) :
+ cm_fun =
let is_mut = false in
- eval_box_deref_mut_or_shared_concrete config region_params type_params cg_args
- is_mut
+ eval_box_deref_mut_or_shared_concrete config generics is_mut
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_deref_mut_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_args : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_box_deref_mut_concrete (config : C.config) (generics : T.egeneric_args)
+ : cm_fun =
let is_mut = true in
- eval_box_deref_mut_or_shared_concrete config region_params type_params cg_args
- is_mut
+ eval_box_deref_mut_or_shared_concrete config generics is_mut
-(** Auxiliary function - see {!eval_non_local_function_call}.
+(** Auxiliary function - see {!eval_assumed_function_call}.
[Box::free] is not handled the same way as the other assumed functions:
- in the regular case, whenever we need to evaluate an assumed function,
@@ -549,11 +555,10 @@ let eval_box_deref_mut_concrete (config : C.config)
It thus updates the box value (by calling {!drop_value}) and updates
the destination (by setting it to [()]).
*)
-let eval_box_free (config : C.config) (region_params : T.erased_region list)
- (type_params : T.ety list) (cg_args : T.const_generic list)
+let eval_box_free (config : C.config) (generics : T.egeneric_args)
(args : E.operand list) (dest : E.place) : cm_fun =
fun cf ctx ->
- match (region_params, type_params, cg_args, args) with
+ match (generics.regions, generics.types, generics.const_generics, args) with
| [], [ boxed_ty ], [], [ E.Move input_box_place ] ->
(* Required type checking *)
let input_box = InterpreterPaths.read_place Write input_box_place ctx in
@@ -570,20 +575,18 @@ let eval_box_free (config : C.config) (region_params : T.erased_region list)
cc cf ctx
| _ -> raise (Failure "Inconsistent state")
-(** Auxiliary function - see {!eval_non_local_function_call} *)
+(** Auxiliary function - see {!eval_assumed_function_call} *)
let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id)
- (_region_params : T.erased_region list) (_type_params : T.ety list)
- (_cg_args : T.const_generic list) : cm_fun =
+ (_generics : T.egeneric_args) : cm_fun =
fun _cf _ctx -> raise Unimplemented
(** Evaluate a non-local function call in concrete mode *)
-let eval_non_local_function_call_concrete (config : C.config)
- (fid : A.assumed_fun_id) (region_params : T.erased_region list)
- (type_params : T.ety list) (cg_args : T.const_generic list)
+let eval_assumed_function_call_concrete (config : C.config)
+ (fid : A.assumed_fun_id) (generics : T.egeneric_args)
(args : E.operand list) (dest : E.place) : cm_fun =
(* Sanity check: we don't fully handle the const generic vars environment
in concrete mode yet *)
- assert (cg_args = []);
+ assert (generics.const_generics = []);
(* There are two cases (and this is extremely annoying):
- the function is not box_free
- the function is box_free
@@ -592,7 +595,7 @@ let eval_non_local_function_call_concrete (config : C.config)
match fid with
| A.BoxFree ->
(* Degenerate case: box_free *)
- eval_box_free config region_params type_params cg_args args dest
+ eval_box_free config generics args dest
| _ ->
(* "Normal" case: not box_free *)
(* Evaluate the operands *)
@@ -607,16 +610,14 @@ let eval_non_local_function_call_concrete (config : C.config)
* but it made it less clear where the computed values came from,
* so we reversed the modifications. *)
let cf_eval_call cf (args_vl : V.typed_value list) : m_fun =
+ fun ctx ->
(* Push the stack frame: we initialize the frame with the return variable,
and one variable per input argument *)
let cc = push_frame in
(* Create and push the return variable *)
let ret_vid = E.VarId.zero in
- let ret_ty =
- get_non_local_function_return_type fid region_params type_params
- cg_args
- in
+ let ret_ty = get_assumed_function_return_type ctx fid generics in
let ret_var = mk_var ret_vid (Some "@return") ret_ty in
let cc = comp cc (push_uninitialized_var ret_var) in
@@ -633,20 +634,14 @@ let eval_non_local_function_call_concrete (config : C.config)
* access to a body. *)
let cf_eval_body : cm_fun =
match fid with
- | A.Replace ->
- eval_replace_concrete config region_params type_params cg_args
- | BoxNew ->
- eval_box_new_concrete config region_params type_params cg_args
- | BoxDeref ->
- eval_box_deref_concrete config region_params type_params cg_args
- | BoxDerefMut ->
- eval_box_deref_mut_concrete config region_params type_params
- cg_args
+ | A.Replace -> eval_replace_concrete config generics
+ | BoxNew -> eval_box_new_concrete config generics
+ | BoxDeref -> eval_box_deref_concrete config generics
+ | BoxDerefMut -> eval_box_deref_mut_concrete config generics
| BoxFree ->
(* Should have been treated above *) raise (Failure "Unreachable")
| VecNew | VecPush | VecInsert | VecLen | VecIndex | VecIndexMut ->
- eval_vec_function_concrete config fid region_params type_params
- cg_args
+ eval_vec_function_concrete config fid generics
| ArrayIndexShared | ArrayIndexMut | ArrayToSliceShared
| ArrayToSliceMut | ArraySubsliceShared | ArraySubsliceMut
| SliceIndexShared | SliceIndexMut | SliceSubsliceShared
@@ -660,13 +655,13 @@ let eval_non_local_function_call_concrete (config : C.config)
let cc = comp cc (pop_frame_assign config dest) in
(* Continue *)
- cc cf
+ cc cf ctx
in
(* Compose and apply *)
comp cf_eval_ops cf_eval_call
-let instantiate_fun_sig (type_params : T.ety list)
- (cg_args : T.const_generic list) (sg : A.fun_sig) : A.inst_fun_sig =
+let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.egeneric_args)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
(* Generate fresh abstraction ids and create a substitution from region
* group ids to abstraction ids *)
let rg_abs_ids_bindings =
@@ -685,7 +680,7 @@ let instantiate_fun_sig (type_params : T.ety list)
T.RegionGroupId.Map.find rg_id asubst_map
in
(* Generate fresh regions and their substitutions *)
- let _, rsubst, _ = Subst.fresh_regions_with_substs sg.region_params in
+ let _, rsubst, _ = Subst.fresh_regions_with_substs sg.generics.regions in
(* Generate the type substitution
* Note that we need the substitution to map the type variables to
* {!rty} types (not {!ety}). In order to do that, we convert the
@@ -694,13 +689,28 @@ let instantiate_fun_sig (type_params : T.ety list)
* This is a current limitation of the analysis: there is still some
* work to do to properly handle full type parametrization.
* *)
- let rtype_params = List.map ety_no_regions_to_rty type_params in
- let tsubst = Subst.make_type_subst_from_vars sg.type_params rtype_params in
+ let rtype_params = List.map ety_no_regions_to_rty generics.types in
+ let tsubst = Subst.make_type_subst_from_vars sg.generics.types rtype_params in
let cgsubst =
- Subst.make_const_generic_subst_from_vars sg.const_generic_params cg_args
+ Subst.make_const_generic_subst_from_vars sg.generics.const_generics
+ generics.const_generics
+ in
+ (* TODO: something annoying with the trait ref subst: we need to use region
+ types, but the arguments use erased regions. For now we use the fact
+ that no regions should appear inside. In the future: we should merge
+ ety and rty. *)
+ let trait_refs =
+ List.map TypesUtils.etrait_ref_no_regions_to_gr_trait_ref
+ generics.trait_refs
+ in
+ let tr_subst =
+ Subst.make_trait_subst_from_clauses sg.generics.trait_clauses trait_refs
in
(* Substitute the signature *)
- let inst_sig = Subst.substitute_signature asubst rsubst tsubst cgsubst sg in
+ let inst_sig =
+ Assoc.ctx_subst_norm_signature ctx asubst rsubst tsubst cgsubst tr_subst
+ tr_self sg
+ in
(* Return *)
inst_sig
@@ -839,7 +849,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun =
match rvalue with
| E.Global _ -> raise (Failure "Unreachable")
| E.Use _
- | E.Ref (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow))
+ | E.RvRef (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow))
| E.UnaryOp _ | E.BinaryOp _ | E.Discriminant _
| E.Aggregate _ ->
let rp = rvalue_get_place rvalue in
@@ -896,7 +906,9 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id)
match config.mode with
| ConcreteMode ->
(* Treat the evaluation of the global as a call to the global body (without arguments) *)
- (eval_local_function_call_concrete config global.body_id [] [] [] [] dest)
+ let generics = TypesUtils.mk_empty_generic_args in
+ (eval_transparent_function_call_concrete config global.body_id generics []
+ dest)
cf ctx
| SymbolicMode ->
(* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be
@@ -1040,26 +1052,26 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun =
(** Evaluate a function call (auxiliary helper for [eval_statement]) *)
and eval_function_call (config : C.config) (call : A.call) : st_cm_fun =
- (* There are two cases:
+ (* There are several cases:
- this is a local function, in which case we execute its body
- - this is a non-local function, in which case there is a special treatment
+ - this is an assumed function, in which case there is a special treatment
+ - this is a trait method
*)
match call.func with
- | A.Regular fid ->
- eval_local_function_call config fid call.region_args call.type_args
- call.const_generic_args call.args call.dest
- | A.Assumed fid ->
- eval_non_local_function_call config fid call.region_args call.type_args
- call.const_generic_args call.args call.dest
+ | A.FunId (A.Regular fid) ->
+ eval_transparent_function_call config fid call.generics call.args
+ call.dest
+ | A.FunId (A.Assumed fid) ->
+ eval_assumed_function_call config fid call.generics call.args call.dest
+ | A.TraitMethod _ -> raise (Failure "Unimplemented")
(** Evaluate a local (i.e., non-assumed) function call in concrete mode *)
-and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
- (_region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+and eval_transparent_function_call_concrete (config : C.config)
+ (fid : A.FunDeclId.id) (generics : T.egeneric_args) (args : E.operand list)
+ (dest : E.place) : st_cm_fun =
(* Sanity check: we don't fully handle the const generic vars environment
in concrete mode yet *)
- assert (cg_args = []);
+ assert (generics.const_generics = []);
fun cf ctx ->
(* Retrieve the (correctly instantiated) body *)
let def = C.ctx_lookup_fun_decl ctx fid in
@@ -1073,16 +1085,14 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
^ Print.name_to_string def.name))
| Some body -> body
in
- let tsubst =
- Subst.make_type_subst_from_vars def.A.signature.type_params type_args
- in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars
- def.A.signature.const_generic_params cg_args
- in
- let locals, body_st =
- Subst.fun_body_substitute_in_body tsubst cgsubst body
+ (* TODO: we need to normalize the types if we want to correctly support traits *)
+ assert (ctx.trait_clauses = [] && generics.trait_refs = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_esubst_from_generics def.A.signature.generics generics tr_self
in
+ let locals, body_st = Subst.fun_body_substitute_in_body subst body in
(* Evaluate the input operands *)
assert (List.length args = body.A.arg_count);
@@ -1139,22 +1149,23 @@ and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
cc cf ctx
(** Evaluate a local (i.e., non-assumed) function call in symbolic mode *)
-and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+and eval_transparent_function_call_symbolic (config : C.config)
+ (fid : A.FunDeclId.id) (generics : T.egeneric_args) (args : E.operand list)
+ (dest : E.place) : st_cm_fun =
fun cf ctx ->
(* Retrieve the (correctly instantiated) signature *)
let def = C.ctx_lookup_fun_decl ctx fid in
let sg = def.A.signature in
(* Instantiate the signature and introduce fresh abstraction and region ids
* while doing so *)
- let inst_sg = instantiate_fun_sig type_args cg_args sg in
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let inst_sg = instantiate_fun_sig ctx generics tr_self sg in
(* Sanity check *)
assert (List.length args = List.length def.A.signature.inputs);
(* Evaluate the function call *)
eval_function_call_symbolic_from_inst_sig config (A.Regular fid) inst_sg
- region_args type_args cg_args args dest cf ctx
+ generics args dest cf ctx
(** Evaluate a function call in symbolic mode by using the function signature.
@@ -1162,10 +1173,8 @@ and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id)
calls in symbolic mode: only their signatures matter.
*)
and eval_function_call_symbolic_from_inst_sig (config : C.config)
- (fid : A.fun_id) (inst_sg : A.inst_fun_sig)
- (_region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+ (fid : A.fun_id) (inst_sg : A.inst_fun_sig) (generics : T.egeneric_args)
+ (args : E.operand list) (dest : E.place) : st_cm_fun =
fun cf ctx ->
(* Generate a fresh symbolic value for the return value *)
let ret_sv_ty = inst_sg.A.output in
@@ -1232,8 +1241,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config)
let expr = cf ctx in
(* Synthesize the symbolic AST *)
- S.synthesize_regular_function_call fid call_id ctx abs_ids type_args cg_args
- args args_places ret_spc dest_place expr
+ S.synthesize_regular_function_call fid call_id ctx abs_ids generics args
+ args_places ret_spc dest_place expr
in
let cc = comp cc cf_call in
@@ -1294,9 +1303,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config)
cc (cf Unit) ctx
(** Evaluate a non-local function call in symbolic mode *)
-and eval_non_local_function_call_symbolic (config : C.config)
- (fid : A.assumed_fun_id) (region_args : T.erased_region list)
- (type_args : T.ety list) (cg_args : T.const_generic list)
+and eval_assumed_function_call_symbolic (config : C.config)
+ (fid : A.assumed_fun_id) (generics : T.egeneric_args)
(args : E.operand list) (dest : E.place) : st_cm_fun =
fun cf ctx ->
(* Sanity check: make sure the type parameters don't contain regions -
@@ -1304,7 +1312,7 @@ and eval_non_local_function_call_symbolic (config : C.config)
assert (
List.for_all
(fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty))
- type_args);
+ generics.types);
(* There are two cases (and this is extremely annoying):
- the function is not box_free
@@ -1315,7 +1323,7 @@ and eval_non_local_function_call_symbolic (config : C.config)
| A.BoxFree ->
(* Degenerate case: box_free - note that this is not really a function
* call: no need to call a "synthesize_..." function *)
- eval_box_free config region_args type_args cg_args args dest (cf Unit) ctx
+ eval_box_free config generics args dest (cf Unit) ctx
| _ ->
(* "Normal" case: not box_free *)
(* In symbolic mode, the behaviour of a function call is completely defined
@@ -1327,55 +1335,50 @@ and eval_non_local_function_call_symbolic (config : C.config)
(* should have been treated above *)
raise (Failure "Unreachable")
| _ ->
- instantiate_fun_sig type_args cg_args (Assumed.get_assumed_sig fid)
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ instantiate_fun_sig ctx generics tr_self
+ (Assumed.get_assumed_sig fid)
in
(* Evaluate the function call *)
eval_function_call_symbolic_from_inst_sig config (A.Assumed fid) inst_sig
- region_args type_args cg_args args dest cf ctx
+ generics args dest cf ctx
(** Evaluate a non-local (i.e, assumed) function call such as [Box::deref]
(auxiliary helper for [eval_statement]) *)
-and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
+and eval_assumed_function_call (config : C.config) (fid : A.assumed_fun_id)
+ (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) :
st_cm_fun =
fun cf ctx ->
(* Debug *)
log#ldebug
(lazy
- (let type_args =
- "[" ^ String.concat ", " (List.map (ety_to_string ctx) type_args) ^ "]"
- in
+ (let generics = PCtx.egeneric_args_to_string ctx generics in
let args =
"[" ^ String.concat ", " (List.map (operand_to_string ctx) args) ^ "]"
in
let dest = place_to_string ctx dest in
- "eval_non_local_function_call:\n- fid:" ^ A.show_assumed_fun_id fid
- ^ "\n- type_args: " ^ type_args ^ "\n- args: " ^ args ^ "\n- dest: "
- ^ dest));
+ "eval_assumed_function_call:\n- fid:" ^ A.show_assumed_fun_id fid
+ ^ "\n- generics: " ^ generics ^ "\n- args: " ^ args ^ "\n- dest: " ^ dest));
match config.mode with
| C.ConcreteMode ->
- eval_non_local_function_call_concrete config fid region_args type_args
- cg_args args dest (cf Unit) ctx
+ eval_assumed_function_call_concrete config fid generics args dest
+ (cf Unit) ctx
| C.SymbolicMode ->
- eval_non_local_function_call_symbolic config fid region_args type_args
- cg_args args dest cf ctx
+ eval_assumed_function_call_symbolic config fid generics args dest cf ctx
(** Evaluate a local (i.e, not assumed) function call (auxiliary helper for
[eval_statement]) *)
-and eval_local_function_call (config : C.config) (fid : A.FunDeclId.id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
+and eval_transparent_function_call (config : C.config) (fid : A.FunDeclId.id)
+ (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) :
st_cm_fun =
match config.mode with
| ConcreteMode ->
- eval_local_function_call_concrete config fid region_args type_args cg_args
- args dest
+ eval_transparent_function_call_concrete config fid generics args dest
| SymbolicMode ->
- eval_local_function_call_symbolic config fid region_args type_args cg_args
- args dest
+ eval_transparent_function_call_symbolic config fid generics args dest
(** Evaluate a statement seen as a function body *)
and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun =
diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli
index 814bc964..0a086fb2 100644
--- a/compiler/InterpreterStatements.mli
+++ b/compiler/InterpreterStatements.mli
@@ -32,7 +32,11 @@ val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun
Note: there are no region parameters, because they should be erased.
*)
val instantiate_fun_sig :
- T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig
+ C.eval_ctx ->
+ T.egeneric_args ->
+ T.rtrait_instance_id ->
+ LA.fun_sig ->
+ LA.inst_fun_sig
(** Helper.
diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml
index 637f1b1e..1513465c 100644
--- a/compiler/InterpreterUtils.ml
+++ b/compiler/InterpreterUtils.ml
@@ -273,7 +273,7 @@ let rvalue_get_place (rv : E.rvalue) : E.place option =
match rv with
| Use (Copy p | Move p) -> Some p
| Use (Constant _) -> None
- | Ref (p, _) -> Some p
+ | RvRef (p, _) -> Some p
| UnaryOp _ | BinaryOp _ | Global _ | Discriminant _ | Aggregate _ -> None
(** See {!ValuesUtils.symbolic_value_has_borrows} *)
diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml
index f29c7f88..9ac5ce13 100644
--- a/compiler/Invariants.ml
+++ b/compiler/Invariants.ml
@@ -7,6 +7,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module A = LlbcAst
module L = Logging
open Cps
@@ -406,13 +407,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(match (tv.V.value, tv.V.ty) with
| V.Literal cv, T.Literal ty -> check_literal_type cv ty
(* ADT case *)
- | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
+ | V.Adt av, T.Adt (T.AdtId def_id, generics) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
(* Check the number of parameters *)
- assert (List.length regions = List.length def.region_params);
- assert (List.length tys = List.length def.type_params);
+ assert (
+ List.length generics.regions = List.length def.generics.regions);
+ assert (List.length generics.types = List.length def.generics.types);
(* Check that the variant id is consistent *)
(match (av.V.variant_id, def.T.kind) with
| Some variant_id, T.Enum variants ->
@@ -421,8 +423,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
| _ -> raise (Failure "Erroneous typing"));
(* Check that the field types are correct *)
let field_types =
- Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id
- tys cgs
+ Assoc.type_decl_get_inst_norm_field_etypes ctx def av.V.variant_id
+ generics
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -431,20 +433,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) ->
- assert (regions = []);
- assert (cgs = []);
+ | V.Adt av, T.Adt (T.Tuple, generics) ->
+ assert (generics.regions = []);
+ assert (generics.const_generics = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
- let fields_with_types = List.combine av.V.field_values tys in
+ let fields_with_types =
+ List.combine av.V.field_values generics.types
+ in
List.iter
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
+ | V.Adt av, T.Adt (T.Assumed aty_id, generics) -> (
assert (av.V.variant_id = None || aty_id = T.Option);
- match (aty_id, av.V.field_values, regions, tys, cgs) with
+ match
+ ( aty_id,
+ av.V.field_values,
+ generics.regions,
+ generics.types,
+ generics.const_generics )
+ with
(* Box *)
| T.Box, [ inner_value ], [], [ inner_ty ], []
| T.Option, [ inner_value ], [], [ inner_ty ], [] ->
@@ -520,14 +530,17 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(* Check the current pair (value, type) *)
(match (atv.V.value, atv.V.ty) with
(* ADT case *)
- | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
+ | V.AAdt av, T.Adt (T.AdtId def_id, generics) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
(* Check the number of parameters *)
- assert (List.length regions = List.length def.region_params);
- assert (List.length tys = List.length def.type_params);
- assert (List.length cgs = List.length def.const_generic_params);
+ assert (
+ List.length generics.regions = List.length def.generics.regions);
+ assert (List.length generics.types = List.length def.generics.types);
+ assert (
+ List.length generics.const_generics
+ = List.length def.generics.const_generics);
(* Check that the variant id is consistent *)
(match (av.V.variant_id, def.T.kind) with
| Some variant_id, T.Enum variants ->
@@ -536,8 +549,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
| _ -> raise (Failure "Erroneous typing"));
(* Check that the field types are correct *)
let field_types =
- Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id
- regions tys cgs
+ Assoc.type_decl_get_inst_norm_field_rtypes ctx def av.V.variant_id
+ generics
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -546,20 +559,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) ->
- assert (regions = []);
- assert (cgs = []);
+ | V.AAdt av, T.Adt (T.Tuple, generics) ->
+ assert (generics.regions = []);
+ assert (generics.const_generics = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
- let fields_with_types = List.combine av.V.field_values tys in
+ let fields_with_types =
+ List.combine av.V.field_values generics.types
+ in
List.iter
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
+ | V.AAdt av, T.Adt (T.Assumed aty_id, generics) -> (
assert (av.V.variant_id = None);
- match (aty_id, av.V.field_values, regions, tys, cgs) with
+ match
+ ( aty_id,
+ av.V.field_values,
+ generics.regions,
+ generics.types,
+ generics.const_generics )
+ with
(* Box *)
| T.Box, [ boxed_value ], [], [ boxed_ty ], [] ->
assert (boxed_value.V.ty = boxed_ty)
diff --git a/compiler/Logging.ml b/compiler/Logging.ml
index 9dc1f5e3..d0f5b0c5 100644
--- a/compiler/Logging.ml
+++ b/compiler/Logging.ml
@@ -57,6 +57,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows"
(** Logger for Invariants *)
let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants"
+(** Logger for AssociatedTypes *)
+let associated_types_log = L.get_logger "MainLogger.AssociatedTypes"
+
(** Logger for SCC *)
let scc_log = L.get_logger "MainLogger.Graph.SCC"
diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml
index b348ba1d..1058fab0 100644
--- a/compiler/PrePasses.ml
+++ b/compiler/PrePasses.ml
@@ -107,7 +107,7 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl =
false
| Assign (_, rv) -> (
match rv with
- | Use _ | Ref _ -> not must_end_with_exit
+ | Use _ | RvRef _ -> not must_end_with_exit
| Aggregate (AggregatedTuple, []) -> not must_end_with_exit
| _ -> false)
| FakeRead _ | Drop _ | Nop -> not must_end_with_exit
@@ -376,7 +376,7 @@ let remove_shallow_borrows (crate : A.crate) (f : A.fun_decl) : A.fun_decl =
method! visit_Assign env p rv =
match (p.projection, rv) with
- | [], E.Ref (_, E.Shallow) ->
+ | [], E.RvRef (_, E.Shallow) ->
(* Filter *)
filtered := E.VarId.Set.add p.var_id !filtered;
Nop
diff --git a/compiler/Print.ml b/compiler/Print.ml
index 9aa73d7c..aebfd09c 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -21,6 +21,9 @@ module Values = struct
type_decl_id_to_string : T.TypeDeclId.id -> string;
const_generic_var_id_to_string : T.ConstGenericVarId.id -> string;
global_decl_id_to_string : T.GlobalDeclId.id -> string;
+ trait_decl_id_to_string : T.TraitDeclId.id -> string;
+ trait_impl_id_to_string : T.TraitImplId.id -> string;
+ trait_clause_id_to_string : T.TraitClauseId.id -> string;
adt_variant_to_string : T.TypeDeclId.id -> T.VariantId.id -> string;
var_id_to_string : E.VarId.id -> string;
adt_field_names :
@@ -34,6 +37,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let value_to_rtype_formatter (fmt : value_formatter) : PT.rtype_formatter =
@@ -43,6 +49,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let value_to_stype_formatter (fmt : value_formatter) : PT.stype_formatter =
@@ -52,6 +61,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let var_id_to_string (id : E.VarId.id) : string =
@@ -86,10 +98,10 @@ module Values = struct
List.map (typed_value_to_string fmt) av.field_values
in
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | T.Adt (T.AdtId def_id, _, _, _) ->
+ | T.Adt (T.AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match av.variant_id with
@@ -111,7 +123,7 @@ module Values = struct
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | T.Adt (T.Assumed aty, _, _, _) -> (
+ | T.Adt (T.Assumed aty, _) -> (
(* Assumed type *)
match (aty, field_values) with
| Box, [ bv ] -> "@Box(" ^ bv ^ ")"
@@ -201,10 +213,10 @@ module Values = struct
List.map (typed_avalue_to_string fmt) av.field_values
in
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | T.Adt (T.AdtId def_id, _, _, _) ->
+ | T.Adt (T.AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match av.variant_id with
@@ -226,7 +238,7 @@ module Values = struct
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | T.Adt (T.Assumed aty, _, _, _) -> (
+ | T.Adt (T.Assumed aty, _) -> (
(* Assumed type *)
match (aty, field_values) with
| Box, [ bv ] -> "@Box(" ^ bv ^ ")"
@@ -452,6 +464,9 @@ module Contexts = struct
PV.adt_variant_to_string = fmt.adt_variant_to_string;
PV.var_id_to_string = fmt.var_id_to_string;
PV.adt_field_names = fmt.adt_field_names;
+ PV.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PV.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PV.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let ast_to_value_formatter (fmt : PA.ast_formatter) : PV.value_formatter =
@@ -486,6 +501,15 @@ module Contexts = struct
let def = C.ctx_lookup_global_decl ctx def_id in
name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_decl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_impl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in
let adt_variant_to_string =
PT.type_ctx_to_adt_variant_to_string_fun ctx.type_context.type_decls
in
@@ -506,6 +530,9 @@ module Contexts = struct
adt_variant_to_string;
var_id_to_string;
adt_field_names;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
let eval_ctx_to_ast_formatter (ctx : C.eval_ctx) : PA.ast_formatter =
@@ -521,6 +548,15 @@ module Contexts = struct
let def = C.ctx_lookup_global_decl ctx def_id in
global_name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_decl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_impl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in
{
rvar_to_string = ctx_fmt.PV.rvar_to_string;
r_to_string = ctx_fmt.PV.r_to_string;
@@ -533,6 +569,9 @@ module Contexts = struct
adt_field_to_string;
fun_decl_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
(** Split an [env] at every occurrence of [Frame], eliminating those elements.
@@ -608,6 +647,12 @@ module EvalCtxLlbcAst = struct
let fmt = PC.ctx_to_rtype_formatter fmt in
PT.rty_to_string fmt t
+ let egeneric_args_to_string (ctx : C.eval_ctx) (x : T.egeneric_args) : string
+ =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.egeneric_args_to_string fmt x
+
let borrow_content_to_string (ctx : C.eval_ctx) (bc : V.borrow_content) :
string =
let fmt = PC.eval_ctx_to_ctx_formatter ctx in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index dfb2c9fd..724f1e0a 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -8,6 +8,9 @@ type type_formatter = {
type_decl_id_to_string : TypeDeclId.id -> string;
const_generic_var_id_to_string : ConstGenericVarId.id -> string;
global_decl_id_to_string : GlobalDeclId.id -> string;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
type value_formatter = {
@@ -18,6 +21,9 @@ type value_formatter = {
adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string;
var_id_to_string : VarId.id -> string;
adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
let value_to_type_formatter (fmt : value_formatter) : type_formatter =
@@ -26,6 +32,9 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter =
type_decl_id_to_string = fmt.type_decl_id_to_string;
const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
global_decl_id_to_string = fmt.global_decl_id_to_string;
+ trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
(* TODO: we need to store which variables we have encountered so far, and
@@ -42,6 +51,9 @@ type ast_formatter = {
adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option;
fun_decl_id_to_string : FunDeclId.id -> string;
global_decl_id_to_string : GlobalDeclId.id -> string;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
let ast_to_value_formatter (fmt : ast_formatter) : value_formatter =
@@ -53,6 +65,9 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter =
adt_variant_to_string = fmt.adt_variant_to_string;
var_id_to_string = fmt.var_id_to_string;
adt_field_names = fmt.adt_field_names;
+ trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let ast_to_type_formatter (fmt : ast_formatter) : type_formatter =
@@ -70,31 +85,51 @@ let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string
let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string
let literal_to_string = Print.PrimitiveValues.literal_to_string
+(* Remark: not using generic_params on purpose, because we may use parameters
+ which either come from LLBC or from pure, and the [generic_params] type
+ for those ASTs is not the same. Note that it works because we actually don't
+ need to know the trait clauses to print the AST: we can thus ignore them.
+*)
let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
(global_decls : A.global_decl GlobalDeclId.Map.t)
- (type_params : type_var list)
+ (trait_decls : A.trait_decl TraitDeclId.Map.t)
+ (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list)
(const_generic_params : const_generic_var list) : type_formatter =
let type_var_id_to_string vid =
- let var = T.TypeVarId.nth type_params vid in
+ let var = TypeVarId.nth type_params vid in
type_var_to_string var
in
let const_generic_var_id_to_string vid =
- let var = T.ConstGenericVarId.nth const_generic_params vid in
+ let var = ConstGenericVarId.nth const_generic_params vid in
const_generic_var_to_string var
in
let type_decl_id_to_string def_id =
- let def = T.TypeDeclId.Map.find def_id type_decls in
+ let def = TypeDeclId.Map.find def_id type_decls in
name_to_string def.name
in
let global_decl_id_to_string def_id =
- let def = T.GlobalDeclId.Map.find def_id global_decls in
+ let def = GlobalDeclId.Map.find def_id global_decls in
+ name_to_string def.name
+ in
+ let trait_decl_id_to_string def_id =
+ let def = TraitDeclId.Map.find def_id trait_decls in
name_to_string def.name
in
+ let trait_impl_id_to_string def_id =
+ let def = TraitImplId.Map.find def_id trait_impls in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id =
+ Print.PT.trait_clause_id_to_pretty_string id
+ in
{
type_var_id_to_string;
type_decl_id_to_string;
const_generic_var_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
(* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter.
@@ -106,7 +141,8 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
(fun_decls : A.fun_decl FunDeclId.Map.t)
(global_decls : A.global_decl GlobalDeclId.Map.t)
- (type_params : type_var list)
+ (trait_decls : A.trait_decl TraitDeclId.Map.t)
+ (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list)
(const_generic_params : const_generic_var list) : ast_formatter =
let type_var_id_to_string vid =
let var = T.TypeVarId.nth type_params vid in
@@ -141,6 +177,17 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
let def = GlobalDeclId.Map.find def_id global_decls in
global_name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = TraitDeclId.Map.find def_id trait_decls in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = TraitImplId.Map.find def_id trait_impls in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id =
+ Print.PT.trait_clause_id_to_pretty_string id
+ in
{
type_var_id_to_string;
const_generic_var_id_to_string;
@@ -151,6 +198,9 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
adt_field_to_string;
fun_decl_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
let assumed_ty_to_string (aty : assumed_ty) : string =
@@ -182,20 +232,18 @@ let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) :
let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string =
match ty with
- | Adt (id, tys, cgs) -> (
- let tys = List.map (ty_to_string fmt false) tys in
- let cgs = List.map (const_generic_to_string fmt) cgs in
- let params = List.append tys cgs in
+ | Adt (id, generics) -> (
match id with
| Tuple ->
- assert (cgs = []);
- "(" ^ String.concat " * " tys ^ ")"
+ let generics = generic_args_to_strings fmt false generics in
+ "(" ^ String.concat " * " generics ^ ")"
| AdtId _ | Assumed _ ->
- let params_s =
- if params = [] then "" else " " ^ String.concat " " params
+ let generics = generic_args_to_strings fmt true generics in
+ let generics_s =
+ if generics = [] then "" else " " ^ String.concat " " generics
in
- let ty_s = type_id_to_string fmt id ^ params_s in
- if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
+ let ty_s = type_id_to_string fmt id ^ generics_s in
+ if generics <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
| TypeVar tv -> fmt.type_var_id_to_string tv
| Literal lty -> literal_type_to_string lty
| Arrow (arg_ty, ret_ty) ->
@@ -204,6 +252,62 @@ let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string =
in
if inside then "(" ^ ty ^ ")" else ty
+and generic_args_to_strings (fmt : type_formatter) (inside : bool)
+ (generics : generic_args) : string list =
+ let tys = List.map (ty_to_string fmt inside) generics.types in
+ let cgs = List.map (const_generic_to_string fmt) generics.const_generics in
+ let trait_refs =
+ List.map (trait_ref_to_string fmt inside) generics.trait_refs
+ in
+ List.concat [ tys; cgs; trait_refs ]
+
+and generic_args_to_string (fmt : type_formatter) (generics : generic_args) :
+ string =
+ String.concat " " (generic_args_to_strings fmt true generics)
+
+and trait_ref_to_string (fmt : type_formatter) (inside : bool) (tr : trait_ref)
+ : string =
+ let trait_id = trait_instance_id_to_string fmt false tr.trait_id in
+ let generics = generic_args_to_string fmt tr.generics in
+ let s = trait_id ^ generics in
+ if tr.generics = empty_generic_args || not inside then s else "(" ^ s ^ ")"
+
+and trait_instance_id_to_string (fmt : type_formatter) (inside : bool)
+ (id : trait_instance_id) : string =
+ match id with
+ | Self -> "Self"
+ | TraitImpl id -> fmt.trait_impl_id_to_string id
+ | Clause id -> fmt.trait_clause_id_to_string id
+ | ParentClause (inst_id, clause_id) ->
+ let inst_id = trait_instance_id_to_string fmt false inst_id in
+ let clause_id = fmt.trait_clause_id_to_string clause_id in
+ "parent(" ^ inst_id ^ ")::" ^ clause_id
+ | ItemClause (inst_id, item_name, clause_id) ->
+ let inst_id = trait_instance_id_to_string fmt false inst_id in
+ let clause_id = fmt.trait_clause_id_to_string clause_id in
+ "(" ^ inst_id ^ ")::" ^ item_name ^ "::[" ^ clause_id ^ "]"
+ | TraitRef tr -> trait_ref_to_string fmt inside tr
+ | UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")"
+
+let trait_clause_to_string (fmt : type_formatter) (clause : trait_clause) :
+ string =
+ let clause_id = fmt.trait_clause_id_to_string clause.clause_id in
+ let trait_id = fmt.trait_decl_id_to_string clause.trait_id in
+ let generics = generic_args_to_strings fmt true clause.generics in
+ let generics =
+ if generics = [] then "" else " " ^ String.concat " " generics
+ in
+ "[" ^ clause_id ^ "]: " ^ trait_id ^ generics
+
+let generic_params_to_strings (fmt : type_formatter) (generics : generic_params)
+ : string list =
+ let tys = List.map type_var_to_string generics.types in
+ let cgs = List.map const_generic_var_to_string generics.const_generics in
+ let trait_clauses =
+ List.map (trait_clause_to_string fmt) generics.trait_clauses
+ in
+ List.concat [ tys; cgs; trait_clauses ]
+
let field_to_string fmt inside (f : field) : string =
match f.field_name with
| None -> ty_to_string fmt inside f.field_ty
@@ -217,11 +321,10 @@ let variant_to_string fmt (v : variant) : string =
^ ")"
let type_decl_to_string (fmt : type_formatter) (def : type_decl) : string =
- let types = def.type_params in
let name = name_to_string def.name in
let params =
- if types = [] then ""
- else " " ^ String.concat " " (List.map type_var_to_string types)
+ if def.generics = empty_generic_params then ""
+ else " " ^ String.concat " " (generic_params_to_strings fmt def.generics)
in
match def.kind with
| Struct fields ->
@@ -353,10 +456,10 @@ let adt_g_value_to_string (fmt : value_formatter)
(field_values : 'v list) (ty : ty) : string =
let field_values = List.map value_to_string field_values in
match ty with
- | Adt (Tuple, _, _) ->
+ | Adt (Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | Adt (AdtId def_id, _, _) ->
+ | Adt (AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match variant_id with
@@ -378,7 +481,7 @@ let adt_g_value_to_string (fmt : value_formatter)
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | Adt (Assumed aty, _, _) -> (
+ | Adt (Assumed aty, _) -> (
(* Assumed type *)
match aty with
| State ->
@@ -631,7 +734,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
(* There are two possibilities: either the [app] is an instantiated,
* top-level qualifier (function, ADT constructore...), or it is a "regular"
* expression *)
- let app, tys =
+ let app, generics =
match app.e with
| Qualif qualif ->
(* Qualifier case *)
@@ -656,9 +759,9 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
in
(* Convert the type instantiation *)
let ty_fmt = ast_to_type_formatter fmt in
- let tys = List.map (ty_to_string ty_fmt true) qualif.type_args in
+ let generics = generic_args_to_strings ty_fmt true qualif.generics in
(* *)
- (qualif_s, tys)
+ (qualif_s, generics)
| _ ->
(* "Regular" expression case *)
let inside = args <> [] || (args = [] && inside) in
@@ -673,7 +776,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
texpression_to_string fmt inside indent1 indent_incr
in
let args = List.map arg_to_string args in
- let all_args = List.append tys args in
+ let all_args = List.append generics args in
(* Put together *)
let e =
if all_args = [] then app else app ^ " " ^ String.concat " " all_args
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index 55513cc2..147c14b9 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -13,6 +13,9 @@ module FieldId = T.FieldId
module SymbolicValueId = V.SymbolicValueId
module FunDeclId = A.FunDeclId
module GlobalDeclId = A.GlobalDeclId
+module TraitDeclId = T.TraitDeclId
+module TraitImplId = T.TraitImplId
+module TraitClauseId = T.TraitClauseId
(** We redefine identifiers for loop: in {!Values}, the identifiers are global
(they monotonically increase across functions) while in {!module:Pure} we want
@@ -38,6 +41,10 @@ type integer_type = T.integer_type [@@deriving show, ord]
type const_generic_var = T.const_generic_var [@@deriving show, ord]
type const_generic = T.const_generic [@@deriving show, ord]
type const_generic_var_id = T.const_generic_var_id [@@deriving show, ord]
+type trait_decl_id = T.trait_decl_id [@@deriving show, ord]
+type trait_impl_id = T.trait_impl_id [@@deriving show, ord]
+type trait_clause_id = T.trait_clause_id [@@deriving show, ord]
+type trait_item_name = T.trait_item_name [@@deriving show, ord]
(** The assumed types for the pure AST.
@@ -177,6 +184,14 @@ class ['self] iter_ty_base =
inherit! [_] T.iter_const_generic
inherit! [_] PV.iter_literal_type
method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> ()
+ method visit_trait_decl_id : 'env -> trait_decl_id -> unit = fun _ _ -> ()
+ method visit_trait_impl_id : 'env -> trait_impl_id -> unit = fun _ _ -> ()
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> unit =
+ fun _ _ -> ()
+
+ method visit_trait_item_name : 'env -> trait_item_name -> unit =
+ fun _ _ -> ()
end
(** Ancestor for map visitor for [ty] *)
@@ -186,6 +201,18 @@ class ['self] map_ty_base =
inherit! [_] T.map_const_generic
inherit! [_] PV.map_literal_type
method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id =
+ fun _ x -> x
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id =
+ fun _ x -> x
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> trait_clause_id =
+ fun _ x -> x
+
+ method visit_trait_item_name : 'env -> trait_item_name -> trait_item_name =
+ fun _ x -> x
end
(** Ancestor for reduce visitor for [ty] *)
@@ -195,6 +222,18 @@ class virtual ['self] reduce_ty_base =
inherit! [_] T.reduce_const_generic
inherit! [_] PV.reduce_literal_type
method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_item_name : 'env -> trait_item_name -> 'a =
+ fun _ _ -> self#zero
end
(** Ancestor for mapreduce visitor for [ty] *)
@@ -206,10 +245,24 @@ class virtual ['self] mapreduce_ty_base =
method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a =
fun _ x -> (x, self#zero)
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_clause_id
+ : 'env -> trait_clause_id -> trait_clause_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_item_name
+ : 'env -> trait_item_name -> trait_item_name * 'a =
+ fun _ x -> (x, self#zero)
end
type ty =
- | Adt of type_id * ty list * const_generic list
+ | Adt of type_id * generic_args
(** {!Adt} encodes ADTs and tuples and assumed types.
TODO: what about the ended regions? (ADTs may be parameterized
@@ -220,6 +273,23 @@ type ty =
| TypeVar of type_var_id
| Literal of literal_type
| Arrow of ty * ty
+
+and trait_ref = { trait_id : trait_instance_id; generics : generic_args }
+
+and generic_args = {
+ types : ty list;
+ const_generics : const_generic list;
+ trait_refs : trait_ref list;
+}
+
+and trait_instance_id =
+ | Self
+ | TraitImpl of trait_impl_id
+ | Clause of trait_clause_id
+ | ParentClause of trait_instance_id * trait_clause_id
+ | ItemClause of trait_instance_id * trait_item_name * trait_clause_id
+ | TraitRef of trait_ref
+ | UnknownTrait of string
[@@deriving
show,
visitors
@@ -265,11 +335,24 @@ type type_decl_kind = Struct of field list | Enum of variant list | Opaque
type type_var = T.type_var [@@deriving show]
+type trait_clause = {
+ clause_id : trait_clause_id;
+ trait_id : trait_decl_id;
+ generics : generic_args;
+}
+[@@deriving show]
+
+type generic_params = {
+ types : type_var list;
+ const_generics : const_generic_var list;
+ trait_clauses : trait_clause list;
+}
+[@@deriving show]
+
type type_decl = {
def_id : TypeDeclId.id;
name : name;
- type_params : type_var list;
- const_generic_params : const_generic_var list;
+ generics : generic_params;
kind : type_decl_kind;
}
[@@deriving show]
@@ -463,18 +546,13 @@ type qualif_id =
| Proj of projection (** Field projector *)
[@@deriving show]
-(** An instantiated qualified.
+(** An instantiated qualifier.
Note that for now we have a clear separation between types and expressions,
- which explains why we have the [type_params] field: a function or ADT
+ which explains why we have the [generics] field: a function or ADT
constructor is always fully instantiated.
*)
-type qualif = {
- id : qualif_id;
- type_args : ty list;
- const_generic_args : const_generic list;
-}
-[@@deriving show]
+type qualif = { id : qualif_id; generics : generic_args } [@@deriving show]
type field_id = FieldId.id [@@deriving show, ord]
type var_id = VarId.id [@@deriving show, ord]
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index d145ce93..77b12811 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -9,17 +9,19 @@ open PureUtils
of fields is fixed: it shouldn't be used for arrays, slices, etc.
*)
let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
- (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list)
- (cgs : const_generic list) : ty list =
+ (type_id : type_id) (variant_id : VariantId.id option)
+ (generics : generic_args) : ty list =
match type_id with
| Tuple ->
(* Tuple *)
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
assert (variant_id = None);
- tys
+ generics.types
| AdtId def_id ->
(* "Regular" ADT *)
let def = TypeDeclId.Map.find def_id type_decls in
- type_decl_get_instantiated_fields_types def variant_id tys cgs
+ type_decl_get_instantiated_fields_types def variant_id generics
| Assumed aty -> (
(* Assumed type *)
match aty with
@@ -27,14 +29,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
(* This type is opaque *)
raise (Failure "Unreachable: opaque type")
| Result ->
- let ty = Collections.List.to_cons_nil tys in
+ let ty = Collections.List.to_cons_nil generics.types in
let variant_id = Option.get variant_id in
if variant_id = result_return_id then [ ty ]
else if variant_id = result_fail_id then [ mk_error_ty ]
else
raise (Failure "Unreachable: improper variant id for result type")
| Error ->
- assert (tys = []);
+ assert (generics = empty_generic_args);
let variant_id = Option.get variant_id in
assert (
variant_id = error_failure_id || variant_id = error_out_of_fuel_id);
@@ -45,14 +47,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
else if variant_id = fuel_succ_id then [ mk_fuel_ty ]
else raise (Failure "Unreachable: improper variant id for fuel type")
| Option ->
- let ty = Collections.List.to_cons_nil tys in
+ let ty = Collections.List.to_cons_nil generics.types in
let variant_id = Option.get variant_id in
if variant_id = option_some_id then [ ty ]
else if variant_id = option_none_id then []
else
raise (Failure "Unreachable: improper variant id for option type")
| Range ->
- let ty = Collections.List.to_cons_nil tys in
+ let ty = Collections.List.to_cons_nil generics.types in
assert (variant_id = None);
[ ty; ty ]
| Vec | Array | Slice | Str ->
@@ -88,12 +90,13 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx =
{ ctx with env }
| PatAdt av ->
(* Compute the field types *)
- let type_id, tys, cgs = ty_as_adt v.ty in
+ let type_id, generics = ty_as_adt v.ty in
let field_tys =
- get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs
+ get_adt_field_types ctx.type_decls type_id av.variant_id generics
in
let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx =
if ty <> v.ty then (
+ (* TODO: we need to normalize the types *)
log#serror
("check_typed_pattern: not the same types:" ^ "\n- ty: "
^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty);
@@ -142,31 +145,29 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
(* Note we can only project fields of structures (not enumerations) *)
(* Deconstruct the projector type *)
let adt_ty, field_ty = destruct_arrow e.ty in
- let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in
+ let adt_id, adt_generics = ty_as_adt adt_ty in
(* Check the ADT type *)
assert (adt_id = proj_adt_id);
- assert (adt_type_args = qualif.type_args);
- assert (adt_cg_args = qualif.const_generic_args);
+ assert (adt_generics = qualif.generics);
(* Retrieve and check the expected field type *)
let variant_id = None in
let expected_field_tys =
get_adt_field_types ctx.type_decls proj_adt_id variant_id
- qualif.type_args qualif.const_generic_args
+ qualif.generics
in
let expected_field_ty = FieldId.nth expected_field_tys field_id in
assert (expected_field_ty = field_ty)
| AdtCons id -> (
let expected_field_tys =
get_adt_field_types ctx.type_decls id.adt_id id.variant_id
- qualif.type_args qualif.const_generic_args
+ qualif.generics
in
let field_tys, adt_ty = destruct_arrows e.ty in
assert (expected_field_tys = field_tys);
match adt_ty with
- | Adt (type_id, tys, cgs) ->
+ | Adt (type_id, generics) ->
assert (type_id = id.adt_id);
- assert (tys = qualif.type_args);
- assert (cgs = qualif.const_generic_args)
+ assert (generics = qualif.generics)
| _ -> raise (Failure "Unreachable")))
| Let (monadic, pat, re, e_next) ->
let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in
@@ -212,15 +213,14 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
| Some ty -> assert (ty = e.ty));
(* Check the fields *)
(* Retrieve and check the expected field type *)
- let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in
+ let adt_id, adt_generics = ty_as_adt e.ty in
assert (adt_id = supd.struct_id);
(* The id can only be: a custom type decl or an array *)
match adt_id with
| AdtId _ ->
let variant_id = None in
let expected_field_tys =
- get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args
- adt_cg_args
+ get_adt_field_types ctx.type_decls adt_id variant_id adt_generics
in
List.iter
(fun (fid, fe) ->
@@ -229,7 +229,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx fe)
supd.updates
| Assumed Array ->
- let expected_field_ty = Collections.List.to_cons_nil adt_type_args in
+ let expected_field_ty =
+ Collections.List.to_cons_nil adt_generics.types
+ in
List.iter
(fun (_, fe) ->
assert (expected_field_ty = fe.ty);
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index f099ef9c..1357793b 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -89,14 +89,31 @@ let mk_mplace (var_id : E.VarId.id) (name : string option)
(projection : mprojection) : mplace =
{ var_id; name; projection }
+let empty_generic_params : generic_params =
+ { types = []; const_generics = []; trait_clauses = [] }
+
+let empty_generic_args : generic_args =
+ { types = []; const_generics = []; trait_refs = [] }
+
+let mk_generic_args_from_types (types : ty list) : generic_args =
+ { types; const_generics = []; trait_refs = [] }
+
+type subst = {
+ ty_subst : TypeVarId.id -> ty;
+ cg_subst : ConstGenericVarId.id -> const_generic;
+ tr_subst : TraitClauseId.id -> trait_instance_id;
+ tr_self : trait_instance_id;
+}
+
(** Type substitution *)
-let ty_substitute (tsubst : TypeVarId.id -> ty)
- (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty =
+let ty_substitute (subst : subst) (ty : ty) : ty =
let obj =
object
inherit [_] map_ty
- method! visit_TypeVar _ var_id = tsubst var_id
- method! visit_ConstGenericVar _ var_id = cgsubst var_id
+ method! visit_TypeVar _ var_id = subst.ty_subst var_id
+ method! visit_ConstGenericVar _ var_id = subst.cg_subst var_id
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
end
in
obj#visit_ty () ty
@@ -115,6 +132,18 @@ let make_const_generic_subst (vars : const_generic_var list)
(cgs : const_generic list) : ConstGenericVarId.id -> const_generic =
Substitute.make_const_generic_subst_from_vars vars cgs
+let make_trait_subst (clauses : trait_clause list) (refs : trait_ref list) :
+ TraitClauseId.id -> trait_instance_id =
+ let clauses = List.map (fun x -> x.clause_id) clauses in
+ let refs = List.map (fun x -> TraitRef x) refs in
+ let ls = List.combine clauses refs in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> TraitClauseId.Map.add k v mp)
+ TraitClauseId.Map.empty ls
+ in
+ fun id -> TraitClauseId.Map.find id mp
+
(** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}.
Raises [Invalid_argument] if the arguments are incorrect.
@@ -135,20 +164,27 @@ let type_decl_get_fields (def : type_decl)
- def: " ^ show_type_decl def ^ "\n- opt_variant_id: "
^ opt_variant_id))
+let make_subst_from_generics (params : generic_params) (args : generic_args)
+ (tr_self : trait_instance_id) : subst =
+ let ty_subst = make_type_subst params.types args.types in
+ let cg_subst =
+ make_const_generic_subst params.const_generics args.const_generics
+ in
+ let tr_subst = make_trait_subst params.trait_clauses args.trait_refs in
+ { ty_subst; cg_subst; tr_subst; tr_self }
+
(** Instantiate the type variables for the chosen variant in an ADT definition,
and return the list of the types of its fields *)
let type_decl_get_instantiated_fields_types (def : type_decl)
- (opt_variant_id : VariantId.id option) (types : ty list)
- (cgs : const_generic list) : ty list =
- let ty_subst = make_type_subst def.type_params types in
- let cg_subst = make_const_generic_subst def.const_generic_params cgs in
+ (opt_variant_id : VariantId.id option) (generics : generic_args) : ty list =
+ (* There shouldn't be any reference to Self *)
+ let tr_self = UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.generics generics tr_self in
let fields = type_decl_get_fields def opt_variant_id in
- List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields
+ List.map (fun f -> ty_substitute subst f.field_ty) fields
-let fun_sig_substitute (tsubst : TypeVarId.id -> ty)
- (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) :
- inst_fun_sig =
- let subst = ty_substitute tsubst cgsubst in
+let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig =
+ let subst = ty_substitute subst in
let inputs = List.map subst sg.inputs in
let output = subst sg.output in
let doutputs = List.map subst sg.doutputs in
@@ -194,9 +230,9 @@ let is_global (e : texpression) : bool =
let is_const (e : texpression) : bool =
match e.e with Const _ -> true | _ -> false
-let ty_as_adt (ty : ty) : type_id * ty list * const_generic list =
+let ty_as_adt (ty : ty) : type_id * generic_args =
match ty with
- | Adt (id, tys, cgs) -> (id, tys, cgs)
+ | Adt (id, generics) -> (id, generics)
| _ -> raise (Failure "Unreachable")
(** Remove the external occurrences of {!Meta} *)
@@ -294,28 +330,30 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list =
(** Destruct an expression into a function call, if possible *)
let opt_destruct_function_call (e : texpression) :
- (fun_or_op_id * ty list * texpression list) option =
+ (fun_or_op_id * generic_args * texpression list) option =
match opt_destruct_qualif_app e with
| None -> None
| Some (qualif, args) -> (
match qualif.id with
- | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args)
+ | FunOrOp fun_id -> Some (fun_id, qualif.generics, args)
| _ -> None)
let opt_destruct_result (ty : ty) : ty option =
match ty with
- | Adt (Assumed Result, tys, cgs) ->
- assert (cgs = []);
- Some (Collections.List.to_cons_nil tys)
+ | Adt (Assumed Result, generics) ->
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
+ Some (Collections.List.to_cons_nil generics.types)
| _ -> None
let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty)
let opt_destruct_tuple (ty : ty) : ty list option =
match ty with
- | Adt (Tuple, tys, cgs) ->
- assert (cgs = []);
- Some tys
+ | Adt (Tuple, generics) ->
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
+ Some generics.types
| _ -> None
let mk_abs (x : typed_pattern) (e : texpression) : texpression =
@@ -387,14 +425,16 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
- if there is > one type: wrap them in a tuple
*)
let mk_simpl_tuple_ty (tys : ty list) : ty =
- match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, [])
+ match tys with
+ | [ ty ] -> ty
+ | _ -> Adt (Tuple, mk_generic_args_from_types tys)
let mk_bool_ty : ty = Literal Bool
-let mk_unit_ty : ty = Adt (Tuple, [], [])
+let mk_unit_ty : ty = Adt (Tuple, empty_generic_args)
let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = []; const_generic_args = [] } in
+ let qualif = { id; generics = empty_generic_args } in
let e = Qualif qualif in
let ty = mk_unit_ty in
{ e; ty }
@@ -434,7 +474,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern =
| [ v ] -> v
| _ ->
let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in
- let ty = Adt (Tuple, tys, []) in
+ let ty = Adt (Tuple, mk_generic_args_from_types tys) in
let value = PatAdt { variant_id = None; field_values = vl } in
{ value; ty }
@@ -445,11 +485,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression =
| _ ->
(* Compute the types of the fields, and the type of the tuple constructor *)
let tys = List.map (fun (v : texpression) -> v.ty) vl in
- let ty = Adt (Tuple, tys, []) in
+ let ty = Adt (Tuple, mk_generic_args_from_types tys) in
let ty = mk_arrows tys ty in
(* Construct the tuple constructor qualifier *)
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = tys; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types tys } in
(* Put everything together *)
let cons = { e = Qualif qualif; ty } in
mk_apps cons vl
@@ -467,32 +507,36 @@ let ty_as_integer (t : ty) : T.integer_type =
let ty_as_literal (t : ty) : T.literal_type =
match t with Literal ty -> ty | _ -> raise (Failure "Unreachable")
-let mk_state_ty : ty = Adt (Assumed State, [], [])
-let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], [])
-let mk_error_ty : ty = Adt (Assumed Error, [], [])
-let mk_fuel_ty : ty = Adt (Assumed Fuel, [], [])
+let mk_state_ty : ty = Adt (Assumed State, empty_generic_args)
+
+let mk_result_ty (ty : ty) : ty =
+ Adt (Assumed Result, mk_generic_args_from_types [ ty ])
+
+let mk_error_ty : ty = Adt (Assumed Error, empty_generic_args)
+let mk_fuel_ty : ty = Adt (Assumed Fuel, empty_generic_args)
let mk_error (error : VariantId.id) : texpression =
let ty = mk_error_ty in
let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in
- let qualif = { id; type_args = []; const_generic_args = [] } in
+ let qualif = { id; generics = empty_generic_args } in
let e = Qualif qualif in
{ e; ty }
let unwrap_result_ty (ty : ty) : ty =
match ty with
- | Adt (Assumed Result, [ ty ], cgs) ->
- assert (cgs = []);
+ | Adt
+ (Assumed Result, { types = [ ty ]; const_generics = []; trait_refs = [] })
+ ->
ty
| _ -> raise (Failure "not a result type")
let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression =
let type_args = [ ty ] in
- let ty = Adt (Assumed Result, type_args, []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id }
in
- let qualif = { id; type_args; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types type_args } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow error.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -505,11 +549,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) :
let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
- let ty = Adt (Assumed Result, type_args, []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id }
in
- let qualif = { id; type_args; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types type_args } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow v.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -518,7 +562,7 @@ let mk_result_return_texpression (v : texpression) : texpression =
(** Create a [Fail err] pattern which captures the error *)
let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern =
let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in
- let ty = Adt (Assumed Result, [ ty ], []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types [ ty ]) in
let value =
PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] }
in
@@ -530,7 +574,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern =
mk_result_fail_pattern error_pat ty
let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
- let ty = Adt (Assumed Result, [ v.ty ], []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types [ v.ty ]) in
let value =
PatAdt { variant_id = Some result_return_id; field_values = [ v ] }
in
@@ -565,11 +609,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
let fields_values = List.map (fun e -> Option.get e) fields in
(* Retrieve the type id and the type args from the pat type (simpler this way *)
- let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in
+ let adt_id, generics = ty_as_adt pat.ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args; const_generic_args } in
+ let qualif = { id = qualif_id; generics } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) fields_values
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index 38850243..64e7716a 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -9,51 +9,53 @@ module E = Expressions
module A = LlbcAst
module C = Contexts
-(** Substitute types variables and regions in a type. *)
-let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r1 T.ty) :
- 'r2 T.ty =
- let open T in
- let visitor =
- object
- inherit [_] map_ty
- method visit_'r _ r = rsubst r
- method! visit_TypeVar _ id = tsubst id
-
- method! visit_type_var_id _ _ =
- (* We should never get here because we reimplemented [visit_TypeVar] *)
- raise (Failure "Unexpected")
+type ('r1, 'r2) subst = {
+ r_subst : 'r1 -> 'r2;
+ ty_subst : T.TypeVarId.id -> 'r2 T.ty;
+ cg_subst : T.ConstGenericVarId.id -> T.const_generic;
+ (** Substitution from *local* trait clause to trait instance *)
+ tr_subst : T.TraitClauseId.id -> 'r2 T.trait_instance_id;
+ (** Substitution for the [Self] trait instance *)
+ tr_self : 'r2 T.trait_instance_id;
+}
+
+let ty_substitute_visitor (subst : ('r1, 'r2) subst) =
+ object
+ inherit [_] T.map_ty
+ method visit_'r _ r = subst.r_subst r
+ method! visit_TypeVar _ id = subst.ty_subst id
- method! visit_ConstGenericVar _ id = cgsubst id
+ method! visit_type_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_TypeVar] *)
+ raise (Failure "Unexpected")
- method! visit_const_generic_var_id _ _ =
- (* We should never get here because we reimplemented [visit_Var] *)
- raise (Failure "Unexpected")
- end
- in
+ method! visit_ConstGenericVar _ id = subst.cg_subst id
- visitor#visit_ty () ty
+ method! visit_const_generic_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_Var] *)
+ raise (Failure "Unexpected")
-let rty_substitute (rsubst : T.RegionId.id -> T.RegionId.id)
- (tsubst : T.TypeVarId.id -> T.rty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.rty) : T.rty =
- let rsubst r =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
- in
- ty_substitute rsubst tsubst cgsubst ty
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
+ end
-let ety_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.ety) : T.ety =
- let rsubst r = r in
- ty_substitute rsubst tsubst cgsubst ty
+(** Substitute types variables and regions in a type. *)
+let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_ty () ty
(** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *)
let erase_regions (ty : T.rty) : T.ety =
- ty_substitute
- (fun _ -> T.Erased)
- (fun vid -> T.TypeVar vid)
- (fun id -> T.ConstGenericVar id)
- ty
+ let subst =
+ {
+ r_subst = (fun _ -> T.Erased);
+ ty_subst = (fun vid -> T.TypeVar vid);
+ cg_subst = (fun id -> T.ConstGenericVar id);
+ tr_subst = (fun id -> T.Clause id);
+ tr_self = T.Self;
+ }
+ in
+ ty_substitute subst ty
(** Generate fresh regions for region variables.
@@ -78,18 +80,20 @@ let fresh_regions_with_substs (region_vars : T.region_var list) :
(* Generate the substitution from region var id to region *)
let rid_subst id = T.RegionVarId.Map.find id rid_map in
(* Generate the substitution from region to region *)
- let rsubst r =
+ let r_subst r =
match r with T.Static -> T.Static | T.Var id -> T.Var (rid_subst id)
in
(* Return *)
- (fresh_region_ids, rid_subst, rsubst)
+ (fresh_region_ids, rid_subst, r_subst)
-(** Erase the regions in a type and substitute the type variables *)
-let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic)
- (ty : 'r T.region T.ty) : T.ety =
- let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in
- ty_substitute rsubst tsubst cgsubst ty
+(** Erase the regions in a type and perform a substitution *)
+let erase_regions_substitute_types (ty_subst : T.TypeVarId.id -> T.ety)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.etrait_instance_id)
+ (tr_self : T.etrait_instance_id) (ty : 'r T.ty) : T.ety =
+ let r_subst (_ : 'r) : T.erased_region = T.Erased in
+ let subst = { r_subst; ty_subst; cg_subst; tr_subst; tr_self } in
+ ty_substitute subst ty
(** Create a region substitution from a list of region variable ids and a list of
regions (with which to substitute the region variable ids *)
@@ -146,16 +150,62 @@ let make_const_generic_subst_from_vars (vars : T.const_generic_var list)
(List.map (fun (x : T.const_generic_var) -> x.T.index) vars)
cgs
-(** Instantiate the type variables in an ADT definition, and return, for
- every variant, the list of the types of its fields *)
-let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : (T.VariantId.id option * T.rty list) list =
- let r_subst = make_region_subst_from_vars def.T.region_params regions in
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
+(** Create a trait substitution from a list of trait clause ids and a list of
+ trait refs *)
+let make_trait_subst (clause_ids : T.TraitClauseId.id list)
+ (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id =
+ let ls = List.combine clause_ids trs in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> T.TraitClauseId.Map.add k (T.TraitRef v) mp)
+ T.TraitClauseId.Map.empty ls
+ in
+ fun id -> T.TraitClauseId.Map.find id mp
+
+let make_trait_subst_from_clauses (clauses : T.trait_clause list)
+ (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id =
+ make_trait_subst
+ (List.map (fun (x : T.trait_clause) -> x.T.clause_id) clauses)
+ trs
+
+let make_subst_from_generics (params : T.generic_params)
+ (args : 'r T.generic_args) (tr_self : 'r T.trait_instance_id) :
+ (T.region_var_id T.region, 'r) subst =
+ let r_subst = make_region_subst_from_vars params.T.regions args.T.regions in
+ let ty_subst = make_type_subst_from_vars params.T.types args.T.types in
+ let cg_subst =
+ make_const_generic_subst_from_vars params.T.const_generics
+ args.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs
+ in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+let make_esubst_from_generics (params : T.generic_params)
+ (generics : T.egeneric_args) (tr_self : T.etrait_instance_id) =
+ let r_subst _ = T.Erased in
+ let ty_subst = make_type_subst_from_vars params.types generics.T.types in
let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
+ make_const_generic_subst_from_vars params.const_generics
+ generics.T.const_generics
in
+ let tr_subst =
+ make_trait_subst_from_clauses params.trait_clauses generics.T.trait_refs
+ in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+(** Instantiate the type variables in an ADT definition, and return, for
+ every variant, the list of the types of its fields.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
+let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
+ (generics : T.rgeneric_args) : (T.VariantId.id option * T.rty list) list =
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.T.generics generics tr_self in
let (variants_fields : (T.VariantId.id option * T.field list) list) =
match def.T.kind with
| T.Enum variants ->
@@ -171,191 +221,218 @@ let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
in
List.map
(fun (id, fields) ->
- ( id,
- List.map
- (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty)
- fields ))
+ (id, List.map (fun f -> ty_substitute subst f.T.field_ty) fields))
variants_fields
(** Instantiate the type variables in an ADT definition, and return the list
- of types of the fields for the chosen variant *)
+ of types of the fields for the chosen variant.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let type_decl_get_instantiated_field_rtypes (def : T.type_decl)
- (opt_variant_id : T.VariantId.id option)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : T.rty list =
- let r_subst = make_region_subst_from_vars def.T.region_params regions in
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
- let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
- in
+ (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) :
+ T.rty list =
+ (* For now, check that there are no clauses - otherwise we might need
+ to normalize the types *)
+ assert (def.generics.trait_clauses = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.T.generics generics tr_self in
let fields = TU.type_decl_get_fields def opt_variant_id in
- List.map
- (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty)
- fields
+ List.map (fun f -> ty_substitute subst f.T.field_ty) fields
(** Return the types of the properly instantiated ADT's variant, provided a
- context *)
+ context.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : T.rty list =
+ (generics : T.rgeneric_args) : T.rty list =
let def = C.ctx_lookup_type_decl ctx def_id in
- type_decl_get_instantiated_field_rtypes def opt_variant_id regions types cgs
+ type_decl_get_instantiated_field_rtypes def opt_variant_id generics
(** Return the types of the properly instantiated ADT value (note that
- here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *)
+ here, ADT is understood in its broad meaning: ADT, assumed value or tuple).
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+ *)
let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx)
- (adt : V.adt_value) (id : T.type_id)
- (region_params : T.RegionId.id T.region list) (type_params : T.rty list)
- (cg_params : T.const_generic list) : T.rty list =
+ (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) :
+ T.rty list =
match id with
| T.AdtId id ->
(* Retrieve the types of the fields *)
- ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id
- region_params type_params cg_params
+ ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id generics
| T.Tuple ->
- assert (List.length region_params = 0);
- type_params
+ assert (generics.regions = []);
+ generics.types
| T.Assumed aty -> (
match aty with
| T.Box | T.Vec ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- type_params
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ generics.types
| T.Option ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- if adt.V.variant_id = Some T.option_some_id then type_params
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ if adt.V.variant_id = Some T.option_some_id then generics.types
else if adt.V.variant_id = Some T.option_none_id then []
else raise (Failure "Unreachable")
| T.Range ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- type_params
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ generics.types
| T.Array | T.Slice | T.Str ->
(* Those types don't have fields *)
raise (Failure "Unreachable"))
(** Instantiate the type variables in an ADT definition, and return the list
- of types of the fields for the chosen variant *)
+ of types of the fields for the chosen variant.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let type_decl_get_instantiated_field_etypes (def : T.type_decl)
- (opt_variant_id : T.VariantId.id option) (types : T.ety list)
- (cgs : T.const_generic list) : T.ety list =
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
- let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
+ (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) :
+ T.ety list =
+ (* For now, check that there are no clauses - otherwise we might need
+ to normalize the types *)
+ assert (def.generics.trait_clauses = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self : T.erased_region T.trait_instance_id =
+ T.UnknownTrait __FUNCTION__
+ in
+ let { r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } =
+ make_esubst_from_generics def.T.generics generics tr_self
in
let fields = TU.type_decl_get_fields def opt_variant_id in
List.map
- (fun f -> erase_regions_substitute_types ty_subst cg_subst f.T.field_ty)
+ (fun (f : T.field) ->
+ erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self
+ f.T.field_ty)
fields
(** Return the types of the properly instantiated ADT's variant, provided a
- context *)
+ context.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+ *)
let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (types : T.ety list) (cgs : T.const_generic list) : T.ety list =
+ (generics : T.egeneric_args) : T.ety list =
let def = C.ctx_lookup_type_decl ctx def_id in
- type_decl_get_instantiated_field_etypes def opt_variant_id types cgs
+ type_decl_get_instantiated_field_etypes def opt_variant_id generics
-let statement_substitute_visitor (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) =
+let statement_substitute_visitor
+ (subst : (T.erased_region, T.erased_region) subst) =
+ (* Keep in synch with [ty_substitute_visitor] *)
object
inherit [_] A.map_statement
- method! visit_ety _ ty = ety_substitute tsubst cgsubst ty
- method! visit_ConstGenericVar _ id = cgsubst id
+ method! visit_'r _ r = subst.r_subst r
+ method! visit_TypeVar _ id = subst.ty_subst id
+
+ method! visit_type_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_TypeVar] *)
+ raise (Failure "Unexpected")
+
+ method! visit_ConstGenericVar _ id = subst.cg_subst id
method! visit_const_generic_var_id _ _ =
(* We should never get here because we reimplemented [visit_Var] *)
raise (Failure "Unexpected")
+
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
end
(** Apply a type substitution to a place *)
-let place_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (p : E.place) :
- E.place =
+let place_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (p : E.place) : E.place =
(* There is in fact nothing to do *)
- (statement_substitute_visitor tsubst cgsubst)#visit_place () p
+ (statement_substitute_visitor subst)#visit_place () p
(** Apply a type substitution to an operand *)
-let operand_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (op : E.operand) :
- E.operand =
- (statement_substitute_visitor tsubst cgsubst)#visit_operand () op
+let operand_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (op : E.operand) : E.operand =
+ (statement_substitute_visitor subst)#visit_operand () op
(** Apply a type substitution to an rvalue *)
-let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (rv : E.rvalue) :
- E.rvalue =
- (statement_substitute_visitor tsubst cgsubst)#visit_rvalue () rv
+let rvalue_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (rv : E.rvalue) : E.rvalue =
+ (statement_substitute_visitor subst)#visit_rvalue () rv
(** Apply a type substitution to an assertion *)
-let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (a : A.assertion) :
- A.assertion =
- (statement_substitute_visitor tsubst cgsubst)#visit_assertion () a
+let assertion_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (a : A.assertion) : A.assertion =
+ (statement_substitute_visitor subst)#visit_assertion () a
(** Apply a type substitution to a call *)
-let call_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (call : A.call) :
- A.call =
- (statement_substitute_visitor tsubst cgsubst)#visit_call () call
+let call_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (call : A.call) : A.call =
+ (statement_substitute_visitor subst)#visit_call () call
(** Apply a type substitution to a statement *)
-let statement_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (st : A.statement) :
- A.statement =
- (statement_substitute_visitor tsubst cgsubst)#visit_statement () st
+let statement_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (st : A.statement) : A.statement =
+ (statement_substitute_visitor subst)#visit_statement () st
(** Apply a type substitution to a function body. Return the local variables
and the body. *)
-let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (body : A.fun_body) :
+let fun_body_substitute_in_body
+ (subst : (T.erased_region, T.erased_region) subst) (body : A.fun_body) :
A.var list * A.statement =
- let rsubst r = r in
let locals =
List.map
- (fun (v : A.var) ->
- { v with A.var_ty = ty_substitute rsubst tsubst cgsubst v.A.var_ty })
+ (fun (v : A.var) -> { v with A.var_ty = ty_substitute subst v.A.var_ty })
body.A.locals
in
- let body = statement_substitute tsubst cgsubst body.body in
+ let body = statement_substitute subst body.body in
(locals, body)
-(** Substitute a function signature *)
+(** Substitute a function signature.
+
+ **IMPORTANT:** this function doesn't normalize the types.
+ *)
let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id)
- (rsubst : T.RegionVarId.id -> T.RegionId.id)
- (tsubst : T.TypeVarId.id -> T.rty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (sg : A.fun_sig) :
- A.inst_fun_sig =
- let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
+ (r_subst : T.RegionVarId.id -> T.RegionId.id)
+ (ty_subst : T.TypeVarId.id -> T.rty)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ let r_subst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region =
+ match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid)
in
- let inputs = List.map (ty_substitute rsubst' tsubst cgsubst) sg.A.inputs in
- let output = ty_substitute rsubst' tsubst cgsubst sg.A.output in
+ let subst = { r_subst = r_subst'; ty_subst; cg_subst; tr_subst; tr_self } in
+ let inputs = List.map (ty_substitute subst) sg.A.inputs in
+ let output = ty_substitute subst sg.A.output in
let subst_region_group (rg : T.region_var_group) : A.abs_region_group =
let id = asubst rg.id in
- let regions = List.map rsubst rg.regions in
+ let regions = List.map r_subst rg.regions in
let parents = List.map asubst rg.parents in
{ id; regions; parents }
in
let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in
{ A.regions_hierarchy; inputs; output }
-(** Substitute type variable identifiers in a type *)
-let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty)
+(** Substitute variable identifiers in a type *)
+let ty_substitute_ids (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty)
: 'r T.ty =
let open T in
let visitor =
object
inherit [_] map_ty
method visit_'r _ r = r
- method! visit_type_var_id _ id = tsubst id
- method! visit_const_generic_var_id _ id = cgsubst id
+ method! visit_type_var_id _ id = ty_subst id
+ method! visit_const_generic_var_id _ id = cg_subst id
end
in
@@ -371,10 +448,10 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
[visit_'r] if we define a class which visits objects of types [ety] and [rty]
while inheriting a class which visit [ty]...
*)
-let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
+let subst_ids_visitor (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) =
@@ -383,10 +460,10 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
inherit [_] T.map_ty
method visit_'r _ r =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
+ match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid)
- method! visit_type_var_id _ id = tsubst id
- method! visit_const_generic_var_id _ id = cgsubst id
+ method! visit_type_var_id _ id = ty_subst id
+ method! visit_const_generic_var_id _ id = cg_subst id
end
in
@@ -395,7 +472,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
inherit [_] C.map_env
method! visit_borrow_id _ bid = bsubst bid
method! visit_loan_id _ bid = bsubst bid
- method! visit_ety _ ty = ty_substitute_ids tsubst cgsubst ty
+ method! visit_ety _ ty = ty_substitute_ids ty_subst cg_subst ty
method! visit_rty env ty = subst_rty#visit_ty env ty
method! visit_symbolic_value_id _ id = ssubst id
@@ -405,7 +482,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
(** We *do* visit meta-values *)
method! visit_mvalue env v = self#visit_typed_value env v
- method! visit_region_id _ id = rsubst id
+ method! visit_region_id _ id = r_subst id
method! visit_region_var_id _ id = rvsubst id
method! visit_abstraction_id _ id = asubst id
end
@@ -425,20 +502,20 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
method visit_env (env : C.env) : C.env = visitor#visit_env () env
end
-let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_value_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) :
V.typed_value =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_typed_value v
-let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_value_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id)
(v : V.typed_value) : V.typed_value =
- typed_value_subst_ids rsubst
+ typed_value_subst_ids r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
@@ -446,41 +523,41 @@ let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
(fun x -> x)
v
-let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_avalue_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) :
V.typed_avalue =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_typed_avalue v
-let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let abs_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs =
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_abs x
-let env_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let env_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : C.env) : C.env =
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_env x
-let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_avalue_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id)
(x : V.typed_avalue) : V.typed_avalue =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst
+ (subst_ids_visitor r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
@@ -490,9 +567,9 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
#visit_typed_avalue
x
-let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env
- =
- (subst_ids_visitor rsubst
+let env_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (x : C.env) :
+ C.env =
+ (subst_ids_visitor r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 17cdcabc..0f107897 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -43,10 +43,7 @@ type call = {
borrows (we need to perform lookups).
*)
abstractions : V.AbstractionId.id list;
- (* TODO: rename to "...args" *)
- type_params : T.ety list;
- (* TODO: rename to "...args" *)
- const_generic_params : T.const_generic list;
+ generics : T.egeneric_args;
args : V.typed_value list;
args_places : mplace option list; (** Meta information *)
dest : V.symbolic_value;
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 7dda1f22..6c2c049b 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -52,6 +52,9 @@ type fun_context = {
type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
[@@deriving show]
+type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
+type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
+
(** Whenever we translate a function call or an ended abstraction, we
store the related information (this is useful when translating ended
children abstractions).
@@ -120,6 +123,8 @@ type bs_ctx = {
type_context : type_context;
fun_context : fun_context;
global_context : global_context;
+ trait_decls_ctx : trait_decls_context;
+ trait_impls_ctx : trait_impls_context;
fun_decl : A.fun_decl;
bid : T.RegionGroupId.id option; (** TODO: rename *)
sg : fun_sig;
@@ -205,7 +210,7 @@ type bs_ctx = {
let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.Ast.ast_formatter =
Print.Ast.decls_and_fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls
ctx.fun_context.llbc_fun_decls ctx.global_context.llbc_global_decls
- ctx.fun_decl
+ ctx.trait_decls_ctx ctx.trait_impls_ctx ctx.fun_decl
let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
let rvar_to_string = Print.Types.region_var_id_to_string in
@@ -223,16 +228,19 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
adt_variant_to_string = ast_fmt.adt_variant_to_string;
var_id_to_string;
adt_field_names = ast_fmt.adt_field_names;
+ trait_decl_id_to_string = ast_fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = ast_fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = ast_fmt.trait_clause_id_to_string;
}
let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter =
- let type_params = ctx.fun_decl.signature.type_params in
- let cg_params = ctx.fun_decl.signature.const_generic_params in
+ let generics = ctx.fun_decl.signature.generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx generics.types
+ generics.const_generics
let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string =
let fmt = bs_ctx_to_ctx_formatter ctx in
@@ -254,12 +262,11 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string =
Print.PT.rty_to_string fmt ty
let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string =
- let type_params = def.type_params in
- let cg_params = def.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ PrintPure.mk_type_formatter type_decls global_decls ctx.trait_decls_ctx
+ ctx.trait_impls_ctx def.generics.types def.generics.const_generics
in
PrintPure.type_decl_to_string fmt def
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index 857fea97..cac56487 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -64,7 +64,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
assert (otherwise_see = None);
(* Return *)
ExpandInt (int_ty, branches, otherwise)
- | T.Adt (_, _, _, _) ->
+ | T.Adt (_, _) ->
(* Branching: it is necessarily an enumeration expansion *)
let get_variant (see : V.symbolic_expansion option) :
T.VariantId.id option * V.symbolic_value list =
@@ -85,7 +85,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
match ls with
| [ (Some see, exp) ] -> ExpandNoBranch (see, exp)
| _ -> raise (Failure "Ill-formed borrow expansion"))
- | T.TypeVar _ | T.Literal Char | Never ->
+ | T.TypeVar _ | T.Literal Char | Never | T.TraitType _ ->
raise (Failure "Ill-formed symbolic expansion")
in
Some (Expansion (place, sv, expansion))
@@ -97,10 +97,10 @@ let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value)
synthesize_symbolic_expansion sv place [ Some see ] el
let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
- (abstractions : V.AbstractionId.id list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) (args : V.typed_value list)
- (args_places : mplace option list) (dest : V.symbolic_value)
- (dest_place : mplace option) (e : expression option) : expression option =
+ (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args)
+ (args : V.typed_value list) (args_places : mplace option list)
+ (dest : V.symbolic_value) (dest_place : mplace option)
+ (e : expression option) : expression option =
Option.map
(fun e ->
let call =
@@ -108,8 +108,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
call_id;
ctx;
abstractions;
- type_params;
- const_generic_params;
+ generics;
args;
dest;
args_places;
@@ -125,26 +124,27 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value)
let synthesize_regular_function_call (fun_id : A.fun_id)
(call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx)
- (abstractions : V.AbstractionId.id list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) (args : V.typed_value list)
- (args_places : mplace option list) (dest : V.symbolic_value)
- (dest_place : mplace option) (e : expression option) : expression option =
+ (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args)
+ (args : V.typed_value list) (args_places : mplace option list)
+ (dest : V.symbolic_value) (dest_place : mplace option)
+ (e : expression option) : expression option =
synthesize_function_call
(Fun (fun_id, call_id))
- ctx abstractions type_params const_generic_params args args_places dest
- dest_place e
+ ctx abstractions generics args args_places dest dest_place e
let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : E.unop)
(arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
- synthesize_function_call (Unop unop) ctx [] [] [] [ arg ] [ arg_place ] dest
- dest_place e
+ let generics = TypesUtils.mk_empty_generic_args in
+ synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ]
+ dest dest_place e
let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : E.binop)
(arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value)
(arg1_place : mplace option) (dest : V.symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
- synthesize_function_call (Binop binop) ctx [] [] [] [ arg0; arg1 ]
+ let generics = TypesUtils.mk_empty_generic_args in
+ synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ]
[ arg0_place; arg1_place ] dest dest_place e
let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : V.abs)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 70ef5e3d..ca661108 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -280,9 +280,7 @@ let translate_crate_to_pure (crate : A.crate) :
log#ldebug (lazy "translate_crate_to_pure");
(* Compute the type and function contexts *)
- let type_context, fun_context, global_context =
- compute_type_fun_global_contexts crate
- in
+ let type_context, fun_context, global_context = compute_contexts crate in
let fun_infos =
FA.analyze_module crate fun_context.C.fun_decls
global_context.C.global_decls !Config.use_state
diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml
index 925f6d39..95c7206a 100644
--- a/compiler/TypesAnalysis.ml
+++ b/compiler/TypesAnalysis.ml
@@ -14,11 +14,10 @@ type expl_info = subtype_info [@@deriving show]
type type_borrows_info = {
contains_static : bool;
- (** Does the type (transitively) contains a static borrow? *)
- contains_borrow : bool;
- (** Does the type (transitively) contains a borrow? *)
+ (** Does the type (transitively) contain a static borrow? *)
+ contains_borrow : bool; (** Does the type (transitively) contain a borrow? *)
contains_nested_borrows : bool;
- (** Does the type (transitively) contains nested borrows? *)
+ (** Does the type (transitively) contain nested borrows? *)
contains_borrow_under_mut : bool;
}
[@@deriving show]
@@ -61,7 +60,7 @@ let initialize_g_type_info (param_infos : 'p) : 'p g_type_info =
let initialize_type_decl_info (def : type_decl) : type_decl_info =
let param_info = { under_borrow = false; under_mut_borrow = false } in
- let param_infos = List.map (fun _ -> param_info) def.type_params in
+ let param_infos = List.map (fun _ -> param_info) def.generics.types in
initialize_g_type_info param_infos
let type_decl_info_to_partial_type_info (info : type_decl_info) :
@@ -122,7 +121,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
let rec analyze (expl_info : expl_info) (ty_info : partial_type_info)
(ty : 'r ty) : partial_type_info =
match ty with
- | Literal _ | Never -> ty_info
+ | Literal _ | Never | TraitType _ -> ty_info
| TypeVar var_id -> (
(* Update the information for the proper parameter, if necessary *)
match ty_info.param_infos with
@@ -171,20 +170,18 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
analyze expl_info ty_info rty
| Adt
( (Tuple | Assumed (Box | Vec | Option | Slice | Array | Str | Range)),
- _,
- tys,
- _ ) ->
+ generics ) ->
(* Nothing to update: just explore the type parameters *)
List.fold_left
(fun ty_info ty -> analyze expl_info ty_info ty)
- ty_info tys
- | Adt (AdtId adt_id, regions, tys, _cgs) ->
+ ty_info generics.types
+ | Adt (AdtId adt_id, generics) ->
(* Lookup the information for this type definition *)
let adt_info = TypeDeclId.Map.find adt_id infos in
(* Update the type info with the information from the adt *)
let ty_info = update_ty_info ty_info adt_info.borrows_info in
(* Check if 'static appears in the region parameters *)
- let found_static = List.exists r_is_static regions in
+ let found_static = List.exists r_is_static generics.regions in
let borrows_info = ty_info.borrows_info in
let borrows_info =
{
@@ -196,7 +193,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
let ty_info = { ty_info with borrows_info } in
(* For every instantiated type parameter: update the exploration info
* then explore the type *)
- let params_tys = List.combine adt_info.param_infos tys in
+ let params_tys = List.combine adt_info.param_infos generics.types in
let ty_info =
List.fold_left
(fun ty_info (param_info, ty) ->
diff --git a/compiler/dune b/compiler/dune
index 6785cad4..db099c3c 100644
--- a/compiler/dune
+++ b/compiler/dune
@@ -12,6 +12,7 @@
(pps ppx_deriving.show ppx_deriving.ord visitors.ppx))
(libraries charon core_unix unionFind ocamlgraph)
(modules
+ AssociatedTypes
Assumed
Collections
Config