summaryrefslogtreecommitdiff
path: root/compiler
diff options
context:
space:
mode:
Diffstat (limited to 'compiler')
-rw-r--r--compiler/AssociatedTypes.ml577
-rw-r--r--compiler/Assumed.ml214
-rw-r--r--compiler/Config.ml20
-rw-r--r--compiler/Contexts.ml127
-rw-r--r--compiler/Driver.ml22
-rw-r--r--compiler/Extract.ml1386
-rw-r--r--compiler/ExtractAssumed.ml61
-rw-r--r--compiler/ExtractBase.ml521
-rw-r--r--compiler/FunsAnalysis.ml15
-rw-r--r--compiler/Interpreter.ml259
-rw-r--r--compiler/InterpreterBorrows.ml3
-rw-r--r--compiler/InterpreterBorrowsCore.ml16
-rw-r--r--compiler/InterpreterExpansion.ml45
-rw-r--r--compiler/InterpreterExpressions.ml117
-rw-r--r--compiler/InterpreterLoopsJoinCtxs.ml18
-rw-r--r--compiler/InterpreterLoopsMatchCtxs.ml19
-rw-r--r--compiler/InterpreterPaths.ml55
-rw-r--r--compiler/InterpreterPaths.mli7
-rw-r--r--compiler/InterpreterProjectors.ml15
-rw-r--r--compiler/InterpreterStatements.ml713
-rw-r--r--compiler/InterpreterStatements.mli9
-rw-r--r--compiler/InterpreterUtils.ml119
-rw-r--r--compiler/Invariants.ml67
-rw-r--r--compiler/LlbcAst.ml1
-rw-r--r--compiler/LlbcAstUtils.ml29
-rw-r--r--compiler/Logging.ml6
-rw-r--r--compiler/PrePasses.ml4
-rw-r--r--compiler/Print.ml132
-rw-r--r--compiler/PrintPure.ml205
-rw-r--r--compiler/Pure.ml177
-rw-r--r--compiler/PureMicroPasses.ml165
-rw-r--r--compiler/PureTypeCheck.ml53
-rw-r--r--compiler/PureUtils.ml155
-rw-r--r--compiler/ReorderDecls.ml8
-rw-r--r--compiler/Substitute.ml492
-rw-r--r--compiler/SymbolicAst.ml33
-rw-r--r--compiler/SymbolicToPure.ml744
-rw-r--r--compiler/SynthesizeSymbolic.ml36
-rw-r--r--compiler/Translate.ml385
-rw-r--r--compiler/TranslateCore.ml79
-rw-r--r--compiler/TypesAnalysis.ml23
-rw-r--r--compiler/Values.ml4
-rw-r--r--compiler/dune2
43 files changed, 5221 insertions, 1917 deletions
diff --git a/compiler/AssociatedTypes.ml b/compiler/AssociatedTypes.ml
new file mode 100644
index 00000000..992dade9
--- /dev/null
+++ b/compiler/AssociatedTypes.ml
@@ -0,0 +1,577 @@
+(** This file implements utilities to handle trait associated types, in
+ particular with normalization helpers.
+
+ When normalizing a type, we simplify the references to the trait associated
+ types, and choose a representative when there are equalities between types
+ enforced by local clauses (i.e., clauses of the shape [where Trait1::T = Trait2::U]).
+ *)
+
+module T = Types
+module TU = TypesUtils
+module V = Values
+module E = Expressions
+module A = LlbcAst
+module C = Contexts
+module Subst = Substitute
+module L = Logging
+module UF = UnionFind
+module PA = Print.EvalCtxLlbcAst
+
+(** The local logger *)
+let log = L.associated_types_log
+
+let trait_type_ref_substitute (subst : ('r, 'r1) Subst.subst)
+ (r : 'r C.trait_type_ref) : 'r1 C.trait_type_ref =
+ let { C.trait_ref; type_name } = r in
+ let trait_ref = Subst.trait_ref_substitute subst trait_ref in
+ { C.trait_ref; type_name }
+
+(* TODO: how not to duplicate below? *)
+module RTyOrd = struct
+ type t = T.rty
+
+ let compare = T.compare_rty
+ let to_string = T.show_rty
+ let pp_t = T.pp_rty
+ let show_t = T.show_rty
+end
+
+module STyOrd = struct
+ type t = T.sty
+
+ let compare = T.compare_sty
+ let to_string = T.show_sty
+ let pp_t = T.pp_sty
+ let show_t = T.show_sty
+end
+
+module RTyMap = Collections.MakeMap (RTyOrd)
+module STyMap = Collections.MakeMap (STyOrd)
+
+(* TODO: is it possible not to have this? *)
+module type TypeWrapper = sig
+ type t
+end
+
+(* TODO: don't manage to get the syntax right so using a functor *)
+module MakeNormalizer
+ (R : TypeWrapper)
+ (RTyMap : Collections.Map with type key = R.t T.region T.ty)
+ (M : Collections.Map with type key = R.t T.region C.trait_type_ref) =
+struct
+ let compute_norm_trait_types_from_preds
+ (trait_type_constraints : R.t T.region T.trait_type_constraint list) :
+ R.t T.region T.ty M.t =
+ (* Compute a union-find structure by recursively exploring the predicates and clauses *)
+ let norm : R.t T.region T.ty UF.elem RTyMap.t ref = ref RTyMap.empty in
+ let get_ref (ty : R.t T.region T.ty) : R.t T.region T.ty UF.elem =
+ match RTyMap.find_opt ty !norm with
+ | Some r -> r
+ | None ->
+ let r = UF.make ty in
+ norm := RTyMap.add ty r !norm;
+ r
+ in
+ let add_trait_type_constraint (c : R.t T.region T.trait_type_constraint) =
+ let trait_ty = T.TraitType (c.trait_ref, c.generics, c.type_name) in
+ let trait_ty_ref = get_ref trait_ty in
+ let ty_ref = get_ref c.ty in
+ let new_repr = UF.get ty_ref in
+ let merged = UF.union trait_ty_ref ty_ref in
+ (* Not sure the set operation is necessary, but I want to control which
+ representative is chosen *)
+ UF.set merged new_repr
+ in
+ (* Explore the local predicates *)
+ List.iter add_trait_type_constraint trait_type_constraints;
+ (* TODO: explore the local clauses *)
+ (* Compute the norm maps *)
+ let rbindings =
+ List.map (fun (k, v) -> (k, UF.get v)) (RTyMap.bindings !norm)
+ in
+ (* Filter the keys to keep only the trait type aliases *)
+ let rbindings =
+ List.filter_map
+ (fun (k, v) ->
+ match k with
+ | T.TraitType (trait_ref, generics, type_name) ->
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ Some ({ C.trait_ref; type_name }, v)
+ | _ -> None)
+ rbindings
+ in
+ M.of_list rbindings
+end
+
+(** Compute the representative classes of trait associated types, for normalization *)
+let compute_norm_trait_stypes_from_preds
+ (trait_type_constraints : T.strait_type_constraint list) :
+ T.sty C.STraitTypeRefMap.t =
+ (* Compute the normalization map for the types with regions *)
+ let module R = struct
+ type t = T.region_var_id
+ end in
+ let module M = C.STraitTypeRefMap in
+ let module Norm = MakeNormalizer (R) (STyMap) (M) in
+ Norm.compute_norm_trait_types_from_preds trait_type_constraints
+
+(** Compute the representative classes of trait associated types, for normalization *)
+let compute_norm_trait_types_from_preds
+ (trait_type_constraints : T.rtrait_type_constraint list) :
+ T.ety C.ETraitTypeRefMap.t * T.rty C.RTraitTypeRefMap.t =
+ (* Compute the normalization map for the types with regions *)
+ let module R = struct
+ type t = T.region_id
+ end in
+ let module M = C.RTraitTypeRefMap in
+ let module Norm = MakeNormalizer (R) (RTyMap) (M) in
+ let rbindings =
+ Norm.compute_norm_trait_types_from_preds trait_type_constraints
+ in
+ (* Compute the normalization map for the types with erased regions *)
+ let ebindings =
+ List.map
+ (fun (k, v) ->
+ ( trait_type_ref_substitute Subst.erase_regions_subst k,
+ Subst.erase_regions v ))
+ (M.bindings rbindings)
+ in
+ (C.ETraitTypeRefMap.of_list ebindings, rbindings)
+
+let ctx_add_norm_trait_stypes_from_preds (ctx : C.eval_ctx)
+ (trait_type_constraints : T.strait_type_constraint list) : C.eval_ctx =
+ let norm_trait_stypes =
+ compute_norm_trait_stypes_from_preds trait_type_constraints
+ in
+ { ctx with C.norm_trait_stypes }
+
+let ctx_add_norm_trait_types_from_preds (ctx : C.eval_ctx)
+ (trait_type_constraints : T.rtrait_type_constraint list) : C.eval_ctx =
+ let norm_trait_etypes, norm_trait_rtypes =
+ compute_norm_trait_types_from_preds trait_type_constraints
+ in
+ { ctx with C.norm_trait_etypes; norm_trait_rtypes }
+
+(** A trait instance id refers to a local clause if it only uses the variants:
+ [Self], [Clause], [ParentClause], [ItemClause] *)
+let rec trait_instance_id_is_local_clause (id : 'r T.trait_instance_id) : bool =
+ match id with
+ | T.Self | Clause _ -> true
+ | TraitImpl _ | BuiltinOrAuto _ | TraitRef _ | UnknownTrait _ -> false
+ | ParentClause (id, _, _) | ItemClause (id, _, _, _) ->
+ trait_instance_id_is_local_clause id
+
+(** About the conversion functions: for now we need them (TODO: merge ety, rty, etc.),
+ but they should be applied to types without regions.
+ *)
+type 'r norm_ctx = {
+ ctx : C.eval_ctx;
+ get_ty_repr : 'r C.trait_type_ref -> 'r T.ty option;
+ convert_ety : T.ety -> 'r T.ty;
+ convert_etrait_ref : T.etrait_ref -> 'r T.trait_ref;
+ ty_to_string : 'r T.ty -> string;
+ trait_ref_to_string : 'r T.trait_ref -> string;
+ trait_instance_id_to_string : 'r T.trait_instance_id -> string;
+ pp_r : Format.formatter -> 'r -> unit;
+}
+
+(** Normalize a type by simplyfying the references to trait associated types
+ and choosing a representative when there are equalities between types
+ enforced by local clauses (i.e., `where Trait1::T = Trait2::U`. *)
+let rec ctx_normalize_ty : 'r. 'r norm_ctx -> 'r T.ty -> 'r T.ty =
+ fun ctx ty ->
+ log#ldebug (lazy ("ctx_normalize_ty: " ^ ctx.ty_to_string ty));
+ match ty with
+ | T.Adt (id, generics) -> Adt (id, ctx_normalize_generic_args ctx generics)
+ | TypeVar _ | Literal _ | Never -> ty
+ | Ref (r, ty, rkind) ->
+ let ty = ctx_normalize_ty ctx ty in
+ T.Ref (r, ty, rkind)
+ | TraitType (trait_ref, generics, type_name) -> (
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: " ^ ctx.ty_to_string ty
+ ^ "\n- trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
+ (* Normalize and attempt to project the type from the trait ref *)
+ let trait_ref = ctx_normalize_trait_ref ctx trait_ref in
+ let generics = ctx_normalize_generic_args ctx generics in
+ let ty : 'r T.ty =
+ match trait_ref.trait_id with
+ | T.TraitRef
+ { T.trait_id = T.TraitImpl impl_id; generics = ref_generics; _ } ->
+ assert (ref_generics = TypesUtils.mk_empty_generic_args);
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: trait ref: "
+ ^ ctx.ty_to_string ty));
+ (* Lookup the implementation *)
+ let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
+ (* Lookup the type *)
+ let ty = snd (List.assoc type_name trait_impl.types) in
+ (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *)
+ let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in
+ (* Substitute *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_subst_from_generics_no_regions trait_impl.generics
+ generics tr_self
+ in
+ let ty = Subst.ty_substitute subst ty in
+ (* Reconvert *)
+ let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in
+ (* Normalize *)
+ ctx_normalize_ty ctx ty
+ | T.TraitImpl impl_id ->
+ (* This happens. This doesn't come from the substituations
+ performed by Aeneas (the [TraitImpl] would be wrapped in a
+ [TraitRef] but from non-normalized traits translated from
+ the Rustc AST.
+ TODO: factor out with the branch above.
+ *)
+ (* Lookup the implementation *)
+ let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
+ (* Lookup the type *)
+ let ty = snd (List.assoc type_name trait_impl.types) in
+ (* Annoying: convert etype to an stype - TODO: hwo to avoid that? *)
+ let ty : T.sty = TypesUtils.ety_no_regions_to_gr_ty ty in
+ (* Substitute *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_subst_from_generics_no_regions trait_impl.generics
+ generics tr_self
+ in
+ let ty = Subst.ty_substitute subst ty in
+ (* Reconvert *)
+ let ty : 'r T.ty = ctx.convert_ety (Subst.erase_regions ty) in
+ (* Normalize *)
+ ctx_normalize_ty ctx ty
+ | _ ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_ty: trait type: not a trait ref: "
+ ^ ctx.ty_to_string ty ^ "\n- trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
+ (* We can't project *)
+ assert (trait_instance_id_is_local_clause trait_ref.trait_id);
+ T.TraitType (trait_ref, generics, type_name)
+ in
+ let tr : 'r C.trait_type_ref = { C.trait_ref; type_name } in
+ (* Lookup the representative, if there is *)
+ match ctx.get_ty_repr tr with None -> ty | Some ty -> ty)
+
+(** This returns the normalized trait instance id together with an optional
+ reference to a trait **implementation**.
+
+ We need this in particular to simplify the trait instance ids after we
+ performed a substitution.
+
+ Example:
+ ========
+ {[
+ trait Trait {
+ type S
+ }
+
+ impl TraitImpl for Foo {
+ type S = usize
+ }
+
+ fn f<T : Trait>(...) -> T::S;
+
+ ...
+ let x = f<Foo>[TraitImpl](...); // T::S ~~> TraitImpl::S ~~> usize
+ ]}
+
+ Several remarks:
+ - as we do not allow higher-order types (yet) then local clauses (and
+ sub-clauses) can't have generic arguments
+ - the [TraitRef] case only happens because of substitution, the role of
+ the normalization is in particular to eliminate it. Inside a [TraitRef]
+ there is necessarily:
+ - an id referencing a local (sub-)clause, that is an id using the variants
+ [Self], [Clause], [ItemClause] and [ParentClause] exclusively. We can't
+ simplify those cases: all we can do is remove the [TraitRef] wrapper
+ by leveraging the fact that the generic arguments must be empty.
+ - a [TraitImpl]. Note that the [TraitImpl] is necessarily just a [TraitImpl],
+ it can't be for instance a [ParentClause(TraitImpl ...)] because the
+ trait resolution would then directly reference the implementation
+ designated by [ParentClause(TraitImpl ...)] (and same for the other cases).
+ In this case we can lookup the trait implementation and recursively project
+ over it.
+ *)
+and ctx_normalize_trait_instance_id :
+ 'r.
+ 'r norm_ctx ->
+ 'r T.trait_instance_id ->
+ 'r T.trait_instance_id * 'r T.trait_ref option =
+ fun ctx id ->
+ match id with
+ | Self -> (id, None)
+ | TraitImpl _ ->
+ (* The [TraitImpl] shouldn't be inside any projection - we check this
+ elsewhere by asserting that whenever we return [None] for the impl
+ trait ref, then the id actually refers to a local clause. *)
+ (id, None)
+ | Clause _ -> (id, None)
+ | BuiltinOrAuto _ -> (id, None)
+ | ParentClause (inst_id, decl_id, clause_id) -> (
+ let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
+ (* Check if the inst_id refers to a specific implementation, if yes project *)
+ match impl with
+ | None ->
+ (* This is actually a local clause *)
+ assert (trait_instance_id_is_local_clause inst_id);
+ (ParentClause (inst_id, decl_id, clause_id), None)
+ | Some impl ->
+ (* We figure out the parent clause by doing the following:
+ {[
+ // The implementation we are looking at
+ impl Impl1 : Trait1 { ... }
+
+ // Check the trait it implements
+ trait Trait1 : ParentTrait1 + ParentTrait2 { ... }
+ ^^^^^^^^^^^^^^^^^^^^^^^^^^^
+ those are the parent clauses
+ ]}
+
+ We can find the parent clauses in the [trait_decl_ref] field, which
+ tells us which specific instantiation of [Trait1] is implemented
+ by [Impl1].
+ *)
+ let clause =
+ T.TraitClauseId.nth impl.trait_decl_ref.decl_generics.trait_refs
+ clause_id
+ in
+ (* Sanity check: the clause necessarily refers to an impl *)
+ let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in
+ (TraitRef clause, Some clause))
+ | ItemClause (inst_id, decl_id, item_name, clause_id) -> (
+ let inst_id, impl = ctx_normalize_trait_instance_id ctx inst_id in
+ (* Check if the inst_id refers to a specific implementation, if yes project *)
+ match impl with
+ | None ->
+ (* This is actually a local clause *)
+ assert (trait_instance_id_is_local_clause inst_id);
+ (ParentClause (inst_id, decl_id, clause_id), None)
+ | Some impl ->
+ (* We figure out the item clause by doing the following:
+ {[
+ // The implementation we are looking at
+ impl Impl1 : Trait1<R> {
+ type S = ...
+ with Impl2 : Trait2 ... // Instances satisfying the declared bounds
+ ^^^^^^^^^^^^^^^^^^
+ Lookup the clause from here
+ }
+ ]}
+ *)
+ (* The referenced instance should be an impl *)
+ let impl_id =
+ TypesUtils.trait_instance_id_as_trait_impl impl.trait_id
+ in
+ let trait_impl = C.ctx_lookup_trait_impl ctx.ctx impl_id in
+ (* Lookup the clause *)
+ let item = List.assoc item_name trait_impl.types in
+ let clause = T.TraitClauseId.nth (fst item) clause_id in
+ (* Sanity check: the clause necessarily refers to an impl *)
+ let _ = TypesUtils.trait_instance_id_as_trait_impl clause.trait_id in
+ (* We need to convert the clause type -
+ TODO: we have too many problems with those conversions, we should
+ merge the types. *)
+ let clause = ctx.convert_etrait_ref clause in
+ (TraitRef clause, Some clause))
+ | TraitRef { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref } ->
+ (* We can't simplify the id *yet* : we will simplify it when projecting.
+ However, we have an implementation to return *)
+ (* Normalize the generics *)
+ let generics = ctx_normalize_generic_args ctx generics in
+ let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
+ let trait_ref : 'r T.trait_ref =
+ { T.trait_id = T.TraitImpl trait_id; generics; trait_decl_ref }
+ in
+ (TraitRef trait_ref, Some trait_ref)
+ | TraitRef trait_ref ->
+ (* The trait instance id necessarily refers to a local sub-clause. We
+ can't project over it and can only peel off the [TraitRef] wrapper *)
+ assert (trait_instance_id_is_local_clause trait_ref.trait_id);
+ assert (trait_ref.generics = TypesUtils.mk_empty_generic_args);
+ (trait_ref.trait_id, None)
+ | UnknownTrait _ ->
+ (* This is actually an error case *)
+ (id, None)
+
+and ctx_normalize_generic_args (ctx : 'r norm_ctx)
+ (generics : 'r T.generic_args) : 'r T.generic_args =
+ let { T.regions; types; const_generics; trait_refs } = generics in
+ let types = List.map (ctx_normalize_ty ctx) types in
+ let trait_refs = List.map (ctx_normalize_trait_ref ctx) trait_refs in
+ { T.regions; types; const_generics; trait_refs }
+
+and ctx_normalize_trait_ref (ctx : 'r norm_ctx) (trait_ref : 'r T.trait_ref) :
+ 'r T.trait_ref =
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: "
+ ^ ctx.trait_ref_to_string trait_ref
+ ^ "\n- raw trait ref:\n"
+ ^ T.show_trait_ref ctx.pp_r trait_ref));
+ let { T.trait_id; generics; trait_decl_ref } = trait_ref in
+ (* Check if the id is an impl, otherwise normalize it *)
+ let trait_id, norm_trait_ref = ctx_normalize_trait_instance_id ctx trait_id in
+ match norm_trait_ref with
+ | None ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: no norm: "
+ ^ ctx.trait_instance_id_to_string trait_id));
+ let generics = ctx_normalize_generic_args ctx generics in
+ let trait_decl_ref = ctx_normalize_trait_decl_ref ctx trait_decl_ref in
+ { T.trait_id; generics; trait_decl_ref }
+ | Some trait_ref ->
+ log#ldebug
+ (lazy
+ ("ctx_normalize_trait_ref: normalized to: "
+ ^ ctx.trait_ref_to_string trait_ref));
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ trait_ref
+
+(* Not sure this one is really necessary *)
+and ctx_normalize_trait_decl_ref (ctx : 'r norm_ctx)
+ (trait_decl_ref : 'r T.trait_decl_ref) : 'r T.trait_decl_ref =
+ let { T.trait_decl_id; decl_generics } = trait_decl_ref in
+ let decl_generics = ctx_normalize_generic_args ctx decl_generics in
+ { T.trait_decl_id; decl_generics }
+
+let ctx_normalize_trait_type_constraint (ctx : 'r norm_ctx)
+ (ttc : 'r T.trait_type_constraint) : 'r T.trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let trait_ref = ctx_normalize_trait_ref ctx trait_ref in
+ let generics = ctx_normalize_generic_args ctx generics in
+ let ty = ctx_normalize_ty ctx ty in
+ { T.trait_ref; generics; type_name; ty }
+
+let mk_snorm_ctx (ctx : C.eval_ctx) : T.RegionVarId.id T.region norm_ctx =
+ let get_ty_repr x = C.STraitTypeRefMap.find_opt x ctx.norm_trait_stypes in
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = TypesUtils.ety_no_regions_to_sty;
+ convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
+ ty_to_string = PA.sty_to_string ctx;
+ trait_ref_to_string = PA.strait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.strait_instance_id_to_string ctx;
+ pp_r = T.pp_region T.pp_region_var_id;
+ }
+
+let mk_rnorm_ctx (ctx : C.eval_ctx) : T.RegionId.id T.region norm_ctx =
+ let get_ty_repr x = C.RTraitTypeRefMap.find_opt x ctx.norm_trait_rtypes in
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = TypesUtils.ety_no_regions_to_rty;
+ convert_etrait_ref = TypesUtils.etrait_ref_no_regions_to_gr_trait_ref;
+ ty_to_string = PA.rty_to_string ctx;
+ trait_ref_to_string = PA.rtrait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.rtrait_instance_id_to_string ctx;
+ pp_r = T.pp_region T.pp_region_id;
+ }
+
+let mk_enorm_ctx (ctx : C.eval_ctx) : T.erased_region norm_ctx =
+ let get_ty_repr x = C.ETraitTypeRefMap.find_opt x ctx.norm_trait_etypes in
+ {
+ ctx;
+ get_ty_repr;
+ convert_ety = (fun x -> x);
+ convert_etrait_ref = (fun x -> x);
+ ty_to_string = PA.ety_to_string ctx;
+ trait_ref_to_string = PA.etrait_ref_to_string ctx;
+ trait_instance_id_to_string = PA.etrait_instance_id_to_string ctx;
+ pp_r = T.pp_erased_region;
+ }
+
+let ctx_normalize_sty (ctx : C.eval_ctx) (ty : T.sty) : T.sty =
+ ctx_normalize_ty (mk_snorm_ctx ctx) ty
+
+let ctx_normalize_rty (ctx : C.eval_ctx) (ty : T.rty) : T.rty =
+ ctx_normalize_ty (mk_rnorm_ctx ctx) ty
+
+let ctx_normalize_ety (ctx : C.eval_ctx) (ty : T.ety) : T.ety =
+ ctx_normalize_ty (mk_enorm_ctx ctx) ty
+
+let ctx_normalize_rtrait_type_constraint (ctx : C.eval_ctx)
+ (ttc : T.rtrait_type_constraint) : T.rtrait_type_constraint =
+ ctx_normalize_trait_type_constraint (mk_rnorm_ctx ctx) ttc
+
+(** Same as [type_decl_get_instantiated_variants_fields_rtypes] but normalizes the types *)
+let type_decl_get_inst_norm_variants_fields_rtypes (ctx : C.eval_ctx)
+ (def : T.type_decl) (generics : T.rgeneric_args) :
+ (T.VariantId.id option * T.rty list) list =
+ let res =
+ Subst.type_decl_get_instantiated_variants_fields_rtypes def generics
+ in
+ List.map
+ (fun (variant_id, types) ->
+ (variant_id, List.map (ctx_normalize_rty ctx) types))
+ res
+
+(** Same as [type_decl_get_instantiated_field_rtypes] but normalizes the types *)
+let type_decl_get_inst_norm_field_rtypes (ctx : C.eval_ctx) (def : T.type_decl)
+ (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) :
+ T.rty list =
+ let types =
+ Subst.type_decl_get_instantiated_field_rtypes def opt_variant_id generics
+ in
+ List.map (ctx_normalize_rty ctx) types
+
+(** Same as [ctx_adt_value_get_instantiated_field_rtypes] but normalizes the types *)
+let ctx_adt_value_get_inst_norm_field_rtypes (ctx : C.eval_ctx)
+ (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) :
+ T.rty list =
+ let types =
+ Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id generics
+ in
+ List.map (ctx_normalize_rty ctx) types
+
+(** Same as [ctx_adt_value_get_instantiated_field_etypes] but normalizes the types *)
+let type_decl_get_inst_norm_field_etypes (ctx : C.eval_ctx) (def : T.type_decl)
+ (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) :
+ T.ety list =
+ let types =
+ Subst.type_decl_get_instantiated_field_etypes def opt_variant_id generics
+ in
+ List.map (ctx_normalize_ety ctx) types
+
+(** Same as [ctx_adt_get_instantiated_field_etypes] but normalizes the types *)
+let ctx_adt_get_inst_norm_field_etypes (ctx : C.eval_ctx)
+ (def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
+ (generics : T.egeneric_args) : T.ety list =
+ let types =
+ Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id
+ generics
+ in
+ List.map (ctx_normalize_ety ctx) types
+
+(** Same as [substitute_signature] but normalizes the types *)
+let ctx_subst_norm_signature (ctx : C.eval_ctx)
+ (asubst : T.RegionGroupId.id -> V.AbstractionId.id)
+ (r_subst : T.RegionVarId.id -> T.RegionId.id)
+ (ty_subst : T.TypeVarId.id -> T.rty)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ let sg =
+ Subst.substitute_signature asubst r_subst ty_subst cg_subst tr_subst tr_self
+ sg
+ in
+ let { A.regions_hierarchy; inputs; output; trait_type_constraints } = sg in
+ let inputs = List.map (ctx_normalize_rty ctx) inputs in
+ let output = ctx_normalize_rty ctx output in
+ let trait_type_constraints =
+ List.map (ctx_normalize_rtrait_type_constraint ctx) trait_type_constraints
+ in
+ { regions_hierarchy; inputs; output; trait_type_constraints }
diff --git a/compiler/Assumed.ml b/compiler/Assumed.ml
index 11cd5666..109175af 100644
--- a/compiler/Assumed.ml
+++ b/compiler/Assumed.ml
@@ -63,79 +63,81 @@ module Sig = struct
let empty_const_generic_params : T.const_generic_var list = []
+ let mk_generic_args regions types const_generics : T.sgeneric_args =
+ { regions; types; const_generics; trait_refs = [] }
+
+ let mk_generic_params regions types const_generics : T.generic_params =
+ { regions; types; const_generics; trait_clauses = [] }
+
let mk_ref_ty (r : T.RegionVarId.id T.region) (ty : T.sty) (is_mut : bool) :
T.sty =
let ref_kind = if is_mut then T.Mut else T.Shared in
mk_ref_ty r ty ref_kind
let mk_array_ty (ty : T.sty) (cg : T.const_generic) : T.sty =
- Adt (Assumed Array, [], [ ty ], [ cg ])
+ Adt (Assumed Array, mk_generic_args [] [ ty ] [ cg ])
+
+ let mk_slice_ty (ty : T.sty) : T.sty =
+ Adt (Assumed Slice, mk_generic_args [] [ ty ] [])
- let mk_slice_ty (ty : T.sty) : T.sty = Adt (Assumed Slice, [], [ ty ], [])
- let range_ty : T.sty = Adt (Assumed Range, [], [ usize_ty ], [])
+ let range_ty : T.sty = Adt (Assumed Range, mk_generic_args [] [ usize_ty ] [])
+
+ let mk_sig generics regions_hierarchy inputs output : A.fun_sig =
+ let preds : T.predicates =
+ { regions_outlive = []; types_outlive = []; trait_type_constraints = [] }
+ in
+ {
+ generics;
+ preds;
+ parent_params_info = None;
+ regions_hierarchy;
+ inputs;
+ output;
+ }
(** [fn<T>(&'a mut T, T) -> T] *)
let mem_replace_sig : A.fun_sig =
(* The signature fields *)
- let region_params = [ region_param_0 ] (* <'a> *) in
+ let regions = [ region_param_0 ] (* <'a> *) in
let regions_hierarchy = [ region_group_0 ] (* [{<'a>}] *) in
- let type_params = [ type_param_0 ] (* <T> *) in
+ let types = [ type_param_0 ] (* <T> *) in
+ let generics = mk_generic_params regions types [] in
let inputs =
[ mk_ref_ty rvar_0 tvar_0 true (* &'a mut T *); tvar_0 (* T *) ]
in
let output = tvar_0 (* T *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(T) -> Box<T>] *)
let box_new_sig : A.fun_sig =
- {
- region_params = [];
- num_early_bound_regions = 0;
- regions_hierarchy = [];
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs = [ tvar_0 (* T *) ];
- output = mk_box_ty tvar_0 (* Box<T> *);
- }
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
+ let regions_hierarchy = [] in
+ let inputs = [ tvar_0 (* T *) ] in
+ let output = mk_box_ty tvar_0 (* Box<T> *) in
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(Box<T>) -> ()] *)
let box_free_sig : A.fun_sig =
- {
- region_params = [];
- num_early_bound_regions = 0;
- regions_hierarchy = [];
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs = [ mk_box_ty tvar_0 (* Box<T> *) ];
- output = mk_unit_ty (* () *);
- }
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
+ let regions_hierarchy = [] in
+ let inputs = [ mk_box_ty tvar_0 (* Box<T> *) ] in
+ let output = mk_unit_ty (* () *) in
+ mk_sig generics regions_hierarchy inputs output
(** Helper for [Box::deref_shared] and [Box::deref_mut].
Returns:
[fn<'a, T>(&'a (mut) Box<T>) -> &'a (mut) T]
*)
let box_deref_gen_sig (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params = [ type_param_0 ] (* <T> *);
- const_generic_params = empty_const_generic_params;
- inputs =
- [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ];
- output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *);
- }
+ let inputs =
+ [ mk_ref_ty rvar_0 (mk_box_ty tvar_0) is_mut (* &'a (mut) Box<T> *) ]
+ in
+ let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in
+ mk_sig generics regions_hierarchy inputs output
(** [fn<'a, T>(&'a Box<T>) -> &'a T] *)
let box_deref_shared_sig = box_deref_gen_sig false
@@ -145,27 +147,18 @@ module Sig = struct
(** [fn<T>() -> Vec<T>] *)
let vec_new_sig : A.fun_sig =
- let region_params = [] in
+ let generics = mk_generic_params [] [ type_param_0 ] [] (* <T> *) in
let regions_hierarchy = [] in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs = [] in
let output = mk_vec_ty tvar_0 (* Vec<T> *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a mut Vec<T>, T)] *)
let vec_push_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *);
@@ -173,22 +166,14 @@ module Sig = struct
]
in
let output = mk_unit_ty (* () *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a mut Vec<T>, usize, T)] *)
let vec_insert_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0 (mk_vec_ty tvar_0) true (* &'a mut Vec<T> *);
@@ -197,44 +182,28 @@ module Sig = struct
]
in
let output = mk_unit_ty (* () *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a Vec<T>) -> usize] *)
let vec_len_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[ mk_ref_ty rvar_0 (mk_vec_ty tvar_0) false (* &'a Vec<T> *) ]
in
let output = mk_usize_ty (* usize *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** Helper:
[fn<T>(&'a (mut) Vec<T>, usize) -> &'a (mut) T]
*)
let vec_index_gen_sig (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0 (mk_vec_ty tvar_0) is_mut (* &'a (mut) Vec<T> *);
@@ -242,15 +211,7 @@ module Sig = struct
]
in
let output = mk_ref_ty rvar_0 tvar_0 is_mut (* &'a (mut) T *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
(** [fn<T>(&'a Vec<T>, usize) -> &'a T] *)
let vec_index_shared_sig : A.fun_sig = vec_index_gen_sig false
@@ -275,10 +236,10 @@ module Sig = struct
let mk_array_slice_borrow_sig (cgs : T.const_generic_var list)
(input_ty : T.TypeVarId.id -> T.sty) (index_ty : T.sty option)
(output_ty : T.TypeVarId.id -> T.sty) (is_mut : bool) : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] cgs (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[
mk_ref_ty rvar_0
@@ -294,15 +255,7 @@ module Sig = struct
(output_ty type_param_0.index)
is_mut (* &'a (mut) output_ty<T> *)
in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = cgs;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
let mk_array_slice_index_sig (is_array : bool) (is_mut : bool) : A.fun_sig =
(* Array<T, N> *)
@@ -345,6 +298,19 @@ module Sig = struct
let array_subslice_sig (is_mut : bool) =
mk_array_slice_subslice_sig true is_mut
+ let array_repeat_sig =
+ let generics =
+ (* <T, N> *)
+ mk_generic_params [] [ type_param_0 ] [ cg_param_0 ]
+ in
+ let regions_hierarchy = [] (* <> *) in
+ let inputs = [ tvar_0 (* T *) ] in
+ let output =
+ (* [T; N] *)
+ mk_array_ty tvar_0 cgvar_0
+ in
+ mk_sig generics regions_hierarchy inputs output
+
let slice_subslice_sig (is_mut : bool) =
mk_array_slice_subslice_sig false is_mut
@@ -352,23 +318,15 @@ module Sig = struct
[fn<T>(&'a [T]) -> usize]
*)
let slice_len_sig : A.fun_sig =
- (* The signature fields *)
- let region_params = [ region_param_0 ] in
+ let generics =
+ mk_generic_params [ region_param_0 ] [ type_param_0 ] [] (* <'a, T> *)
+ in
let regions_hierarchy = [ region_group_0 ] (* <'a> *) in
- let type_params = [ type_param_0 ] (* <T> *) in
let inputs =
[ mk_ref_ty rvar_0 (mk_slice_ty tvar_0) false (* &'a [T] *) ]
in
let output = mk_usize_ty (* usize *) in
- {
- region_params;
- num_early_bound_regions = 0;
- regions_hierarchy;
- type_params;
- const_generic_params = empty_const_generic_params;
- inputs;
- output;
- }
+ mk_sig generics regions_hierarchy inputs output
end
type assumed_info = A.assumed_fun_id * A.fun_sig * bool * name
@@ -439,6 +397,8 @@ let assumed_infos : assumed_info list =
Sig.array_subslice_sig true,
true,
to_name [ "@ArraySubsliceMut" ] );
+ (* Array Repeat *)
+ (ArrayRepeat, Sig.array_repeat_sig, false, to_name [ "@ArrayRepeat" ]);
(* Slice Index *)
( SliceIndexShared,
Sig.slice_index_sig false,
diff --git a/compiler/Config.ml b/compiler/Config.ml
index bd80769f..62f6c300 100644
--- a/compiler/Config.ml
+++ b/compiler/Config.ml
@@ -323,3 +323,23 @@ let wrap_opaque_in_sig = ref false
information), we use short names (i.e., the original field names).
*)
let record_fields_short_names = ref false
+
+(** Parameterize the traits with their associated types, so as not to use
+ types as first class objects.
+
+ This is useful for some backends with limited expressiveness like HOL4,
+ and to account for type constraints (like [fn f<T : Foo>(...) where T::bar = usize]).
+ *)
+let parameterize_trait_types = ref false
+
+(** For sanity check: type check the generated pure code (activates checks in
+ several places).
+
+ TODO: deactivated for now because we need to implement the normalization of
+ trait associated types in the pure code.
+ *)
+let type_check_pure_code = ref false
+
+(** Shall we fail hard if there is an issue at code-generation time?
+ We may not want in case outputting a code with holes helps debugging *)
+let extract_fail_hard = ref false
diff --git a/compiler/Contexts.ml b/compiler/Contexts.ml
index 2ca5653d..dac64a9a 100644
--- a/compiler/Contexts.ml
+++ b/compiler/Contexts.ml
@@ -5,6 +5,7 @@ open LlbcAst
module V = Values
open ValuesUtils
open Identifiers
+module L = Logging
(** The [Id] module for dummy variables.
@@ -17,6 +18,9 @@ IdGen ()
type dummy_var_id = DummyVarId.id [@@deriving show, ord]
+(** The local logger *)
+let log = L.contexts_log
+
(** Some global counters.
Note that those counters were initially stored in {!eval_ctx} values,
@@ -40,6 +44,7 @@ type dummy_var_id = DummyVarId.id [@@deriving show, ord]
fn f x : fun_type =
let id = fresh_id () in
...
+ fun () -> ...
let g = f x in // <-- the fresh identifier gets generated here
let x1 = g () in // <-- no fresh generation here
@@ -250,27 +255,127 @@ type type_context = {
}
[@@deriving show]
-type fun_context = { fun_decls : fun_decl FunDeclId.Map.t } [@@deriving show]
+type fun_context = {
+ fun_decls : fun_decl FunDeclId.Map.t;
+ fun_infos : FunsAnalysis.fun_info FunDeclId.Map.t;
+}
+[@@deriving show]
type global_context = { global_decls : global_decl GlobalDeclId.Map.t }
[@@deriving show]
+type trait_decls_context = { trait_decls : trait_decl TraitDeclId.Map.t }
+[@@deriving show]
+
+type trait_impls_context = { trait_impls : trait_impl TraitImplId.Map.t }
+[@@deriving show]
+
+type decls_ctx = {
+ type_ctx : type_context;
+ fun_ctx : fun_context;
+ global_ctx : global_context;
+ trait_decls_ctx : trait_decls_context;
+ trait_impls_ctx : trait_impls_context;
+}
+[@@deriving show]
+
+(** A reference to a trait associated type *)
+type 'r trait_type_ref = { trait_ref : 'r trait_ref; type_name : string }
+[@@deriving show, ord]
+
+type etrait_type_ref = erased_region trait_type_ref [@@deriving show, ord]
+
+type rtrait_type_ref = Types.RegionId.id Types.region trait_type_ref
+[@@deriving show, ord]
+
+type strait_type_ref = Types.RegionVarId.id Types.region trait_type_ref
+[@@deriving show, ord]
+
+(* TODO: correctly use the functors so as not to have a duplication below *)
+module ETraitTypeRefOrd = struct
+ type t = etrait_type_ref
+
+ let compare = compare_etrait_type_ref
+ let to_string = show_etrait_type_ref
+ let pp_t = pp_etrait_type_ref
+ let show_t = show_etrait_type_ref
+end
+
+module RTraitTypeRefOrd = struct
+ type t = rtrait_type_ref
+
+ let compare = compare_rtrait_type_ref
+ let to_string = show_rtrait_type_ref
+ let pp_t = pp_rtrait_type_ref
+ let show_t = show_rtrait_type_ref
+end
+
+module STraitTypeRefOrd = struct
+ type t = strait_type_ref
+
+ let compare = compare_strait_type_ref
+ let to_string = show_strait_type_ref
+ let pp_t = pp_strait_type_ref
+ let show_t = show_strait_type_ref
+end
+
+module ETraitTypeRefMap = Collections.MakeMap (ETraitTypeRefOrd)
+module RTraitTypeRefMap = Collections.MakeMap (RTraitTypeRefOrd)
+module STraitTypeRefMap = Collections.MakeMap (STraitTypeRefOrd)
+
(** Evaluation context *)
type eval_ctx = {
type_context : type_context;
fun_context : fun_context;
global_context : global_context;
+ trait_decls_context : trait_decls_context;
+ trait_impls_context : trait_impls_context;
region_groups : RegionGroupId.id list;
type_vars : type_var list;
const_generic_vars : const_generic_var list;
+ const_generic_vars_map : typed_value Types.ConstGenericVarId.Map.t;
+ (** The map from const generic vars to their values. Those values
+ can be symbolic values or concrete values (in the latter case:
+ if we run in interpreter mode) *)
+ norm_trait_etypes : ety ETraitTypeRefMap.t;
+ (** The normalized trait types (a map from trait types to their representatives).
+ Note that this doesn't support account higher-order types. *)
+ norm_trait_rtypes : rty RTraitTypeRefMap.t;
+ (** We need this because we manipulate two kinds of types.
+ Note that we actually forbid regions from appearing both in the trait
+ references and in the constraints given to the associated types,
+ meaning that we don't have to worry about mismatches due to changes
+ in region ids.
+
+ TODO: how not to duplicate?
+ *)
+ norm_trait_stypes : sty STraitTypeRefMap.t;
+ (** We sometimes need to normalize types in non-instantiated signatures.
+
+ Note that we either need to use the etypes/rtypes maps, or the stypes map.
+ This means that we either compute the maps for etypes and rtypes, or compute
+ the one for stypes (we don't always compute and carry all the maps).
+ *)
env : env;
ended_regions : RegionId.Set.t;
}
[@@deriving show]
+let lookup_type_var_opt (ctx : eval_ctx) (vid : TypeVarId.id) : type_var option
+ =
+ if TypeVarId.to_int vid < List.length ctx.type_vars then
+ Some (TypeVarId.nth ctx.type_vars vid)
+ else None
+
let lookup_type_var (ctx : eval_ctx) (vid : TypeVarId.id) : type_var =
TypeVarId.nth ctx.type_vars vid
+let lookup_const_generic_var_opt (ctx : eval_ctx) (vid : ConstGenericVarId.id) :
+ const_generic_var option =
+ if ConstGenericVarId.to_int vid < List.length ctx.const_generic_vars then
+ Some (ConstGenericVarId.nth ctx.const_generic_vars vid)
+ else None
+
let lookup_const_generic_var (ctx : eval_ctx) (vid : ConstGenericVarId.id) :
const_generic_var =
ConstGenericVarId.nth ctx.const_generic_vars vid
@@ -304,6 +409,12 @@ let ctx_lookup_global_decl (ctx : eval_ctx) (gid : GlobalDeclId.id) :
global_decl =
GlobalDeclId.Map.find gid ctx.global_context.global_decls
+let ctx_lookup_trait_decl (ctx : eval_ctx) (id : TraitDeclId.id) : trait_decl =
+ TraitDeclId.Map.find id ctx.trait_decls_context.trait_decls
+
+let ctx_lookup_trait_impl (ctx : eval_ctx) (id : TraitImplId.id) : trait_impl =
+ TraitImplId.Map.find id ctx.trait_impls_context.trait_impls
+
(** Retrieve a variable's value in the current frame *)
let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value =
snd (env_lookup_var env vid)
@@ -312,6 +423,11 @@ let env_lookup_var_value (env : env) (vid : VarId.id) : typed_value =
let ctx_lookup_var_value (ctx : eval_ctx) (vid : VarId.id) : typed_value =
env_lookup_var_value ctx.env vid
+(** Retrieve a const generic value in an evaluation context *)
+let ctx_lookup_const_generic_value (ctx : eval_ctx) (vid : ConstGenericVarId.id)
+ : typed_value =
+ Types.ConstGenericVarId.Map.find vid ctx.const_generic_vars_map
+
(** Update a variable's value in the current frame.
This is a helper function: it can break invariants and doesn't perform
@@ -361,6 +477,15 @@ let ctx_push_var (ctx : eval_ctx) (var : var) (v : typed_value) : eval_ctx =
*)
let ctx_push_vars (ctx : eval_ctx) (vars : (var * typed_value) list) : eval_ctx
=
+ log#ldebug
+ (lazy
+ ("push_vars:\n"
+ ^ String.concat "\n"
+ (List.map
+ (fun (var, value) ->
+ (* We can unfortunately not use Print because it depends on Contexts... *)
+ show_var var ^ " -> " ^ V.show_typed_value value)
+ vars)));
assert (
List.for_all
(fun (var, (value : typed_value)) -> var.var_ty = value.ty)
diff --git a/compiler/Driver.ml b/compiler/Driver.ml
index b646a53d..414b042d 100644
--- a/compiler/Driver.ml
+++ b/compiler/Driver.ml
@@ -17,11 +17,15 @@ let log = main_log
let _ =
(* Set up the logging - for now we use default values - TODO: use the
* command-line arguments *)
- (* By setting a level for the main_logger_handler, we filter everything *)
+ (* By setting a level for the main_logger_handler, we filter everything.
+ To have a good trace: one should switch between Info and Debug.
+ *)
Easy_logging.Handlers.set_level main_logger_handler EL.Debug;
main_log#set_level EL.Info;
llbc_of_json_logger#set_level EL.Info;
pre_passes_log#set_level EL.Info;
+ associated_types_log#set_level EL.Info;
+ contexts_log#set_level EL.Info;
interpreter_log#set_level EL.Info;
statements_log#set_level EL.Info;
loops_match_ctxs_log#set_level EL.Info;
@@ -164,8 +168,6 @@ let () =
decompose_monadic_let_bindings := true;
decompose_nested_let_patterns := true
| Lean ->
- (* The Lean backend is experimental: print a warning *)
- log#lwarning (lazy "The Lean backend is experimental");
(* We don't support fuel for the Lean backend *)
if !use_fuel then (
log#error "The Lean backend doesn't support the -use-fuel option";
@@ -220,20 +222,6 @@ let () =
in
if has_loops then log#lwarning (lazy "Support for loops is experimental");
- (* If we target Lean, we request the crates to be split into several files
- whenever there are opaque functions *)
- if
- !backend = Lean
- && A.FunDeclId.Map.exists
- (fun _ (d : A.fun_decl) -> d.body = None)
- m.functions
- && not !split_files
- then (
- log#error
- "For Lean, we request the -split-file option whenever using opaque \
- functions";
- fail ());
-
(* We don't support mutually recursive definitions with decreases clauses in Lean *)
if
!backend = Lean && !extract_decreases_clauses
diff --git a/compiler/Extract.ml b/compiler/Extract.ml
index c4238d83..ac81d6f3 100644
--- a/compiler/Extract.ml
+++ b/compiler/Extract.ml
@@ -338,6 +338,7 @@ let assumed_llbc_functions () :
(ArraySubsliceShared, None, "array_subslice_shared");
(ArraySubsliceMut, None, "array_subslice_mut_fwd");
(ArraySubsliceMut, rg0, "array_subslice_mut_back");
+ (ArrayRepeat, None, "array_repeat");
(SliceIndexShared, None, "slice_index_shared");
(SliceIndexMut, None, "slice_index_mut_fwd");
(SliceIndexMut, rg0, "slice_index_mut_back");
@@ -369,6 +370,7 @@ let assumed_llbc_functions () :
(ArraySubsliceShared, None, "Array.subslice_shared");
(ArraySubsliceMut, None, "Array.subslice_mut");
(ArraySubsliceMut, rg0, "Array.subslice_mut_back");
+ (ArrayRepeat, None, "Array.repeat");
(SliceIndexShared, None, "Slice.index_shared");
(SliceIndexMut, None, "Slice.index_mut");
(SliceIndexMut, rg0, "Slice.index_mut_back");
@@ -529,7 +531,15 @@ let type_decl_kind_to_qualif (kind : decl_kind)
*)
Some "with"
| (Assumed | Declared), None -> Some "Axiom"
- | _ -> raise (Failure "Unexpected"))
+ | SingleNonRec, None ->
+ (* This is for traits *)
+ Some "Record"
+ | _ ->
+ raise
+ (Failure
+ ("Unexpected: (" ^ show_decl_kind kind ^ ", "
+ ^ Print.option_to_string show_type_decl_kind type_kind
+ ^ ")")))
| Lean -> (
match kind with
| SingleNonRec ->
@@ -647,6 +657,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| _ ->
raise (Failure ("Unexpected name shape: " ^ Print.name_to_string name))
in
+ let flatten_name (name : string list) : string =
+ match !backend with
+ | FStar | Coq | HOL4 -> String.concat "_" name
+ | Lean -> String.concat "." name
+ in
let get_type_name = get_name in
let type_name_to_camel_case name =
let name = get_type_name name in
@@ -668,15 +683,20 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
in
let field_name (def_name : name) (field_id : FieldId.id)
(field_name : string option) : string =
- let field_name =
+ let field_name_s =
match field_name with
| Some field_name -> field_name
- | None -> FieldId.to_string field_id
+ | None ->
+ (* TODO: extract structs with no field names to tuples *)
+ FieldId.to_string field_id
in
- if !Config.record_fields_short_names then field_name
+ if !Config.record_fields_short_names then
+ if field_name = None then (* TODO: this is a bit ugly *)
+ "_" ^ field_name_s
+ else field_name_s
else
let def_name = type_name_to_snake_case def_name ^ "_" in
- def_name ^ field_name
+ def_name ^ field_name_s
in
let variant_name (def_name : name) (variant : string) : string =
match !backend with
@@ -700,9 +720,7 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let get_fun_name fname =
let fname = get_name fname in
(* TODO: don't convert to snake case for Coq, HOL4, F* *)
- match !backend with
- | FStar | Coq | HOL4 -> String.concat "_" (List.map to_snake_case fname)
- | Lean -> String.concat "." fname
+ flatten_name fname
in
let global_name (name : global_name) : string =
(* Converting to snake case also lowercases the letters (in Rust, global
@@ -720,6 +738,50 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
fname ^ suffix
in
+ let trait_decl_name (trait_decl : trait_decl) : string =
+ type_name trait_decl.name
+ in
+
+ let trait_impl_name (trait_decl : trait_decl) (trait_impl : trait_impl) :
+ string =
+ (* TODO: provisional: we concatenate the trait impl name (which is its type)
+ with the trait decl name *)
+ let trait_decl =
+ let name = trait_decl.name in
+ match !backend with
+ | FStar | Coq | HOL4 -> type_name_to_snake_case name ^ "_inst"
+ | Lean -> String.concat "" (get_type_name name) ^ "Inst"
+ in
+ flatten_name (get_type_name trait_impl.name @ [ trait_decl ])
+ in
+
+ let trait_parent_clause_name (trait_decl : trait_decl) (clause : trait_clause)
+ : string =
+ (* TODO: improve - it would be better to not use indices *)
+ let clause = "parent_clause_" ^ TraitClauseId.to_string clause.clause_id in
+ if !Config.record_fields_short_names then clause
+ else trait_decl_name trait_decl ^ "_" ^ clause
+ in
+ let trait_type_name (trait_decl : trait_decl) (item : string) : string =
+ if !Config.record_fields_short_names then item
+ else trait_decl_name trait_decl ^ "_" ^ item
+ in
+ let trait_const_name (trait_decl : trait_decl) (item : string) : string =
+ if !Config.record_fields_short_names then item
+ else trait_decl_name trait_decl ^ "_" ^ item
+ in
+ let trait_method_name (trait_decl : trait_decl) (item : string) : string =
+ if !Config.record_fields_short_names then item
+ else trait_decl_name trait_decl ^ "_" ^ item
+ in
+ let trait_type_clause_name (trait_decl : trait_decl) (item : string)
+ (clause : trait_clause) : string =
+ (* TODO: improve - it would be better to not use indices *)
+ trait_type_name trait_decl item
+ ^ "_clause_"
+ ^ TraitClauseId.to_string clause.clause_id
+ in
+
let termination_measure_name (_fid : A.FunDeclId.id) (fname : fun_name)
(num_loops : int) (loop_id : LoopId.id option) : string =
let fname = get_fun_name fname in
@@ -757,6 +819,21 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
let var_basename (_varset : StringSet.t) (basename : string option) (ty : ty)
: string =
+ (* Small helper to derive var names from ADT type names.
+
+ We do the following:
+ - convert the type name to snake case
+ - take the first letter of every "letter group"
+ Ex.: "HashMap" -> "hash_map" -> "hm"
+ *)
+ let name_from_type_ident (name : string) : string =
+ let cl = to_snake_case name in
+ let cl = String.split_on_char '_' cl in
+ let cl = List.filter (fun s -> String.length s > 0) cl in
+ assert (List.length cl > 0);
+ let cl = List.map (fun s -> s.[0]) cl in
+ StringUtils.string_of_chars cl
+ in
(* If there is a basename, we use it *)
match basename with
| Some basename ->
@@ -765,11 +842,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| None -> (
(* No basename: we use the first letter of the type *)
match ty with
- | Adt (type_id, tys, _) -> (
+ | Adt (type_id, generics) -> (
match type_id with
| Tuple ->
(* The "pair" case is frequent enough to have its special treatment *)
- if List.length tys = 2 then "p" else "t"
+ if List.length generics.types = 2 then "p" else "t"
| Assumed Result -> "r"
| Assumed Error -> ConstStrings.error_basename
| Assumed Fuel -> ConstStrings.fuel_basename
@@ -781,24 +858,14 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| Assumed Range -> "r"
| Assumed State -> ConstStrings.state_basename
| AdtId adt_id ->
- let def =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
- in
- (* We do the following:
- * - compute the type name, and retrieve the last ident
- * - convert this to snake case
- * - take the first letter of every "letter group"
+ let def = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
+ (* Derive the var name from the last ident of the type name
* Ex.: ["hashmap"; "HashMap"] ~~> "HashMap" -> "hash_map" -> "hm"
*)
- (* Thename shouldn't be empty, and its last element should
+ (* The name shouldn't be empty, and its last element should
* be an ident *)
let cl = List.nth def.name (List.length def.name - 1) in
- let cl = to_snake_case (Names.as_ident cl) in
- let cl = String.split_on_char '_' cl in
- let cl = List.filter (fun s -> String.length s > 0) cl in
- assert (List.length cl > 0);
- let cl = List.map (fun s -> s.[0]) cl in
- StringUtils.string_of_chars cl)
+ name_from_type_ident (Names.as_ident cl))
| TypeVar _ -> (
(* TODO: use "t" also for F* *)
match !backend with
@@ -806,7 +873,8 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| Coq | Lean | HOL4 -> "t" (* lacking inspiration here... *))
| Literal lty -> (
match lty with Bool -> "b" | Char -> "c" | Integer _ -> "i")
- | Arrow _ -> "f")
+ | Arrow _ -> "f"
+ | TraitType (_, _, name) -> name_from_type_ident name)
in
let type_var_basename (_varset : StringSet.t) (basename : string) : string =
(* Rust type variables are snake-case and start with a capital letter *)
@@ -828,6 +896,12 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
to_snake_case basename
| Coq | Lean -> basename
in
+ let trait_clause_basename (_varset : StringSet.t) (_clause : trait_clause) :
+ string =
+ (* TODO: actually use the clause to derive the name *)
+ "inst"
+ in
+ let trait_self_clause_basename = "self_clause" in
let append_index (basename : string) (i : int) : string =
basename ^ string_of_int i
in
@@ -838,11 +912,11 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
| Scalar sv -> (
match !backend with
| FStar -> F.pp_print_string fmt (Z.to_string sv.PV.value)
- | Coq | HOL4 ->
+ | Coq | HOL4 | Lean ->
let print_brackets = inside && !backend = HOL4 in
if print_brackets then F.pp_print_string fmt "(";
(match !backend with
- | Coq -> ()
+ | Coq | Lean -> ()
| HOL4 ->
F.pp_print_string fmt ("int_to_" ^ int_name sv.PV.int_ty);
F.pp_print_space fmt ()
@@ -850,30 +924,20 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
(* We need to add parentheses if the value is negative *)
if sv.PV.value >= Z.of_int 0 then
F.pp_print_string fmt (Z.to_string sv.PV.value)
- else F.pp_print_string fmt ("(" ^ Z.to_string sv.PV.value ^ ")");
+ else
+ F.pp_print_string fmt
+ ("(" ^ Z.to_string sv.PV.value
+ ^ if !backend = Lean then ":Int" else "" ^ ")");
(match !backend with
- | Coq -> F.pp_print_string fmt ("%" ^ int_name sv.PV.int_ty)
+ | Coq ->
+ let iname = int_name sv.PV.int_ty in
+ F.pp_print_string fmt ("%" ^ iname)
+ | Lean ->
+ let iname = String.lowercase_ascii (int_name sv.PV.int_ty) in
+ F.pp_print_string fmt ("#" ^ iname)
| HOL4 -> ()
| _ -> raise (Failure "Unreachable"));
- if print_brackets then F.pp_print_string fmt ")"
- | Lean ->
- F.pp_print_string fmt "(";
- F.pp_print_string fmt (int_name sv.int_ty);
- F.pp_print_string fmt ".ofInt ";
- (* Something very annoying: negated values like `-3` are
- ambiguous in Lean because of conversions, so we have to
- be extremely explicit with negative numbers.
- *)
- if Z.lt sv.value Z.zero then (
- F.pp_print_string fmt "(";
- F.pp_print_string fmt "-";
- F.pp_print_string fmt "(";
- Z.pp_print fmt (Z.neg sv.value);
- F.pp_print_string fmt ":Int";
- F.pp_print_string fmt ")";
- F.pp_print_string fmt ")")
- else Z.pp_print fmt sv.value;
- F.pp_print_string fmt ")")
+ if print_brackets then F.pp_print_string fmt ")")
| Bool b ->
let b =
match !backend with
@@ -919,10 +983,19 @@ let mk_formatter (ctx : trans_ctx) (crate_name : string)
fun_name;
termination_measure_name;
decreases_proof_name;
+ trait_decl_name;
+ trait_impl_name;
+ trait_parent_clause_name;
+ trait_const_name;
+ trait_type_name;
+ trait_method_name;
+ trait_type_clause_name;
opaque_pre;
var_basename;
type_var_basename;
const_generic_var_basename;
+ trait_self_clause_basename;
+ trait_clause_basename;
append_index;
extract_literal;
extract_unop;
@@ -1131,13 +1204,13 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
(no_params_tys : TypeDeclId.Set.t) (inside : bool) (ty : ty) : unit =
let extract_rec = extract_ty ctx fmt no_params_tys in
match ty with
- | Adt (type_id, tys, cgs) -> (
- let has_params = tys <> [] || cgs <> [] in
+ | Adt (type_id, generics) -> (
+ let has_params = generics <> empty_generic_args in
match type_id with
| Tuple ->
(* This is a bit annoying, but in F*/Coq/HOL4 [()] is not the unit type:
* we have to write [unit]... *)
- if tys = [] then F.pp_print_string fmt (unit_name ())
+ if generics.types = [] then F.pp_print_string fmt (unit_name ())
else (
F.pp_print_string fmt "(";
Collections.List.iter_link
@@ -1152,7 +1225,7 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
in
F.pp_print_string fmt product;
F.pp_print_space fmt ())
- (extract_rec true) tys;
+ (extract_rec true) generics.types;
F.pp_print_string fmt ")")
| AdtId _ | Assumed _ -> (
(* HOL4 behaves differently. Where in Coq/FStar/Lean we would write:
@@ -1169,36 +1242,34 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
(* TODO: for now, only the opaque *functions* are extracted in the
opaque module. The opaque *types* are assumed. *)
F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx);
- if tys <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_rec true) tys);
- if cgs <> [] then (
- F.pp_print_space fmt ();
- Collections.List.iter_link (F.pp_print_space fmt)
- (extract_const_generic ctx fmt true)
- cgs);
+ extract_generic_args ctx fmt no_params_tys generics;
if print_paren then F.pp_print_string fmt ")"
| HOL4 ->
- (* Const generics are unsupported in HOL4 *)
- assert (cgs = []);
+ let { types; const_generics; trait_refs } = generics in
+ (* Const generics are not supported in HOL4 *)
+ assert (const_generics = []);
let print_tys =
match type_id with
| AdtId id -> not (TypeDeclId.Set.mem id no_params_tys)
| Assumed _ -> true
| _ -> raise (Failure "Unreachable")
in
- if tys <> [] && print_tys then (
- let print_paren = List.length tys > 1 in
+ if types <> [] && print_tys then (
+ let print_paren = List.length types > 1 in
if print_paren then F.pp_print_string fmt "(";
Collections.List.iter_link
(fun () ->
F.pp_print_string fmt ",";
F.pp_print_space fmt ())
- (extract_rec true) tys;
+ (extract_rec true) types;
if print_paren then F.pp_print_string fmt ")";
F.pp_print_space fmt ());
- F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx)))
+ F.pp_print_string fmt (ctx_get_type with_opaque_pre type_id ctx);
+ if trait_refs <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_trait_ref ctx fmt no_params_tys true)
+ trait_refs)))
| TypeVar vid -> F.pp_print_string fmt (ctx_get_type_var vid ctx)
| Literal lty -> extract_literal_type ctx fmt lty
| Arrow (arg_ty, ret_ty) ->
@@ -1209,6 +1280,113 @@ let rec extract_ty (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_space fmt ();
extract_rec false ret_ty;
if inside then F.pp_print_string fmt ")"
+ | TraitType (trait_ref, generics, type_name) ->
+ if !parameterize_trait_types then raise (Failure "Unimplemented")
+ else if trait_ref.trait_id <> Self then (
+ (* HOL4 doesn't have 1st class types *)
+ assert (!backend <> HOL4);
+ let use_brackets = generics <> empty_generic_args in
+ if use_brackets then F.pp_print_string fmt "(";
+ extract_trait_ref ctx fmt no_params_tys false trait_ref;
+ extract_generic_args ctx fmt no_params_tys generics;
+ let name =
+ ctx_get_trait_type trait_ref.trait_decl_ref.trait_decl_id type_name
+ ctx
+ in
+ if use_brackets then F.pp_print_string fmt ")";
+ F.pp_print_string fmt ("." ^ name))
+ else
+ (* There are two situations:
+ - we are extracting a declared item (typically a function signature)
+ for a trait declaration. We directly refer to the item (we extract
+ trait declarations as structures, so we can refer to their fields)
+ - we are extracting a provided method for a trait declaration. We
+ refer to the item in the self trait clause (see {!SelfTraitClauseId}).
+
+ Remark: we can't get there for trait *implementations* because then the
+ types should have been normalized.
+ *)
+ let trait_decl_id = Option.get ctx.trait_decl_id in
+ let item_name = ctx_get_trait_type trait_decl_id type_name ctx in
+ assert (generics = empty_generic_args);
+ if ctx.is_provided_method then
+ (* Provided method: use the trait self clause *)
+ let self_clause = ctx_get_trait_self_clause ctx in
+ F.pp_print_string fmt (self_clause ^ "." ^ item_name)
+ else
+ (* Declaration: directly refer to the item *)
+ F.pp_print_string fmt item_name
+
+and extract_trait_ref (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_ref) : unit =
+ let use_brackets = tr.generics <> empty_generic_args && inside in
+ if use_brackets then F.pp_print_string fmt "(";
+ extract_trait_instance_id ctx fmt no_params_tys inside tr.trait_id;
+ extract_generic_args ctx fmt no_params_tys tr.generics;
+ if use_brackets then F.pp_print_string fmt ")"
+
+and extract_trait_decl_ref (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (tr : trait_decl_ref) :
+ unit =
+ let use_brackets = tr.decl_generics <> empty_generic_args && inside in
+ let is_opaque = false in
+ let name = ctx_get_trait_decl is_opaque tr.trait_decl_id ctx in
+ if use_brackets then F.pp_print_string fmt "(";
+ F.pp_print_string fmt name;
+ extract_generic_args ctx fmt no_params_tys tr.decl_generics;
+ if use_brackets then F.pp_print_string fmt ")"
+
+and extract_generic_args (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (generics : generic_args) : unit =
+ let { types; const_generics; trait_refs } = generics in
+ if !backend <> HOL4 then (
+ if types <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_ty ctx fmt no_params_tys true)
+ types);
+ if const_generics <> [] then (
+ assert (!backend <> HOL4);
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_const_generic ctx fmt true)
+ const_generics));
+ if trait_refs <> [] then (
+ F.pp_print_space fmt ();
+ Collections.List.iter_link (F.pp_print_space fmt)
+ (extract_trait_ref ctx fmt no_params_tys true)
+ trait_refs)
+
+and extract_trait_instance_id (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (inside : bool) (id : trait_instance_id)
+ : unit =
+ let with_opaque_pre = false in
+ match id with
+ | Self ->
+ (* This has specific treatment depending on the item we're extracting
+ (associated type, etc.). We should have caught this elsewhere. *)
+ raise (Failure "Unexpected")
+ | TraitImpl id ->
+ let name = ctx_get_trait_impl with_opaque_pre id ctx in
+ F.pp_print_string fmt name
+ | Clause id ->
+ let name = ctx_get_local_trait_clause id ctx in
+ F.pp_print_string fmt name
+ | ParentClause (inst_id, decl_id, clause_id) ->
+ (* Use the trait decl id to lookup the name *)
+ let name = ctx_get_trait_parent_clause decl_id clause_id ctx in
+ extract_trait_instance_id ctx fmt no_params_tys true inst_id;
+ F.pp_print_string fmt ("." ^ name)
+ | ItemClause (inst_id, decl_id, item_name, clause_id) ->
+ (* Use the trait decl id to lookup the name *)
+ let name = ctx_get_trait_item_clause decl_id item_name clause_id ctx in
+ extract_trait_instance_id ctx fmt no_params_tys true inst_id;
+ F.pp_print_string fmt ("." ^ name)
+ | TraitRef trait_ref ->
+ extract_trait_ref ctx fmt no_params_tys inside trait_ref
+ | UnknownTrait _ ->
+ (* This is an error case *)
+ raise (Failure "Unexpected")
(** Compute the names for all the top-level identifiers used in a type
definition (type name, variant names, field names, etc. but not type
@@ -1524,6 +1702,172 @@ let extract_comment (fmt : F.formatter) (sl : string list) : unit =
F.pp_print_string fmt rd;
F.pp_close_box fmt ()
+let extract_trait_clause_type (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) (clause : trait_clause) : unit =
+ let with_opaque_pre = false in
+ let trait_name = ctx_get_trait_decl with_opaque_pre clause.trait_id ctx in
+ F.pp_print_string fmt trait_name;
+ extract_generic_args ctx fmt no_params_tys clause.generics
+
+(** Insert a space, if necessary *)
+let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
+ if !space then space := false else F.pp_print_space fmt ()
+
+(** Extract the trait self clause.
+
+ We add the trait self clause for provided methods (see {!TraitSelfClauseId}).
+ *)
+let extract_trait_self_clause (insert_req_space : unit -> unit)
+ (ctx : extraction_ctx) (fmt : F.formatter) (trait_decl : trait_decl)
+ (params : string list) : unit =
+ insert_req_space ();
+ F.pp_print_string fmt "(";
+ let self_clause = ctx_get_trait_self_clause ctx in
+ F.pp_print_string fmt self_clause;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ let with_opaque_pre = false in
+ let trait_id = ctx_get_trait_decl with_opaque_pre trait_decl.def_id ctx in
+ F.pp_print_string fmt trait_id;
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ params;
+ F.pp_print_string fmt ")"
+
+(**
+ - [trait_decl]: if [Some], it means we are extracting the generics for a provided
+ method and need to insert a trait self clause (see {!TraitSelfClauseId}).
+ *)
+let extract_generic_params (ctx : extraction_ctx) (fmt : F.formatter)
+ (no_params_tys : TypeDeclId.Set.t) ?(use_forall = false)
+ ?(use_forall_use_sep = true) ?(as_implicits : bool = false)
+ ?(space : bool ref option = None) ?(trait_decl : trait_decl option = None)
+ (generics : generic_params) (type_params : string list)
+ (cg_params : string list) (trait_clauses : string list) : unit =
+ let all_params = List.concat [ type_params; cg_params; trait_clauses ] in
+ (* HOL4 doesn't support const generics *)
+ assert (cg_params = [] || !backend <> HOL4);
+ let left_bracket (implicit : bool) =
+ if implicit then F.pp_print_string fmt "{" else F.pp_print_string fmt "("
+ in
+ let right_bracket (implicit : bool) =
+ if implicit then F.pp_print_string fmt "}" else F.pp_print_string fmt ")"
+ in
+ let insert_req_space () =
+ match space with
+ | None -> F.pp_print_space fmt ()
+ | Some space -> insert_req_space fmt space
+ in
+ (* Print the type/const generic parameters *)
+ if all_params <> [] then (
+ if use_forall then (
+ if use_forall_use_sep then (
+ insert_req_space ();
+ F.pp_print_string fmt ":");
+ insert_req_space ();
+ F.pp_print_string fmt "forall");
+ (* Small helper - we may need to split the parameters *)
+ let print_generics (as_implicits : bool) (type_params : string list)
+ (const_generics : const_generic_var list)
+ (trait_clauses : trait_clause list) : unit =
+ (* Note that in HOL4 we don't print the type parameters. *)
+ if !backend <> HOL4 then (
+ (* Print the type parameters *)
+ if type_params <> [] then (
+ insert_req_space ();
+ (* ( *)
+ left_bracket as_implicits;
+ List.iter
+ (fun s ->
+ F.pp_print_string fmt s;
+ F.pp_print_space fmt ())
+ type_params;
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (type_keyword ());
+ (* ) *)
+ right_bracket as_implicits);
+ (* Print the const generic parameters *)
+ List.iter
+ (fun (var : const_generic_var) ->
+ insert_req_space ();
+ (* ( *)
+ left_bracket as_implicits;
+ let n = ctx_get_const_generic_var var.index ctx in
+ F.pp_print_string fmt n;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_literal_type ctx fmt var.ty;
+ (* ) *)
+ right_bracket as_implicits)
+ const_generics);
+ (* Print the trait clauses *)
+ List.iter
+ (fun (clause : trait_clause) ->
+ insert_req_space ();
+ (* ( *)
+ left_bracket as_implicits;
+ let n = ctx_get_local_trait_clause clause.clause_id ctx in
+ F.pp_print_string fmt n;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_trait_clause_type ctx fmt no_params_tys clause;
+ (* ) *)
+ right_bracket as_implicits)
+ trait_clauses
+ in
+ (* If we extract the generics for a provided method for a trait declaration
+ (indicated by the trait decl given as input), we need to split the generics:
+ - we print the generics for the trait decl
+ - we print the trait self clause
+ - we print the generics for the trait method
+ *)
+ match trait_decl with
+ | None ->
+ print_generics as_implicits type_params generics.const_generics
+ generics.trait_clauses
+ | Some trait_decl ->
+ (* Split the generics between the generics specific to the trait decl
+ and those specific to the trait method *)
+ let open Collections.List in
+ let dtype_params, mtype_params =
+ split_at type_params (length trait_decl.generics.types)
+ in
+ let dcgs, mcgs =
+ split_at generics.const_generics
+ (length trait_decl.generics.const_generics)
+ in
+ let dtrait_clauses, mtrait_clauses =
+ split_at generics.trait_clauses
+ (length trait_decl.generics.trait_clauses)
+ in
+ (* Extract the trait decl generics - note that we can always deduce
+ those parameters from the trait self clause: for this reason
+ they are always implicit *)
+ print_generics true dtype_params dcgs dtrait_clauses;
+ (* Extract the trait self clause *)
+ let params =
+ concat
+ [
+ dtype_params;
+ map
+ (fun (cg : const_generic_var) ->
+ ctx_get_const_generic_var cg.index ctx)
+ dcgs;
+ map
+ (fun c -> ctx_get_local_trait_clause c.clause_id ctx)
+ dtrait_clauses;
+ ]
+ in
+ extract_trait_self_clause insert_req_space ctx fmt trait_decl params;
+ (* Extract the method generics *)
+ print_generics as_implicits mtype_params mcgs mtrait_clauses)
+
(** Extract a type declaration.
This function is for all type declarations and all backends **at the exception**
@@ -1551,19 +1895,15 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
*)
let is_opaque = type_kind = None in
let is_opaque_coq = !backend = Coq && is_opaque in
- let use_forall =
- is_opaque_coq && (def.type_params <> [] || def.const_generic_params <> [])
- in
+ let use_forall = is_opaque_coq && def.generics <> empty_generic_params in
(* Retrieve the definition name *)
let with_opaque_pre = false in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
(* Add the type and const generic params - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let ctx_body, type_params, cg_params =
- ctx_add_type_const_generic_params def.type_params def.const_generic_params
- ctx
+ let ctx_body, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params def.generics ctx
in
- let ty_cg_params = List.append type_params cg_params in
(* Add a break before *)
if !backend <> HOL4 || not (decl_is_first_from_group kind) then
F.pp_print_break fmt 0 0;
@@ -1583,40 +1923,12 @@ let extract_type_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(match qualif with
| Some qualif -> F.pp_print_string fmt (qualif ^ " " ^ def_name)
| None -> F.pp_print_string fmt def_name);
- (* HOL4 doesn't support const generics *)
- assert (cg_params = [] || !backend <> HOL4);
- (* Print the type/const generic parameters *)
- if ty_cg_params <> [] && !backend <> HOL4 then (
- if use_forall then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "forall");
- (* Print the type parameters *)
- if type_params <> [] then (
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- List.iter
- (fun s ->
- F.pp_print_string fmt s;
- F.pp_print_space fmt ())
- type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt (type_keyword () ^ ")"));
- (* Print the const generic parameters *)
- List.iter
- (fun (var : const_generic_var) ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "(";
- let n = ctx_get_const_generic_var var.index ctx in
- F.pp_print_string fmt n;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt var.ty;
- F.pp_print_string fmt ")")
- def.const_generic_params);
+ (* HOL4 doesn't support const generics, and type definitions in HOL4 don't
+ support trait clauses *)
+ assert ((cg_params = [] && trait_clauses = []) || !backend <> HOL4);
+ (* Print the generic parameters *)
+ extract_generic_params ctx_body fmt type_decl_group ~use_forall def.generics
+ type_params cg_params trait_clauses;
(* Print the "=" if we extract the body*)
if extract_body then (
F.pp_print_space fmt ();
@@ -1676,9 +1988,12 @@ let extract_type_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
let with_opaque_pre = false in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
(* Generic parameters are unsupported *)
- assert (def.const_generic_params = []);
+ assert (def.generics.const_generics = []);
+ (* Trait clauses on type definitions are unsupported *)
+ assert (def.generics.trait_clauses = []);
+ (* Types *)
(* Count the number of parameters *)
- let num_params = List.length def.type_params in
+ let num_params = List.length def.generics.types in
(* Generate the declaration *)
F.pp_print_space fmt ();
F.pp_print_string fmt
@@ -1699,8 +2014,7 @@ let extract_type_decl_hol4_empty_record (ctx : extraction_ctx)
let with_opaque_pre = false in
let def_name = ctx_get_local_type with_opaque_pre def.def_id ctx in
(* Sanity check *)
- assert (def.type_params = []);
- assert (def.const_generic_params = []);
+ assert (def.generics = empty_generic_params);
(* Generate the declaration *)
F.pp_print_space fmt ();
F.pp_print_string fmt ("Type " ^ def_name ^ " = “: unit”");
@@ -1740,13 +2054,12 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
(kind : decl_kind) (decl : type_decl) : unit =
assert (!backend = Coq);
(* Generating the [Arguments] instructions is useful only if there are type parameters *)
- if decl.type_params = [] && decl.const_generic_params = [] then ()
+ if decl.generics.types = [] && decl.generics.const_generics = [] then ()
else
(* Add the type params - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let _ctx_body, type_params, cg_params =
- ctx_add_type_const_generic_params decl.type_params
- decl.const_generic_params ctx
+ let _ctx_body, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params decl.generics ctx
in
(* Auxiliary function to extract an [Arguments Cons {T} _ _.] instruction *)
let extract_arguments_info (cons_name : string) (fields : 'a list) : unit =
@@ -1754,27 +2067,22 @@ let extract_type_decl_coq_arguments (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_print_break fmt 0 0;
(* Open a box *)
F.pp_open_hovbox fmt ctx.indent_incr;
- (* Small utility *)
- let print_vars () =
- List.iter
- (fun (var : string) ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt ("{" ^ var ^ "}"))
- (List.append type_params cg_params)
- in
- let print_fields () =
- List.iter
- (fun _ ->
- F.pp_print_space fmt ();
- F.pp_print_string fmt "_")
- fields
- in
F.pp_print_break fmt 0 0;
F.pp_print_string fmt "Arguments";
F.pp_print_space fmt ();
F.pp_print_string fmt cons_name;
- print_vars ();
- print_fields ();
+ (* Print the type/const params and the trait clauses (`{T}`) *)
+ List.iter
+ (fun (var : string) ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ("{" ^ var ^ "}"))
+ (List.concat [ type_params; cg_params; trait_clauses ]);
+ (* Print the fields (`_`) *)
+ List.iter
+ (fun _ ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "_")
+ fields;
F.pp_print_string fmt ".";
(* Close the box *)
@@ -1827,9 +2135,8 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
let is_rec = decl_is_from_rec_group kind in
if is_rec then
(* Add the type params *)
- let ctx, type_params, cg_params =
- ctx_add_type_const_generic_params decl.type_params
- decl.const_generic_params ctx
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params decl.generics ctx
in
let ctx, record_var = ctx_add_var "x" (VarId.of_int 0) ctx in
let ctx, field_var = ctx_add_var "x" (VarId.of_int 1) ctx in
@@ -1850,33 +2157,12 @@ let extract_type_decl_record_field_projectors (ctx : extraction_ctx)
F.pp_print_space fmt ();
let field_name = ctx_get_field (AdtId decl.def_id) field_id ctx in
F.pp_print_string fmt field_name;
- F.pp_print_space fmt ();
- (* Print the type parameters *)
- if type_params <> [] then (
- F.pp_print_string fmt "{";
- List.iter
- (fun p ->
- F.pp_print_string fmt p;
- F.pp_print_space fmt ())
- type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- F.pp_print_string fmt "Type}";
- F.pp_print_space fmt ());
- (* Print the const generic parameters *)
- if cg_params <> [] then
- List.iter
- (fun (v : const_generic_var) ->
- F.pp_print_string fmt "{";
- let n = ctx_get_const_generic_var v.index ctx in
- F.pp_print_string fmt n;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt v.ty;
- F.pp_print_string fmt "}";
- F.pp_print_space fmt ())
- decl.const_generic_params;
+ (* Print the generics *)
+ let as_implicits = true in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~as_implicits
+ decl.generics type_params cg_params trait_clauses;
(* Print the record parameter *)
+ F.pp_print_space fmt ();
F.pp_print_string fmt "(";
F.pp_print_string fmt record_var;
F.pp_print_space fmt ();
@@ -2067,10 +2353,11 @@ let extract_state_type (fmt : F.formatter) (ctx : extraction_ctx)
(** Compute the names for all the pure functions generated from a rust function
(forward function and backward functions).
*)
-let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
+let extract_fun_decl_register_names (ctx : extraction_ctx)
(has_decreases_clause : fun_decl -> bool) (def : pure_fun_translation) :
extraction_ctx =
- let (fwd, loop_fwds), back_ls = def in
+ let fwd = def.fwd in
+ let backs = def.backs in
(* Register the decrease clauses, if necessary *)
let register_decreases ctx def =
if has_decreases_clause def then
@@ -2083,22 +2370,19 @@ let extract_fun_decl_register_names (ctx : extraction_ctx) (keep_fwd : bool)
| Lean -> ctx_add_decreases_proof def ctx
else ctx
in
- let ctx = List.fold_left register_decreases ctx (fwd :: loop_fwds) in
- let register_fun ctx f = ctx_add_fun_decl (keep_fwd, def) f ctx in
+ let ctx = List.fold_left register_decreases ctx (fwd.f :: fwd.loops) in
+ let register_fun ctx f = ctx_add_fun_decl def f ctx in
let register_funs ctx fl = List.fold_left register_fun ctx fl in
- (* Register the forward functions' names *)
- let ctx = register_funs ctx (fwd :: loop_fwds) in
- (* Register the backward functions' names *)
+ (* Register the names of the forward functions *)
let ctx =
- List.fold_left
- (fun ctx (back, loop_backs) ->
- let ctx = register_fun ctx back in
- register_funs ctx loop_backs)
- ctx back_ls
+ if def.keep_fwd then register_funs ctx (fwd.f :: fwd.loops) else ctx
in
-
- (* Return *)
- ctx
+ (* Register the names of the backward functions *)
+ List.fold_left
+ (fun ctx { f = back; loops = loop_backs } ->
+ let ctx = register_fun ctx back in
+ register_funs ctx loop_backs)
+ ctx backs
(** Simply add the global name to the context. *)
let extract_global_decl_register_names (ctx : extraction_ctx)
@@ -2122,11 +2406,11 @@ let extract_adt_g_value
(inside : bool) (variant_id : VariantId.id option) (field_values : 'v list)
(ty : ty) : extraction_ctx =
match ty with
- | Adt (Tuple, type_args, cg_args) ->
+ | Adt (Tuple, generics) ->
(* Tuple *)
(* For now, we only support fully applied tuple constructors *)
- assert (List.length type_args = List.length field_values);
- assert (cg_args = []);
+ assert (List.length generics.types = List.length field_values);
+ assert (generics.const_generics = [] && generics.trait_refs = []);
(* This is very annoying: in Coq, we can't write [()] for the value of
type [unit], we have to write [tt]. *)
if !backend = Coq && field_values = [] then (
@@ -2144,7 +2428,7 @@ let extract_adt_g_value
in
F.pp_print_string fmt ")";
ctx)
- | Adt (adt_id, _, _) ->
+ | Adt (adt_id, _) ->
(* "Regular" ADT *)
(* If we are generating a pattern for a let-binding and we target Lean,
@@ -2249,6 +2533,9 @@ let rec extract_texpression (ctx : extraction_ctx) (fmt : F.formatter)
| Var var_id ->
let var_name = ctx_get_var var_id ctx in
F.pp_print_string fmt var_name
+ | CVar var_id ->
+ let var_name = ctx_get_const_generic_var var_id ctx in
+ F.pp_print_string fmt var_name
| Const cv -> ctx.fmt.extract_literal fmt inside cv
| App _ ->
let app, args = destruct_apps e in
@@ -2279,14 +2566,23 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(* Top-level qualifier *)
match qualif.id with
| FunOrOp fun_id ->
- extract_function_call ctx fmt inside fun_id qualif.type_args
- qualif.const_generic_args args
+ extract_function_call ctx fmt inside fun_id qualif.generics args
| Global global_id -> extract_global ctx fmt global_id
| AdtCons adt_cons_id ->
- extract_adt_cons ctx fmt inside adt_cons_id qualif.type_args
- qualif.const_generic_args args
+ extract_adt_cons ctx fmt inside adt_cons_id qualif.generics args
| Proj proj ->
- extract_field_projector ctx fmt inside app proj qualif.type_args args)
+ extract_field_projector ctx fmt inside app proj qualif.generics args
+ | TraitConst (trait_ref, generics, const_name) ->
+ let use_brackets = generics <> empty_generic_args in
+ if use_brackets then F.pp_print_string fmt "(";
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref;
+ extract_generic_args ctx fmt TypeDeclId.Set.empty generics;
+ let name =
+ ctx_get_trait_const trait_ref.trait_decl_ref.trait_decl_id
+ const_name ctx
+ in
+ if use_brackets then F.pp_print_string fmt ")";
+ F.pp_print_string fmt ("." ^ name))
| _ ->
(* "Regular" expression *)
(* Open parentheses *)
@@ -2309,8 +2605,8 @@ and extract_App (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(** Subcase of the app case: function call *)
and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
- (inside : bool) (fid : fun_or_op_id) (type_args : ty list)
- (cg_args : const_generic list) (args : texpression list) : unit =
+ (inside : bool) (fid : fun_or_op_id) (generics : generic_args)
+ (args : texpression list) : unit =
match (fid, args) with
| Unop unop, [ arg ] ->
(* A unop can have *at most* one argument (the result can't be a function!).
@@ -2329,22 +2625,97 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_open_hovbox fmt ctx.indent_incr;
(* Print the function name *)
let with_opaque_pre = ctx.use_opaque_pre in
- let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
- F.pp_print_string fmt fun_name;
- (* Sanity check: HOL4 doesn't support const generics *)
- assert (cg_args = [] || !backend <> HOL4);
- (* Print the type parameters, if the backend is not HOL4 *)
- if !backend <> HOL4 then (
- List.iter
- (fun ty ->
- F.pp_print_space fmt ();
- extract_ty ctx fmt TypeDeclId.Set.empty true ty)
- type_args;
- List.iter
- (fun cg ->
+ (* For the function name: the id is not the same depending on whether
+ we call a trait method and a "regular" function (remark: trait
+ method *implementations* are considered as regular functions here;
+ only calls to method of traits which are parameterized in a where
+ clause have a special treatment.
+
+ Remark: the reason why trait method declarations have a special
+ treatment is that, as traits are extracted to records, we may
+ allow collisions between trait item names and some other names,
+ while we do not allow collisions between function names.
+
+ # Impl trait refs:
+ ==================
+ When the trait ref refers to an impl, in
+ [InterpreterStatement.eval_transparent_function_call_symbolic] we
+ replace the call to the trait impl method to a call to the function
+ which implements the trait method (that is, we "forget" that we
+ called a trait method, and treat it as a regular function call).
+
+ # Provided trait methods:
+ =========================
+ Calls to provided trait methods also have a special treatment.
+ For now, we do not allow overriding provided trait methods (methods
+ for which a default implementation is provided in the trait declaration).
+ Whenever we translate a provided trait method, we translate it once as
+ a function which takes a trait ref as input. We have to handle this
+ case below.
+
+ With an example, if in Rust we write:
+ {[
+ fn Foo {
+ fn f(&self) -> u32; // Required
+ fn ret_true(&self) -> bool { true } // Provided
+ }
+ ]}
+
+ We generate:
+ {[
+ structure Foo (Self : Type) = {
+ f : Self -> result u32
+ }
+
+ let ret_true (Self : Type) (self_clause : Foo Self) (self : Self) : result bool =
+ true
+ ]}
+ *)
+ (match fun_id with
+ | FromLlbc
+ (TraitMethod (trait_ref, method_name, _fun_decl_id), lp_id, rg_id) ->
+ (* We have to check whether the trait method is required or provided *)
+ let trait_decl_id = trait_ref.trait_decl_ref.trait_decl_id in
+ let trait_decl =
+ TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls
+ in
+ let method_id =
+ PureUtils.trait_decl_get_method trait_decl method_name
+ in
+
+ if not method_id.is_provided then (
+ (* Required method *)
+ assert (lp_id = None);
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref;
+ let fun_name =
+ ctx_get_trait_method trait_ref.trait_decl_ref.trait_decl_id
+ method_name rg_id ctx
+ in
+ F.pp_print_string fmt ("." ^ fun_name))
+ else
+ (* Provided method: we see it as a regular function call, and use
+ the function name *)
+ let fun_id =
+ FromLlbc (FunId (A.Regular method_id.id), lp_id, rg_id)
+ in
+ let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
+ F.pp_print_string fmt fun_name;
+
+ (* Note that we do not need to print the generics for the trait
+ declaration: they are always implicit as they can be deduced
+ from the trait self clause.
+
+ Print the trait ref (to instantate the self clause) *)
F.pp_print_space fmt ();
- extract_const_generic ctx fmt true cg)
- cg_args);
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty true trait_ref
+ | _ ->
+ let fun_name = ctx_get_function with_opaque_pre fun_id ctx in
+ F.pp_print_string fmt fun_name);
+
+ (* Sanity check: HOL4 doesn't support const generics *)
+ assert (generics.const_generics = [] || !backend <> HOL4);
+ (* Print the generics *)
+ extract_generic_args ctx fmt TypeDeclId.Set.empty generics;
(* Print the arguments *)
List.iter
(fun ve ->
@@ -2366,9 +2737,9 @@ and extract_function_call (ctx : extraction_ctx) (fmt : F.formatter)
(** Subcase of the app case: ADT constructor *)
and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
- (adt_cons : adt_cons_id) (type_args : ty list)
- (cg_args : const_generic list) (args : texpression list) : unit =
- let e_ty = Adt (adt_cons.adt_id, type_args, cg_args) in
+ (adt_cons : adt_cons_id) (generics : generic_args) (args : texpression list)
+ : unit =
+ let e_ty = Adt (adt_cons.adt_id, generics) in
let is_single_pat = false in
let _ =
extract_adt_g_value
@@ -2382,7 +2753,7 @@ and extract_adt_cons (ctx : extraction_ctx) (fmt : F.formatter) (inside : bool)
(** Subcase of the app case: ADT field projector. *)
and extract_field_projector (ctx : extraction_ctx) (fmt : F.formatter)
(inside : bool) (original_app : texpression) (proj : projection)
- (_proj_type_params : ty list) (args : texpression list) : unit =
+ (_generics : generic_args) (args : texpression list) : unit =
(* We isolate the first argument (if there is), in order to pretty print the
* projection ([x.field] instead of [MkAdt?.field x] *)
match args with
@@ -2734,9 +3105,7 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
let extract_as_unit =
match (!backend, supd.struct_id) with
| HOL4, AdtId adt_id ->
- let d =
- TypeDeclId.Map.find adt_id ctx.trans_ctx.type_context.type_decls
- in
+ let d = TypeDeclId.Map.find adt_id ctx.trans_ctx.type_ctx.type_decls in
d.kind = Struct []
| _ -> false
in
@@ -2835,17 +3204,17 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_open_hvbox fmt ctx.indent_incr;
let need_paren = inside in
if need_paren then F.pp_print_string fmt "(";
- (* Open the box for `Array.mk T N [` *)
+ (* Open the box for `Array.replicate T N [` *)
F.pp_open_hovbox fmt ctx.indent_incr;
(* Print the array constructor *)
let cs = ctx_get_struct false (Assumed Array) ctx in
F.pp_print_string fmt cs;
(* Print the parameters *)
- let _, tys, cgs = ty_as_adt e_ty in
- let ty = Collections.List.to_cons_nil tys in
+ let _, generics = ty_as_adt e_ty in
+ let ty = Collections.List.to_cons_nil generics.types in
F.pp_print_space fmt ();
extract_ty ctx fmt TypeDeclId.Set.empty true ty;
- let cg = Collections.List.to_cons_nil cgs in
+ let cg = Collections.List.to_cons_nil generics.const_generics in
F.pp_print_space fmt ();
extract_const_generic ctx fmt true cg;
F.pp_print_space fmt ();
@@ -2872,17 +3241,15 @@ and extract_StructUpdate (ctx : extraction_ctx) (fmt : F.formatter)
F.pp_close_box fmt ()
| _ -> raise (Failure "Unreachable")
-(** Insert a space, if necessary *)
-let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
- if !space then space := false else F.pp_print_space fmt ()
-
(** A small utility to print the parameters of a function signature.
We return two contexts:
- - the context augmented with bindings for the type parameters
- - the context augmented with bindings for the type parameters *and*
+ - the context augmented with bindings for the generics
+ - the context augmented with bindings for the generics *and*
bindings for the input values
+ We also return names for the type parameters, const generics, etc.
+
TODO: do we really need the first one? We should probably always use
the second one.
It comes from the fact that when we print the input values for the
@@ -2890,57 +3257,40 @@ let insert_req_space (fmt : F.formatter) (space : bool ref) : unit =
patterns, not the variables). We should figure a cleaner way.
*)
let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
- (fmt : F.formatter) (def : fun_decl) : extraction_ctx * extraction_ctx =
+ (fmt : F.formatter) (def : fun_decl) :
+ extraction_ctx * extraction_ctx * string list =
+ (* First, add the associated types and constants if the function is a method
+ in a trait declaration.
+
+ About the order: we want to make sure the names are reserved for
+ those (variable names might collide with them but it is ok, we will add
+ suffixes to the variables).
+
+ TODO: micro-pass to update what happens when calling trait provided
+ functions.
+ *)
+ let ctx, trait_decl =
+ match def.kind with
+ | TraitMethodProvided (decl_id, _) ->
+ let trait_decl = T.TraitDeclId.Map.find decl_id ctx.trans_trait_decls in
+ let ctx, _ = ctx_add_trait_self_clause ctx in
+ let ctx = { ctx with is_provided_method = true } in
+ (ctx, Some trait_decl)
+ | _ -> (ctx, None)
+ in
(* Add the type parameters - note that we need those bindings only for the
* body translation (they are not top-level) *)
- let ctx, type_params, cg_params =
- ctx_add_type_const_generic_params def.signature.type_params
- def.signature.const_generic_params ctx
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params def.signature.generics ctx
in
- (* Print the parameters - rem.: we should have filtered the functions
- * with no input parameters *)
- (* The type parameters.
-
- Note that in HOL4 we don't print the type parameters.
- *)
- if (type_params <> [] || cg_params <> []) && !backend <> HOL4 then (
- (* Open a box for the type and const generic parameters *)
- F.pp_open_hovbox fmt 0;
- (* The type parameters *)
- if type_params <> [] then (
- insert_req_space fmt space;
- F.pp_print_string fmt "(";
- List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
- F.pp_print_string fmt pname;
- F.pp_print_space fmt ())
- def.signature.type_params;
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- let type_keyword =
- match !backend with
- | FStar -> "Type0"
- | Coq | Lean -> "Type"
- | HOL4 -> raise (Failure "Unreachable")
- in
- F.pp_print_string fmt (type_keyword ^ ")"));
- (* The const generic parameters *)
- if cg_params <> [] then
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- insert_req_space fmt space;
- F.pp_print_string fmt "(";
- F.pp_print_string fmt pname;
- F.pp_print_space fmt ();
- F.pp_print_string fmt ":";
- F.pp_print_space fmt ();
- extract_literal_type ctx fmt p.ty;
- F.pp_print_string fmt ")")
- def.signature.const_generic_params;
- (* Close the box for the type parameters *)
- F.pp_close_box fmt ());
+ (* Print the generics *)
+ (* Open a box for the generics *)
+ F.pp_open_hovbox fmt 0;
+ (let space = Some space in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~space ~trait_decl
+ def.signature.generics type_params cg_params trait_clauses);
+ (* Close the box for the generics *)
+ F.pp_close_box fmt ();
(* The input parameters - note that doing this adds bindings to the context *)
let ctx_body =
match def.body with
@@ -2963,7 +3313,7 @@ let extract_fun_parameters (space : bool ref) (ctx : extraction_ctx)
ctx)
ctx body.inputs_lvs
in
- (ctx, ctx_body)
+ (ctx, ctx_body, List.concat [ type_params; cg_params; trait_clauses ])
(** A small utility to print the types of the input parameters in the form:
[u32 -> list u32 -> ...]
@@ -2982,6 +3332,11 @@ let extract_fun_input_parameters_types (ctx : extraction_ctx)
in
List.iter extract_param def.signature.inputs
+let extract_fun_inputs_output_parameters_types (ctx : extraction_ctx)
+ (fmt : F.formatter) (def : fun_decl) : unit =
+ extract_fun_input_parameters_types ctx fmt def;
+ extract_ty ctx fmt TypeDeclId.Set.empty false def.signature.output
+
let assert_backend_supports_decreases_clauses () =
match !backend with
| FStar | Lean -> ()
@@ -3032,7 +3387,7 @@ let extract_template_fstar_decreases_clause (ctx : extraction_ctx)
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
- let _, _ = extract_fun_parameters space ctx fmt def in
+ let _, _, _ = extract_fun_parameters space ctx fmt def in
insert_req_space fmt space;
F.pp_print_string fmt ":";
(* Print the signature *)
@@ -3094,7 +3449,7 @@ let extract_template_lean_termination_and_decreasing (ctx : extraction_ctx)
F.pp_print_space fmt ();
(* Extract the parameters *)
let space = ref true in
- let _, ctx_body = extract_fun_parameters space ctx fmt def in
+ let _, ctx_body, _ = extract_fun_parameters space ctx fmt def in
(* Print the ":=" *)
F.pp_print_space fmt ();
F.pp_print_string fmt ":=";
@@ -3164,7 +3519,7 @@ let extract_fun_comment (ctx : extraction_ctx) (fmt : F.formatter)
(def : fun_decl) : unit =
let { keep_fwd; num_backs } =
PureUtils.RegularFunIdMap.find
- (A.Regular def.def_id, def.loop_id, def.back_id)
+ (Pure.FunId (A.Regular def.def_id), def.loop_id, def.back_id)
ctx.fun_name_info
in
let comment_pre = "[" ^ Print.fun_name_to_string def.basename ^ "]: " in
@@ -3234,9 +3589,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
*)
let is_opaque_coq = !backend = Coq && is_opaque in
let use_forall =
- is_opaque_coq
- && (def.signature.type_params <> []
- || def.signature.const_generic_params <> [])
+ is_opaque_coq && def.signature.generics <> empty_generic_params
in
(* Print the qualifier ("assume", etc.).
@@ -3262,7 +3615,7 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Open a box for "(PARAMS) :" *)
F.pp_open_hovbox fmt 0;
let space = ref true in
- let ctx, ctx_body = extract_fun_parameters space ctx fmt def in
+ let ctx, ctx_body, all_params = extract_fun_parameters space ctx fmt def in
(* Print the return type - note that we have to be careful when
* printing the input values for the decrease clause, because
* it introduces bindings in the context... We thus "forget"
@@ -3310,20 +3663,13 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* The name of the decrease clause *)
let decr_name = ctx_get_termination_measure def.def_id def.loop_id ctx in
F.pp_print_string fmt decr_name;
- (* Print the type/const generic parameters - TODO: we do this many
+ (* Print the generic parameters - TODO: we do this many
times, we should have a helper to factor it out *)
List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
+ (fun (name : string) ->
F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.type_params;
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.const_generic_params;
+ F.pp_print_string fmt name)
+ all_params;
(* Print the input values: we have to be careful here to print
* only the input values which are in common with the *forward*
* function (the additional input values "given back" to the
@@ -3410,19 +3756,12 @@ let extract_fun_decl_gen (ctx : extraction_ctx) (fmt : F.formatter)
(* Open the box for [DECREASES] *)
F.pp_open_hovbox fmt ctx.indent_incr;
F.pp_print_string fmt terminates_name;
- (* Print the type/const generic params - TODO: factor out *)
+ (* Print the generic params - TODO: factor out *)
List.iter
- (fun (p : type_var) ->
- let pname = ctx_get_type_var p.index ctx in
+ (fun (name : string) ->
F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.type_params;
- List.iter
- (fun (p : const_generic_var) ->
- let pname = ctx_get_const_generic_var p.index ctx in
- F.pp_print_space fmt ();
- F.pp_print_string fmt pname)
- def.signature.const_generic_params;
+ F.pp_print_string fmt name)
+ all_params;
(* Print the variables *)
List.iter
(fun v ->
@@ -3480,13 +3819,10 @@ let extract_fun_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
ctx_get_local_function with_opaque_pre def.def_id def.loop_id def.back_id
ctx
in
- assert (def.signature.const_generic_params = []);
+ assert (def.signature.generics.const_generics = []);
(* Add the type/const gen parameters - note that we need those bindings
only for the generation of the type (they are not top-level) *)
- let ctx, _, _ =
- ctx_add_type_const_generic_params def.signature.type_params
- def.signature.const_generic_params ctx
- in
+ let ctx, _, _, _ = ctx_add_generic_params def.signature.generics ctx in
(* Add breaks to insert new lines between definitions *)
F.pp_print_break fmt 0 0;
(* Open a box for the whole definition *)
@@ -3635,8 +3971,13 @@ let extract_global_decl_hol4_opaque (ctx : extraction_ctx) (fmt : F.formatter)
(* Print the type *)
F.pp_open_hovbox fmt 0;
extract_ty ctx fmt TypeDeclId.Set.empty false ty;
+ (* Close the definition *)
+ F.pp_print_string fmt ")";
F.pp_close_box fmt ();
- (* Close the definition boxe *) F.pp_close_box fmt ()
+ (* Close the definition box *)
+ F.pp_close_box fmt ();
+ (* Add a line *)
+ F.pp_print_space fmt ()
(** Extract a global declaration.
@@ -3662,10 +4003,9 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(global : A.global_decl) (body : fun_decl) (interface : bool) : unit =
assert body.is_global_decl_body;
assert (Option.is_none body.back_id);
- assert (List.length body.signature.inputs = 0);
+ assert (body.signature.inputs = []);
assert (List.length body.signature.doutputs = 1);
- assert (List.length body.signature.type_params = 0);
- assert (List.length body.signature.const_generic_params = 0);
+ assert (body.signature.generics = empty_generic_params);
(* Add a break then the name of the corresponding LLBC declaration *)
F.pp_print_break fmt 0 0;
@@ -3676,7 +4016,7 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
let decl_name = ctx_get_global with_opaque_pre global.def_id ctx in
let body_name =
ctx_get_function with_opaque_pre
- (FromLlbc (Regular global.body_id, None, None))
+ (FromLlbc (Pure.FunId (Regular global.body_id), None, None))
ctx
in
@@ -3713,6 +4053,433 @@ let extract_global_decl (ctx : extraction_ctx) (fmt : F.formatter)
(* Add a break to insert lines between declarations *)
F.pp_print_break fmt 0 0
+(** Register the names for one trait method item *)
+let extract_trait_decl_method_register_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl) (name : string) (id : fun_decl_id) :
+ extraction_ctx =
+ (* We add one field per required forward/backward function *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+
+ let register_fun ctx f = ctx_add_trait_method trait_decl name f.f ctx in
+ (* Register the names *)
+ let funs = trans.fwd :: trans.backs in
+ List.fold_left register_fun ctx funs
+
+(** Similar to {!extract_type_decl_register_names} *)
+let extract_trait_decl_register_names (ctx : extraction_ctx)
+ (trait_decl : trait_decl) : extraction_ctx =
+ let {
+ def_id = _;
+ name = _;
+ generics;
+ preds = _;
+ all_trait_clauses = _;
+ consts;
+ types;
+ required_methods;
+ provided_methods = _;
+ } =
+ trait_decl
+ in
+ let ctx = ctx_add_trait_decl trait_decl ctx in
+ (* Parent clauses *)
+ let ctx =
+ List.fold_left
+ (fun ctx clause -> ctx_add_trait_parent_clause trait_decl clause ctx)
+ ctx generics.trait_clauses
+ in
+ (* Constants *)
+ let ctx =
+ List.fold_left
+ (fun ctx (name, (_, _)) -> ctx_add_trait_const trait_decl name ctx)
+ ctx consts
+ in
+ (* Types *)
+ let ctx =
+ List.fold_left
+ (fun ctx (name, (clauses, _)) ->
+ let ctx = ctx_add_trait_type trait_decl name ctx in
+ List.fold_left
+ (fun ctx clause ->
+ ctx_add_trait_type_clause trait_decl name clause ctx)
+ ctx clauses)
+ ctx types
+ in
+ (* Required methods *)
+ List.fold_left
+ (fun ctx (name, id) ->
+ (* We add one field per required forward/backward function *)
+ extract_trait_decl_method_register_names ctx trait_decl name id)
+ ctx required_methods
+
+(** Similar to {!extract_type_decl_register_names} *)
+let extract_trait_impl_register_names (ctx : extraction_ctx)
+ (trait_impl : trait_impl) : extraction_ctx =
+ (* For now we do not support overriding provided methods *)
+ assert (trait_impl.provided_methods = []);
+ (* Everything is taken care of by {!extract_trait_decl_register_names} *but*
+ the name of the implementation itself *)
+ ctx_add_trait_impl trait_impl ctx
+
+(** Small helper.
+
+ The type `ty` is to be understood in a very general sense.
+ *)
+let extract_trait_item (ctx : extraction_ctx) (fmt : F.formatter)
+ (item_name : string) (separator : string) (ty : unit -> unit) : unit =
+ F.pp_print_space fmt ();
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ F.pp_print_string fmt item_name;
+ F.pp_print_space fmt ();
+ (* ":" or "=" *)
+ F.pp_print_string fmt separator;
+ ty ();
+ (match !Config.backend with Lean -> () | _ -> F.pp_print_string fmt ";");
+ F.pp_close_box fmt ()
+
+let extract_trait_decl_item (ctx : extraction_ctx) (fmt : F.formatter)
+ (item_name : string) (ty : unit -> unit) : unit =
+ extract_trait_item ctx fmt item_name ":" ty
+
+let extract_trait_impl_item (ctx : extraction_ctx) (fmt : F.formatter)
+ (item_name : string) (ty : unit -> unit) : unit =
+ let assign = match !Config.backend with Lean -> ":=" | _ -> "=" in
+ extract_trait_item ctx fmt item_name assign ty
+
+(** Small helper - TODO: move *)
+let generic_params_drop_prefix (g1 : generic_params) (g2 : generic_params) :
+ generic_params =
+ let open Collections.List in
+ let types = drop (length g1.types) g2.types in
+ let const_generics = drop (length g1.const_generics) g2.const_generics in
+ let trait_clauses = drop (length g1.trait_clauses) g2.trait_clauses in
+ { types; const_generics; trait_clauses }
+
+(** Small helper.
+
+ Extract the items for a method in a trait decl.
+ *)
+let extract_trait_decl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
+ (decl : trait_decl) (item_name : string) (id : fun_decl_id) : unit =
+ (* Lookup the definition *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (* Extract the items *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ let extract_method (f : fun_and_loops) =
+ let f = f.f in
+ let fun_name = ctx_get_trait_method decl.def_id item_name f.back_id ctx in
+ let ty () =
+ (* Extract the generics *)
+ (* We need to add the generics specific to the method, by removing those
+ which actually apply to the trait decl *)
+ let generics =
+ generic_params_drop_prefix decl.generics f.signature.generics
+ in
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params generics ctx
+ in
+ let use_forall = generics <> empty_generic_params in
+ let use_forall_use_sep = false in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall
+ ~use_forall_use_sep generics type_params cg_params trait_clauses;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the inputs and output *)
+ F.pp_print_space fmt ();
+ extract_fun_inputs_output_parameters_types ctx fmt f
+ in
+ extract_trait_decl_item ctx fmt fun_name ty
+ in
+ List.iter extract_method funs
+
+(** Extract a trait declaration *)
+let extract_trait_decl (ctx : extraction_ctx) (fmt : F.formatter)
+ (decl : trait_decl) : unit =
+ (* Retrieve the trait name *)
+ let with_opaque_pre = false in
+ let decl_name = ctx_get_trait_decl with_opaque_pre decl.def_id ctx in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt
+ [ "Trait declaration: [" ^ Print.name_to_string decl.name ^ "]" ];
+ F.pp_print_break fmt 0 0;
+ (* Open two outer boxes for the definition, so that whenever possible it gets printed on
+ one line and indents are correct.
+
+ There is just an exception with Lean: in this backend, line breaks are important
+ for the parsing, so we always open a vertical box.
+ *)
+ if !Config.backend = Lean then F.pp_open_vbox fmt ctx.indent_incr
+ else (
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_hvbox fmt ctx.indent_incr);
+
+ (* `struct Trait (....) =` *)
+ (* Open the box for the name + generics *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ let qualif =
+ Option.get (ctx.fmt.type_decl_kind_to_qualif SingleNonRec (Some Struct))
+ in
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt decl_name;
+ (* Print the generics *)
+ (* We ignore the trait clauses, which we extract as *fields* *)
+ let generics = { decl.generics with trait_clauses = [] } in
+ (* Add the type and const generic params - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params generics ctx
+ in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty generics type_params
+ cg_params trait_clauses;
+
+ F.pp_print_space fmt ();
+ (match !backend with
+ | Lean -> F.pp_print_string fmt "where"
+ | _ -> F.pp_print_string fmt "{");
+
+ (* Close the box for the name + generics *)
+ F.pp_close_box fmt ();
+
+ (*
+ * Extract the items
+ *)
+
+ (* The parent clauses *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_parent_clause decl.def_id clause.clause_id ctx
+ in
+ let ty () =
+ extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ decl.generics.trait_clauses;
+
+ (* The constants *)
+ List.iter
+ (fun (name, (ty, _)) ->
+ let item_name = ctx_get_trait_const decl.def_id name ctx in
+ let ty () =
+ let inside = false in
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty inside ty
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ decl.consts;
+
+ (* The types *)
+ List.iter
+ (fun (name, (clauses, _)) ->
+ (* Extract the type *)
+ let item_name = ctx_get_trait_type decl.def_id name ctx in
+ let ty () =
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (type_keyword ())
+ in
+ extract_trait_decl_item ctx fmt item_name ty;
+ (* Extract the clauses *)
+ List.iter
+ (fun clause ->
+ let item_name =
+ ctx_get_trait_item_clause decl.def_id name clause.clause_id ctx
+ in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_trait_clause_type ctx fmt TypeDeclId.Set.empty clause
+ in
+ extract_trait_decl_item ctx fmt item_name ty)
+ clauses)
+ decl.types;
+
+ (* The required methods *)
+ List.iter
+ (fun (name, id) -> extract_trait_decl_method_items ctx fmt decl name id)
+ decl.required_methods;
+
+ (* Close the brackets *)
+ (match !Config.backend with
+ | Lean -> ()
+ | _ ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}");
+
+ (* Close the outer boxes for the definition *)
+ if !Config.backend <> Lean then F.pp_close_box fmt ();
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
+(** Small helper.
+
+ Extract the items for a method in a trait impl.
+ *)
+let extract_trait_impl_method_items (ctx : extraction_ctx) (fmt : F.formatter)
+ (impl : trait_impl) (item_name : string) (id : fun_decl_id)
+ (impl_generics : string list * string list * string list) : unit =
+ let trait_decl_id = impl.impl_trait.trait_decl_id in
+ (* Lookup the definition *)
+ let trans = A.FunDeclId.Map.find id ctx.trans_funs in
+ (* Extract the items *)
+ let funs = if trans.keep_fwd then trans.fwd :: trans.backs else trans.backs in
+ let extract_method (f : fun_and_loops) =
+ let f = f.f in
+ let fun_name = ctx_get_trait_method trait_decl_id item_name f.back_id ctx in
+ let ty () =
+ (* Extract the generics - we need to quantify over the generics which
+ are specific to the method, and call it will all the generics
+ (trait impl + method generics) *)
+ let f_generics =
+ generic_params_drop_prefix impl.generics f.signature.generics
+ in
+ let ctx, f_tys, f_cgs, f_tcs = ctx_add_generic_params f_generics ctx in
+ let use_forall = f_generics <> empty_generic_params in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty ~use_forall f_generics
+ f_tys f_cgs f_tcs;
+ if use_forall then F.pp_print_string fmt ",";
+ (* Extract the function call *)
+ F.pp_print_space fmt ();
+ let id = ctx_get_local_function false f.def_id None f.back_id ctx in
+ F.pp_print_string fmt id;
+ let all_generics =
+ let i_tys, i_cgs, i_tcs = impl_generics in
+ List.concat [ i_tys; f_tys; i_cgs; f_cgs; i_tcs; f_tcs ]
+ in
+ List.iter
+ (fun p ->
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt p)
+ all_generics
+ in
+ extract_trait_impl_item ctx fmt fun_name ty
+ in
+ List.iter extract_method funs
+
+(** Extract a trait implementation *)
+let extract_trait_impl (ctx : extraction_ctx) (fmt : F.formatter)
+ (impl : trait_impl) : unit =
+ (* Retrieve the impl name *)
+ let with_opaque_pre = false in
+ let impl_name = ctx_get_trait_impl with_opaque_pre impl.def_id ctx in
+ (* Add a break before *)
+ F.pp_print_break fmt 0 0;
+ (* Print a comment to link the extracted type to its original rust definition *)
+ extract_comment fmt
+ [ "Trait implementation: [" ^ Print.name_to_string impl.name ^ "]" ];
+ F.pp_print_break fmt 0 0;
+
+ (* Open two outer boxes for the definition, so that whenever possible it gets printed on
+ one line and indents are correct.
+
+ There is just an exception with Lean: in this backend, line breaks are important
+ for the parsing, so we always open a vertical box.
+ *)
+ if !Config.backend = Lean then (
+ F.pp_open_vbox fmt 0;
+ F.pp_open_vbox fmt ctx.indent_incr)
+ else (
+ F.pp_open_hvbox fmt 0;
+ F.pp_open_hvbox fmt ctx.indent_incr);
+
+ (* `let (....) : Trait ... =` *)
+ (* Open the box for the name + generics *)
+ F.pp_open_hovbox fmt ctx.indent_incr;
+ let qualif = Option.get (ctx.fmt.fun_decl_kind_to_qualif SingleNonRec) in
+ F.pp_print_string fmt qualif;
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt impl_name;
+
+ (* Print the generics *)
+ (* Add the type and const generic params - note that we need those bindings only for the
+ * body translation (they are not top-level) *)
+ let ctx, type_params, cg_params, trait_clauses =
+ ctx_add_generic_params impl.generics ctx
+ in
+ let all_generics = (type_params, cg_params, trait_clauses) in
+ extract_generic_params ctx fmt TypeDeclId.Set.empty impl.generics type_params
+ cg_params trait_clauses;
+
+ (* Print the type *)
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt ":";
+ F.pp_print_space fmt ();
+ extract_trait_decl_ref ctx fmt TypeDeclId.Set.empty false impl.impl_trait;
+
+ F.pp_print_space fmt ();
+ if !Config.backend = Lean then F.pp_print_string fmt ":= {"
+ else F.pp_print_string fmt "= {";
+
+ (* Close the box for the name + generics *)
+ F.pp_close_box fmt ();
+
+ (*
+ * Extract the items
+ *)
+
+ (* The parent clauses - we retrieve those from the impl_ref *)
+ let trait_decl_id = impl.impl_trait.trait_decl_id in
+ TraitClauseId.iteri
+ (fun clause_id trait_ref ->
+ let item_name = ctx_get_trait_parent_clause trait_decl_id clause_id ctx in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref
+ in
+ extract_trait_impl_item ctx fmt item_name ty)
+ impl.impl_trait.decl_generics.trait_refs;
+
+ (* The constants *)
+ List.iter
+ (fun (name, (_, id)) ->
+ let item_name = ctx_get_trait_const trait_decl_id name ctx in
+ let ty () =
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt (ctx_get_global false id ctx)
+ in
+
+ extract_trait_impl_item ctx fmt item_name ty)
+ impl.consts;
+
+ (* The types *)
+ List.iter
+ (fun (name, (trait_refs, ty)) ->
+ (* Extract the type *)
+ let item_name = ctx_get_trait_type trait_decl_id name ctx in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_ty ctx fmt TypeDeclId.Set.empty false ty
+ in
+ extract_trait_impl_item ctx fmt item_name ty;
+ (* Extract the clauses *)
+ TraitClauseId.iteri
+ (fun clause_id trait_ref ->
+ let item_name =
+ ctx_get_trait_item_clause trait_decl_id name clause_id ctx
+ in
+ let ty () =
+ F.pp_print_space fmt ();
+ extract_trait_ref ctx fmt TypeDeclId.Set.empty false trait_ref
+ in
+ extract_trait_impl_item ctx fmt item_name ty)
+ trait_refs)
+ impl.types;
+
+ (* The required methods *)
+ List.iter
+ (fun (name, id) ->
+ extract_trait_impl_method_items ctx fmt impl name id all_generics)
+ impl.required_methods;
+
+ (* Close the outer boxes for the definition, as well as the brackets *)
+ F.pp_close_box fmt ();
+ F.pp_print_space fmt ();
+ F.pp_print_string fmt "}";
+ F.pp_close_box fmt ();
+ (* Add breaks to insert new lines between definitions *)
+ F.pp_print_break fmt 0 0
+
(** Extract a unit test, if the function is a unit function (takes no
parameters, returns unit).
@@ -3735,8 +4502,7 @@ let extract_unit_test_if_unit_fun (ctx : extraction_ctx) (fmt : F.formatter)
(* Check if this is a unit function *)
let sg = def.signature in
if
- sg.type_params = []
- && sg.const_generic_params = []
+ sg.generics = empty_generic_params
&& (sg.inputs = [ mk_unit_ty ] || sg.inputs = [])
&& sg.output = mk_result_ty mk_unit_ty
then (
diff --git a/compiler/ExtractAssumed.ml b/compiler/ExtractAssumed.ml
new file mode 100644
index 00000000..7f094b24
--- /dev/null
+++ b/compiler/ExtractAssumed.ml
@@ -0,0 +1,61 @@
+(** This file declares external identifiers that we catch to map them to
+ definitions coming from the standard libraries in our backends. *)
+
+open Names
+
+type simple_name = string list [@@deriving show, ord]
+
+let name_to_simple_name (s : name) : simple_name =
+ (* We simply ignore the disambiguators *)
+ List.filter_map (function Ident id -> Some id | Disambiguator _ -> None) s
+
+(** Small helper which cuts a string at the occurrences of "::" *)
+let string_to_simple_name (s : string) : simple_name =
+ (* No function to split by using string separator?? *)
+ let name = String.split_on_char ':' s in
+ List.filter (fun s -> s <> "") name
+
+module SimpleNameOrd = struct
+ type t = simple_name
+
+ let compare = compare_simple_name
+ let to_string = show_simple_name
+ let pp_t = pp_simple_name
+ let show_t = show_simple_name
+end
+
+module SimpleNameMap = Collections.MakeMap (SimpleNameOrd)
+
+let assumed_globals : (string * string) list =
+ [
+ (* Min *)
+ ("core::num::usize::MIN", "core_usize_min");
+ ("core::num::u8::MIN", "core_u8_min");
+ ("core::num::u16::MIN", "core_u16_min");
+ ("core::num::u32::MIN", "core_u32_min");
+ ("core::num::u64::MIN", "core_u64_min");
+ ("core::num::u128::MIN", "core_u128_min");
+ ("core::num::isize::MIN", "core_isize_min");
+ ("core::num::i8::MIN", "core_i8_min");
+ ("core::num::i16::MIN", "core_i16_min");
+ ("core::num::i32::MIN", "core_i32_min");
+ ("core::num::i64::MIN", "core_i64_min");
+ ("core::num::i128::MIN", "core_i128_min");
+ (* Max *)
+ ("core::num::usize::MAX", "core_usize_max");
+ ("core::num::u8::MAX", "core_u8_max");
+ ("core::num::u16::MAX", "core_u16_max");
+ ("core::num::u32::MAX", "core_u32_max");
+ ("core::num::u64::MAX", "core_u64_max");
+ ("core::num::u128::MAX", "core_u128_max");
+ ("core::num::isize::MAX", "core_isize_max");
+ ("core::num::i8::MAX", "core_i8_max");
+ ("core::num::i16::MAX", "core_i16_max");
+ ("core::num::i32::MAX", "core_i32_max");
+ ("core::num::i64::MAX", "core_i64_max");
+ ("core::num::i128::MAX", "core_i128_max");
+ ]
+
+let assumed_globals_map : string SimpleNameMap.t =
+ SimpleNameMap.of_list
+ (List.map (fun (x, y) -> (string_to_simple_name x, y)) assumed_globals)
diff --git a/compiler/ExtractBase.ml b/compiler/ExtractBase.ml
index d733c763..1586e6ed 100644
--- a/compiler/ExtractBase.ml
+++ b/compiler/ExtractBase.ml
@@ -5,6 +5,7 @@ open TranslateCore
module C = Contexts
module RegionVarId = T.RegionVarId
module F = Format
+open ExtractAssumed
(** The local logger *)
let log = L.pure_to_extract_log
@@ -77,6 +78,7 @@ type decl_kind =
F*: [val x : Type0]
Coq: [Axiom x : Type.]
*)
+[@@deriving show]
(** Return [true] if the declaration is the last from its group of declarations.
@@ -111,9 +113,9 @@ let decl_is_first_from_group (kind : decl_kind) : bool =
let decl_is_not_last_from_group (kind : decl_kind) : bool =
not (decl_is_last_from_group kind)
-(* TODO: this should a module we give to a functor! *)
+type type_decl_kind = Enum | Struct [@@deriving show]
-type type_decl_kind = Enum | Struct
+(* TODO: this should be a module we give to a functor! *)
(** A formatter's role is twofold:
1. Come up with name suggestions.
@@ -125,6 +127,9 @@ type type_decl_kind = Enum | Struct
snake case, adding prefixes/suffixes, etc.
2. Format some specific terms, like constants.
+
+ TODO: unclear that this is useful now that all the backends are so much
+ entangled in Extract.ml
*)
type formatter = {
bool_name : string;
@@ -239,6 +244,13 @@ type formatter = {
the same purpose as in {!field:fun_name}.
- loop identifier, if this is for a loop
*)
+ trait_decl_name : trait_decl -> string;
+ trait_impl_name : trait_decl -> trait_impl -> string;
+ trait_parent_clause_name : trait_decl -> trait_clause -> string;
+ trait_const_name : trait_decl -> string -> string;
+ trait_type_name : trait_decl -> string -> string;
+ trait_method_name : trait_decl -> string -> string;
+ trait_type_clause_name : trait_decl -> string -> trait_clause -> string;
opaque_pre : unit -> string;
(** TODO: obsolete, remove.
@@ -288,6 +300,14 @@ type formatter = {
(** Generates a type variable basename. *)
const_generic_var_basename : StringSet.t -> string -> string;
(** Generates a const generic variable basename. *)
+ trait_self_clause_basename : string;
+ trait_clause_basename : StringSet.t -> trait_clause -> string;
+ (** Return a base name for a trait clause. We might add a suffix to prevent
+ collisions.
+
+ In the traduction we explicitely manipulate the trait clause instances,
+ that is we introduce one input variable for each trait clause.
+ *)
append_index : string -> int -> string;
(** Appends an index to a name - we use this to generate unique
names: when doing so, the role of the formatter is just to concatenate
@@ -396,10 +416,59 @@ type id =
| TypeVarId of TypeVarId.id
| ConstGenericVarId of ConstGenericVarId.id
| VarId of VarId.id
+ | TraitDeclId of TraitDeclId.id
+ | TraitImplId of TraitImplId.id
+ | LocalTraitClauseId of TraitClauseId.id
+ | TraitMethodId of TraitDeclId.id * string * T.RegionGroupId.id option
+ (** Something peculiar with trait methods: because we have to take into
+ account forward/backward functions, we may need to generate fields
+ items per method.
+ *)
+ | TraitItemId of TraitDeclId.id * string
+ (** A trait associated item which is not a method *)
+ | TraitParentClauseId of TraitDeclId.id * TraitClauseId.id
+ | TraitItemClauseId of TraitDeclId.id * string * TraitClauseId.id
+ | TraitSelfClauseId
+ (** Specifically for the clause: [Self : Trait].
+
+ For now, we forbid provided methods (methods in trait declarations
+ with a default implementation) from being overriden in trait implementations.
+ We extract trait provided methods such that they take an instance of
+ the trait as input: this instance is given by the trait self clause.
+
+ For instance:
+ {[
+ //
+ // Rust
+ //
+ trait ToU64 {
+ fn to_u64(&self) -> u64;
+
+ // Provided method
+ fn is_pos(&self) -> bool {
+ self.to_u64() > 0
+ }
+ }
+
+ //
+ // Generated code
+ //
+ struct ToU64 (T : Type) {
+ to_u64 : T -> u64;
+ }
+
+ // The trait self clause
+ // vvvvvvvvvvvvvvvvvvvvvv
+ let is_pos (T : Type) (trait_self : ToU64 T) (self : T) : bool =
+ trait_self.to_u64 self > 0
+ ]}
+ *)
| UnknownId
(** Used for stored various strings like keywords, definitions which
should always be in context, etc. and which can't be linked to one
of the above.
+
+ TODO: rename to "keyword"
*)
[@@deriving show, ord]
@@ -445,23 +514,52 @@ type names_map = {
*)
}
-let names_map_add (id_to_string : id -> string) (is_opaque : bool) (id : id)
- (name : string) (nm : names_map) : names_map =
- (* Check if there is a clash *)
- (match StringMap.find_opt name nm.name_to_id with
+let empty_names_map : names_map =
+ {
+ id_to_name = IdMap.empty;
+ name_to_id = StringMap.empty;
+ names_set = StringSet.empty;
+ opaque_ids = IdSet.empty;
+ }
+
+(** Small helper to report name collision *)
+let report_name_collision (id_to_string : id -> string) (id1 : id) (id2 : id)
+ (name : string) : unit =
+ let id1 = "\n- " ^ id_to_string id1 in
+ let id2 = "\n- " ^ id_to_string id2 in
+ let err =
+ "Name clash detected: the following identifiers are bound to the same name \
+ \"" ^ name ^ "\":" ^ id1 ^ id2
+ ^ "\nYou may want to rename some of your definitions, or report an issue."
+ in
+ log#serror err;
+ (* If we fail hard on errors, raise an exception *)
+ if !Config.extract_fail_hard then raise (Failure err)
+
+let names_map_get_id_from_name (name : string) (nm : names_map) : id option =
+ StringMap.find_opt name nm.name_to_id
+
+let names_map_check_collision (id_to_string : id -> string) (id : id)
+ (name : string) (nm : names_map) : unit =
+ match names_map_get_id_from_name name nm with
| None -> () (* Ok *)
| Some clash ->
(* There is a clash: print a nice debugging message for the user *)
- let id1 = "\n- " ^ id_to_string clash in
- let id2 = "\n- " ^ id_to_string id in
- let err =
- "Name clash detected: the following identifiers are bound to the same \
- name \"" ^ name ^ "\":" ^ id1 ^ id2
- in
- log#serror err;
- raise (Failure err));
+ report_name_collision id_to_string clash id name
+
+let names_map_add (id_to_string : id -> string) (is_opaque : bool) (id : id)
+ (name : string) (nm : names_map) : names_map =
+ (* Check if there is a clash *)
+ names_map_check_collision id_to_string id name nm;
(* Sanity check *)
- assert (not (StringSet.mem name nm.names_set));
+ if StringSet.mem name nm.names_set then (
+ let err =
+ "Error when registering the name for id: " ^ id_to_string id
+ ^ ":\nThe chosen name is already in the names set: " ^ name
+ in
+ log#serror err;
+ (* If we fail hard on errors, raise an exception *)
+ if !Config.extract_fail_hard then raise (Failure err));
(* Insert *)
let id_to_name = IdMap.add id name nm.id_to_name in
let name_to_id = StringMap.add name id nm.name_to_id in
@@ -549,6 +647,7 @@ type fun_name_info = { keep_fwd : bool; num_backs : int }
functions, etc.
*)
type extraction_ctx = {
+ crate : A.crate;
trans_ctx : trans_ctx;
names_map : names_map;
(** The map for id to names, where we forbid name collisions
@@ -556,6 +655,15 @@ type extraction_ctx = {
unsafe_names_map : unsafe_names_map;
(** The map for id to names, where we allow name collisions
(ex.: we might allow record field name collisions). *)
+ strict_names_map : names_map;
+ (** This map is a sub-map of [names_map]. For the ids in this map we also
+ forbid collisions with names in the [unsafe_names_map].
+
+ We do so for keywords for instance, but also for types (in a dependently
+ typed language, we might have an issue if the field of a record has, say,
+ the name "u32", and another field of the same record refers to "u32"
+ (for instance in its type).
+ *)
fmt : formatter;
indent_incr : int;
(** The indent increment we insert whenever we need to indent more *)
@@ -586,6 +694,14 @@ type extraction_ctx = {
in case a Rust function only has one backward translation
and we filter the forward function because it returns unit.
*)
+ trait_decl_id : trait_decl_id option;
+ (** If we are extracting a trait declaration, identifies it *)
+ is_provided_method : bool;
+ trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
+ trans_funs : pure_fun_translation A.FunDeclId.Map.t;
+ functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
+ trans_trait_decls : Pure.trait_decl Pure.TraitDeclId.Map.t;
+ trans_trait_impls : Pure.trait_impl Pure.TraitImplId.Map.t;
}
(** Debugging function, used when communicating name collisions to the user,
@@ -593,9 +709,16 @@ type extraction_ctx = {
instance).
*)
let id_to_string (id : id) (ctx : extraction_ctx) : string =
- let global_decls = ctx.trans_ctx.global_context.global_decls in
- let fun_decls = ctx.trans_ctx.fun_context.fun_decls in
- let type_decls = ctx.trans_ctx.type_context.type_decls in
+ let global_decls = ctx.trans_ctx.global_ctx.global_decls in
+ let fun_decls = ctx.trans_ctx.fun_ctx.fun_decls in
+ let type_decls = ctx.trans_ctx.type_ctx.type_decls in
+ let trait_decls = ctx.trans_ctx.trait_decls_ctx.trait_decls in
+ let trait_decl_id_to_string (id : A.TraitDeclId.id) : string =
+ let trait_name =
+ Print.fun_name_to_string (A.TraitDeclId.Map.find id trait_decls).name
+ in
+ "trait_decl: " ^ trait_name ^ " (id: " ^ A.TraitDeclId.to_string id ^ ")"
+ in
(* TODO: factorize the pretty-printing with what is in PrintPure *)
let get_type_name (id : type_id) : string =
match id with
@@ -614,10 +737,17 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| FromLlbc (fid, lp_id, rg_id) ->
let fun_name =
match fid with
- | Regular fid ->
+ | FunId (Regular fid) ->
Print.fun_name_to_string
(A.FunDeclId.Map.find fid fun_decls).name
- | Assumed aid -> A.show_assumed_fun_id aid
+ | FunId (Assumed aid) -> A.show_assumed_fun_id aid
+ | TraitMethod (trait_ref, method_name, _) ->
+ (* Shouldn't happen *)
+ if !Config.extract_fail_hard then raise (Failure "Unexpected")
+ else
+ "Trait method: decl: "
+ ^ TraitDeclId.to_string trait_ref.trait_decl_ref.trait_decl_id
+ ^ ", method_name: " ^ method_name
in
let lp_kind =
@@ -677,8 +807,16 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
if variant_id = option_some_id then "@option::Some"
else if variant_id = option_none_id then "@option::None"
else raise (Failure "Unreachable")
- | Assumed (State | Vec | Fuel | Array | Slice | Str | Range) ->
- raise (Failure "Unreachable")
+ | Assumed Fuel ->
+ if variant_id = fuel_zero_id then "@fuel::0"
+ else if variant_id = fuel_succ_id then "@fuel::Succ"
+ else raise (Failure "Unreachable")
+ | Assumed (State | Vec | Array | Slice | Str | Range) ->
+ raise
+ (Failure
+ ("Unreachable: variant id ("
+ ^ VariantId.to_string variant_id
+ ^ ") for " ^ show_type_id id))
| AdtId id -> (
let def = TypeDeclId.Map.find id type_decls in
match def.kind with
@@ -716,43 +854,120 @@ let id_to_string (id : id) (ctx : extraction_ctx) : string =
| ConstGenericVarId id ->
"const_generic_var_id: " ^ ConstGenericVarId.to_string id
| VarId id -> "var_id: " ^ VarId.to_string id
+ | TraitDeclId id -> "trait_decl_id: " ^ TraitDeclId.to_string id
+ | TraitImplId id -> "trait_impl_id: " ^ TraitImplId.to_string id
+ | LocalTraitClauseId id ->
+ "local_trait_clause_id: " ^ TraitClauseId.to_string id
+ | TraitParentClauseId (id, clause_id) ->
+ "trait_parent_clause_id: " ^ trait_decl_id_to_string id ^ ", clause_id: "
+ ^ TraitClauseId.to_string clause_id
+ | TraitItemClauseId (id, item_name, clause_id) ->
+ "trait_item_clause_id: " ^ trait_decl_id_to_string id ^ ", item name: "
+ ^ item_name ^ ", clause_id: "
+ ^ TraitClauseId.to_string clause_id
+ | TraitItemId (id, name) ->
+ "trait_item_id: " ^ trait_decl_id_to_string id ^ ", type name: " ^ name
+ | TraitMethodId (trait_decl_id, fun_name, rg_id) ->
+ let fwd_back_kind =
+ match rg_id with
+ | None -> "forward"
+ | Some rg_id -> "backward " ^ RegionGroupId.to_string rg_id
+ in
+ trait_decl_id_to_string trait_decl_id
+ ^ ", method name (" ^ fwd_back_kind ^ "): " ^ fun_name
+ | TraitSelfClauseId -> "trait_self_clause"
+
+(** Return [true] if we are strict on collisions for this id (i.e., we forbid
+ collisions even with the ids in the unsafe names map) *)
+let strict_collisions (id : id) : bool =
+ match id with UnknownId | TypeId _ -> true | _ -> false
(** We might not check for collisions for some specific ids (ex.: field names) *)
let allow_collisions (id : id) : bool =
match id with
- | FieldId (_, _) -> !Config.record_fields_short_names
+ | FieldId _ | TraitItemClauseId _ | TraitParentClauseId _ | TraitItemId _
+ | TraitMethodId _ ->
+ !Config.record_fields_short_names
| _ -> false
let ctx_add (is_opaque : bool) (id : id) (name : string) (ctx : extraction_ctx)
: extraction_ctx =
- (* We do not use the same name map if we allow/disallow collisions *)
+ (* The id_to_string function to print nice debugging messages if there are
+ * collisions *)
+ let id_to_string (id : id) : string = id_to_string id ctx in
+ (* We do not use the same name map if we allow/disallow collisions.
+ We notably use it for field names: some backends like Lean can use the
+ type information to disambiguate field projections.
+
+ Remark: we still need to check that those "unsafe" ids don't collide with
+ the ids that we mark as "strict on collision".
+
+ For instance, we don't allow naming a field "let". We enforce this by
+ not checking collision between ids for which we permit collisions (ex.:
+ between fields), but still checking collisions between those ids and the
+ others (ex.: fields and keywords).
+ *)
if allow_collisions id then (
assert (not is_opaque);
+ (* Check with the ids which are considered to be strict on collisions *)
+ names_map_check_collision id_to_string id name ctx.strict_names_map;
{
ctx with
unsafe_names_map = unsafe_names_map_add id name ctx.unsafe_names_map;
})
else
- (* The id_to_string function to print nice debugging messages if there are
- * collisions *)
- let id_to_string (id : id) : string = id_to_string id ctx in
+ (* Remark: if we are strict on collisions:
+ - we add the id to the strict collisions map
+ - we check that the id doesn't collide with the unsafe map
+ *)
+ let strict_names_map =
+ if strict_collisions id then
+ names_map_add id_to_string is_opaque id name ctx.strict_names_map
+ else ctx.strict_names_map
+ in
let names_map =
names_map_add id_to_string is_opaque id name ctx.names_map
in
- { ctx with names_map }
+ { ctx with strict_names_map; names_map }
(** [with_opaque_pre]: if [true] and the definition is opaque, add the opaque prefix *)
let ctx_get (with_opaque_pre : bool) (id : id) (ctx : extraction_ctx) : string =
(* We do not use the same name map if we allow/disallow collisions *)
- if allow_collisions id then IdMap.find id ctx.unsafe_names_map.id_to_name
+ let map_to_string (m : string IdMap.t) : string =
+ "[\n"
+ ^ String.concat ","
+ (List.map
+ (fun (id, n) -> "\n " ^ id_to_string id ctx ^ " -> " ^ n)
+ (IdMap.bindings m))
+ ^ "\n]"
+ in
+ if allow_collisions id then (
+ let m = ctx.unsafe_names_map.id_to_name in
+ match IdMap.find_opt id m with
+ | Some s -> s
+ | None ->
+ let err =
+ "Could not find: " ^ id_to_string id ctx ^ "\nNames map:\n"
+ ^ map_to_string m
+ in
+ log#serror err;
+ if !Config.extract_fail_hard then raise (Failure err)
+ else
+ "(%%%ERROR: unknown identifier\": " ^ id_to_string id ctx ^ "\"%%%)")
else
- match IdMap.find_opt id ctx.names_map.id_to_name with
+ let m = ctx.names_map.id_to_name in
+ match IdMap.find_opt id m with
| Some s ->
let is_opaque = IdSet.mem id ctx.names_map.opaque_ids in
if with_opaque_pre && is_opaque then ctx.fmt.opaque_pre () ^ s else s
| None ->
- log#serror ("Could not find: " ^ id_to_string id ctx);
- raise Not_found
+ let err =
+ "Could not find: " ^ id_to_string id ctx ^ "\nNames map:\n"
+ ^ map_to_string m
+ in
+ log#serror err;
+ if !Config.extract_fail_hard then raise (Failure err)
+ else "(ERROR: \"" ^ id_to_string id ctx ^ "\")"
let ctx_get_global (with_opaque_pre : bool) (id : A.GlobalDeclId.id)
(ctx : extraction_ctx) : string =
@@ -765,7 +980,7 @@ let ctx_get_function (with_opaque_pre : bool) (id : fun_id)
let ctx_get_local_function (with_opaque_pre : bool) (id : A.FunDeclId.id)
(lp : LoopId.id option) (rg : RegionGroupId.id option)
(ctx : extraction_ctx) : string =
- ctx_get_function with_opaque_pre (FromLlbc (Regular id, lp, rg)) ctx
+ ctx_get_function with_opaque_pre (FromLlbc (FunId (Regular id), lp, rg)) ctx
let ctx_get_type (with_opaque_pre : bool) (id : type_id) (ctx : extraction_ctx)
: string =
@@ -785,6 +1000,46 @@ let ctx_get_assumed_type (id : assumed_ty) (ctx : extraction_ctx) : string =
let is_opaque = false in
ctx_get_type is_opaque (Assumed id) ctx
+let ctx_get_trait_self_clause (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre TraitSelfClauseId ctx
+
+let ctx_get_trait_decl (with_opaque_pre : bool) (id : trait_decl_id)
+ (ctx : extraction_ctx) : string =
+ ctx_get with_opaque_pre (TraitDeclId id) ctx
+
+let ctx_get_trait_impl (with_opaque_pre : bool) (id : trait_impl_id)
+ (ctx : extraction_ctx) : string =
+ ctx_get with_opaque_pre (TraitImplId id) ctx
+
+let ctx_get_trait_item (id : trait_decl_id) (item_name : string)
+ (ctx : extraction_ctx) : string =
+ let is_opaque = false in
+ ctx_get is_opaque (TraitItemId (id, item_name)) ctx
+
+let ctx_get_trait_const (id : trait_decl_id) (item_name : string)
+ (ctx : extraction_ctx) : string =
+ ctx_get_trait_item id item_name ctx
+
+let ctx_get_trait_type (id : trait_decl_id) (item_name : string)
+ (ctx : extraction_ctx) : string =
+ ctx_get_trait_item id item_name ctx
+
+let ctx_get_trait_method (id : trait_decl_id) (item_name : string)
+ (rg_id : T.RegionGroupId.id option) (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre (TraitMethodId (id, item_name, rg_id)) ctx
+
+let ctx_get_trait_parent_clause (id : trait_decl_id) (clause : trait_clause_id)
+ (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre (TraitParentClauseId (id, clause)) ctx
+
+let ctx_get_trait_item_clause (id : trait_decl_id) (item : string)
+ (clause : trait_clause_id) (ctx : extraction_ctx) : string =
+ let with_opaque_pre = false in
+ ctx_get with_opaque_pre (TraitItemClauseId (id, item, clause)) ctx
+
let ctx_get_var (id : VarId.id) (ctx : extraction_ctx) : string =
let is_opaque = false in
ctx_get is_opaque (VarId id) ctx
@@ -798,6 +1053,11 @@ let ctx_get_const_generic_var (id : ConstGenericVarId.id) (ctx : extraction_ctx)
let is_opaque = false in
ctx_get is_opaque (ConstGenericVarId id) ctx
+let ctx_get_local_trait_clause (id : TraitClauseId.id) (ctx : extraction_ctx) :
+ string =
+ let is_opaque = false in
+ ctx_get is_opaque (LocalTraitClauseId id) ctx
+
let ctx_get_field (type_id : type_id) (field_id : FieldId.id)
(ctx : extraction_ctx) : string =
let is_opaque = false in
@@ -863,6 +1123,26 @@ let ctx_add_var (basename : string) (id : VarId.id) (ctx : extraction_ctx) :
let ctx = ctx_add is_opaque (VarId id) name ctx in
(ctx, name)
+(** Generate a unique variable name for the trait self clause and add it to the context *)
+let ctx_add_trait_self_clause (ctx : extraction_ctx) : extraction_ctx * string =
+ let is_opaque = false in
+ let basename = ctx.fmt.trait_self_clause_basename in
+ let name =
+ basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
+ in
+ let ctx = ctx_add is_opaque TraitSelfClauseId name ctx in
+ (ctx, name)
+
+(** Generate a unique trait clause name and add it to the context *)
+let ctx_add_local_trait_clause (basename : string) (id : TraitClauseId.id)
+ (ctx : extraction_ctx) : extraction_ctx * string =
+ let is_opaque = false in
+ let name =
+ basename_to_unique ctx.names_map.names_set ctx.fmt.append_index basename
+ in
+ let ctx = ctx_add is_opaque (LocalTraitClauseId id) name ctx in
+ (ctx, name)
+
(** See {!ctx_add_var} *)
let ctx_add_vars (vars : var list) (ctx : extraction_ctx) :
extraction_ctx * string list =
@@ -885,12 +1165,26 @@ let ctx_add_const_generic_params (vars : const_generic_var list)
ctx_add_const_generic_var var.name var.index ctx)
ctx vars
-let ctx_add_type_const_generic_params (tvars : type_var list)
- (cgvars : const_generic_var list) (ctx : extraction_ctx) :
- extraction_ctx * string list * string list =
- let ctx, tys = ctx_add_type_params tvars ctx in
- let ctx, cgs = ctx_add_const_generic_params cgvars ctx in
- (ctx, tys, cgs)
+let ctx_add_local_trait_clauses (clauses : trait_clause list)
+ (ctx : extraction_ctx) : extraction_ctx * string list =
+ List.fold_left_map
+ (fun ctx (c : trait_clause) ->
+ let basename = ctx.fmt.trait_clause_basename ctx.names_map.names_set c in
+ ctx_add_local_trait_clause basename c.clause_id ctx)
+ ctx clauses
+
+(** Returns the lists of names for:
+ - the type variables
+ - the const generic variables
+ - the trait clauses
+ *)
+let ctx_add_generic_params (generics : generic_params) (ctx : extraction_ctx) :
+ extraction_ctx * string list * string list * string list =
+ let { types; const_generics; trait_clauses } = generics in
+ let ctx, tys = ctx_add_type_params types ctx in
+ let ctx, cgs = ctx_add_const_generic_params const_generics ctx in
+ let ctx, tcs = ctx_add_local_trait_clauses trait_clauses ctx in
+ (ctx, tys, cgs, tcs)
let ctx_add_type_decl_struct (def : type_decl) (ctx : extraction_ctx) :
extraction_ctx * string =
@@ -977,49 +1271,62 @@ let ctx_add_global_decl_and_body (def : A.global_decl) (ctx : extraction_ctx) :
extraction_ctx =
(* TODO: update once the body id can be an option *)
let is_opaque = false in
- let name = ctx.fmt.global_name def.name in
let decl = GlobalId def.def_id in
- let body = FunId (FromLlbc (Regular def.body_id, None, None)) in
- let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in
- let ctx = ctx_add is_opaque body (name ^ "_body") ctx in
- ctx
-let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
- (def : fun_decl) (ctx : extraction_ctx) : extraction_ctx =
- (* Sanity check: the function should not be a global body - those are handled
- * separately *)
- assert (not def.is_global_decl_body);
+ (* Check if the global corresponds to an assumed global that we should map
+ to a custom definition in our standard library (for instance, happens
+ with "core::num::usize::MAX") *)
+ let sname = name_to_simple_name def.name in
+ match SimpleNameMap.find_opt sname assumed_globals_map with
+ | Some name ->
+ (* Yes: register the custom binding *)
+ ctx_add is_opaque decl name ctx
+ | None ->
+ (* Not the case: "standard" registration *)
+ let name = ctx.fmt.global_name def.name in
+ let body = FunId (FromLlbc (FunId (Regular def.body_id), None, None)) in
+ let ctx = ctx_add is_opaque decl (name ^ "_c") ctx in
+ let ctx = ctx_add is_opaque body (name ^ "_body") ctx in
+ ctx
+
+let ctx_compute_fun_name (trans_group : pure_fun_translation) (def : fun_decl)
+ (ctx : extraction_ctx) : string =
(* Lookup the LLBC def to compute the region group information *)
let def_id = def.def_id in
- let llbc_def =
- A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_context.fun_decls
- in
+ let llbc_def = A.FunDeclId.Map.find def_id ctx.trans_ctx.fun_ctx.fun_decls in
let sg = llbc_def.signature in
let num_rgs = List.length sg.regions_hierarchy in
- let keep_fwd, (_, backs) = trans_group in
+ let { keep_fwd; fwd = _; backs } = trans_group in
let num_backs = List.length backs in
let rg_info =
match def.back_id with
| None -> None
| Some rg_id ->
let rg = T.RegionGroupId.nth sg.regions_hierarchy rg_id in
- let regions =
+ let region_names =
List.map
- (fun rid -> T.RegionVarId.nth sg.region_params rid)
+ (fun rid -> (T.RegionVarId.nth sg.generics.regions rid).name)
rg.regions
in
- let region_names =
- List.map (fun (r : T.region_var) -> r.name) regions
- in
Some { id = rg_id; region_names }
in
+ (* Add the function name *)
+ ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
+ (keep_fwd, num_backs)
+
+let ctx_add_fun_decl (trans_group : pure_fun_translation) (def : fun_decl)
+ (ctx : extraction_ctx) : extraction_ctx =
+ (* Sanity check: the function should not be a global body - those are handled
+ * separately *)
+ assert (not def.is_global_decl_body);
+ (* Lookup the LLBC def to compute the region group information *)
+ let def_id = def.def_id in
+ let { keep_fwd; fwd = _; backs } = trans_group in
+ let num_backs = List.length backs in
let is_opaque = def.body = None in
(* Add the function name *)
- let def_name =
- ctx.fmt.fun_name def.basename def.num_loops def.loop_id num_rgs rg_info
- (keep_fwd, num_backs)
- in
- let fun_id = (A.Regular def_id, def.loop_id, def.back_id) in
+ let def_name = ctx_compute_fun_name trans_group def ctx in
+ let fun_id = (Pure.FunId (Regular def_id), def.loop_id, def.back_id) in
let ctx = ctx_add is_opaque (FunId (FromLlbc fun_id)) def_name ctx in
(* Add the name info *)
{
@@ -1029,6 +1336,91 @@ let ctx_add_fun_decl (trans_group : bool * pure_fun_translation)
ctx.fun_name_info;
}
+let ctx_add_trait_decl (d : trait_decl) (ctx : extraction_ctx) : extraction_ctx
+ =
+ let is_opaque = false in
+ let name = ctx.fmt.trait_decl_name d in
+ ctx_add is_opaque (TraitDeclId d.def_id) name ctx
+
+let ctx_add_trait_impl (d : trait_impl) (ctx : extraction_ctx) : extraction_ctx
+ =
+ (* We need to lookup the trait decl that is implemented by the trait impl *)
+ let decl =
+ Pure.TraitDeclId.Map.find d.impl_trait.trait_decl_id ctx.trans_trait_decls
+ in
+ (* Compute the name *)
+ let is_opaque = false in
+ let name = ctx.fmt.trait_impl_name decl d in
+ ctx_add is_opaque (TraitImplId d.def_id) name ctx
+
+let ctx_add_trait_const (d : trait_decl) (item : string) (ctx : extraction_ctx)
+ : extraction_ctx =
+ let is_opaque = false in
+ let name = ctx.fmt.trait_const_name d item in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name d ^ name
+ in
+ ctx_add is_opaque (TraitItemId (d.def_id, item)) name ctx
+
+let ctx_add_trait_type (d : trait_decl) (item : string) (ctx : extraction_ctx) :
+ extraction_ctx =
+ let is_opaque = false in
+ let name = ctx.fmt.trait_type_name d item in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name d ^ name
+ in
+ ctx_add is_opaque (TraitItemId (d.def_id, item)) name ctx
+
+let ctx_add_trait_method (d : trait_decl) (item_name : string) (f : fun_decl)
+ (ctx : extraction_ctx) : extraction_ctx =
+ (* We do something special: we use the base name but remove everything
+ but the crate (because [get_name] removes it) and the last ident.
+ This allows us to reuse the [ctx_compute_fun_decl] function.
+ *)
+ let basename : name =
+ match (f.basename : name) with
+ | Ident crate :: name -> Ident crate :: [ Collections.List.last name ]
+ | _ -> raise (Failure "Unexpected")
+ in
+ let f = { f with basename } in
+ let trans = A.FunDeclId.Map.find f.def_id ctx.trans_funs in
+ let name = ctx_compute_fun_name trans f ctx in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name d ^ "_" ^ name
+ in
+ let is_opaque = false in
+ ctx_add is_opaque (TraitMethodId (d.def_id, item_name, f.back_id)) name ctx
+
+let ctx_add_trait_parent_clause (d : trait_decl) (clause : trait_clause)
+ (ctx : extraction_ctx) : extraction_ctx =
+ let is_opaque = false in
+ let name = ctx.fmt.trait_parent_clause_name d clause in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name d ^ name
+ in
+ ctx_add is_opaque (TraitParentClauseId (d.def_id, clause.clause_id)) name ctx
+
+let ctx_add_trait_type_clause (d : trait_decl) (item : string)
+ (clause : trait_clause) (ctx : extraction_ctx) : extraction_ctx =
+ let is_opaque = false in
+ let name = ctx.fmt.trait_type_clause_name d item clause in
+ (* Add a prefix if necessary *)
+ let name =
+ if !Config.record_fields_short_names then name
+ else ctx.fmt.trait_decl_name d ^ name
+ in
+ ctx_add is_opaque
+ (TraitItemClauseId (d.def_id, item, clause.clause_id))
+ name ctx
+
type names_map_init = {
keywords : string list;
assumed_adts : (assumed_ty * string) list;
@@ -1089,7 +1481,8 @@ let initialize_names_map (fmt : formatter) (init : names_map_init) : names_map =
in
let assumed_functions =
List.map
- (fun (fid, rg, name) -> (FromLlbc (A.Assumed fid, None, rg), name))
+ (fun (fid, rg, name) ->
+ (FromLlbc (Pure.FunId (A.Assumed fid), None, rg), name))
init.assumed_llbc_functions
@ List.map (fun (fid, name) -> (Pure fid, name)) init.assumed_pure_functions
in
diff --git a/compiler/FunsAnalysis.ml b/compiler/FunsAnalysis.ml
index b72fa078..f4406653 100644
--- a/compiler/FunsAnalysis.ml
+++ b/compiler/FunsAnalysis.ml
@@ -70,14 +70,14 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
method! visit_rvalue _env rv =
match rv with
- | Use _ | Ref _ | Global _ | Discriminant _ | Aggregate _ -> ()
+ | Use _ | RvRef _ | Global _ | Discriminant _ | Aggregate _ -> ()
| UnaryOp (uop, _) -> can_fail := EU.unop_can_fail uop || !can_fail
| BinaryOp (bop, _, _) ->
can_fail := EU.binop_can_fail bop || !can_fail
method! visit_Call env call =
(match call.func with
- | Regular id ->
+ | FunId (Regular id) ->
if FunDeclId.Set.mem id fun_ids then (
can_diverge := true;
is_rec := true)
@@ -86,9 +86,13 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
self#may_fail info.can_fail;
stateful := !stateful || info.stateful;
can_diverge := !can_diverge || info.can_diverge
- | Assumed id ->
+ | FunId (Assumed id) ->
(* None of the assumed functions can diverge nor are considered stateful *)
- can_fail := !can_fail || Assumed.assumed_can_fail id);
+ can_fail := !can_fail || Assumed.assumed_can_fail id
+ | TraitMethod _ ->
+ (* We consider trait functions can fail, diverge, and are not stateful *)
+ can_fail := true;
+ can_diverge := true);
super#visit_Call env call
method! visit_Panic env =
@@ -141,7 +145,8 @@ let analyze_module (m : crate) (funs_map : fun_decl FunDeclId.Map.t)
let rec analyze_decl_groups (decls : declaration_group list) : unit =
match decls with
| [] -> ()
- | Type _ :: decls' -> analyze_decl_groups decls'
+ | (Type _ | TraitDecl _ | TraitImpl _) :: decls' ->
+ analyze_decl_groups decls'
| Fun decl :: decls' ->
analyze_fun_decl_group decl;
analyze_decl_groups decls'
diff --git a/compiler/Interpreter.ml b/compiler/Interpreter.ml
index 154c5a21..24ff4808 100644
--- a/compiler/Interpreter.ml
+++ b/compiler/Interpreter.ml
@@ -12,55 +12,165 @@ module SA = SymbolicAst
(** The local logger *)
let log = L.interpreter_log
-let compute_type_fun_global_contexts (m : A.crate) :
- C.type_context * C.fun_context * C.global_context =
- let type_decls_list, _, _ = split_declarations m.declarations in
+let compute_contexts (m : A.crate) : C.decls_ctx =
+ let type_decls_list, _, _, _, _ = split_declarations m.declarations in
let type_decls = m.types in
let fun_decls = m.functions in
let global_decls = m.globals in
- let type_decls_groups, _funs_defs_groups, _globals_defs_groups =
+ let trait_decls = m.trait_decls in
+ let trait_impls = m.trait_impls in
+ let type_decls_groups, _, _, _, _ =
split_declarations_to_group_maps m.declarations
in
let type_infos =
TypesAnalysis.analyze_type_declarations type_decls type_decls_list
in
- let type_context = { C.type_decls_groups; type_decls; type_infos } in
- let fun_context = { C.fun_decls } in
- let global_context = { C.global_decls } in
- (type_context, fun_context, global_context)
-
-let initialize_eval_context (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list)
- (const_generic_vars : T.const_generic_var list) : C.eval_ctx =
- C.reset_global_counters ();
- {
- C.type_context;
- C.fun_context;
- C.global_context;
- C.region_groups;
- C.type_vars;
- C.const_generic_vars;
- C.env = [ C.Frame ];
- C.ended_regions = T.RegionId.Set.empty;
- }
+ let type_ctx = { C.type_decls_groups; type_decls; type_infos } in
+ let fun_infos =
+ FunsAnalysis.analyze_module m fun_decls global_decls !Config.use_state
+ in
+ let fun_ctx = { C.fun_decls; fun_infos } in
+ let global_ctx = { C.global_decls } in
+ let trait_decls_ctx = { C.trait_decls } in
+ let trait_impls_ctx = { C.trait_impls } in
+ { C.type_ctx; fun_ctx; global_ctx; trait_decls_ctx; trait_impls_ctx }
+
+(** Small helper.
+
+ Normalize an instantiated function signature provided we used this signature
+ to compute a normalization map (for the associated types) and that we added
+ it in the context.
+ *)
+let normalize_inst_fun_sig (ctx : C.eval_ctx) (sg : A.inst_fun_sig) :
+ A.inst_fun_sig =
+ let { A.regions_hierarchy = _; trait_type_constraints = _; inputs; output } =
+ sg
+ in
+ let norm = AssociatedTypes.ctx_normalize_rty ctx in
+ let inputs = List.map norm inputs in
+ let output = norm output in
+ { sg with A.inputs; output }
+
+(** Instantiate a function signature for a symbolic execution.
+
+ We return a new context because we compute and add the type normalization
+ map in the same step.
+
+ **WARNING**: this doesn't normalize the types. This step has to be done
+ separately. Remark: we need to normalize essentially because of the where
+ clauses (we are not considering a function call, so we don't need to
+ normalize because a trait clause was instantiated with a specific trait ref).
+ *)
+let symbolic_instantiate_fun_sig (ctx : C.eval_ctx) (sg : A.fun_sig)
+ (kind : A.fun_kind) : C.eval_ctx * A.inst_fun_sig =
+ let tr_self =
+ match kind with
+ | RegularKind | TraitMethodImpl _ -> T.UnknownTrait __FUNCTION__
+ | TraitMethodDecl _ | TraitMethodProvided _ -> T.Self
+ in
+ let generics =
+ let { T.regions; types; const_generics; trait_clauses } = sg.generics in
+ let regions = List.map (fun _ -> T.Erased) regions in
+ let types = List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) types in
+ let const_generics =
+ List.map
+ (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
+ const_generics
+ in
+ (* Annoying that we have to generate this substitution here *)
+ let r_subst _ = raise (Failure "Unexpected region") in
+ let ty_subst = Subst.make_type_subst_from_vars sg.generics.types types in
+ let cg_subst =
+ Subst.make_const_generic_subst_from_vars sg.generics.const_generics
+ const_generics
+ in
+ (* TODO: some clauses may use the types of other clauses, so we may have to
+ reorder them.
+
+ Example:
+ If in Rust we write:
+ {[
+ pub fn use_get<'a, T: Get>(x: &'a mut T) -> u32
+ where
+ T::Item: ToU32,
+ {
+ x.get().to_u32()
+ }
+ ]}
+
+ In LLBC we get:
+ {[
+ fn demo::use_get<'a, T>(@1: &'a mut (T)) -> u32
+ where
+ [@TraitClause0]: demo::Get<T>,
+ [@TraitClause1]: demo::ToU32<@TraitClause0::Item>, // HERE
+ {
+ ... // Omitted
+ }
+ ]}
+ *)
+ (* We will need to update the trait refs map while we perform the instantiations *)
+ let mk_tr_subst
+ (tr_map : T.erased_region T.trait_instance_id T.TraitClauseId.Map.t)
+ clause_id : T.erased_region T.trait_instance_id =
+ match T.TraitClauseId.Map.find_opt clause_id tr_map with
+ | Some tr -> tr
+ | None -> raise (Failure "Local trait clause not found")
+ in
+ let mk_subst tr_map =
+ let tr_subst = mk_tr_subst tr_map in
+ { Subst.r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+ in
+ let _, trait_refs =
+ List.fold_left_map
+ (fun tr_map (c : T.trait_clause) ->
+ let subst = mk_subst tr_map in
+ let { T.trait_id = trait_decl_id; generics; _ } = c in
+ let generics = Subst.generic_args_substitute subst generics in
+ let trait_decl_ref = { T.trait_decl_id; decl_generics = generics } in
+ (* Note that because we directly refer to the clause, we give it
+ empty generics *)
+ let trait_id = T.Clause c.clause_id in
+ let trait_ref =
+ {
+ T.trait_id;
+ generics = TypesUtils.mk_empty_generic_args;
+ trait_decl_ref;
+ }
+ in
+ (* Update the traits map *)
+ let tr_map = T.TraitClauseId.Map.add c.T.clause_id trait_id tr_map in
+ (tr_map, trait_ref))
+ T.TraitClauseId.Map.empty trait_clauses
+ in
+ { T.regions; types; const_generics; trait_refs }
+ in
+ let inst_sg = instantiate_fun_sig ctx generics tr_self sg in
+ (* Compute the normalization maps *)
+ let ctx =
+ AssociatedTypes.ctx_add_norm_trait_types_from_preds ctx
+ inst_sg.trait_type_constraints
+ in
+ (* Normalize the signature *)
+ let inst_sg = normalize_inst_fun_sig ctx inst_sg in
+ (* Return *)
+ (ctx, inst_sg)
(** Initialize an evaluation context to execute a function.
- Introduces local variables initialized in the following manner:
- - input arguments are initialized as symbolic values
- - the remaining locals are initialized as [⊥]
- Abstractions are introduced for the regions present in the function
- signature.
-
- We return:
- - the initialized evaluation context
- - the list of symbolic values introduced for the input values
- - the instantiated function signature
+ Introduces local variables initialized in the following manner:
+ - input arguments are initialized as symbolic values
+ - the remaining locals are initialized as [⊥]
+ Abstractions are introduced for the regions present in the function
+ signature.
+
+ We return:
+ - the initialized evaluation context
+ - the list of symbolic values introduced for the input values
+ - the instantiated function signature
*)
-let initialize_symbolic_context_for_fun (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (fdef : A.fun_decl) : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig =
+let initialize_symbolic_context_for_fun (ctx : C.decls_ctx) (fdef : A.fun_decl)
+ : C.eval_ctx * V.symbolic_value list * A.inst_fun_sig =
(* The abstractions are not initialized the same way as for function
* calls: they contain *loan* projectors, because they "provide" us
* with the input values (which behave as if they had been returned
@@ -78,19 +188,15 @@ let initialize_symbolic_context_for_fun (type_context : C.type_context)
List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy
in
let ctx =
- initialize_eval_context type_context fun_context global_context
- region_groups sg.type_params sg.const_generic_params
+ initialize_eval_context ctx region_groups sg.generics.types
+ sg.generics.const_generics
in
- (* Instantiate the signature *)
- let type_params =
- List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
+ (* Instantiate the signature. This updates the context because we compute
+ at the same time the normalization map for the associated types.
+ *)
+ let ctx, inst_sg =
+ symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind
in
- let cg_params =
- List.map
- (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
- sg.const_generic_params
- in
- let inst_sg = instantiate_fun_sig type_params cg_params sg in
(* Create fresh symbolic values for the inputs *)
let input_svs =
List.map (fun ty -> mk_fresh_symbolic_value V.SynthInput ty) inst_sg.inputs
@@ -165,15 +271,9 @@ let evaluate_function_symbolic_synthesize_backward_from_return
* an instantiation of the signature, so that we use fresh
* region ids for the return abstractions. *)
let sg = fdef.signature in
- let type_params =
- List.map (fun (v : T.type_var) -> T.TypeVar v.T.index) sg.type_params
- in
- let cg_params =
- List.map
- (fun (v : T.const_generic_var) -> T.ConstGenericVar v.T.index)
- sg.const_generic_params
+ let _, ret_inst_sg =
+ symbolic_instantiate_fun_sig ctx fdef.signature fdef.kind
in
- let ret_inst_sg = instantiate_fun_sig type_params cg_params sg in
let ret_rty = ret_inst_sg.output in
(* Move the return value out of the return variable *)
let pop_return_value = is_regular_return in
@@ -347,19 +447,14 @@ let evaluate_function_symbolic_synthesize_backward_from_return
for the synthesis)
- the symbolic AST generated by the symbolic execution
*)
-let evaluate_function_symbolic (synthesize : bool)
- (type_context : C.type_context) (fun_context : C.fun_context)
- (global_context : C.global_context) (fdef : A.fun_decl) :
- V.symbolic_value list * SA.expression option =
+let evaluate_function_symbolic (synthesize : bool) (ctx : C.decls_ctx)
+ (fdef : A.fun_decl) : V.symbolic_value list * SA.expression option =
(* Debug *)
let name_to_string () = Print.fun_name_to_string fdef.A.name in
log#ldebug (lazy ("evaluate_function_symbolic: " ^ name_to_string ()));
(* Create the evaluation context *)
- let ctx, input_svs, inst_sg =
- initialize_symbolic_context_for_fun type_context fun_context global_context
- fdef
- in
+ let ctx, input_svs, inst_sg = initialize_symbolic_context_for_fun ctx fdef in
(* Create the continuation to finish the evaluation *)
let config = C.mk_config C.SymbolicMode in
@@ -488,7 +583,8 @@ module Test = struct
(** Test a unit function (taking no arguments) by evaluating it in an empty
environment.
*)
- let test_unit_function (crate : A.crate) (fid : A.FunDeclId.id) : unit =
+ let test_unit_function (crate : A.crate) (decls_ctx : C.decls_ctx)
+ (fid : A.FunDeclId.id) : unit =
(* Retrieve the function declaration *)
let fdef = A.FunDeclId.Map.find fid crate.functions in
let body = Option.get fdef.body in
@@ -498,17 +594,11 @@ module Test = struct
(lazy ("test_unit_function: " ^ Print.fun_name_to_string fdef.A.name));
(* Sanity check - *)
- assert (List.length fdef.A.signature.region_params = 0);
- assert (List.length fdef.A.signature.type_params = 0);
+ assert (fdef.A.signature.generics = TypesUtils.mk_empty_generic_params);
assert (body.A.arg_count = 0);
(* Create the evaluation context *)
- let type_context, fun_context, global_context =
- compute_type_fun_global_contexts crate
- in
- let ctx =
- initialize_eval_context type_context fun_context global_context [] [] []
- in
+ let ctx = initialize_eval_context decls_ctx [] [] [] in
(* Insert the (uninitialized) local variables *)
let ctx = C.ctx_push_uninitialized_vars ctx body.A.locals in
@@ -536,9 +626,7 @@ module Test = struct
(no parameters, no arguments) - TODO: move *)
let fun_decl_is_transparent_unit (def : A.fun_decl) : bool =
Option.is_some def.body
- && def.A.signature.region_params = []
- && def.A.signature.type_params = []
- && def.A.signature.const_generic_params = []
+ && def.A.signature.generics = TypesUtils.mk_empty_generic_params
&& def.A.signature.inputs = []
(** Test all the unit functions in a list of function definitions *)
@@ -548,24 +636,9 @@ module Test = struct
(fun _ -> fun_decl_is_transparent_unit)
crate.functions
in
+ let decls_ctx = compute_contexts crate in
let test_unit_fun _ (def : A.fun_decl) : unit =
- test_unit_function crate def.A.def_id
+ test_unit_function crate decls_ctx def.A.def_id
in
A.FunDeclId.Map.iter test_unit_fun unit_funs
-
- (** Execute the symbolic interpreter on a function. *)
- let test_function_symbolic (synthesize : bool) (type_context : C.type_context)
- (fun_context : C.fun_context) (global_context : C.global_context)
- (fdef : A.fun_decl) : unit =
- (* Debug *)
- log#ldebug
- (lazy ("test_function_symbolic: " ^ Print.fun_name_to_string fdef.A.name));
-
- (* Evaluate *)
- let _ =
- evaluate_function_symbolic synthesize type_context fun_context
- global_context fdef
- in
-
- ()
end
diff --git a/compiler/InterpreterBorrows.ml b/compiler/InterpreterBorrows.ml
index 4d67a4e4..e97795a1 100644
--- a/compiler/InterpreterBorrows.ml
+++ b/compiler/InterpreterBorrows.ml
@@ -452,7 +452,8 @@ let give_back_symbolic_value (_config : C.config)
| V.SynthInputGivenBack | SynthRetGivenBack | FunCallGivenBack | LoopGivenBack
->
()
- | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate ->
+ | FunCallRet | SynthInput | Global | LoopOutput | LoopJoin | Aggregate
+ | ConstGeneric | TraitConst ->
raise (Failure "Unreachable"));
(* Store the given-back value as a meta-value for synthesis purposes *)
let mv = nsv in
diff --git a/compiler/InterpreterBorrowsCore.ml b/compiler/InterpreterBorrowsCore.ml
index bf083aa4..e7da045c 100644
--- a/compiler/InterpreterBorrowsCore.ml
+++ b/compiler/InterpreterBorrowsCore.ml
@@ -100,15 +100,18 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
(compare_regions : T.RegionId.id T.region -> T.RegionId.id T.region -> bool)
(ty1 : T.rty) (ty2 : T.rty) : bool =
let compare = compare_rtys default combine compare_regions in
+ (* Normalize the associated types *)
match (ty1, ty2) with
| T.Literal lit1, T.Literal lit2 ->
assert (lit1 = lit2);
default
- | T.Adt (id1, regions1, tys1, cgs1), T.Adt (id2, regions2, tys2, cgs2) ->
+ | T.Adt (id1, generics1), T.Adt (id2, generics2) ->
assert (id1 = id2);
(* There are no regions in the const generics, so we ignore them,
but we still check they are the same, for sanity *)
- assert (cgs1 = cgs2);
+ assert (generics1.const_generics = generics2.const_generics);
+
+ (* We also ignore the trait refs *)
(* The check for the ADTs is very crude: we simply compare the arguments
* two by two.
@@ -123,14 +126,14 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
* this check would still be a reasonable conservative approximation. *)
(* Check the region parameters *)
- let regions = List.combine regions1 regions2 in
+ let regions = List.combine generics1.regions generics2.regions in
let params_b =
List.fold_left
(fun b (r1, r2) -> combine b (compare_regions r1 r2))
default regions
in
(* Check the type parameters *)
- let tys = List.combine tys1 tys2 in
+ let tys = List.combine generics1.types generics2.types in
let tys_b =
List.fold_left
(fun b (ty1, ty2) -> combine b (compare ty1 ty2))
@@ -150,6 +153,11 @@ let rec compare_rtys (default : bool) (combine : bool -> bool -> bool)
| T.TypeVar id1, T.TypeVar id2 ->
assert (id1 = id2);
default
+ | T.TraitType _, T.TraitType _ ->
+ (* The types should have been normalized. If after normalization we
+ get trait types, we can consider them as variables *)
+ assert (ty1 = ty2);
+ default
| _ ->
log#lerror
(lazy
diff --git a/compiler/InterpreterExpansion.ml b/compiler/InterpreterExpansion.ml
index 81e73e3e..ea692386 100644
--- a/compiler/InterpreterExpansion.ml
+++ b/compiler/InterpreterExpansion.ml
@@ -9,6 +9,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
module Inv = Invariants
@@ -204,7 +205,7 @@ let apply_symbolic_expansion_non_borrow (config : C.config)
apply_symbolic_expansion_to_avalues config allow_reborrows original_sv
expansion ctx
-(** Compute the expansion of a non-assumed (i.e.: not [Option], [Box], etc.)
+(** Compute the expansion of a non-assumed (i.e.: not [Box], etc.)
adt value.
The function might return a list of values if the symbolic value to expand
@@ -214,18 +215,15 @@ let apply_symbolic_expansion_non_borrow (config : C.config)
doesn't allow the expansion of enumerations *containing several variants*.
*)
let compute_expanded_symbolic_non_assumed_adt_value (expand_enumerations : bool)
- (kind : V.sv_kind) (def_id : T.TypeDeclId.id)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
- =
+ (kind : V.sv_kind) (def_id : T.TypeDeclId.id) (generics : T.rgeneric_args)
+ (ctx : C.eval_ctx) : V.symbolic_expansion list =
(* Lookup the definition and check if it is an enumeration with several
* variants *)
let def = C.ctx_lookup_type_decl ctx def_id in
- assert (List.length regions = List.length def.T.region_params);
+ assert (List.length generics.regions = List.length def.T.generics.regions);
(* Retrieve, for every variant, the list of its instantiated field types *)
let variants_fields_types =
- Subst.type_decl_get_instantiated_variants_fields_rtypes def regions types
- cgs
+ Assoc.type_decl_get_inst_norm_variants_fields_rtypes ctx def generics
in
(* Check if there is strictly more than one variant *)
if List.length variants_fields_types > 1 && not expand_enumerations then
@@ -280,15 +278,14 @@ let compute_expanded_symbolic_box_value (kind : V.sv_kind) (boxed_ty : T.rty) :
doesn't allow the expansion of enumerations *containing several variants*.
*)
let compute_expanded_symbolic_adt_value (expand_enumerations : bool)
- (kind : V.sv_kind) (adt_id : T.type_id)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) (ctx : C.eval_ctx) : V.symbolic_expansion list
- =
- match (adt_id, regions, types) with
+ (kind : V.sv_kind) (adt_id : T.type_id) (generics : T.rgeneric_args)
+ (ctx : C.eval_ctx) : V.symbolic_expansion list =
+ match (adt_id, generics.regions, generics.types) with
| T.AdtId def_id, _, _ ->
compute_expanded_symbolic_non_assumed_adt_value expand_enumerations kind
- def_id regions types cgs ctx
- | T.Tuple, [], _ -> [ compute_expanded_symbolic_tuple_value kind types ]
+ def_id generics ctx
+ | T.Tuple, [], _ ->
+ [ compute_expanded_symbolic_tuple_value kind generics.types ]
| T.Assumed T.Option, [], [ ty ] ->
compute_expanded_symbolic_option_value expand_enumerations kind ty
| T.Assumed T.Box, [], [ boxed_ty ] ->
@@ -543,12 +540,12 @@ let expand_symbolic_value_no_branching (config : C.config)
fun cf ctx ->
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types, cgs) ->
+ | T.Adt (adt_id, generics) ->
(* Compute the expanded value *)
let allow_branching = false in
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types cgs ctx
+ generics ctx
in
(* There should be exacly one branch *)
let see = Collections.List.to_cons_nil seel in
@@ -600,12 +597,12 @@ let expand_symbolic_adt (config : C.config) (sv : V.symbolic_value)
(* Execute *)
match rty with
(* ADTs *)
- | T.Adt (adt_id, regions, types, cgs) ->
+ | T.Adt (adt_id, generics) ->
let allow_branching = true in
(* Compute the expanded value *)
let seel =
compute_expanded_symbolic_adt_value allow_branching sv.sv_kind adt_id
- regions types cgs ctx
+ generics ctx
in
(* Apply *)
let seel = List.map (fun see -> (Some see, cf_branches)) seel in
@@ -679,7 +676,7 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
^ symbolic_value_to_string ctx sv));
let cc : cm_fun =
match sv.V.sv_ty with
- | T.Adt (AdtId def_id, _, _, _) ->
+ | T.Adt (AdtId def_id, _) ->
(* {!expand_symbolic_value_no_branching} checks if there are branchings,
* but we prefer to also check it here - this leads to cleaner messages
* and debugging *)
@@ -704,16 +701,16 @@ let greedy_expand_symbolics_with_borrows (config : C.config) : cm_fun =
[config]): "
^ Print.name_to_string def.name))
else expand_symbolic_value_no_branching config sv None
- | T.Adt ((Tuple | Assumed Box), _, _, _) | T.Ref (_, _, _) ->
+ | T.Adt ((Tuple | Assumed Box), _) | T.Ref (_, _, _) ->
(* Ok *)
expand_symbolic_value_no_branching config sv None
- | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _, _, _)
- ->
+ | T.Adt (Assumed (Vec | Option | Array | Slice | Str | Range), _) ->
(* We can't expand those *)
raise
(Failure
"Attempted to greedily expand an ADT which can't be expanded ")
- | T.TypeVar _ | T.Literal _ | Never -> raise (Failure "Unreachable")
+ | T.TypeVar _ | T.Literal _ | Never | T.TraitType _ ->
+ raise (Failure "Unreachable")
in
(* Compose and continue *)
comp cc expand cf ctx
diff --git a/compiler/InterpreterExpressions.ml b/compiler/InterpreterExpressions.ml
index 8b2070c6..29826233 100644
--- a/compiler/InterpreterExpressions.ml
+++ b/compiler/InterpreterExpressions.ml
@@ -7,6 +7,7 @@ module E = Expressions
open Utils
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
open ValuesUtils
@@ -141,11 +142,18 @@ let rec copy_value (allow_adt_copy : bool) (config : C.config)
| V.Adt av ->
(* Sanity check *)
(match v.V.ty with
- | T.Adt (T.Assumed (T.Box | Vec), _, _, _) ->
+ | T.Adt (T.Assumed (T.Box | Vec), _) ->
raise (Failure "Can't copy an assumed value other than Option")
- | T.Adt (T.AdtId _, _, _, _) -> assert allow_adt_copy
- | T.Adt ((T.Assumed Option | T.Tuple), _, _, _) -> () (* Ok *)
- | T.Adt (T.Assumed (Slice | T.Array), [], [ ty ], []) ->
+ | T.Adt (T.AdtId _, _) -> assert allow_adt_copy
+ | T.Adt ((T.Assumed Option | T.Tuple), _) -> () (* Ok *)
+ | T.Adt
+ ( T.Assumed (Slice | T.Array),
+ {
+ regions = [];
+ types = [ ty ];
+ const_generics = [];
+ trait_refs = [];
+ } ) ->
assert (ty_is_primitively_copyable ty)
| _ -> raise (Failure "Unreachable"));
let ctx, fields =
@@ -230,17 +238,16 @@ let prepare_eval_operand_reorganize (config : C.config) (op : E.operand) :
let prepare : cm_fun =
fun cf ctx ->
match op with
- | Expressions.Constant (ty, cv) ->
+ | E.Constant _ ->
(* No need to reorganize the context *)
- literal_to_typed_value (TypesUtils.ty_as_literal ty) cv |> ignore;
cf ctx
- | Expressions.Copy p ->
+ | E.Copy p ->
(* Access the value *)
let access = Read in
(* Expand the symbolic values, if necessary *)
let expand_prim_copy = true in
access_rplace_reorganize config expand_prim_copy access p cf ctx
- | Expressions.Move p ->
+ | E.Move p ->
(* Access the value *)
let access = Move in
let expand_prim_copy = false in
@@ -260,9 +267,70 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand)
^ "\n- ctx:\n" ^ eval_ctx_to_string ctx ^ "\n"));
(* Evaluate *)
match op with
- | Expressions.Constant (ty, cv) ->
- cf (literal_to_typed_value (TypesUtils.ty_as_literal ty) cv) ctx
- | Expressions.Copy p ->
+ | E.Constant cv -> (
+ match cv.value with
+ | E.CLiteral lit ->
+ cf (literal_to_typed_value (TypesUtils.ty_as_literal cv.ty) lit) ctx
+ | E.TraitConst (trait_ref, generics, const_name) -> (
+ assert (generics = TypesUtils.mk_empty_generic_args);
+ match trait_ref.trait_id with
+ | T.TraitImpl _ ->
+ (* This shouldn't happen: if we refer to a concrete implementation, we
+ should directly refer to the top-level constant *)
+ raise (Failure "Unreachable")
+ | _ -> (
+ (* We refer to a constant defined in a local clause: simply
+ introduce a fresh symbolic value *)
+ let ctx0 = ctx in
+ (* Lookup the trait declaration to retrieve the type of the symbolic value *)
+ let trait_decl =
+ C.ctx_lookup_trait_decl ctx
+ trait_ref.trait_decl_ref.trait_decl_id
+ in
+ let _, (ty, _) =
+ List.find (fun (name, _) -> name = const_name) trait_decl.consts
+ in
+ (* Introduce a fresh symbolic value *)
+ let v = mk_fresh_symbolic_typed_value_from_ety V.TraitConst ty in
+ (* Continue the evaluation *)
+ let e = cf v ctx in
+ (* We have to wrap the generated expression *)
+ match e with
+ | None -> None
+ | Some e ->
+ Some
+ (SymbolicAst.IntroSymbolic
+ ( ctx0,
+ None,
+ value_as_symbolic v.value,
+ SymbolicAst.TraitConstValue
+ (trait_ref, generics, const_name),
+ e ))))
+ | E.CVar vid -> (
+ let ctx0 = ctx in
+ (* Lookup the const generic value *)
+ let cv = C.ctx_lookup_const_generic_value ctx vid in
+ (* Copy the value *)
+ let allow_adt_copy = false in
+ let ctx, v = copy_value allow_adt_copy config ctx cv in
+ (* Continue *)
+ let e = cf v ctx in
+ (* We have to wrap the generated expression *)
+ match e with
+ | None -> None
+ | Some e ->
+ (* If we are synthesizing a symbolic AST, it means that we are in symbolic
+ mode: the value of the const generic is necessarily symbolic. *)
+ assert (is_symbolic cv.V.value);
+ (* *)
+ Some
+ (SymbolicAst.IntroSymbolic
+ ( ctx0,
+ None,
+ value_as_symbolic v.value,
+ SymbolicAst.ConstGenericValue vid,
+ e ))))
+ | E.Copy p ->
(* Access the value *)
let access = Read in
let cc = read_place access p in
@@ -283,7 +351,7 @@ let eval_operand_no_reorganize (config : C.config) (op : E.operand)
in
(* Compose and apply *)
comp cc copy cf ctx
- | Expressions.Move p ->
+ | E.Move p ->
(* Access the value *)
let access = Move in
let cc = read_place access p in
@@ -656,7 +724,8 @@ let eval_rvalue_aggregate (config : C.config)
| E.AggregatedTuple ->
let tys = List.map (fun (v : V.typed_value) -> v.V.ty) values in
let v = V.Adt { variant_id = None; field_values = values } in
- let ty = T.Adt (T.Tuple, [], tys, []) in
+ let generics = TypesUtils.mk_generic_args [] tys [] [] in
+ let ty = T.Adt (T.Tuple, generics) in
let aggregated : V.typed_value = { V.value = v; ty } in
(* Call the continuation *)
cf aggregated ctx
@@ -667,20 +736,22 @@ let eval_rvalue_aggregate (config : C.config)
assert (List.length values = 1)
else raise (Failure "Unreachable");
(* Construt the value *)
- let aty = T.Adt (T.Assumed T.Option, [], [ ty ], []) in
+ let generics = TypesUtils.mk_generic_args [] [ ty ] [] [] in
+ let aty = T.Adt (T.Assumed T.Option, generics) in
let av : V.adt_value =
{ V.variant_id = Some variant_id; V.field_values = values }
in
let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
(* Call the continuation *)
cf aggregated ctx
- | E.AggregatedAdt (def_id, opt_variant_id, regions, types, cgs) ->
+ | E.AggregatedAdt (def_id, opt_variant_id, generics) ->
(* Sanity checks *)
let type_decl = C.ctx_lookup_type_decl ctx def_id in
- assert (List.length type_decl.region_params = List.length regions);
+ assert (
+ List.length type_decl.generics.regions = List.length generics.regions);
let expected_field_types =
- Subst.ctx_adt_get_instantiated_field_etypes ctx def_id opt_variant_id
- types cgs
+ Assoc.ctx_adt_get_inst_norm_field_etypes ctx def_id opt_variant_id
+ generics
in
assert (
expected_field_types
@@ -689,7 +760,7 @@ let eval_rvalue_aggregate (config : C.config)
let av : V.adt_value =
{ V.variant_id = opt_variant_id; V.field_values = values }
in
- let aty = T.Adt (T.AdtId def_id, regions, types, cgs) in
+ let aty = T.Adt (T.AdtId def_id, generics) in
let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
(* Call the continuation *)
cf aggregated ctx
@@ -709,7 +780,8 @@ let eval_rvalue_aggregate (config : C.config)
let av : V.adt_value =
{ V.variant_id = None; V.field_values = values }
in
- let aty = T.Adt (T.Assumed T.Range, [], [ ety ], []) in
+ let generics = TypesUtils.mk_generic_args_from_types [ ety ] in
+ let aty = T.Adt (T.Assumed T.Range, generics) in
let aggregated : V.typed_value = { V.value = Adt av; ty = aty } in
(* Call the continuation *)
cf aggregated ctx
@@ -719,7 +791,8 @@ let eval_rvalue_aggregate (config : C.config)
(* Sanity check: the number of values is consistent with the length *)
let len = (literal_as_scalar (const_generic_as_literal cg)).value in
assert (len = Z.of_int (List.length values));
- let ty = T.Adt (T.Assumed T.Array, [], [ ety ], [ cg ]) in
+ let generics = TypesUtils.mk_generic_args [] [ ety ] [ cg ] [] in
+ let ty = T.Adt (T.Assumed T.Array, generics) in
(* In order to generate a better AST, we introduce a symbolic
value equal to the array. The reason is that otherwise, the
array we introduce here might be duplicated in the generated
@@ -752,7 +825,7 @@ let eval_rvalue_not_global (config : C.config) (rvalue : E.rvalue)
(* Delegate to the proper auxiliary function *)
match rvalue with
| E.Use op -> comp_wrap (eval_operand config op) ctx
- | E.Ref (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx
+ | E.RvRef (p, bkind) -> comp_wrap (eval_rvalue_ref config p bkind) ctx
| E.UnaryOp (unop, op) -> eval_unary_op config unop op cf ctx
| E.BinaryOp (binop, op1, op2) -> eval_binary_op config binop op1 op2 cf ctx
| E.Aggregate (aggregate_kind, ops) ->
diff --git a/compiler/InterpreterLoopsJoinCtxs.ml b/compiler/InterpreterLoopsJoinCtxs.ml
index bf88e055..6d3ecb18 100644
--- a/compiler/InterpreterLoopsJoinCtxs.ml
+++ b/compiler/InterpreterLoopsJoinCtxs.ml
@@ -554,9 +554,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context;
fun_context;
global_context;
+ trait_decls_context;
+ trait_impls_context;
region_groups;
type_vars;
const_generic_vars;
+ const_generic_vars_map;
+ norm_trait_etypes;
+ norm_trait_rtypes;
+ norm_trait_stypes;
env = _;
ended_regions = ended_regions0;
} =
@@ -566,9 +572,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context = _;
fun_context = _;
global_context = _;
+ trait_decls_context = _;
+ trait_impls_context = _;
region_groups = _;
type_vars = _;
const_generic_vars = _;
+ const_generic_vars_map = _;
+ norm_trait_etypes = _;
+ norm_trait_rtypes = _;
+ norm_trait_stypes = _;
env = _;
ended_regions = ended_regions1;
} =
@@ -580,9 +592,15 @@ let join_ctxs (loop_id : V.LoopId.id) (fixed_ids : ids_sets) (ctx0 : C.eval_ctx)
C.type_context;
fun_context;
global_context;
+ trait_decls_context;
+ trait_impls_context;
region_groups;
type_vars;
const_generic_vars;
+ const_generic_vars_map;
+ norm_trait_etypes;
+ norm_trait_rtypes;
+ norm_trait_stypes;
env;
ended_regions;
}
diff --git a/compiler/InterpreterLoopsMatchCtxs.ml b/compiler/InterpreterLoopsMatchCtxs.ml
index 9248e513..8cab546e 100644
--- a/compiler/InterpreterLoopsMatchCtxs.ml
+++ b/compiler/InterpreterLoopsMatchCtxs.ml
@@ -149,20 +149,25 @@ let rec match_types (match_distinct_types : 'r T.ty -> 'r T.ty -> 'r T.ty)
(match_regions : 'r -> 'r -> 'r) (ty0 : 'r T.ty) (ty1 : 'r T.ty) : 'r T.ty =
let match_rec = match_types match_distinct_types match_regions in
match (ty0, ty1) with
- | Adt (id0, regions0, tys0, cgs0), Adt (id1, regions1, tys1, cgs1) ->
+ | Adt (id0, generics0), Adt (id1, generics1) ->
assert (id0 = id1);
- assert (cgs0 = cgs1);
+ assert (generics0.const_generics = generics1.const_generics);
+ assert (generics0.trait_refs = generics1.trait_refs);
let id = id0 in
- let cgs = cgs1 in
+ let const_generics = generics1.const_generics in
+ let trait_refs = generics1.trait_refs in
let regions =
List.map
(fun (id0, id1) -> match_regions id0 id1)
- (List.combine regions0 regions1)
+ (List.combine generics0.regions generics1.regions)
in
- let tys =
- List.map (fun (ty0, ty1) -> match_rec ty0 ty1) (List.combine tys0 tys1)
+ let types =
+ List.map
+ (fun (ty0, ty1) -> match_rec ty0 ty1)
+ (List.combine generics0.types generics1.types)
in
- Adt (id, regions, tys, cgs)
+ let generics = { T.regions; types; const_generics; trait_refs } in
+ Adt (id, generics)
| TypeVar vid0, TypeVar vid1 ->
assert (vid0 = vid1);
let vid = vid0 in
diff --git a/compiler/InterpreterPaths.ml b/compiler/InterpreterPaths.ml
index 04dc8892..465d0028 100644
--- a/compiler/InterpreterPaths.ml
+++ b/compiler/InterpreterPaths.ml
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open Cps
open ValuesUtils
@@ -97,7 +98,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
match (pe, v.V.value, v.V.ty) with
| ( Field (((ProjAdt (_, _) | ProjOption _) as proj_kind), field_id),
V.Adt adt,
- T.Adt (type_id, _, _, _) ) -> (
+ T.Adt (type_id, _) ) -> (
(* Check consistency *)
(match (proj_kind, type_id) with
| ProjAdt (def_id, opt_variant_id), T.AdtId def_id' ->
@@ -119,8 +120,7 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
let updated = { v with value = nadt } in
Ok (ctx, { res with updated }))
(* Tuples *)
- | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _, _, _)
- -> (
+ | Field (ProjTuple arity, field_id), V.Adt adt, T.Adt (T.Tuple, _) -> (
assert (arity = List.length adt.field_values);
let fv = T.FieldId.nth adt.field_values field_id in
(* Project *)
@@ -145,9 +145,9 @@ let rec access_projection (access : projection_access) (ctx : C.eval_ctx)
(* Box dereferencement *)
| ( DerefBox,
Adt { variant_id = None; field_values = [ bv ] },
- T.Adt (T.Assumed T.Box, _, _, _) ) -> (
- (* We allow moving inside of boxes. In practice, this kind of
- * manipulations should happen only inside unsage code, so
+ T.Adt (T.Assumed T.Box, _) ) -> (
+ (* We allow moving outside of boxes. In practice, this kind of
+ * manipulations should happen only inside unsafe code, so
* it shouldn't happen due to user code, and we leverage it
* when implementing box dereferencement for the concrete
* interpreter *)
@@ -357,24 +357,23 @@ let write_place (access : access_kind) (p : E.place) (nv : V.typed_value)
| Error e -> raise (Failure ("Unreachable: " ^ show_path_fail_kind e))
| Ok ctx -> ctx
-let compute_expanded_bottom_adt_value (tyctx : T.type_decl T.TypeDeclId.Map.t)
+let compute_expanded_bottom_adt_value (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (regions : T.erased_region list) (types : T.ety list)
- (cgs : T.const_generic list) : V.typed_value =
+ (generics : T.egeneric_args) : V.typed_value =
(* Lookup the definition and check if it is an enumeration - it
should be an enumeration if and only if the projection element
is a field projection with *some* variant id. Retrieve the list
of fields at the same time. *)
- let def = T.TypeDeclId.Map.find def_id tyctx in
- assert (List.length regions = List.length def.T.region_params);
+ let def = C.ctx_lookup_type_decl ctx def_id in
+ assert (List.length generics.regions = List.length def.T.generics.regions);
(* Compute the field types *)
let field_types =
- Subst.type_decl_get_instantiated_field_etypes def opt_variant_id types cgs
+ Assoc.type_decl_get_inst_norm_field_etypes ctx def opt_variant_id generics
in
(* Initialize the expanded value *)
let fields = List.map mk_bottom field_types in
let av = V.Adt { variant_id = opt_variant_id; field_values = fields } in
- let ty = T.Adt (T.AdtId def_id, regions, types, cgs) in
+ let ty = T.Adt (T.AdtId def_id, generics) in
{ V.value = av; V.ty }
let compute_expanded_bottom_option_value (variant_id : T.VariantId.id)
@@ -387,7 +386,8 @@ let compute_expanded_bottom_option_value (variant_id : T.VariantId.id)
else raise (Failure "Unreachable")
in
let av = V.Adt { variant_id = Some variant_id; field_values } in
- let ty = T.Adt (T.Assumed T.Option, [], [ param_ty ], []) in
+ let generics = TypesUtils.mk_generic_args [] [ param_ty ] [] [] in
+ let ty = T.Adt (T.Assumed T.Option, generics) in
{ V.value = av; ty }
let compute_expanded_bottom_tuple_value (field_types : T.ety list) :
@@ -395,7 +395,8 @@ let compute_expanded_bottom_tuple_value (field_types : T.ety list) :
(* Generate the field values *)
let fields = List.map mk_bottom field_types in
let v = V.Adt { variant_id = None; field_values = fields } in
- let ty = T.Adt (T.Tuple, [], field_types, []) in
+ let generics = TypesUtils.mk_generic_args [] field_types [] [] in
+ let ty = T.Adt (T.Tuple, generics) in
{ V.value = v; V.ty }
(** Auxiliary helper to expand {!V.Bottom} values.
@@ -447,19 +448,29 @@ let expand_bottom_value_from_projection (access : access_kind) (p : E.place)
match (pe, ty) with
(* "Regular" ADTs *)
| ( Field (ProjAdt (def_id, opt_variant_id), _),
- T.Adt (T.AdtId def_id', regions, types, cgs) ) ->
+ T.Adt (T.AdtId def_id', generics) ) ->
assert (def_id = def_id');
- compute_expanded_bottom_adt_value ctx.type_context.type_decls def_id
- opt_variant_id regions types cgs
+ compute_expanded_bottom_adt_value ctx def_id opt_variant_id generics
(* Option *)
| ( Field (ProjOption variant_id, _),
- T.Adt (T.Assumed T.Option, [], [ ty ], []) ) ->
+ T.Adt
+ ( T.Assumed T.Option,
+ {
+ T.regions = [];
+ types = [ ty ];
+ const_generics = [];
+ trait_refs = [];
+ } ) ) ->
compute_expanded_bottom_option_value variant_id ty
(* Tuples *)
- | Field (ProjTuple arity, _), T.Adt (T.Tuple, [], tys, []) ->
- assert (arity = List.length tys);
+ | ( Field (ProjTuple arity, _),
+ T.Adt
+ ( T.Tuple,
+ { T.regions = []; types; const_generics = []; trait_refs = [] } ) )
+ ->
+ assert (arity = List.length types);
(* Generate the field values *)
- compute_expanded_bottom_tuple_value tys
+ compute_expanded_bottom_tuple_value types
| _ ->
raise
(Failure
diff --git a/compiler/InterpreterPaths.mli b/compiler/InterpreterPaths.mli
index 4a9f3b41..041b0a97 100644
--- a/compiler/InterpreterPaths.mli
+++ b/compiler/InterpreterPaths.mli
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open Cps
open InterpreterExpansion
@@ -56,12 +57,10 @@ val compute_expanded_bottom_tuple_value : T.ety list -> V.typed_value
(** Compute an expanded ADT ⊥ value *)
val compute_expanded_bottom_adt_value :
- T.type_decl T.TypeDeclId.Map.t ->
+ C.eval_ctx ->
T.TypeDeclId.id ->
T.VariantId.id option ->
- T.erased_region list ->
- T.ety list ->
- T.const_generic list ->
+ T.egeneric_args ->
V.typed_value
(** Compute an expanded [Option] ⊥ value *)
diff --git a/compiler/InterpreterProjectors.ml b/compiler/InterpreterProjectors.ml
index faed066b..9e0c2b75 100644
--- a/compiler/InterpreterProjectors.ml
+++ b/compiler/InterpreterProjectors.ml
@@ -3,6 +3,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module L = Logging
open TypesUtils
open InterpreterUtils
@@ -24,12 +25,12 @@ let rec apply_proj_borrows_on_shared_borrow (ctx : C.eval_ctx)
else
match (v.V.value, ty) with
| V.Literal _, T.Literal _ -> []
- | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
+ | V.Adt adt, T.Adt (id, generics) ->
(* Retrieve the types of the fields *)
let field_types =
- Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys cgs
+ Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics
in
+
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
let proj_fields =
@@ -103,11 +104,10 @@ let rec apply_proj_borrows (check_symbolic_no_ended : bool) (ctx : C.eval_ctx)
let value : V.avalue =
match (v.V.value, ty) with
| V.Literal _, T.Literal _ -> V.AIgnored
- | V.Adt adt, T.Adt (id, region_params, tys, cgs) ->
+ | V.Adt adt, T.Adt (id, generics) ->
(* Retrieve the types of the fields *)
let field_types =
- Subst.ctx_adt_value_get_instantiated_field_rtypes ctx adt id
- region_params tys cgs
+ Assoc.ctx_adt_value_get_inst_norm_field_rtypes ctx adt id generics
in
(* Project over the field values *)
let fields_types = List.combine adt.V.field_values field_types in
@@ -268,8 +268,7 @@ let apply_proj_loans_on_symbolic_expansion (regions : T.RegionId.Set.t)
let (value, ty) : V.avalue * T.rty =
match (see, original_sv_ty) with
| SeLiteral _, T.Literal _ -> (V.AIgnored, original_sv_ty)
- | SeAdt (variant_id, field_values), T.Adt (_id, _region_params, _tys, _cgs)
- ->
+ | SeAdt (variant_id, field_values), T.Adt (_id, _generics) ->
(* Project over the field values *)
let field_values =
List.map
diff --git a/compiler/InterpreterStatements.ml b/compiler/InterpreterStatements.ml
index 045c4484..36bc3492 100644
--- a/compiler/InterpreterStatements.ml
+++ b/compiler/InterpreterStatements.ml
@@ -17,6 +17,7 @@ open InterpreterProjectors
open InterpreterExpansion
open InterpreterPaths
open InterpreterExpressions
+module PCtx = Print.EvalCtxLlbcAst
(** The local logger *)
let log = L.statements_log
@@ -232,9 +233,8 @@ let set_discriminant (config : C.config) (p : E.place)
let update_value cf (v : V.typed_value) : m_fun =
fun ctx ->
match (v.V.ty, v.V.value) with
- | ( T.Adt
- (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
- V.Adt av ) -> (
+ | T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), generics), V.Adt av
+ -> (
(* There are two situations:
- either the discriminant is already the proper one (in which case we
don't do anything)
@@ -251,28 +251,26 @@ let set_discriminant (config : C.config) (p : E.place)
let bottom_v =
match type_id with
| T.AdtId def_id ->
- compute_expanded_bottom_adt_value
- ctx.type_context.type_decls def_id (Some variant_id)
- regions types cgs
+ compute_expanded_bottom_adt_value ctx def_id
+ (Some variant_id) generics
| T.Assumed T.Option ->
- assert (regions = []);
+ assert (generics.regions = []);
compute_expanded_bottom_option_value variant_id
- (Collections.List.to_cons_nil types)
+ (Collections.List.to_cons_nil generics.types)
| _ -> raise (Failure "Unreachable")
in
assign_to_place config bottom_v p (cf Unit) ctx)
- | ( T.Adt
- (((T.AdtId _ | T.Assumed T.Option) as type_id), regions, types, cgs),
- V.Bottom ) ->
+ | T.Adt (((T.AdtId _ | T.Assumed T.Option) as type_id), generics), V.Bottom
+ ->
let bottom_v =
match type_id with
| T.AdtId def_id ->
- compute_expanded_bottom_adt_value ctx.type_context.type_decls
- def_id (Some variant_id) regions types cgs
+ compute_expanded_bottom_adt_value ctx def_id (Some variant_id)
+ generics
| T.Assumed T.Option ->
- assert (regions = []);
+ assert (generics.regions = []);
compute_expanded_bottom_option_value variant_id
- (Collections.List.to_cons_nil types)
+ (Collections.List.to_cons_nil generics.types)
| _ -> raise (Failure "Unreachable")
in
assign_to_place config bottom_v p (cf Unit) ctx
@@ -301,24 +299,34 @@ let ctx_push_frame (ctx : C.eval_ctx) : C.eval_ctx =
let push_frame : cm_fun = fun cf ctx -> cf (ctx_push_frame ctx)
(** Small helper: compute the type of the return value for a specific
- instantiation of a non-local function.
+ instantiation of an assumed function.
*)
-let get_non_local_function_return_type (fid : A.assumed_fun_id)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) : T.ety =
+let get_assumed_function_return_type (ctx : C.eval_ctx) (fid : A.assumed_fun_id)
+ (generics : T.egeneric_args) : T.ety =
+ assert (generics.trait_refs = []);
(* [Box::free] has a special treatment *)
- match (fid, region_params, type_params, const_generic_params) with
- | A.BoxFree, [], [ _ ], [] -> mk_unit_ty
+ match fid with
+ | A.BoxFree ->
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ mk_unit_ty
| _ ->
(* Retrieve the function's signature *)
let sg = Assumed.get_assumed_sig fid in
(* Instantiate the return type *)
- let tsubst = Subst.make_type_subst_from_vars sg.type_params type_params in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars sg.const_generic_params
- const_generic_params
+ (* There shouldn't be any reference to Self *)
+ let tr_self : T.erased_region T.trait_instance_id =
+ T.UnknownTrait __FUNCTION__
+ in
+ let { Subst.r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } =
+ Subst.make_esubst_from_generics sg.generics generics tr_self
in
- Subst.erase_regions_substitute_types tsubst cgsubst sg.output
+ let ty =
+ Subst.erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self
+ sg.output
+ in
+ Assoc.ctx_normalize_ety ctx ty
let move_return_value (config : C.config) (pop_return_value : bool)
(cf : V.typed_value option -> m_fun) : m_fun =
@@ -418,19 +426,19 @@ let pop_frame_assign (config : C.config) (dest : E.place) : cm_fun =
in
comp cf_pop cf_assign
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_replace_concrete (_config : C.config)
- (_region_params : T.erased_region list) (_type_params : T.ety list)
- (_cg_params : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_replace_concrete (_config : C.config) (_generics : T.egeneric_args) :
+ cm_fun =
fun _cf _ctx -> raise Unimplemented
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_new_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_box_new_concrete (config : C.config) (generics : T.egeneric_args) :
+ cm_fun =
fun cf ctx ->
(* Check and retrieve the arguments *)
- match (region_params, type_params, cg_params, ctx.env) with
+ match
+ (generics.regions, generics.types, generics.const_generics, ctx.env)
+ with
| ( [],
[ boxed_ty ],
[],
@@ -448,7 +456,8 @@ let eval_box_new_concrete (config : C.config)
(* Create the new box *)
let cf_create cf (moved_input_value : V.typed_value) : m_fun =
(* Create the box value *)
- let box_ty = T.Adt (T.Assumed T.Box, [], [ boxed_ty ], []) in
+ let generics = TypesUtils.mk_generic_args_from_types [ boxed_ty ] in
+ let box_ty = T.Adt (T.Assumed T.Box, generics) in
let box_v =
V.Adt { variant_id = None; field_values = [ moved_input_value ] }
in
@@ -467,13 +476,14 @@ let eval_box_new_concrete (config : C.config)
| _ -> raise (Failure "Inconsistent state")
(** Auxiliary function which factorizes code to evaluate [std::Deref::deref]
- and [std::DerefMut::deref_mut] - see {!eval_non_local_function_call} *)
+ and [std::DerefMut::deref_mut] - see {!eval_assumed_function_call} *)
let eval_box_deref_mut_or_shared_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) (is_mut : bool) : cm_fun =
+ (generics : T.egeneric_args) (is_mut : bool) : cm_fun =
fun cf ctx ->
(* Check the arguments *)
- match (region_params, type_params, cg_params, ctx.env) with
+ match
+ (generics.regions, generics.types, generics.const_generics, ctx.env)
+ with
| ( [],
[ boxed_ty ],
[],
@@ -495,7 +505,7 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config)
{ E.var_id = input_var.C.index; projection = [ E.Deref; E.DerefBox ] }
in
let borrow_kind = if is_mut then E.Mut else E.Shared in
- let rv = E.Ref (p, borrow_kind) in
+ let rv = E.RvRef (p, borrow_kind) in
let cf_borrow = eval_rvalue_not_global config rv in
(* Move the borrow to its destination *)
@@ -514,23 +524,19 @@ let eval_box_deref_mut_or_shared_concrete (config : C.config)
comp cf_borrow cf_move cf ctx
| _ -> raise (Failure "Inconsistent state")
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_deref_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_box_deref_concrete (config : C.config) (generics : T.egeneric_args) :
+ cm_fun =
let is_mut = false in
- eval_box_deref_mut_or_shared_concrete config region_params type_params
- cg_params is_mut
+ eval_box_deref_mut_or_shared_concrete config generics is_mut
-(** Auxiliary function - see {!eval_non_local_function_call} *)
-let eval_box_deref_mut_concrete (config : C.config)
- (region_params : T.erased_region list) (type_params : T.ety list)
- (cg_params : T.const_generic list) : cm_fun =
+(** Auxiliary function - see {!eval_assumed_function_call} *)
+let eval_box_deref_mut_concrete (config : C.config) (generics : T.egeneric_args)
+ : cm_fun =
let is_mut = true in
- eval_box_deref_mut_or_shared_concrete config region_params type_params
- cg_params is_mut
+ eval_box_deref_mut_or_shared_concrete config generics is_mut
-(** Auxiliary function - see {!eval_non_local_function_call}.
+(** Auxiliary function - see {!eval_assumed_function_call}.
[Box::free] is not handled the same way as the other assumed functions:
- in the regular case, whenever we need to evaluate an assumed function,
@@ -549,11 +555,10 @@ let eval_box_deref_mut_concrete (config : C.config)
It thus updates the box value (by calling {!drop_value}) and updates
the destination (by setting it to [()]).
*)
-let eval_box_free (config : C.config) (region_params : T.erased_region list)
- (type_params : T.ety list) (cg_params : T.const_generic list)
+let eval_box_free (config : C.config) (generics : T.egeneric_args)
(args : E.operand list) (dest : E.place) : cm_fun =
fun cf ctx ->
- match (region_params, type_params, cg_params, args) with
+ match (generics.regions, generics.types, generics.const_generics, args) with
| [], [ boxed_ty ], [], [ E.Move input_box_place ] ->
(* Required type checking *)
let input_box = InterpreterPaths.read_place Write input_box_place ctx in
@@ -570,17 +575,20 @@ let eval_box_free (config : C.config) (region_params : T.erased_region list)
cc cf ctx
| _ -> raise (Failure "Inconsistent state")
-(** Auxiliary function - see {!eval_non_local_function_call} *)
+(** Auxiliary function - see {!eval_assumed_function_call} *)
let eval_vec_function_concrete (_config : C.config) (_fid : A.assumed_fun_id)
- (_region_params : T.erased_region list) (_type_params : T.ety list)
- (_cg_params : T.const_generic list) : cm_fun =
+ (_generics : T.egeneric_args) : cm_fun =
fun _cf _ctx -> raise Unimplemented
(** Evaluate a non-local function call in concrete mode *)
-let eval_non_local_function_call_concrete (config : C.config)
- (fid : A.assumed_fun_id) (region_params : T.erased_region list)
- (type_params : T.ety list) (cg_params : T.const_generic list)
- (args : E.operand list) (dest : E.place) : cm_fun =
+let eval_assumed_function_call_concrete (config : C.config)
+ (fid : A.assumed_fun_id) (call : A.call) : cm_fun =
+ let generics = call.generics in
+ let args = call.args in
+ let dest = call.dest in
+ (* Sanity check: we don't fully handle the const generic vars environment
+ in concrete mode yet *)
+ assert (generics.const_generics = []);
(* There are two cases (and this is extremely annoying):
- the function is not box_free
- the function is box_free
@@ -589,7 +597,7 @@ let eval_non_local_function_call_concrete (config : C.config)
match fid with
| A.BoxFree ->
(* Degenerate case: box_free *)
- eval_box_free config region_params type_params cg_params args dest
+ eval_box_free config generics args dest
| _ ->
(* "Normal" case: not box_free *)
(* Evaluate the operands *)
@@ -604,16 +612,14 @@ let eval_non_local_function_call_concrete (config : C.config)
* but it made it less clear where the computed values came from,
* so we reversed the modifications. *)
let cf_eval_call cf (args_vl : V.typed_value list) : m_fun =
+ fun ctx ->
(* Push the stack frame: we initialize the frame with the return variable,
and one variable per input argument *)
let cc = push_frame in
(* Create and push the return variable *)
let ret_vid = E.VarId.zero in
- let ret_ty =
- get_non_local_function_return_type fid region_params type_params
- cg_params
- in
+ let ret_ty = get_assumed_function_return_type ctx fid generics in
let ret_var = mk_var ret_vid (Some "@return") ret_ty in
let cc = comp cc (push_uninitialized_var ret_var) in
@@ -630,23 +636,17 @@ let eval_non_local_function_call_concrete (config : C.config)
* access to a body. *)
let cf_eval_body : cm_fun =
match fid with
- | A.Replace ->
- eval_replace_concrete config region_params type_params cg_params
- | BoxNew ->
- eval_box_new_concrete config region_params type_params cg_params
- | BoxDeref ->
- eval_box_deref_concrete config region_params type_params cg_params
- | BoxDerefMut ->
- eval_box_deref_mut_concrete config region_params type_params
- cg_params
+ | A.Replace -> eval_replace_concrete config generics
+ | BoxNew -> eval_box_new_concrete config generics
+ | BoxDeref -> eval_box_deref_concrete config generics
+ | BoxDerefMut -> eval_box_deref_mut_concrete config generics
| BoxFree ->
(* Should have been treated above *) raise (Failure "Unreachable")
| VecNew | VecPush | VecInsert | VecLen | VecIndex | VecIndexMut ->
- eval_vec_function_concrete config fid region_params type_params
- cg_params
+ eval_vec_function_concrete config fid generics
| ArrayIndexShared | ArrayIndexMut | ArrayToSliceShared
| ArrayToSliceMut | ArraySubsliceShared | ArraySubsliceMut
- | SliceIndexShared | SliceIndexMut | SliceSubsliceShared
+ | ArrayRepeat | SliceIndexShared | SliceIndexMut | SliceSubsliceShared
| SliceSubsliceMut | SliceLen ->
raise (Failure "Unimplemented")
in
@@ -657,50 +657,11 @@ let eval_non_local_function_call_concrete (config : C.config)
let cc = comp cc (pop_frame_assign config dest) in
(* Continue *)
- cc cf
+ cc cf ctx
in
(* Compose and apply *)
comp cf_eval_ops cf_eval_call
-let instantiate_fun_sig (type_params : T.ety list)
- (cg_params : T.const_generic list) (sg : A.fun_sig) : A.inst_fun_sig =
- (* Generate fresh abstraction ids and create a substitution from region
- * group ids to abstraction ids *)
- let rg_abs_ids_bindings =
- List.map
- (fun rg ->
- let abs_id = C.fresh_abstraction_id () in
- (rg.T.id, abs_id))
- sg.regions_hierarchy
- in
- let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t =
- List.fold_left
- (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp)
- T.RegionGroupId.Map.empty rg_abs_ids_bindings
- in
- let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id =
- T.RegionGroupId.Map.find rg_id asubst_map
- in
- (* Generate fresh regions and their substitutions *)
- let _, rsubst, _ = Subst.fresh_regions_with_substs sg.region_params in
- (* Generate the type substitution
- * Note that we need the substitution to map the type variables to
- * {!rty} types (not {!ety}). In order to do that, we convert the
- * type parameters to types with regions. This is possible only
- * if those types don't contain any regions.
- * This is a current limitation of the analysis: there is still some
- * work to do to properly handle full type parametrization.
- * *)
- let rtype_params = List.map ety_no_regions_to_rty type_params in
- let tsubst = Subst.make_type_subst_from_vars sg.type_params rtype_params in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars sg.const_generic_params cg_params
- in
- (* Substitute the signature *)
- let inst_sig = Subst.substitute_signature asubst rsubst tsubst cgsubst sg in
- (* Return *)
- inst_sig
-
(** Helper
Create abstractions (with no avalues, which have to be inserted afterwards)
@@ -836,7 +797,7 @@ let rec eval_statement (config : C.config) (st : A.statement) : st_cm_fun =
match rvalue with
| E.Global _ -> raise (Failure "Unreachable")
| E.Use _
- | E.Ref (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow))
+ | E.RvRef (_, (E.Shared | E.Mut | E.TwoPhaseMut | E.Shallow))
| E.UnaryOp _ | E.BinaryOp _ | E.Discriminant _
| E.Aggregate _ ->
let rp = rvalue_get_place rvalue in
@@ -893,7 +854,16 @@ and eval_global (config : C.config) (dest : E.place) (gid : LA.GlobalDeclId.id)
match config.mode with
| ConcreteMode ->
(* Treat the evaluation of the global as a call to the global body (without arguments) *)
- (eval_local_function_call_concrete config global.body_id [] [] [] [] dest)
+ let call =
+ {
+ A.func = A.FunId (A.Regular global.body_id);
+ generics = TypesUtils.mk_empty_generic_args;
+ trait_and_method_generic_args = None;
+ args = [];
+ dest;
+ }
+ in
+ (eval_transparent_function_call_concrete config global.body_id call)
cf ctx
| SymbolicMode ->
(* Generate a fresh symbolic value. In the translation, this fresh symbolic value will be
@@ -1037,126 +1007,354 @@ and eval_switch (config : C.config) (switch : A.switch) : st_cm_fun =
(** Evaluate a function call (auxiliary helper for [eval_statement]) *)
and eval_function_call (config : C.config) (call : A.call) : st_cm_fun =
- (* There are two cases:
+ (* There are several cases:
- this is a local function, in which case we execute its body
- - this is a non-local function, in which case there is a special treatment
+ - this is an assumed function, in which case there is a special treatment
+ - this is a trait method
*)
- match call.func with
- | A.Regular fid ->
- eval_local_function_call config fid call.region_args call.type_args
- call.const_generic_args call.args call.dest
- | A.Assumed fid ->
- eval_non_local_function_call config fid call.region_args call.type_args
- call.const_generic_args call.args call.dest
+ match config.mode with
+ | C.ConcreteMode -> eval_function_call_concrete config call
+ | C.SymbolicMode -> eval_function_call_symbolic config call
-(** Evaluate a local (i.e., non-assumed) function call in concrete mode *)
-and eval_local_function_call_concrete (config : C.config) (fid : A.FunDeclId.id)
- (_region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+and eval_function_call_concrete (config : C.config) (call : A.call) : st_cm_fun
+ =
fun cf ctx ->
- (* Retrieve the (correctly instantiated) body *)
- let def = C.ctx_lookup_fun_decl ctx fid in
- (* We can evaluate the function call only if it is not opaque *)
- let body =
- match def.body with
- | None ->
- raise
- (Failure
- ("Can't evaluate a call to an opaque function: "
- ^ Print.name_to_string def.name))
- | Some body -> body
- in
- let tsubst =
- Subst.make_type_subst_from_vars def.A.signature.type_params type_args
- in
- let cgsubst =
- Subst.make_const_generic_subst_from_vars
- def.A.signature.const_generic_params cg_args
- in
- let locals, body_st = Subst.fun_body_substitute_in_body tsubst cgsubst body in
-
- (* Evaluate the input operands *)
- assert (List.length args = body.A.arg_count);
- let cc = eval_operands config args in
-
- (* Push a frame delimiter - we use {!comp_transmit} to transmit the result
- * of the operands evaluation from above to the functions afterwards, while
- * ignoring it in this function *)
- let cc = comp_transmit cc push_frame in
-
- (* Compute the initial values for the local variables *)
- (* 1. Push the return value *)
- let ret_var, locals =
- match locals with
- | ret_ty :: locals -> (ret_ty, locals)
- | _ -> raise (Failure "Unreachable")
- in
- let input_locals, locals =
- Collections.List.split_at locals body.A.arg_count
- in
-
- let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in
-
- (* 2. Push the input values *)
- let cf_push_inputs cf args =
- let inputs = List.combine input_locals args in
- (* Note that this function checks that the variables and their values
- * have the same type (this is important) *)
- push_vars inputs cf
- in
- let cc = comp cc cf_push_inputs in
+ match call.func with
+ | A.FunId (A.Regular fid) ->
+ eval_transparent_function_call_concrete config fid call cf ctx
+ | A.FunId (A.Assumed fid) ->
+ (* Continue - note that we do as if the function call has been successful,
+ * by giving {!Unit} to the continuation, because we place us in the case
+ * where we haven't panicked. Of course, the translation needs to take the
+ * panic case into account... *)
+ eval_assumed_function_call_concrete config fid call (cf Unit) ctx
+ | A.TraitMethod _ -> raise (Failure "Unimplemented")
+
+and eval_function_call_symbolic (config : C.config) (call : A.call) : st_cm_fun
+ =
+ match call.func with
+ | A.FunId (A.Regular _) | A.TraitMethod _ ->
+ eval_transparent_function_call_symbolic config call
+ | A.FunId (A.Assumed fid) ->
+ eval_assumed_function_call_symbolic config fid call
- (* 3. Push the remaining local variables (initialized as {!Bottom}) *)
- let cc = comp cc (push_uninitialized_vars locals) in
+(** Evaluate a local (i.e., non-assumed) function call in concrete mode *)
+and eval_transparent_function_call_concrete (config : C.config)
+ (fid : A.FunDeclId.id) (call : A.call) : st_cm_fun =
+ let generics = call.A.generics in
+ let args = call.A.args in
+ let dest = call.A.dest in
+ (* Sanity check: we don't fully handle the const generic vars environment
+ in concrete mode yet *)
+ assert (generics.const_generics = []);
+ fun cf ctx ->
+ (* Retrieve the (correctly instantiated) body *)
+ let def = C.ctx_lookup_fun_decl ctx fid in
+ (* We can evaluate the function call only if it is not opaque *)
+ let body =
+ match def.body with
+ | None ->
+ raise
+ (Failure
+ ("Can't evaluate a call to an opaque function: "
+ ^ Print.name_to_string def.name))
+ | Some body -> body
+ in
+ (* TODO: we need to normalize the types if we want to correctly support traits *)
+ assert (generics.trait_refs = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst =
+ Subst.make_esubst_from_generics def.A.signature.generics generics tr_self
+ in
+ let locals, body_st = Subst.fun_body_substitute_in_body subst body in
+
+ (* Evaluate the input operands *)
+ assert (List.length args = body.A.arg_count);
+ let cc = eval_operands config args in
+
+ (* Push a frame delimiter - we use {!comp_transmit} to transmit the result
+ * of the operands evaluation from above to the functions afterwards, while
+ * ignoring it in this function *)
+ let cc = comp_transmit cc push_frame in
+
+ (* Compute the initial values for the local variables *)
+ (* 1. Push the return value *)
+ let ret_var, locals =
+ match locals with
+ | ret_ty :: locals -> (ret_ty, locals)
+ | _ -> raise (Failure "Unreachable")
+ in
+ let input_locals, locals =
+ Collections.List.split_at locals body.A.arg_count
+ in
- (* Execute the function body *)
- let cc = comp cc (eval_function_body config body_st) in
+ let cc = comp_transmit cc (push_var ret_var (mk_bottom ret_var.var_ty)) in
- (* Pop the stack frame and move the return value to its destination *)
- let cf_finish cf res =
- match res with
- | Panic -> cf Panic
- | Return ->
- (* Pop the stack frame, retrieve the return value, move it to
- * its destination and continue *)
- pop_frame_assign config dest (cf Unit)
- | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _
- | EndContinue _ ->
- raise (Failure "Unreachable")
- in
- let cc = comp cc cf_finish in
+ (* 2. Push the input values *)
+ let cf_push_inputs cf args =
+ let inputs = List.combine input_locals args in
+ (* Note that this function checks that the variables and their values
+ * have the same type (this is important) *)
+ push_vars inputs cf
+ in
+ let cc = comp cc cf_push_inputs in
+
+ (* 3. Push the remaining local variables (initialized as {!Bottom}) *)
+ let cc = comp cc (push_uninitialized_vars locals) in
+
+ (* Execute the function body *)
+ let cc = comp cc (eval_function_body config body_st) in
+
+ (* Pop the stack frame and move the return value to its destination *)
+ let cf_finish cf res =
+ match res with
+ | Panic -> cf Panic
+ | Return ->
+ (* Pop the stack frame, retrieve the return value, move it to
+ * its destination and continue *)
+ pop_frame_assign config dest (cf Unit)
+ | Break _ | Continue _ | Unit | LoopReturn _ | EndEnterLoop _
+ | EndContinue _ ->
+ raise (Failure "Unreachable")
+ in
+ let cc = comp cc cf_finish in
- (* Continue *)
- cc cf ctx
+ (* Continue *)
+ cc cf ctx
(** Evaluate a local (i.e., non-assumed) function call in symbolic mode *)
-and eval_local_function_call_symbolic (config : C.config) (fid : A.FunDeclId.id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
+and eval_transparent_function_call_symbolic (config : C.config) (call : A.call)
+ : st_cm_fun =
fun cf ctx ->
- (* Retrieve the (correctly instantiated) signature *)
- let def = C.ctx_lookup_fun_decl ctx fid in
- let sg = def.A.signature in
- (* Instantiate the signature and introduce fresh abstraction and region ids
- * while doing so *)
- let inst_sg = instantiate_fun_sig type_args cg_args sg in
+ (* Instantiate the signature and introduce fresh abstractions and region ids while doing so.
+
+ We perform some manipulations when instantiating the signature.
+
+ # Trait impl calls
+ ==================
+ In particular, we have a special treatment of trait method calls when
+ the trait ref is a known impl.
+
+ For instance:
+ {[
+ trait HasValue {
+ fn has_value(&self) -> bool;
+ }
+
+ impl<T> HasValue for Option<T> {
+ fn has_value(&self) {
+ match self {
+ None => false,
+ Some(_) => true,
+ }
+ }
+ }
+
+ fn option_has_value<T>(x: &Option<T>) -> bool {
+ x.has_value()
+ }
+ ]}
+
+ The generated code looks like this:
+ {[
+ structure HasValue (Self : Type) = {
+ has_value : Self -> result bool
+ }
+
+ let OptionHasValueImpl.has_value (Self : Type) (self : Self) : result bool =
+ match self with
+ | None => false
+ | Some _ => true
+
+ let OptionHasValueInstance (T : Type) : HasValue (Option T) = {
+ has_value = OptionHasValueInstance.has_value
+ }
+ ]}
+
+ In [option_has_value], we don't want to refer to the [has_value] method
+ of the instance of [HasValue] for [Option<T>]. We want to refer directly
+ to the function which implements [has_value] for [Option<T>].
+ That is, instead of generating this:
+ {[
+ let option_has_value (T : Type) (x : Option T) : result bool =
+ (OptionHasValueInstance T).has_value x
+ ]}
+
+ We want to generate this:
+ {[
+ let option_has_value (T : Type) (x : Option T) : result bool =
+ OptionHasValueImpl.has_value T x
+ ]}
+
+ # Provided trait methods
+ ========================
+ Calls to provided trait methods also have a special treatment because
+ for now we forbid overriding provided trait methods in the trait implementations,
+ which means that whenever we call a provided trait method, we do not refer
+ to a trait clause but directly to the method provided in the trait declaration.
+ *)
+ let func, generics, def, inst_sg =
+ match call.func with
+ | A.FunId (A.Regular fid) ->
+ let def = C.ctx_lookup_fun_decl ctx fid in
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let inst_sg =
+ instantiate_fun_sig ctx call.generics tr_self def.A.signature
+ in
+ (call.func, call.generics, def, inst_sg)
+ | A.FunId (A.Assumed _) ->
+ (* Unreachable: must be a transparent function *)
+ raise (Failure "Unreachable")
+ | A.TraitMethod (trait_ref, method_name, _) -> (
+ log#ldebug
+ (lazy
+ ("trait method call:\n- call: " ^ call_to_string ctx call
+ ^ "\n- method name: " ^ method_name ^ "\n- call.generics:\n"
+ ^ egeneric_args_to_string ctx call.generics
+ ^ "\n- trait and method generics:\n"
+ ^ egeneric_args_to_string ctx
+ (Option.get call.trait_and_method_generic_args)));
+ (* When instantiating, we need to group the generics for the trait ref
+ and the method *)
+ let generics = Option.get call.trait_and_method_generic_args in
+ (* Lookup the trait method signature - there are several possibilities
+ depending on whethere we call a top-level trait method impl or the
+ method from a local clause *)
+ match trait_ref.trait_id with
+ | TraitImpl impl_id -> (
+ (* Lookup the trait impl *)
+ let trait_impl = C.ctx_lookup_trait_impl ctx impl_id in
+ log#ldebug
+ (lazy ("trait impl: " ^ trait_impl_to_string ctx trait_impl));
+ (* First look in the required methods *)
+ let method_id =
+ List.find_opt
+ (fun (s, _) -> s = method_name)
+ trait_impl.required_methods
+ in
+ match method_id with
+ | Some (_, id) ->
+ (* This is a required method *)
+ let method_def = C.ctx_lookup_fun_decl ctx id in
+ (* Instantiate *)
+ let tr_self =
+ T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref)
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self
+ method_def.A.signature
+ in
+ (* Also update the function identifier: we want to forget
+ the fact that we called a trait method, and treat it as
+ a regular function call to the top-level function
+ which implements the method. In order to do this properly,
+ we also need to update the generics.
+ *)
+ let func = A.FunId (A.Regular id) in
+ (func, generics, method_def, inst_sg)
+ | None ->
+ (* If not found, lookup the methods provided by the trait *declaration*
+ (remember: for now, we forbid overriding provided methods) *)
+ assert (trait_impl.provided_methods = []);
+ let trait_decl =
+ C.ctx_lookup_trait_decl ctx
+ trait_ref.trait_decl_ref.trait_decl_id
+ in
+ let _, method_id =
+ List.find
+ (fun (s, _) -> s = method_name)
+ trait_decl.provided_methods
+ in
+ let method_id = Option.get method_id in
+ let method_def = C.ctx_lookup_fun_decl ctx method_id in
+ (* For the instantiation we have to do something peculiar
+ because the method was defined for the trait declaration.
+ We have to group:
+ - the parameters given to the trait decl reference
+ - the parameters given to the method itself
+ For instance:
+ {[
+ trait Foo<T> {
+ fn f<U>(...) { ... }
+ }
+
+ fn g<G>(x : G) where Clause0: Foo<G, bool>
+ {
+ x.f::<u32>(...) // The arguments to f are: <G, bool, u32>
+ }
+ ]}
+ *)
+ let all_generics =
+ TypesUtils.merge_generic_args
+ trait_ref.trait_decl_ref.decl_generics call.generics
+ in
+ log#ldebug
+ (lazy
+ ("provided method call:" ^ "\n- method name: " ^ method_name
+ ^ "\n- all_generics:\n"
+ ^ egeneric_args_to_string ctx all_generics
+ ^ "\n- parent params info: "
+ ^ Print.option_to_string A.show_params_info
+ method_def.signature.parent_params_info));
+ let tr_self =
+ T.TraitRef (etrait_ref_no_regions_to_gr_trait_ref trait_ref)
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx all_generics tr_self
+ method_def.A.signature
+ in
+ (call.func, call.generics, method_def, inst_sg))
+ | _ ->
+ (* We are using a local clause - we lookup the trait decl *)
+ let trait_decl =
+ C.ctx_lookup_trait_decl ctx trait_ref.trait_decl_ref.trait_decl_id
+ in
+ (* Lookup the method decl in the required *and* the provided methods *)
+ let _, method_id =
+ let provided =
+ List.filter_map
+ (fun (id, f) ->
+ match f with None -> None | Some f -> Some (id, f))
+ trait_decl.provided_methods
+ in
+ List.find
+ (fun (s, _) -> s = method_name)
+ (List.append trait_decl.required_methods provided)
+ in
+ let method_def = C.ctx_lookup_fun_decl ctx method_id in
+ log#ldebug (lazy ("method:\n" ^ fun_decl_to_string ctx method_def));
+ (* Instantiate *)
+ let tr_self = T.TraitRef trait_ref in
+ let tr_self =
+ TypesUtils.etrait_instance_id_no_regions_to_gr_trait_instance_id
+ tr_self
+ in
+ let inst_sg =
+ instantiate_fun_sig ctx generics tr_self method_def.A.signature
+ in
+ (call.func, call.generics, method_def, inst_sg))
+ in
(* Sanity check *)
- assert (List.length args = List.length def.A.signature.inputs);
+ assert (List.length call.args = List.length def.A.signature.inputs);
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config (A.Regular fid) inst_sg
- region_args type_args cg_args args dest cf ctx
+ eval_function_call_symbolic_from_inst_sig config func inst_sg generics
+ call.args call.dest cf ctx
(** Evaluate a function call in symbolic mode by using the function signature.
This allows us to factorize the evaluation of local and non-local function
calls in symbolic mode: only their signatures matter.
+
+ The [self_trait_ref] trait ref refers to [Self]. We use it when calling
+ a provided trait method, because those methods have a special treatment:
+ we dot not group them with the required trait methods, and forbid (for now)
+ overriding them. We treat them as regular method, which take an additional
+ trait ref as input.
*)
and eval_function_call_symbolic_from_inst_sig (config : C.config)
- (fid : A.fun_id) (inst_sg : A.inst_fun_sig)
- (_region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
+ (fid : A.fun_id_or_trait_method_ref) (inst_sg : A.inst_fun_sig)
+ (generics : T.egeneric_args) (args : E.operand list) (dest : E.place) :
st_cm_fun =
fun cf ctx ->
(* Generate a fresh symbolic value for the return value *)
@@ -1224,8 +1422,8 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config)
let expr = cf ctx in
(* Synthesize the symbolic AST *)
- S.synthesize_regular_function_call fid call_id ctx abs_ids type_args cg_args
- args args_places ret_spc dest_place expr
+ S.synthesize_regular_function_call fid call_id ctx abs_ids generics args
+ args_places ret_spc dest_place expr
in
let cc = comp cc cf_call in
@@ -1286,17 +1484,18 @@ and eval_function_call_symbolic_from_inst_sig (config : C.config)
cc (cf Unit) ctx
(** Evaluate a non-local function call in symbolic mode *)
-and eval_non_local_function_call_symbolic (config : C.config)
- (fid : A.assumed_fun_id) (region_args : T.erased_region list)
- (type_args : T.ety list) (cg_args : T.const_generic list)
- (args : E.operand list) (dest : E.place) : st_cm_fun =
+and eval_assumed_function_call_symbolic (config : C.config)
+ (fid : A.assumed_fun_id) (call : A.call) : st_cm_fun =
fun cf ctx ->
+ let generics = call.generics in
+ let args = call.args in
+ let dest = call.dest in
(* Sanity check: make sure the type parameters don't contain regions -
* this is a current limitation of our synthesis *)
assert (
List.for_all
(fun ty -> not (ty_has_borrows ctx.type_context.type_infos ty))
- type_args);
+ generics.types);
(* There are two cases (and this is extremely annoying):
- the function is not box_free
@@ -1307,7 +1506,7 @@ and eval_non_local_function_call_symbolic (config : C.config)
| A.BoxFree ->
(* Degenerate case: box_free - note that this is not really a function
* call: no need to call a "synthesize_..." function *)
- eval_box_free config region_args type_args cg_args args dest (cf Unit) ctx
+ eval_box_free config generics args dest (cf Unit) ctx
| _ ->
(* "Normal" case: not box_free *)
(* In symbolic mode, the behaviour of a function call is completely defined
@@ -1319,55 +1518,15 @@ and eval_non_local_function_call_symbolic (config : C.config)
(* should have been treated above *)
raise (Failure "Unreachable")
| _ ->
- instantiate_fun_sig type_args cg_args (Assumed.get_assumed_sig fid)
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ instantiate_fun_sig ctx generics tr_self
+ (Assumed.get_assumed_sig fid)
in
(* Evaluate the function call *)
- eval_function_call_symbolic_from_inst_sig config (A.Assumed fid) inst_sig
- region_args type_args cg_args args dest cf ctx
-
-(** Evaluate a non-local (i.e, assumed) function call such as [Box::deref]
- (auxiliary helper for [eval_statement]) *)
-and eval_non_local_function_call (config : C.config) (fid : A.assumed_fun_id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
- fun cf ctx ->
- (* Debug *)
- log#ldebug
- (lazy
- (let type_args =
- "[" ^ String.concat ", " (List.map (ety_to_string ctx) type_args) ^ "]"
- in
- let args =
- "[" ^ String.concat ", " (List.map (operand_to_string ctx) args) ^ "]"
- in
- let dest = place_to_string ctx dest in
- "eval_non_local_function_call:\n- fid:" ^ A.show_assumed_fun_id fid
- ^ "\n- type_args: " ^ type_args ^ "\n- args: " ^ args ^ "\n- dest: "
- ^ dest));
-
- match config.mode with
- | C.ConcreteMode ->
- eval_non_local_function_call_concrete config fid region_args type_args
- cg_args args dest (cf Unit) ctx
- | C.SymbolicMode ->
- eval_non_local_function_call_symbolic config fid region_args type_args
- cg_args args dest cf ctx
-
-(** Evaluate a local (i.e, not assumed) function call (auxiliary helper for
- [eval_statement]) *)
-and eval_local_function_call (config : C.config) (fid : A.FunDeclId.id)
- (region_args : T.erased_region list) (type_args : T.ety list)
- (cg_args : T.const_generic list) (args : E.operand list) (dest : E.place) :
- st_cm_fun =
- match config.mode with
- | ConcreteMode ->
- eval_local_function_call_concrete config fid region_args type_args cg_args
- args dest
- | SymbolicMode ->
- eval_local_function_call_symbolic config fid region_args type_args cg_args
- args dest
+ eval_function_call_symbolic_from_inst_sig config (A.FunId (A.Assumed fid))
+ inst_sig generics args dest cf ctx
(** Evaluate a statement seen as a function body *)
and eval_function_body (config : C.config) (body : A.statement) : st_cm_fun =
diff --git a/compiler/InterpreterStatements.mli b/compiler/InterpreterStatements.mli
index 814bc964..e65758ae 100644
--- a/compiler/InterpreterStatements.mli
+++ b/compiler/InterpreterStatements.mli
@@ -25,15 +25,6 @@ open InterpreterExpressions
*)
val pop_frame : C.config -> bool -> (V.typed_value option -> m_fun) -> m_fun
-(** Instantiate a function signature, introducing **fresh** abstraction ids and
- region ids. This is mostly used in preparation of function calls, when
- evaluating in symbolic mode of course.
-
- Note: there are no region parameters, because they should be erased.
- *)
-val instantiate_fun_sig :
- T.ety list -> T.const_generic list -> LA.fun_sig -> LA.inst_fun_sig
-
(** Helper.
Create a list of abstractions from a list of regions groups, and insert
diff --git a/compiler/InterpreterUtils.ml b/compiler/InterpreterUtils.ml
index 7bd37550..7aaee6ff 100644
--- a/compiler/InterpreterUtils.ml
+++ b/compiler/InterpreterUtils.ml
@@ -10,6 +10,11 @@ open TypesUtils
module PA = Print.EvalCtxLlbcAst
open Cps
+(* TODO: we should probably rename the file to ContextsUtils *)
+
+(** The local logger *)
+let log = L.interpreter_log
+
(** Some utilities *)
(** Auxiliary function - call a function which requires a continuation,
@@ -38,6 +43,15 @@ let typed_value_to_string = PA.typed_value_to_string
let typed_avalue_to_string = PA.typed_avalue_to_string
let place_to_string = PA.place_to_string
let operand_to_string = PA.operand_to_string
+let egeneric_args_to_string = PA.egeneric_args_to_string
+let rtrait_instance_id_to_string = PA.rtrait_instance_id_to_string
+let fun_sig_to_string = PA.fun_sig_to_string
+let fun_decl_to_string = PA.fun_decl_to_string
+let call_to_string = PA.call_to_string
+
+let trait_impl_to_string ctx =
+ PA.trait_impl_to_string { ctx with type_vars = []; const_generic_vars = [] }
+
let statement_to_string ctx = PA.statement_to_string ctx "" " "
let statement_to_string_with_tab ctx = PA.statement_to_string ctx " " " "
let env_elem_to_string ctx = PA.env_elem_to_string ctx "" " "
@@ -255,7 +269,8 @@ let value_has_ret_symbolic_value_with_borrow_under_mut (ctx : C.eval_ctx)
raise Found
else ()
| V.SynthInput | V.SynthInputGivenBack | V.FunCallGivenBack
- | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate ->
+ | V.SynthRetGivenBack | V.Global | V.LoopGivenBack | V.Aggregate
+ | V.ConstGeneric | V.TraitConst ->
()
end
in
@@ -272,7 +287,7 @@ let rvalue_get_place (rv : E.rvalue) : E.place option =
match rv with
| Use (Copy p | Move p) -> Some p
| Use (Constant _) -> None
- | Ref (p, _) -> Some p
+ | RvRef (p, _) -> Some p
| UnaryOp _ | BinaryOp _ | Global _ | Discriminant _ | Aggregate _ -> None
(** See {!ValuesUtils.symbolic_value_has_borrows} *)
@@ -403,3 +418,103 @@ let compute_contexts_ids (ctxl : C.eval_ctx list) : ids_sets * ids_to_values =
(** Compute the sets of ids found in a context. *)
let compute_context_ids (ctx : C.eval_ctx) : ids_sets * ids_to_values =
compute_contexts_ids [ ctx ]
+
+(** **WARNING**: this function doesn't compute the normalized types
+ (for the trait type aliases). This should be computed afterwards.
+ *)
+let initialize_eval_context (ctx : C.decls_ctx)
+ (region_groups : T.RegionGroupId.id list) (type_vars : T.type_var list)
+ (const_generic_vars : T.const_generic_var list) : C.eval_ctx =
+ C.reset_global_counters ();
+ let const_generic_vars_map =
+ T.ConstGenericVarId.Map.of_list
+ (List.map
+ (fun (cg : T.const_generic_var) ->
+ let ty = TypesUtils.ety_no_regions_to_rty (T.Literal cg.ty) in
+ let cv = mk_fresh_symbolic_typed_value V.ConstGeneric ty in
+ (cg.index, cv))
+ const_generic_vars)
+ in
+ {
+ C.type_context = ctx.type_ctx;
+ C.fun_context = ctx.fun_ctx;
+ C.global_context = ctx.global_ctx;
+ C.trait_decls_context = ctx.trait_decls_ctx;
+ C.trait_impls_context = ctx.trait_impls_ctx;
+ C.region_groups;
+ C.type_vars;
+ C.const_generic_vars;
+ C.const_generic_vars_map;
+ C.norm_trait_etypes = C.ETraitTypeRefMap.empty (* Empty for now *);
+ C.norm_trait_rtypes = C.RTraitTypeRefMap.empty (* Empty for now *);
+ C.norm_trait_stypes = C.STraitTypeRefMap.empty (* Empty for now *);
+ C.env = [ C.Frame ];
+ C.ended_regions = T.RegionId.Set.empty;
+ }
+
+(** Instantiate a function signature, introducing **fresh** abstraction ids and
+ region ids. This is mostly used in preparation of function calls (when
+ evaluating in symbolic mode).
+
+ Note: there are no region parameters, because they should be erased.
+ *)
+let instantiate_fun_sig (ctx : C.eval_ctx) (generics : T.egeneric_args)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ log#ldebug
+ (lazy
+ ("instantiate_fun_sig:" ^ "\n- generics: "
+ ^ egeneric_args_to_string ctx generics
+ ^ "\n- tr_self: "
+ ^ rtrait_instance_id_to_string ctx tr_self
+ ^ "\n- sg: " ^ fun_sig_to_string ctx sg));
+ (* Generate fresh abstraction ids and create a substitution from region
+ * group ids to abstraction ids *)
+ let rg_abs_ids_bindings =
+ List.map
+ (fun rg ->
+ let abs_id = C.fresh_abstraction_id () in
+ (rg.T.id, abs_id))
+ sg.regions_hierarchy
+ in
+ let asubst_map : V.AbstractionId.id T.RegionGroupId.Map.t =
+ List.fold_left
+ (fun mp (rg_id, abs_id) -> T.RegionGroupId.Map.add rg_id abs_id mp)
+ T.RegionGroupId.Map.empty rg_abs_ids_bindings
+ in
+ let asubst (rg_id : T.RegionGroupId.id) : V.AbstractionId.id =
+ T.RegionGroupId.Map.find rg_id asubst_map
+ in
+ (* Generate fresh regions and their substitutions *)
+ let _, rsubst, _ = Subst.fresh_regions_with_substs sg.generics.regions in
+ (* Generate the type substitution
+ * Note that we need the substitution to map the type variables to
+ * {!rty} types (not {!ety}). In order to do that, we convert the
+ * type parameters to types with regions. This is possible only
+ * if those types don't contain any regions.
+ * This is a current limitation of the analysis: there is still some
+ * work to do to properly handle full type parametrization.
+ * *)
+ let rtype_params = List.map ety_no_regions_to_rty generics.types in
+ let tsubst = Subst.make_type_subst_from_vars sg.generics.types rtype_params in
+ let cgsubst =
+ Subst.make_const_generic_subst_from_vars sg.generics.const_generics
+ generics.const_generics
+ in
+ (* TODO: something annoying with the trait ref subst: we need to use region
+ types, but the arguments use erased regions. For now we use the fact
+ that no regions should appear inside. In the future: we should merge
+ ety and rty. *)
+ let trait_refs =
+ List.map TypesUtils.etrait_ref_no_regions_to_gr_trait_ref
+ generics.trait_refs
+ in
+ let tr_subst =
+ Subst.make_trait_subst_from_clauses sg.generics.trait_clauses trait_refs
+ in
+ (* Substitute the signature *)
+ let inst_sig =
+ AssociatedTypes.ctx_subst_norm_signature ctx asubst rsubst tsubst cgsubst
+ tr_subst tr_self sg
+ in
+ (* Return *)
+ inst_sig
diff --git a/compiler/Invariants.ml b/compiler/Invariants.ml
index f29c7f88..9ac5ce13 100644
--- a/compiler/Invariants.ml
+++ b/compiler/Invariants.ml
@@ -7,6 +7,7 @@ module V = Values
module E = Expressions
module C = Contexts
module Subst = Substitute
+module Assoc = AssociatedTypes
module A = LlbcAst
module L = Logging
open Cps
@@ -406,13 +407,14 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(match (tv.V.value, tv.V.ty) with
| V.Literal cv, T.Literal ty -> check_literal_type cv ty
(* ADT case *)
- | V.Adt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
+ | V.Adt av, T.Adt (T.AdtId def_id, generics) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
(* Check the number of parameters *)
- assert (List.length regions = List.length def.region_params);
- assert (List.length tys = List.length def.type_params);
+ assert (
+ List.length generics.regions = List.length def.generics.regions);
+ assert (List.length generics.types = List.length def.generics.types);
(* Check that the variant id is consistent *)
(match (av.V.variant_id, def.T.kind) with
| Some variant_id, T.Enum variants ->
@@ -421,8 +423,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
| _ -> raise (Failure "Erroneous typing"));
(* Check that the field types are correct *)
let field_types =
- Subst.type_decl_get_instantiated_field_etypes def av.V.variant_id
- tys cgs
+ Assoc.type_decl_get_inst_norm_field_etypes ctx def av.V.variant_id
+ generics
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -431,20 +433,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.Adt av, T.Adt (T.Tuple, regions, tys, cgs) ->
- assert (regions = []);
- assert (cgs = []);
+ | V.Adt av, T.Adt (T.Tuple, generics) ->
+ assert (generics.regions = []);
+ assert (generics.const_generics = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
- let fields_with_types = List.combine av.V.field_values tys in
+ let fields_with_types =
+ List.combine av.V.field_values generics.types
+ in
List.iter
(fun ((v, ty) : V.typed_value * T.ety) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.Adt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
+ | V.Adt av, T.Adt (T.Assumed aty_id, generics) -> (
assert (av.V.variant_id = None || aty_id = T.Option);
- match (aty_id, av.V.field_values, regions, tys, cgs) with
+ match
+ ( aty_id,
+ av.V.field_values,
+ generics.regions,
+ generics.types,
+ generics.const_generics )
+ with
(* Box *)
| T.Box, [ inner_value ], [], [ inner_ty ], []
| T.Option, [ inner_value ], [], [ inner_ty ], [] ->
@@ -520,14 +530,17 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(* Check the current pair (value, type) *)
(match (atv.V.value, atv.V.ty) with
(* ADT case *)
- | V.AAdt av, T.Adt (T.AdtId def_id, regions, tys, cgs) ->
+ | V.AAdt av, T.Adt (T.AdtId def_id, generics) ->
(* Retrieve the definition to check the variant id, the number of
* parameters, etc. *)
let def = C.ctx_lookup_type_decl ctx def_id in
(* Check the number of parameters *)
- assert (List.length regions = List.length def.region_params);
- assert (List.length tys = List.length def.type_params);
- assert (List.length cgs = List.length def.const_generic_params);
+ assert (
+ List.length generics.regions = List.length def.generics.regions);
+ assert (List.length generics.types = List.length def.generics.types);
+ assert (
+ List.length generics.const_generics
+ = List.length def.generics.const_generics);
(* Check that the variant id is consistent *)
(match (av.V.variant_id, def.T.kind) with
| Some variant_id, T.Enum variants ->
@@ -536,8 +549,8 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
| _ -> raise (Failure "Erroneous typing"));
(* Check that the field types are correct *)
let field_types =
- Subst.type_decl_get_instantiated_field_rtypes def av.V.variant_id
- regions tys cgs
+ Assoc.type_decl_get_inst_norm_field_rtypes ctx def av.V.variant_id
+ generics
in
let fields_with_types =
List.combine av.V.field_values field_types
@@ -546,20 +559,28 @@ let check_typing_invariant (ctx : C.eval_ctx) : unit =
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Tuple case *)
- | V.AAdt av, T.Adt (T.Tuple, regions, tys, cgs) ->
- assert (regions = []);
- assert (cgs = []);
+ | V.AAdt av, T.Adt (T.Tuple, generics) ->
+ assert (generics.regions = []);
+ assert (generics.const_generics = []);
assert (av.V.variant_id = None);
(* Check that the fields have the proper values - and check that there
* are as many fields as field types at the same time *)
- let fields_with_types = List.combine av.V.field_values tys in
+ let fields_with_types =
+ List.combine av.V.field_values generics.types
+ in
List.iter
(fun ((v, ty) : V.typed_avalue * T.rty) -> assert (v.V.ty = ty))
fields_with_types
(* Assumed type case *)
- | V.AAdt av, T.Adt (T.Assumed aty_id, regions, tys, cgs) -> (
+ | V.AAdt av, T.Adt (T.Assumed aty_id, generics) -> (
assert (av.V.variant_id = None);
- match (aty_id, av.V.field_values, regions, tys, cgs) with
+ match
+ ( aty_id,
+ av.V.field_values,
+ generics.regions,
+ generics.types,
+ generics.const_generics )
+ with
(* Box *)
| T.Box, [ boxed_value ], [], [ boxed_ty ], [] ->
assert (boxed_value.V.ty = boxed_ty)
diff --git a/compiler/LlbcAst.ml b/compiler/LlbcAst.ml
index f4d26e18..2db859b2 100644
--- a/compiler/LlbcAst.ml
+++ b/compiler/LlbcAst.ml
@@ -11,6 +11,7 @@ type abs_region_groups = (AbstractionId.id, RegionId.id) g_region_groups
(** A function signature, after instantiation *)
type inst_fun_sig = {
regions_hierarchy : abs_region_groups;
+ trait_type_constraints : rtrait_type_constraint list;
inputs : rty list;
output : rty;
}
diff --git a/compiler/LlbcAstUtils.ml b/compiler/LlbcAstUtils.ml
index 1111c297..a982af30 100644
--- a/compiler/LlbcAstUtils.ml
+++ b/compiler/LlbcAstUtils.ml
@@ -12,3 +12,32 @@ let lookup_fun_name (fun_id : fun_id) (fun_decls : fun_decl FunDeclId.Map.t) :
match fun_id with
| Regular id -> (FunDeclId.Map.find id fun_decls).name
| Assumed aid -> Assumed.get_assumed_name aid
+
+(** Return the opaque declarations found in the crate.
+
+ [filter_assumed]: if [true], do not consider as opaque the external definitions
+ that we will map to definitions from the standard library.
+
+ Remark: the list of functions also contains the list of opaque global bodies.
+ *)
+let crate_get_opaque_decls (k : crate) (filter_assumed : bool) :
+ T.type_decl list * fun_decl list =
+ let open ExtractAssumed in
+ let is_opaque_fun (d : fun_decl) : bool =
+ let sname = name_to_simple_name d.name in
+ d.body = None
+ (* Something to pay attention to: we must ignore trait method *declarations*
+ (which don't have a body but must not be considered as opaque) *)
+ && (match d.kind with TraitMethodDecl _ -> false | _ -> true)
+ && ((not filter_assumed)
+ || not (SimpleNameMap.mem sname assumed_globals_map))
+ in
+ let is_opaque_type (d : T.type_decl) : bool = d.kind = T.Opaque in
+ (* Note that by checking the function bodies we also the globals *)
+ ( List.filter is_opaque_type (T.TypeDeclId.Map.values k.types),
+ List.filter is_opaque_fun (FunDeclId.Map.values k.functions) )
+
+(** Return true if the crate contains opaque declarations, ignoring the assumed
+ definitions. *)
+let crate_has_opaque_decls (k : crate) (filter_assumed : bool) : bool =
+ crate_get_opaque_decls k filter_assumed <> ([], [])
diff --git a/compiler/Logging.ml b/compiler/Logging.ml
index 9dc1f5e3..59abbfc7 100644
--- a/compiler/Logging.ml
+++ b/compiler/Logging.ml
@@ -9,6 +9,9 @@ let pre_passes_log = L.get_logger "MainLogger.PrePasses"
(** Logger for Translate *)
let translate_log = L.get_logger "MainLogger.Translate"
+(** Logger for Contexts *)
+let contexts_log = L.get_logger "MainLogger.Contexts"
+
(** Logger for PureUtils *)
let pure_utils_log = L.get_logger "MainLogger.PureUtils"
@@ -57,6 +60,9 @@ let borrows_log = L.get_logger "MainLogger.Interpreter.Borrows"
(** Logger for Invariants *)
let invariants_log = L.get_logger "MainLogger.Interpreter.Invariants"
+(** Logger for AssociatedTypes *)
+let associated_types_log = L.get_logger "MainLogger.AssociatedTypes"
+
(** Logger for SCC *)
let scc_log = L.get_logger "MainLogger.Graph.SCC"
diff --git a/compiler/PrePasses.ml b/compiler/PrePasses.ml
index b348ba1d..1058fab0 100644
--- a/compiler/PrePasses.ml
+++ b/compiler/PrePasses.ml
@@ -107,7 +107,7 @@ let remove_useless_cf_merges (crate : A.crate) (f : A.fun_decl) : A.fun_decl =
false
| Assign (_, rv) -> (
match rv with
- | Use _ | Ref _ -> not must_end_with_exit
+ | Use _ | RvRef _ -> not must_end_with_exit
| Aggregate (AggregatedTuple, []) -> not must_end_with_exit
| _ -> false)
| FakeRead _ | Drop _ | Nop -> not must_end_with_exit
@@ -376,7 +376,7 @@ let remove_shallow_borrows (crate : A.crate) (f : A.fun_decl) : A.fun_decl =
method! visit_Assign env p rv =
match (p.projection, rv) with
- | [], E.Ref (_, E.Shallow) ->
+ | [], E.RvRef (_, E.Shallow) ->
(* Filter *)
filtered := E.VarId.Set.add p.var_id !filtered;
Nop
diff --git a/compiler/Print.ml b/compiler/Print.ml
index 9aa73d7c..5d5c16ee 100644
--- a/compiler/Print.ml
+++ b/compiler/Print.ml
@@ -21,6 +21,9 @@ module Values = struct
type_decl_id_to_string : T.TypeDeclId.id -> string;
const_generic_var_id_to_string : T.ConstGenericVarId.id -> string;
global_decl_id_to_string : T.GlobalDeclId.id -> string;
+ trait_decl_id_to_string : T.TraitDeclId.id -> string;
+ trait_impl_id_to_string : T.TraitImplId.id -> string;
+ trait_clause_id_to_string : T.TraitClauseId.id -> string;
adt_variant_to_string : T.TypeDeclId.id -> T.VariantId.id -> string;
var_id_to_string : E.VarId.id -> string;
adt_field_names :
@@ -34,6 +37,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let value_to_rtype_formatter (fmt : value_formatter) : PT.rtype_formatter =
@@ -43,6 +49,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let value_to_stype_formatter (fmt : value_formatter) : PT.stype_formatter =
@@ -52,6 +61,9 @@ module Values = struct
PT.type_decl_id_to_string = fmt.type_decl_id_to_string;
PT.const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
PT.global_decl_id_to_string = fmt.global_decl_id_to_string;
+ PT.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PT.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PT.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let var_id_to_string (id : E.VarId.id) : string =
@@ -86,10 +98,10 @@ module Values = struct
List.map (typed_value_to_string fmt) av.field_values
in
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | T.Adt (T.AdtId def_id, _, _, _) ->
+ | T.Adt (T.AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match av.variant_id with
@@ -111,7 +123,7 @@ module Values = struct
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | T.Adt (T.Assumed aty, _, _, _) -> (
+ | T.Adt (T.Assumed aty, _) -> (
(* Assumed type *)
match (aty, field_values) with
| Box, [ bv ] -> "@Box(" ^ bv ^ ")"
@@ -201,10 +213,10 @@ module Values = struct
List.map (typed_avalue_to_string fmt) av.field_values
in
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | T.Adt (T.AdtId def_id, _, _, _) ->
+ | T.Adt (T.AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match av.variant_id with
@@ -226,7 +238,7 @@ module Values = struct
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | T.Adt (T.Assumed aty, _, _, _) -> (
+ | T.Adt (T.Assumed aty, _) -> (
(* Assumed type *)
match (aty, field_values) with
| Box, [ bv ] -> "@Box(" ^ bv ^ ")"
@@ -452,6 +464,9 @@ module Contexts = struct
PV.adt_variant_to_string = fmt.adt_variant_to_string;
PV.var_id_to_string = fmt.var_id_to_string;
PV.adt_field_names = fmt.adt_field_names;
+ PV.trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ PV.trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ PV.trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let ast_to_value_formatter (fmt : PA.ast_formatter) : PV.value_formatter =
@@ -463,20 +478,27 @@ module Contexts = struct
let ctx_to_rtype_formatter (fmt : ctx_formatter) : PT.rtype_formatter =
PV.value_to_rtype_formatter fmt
+ let ctx_to_stype_formatter (fmt : ctx_formatter) : PT.stype_formatter =
+ PV.value_to_stype_formatter fmt
+
let eval_ctx_to_ctx_formatter (ctx : C.eval_ctx) : ctx_formatter =
- (* We shouldn't use rvar_to_string *)
- let rvar_to_string _r =
- raise (Failure "Unexpected use of rvar_to_string")
+ let rvar_to_string r =
+ (* In theory we shouldn't use rvar_to_string, but it can happen
+ when printing definitions for instance... *)
+ T.RegionVarId.to_string r
in
let r_to_string r = PT.region_id_to_string r in
let type_var_id_to_string vid =
- let v = C.lookup_type_var ctx vid in
- v.name
+ (* The context may be invalid *)
+ match C.lookup_type_var_opt ctx vid with
+ | None -> T.TypeVarId.to_string vid
+ | Some v -> v.name
in
let const_generic_var_id_to_string vid =
- let v = C.lookup_const_generic_var ctx vid in
- v.name
+ match C.lookup_const_generic_var_opt ctx vid with
+ | None -> T.ConstGenericVarId.to_string vid
+ | Some v -> v.name
in
let type_decl_id_to_string def_id =
let def = C.ctx_lookup_type_decl ctx def_id in
@@ -486,6 +508,15 @@ module Contexts = struct
let def = C.ctx_lookup_global_decl ctx def_id in
name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_decl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_impl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in
let adt_variant_to_string =
PT.type_ctx_to_adt_variant_to_string_fun ctx.type_context.type_decls
in
@@ -506,6 +537,9 @@ module Contexts = struct
adt_variant_to_string;
var_id_to_string;
adt_field_names;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
let eval_ctx_to_ast_formatter (ctx : C.eval_ctx) : PA.ast_formatter =
@@ -521,6 +555,15 @@ module Contexts = struct
let def = C.ctx_lookup_global_decl ctx def_id in
global_name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_decl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = C.ctx_lookup_trait_impl ctx def_id in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id = PT.trait_clause_id_to_pretty_string id in
{
rvar_to_string = ctx_fmt.PV.rvar_to_string;
r_to_string = ctx_fmt.PV.r_to_string;
@@ -533,6 +576,9 @@ module Contexts = struct
adt_field_to_string;
fun_decl_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
(** Split an [env] at every occurrence of [Frame], eliminating those elements.
@@ -608,6 +654,50 @@ module EvalCtxLlbcAst = struct
let fmt = PC.ctx_to_rtype_formatter fmt in
PT.rty_to_string fmt t
+ let sty_to_string (ctx : C.eval_ctx) (t : T.sty) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.sty_to_string fmt t
+
+ let etrait_ref_to_string (ctx : C.eval_ctx) (x : T.etrait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.etrait_ref_to_string fmt x
+
+ let rtrait_ref_to_string (ctx : C.eval_ctx) (x : T.rtrait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_rtype_formatter fmt in
+ PT.rtrait_ref_to_string fmt x
+
+ let strait_ref_to_string (ctx : C.eval_ctx) (x : T.strait_ref) : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.strait_ref_to_string fmt x
+
+ let etrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.etrait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.etrait_instance_id_to_string fmt x
+
+ let rtrait_instance_id_to_string (ctx : C.eval_ctx) (x : T.rtrait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_rtype_formatter fmt in
+ PT.rtrait_instance_id_to_string fmt x
+
+ let strait_instance_id_to_string (ctx : C.eval_ctx) (x : T.strait_instance_id)
+ : string =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_stype_formatter fmt in
+ PT.strait_instance_id_to_string fmt x
+
+ let egeneric_args_to_string (ctx : C.eval_ctx) (x : T.egeneric_args) : string
+ =
+ let fmt = PC.eval_ctx_to_ctx_formatter ctx in
+ let fmt = PC.ctx_to_etype_formatter fmt in
+ PT.egeneric_args_to_string fmt x
+
let borrow_content_to_string (ctx : C.eval_ctx) (bc : V.borrow_content) :
string =
let fmt = PC.eval_ctx_to_ctx_formatter ctx in
@@ -653,11 +743,27 @@ module EvalCtxLlbcAst = struct
let fmt = PC.eval_ctx_to_ast_formatter ctx in
PE.operand_to_string fmt op
+ let call_to_string (ctx : C.eval_ctx) (call : A.call) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.call_to_string fmt "" call
+
+ let fun_decl_to_string (ctx : C.eval_ctx) (f : A.fun_decl) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.fun_decl_to_string fmt "" " " f
+
+ let fun_sig_to_string (ctx : C.eval_ctx) (x : A.fun_sig) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.fun_sig_to_string fmt "" " " x
+
let statement_to_string (ctx : C.eval_ctx) (indent : string)
(indent_incr : string) (e : A.statement) : string =
let fmt = PC.eval_ctx_to_ast_formatter ctx in
PA.statement_to_string fmt indent indent_incr e
+ let trait_impl_to_string (ctx : C.eval_ctx) (timpl : A.trait_impl) : string =
+ let fmt = PC.eval_ctx_to_ast_formatter ctx in
+ PA.trait_impl_to_string fmt " " " " timpl
+
let env_elem_to_string (ctx : C.eval_ctx) (indent : string)
(indent_incr : string) (ev : C.env_elem) : string =
let fmt = PC.eval_ctx_to_ctx_formatter ctx in
diff --git a/compiler/PrintPure.ml b/compiler/PrintPure.ml
index cfb63ec2..5fb5978b 100644
--- a/compiler/PrintPure.ml
+++ b/compiler/PrintPure.ml
@@ -8,6 +8,9 @@ type type_formatter = {
type_decl_id_to_string : TypeDeclId.id -> string;
const_generic_var_id_to_string : ConstGenericVarId.id -> string;
global_decl_id_to_string : GlobalDeclId.id -> string;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
type value_formatter = {
@@ -18,6 +21,9 @@ type value_formatter = {
adt_variant_to_string : TypeDeclId.id -> VariantId.id -> string;
var_id_to_string : VarId.id -> string;
adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
let value_to_type_formatter (fmt : value_formatter) : type_formatter =
@@ -26,6 +32,9 @@ let value_to_type_formatter (fmt : value_formatter) : type_formatter =
type_decl_id_to_string = fmt.type_decl_id_to_string;
const_generic_var_id_to_string = fmt.const_generic_var_id_to_string;
global_decl_id_to_string = fmt.global_decl_id_to_string;
+ trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
(* TODO: we need to store which variables we have encountered so far, and
@@ -42,6 +51,9 @@ type ast_formatter = {
adt_field_names : TypeDeclId.id -> VariantId.id option -> string list option;
fun_decl_id_to_string : FunDeclId.id -> string;
global_decl_id_to_string : GlobalDeclId.id -> string;
+ trait_decl_id_to_string : TraitDeclId.id -> string;
+ trait_impl_id_to_string : TraitImplId.id -> string;
+ trait_clause_id_to_string : TraitClauseId.id -> string;
}
let ast_to_value_formatter (fmt : ast_formatter) : value_formatter =
@@ -53,6 +65,9 @@ let ast_to_value_formatter (fmt : ast_formatter) : value_formatter =
adt_variant_to_string = fmt.adt_variant_to_string;
var_id_to_string = fmt.var_id_to_string;
adt_field_names = fmt.adt_field_names;
+ trait_decl_id_to_string = fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = fmt.trait_clause_id_to_string;
}
let ast_to_type_formatter (fmt : ast_formatter) : type_formatter =
@@ -70,31 +85,51 @@ let literal_type_to_string = Print.PrimitiveValues.literal_type_to_string
let scalar_value_to_string = Print.PrimitiveValues.scalar_value_to_string
let literal_to_string = Print.PrimitiveValues.literal_to_string
+(* Remark: not using generic_params on purpose, because we may use parameters
+ which either come from LLBC or from pure, and the [generic_params] type
+ for those ASTs is not the same. Note that it works because we actually don't
+ need to know the trait clauses to print the AST: we can thus ignore them.
+*)
let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
(global_decls : A.global_decl GlobalDeclId.Map.t)
- (type_params : type_var list)
+ (trait_decls : A.trait_decl TraitDeclId.Map.t)
+ (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list)
(const_generic_params : const_generic_var list) : type_formatter =
let type_var_id_to_string vid =
- let var = T.TypeVarId.nth type_params vid in
+ let var = TypeVarId.nth type_params vid in
type_var_to_string var
in
let const_generic_var_id_to_string vid =
- let var = T.ConstGenericVarId.nth const_generic_params vid in
+ let var = ConstGenericVarId.nth const_generic_params vid in
const_generic_var_to_string var
in
let type_decl_id_to_string def_id =
- let def = T.TypeDeclId.Map.find def_id type_decls in
+ let def = TypeDeclId.Map.find def_id type_decls in
name_to_string def.name
in
let global_decl_id_to_string def_id =
- let def = T.GlobalDeclId.Map.find def_id global_decls in
+ let def = GlobalDeclId.Map.find def_id global_decls in
name_to_string def.name
in
+ let trait_decl_id_to_string def_id =
+ let def = TraitDeclId.Map.find def_id trait_decls in
+ name_to_string def.name
+ in
+ let trait_impl_id_to_string def_id =
+ let def = TraitImplId.Map.find def_id trait_impls in
+ name_to_string def.name
+ in
+ let trait_clause_id_to_string id =
+ Print.PT.trait_clause_id_to_pretty_string id
+ in
{
type_var_id_to_string;
type_decl_id_to_string;
const_generic_var_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
(* TODO: there is a bit of duplication with Print.fun_decl_to_ast_formatter.
@@ -106,19 +141,21 @@ let mk_type_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
(fun_decls : A.fun_decl FunDeclId.Map.t)
(global_decls : A.global_decl GlobalDeclId.Map.t)
- (type_params : type_var list)
+ (trait_decls : A.trait_decl TraitDeclId.Map.t)
+ (trait_impls : A.trait_impl TraitImplId.Map.t) (type_params : type_var list)
(const_generic_params : const_generic_var list) : ast_formatter =
- let type_var_id_to_string vid =
- let var = T.TypeVarId.nth type_params vid in
- type_var_to_string var
- in
- let const_generic_var_id_to_string vid =
- let var = T.ConstGenericVarId.nth const_generic_params vid in
- const_generic_var_to_string var
- in
- let type_decl_id_to_string def_id =
- let def = T.TypeDeclId.Map.find def_id type_decls in
- name_to_string def.name
+ let ({
+ type_var_id_to_string;
+ type_decl_id_to_string;
+ const_generic_var_id_to_string;
+ global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
+ }
+ : type_formatter) =
+ mk_type_formatter type_decls global_decls trait_decls trait_impls
+ type_params const_generic_params
in
let adt_variant_to_string =
Print.Types.type_ctx_to_adt_variant_to_string_fun type_decls
@@ -137,10 +174,6 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
let def = FunDeclId.Map.find def_id fun_decls in
fun_name_to_string def.name
in
- let global_decl_id_to_string def_id =
- let def = GlobalDeclId.Map.find def_id global_decls in
- global_name_to_string def.name
- in
{
type_var_id_to_string;
const_generic_var_id_to_string;
@@ -151,6 +184,9 @@ let mk_ast_formatter (type_decls : T.type_decl TypeDeclId.Map.t)
adt_field_to_string;
fun_decl_id_to_string;
global_decl_id_to_string;
+ trait_decl_id_to_string;
+ trait_impl_id_to_string;
+ trait_clause_id_to_string;
}
let assumed_ty_to_string (aty : assumed_ty) : string =
@@ -182,20 +218,18 @@ let const_generic_to_string (fmt : type_formatter) (cg : T.const_generic) :
let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string =
match ty with
- | Adt (id, tys, cgs) -> (
- let tys = List.map (ty_to_string fmt false) tys in
- let cgs = List.map (const_generic_to_string fmt) cgs in
- let params = List.append tys cgs in
+ | Adt (id, generics) -> (
match id with
| Tuple ->
- assert (cgs = []);
- "(" ^ String.concat " * " tys ^ ")"
+ let generics = generic_args_to_strings fmt false generics in
+ "(" ^ String.concat " * " generics ^ ")"
| AdtId _ | Assumed _ ->
- let params_s =
- if params = [] then "" else " " ^ String.concat " " params
+ let generics = generic_args_to_strings fmt true generics in
+ let generics_s =
+ if generics = [] then "" else " " ^ String.concat " " generics
in
- let ty_s = type_id_to_string fmt id ^ params_s in
- if params <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
+ let ty_s = type_id_to_string fmt id ^ generics_s in
+ if generics <> [] && inside then "(" ^ ty_s ^ ")" else ty_s)
| TypeVar tv -> fmt.type_var_id_to_string tv
| Literal lty -> literal_type_to_string lty
| Arrow (arg_ty, ret_ty) ->
@@ -203,6 +237,71 @@ let rec ty_to_string (fmt : type_formatter) (inside : bool) (ty : ty) : string =
ty_to_string fmt true arg_ty ^ " -> " ^ ty_to_string fmt false ret_ty
in
if inside then "(" ^ ty ^ ")" else ty
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = trait_ref_to_string fmt false trait_ref in
+ let s =
+ if generics = empty_generic_args then trait_ref ^ "::" ^ type_name
+ else
+ let generics = generic_args_to_string fmt generics in
+ "(" ^ trait_ref ^ " " ^ generics ^ ")::" ^ type_name
+ in
+ if inside then "(" ^ s ^ ")" else s
+
+and generic_args_to_strings (fmt : type_formatter) (inside : bool)
+ (generics : generic_args) : string list =
+ let tys = List.map (ty_to_string fmt inside) generics.types in
+ let cgs = List.map (const_generic_to_string fmt) generics.const_generics in
+ let trait_refs =
+ List.map (trait_ref_to_string fmt inside) generics.trait_refs
+ in
+ List.concat [ tys; cgs; trait_refs ]
+
+and generic_args_to_string (fmt : type_formatter) (generics : generic_args) :
+ string =
+ String.concat " " (generic_args_to_strings fmt true generics)
+
+and trait_ref_to_string (fmt : type_formatter) (inside : bool) (tr : trait_ref)
+ : string =
+ let trait_id = trait_instance_id_to_string fmt false tr.trait_id in
+ let generics = generic_args_to_string fmt tr.generics in
+ let s = trait_id ^ generics in
+ if tr.generics = empty_generic_args || not inside then s else "(" ^ s ^ ")"
+
+and trait_instance_id_to_string (fmt : type_formatter) (inside : bool)
+ (id : trait_instance_id) : string =
+ match id with
+ | Self -> "Self"
+ | TraitImpl id -> fmt.trait_impl_id_to_string id
+ | Clause id -> fmt.trait_clause_id_to_string id
+ | ParentClause (inst_id, _decl_id, clause_id) ->
+ let inst_id = trait_instance_id_to_string fmt false inst_id in
+ let clause_id = fmt.trait_clause_id_to_string clause_id in
+ "parent(" ^ inst_id ^ ")::" ^ clause_id
+ | ItemClause (inst_id, _decl_id, item_name, clause_id) ->
+ let inst_id = trait_instance_id_to_string fmt false inst_id in
+ let clause_id = fmt.trait_clause_id_to_string clause_id in
+ "(" ^ inst_id ^ ")::" ^ item_name ^ "::[" ^ clause_id ^ "]"
+ | TraitRef tr -> trait_ref_to_string fmt inside tr
+ | UnknownTrait msg -> "UNKNOWN(" ^ msg ^ ")"
+
+let trait_clause_to_string (fmt : type_formatter) (clause : trait_clause) :
+ string =
+ let clause_id = fmt.trait_clause_id_to_string clause.clause_id in
+ let trait_id = fmt.trait_decl_id_to_string clause.trait_id in
+ let generics = generic_args_to_strings fmt true clause.generics in
+ let generics =
+ if generics = [] then "" else " " ^ String.concat " " generics
+ in
+ "[" ^ clause_id ^ "]: " ^ trait_id ^ generics
+
+let generic_params_to_strings (fmt : type_formatter) (generics : generic_params)
+ : string list =
+ let tys = List.map type_var_to_string generics.types in
+ let cgs = List.map const_generic_var_to_string generics.const_generics in
+ let trait_clauses =
+ List.map (trait_clause_to_string fmt) generics.trait_clauses
+ in
+ List.concat [ tys; cgs; trait_clauses ]
let field_to_string fmt inside (f : field) : string =
match f.field_name with
@@ -217,11 +316,10 @@ let variant_to_string fmt (v : variant) : string =
^ ")"
let type_decl_to_string (fmt : type_formatter) (def : type_decl) : string =
- let types = def.type_params in
let name = name_to_string def.name in
let params =
- if types = [] then ""
- else " " ^ String.concat " " (List.map type_var_to_string types)
+ if def.generics = empty_generic_params then ""
+ else " " ^ String.concat " " (generic_params_to_strings fmt def.generics)
in
match def.kind with
| Struct fields ->
@@ -353,10 +451,10 @@ let adt_g_value_to_string (fmt : value_formatter)
(field_values : 'v list) (ty : ty) : string =
let field_values = List.map value_to_string field_values in
match ty with
- | Adt (Tuple, _, _) ->
+ | Adt (Tuple, _) ->
(* Tuple *)
"(" ^ String.concat ", " field_values ^ ")"
- | Adt (AdtId def_id, _, _) ->
+ | Adt (AdtId def_id, _) ->
(* "Regular" ADT *)
let adt_ident =
match variant_id with
@@ -378,7 +476,7 @@ let adt_g_value_to_string (fmt : value_formatter)
let field_values = String.concat " " field_values in
adt_ident ^ " { " ^ field_values ^ " }"
else adt_ident
- | Adt (Assumed aty, _, _) -> (
+ | Adt (Assumed aty, _) -> (
(* Assumed type *)
match aty with
| State ->
@@ -464,10 +562,10 @@ let rec typed_pattern_to_string (fmt : ast_formatter) (v : typed_pattern) :
let fun_sig_to_string (fmt : ast_formatter) (sg : fun_sig) : string =
let ty_fmt = ast_to_type_formatter fmt in
- let type_params = List.map type_var_to_string sg.type_params in
+ let generics = generic_params_to_strings ty_fmt sg.generics in
let inputs = List.map (ty_to_string ty_fmt false) sg.inputs in
let output = ty_to_string ty_fmt false sg.output in
- let all_types = List.concat [ type_params; inputs; [ output ] ] in
+ let all_types = List.concat [ generics; inputs; [ output ] ] in
String.concat " -> " all_types
let inst_fun_sig_to_string (fmt : ast_formatter) (sg : inst_fun_sig) : string =
@@ -512,6 +610,7 @@ let llbc_assumed_fun_id_to_string (fid : A.assumed_fun_id) : string =
| ArrayToSliceMut -> "@ArrayToSliceMut"
| ArraySubsliceShared -> "@ArraySubsliceShared"
| ArraySubsliceMut -> "@ArraySubsliceMut"
+ | ArrayRepeat -> "@ArrayRepeat"
| SliceLen -> "@SliceLen"
| SliceIndexShared -> "@SliceIndexShared"
| SliceIndexMut -> "@SliceIndexMut"
@@ -531,8 +630,11 @@ let regular_fun_id_to_string (fmt : ast_formatter) (fun_id : fun_id) : string =
| FromLlbc (fid, lp_id, rg_id) ->
let f =
match fid with
- | Regular fid -> fmt.fun_decl_id_to_string fid
- | Assumed fid -> llbc_assumed_fun_id_to_string fid
+ | FunId (Regular fid) -> fmt.fun_decl_id_to_string fid
+ | FunId (Assumed fid) -> llbc_assumed_fun_id_to_string fid
+ | TraitMethod (trait_ref, method_name, _) ->
+ let fmt = ast_to_type_formatter fmt in
+ trait_ref_to_string fmt true trait_ref ^ "." ^ method_name
in
f ^ fun_suffix lp_id rg_id
| Pure fid -> pure_assumed_fun_id_to_string fid
@@ -559,9 +661,8 @@ let fun_or_op_id_to_string (fmt : ast_formatter) (fun_id : fun_or_op_id) :
let rec texpression_to_string (fmt : ast_formatter) (inside : bool)
(indent : string) (indent_incr : string) (e : texpression) : string =
match e.e with
- | Var var_id ->
- let s = fmt.var_id_to_string var_id in
- if inside then "(" ^ s ^ ")" else s
+ | Var var_id -> fmt.var_id_to_string var_id
+ | CVar cg_id -> fmt.const_generic_var_id_to_string cg_id
| Const cv -> literal_to_string cv
| App _ ->
(* Recursively destruct the app, to have a pair (app, arguments list) *)
@@ -632,10 +733,11 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
(* There are two possibilities: either the [app] is an instantiated,
* top-level qualifier (function, ADT constructore...), or it is a "regular"
* expression *)
- let app, tys =
+ let app, generics =
match app.e with
| Qualif qualif ->
(* Qualifier case *)
+ let ty_fmt = ast_to_type_formatter fmt in
(* Convert the qualifier identifier *)
let qualif_s =
match qualif.id with
@@ -654,12 +756,17 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
let field_s = adt_field_to_string value_fmt adt_id field_id in
(* Adopting an F*-like syntax *)
ConstStrings.constructor_prefix ^ adt_s ^ "?." ^ field_s
+ | TraitConst (trait_ref, generics, const_name) ->
+ let trait_ref = trait_ref_to_string ty_fmt true trait_ref in
+ let generics_s = generic_args_to_string ty_fmt generics in
+ if generics <> empty_generic_args then
+ "(" ^ trait_ref ^ generics_s ^ ")." ^ const_name
+ else trait_ref ^ "." ^ const_name
in
(* Convert the type instantiation *)
- let ty_fmt = ast_to_type_formatter fmt in
- let tys = List.map (ty_to_string ty_fmt true) qualif.type_args in
+ let generics = generic_args_to_strings ty_fmt true qualif.generics in
(* *)
- (qualif_s, tys)
+ (qualif_s, generics)
| _ ->
(* "Regular" expression case *)
let inside = args <> [] || (args = [] && inside) in
@@ -674,7 +781,7 @@ and app_to_string (fmt : ast_formatter) (inside : bool) (indent : string)
texpression_to_string fmt inside indent1 indent_incr
in
let args = List.map arg_to_string args in
- let all_args = List.append tys args in
+ let all_args = List.append generics args in
(* Put together *)
let e =
if all_args = [] then app else app ^ " " ^ String.concat " " all_args
diff --git a/compiler/Pure.ml b/compiler/Pure.ml
index ac4ca081..47c7beb4 100644
--- a/compiler/Pure.ml
+++ b/compiler/Pure.ml
@@ -13,6 +13,9 @@ module FieldId = T.FieldId
module SymbolicValueId = V.SymbolicValueId
module FunDeclId = A.FunDeclId
module GlobalDeclId = A.GlobalDeclId
+module TraitDeclId = T.TraitDeclId
+module TraitImplId = T.TraitImplId
+module TraitClauseId = T.TraitClauseId
(** We redefine identifiers for loop: in {!Values}, the identifiers are global
(they monotonically increase across functions) while in {!module:Pure} we want
@@ -37,6 +40,13 @@ module ConstGenericVarId = T.ConstGenericVarId
type integer_type = T.integer_type [@@deriving show, ord]
type const_generic_var = T.const_generic_var [@@deriving show, ord]
type const_generic = T.const_generic [@@deriving show, ord]
+type const_generic_var_id = T.const_generic_var_id [@@deriving show, ord]
+type trait_decl_id = T.trait_decl_id [@@deriving show, ord]
+type trait_impl_id = T.trait_impl_id [@@deriving show, ord]
+type trait_clause_id = T.trait_clause_id [@@deriving show, ord]
+type trait_item_name = T.trait_item_name [@@deriving show, ord]
+type global_decl_id = T.global_decl_id [@@deriving show, ord]
+type fun_decl_id = A.fun_decl_id [@@deriving show, ord]
(** The assumed types for the pure AST.
@@ -176,6 +186,14 @@ class ['self] iter_ty_base =
inherit! [_] T.iter_const_generic
inherit! [_] PV.iter_literal_type
method visit_type_var_id : 'env -> type_var_id -> unit = fun _ _ -> ()
+ method visit_trait_decl_id : 'env -> trait_decl_id -> unit = fun _ _ -> ()
+ method visit_trait_impl_id : 'env -> trait_impl_id -> unit = fun _ _ -> ()
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> unit =
+ fun _ _ -> ()
+
+ method visit_trait_item_name : 'env -> trait_item_name -> unit =
+ fun _ _ -> ()
end
(** Ancestor for map visitor for [ty] *)
@@ -185,6 +203,18 @@ class ['self] map_ty_base =
inherit! [_] T.map_const_generic
inherit! [_] PV.map_literal_type
method visit_type_var_id : 'env -> type_var_id -> type_var_id = fun _ x -> x
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id =
+ fun _ x -> x
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id =
+ fun _ x -> x
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> trait_clause_id =
+ fun _ x -> x
+
+ method visit_trait_item_name : 'env -> trait_item_name -> trait_item_name =
+ fun _ x -> x
end
(** Ancestor for reduce visitor for [ty] *)
@@ -194,6 +224,18 @@ class virtual ['self] reduce_ty_base =
inherit! [_] T.reduce_const_generic
inherit! [_] PV.reduce_literal_type
method visit_type_var_id : 'env -> type_var_id -> 'a = fun _ _ -> self#zero
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_clause_id : 'env -> trait_clause_id -> 'a =
+ fun _ _ -> self#zero
+
+ method visit_trait_item_name : 'env -> trait_item_name -> 'a =
+ fun _ _ -> self#zero
end
(** Ancestor for mapreduce visitor for [ty] *)
@@ -205,10 +247,24 @@ class virtual ['self] mapreduce_ty_base =
method visit_type_var_id : 'env -> type_var_id -> type_var_id * 'a =
fun _ x -> (x, self#zero)
+
+ method visit_trait_decl_id : 'env -> trait_decl_id -> trait_decl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_impl_id : 'env -> trait_impl_id -> trait_impl_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_clause_id
+ : 'env -> trait_clause_id -> trait_clause_id * 'a =
+ fun _ x -> (x, self#zero)
+
+ method visit_trait_item_name
+ : 'env -> trait_item_name -> trait_item_name * 'a =
+ fun _ x -> (x, self#zero)
end
type ty =
- | Adt of type_id * ty list * const_generic list
+ | Adt of type_id * generic_args
(** {!Adt} encodes ADTs and tuples and assumed types.
TODO: what about the ended regions? (ADTs may be parameterized
@@ -219,8 +275,38 @@ type ty =
| TypeVar of type_var_id
| Literal of literal_type
| Arrow of ty * ty
+ | TraitType of trait_ref * generic_args * string
+ (** The string is for the name of the associated type *)
+
+and trait_ref = {
+ trait_id : trait_instance_id;
+ generics : generic_args;
+ trait_decl_ref : trait_decl_ref;
+}
+
+and trait_decl_ref = {
+ trait_decl_id : trait_decl_id;
+ decl_generics : generic_args; (* The name: annoying field collisions... *)
+}
+
+and generic_args = {
+ types : ty list;
+ const_generics : const_generic list;
+ trait_refs : trait_ref list;
+}
+
+and trait_instance_id =
+ | Self
+ | TraitImpl of trait_impl_id
+ | Clause of trait_clause_id
+ | ParentClause of trait_instance_id * trait_decl_id * trait_clause_id
+ | ItemClause of
+ trait_instance_id * trait_decl_id * trait_item_name * trait_clause_id
+ | TraitRef of trait_ref
+ | UnknownTrait of string
[@@deriving
show,
+ ord,
visitors
{
name = "iter_ty";
@@ -264,12 +350,37 @@ type type_decl_kind = Struct of field list | Enum of variant list | Opaque
type type_var = T.type_var [@@deriving show]
+type trait_clause = {
+ clause_id : trait_clause_id;
+ trait_id : trait_decl_id;
+ generics : generic_args;
+}
+[@@deriving show]
+
+type generic_params = {
+ types : type_var list;
+ const_generics : const_generic_var list;
+ trait_clauses : trait_clause list;
+}
+[@@deriving show]
+
+type trait_type_constraint = {
+ trait_ref : trait_ref;
+ generics : generic_args;
+ type_name : trait_item_name;
+ ty : ty;
+}
+[@@deriving show, ord]
+
+type predicates = { trait_type_constraints : trait_type_constraint list }
+[@@deriving show]
+
type type_decl = {
def_id : TypeDeclId.id;
name : name;
- type_params : type_var list;
- const_generic_params : const_generic_var list;
+ generics : generic_params;
kind : type_decl_kind;
+ preds : predicates;
}
[@@deriving show]
@@ -420,8 +531,15 @@ type pure_assumed_fun_id =
| FuelEqZero (** Test if some fuel is equal to 0 - TODO: ugly *)
[@@deriving show, ord]
+type fun_id_or_trait_method_ref =
+ | FunId of A.fun_id
+ | TraitMethod of trait_ref * string * fun_decl_id
+ (** The fun decl id is not really needed and here for convenience purposes *)
+[@@deriving show, ord]
+
(** A function id for a non-assumed function *)
-type regular_fun_id = A.fun_id * LoopId.id option * T.RegionGroupId.id option
+type regular_fun_id =
+ fun_id_or_trait_method_ref * LoopId.id option * T.RegionGroupId.id option
[@@deriving show, ord]
(** A function identifier *)
@@ -457,23 +575,20 @@ type projection = { adt_id : type_id; field_id : FieldId.id } [@@deriving show]
type qualif_id =
| FunOrOp of fun_or_op_id (** A function or an operation *)
- | Global of GlobalDeclId.id
+ | Global of global_decl_id
| AdtCons of adt_cons_id (** A function or ADT constructor identifier *)
| Proj of projection (** Field projector *)
+ | TraitConst of trait_ref * generic_args * string
+ (** A trait associated constant *)
[@@deriving show]
-(** An instantiated qualified.
+(** An instantiated qualifier.
Note that for now we have a clear separation between types and expressions,
- which explains why we have the [type_params] field: a function or ADT
+ which explains why we have the [generics] field: a function or ADT
constructor is always fully instantiated.
*)
-type qualif = {
- id : qualif_id;
- type_args : ty list;
- const_generic_args : const_generic list;
-}
-[@@deriving show]
+type qualif = { id : qualif_id; generics : generic_args } [@@deriving show]
type field_id = FieldId.id [@@deriving show, ord]
type var_id = VarId.id [@@deriving show, ord]
@@ -536,6 +651,7 @@ class virtual ['self] mapreduce_expression_base =
*)
type expression =
| Var of var_id (** a variable *)
+ | CVar of const_generic_var_id (** a const generic var *)
| Const of literal
| App of texpression * texpression
(** Application of a function to an argument.
@@ -787,11 +903,11 @@ type fun_sig_info = {
- etc.
*)
type fun_sig = {
- type_params : type_var list;
- const_generic_params : const_generic_var list;
+ generics : generic_params;
(** TODO: we should analyse the signature to make the type parameters implicit whenever possible *)
+ preds : predicates;
inputs : ty list;
- (** The input types.
+ (** The types of the inputs.
Note that those input types take into account the [fuel] parameter,
if the function uses fuel for termination, and the [state] parameter,
@@ -861,8 +977,11 @@ type fun_body = {
}
[@@deriving show]
+type fun_kind = A.fun_kind [@@deriving show]
+
type fun_decl = {
def_id : FunDeclId.id;
+ kind : fun_kind;
num_loops : int;
(** The number of loops in the parent forward function (basically the number
of loops appearing in the original Rust functions, unless some loops are
@@ -882,3 +1001,29 @@ type fun_decl = {
body : fun_body option;
}
[@@deriving show]
+
+type trait_decl = {
+ def_id : trait_decl_id;
+ name : name;
+ generics : generic_params;
+ preds : predicates;
+ all_trait_clauses : trait_clause list;
+ consts : (trait_item_name * (ty * global_decl_id option)) list;
+ types : (trait_item_name * (trait_clause list * ty option)) list;
+ required_methods : (trait_item_name * fun_decl_id) list;
+ provided_methods : (trait_item_name * fun_decl_id option) list;
+}
+[@@deriving show]
+
+type trait_impl = {
+ def_id : trait_impl_id;
+ name : name;
+ impl_trait : trait_decl_ref;
+ generics : generic_params;
+ preds : predicates;
+ consts : (trait_item_name * (ty * global_decl_id)) list;
+ types : (trait_item_name * (trait_ref list * ty)) list;
+ required_methods : (trait_item_name * fun_decl_id) list;
+ provided_methods : (trait_item_name * fun_decl_id) list;
+}
+[@@deriving show]
diff --git a/compiler/PureMicroPasses.ml b/compiler/PureMicroPasses.ml
index b6025df4..cedc3559 100644
--- a/compiler/PureMicroPasses.ml
+++ b/compiler/PureMicroPasses.ml
@@ -376,8 +376,7 @@ let compute_pretty_names (def : fun_decl) : fun_decl =
let ty = e.ty in
let ctx, e =
match e.e with
- | Var _ -> (* Nothing to do *) (ctx, e.e)
- | Const _ -> (* Nothing to do *) (ctx, e.e)
+ | Var _ | CVar _ | Const _ -> (* Nothing to do *) (ctx, e.e)
| App (app, arg) ->
let ctx, app = update_texpression app ctx in
let ctx, arg = update_texpression arg ctx in
@@ -584,13 +583,10 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| Qualif
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args = _;
- const_generic_args = _;
+ generics = _;
} ->
(* Lookup the def *)
- let decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
- in
+ let decl = TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls in
(* Check that there are as many arguments as there are fields - note
that the def should have a body (otherwise we couldn't use the
constructor) *)
@@ -599,8 +595,7 @@ let intro_struct_updates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
(* Check if the definition is recursive *)
let is_rec =
match
- TypeDeclId.Map.find adt_id
- ctx.type_context.type_decls_groups
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls_groups
with
| NonRec _ -> false
| Rec _ -> true
@@ -682,8 +677,8 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
| _ -> false
in
(* And either:
- * 2.1 the right-expression is a variable or a global *)
- let var_or_global = is_var re || is_global re in
+ * 2.1 the right-expression is a variable, a global or a const generic var *)
+ let var_or_global = is_var re || is_cvar re || is_global re in
(* Or:
* 2.2 the right-expression is a constant value, an ADT value,
* a projection or a primitive function call *and* the flag
@@ -767,10 +762,10 @@ let inline_useless_var_reassignments (inline_named : bool) (inline_pure : bool)
In this situation, we can remove the call [f@fwd x].
*)
let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
- (id0 : A.fun_id) (lp_id0 : LoopId.id option)
- (rg_id0 : T.RegionGroupId.id option) (tys0 : ty list)
+ (id0 : fun_id_or_trait_method_ref) (lp_id0 : LoopId.id option)
+ (rg_id0 : T.RegionGroupId.id option) (generics0 : generic_args)
(args0 : texpression list) (e : texpression) : bool =
- let check_call (fun_id1 : fun_or_op_id) (tys1 : ty list)
+ let check_call (fun_id1 : fun_or_op_id) (generics1 : generic_args)
(args1 : texpression list) : bool =
(* Check the fun_ids, to see if call1's function is a child of call0's function *)
match fun_id1 with
@@ -793,7 +788,12 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
(* We need to use the regions hierarchy *)
(* First, lookup the signature of the LLBC function *)
let sg =
- LlbcAstUtils.lookup_fun_sig id0 ctx.fun_context.fun_decls
+ let id0 =
+ match id0 with
+ | FunId fun_id -> fun_id
+ | TraitMethod (_, _, fun_decl_id) -> A.Regular fun_decl_id
+ in
+ LlbcAstUtils.lookup_fun_sig id0 ctx.fun_ctx.fun_decls
in
(* Compute the set of ancestors of the function in call1 *)
let call1_ancestors =
@@ -817,8 +817,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
let input_eq (v0, v1) =
PureUtils.remove_meta v0 = PureUtils.remove_meta v1
in
- (* Compare the input types and the prefix of the input arguments *)
- tys0 = tys1 && List.for_all input_eq args
+ (* Compare the generics and the prefix of the input arguments *)
+ generics0 = generics1 && List.for_all input_eq args
else (* Not a child *)
false
else (* Not the same function *)
@@ -834,7 +834,7 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
method! visit_texpression env e =
match e.e with
- | Var _ | Const _ -> fun _ -> false
+ | Var _ | CVar _ | Const _ -> fun _ -> false
| StructUpdate _ ->
(* There shouldn't be monadic calls in structure updates - also
note that by returning [false] we are conservative: we might
@@ -844,8 +844,8 @@ let expression_contains_child_call_in_all_paths (ctx : trans_ctx)
| Let (_, _, re, e) -> (
match opt_destruct_function_call re with
| None -> fun () -> self#visit_texpression env e ()
- | Some (func1, tys1, args1) ->
- let call_is_child = check_call func1 tys1 args1 in
+ | Some (func1, generics1, args1) ->
+ let call_is_child = check_call func1 generics1 args1 in
if call_is_child then fun () -> true
else fun () -> self#visit_texpression env e ())
| App _ -> (
@@ -930,7 +930,7 @@ let filter_useless (filter_monadic_calls : bool) (ctx : trans_ctx)
method! visit_expression env e =
match e with
- | Var _ | Const _ | App _ | Qualif _
+ | Var _ | CVar _ | Const _ | App _ | Qualif _
| Switch (_, _)
| Meta (_, _)
| StructUpdate _ | Abs _ ->
@@ -1086,13 +1086,12 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| Qualif
{
id = AdtCons { adt_id = AdtId adt_id; variant_id = None };
- type_args;
- const_generic_args;
+ generics;
} ->
(* This is a struct *)
(* Retrieve the definiton, to find how many fields there are *)
let adt_decl =
- TypeDeclId.Map.find adt_id ctx.type_context.type_decls
+ TypeDeclId.Map.find adt_id ctx.type_ctx.type_decls
in
let fields =
match adt_decl.kind with
@@ -1108,7 +1107,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
* [x.field] for some variable [x], and where the projection
* is for the proper ADT *)
let to_var_proj (i : int) (arg : texpression) :
- (ty list * const_generic list * var_id) option =
+ (generic_args * var_id) option =
match arg.e with
| App (proj, x) -> (
match (proj.e, x.e) with
@@ -1116,16 +1115,14 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
{
id =
Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = proj_type_args;
- const_generic_args = proj_const_generic_args;
+ generics = proj_generics;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
if
proj_adt_id = adt_id
&& FieldId.to_int field_id = i
- then
- Some (proj_type_args, proj_const_generic_args, v)
+ then Some (proj_generics, v)
else None
| _ -> None)
| _ -> None
@@ -1136,14 +1133,13 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
if List.length args = num_fields then
(* Check that this is the same variable we project from -
* note that we checked above that there is at least one field *)
- let (_, _, x), end_args = Collections.List.pop args in
- if List.for_all (fun (_, _, y) -> y = x) end_args then (
+ let (_, x), end_args = Collections.List.pop args in
+ if List.for_all (fun (_, y) -> y = x) end_args then (
(* We can substitute *)
(* Sanity check: all types correct *)
assert (
List.for_all
- (fun (tys, cgs, _) ->
- tys = type_args && cgs = const_generic_args)
+ (fun (generics1, _) -> generics1 = generics)
args);
{ e with e = Var x })
else super#visit_texpression env e
@@ -1162,8 +1158,7 @@ let simplify_aggregates (ctx : trans_ctx) (def : fun_decl) : fun_decl =
| ( Qualif
{
id = Proj { adt_id = AdtId proj_adt_id; field_id };
- type_args = _;
- const_generic_args = _;
+ generics = _;
},
Var v ) ->
(* We check that this is the proper ADT, and the proper field *)
@@ -1361,8 +1356,8 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_sig =
{
- type_params = fun_sig.type_params;
- const_generic_params = fun_sig.const_generic_params;
+ generics = fun_sig.generics;
+ preds = fun_sig.preds;
inputs = inputs_tys;
output;
doutputs;
@@ -1427,6 +1422,7 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
let loop_def =
{
def_id = def.def_id;
+ kind = def.kind;
num_loops;
loop_id = Some loop.loop_id;
back_id = def.back_id;
@@ -1466,13 +1462,12 @@ let decompose_loops (def : fun_decl) : fun_decl * fun_decl list =
In such situation, we can remove the forward function definition
altogether.
*)
-let keep_forward (trans : pure_fun_translation) : bool =
- let (fwd, _), backs = trans in
+let keep_forward (fwd : fun_and_loops) (backs : fun_and_loops list) : bool =
(* Note that at this point, the output types are no longer seen as tuples:
* they should be lists of length 1. *)
if
!Config.filter_useless_functions
- && fwd.signature.output = mk_result_ty mk_unit_ty
+ && fwd.f.signature.output = mk_result_ty mk_unit_ty
&& backs <> []
then false
else true
@@ -1528,7 +1523,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
match opt_destruct_function_call e with
| Some (fun_id, _tys, args) -> (
match fun_id with
- | Fun (FromLlbc (A.Assumed aid, _lp_id, rg_id)) -> (
+ | Fun (FromLlbc (FunId (A.Assumed aid), _lp_id, rg_id)) -> (
(* Below, when dealing with the arguments: we consider the very
* general case, where functions could be boxed (meaning we
* could have: [box_new f x])
@@ -1568,7 +1563,7 @@ let eliminate_box_functions (_ctx : trans_ctx) (def : fun_decl) : fun_decl =
| ArraySubsliceMut | SliceIndexShared | SliceIndexMut
| SliceSubsliceShared | SliceSubsliceMut | ArrayIndexShared
| ArrayIndexMut | ArrayToSliceShared | ArrayToSliceMut
- | SliceLen ),
+ | ArrayRepeat | SliceLen ),
_ ) ->
super#visit_texpression env e)
| _ -> super#visit_texpression env e)
@@ -1914,7 +1909,7 @@ let apply_end_passes_to_def (ctx : trans_ctx) (def : fun_decl) : fun_decl =
[ctx]: used only for printing.
*)
let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
- (fun_decl * fun_decl list) option =
+ fun_and_loops option =
(* Debug *)
log#ldebug
(lazy
@@ -1955,9 +1950,9 @@ let apply_passes_to_def (ctx : trans_ctx) (def : fun_decl) :
let def, loops = decompose_loops def in
(* Apply the remaining passes *)
- let def = apply_end_passes_to_def ctx def in
+ let f = apply_end_passes_to_def ctx def in
let loops = List.map (apply_end_passes_to_def ctx) loops in
- Some (def, loops)
+ Some { f; loops }
(** Small utility for {!filter_loop_inputs} *)
let filter_prefix (keep : bool list) (ls : 'a list) : 'a list =
@@ -1983,8 +1978,8 @@ end
module FunLoopIdMap = Collections.MakeMap (FunLoopIdOrderedType)
(** Filter the useless loop input parameters. *)
-let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
- (bool * pure_fun_translation) list =
+let filter_loop_inputs (transl : pure_fun_translation list) :
+ pure_fun_translation list =
(* We need to explore groups of mutually recursive functions. In order
to compute which parameters are useless, we need to explore the
functions by groups of mutually recursive definitions.
@@ -2002,10 +1997,11 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(List.concat
(List.concat
(List.map
- (fun (_, ((fwd, loops_fwd), backs)) ->
- [ fwd :: loops_fwd ]
+ (fun { fwd; backs; _ } ->
+ [ fwd.f :: fwd.loops ]
:: List.map
- (fun (back, loops_back) -> [ back :: loops_back ])
+ (fun { f = back; loops = loops_back } ->
+ [ back :: loops_back ])
backs)
transl)))
in
@@ -2030,7 +2026,6 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
additional parameters.
*)
let used_map = ref FunLoopIdMap.empty in
- let fun_id_to_fun_loop_id (fid, loop_id, _) = (fid, loop_id) in
(* We start by computing the filtering information, for each function *)
let compute_one_filter_info (decl : fun_decl) =
@@ -2075,8 +2070,8 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id')) ->
- if fun_id_to_fun_loop_id fun_id' = fun_id then (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id', loop_id', _))) ->
+ if (fun_id', loop_id') = fun_id then (
(* For each argument, check if it is exactly the original
input parameter. Note that there shouldn't be partial
applications of loop functions: the number of arguments
@@ -2135,22 +2130,15 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
(* We then apply the filtering to all the function definitions at once *)
let filter_in_one (decl : fun_decl) : fun_decl =
(* Filter the function signature *)
- let fun_id = (A.Regular decl.def_id, decl.loop_id, decl.back_id) in
+ let fun_id = (A.Regular decl.def_id, decl.loop_id) in
let decl =
- match FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) decl
| Some used_info ->
let num_filtered =
List.length (List.filter (fun b -> not b) used_info)
in
- let {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- } =
+ let { generics; preds; inputs; output; doutputs; info } =
decl.signature
in
let {
@@ -2178,16 +2166,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
effect_info;
}
in
- let signature =
- {
- type_params;
- const_generic_params;
- inputs;
- output;
- doutputs;
- info;
- }
- in
+ let signature = { generics; preds; inputs; output; doutputs; info } in
{ decl with signature }
in
@@ -2201,9 +2180,7 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
let { inputs; inputs_lvs; body } = body in
let inputs, inputs_lvs =
- match
- FunLoopIdMap.find_opt (fun_id_to_fun_loop_id fun_id) !used_map
- with
+ match FunLoopIdMap.find_opt fun_id !used_map with
| None -> (* Nothing to filter *) (inputs, inputs_lvs)
| Some used_info ->
let inputs = filter_prefix used_info inputs in
@@ -2223,11 +2200,10 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
match e_app.e with
| Qualif qualif -> (
match qualif.id with
- | FunOrOp (Fun (FromLlbc fun_id)) -> (
+ | FunOrOp (Fun (FromLlbc (FunId fun_id, loop_id, _)))
+ -> (
match
- FunLoopIdMap.find_opt
- (fun_id_to_fun_loop_id fun_id)
- !used_map
+ FunLoopIdMap.find_opt (fun_id, loop_id) !used_map
with
| None -> super#visit_texpression env e
| Some used_info ->
@@ -2267,13 +2243,13 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
in
let transl =
List.map
- (fun (b, (fwd, backs)) ->
- let filter_fun_and_loops (f, fl) =
- (filter_in_one f, List.map filter_in_one fl)
+ (fun trans ->
+ let filter_fun_and_loops f =
+ { f = filter_in_one f.f; loops = List.map filter_in_one f.loops }
in
- let fwd = filter_fun_and_loops fwd in
- let backs = List.map filter_fun_and_loops backs in
- (b, (fwd, backs)))
+ let fwd = filter_fun_and_loops trans.fwd in
+ let backs = List.map filter_fun_and_loops trans.backs in
+ { trans with fwd; backs })
transl
in
@@ -2294,18 +2270,17 @@ let filter_loop_inputs (transl : (bool * pure_fun_translation) list) :
but convenient.
*)
let apply_passes_to_pure_fun_translations (ctx : trans_ctx)
- (transl : (fun_decl * fun_decl list) list) :
- (bool * pure_fun_translation) list =
- let apply_to_one (trans : fun_decl * fun_decl list) :
- bool * pure_fun_translation =
+ (transl : (fun_decl * fun_decl list) list) : pure_fun_translation list =
+ let apply_to_one (trans : fun_decl * fun_decl list) : pure_fun_translation =
(* Apply the passes to the individual functions *)
- let forward, backwards = trans in
- let forward = Option.get (apply_passes_to_def ctx forward) in
- let backwards = List.filter_map (apply_passes_to_def ctx) backwards in
- let trans = (forward, backwards) in
+ let fwd, backs = trans in
+ let fwd = Option.get (apply_passes_to_def ctx fwd) in
+ let backs = List.filter_map (apply_passes_to_def ctx) backs in
(* Compute whether we need to filter the forward function or not *)
- (keep_forward trans, trans)
+ let keep_fwd = keep_forward fwd backs in
+ { keep_fwd; fwd; backs }
in
+
let transl = List.map apply_to_one transl in
(* Filter the useless inputs in the loop functions *)
diff --git a/compiler/PureTypeCheck.ml b/compiler/PureTypeCheck.ml
index 8d28bb8a..b80ff72f 100644
--- a/compiler/PureTypeCheck.ml
+++ b/compiler/PureTypeCheck.ml
@@ -9,17 +9,19 @@ open PureUtils
of fields is fixed: it shouldn't be used for arrays, slices, etc.
*)
let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
- (type_id : type_id) (variant_id : VariantId.id option) (tys : ty list)
- (cgs : const_generic list) : ty list =
+ (type_id : type_id) (variant_id : VariantId.id option)
+ (generics : generic_args) : ty list =
match type_id with
| Tuple ->
(* Tuple *)
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
assert (variant_id = None);
- tys
+ generics.types
| AdtId def_id ->
(* "Regular" ADT *)
let def = TypeDeclId.Map.find def_id type_decls in
- type_decl_get_instantiated_fields_types def variant_id tys cgs
+ type_decl_get_instantiated_fields_types def variant_id generics
| Assumed aty -> (
(* Assumed type *)
match aty with
@@ -27,14 +29,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
(* This type is opaque *)
raise (Failure "Unreachable: opaque type")
| Result ->
- let ty = Collections.List.to_cons_nil tys in
+ let ty = Collections.List.to_cons_nil generics.types in
let variant_id = Option.get variant_id in
if variant_id = result_return_id then [ ty ]
else if variant_id = result_fail_id then [ mk_error_ty ]
else
raise (Failure "Unreachable: improper variant id for result type")
| Error ->
- assert (tys = []);
+ assert (generics = empty_generic_args);
let variant_id = Option.get variant_id in
assert (
variant_id = error_failure_id || variant_id = error_out_of_fuel_id);
@@ -45,14 +47,14 @@ let get_adt_field_types (type_decls : type_decl TypeDeclId.Map.t)
else if variant_id = fuel_succ_id then [ mk_fuel_ty ]
else raise (Failure "Unreachable: improper variant id for fuel type")
| Option ->
- let ty = Collections.List.to_cons_nil tys in
+ let ty = Collections.List.to_cons_nil generics.types in
let variant_id = Option.get variant_id in
if variant_id = option_some_id then [ ty ]
else if variant_id = option_none_id then []
else
raise (Failure "Unreachable: improper variant id for option type")
| Range ->
- let ty = Collections.List.to_cons_nil tys in
+ let ty = Collections.List.to_cons_nil generics.types in
assert (variant_id = None);
[ ty; ty ]
| Vec | Array | Slice | Str ->
@@ -65,6 +67,9 @@ type tc_ctx = {
global_decls : A.global_decl A.GlobalDeclId.Map.t;
(** The global declarations *)
env : ty VarId.Map.t; (** Environment from variables to types *)
+ const_generics : ty T.ConstGenericVarId.Map.t;
+ (** The types of the const generics *)
+ (* TODO: add trait type constraints *)
}
let check_literal (v : literal) (ty : literal_type) : unit =
@@ -86,12 +91,13 @@ let rec check_typed_pattern (ctx : tc_ctx) (v : typed_pattern) : tc_ctx =
{ ctx with env }
| PatAdt av ->
(* Compute the field types *)
- let type_id, tys, cgs = ty_as_adt v.ty in
+ let type_id, generics = ty_as_adt v.ty in
let field_tys =
- get_adt_field_types ctx.type_decls type_id av.variant_id tys cgs
+ get_adt_field_types ctx.type_decls type_id av.variant_id generics
in
let check_value (ctx : tc_ctx) (ty : ty) (v : typed_pattern) : tc_ctx =
if ty <> v.ty then (
+ (* TODO: we need to normalize the types *)
log#serror
("check_typed_pattern: not the same types:" ^ "\n- ty: "
^ show_ty ty ^ "\n- v.ty: " ^ show_ty v.ty);
@@ -115,6 +121,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
match VarId.Map.find_opt var_id ctx.env with
| None -> ()
| Some ty -> assert (ty = e.ty))
+ | CVar cg_id ->
+ let ty = T.ConstGenericVarId.Map.find cg_id ctx.const_generics in
+ assert (ty = e.ty)
| Const cv -> check_literal cv (ty_as_literal e.ty)
| App (app, arg) ->
let input_ty, output_ty = destruct_arrow app.ty in
@@ -133,35 +142,34 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
match qualif.id with
| FunOrOp _ -> () (* TODO *)
| Global _ -> () (* TODO *)
+ | TraitConst _ -> () (* TODO *)
| Proj { adt_id = proj_adt_id; field_id } ->
(* Note we can only project fields of structures (not enumerations) *)
(* Deconstruct the projector type *)
let adt_ty, field_ty = destruct_arrow e.ty in
- let adt_id, adt_type_args, adt_cg_args = ty_as_adt adt_ty in
+ let adt_id, adt_generics = ty_as_adt adt_ty in
(* Check the ADT type *)
assert (adt_id = proj_adt_id);
- assert (adt_type_args = qualif.type_args);
- assert (adt_cg_args = qualif.const_generic_args);
+ assert (adt_generics = qualif.generics);
(* Retrieve and check the expected field type *)
let variant_id = None in
let expected_field_tys =
get_adt_field_types ctx.type_decls proj_adt_id variant_id
- qualif.type_args qualif.const_generic_args
+ qualif.generics
in
let expected_field_ty = FieldId.nth expected_field_tys field_id in
assert (expected_field_ty = field_ty)
| AdtCons id -> (
let expected_field_tys =
get_adt_field_types ctx.type_decls id.adt_id id.variant_id
- qualif.type_args qualif.const_generic_args
+ qualif.generics
in
let field_tys, adt_ty = destruct_arrows e.ty in
assert (expected_field_tys = field_tys);
match adt_ty with
- | Adt (type_id, tys, cgs) ->
+ | Adt (type_id, generics) ->
assert (type_id = id.adt_id);
- assert (tys = qualif.type_args);
- assert (cgs = qualif.const_generic_args)
+ assert (generics = qualif.generics)
| _ -> raise (Failure "Unreachable")))
| Let (monadic, pat, re, e_next) ->
let expected_pat_ty = if monadic then destruct_result re.ty else re.ty in
@@ -207,15 +215,14 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
| Some ty -> assert (ty = e.ty));
(* Check the fields *)
(* Retrieve and check the expected field type *)
- let adt_id, adt_type_args, adt_cg_args = ty_as_adt e.ty in
+ let adt_id, adt_generics = ty_as_adt e.ty in
assert (adt_id = supd.struct_id);
(* The id can only be: a custom type decl or an array *)
match adt_id with
| AdtId _ ->
let variant_id = None in
let expected_field_tys =
- get_adt_field_types ctx.type_decls adt_id variant_id adt_type_args
- adt_cg_args
+ get_adt_field_types ctx.type_decls adt_id variant_id adt_generics
in
List.iter
(fun (fid, fe) ->
@@ -224,7 +231,9 @@ let rec check_texpression (ctx : tc_ctx) (e : texpression) : unit =
check_texpression ctx fe)
supd.updates
| Assumed Array ->
- let expected_field_ty = Collections.List.to_cons_nil adt_type_args in
+ let expected_field_ty =
+ Collections.List.to_cons_nil adt_generics.types
+ in
List.iter
(fun (_, fe) ->
assert (expected_field_ty = fe.ty);
diff --git a/compiler/PureUtils.ml b/compiler/PureUtils.ml
index 1c8d8921..4e44f252 100644
--- a/compiler/PureUtils.ml
+++ b/compiler/PureUtils.ml
@@ -89,14 +89,31 @@ let mk_mplace (var_id : E.VarId.id) (name : string option)
(projection : mprojection) : mplace =
{ var_id; name; projection }
+let empty_generic_params : generic_params =
+ { types = []; const_generics = []; trait_clauses = [] }
+
+let empty_generic_args : generic_args =
+ { types = []; const_generics = []; trait_refs = [] }
+
+let mk_generic_args_from_types (types : ty list) : generic_args =
+ { types; const_generics = []; trait_refs = [] }
+
+type subst = {
+ ty_subst : TypeVarId.id -> ty;
+ cg_subst : ConstGenericVarId.id -> const_generic;
+ tr_subst : TraitClauseId.id -> trait_instance_id;
+ tr_self : trait_instance_id;
+}
+
(** Type substitution *)
-let ty_substitute (tsubst : TypeVarId.id -> ty)
- (cgsubst : ConstGenericVarId.id -> const_generic) (ty : ty) : ty =
+let ty_substitute (subst : subst) (ty : ty) : ty =
let obj =
object
inherit [_] map_ty
- method! visit_TypeVar _ var_id = tsubst var_id
- method! visit_ConstGenericVar _ var_id = cgsubst var_id
+ method! visit_TypeVar _ var_id = subst.ty_subst var_id
+ method! visit_ConstGenericVar _ var_id = subst.cg_subst var_id
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
end
in
obj#visit_ty () ty
@@ -115,6 +132,18 @@ let make_const_generic_subst (vars : const_generic_var list)
(cgs : const_generic list) : ConstGenericVarId.id -> const_generic =
Substitute.make_const_generic_subst_from_vars vars cgs
+let make_trait_subst (clauses : trait_clause list) (refs : trait_ref list) :
+ TraitClauseId.id -> trait_instance_id =
+ let clauses = List.map (fun x -> x.clause_id) clauses in
+ let refs = List.map (fun x -> TraitRef x) refs in
+ let ls = List.combine clauses refs in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> TraitClauseId.Map.add k v mp)
+ TraitClauseId.Map.empty ls
+ in
+ fun id -> TraitClauseId.Map.find id mp
+
(** Retrieve the list of fields for the given variant of a {!type:Aeneas.Pure.type_decl}.
Raises [Invalid_argument] if the arguments are incorrect.
@@ -135,20 +164,27 @@ let type_decl_get_fields (def : type_decl)
- def: " ^ show_type_decl def ^ "\n- opt_variant_id: "
^ opt_variant_id))
+let make_subst_from_generics (params : generic_params) (args : generic_args)
+ (tr_self : trait_instance_id) : subst =
+ let ty_subst = make_type_subst params.types args.types in
+ let cg_subst =
+ make_const_generic_subst params.const_generics args.const_generics
+ in
+ let tr_subst = make_trait_subst params.trait_clauses args.trait_refs in
+ { ty_subst; cg_subst; tr_subst; tr_self }
+
(** Instantiate the type variables for the chosen variant in an ADT definition,
and return the list of the types of its fields *)
let type_decl_get_instantiated_fields_types (def : type_decl)
- (opt_variant_id : VariantId.id option) (types : ty list)
- (cgs : const_generic list) : ty list =
- let ty_subst = make_type_subst def.type_params types in
- let cg_subst = make_const_generic_subst def.const_generic_params cgs in
+ (opt_variant_id : VariantId.id option) (generics : generic_args) : ty list =
+ (* There shouldn't be any reference to Self *)
+ let tr_self = UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.generics generics tr_self in
let fields = type_decl_get_fields def opt_variant_id in
- List.map (fun f -> ty_substitute ty_subst cg_subst f.field_ty) fields
+ List.map (fun f -> ty_substitute subst f.field_ty) fields
-let fun_sig_substitute (tsubst : TypeVarId.id -> ty)
- (cgsubst : ConstGenericVarId.id -> const_generic) (sg : fun_sig) :
- inst_fun_sig =
- let subst = ty_substitute tsubst cgsubst in
+let fun_sig_substitute (subst : subst) (sg : fun_sig) : inst_fun_sig =
+ let subst = ty_substitute subst in
let inputs = List.map subst sg.inputs in
let output = subst sg.output in
let doutputs = List.map subst sg.doutputs in
@@ -164,7 +200,8 @@ let fun_sig_substitute (tsubst : TypeVarId.id -> ty)
*)
let rec let_group_requires_parentheses (e : texpression) : bool =
match e.e with
- | Var _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ -> false
+ | Var _ | CVar _ | Const _ | App _ | Abs _ | Qualif _ | StructUpdate _ ->
+ false
| Let (monadic, _, _, next_e) ->
if monadic then true else let_group_requires_parentheses next_e
| Switch (_, _) -> false
@@ -184,15 +221,18 @@ let is_var (e : texpression) : bool =
let as_var (e : texpression) : VarId.id =
match e.e with Var v -> v | _ -> raise (Failure "Unreachable")
+let is_cvar (e : texpression) : bool =
+ match e.e with CVar _ -> true | _ -> false
+
let is_global (e : texpression) : bool =
match e.e with Qualif { id = Global _; _ } -> true | _ -> false
let is_const (e : texpression) : bool =
match e.e with Const _ -> true | _ -> false
-let ty_as_adt (ty : ty) : type_id * ty list * const_generic list =
+let ty_as_adt (ty : ty) : type_id * generic_args =
match ty with
- | Adt (id, tys, cgs) -> (id, tys, cgs)
+ | Adt (id, generics) -> (id, generics)
| _ -> raise (Failure "Unreachable")
(** Remove the external occurrences of {!Meta} *)
@@ -290,28 +330,30 @@ let destruct_qualif_app (e : texpression) : qualif * texpression list =
(** Destruct an expression into a function call, if possible *)
let opt_destruct_function_call (e : texpression) :
- (fun_or_op_id * ty list * texpression list) option =
+ (fun_or_op_id * generic_args * texpression list) option =
match opt_destruct_qualif_app e with
| None -> None
| Some (qualif, args) -> (
match qualif.id with
- | FunOrOp fun_id -> Some (fun_id, qualif.type_args, args)
+ | FunOrOp fun_id -> Some (fun_id, qualif.generics, args)
| _ -> None)
let opt_destruct_result (ty : ty) : ty option =
match ty with
- | Adt (Assumed Result, tys, cgs) ->
- assert (cgs = []);
- Some (Collections.List.to_cons_nil tys)
+ | Adt (Assumed Result, generics) ->
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
+ Some (Collections.List.to_cons_nil generics.types)
| _ -> None
let destruct_result (ty : ty) : ty = Option.get (opt_destruct_result ty)
let opt_destruct_tuple (ty : ty) : ty list option =
match ty with
- | Adt (Tuple, tys, cgs) ->
- assert (cgs = []);
- Some tys
+ | Adt (Tuple, generics) ->
+ assert (generics.const_generics = []);
+ assert (generics.trait_refs = []);
+ Some generics.types
| _ -> None
let mk_abs (x : typed_pattern) (e : texpression) : texpression =
@@ -383,14 +425,16 @@ let mk_switch (scrut : texpression) (sb : switch_body) : texpression =
- if there is > one type: wrap them in a tuple
*)
let mk_simpl_tuple_ty (tys : ty list) : ty =
- match tys with [ ty ] -> ty | _ -> Adt (Tuple, tys, [])
+ match tys with
+ | [ ty ] -> ty
+ | _ -> Adt (Tuple, mk_generic_args_from_types tys)
let mk_bool_ty : ty = Literal Bool
-let mk_unit_ty : ty = Adt (Tuple, [], [])
+let mk_unit_ty : ty = Adt (Tuple, empty_generic_args)
let mk_unit_rvalue : texpression =
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = []; const_generic_args = [] } in
+ let qualif = { id; generics = empty_generic_args } in
let e = Qualif qualif in
let ty = mk_unit_ty in
{ e; ty }
@@ -430,7 +474,7 @@ let mk_simpl_tuple_pattern (vl : typed_pattern list) : typed_pattern =
| [ v ] -> v
| _ ->
let tys = List.map (fun (v : typed_pattern) -> v.ty) vl in
- let ty = Adt (Tuple, tys, []) in
+ let ty = Adt (Tuple, mk_generic_args_from_types tys) in
let value = PatAdt { variant_id = None; field_values = vl } in
{ value; ty }
@@ -441,11 +485,11 @@ let mk_simpl_tuple_texpression (vl : texpression list) : texpression =
| _ ->
(* Compute the types of the fields, and the type of the tuple constructor *)
let tys = List.map (fun (v : texpression) -> v.ty) vl in
- let ty = Adt (Tuple, tys, []) in
+ let ty = Adt (Tuple, mk_generic_args_from_types tys) in
let ty = mk_arrows tys ty in
(* Construct the tuple constructor qualifier *)
let id = AdtCons { adt_id = Tuple; variant_id = None } in
- let qualif = { id; type_args = tys; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types tys } in
(* Put everything together *)
let cons = { e = Qualif qualif; ty } in
mk_apps cons vl
@@ -463,32 +507,36 @@ let ty_as_integer (t : ty) : T.integer_type =
let ty_as_literal (t : ty) : T.literal_type =
match t with Literal ty -> ty | _ -> raise (Failure "Unreachable")
-let mk_state_ty : ty = Adt (Assumed State, [], [])
-let mk_result_ty (ty : ty) : ty = Adt (Assumed Result, [ ty ], [])
-let mk_error_ty : ty = Adt (Assumed Error, [], [])
-let mk_fuel_ty : ty = Adt (Assumed Fuel, [], [])
+let mk_state_ty : ty = Adt (Assumed State, empty_generic_args)
+
+let mk_result_ty (ty : ty) : ty =
+ Adt (Assumed Result, mk_generic_args_from_types [ ty ])
+
+let mk_error_ty : ty = Adt (Assumed Error, empty_generic_args)
+let mk_fuel_ty : ty = Adt (Assumed Fuel, empty_generic_args)
let mk_error (error : VariantId.id) : texpression =
let ty = mk_error_ty in
let id = AdtCons { adt_id = Assumed Error; variant_id = Some error } in
- let qualif = { id; type_args = []; const_generic_args = [] } in
+ let qualif = { id; generics = empty_generic_args } in
let e = Qualif qualif in
{ e; ty }
let unwrap_result_ty (ty : ty) : ty =
match ty with
- | Adt (Assumed Result, [ ty ], cgs) ->
- assert (cgs = []);
+ | Adt
+ (Assumed Result, { types = [ ty ]; const_generics = []; trait_refs = [] })
+ ->
ty
| _ -> raise (Failure "not a result type")
let mk_result_fail_texpression (error : texpression) (ty : ty) : texpression =
let type_args = [ ty ] in
- let ty = Adt (Assumed Result, type_args, []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_fail_id }
in
- let qualif = { id; type_args; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types type_args } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow error.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -501,11 +549,11 @@ let mk_result_fail_texpression_with_error_id (error : VariantId.id) (ty : ty) :
let mk_result_return_texpression (v : texpression) : texpression =
let type_args = [ v.ty ] in
- let ty = Adt (Assumed Result, type_args, []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types type_args) in
let id =
AdtCons { adt_id = Assumed Result; variant_id = Some result_return_id }
in
- let qualif = { id; type_args; const_generic_args = [] } in
+ let qualif = { id; generics = mk_generic_args_from_types type_args } in
let cons_e = Qualif qualif in
let cons_ty = mk_arrow v.ty ty in
let cons = { e = cons_e; ty = cons_ty } in
@@ -514,7 +562,7 @@ let mk_result_return_texpression (v : texpression) : texpression =
(** Create a [Fail err] pattern which captures the error *)
let mk_result_fail_pattern (error_pat : pattern) (ty : ty) : typed_pattern =
let error_pat : typed_pattern = { value = error_pat; ty = mk_error_ty } in
- let ty = Adt (Assumed Result, [ ty ], []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types [ ty ]) in
let value =
PatAdt { variant_id = Some result_fail_id; field_values = [ error_pat ] }
in
@@ -526,7 +574,7 @@ let mk_result_fail_pattern_ignore_error (ty : ty) : typed_pattern =
mk_result_fail_pattern error_pat ty
let mk_result_return_pattern (v : typed_pattern) : typed_pattern =
- let ty = Adt (Assumed Result, [ v.ty ], []) in
+ let ty = Adt (Assumed Result, mk_generic_args_from_types [ v.ty ]) in
let value =
PatAdt { variant_id = Some result_return_id; field_values = [ v ] }
in
@@ -561,11 +609,11 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
let fields_values = List.map (fun e -> Option.get e) fields in
(* Retrieve the type id and the type args from the pat type (simpler this way *)
- let adt_id, type_args, const_generic_args = ty_as_adt pat.ty in
+ let adt_id, generics = ty_as_adt pat.ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args; const_generic_args } in
+ let qualif = { id = qualif_id; generics } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) fields_values
@@ -577,3 +625,20 @@ let rec typed_pattern_to_texpression (pat : typed_pattern) : texpression option
Some (mk_apps cons fields_values).e
in
match e_opt with None -> None | Some e -> Some { e; ty = pat.ty }
+
+type trait_decl_method_decl_id = { is_provided : bool; id : fun_decl_id }
+
+let trait_decl_get_method (trait_decl : trait_decl) (method_name : string) :
+ trait_decl_method_decl_id =
+ (* First look in the required methods *)
+ let method_id =
+ List.find_opt (fun (s, _) -> s = method_name) trait_decl.required_methods
+ in
+ match method_id with
+ | Some (_, id) -> { is_provided = false; id }
+ | None ->
+ (* Must be a provided method *)
+ let _, id =
+ List.find (fun (s, _) -> s = method_name) trait_decl.provided_methods
+ in
+ { is_provided = true; id = Option.get id }
diff --git a/compiler/ReorderDecls.ml b/compiler/ReorderDecls.ml
index fc4744bc..10b68da3 100644
--- a/compiler/ReorderDecls.ml
+++ b/compiler/ReorderDecls.ml
@@ -38,14 +38,16 @@ let compute_body_fun_deps (e : texpression) : FunIdSet.t =
method! visit_qualif _ id =
match id.id with
- | FunOrOp (Unop _ | Binop _) | Global _ | AdtCons _ | Proj _ -> ()
+ | FunOrOp (Unop _ | Binop _)
+ | Global _ | AdtCons _ | Proj _ | TraitConst _ ->
+ ()
| FunOrOp (Fun fid) -> (
match fid with
| Pure _ -> ()
| FromLlbc (fid, lp_id, rg_id) -> (
match fid with
- | Assumed _ -> ()
- | Regular fid ->
+ | FunId (Assumed _) -> ()
+ | TraitMethod (_, _, fid) | FunId (Regular fid) ->
let id = { def_id = fid; lp_id; rg_id } in
ids := FunIdSet.add id !ids))
end
diff --git a/compiler/Substitute.ml b/compiler/Substitute.ml
index 38850243..b1680282 100644
--- a/compiler/Substitute.ml
+++ b/compiler/Substitute.ml
@@ -9,51 +9,67 @@ module E = Expressions
module A = LlbcAst
module C = Contexts
-(** Substitute types variables and regions in a type. *)
-let ty_substitute (rsubst : 'r1 -> 'r2) (tsubst : T.TypeVarId.id -> 'r2 T.ty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : 'r1 T.ty) :
- 'r2 T.ty =
- let open T in
- let visitor =
- object
- inherit [_] map_ty
- method visit_'r _ r = rsubst r
- method! visit_TypeVar _ id = tsubst id
+type ('r1, 'r2) subst = {
+ r_subst : 'r1 -> 'r2;
+ ty_subst : T.TypeVarId.id -> 'r2 T.ty;
+ cg_subst : T.ConstGenericVarId.id -> T.const_generic;
+ (** Substitution from *local* trait clause to trait instance *)
+ tr_subst : T.TraitClauseId.id -> 'r2 T.trait_instance_id;
+ (** Substitution for the [Self] trait instance *)
+ tr_self : 'r2 T.trait_instance_id;
+}
+
+let ty_substitute_visitor (subst : ('r1, 'r2) subst) =
+ object
+ inherit [_] T.map_ty
+ method visit_'r _ r = subst.r_subst r
+ method! visit_TypeVar _ id = subst.ty_subst id
- method! visit_type_var_id _ _ =
- (* We should never get here because we reimplemented [visit_TypeVar] *)
- raise (Failure "Unexpected")
+ method! visit_type_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_TypeVar] *)
+ raise (Failure "Unexpected")
- method! visit_ConstGenericVar _ id = cgsubst id
+ method! visit_ConstGenericVar _ id = subst.cg_subst id
- method! visit_const_generic_var_id _ _ =
- (* We should never get here because we reimplemented [visit_Var] *)
- raise (Failure "Unexpected")
- end
- in
+ method! visit_const_generic_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_Var] *)
+ raise (Failure "Unexpected")
- visitor#visit_ty () ty
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
+ end
-let rty_substitute (rsubst : T.RegionId.id -> T.RegionId.id)
- (tsubst : T.TypeVarId.id -> T.rty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.rty) : T.rty =
- let rsubst r =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
- in
- ty_substitute rsubst tsubst cgsubst ty
+(** Substitute types variables and regions in a type.
+
+ **IMPORTANT**: this doesn't normalize the types.
+ *)
+let ty_substitute (subst : ('r1, 'r2) subst) (ty : 'r1 T.ty) : 'r2 T.ty =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_ty () ty
-let ety_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (ty : T.ety) : T.ety =
- let rsubst r = r in
- ty_substitute rsubst tsubst cgsubst ty
+(** **IMPORTANT**: this doesn't normalize the types. *)
+let trait_ref_substitute (subst : ('r1, 'r2) subst) (tr : 'r1 T.trait_ref) :
+ 'r2 T.trait_ref =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_trait_ref () tr
+
+(** **IMPORTANT**: this doesn't normalize the types. *)
+let generic_args_substitute (subst : ('r1, 'r2) subst) (g : 'r1 T.generic_args)
+ : 'r2 T.generic_args =
+ let visitor = ty_substitute_visitor subst in
+ visitor#visit_generic_args () g
+
+let erase_regions_subst : ('r, T.erased_region) subst =
+ {
+ r_subst = (fun _ -> T.Erased);
+ ty_subst = (fun vid -> T.TypeVar vid);
+ cg_subst = (fun id -> T.ConstGenericVar id);
+ tr_subst = (fun id -> T.Clause id);
+ tr_self = T.Self;
+ }
(** Convert an {!T.rty} to an {!T.ety} by erasing the region variables *)
-let erase_regions (ty : T.rty) : T.ety =
- ty_substitute
- (fun _ -> T.Erased)
- (fun vid -> T.TypeVar vid)
- (fun id -> T.ConstGenericVar id)
- ty
+let erase_regions (ty : 'r T.ty) : T.ety = ty_substitute erase_regions_subst ty
(** Generate fresh regions for region variables.
@@ -78,18 +94,20 @@ let fresh_regions_with_substs (region_vars : T.region_var list) :
(* Generate the substitution from region var id to region *)
let rid_subst id = T.RegionVarId.Map.find id rid_map in
(* Generate the substitution from region to region *)
- let rsubst r =
+ let r_subst r =
match r with T.Static -> T.Static | T.Var id -> T.Var (rid_subst id)
in
(* Return *)
- (fresh_region_ids, rid_subst, rsubst)
+ (fresh_region_ids, rid_subst, r_subst)
-(** Erase the regions in a type and substitute the type variables *)
-let erase_regions_substitute_types (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic)
- (ty : 'r T.region T.ty) : T.ety =
- let rsubst (_ : 'r T.region) : T.erased_region = T.Erased in
- ty_substitute rsubst tsubst cgsubst ty
+(** Erase the regions in a type and perform a substitution *)
+let erase_regions_substitute_types (ty_subst : T.TypeVarId.id -> T.ety)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.etrait_instance_id)
+ (tr_self : T.etrait_instance_id) (ty : 'r T.ty) : T.ety =
+ let r_subst (_ : 'r) : T.erased_region = T.Erased in
+ let subst = { r_subst; ty_subst; cg_subst; tr_subst; tr_self } in
+ ty_substitute subst ty
(** Create a region substitution from a list of region variable ids and a list of
regions (with which to substitute the region variable ids *)
@@ -146,16 +164,81 @@ let make_const_generic_subst_from_vars (vars : T.const_generic_var list)
(List.map (fun (x : T.const_generic_var) -> x.T.index) vars)
cgs
-(** Instantiate the type variables in an ADT definition, and return, for
- every variant, the list of the types of its fields *)
-let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : (T.VariantId.id option * T.rty list) list =
- let r_subst = make_region_subst_from_vars def.T.region_params regions in
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
+(** Create a trait substitution from a list of trait clause ids and a list of
+ trait refs *)
+let make_trait_subst (clause_ids : T.TraitClauseId.id list)
+ (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id =
+ let ls = List.combine clause_ids trs in
+ let mp =
+ List.fold_left
+ (fun mp (k, v) -> T.TraitClauseId.Map.add k (T.TraitRef v) mp)
+ T.TraitClauseId.Map.empty ls
+ in
+ fun id -> T.TraitClauseId.Map.find id mp
+
+let make_trait_subst_from_clauses (clauses : T.trait_clause list)
+ (trs : 'r T.trait_ref list) : T.TraitClauseId.id -> 'r T.trait_instance_id =
+ make_trait_subst
+ (List.map (fun (x : T.trait_clause) -> x.T.clause_id) clauses)
+ trs
+
+let make_subst_from_generics (params : T.generic_params)
+ (args : 'r T.region T.generic_args)
+ (tr_self : 'r T.region T.trait_instance_id) :
+ (T.region_var_id T.region, 'r T.region) subst =
+ let r_subst = make_region_subst_from_vars params.T.regions args.T.regions in
+ let ty_subst = make_type_subst_from_vars params.T.types args.T.types in
let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
+ make_const_generic_subst_from_vars params.T.const_generics
+ args.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs
+ in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+let make_subst_from_generics_no_regions :
+ 'r.
+ T.generic_params ->
+ 'r T.generic_args ->
+ 'r T.trait_instance_id ->
+ (T.region_var_id T.region, 'r) subst =
+ fun params args tr_self ->
+ let r_subst _ = raise (Failure "Unexpected region") in
+ let ty_subst = make_type_subst_from_vars params.T.types args.T.types in
+ let cg_subst =
+ make_const_generic_subst_from_vars params.T.const_generics
+ args.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.T.trait_clauses args.T.trait_refs
+ in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+let make_esubst_from_generics (params : T.generic_params)
+ (generics : T.egeneric_args) (tr_self : T.etrait_instance_id) =
+ let r_subst _ = T.Erased in
+ let ty_subst = make_type_subst_from_vars params.types generics.T.types in
+ let cg_subst =
+ make_const_generic_subst_from_vars params.const_generics
+ generics.T.const_generics
+ in
+ let tr_subst =
+ make_trait_subst_from_clauses params.trait_clauses generics.T.trait_refs
in
+ { r_subst; ty_subst; cg_subst; tr_subst; tr_self }
+
+(** Instantiate the type variables in an ADT definition, and return, for
+ every variant, the list of the types of its fields.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
+let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
+ (generics : T.rgeneric_args) : (T.VariantId.id option * T.rty list) list =
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.T.generics generics tr_self in
let (variants_fields : (T.VariantId.id option * T.field list) list) =
match def.T.kind with
| T.Enum variants ->
@@ -171,191 +254,232 @@ let type_decl_get_instantiated_variants_fields_rtypes (def : T.type_decl)
in
List.map
(fun (id, fields) ->
- ( id,
- List.map
- (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty)
- fields ))
+ (id, List.map (fun f -> ty_substitute subst f.T.field_ty) fields))
variants_fields
(** Instantiate the type variables in an ADT definition, and return the list
- of types of the fields for the chosen variant *)
+ of types of the fields for the chosen variant.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let type_decl_get_instantiated_field_rtypes (def : T.type_decl)
- (opt_variant_id : T.VariantId.id option)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : T.rty list =
- let r_subst = make_region_subst_from_vars def.T.region_params regions in
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
- let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
- in
+ (opt_variant_id : T.VariantId.id option) (generics : T.rgeneric_args) :
+ T.rty list =
+ (* For now, check that there are no clauses - otherwise we might need
+ to normalize the types *)
+ assert (def.generics.trait_clauses = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self = T.UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics def.T.generics generics tr_self in
let fields = TU.type_decl_get_fields def opt_variant_id in
- List.map
- (fun f -> ty_substitute r_subst ty_subst cg_subst f.T.field_ty)
- fields
+ List.map (fun f -> ty_substitute subst f.T.field_ty) fields
(** Return the types of the properly instantiated ADT's variant, provided a
- context *)
+ context.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let ctx_adt_get_instantiated_field_rtypes (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (regions : T.RegionId.id T.region list) (types : T.rty list)
- (cgs : T.const_generic list) : T.rty list =
+ (generics : T.rgeneric_args) : T.rty list =
let def = C.ctx_lookup_type_decl ctx def_id in
- type_decl_get_instantiated_field_rtypes def opt_variant_id regions types cgs
+ type_decl_get_instantiated_field_rtypes def opt_variant_id generics
(** Return the types of the properly instantiated ADT value (note that
- here, ADT is understood in its broad meaning: ADT, assumed value or tuple) *)
+ here, ADT is understood in its broad meaning: ADT, assumed value or tuple).
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+ *)
let ctx_adt_value_get_instantiated_field_rtypes (ctx : C.eval_ctx)
- (adt : V.adt_value) (id : T.type_id)
- (region_params : T.RegionId.id T.region list) (type_params : T.rty list)
- (cg_params : T.const_generic list) : T.rty list =
+ (adt : V.adt_value) (id : T.type_id) (generics : T.rgeneric_args) :
+ T.rty list =
match id with
| T.AdtId id ->
(* Retrieve the types of the fields *)
- ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id
- region_params type_params cg_params
+ ctx_adt_get_instantiated_field_rtypes ctx id adt.V.variant_id generics
| T.Tuple ->
- assert (List.length region_params = 0);
- type_params
+ assert (generics.regions = []);
+ generics.types
| T.Assumed aty -> (
match aty with
| T.Box | T.Vec ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- type_params
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ generics.types
| T.Option ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- if adt.V.variant_id = Some T.option_some_id then type_params
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ if adt.V.variant_id = Some T.option_some_id then generics.types
else if adt.V.variant_id = Some T.option_none_id then []
else raise (Failure "Unreachable")
| T.Range ->
- assert (List.length region_params = 0);
- assert (List.length type_params = 1);
- assert (List.length cg_params = 0);
- type_params
+ assert (generics.regions = []);
+ assert (List.length generics.types = 1);
+ assert (generics.const_generics = []);
+ generics.types
| T.Array | T.Slice | T.Str ->
(* Those types don't have fields *)
raise (Failure "Unreachable"))
(** Instantiate the type variables in an ADT definition, and return the list
- of types of the fields for the chosen variant *)
+ of types of the fields for the chosen variant.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+*)
let type_decl_get_instantiated_field_etypes (def : T.type_decl)
- (opt_variant_id : T.VariantId.id option) (types : T.ety list)
- (cgs : T.const_generic list) : T.ety list =
- let ty_subst = make_type_subst_from_vars def.T.type_params types in
- let cg_subst =
- make_const_generic_subst_from_vars def.T.const_generic_params cgs
+ (opt_variant_id : T.VariantId.id option) (generics : T.egeneric_args) :
+ T.ety list =
+ (* For now, check that there are no clauses - otherwise we might need
+ to normalize the types *)
+ assert (def.generics.trait_clauses = []);
+ (* There shouldn't be any reference to Self *)
+ let tr_self : T.erased_region T.trait_instance_id =
+ T.UnknownTrait __FUNCTION__
+ in
+ let { r_subst = _; ty_subst; cg_subst; tr_subst; tr_self } =
+ make_esubst_from_generics def.T.generics generics tr_self
in
let fields = TU.type_decl_get_fields def opt_variant_id in
List.map
- (fun f -> erase_regions_substitute_types ty_subst cg_subst f.T.field_ty)
+ (fun (f : T.field) ->
+ erase_regions_substitute_types ty_subst cg_subst tr_subst tr_self
+ f.T.field_ty)
fields
(** Return the types of the properly instantiated ADT's variant, provided a
- context *)
+ context.
+
+ **IMPORTANT**: this function doesn't normalize the types, you may want to
+ use the [AssociatedTypes] equivalent instead.
+ *)
let ctx_adt_get_instantiated_field_etypes (ctx : C.eval_ctx)
(def_id : T.TypeDeclId.id) (opt_variant_id : T.VariantId.id option)
- (types : T.ety list) (cgs : T.const_generic list) : T.ety list =
+ (generics : T.egeneric_args) : T.ety list =
let def = C.ctx_lookup_type_decl ctx def_id in
- type_decl_get_instantiated_field_etypes def opt_variant_id types cgs
+ type_decl_get_instantiated_field_etypes def opt_variant_id generics
-let statement_substitute_visitor (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) =
+let statement_substitute_visitor
+ (subst : (T.erased_region, T.erased_region) subst) =
+ (* Keep in synch with [ty_substitute_visitor] *)
object
inherit [_] A.map_statement
- method! visit_ety _ ty = ety_substitute tsubst cgsubst ty
- method! visit_ConstGenericVar _ id = cgsubst id
+ method! visit_'r _ r = subst.r_subst r
+ method! visit_TypeVar _ id = subst.ty_subst id
+
+ method! visit_type_var_id _ _ =
+ (* We should never get here because we reimplemented [visit_TypeVar] *)
+ raise (Failure "Unexpected")
+
+ method! visit_ConstGenericVar _ id = subst.cg_subst id
method! visit_const_generic_var_id _ _ =
(* We should never get here because we reimplemented [visit_Var] *)
raise (Failure "Unexpected")
+
+ method! visit_Clause _ id = subst.tr_subst id
+ method! visit_Self _ = subst.tr_self
end
(** Apply a type substitution to a place *)
-let place_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (p : E.place) :
- E.place =
+let place_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (p : E.place) : E.place =
(* There is in fact nothing to do *)
- (statement_substitute_visitor tsubst cgsubst)#visit_place () p
+ (statement_substitute_visitor subst)#visit_place () p
(** Apply a type substitution to an operand *)
-let operand_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (op : E.operand) :
- E.operand =
- (statement_substitute_visitor tsubst cgsubst)#visit_operand () op
+let operand_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (op : E.operand) : E.operand =
+ (statement_substitute_visitor subst)#visit_operand () op
(** Apply a type substitution to an rvalue *)
-let rvalue_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (rv : E.rvalue) :
- E.rvalue =
- (statement_substitute_visitor tsubst cgsubst)#visit_rvalue () rv
+let rvalue_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (rv : E.rvalue) : E.rvalue =
+ (statement_substitute_visitor subst)#visit_rvalue () rv
(** Apply a type substitution to an assertion *)
-let assertion_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (a : A.assertion) :
- A.assertion =
- (statement_substitute_visitor tsubst cgsubst)#visit_assertion () a
+let assertion_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (a : A.assertion) : A.assertion =
+ (statement_substitute_visitor subst)#visit_assertion () a
(** Apply a type substitution to a call *)
-let call_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (call : A.call) :
- A.call =
- (statement_substitute_visitor tsubst cgsubst)#visit_call () call
+let call_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (call : A.call) : A.call =
+ (statement_substitute_visitor subst)#visit_call () call
(** Apply a type substitution to a statement *)
-let statement_substitute (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (st : A.statement) :
- A.statement =
- (statement_substitute_visitor tsubst cgsubst)#visit_statement () st
+let statement_substitute (subst : (T.erased_region, T.erased_region) subst)
+ (st : A.statement) : A.statement =
+ (statement_substitute_visitor subst)#visit_statement () st
(** Apply a type substitution to a function body. Return the local variables
and the body. *)
-let fun_body_substitute_in_body (tsubst : T.TypeVarId.id -> T.ety)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (body : A.fun_body) :
+let fun_body_substitute_in_body
+ (subst : (T.erased_region, T.erased_region) subst) (body : A.fun_body) :
A.var list * A.statement =
- let rsubst r = r in
let locals =
List.map
- (fun (v : A.var) ->
- { v with A.var_ty = ty_substitute rsubst tsubst cgsubst v.A.var_ty })
+ (fun (v : A.var) -> { v with A.var_ty = ty_substitute subst v.A.var_ty })
body.A.locals
in
- let body = statement_substitute tsubst cgsubst body.body in
+ let body = statement_substitute subst body.body in
(locals, body)
-(** Substitute a function signature *)
+let trait_type_constraint_substitute (subst : ('r1, 'r2) subst)
+ (ttc : 'r1 T.trait_type_constraint) : 'r2 T.trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let visitor = ty_substitute_visitor subst in
+ let trait_ref = visitor#visit_trait_ref () trait_ref in
+ let generics = visitor#visit_generic_args () generics in
+ let ty = visitor#visit_ty () ty in
+ { T.trait_ref; generics; type_name; ty }
+
+(** Substitute a function signature.
+
+ **IMPORTANT:** this function doesn't normalize the types.
+ *)
let substitute_signature (asubst : T.RegionGroupId.id -> V.AbstractionId.id)
- (rsubst : T.RegionVarId.id -> T.RegionId.id)
- (tsubst : T.TypeVarId.id -> T.rty)
- (cgsubst : T.ConstGenericVarId.id -> T.const_generic) (sg : A.fun_sig) :
- A.inst_fun_sig =
- let rsubst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
+ (r_subst : T.RegionVarId.id -> T.RegionId.id)
+ (ty_subst : T.TypeVarId.id -> T.rty)
+ (cg_subst : T.ConstGenericVarId.id -> T.const_generic)
+ (tr_subst : T.TraitClauseId.id -> T.rtrait_instance_id)
+ (tr_self : T.rtrait_instance_id) (sg : A.fun_sig) : A.inst_fun_sig =
+ let r_subst' (r : T.RegionVarId.id T.region) : T.RegionId.id T.region =
+ match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid)
in
- let inputs = List.map (ty_substitute rsubst' tsubst cgsubst) sg.A.inputs in
- let output = ty_substitute rsubst' tsubst cgsubst sg.A.output in
+ let subst = { r_subst = r_subst'; ty_subst; cg_subst; tr_subst; tr_self } in
+ let inputs = List.map (ty_substitute subst) sg.A.inputs in
+ let output = ty_substitute subst sg.A.output in
let subst_region_group (rg : T.region_var_group) : A.abs_region_group =
let id = asubst rg.id in
- let regions = List.map rsubst rg.regions in
+ let regions = List.map r_subst rg.regions in
let parents = List.map asubst rg.parents in
{ id; regions; parents }
in
let regions_hierarchy = List.map subst_region_group sg.A.regions_hierarchy in
- { A.regions_hierarchy; inputs; output }
+ let trait_type_constraints =
+ List.map
+ (trait_type_constraint_substitute subst)
+ sg.preds.trait_type_constraints
+ in
+ { A.inputs; output; regions_hierarchy; trait_type_constraints }
-(** Substitute type variable identifiers in a type *)
-let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty)
+(** Substitute variable identifiers in a type *)
+let ty_substitute_ids (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id) (ty : 'r T.ty)
: 'r T.ty =
let open T in
let visitor =
object
inherit [_] map_ty
method visit_'r _ r = r
- method! visit_type_var_id _ id = tsubst id
- method! visit_const_generic_var_id _ id = cgsubst id
+ method! visit_type_var_id _ id = ty_subst id
+ method! visit_const_generic_var_id _ id = cg_subst id
end
in
@@ -371,10 +495,10 @@ let ty_substitute_ids (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
[visit_'r] if we define a class which visits objects of types [ety] and [rty]
while inheriting a class which visit [ty]...
*)
-let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
+let subst_ids_visitor (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) =
@@ -383,10 +507,10 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
inherit [_] T.map_ty
method visit_'r _ r =
- match r with T.Static -> T.Static | T.Var rid -> T.Var (rsubst rid)
+ match r with T.Static -> T.Static | T.Var rid -> T.Var (r_subst rid)
- method! visit_type_var_id _ id = tsubst id
- method! visit_const_generic_var_id _ id = cgsubst id
+ method! visit_type_var_id _ id = ty_subst id
+ method! visit_const_generic_var_id _ id = cg_subst id
end
in
@@ -395,7 +519,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
inherit [_] C.map_env
method! visit_borrow_id _ bid = bsubst bid
method! visit_loan_id _ bid = bsubst bid
- method! visit_ety _ ty = ty_substitute_ids tsubst cgsubst ty
+ method! visit_ety _ ty = ty_substitute_ids ty_subst cg_subst ty
method! visit_rty env ty = subst_rty#visit_ty env ty
method! visit_symbolic_value_id _ id = ssubst id
@@ -405,7 +529,7 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
(** We *do* visit meta-values *)
method! visit_mvalue env v = self#visit_typed_value env v
- method! visit_region_id _ id = rsubst id
+ method! visit_region_id _ id = r_subst id
method! visit_region_var_id _ id = rvsubst id
method! visit_abstraction_id _ id = asubst id
end
@@ -425,20 +549,20 @@ let subst_ids_visitor (rsubst : T.RegionId.id -> T.RegionId.id)
method visit_env (env : C.env) : C.env = visitor#visit_env () env
end
-let typed_value_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_value_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_value) :
V.typed_value =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_typed_value v
-let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_value_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id)
(v : V.typed_value) : V.typed_value =
- typed_value_subst_ids rsubst
+ typed_value_subst_ids r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
@@ -446,41 +570,41 @@ let typed_value_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
(fun x -> x)
v
-let typed_avalue_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_avalue_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id) (v : V.typed_avalue) :
V.typed_avalue =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_typed_avalue v
-let abs_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let abs_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : V.abs) : V.abs =
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_abs x
-let env_subst_ids (rsubst : T.RegionId.id -> T.RegionId.id)
+let env_subst_ids (r_subst : T.RegionId.id -> T.RegionId.id)
(rvsubst : T.RegionVarId.id -> T.RegionVarId.id)
- (tsubst : T.TypeVarId.id -> T.TypeVarId.id)
- (cgsubst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
+ (ty_subst : T.TypeVarId.id -> T.TypeVarId.id)
+ (cg_subst : T.ConstGenericVarId.id -> T.ConstGenericVarId.id)
(ssubst : V.SymbolicValueId.id -> V.SymbolicValueId.id)
(bsubst : V.BorrowId.id -> V.BorrowId.id)
(asubst : V.AbstractionId.id -> V.AbstractionId.id) (x : C.env) : C.env =
- (subst_ids_visitor rsubst rvsubst tsubst cgsubst ssubst bsubst asubst)
+ (subst_ids_visitor r_subst rvsubst ty_subst cg_subst ssubst bsubst asubst)
#visit_env x
-let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
+let typed_avalue_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id)
(x : V.typed_avalue) : V.typed_avalue =
let asubst _ = raise (Failure "Unreachable") in
- (subst_ids_visitor rsubst
+ (subst_ids_visitor r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
@@ -490,9 +614,9 @@ let typed_avalue_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id)
#visit_typed_avalue
x
-let env_subst_rids (rsubst : T.RegionId.id -> T.RegionId.id) (x : C.env) : C.env
- =
- (subst_ids_visitor rsubst
+let env_subst_rids (r_subst : T.RegionId.id -> T.RegionId.id) (x : C.env) :
+ C.env =
+ (subst_ids_visitor r_subst
(fun x -> x)
(fun x -> x)
(fun x -> x)
diff --git a/compiler/SymbolicAst.ml b/compiler/SymbolicAst.ml
index 7dc94dcd..4df8fec7 100644
--- a/compiler/SymbolicAst.ml
+++ b/compiler/SymbolicAst.ml
@@ -29,7 +29,7 @@ type mplace = {
[@@deriving show]
type call_id =
- | Fun of A.fun_id * V.FunCallId.id
+ | Fun of A.fun_id_or_trait_method_ref * V.FunCallId.id
(** A "regular" function (i.e., a function which is not a primitive operation) *)
| Unop of E.unop
| Binop of E.binop
@@ -43,10 +43,7 @@ type call = {
borrows (we need to perform lookups).
*)
abstractions : V.AbstractionId.id list;
- (* TODO: rename to "...args" *)
- type_params : T.ety list;
- (* TODO: rename to "...args" *)
- const_generic_params : T.const_generic list;
+ generics : T.egeneric_args;
args : V.typed_value list;
args_places : mplace option list; (** Meta information *)
dest : V.symbolic_value;
@@ -79,6 +76,9 @@ class ['self] iter_expression_base =
method visit_loop_id : 'env -> V.loop_id -> unit = fun _ _ -> ()
method visit_variant_id : 'env -> variant_id -> unit = fun _ _ -> ()
+ method visit_const_generic_var_id : 'env -> T.const_generic_var_id -> unit =
+ fun _ _ -> ()
+
method visit_symbolic_value_id : 'env -> V.symbolic_value_id -> unit =
fun _ _ -> ()
@@ -120,6 +120,9 @@ class ['self] iter_expression_base =
method visit_symbolic_expansion : 'env -> V.symbolic_expansion -> unit =
fun _ _ -> ()
+
+ method visit_etrait_ref : 'env -> T.etrait_ref -> unit = fun _ _ -> ()
+ method visit_egeneric_args : 'env -> T.egeneric_args -> unit = fun _ _ -> ()
end
(** **Rem.:** here, {!expression} is not at all equivalent to the expressions
@@ -171,14 +174,15 @@ type expression =
* expression
(** We introduce a new symbolic value, equal to some other value.
- This is used for instance when reorganizing the environment to compute
- fixed points: we duplicate some shared symbolic values to destructure
- the shared values, in order to make the environment a bit more general
- (while losing precision of course).
+ This is used for instance when reorganizing the environment to compute
+ fixed points: we duplicate some shared symbolic values to destructure
+ the shared values, in order to make the environment a bit more general
+ (while losing precision of course). We also use it to introduce symbolic
+ values when evaluating constant generics, or trait constants.
- The context is the evaluation context from before introducing the new
- value. It has the same purpose as for the {!Return} case.
- *)
+ The context is the evaluation context from before introducing the new
+ value. It has the same purpose as for the {!Return} case.
+ *)
| ForwardEnd of
Contexts.eval_ctx
* V.typed_value symbolic_value_id_map option
@@ -253,6 +257,11 @@ and value_aggregate =
| SingleValue of V.typed_value (** Regular case *)
| Array of V.typed_value list
(** This is used when introducing array aggregates *)
+ | ConstGenericValue of T.const_generic_var_id
+ (** This is used when evaluating a const generic value: in the interpreter,
+ we introduce a fresh symbolic value. *)
+ | TraitConstValue of T.etrait_ref * T.egeneric_args * string
+ (** A trait constant value *)
[@@deriving
show,
visitors
diff --git a/compiler/SymbolicToPure.ml b/compiler/SymbolicToPure.ml
index 3512270a..429198ad 100644
--- a/compiler/SymbolicToPure.ml
+++ b/compiler/SymbolicToPure.ml
@@ -4,6 +4,7 @@ open Pure
open PureUtils
module Id = Identifiers
module C = Contexts
+module A = LlbcAst
module S = SymbolicAst
module TA = TypesAnalysis
module L = Logging
@@ -52,6 +53,9 @@ type fun_context = {
type global_context = { llbc_global_decls : A.global_decl A.GlobalDeclId.Map.t }
[@@deriving show]
+type trait_decls_context = A.trait_decl A.TraitDeclId.Map.t [@@deriving show]
+type trait_impls_context = A.trait_impl A.TraitImplId.Map.t [@@deriving show]
+
(** Whenever we translate a function call or an ended abstraction, we
store the related information (this is useful when translating ended
children abstractions).
@@ -106,8 +110,7 @@ type loop_info = {
loop_id : LoopId.id;
input_vars : var list;
input_svl : V.symbolic_value list;
- type_args : ty list;
- const_generic_args : const_generic list;
+ generics : generic_args;
forward_inputs : texpression list option;
(** The forward inputs are initialized at [None] *)
forward_output_no_state_no_result : var option;
@@ -120,6 +123,8 @@ type bs_ctx = {
type_context : type_context;
fun_context : fun_context;
global_context : global_context;
+ trait_decls_ctx : trait_decls_context;
+ trait_impls_ctx : trait_impls_context;
fun_decl : A.fun_decl;
bid : T.RegionGroupId.id option; (** TODO: rename *)
sg : fun_sig;
@@ -201,34 +206,11 @@ type bs_ctx = {
}
[@@deriving show]
-let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit =
- let env = VarId.Map.empty in
- let ctx =
- {
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
- env;
- }
- in
- let _ = PureTypeCheck.check_typed_pattern ctx v in
- ()
-
-let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
- let env = VarId.Map.empty in
- let ctx =
- {
- PureTypeCheck.type_decls = ctx.type_context.type_decls;
- global_decls = ctx.global_context.llbc_global_decls;
- env;
- }
- in
- PureTypeCheck.check_texpression ctx e
-
(* TODO: move *)
let bs_ctx_to_ast_formatter (ctx : bs_ctx) : Print.Ast.ast_formatter =
Print.Ast.decls_and_fun_decl_to_ast_formatter ctx.type_context.llbc_type_decls
ctx.fun_context.llbc_fun_decls ctx.global_context.llbc_global_decls
- ctx.fun_decl
+ ctx.trait_decls_ctx ctx.trait_impls_ctx ctx.fun_decl
let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
let rvar_to_string = Print.Types.region_var_id_to_string in
@@ -246,16 +228,25 @@ let bs_ctx_to_ctx_formatter (ctx : bs_ctx) : Print.Contexts.ctx_formatter =
adt_variant_to_string = ast_fmt.adt_variant_to_string;
var_id_to_string;
adt_field_names = ast_fmt.adt_field_names;
+ trait_decl_id_to_string = ast_fmt.trait_decl_id_to_string;
+ trait_impl_id_to_string = ast_fmt.trait_impl_id_to_string;
+ trait_clause_id_to_string = ast_fmt.trait_clause_id_to_string;
}
let bs_ctx_to_pp_ast_formatter (ctx : bs_ctx) : PrintPure.ast_formatter =
- let type_params = ctx.fun_decl.signature.type_params in
- let cg_params = ctx.fun_decl.signature.const_generic_params in
+ let generics = ctx.fun_decl.signature.generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx generics.types
+ generics.const_generics
+
+let ctx_egeneric_args_to_string (ctx : bs_ctx) (args : T.egeneric_args) : string
+ =
+ let fmt = bs_ctx_to_ctx_formatter ctx in
+ let fmt = Print.PC.ctx_to_etype_formatter fmt in
+ Print.PT.egeneric_args_to_string fmt args
let symbolic_value_to_string (ctx : bs_ctx) (sv : V.symbolic_value) : string =
let fmt = bs_ctx_to_ctx_formatter ctx in
@@ -277,12 +268,11 @@ let rty_to_string (ctx : bs_ctx) (ty : T.rty) : string =
Print.PT.rty_to_string fmt ty
let type_decl_to_string (ctx : bs_ctx) (def : type_decl) : string =
- let type_params = def.type_params in
- let cg_params = def.const_generic_params in
let type_decls = ctx.type_context.llbc_type_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ PrintPure.mk_type_formatter type_decls global_decls ctx.trait_decls_ctx
+ ctx.trait_impls_ctx def.generics.types def.generics.const_generics
in
PrintPure.type_decl_to_string fmt def
@@ -291,26 +281,27 @@ let texpression_to_string (ctx : bs_ctx) (e : texpression) : string =
PrintPure.texpression_to_string fmt false "" " " e
let fun_sig_to_string (ctx : bs_ctx) (sg : fun_sig) : string =
- let type_params = sg.type_params in
- let cg_params = sg.const_generic_params in
+ let type_params = sg.generics.types in
+ let cg_params = sg.generics.const_generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params
in
PrintPure.fun_sig_to_string fmt sg
let fun_decl_to_string (ctx : bs_ctx) (def : Pure.fun_decl) : string =
- let type_params = def.signature.type_params in
- let cg_params = def.signature.const_generic_params in
+ let generics = def.signature.generics in
+ let type_params = generics.types in
+ let cg_params = generics.const_generics in
let type_decls = ctx.type_context.llbc_type_decls in
let fun_decls = ctx.fun_context.llbc_fun_decls in
let global_decls = ctx.global_context.llbc_global_decls in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls
+ ctx.trait_decls_ctx ctx.trait_impls_ctx type_params cg_params
in
PrintPure.fun_decl_to_string fmt def
@@ -328,17 +319,18 @@ let abs_to_string (ctx : bs_ctx) (abs : V.abs) : string =
Print.Values.abs_to_string fmt verbose indent indent_incr abs
let get_instantiated_fun_sig (fun_id : A.fun_id)
- (back_id : T.RegionGroupId.id option) (tys : ty list)
- (cgs : const_generic list) (ctx : bs_ctx) : inst_fun_sig =
+ (back_id : T.RegionGroupId.id option) (generics : generic_args)
+ (ctx : bs_ctx) : inst_fun_sig =
(* Lookup the non-instantiated function signature *)
let sg =
(RegularFunIdNotLoopMap.find (fun_id, back_id) ctx.fun_context.fun_sigs).sg
in
(* Create the substitution *)
- let tsubst = make_type_subst sg.type_params tys in
- let cgsubst = make_const_generic_subst sg.const_generic_params cgs in
+ (* There shouldn't be any reference to Self *)
+ let tr_self = UnknownTrait __FUNCTION__ in
+ let subst = make_subst_from_generics sg.generics generics tr_self in
(* Apply *)
- fun_sig_substitute tsubst cgsubst sg
+ fun_sig_substitute subst sg
let bs_ctx_lookup_llbc_type_decl (id : TypeDeclId.id) (ctx : bs_ctx) :
T.type_decl =
@@ -354,74 +346,121 @@ let bs_ctx_lookup_local_function_sig (def_id : A.FunDeclId.id)
let id = (A.Regular def_id, back_id) in
(RegularFunIdNotLoopMap.find id ctx.fun_context.fun_sigs).sg
-let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
- (args : texpression list) (ctx : bs_ctx) : bs_ctx =
- let calls = ctx.calls in
- assert (not (V.FunCallId.Map.mem call_id calls));
- let info =
- { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
- in
- let calls = V.FunCallId.Map.add call_id info calls in
- { ctx with calls }
-
-(** [back_args]: the *additional* list of inputs received by the backward function *)
-let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
- (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
- : bs_ctx * fun_or_op_id =
- (* Insert the abstraction in the call informations *)
- let info = V.FunCallId.Map.find call_id ctx.calls in
- assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
- let backwards =
- T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
- in
- let info = { info with backwards } in
- let calls = V.FunCallId.Map.add call_id info ctx.calls in
- (* Insert the abstraction in the abstractions map *)
- let abstractions = ctx.abstractions in
- assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
- let abstractions =
- V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
- in
- (* Retrieve the fun_id *)
- let fun_id =
- match info.forward.call_id with
- | S.Fun (fid, _) -> Fun (FromLlbc (fid, None, Some back_id))
- | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
- in
- (* Update the context and return *)
- ({ ctx with calls; abstractions }, fun_id)
+(* Some generic translation functions (we need to translate different "flavours"
+ of types: sty, forward types, backward types, etc.) *)
+let rec translate_generic_args (translate_ty : 'r T.ty -> ty)
+ (generics : 'r T.generic_args) : generic_args =
+ (* We ignore the regions: if they didn't cause trouble for the symbolic execution,
+ then everything's fine *)
+ let types = List.map translate_ty generics.types in
+ let const_generics = generics.const_generics in
+ let trait_refs =
+ List.map (translate_trait_ref translate_ty) generics.trait_refs
+ in
+ { types; const_generics; trait_refs }
+
+and translate_trait_ref (translate_ty : 'r T.ty -> ty) (tr : 'r T.trait_ref) :
+ trait_ref =
+ let trait_id = translate_trait_instance_id translate_ty tr.trait_id in
+ let generics = translate_generic_args translate_ty tr.generics in
+ let trait_decl_ref =
+ translate_trait_decl_ref translate_ty tr.trait_decl_ref
+ in
+ { trait_id; generics; trait_decl_ref }
+
+and translate_trait_decl_ref (translate_ty : 'r T.ty -> ty)
+ (tr : 'r T.trait_decl_ref) : trait_decl_ref =
+ let decl_generics = translate_generic_args translate_ty tr.decl_generics in
+ { trait_decl_id = tr.trait_decl_id; decl_generics }
+
+and translate_trait_instance_id (translate_ty : 'r T.ty -> ty)
+ (id : 'r T.trait_instance_id) : trait_instance_id =
+ let translate_trait_instance_id = translate_trait_instance_id translate_ty in
+ match id with
+ | T.Self -> Self
+ | TraitImpl id -> TraitImpl id
+ | BuiltinOrAuto _ ->
+ (* We should have eliminated those in the prepasses *)
+ raise (Failure "Unreachable")
+ | Clause id -> Clause id
+ | ParentClause (inst_id, decl_id, clause_id) ->
+ let inst_id = translate_trait_instance_id inst_id in
+ ParentClause (inst_id, decl_id, clause_id)
+ | ItemClause (inst_id, decl_id, item_name, clause_id) ->
+ let inst_id = translate_trait_instance_id inst_id in
+ ItemClause (inst_id, decl_id, item_name, clause_id)
+ | TraitRef tr -> TraitRef (translate_trait_ref translate_ty tr)
+ | UnknownTrait s -> raise (Failure ("Unknown trait found: " ^ s))
let rec translate_sty (ty : T.sty) : ty =
let translate = translate_sty in
match ty with
- | T.Adt (type_id, regions, tys, cgs) -> (
- (* Can't translate types with regions for now *)
- assert (regions = []);
- let tys = List.map translate tys in
+ | T.Adt (type_id, generics) -> (
+ let generics = translate_sgeneric_args generics in
match type_id with
- | T.AdtId adt_id -> Adt (AdtId adt_id, tys, cgs)
- | T.Tuple -> mk_simpl_tuple_ty tys
+ | T.AdtId adt_id -> Adt (AdtId adt_id, generics)
+ | T.Tuple ->
+ assert (generics.const_generics = []);
+ mk_simpl_tuple_ty generics.types
| T.Assumed aty -> (
match aty with
- | T.Vec -> Adt (Assumed Vec, tys, cgs)
- | T.Option -> Adt (Assumed Option, tys, cgs)
+ | T.Vec -> Adt (Assumed Vec, generics)
+ | T.Option -> Adt (Assumed Option, generics)
| T.Box -> (
(* Eliminate the boxes *)
- match tys with
+ match generics.types with
| [ ty ] -> ty
| _ ->
raise
(Failure
"Box/vec/option type with incorrect number of arguments")
)
- | T.Array -> Adt (Assumed Array, tys, cgs)
- | T.Slice -> Adt (Assumed Slice, tys, cgs)
- | T.Str -> Adt (Assumed Str, tys, cgs)
- | T.Range -> Adt (Assumed Range, tys, cgs)))
+ | T.Array -> Adt (Assumed Array, generics)
+ | T.Slice -> Adt (Assumed Slice, generics)
+ | T.Str -> Adt (Assumed Str, generics)
+ | T.Range -> Adt (Assumed Range, generics)))
| TypeVar vid -> TypeVar vid
| Literal ty -> Literal ty
| Never -> raise (Failure "Unreachable")
| Ref (_, rty, _) -> translate rty
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = translate_strait_ref trait_ref in
+ let generics = translate_sgeneric_args generics in
+ TraitType (trait_ref, generics, type_name)
+
+and translate_sgeneric_args (generics : T.sgeneric_args) : generic_args =
+ translate_generic_args translate_sty generics
+
+and translate_strait_ref (tr : T.strait_ref) : trait_ref =
+ translate_trait_ref translate_sty tr
+
+and translate_strait_instance_id (id : T.strait_instance_id) : trait_instance_id
+ =
+ translate_trait_instance_id translate_sty id
+
+let translate_trait_clause (clause : T.trait_clause) : trait_clause =
+ let { T.clause_id; meta = _; trait_id; generics } = clause in
+ let generics = translate_sgeneric_args generics in
+ { clause_id; trait_id; generics }
+
+let translate_strait_type_constraint (ttc : T.strait_type_constraint) :
+ trait_type_constraint =
+ let { T.trait_ref; generics; type_name; ty } = ttc in
+ let trait_ref = translate_strait_ref trait_ref in
+ let generics = translate_sgeneric_args generics in
+ let ty = translate_sty ty in
+ { trait_ref; generics; type_name; ty }
+
+let translate_predicates (preds : T.predicates) : predicates =
+ let trait_type_constraints =
+ List.map translate_strait_type_constraint preds.trait_type_constraints
+ in
+ { trait_type_constraints }
+
+let translate_generic_params (generics : T.generic_params) : generic_params =
+ let { T.regions = _; types; const_generics; trait_clauses } = generics in
+ let trait_clauses = List.map translate_trait_clause trait_clauses in
+ { types; const_generics; trait_clauses }
let translate_field (f : T.field) : field =
let field_name = f.field_name in
@@ -452,15 +491,16 @@ let translate_type_decl_kind (kind : T.type_decl_kind) : type_decl_kind =
point of moving this definition for now.
*)
let translate_type_decl (def : T.type_decl) : type_decl =
- (* Translate *)
let def_id = def.T.def_id in
let name = def.name in
+ let { T.regions; types; const_generics; trait_clauses } = def.generics in
(* Can't translate types with regions for now *)
- assert (def.region_params = []);
- let type_params = def.type_params in
- let const_generic_params = def.const_generic_params in
+ assert (regions = []);
+ let trait_clauses = List.map translate_trait_clause trait_clauses in
+ let generics = { types; const_generics; trait_clauses } in
let kind = translate_type_decl_kind def.T.kind in
- { def_id; name; type_params; const_generic_params; kind }
+ let preds = translate_predicates def.preds in
+ { def_id; name; generics; kind; preds }
let translate_type_id (id : T.type_id) : type_id =
match id with
@@ -488,28 +528,33 @@ let translate_type_id (id : T.type_id) : type_id =
let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
let translate = translate_fwd_ty type_infos in
match ty with
- | T.Adt (type_id, regions, tys, cgs) -> (
- (* Can't translate types with regions for now *)
- assert (regions = []);
- (* Translate the type parameters *)
- let t_tys = List.map translate tys in
+ | T.Adt (type_id, generics) -> (
+ let t_generics = translate_fwd_generic_args type_infos generics in
(* Eliminate boxes and simplify tuples *)
match type_id with
| AdtId _
| T.Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
+ assert (
+ not
+ (List.exists
+ (TypesUtils.ty_has_borrows type_infos)
+ generics.types));
let type_id = translate_type_id type_id in
- Adt (type_id, t_tys, cgs)
+ Adt (type_id, t_generics)
| Tuple ->
(* Note that if there is exactly one type, [mk_simpl_tuple_ty] is the
identity *)
- mk_simpl_tuple_ty t_tys
+ mk_simpl_tuple_ty t_generics.types
| T.Assumed T.Box -> (
(* We eliminate boxes *)
(* No general parametricity for now *)
- assert (not (List.exists (TypesUtils.ty_has_borrows type_infos) tys));
- match t_tys with
+ assert (
+ not
+ (List.exists
+ (TypesUtils.ty_has_borrows type_infos)
+ generics.types));
+ match t_generics.types with
| [ bty ] -> bty
| _ ->
raise
@@ -520,12 +565,34 @@ let rec translate_fwd_ty (type_infos : TA.type_infos) (ty : 'r T.ty) : ty =
| Never -> raise (Failure "Unreachable")
| Literal lty -> Literal lty
| Ref (_, rty, _) -> translate rty
+ | TraitType (trait_ref, generics, type_name) ->
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ TraitType (trait_ref, generics, type_name)
+
+and translate_fwd_generic_args (type_infos : TA.type_infos)
+ (generics : 'r T.generic_args) : generic_args =
+ translate_generic_args (translate_fwd_ty type_infos) generics
+
+and translate_fwd_trait_ref (type_infos : TA.type_infos) (tr : 'r T.trait_ref) :
+ trait_ref =
+ translate_trait_ref (translate_fwd_ty type_infos) tr
+
+and translate_fwd_trait_instance_id (type_infos : TA.type_infos)
+ (id : 'r T.trait_instance_id) : trait_instance_id =
+ translate_trait_instance_id (translate_fwd_ty type_infos) id
(** Simply calls [translate_fwd_ty] *)
let ctx_translate_fwd_ty (ctx : bs_ctx) (ty : 'r T.ty) : ty =
let type_infos = ctx.type_context.type_infos in
translate_fwd_ty type_infos ty
+(** Simply calls [translate_fwd_generic_args] *)
+let ctx_translate_fwd_generic_args (ctx : bs_ctx) (generics : 'r T.generic_args)
+ : generic_args =
+ let type_infos = ctx.type_context.type_infos in
+ translate_fwd_generic_args type_infos generics
+
(** Translate a type, when some regions may have ended.
We return an option, because the translated type may be empty.
@@ -538,7 +605,7 @@ let rec translate_back_ty (type_infos : TA.type_infos)
(* A small helper for "leave" types *)
let wrap ty = if inside_mut then Some ty else None in
match ty with
- | T.Adt (type_id, _, tys, cgs) -> (
+ | T.Adt (type_id, generics) -> (
match type_id with
| T.AdtId _
| Assumed (T.Vec | T.Option | T.Array | T.Slice | T.Str | T.Range) ->
@@ -546,22 +613,24 @@ let rec translate_back_ty (type_infos : TA.type_infos)
assert (not (TypesUtils.ty_has_borrows type_infos ty));
let type_id = translate_type_id type_id in
if inside_mut then
- let tys_t = List.filter_map translate tys in
- Some (Adt (type_id, tys_t, cgs))
+ (* We do not want to filter anything, so we translate the generics
+ as "forward" types *)
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (Adt (type_id, generics))
else None
| Assumed T.Box -> (
(* Don't accept ADTs (which are not tuples) with borrows for now *)
assert (not (TypesUtils.ty_has_borrows type_infos ty));
(* Eliminate the box *)
- match tys with
+ match generics.types with
| [ bty ] -> translate bty
| _ ->
raise
(Failure "Unreachable: boxes receive exactly one type parameter")
)
| T.Tuple -> (
- (* Tuples can contain borrows (which we eliminated) *)
- let tys_t = List.filter_map translate tys in
+ (* Tuples can contain borrows (which we eliminate) *)
+ let tys_t = List.filter_map translate generics.types in
match tys_t with
| [] -> None
| _ ->
@@ -582,6 +651,13 @@ let rec translate_back_ty (type_infos : TA.type_infos)
if keep_region r then
translate_back_ty type_infos keep_region inside_mut rty
else None)
+ | TraitType (trait_ref, generics, type_name) ->
+ assert (generics.regions = []);
+ (* Translate the trait ref and the generics as "forward" generics -
+ we do not want to filter any type *)
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ Some (TraitType (trait_ref, generics, type_name))
(** Simply calls [translate_back_ty] *)
let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
@@ -589,6 +665,80 @@ let ctx_translate_back_ty (ctx : bs_ctx) (keep_region : 'r -> bool)
let type_infos = ctx.type_context.type_infos in
translate_back_ty type_infos keep_region inside_mut ty
+let mk_type_check_ctx (ctx : bs_ctx) : PureTypeCheck.tc_ctx =
+ let const_generics =
+ T.ConstGenericVarId.Map.of_list
+ (List.map
+ (fun (cg : T.const_generic_var) ->
+ (cg.index, ctx_translate_fwd_ty ctx (T.Literal cg.ty)))
+ ctx.sg.generics.const_generics)
+ in
+ let env = VarId.Map.empty in
+ {
+ PureTypeCheck.type_decls = ctx.type_context.type_decls;
+ global_decls = ctx.global_context.llbc_global_decls;
+ env;
+ const_generics;
+ }
+
+let type_check_pattern (ctx : bs_ctx) (v : typed_pattern) : unit =
+ let ctx = mk_type_check_ctx ctx in
+ let _ = PureTypeCheck.check_typed_pattern ctx v in
+ ()
+
+let type_check_texpression (ctx : bs_ctx) (e : texpression) : unit =
+ if !Config.type_check_pure_code then
+ let ctx = mk_type_check_ctx ctx in
+ PureTypeCheck.check_texpression ctx e
+
+let translate_fun_id_or_trait_method_ref (ctx : bs_ctx)
+ (id : A.fun_id_or_trait_method_ref) : fun_id_or_trait_method_ref =
+ match id with
+ | A.FunId fun_id -> FunId fun_id
+ | TraitMethod (trait_ref, method_name, fun_decl_id) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ TraitMethod (trait_ref, method_name, fun_decl_id)
+
+let bs_ctx_register_forward_call (call_id : V.FunCallId.id) (forward : S.call)
+ (args : texpression list) (ctx : bs_ctx) : bs_ctx =
+ let calls = ctx.calls in
+ assert (not (V.FunCallId.Map.mem call_id calls));
+ let info =
+ { forward; forward_inputs = args; backwards = T.RegionGroupId.Map.empty }
+ in
+ let calls = V.FunCallId.Map.add call_id info calls in
+ { ctx with calls }
+
+(** [back_args]: the *additional* list of inputs received by the backward function *)
+let bs_ctx_register_backward_call (abs : V.abs) (call_id : V.FunCallId.id)
+ (back_id : T.RegionGroupId.id) (back_args : texpression list) (ctx : bs_ctx)
+ : bs_ctx * fun_or_op_id =
+ (* Insert the abstraction in the call informations *)
+ let info = V.FunCallId.Map.find call_id ctx.calls in
+ assert (not (T.RegionGroupId.Map.mem back_id info.backwards));
+ let backwards =
+ T.RegionGroupId.Map.add back_id (abs, back_args) info.backwards
+ in
+ let info = { info with backwards } in
+ let calls = V.FunCallId.Map.add call_id info ctx.calls in
+ (* Insert the abstraction in the abstractions map *)
+ let abstractions = ctx.abstractions in
+ assert (not (V.AbstractionId.Map.mem abs.abs_id abstractions));
+ let abstractions =
+ V.AbstractionId.Map.add abs.abs_id (abs, back_args) abstractions
+ in
+ (* Retrieve the fun_id *)
+ let fun_id =
+ match info.forward.call_id with
+ | S.Fun (fid, _) ->
+ let fid = translate_fun_id_or_trait_method_ref ctx fid in
+ Fun (FromLlbc (fid, None, Some back_id))
+ | S.Unop _ | S.Binop _ -> raise (Failure "Unreachable")
+ in
+ (* Update the context and return *)
+ ({ ctx with calls; abstractions }, fun_id)
+
(** List the ancestors of an abstraction *)
let list_ancestor_abstractions_ids (ctx : bs_ctx) (abs : V.abs)
(call_id : V.FunCallId.id) : V.AbstractionId.id list =
@@ -642,10 +792,10 @@ let mk_fuel_input_as_list (ctx : bs_ctx) (info : fun_effect_info) :
(** Small utility. *)
let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (lid : V.LoopId.id option)
+ (fun_id : A.fun_id_or_trait_method_ref) (lid : V.LoopId.id option)
(gid : T.RegionGroupId.id option) : fun_effect_info =
match fun_id with
- | A.Regular fid ->
+ | A.TraitMethod (_, _, fid) | A.FunId (A.Regular fid) ->
let info = A.FunDeclId.Map.find fid fun_infos in
let stateful_group = info.stateful in
let stateful =
@@ -658,7 +808,7 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
can_diverge = info.can_diverge;
is_rec = info.is_rec || Option.is_some lid;
}
- | A.Assumed aid ->
+ | A.FunId (A.Assumed aid) ->
assert (lid = None);
{
can_fail = Assumed.assumed_can_fail aid;
@@ -673,12 +823,14 @@ let get_fun_effect_info (fun_infos : FA.fun_info A.FunDeclId.Map.t)
Note that the function also takes a list of names for the inputs, and
computes, for every output for the backward functions, a corresponding
name (outputs for backward functions come from borrows in the inputs
- of the forward function) which we use as hints to generate pretty names.
+ of the forward function) which we use as hints to generate pretty names
+ in the extracted code.
*)
-let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (fun_id : A.fun_id) (type_infos : TA.type_infos) (sg : A.fun_sig)
- (input_names : string option list) (bid : T.RegionGroupId.id option) :
- fun_sig_named_outputs =
+let translate_fun_sig (decls_ctx : C.decls_ctx) (fun_id : A.fun_id)
+ (sg : A.fun_sig) (input_names : string option list)
+ (bid : T.RegionGroupId.id option) : fun_sig_named_outputs =
+ let fun_infos = decls_ctx.fun_ctx.fun_infos in
+ let type_infos = decls_ctx.type_ctx.type_infos in
(* Retrieve the list of parent backward functions *)
let gid, parents =
match bid with
@@ -689,7 +841,34 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
in
(* Is the function stateful, and can it fail? *)
let lid = None in
- let effect_info = get_fun_effect_info fun_infos fun_id lid bid in
+ let effect_info = get_fun_effect_info fun_infos (A.FunId fun_id) lid bid in
+ (* We need an evaluation context to normalize the types (to normalize the
+ associated types, etc. - for instance it may happen that the types
+ refer to the types associated to a trait ref, but where the trait ref
+ is a known impl). *)
+ (* Create the context *)
+ let ctx =
+ let region_groups =
+ List.map (fun (g : T.region_var_group) -> g.id) sg.regions_hierarchy
+ in
+ let ctx =
+ InterpreterUtils.initialize_eval_context decls_ctx region_groups
+ sg.generics.types sg.generics.const_generics
+ in
+ (* Compute the normalization map for the *sty* types and add it to the context *)
+ AssociatedTypes.ctx_add_norm_trait_stypes_from_preds ctx
+ sg.preds.trait_type_constraints
+ in
+
+ (* Normalize the signature *)
+ let sg =
+ let ({ A.inputs; output; _ } : A.fun_sig) = sg in
+ let norm = AssociatedTypes.ctx_normalize_sty ctx in
+ let inputs = List.map norm inputs in
+ let output = norm output in
+ { sg with A.inputs; output }
+ in
+
(* List the inputs for:
* - the fuel
* - the forward function
@@ -806,9 +985,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(* Wrap in a result type *)
if effect_info.can_fail then mk_result_ty output else output
in
- (* Type/const generic parameters *)
- let type_params = sg.type_params in
- let const_generic_params = sg.const_generic_params in
+ (* Generic parameters *)
+ let generics = translate_generic_params sg.generics in
(* Return *)
let has_fuel = fuel <> [] in
let num_fwd_inputs_no_state = List.length fwd_inputs in
@@ -836,9 +1014,8 @@ let translate_fun_sig (fun_infos : FA.fun_info A.FunDeclId.Map.t)
effect_info;
}
in
- let sg =
- { type_params; const_generic_params; inputs; output; doutputs; info }
- in
+ let preds = translate_predicates sg.A.preds in
+ let sg = { generics; preds; inputs; output; doutputs; info } in
{ sg; output_names }
let bs_ctx_fresh_state_var (ctx : bs_ctx) : bs_ctx * typed_pattern =
@@ -917,7 +1094,7 @@ let lookup_var_for_symbolic_value (sv : V.symbolic_value) (ctx : bs_ctx) : var =
(** Peel boxes as long as the value is of the form [Box<T>] *)
let rec unbox_typed_value (v : V.typed_value) : V.typed_value =
match (v.value, v.ty) with
- | V.Adt av, T.Adt (T.Assumed T.Box, _, _, _) -> (
+ | V.Adt av, T.Adt (T.Assumed T.Box, _) -> (
match av.field_values with
| [ bv ] -> unbox_typed_value bv
| _ -> raise (Failure "Unreachable"))
@@ -962,16 +1139,16 @@ let rec typed_value_to_texpression (ctx : bs_ctx) (ectx : C.eval_ctx)
let field_values = List.map translate av.field_values in
(* Eliminate the tuple wrapper if it is a tuple with exactly one field *)
match v.ty with
- | T.Adt (T.Tuple, _, _, _) ->
+ | T.Adt (T.Tuple, _) ->
assert (variant_id = None);
mk_simpl_tuple_texpression field_values
| _ ->
- (* Retrieve the type, the translated type arguments and the
- * const generic arguments from the translated type (simpler this way) *)
- let adt_id, type_args, const_generic_args = ty_as_adt ty in
+ (* Retrieve the type and the translated generics from the translated
+ type (simpler this way) *)
+ let adt_id, generics = ty_as_adt ty in
(* Create the constructor *)
let qualif_id = AdtCons { adt_id; variant_id = av.variant_id } in
- let qualif = { id = qualif_id; type_args; const_generic_args } in
+ let qualif = { id = qualif_id; generics } in
let cons_e = Qualif qualif in
let field_tys =
List.map (fun (v : texpression) -> v.ty) field_values
@@ -1038,7 +1215,7 @@ let rec typed_avalue_to_consumed (ctx : bs_ctx) (ectx : C.eval_ctx)
(* Translate the field values *)
let field_values = List.filter_map translate adt_v.field_values in
(* For now, only tuples can contain borrows *)
- let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in
+ let adt_id, _ = TypesUtils.ty_as_adt av.ty in
match adt_id with
| T.AdtId _
| T.Assumed
@@ -1185,7 +1362,7 @@ let rec typed_avalue_to_given_back (mp : mplace option) (av : V.typed_avalue)
(* For now, only tuples can contain borrows - note that if we gave
* something like a [&mut Vec] to a function, we give back the
* vector value upon visiting the "abstraction borrow" node *)
- let adt_id, _, _, _ = TypesUtils.ty_as_adt av.ty in
+ let adt_id, _ = TypesUtils.ty_as_adt av.ty in
match adt_id with
| T.AdtId _
| T.Assumed
@@ -1457,9 +1634,12 @@ and translate_return_with_loop (loop_id : V.LoopId.id) (is_continue : bool)
and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
texpression =
+ log#ldebug
+ (lazy
+ ("translate_function_call:\n"
+ ^ ctx_egeneric_args_to_string ctx call.generics));
(* Translate the function call *)
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- let const_generic_args = call.const_generic_params in
+ let generics = ctx_translate_fwd_generic_args ctx call.generics in
let args =
let args = List.map (typed_value_to_texpression ctx call.ctx) call.args in
let args_mplaces = List.map translate_opt_mplace call.args_places in
@@ -1475,7 +1655,8 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
match call.call_id with
| S.Fun (fid, call_id) ->
(* Regular function call *)
- let func = Fun (FromLlbc (fid, None, None)) in
+ let fid_t = translate_fun_id_or_trait_method_ref ctx fid in
+ let func = Fun (FromLlbc (fid_t, None, None)) in
(* Retrieve the effect information about this function (can fail,
* takes a state as input, etc.) *)
let effect_info =
@@ -1561,7 +1742,7 @@ and translate_function_call (call : S.call) (e : S.expression) (ctx : bs_ctx) :
| None -> dest
| Some out_state -> mk_simpl_tuple_pattern [ out_state; dest ]
in
- let func = { id = FunOrOp fun_id; type_args; const_generic_args } in
+ let func = { id = FunOrOp fun_id; generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty dest_v.ty else dest_v.ty
@@ -1665,9 +1846,11 @@ and translate_end_abstraction_synth_input (ectx : C.eval_ctx) (abs : V.abs)
(* Group the two lists *)
let variables_values = List.combine given_back_variables consumed_values in
(* Sanity check: the two lists match (same types) *)
- List.iter
- (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
- variables_values;
+ (* TODO: normalize the types *)
+ if !Config.type_check_pure_code then
+ List.iter
+ (fun (var, v) -> assert ((var : var).ty = (v : texpression).ty))
+ variables_values;
(* Translate the next expression *)
let next_e = translate_expression e ctx in
(* Generate the assignemnts *)
@@ -1692,8 +1875,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
let effect_info =
get_fun_effect_info ctx.fun_context.fun_infos fun_id None (Some rg_id)
in
- let type_args = List.map (ctx_translate_fwd_ty ctx) call.type_params in
- let const_generic_args = call.const_generic_params in
+ let generics = ctx_translate_fwd_generic_args ctx call.generics in
(* Retrieve the original call and the parent abstractions *)
let _forward, backwards = get_abs_ancestors ctx abs call_id in
(* Retrieve the values consumed when we called the forward function and
@@ -1741,34 +1923,35 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
| Some nstate -> mk_simpl_tuple_pattern [ nstate; output ]
in
(* Sanity check: there is the proper number of inputs and outputs, and they have the proper type *)
- let _ =
- let inst_sg =
- get_instantiated_fun_sig fun_id (Some rg_id) type_args const_generic_args
- ctx
- in
- log#ldebug
- (lazy
- ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
- ^ string_of_int (List.length inputs)
- ^ "): "
- ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
- ^ "\n- inst_sg.inputs ("
- ^ string_of_int (List.length inst_sg.inputs)
- ^ "): "
- ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
- List.iter
- (fun (x, ty) -> assert ((x : texpression).ty = ty))
- (List.combine inputs inst_sg.inputs);
- log#ldebug
- (lazy
- ("\n- outputs: "
- ^ string_of_int (List.length outputs)
- ^ "\n- expected outputs: "
- ^ string_of_int (List.length inst_sg.doutputs)));
- List.iter
- (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
- (List.combine outputs inst_sg.doutputs)
- in
+ (if (* TODO: normalize the types *) !Config.type_check_pure_code then
+ match fun_id with
+ | A.FunId fun_id ->
+ let inst_sg =
+ get_instantiated_fun_sig fun_id (Some rg_id) generics ctx
+ in
+ log#ldebug
+ (lazy
+ ("\n- fun_id: " ^ A.show_fun_id fun_id ^ "\n- inputs ("
+ ^ string_of_int (List.length inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (texpression_to_string ctx) inputs)
+ ^ "\n- inst_sg.inputs ("
+ ^ string_of_int (List.length inst_sg.inputs)
+ ^ "): "
+ ^ String.concat ", " (List.map (ty_to_string ctx) inst_sg.inputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : texpression).ty = ty))
+ (List.combine inputs inst_sg.inputs);
+ log#ldebug
+ (lazy
+ ("\n- outputs: "
+ ^ string_of_int (List.length outputs)
+ ^ "\n- expected outputs: "
+ ^ string_of_int (List.length inst_sg.doutputs)));
+ List.iter
+ (fun (x, ty) -> assert ((x : typed_pattern).ty = ty))
+ (List.combine outputs inst_sg.doutputs)
+ | _ -> (* TODO: trait methods *) ());
(* Retrieve the function id, and register the function call in the context
* if necessary *)
let ctx, func =
@@ -1788,7 +1971,7 @@ and translate_end_abstraction_fun_call (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = { id = FunOrOp func; type_args; const_generic_args } in
+ let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
@@ -1907,12 +2090,11 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
| V.LoopCall ->
let fun_id = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fun_id (Some vloop_id)
- (Some rg_id)
+ get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fun_id)
+ (Some vloop_id) (Some rg_id)
in
let loop_info = LoopId.Map.find loop_id ctx.loops in
- let type_args = loop_info.type_args in
- let const_generic_args = loop_info.const_generic_args in
+ let generics = loop_info.generics in
let fwd_inputs = Option.get loop_info.forward_inputs in
(* Retrieve the additional backward inputs. Note that those are actually
the backward inputs of the function we are synthesizing (and that we
@@ -1960,8 +2142,8 @@ and translate_end_abstraction_loop (ectx : C.eval_ctx) (abs : V.abs)
if effect_info.can_fail then mk_result_ty output.ty else output.ty
in
let func_ty = mk_arrows input_tys ret_ty in
- let func = Fun (FromLlbc (fun_id, Some loop_id, Some rg_id)) in
- let func = { id = FunOrOp func; type_args; const_generic_args } in
+ let func = Fun (FromLlbc (FunId fun_id, Some loop_id, Some rg_id)) in
+ let func = { id = FunOrOp func; generics } in
let func = { e = Qualif func; ty = func_ty } in
let call = mk_apps func args in
(* **Optimization**:
@@ -2021,9 +2203,7 @@ and translate_global_eval (gid : A.GlobalDeclId.id) (sval : V.symbolic_value)
(e : S.expression) (ctx : bs_ctx) : texpression =
let ctx, var = fresh_var_for_symbolic_value sval ctx in
let decl = A.GlobalDeclId.Map.find gid ctx.global_context.llbc_global_decls in
- let global_expr =
- { id = Global gid; type_args = []; const_generic_args = [] }
- in
+ let global_expr = { id = Global gid; generics = empty_generic_args } in
(* We use translate_fwd_ty to translate the global type *)
let ty = ctx_translate_fwd_ty ctx decl.ty in
let gval = { e = Qualif global_expr; ty } in
@@ -2037,11 +2217,7 @@ and translate_assertion (ectx : C.eval_ctx) (v : V.typed_value)
let v = typed_value_to_texpression ctx ectx v in
let args = [ v ] in
let func =
- {
- id = FunOrOp (Fun (Pure Assert));
- type_args = [];
- const_generic_args = [];
- }
+ { id = FunOrOp (Fun (Pure Assert)); generics = empty_generic_args }
in
let func_ty = mk_arrow (Literal Bool) mk_unit_ty in
let func = { e = Qualif func; ty = func_ty } in
@@ -2189,7 +2365,7 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
(branch : S.expression) (ctx : bs_ctx) : texpression =
(* TODO: always introduce a match, and use micro-passes to turn the
the match into a let? *)
- let type_id, _, _, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
+ let type_id, _ = TypesUtils.ty_as_adt sv.V.sv_ty in
let ctx, vars = fresh_vars_for_symbolic_values svl ctx in
let branch = translate_expression branch ctx in
match type_id with
@@ -2224,10 +2400,10 @@ and translate_ExpandAdt_one_branch (sv : V.symbolic_value)
* field.
* We use the [dest] variable in order not to have to recompute
* the type of the result of the projection... *)
- let adt_id, type_args, const_generic_args = ty_as_adt scrutinee.ty in
+ let adt_id, generics = ty_as_adt scrutinee.ty in
let gen_field_proj (field_id : FieldId.id) (dest : var) : texpression =
let proj_kind = { adt_id; field_id } in
- let qualif = { id = Proj proj_kind; type_args; const_generic_args } in
+ let qualif = { id = Proj proj_kind; generics } in
let proj_e = Qualif qualif in
let proj_ty = mk_arrow scrutinee.ty dest.ty in
let proj = { e = proj_e; ty = proj_ty } in
@@ -2282,8 +2458,9 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
(* Translate the next expression *)
let next_e = translate_expression e ctx in
- (* Translate the value: there are two cases, depending on whether this
- is a "regular" let-binding or an array aggregate.
+ (* Translate the value: there are several cases, depending on whether this
+ is a "regular" let-binding, an array aggregate, a const generic or
+ a trait associated constant.
*)
let v =
match v with
@@ -2298,6 +2475,14 @@ and translate_intro_symbolic (ectx : C.eval_ctx) (p : S.mplace option)
{ struct_id = Assumed Array; init = None; updates = values }
in
{ e = StructUpdate su; ty = var.ty }
+ | ConstGenericValue cg_id -> { e = CVar cg_id; ty = var.ty }
+ | TraitConstValue (trait_ref, generics, const_name) ->
+ let type_infos = ctx.type_context.type_infos in
+ let trait_ref = translate_fwd_trait_ref type_infos trait_ref in
+ let generics = translate_fwd_generic_args type_infos generics in
+ let qualif_id = TraitConst (trait_ref, generics, const_name) in
+ let qualif = { id = qualif_id; generics = empty_generic_args } in
+ { e = Qualif qualif; ty = var.ty }
in
(* Make the let-binding *)
@@ -2370,7 +2555,7 @@ and translate_forward_end (ectx : C.eval_ctx)
(* Lookup the effect info for the loop function *)
let fid = A.Regular ctx.fun_decl.A.def_id in
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos fid None ctx.bid
+ get_fun_effect_info ctx.fun_context.fun_infos (A.FunId fid) None ctx.bid
in
(* Introduce a fresh output value for the forward function *)
@@ -2415,14 +2600,8 @@ and translate_forward_end (ectx : C.eval_ctx)
let out_pat = mk_simpl_tuple_pattern out_pats in
let loop_call =
- let fun_id = Fun (FromLlbc (fid, Some loop_id, None)) in
- let func =
- {
- id = FunOrOp fun_id;
- type_args = loop_info.type_args;
- const_generic_args = loop_info.const_generic_args;
- }
- in
+ let fun_id = Fun (FromLlbc (FunId fid, Some loop_id, None)) in
+ let func = { id = FunOrOp fun_id; generics = loop_info.generics } in
let input_tys = (List.map (fun (x : texpression) -> x.ty)) args in
let ret_ty =
if effect_info.can_fail then mk_result_ty out_pat.ty else out_pat.ty
@@ -2541,14 +2720,31 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
(* Note that we will retrieve the input values later in the [ForwardEnd]
(and will introduce the outputs at that moment, together with the actual
- call to the loop forward function *)
- let type_args =
- List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) ctx.sg.type_params
- in
- let const_generic_args =
- List.map
- (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index)
- ctx.sg.const_generic_params
+ call to the loop forward function) *)
+ let generics =
+ let { types; const_generics; trait_clauses } = ctx.sg.generics in
+ let types =
+ List.map (fun (ty : T.type_var) -> TypeVar ty.T.index) types
+ in
+ let const_generics =
+ List.map
+ (fun (cg : T.const_generic_var) -> T.ConstGenericVar cg.T.index)
+ const_generics
+ in
+ let trait_refs =
+ List.map
+ (fun (c : trait_clause) ->
+ let trait_decl_ref =
+ { trait_decl_id = c.trait_id; decl_generics = empty_generic_args }
+ in
+ {
+ trait_id = Clause c.clause_id;
+ generics = empty_generic_args;
+ trait_decl_ref;
+ })
+ trait_clauses
+ in
+ { types; const_generics; trait_refs }
in
let loop_info =
@@ -2556,8 +2752,7 @@ and translate_loop (loop : S.loop) (ctx : bs_ctx) : texpression =
loop_id;
input_vars = inputs;
input_svl = loop.input_svalues;
- type_args;
- const_generic_args;
+ generics;
forward_inputs = None;
forward_output_no_state_no_result = None;
}
@@ -2648,8 +2843,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let func =
{
id = FunOrOp (Fun (Pure FuelEqZero));
- type_args = [];
- const_generic_args = [];
+ generics = empty_generic_args;
}
in
let func_ty = mk_arrow mk_fuel_ty mk_bool_ty in
@@ -2661,8 +2855,7 @@ let wrap_in_match_fuel (fuel0 : VarId.id) (fuel : VarId.id) (body : texpression)
let func =
{
id = FunOrOp (Fun (Pure FuelDecrease));
- type_args = [];
- const_generic_args = [];
+ generics = empty_generic_args;
}
in
let func_ty = mk_arrow mk_fuel_ty mk_fuel_ty in
@@ -2727,8 +2920,8 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
| None -> None
| Some body ->
let effect_info =
- get_fun_effect_info ctx.fun_context.fun_infos (Regular def_id) None
- bid
+ get_fun_effect_info ctx.fun_context.fun_infos (FunId (Regular def_id))
+ None bid
in
let body = translate_expression body ctx in
(* Add a match over the fuel, if necessary *)
@@ -2803,10 +2996,12 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
^ "\n- signature.inputs: "
^ String.concat ", " (List.map (ty_to_string ctx) signature.inputs)
));
- assert (
- List.for_all
- (fun (var, ty) -> (var : var).ty = ty)
- (List.combine inputs signature.inputs));
+ (* TODO: we need to normalize the types *)
+ if !Config.type_check_pure_code then
+ assert (
+ List.for_all
+ (fun (var, ty) -> (var : var).ty = ty)
+ (List.combine inputs signature.inputs));
Some { inputs; inputs_lvs; body }
in
@@ -2821,6 +3016,7 @@ let translate_fun_decl (ctx : bs_ctx) (body : S.expression option) : fun_decl =
let def =
{
def_id;
+ kind = def.kind;
num_loops;
loop_id;
back_id = bid;
@@ -2853,8 +3049,7 @@ let translate_type_decls (type_decls : T.type_decl list) : type_decl list =
- optional names for the outputs values (we derive them for the backward
functions)
*)
-let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
- (type_infos : TA.type_infos)
+let translate_fun_signatures (decls_ctx : C.decls_ctx)
(functions : (A.fun_id * string option list * A.fun_sig) list) :
fun_sig_named_outputs RegularFunIdNotLoopMap.t =
(* For every function, translate the signatures of:
@@ -2865,17 +3060,14 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
(sg : A.fun_sig) : (regular_fun_id_not_loop * fun_sig_named_outputs) list
=
(* The forward function *)
- let fwd_sg =
- translate_fun_sig fun_infos fun_id type_infos sg input_names None
- in
+ let fwd_sg = translate_fun_sig decls_ctx fun_id sg input_names None in
let fwd_id = (fun_id, None) in
(* The backward functions *)
let back_sgs =
List.map
(fun (rg : T.region_var_group) ->
let tsg =
- translate_fun_sig fun_infos fun_id type_infos sg input_names
- (Some rg.id)
+ translate_fun_sig decls_ctx fun_id sg input_names (Some rg.id)
in
let id = (fun_id, Some rg.id) in
(id, tsg))
@@ -2891,3 +3083,91 @@ let translate_fun_signatures (fun_infos : FA.fun_info A.FunDeclId.Map.t)
List.fold_left
(fun m (id, sg) -> RegularFunIdNotLoopMap.add id sg m)
RegularFunIdNotLoopMap.empty translated
+
+let translate_trait_decl (type_infos : TA.type_infos)
+ (trait_decl : A.trait_decl) : trait_decl =
+ let {
+ A.def_id;
+ name;
+ generics;
+ preds;
+ all_trait_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_decl
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let all_trait_clauses = List.map translate_trait_clause all_trait_clauses in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_clauses, ty)) ->
+ ( name,
+ ( List.map translate_trait_clause trait_clauses,
+ Option.map (translate_fwd_ty type_infos) ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ generics;
+ preds;
+ all_trait_clauses;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }
+
+let translate_trait_impl (type_infos : TA.type_infos)
+ (trait_impl : A.trait_impl) : trait_impl =
+ let {
+ A.def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ } =
+ trait_impl
+ in
+ let impl_trait =
+ translate_trait_decl_ref (translate_fwd_ty type_infos) impl_trait
+ in
+ let generics = translate_generic_params generics in
+ let preds = translate_predicates preds in
+ let consts =
+ List.map
+ (fun (name, (ty, id)) -> (name, (translate_fwd_ty type_infos ty, id)))
+ consts
+ in
+ let types =
+ List.map
+ (fun (name, (trait_refs, ty)) ->
+ ( name,
+ ( List.map (translate_fwd_trait_ref type_infos) trait_refs,
+ translate_fwd_ty type_infos ty ) ))
+ types
+ in
+ {
+ def_id;
+ name;
+ impl_trait;
+ generics;
+ preds;
+ consts;
+ types;
+ required_methods;
+ provided_methods;
+ }
diff --git a/compiler/SynthesizeSymbolic.ml b/compiler/SynthesizeSymbolic.ml
index 857fea97..aeb6899f 100644
--- a/compiler/SynthesizeSymbolic.ml
+++ b/compiler/SynthesizeSymbolic.ml
@@ -64,7 +64,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
assert (otherwise_see = None);
(* Return *)
ExpandInt (int_ty, branches, otherwise)
- | T.Adt (_, _, _, _) ->
+ | T.Adt (_, _) ->
(* Branching: it is necessarily an enumeration expansion *)
let get_variant (see : V.symbolic_expansion option) :
T.VariantId.id option * V.symbolic_value list =
@@ -85,7 +85,7 @@ let synthesize_symbolic_expansion (sv : V.symbolic_value)
match ls with
| [ (Some see, exp) ] -> ExpandNoBranch (see, exp)
| _ -> raise (Failure "Ill-formed borrow expansion"))
- | T.TypeVar _ | T.Literal Char | Never ->
+ | T.TypeVar _ | T.Literal Char | Never | T.TraitType _ ->
raise (Failure "Ill-formed symbolic expansion")
in
Some (Expansion (place, sv, expansion))
@@ -97,10 +97,10 @@ let synthesize_symbolic_expansion_no_branching (sv : V.symbolic_value)
synthesize_symbolic_expansion sv place [ Some see ] el
let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
- (abstractions : V.AbstractionId.id list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) (args : V.typed_value list)
- (args_places : mplace option list) (dest : V.symbolic_value)
- (dest_place : mplace option) (e : expression option) : expression option =
+ (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args)
+ (args : V.typed_value list) (args_places : mplace option list)
+ (dest : V.symbolic_value) (dest_place : mplace option)
+ (e : expression option) : expression option =
Option.map
(fun e ->
let call =
@@ -108,8 +108,7 @@ let synthesize_function_call (call_id : call_id) (ctx : Contexts.eval_ctx)
call_id;
ctx;
abstractions;
- type_params;
- const_generic_params;
+ generics;
args;
dest;
args_places;
@@ -123,28 +122,29 @@ let synthesize_global_eval (gid : A.GlobalDeclId.id) (dest : V.symbolic_value)
(e : expression option) : expression option =
Option.map (fun e -> EvalGlobal (gid, dest, e)) e
-let synthesize_regular_function_call (fun_id : A.fun_id)
+let synthesize_regular_function_call (fun_id : A.fun_id_or_trait_method_ref)
(call_id : V.FunCallId.id) (ctx : Contexts.eval_ctx)
- (abstractions : V.AbstractionId.id list) (type_params : T.ety list)
- (const_generic_params : T.const_generic list) (args : V.typed_value list)
- (args_places : mplace option list) (dest : V.symbolic_value)
- (dest_place : mplace option) (e : expression option) : expression option =
+ (abstractions : V.AbstractionId.id list) (generics : T.egeneric_args)
+ (args : V.typed_value list) (args_places : mplace option list)
+ (dest : V.symbolic_value) (dest_place : mplace option)
+ (e : expression option) : expression option =
synthesize_function_call
(Fun (fun_id, call_id))
- ctx abstractions type_params const_generic_params args args_places dest
- dest_place e
+ ctx abstractions generics args args_places dest dest_place e
let synthesize_unary_op (ctx : Contexts.eval_ctx) (unop : E.unop)
(arg : V.typed_value) (arg_place : mplace option) (dest : V.symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
- synthesize_function_call (Unop unop) ctx [] [] [] [ arg ] [ arg_place ] dest
- dest_place e
+ let generics = TypesUtils.mk_empty_generic_args in
+ synthesize_function_call (Unop unop) ctx [] generics [ arg ] [ arg_place ]
+ dest dest_place e
let synthesize_binary_op (ctx : Contexts.eval_ctx) (binop : E.binop)
(arg0 : V.typed_value) (arg0_place : mplace option) (arg1 : V.typed_value)
(arg1_place : mplace option) (dest : V.symbolic_value)
(dest_place : mplace option) (e : expression option) : expression option =
- synthesize_function_call (Binop binop) ctx [] [] [] [ arg0; arg1 ]
+ let generics = TypesUtils.mk_empty_generic_args in
+ synthesize_function_call (Binop binop) ctx [] generics [ arg0; arg1 ]
[ arg0_place; arg1_place ] dest dest_place e
let synthesize_end_abstraction (ctx : Contexts.eval_ctx) (abs : V.abs)
diff --git a/compiler/Translate.ml b/compiler/Translate.ml
index 70ef5e3d..e69abee1 100644
--- a/compiler/Translate.ml
+++ b/compiler/Translate.ml
@@ -5,6 +5,7 @@ module T = Types
module A = LlbcAst
module SA = SymbolicAst
module Micro = PureMicroPasses
+module C = Contexts
open PureUtils
open TranslateCore
@@ -28,18 +29,12 @@ let translate_function_to_symbolics (trans_ctx : trans_ctx) (fdef : A.fun_decl)
("translate_function_to_symbolics: "
^ Print.fun_name_to_string fdef.A.name));
- let { type_context; fun_context; global_context } = trans_ctx in
- let fun_context = { C.fun_decls = fun_context.fun_decls } in
-
match fdef.body with
| None -> None
| Some _ ->
(* Evaluate *)
let synthesize = true in
- let inputs, symb =
- evaluate_function_symbolic synthesize type_context fun_context
- global_context fdef
- in
+ let inputs, symb = evaluate_function_symbolic synthesize trans_ctx fdef in
Some (inputs, Option.get symb)
(** Translate a function, by generating its forward and backward translations.
@@ -57,7 +52,6 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(lazy
("translate_function_to_pure: " ^ Print.fun_name_to_string fdef.A.name));
- let { type_context; fun_context; global_context } = trans_ctx in
let def_id = fdef.def_id in
(* Compute the symbolic ASTs, if the function is transparent *)
@@ -82,25 +76,25 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(List.filter_map
(fun (tid, g) ->
match g with Charon.GAst.NonRec _ -> None | Rec _ -> Some tid)
- (T.TypeDeclId.Map.bindings trans_ctx.type_context.type_decls_groups))
+ (T.TypeDeclId.Map.bindings trans_ctx.type_ctx.type_decls_groups))
in
let type_context =
{
- SymbolicToPure.type_infos = type_context.type_infos;
- llbc_type_decls = type_context.type_decls;
+ SymbolicToPure.type_infos = trans_ctx.type_ctx.type_infos;
+ llbc_type_decls = trans_ctx.type_ctx.type_decls;
type_decls = pure_type_decls;
recursive_decls = recursive_type_decls;
}
in
let fun_context =
{
- SymbolicToPure.llbc_fun_decls = fun_context.fun_decls;
+ SymbolicToPure.llbc_fun_decls = trans_ctx.fun_ctx.fun_decls;
fun_sigs;
- fun_infos = fun_context.fun_infos;
+ fun_infos = trans_ctx.fun_ctx.fun_infos;
}
in
let global_context =
- { SymbolicToPure.llbc_global_decls = global_context.global_decls }
+ { SymbolicToPure.llbc_global_decls = trans_ctx.global_ctx.global_decls }
in
(* Compute the set of loops, and find better ids for them (starting at 0).
@@ -148,6 +142,8 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
type_context;
fun_context;
global_context;
+ trait_decls_ctx = trans_ctx.trait_decls_ctx.trait_decls;
+ trait_impls_ctx = trans_ctx.trait_impls_ctx.trait_impls;
fun_decl = fdef;
forward_inputs = [];
(* Empty for now *)
@@ -274,21 +270,18 @@ let translate_function_to_pure (trans_ctx : trans_ctx)
(* Return *)
(pure_forward, pure_backwards)
+(* TODO: factor out the return type *)
let translate_crate_to_pure (crate : A.crate) :
- trans_ctx * Pure.type_decl list * (bool * pure_fun_translation) list =
+ trans_ctx
+ * Pure.type_decl list
+ * pure_fun_translation list
+ * Pure.trait_decl list
+ * Pure.trait_impl list =
(* Debug *)
log#ldebug (lazy "translate_crate_to_pure");
- (* Compute the type and function contexts *)
- let type_context, fun_context, global_context =
- compute_type_fun_global_contexts crate
- in
- let fun_infos =
- FA.analyze_module crate fun_context.C.fun_decls
- global_context.C.global_decls !Config.use_state
- in
- let fun_context = { fun_decls = fun_context.fun_decls; fun_infos } in
- let trans_ctx = { type_context; fun_context; global_context } in
+ (* Compute the translation context *)
+ let trans_ctx = compute_contexts crate in
(* Translate all the type definitions *)
let type_decls =
@@ -323,10 +316,7 @@ let translate_crate_to_pure (crate : A.crate) :
(A.FunDeclId.Map.values crate.functions)
in
let sigs = List.append assumed_sigs local_sigs in
- let fun_sigs =
- SymbolicToPure.translate_fun_signatures fun_context.fun_infos
- type_context.type_infos sigs
- in
+ let fun_sigs = SymbolicToPure.translate_fun_signatures trans_ctx sigs in
(* Translate all the *transparent* functions *)
let pure_translations =
@@ -335,28 +325,38 @@ let translate_crate_to_pure (crate : A.crate) :
(A.FunDeclId.Map.values crate.functions)
in
+ (* Translate the trait declarations *)
+ let type_infos = trans_ctx.type_ctx.type_infos in
+ let trait_decls =
+ List.map
+ (SymbolicToPure.translate_trait_decl type_infos)
+ (T.TraitDeclId.Map.values trans_ctx.trait_decls_ctx.trait_decls)
+ in
+
+ (* Translate the trait implementations *)
+ let trait_impls =
+ List.map
+ (SymbolicToPure.translate_trait_impl type_infos)
+ (T.TraitImplId.Map.values trans_ctx.trait_impls_ctx.trait_impls)
+ in
+
(* Apply the micro-passes *)
let pure_translations =
Micro.apply_passes_to_pure_fun_translations trans_ctx pure_translations
in
(* Return *)
- (trans_ctx, type_decls, pure_translations)
-
-(** Extraction context *)
-type gen_ctx = {
- crate : A.crate;
- extract_ctx : ExtractBase.extraction_ctx;
- trans_types : Pure.type_decl Pure.TypeDeclId.Map.t;
- trans_funs : (bool * pure_fun_translation) A.FunDeclId.Map.t;
- functions_with_decreases_clause : PureUtils.FunLoopIdSet.t;
-}
+ (trans_ctx, type_decls, pure_translations, trait_decls, trait_impls)
+
+type gen_ctx = ExtractBase.extraction_ctx
type gen_config = {
extract_types : bool;
extract_decreases_clauses : bool;
extract_template_decreases_clauses : bool;
extract_fun_decls : bool;
+ extract_trait_decls : bool;
+ extract_trait_impls : bool;
extract_transparent : bool;
(** If [true], extract the transparent declarations, otherwise ignore. *)
extract_opaque : bool;
@@ -383,21 +383,23 @@ type gen_config = {
test_trans_unit_functions : bool;
}
-(** Returns the pair: (has opaque type decls, has opaque fun decls) *)
-let module_has_opaque_decls (ctx : gen_ctx) : bool * bool =
- let has_opaque_types =
- Pure.TypeDeclId.Map.exists
- (fun _ (d : Pure.type_decl) ->
- match d.kind with Opaque -> true | _ -> false)
- ctx.trans_types
- in
- let has_opaque_funs =
- A.FunDeclId.Map.exists
- (fun _ ((_, ((t_fwd, _), _)) : bool * pure_fun_translation) ->
- Option.is_none t_fwd.body)
- ctx.trans_funs
+(** Returns the pair: (has opaque type decls, has opaque fun decls).
+
+ [filter_assumed]: if [true], do not consider as opaque the external definitions
+ that we will map to definitions from the standard library.
+ *)
+let crate_has_opaque_decls (ctx : gen_ctx) (filter_assumed : bool) : bool * bool
+ =
+ let types, funs =
+ LlbcAstUtils.crate_get_opaque_decls ctx.crate filter_assumed
in
- (has_opaque_types, has_opaque_funs)
+ log#ldebug
+ (lazy
+ ("Opaque decls:" ^ "\n- types:\n"
+ ^ String.concat ",\n" (List.map T.show_type_decl types)
+ ^ "\n- functions:\n"
+ ^ String.concat ",\n" (List.map A.show_fun_decl funs)));
+ (types <> [], funs <> [])
(** Export a type declaration.
@@ -429,9 +431,9 @@ let export_type (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
|| ((not is_opaque) && config.extract_transparent)
then (
if extract_decl then
- Extract.extract_type_decl ctx.extract_ctx fmt type_decl_group kind def;
+ Extract.extract_type_decl ctx fmt type_decl_group kind def;
if extract_extra_info then
- Extract.extract_type_decl_extra_info ctx.extract_ctx fmt kind def)
+ Extract.extract_type_decl_extra_info ctx fmt kind def)
(** Export a group of types.
@@ -483,7 +485,7 @@ let export_types_group (fmt : Format.formatter) (config : gen_config)
End
]}
*)
- Extract.start_type_decl_group ctx.extract_ctx fmt is_rec defs;
+ Extract.start_type_decl_group ctx fmt is_rec defs;
List.iteri
(fun i def ->
let kind = kind_from_index i in
@@ -504,26 +506,34 @@ let export_types_group (fmt : Format.formatter) (config : gen_config)
*)
let export_global (fmt : Format.formatter) (config : gen_config) (ctx : gen_ctx)
(id : A.GlobalDeclId.id) : unit =
- let global_decls = ctx.extract_ctx.trans_ctx.global_context.global_decls in
+ let global_decls = ctx.trans_ctx.global_ctx.global_decls in
let global = A.GlobalDeclId.Map.find id global_decls in
- let _, ((body, loop_fwds), body_backs) =
- A.FunDeclId.Map.find global.body_id ctx.trans_funs
- in
- assert (body_backs = []);
- assert (loop_fwds = []);
+ let trans = A.FunDeclId.Map.find global.body_id ctx.trans_funs in
+ assert (trans.fwd.loops = []);
+ assert (trans.backs = []);
+ let body = trans.fwd.f in
let is_opaque = Option.is_none body.Pure.body in
- if
+ (* Check if we extract the global *)
+ let extract =
config.extract_globals
&& (((not is_opaque) && config.extract_transparent)
|| (is_opaque && config.extract_opaque))
- then
+ in
+ (* Check if it is an assumed global - if yes, we ignore it because we
+ map the definition to one in the standard library *)
+ let open ExtractAssumed in
+ let sname = name_to_simple_name global.name in
+ let extract =
+ extract && SimpleNameMap.find_opt sname assumed_globals_map = None
+ in
+ if extract then
(* We don't wrap global declaration groups between calls to functions
[{start, end}_global_decl_group] (which don't exist): global declaration
groups are always singletons, so the [extract_global_decl] function
takes care of generating the delimiters.
*)
- Extract.extract_global_decl ctx.extract_ctx fmt global body config.interface
+ Extract.extract_global_decl ctx fmt global body config.interface
(** Utility.
@@ -604,14 +614,13 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config)
then
Some
(fun () ->
- Extract.extract_fun_decl ctx.extract_ctx fmt kind has_decr_clause
- def)
+ Extract.extract_fun_decl ctx fmt kind has_decr_clause def)
else None)
decls
in
let extract_defs = List.filter_map (fun x -> x) extract_defs in
if extract_defs <> [] then (
- Extract.start_fun_decl_group ctx.extract_ctx fmt is_rec decls;
+ Extract.start_fun_decl_group ctx fmt is_rec decls;
List.iter (fun f -> f ()) extract_defs;
Extract.end_fun_decl_group fmt is_rec decls)
@@ -621,7 +630,7 @@ let export_functions_group_scc (fmt : Format.formatter) (config : gen_config)
check if the forward and backward functions are mutually recursive.
*)
let export_functions_group (fmt : Format.formatter) (config : gen_config)
- (ctx : gen_ctx) (pure_ls : (bool * pure_fun_translation) list) : unit =
+ (ctx : gen_ctx) (pure_ls : pure_fun_translation list) : unit =
(* Utility to check a function has a decrease clause *)
let has_decreases_clause (def : Pure.fun_decl) : bool =
PureUtils.FunLoopIdSet.mem (def.def_id, def.loop_id)
@@ -631,7 +640,7 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Extract the decrease clauses template bodies *)
if config.extract_template_decreases_clauses then
List.iter
- (fun (_, ((fwd, loop_fwds), _)) ->
+ (fun { fwd; _ } ->
(* We only generate decreases clauses for the forward functions, because
the termination argument should only depend on the forward inputs.
The backward functions thus use the same decreases clauses as the
@@ -647,19 +656,18 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
if has_decr_clause then
match !Config.backend with
| Lean ->
- Extract.extract_template_lean_termination_and_decreasing
- ctx.extract_ctx fmt decl
+ Extract.extract_template_lean_termination_and_decreasing ctx fmt
+ decl
| FStar ->
- Extract.extract_template_fstar_decreases_clause ctx.extract_ctx
- fmt decl
+ Extract.extract_template_fstar_decreases_clause ctx fmt decl
| Coq ->
raise (Failure "Coq doesn't have decreases/termination clauses")
| HOL4 ->
raise
(Failure "HOL4 doesn't have decreases/termination clauses")
in
- extract_decrease fwd;
- List.iter extract_decrease loop_fwds)
+ extract_decrease fwd.f;
+ List.iter extract_decrease fwd.loops)
pure_ls;
(* Concatenate the function definitions, filtering the useless forward
@@ -667,15 +675,13 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
let decls =
List.concat
(List.map
- (fun (keep_fwd, ((fwd, fwd_loops), (back_ls : fun_and_loops list))) ->
- let fwd = if keep_fwd then List.append fwd_loops [ fwd ] else [] in
- let back : Pure.fun_decl list =
+ (fun { keep_fwd; fwd; backs } ->
+ let fwd = if keep_fwd then List.append fwd.loops [ fwd.f ] else [] in
+ let backs : Pure.fun_decl list =
List.concat
- (List.map
- (fun (back, loop_backs) -> List.append loop_backs [ back ])
- back_ls)
+ (List.map (fun back -> List.append back.loops [ back.f ]) backs)
in
- List.append fwd back)
+ List.append fwd backs)
pure_ls)
in
@@ -693,11 +699,24 @@ let export_functions_group (fmt : Format.formatter) (config : gen_config)
(* Insert unit tests if necessary *)
if config.test_trans_unit_functions then
List.iter
- (fun (keep_fwd, ((fwd, _), _)) ->
- if keep_fwd then
- Extract.extract_unit_test_if_unit_fun ctx.extract_ctx fmt fwd)
+ (fun trans ->
+ if trans.keep_fwd then
+ Extract.extract_unit_test_if_unit_fun ctx fmt trans.fwd.f)
pure_ls
+(** Export a trait declaration. *)
+let export_trait_decl (fmt : Format.formatter) (_config : gen_config)
+ (ctx : gen_ctx) (trait_decl_id : Pure.trait_decl_id) : unit =
+ let trait_decl = T.TraitDeclId.Map.find trait_decl_id ctx.trans_trait_decls in
+ let ctx = { ctx with trait_decl_id = Some trait_decl.def_id } in
+ Extract.extract_trait_decl ctx fmt trait_decl
+
+(** Export a trait implementation. *)
+let export_trait_impl (fmt : Format.formatter) (_config : gen_config)
+ (ctx : gen_ctx) (trait_impl_id : Pure.trait_impl_id) : unit =
+ let trait_impl = T.TraitImplId.Map.find trait_impl_id ctx.trans_trait_impls in
+ Extract.extract_trait_impl ctx fmt trait_impl
+
(** A generic utility to generate the extracted definitions: as we may want to
split the definitions between different files (or not), we can control
what is precisely extracted.
@@ -712,12 +731,14 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
let export_functions_group = export_functions_group fmt config ctx in
let export_global = export_global fmt config ctx in
let export_types_group = export_types_group fmt config ctx in
+ let export_trait_decl = export_trait_decl fmt config ctx in
+ let export_trait_impl = export_trait_impl fmt config ctx in
let export_state_type () : unit =
let kind =
if config.interface then ExtractBase.Declared else ExtractBase.Assumed
in
- Extract.extract_state_type fmt ctx.extract_ctx kind
+ Extract.extract_state_type fmt ctx kind
in
let export_decl_group (dg : A.declaration_group) : unit =
@@ -725,11 +746,18 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
| Type (NonRec id) ->
if config.extract_types then export_types_group false [ id ]
| Type (Rec ids) -> if config.extract_types then export_types_group true ids
- | Fun (NonRec id) ->
+ | Fun (NonRec id) -> (
(* Lookup *)
let pure_fun = A.FunDeclId.Map.find id ctx.trans_funs in
- (* Translate *)
- export_functions_group [ pure_fun ]
+ (* Special case: we skip trait method *declarations* (we will
+ extract their type directly in the records we generate for
+ the trait declarations themselves, there is no point in having
+ separate type definitions) *)
+ match pure_fun.fwd.f.Pure.kind with
+ | TraitMethodDecl _ -> ()
+ | _ ->
+ (* Translate *)
+ export_functions_group [ pure_fun ])
| Fun (Rec ids) ->
(* General case of mutually recursive functions *)
(* Lookup *)
@@ -739,11 +767,17 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
(* Translate *)
export_functions_group pure_funs
| Global id -> export_global id
+ | TraitDecl id ->
+ if config.extract_trait_decls && config.extract_transparent then
+ export_trait_decl id
+ | TraitImpl id ->
+ if config.extract_trait_impls && config.extract_transparent then
+ export_trait_impl id
in
(* If we need to export the state type: we try to export it after we defined
* the type definitions, because if the user wants to define a model for the
- * type, he might want to reuse those in the state type.
+ * type, they might want to reuse those in the state type.
* More specifically: if we extract functions in the same file as the type,
* we have no choice but to define the state type before the functions,
* because they may reuse this state type: in this case, we define/declare
@@ -764,7 +798,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
config.extract_opaque && config.extract_fun_decls
&& !Config.wrap_opaque_in_sig
&&
- let _, opaque_funs = module_has_opaque_decls ctx in
+ let _, opaque_funs = crate_has_opaque_decls ctx true in
opaque_funs
in
if wrap_in_sig then (
@@ -774,7 +808,7 @@ let extract_definitions (fmt : Format.formatter) (config : gen_config)
if config.extract_transparent then "Definitions" else "OpaqueDefs"
in
Format.pp_print_break fmt 0 0;
- Format.pp_open_vbox fmt ctx.extract_ctx.indent_incr;
+ Format.pp_open_vbox fmt ctx.indent_incr;
Format.pp_print_string fmt ("structure " ^ struct_name ^ " where");
Format.pp_print_break fmt 0 0);
List.iter export_decl_group ctx.crate.declarations;
@@ -904,7 +938,9 @@ let extract_file (config : gen_config) (ctx : gen_ctx) (fi : extract_file_info)
let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
unit =
(* Translate the module to the pure AST *)
- let trans_ctx, trans_types, trans_funs = translate_crate_to_pure crate in
+ let trans_ctx, trans_types, trans_funs, trans_trait_decls, trans_trait_impls =
+ translate_crate_to_pure crate
+ in
(* Initialize the extraction context - for now we extract only to F*.
* We initialize the names map by registering the keywords used in the
@@ -921,36 +957,36 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
mk_formatter_and_names_map trans_ctx crate.name
variant_concatenate_type_name
in
- (* Put everything in the context *)
- let ctx =
- {
- ExtractBase.trans_ctx;
- names_map;
- unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty };
- fmt;
- indent_incr = 2;
- use_opaque_pre = !Config.split_files;
- use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
- fun_name_info = PureUtils.RegularFunIdMap.empty;
- }
+ let strict_names_map =
+ let open ExtractBase in
+ let ids =
+ List.filter
+ (fun (id, _) -> strict_collisions id)
+ (IdMap.bindings names_map.id_to_name)
+ in
+ let is_opaque = false in
+ List.fold_left
+ (* id_to_string: we shouldn't need to use it *)
+ (fun m (id, n) -> names_map_add show_id is_opaque id n m)
+ empty_names_map ids
in
(* We need to compute which functions are recursive, in order to know
* whether we should generate a decrease clause or not. *)
let rec_functions =
List.map
- (fun (_, ((fwd, loop_fwds), _)) ->
- let fwd =
- if fwd.Pure.signature.info.effect_info.is_rec then
- [ (fwd.def_id, None) ]
+ (fun { fwd; _ } ->
+ let fwd_f =
+ if fwd.f.Pure.signature.info.effect_info.is_rec then
+ [ (fwd.f.def_id, None) ]
else []
in
let loop_fwds =
List.map
(fun (def : Pure.fun_decl) -> [ (def.def_id, def.loop_id) ])
- loop_fwds
+ fwd.loops
in
- fwd :: loop_fwds)
+ fwd_f :: loop_fwds)
trans_funs
in
let rec_functions : PureUtils.fun_loop_id list =
@@ -958,22 +994,70 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
in
let rec_functions = PureUtils.FunLoopIdSet.of_list rec_functions in
- (* Register unique names for all the top-level types, globals and functions.
+ (* Put the translated definitions in maps *)
+ let trans_types =
+ Pure.TypeDeclId.Map.of_list
+ (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types)
+ in
+ let trans_funs : pure_fun_translation A.FunDeclId.Map.t =
+ A.FunDeclId.Map.of_list
+ (List.map
+ (fun (trans : pure_fun_translation) -> (trans.fwd.f.def_id, trans))
+ trans_funs)
+ in
+
+ (* Put everything in the context *)
+ let ctx =
+ let trans_trait_decls =
+ T.TraitDeclId.Map.of_list
+ (List.map
+ (fun (d : Pure.trait_decl) -> (d.def_id, d))
+ trans_trait_decls)
+ in
+ let trans_trait_impls =
+ T.TraitImplId.Map.of_list
+ (List.map
+ (fun (d : Pure.trait_impl) -> (d.def_id, d))
+ trans_trait_impls)
+ in
+ {
+ ExtractBase.crate;
+ trans_ctx;
+ names_map;
+ unsafe_names_map = { id_to_name = ExtractBase.IdMap.empty };
+ strict_names_map;
+ fmt;
+ indent_incr = 2;
+ use_opaque_pre = !Config.split_files;
+ use_dep_ite = !Config.backend = Lean && !Config.extract_decreases_clauses;
+ fun_name_info = PureUtils.RegularFunIdMap.empty;
+ trait_decl_id = None (* None by default *);
+ is_provided_method = false (* false by default *);
+ trans_trait_decls;
+ trans_trait_impls;
+ trans_types;
+ trans_funs;
+ functions_with_decreases_clause = rec_functions;
+ }
+ in
+
+ (* Register unique names for all the top-level types, globals, functions...
* Note that the order in which we generate the names doesn't matter:
* we just need to generate a mapping from identifier to name, and make
* sure there are no name clashes. *)
let ctx =
List.fold_left
(fun ctx def -> Extract.extract_type_decl_register_names ctx def)
- ctx trans_types
+ ctx
+ (Pure.TypeDeclId.Map.values trans_types)
in
let ctx =
List.fold_left
- (fun ctx (keep_fwd, defs) ->
+ (fun ctx (trans : pure_fun_translation) ->
(* If requested by the user, register termination measures and decreases
proofs for all the recursive functions *)
- let fwd_def = fst (fst defs) in
+ let fwd_def = trans.fwd.f in
let gen_decr_clause (def : Pure.fun_decl) =
!Config.extract_decreases_clauses
&& PureUtils.FunLoopIdSet.mem
@@ -984,10 +1068,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
* those are handled later *)
let is_global = fwd_def.Pure.is_global_decl_body in
if is_global then ctx
- else
- Extract.extract_fun_decl_register_names ctx keep_fwd gen_decr_clause
- defs)
- ctx trans_funs
+ else Extract.extract_fun_decl_register_names ctx gen_decr_clause trans)
+ ctx
+ (A.FunDeclId.Map.values trans_funs)
in
let ctx =
@@ -995,6 +1078,16 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(A.GlobalDeclId.Map.values crate.globals)
in
+ let ctx =
+ List.fold_left Extract.extract_trait_decl_register_names ctx
+ trans_trait_decls
+ in
+
+ let ctx =
+ List.fold_left Extract.extract_trait_impl_register_names ctx
+ trans_trait_impls
+ in
+
(* Open the output file *)
(* First compute the filename by replacing the extension and converting the
* case (rust module names are snake case) *)
@@ -1023,19 +1116,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(namespace, crate_name, Filename.concat dest_dir crate_name)
in
- (* Put the translated definitions in maps *)
- let trans_types =
- Pure.TypeDeclId.Map.of_list
- (List.map (fun (d : Pure.type_decl) -> (d.def_id, d)) trans_types)
- in
- let trans_funs =
- A.FunDeclId.Map.of_list
- (List.map
- (fun ((keep_fwd, (fd, bdl)) : bool * pure_fun_translation) ->
- ((fst fd).def_id, (keep_fwd, (fd, bdl))))
- trans_funs)
- in
-
let mkdir_if dest_dir =
if not (Sys.file_exists dest_dir) then (
log#linfo (lazy ("Creating missing directory: " ^ dest_dir));
@@ -1091,16 +1171,6 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
in
(* Extract the file(s) *)
- let gen_ctx =
- {
- crate;
- extract_ctx = ctx;
- trans_types;
- trans_funs;
- functions_with_decreases_clause = rec_functions;
- }
- in
-
let module_delimiter =
match !Config.backend with
| FStar -> "."
@@ -1136,6 +1206,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
extract_decreases_clauses = !Config.extract_decreases_clauses;
extract_template_decreases_clauses = false;
extract_fun_decls = false;
+ extract_trait_decls = false;
+ extract_trait_impls = false;
extract_transparent = true;
extract_opaque = false;
extract_state_type = false;
@@ -1147,7 +1219,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
(* Check if there are opaque types and functions - in which case we need
* to split *)
- let has_opaque_types, has_opaque_funs = module_has_opaque_decls gen_ctx in
+ let has_opaque_types, has_opaque_funs = crate_has_opaque_decls ctx true in
let has_opaque_types = has_opaque_types || !Config.use_state in
(* Extract the types *)
@@ -1168,6 +1240,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
{
base_gen_config with
extract_types = true;
+ extract_trait_decls = true;
extract_opaque = true;
extract_state_type = !Config.use_state;
interface = has_opaque_types;
@@ -1186,7 +1259,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file types_config gen_ctx file_info;
+ extract_file types_config ctx file_info;
(* Extract the template clauses *)
(if needs_clauses_module && !Config.extract_template_decreases_clauses then
@@ -1214,9 +1287,9 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file template_clauses_config gen_ctx file_info);
+ extract_file template_clauses_config ctx file_info);
- (* Extract the opaque functions, if needed *)
+ (* Extract the opaque declarations, if needed *)
let opaque_funs_module =
if has_opaque_funs then (
(* In the case of Lean we generate a template file *)
@@ -1244,17 +1317,14 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
{
base_gen_config with
extract_fun_decls = true;
+ extract_trait_impls = true;
+ extract_globals = true;
extract_transparent = false;
extract_opaque = true;
interface = true;
}
in
- let gen_ctx =
- {
- gen_ctx with
- extract_ctx = { gen_ctx.extract_ctx with use_opaque_pre = false };
- }
- in
+ let ctx = { ctx with use_opaque_pre = false } in
let file_info =
{
filename = opaque_filename;
@@ -1268,7 +1338,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [ types_module ];
}
in
- extract_file opaque_config gen_ctx file_info;
+ extract_file opaque_config ctx file_info;
(* Return the additional dependencies *)
[ opaque_imported_module ])
else []
@@ -1281,6 +1351,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
{
base_gen_config with
extract_fun_decls = true;
+ extract_trait_impls = true;
extract_globals = true;
test_trans_unit_functions = !Config.test_trans_unit_functions;
}
@@ -1307,7 +1378,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
[ types_module ] @ opaque_funs_module @ clauses_module;
}
in
- extract_file fun_config gen_ctx file_info)
+ extract_file fun_config ctx file_info)
else
let gen_config =
{
@@ -1316,6 +1387,8 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
extract_template_decreases_clauses =
!Config.extract_template_decreases_clauses;
extract_fun_decls = true;
+ extract_trait_decls = true;
+ extract_trait_impls = true;
extract_transparent = true;
extract_opaque = true;
extract_state_type = !Config.use_state;
@@ -1337,7 +1410,7 @@ let translate_crate (filename : string) (dest_dir : string) (crate : A.crate) :
custom_includes = [];
}
in
- extract_file gen_config gen_ctx file_info);
+ extract_file gen_config ctx file_info);
(* Generate the build file *)
match !Config.backend with
diff --git a/compiler/TranslateCore.ml b/compiler/TranslateCore.ml
index ba5e237b..3427fd43 100644
--- a/compiler/TranslateCore.ml
+++ b/compiler/TranslateCore.ml
@@ -10,64 +10,69 @@ module FA = FunsAnalysis
(** The local logger *)
let log = L.translate_log
-type type_context = C.type_context [@@deriving show]
-
-type fun_context = {
- fun_decls : A.fun_decl A.FunDeclId.Map.t;
- fun_infos : FA.fun_info A.FunDeclId.Map.t;
-}
-[@@deriving show]
+type trans_ctx = C.decls_ctx [@@deriving show]
+type fun_and_loops = { f : Pure.fun_decl; loops : Pure.fun_decl list }
+type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
-type global_context = C.global_context [@@deriving show]
+type pure_fun_translation = {
+ keep_fwd : bool;
+ (** Should we extract the forward function?
-type trans_ctx = {
- type_context : type_context;
- fun_context : fun_context;
- global_context : global_context;
+ If the forward function returns `()` and there is exactly one
+ backward function, we may merge the forward into the backward
+ function and thus don't extract the forward function)?
+ *)
+ fwd : fun_and_loops;
+ backs : fun_and_loops list;
}
-type fun_and_loops = Pure.fun_decl * Pure.fun_decl list
-type pure_fun_translation_no_loops = Pure.fun_decl * Pure.fun_decl list
-type pure_fun_translation = fun_and_loops * fun_and_loops list
+let trans_ctx_to_type_formatter (ctx : trans_ctx)
+ (type_params : Pure.type_var list)
+ (const_generic_params : Pure.const_generic_var list) :
+ PrintPure.type_formatter =
+ let type_decls = ctx.type_ctx.type_decls in
+ let global_decls = ctx.global_ctx.global_decls in
+ let trait_decls = ctx.trait_decls_ctx.trait_decls in
+ let trait_impls = ctx.trait_impls_ctx.trait_impls in
+ PrintPure.mk_type_formatter type_decls global_decls trait_decls trait_impls
+ type_params const_generic_params
let type_decl_to_string (ctx : trans_ctx) (def : Pure.type_decl) : string =
- let type_params = def.type_params in
- let cg_params = def.const_generic_params in
- let type_decls = ctx.type_context.type_decls in
- let global_decls = ctx.global_context.global_decls in
+ let generics = def.generics in
let fmt =
- PrintPure.mk_type_formatter type_decls global_decls type_params cg_params
+ trans_ctx_to_type_formatter ctx generics.types generics.const_generics
in
PrintPure.type_decl_to_string fmt def
let type_id_to_string (ctx : trans_ctx) (id : Pure.TypeDeclId.id) : string =
Print.fun_name_to_string
- (Pure.TypeDeclId.Map.find id ctx.type_context.type_decls).name
+ (Pure.TypeDeclId.Map.find id ctx.type_ctx.type_decls).name
+
+let trans_ctx_to_ast_formatter (ctx : trans_ctx)
+ (type_params : Pure.type_var list)
+ (const_generic_params : Pure.const_generic_var list) :
+ PrintPure.ast_formatter =
+ let type_decls = ctx.type_ctx.type_decls in
+ let fun_decls = ctx.fun_ctx.fun_decls in
+ let global_decls = ctx.global_ctx.global_decls in
+ let trait_decls = ctx.trait_decls_ctx.trait_decls in
+ let trait_impls = ctx.trait_impls_ctx.trait_impls in
+ PrintPure.mk_ast_formatter type_decls fun_decls global_decls trait_decls
+ trait_impls type_params const_generic_params
let fun_sig_to_string (ctx : trans_ctx) (sg : Pure.fun_sig) : string =
- let type_params = sg.type_params in
- let cg_params = sg.const_generic_params in
- let type_decls = ctx.type_context.type_decls in
- let fun_decls = ctx.fun_context.fun_decls in
- let global_decls = ctx.global_context.global_decls in
+ let generics = sg.generics in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ trans_ctx_to_ast_formatter ctx generics.types generics.const_generics
in
PrintPure.fun_sig_to_string fmt sg
let fun_decl_to_string (ctx : trans_ctx) (def : Pure.fun_decl) : string =
- let type_params = def.signature.type_params in
- let cg_params = def.signature.const_generic_params in
- let type_decls = ctx.type_context.type_decls in
- let fun_decls = ctx.fun_context.fun_decls in
- let global_decls = ctx.global_context.global_decls in
+ let generics = def.signature.generics in
let fmt =
- PrintPure.mk_ast_formatter type_decls fun_decls global_decls type_params
- cg_params
+ trans_ctx_to_ast_formatter ctx generics.types generics.const_generics
in
PrintPure.fun_decl_to_string fmt def
let fun_decl_id_to_string (ctx : trans_ctx) (id : A.FunDeclId.id) : string =
- Print.fun_name_to_string
- (A.FunDeclId.Map.find id ctx.fun_context.fun_decls).name
+ Print.fun_name_to_string (A.FunDeclId.Map.find id ctx.fun_ctx.fun_decls).name
diff --git a/compiler/TypesAnalysis.ml b/compiler/TypesAnalysis.ml
index 925f6d39..95c7206a 100644
--- a/compiler/TypesAnalysis.ml
+++ b/compiler/TypesAnalysis.ml
@@ -14,11 +14,10 @@ type expl_info = subtype_info [@@deriving show]
type type_borrows_info = {
contains_static : bool;
- (** Does the type (transitively) contains a static borrow? *)
- contains_borrow : bool;
- (** Does the type (transitively) contains a borrow? *)
+ (** Does the type (transitively) contain a static borrow? *)
+ contains_borrow : bool; (** Does the type (transitively) contain a borrow? *)
contains_nested_borrows : bool;
- (** Does the type (transitively) contains nested borrows? *)
+ (** Does the type (transitively) contain nested borrows? *)
contains_borrow_under_mut : bool;
}
[@@deriving show]
@@ -61,7 +60,7 @@ let initialize_g_type_info (param_infos : 'p) : 'p g_type_info =
let initialize_type_decl_info (def : type_decl) : type_decl_info =
let param_info = { under_borrow = false; under_mut_borrow = false } in
- let param_infos = List.map (fun _ -> param_info) def.type_params in
+ let param_infos = List.map (fun _ -> param_info) def.generics.types in
initialize_g_type_info param_infos
let type_decl_info_to_partial_type_info (info : type_decl_info) :
@@ -122,7 +121,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
let rec analyze (expl_info : expl_info) (ty_info : partial_type_info)
(ty : 'r ty) : partial_type_info =
match ty with
- | Literal _ | Never -> ty_info
+ | Literal _ | Never | TraitType _ -> ty_info
| TypeVar var_id -> (
(* Update the information for the proper parameter, if necessary *)
match ty_info.param_infos with
@@ -171,20 +170,18 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
analyze expl_info ty_info rty
| Adt
( (Tuple | Assumed (Box | Vec | Option | Slice | Array | Str | Range)),
- _,
- tys,
- _ ) ->
+ generics ) ->
(* Nothing to update: just explore the type parameters *)
List.fold_left
(fun ty_info ty -> analyze expl_info ty_info ty)
- ty_info tys
- | Adt (AdtId adt_id, regions, tys, _cgs) ->
+ ty_info generics.types
+ | Adt (AdtId adt_id, generics) ->
(* Lookup the information for this type definition *)
let adt_info = TypeDeclId.Map.find adt_id infos in
(* Update the type info with the information from the adt *)
let ty_info = update_ty_info ty_info adt_info.borrows_info in
(* Check if 'static appears in the region parameters *)
- let found_static = List.exists r_is_static regions in
+ let found_static = List.exists r_is_static generics.regions in
let borrows_info = ty_info.borrows_info in
let borrows_info =
{
@@ -196,7 +193,7 @@ let analyze_full_ty (r_is_static : 'r -> bool) (updated : bool ref)
let ty_info = { ty_info with borrows_info } in
(* For every instantiated type parameter: update the exploration info
* then explore the type *)
- let params_tys = List.combine adt_info.param_infos tys in
+ let params_tys = List.combine adt_info.param_infos generics.types in
let ty_info =
List.fold_left
(fun ty_info (param_info, ty) ->
diff --git a/compiler/Values.ml b/compiler/Values.ml
index d884c319..de27e7a9 100644
--- a/compiler/Values.ml
+++ b/compiler/Values.ml
@@ -52,6 +52,10 @@ type sv_kind =
(** The result of a loop join (when computing loop fixed points) *)
| Aggregate
(** A symbolic value we introduce in place of an aggregate value *)
+ | ConstGeneric
+ (** A symbolic value we introduce when using a const generic as a value *)
+ | TraitConst
+ (** A symbolic value we introduce when evaluating a trait associated constant *)
[@@deriving show, ord]
(** Ancestor for {!symbolic_value} iter visitor *)
diff --git a/compiler/dune b/compiler/dune
index 6785cad4..2f5a0a44 100644
--- a/compiler/dune
+++ b/compiler/dune
@@ -12,6 +12,7 @@
(pps ppx_deriving.show ppx_deriving.ord visitors.ppx))
(libraries charon core_unix unionFind ocamlgraph)
(modules
+ AssociatedTypes
Assumed
Collections
Config
@@ -21,6 +22,7 @@
Expressions
ExpressionsUtils
Extract
+ ExtractAssumed
ExtractBase
FunsAnalysis
Identifiers